In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# %reload_ext tensorboard
# %tensorboard --logdir=lightning_logs/

In [3]:
import numpy as np
import pandas as pd
import random
import math

import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.optim import Adam 

# import lightning as L
from torch.utils.data import Dataset, DataLoader # these are needed for the training data
import pytorch_lightning as pl

In [4]:
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(seed)

In [5]:
Lmax=407
Lmax=50

bs = 256
num_workers = 4
learning_rate = 5e-4
nfolds = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [6]:
input_dir = '../input/stanford-ribonanza-rna-folding'

tok_file='tokenizer.json'

train_file = 'train_data.parquet'
sequence_file = 'train_sequences.parquet'

In [7]:
# train_df = pd.read_csv(train_csv)

train_df = pd.read_parquet(train_file)
sequences_df = pd.read_parquet(sequence_file)

In [8]:
train_df.shape, sequences_df.shape

((1643680, 419), (806573, 2))

In [9]:
train_df.head(2)

Unnamed: 0,sequence_id,sequence,experiment_type,dataset_name,reads,signal_to_noise,SN_filter,reactivity_0001,reactivity_0002,reactivity_0003,...,reactivity_error_0197,reactivity_error_0198,reactivity_error_0199,reactivity_error_0200,reactivity_error_0201,reactivity_error_0202,reactivity_error_0203,reactivity_error_0204,reactivity_error_0205,reactivity_error_0206
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,2343,0.944,0,,,,...,,,,,,,,,,
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,5326,1.933,1,,,,...,,,,,,,,,,


In [10]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...


In [11]:
sequences_df['L'] = sequences_df.sequence.apply(len)

In [12]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence,L
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,170
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,170


### Dataset

In [13]:
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

In [14]:
from dataloader import RNA_Token_Dataset

In [15]:
tokenizer = Tokenizer.from_file(tok_file)

In [16]:
# ds_train = RNA_Token_Dataset(sequences_df, Lmax, tokenizer)
ds_train = RNA_Token_Dataset(sequences_df, Lmax)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, num_workers=0, persistent_workers=False, drop_last=True)

In [17]:
inputs, targets = ds_train.__getitem__(0)

# print(ds_train.__len__())

inputs['seq'].shape, inputs['att_mask'].shape, inputs['seq'].dtype, inputs['att_mask'].dtype

(torch.Size([51]), torch.Size([51]), torch.int32, torch.bool)

In [18]:
inputs['seq']

tensor([ 40,  45,  29,  87,  42,  43, 129, 206, 218,   3, 100,   8,  41,  33,
        203,  43,  49,  26, 232, 129,  78,  45,  10,  56,   3,  86,  20,  45,
        103,   3, 121,   3,  49,  58,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0], dtype=torch.int32)

In [19]:
targets['labels']

tensor([ 40,  45,  29,  87,  42,  43, 129, 206, 218,  98, 100,   8,  41,  33,
        203,  43,  49,  26, 232, 129,  78,  45,  10,  56, 199,  86,  20,  45,
        103,  69, 121,  14,  49,  58,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0], dtype=torch.int32)

In [20]:
targets['mlm_target']

tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,  98,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 199,   0,   0,   0,
          0,  69,   0,  14,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0])

In [21]:
targets['token_mask']

tensor([False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False,  True,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False])

In [22]:
((inputs['seq'] == 3) == targets['token_mask']).sum()

tensor(51)

In [23]:
(inputs['seq'] != 3).sum(), (~targets['token_mask']).sum()

(tensor(47), tensor(47))

In [24]:
# check mask indeces mark masked tokens
((inputs['seq'] == 3) == targets['token_mask']).sum() == len(targets['mlm_target'])

tensor(True)

In [25]:
inputs['att_mask'].sum()

tensor(34)

In [26]:
# inputs['seq'] == ds_train.seq_map['M']

In [27]:
# targets['labels']

In [28]:
# inputs['att_mask']

In [29]:
# inputs['att_mask']

In [30]:
sum(targets['labels'] != inputs['seq'])/sum(targets['labels'] == inputs['seq'])

tensor(0.0851)

In [31]:
inputs, targets = next(iter(dl_train))
inputs['seq'].shape, targets['labels'].shape

(torch.Size([2, 51]), torch.Size([2, 51]))

In [32]:
# targets['mlm_target'][0]

#### DataModule

In [33]:
from dataloader import create_dataloader, DataloaderWrapper

