In [1]:
import os
import pickle

import numpy as np

# NOTE(review): `seed` is not referenced by any visible cell; the fine-tuning runs
# below use the per-run `seeds` list instead — confirm whether this is still needed.
seed = 2023

In [2]:
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
import transformers

# Only report CRITICAL-level messages, to avoid excessive logging.
# Use the named level constant rather than the bare number 50.
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Directory that receives the per-experiment results CSVs.
output_dir = "rumours_output"
# exist_ok=True replaces the separate isdir() check: it is a single call with no
# check-then-create race, and re-running the cell is a no-op when the directory exists.
os.makedirs(output_dir, exist_ok=True)

## Rumours

In [6]:
# Run the data-loading helper script in this kernel.
# NOTE(review): later cells use `df_rumours` and `split_ids`, which no visible cell
# defines — presumably this script creates them; confirm against the script, since
# this is hidden state that a reader cannot see from the notebook alone.
%run load_sbert-embeddings.py

In [7]:
# Preview the first five rows of the loaded dataset.
df_rumours.head(n=5)

Unnamed: 0,id,label,datetime,text,timeline_id,set
0,5.249902e+17,0,2014-10-22 18:26:23,Police have clarified that there were two shoo...,0,train
1,5.249906e+17,0,2014-10-22 18:27:58,"@CTVNews you guys ""confirmed"" there were 3 sho...",0,train
2,5.249908e+17,1,2014-10-22 18:28:46,@CTVNews get it right. http://t.co/GHYxMuzPG9,0,train
3,5.249927e+17,1,2014-10-22 18:36:29,RT @CTVNews Police have clarified that there w...,0,train
4,5.250038e+17,1,2014-10-22 19:20:41,@CTVNews @ctvsaskatoon so what happened at Rid...,0,train


## Baseline: Fine-tune BERT for classification

In [8]:
num_epochs = 5  # training epochs per run (five epoch rows appear in each training table)
seeds = [1, 12, 123]  # one k-fold fine-tuning run per seed
validation_metric = "f1"  # metric name forwarded to the fine-tuning helper

In [9]:
# Arguments shared by both fine-tuning calls below; the loss-specific arguments
# (`loss`, `gamma`, `results_output`) are supplied separately at each call.
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=df_rumours,
    feature_name="text",
    label_column="label",
    seeds=seeds,
    split_ids=split_ids,
    k_fold=True,
    validation_metric=validation_metric,
    device=device,
    verbose=False,
)

## Focal Loss

In [10]:
# Focal-loss settings; `gamma` is forwarded to the fine-tuning helper
# (presumably the focal-loss focusing parameter — confirm in nlpsig_networks).
loss, gamma = "focal", 2

In [11]:
# Fine-tune BERT with the focal-loss settings for every seed; the returned
# results table is also saved to the CSV path given here.
bert_classifier = fine_tune_transformer_average_seed(
    results_output=f"{output_dir}/bert_classifier_focal.csv",
    loss=loss,
    gamma=gamma,
    **kwargs,
)

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.201621,0.710611,0.706456
2,No log,0.168554,0.773633,0.769138
3,No log,0.179551,0.787781,0.778088
4,No log,0.22282,0.783923,0.771469
5,No log,0.308203,0.798714,0.784298


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.221911,0.681289,0.663439
2,No log,0.174878,0.772447,0.766064
3,No log,0.169233,0.782042,0.769604
4,No log,0.208359,0.784099,0.7696
5,No log,0.276813,0.782042,0.770366


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.206696,0.748026,0.707228
2,No log,0.200577,0.766691,0.736719
3,No log,0.248392,0.748026,0.687757
4,No log,0.261207,0.765973,0.737747
5,No log,0.357613,0.770998,0.741011


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.205713,0.682274,0.678593
2,No log,0.164696,0.802007,0.781795
3,No log,0.182934,0.795987,0.780925
4,No log,0.235859,0.78796,0.76704
5,No log,0.288463,0.799331,0.786162


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.212702,0.703362,0.697281
2,No log,0.178537,0.766645,0.758475
3,No log,0.225645,0.779169,0.764219
4,No log,0.341089,0.749506,0.721507
5,No log,0.348777,0.773896,0.761937


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.226092,0.659807,0.578783
2,No log,0.173752,0.776849,0.768922
3,No log,0.18936,0.793569,0.778785
4,No log,0.251466,0.787138,0.773315
5,No log,0.294368,0.785852,0.77002


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.209956,0.710761,0.700893
2,No log,0.179201,0.786155,0.760802
3,No log,0.194814,0.790267,0.773679
4,No log,0.237652,0.790953,0.778647
5,No log,0.272853,0.798492,0.78641


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.227259,0.695621,0.593967
2,No log,0.222563,0.745872,0.680788
3,No log,0.208049,0.758076,0.729607
4,No log,0.262142,0.754487,0.732061
5,No log,0.347008,0.768844,0.736958


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.212769,0.713043,0.672149
2,No log,0.166624,0.797324,0.787465
3,No log,0.177124,0.785953,0.782013
4,No log,0.243856,0.794649,0.779981
5,No log,0.252143,0.795987,0.782675


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.227748,0.640079,0.638721
2,No log,0.195814,0.733026,0.729946
3,No log,0.204088,0.75412,0.751821
4,No log,0.242569,0.767304,0.763103
5,No log,0.349559,0.764008,0.753232


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.227744,0.708682,0.614695
2,No log,0.175406,0.767846,0.763191
3,No log,0.163093,0.786495,0.778857
4,No log,0.211596,0.789068,0.774518
5,No log,0.271482,0.789711,0.774163


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.230945,0.610692,0.610612
2,No log,0.171977,0.801919,0.788305
3,No log,0.183311,0.802605,0.790624
4,No log,0.240617,0.788897,0.767458
5,No log,0.243567,0.804661,0.791389


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.21386,0.73295,0.665101
2,No log,0.213874,0.728643,0.719687
3,No log,0.236275,0.758794,0.729702
4,No log,0.244598,0.745872,0.72947
5,No log,0.337191,0.75664,0.724773


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.210711,0.67893,0.673826
2,No log,0.172196,0.783278,0.751173
3,No log,0.159413,0.803344,0.788586
4,No log,0.197641,0.8,0.784909
5,No log,0.256258,0.792642,0.778877


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.218569,0.644034,0.64338
2,No log,0.180713,0.785761,0.779174
3,No log,0.210476,0.725775,0.725593
4,No log,0.306362,0.760712,0.745402
5,No log,0.349985,0.774555,0.764261


  0%|          | 0/973 [00:00<?, ?it/s]

