In [71]:
%load_ext autoreload
%autoreload 2

In [77]:
import numpy as np
import pandas as pd
import random
import math

import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.optim import Adam 

# import lightning as L
from torch.utils.data import Dataset, DataLoader # these are needed for the training data
import pytorch_lightning as pl

In [2]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [3]:
Lmax=407
Lmax=206

bs = 32
num_workers = 4
nfolds = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [4]:
seed = 42

seed_everything(seed)

In [5]:
vocab = {'A':0,'C':1,'G':2,'U':3,'M':4}

In [6]:
input_dir = '../input/stanford-ribonanza-rna-folding'
train_dir = f'{input_dir}/train'
test_dir = f'{input_dir}/test'
train_csv = f'{input_dir}/train_data.csv'

train_file = 'train_data.parquet'
sequence_file = 'train_sequences.parquet'

In [7]:
# train_df = pd.read_csv(train_csv)

train_df = pd.read_parquet(train_file)
sequences_df = pd.read_parquet(sequence_file)

In [8]:
train_df.shape, sequences_df.shape

((1643680, 419), (806573, 2))

In [9]:
train_df.head(2)

Unnamed: 0,sequence_id,sequence,experiment_type,dataset_name,reads,signal_to_noise,SN_filter,reactivity_0001,reactivity_0002,reactivity_0003,...,reactivity_error_0197,reactivity_error_0198,reactivity_error_0199,reactivity_error_0200,reactivity_error_0201,reactivity_error_0202,reactivity_error_0203,reactivity_error_0204,reactivity_error_0205,reactivity_error_0206
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,2343,0.944,0,,,,...,,,,,,,,,,
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,5326,1.933,1,,,,...,,,,,,,,,,


In [10]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...


In [11]:
sequences_df['L'] = sequences_df.sequence.apply(len)

In [12]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence,L
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,170
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,170


### Dataset

In [13]:
foo = sequences_df['sequence'].values

In [14]:
len(foo)

806573

In [15]:
foo.shape

(806573,)

In [74]:
class RNA_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, mode='train', seed=seed, v=vocab,
                 mask_only=False, Lmax=Lmax, **kwargs):
        
        self.seq_map = v
        self.Lmax = Lmax
        
        self.seq = df['sequence'].values
        self.L = df['L'].values
        
        self.mask_only = mask_only
        
    def __len__(self):
        return len(self.seq)  
    
    def __getitem__(self, idx):
        seq = self.seq[idx]
        
        mask = torch.zeros(self.Lmax, dtype=torch.bool)
        mask[:len(seq)] = True
        
        if self.mask_only:
            return {'mask':mask}
        
        seq = [self.seq_map[s] for s in seq]
        seq = np.array(seq)        
        seq = np.pad(seq, (0, self.Lmax - len(seq)))
        seq = torch.from_numpy(seq)     
        
        rand = torch.rand(mask.nonzero().shape)
        mask_arr = rand < 0.15
        
        selection = torch.flatten((mask_arr).nonzero()).tolist()
        
        mlm = seq.detach().clone()
        mlm[selection] = self.seq_map['M']
        
        # true when token is masked
        mask_idx = mlm == self.seq_map['M']
        
        return {'seq': mlm, 'att_mask': mask}, {'labels': seq, 'mask_idx': mask_idx}

### Test dataset

In [56]:
ds_train = RNA_Dataset(sequences_df, mode='train')
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, num_workers=0, persistent_workers=False, drop_last=True)

In [57]:
inputs, targets = ds_train.__getitem__(0)

print(ds_train.__len__())

inputs['seq'].shape, inputs['att_mask'].shape, inputs['seq'].dtype, inputs['att_mask'].dtype

806573


(torch.Size([206]), torch.Size([206]), torch.int32, torch.bool)

In [58]:
# inputs['seq'] == ds_train.seq_map['M']

In [60]:
# targets['mask_idx']

In [49]:
# inputs['att_mask']

In [37]:
# ~inputs['att_mask']

In [61]:
sum(targets['labels'] != inputs['seq'])/sum(targets['labels'] == inputs['seq'])

tensor(0.1977)

In [62]:
inputs, targets = next(iter(dl_train))
inputs['seq'].shape, targets['labels'].shape

(torch.Size([2, 206]), torch.Size([2, 206]))

In [65]:
# targets['mask_idx'][0]

#### DataModule

In [68]:
class rna_datamodule(pl.LightningDataModule):
    def __init__(self, train_df, val_df=None, train_bs=bs, val_bs=bs):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        self.train_bs = train_bs
        self.val_bs = val_bs
        
    def train_dataloader(self):
        train_ds = RNA_Dataset(self.train_df, mode='train')
        train_dl = torch.utils.data.DataLoader(ds_train, batch_size=train_bs, num_workers=0, persistent_workers=False, drop_last=True)
        
        return train_dl

### Model

In [22]:
emb = nn.Embedding(len(vocab),192)

In [23]:
emb(targets['labels']).shape

torch.Size([2, 206, 192])

In [24]:
inputs['att_mask'].sum(-1).max()

tensor(170)

In [91]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        
        print(x.shape)
        print(emb.shape)
        
        return x + emb

In [92]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [103]:
class bert_config:
    num_embeddings=len(vocab)
    dim=192
    num_layers=12
    head_size=32
    dropout=0.1

In [98]:
class RNA_Model(nn.Module):
    def __init__(self, config=bert_config, **kwargs):
        super().__init__()
        
        self.config = config
        
        dim = config.dim
        
        self.emb = nn.Embedding(config.num_embeddings, dim)
        
        self.pos_enc = SinusoidalPosEmb(dim)
        self.pos_enc = PositionalEncoding(dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=config.dropout, activation=nn.GELU(), batch_first=True, norm_first=True), num_layers)
        
#         self.decoder = nn.Linear(dim, num_embeddings)
        
#         self.softmax = nn.LogSoftmax(dim=-1) 
    
    def forward(self, x0):
        # slice out extra padding starting from max length in batch
        Lmax = x0['att_mask'].sum(-1).max()
        
        mask = x0['att_mask'][:,:Lmax]
        x = x0['seq'][:,:Lmax]

        x = self.emb(x)
        x = self.pos_enc(x)
        
        output = self.transformer(x, src_key_padding_mask=~mask)
        
#         token_predictions = self.decoder(output)
        
#         return self.softmax(token_predictions)

        return output

In [99]:
class mlm_model(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        
        self.bert = model
        self.decoder = nn.Linear(dim, num_embeddings)
        self.softmax = nn.LogSoftmax(dim=-1) 
        
        self.ml_criterion = nn.NLLLoss(ignore_index=0).to(device)
        
    def forward(self, x):
        embeds = self.bert(x)
    
        token_predictions = self.decoder(embeds)
        
        return self.softmax(token_predictions)
    
    def training_step(self, batch, batch_idx):
        inputs, target = batch
        
        tokens = self(inputs)
        
        loss_token = self.ml_criterion(tokens.transpose(1, 2), token_target)
        
        return loss_token

### Train

In [100]:
model = RNA_Model()

In [101]:
foo = model(inputs)

In [102]:
foo.shape

torch.Size([2, 170, 5])

In [69]:
dm = rna_datamodule(sequences_df)

In [72]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# trainer.fit(model=model, train_dataloaders=train_loader)

In [86]:
inputs['att_mask'].sum(-1).max()

tensor(170)

In [85]:
inputs['seq'].shape

torch.Size([2, 206])