In [1]:
# Standard library
import os
import pickle

# Third-party
import numpy as np

# NOTE(review): `seed` is not referenced by the cells shown here; the
# fine-tuning runs below use their own `seeds` list instead.
seed = 2023

In [2]:
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
import transformers

# Only report critical errors to avoid excessive logging.
# CRITICAL (== 50) is re-exported from the stdlib `logging` module by
# transformers.utils.logging; using the name instead of the bare literal 50
# makes the intended level explicit.
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Directory that receives the results CSVs written by the fine-tuning cells.
output_dir = "rumours_output"
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of testing os.path.isdir() first.
os.makedirs(output_dir, exist_ok=True)

## Rumours

In [6]:
# Execute the data-loading helper script (its name suggests it also loads SBERT
# embeddings). It must define `df_rumours`, which every cell below relies on.
%run load_sbert-embeddings.py

In [7]:
# Peek at the first rows of the loaded rumours data.
df_rumours.head(n=5)

Unnamed: 0,id,label,datetime,text,timeline_id,set
0,5.249902e+17,0,2014-10-22 18:26:23,Police have clarified that there were two shoo...,0,train
1,5.249906e+17,0,2014-10-22 18:27:58,"@CTVNews you guys ""confirmed"" there were 3 sho...",0,train
2,5.249908e+17,1,2014-10-22 18:28:46,@CTVNews get it right. http://t.co/GHYxMuzPG9,0,train
3,5.249927e+17,1,2014-10-22 18:36:29,RT @CTVNews Police have clarified that there w...,0,train
4,5.250038e+17,1,2014-10-22 19:20:41,@CTVNews @ctvsaskatoon so what happened at Rid...,0,train


## Baseline: Fine-tune BERT for classification

In [8]:
# Fine-tuning configuration shared by the runs below.
num_epochs = 10
seeds = [1, 12, 123]  # each seed gives one row in the results table
loss = "focal"
gamma = 2  # focal-loss focusing parameter
validation_metric = "f1"

# Per-row timeline ids, passed as `split_ids` to the fine-tuning helper.
# Convert to a NumPy array first so torch builds the tensor positionally,
# rather than treating the Series as a generic Python sequence (element access
# on a Series goes through label-based __getitem__, which breaks for
# non-RangeIndex frames).
split_ids = torch.tensor(df_rumours["timeline_id"].astype(int).to_numpy())