saving the results dataframe to CSV in rumours_output/bert_classifier_focal.csv


In [12]:
# Metrics table returned for the focal-loss runs (one row per seed).
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.628705,0.589027,"[0.7167235494880547, 0.46133044889129254]",0.589742,"[0.7110609480812641, 0.4684239428885228]",0.588463,"[0.7224770642201835, 0.45444858817261585]",1,focal,2,True
0,0.624231,0.579203,"[0.7168539325842697, 0.44155124653739614]",0.58127,"[0.7026431718061674, 0.4598961338718984]",0.578133,"[0.731651376146789, 0.4246137453383058]",12,focal,2,True
0,0.609506,0.559226,"[0.7080953044447542, 0.4103574444131719]",0.561884,"[0.6888045540796964, 0.434964200477327]",0.558442,"[0.7284977064220184, 0.38838572189664355]",123,focal,2,True


In [13]:
# Overall F1 averaged across the seeds.
bert_classifier["f1"].agg("mean")

0.5758186543931565

In [14]:
# Overall precision averaged across the seeds.
bert_classifier["precision"].agg("mean")

0.5776321585341461

In [15]:
# Overall recall averaged across the seeds.
bert_classifier["recall"].agg("mean")

0.5750123670327594

In [16]:
# Per-class F1 (one entry per class) averaged across the seeds.
np.mean(np.stack(bert_classifier["f1_scores"]), axis=0)

array([0.71389093, 0.43774638])

In [17]:
# Per-class precision (one entry per class) averaged across the seeds.
np.mean(np.stack(bert_classifier["precision_scores"]), axis=0)

array([0.70083622, 0.45442809])

In [18]:
# Per-class recall (one entry per class) averaged across the seeds.
np.mean(np.stack(bert_classifier["recall_scores"]), axis=0)

array([0.72754205, 0.42248269])

## Using Cross-Entropy loss

In [19]:
# Cross-entropy settings; gamma is cleared because it is presumably only
# meaningful for focal loss (confirm in nlpsig_networks).
loss, gamma = "cross_entropy", None

