In [None]:
pip install -U transformers

In [None]:
import json
import gc
import os
import time
import itertools
from pathlib import Path

import optuna
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from transformers import (
    BertConfig,
    BertModel,
    BertTokenizer,
    BertForPreTraining,
    BertForMaskedLM,
    DataCollatorForLanguageModeling
)
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [None]:
# Sanity check: list the augmented CSV so a missing input file is noticed
# before the read_csv cell below tries to load it.
!ls ../input/augmented-data-for-stanford-covid-vaccine/48k_augment.csv

In [None]:
# Load the augmented table from the attached input directory.
AUGMENTED_CSV_PATH = Path("../input/augmented-data-for-stanford-covid-vaccine/48k_augment.csv")
aug_df = pd.read_csv(AUGMENTED_CSV_PATH)

In [None]:
# Keep only the rows whose `score` is strictly positive and renumber them from 0.
_has_positive_score = aug_df["score"] > 0
all_df = aug_df.loc[_has_positive_score].reset_index(drop=True)

In [None]:
# Number of characters in each row's `sequence` string.
all_df["seq_length"] = all_df["sequence"].apply(len)

In [None]:
def make_all_sequence(row):
    """Fuse a row's three per-position strings into one token per position.

    Args:
        row: mapping with the keys ``"seq_length"``, ``"sequence"``,
            ``"structure"`` and ``"predicted_loop_type"``.

    Returns:
        A list of ``row["seq_length"]`` strings; element ``i`` is the
        concatenation of character ``i`` of ``sequence``, ``structure`` and
        ``predicted_loop_type``, in that order.
    """
    n_positions = row["seq_length"]
    bases = row["sequence"]
    brackets = row["structure"]
    loops = row["predicted_loop_type"]

    merged = []
    for pos in range(n_positions):
        merged.append(bases[pos] + brackets[pos] + loops[pos])
    return merged

In [None]:
# Build the per-position token list for every row and store it as `text`.
_token_source_columns = ["sequence", "structure", "predicted_loop_type", "seq_length"]
all_df["text"] = all_df[_token_source_columns].apply(make_all_sequence, axis=1)

In [None]:
# The same single characters that appear in tokens1 ("ACGU"), tokens2 ("().")
# and tokens3 ("BEHIMSX") below, concatenated into one string. No other cell
# visible in this part of the notebook reads this constant.
ALL_TOKENS = "().ACGUBEHIMSX"

In [None]:
# Alphabets for the first, second and third character of every vocabulary
# token (the vocab cell below concatenates one symbol from each, in this order).
tokens1 = list("ACGU")
tokens2 = list("().")
tokens3 = list("BEHIMSX")

In [None]:
# Write vocab.txt: the five BERT special tokens first, then every
# tokens1 x tokens2 x tokens3 combination (tokens1 varies slowest).
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
vocab_list = [
    "".join(parts)
    for parts in itertools.product(tokens1, tokens2, tokens3)
]
ix = len(vocab_list)  # same final value as the running counter it replaces

with open("vocab.txt", "w") as f:
    for token in special_tokens + vocab_list:
        f.write(token + "\n")

## Split train and validation sets

In [None]:
# Hold out 20% of the rows for validation; the fixed seed keeps the shuffled
# split identical from run to run.
train_df, valid_df = train_test_split(
    all_df,
    test_size=0.2,
    random_state=2020,
    shuffle=True,
)

## Set up the BERT pre-training environment