In [9]:
# Fine-tune bert-base-uncased once per seed in `seeds` with the focal-loss
# settings from the config cell. `split_ids` (timeline ids) and k_fold=True
# are passed so the helper can build its folds from them.
# NOTE(review): `loss`/`gamma` are rebound to cross-entropy settings later in
# the notebook; re-running this cell after that would not use focal loss.
bert_classifier = fine_tune_transformer_average_seed(
    # data
    df=df_rumours,
    feature_name="text",
    label_column="label",
    split_ids=split_ids,
    k_fold=True,
    # model / optimisation
    pretrained_model_name="bert-base-uncased",
    num_epochs=num_epochs,
    loss=loss,
    gamma=gamma,
    seeds=seeds,
    validation_metric=validation_metric,
    # runtime / output
    device=device,
    verbose=False,
    results_output=f"{output_dir}/bert_classifier.csv",
)

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.204061,0.688746,0.687374
2,No log,0.168122,0.767846,0.76497
3,No log,0.171746,0.78135,0.774397
4,No log,0.238052,0.788424,0.77767
5,No log,0.401542,0.781994,0.75796
6,No log,0.420593,0.784566,0.769301
7,No log,0.498974,0.786495,0.765528
8,No log,0.567132,0.785852,0.765315
9,No log,0.551505,0.788424,0.772946
10,0.064900,0.56,0.785852,0.770515


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.222291,0.694311,0.643889
2,No log,0.170878,0.790267,0.777069
3,No log,0.174684,0.782042,0.775824
4,No log,0.234802,0.784784,0.777442
5,No log,0.306338,0.764907,0.744321
6,No log,0.386139,0.781357,0.771434
7,No log,0.429505,0.783413,0.773104
8,No log,0.517221,0.782728,0.773791
9,No log,0.584737,0.793009,0.7782
10,No log,0.574154,0.790953,0.778342


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.206155,0.750897,0.704495
2,No log,0.192453,0.771716,0.744985
3,No log,0.251449,0.725054,0.637026
4,No log,0.23675,0.774587,0.744645
5,No log,0.345618,0.781048,0.75425
6,No log,0.481782,0.774587,0.740168
7,No log,0.518147,0.763819,0.745472
8,No log,0.566889,0.773869,0.749311
9,No log,0.577623,0.774587,0.750724
10,No log,0.595421,0.777459,0.74928


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.206432,0.670903,0.669512
2,No log,0.161797,0.80602,0.791933
3,No log,0.191499,0.785953,0.763245
4,No log,0.317277,0.766555,0.736082
5,No log,0.27909,0.786622,0.776905
6,No log,0.426307,0.795318,0.778759
7,No log,0.51367,0.793311,0.778051
8,No log,0.594165,0.796656,0.778594
9,No log,0.618347,0.798662,0.782631
10,No log,0.580221,0.797993,0.785778


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.208448,0.742914,0.728372
2,No log,0.1859,0.760053,0.749945
3,No log,0.197525,0.784443,0.774922
4,No log,0.302332,0.765985,0.744914
5,No log,0.51138,0.736981,0.700413
6,No log,0.622812,0.76203,0.74124
7,No log,0.474753,0.764008,0.750979
8,No log,0.54667,0.771259,0.760326
9,No log,0.586,0.769941,0.760447
10,No log,0.604227,0.7706,0.760804


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.222978,0.641801,0.640191
2,No log,0.16753,0.78328,0.761491
3,No log,0.169268,0.799357,0.788061
4,No log,0.258595,0.792283,0.771981
5,No log,0.300341,0.784566,0.770884
6,No log,0.426137,0.790354,0.774109
7,No log,0.482873,0.792283,0.781725
8,No log,0.544243,0.782637,0.758998
9,No log,0.589594,0.785852,0.764523
10,0.070500,0.553925,0.793569,0.778467


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.210638,0.670322,0.668225
2,No log,0.178897,0.793009,0.774596
3,No log,0.182372,0.800548,0.786521
4,No log,0.320088,0.792324,0.774328
5,No log,0.265847,0.793694,0.78853
6,No log,0.329755,0.802605,0.791893
7,No log,0.437721,0.787526,0.764858
8,No log,0.491499,0.792324,0.771753
9,No log,0.495515,0.798492,0.785351
10,No log,0.531499,0.799178,0.781776


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.221935,0.69921,0.596816
2,No log,0.203531,0.752333,0.706986
3,No log,0.218407,0.757358,0.730872
4,No log,0.26264,0.754487,0.729025
5,No log,0.358285,0.757358,0.721309
6,No log,0.46562,0.759512,0.726182
7,No log,0.570216,0.74659,0.694763
8,No log,0.518708,0.760948,0.732538
9,No log,0.608649,0.766691,0.733746
10,No log,0.603648,0.763819,0.733478


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.212102,0.714381,0.678733
2,No log,0.162224,0.802676,0.793569
3,No log,0.168937,0.794649,0.788741
4,No log,0.24622,0.79398,0.781669
5,No log,0.299663,0.796656,0.781725
6,No log,0.324305,0.802007,0.789013
7,No log,0.52931,0.775251,0.744453
8,No log,0.536587,0.79398,0.77983
9,No log,0.581259,0.8,0.786341
10,No log,0.606179,0.79398,0.77568


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.228128,0.628873,0.626765
2,No log,0.195674,0.748846,0.742937
3,No log,0.252662,0.740936,0.711033
4,No log,0.372129,0.758734,0.746053
5,No log,0.540215,0.74621,0.724261
6,No log,0.490903,0.750824,0.742357
7,No log,0.623468,0.752802,0.736986
8,No log,0.644102,0.760053,0.749241
9,No log,0.698331,0.76203,0.748489
10,No log,0.696165,0.76269,0.750827


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.242515,0.719614,0.636662
2,No log,0.167489,0.784566,0.776692
3,No log,0.194638,0.805145,0.784811
4,No log,0.236957,0.79164,0.782
5,No log,0.313702,0.797428,0.784126
6,No log,0.354854,0.785852,0.777342
7,No log,0.478213,0.788424,0.770372
8,No log,0.508315,0.795498,0.778821
9,No log,0.569051,0.796785,0.776639
10,0.068500,0.547891,0.801286,0.784998


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.235803,0.558602,0.551856
2,No log,0.179092,0.793694,0.78155
3,No log,0.193209,0.797121,0.786113
4,No log,0.208034,0.790953,0.781069
5,No log,0.344423,0.790953,0.77226
6,No log,0.471714,0.782728,0.76099
7,No log,0.409258,0.782042,0.775285
8,No log,0.536905,0.790267,0.774918
9,No log,0.571124,0.78684,0.768174
10,No log,0.577936,0.791638,0.774612


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.216534,0.735822,0.665417
2,No log,0.213675,0.721464,0.71578
3,No log,0.236296,0.738693,0.727609
4,No log,0.419826,0.735104,0.668606
5,No log,0.423972,0.739411,0.685667
6,No log,0.470541,0.745872,0.711794
7,No log,0.587386,0.758794,0.729128
8,No log,0.666577,0.754487,0.718011
9,No log,0.682721,0.759512,0.725869
10,No log,0.686716,0.755922,0.723192


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.206504,0.698328,0.690751
2,No log,0.173134,0.76388,0.760167
3,No log,0.169957,0.8,0.781033
4,No log,0.190873,0.780602,0.774028
5,No log,0.443715,0.777258,0.755497
6,No log,0.405215,0.783278,0.767539
7,No log,0.462355,0.784615,0.767738
8,No log,0.502846,0.79398,0.780147
9,No log,0.522151,0.791304,0.777765
10,No log,0.537994,0.789967,0.776341


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.218157,0.649308,0.648622
2,No log,0.179774,0.765326,0.762161
3,No log,0.223702,0.73764,0.737295
4,No log,0.369279,0.769281,0.752501
5,No log,0.346867,0.765326,0.756865
6,No log,0.52678,0.779169,0.766452
7,No log,0.492198,0.777851,0.766956
8,No log,0.603412,0.769281,0.754435
9,No log,0.57624,0.77851,0.768662
10,No log,0.574608,0.777851,0.767774


  0%|          | 0/973 [00:00<?, ?it/s]

