In [1]:
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn

In [2]:
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint

In [26]:
class Transformer(pl.LightningModule):
    def __init__(self, vocab_size, embed_size, latent_dim, num_heads, hidden_dim, num_layers, max_seq_length, vocab):
        super().__init__()
        self.latent_dim = latent_dim
        self.max_seq_length = max_seq_length
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        
        # Positional encoding
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, embed_size))
        
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(embed_size, num_heads, hidden_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        
        # Latent space
        self.fc_mu = nn.Linear(embed_size, latent_dim)
        self.fc_logvar = nn.Linear(embed_size, latent_dim)
        self.fc_latent_to_hidden = nn.Linear(latent_dim, embed_size)
        
        # Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(embed_size, num_heads, hidden_dim, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Output projection
        self.output_fc = nn.Linear(embed_size, vocab_size)
        self.vocab = vocab 
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=vocab.char2idx[vocab.pad_token], reduction='sum')
        self.kl_weight = 0.6
        

    def compute_loss(self, outputs, targets, logvar, mu):
        r_loss = self.criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kl_loss = kl_loss * self.kl_weight
        return r_loss, kl_loss
    
    def cyclical_annealing(self,T,M,step,R=0.4, max_kl_weight=1):
        """
        Implementing: <https://arxiv.org/abs/1903.10145>
        T = Total steps 
        M = Number of cycles 
        R = Proportion used to increase beta
        t = Global step 
        """
        period = (T/M) # N_iters/N_cycles 
        internal_period = (step) % (period)  # Itteration_number/(Global Period)
        tau = internal_period/period
        if tau > R:
            tau = max_kl_weight
        else:
            tau = min(max_kl_weight, tau/R) # Linear function 
        return tau
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def encode(self, x, mask=None):
        emb = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        encoded = self.encoder(emb, src_key_padding_mask=mask)
        last_hidden_state = encoded[:, -1, :] # z this hidden state is the last hidden state.
        mu_vector = self.fc_mu(last_hidden_state)
        logvar_vector = self.fc_logvar(last_hidden_state)
        return last_hidden_state, emb, logvar_vector, mu_vector
    
    def decode(self, last_hidden_state, x, mask=None):
        resized_latent = self.fc_latent_to_hidden(last_hidden_state)
        resized_latent = resized_latent.unsqueeze(1).repeat(1, self.max_seq_length, 1)
        # add positional encoding to reshaped latents 
        
        resized_latent = resized_latent + self.positional_encoding[:, :resized_latent.size(1), :] 
        # Add positional encoding
        hidden = resized_latent + self.positional_encoding[:, :self.max_seq_length, :]
        tgt_mask = self.generate_square_subsequent_mask(x.size(1)).to(x.device)
        # Pass through decoder
        # decoded = self.decoder(x, hidden, tgt_key_padding_mask=mask, memory_key_padding_mask=mask)
        decoded = self.decoder(x, hidden, tgt_mask=tgt_mask,tgt_is_causal=True,tgt_key_padding_mask=mask, memory_key_padding_mask=mask)

        outputs_v = self.output_fc(decoded)
        
        return outputs_v

    def forward(self,batch):
        with_bos, with_eos, masks = batch
        last_hidden_states, memory, logvar, mu = m.encode(with_bos, mask=masks)
        z = self.reparameterize(mu, logvar)
        output_v = m.decode(z, memory, mask=masks)
        return output_v, with_eos, logvar, mu

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001) 

    def training_step(self, batch, batch_idx):
        outputs_v, with_eos, logvar, mu = self.forward(batch)
        r_loss, kl_loss = self.compute_loss(outputs_v, with_eos, logvar, mu)
        # recon_loss = self.criterion(outputs_v.view(-1, outputs_v.size(-1)), with_eos.view(-1))
        self.kl_weight = self.cyclical_annealing(100000, 100, step=self.global_step, max_kl_weight=0.6)
        # self.kl_weight = 0.8
        r = {}
        r['r_loss'] = r_loss 
        r['kl_loss'] = kl_loss
        r['loss'] = r_loss + kl_loss
        r['kl_weight'] = self.kl_weight
        # self.log(r)
        for key in r:
            self.log(key, r[key])

        return r

In [27]:
class TransformerVocab:
    def __init__(self, sequences=None, target_col=None, char2idx={}, idx2char={}):
        self.sos_token = '<sos>' 
        self.eos_token = '<eos>' 
        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self.char2idx = char2idx 
        self.idx2char = idx2char
        if sequences != None:
            self.extract_charset(sequences)
        else:
            self.build_vocab_aa()

    def handle_special(self, smi):
        smi = smi.replace('Cl', 'Q')
        smi = smi.replace('Br', 'W')
        smi = smi.replace('[nH]', 'X')
        smi = smi.replace('[H]', 'Y')
        return smi
    def reverse_special(self, smi):
        smi = smi.replace('Q', 'Cl')
        smi = smi.replace('W', 'Br')
        smi = smi.replace( 'X','[nH]')
        smi = smi.replace('Y', '[H]') 
        return smi 
    def extract_charset(self, sequences):
        """
        Extract charset from SMILES strings

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame containing SMILES strings

        """
        from tqdm import tqdm

        print('extracting charset..')
        i = 0
        for c in self.special_tokens:
            if c not in self.char2idx:
                self.char2idx[c] = i
                self.idx2char[i] = c
                i += 1
        all_smi = sequences
        for _, smi in enumerate(tqdm(all_smi)):
            smi = self.handle_special(smi)
            for c in smi:
                if c not in self.char2idx:
                    self.char2idx[c] = i
                    self.idx2char[i] = c
                    i += 1
    
    def build_vocab_aa(self):
        aa_alphabet = list('ACDEFGHIKLMNPQRSTVWYXUZBO')  # 20 standard amino acids
        all_chars = self.special_tokens + aa_alphabet
        self.char2idx = {token: idx for idx, token in enumerate(all_chars)}
        self.idx2char = {idx: token for token, idx in self.char2idx.items()}

