In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced again in the visible cells; the fine-tuning
# runs take their seeds from the `seeds` list instead — TODO confirm it is still needed.
seed = 2023

In [2]:
import torch

# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Display the selected device as the cell output.
device

'cuda'

In [3]:
import transformers

# Only report CRITICAL-level messages from transformers to avoid excessive logging
# during the repeated fine-tuning runs. The named constant replaces the bare
# level number 50 (CRITICAL == 50 in the standard logging levels).
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Directory that receives the result CSVs written by the fine-tuning runs below.
output_dir = "rumours_output"
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of testing `os.path.isdir` before calling `os.makedirs`.
os.makedirs(output_dir, exist_ok=True)

## Rumours

In [6]:
# Execute `load_sbert-embeddings.py` and pull its variables into the notebook namespace.
# Later cells use `df_rumours`, `label_to_id`, `id_to_label`, `output_dim` and
# `split_ids`, none of which are defined in the visible cells, so they are presumably
# created by this script — TODO confirm.
%run load_sbert-embeddings.py

In [7]:
# Preview the first rows of the rumours dataframe to check its columns
# (id, label, datetime, text, timeline_id, set in the output below).
df_rumours.head()

Unnamed: 0,id,label,datetime,text,timeline_id,set
0,5.249902e+17,0,2014-10-22 18:26:23,Police have clarified that there were two shoo...,0,train
1,5.249906e+17,0,2014-10-22 18:27:58,"@CTVNews you guys ""confirmed"" there were 3 sho...",0,train
2,5.249908e+17,1,2014-10-22 18:28:46,@CTVNews get it right. http://t.co/GHYxMuzPG9,0,train
3,5.249927e+17,1,2014-10-22 18:36:29,RT @CTVNews Police have clarified that there w...,0,train
4,5.250038e+17,1,2014-10-22 19:20:41,@CTVNews @ctvsaskatoon so what happened at Rid...,0,train


## Baseline: Fine-tune BERT for classification

In [8]:
# Shared training configuration for every fine-tuning run in this notebook.
validation_metric = "f1"  # forwarded as the `validation_metric` argument
seeds = [1, 12, 123]  # the results dataframe has one row per seed
num_epochs = 5  # epochs per fine-tuning run

In [9]:
# Inspect the label -> id mapping that is passed to the fine-tuning function.
label_to_id

{'0': 0, '1': 1}

In [10]:
# Inspect the id -> label mapping that is passed to the fine-tuning function.
id_to_label

{0: '0', 1: '1'}

In [11]:
# Keyword arguments shared by every `fine_tune_transformer_average_seed` call in this
# notebook; only the loss settings and the results path differ between runs.
# NOTE(review): `df_rumours`, `label_to_id`, `id_to_label`, `output_dim` and `split_ids`
# are not defined in the visible cells — presumably provided by the `%run` script.
kwargs = dict(
    # model and training schedule
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    # data: source dataframe plus the text / label columns to read from it
    df=df_rumours,
    feature_name="text",
    label_column="label",
    # label encoding and classifier output size
    label_to_id=label_to_id,
    id_to_label=id_to_label,
    output_dim=output_dim,
    # run setup
    seeds=seeds,
    device=device,
    batch_size=8,
    # data splitting and model selection
    split_ids=split_ids,
    k_fold=True,
    validation_metric=validation_metric,
    verbose=False,
)

## Focal Loss

In [12]:
# Loss configuration for the focal-loss run: loss name and its gamma parameter.
loss, gamma = "focal", 2

