In [13]:
import torch.nn
import torch

class TransformerEncoder(torch.nn.Module):
    """Pool a ``(batch, seq, mz)`` input into one ``dim``-sized vector per sample.

    Each step of the input is linearly projected from ``mz`` to ``dim``
    features, a learned CLS token is prepended, learned positional embeddings
    are added, and the result is run through a stack of transformer encoder
    layers. The encoded CLS slot is returned as the summary.

    Args:
        num_layers: Number of stacked ``TransformerEncoderLayer`` blocks.
        num_heads: Attention heads per layer.
        dim: Model width (``d_model``) and size of the returned vector.
        time: Largest supported sequence length; the positional table holds
            ``time + 1`` rows because of the extra CLS slot.
        mz: Feature count of every input step.
    """

    def __init__(self, num_layers: int, num_heads: int, dim: int, time: int, mz: int):
        super().__init__()
        # Creation order is kept as pos -> proj -> cls -> encoder so that
        # seeded initialisation draws random numbers in the same sequence.
        self.pos = torch.nn.Embedding(time + 1, dim)
        self.proj = torch.nn.Linear(mz, dim)
        self.cls = torch.nn.Parameter(torch.randn(1, 1, dim))
        layer = torch.nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.encoder = torch.nn.TransformerEncoder(layer, num_layers)

    def forward(self, x: torch.Tensor):
        """Return the encoded CLS token, shape ``(batch, dim)``.

        ``x`` must be three-dimensional, ``(batch, seq, mz)``, with
        ``seq <= time`` so every position has a row in the embedding table.
        """
        # Unpacking three names also rejects inputs that are not 3-D.
        batch_size, _seq_len, _features = x.shape

        tokens = self.proj(x)
        cls_tokens = self.cls.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)

        # One positional row per slot, CLS included; broadcast over the batch.
        positions = torch.arange(tokens.size(1), device=tokens.device)
        tokens.add_(self.pos(positions))

        encoded = self.encoder(tokens)
        return encoded[:, 0, :]
# Smoke test: one (1, 1024, 1024) sample must be pooled into a (1, 1024) vector.
encoder = TransformerEncoder(num_layers=8, num_heads=8, dim=1024, time=1024, mz=1024)
assert encoder(torch.randn(1, 1024, 1024)).shape == torch.Size([1, 1024])
    

In [19]:
class TransformerDecoder(torch.nn.Module):
    """Expand one latent vector per sample into a ``(time, mz)`` sequence.

    The input sequence to the transformer stack has ``time + 1`` slots, each
    initialised with its learned positional embedding. Slot 0 additionally
    receives the latent vector, so the remaining slots can only obtain
    sample-specific information by attending to it. After the stack, every
    slot is projected from ``dim`` to ``mz`` features and slot 0 is dropped.

    Args:
        num_layers: Number of stacked ``TransformerEncoderLayer`` blocks.
        num_heads: Attention heads per layer.
        dim: Model width (``d_model``); must match the latent size.
        time: Number of output steps to generate.
        mz: Feature count of every output step.
    """

    def __init__(self, num_layers: int, num_heads: int, dim: int, time: int, mz: int):
        super().__init__()
        self.time = time
        self.mz = mz

        self.pos = torch.nn.Embedding(time+1, dim)
        self.out_put_proj = torch.nn.Linear(dim, mz)
        # Self-attention only (no cross-attention), hence an *encoder* stack.
        self.decoder = torch.nn.TransformerEncoder(torch.nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads, batch_first=True), num_layers)

    def forward(self, latten: torch.Tensor):
        """Decode latents of shape ``(batch, dim)`` to ``(batch, time, mz)``."""
        batch = latten.shape[0]
        pos = self.pos(torch.arange(self.time + 1, device=latten.device))  # (time + 1, dim)

        # Slot 0 = positional embedding + latent, built out-of-place.
        # Writing in place into `pos.expand(batch, ...)` fails for batch > 1,
        # because every batch row of an expanded view aliases the same memory.
        latent_slot = (pos[0] + latten).unsqueeze(1)  # (batch, 1, dim)
        # The remaining slots are identical across the batch, so a broadcast
        # view is enough; torch.cat materialises the final tensor.
        query_slots = pos[1:].unsqueeze(0).expand(batch, -1, -1)  # (batch, time, dim)

        x = self.decoder(torch.cat((latent_slot, query_slots), dim=1))
        # Drop slot 0: it only carried the latent, it is not an output step.
        return self.out_put_proj(x)[:, 1:, :]
# Smoke test: one (1, 1024) latent must decode to a (1, 1024, 1024) sequence.
decoder = TransformerDecoder(num_layers=8, num_heads=8, dim=1024, time=1024, mz=1024)
assert decoder(torch.randn(1, 1024)).shape == torch.Size([1, 1024, 1024])

In [20]:
class Model(torch.nn.Module):
    """Encoder-decoder pair: compress a sequence to a latent, then expand it back.

    Both halves are built from the same hyper-parameters, so the decoder's
    output has the same ``(batch, time, mz)`` layout the encoder consumes.

    Args:
        num_layers: Transformer layers in each half.
        num_heads: Attention heads per layer.
        dim: Model width, which is also the latent size passed between halves.
        time: Sequence length.
        mz: Feature count per step.
    """

    def __init__(self, num_layers: int, num_heads: int, dim: int, time: int, mz: int):
        super().__init__()
        shared = dict(
            num_layers=num_layers,
            num_heads=num_heads,
            dim=dim,
            time=time,
            mz=mz,
        )
        self.encoder = TransformerEncoder(**shared)
        self.decoder = TransformerDecoder(**shared)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the encoder and feed its output to the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)
# Smoke test: the full round trip must return a tensor shaped like its input.
model = Model(num_layers=8, num_heads=8, dim=1024, time=1024, mz=1024)
assert model(torch.randn(1, 1024, 1024)).shape == torch.Size([1, 1024, 1024])