In [None]:
# Toy dimensions for the sequence example: batch/sequence sizes first,
# then the input/output embedding widths.
batch_size, sequence_length = 2, 4
d_embedding_in, d_embedding_out = 6, 7

In [2]:
import torch

In [3]:
# Random toy batch shaped (batch_size, sequence_length, d_embedding_in).
x = torch.randn((batch_size, sequence_length, d_embedding_in))
x.shape

torch.Size([2, 4, 6])

In [10]:
from torch.nn import Parameter
import torch.nn as nn


# _NLinear is a simplified copy of delu.nn.NLinear:
# https://yura52.github.io/delu/stable/api/generated/delu.nn.NLinear.html
class _NLinear(nn.Module):
    """N *separate* linear layers for N feature embeddings.

    In other words,
    each feature embedding is transformed by its own dedicated linear layer.
    """

    def __init__(self, n: int, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.weight = Parameter(torch.empty(n, in_features, out_features))
        self.bias = Parameter(torch.empty(n, out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters."""
        d_in_rsqrt = self.weight.shape[-2] ** -0.5
        nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass."""
        if x.ndim != 3:
            raise ValueError(
                '_NLinear supports only inputs with exactly one batch dimension,'
                ' so `x` must have a shape like (BATCH_SIZE, N_FEATURES, D_EMBEDDING).'
            )
        assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]
        x = x.transpose(0, 1)
        x = x @ self.weight
        x = x.transpose(0, 1)
        if self.bias is not None:
            x = x + self.bias
        return x

In [11]:
# One linear layer per sequence position.
m = _NLinear(
    n=sequence_length,
    in_features=d_embedding_in,
    out_features=d_embedding_out,
)
m(x).shape

torch.Size([2, 4, 7])

(CV) Training a separate linear layer for each spatial embedding of an image (patch embeddings, pixel embeddings, etc.; total count: width × height), i.e. a 1×1 convolution whose weights are *not* shared across spatial positions:

In [13]:
# Toy image batch laid out as (batch, width, height, channels).
batch_size = 2
width, height = 4, 5
in_channels, out_channels = 6, 7
x = torch.randn((batch_size, width, height, in_channels))
x.shape

torch.Size([2, 4, 5, 6])

In [14]:
# Request one linear layer per spatial position by passing the grid as `n`.
m = _NLinear(
    n=(width, height),
    in_features=in_channels,
    out_features=out_channels,
)
m(x).shape

TypeError: empty() received an invalid combination of arguments - got (tuple, int, int), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