saving the results dataframe to CSV in rumours_output/bert_classifier.csv


In [10]:
# Results table returned by the focal-loss runs: one row per entry in `seeds`.
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.623299,0.591448,"[0.7055223663121084, 0.4773726402896302]",0.590576,"[0.7173333333333334, 0.4638190954773869]",0.592918,"[0.6940940366972477, 0.49174214171550346]",1,focal,2,True
0,0.614911,0.546403,"[0.7226845637583893, 0.37012195121951225]",0.55605,"[0.679454820797577, 0.43264433357091947]",0.547589,"[0.7717889908256881, 0.32338838572189665]",12,focal,2,True
0,0.603355,0.589922,"[0.6641414141414141, 0.515703231679563]",0.594452,"[0.7387640449438202, 0.45013905442987684]",0.603417,"[0.6032110091743119, 0.6036228023441662]",123,focal,2,True


In [11]:
# Mean over seeds of the macro-averaged F1.
bert_classifier["f1"].agg("mean")

0.5759243612334363

In [12]:
# Mean over seeds of the macro-averaged precision.
bert_classifier["precision"].agg("mean")

0.580359113758819

In [13]:
# Mean over seeds of the macro-averaged recall.
bert_classifier["recall"].agg("mean")

0.5813078944131357

In [14]:
# Per-class F1 averaged over seeds (one entry per class; presumably ordered
# label 0 then label 1 — confirm against the helper).
np.vstack(bert_classifier["f1_scores"].tolist()).mean(axis=0)