In [None]:
class EarlyStopping:
    """Track validation loss, checkpoint every improvement, and flag when to stop.

    ref: https://github.com/Bjarten/early-stopping-pytorch

    Args:
        patience: number of non-improving calls (since the last improvement)
            after which ``early_stop`` becomes True.
        verbose: when True, print a message each time a checkpoint is written.

    Attributes set here and read by callers:
        best_score: negated best validation loss seen so far (None before the
            first call).
        early_stop: True once ``patience`` non-improving calls have accumulated.
        best_model_savepath: path of the checkpoint holding the best weights.
    """
    def __init__(self, patience=2, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
        self.val_loss_min = np.inf
        self.best_model_savepath = None

    def __call__(self, val_loss, model, save_name):
        """Record one validation result.

        An improvement (loss lower than or equal to the best so far, or the
        very first call) saves ``model`` to ``save_name``, deletes the
        previous best checkpoint, and resets the patience counter. Anything
        else increments the counter and may set ``early_stop``.
        """
        score = -val_loss
        # Written as `>=` so that a NaN loss (every comparison False) counts
        # as "no improvement" instead of silently replacing the best model.
        improved = self.best_score is None or score >= self.best_score
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        previous_path = self.best_model_savepath
        self.best_score = score
        self.save_checkpoint(val_loss, model, save_name)
        self.best_model_savepath = save_name
        self.counter = 0

        # Remove the superseded checkpoint, but never the file just written:
        # when a caller reuses the same name the new checkpoint must survive.
        if previous_path is not None:
            same_file = os.path.abspath(previous_path) == os.path.abspath(save_name)
            if not same_file and os.path.exists(previous_path):
                os.remove(previous_path)

    def save_checkpoint(self, val_loss, model, save_name):
        """Write ``model.state_dict()`` to ``save_name`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ('
                  f'{self.val_loss_min:.5f} --> {val_loss:.5f}'
                  ').  Saving model ...')
            print("Save model: {}".format(save_name))
        torch.save(model.state_dict(), save_name)
        self.val_loss_min = val_loss

    def get_best_filepath(self):
        """Return the path of the best checkpoint saved so far (None if none)."""
        return self.best_model_savepath

In [None]:
class NullScheduler():
    """Constant learning-rate schedule: every query returns the configured rate."""

    def __init__(self, lr=0.01):
        super().__init__()
        self.cycle = 0
        self.lr = lr

    def __call__(self, time):
        # The argument is accepted so the object can be called like other
        # schedulers, but the returned rate does not depend on it.
        return self.lr

    def __str__(self):
        header = "NullScheduler"
        rate_line = "lr={0:0.5f}".format(self.lr)
        return "\n".join((header, rate_line))

def adjust_learning_rate(optimizer, lr):
    """Set the ``'lr'`` entry of every parameter group of ``optimizer`` to ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)


def get_learning_rate(optimizer):
    """Return the learning rate of an optimizer with exactly one parameter group.

    Raises:
        AssertionError: if ``optimizer.param_groups`` does not hold exactly
            one group.
    """
    rates = [group['lr'] for group in optimizer.param_groups]
    assert len(rates) == 1
    (only_rate,) = rates
    return only_rate

In [None]:
class RnaDataset(Dataset):
    """Dataset of token-id sequences built from the ``"text"`` column of a frame.

    Every entry of ``df["text"]`` is passed through
    ``tokenizer.convert_tokens_to_ids`` once, at construction time.
    """

    def __init__(self, tokenizer, df, block_size=256):
        # `block_size` is accepted but not used by this implementation.
        token_lists = df["text"].tolist()
        self.examples = list(map(tokenizer.convert_tokens_to_ids, token_lists))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        ids = self.examples[i]
        return torch.tensor(ids, dtype=torch.long)

