In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import random
import math

import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.optim import Adam 

# import lightning as L
from torch.utils.data import Dataset, DataLoader # these are needed for the training data
import pytorch_lightning as pl



In [3]:
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(seed)

In [4]:
Lmax=407
Lmax=206

bs = 32
num_workers = 4
nfolds = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [5]:
vocab = {'PAD': 0, 'A':1,'C':2,'G':3,'U':4,'M':5}
len(vocab)

6

In [6]:
input_dir = '../input/stanford-ribonanza-rna-folding'
train_dir = f'{input_dir}/train'
test_dir = f'{input_dir}/test'
train_csv = f'{input_dir}/train_data.csv'

train_file = 'train_data.parquet'
sequence_file = 'train_sequences.parquet'

In [7]:
# train_df = pd.read_csv(train_csv)

train_df = pd.read_parquet(train_file)
sequences_df = pd.read_parquet(sequence_file)

In [8]:
train_df.shape, sequences_df.shape

((1643680, 419), (806573, 2))

In [9]:
train_df.head(2)

Unnamed: 0,sequence_id,sequence,experiment_type,dataset_name,reads,signal_to_noise,SN_filter,reactivity_0001,reactivity_0002,reactivity_0003,...,reactivity_error_0197,reactivity_error_0198,reactivity_error_0199,reactivity_error_0200,reactivity_error_0201,reactivity_error_0202,reactivity_error_0203,reactivity_error_0204,reactivity_error_0205,reactivity_error_0206
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,2343,0.944,0,,,,...,,,,,,,,,,
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,5326,1.933,1,,,,...,,,,,,,,,,


In [10]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...


In [11]:
sequences_df['L'] = sequences_df.sequence.apply(len)

In [12]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence,L
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,170
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,170


### Dataset

In [13]:
foo = sequences_df['sequence'].values

In [14]:
len(foo)

806573

In [15]:
foo.shape

(806573,)

In [16]:
class RNA_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, mode='train', seed=seed, v=vocab,
                 mask_only=False, Lmax=Lmax, **kwargs):
        
        self.seq_map = v
        self.Lmax = Lmax
        
        self.seq = df['sequence'].values
        self.L = df['L'].values
        
        self.mask_only = mask_only
        
    def __len__(self):
        return len(self.seq)  
    
    def __getitem__(self, idx):
        seq = self.seq[idx]
        
        mask = torch.zeros(self.Lmax, dtype=torch.bool)
        mask[:len(seq)] = True
        
        if self.mask_only:
            return {'mask':mask}
        
        seq = [self.seq_map[s] for s in seq]
        seq = np.array(seq)        
        seq = np.pad(seq, (0, self.Lmax - len(seq)))
        seq = torch.from_numpy(seq)
        
        rand = torch.rand(mask.sum())
        mask_arr = rand < 0.15
        
        print(rand.shape)
        
        selection = torch.flatten((mask_arr).nonzero()).tolist()
        
        mlm = seq.detach().clone()
        mlm[selection] = self.seq_map['M']
        
        token_mask = mlm == self.seq_map['M']
        
        # # true when token is masked
        # selection = torch.flatten((~mask_arr).nonzero()).tolist()
        # mlm_target = mlm.detach().clone()
        # mlm_target[selection] = 0
        
        mlm_target = mlm.masked_fill(~token_mask, 0)
        
        
        return {'seq': mlm, 'att_mask': mask}, {'labels': seq, 'token_mask': token_mask, 'mlm_target': mlm_target.long()}

### Test dataset

In [17]:
ds_train = RNA_Dataset(sequences_df, mode='train')
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, num_workers=0, persistent_workers=False, drop_last=True)

In [18]:
inputs, targets = ds_train.__getitem__(0)

# print(ds_train.__len__())

inputs['seq'].shape, inputs['att_mask'].shape, inputs['seq'].dtype, inputs['att_mask'].dtype

torch.Size([170])


(torch.Size([206]), torch.Size([206]), torch.int32, torch.bool)

In [19]:
targets['mlm_target']

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        5, 0, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 5, 0,
        0, 0, 0, 5, 0, 0, 0, 5, 0, 0, 5, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 5, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 5, 5, 0, 0, 0, 5, 0, 5,
        0, 0, 0, 0, 0, 5, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [20]:
