In [42]:
import numpy as np
import pandas as pd
import random

import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.optim import Adam 

# import lightning as L
from torch.utils.data import Dataset, DataLoader # these are needed for the training data
import pytorch_lightning as pl

In [2]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [3]:
Lmax=407
Lmax=206

bs = 32
num_workers = 4
nfolds = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [4]:
seed = 42

seed_everything(seed)

In [5]:
vocab = {'A':0,'C':1,'G':2,'U':3,'M':4}

In [6]:
input_dir = '../input/stanford-ribonanza-rna-folding'
train_dir = f'{input_dir}/train'
test_dir = f'{input_dir}/test'
train_csv = f'{input_dir}/train_data.csv'

train_file = 'train_data.parquet'
sequence_file = 'train_sequences.parquet'

In [7]:
# train_df = pd.read_csv(train_csv)

train_df = pd.read_parquet(train_file)
sequences_df = pd.read_parquet(sequence_file)

In [8]:
train_df.shape, sequences_df.shape

((1643680, 419), (806573, 2))

In [9]:
train_df.head(2)

Unnamed: 0,sequence_id,sequence,experiment_type,dataset_name,reads,signal_to_noise,SN_filter,reactivity_0001,reactivity_0002,reactivity_0003,...,reactivity_error_0197,reactivity_error_0198,reactivity_error_0199,reactivity_error_0200,reactivity_error_0201,reactivity_error_0202,reactivity_error_0203,reactivity_error_0204,reactivity_error_0205,reactivity_error_0206
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,2343,0.944,0,,,,...,,,,,,,,,,
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,5326,1.933,1,,,,...,,,,,,,,,,


In [10]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...


In [11]:
sequences_df['L'] = sequences_df.sequence.apply(len)

In [12]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence,L
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,170
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,170


### Dataset

In [13]:
foo = sequences_df['sequence'].values

In [14]:
len(foo)

806573

In [15]:
foo.shape

(806573,)

In [16]:
class RNA_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, mode='train', seed=seed, v=vocab,
                 mask_only=False, Lmax=Lmax, **kwargs):
        
        self.seq_map = v
        self.Lmax = Lmax
        
        self.seq = df['sequence'].values
        self.L = df['L'].values
        
        self.mask_only = mask_only
        
    def __len__(self):
        return len(self.seq)  
    
    def __getitem__(self, idx):
        seq = self.seq[idx]
        
        mask = torch.zeros(self.Lmax, dtype=torch.bool)
        mask[:len(seq)] = True
        
        if self.mask_only:
            return {'mask':mask}
        
        seq = [self.seq_map[s] for s in seq]
        seq = np.array(seq)        
        seq = np.pad(seq, (0, self.Lmax - len(seq)))
        seq = torch.from_numpy(seq)     
        
        rand = torch.rand(mask.nonzero().shape)
        mask_arr = rand < 0.15
        
        selection = torch.flatten((mask_arr).nonzero()).tolist()
        
        mlm = seq.detach().clone()
        mlm[selection] = self.seq_map['M']
        
        
        return {'seq': mlm, 'att_mask': mask}, {'labels': seq}

### Test dataset

In [22]:
ds_train = RNA_Dataset(sequences_df, mode='train')

# sampler_train = torch.utils.data.RandomSampler(ds_train)
# batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, batch_size=2, drop_last=True)

# dl_train = torch.utils.data.DataLoader(ds_train, batch_sampler=batch_sampler_train, num_workers=num_workers, persistent_workers=False)

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, num_workers=0, persistent_workers=False, drop_last=True)

In [23]:
inputs, targets = ds_train.__getitem__(0)

print(ds_train.__len__())

# foo
inputs['seq'].shape, inputs['att_mask'].shape, inputs['seq'].dtype, inputs['att_mask'].dtype

806573


(torch.Size([206]), torch.Size([206]), torch.int32, torch.bool)

In [53]:
inputs['att_mask'][0]

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, 

In [25]:
sum(targets['labels'] != inputs['seq'])/sum(targets['labels'] == inputs['seq'])

tensor(0.1771)

In [40]:
inputs, targets = next(iter(dl_train))
inputs['seq'].shape, targets['labels'].shape

(torch.Size([2, 206]), torch.Size([2, 206]))

### Model

In [43]:
emb = nn.Embedding(len(vocab),192)

In [47]:
emb(targets['labels']).shape

torch.Size([2, 206, 192])

In [51]:
inputs['att_mask'].sum(-1).max()

tensor(170)

In [45]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

In [44]:
class RNA_Model(nn.Module):
    def __init__(self, num_embeddings=len(vocab), dim=192, num_layers=12, head_size=32, **kwargs):
        super().__init__()
        
        self.emb = nn.Embedding(num_embeddings,dim)
        
        self.pos_enc = SinusoidalPosEmb(dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True), num_layers)
        
        self.decoder = nn.Linear(dim,num_embeddings)
    
    def forward(self, x0):
        mask = x0['mask']
        
        Lmax = mask.sum(-1).max()
        mask = mask[:,:Lmax]
        
        x = x0['seq'][:,:Lmax]
        
        pos = torch.arange(Lmax, device=x.device).unsqueeze(0)
        pos = self.pos_enc(pos)
        
        x = self.emb(x)
        x = x + pos
        
        x = self.transformer(x, src_key_padding_mask=~mask)
        x = self.decoder(x)
        
        return x