In [None]:
def bert_pretrain(train_df, valid_df, config):
    """Pre-train a BERT masked-language model on the token lists in ``train_df``.

    Args:
        train_df: frame whose ``"text"`` column holds one token list per row;
            used for the optimisation steps.
        valid_df: frame with the same layout; used for the per-epoch
            validation loss that drives early stopping.
        config: dict with the keys ``learning_rate``, ``batch_size``,
            ``n_epoch``, ``n_early_stopping_patience``, ``bert_hidden_size``,
            ``bert_num_hidden_layers``, ``bert_num_attention_heads`` and
            ``bert_intermediate_size``.

    Returns:
        ``(model, bert_config, best_score)``. ``best_score`` is the lowest
        mean per-batch validation loss observed, and ``model`` carries the
        weights of the checkpoint that achieved it, whether training stopped
        early or ran through every epoch.

    Side effects:
        Reads ``vocab.txt`` and writes ``./checkpoint_epoch*_val*.pth`` files
        in the working directory (superseded ones are deleted by
        ``EarlyStopping``).
    """
    ###################################
    # Tokenizer
    ###################################
    tokenizer = BertTokenizer(
        "vocab.txt",
        do_basic_tokenize=False,
        do_lower_case=False,
        strip_accents=False,
        never_split=vocab_list
    )

    ###################################
    # Bert Config
    ###################################
    bert_config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=config["bert_hidden_size"],
        num_hidden_layers=config["bert_num_hidden_layers"],
        num_attention_heads=config["bert_num_attention_heads"],
        intermediate_size=config["bert_intermediate_size"]
    )

    ###################################
    # Model
    ###################################
    model = BertForMaskedLM(config=bert_config)
    model.to(device)

    ###################################
    # Dataset
    ###################################
    train_dataset = RnaDataset(tokenizer, train_df)
    valid_dataset = RnaDataset(tokenizer, valid_df)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer
    )

    ###################################
    # Dataloader
    ###################################
    # NOTE(review): the training loader iterates in a fixed order every epoch
    # (no shuffle=True); confirm whether that is intended.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        collate_fn=data_collator,
    )

    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=config["batch_size"],
        collate_fn=data_collator,
    )
    dataloaders_dict = {
        "train": train_dataloader,
        "valid": valid_dataloader
    }

    ##################################
    # early stopping
    ##################################
    early_stopping = EarlyStopping(
        patience=config["n_early_stopping_patience"],
        verbose=False
    )

    ##################
    # lr scheduler
    ##################
    # Any callable mapping an epoch index to a learning rate fits here; a
    # negative return value makes the epoch skip both phases (see the loop).
    scheduler = NullScheduler(lr=config["learning_rate"])

    ##################
    # Optimizer
    ##################
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=scheduler(0)
    )

    ###############################
    # train epoch loop
    ###############################
    iteration = 1
    num_epochs = config["n_epoch"]
    valid_period = 1

    print(f"Optimizer\n  {optimizer}")
    print(f"Scheduler\n  {scheduler}")
    print("** start training here! **")
    print("                    |  val   |  train ")
    print("rate    iter  epoch |  loss  |  loss  | time  ")
    print("--------------------------------------------------------------------------------")

    # NOTE(review): this runs n_epoch + 1 epochs (0..n_epoch inclusive); kept
    # as in the original, confirm whether the extra epoch is intended.
    for epoch in range(num_epochs + 1):
        t_epoch_start = time.time()
        epoch_train_loss = 0.0
        epoch_val_loss = 0.0

        for phase in ['train', 'valid']:
            if phase == 'train':
                lr = scheduler(epoch)
                if lr < 0:
                    break
                adjust_learning_rate(optimizer, lr)
                model.train()
            else:  # valid
                if (epoch + 1) % valid_period == 0:
                    model.eval()
                else:
                    continue

            # get batch data loop
            for iter_i, model_input in enumerate(dataloaders_dict[phase]):
                for k, v in model_input.items():
                    if isinstance(v, torch.Tensor):
                        model_input[k] = v.to(device)

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    # Index instead of tuple-unpacking: recent transformers
                    # releases return a dict-like ModelOutput, and unpacking
                    # that yields its *keys* (strings), not tensors. Integer
                    # indexing works for both the tuple and ModelOutput forms;
                    # element 0 is the loss because the collator's batch
                    # includes labels.
                    outputs = model(**model_input)
                    loss = outputs[0]

                    if phase == 'train':
                        print(f"\r{iter_i*config['batch_size']} / {len(train_dataset)}", end='')
                        loss.backward()
                        optimizer.step()
                        epoch_train_loss += loss.item()
                        iteration += 1
                    elif phase == "valid":
                        epoch_val_loss += loss.item()

        elapsed_time = time.time() - t_epoch_start
        lr = get_learning_rate(optimizer)

        # Each accumulated value is already a per-batch mean, so average over
        # the number of batches (not the number of samples).
        epoch_train_loss /= len(train_dataloader)
        epoch_val_loss /= len(valid_dataloader)

        print("\r", end="")
        print(
            "{0:1.5f}  {1:4d}  {2:3d}  | {3:4.4f} {4:4.4f}  {5:1.5f}"
            .format(
                lr,
                iteration,
                epoch,
                epoch_val_loss,
                epoch_train_loss,
                elapsed_time),
        )

        ######################
        # early stopping
        ######################
        model_save_path = f"./checkpoint_epoch{epoch}_val{epoch_val_loss:.4f}.pth"
        early_stopping(epoch_val_loss, model, model_save_path)
        if early_stopping.early_stop:
            print("******** Early stopping ********")
            print(f"Best Score: {early_stopping.best_score*(-1)}")
            break

    best_score = early_stopping.best_score*(-1)

    # Always hand back the best checkpoint's weights so the returned model
    # matches the returned score, even when every epoch was run.
    best_model_save_path = early_stopping.get_best_filepath()
    if best_model_save_path is not None:
        model.load_state_dict(
            torch.load(best_model_save_path, map_location=device)
        )
    return model, bert_config, best_score

