In [1]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt

import torch 
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler, RandomSampler

# Pick the compute device. Fall back to the CPU when no GPU is present so that
# `device` is always defined for the cells below (previously it was left
# undefined on CPU-only machines, causing a NameError later).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Deep learning
from transformers import AutoTokenizer
from loop_train_berts import (
    set_seed,
    BertClassifier,
    preprocessing_for_bert,
    initialize_model, 
    train,
    bert_predict
)


from sklearn.manifold import TSNE

import random
import time

In [2]:
# Identifier of the pretrained checkpoint used for both the tokenizer and the
# classifier (loaded as a RobertaModel, per the initialization warning below).
model_path = "allenai/biomed_roberta_base"

In [3]:
# For fine-tuning BERT, the authors recommend a batch size of 16 or 32; a smaller
# batch is used here to fit in GPU memory.
# NOTE(review): an earlier comment said our RTX could hold 8 — confirm whether 4
# is still required or can be raised.
batch_size = 4

# Positive and negative examples live in separate TSV files; label them 1 / 0.
positive = pd.read_csv('data/positive.tsv', sep='\t', index_col=0)
positive['target'] = 1
negative = pd.read_csv('data/negative.tsv', sep='\t', index_col=0)
negative['target'] = 0

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; pd.concat
# is its replacement and keeps the same row order and original index labels.
data = pd.concat([positive, negative])

# Model input text is "<Title> <Abstract>". Missing titles are filled with ''
# so a NaN is never rendered as the literal token "nan" by map(str).
data['concat'] = (
    data['Title'].fillna('').map(str)
    + " "
    + data['Abstract'].fillna(' ').map(str)
)
# Lower-case the text explicitly before tokenization.
data['bert'] = data['concat'].str.lower()

In [4]:
# Load the tokenizer that matches the pretrained checkpoint.
# NOTE(review): RoBERTa-family tokenizers are case-sensitive BPE and presumably
# ignore `do_lower_case`; the input text is lower-cased explicitly when the
# `bert` column is built. Confirm against the installed transformers version.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    do_lower_case=True,
)

In [None]:
# Tokenize every document with the project helper. It also returns `offset`
# (return_offset=True); NOTE(review): its meaning is defined in loop_train_berts.
train_inputs, train_masks, offset = preprocessing_for_bert(
    tokenizer, data['bert'], return_offset=True
)
# Integer class labels: 1 for rows from positive.tsv, 0 for rows from negative.tsv.
train_labels = torch.tensor(data['target'].to_numpy())

# Training DataLoader; examples are visited in random order each epoch.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=1,
)

In [6]:
# Loss for the two-class (target 0/1) task. `nn` is never imported in this
# notebook, so reach it through the already-imported `torch` package instead
# of the bare name (which raised NameError).
# NOTE(review): `loss_fn` is not passed to `train` below — confirm whether
# loop_train_berts relies on this global or defines its own loss.
loss_fn = torch.nn.CrossEntropyLoss()

# Fix the random seed before building the model (set_seed comes from
# loop_train_berts; presumably it seeds Python/NumPy/torch — verify there).
set_seed(42)
bert_classifier, optimizer, scheduler = initialize_model(model_path, device, train_dataloader, epochs=4)

Some weights of the model checkpoint at allenai/biomed_roberta_base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# Fine-tune for 4 epochs (matching the epochs passed to initialize_model).
# NOTE(review): save_emb=True is forwarded to loop_train_berts.train;
# presumably it saves embeddings — confirm there.
train(
    bert_classifier,
    device,
    train_dataloader,
    optimizer,
    scheduler,
    epochs=4,
    save_emb=True,
)

Start training...

 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   1    |   100   |   0.448089   |     -      |     -     |   23.82  
   1    |   200   |   0.416967   |     -      |     -     |   23.77  
   1    |   300   |   0.308454   |     -      |     -     |   24.39  
   1    |   400   |   0.388859   |     -      |     -     |   25.01  
   1    |   500   |   0.314160   |     -      |     -     |   25.27  
   1    |   600   |   0.342056   |     -      |     -     |   26.08  
   1    |   700   |   0.191861   |     -      |     -     |   26.53  
   1    |   800   |   0.192471   |     -      |     -     |   26.81  
   1    |   900   |   0.157770   |     -      |     -     |   27.45  
   1    |  1000   |   0.197449   |     -      |     -     |   27.97  
   1    |  1100   |   0.218265   |     -      |     -     |   28.43  
   1    |  1200   |   0.121481   |     -      |     -     |   28.98  




 Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
----------------------------------------------------------------------
   4    |   100   |   0.137687   |     -      |     -     |   23.47  
   4    |   200   |   0.185541   |     -      |     -     |   23.71  
   4    |   300   |   0.155648   |     -      |     -     |   24.50  
   4    |   400   |   0.096470   |     -      |     -     |   25.00  
   4    |   500   |   0.123750   |     -      |     -     |   25.51  
   4    |   600   |   0.147560   |     -      |     -     |   25.97  
   4    |   700   |   0.082644   |     -      |     -     |   26.42  
   4    |   800   |   0.122023   |     -      |     -     |   26.76  
   4    |   900   |   0.084847   |     -      |     -     |   26.98  
   4    |  1000   |   0.215947   |     -      |     -     |   27.81  
   4    |  1100   |   0.148923   |     -      |     -     |   28.23  
   4    |  1200   |   0.135644   |     -      |     -     |   28.67  
   4    |  1300  