# check mask indeces mark masked tokens
((targets['mlm_target'] == 5) == targets['token_mask']).sum() == len(targets['mlm_target'])

tensor(True)

In [21]:
inputs['att_mask'].sum()

tensor(170)

In [22]:
# inputs['seq'] == ds_train.seq_map['M']

In [23]:
targets['labels']

tensor([3, 3, 3, 1, 1, 2, 3, 1, 2, 4, 2, 3, 1, 3, 4, 1, 3, 1, 3, 4, 2, 3, 1, 1,
        1, 1, 1, 2, 3, 4, 4, 3, 1, 4, 1, 4, 3, 3, 1, 4, 4, 4, 1, 2, 4, 2, 2, 3,
        1, 3, 3, 1, 3, 1, 2, 3, 1, 1, 2, 4, 1, 2, 2, 1, 2, 3, 1, 1, 2, 1, 3, 3,
        3, 3, 1, 1, 1, 2, 4, 2, 4, 1, 2, 2, 2, 3, 4, 3, 3, 2, 3, 4, 2, 4, 2, 2,
        3, 4, 4, 4, 3, 1, 2, 3, 1, 3, 4, 1, 1, 3, 4, 2, 2, 4, 1, 1, 3, 4, 2, 1,
        1, 2, 1, 4, 3, 2, 2, 1, 2, 3, 2, 3, 3, 3, 4, 2, 2, 4, 4, 2, 3, 3, 3, 1,
        2, 2, 2, 3, 2, 1, 1, 1, 1, 3, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1,
        1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)

In [24]:
inputs['seq']

tensor([3, 3, 3, 1, 1, 2, 3, 1, 2, 5, 2, 3, 1, 3, 4, 1, 3, 1, 3, 4, 2, 3, 1, 1,
        5, 1, 1, 2, 3, 5, 4, 5, 1, 4, 1, 4, 3, 3, 1, 4, 4, 4, 1, 2, 4, 2, 2, 3,
        1, 3, 5, 1, 5, 1, 2, 3, 1, 1, 2, 5, 1, 2, 2, 1, 2, 3, 1, 1, 2, 1, 3, 3,
        3, 5, 5, 5, 1, 2, 4, 2, 4, 1, 2, 2, 2, 3, 4, 5, 3, 2, 3, 4, 2, 4, 5, 2,
        3, 4, 4, 5, 3, 1, 2, 5, 1, 3, 5, 1, 5, 3, 4, 2, 2, 5, 1, 1, 3, 5, 2, 1,
        1, 2, 1, 4, 3, 2, 2, 5, 2, 3, 2, 3, 5, 3, 4, 2, 5, 5, 4, 2, 3, 5, 3, 5,
        2, 2, 2, 3, 2, 5, 1, 1, 1, 5, 1, 5, 1, 2, 1, 1, 2, 1, 1, 5, 1, 1, 2, 1,
        1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)

In [25]:
targets['mlm_target']

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        5, 0, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 5, 0,
        0, 0, 0, 5, 0, 0, 0, 5, 0, 0, 5, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 5, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 5, 5, 0, 0, 0, 5, 0, 5,
        0, 0, 0, 0, 0, 5, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [26]:
# targets['labels']

In [27]:
# inputs['att_mask']

In [28]:
# inputs['att_mask']

In [29]:
sum(targets['labels'] != inputs['seq'])/sum(targets['labels'] == inputs['seq'])

tensor(0.1573)

In [30]:
inputs, targets = next(iter(dl_train))
inputs['seq'].shape, targets['labels'].shape

torch.Size([170])
torch.Size([170])


(torch.Size([2, 206]), torch.Size([2, 206]))

In [31]:
# targets['mlm_target'][0]

#### DataModule

In [32]:
class rna_datamodule(pl.LightningDataModule):
    def __init__(self, train_df, val_df, train_bs=bs, val_bs=bs):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        self.val_df = val_df
        self.val_bs = val_bs
        
    def train_dataloader(self):
        train_ds = RNA_Dataset(self.train_df, mode='train')
        train_dl = torch.utils.data.DataLoader(ds_train, batch_size=self.train_bs, num_workers=0, persistent_workers=False, drop_last=True)
        
        return train_dl
    
    def train_dataloader(self):
        val_ds = RNA_Dataset(self.val_df)
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=self.val_bs, num_workers=0, persistent_workers=False, drop_last=True)
        
        return val_dl

### Model

In [33]:
emb = nn.Embedding(len(vocab),192)

In [34]:
emb(targets['labels']).shape

torch.Size([2, 206, 192])

In [35]:
inputs['att_mask'].sum(-1).max()

tensor(170)

In [36]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        # device = x.device
        
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        
        print(x.shape)
        print(emb.shape)
        
        return x + emb

In [37]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [38]:
class bert_config:
    num_embeddings=len(vocab)
    dim=192
    num_layers=12
    head_size=32
    dropout=0.1

In [39]:
# bert_config.dim

In [114]:
class bert_rna(nn.Module):
    def __init__(self, config=bert_config, **kwargs):
        super().__init__()
        
        self.config = config
        
        dim = config.dim
        
        self.emb = nn.Embedding(config.num_embeddings, dim)
        
        self.pos_enc = SinusoidalPosEmb(dim)
        self.pos_enc = PositionalEncoding(dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim, 
                nhead=dim//config.head_size, 
                dim_feedforward=4*dim,
                dropout=config.dropout, 
                activation=nn.GELU(), 
                batch_first=True, 
                norm_first=True), 
            config.num_layers)
    
    def forward(self, seq, mask):
        
        x = self.emb(seq)
        x = self.pos_enc(x)
        
        output = self.transformer(x, src_key_padding_mask=~mask)

        return output