In [13]:
# Fine-tune BERT with the focal-loss settings on top of the shared `kwargs`.
# The returned results dataframe is also written to a CSV inside `output_dir`
# (see the "saving the results dataframe to CSV" line in the cell output).
bert_classifier = fine_tune_transformer_average_seed(
    results_output=f"{output_dir}/bert_classifier_focal.csv",
    gamma=gamma,
    loss=loss,
    **kwargs,
)

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.175128,0.78135,0.769612
2,0.215100,0.229847,0.801929,0.782294
3,0.130100,0.26366,0.787781,0.763419
4,0.062600,0.437864,0.789068,0.772723
5,0.062600,0.576602,0.786495,0.768369


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.180301,0.790953,0.768101
2,0.211600,0.162821,0.768334,0.763853
3,0.132200,0.206907,0.806717,0.79541
4,0.132200,0.341469,0.795751,0.780634
5,0.070500,0.425662,0.809459,0.796885


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.24009,0.648241,0.393293
2,0.214200,0.260987,0.777459,0.747333
3,0.152300,0.22076,0.781048,0.755537
4,0.152300,0.336756,0.767408,0.746393
5,0.088400,0.407268,0.769562,0.745787


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.192495,0.772575,0.746799
2,0.212900,0.17527,0.808696,0.800158
3,0.130800,0.198096,0.799331,0.793606
4,0.072800,0.289047,0.805351,0.789521
5,0.072800,0.455159,0.8,0.786648


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.208769,0.75412,0.725217
2,0.216200,0.186639,0.782465,0.776928
3,0.134100,0.335128,0.777851,0.757038
4,0.068000,0.37907,0.787739,0.777919
5,0.068000,0.526244,0.777851,0.764904


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.173963,0.805145,0.789334
2,0.204900,0.172301,0.814148,0.804817
3,0.127600,0.246576,0.821222,0.803831
4,0.075500,0.311917,0.813505,0.80146
5,0.075500,0.427955,0.810289,0.796264


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.276985,0.770391,0.731051
2,0.212500,0.203743,0.801234,0.784101
3,0.138400,0.246892,0.801919,0.789078
4,0.138400,0.473059,0.795751,0.778145
5,0.064600,0.562765,0.797121,0.77831


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.194702,0.750179,0.703843
2,0.212900,0.212851,0.772434,0.752645
3,0.149500,0.255148,0.748026,0.740709
4,0.149500,0.252674,0.776023,0.761387
5,0.093000,0.331272,0.778177,0.757138


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.19004,0.755853,0.708604
2,0.221300,0.202323,0.807358,0.798527
3,0.140800,0.190588,0.807358,0.795708
4,0.059900,0.416343,0.795987,0.781253
5,0.059900,0.509129,0.799331,0.783941


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.201363,0.760053,0.742978
2,0.211700,0.190972,0.77851,0.775523
3,0.117600,0.310327,0.768622,0.756079
4,0.036600,0.547778,0.7706,0.756829
5,0.036600,0.714912,0.773896,0.76364


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.228716,0.783923,0.76521
2,0.219800,0.171287,0.817363,0.802469
3,0.136100,0.186493,0.803215,0.793517
4,0.073000,0.388246,0.797428,0.780485
5,0.073000,0.565887,0.79164,0.772317


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.232894,0.690199,0.593713
2,0.240000,0.172886,0.784099,0.778298
3,0.162300,0.26637,0.797121,0.773742
4,0.162300,0.296031,0.799178,0.786004
5,0.084000,0.427534,0.785469,0.769143


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.205668,0.745872,0.672023
2,0.212600,0.210247,0.698492,0.696423
3,0.141400,0.281554,0.779612,0.765651
4,0.141400,0.331408,0.788227,0.766151
5,0.070000,0.445326,0.786073,0.762484


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.205116,0.755853,0.753693
2,0.210200,0.174034,0.804682,0.795242
3,0.137800,0.29092,0.806689,0.793487
4,0.078600,0.281582,0.798662,0.790563
5,0.078600,0.402969,0.803344,0.790439


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.193928,0.736322,0.735194
2,0.212600,0.174399,0.78708,0.781614
3,0.123100,0.330782,0.7706,0.763379
4,0.050800,0.557238,0.767304,0.754534
5,0.050800,0.615424,0.771259,0.759459


saving the results dataframe to CSV in rumours_output/bert_classifier_focal.csv