class TransformerDataset:
    def __init__(self, sequences, vocab, max_len=100):
        self.sequences = sequences
        self.vocab = vocab 
        self.max_len = max_len 
    
    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        seq = self.vocab.handle_special(seq)
        tokens = [self.vocab.char2idx[i] for i in seq]
        
        with_bos = [self.vocab.char2idx[self.vocab.sos_token]] + tokens
        with_eos = tokens + [self.vocab.char2idx[self.vocab.eos_token]]
        with_bos += [self.vocab.char2idx[self.vocab.pad_token]] * (self.max_len - len(with_bos))
        with_eos += [self.vocab.char2idx[self.vocab.pad_token]] * (self.max_len - len(with_eos))
        attention_mask = [1 if t != self.vocab.char2idx[self.vocab.pad_token] else 0 for t in with_bos]
        return torch.tensor(with_bos), torch.tensor(with_eos), torch.tensor(attention_mask).float()

    def collate(self, batch):
        with_bos, with_eos, masks = zip(*batch)
        with_bos = torch.stack(with_bos) 
        with_eos = torch.stack(with_eos)
        masks = torch.stack(masks)
        return with_bos, with_eos, masks

In [28]:
import pandas as pd

In [29]:
df = pd.read_csv('/workspace/uniref50_small.csv')

In [30]:
df = df.sample(n=500000)

In [31]:
df['str_len'] = df.Sequence.str.len()

In [47]:
max_len = 200

In [48]:
df = df[df['str_len'] <= 200]

In [49]:
sequences = df.Sequence.values.tolist()

In [50]:
vocab = TransformerVocab()

ds = TransformerDataset(sequences, vocab=vocab, max_len=max_len + 2)


In [51]:
# wandb.finish()
import wandb
wandb.finish()
wandb_logger = WandbLogger(project="run_pod_prot", log_model=True)


In [52]:
m = Transformer(vocab_size=len(vocab.char2idx), embed_size=256, latent_dim=256, num_heads=8, hidden_dim=64, num_layers=8, max_seq_length=max_len + 2, vocab=vocab)

dl = torch.utils.data.DataLoader(ds, batch_size=128, collate_fn=ds.collate)

trainer = pl.Trainer(devices=1, logger=[wandb_logger])

trainer.fit(m, dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params | Mode 
-------------------------------------------------------------------
0 | embedding           | Embedding          | 7.4 K  | train
1 | encoder             | TransformerEncoder | 2.4 M  | train
2 | fc_mu               | Linear             | 65.8 K | train
3 | fc_logvar           | Linear             | 65.8 K | train
4 | fc_latent_to_hidden | Linear             | 65.8 K | train
5 | decoder             | TransformerDecoder | 4.5 M  | train
6 | output_fc           | Linear             | 7.5 K  | train
7 | criterion           | CrossEntropyLoss   | 0      | train
  | other params        | n/a                | 51.7 K | n/a  
-------------------------------------------------------------------
7.1 M     Trainable params
0         Non-trainable params
7.1 M     Total params
28.520    Total estimated model params size (MB)
202       Modules in train mode
0         Modules in eval mode
/usr/loc

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...

KeyboardInterrupt



In [92]:
def generate(m):
    current_tokens = [m.vocab.char2idx[m.vocab.sos_token]]
    tgt_tokens = torch.tensor([current_tokens]).long().to(m.device)
    s = ""
    last_hidden_state = torch.normal(0, 1.0, size=(1, 128)).to(m.device)
    resized_latent = m.fc_latent_to_hidden(last_hidden_state) 
    resized_latent = resized_latent.unsqueeze(1).repeat(1, 1, 1) 
    emb = m.embedding(tgt_tokens) + m.positional_encoding[:, :tgt_tokens.size(1), :]
    for i in range(128):
        tgt_mask = m.generate_square_subsequent_mask(emb.size(1)).to(emb.device)
        decoded = m.decoder(emb, resized_latent, tgt_mask=tgt_mask)
        outputs_v = m.output_fc(decoded)
        outputs_v = outputs_v[:,-1,:] 
        top_char = torch.argmax(outputs_v)
        print(top_char)
        print(outputs_v.shape)
        if top_char == vocab.char2idx[vocab.eos_token]:
            break
        current_tokens.append(top_char.item())
        tgt_tokens = torch.tensor([current_tokens]).long().to(m.device)
        emb = m.embedding(tgt_tokens) + m.positional_encoding[:, :tgt_tokens.size(1), :]
        s += vocab.idx2char[top_char.item()]
    s = vocab.reverse_special(s)
    print(s)

In [93]:
generate(m)

tensor(14)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(19)
torch.Size([1, 29])
tensor(1


KeyboardInterrupt