In [115]:
class bert_mlm(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        
        self.bert = model
        self.config = model.config
        
        self.decoder = nn.Linear(self.config.dim, self.config.num_embeddings)
        self.softmax = nn.LogSoftmax(dim=-1) 
        
        self.ml_criterion = nn.NLLLoss(ignore_index=0)
        
    def forward(self, seq, mask):
        # [bs, seq_len, dim]
        embeds = self.bert(seq, mask)
    
        # [bs, seq_len, vocab_size]
        token_predictions = self.decoder(embeds)
        
        # [bs, seq_len, vocab_size]
        return self.softmax(token_predictions)
    
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        
        # slice out extra padding starting from max length in batch
        Lmax = inputs['att_mask'].sum(-1).max()
        
        mask = inputs['att_mask'][:,:Lmax]
        seq = inputs['seq'][:,:Lmax]
        inv_token_mask = targets['token_mask'][:,:Lmax]
        inv_token_mask = ~inv_token_mask
        
        # predict
        # [bs, seq_len, dim]
        tokens = self(seq, mask)
        token_targets = targets['mlm_target'][:,:tokens.shape[1]]
        
        # print(inv_token_mask)
        print(inv_token_mask.shape)
        print(inv_token_mask.unsqueeze(-1).expand_as(tokens).shape)
        
        tm = inv_token_mask.unsqueeze(-1).expand_as(tokens)
        tokens = tokens.masked_fill(tm, 0)
        
        loss = self.ml_criterion(tokens.transpose(1, 2), token_targets)
        
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss

### Test model

In [110]:
model = bert_rna()
mlm_model = bert_mlm(model)

In [111]:
foo = model(inputs['seq'], inputs['att_mask'])
foo.shape

torch.Size([2, 206, 192])

In [112]:
foo = mlm_model(inputs['seq'], inputs['att_mask'])
foo.shape

torch.Size([2, 206, 6])

In [113]:
mlm_model.training_step((inputs, targets), 0)

torch.Size([2, 170])
torch.Size([2, 170, 6])


tensor(5.7783, grad_fn=<NllLoss2DBackward0>)

### Split

In [179]:
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

In [180]:
rs = ShuffleSplit(n_splits=5, test_size=.1, random_state=seed)
train_idx, val_idx = next(rs.split(sequences_df))

In [181]:
len(train_idx), len(val_idx)

(725915, 80658)

### Train

In [182]:
t_df = sequences_df.iloc[train_idx]
v_df = sequences_df.iloc[val_idx]

In [183]:
dm = rna_datamodule(t_df, v_df)

In [184]:
model = bert_rna()
mlm_model = bert_mlm(model)

In [72]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# trainer.fit(model=model, train_dataloaders=train_loader)