In [20]:
# Fine-tune BERT with the cross-entropy settings for every seed; the returned
# results table is also saved to the CSV path given here.
bert_classifier_ce = fine_tune_transformer_average_seed(
    results_output=f"{output_dir}/bert_classifier_ce.csv",
    loss=loss,
    gamma=gamma,
    **kwargs,
)

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.52742,0.757556,0.74267
2,No log,0.460328,0.796141,0.781541
3,No log,0.502343,0.799357,0.774638
4,No log,0.55555,0.801286,0.786902
5,No log,0.584539,0.801929,0.789279


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.564482,0.703221,0.695925
2,No log,0.471149,0.779986,0.77344
3,No log,0.469479,0.793009,0.781343
4,No log,0.571571,0.782728,0.762892
5,No log,0.622543,0.788211,0.771184


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.547329,0.733668,0.658365
2,No log,0.517336,0.778894,0.755246
3,No log,0.562546,0.765973,0.724041
4,No log,0.593968,0.78033,0.758249
5,No log,0.619812,0.782484,0.760733


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.682851,0.597324,0.594669
2,No log,0.457079,0.801338,0.777862
3,No log,0.447292,0.812709,0.800276
4,No log,0.480813,0.808027,0.798068
5,No log,0.478057,0.809365,0.797905


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.537948,0.733685,0.728829
2,No log,0.480495,0.780488,0.773416
3,No log,0.628094,0.767963,0.745466
4,No log,0.624019,0.764667,0.749264
5,No log,0.661223,0.769281,0.757453


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.595622,0.66881,0.539504
2,No log,0.449915,0.798714,0.783517
3,No log,0.487879,0.803859,0.792222
4,No log,0.553569,0.797428,0.781487
5,No log,0.603184,0.803859,0.789811


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.526986,0.751199,0.728485
2,No log,0.506189,0.780672,0.750005
3,No log,0.548216,0.778615,0.753224
4,No log,0.545199,0.790267,0.773497
5,No log,0.583109,0.784099,0.767668


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.633636,0.678392,0.624673
2,No log,0.56477,0.75664,0.715449
3,No log,0.539606,0.760948,0.720127
4,No log,0.560154,0.758794,0.737935
5,No log,0.624117,0.761665,0.738205


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.658642,0.624749,0.389525
2,No log,0.480413,0.79398,0.771491
3,No log,0.475983,0.79398,0.783088
4,No log,0.525761,0.79398,0.782673
5,No log,0.565612,0.808027,0.796349


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.604106,0.669743,0.662909
2,No log,0.521429,0.759394,0.743823
3,No log,0.543982,0.759394,0.752896
4,No log,0.58588,0.760712,0.75347
5,No log,0.69291,0.765985,0.754655


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.613685,0.70418,0.607085
2,No log,0.465573,0.798071,0.781435
3,No log,0.473249,0.805145,0.792907
4,No log,0.556424,0.800643,0.781615
5,No log,0.603633,0.801286,0.785162


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.580482,0.697053,0.692423
2,No log,0.453034,0.799178,0.786464
3,No log,0.454246,0.799178,0.790188
4,No log,0.555573,0.790267,0.768755
5,No log,0.566697,0.800548,0.78684


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.606319,0.702082,0.568029
2,No log,0.531708,0.778177,0.751814
3,No log,0.52933,0.771716,0.749245
4,No log,0.567354,0.769562,0.742445
5,No log,0.608636,0.769562,0.744536


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.489326,0.780602,0.766038
2,No log,0.461967,0.785284,0.768549
3,No log,0.509898,0.793311,0.776678
4,No log,0.58211,0.789298,0.770867
5,No log,0.635305,0.788629,0.771709


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.638944,0.666447,0.573224
2,No log,0.51865,0.757416,0.748669
3,No log,0.567206,0.743573,0.740688
4,No log,0.61875,0.766645,0.754072
5,No log,0.695375,0.771259,0.756791


  0%|          | 0/973 [00:00<?, ?it/s]

saving the results dataframe to CSV in rumours_output/bert_classifier_ce.csv


In [21]:
# Metrics table returned for the cross-entropy runs (one row per seed).
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.638397,0.601872,"[0.7224606580829757, 0.48128342245989303]",0.602054,"[0.7210165619645916, 0.4830917874396135]",0.6017,"[0.7239105504587156, 0.47948854555141185]",1,cross_entropy,,True
0,0.638397,0.562956,"[0.7445351593363181, 0.3813775510204081]",0.581742,"[0.6885046273745737, 0.4749801429706116]",0.564543,"[0.810493119266055, 0.3185935002663825]",12,cross_entropy,,True
0,0.635042,0.598378,"[0.7197251646149442, 0.4770299145299145]",0.598502,"[0.7186963979416809, 0.47830744509908946]",0.598258,"[0.720756880733945, 0.4757591901971231]",123,cross_entropy,,True


In [22]:
# Overall F1 averaged across the seeds.
bert_classifier_ce["f1"].agg("mean")

0.5877353116740756

In [23]:
# Overall precision averaged across the seeds.
bert_classifier_ce["precision"].agg("mean")

0.59409949379836

In [24]:
# Overall recall averaged across the seeds.
bert_classifier_ce["recall"].agg("mean")

0.5881669644122721

In [25]:
# Per-class F1 (one entry per class) averaged across the seeds.
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

array([0.72890699, 0.44656363])

In [26]:
# Per-class precision (one entry per class) averaged across the seeds.
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

array([0.70940586, 0.47879313])

In [27]:
# Per-class recall (one entry per class) averaged across the seeds.
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)

array([0.75172018, 0.42461375])