In [34]:
class rna_datamodule(pl.LightningDataModule):
    def __init__(self, train_df, val_df, bs=bs, num_workers=0):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        self.bs = bs
        self.num_workers = num_workers
        self.pw = True if num_workers>0 else False
        
    def train_dataloader(self):
        train_ds = RNA_Token_Dataset(self.train_df, Lmax)
        
        return torch.utils.data.DataLoader(
            train_ds, 
            batch_size=self.bs,
            shuffle=True,
            num_workers=self.num_workers, 
            persistent_workers=self.pw, 
            drop_last=True
        )

    def val_dataloader(self):
        val_ds = RNA_Token_Dataset(self.val_df, Lmax)
        
        return torch.utils.data.DataLoader(
            val_ds, 
            batch_size=self.bs, 
            num_workers=self.num_workers, 
            persistent_workers=self.pw, 
            drop_last=True
        )

### Model

In [35]:
emb = nn.Embedding(tokenizer.get_vocab_size(), 192)

In [36]:
# tokenizer.get_vocab_size()

In [37]:
emb(targets['labels']).shape

torch.Size([2, 51, 192])

In [38]:
inputs['att_mask'].sum(-1).max()

tensor(34)

In [39]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        # device = x.device
        
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        
        # print(x.shape)
        # print(emb.shape)
        
        return x + emb

In [40]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [41]:
dim = 256
head_size=32
num_embeddings = tokenizer.get_vocab_size()

class bert_config:
    num_embeddings=num_embeddings
    dim=dim
    head_size=head_size
    nhead=dim//head_size
    dim_feedforward=4*dim
    num_layers=2
    dropout=0.1

In [42]:
for k, v in bert_config.__dict__.items():
    if '__' not in k:
        print(k, v)

num_embeddings 256
dim 256
head_size 32
nhead 8
dim_feedforward 1024
num_layers 2
dropout 0.1


In [43]:
class bert_rna(nn.Module):
    def __init__(self, config=bert_config, **kwargs):
        super().__init__()
        
        self.config = config
        
        dim = config.dim
        
        self.emb = nn.Embedding(config.num_embeddings, dim)
        
        self.pos_enc = SinusoidalPosEmb(dim)
        self.pos_enc = PositionalEncoding(dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim, 
                nhead=dim//config.head_size, 
                dim_feedforward=4*dim,
                dropout=config.dropout, 
                activation=nn.GELU(), 
                batch_first=True, 
                norm_first=True), 
            config.num_layers)
    
    def forward(self, seq, mask):
        
        x = self.emb(seq)
        x = self.pos_enc(x)
        
        output = self.transformer(x, src_key_padding_mask=~mask)

        return output

In [44]:
class bert_mlm(nn.Module):
    def __init__(self, bert):
        super().__init__()
        
        self.bert = bert
        self.config = bert.config
        
        self.decoder = nn.Linear(self.config.dim, self.config.num_embeddings)
        self.softmax = nn.LogSoftmax(dim=-1)
        
    def forward(self, seq, mask):
        # [bs, seq_len, dim]
        embeds = self.bert(seq, mask)
    
        # [bs, seq_len, vocab_size]
        token_predictions = self.decoder(embeds)
        
        # [bs, seq_len, vocab_size]
        return self.softmax(token_predictions)

