In [4]:
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn

In [5]:
class Transformer(pl.LightningModule):
    def __init__(self, vocab_size, embed_size, latent_dim, num_heads, hidden_dim, num_layers, max_seq_length):
        super().__init__()
        self.latent_dim = latent_dim
        self.max_seq_length = max_seq_length
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        
        # Positional encoding
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, embed_size))
        
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(embed_size, num_heads, hidden_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Latent space
        self.fc_mu = nn.Linear(embed_size, latent_dim)
        self.fc_logvar = nn.Linear(embed_size, latent_dim)
        self.fc_latent_to_hidden = nn.Linear(latent_dim, embed_size)
        
        # Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(embed_size, num_heads, hidden_dim, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        
        # Output projection
        self.output_fc = nn.Linear(embed_size, vocab_size)

        

    def encode(self, x, mask=None):
        emb = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        encoded = self.encoder(emb, src_key_padding_mask=mask)
        last_hidden_state = encoded[:, -1, :] # z this hidden state is the last hidden state.
        # print('last hidden state shape')
        # print(last_hidden_state.shape)
        # pooled_latent = encoded.mean(dim=1) # batch_size, embedding_size
        mu_vector = self.fc_mu(last_hidden_state)
        logvar_vector = self.fc_logvar(last_hidden_state)
        return last_hidden_state, emb
    
    def decode(self, last_hidden_state, x, mask=None):
        resized_latent = self.fc_latent_to_hidden(last_hidden_state)
        resized_latent = resized_latent.unsqueeze(1).repeat(1, self.max_seq_length, 1)
        # add positional encoding to reshaped latents 
        
        resized_latent = resized_latent + self.positional_encoding[:, :resized_latent.size(1), :] 
        
        # Add positional encoding
        hidden = resized_latent + self.positional_encoding[:, :self.max_seq_length, :]

        # Pass through decoder
        decoded = self.decoder(x, hidden, tgt_key_padding_mask=mask, memory_key_padding_mask=mask)

        outputs_v = self.output_fc(decoded)
        
        return outputs_v

    def forward(self,batch):
        with_bos, with_eos, masks = batch
        last_hidden_states, memory = m.encode(with_bos, mask=masks)
        output_v = m.decode(last_hidden_states, memory, mask=masks)
        return output_v, with_eos 

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001) 

    def training_step(self, batch, batch_idx):
        outputs_v, with_eos = self.forward(batch)
        recon_loss = F.cross_entropy(outputs_v.view(-1, outputs_v.size(-1)), with_eos.view(-1), ignore_index=0)
        print(recon_loss)
        return recon_loss
        

In [6]:
class TransformerVocab:
    def __init__(self, df=None, target_col=None, char2idx={}, idx2char={}):
        self.sos_token = '<sos>' 
        self.eos_token = '<eos>' 
        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self.char2idx = char2idx 
        self.idx2char = idx2char
        self.build_vocab()
    
    def build_vocab(self):
        aa_alphabet = list("ACDEFGHIKLMNPQRSTVWY")  # 20 standard amino acids
        all_chars = self.special_tokens + aa_alphabet
        self.char2idx = {token: idx for idx, token in enumerate(all_chars)}
        self.idx2char = {idx: token for token, idx in self.char2idx.items()}

class TransformerDataset:
    def __init__(self, sequences, vocab, max_len=100):
        self.sequences = sequences
        self.vocab = vocab 
        self.max_len = max_len 
    
    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        tokens = [self.vocab.char2idx[i] for i in seq]
        
        with_bos = [self.vocab.char2idx[self.vocab.sos_token]] + tokens
        with_eos = tokens + [self.vocab.char2idx[self.vocab.eos_token]]
        with_bos += [self.vocab.char2idx[self.vocab.pad_token]] * (self.max_len - len(with_bos))
        with_eos += [self.vocab.char2idx[self.vocab.pad_token]] * (self.max_len - len(with_eos))
        attention_mask = [1 if t != self.vocab.char2idx[self.vocab.pad_token] else 0 for t in with_bos]
        return torch.tensor(with_bos), torch.tensor(with_eos), torch.tensor(attention_mask).float()

    def collate(self, batch):
        with_bos, with_eos, masks = zip(*batch)
        with_bos = torch.stack(with_bos) 
        with_eos = torch.stack(with_eos)
        masks = torch.stack(masks)
        return with_bos, with_eos, masks

In [7]:
sequences = [
    "ACDEFGHIKLMNPQRSTVWY",
    "ACD",
    "MNPQRSTVWY",
    "DEFGHI"
]

vocab = TransformerVocab()

ds = TransformerDataset(sequences, vocab, max_len=77)

m = Transformer(vocab_size=len(vocab.char2idx), embed_size=32, latent_dim=32, num_heads=4, hidden_dim=6, num_layers=2, max_seq_length=77)

dl = torch.utils.data.DataLoader(ds, batch_size=2, collate_fn=ds.collate)

trainer = pl.Trainer(devices=1)

trainer.fit(m, dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which wil

Training: |          | 0/? [00:00<?, ?it/s]

tensor(3.3949, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.4344, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3319, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.4329, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.4154, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.4391, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3630, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.4274, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3475, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3932, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3302, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3952, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3218, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3815, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.2665, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.3719, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.2597, device='cuda:0', grad_fn=

`Trainer.fit` stopped: `max_epochs=1000` reached.


tensor(0.1452, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.1267, device='cuda:0', grad_fn=<NllLossBackward0>)