array([0.69744945, 0.45439927])

In [15]:
# Per-class precision averaged over seeds (presumably label 0 then label 1).
np.vstack(bert_classifier["precision_scores"].tolist()).mean(axis=0)

array([0.71185073, 0.44886749])

In [16]:
# Per-class recall averaged over seeds (presumably label 0 then label 1).
np.vstack(bert_classifier["recall_scores"].tolist()).mean(axis=0)

array([0.68969801, 0.47291778])

## Using Cross-Entropy loss

In [19]:
# Switch to standard cross-entropy; gamma is left unset (None) for this run.
# NOTE(review): this rebinds the `loss`/`gamma` globals used by the focal-loss
# cell, so re-running that cell afterwards would silently train with
# cross-entropy.
loss, gamma = "cross_entropy", None

In [20]:
# Repeat the fine-tuning with the cross-entropy settings (`loss`, `gamma` as
# set in the previous cell). Results go to their own CSV so they are kept apart
# from the focal-loss results in `bert_classifier.csv`; previously both runs
# wrote to the same file.
bert_classifier_ce = fine_tune_transformer_average_seed(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=df_rumours,
    feature_name="text",
    label_column="label",
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    split_ids=split_ids,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/bert_classifier_ce.csv",
    device=device,
    verbose=False,
)

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.515088,0.771704,0.744668
2,No log,0.482425,0.765916,0.761267
3,No log,0.45359,0.798714,0.78718
4,No log,0.596315,0.789711,0.761596
5,No log,0.688636,0.801929,0.782108
6,No log,0.959194,0.769775,0.737264
7,No log,0.896423,0.8,0.781183
8,No log,0.920382,0.800643,0.788883
9,No log,0.98077,0.802572,0.789465
10,0.187000,1.047564,0.803859,0.786956


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.56952,0.716244,0.705633
2,No log,0.467208,0.789582,0.771349
3,No log,0.461566,0.805346,0.78892
4,No log,0.508753,0.776559,0.761466
5,No log,0.626719,0.784784,0.773548
6,No log,0.67826,0.784099,0.771233
7,No log,0.804088,0.793009,0.78075
8,No log,0.940678,0.792324,0.776867
9,No log,0.94953,0.797121,0.78583
10,No log,0.953349,0.799863,0.787117


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.539478,0.754487,0.697409
2,No log,0.515542,0.773151,0.748889
3,No log,0.609649,0.751615,0.693398
4,No log,0.696941,0.768844,0.744623
5,No log,0.840369,0.764537,0.722713
6,No log,0.940827,0.757358,0.728663
7,No log,0.919318,0.782484,0.762723
8,No log,1.052478,0.770998,0.74663
9,No log,1.113914,0.77028,0.749972
10,No log,1.152819,0.772434,0.749438


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.512784,0.755184,0.739118
2,No log,0.491624,0.790635,0.754436
3,No log,0.422619,0.826756,0.818218
4,No log,0.447599,0.807358,0.79988
5,No log,0.545221,0.768562,0.766302
6,No log,0.579822,0.810702,0.801613
7,No log,0.686068,0.804682,0.798542
8,No log,0.715641,0.814716,0.80371
9,No log,0.771628,0.810702,0.799989
10,No log,0.782926,0.812709,0.80243


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.536304,0.734344,0.729053
2,No log,0.478323,0.77851,0.770756
3,No log,0.71871,0.754779,0.725027
4,No log,0.611021,0.779829,0.771417
5,No log,0.819567,0.761371,0.745485
6,No log,0.91875,0.765985,0.752669
7,No log,1.157568,0.76269,0.739231
8,No log,1.043031,0.773896,0.762949
9,No log,1.147644,0.783125,0.769564
10,No log,1.158944,0.781147,0.767542


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.5907,0.696463,0.610355
2,No log,0.457221,0.79164,0.768356
3,No log,0.475701,0.807074,0.797918
4,No log,0.65065,0.807717,0.786711
5,No log,0.704587,0.78135,0.774291
6,No log,0.807651,0.796785,0.786391
7,No log,1.074444,0.785852,0.760491
8,No log,1.030583,0.780064,0.774651
9,No log,1.100883,0.797428,0.783227
10,0.203100,1.085342,0.798071,0.78503


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.535975,0.745716,0.70904
2,No log,0.491248,0.778615,0.747245
3,No log,0.497309,0.785469,0.769858
4,No log,0.588363,0.791638,0.773868
5,No log,0.630497,0.788897,0.782873
6,No log,0.688479,0.797121,0.786936
7,No log,0.9306,0.783413,0.761636
8,No log,0.937413,0.78684,0.773824
9,No log,1.068761,0.784784,0.766623
10,No log,1.047971,0.793009,0.776993


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.636339,0.671931,0.622064
2,No log,0.584792,0.731515,0.70564
3,No log,0.571704,0.752333,0.728591
4,No log,0.640533,0.738693,0.724611
5,No log,0.87693,0.744436,0.687405
6,No log,1.085868,0.742283,0.697495
7,No log,1.079451,0.748744,0.721317
8,No log,1.238371,0.751615,0.713342
9,No log,1.270031,0.751615,0.72532
10,No log,1.30882,0.753051,0.72442


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.632912,0.622742,0.383759
2,No log,0.464891,0.803344,0.790733
3,No log,0.47456,0.795987,0.782053
4,No log,0.509013,0.775251,0.77023
5,No log,0.564573,0.797993,0.783645
6,No log,0.666311,0.790635,0.779637
7,No log,0.92496,0.78796,0.761255
8,No log,0.861477,0.779933,0.770179
9,No log,0.938915,0.792642,0.7784
10,No log,0.99846,0.795987,0.776446


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.602192,0.68029,0.661878
2,No log,0.511382,0.76203,0.750803
3,No log,0.564571,0.760053,0.744079
4,No log,0.747071,0.758075,0.739211
5,No log,0.840359,0.751483,0.732713
6,No log,1.06323,0.769941,0.756211
7,No log,1.261024,0.754779,0.734735
8,No log,1.220452,0.766645,0.759185
9,No log,1.355545,0.763349,0.747683
10,No log,1.322488,0.766645,0.758596


  0%|          | 0/973 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.614951,0.708682,0.61662