In [14]:
# Show the focal-loss results: one row per seed with accuracy, aggregate
# f1 / precision / recall and the corresponding per-class `*_scores` lists.
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.601491,0.560027,"[0.6950941243582431, 0.4249596557288865]",0.560329,"[0.6915437003405221, 0.42911461162411735]",0.559783,"[0.6986811926605505, 0.42088438998401706]",1,focal,2,True
0,0.62013,0.543761,"[0.7304232804232805, 0.3570977917981073]",0.557893,"[0.6780451866404715, 0.4377416860015468]",0.546558,"[0.7915711009174312, 0.3015450186467768]",12,focal,2,True
0,0.600559,0.548095,"[0.7020714583622966, 0.3941193101498446]",0.550695,"[0.6815114709851552, 0.4198795180722892]",0.547624,"[0.7239105504587156, 0.37133724027703785]",123,focal,2,True


In [15]:
# Mean of the per-seed `f1` column (focal loss).
bert_classifier["f1"].mean()

0.5506276034701097

In [16]:
# Mean of the per-seed `precision` column (focal loss).
bert_classifier["precision"].mean()

0.556306028944017

In [17]:
# Mean of the per-seed `recall` column (focal loss).
bert_classifier["recall"].mean()

0.5513215821574214

In [18]:
# Per-class F1 for the focal-loss run: stack the per-seed `f1_scores` lists into
# one array (one row per seed) and average over the seed axis.
np.mean(np.stack(bert_classifier["f1_scores"]), axis=0)

array([0.70919629, 0.39205892])

In [19]:
# Per-class precision for the focal-loss run: stack the per-seed `precision_scores`
# lists into one array (one row per seed) and average over the seed axis.
np.mean(np.stack(bert_classifier["precision_scores"]), axis=0)

array([0.68370012, 0.42891194])

In [20]:
# Per-class recall for the focal-loss run: stack the per-seed `recall_scores`
# lists into one array (one row per seed) and average over the seed axis.
np.mean(np.stack(bert_classifier["recall_scores"]), axis=0)

array([0.73805428, 0.36458888])

## Using Cross-Entropy loss

In [21]:
# Loss configuration for the cross-entropy run; no gamma is set for this loss.
loss, gamma = "cross_entropy", None

In [22]:
# Fine-tune BERT with the cross-entropy settings on top of the shared `kwargs`.
# The returned results dataframe is also written to a CSV inside `output_dir`
# (see the "saving the results dataframe to CSV" line in the cell output).
bert_classifier_ce = fine_tune_transformer_average_seed(
    results_output=f"{output_dir}/bert_classifier_ce.csv",
    gamma=gamma,
    loss=loss,
    **kwargs,
)

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.534237,0.755627,0.74951
2,0.585600,0.523375,0.804502,0.784382
3,0.405400,0.532079,0.801286,0.782047
4,0.296200,0.760756,0.803215,0.791064
5,0.296200,0.91306,0.798071,0.784738


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.55986,0.732008,0.67919
2,0.606000,0.450734,0.797121,0.780899
3,0.429900,0.553459,0.801234,0.79099
4,0.429900,0.822448,0.800548,0.787922
5,0.296700,0.938725,0.795751,0.780803


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.515933,0.76023,0.722948
2,0.546500,0.53196,0.773869,0.739809
3,0.341900,0.75834,0.79397,0.7635
4,0.341900,1.058996,0.778177,0.763948
5,0.179000,1.216479,0.781048,0.758224


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.559142,0.77592,0.745348
2,0.562800,0.552186,0.810033,0.788503
3,0.380600,0.555676,0.820736,0.804155
4,0.294500,0.823779,0.816722,0.802666
5,0.294500,0.857257,0.820736,0.808144


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.552789,0.74555,0.712983
2,0.612300,0.534647,0.760053,0.757521
3,0.423300,0.552274,0.789057,0.779803
4,0.274700,0.889837,0.791035,0.780917
5,0.274700,1.050337,0.783125,0.772076


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.547582,0.752412,0.738554
2,0.596000,0.480532,0.792283,0.780929
3,0.426900,0.595866,0.814148,0.797325
4,0.321100,0.833292,0.797428,0.781157
5,0.321100,0.875806,0.800643,0.787334


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.616596,0.716929,0.713905
2,0.573900,0.587312,0.779986,0.775511
3,0.397300,0.750646,0.801919,0.786166
4,0.397300,0.816148,0.799178,0.781776
5,0.294700,0.955636,0.800548,0.788072


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.54022,0.73654,0.699338
2,0.567300,0.589955,0.757358,0.728379
3,0.396400,0.595849,0.763101,0.744592
4,0.396400,0.797275,0.768844,0.74231
5,0.256400,0.990864,0.767408,0.745932


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.550009,0.735117,0.669779
2,0.623900,0.514651,0.799331,0.776374
3,0.398300,0.643668,0.812709,0.804575
4,0.264800,0.774677,0.812709,0.803055
5,0.264800,0.888565,0.810033,0.798132


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.51241,0.758734,0.742125
2,0.562500,0.5406,0.789717,0.783846
3,0.384100,0.700819,0.777192,0.763501
4,0.240500,0.932919,0.782465,0.770161
5,0.240500,1.15946,0.785761,0.77512


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.496857,0.793569,0.777656
2,0.545800,0.557792,0.810289,0.791216
3,0.356700,0.760594,0.805145,0.786458
4,0.219400,0.943866,0.811576,0.795338
5,0.219400,1.018637,0.809003,0.793348


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.677707,0.664154,0.544838
2,0.628400,0.45981,0.804661,0.785889
3,0.449700,0.6078,0.799863,0.77523
4,0.449700,0.796474,0.780672,0.763701
5,0.269500,1.001797,0.793009,0.778534


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.591444,0.715721,0.610108
2,0.552700,0.612183,0.750179,0.738449
3,0.393100,0.716231,0.772434,0.759418
4,0.393100,1.031243,0.769562,0.733243
5,0.232100,1.167606,0.765973,0.734887


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.522766,0.754515,0.743486
2,0.602600,0.470483,0.797324,0.788218
3,0.408000,0.595292,0.808027,0.800734
4,0.303300,0.710465,0.81204,0.802041
5,0.303300,0.894472,0.813378,0.802025


Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.536587,0.748187,0.718452
2,0.603400,0.533451,0.784443,0.769355
3,0.381100,0.763288,0.76269,0.759635
4,0.263900,0.957753,0.780488,0.769722
5,0.263900,1.144476,0.7706,0.75778