## Training

In [None]:
# Run a full garbage-collection pass before the training cell below.
gc.collect()

In [None]:
# model, bert_config, best_score = bert_pretrain(train_df, valid_df, config)

In [None]:
# Hyper-parameters consumed by `bert_pretrain`.
config = dict(
    # optimisation
    learning_rate=0.001,
    batch_size=16,
    n_epoch=200,
    n_early_stopping_patience=20,
    # BERT architecture
    bert_hidden_size=128,
    bert_num_hidden_layers=8,
    bert_num_attention_heads=4,
    bert_intermediate_size=256,
)

_pretrain_result = bert_pretrain(train_df, valid_df, config)
model, bert_config, best_score = _pretrain_result

In [None]:
# Persist the raw weights; the file name embeds the score returned by bert_pretrain.
_weights_path = f"./bert_mlm_{best_score}.model"
torch.save(model.state_dict(), _weights_path)

In [None]:
# Also export in the `save_pretrained` directory layout so `from_pretrained` can reload it.
_export_dir = f"./bert_mlm_{best_score}"
model.save_pretrained(_export_dir)

In [None]:
# Write the architecture configuration next to the weights.
_config_json_path = "./bert_config.json"
bert_config.to_json_file(_config_json_path)

In [None]:
# Take the `bert` sub-module of the trained masked-LM model; the cells below
# call it directly on token ids (presumably the encoder without the LM head -
# confirm against the transformers version in use).
bert = model.bert

In [None]:
# Rebuild the tokenizer with the same options `bert_pretrain` uses internally.
_tokenizer_options = dict(
    do_basic_tokenize=False,
    do_lower_case=False,
    strip_accents=False,
    never_split=vocab_list,
)
tokenizer = BertTokenizer("vocab.txt", **_tokenizer_options)

# Note: despite its name, this dataset wraps the validation frame.
test_dataset = RnaDataset(tokenizer, valid_df)

In [None]:
# Forward one example (as a batch of 1) through the bare encoder. Gradients
# are not needed for this inspection, so skip building the autograd graph.
with torch.no_grad():
    _encoder_outputs = bert(test_dataset[0].unsqueeze(0).to(device))

# Recent transformers releases return a dict-like ModelOutput; tuple-unpacking
# that yields its keys (strings) rather than tensors, so read the fields
# explicitly. Older releases return a plain tuple.
if isinstance(_encoder_outputs, tuple):
    last_hidden, pool = _encoder_outputs[:2]
else:
    last_hidden = _encoder_outputs.last_hidden_state
    # May be None if the masked-LM model built its encoder without a pooling
    # layer - TODO confirm for the installed transformers version.
    pool = _encoder_outputs.pooler_output

In [None]:
# Display the shape of `last_hidden` from the encoder call in the previous cell.
last_hidden.shape

In [None]:
# Round-trip check: reload the model from the directory written by save_pretrained.
_reload_dir = f"./bert_mlm_{best_score}"
pretrained_model = BertForMaskedLM.from_pretrained(_reload_dir)

In [None]:
# Move the reloaded model onto `device`; as the cell's last expression, the
# returned module is also displayed.
pretrained_model.to(device)

In [None]:
# Smoke test: run the reloaded model on the first example, shaped as a batch of 1.
_sample_ids = test_dataset[0].view(1, -1).to(device)
pretrained_model(_sample_ids)