2,No log,0.452472,0.810932,0.796433
3,No log,0.46865,0.801929,0.786897
4,No log,0.584174,0.800643,0.783718
5,No log,0.713329,0.789068,0.771337
6,No log,0.882563,0.793569,0.777157
7,No log,0.9467,0.794212,0.780922
8,No log,1.054776,0.796785,0.782919
9,No log,1.183761,0.782637,0.760474
10,0.186900,1.143012,0.794855,0.779367


  0%|          | 0/856 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.69173,0.618917,0.618787
2,No log,0.489086,0.78684,0.775639
3,No log,0.458621,0.804661,0.795046
4,No log,0.555386,0.793694,0.767466
5,No log,0.548048,0.795065,0.780979
6,No log,0.730583,0.791638,0.772119
7,No log,0.841145,0.786155,0.765501
8,No log,0.796413,0.804661,0.794263
9,No log,0.892912,0.799178,0.785691
10,No log,0.918145,0.797807,0.783423


  0%|          | 0/1148 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.572642,0.737258,0.654742
2,No log,0.533342,0.776023,0.753743
3,No log,0.576932,0.77028,0.749297
4,No log,0.703938,0.776741,0.748879
5,No log,0.86419,0.767408,0.72352
6,No log,0.902257,0.767408,0.75006
7,No log,1.060084,0.776741,0.74666
8,No log,1.125634,0.773151,0.748137
9,No log,1.189279,0.776741,0.75124
10,No log,1.219058,0.779612,0.751841


  0%|          | 0/1349 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.487634,0.779264,0.759857
