In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

Reusing TensorBoard on port 6006 (pid 10304), started 6 days, 1:31:48 ago. (Use '!kill 10304' to kill it.)

In [3]:
import numpy as np
import pandas as pd
import random
import math

import os
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.optim import Adam 

# import lightning as L
from torch.utils.data import Dataset, DataLoader # these are needed for the training data
import pytorch_lightning as pl



In [4]:
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(seed)

In [5]:
Lmax=407
Lmax=206

bs = 64
num_workers = 4
learning_rate = 4e-5
nfolds = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [6]:
vocab = {'PAD':0,'A':1,'C':2,'G':3,'U':4,'M':5,'S':6} # MASK:5, SEP:6
len(vocab)

7

In [7]:
input_dir = '../input/stanford-ribonanza-rna-folding'
train_dir = f'{input_dir}/train'
test_dir = f'{input_dir}/test'
train_csv = f'{input_dir}/train_data.csv'

train_file = 'train_data.parquet'
sequence_file = 'train_sequences.parquet'

In [8]:
# train_df = pd.read_csv(train_csv)

train_df = pd.read_parquet(train_file)
sequences_df = pd.read_parquet(sequence_file)

In [9]:
train_df.shape, sequences_df.shape

((1643680, 419), (806573, 2))

In [10]:
train_df.head(2)

Unnamed: 0,sequence_id,sequence,experiment_type,dataset_name,reads,signal_to_noise,SN_filter,reactivity_0001,reactivity_0002,reactivity_0003,...,reactivity_error_0197,reactivity_error_0198,reactivity_error_0199,reactivity_error_0200,reactivity_error_0201,reactivity_error_0202,reactivity_error_0203,reactivity_error_0204,reactivity_error_0205,reactivity_error_0206
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,2343,0.944,0,,,,...,,,,,,,,,,
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,2A3_MaP,15k_2A3,5326,1.933,1,,,,...,,,,,,,,,,


In [11]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...


In [12]:
sequences_df['L'] = sequences_df.sequence.apply(len)

In [13]:
sequences_df.head(2)

Unnamed: 0,sequence_id,sequence,L
0,8cdfeef009ea,GGGAACGACUCGAGUAGAGUCGAAAAACGUUGAUAUGGAUUUACUC...,170
1,51e61fbde94d,GGGAACGACUCGAGUAGAGUCGAAAAACAUUGAUAUGGAUUUACUC...,170


### Dataset

In [14]:
foo = sequences_df['sequence'].values

In [15]:
len(foo)

806573

In [16]:
foo.shape

(806573,)

In [17]:
from dataloader import RNA_Dataset

### Test dataset

In [18]:
ds_train = RNA_Dataset(sequences_df, Lmax, vocab, seed, mode='train', nsp=True)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, num_workers=0, persistent_workers=False, drop_last=True)

In [19]:
inputs, targets = ds_train.__getitem__(0)

# print(ds_train.__len__())

inputs['seq'].shape, inputs['att_mask'].shape, inputs['seq'].dtype, inputs['att_mask'].dtype

(torch.Size([207]), torch.Size([207]), torch.int32, torch.bool)

In [20]:
# targets