saving the results dataframe to CSV in rumours_output/bert_classifier_ce.csv


In [23]:
# Show the cross-entropy results: one row per seed with accuracy, aggregate
# f1 / precision / recall and the corresponding per-class `*_scores` lists.
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.617521,0.564434,"[0.716496269687759, 0.41237113402061853]",0.568644,"[0.6914666666666667, 0.4458204334365325]",0.563498,"[0.7434059633027523, 0.38359083644112946]",1,cross_entropy,,True
0,0.623672,0.588426,"[0.7088680605623648, 0.4679841897233201]",0.588033,"[0.7130838410211778, 0.4629822732012513]",0.588899,"[0.7047018348623854, 0.4730953649440597]",12,cross_entropy,,True
0,0.625909,0.540592,"[0.7385697538100822, 0.34261382246970196]",0.56075,"[0.6767724994031988, 0.4447278911564626]",0.545711,"[0.8127866972477065, 0.27863612147043154]",123,cross_entropy,,True


In [24]:
# Mean of the per-seed `f1` column (cross-entropy loss).
bert_classifier_ce["f1"].mean()

0.5644838717123077

In [25]:
# Mean of the per-seed `precision` column (cross-entropy loss).
bert_classifier_ce["precision"].mean()

0.572475600814215

In [26]:
# Mean of the per-seed `recall` column (cross-entropy loss).
bert_classifier_ce["recall"].mean()

0.5660361363780776

In [27]:
# Per-class F1 for the cross-entropy run: stack the per-seed `f1_scores` lists into
# one array (one row per seed) and average over the seed axis.
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

array([0.72131136, 0.40765638])

In [28]:
# Per-class precision for the cross-entropy run: stack the per-seed `precision_scores`
# lists into one array (one row per seed) and average over the seed axis.
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

array([0.69377434, 0.45117687])

In [29]:
# Per-class recall for the cross-entropy run: stack the per-seed `recall_scores`
# lists into one array (one row per seed) and average over the seed axis.
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)

array([0.7536315 , 0.37844077])