2,No log,0.464184,0.78194,0.768909
3,No log,0.624433,0.770569,0.732497
4,No log,0.606522,0.798662,0.789858
5,No log,0.737273,0.777258,0.759897
6,No log,0.887945,0.786622,0.774099
7,No log,1.061992,0.789967,0.776962
8,No log,1.094186,0.774582,0.754045
9,No log,1.201645,0.789298,0.771614
10,No log,1.234728,0.779933,0.760485


  0%|          | 0/1039 [00:00<?, ?it/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]

Map:   0%|          | 0/5568 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.633545,0.667106,0.566072
2,No log,0.512595,0.771259,0.764978
3,No log,0.549044,0.764008,0.753374
4,No log,0.706583,0.772577,0.754085
5,No log,0.855388,0.75412,0.738208
6,No log,1.026764,0.758075,0.740954
7,No log,1.16564,0.766645,0.746565
8,No log,1.106261,0.766645,0.757473
9,No log,1.190833,0.765985,0.757236
10,No log,1.205723,0.7706,0.759564


  0%|          | 0/973 [00:00<?, ?it/s]

saving the results dataframe to CSV in rumours_output/bert_classifier.csv


In [21]:
# Results table returned by the cross-entropy runs: one row per entry in `seeds`.
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.630382,0.59125,"[0.7177224199288257, 0.4647773279352227]",0.591879,"[0.7127509188577891, 0.4710065645514223]",0.590737,"[0.7227637614678899, 0.45871070857751733]",1,cross_entropy,,True
0,0.63616,0.578723,"[0.7342771576368092, 0.4231678486997636]",0.587091,"[0.6990668740279938, 0.4751161247511613]",0.577341,"[0.7732224770642202, 0.3814597762386787]",12,cross_entropy,,True
0,0.635974,0.606165,"[0.7145154217219706, 0.4978143481614811]",0.605006,"[0.728899492991351, 0.48111332007952284]",0.608202,"[0.7006880733944955, 0.515716568993074]",123,cross_entropy,,True


In [22]:
# Mean over seeds of the macro-averaged F1 (cross-entropy runs).
bert_classifier_ce["f1"].agg("mean")

0.5920457540140122

In [23]:
# Mean over seeds of the macro-averaged precision (cross-entropy runs).
bert_classifier_ce["precision"].agg("mean")

0.5946588825432068

In [24]:
# Mean over seeds of the macro-averaged recall (cross-entropy runs).
bert_classifier_ce["recall"].agg("mean")

0.5920935609559793

In [25]:
# Per-class F1 averaged over seeds (cross-entropy; presumably label 0 then 1).
np.vstack(bert_classifier_ce["f1_scores"].tolist()).mean(axis=0)

array([0.72217167, 0.46191984])

In [26]:
# Per-class precision averaged over seeds (cross-entropy; presumably label 0 then 1).
np.vstack(bert_classifier_ce["precision_scores"].tolist()).mean(axis=0)

array([0.71357243, 0.47574534])

In [27]:
# Per-class recall averaged over seeds (cross-entropy; presumably label 0 then 1).
np.vstack(bert_classifier_ce["recall_scores"].tolist()).mean(axis=0)

array([0.73222477, 0.45196235])