{'labels': tensor([3, 3, 3, 1, 1, 2, 3, 1, 2, 4, 2, 3, 1, 3, 4, 1, 3, 1, 3, 4, 2, 3, 1, 1,
         1, 1, 1, 2, 3, 4, 4, 3, 1, 4, 1, 4, 3, 3, 1, 4, 4, 4, 1, 2, 4, 2, 2, 3,
         1, 3, 3, 1, 3, 1, 2, 3, 1, 1, 2, 4, 1, 2, 2, 1, 2, 3, 1, 1, 2, 1, 3, 3,
         3, 3, 1, 1, 1, 2, 4, 2, 4, 1, 2, 2, 2, 6, 3, 4, 3, 3, 2, 3, 4, 2, 4, 2,
         2, 3, 4, 4, 4, 3, 1, 2, 3, 1, 3, 4, 1, 1, 3, 4, 2, 2, 4, 1, 1, 3, 4, 2,
         1, 1, 2, 1, 4, 3, 2, 2, 1, 2, 3, 2, 3, 3, 3, 4, 2, 2, 4, 4, 2, 3, 3, 3,
         1, 2, 2, 2, 3, 2, 1, 1, 1, 1, 3, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2,
         1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32),
 'token_mask': tensor([False, False, False, False, False, False, False, False,  True, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False,  True, False, False, False, False,  True, False,
          True, 

In [21]:
inputs['seq']

tensor([3, 3, 3, 1, 1, 2, 3, 1, 5, 4, 2, 3, 1, 3, 4, 1, 3, 1, 3, 4, 2, 3, 1, 5,
        1, 1, 1, 2, 5, 4, 5, 3, 1, 4, 1, 4, 3, 3, 1, 4, 4, 4, 1, 2, 4, 2, 2, 3,
        1, 5, 3, 5, 3, 1, 2, 3, 1, 1, 5, 4, 1, 2, 2, 1, 2, 3, 1, 1, 2, 1, 3, 3,
        5, 5, 5, 1, 1, 2, 4, 2, 4, 1, 2, 2, 2, 6, 5, 4, 3, 3, 2, 3, 4, 5, 4, 2,
        2, 3, 5, 4, 4, 3, 5, 2, 3, 5, 3, 5, 1, 1, 3, 4, 5, 2, 4, 1, 5, 3, 4, 2,
        1, 1, 2, 1, 4, 3, 5, 2, 1, 2, 3, 5, 3, 3, 3, 5, 5, 2, 4, 4, 5, 3, 5, 3,
        1, 2, 2, 2, 5, 2, 1, 1, 5, 1, 5, 1, 1, 1, 2, 1, 1, 2, 5, 1, 2, 1, 1, 2,
        1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)

In [22]:
targets['labels']

tensor([3, 3, 3, 1, 1, 2, 3, 1, 2, 4, 2, 3, 1, 3, 4, 1, 3, 1, 3, 4, 2, 3, 1, 1,
        1, 1, 1, 2, 3, 4, 4, 3, 1, 4, 1, 4, 3, 3, 1, 4, 4, 4, 1, 2, 4, 2, 2, 3,
        1, 3, 3, 1, 3, 1, 2, 3, 1, 1, 2, 4, 1, 2, 2, 1, 2, 3, 1, 1, 2, 1, 3, 3,
        3, 3, 1, 1, 1, 2, 4, 2, 4, 1, 2, 2, 2, 6, 3, 4, 3, 3, 2, 3, 4, 2, 4, 2,
        2, 3, 4, 4, 4, 3, 1, 2, 3, 1, 3, 4, 1, 1, 3, 4, 2, 2, 4, 1, 1, 3, 4, 2,
        1, 1, 2, 1, 4, 3, 2, 2, 1, 2, 3, 2, 3, 3, 3, 4, 2, 2, 4, 4, 2, 3, 3, 3,
        1, 2, 2, 2, 3, 2, 1, 1, 1, 1, 3, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2,
        1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)

In [23]:
targets['mlm_target']

tensor([0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        3, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0,
        0, 0, 4, 0, 0, 0, 1, 0, 0, 1, 0, 4, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 4, 2, 0, 0, 0, 2, 0, 3, 0,
        0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [24]:
targets['token_mask']

tensor([False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False, False, False, False,  True, False,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False,  True, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False,  True, False, False, False, False,  True, False,
        False, False,  True, False, False,  True, False,  True, False, False,
        False, False,  True, False, False, False,  True, False, False, False,
        False, False, False, False, False, False,  True, False, 

In [25]:
((inputs['seq'] == 5) == targets['token_mask']).sum()

tensor(207)

In [26]:
# check mask indeces mark masked tokens
((inputs['seq'] == 5) == targets['token_mask']).sum() == len(targets['mlm_target'])

tensor(True)

In [27]:
inputs['att_mask'].sum()

tensor(171)

In [28]:
# inputs['seq'] == ds_train.seq_map['M']

In [29]:
# targets['labels']

In [30]:
# inputs['att_mask']

In [31]:
# inputs['att_mask']

In [32]:
sum(targets['labels'] != inputs['seq'])/sum(targets['labels'] == inputs['seq'])

tensor(0.1564)

In [33]:
inputs, targets = next(iter(dl_train))
inputs['seq'].shape, targets['labels'].shape

(torch.Size([2, 207]), torch.Size([2, 207]))

In [34]:
# targets['mlm_target'][0]

#### DataModule

In [35]:
from dataloader import create_dataloader, DataloaderWrapper

In [36]:
class rna_datamodule(pl.LightningDataModule):
    def __init__(self, train_df, val_df, bs=bs, num_workers=0):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        self.bs = bs
        self.num_workers = num_workers
        self.pw = True if num_workers>0 else False
        
    def train_dataloader(self):
        train_ds = RNA_Dataset(self.train_df, Lmax, vocab, seed, mode='train', nsp=True)
        
        return torch.utils.data.DataLoader(
            train_ds, 
            batch_size=self.bs, 
            num_workers=self.num_workers, 
            persistent_workers=self.pw, 
            drop_last=True
        )
        
        
    
    def val_dataloader(self):
        val_ds = RNA_Dataset(self.val_df, Lmax, vocab, seed, nsp=True)
        
        return torch.utils.data.DataLoader(
            val_ds, 
            batch_size=self.bs, 
            num_workers=self.num_workers, 
            persistent_workers=self.pw, 
            drop_last=True
        )
        
        

### Model

In [37]:
emb = nn.Embedding(len(vocab),192)

In [38]:
emb(targets['labels']).shape

torch.Size([2, 207, 192])

In [39]:
inputs['att_mask'].sum(-1).max()

tensor(178)

In [40]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        # device = x.device
        
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim) * (-emb))
        emb = x[...,None] * emb[None,...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        
        # print(x.shape)
        # print(emb.shape)
        
        return x + emb

In [41]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [42]:
dim = 192
head_size=32

class bert_config:
    num_embeddings=len(vocab)
    dim=dim
    head_size=head_size
    nhead=dim//head_size
    dim_feedforward=4*dim
    num_layers=12
    dropout=0.1

In [43]:
for k, v in bert_config.__dict__.items():
    if '__' not in k:
        print(k, v)

num_embeddings 7
dim 192
head_size 32
nhead 6
dim_feedforward 768
num_layers 12
dropout 0.1


In [44]:
class bert_rna(nn.Module):
    def __init__(self, config=bert_config, **kwargs):
        super().__init__()
        
        self.config = config
        
        dim = config.dim
        
        self.emb = nn.Embedding(config.num_embeddings, dim)
        
        self.pos_enc = SinusoidalPosEmb(dim)
        self.pos_enc = PositionalEncoding(dim)
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim, 
                nhead=dim//config.head_size, 
                dim_feedforward=4*dim,
                dropout=config.dropout, 
                activation=nn.GELU(), 
                batch_first=True, 
                norm_first=True), 
            config.num_layers)
    
    def forward(self, seq, mask):
        
        x = self.emb(seq)
        x = self.pos_enc(x)
        
        output = self.transformer(x, src_key_padding_mask=~mask)

        return output

In [45]:
class bert_mlm(nn.Module):
    def __init__(self, bert):
        super().__init__()
        
        self.bert = bert
        self.config = bert.config
        
        self.decoder = nn.Linear(self.config.dim, self.config.num_embeddings)
        self.softmax = nn.LogSoftmax(dim=-1)
        
        self.nsp = nn.Linear(self.config.dim, 2)
        
    def forward(self, seq, mask):
        # [bs, seq_len, dim]
        embeds = self.bert(seq, mask)
    
        # [bs, seq_len, vocab_size]
        token_predictions = self.decoder(embeds)
        
        # [bs, seq_len, vocab_size], 
        return self.softmax(token_predictions), self.nsp(embeds[:,0,:])
    # sortout why use only first word for nsp prediction

In [46]:
class pl_model(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        
        self.model = model
        
        # maybe use reduction='sum'?
        self.ml_criterion = nn.NLLLoss(ignore_index=0)
        # self.ml_criterion = nn.NLLLoss(ignore_index=0, reduction='sum')
        
        self.nsp_criterion = nn.BCEWithLogitsLoss() 
        
    def configure_optimizers(self):
        return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015)
        
    def forward(self, seq, mask):
        return self.model(seq, mask)
    
    def get_data(self, batch, batch_idx):
        inputs, targets = batch
        
        # slice out extra padding starting from max length in batch
        Lmax = inputs['att_mask'].sum(-1).max()
        
        mask = inputs['att_mask'][:,:Lmax]
        seq = inputs['seq'][:,:Lmax]
        token_targets = targets['mlm_target'][:,:Lmax]
        inv_token_mask = targets['token_mask'][:,:Lmax]
        inv_token_mask = ~inv_token_mask
        
        return seq, mask, token_targets, inv_token_mask, targets['nsp_target']
    
    def mlm_acc(self, result, target, inv_token_mask):
        r = result.argmax(-1).masked_select(~inv_token_mask)  
        t = target.masked_select(~inv_token_mask)  
        s = (r == t).sum()  
        
        return round(float(s / r.size(0)), 2)
    
    def nsp_acc(self, result: torch.Tensor, target: torch.Tensor):
        s = (result.argmax(1) == target.argmax(1)).sum()  
        return round(float(s / result.size(0)), 2)
    
    def training_step(self, batch, batch_idx):
        
        seq, mask, token_targets, inv_token_mask, nsp_target = self.get_data(batch, batch_idx)
        
        # [bs, seq_len, dim]
        tokens, nsp = self(seq, mask)
        
        tm = inv_token_mask.unsqueeze(-1).expand_as(tokens)
        tokens = tokens.masked_fill(tm, 0)
        
        mlm_loss = self.ml_criterion(tokens.transpose(1, 2), token_targets)
        nsp_loss = self.nsp_criterion(nsp, nsp_target)
        
        loss = mlm_loss + nsp_loss
        self.log("train_mlm_loss", mlm_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_nsp_loss", nsp_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        acc = self.mlm_acc(tokens, token_targets, inv_token_mask)
        self.log("train_mlm_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
                     
        acc = self.nsp_acc(nsp, nsp_target)
        self.log("train_nsp_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        seq, mask, token_targets, inv_token_mask, nsp_target = self.get_data(batch, batch_idx)

        # [bs, seq_len, dim]
        tokens, nsp = self(seq, mask)
        
        tm = inv_token_mask.unsqueeze(-1).expand_as(tokens)
        tokens = tokens.masked_fill(tm, 0)
        
        mlm_loss = self.ml_criterion(tokens.transpose(1, 2), token_targets)
        nsp_loss = self.nsp_criterion(nsp, nsp_target)
        
        loss = mlm_loss + nsp_loss
        self.log("val_mlm_loss", mlm_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_nsp_loss", nsp_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        acc = self.mlm_acc(tokens, token_targets, inv_token_mask)
        self.log("val_mlm_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
                     
        acc = self.nsp_acc(nsp, nsp_target)
        self.log("val_nsp_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

### Test model

In [47]:
# model = bert_rna()
mlm_model = bert_mlm(bert_rna())
model = pl_model(mlm_model)

In [48]:
# foo = model(inputs['seq'], inputs['att_mask'])
# foo.shape

In [49]:
foo, bar = mlm_model(inputs['seq'], inputs['att_mask'])
foo.shape, bar.shape

(torch.Size([2, 207, 7]), torch.Size([2, 2]))

In [50]:
model.training_step((inputs, targets), 0)

C:\Users\rosul\anaconda3\envs\devenv\lib\site-packages\pytorch_lightning\core\module.py:420: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(2.9780, grad_fn=<AddBackward0>)

### Split

In [51]:
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

In [52]:
rs = ShuffleSplit(n_splits=5, test_size=.1, random_state=seed)
train_idx, val_idx = next(rs.split(sequences_df))

In [53]:
len(train_idx), len(val_idx)

(725915, 80658)

### Train

In [54]:
t_df = sequences_df.iloc[train_idx]
v_df = sequences_df.iloc[val_idx]

In [55]:
dm = rna_datamodule(t_df, v_df, num_workers=8)

In [56]:
# model = bert_rna()
mlm_model = bert_mlm(bert_rna())
model = pl_model(mlm_model)

In [57]:
trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=5,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [58]:
# next(iter(dm.val_dataloader()))

In [59]:
trainer.fit(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | bert_mlm          | 5.3 M 
1 | ml_criterion  | NLLLoss           | 0     
2 | nsp_criterion | BCEWithLogitsLoss | 0     
----------------------------------------------------
5.3 M     Trainable params
0         Non-trainable params
5.3 M     Total params
21.366    Total estimated model params size (MB)


Sanity Checking: |                                                               | 0/? [00:00<?, ?it/s]

C:\Users\rosul\anaconda3\envs\devenv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
  return torch._transformer_encoder_layer_fwd(
C:\Users\rosul\anaconda3\envs\devenv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                                                      | 0/? [00:00<?, ?it/s]

C:\Users\rosul\anaconda3\envs\devenv\lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


### Test

In [None]:
inputs, targets = next(iter(dl_train))

In [82]:
# inputs

In [None]:
preds = model(inputs['seq'].to(device), inputs['att_mask'].to(device))
preds.shape

In [None]:
tm = targets['token_mask'].unsqueeze(-1).expand_as(preds).to(device)
preds = preds.masked_fill(~tm, 0)

preds.shape

In [None]:
preds[1].topk(1).indices.view(-1)

In [None]:
targets['labels'][0]

In [None]:
# (preds[0].topk(1).indices.view(-1) == targets['labels'][0].to(device))

In [None]:
acc = model.mlm_acc(preds[0], targets['labels'][0].to(device), targets['token_mask'].to(device))
acc

### Measure