# Task 2: RoBERTa-based ADR (Adverse Drug Reaction) Classification

## 1. Install necessary libraries

In [None]:
# Colab-only dependency install. NOTE(review): only tf-estimator-nightly is
# pinned, so a fresh run may resolve newer, incompatible releases; consider
# `%pip install pkg==x.y.z`. contractions, sent2vec, imbalanced-learn and
# seqeval are not imported by any visible cell of this notebook.
!pip install -q contractions transformers sent2vec imbalanced-learn seqeval[gpu] ekphrasis
!pip install -q tf-estimator-nightly==2.8.0.dev2021122109
# !python -m pip uninstall -q -y spacy
# !python -m pip install -q -U spacy

[K     |████████████████████████████████| 4.0 MB 9.8 MB/s 
[K     |████████████████████████████████| 43 kB 1.0 MB/s 
[K     |████████████████████████████████| 80 kB 2.9 MB/s 
[K     |████████████████████████████████| 106 kB 38.6 MB/s 
[K     |████████████████████████████████| 287 kB 18.2 MB/s 
[K     |████████████████████████████████| 77 kB 3.3 MB/s 
[K     |████████████████████████████████| 895 kB 39.2 MB/s 
[K     |████████████████████████████████| 6.6 MB 9.0 MB/s 
[K     |████████████████████████████████| 596 kB 28.1 MB/s 
[K     |████████████████████████████████| 45 kB 1.1 MB/s 
[K     |████████████████████████████████| 53 kB 1.2 MB/s 
[?25h  Building wheel for ekphrasis (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 462 kB 7.8 MB/s 
[K     |████████████████████████████████| 6.0 MB 11.1 MB/s 
[K     |████████████████████████████████| 42 kB 1.4 MB/s 
[K     |██████████████████████

In [None]:
# !python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.2.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0-py3-none-any.whl (13.9 MB)
[K     |████████████████████████████████| 13.9 MB 409 kB/s 
Installing collected packages: en-core-web-sm
  Attempting uninstall: en-core-web-sm
    Found existing installation: en-core-web-sm 2.2.5
    Uninstalling en-core-web-sm-2.2.5:
      Successfully uninstalled en-core-web-sm-2.2.5
Successfully installed en-core-web-sm-3.2.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


## 2. Import libraries

In [None]:
import numpy as np
import pandas as pd
import torch
import warnings
import torch.nn as nn

from transformers import RobertaTokenizerFast, RobertaModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import accuracy_score, classification_report, f1_score
from ekphrasis.classes.preprocessor import TextPreProcessor
from ekphrasis.classes.tokenizer import SocialTokenizer
from ekphrasis.dicts.emoticons import emoticons
from tqdm import tqdm

# Show full frames/cells (tweets are long) and hide library warnings
for option_name in ("display.max_rows", "display.max_columns", "display.max_colwidth"):
    pd.set_option(option_name, None)
warnings.filterwarnings("ignore")

## 3. Load the dataset

In [None]:
# Load the tab-separated training and validation splits
TRAIN_COLUMNS = ['tweet_id', 'begin', 'end', 'type', 'extraction', 'drug', 'tweet', 'meddra_code', 'meddra_term']
training = pd.read_csv('task3_training.tsv', sep='\t', usecols=TRAIN_COLUMNS)
validation = pd.read_csv('task3_validation.tsv', sep='\t', skipinitialspace=True)

In [None]:
training.head(n=5)

Unnamed: 0,tweet_id,begin,end,type,extraction,drug,tweet,meddra_code,meddra_term
0,331187619096588288,,,,,ofloxacin,@seefisch:oral drugs for pyelonephritis:ciprofloxacin levofloxacin tmp/smz do not use nitrofurantoin for pyelo(only cystitis)@david_medinaf,,
1,332227554956161024,,,,,trazodone,happy for wellbutrin; has similar effects as adderall.. trazodone is super promising for sleep.. but abilify can cause weight gain -_-,,
2,332448217490944000,,,,,lamotrigine,"@stilgarg i'm ok ty have an official diagnosis of bipolar now, feeling ok at the moment lamotrigine has been increased having monotherapy:/",,
3,332977955754110976,,,,,cymbalta,i'm soo depressed cymbalta couldn't help me .,,
4,333674203331051520,,,,,seroquel,"time for my daily afternoon relaxation ritual of smoking weed, taking 2 mgs of clonazepam, and 400 mg of seroquel xr.",,


In [None]:
for split_name, split_frame in (("Training", training), ("Validation", validation)):
    print(f"Shape of {split_name} data: {split_frame.shape}")

Shape of Training data: (2246, 9)
Shape of Validation data: (560, 9)


## 4. Prepare the data

In [None]:
training.iloc[:5]

Unnamed: 0,tweet_id,begin,end,type,extraction,drug,tweet,meddra_code,meddra_term
0,331187619096588288,,,,,ofloxacin,@seefisch:oral drugs for pyelonephritis:ciprofloxacin levofloxacin tmp/smz do not use nitrofurantoin for pyelo(only cystitis)@david_medinaf,,
1,332227554956161024,,,,,trazodone,happy for wellbutrin; has similar effects as adderall.. trazodone is super promising for sleep.. but abilify can cause weight gain -_-,,
2,332448217490944000,,,,,lamotrigine,"@stilgarg i'm ok ty have an official diagnosis of bipolar now, feeling ok at the moment lamotrigine has been increased having monotherapy:/",,
3,332977955754110976,,,,,cymbalta,i'm soo depressed cymbalta couldn't help me .,,
4,333674203331051520,,,,,seroquel,"time for my daily afternoon relaxation ritual of smoking weed, taking 2 mgs of clonazepam, and 400 mg of seroquel xr.",,


In [None]:
# All rows are kept: rows without an ADR span (null `begin`) form the negative
# class. Only exact duplicate (extraction, tweet) pairs are dropped, keeping
# the first occurrence, and the index is rebuilt as a clean 0..n-1 range.
training = training.drop_duplicates(subset=["extraction", "tweet"], keep='first')
print(f"Shape after removing duplicates Training data: {training.shape}")
training = training.reset_index(drop=True)

Shape after removing duplicates Training data: (2172, 9)


In [None]:
# Pre-processing
# Referred from: https://github.com/cbaziotis/ekphrasis

text_processor = TextPreProcessor(
    # terms that will be normalized (each listed once; 'url' was duplicated)
    normalize=['url', 'email', 'percent', 'money', 'phone', 'user',
        'time', 'date', 'number'],
    
    # terms that will be annotated
    annotate={"hashtag", "allcaps", "elongated", "repeated",
        'emphasis', 'censored'},
    fix_html=True,  # fix HTML tokens
    
    # corpus from which the word statistics are going to be used 
    # for word segmentation 
    segmenter="twitter", 
    
    # corpus from which the word statistics are going to be used 
    # for spell correction
    corrector="twitter", 
    
    unpack_hashtags=True,  # perform word segmentation on hashtags
    unpack_contractions=True,  # Unpack contractions (can't -> can not)
    spell_correct_elong=False,  # spell correction for elongated words
    
    # select a tokenizer. You can use SocialTokenizer, or pass your own
    # the tokenizer, should take as input a string and return a list of tokens
    tokenizer=SocialTokenizer(lowercase=True).tokenize,
    
    # list of dictionaries, for replacing tokens extracted from the text,
    # with other expressions. You can pass more than one dictionaries.
    dicts=[emoticons]
)

Word statistics files not found!
Downloading... done!
Unpacking... done!
Reading twitter - 1grams ...
generating cache file for faster loading...
reading ngrams /root/.ekphrasis/stats/twitter/counts_1grams.txt
Reading twitter - 2grams ...
generating cache file for faster loading...
reading ngrams /root/.ekphrasis/stats/twitter/counts_2grams.txt
Reading twitter - 1grams ...


In [None]:
# Normalise each tweet with ekphrasis and re-join its tokens with single spaces
def _clean_tweet(raw_tweet):
    return " ".join(text_processor.pre_process_doc(raw_tweet))

training['clean_tweets'] = training['tweet'].map(_clean_tweet)
validation['clean_tweets'] = validation['tweet'].map(_clean_tweet)

In [None]:
# Binary target: 1 when the row has an ADR span (non-null `begin`), else 0
training['label'] = training['begin'].notnull().astype(int)
validation['label'] = validation['begin'].notnull().astype(int)

In [None]:
# Training rows after pre-processing, with their label and ADR extraction
training.loc[:, ['clean_tweets', 'label', 'extraction']].head()

Unnamed: 0,clean_tweets,label,extraction
0,<user> : oral drugs for pyelonephritis : ciprofloxacin levofloxacin tmp / smz do not use nitrofurantoin for pyelo ( only cystitis ) <user>,0,
1,happy for wellbutrin ; has similar effects as adderall . <repeated> trazodone is super promising for sleep . <repeated> but abilify can cause weight gain -_-,0,
2,"<user> i am ok ty have an official diagnosis of bipolar now , feeling ok at the moment lamotrigine has been increased having monotherapy <annoyed>",0,
3,i am soo depressed cymbalta could not help me .,0,
4,"time for my daily afternoon relaxation ritual of smoking weed , taking <number> mgs of clonazepam , and <number> mg of seroquel xr .",0,


In [None]:
validation.head()[['clean_tweets', 'label', 'extraction']]

Unnamed: 0,clean_tweets,label,extraction
0,"do you have any medication allergies ? "" asthma ! <repeated> "" me : "" . <repeated> "" pt : "" no wait . avelox , that ' s it ! "" "" so no other allergies ? "" "" right ! "" * cont",1,allergies
1,"<user> if <hashtag> a velox </hashtag> has hurt your liver , avoid tylenol always , as it further damages liver , eat grapefruit unless taking cardiac drugs",1,HURT YOUR Liver
2,"apparently , baclofen greatly exacerbates the "" ad "" part of my adhd . average length of focus today : about <number> seconds .",1,AD
3,"apparently , baclofen greatly exacerbates the "" ad "" part of my adhd . average length of focus today : about <number> seconds .",1,focus
4,pt of mine died from cipro rt <user> : <user> if only more doctors thought like you ! i lost my entire life to <number> cipro pills,1,died


### Prepare tweets for RoBERTa

In [None]:
# Training hyper-parameters
BATCH_SIZE = 32
N_EPOCHS = 5
MAX_LENGTH = 128  # token cap per tweet; longer inputs are truncated by the tokenizer
LEARNING_RATE = 2e-5
SEED = 42

# Seed the RNGs used later (classifier-head initialisation, RandomSampler
# shuffling) so Restart & Run All reproduces the same run. NOTE: this does not
# force deterministic cuDNN kernels, so GPU results may still vary slightly.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Fast RoBERTa tokenizer for the same checkpoint the model is loaded from
PRETRAINED_NAME = "roberta-base"
tokenizer = RobertaTokenizerFast.from_pretrained(PRETRAINED_NAME)

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [None]:
# Tokenize each split: pad to that split's longest sequence (capped at
# MAX_LENGTH, truncating longer tweets) and return PyTorch tensors
encode_kwargs = dict(padding="longest", truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
training_enc = tokenizer.batch_encode_plus(training.clean_tweets.to_list(), **encode_kwargs)
validation_enc = tokenizer.batch_encode_plus(validation.clean_tweets.to_list(), **encode_kwargs)

In [None]:
training_enc.data.keys()

dict_keys(['input_ids', 'attention_mask'])

In [None]:
# Define dataloader
def get_dataloader(encoding, target, set='train', indices=None):
    """Wrap tokenized tweets, labels and row indices in a DataLoader.

    Args:
        encoding: tokenizer output exposing ``input_ids`` and ``attention_mask`` tensors.
        target: 1-D tensor of integer labels, one per row of ``encoding``.
        set: ``'train'`` shuffles batches (RandomSampler); any other value keeps
            the original row order (SequentialSampler) so evaluation is deterministic.
        indices: optional sequence of DataFrame row labels carried through each
            batch so predictions can be mapped back to their rows. Defaults to the
            index of the global ``training`` frame for ``'train'`` and of
            ``validation`` otherwise (the original behaviour).

    Returns:
        DataLoader yielding ``(input_ids, attention_mask, target, index)`` batches
        of ``BATCH_SIZE``.
    """
    if indices is None:
        source_frame = training if set == 'train' else validation
        indices = source_frame.index.values.tolist()
    index_tensor = torch.tensor(list(indices))

    # TensorDataset itself rejects tensors whose first dimensions disagree.
    data = TensorDataset(encoding.input_ids, encoding.attention_mask, target, index_tensor)
    sampler = RandomSampler(data) if set == 'train' else SequentialSampler(data)
    return DataLoader(data, sampler=sampler, batch_size=BATCH_SIZE)

In [None]:
# One DataLoader per split; labels become an integer tensor
training_dataloader = get_dataloader(training_enc, torch.tensor(training['label'].tolist()), set='train')
validation_dataloader = get_dataloader(validation_enc, torch.tensor(validation['label'].tolist()), set='valid')

In [None]:
# Sanity check: pull one batch and confirm the four tensors line up
first_batch = next(iter(training_dataloader), None)
if first_batch is not None:
    input_ids, attn_mask, target, index = first_batch
    print(input_ids.shape, attn_mask.shape, target.shape, index.shape)

torch.Size([32, 124]) torch.Size([32, 124]) torch.Size([32]) torch.Size([32])


## 5. Model Building

In [None]:
# Prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [None]:
class RobertaClassifier(nn.Module):
    """Linear classification head on top of a transformer encoder.

    The encoder's ``pooler_output`` (one vector per sequence) is projected to
    ``num_labels`` logits. No softmax is applied, so the output feeds directly
    into ``nn.CrossEntropyLoss``.

    Args:
        transformer: encoder called as ``transformer(input_ids=..., attention_mask=...)``
            and returning a mapping with a ``"pooler_output"`` entry.
        num_labels: number of output classes (default 2: no ADR / ADR).
    """

    def __init__(self, transformer, num_labels=2):
        super().__init__()
        self.transformer = transformer
        # Take the width from the encoder config instead of hard-coding 768 so
        # other checkpoints (e.g. roberta-large) work too; fall back to 768 for
        # encoders that expose no config.
        config = getattr(transformer, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.linear_layer = nn.Linear(hidden_size, num_labels)

    def forward(self, ip_ids, attn_mask):
        """Return logits of shape (batch, num_labels) for a batch of token ids."""
        op = self.transformer(input_ids=ip_ids,
                              attention_mask=attn_mask)
        return self.linear_layer(op["pooler_output"])

In [None]:
def count_parameter(model):
    """Return the total number of scalar parameters in ``model`` that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Pretrained RoBERTa encoder plus a freshly initialised linear head, on `device`
transformer = RobertaModel.from_pretrained('roberta-base')
model = RobertaClassifier(transformer=transformer)
model = model.to(device)

In [None]:
n_trainable = count_parameter(model)
print(f"The model has {n_trainable:,} trainable parameters.")

The model has 124,647,170 trainable parameters.


In [None]:
# Loss on raw logits, and AdamW over every model parameter
criterion = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Define the train and validation functions

def train(model, dataloader, clip=1.0, optimizer=None, loss_fn=None):
    """Run one training epoch over ``dataloader``.

    Args:
        model: classifier returning logits of shape (batch, n_classes).
        dataloader: yields ``(input_ids, attention_mask, target, index)`` batches.
        clip: max global gradient norm passed to ``clip_grad_norm_``.
        optimizer: optimizer to step; defaults to the notebook-level ``optim``.
        loss_fn: loss applied to ``(logits, target)``; defaults to ``criterion``.

    Returns:
        ``(mean batch loss, binary F1 of the argmax predictions)`` for the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches (avoids a bare ZeroDivisionError).
    """
    optimizer = optim if optimizer is None else optimizer
    loss_fn = criterion if loss_fn is None else loss_fn
    model.train()

    epoch_loss = 0.0
    batch_num = 0
    y_pred, y_true = [], []

    # Iterate the DataLoader directly so tqdm can read its len() and show a real
    # progress bar with ETA (wrapping enumerate() hid the total).
    for batch in tqdm(dataloader, desc="train", leave=False):
        input_ids, attn_mask, target, _ = (row.to(device) for row in batch)

        optimizer.zero_grad()
        output = model(input_ids, attn_mask)
        loss = loss_fn(output, target)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        batch_num += 1
        y_pred.extend(torch.argmax(output, -1).tolist())
        y_true.extend(target.tolist())

    if batch_num == 0:
        raise ValueError("train() received an empty dataloader")
    return epoch_loss / batch_num, f1_score(y_true, y_pred)

def evaluate(model, dataloader):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Returns:
        ``(mean batch loss, binary F1, predictions, true labels)``; the two
        lists follow the order in which the dataloader yielded the rows.
    """
    model.eval()
    batch_losses = []
    predictions, labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attn_mask, target, _ = [t.to(device) for t in batch]
            logits = model(input_ids, attn_mask)
            batch_losses.append(criterion(logits, target).item())
            predictions += logits.argmax(dim=-1).tolist()
            labels += target.tolist()

    mean_loss = sum(batch_losses) / len(batch_losses)
    return mean_loss, f1_score(labels, predictions), predictions, labels

In [None]:
# Lowest validation loss seen so far, plus per-epoch loss curves
best_valid_loss = float('inf')
total_train_loss = []
total_valid_loss = []

In [None]:
# Train for N_EPOCHS, checkpoint every epoch, and keep the lowest-val-loss model.
for epoch in tqdm(range(N_EPOCHS)):
    train_loss, train_f1_score = train(model, training_dataloader)
    total_train_loss.append(train_loss)

    valid_loss, valid_f1_score, pred, target = evaluate(model, validation_dataloader)
    total_valid_loss.append(valid_loss)

    # Update the running minimum itself. Assigning to a differently named
    # variable left `best_valid_loss` at inf, so every epoch overwrote the
    # "best" checkpoint and the final model was always reloaded.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_pred, best_target = pred, target
        torch.save(model.state_dict(), "model_least_loss_rob.pt")
        print("\nBest Model Saved!!\n")
    
    torch.save(model.state_dict(), "model_checkpoint_rob" + str(epoch+1) + ".pt")
    print("Checkpoint Model Saved!\n")

    print(f"Epoch: {epoch+1:02}")
    print(f"Train Total Loss: {train_loss:.3f} | Train F1 Score: {train_f1_score:.3f}")
    print(f"Valid Total Loss: {valid_loss:.3f} | Valid F1 Score: {valid_f1_score:.3f}")
    print("-"*20)

  0%|          | 0/5 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:01,  1.36s/it][A
2it [00:02,  1.28s/it][A
3it [00:03,  1.25s/it][A
4it [00:05,  1.23s/it][A
5it [00:06,  1.22s/it][A
6it [00:07,  1.22s/it][A
7it [00:08,  1.21s/it][A
8it [00:09,  1.20s/it][A
9it [00:11,  1.21s/it][A
10it [00:12,  1.20s/it][A
11it [00:13,  1.20s/it][A
12it [00:14,  1.20s/it][A
13it [00:15,  1.20s/it][A
14it [00:17,  1.20s/it][A
15it [00:18,  1.20s/it][A
16it [00:19,  1.20s/it][A
17it [00:20,  1.20s/it][A
18it [00:21,  1.20s/it][A
19it [00:23,  1.21s/it][A
20it [00:24,  1.21s/it][A
21it [00:25,  1.20s/it][A
22it [00:26,  1.20s/it][A
23it [00:27,  1.21s/it][A
24it [00:29,  1.21s/it][A
25it [00:30,  1.20s/it][A
26it [00:31,  1.20s/it][A
27it [00:32,  1.20s/it][A
28it [00:33,  1.20s/it][A
29it [00:35,  1.20s/it][A
30it [00:36,  1.21s/it][A
31it [00:37,  1.21s/it][A
32it [00:38,  1.20s/it][A
33it [00:39,  1.21s/it][A
34it [00:41,  1.21s/it][A
35it [00:42,  1.21s/it][A
36it


Best Model Saved!!



 20%|██        | 1/5 [01:31<06:05, 91.26s/it]

Checkpoint Model Saved!

Epoch: 01
Train Total Loss: 0.106 | Train F1 Score: 0.975
Valid Total Loss: 0.767 | Valid F1 Score: 0.899
--------------------



0it [00:00, ?it/s][A
1it [00:01,  1.34s/it][A
2it [00:02,  1.29s/it][A
3it [00:03,  1.25s/it][A
4it [00:05,  1.24s/it][A
5it [00:06,  1.26s/it][A
6it [00:07,  1.25s/it][A
7it [00:08,  1.24s/it][A
8it [00:09,  1.23s/it][A
9it [00:11,  1.22s/it][A
10it [00:12,  1.22s/it][A
11it [00:13,  1.22s/it][A
12it [00:14,  1.22s/it][A
13it [00:16,  1.21s/it][A
14it [00:17,  1.21s/it][A
15it [00:18,  1.22s/it][A
16it [00:19,  1.22s/it][A
17it [00:20,  1.22s/it][A
18it [00:22,  1.22s/it][A
19it [00:23,  1.21s/it][A
20it [00:24,  1.21s/it][A
21it [00:25,  1.21s/it][A
22it [00:26,  1.22s/it][A
23it [00:28,  1.21s/it][A
24it [00:29,  1.21s/it][A
25it [00:30,  1.21s/it][A
26it [00:31,  1.21s/it][A
27it [00:33,  1.22s/it][A
28it [00:34,  1.21s/it][A
29it [00:35,  1.21s/it][A
30it [00:36,  1.21s/it][A
31it [00:37,  1.21s/it][A
32it [00:39,  1.21s/it][A
33it [00:40,  1.22s/it][A
34it [00:41,  1.22s/it][A
35it [00:42,  1.22s/it][A
36it [00:43,  1.22s/it][A
37it [00:45,  


Best Model Saved!!



 40%|████      | 2/5 [03:02<04:34, 91.42s/it]

Checkpoint Model Saved!

Epoch: 02
Train Total Loss: 0.080 | Train F1 Score: 0.985
Valid Total Loss: 0.773 | Valid F1 Score: 0.902
--------------------



0it [00:00, ?it/s][A
1it [00:01,  1.35s/it][A
2it [00:02,  1.30s/it][A
3it [00:03,  1.26s/it][A
4it [00:05,  1.24s/it][A
5it [00:06,  1.23s/it][A
6it [00:07,  1.22s/it][A
7it [00:08,  1.22s/it][A
8it [00:09,  1.22s/it][A
9it [00:11,  1.22s/it][A
10it [00:12,  1.21s/it][A
11it [00:13,  1.22s/it][A
12it [00:14,  1.22s/it][A
13it [00:15,  1.21s/it][A
14it [00:17,  1.21s/it][A
15it [00:18,  1.21s/it][A
16it [00:19,  1.22s/it][A
17it [00:20,  1.22s/it][A
18it [00:22,  1.22s/it][A
19it [00:23,  1.21s/it][A
20it [00:24,  1.21s/it][A
21it [00:25,  1.22s/it][A
22it [00:26,  1.21s/it][A
23it [00:28,  1.21s/it][A
24it [00:29,  1.21s/it][A
25it [00:30,  1.21s/it][A
26it [00:31,  1.21s/it][A
27it [00:32,  1.22s/it][A
28it [00:34,  1.22s/it][A
29it [00:35,  1.22s/it][A
30it [00:36,  1.22s/it][A
31it [00:37,  1.22s/it][A
32it [00:39,  1.22s/it][A
33it [00:40,  1.22s/it][A
34it [00:41,  1.21s/it][A
35it [00:42,  1.22s/it][A
36it [00:43,  1.22s/it][A
37it [00:45,  


Best Model Saved!!



 60%|██████    | 3/5 [04:34<03:03, 91.52s/it]

Checkpoint Model Saved!

Epoch: 03
Train Total Loss: 0.056 | Train F1 Score: 0.989
Valid Total Loss: 0.702 | Valid F1 Score: 0.898
--------------------



0it [00:00, ?it/s][A
1it [00:01,  1.33s/it][A
2it [00:02,  1.28s/it][A
3it [00:03,  1.25s/it][A
4it [00:05,  1.24s/it][A
5it [00:06,  1.22s/it][A
6it [00:07,  1.22s/it][A
7it [00:08,  1.22s/it][A
8it [00:09,  1.22s/it][A
9it [00:11,  1.21s/it][A
10it [00:12,  1.21s/it][A
11it [00:13,  1.21s/it][A
12it [00:14,  1.21s/it][A
13it [00:15,  1.21s/it][A
14it [00:17,  1.22s/it][A
15it [00:18,  1.22s/it][A
16it [00:19,  1.22s/it][A
17it [00:20,  1.22s/it][A
18it [00:21,  1.22s/it][A
19it [00:23,  1.22s/it][A
20it [00:24,  1.21s/it][A
21it [00:25,  1.22s/it][A
22it [00:26,  1.22s/it][A
23it [00:28,  1.21s/it][A
24it [00:29,  1.21s/it][A
25it [00:30,  1.21s/it][A
26it [00:31,  1.21s/it][A
27it [00:32,  1.21s/it][A
28it [00:34,  1.21s/it][A
29it [00:35,  1.21s/it][A
30it [00:36,  1.21s/it][A
31it [00:37,  1.22s/it][A
32it [00:39,  1.22s/it][A
33it [00:40,  1.22s/it][A
34it [00:41,  1.22s/it][A
35it [00:42,  1.21s/it][A
36it [00:43,  1.21s/it][A
37it [00:45,  


Best Model Saved!!



 80%|████████  | 4/5 [06:06<01:31, 91.54s/it]

Checkpoint Model Saved!

Epoch: 04
Train Total Loss: 0.026 | Train F1 Score: 0.995
Valid Total Loss: 0.862 | Valid F1 Score: 0.909
--------------------



0it [00:00, ?it/s][A
1it [00:01,  1.33s/it][A
2it [00:02,  1.28s/it][A
3it [00:03,  1.25s/it][A
4it [00:04,  1.23s/it][A
5it [00:06,  1.22s/it][A
6it [00:07,  1.22s/it][A
7it [00:08,  1.22s/it][A
8it [00:09,  1.22s/it][A
9it [00:11,  1.22s/it][A
10it [00:12,  1.22s/it][A
11it [00:13,  1.21s/it][A
12it [00:14,  1.21s/it][A
13it [00:15,  1.21s/it][A
14it [00:17,  1.21s/it][A
15it [00:18,  1.22s/it][A
16it [00:19,  1.21s/it][A
17it [00:20,  1.21s/it][A
18it [00:21,  1.21s/it][A
19it [00:23,  1.21s/it][A
20it [00:24,  1.22s/it][A
21it [00:25,  1.21s/it][A
22it [00:26,  1.22s/it][A
23it [00:28,  1.22s/it][A
24it [00:29,  1.22s/it][A
25it [00:30,  1.23s/it][A
26it [00:31,  1.23s/it][A
27it [00:33,  1.23s/it][A
28it [00:34,  1.23s/it][A
29it [00:35,  1.22s/it][A
30it [00:36,  1.22s/it][A
31it [00:37,  1.22s/it][A
32it [00:39,  1.22s/it][A
33it [00:40,  1.22s/it][A
34it [00:41,  1.22s/it][A
35it [00:42,  1.22s/it][A
36it [00:43,  1.21s/it][A
37it [00:45,  


Best Model Saved!!



100%|██████████| 5/5 [07:37<00:00, 91.53s/it]

Checkpoint Model Saved!

Epoch: 05
Train Total Loss: 0.028 | Train F1 Score: 0.996
Valid Total Loss: 0.918 | Valid F1 Score: 0.905
--------------------





In [None]:
best_report = classification_report(y_true=best_target, y_pred=best_pred)
print(best_report)

              precision    recall  f1-score   support

           0       0.86      0.75      0.80       195
           1       0.87      0.94      0.90       365

    accuracy                           0.87       560
   macro avg       0.87      0.84      0.85       560
weighted avg       0.87      0.87      0.87       560



In [None]:
from google.colab import drive

# Mount Google Drive so checkpoints can be copied off the ephemeral runtime
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [None]:
# Save the weights to drive: copy every *.pt checkpoint into the project folder.
# shutil raises on failure (a failing `!cp` was silently ignored), and Path
# handles the spaces in the Drive path without shell escaping.
# NOTE(review): call drive.flush_and_unmount() before the runtime ends if
# pending Drive writes must be guaranteed.
import shutil
from pathlib import Path

drive_dir = Path('/content/gdrive/My Drive/Colab Notebooks/NLP Final Project')
drive_dir.mkdir(parents=True, exist_ok=True)
for checkpoint_path in sorted(Path('.').glob('*.pt')):
    shutil.copy2(checkpoint_path, drive_dir / checkpoint_path.name)
    print(f"Copied {checkpoint_path.name} -> {drive_dir}")

In [None]:
# Loading the saved model.
# Use a *fresh* encoder: wrapping the `transformer` object already inside
# `model` would make both classifiers share that submodule, so
# load_state_dict() would silently overwrite the trained model's weights.
output_model = 'model_least_loss_rob.pt'

model_test = RobertaClassifier(RobertaModel.from_pretrained('roberta-base')).to(device)
model_test.load_state_dict(torch.load(output_model, map_location=device))

<All keys matched successfully>

In [None]:
# Training set: predict every row with the reloaded model, keeping the row
# indices so the (shuffled) predictions can be mapped back to the DataFrame.

model_test.eval()  # inference mode (disables dropout)

y_pred_train, y_true_train = [], []
train_indexes_list = []

for batch in training_dataloader:
    input_ids, attn_mask, target, indexes = (t.to(device) for t in batch)
    with torch.no_grad():
        logits = model_test(input_ids, attn_mask)
    y_pred_train += logits.argmax(-1).tolist()
    y_true_train += target.tolist()
    train_indexes_list += indexes.tolist()

In [None]:
train_f1 = f1_score(y_true_train, y_pred_train)
train_report = classification_report(y_true_train, y_pred_train)
print(f"F1-score: {train_f1}\n", f"Classification report: \n{train_report}", sep='\n')

F1-score: 0.9996512033484479

Classification report: 
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       739
           1       1.00      1.00      1.00      1433

    accuracy                           1.00      2172
   macro avg       1.00      1.00      1.00      2172
weighted avg       1.00      1.00      1.00      2172



In [None]:
# Validation set: same inference loop as for the training set.
# Switch to eval mode here as well so this cell is correct when run on its own
# (it previously depended on the training-set cell having called eval()).
model_test.eval()

y_pred_valid, y_true_valid = [], []
valid_indexes_list = []

for index, batch in enumerate(validation_dataloader):
    batch = tuple(row.to(device) for row in batch)
    input_ids, attn_mask, target, indexes = batch
    
    with torch.no_grad():
        output = model_test(input_ids, attn_mask)
        
        y_pred_valid.extend(torch.argmax(output, -1).tolist())
        y_true_valid.extend(target.tolist())
        valid_indexes_list.extend(indexes.tolist())

In [None]:
valid_f1 = f1_score(y_true_valid, y_pred_valid)
valid_report = classification_report(y_true_valid, y_pred_valid)
print(f"F1-score: {valid_f1}\n", f"Classification report: \n{valid_report}", sep='\n')

F1-score: 0.9047619047619047

Classification report: 
              precision    recall  f1-score   support

           0       0.86      0.75      0.80       195
           1       0.87      0.94      0.90       365

    accuracy                           0.87       560
   macro avg       0.87      0.84      0.85       560
weighted avg       0.87      0.87      0.87       560



In [None]:
# Attach each row's predicted label as an integer `ADR` column. A Series keyed
# by the row indices carried through the dataloaders aligns the shuffled
# predictions with their rows. (Assigning through .loc onto a not-yet-existing
# column first creates it NaN-filled, upcasting the labels to 1.0/0.0.)
train_pred = pd.Series(y_pred_train, index=train_indexes_list, dtype=int)
valid_pred = pd.Series(y_pred_valid, index=valid_indexes_list, dtype=int)

# Every row needs exactly one prediction, otherwise alignment would leave NaNs.
assert train_pred.index.is_unique and set(train_pred.index) == set(training.index)
assert valid_pred.index.is_unique and set(valid_pred.index) == set(validation.index)

training['ADR'] = train_pred
validation['ADR'] = valid_pred

In [None]:
# Persist both splits (with the predicted ADR column) to CSV, index included
for split_frame, csv_path in ((training, "training_data_with_ADR.csv"),
                              (validation, "validation_data_with_ADR.csv")):
    split_frame.to_csv(csv_path)

In [None]:
# Cross-check that predictions were mapped to the correct rows, for both splits:
# the labels carried through each dataloader must equal the labels stored at
# the rows their indices point to, then score the DataFrame columns directly.
checks = (
    ("Training", training, train_indexes_list, y_true_train),
    ("Validation", validation, valid_indexes_list, y_true_valid),
)
for split_name, split_frame, row_ids, y_true_list in checks:
    assert split_frame.loc[row_ids, 'label'].tolist() == y_true_list, f"{split_name}: index mapping is wrong"
    print(f"{split_name}",
          f"F1-score: {f1_score(split_frame.label, split_frame.ADR)}\n",
          f"Classification report: \n{classification_report(split_frame.label, split_frame.ADR)}",
          sep='\n')

F1-score: 0.9996512033484479

Classification report: 
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       739
           1       1.00      1.00      1.00      1433

    accuracy                           1.00      2172
   macro avg       1.00      1.00      1.00      2172
weighted avg       1.00      1.00      1.00      2172