In [45]:
class pl_model(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        
        self.model = model
        
        # maybe use reduction='sum'?
        self.ml_criterion = nn.NLLLoss()
        # self.ml_criterion = nn.NLLLoss(ignore_index=0)
        # self.ml_criterion = nn.NLLLoss(ignore_index=0, reduction='sum')
        
    def configure_optimizers(self):
        return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015)
        
    def forward(self, seq, mask):
        return self.model(seq, mask)
    
    def get_data(self, batch, batch_idx):
        inputs, targets = batch
        
        # slice out extra padding starting from max length in batch
        Lmax = inputs['att_mask'].sum(-1).max()
        
        mask = inputs['att_mask'][:,:Lmax]
        seq = inputs['seq'][:,:Lmax]
        token_targets = targets['mlm_target'][:,:Lmax]
        token_mask = targets['token_mask'][:,:Lmax]
        
        return seq, mask, token_targets, token_mask
    
    def compute_acc(self, result, target, token_mask):
        r = result.argmax(-1).masked_select(token_mask)
        t = target.masked_select(token_mask)  
        s = (r == t).sum()  
        
        return round(float(s / token_mask.sum()), 2)
    
    def training_step(self, batch, batch_idx):
        
        seq, mask, token_targets, token_mask = self.get_data(batch, batch_idx)
        
        # [bs, seq_len, dim]
        tokens = self(seq, mask)
        
        tm = token_mask.unsqueeze(-1).expand_as(tokens)
        tokens = tokens.masked_fill(~tm, 0)
        
        loss = self.ml_criterion(tokens.transpose(1, 2), token_targets)
        
        self.log("train_mlm_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        
        acc = self.compute_acc(tokens, token_targets, token_mask)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        seq, mask, token_targets, token_mask = self.get_data(batch, batch_idx)

        # [bs, seq_len, dim]
        tokens = self(seq, mask)

        tm = token_mask.unsqueeze(-1).expand_as(tokens)
        tokens = tokens.masked_fill(~tm, 0)

        loss = self.ml_criterion(tokens.transpose(1, 2), token_targets)
        
        self.log("val_mlm_loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        
        acc = self.compute_acc(tokens, token_targets, token_mask)
        self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)

        return loss

### Test model

In [46]:
# model = bert_rna()
mlm_model = bert_mlm(bert_rna())
model = pl_model(mlm_model)



In [47]:
# foo = model(inputs['seq'], inputs['att_mask'])
# foo.shape

In [48]:
foo = mlm_model(inputs['seq'], inputs['att_mask'])
foo.shape

torch.Size([2, 51, 256])

In [49]:
model.training_step((inputs, targets), 0)

  rank_zero_warn(


tensor(0.8220, grad_fn=<NllLoss2DBackward0>)

### Split

In [50]:
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

In [51]:
rs = ShuffleSplit(n_splits=5, test_size=.1, random_state=seed)
train_idx, val_idx = next(rs.split(sequences_df))

In [52]:
len(train_idx), len(val_idx)

(725915, 80658)

### Train

In [53]:
t_df = sequences_df.iloc[train_idx]
v_df = sequences_df.iloc[val_idx]

In [54]:
dm = rna_datamodule(t_df, v_df, num_workers=12)

In [55]:
# model = bert_rna()
mlm_model = bert_mlm(bert_rna())
model = pl_model(mlm_model)



In [56]:
trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=20,
    accumulate_grad_batches=4,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [57]:
# next(iter(dm.val_dataloader()))

In [58]:
model.device

device(type='cpu')

In [59]:
trainer.fit(model, dm)

You are using a CUDA device ('NVIDIA GeForce RTX 4090 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type     | Params
------------------------------------------
0 | model        | bert_mlm | 1.7 M 
1 | ml_criterion | NLLLoss  | 0     
------------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.843     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


### Test

In [60]:
inputs, targets = next(iter(dl_train))

In [61]:
model.device, device

(device(type='cpu'), 'cuda')

In [62]:
foo = model.to(device)

In [63]:
inputs['seq'].to(device).device, inputs['att_mask'].to(device).device

(device(type='cuda', index=0), device(type='cuda', index=0))

In [64]:
# model.to(device)

In [65]:
preds = model(inputs['seq'].to(device), inputs['att_mask'].to(device))
preds.shape

torch.Size([2, 51, 256])

In [66]:
tm = targets['token_mask'].unsqueeze(-1).expand_as(preds).to(device)
preds = preds.masked_fill(~tm, 0)

preds.shape

torch.Size([2, 51, 256])

In [67]:
tokenizer.decode([41])

'CAG'

In [68]:
preds[0].topk(1).indices.view(-1)

tensor([ 0, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0, 12, 12,  0,  0,  0, 12,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')

In [69]:
preds.shape

torch.Size([2, 51, 256])

In [70]:
# preds[0,1]

In [71]:
torch.exp(preds[0,8])

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 

In [72]:
# preds[1].topk(2).indices

In [73]:
targets['labels'][0]

tensor([ 40,  45,  29,  87,  42,  43, 129, 206, 218,  98, 100,   8,  41,  33,
        203,  43,  49,  26, 232, 129,  78,  45,  10,  56, 199,  86,  20,  45,
        103,  69, 121,  14,  49,  58,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0], dtype=torch.int32)

In [74]:
targets['mlm_target'][0]

tensor([ 0, 45,  0,  0,  0,  0,  0,  0,  0,  0,  0,  8, 41,  0,  0,  0, 49,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [75]:
# (preds[0].topk(1).indices.view(-1) == targets['labels'][0].to(device))

In [76]:
acc = model.compute_acc(preds[0], targets['mlm_target'][0].to(device), targets['token_mask'].to(device))
acc

0.62

In [77]:
preds[0].argmax(-1)

tensor([ 0, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0, 12, 12,  0,  0,  0, 12,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')

In [78]:
i=1
r = preds[i].argmax(-1).masked_select(targets['token_mask'][i].to(device))
r

tensor([12, 12, 12, 12, 12, 12, 12, 12], device='cuda:0')

In [79]:
t = targets['mlm_target'][i].to(device).masked_select(targets['token_mask'][i].to(device))
t

tensor([ 99,  87,  43, 129, 100,  43, 217,  58], device='cuda:0')

In [80]:
targets['token_mask'].sum()

tensor(13)

In [81]:
(r == t).sum() / targets['token_mask'].sum()

tensor(0., device='cuda:0')

### Measure