In [1]:
# Importing the libraries needed
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import transformers
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer
import logging
logging.basicConfig(level=logging.ERROR)
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time

In [3]:
from transformers import LongformerConfig, LongformerModel, LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

In [4]:
# device = 'mps' if torch.backends.mps.is_available() else 'cpu'
device = 0
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [5]:
# Defining some key variables that will be used later on in the training
MAX_LEN = 512

LEARNING_RATE = 1e-05
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

In [6]:
with open('../data/twibot20/train.json') as f:
    train_data = json.loads(f.read())

with open('../data/twibot20/val.json') as f:
    val_data = json.loads(f.read())
    
with open('../data/twibot20/test.json') as f:
    test_data = json.loads(f.read())


In [7]:
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(device)
model.requires_grad = False
model.eval()

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


LongformerModel(
  (embeddings): LongformerEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(4098, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): LongformerEncoder(
    (layer): ModuleList(
      (0): LongformerLayer(
        (attention): LongformerAttention(
          (self): LongformerSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (query_global): Linear(in_features=768, out_features=768, bias=True)
            (key_global): Linear(in_features=768, out_features=768, bias=True)
            (value_global): Linear(in_features=768, out_features=768, bias=True)
          )
          (o

In [8]:
def process_tweet_data(data):
    if len(data['tweets']) == 0:
        text = ""
    else:
        tweet_list = np.random.choice(data['tweets'], 1000)
        text = " ".join(" ".join(tweet_list).split(' ')[:MAX_LEN])

    if data['label'] == 'bot':
        label = 1
    else:
        label = 0
    inputs = tokenizer.encode_plus(
        text,
        max_length=MAX_LEN,
        pad_to_max_length=True,
        return_token_type_ids=True,
        add_special_tokens=True
    )
        
    ids = inputs['input_ids']
    mask = inputs['attention_mask']
    token_type_ids = inputs["token_type_ids"]
    
    output = model(input_ids=torch.tensor(ids, dtype=torch.long).reshape(1, -1).cuda(), attention_mask=torch.tensor(mask, dtype=torch.long).reshape(1, -1).cuda(), token_type_ids=torch.tensor(token_type_ids, dtype=torch.long).reshape(1, -1).cuda())[0][0, 0, :].detach()
    tweet_embed = output
    
    created_dt = pd.to_datetime(data['created_at'])
    created_hms = np.zeros(24 + 60 + 60)
    created_hms[created_dt.hour] = 1
    created_hms[created_dt.minute + 24] = 1
    created_hms[created_dt.second + 84] = 1
    
    
    text = str(data['description'])
    text = " ".join(text.split())

    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        pad_to_max_length=True,
        return_token_type_ids=True
    )
    ids = inputs['input_ids']
    mask = inputs['attention_mask']
    token_type_ids = inputs["token_type_ids"]
    
    output = model(input_ids=torch.tensor(ids, dtype=torch.long).reshape(1, -1).cuda(), attention_mask=torch.tensor(mask, dtype=torch.long).reshape(1, -1).cuda(), token_type_ids=torch.tensor(token_type_ids, dtype=torch.long).reshape(1, -1).cuda())[0][0, 0, :].detach()
    desc_embed = output
    

    return {
        'tweet_embed': tweet_embed.detach().cpu(),
        'desc_embed': desc_embed.detach().cpu(),
        'hms': torch.tensor(created_hms, dtype = torch.float32),
        'metrics': torch.tensor(list(data['public_metrics'].values()), dtype = torch.float32),
        'verified': not 'false' in data['verified'],
        'targets': torch.tensor(label, dtype=torch.float32)
            }

In [9]:
class UserTweetDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_len):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        return self.dataset[index]

In [10]:
train_set = pickle.load(open(f"../processed_data/train_set_{MAX_LEN}_long", "rb"))
val_set = pickle.load(open(f"../processed_data/val_set_{MAX_LEN}_long", "rb"))
test_set = pickle.load(open(f"../processed_data/test_set_{MAX_LEN}_long", "rb"))

train_adj = pickle.load(open(f"../processed_data/train_adj", "rb"))
val_adj = pickle.load(open(f"../processed_data/val_adj", "rb"))
test_adj = pickle.load(open(f"../processed_data/test_adj", "rb"))

In [29]:
for i in range(len(train_adj)):
    train_set[i]['adjacency1'] = train_adj[i]['adjacency1']
    train_set[i]['ind'] = train_adj[i]['ind']
    
for i in range(len(val_adj)):
    val_set[i]['adjacency1'] = val_adj[i]['adjacency1']
    val_set[i]['ind'] = val_adj[i]['ind']

for i in range(len(test_adj)):
    test_set[i]['adjacency1'] = test_adj[i]['adjacency1']
    test_set[i]['ind'] = test_adj[i]['ind']

In [10]:
try:
    train_set = pickle.load(open(f"../processed_data/train_adj", "rb"))
    val_set = pickle.load(open(f"../processed_data/val_adj", "rb"))
    test_set = pickle.load(open(f"../processed_data/test_adj", "rb"))
except:
    start = time.time()
    train_set = []
    for i in range(len(train_data)):
        if i % 100 == 0:
            print(i)
        train_set.append(process_tweet_data(train_data[i]))
    print(time.time() - start)

    start = time.time()
    val_set = []
    for i in range(len(val_data)):
        val_set.append(process_tweet_data(val_data[i]))
    print(time.time() - start)

    start = time.time()
    test_set = []
    for i in range(len(val_data)):
        test_set.append(process_tweet_data(test_data[i]))
    print(time.time() - start)
    user_data = {}

    user_ids = []
    neighbor_ids = []

    for data in train_set:
        user_ids.append(data['id'])
        user_data[data['id']] = data

    for data in val_set:
        user_ids.append(data['id'])
        user_data[data['id']] = data

    for data in test_set:
        user_ids.append(data['id'])
        user_data[data['id']] = data

    user_ids = set(user_ids)
    user_ids = list(user_ids)

    id_map = [{},{}]
    for i in range(len(user_ids)):
        id_map[0][user_ids[i]] = i
        id_map[1][i] = user_ids[i]

    d1_adjacency = torch.zeros((len(user_ids), len(user_ids)))
    d2_adjacency = torch.zeros((len(user_ids), len(user_ids)))

    i = 0
    for uid1 in user_data:
        if i %100 == 0:
            print(i)
        i+=1
        for uid2 in user_data:
            if uid1 == uid2:
                continue
            if uid2 in user_data[uid1]['friends'] or uid2 in user_data[uid1]['follows']:
                d1_adjacency[id_map[0][uid1], id_map[0][uid2]] = 1
                d1_adjacency[id_map[0][uid2], id_map[0][uid1]] = 1

    d2_adjacency = ((d1_adjacency @ d1_adjacency) > 0).float()
    for i in range(len(d2_adjacency)):
        d2_adjacency[i,i] = 0

    print(torch.mean(torch.sum(d1_adjacency, axis = 0)))
    print(torch.mean(torch.sum(d2_adjacency, axis = 0)))

    d2_adjacency.shape

    for data in train_set:
        del data['friends']
        del data['follows']
        data['adjacency1'] = d1_adjacency[id_map[0][data['id']]] / torch.sum(d1_adjacency[id_map[0][data['id']]] + .01)
        data['adjacency2'] = d2_adjacency[id_map[0][data['id']]] / torch.sum(d1_adjacency[id_map[0][data['id']]] + .01)
        data['ind'] = id_map[0][data['id']]

    for data in val_set:
        del data['friends']
        del data['follows']
        data['adjacency1'] = d1_adjacency[id_map[0][data['id']]] / torch.sum(d1_adjacency[id_map[0][data['id']]] + .01)
        data['adjacency2'] = d2_adjacency[id_map[0][data['id']]] / torch.sum(d1_adjacency[id_map[0][data['id']]] + .01)
        data['ind'] = id_map[0][data['id']]

    for data in test_set:
        del data['friends']
        del data['follows']
        data['adjacency1'] = d1_adjacency[id_map[0][data['id']]] / torch.sum(d1_adjacency[id_map[0][data['id']]] + .01)
        data['adjacency2'] = d2_adjacency[id_map[0][data['id']]] / torch.sum(d1_adjacency[id_map[0][data['id']]] + .01)
        data['ind'] = id_map[0][data['id']]

    pickle.dump(test_set, open("../processed_data/test_adj", "wb"))
    pickle.dump(train_set, open("../processed_data/train_adj", "wb"))
    pickle.dump(val_set, open("../processed_data/val_adj", "wb"))

In [40]:
TRAIN_BATCH_SIZE = 64
VALID_BATCH_SIZE = 256

train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

val_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False,
                'num_workers': 0
                }

train_loader = DataLoader(train_set, **train_params)
val_loader = DataLoader(val_set, **val_params)
test_loader = DataLoader(test_set, **val_params)

In [41]:
class TweetModel(torch.nn.Module):
    def __init__(self):
        super(TweetModel, self).__init__()
        
        self.l1 = torch.nn.Linear(768, 256)
        self.l2 = torch.nn.Linear(256, 64)
        self.l3 = torch.nn.Linear(64, 32)
        self.l4 = torch.nn.Linear(32, 1)
        self.dropout = torch.nn.Dropout(0.5)
        self.activation = torch.nn.ReLU()

    def forward(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        x = self.activation(self.dropout(self.l1(self.dropout(tweet_embed))))
        x = self.activation(self.dropout(self.l2(x)))
        x = self.activation(self.dropout(self.l3(x)))
        x = self.l4(x)
        output = torch.squeeze(x)
        return output
    
class TweetDescModel(torch.nn.Module):
    def __init__(self):
        super(TweetDescModel, self).__init__()
        
        self.tweet_l1 = torch.nn.Linear(768, 128)
        
        self.desc_l1 = torch.nn.Linear(768, 128)
        
        self.fuse_l1 = torch.nn.Linear(self.tweet_l1.out_features + self.desc_l1.out_features, 128)
        self.fuse_l2 = torch.nn.Linear(128, 64)
        self.fuse_l3 = torch.nn.Linear(64, 32)
        self.fuse_l4 = torch.nn.Linear(32, 1)
        
        self.dropout = torch.nn.Dropout(0.5)
        self.activation = torch.nn.ReLU()

    def forward(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        tweet_x = self.activation(self.dropout(self.tweet_l1(self.dropout(tweet_embed))))
        desc_x = self.activation(self.dropout(self.desc_l1(self.dropout(desc_embed))))
        
        fuse_x = torch.cat([tweet_x, desc_x], axis = 1)
        fuse_x = self.activation(self.dropout(self.fuse_l1(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l2(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l3(fuse_x)))
        fuse_x = self.fuse_l4(fuse_x)
        output = torch.squeeze(fuse_x)
        return output

class FullModel(torch.nn.Module):
    def __init__(self):
        super(FullModel, self).__init__()
        
        self.tweet_l1 = torch.nn.Linear(768, 128)
        
        self.desc_l1 = torch.nn.Linear(768, 128)
        
        self.hms_l1 = torch.nn.Linear(24+60+60, 16)
        
        self.metrics_l1 = torch.nn.Linear(4, 16)
        
        self.fuse_l1 = torch.nn.Linear(self.tweet_l1.out_features + self.desc_l1.out_features + self.hms_l1.out_features + self.metrics_l1.out_features + 1, 128)
        self.fuse_l2 = torch.nn.Linear(128, 64)
        self.fuse_l3 = torch.nn.Linear(64, 32)
        self.fuse_l4 = torch.nn.Linear(32, 1)
        
        self.dropout = torch.nn.Dropout(0.5)
        self.activation = torch.nn.ReLU()

    def forward(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        hms = data['hms'].to(device)
        metrics = data['metrics'].to(device)
        verified = data['verified'].to(device).reshape(-1, 1)
        
        tweet_x = self.activation(self.dropout(self.tweet_l1(self.dropout(tweet_embed))))
        desc_x = self.activation(self.dropout(self.desc_l1(self.dropout(desc_embed))))
        
        hms_x = self.activation(self.dropout(self.hms_l1(hms)))
        metrics_x = self.activation(self.dropout(self.metrics_l1(metrics)))
        
        fuse_x = torch.cat([tweet_x, desc_x, hms_x, metrics_x, verified], axis = 1)
        fuse_x = self.activation(self.dropout(self.fuse_l1(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l2(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l3(fuse_x)))
        fuse_x = self.fuse_l4(fuse_x)
        output = torch.squeeze(fuse_x)
        return output
    
    def embed(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        hms = data['hms'].to(device)
        metrics = data['metrics'].to(device)
        verified = data['verified'].to(device).reshape(-1, 1)
        
        tweet_x = self.activation(self.dropout(self.tweet_l1(self.dropout(tweet_embed))))
        desc_x = self.activation(self.dropout(self.desc_l1(self.dropout(desc_embed))))
        
        hms_x = self.activation(self.dropout(self.hms_l1(hms)))
        metrics_x = self.activation(self.dropout(self.metrics_l1(metrics)))
        
        fuse_x = torch.cat([tweet_x, desc_x, hms_x, metrics_x, verified], axis = 1)
        fuse_x = self.activation(self.dropout(self.fuse_l1(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l2(fuse_x)))
        return fuse_x
    
class EarlyFullModel(torch.nn.Module):
    def __init__(self):
        super(EarlyFullModel, self).__init__()

        self.fuse_l1 = torch.nn.Linear(768 + 768 + 24 + 60 + 60 + 4 + 1, 256)
        self.fuse_l2 = torch.nn.Linear(256, 128)
        self.fuse_l3 = torch.nn.Linear(128, 64)
        self.fuse_l4 = torch.nn.Linear(64, 32)
        self.fuse_l5 = torch.nn.Linear(32, 1)
        
        self.dropout = torch.nn.Dropout(0.5)
        self.activation = torch.nn.ReLU()

    def forward(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        hms = data['hms'].to(device)
        metrics = data['metrics'].to(device)
        verified = data['verified'].to(device).reshape(-1, 1)
        
        fuse_x = torch.cat([tweet_embed, desc_embed, hms, metrics, verified], axis = 1)
        fuse_x = self.activation(self.dropout(self.fuse_l1(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l2(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l3(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l4(fuse_x)))
        fuse_x = self.fuse_l5(fuse_x)
        output = torch.squeeze(fuse_x)
        return output
    
    def embed(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        hms = data['hms'].to(device)
        metrics = data['metrics'].to(device)
        verified = data['verified'].to(device).reshape(-1, 1)
        
        fuse_x = torch.cat([tweet_embed, desc_embed, hms, metrics, verified], axis = 1)
        fuse_x = self.activation(self.dropout(self.fuse_l1(fuse_x)))
        fuse_x = self.activation(self.dropout(self.fuse_l2(fuse_x)))
        fuse_x = self.fuse_l3(fuse_x)
        return fuse_x
    
    
class LateFullModel(torch.nn.Module):
    def __init__(self):
        super(LateFullModel, self).__init__()
        
        self.tweet_l1 = torch.nn.Linear(768, 128)
        self.tweet_l2 = torch.nn.Linear(128, 64)
        self.tweet_l3 = torch.nn.Linear(64, 32)
        
        self.desc_l1 = torch.nn.Linear(768, 128)
        self.desc_l2 = torch.nn.Linear(128, 64)
        self.desc_l3 = torch.nn.Linear(64, 32)
        
        self.hms_l1 = torch.nn.Linear(24+60+60, 16)
        
        self.metrics_l1 = torch.nn.Linear(4, 16)
        
        self.fuse_l1 = torch.nn.Linear(self.tweet_l3.out_features + self.desc_l3.out_features + self.hms_l1.out_features + self.metrics_l1.out_features + 1, 64)
        self.fuse_l2 = torch.nn.Linear(64, 1)
        
        self.dropout = torch.nn.Dropout(0.5)
        self.activation = torch.nn.ReLU()

    def forward(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        hms = data['hms'].to(device)
        metrics = data['metrics'].to(device)
        verified = data['verified'].to(device).reshape(-1, 1)
        
        tweet_x = self.activation(self.dropout(self.tweet_l1(self.dropout(tweet_embed))))
        tweet_x = self.activation(self.dropout(self.tweet_l2(tweet_x)))
        tweet_x = self.activation(self.dropout(self.tweet_l3(tweet_x)))
        
        desc_x = self.activation(self.dropout(self.desc_l1(self.dropout(desc_embed))))
        desc_x = self.activation(self.dropout(self.desc_l2(desc_x)))
        desc_x = self.activation(self.dropout(self.desc_l3(desc_x)))
        
        hms_x = self.activation(self.dropout(self.hms_l1(self.dropout(hms))))
        
        metrics_x = self.activation(self.dropout(self.metrics_l1(self.dropout(metrics))))

        fuse_x = torch.cat([tweet_x, desc_x, hms_x, metrics_x, verified], axis = 1)
        fuse_x = self.activation(self.dropout(self.fuse_l1(fuse_x)))
        fuse_x = self.fuse_l2(fuse_x)
        output = torch.squeeze(fuse_x)
        return output
    
    def embed(self, data):
        tweet_embed = data['tweet_embed'].to(device)
        desc_embed = data['desc_embed'].to(device)
        hms = data['hms'].to(device)
        metrics = data['metrics'].to(device)
        verified = data['verified'].to(device).reshape(-1, 1)
        
        tweet_x = self.activation(self.dropout(self.tweet_l1(self.dropout(tweet_embed))))
        tweet_x = self.activation(self.dropout(self.tweet_l2(tweet_x)))
        tweet_x = self.activation(self.dropout(self.tweet_l3(tweet_x)))
        
        desc_x = self.activation(self.dropout(self.desc_l1(self.dropout(desc_embed))))
        desc_x = self.activation(self.dropout(self.desc_l2(desc_x)))
        desc_x = self.activation(self.dropout(self.desc_l3(desc_x)))
        
        hms_x = self.activation(self.dropout(self.hms_l1(self.dropout(hms))))

        metrics_x = self.activation(self.dropout(self.metrics_l1(self.dropout(metrics))))
        
        fuse_x = torch.cat([tweet_x, desc_x, hms_x, metrics_x, verified], axis = 1)
        fuse_x = self.fuse_l1(fuse_x)
        output = torch.squeeze(fuse_x)
        return output
    
class GraphModel(torch.nn.Module):
    def __init__(self):
        super(GraphModel, self).__init__()
        
        self.embed_memo = torch.zeros(len(train_set[0]['adjacency1']), 64).cuda()
        self.embed_memo.requires_grad = False
        self.full_model = FullModel()
        
        self.fuse_l3 = torch.nn.Linear(128, 32)
        self.fuse_l4 = torch.nn.Linear(32, 1)
        
        self.dropout = torch.nn.Dropout(0.5)
        self.activation = torch.nn.ReLU()

    def forward(self, data):
        data_embed = self.full_model.embed(data)
        #self.embed_memo[data['id']] = data_embed
        
        adjacency = data['adjacency1'].cuda() @ self.embed_memo
        
        #print(data_embed.shape, adjacency.shape)
        
        fuse_x = torch.cat([data_embed, adjacency], axis = 1)
        
        fuse_x = self.activation(self.dropout(self.fuse_l3(fuse_x)))
        fuse_x = self.fuse_l4(fuse_x)
        output = torch.squeeze(fuse_x)
        return output
    
    def update_embed(self, data):
        data_embed = self.full_model.embed(data)
        self.embed_memo[data['ind']] = data_embed.detach()
    
    def embed(self, data):
        return self.full_model.embed(data)
    
    

In [42]:
model = GraphModel()
model.to(device)

GraphModel(
  (full_model): FullModel(
    (tweet_l1): Linear(in_features=768, out_features=128, bias=True)
    (desc_l1): Linear(in_features=768, out_features=128, bias=True)
    (hms_l1): Linear(in_features=144, out_features=16, bias=True)
    (metrics_l1): Linear(in_features=4, out_features=16, bias=True)
    (fuse_l1): Linear(in_features=289, out_features=128, bias=True)
    (fuse_l2): Linear(in_features=128, out_features=64, bias=True)
    (fuse_l3): Linear(in_features=64, out_features=32, bias=True)
    (fuse_l4): Linear(in_features=32, out_features=1, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (activation): ReLU()
  )
  (fuse_l3): Linear(in_features=128, out_features=32, bias=True)
  (fuse_l4): Linear(in_features=32, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (activation): ReLU()
)

In [43]:
# Creating the loss function and optimizer
loss_function = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params =  model.parameters(), lr=2e-4)

In [44]:
def calcuate_accuracy(preds, targets):
    n_correct = (preds==targets).sum().item()
    return n_correct

In [45]:
def eval_model(model, dataset):
    tp = 0
    fp = 0
    tn = 0
    fn = 0
    model.eval()
    for data in dataset:
        preds = model(data)
        targets = data['targets']
        for i in range(len(preds)):
            pred = preds[i]
            target = targets[i]
            if pred > 0:
                if target == 1:
                    tp += 1
                else:
                    fp += 1
            else:
                if target == 1:
                    fn += 1
                else:
                    tn += 1
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * recall * precision / (recall + precision)
    accuracy = (tp + tn) / (tp + tn + fn + fp)
    return accuracy, f1, recall, precision

In [46]:
# Defining the training function on the 80% of the dataset for tuning the distilbert model

def train(epoch, verbose = False, update_e = False):
    start = time.time()
    tr_loss = 0
    n_correct = 0
    nb_tr_steps = 0
    nb_tr_examples = 0
    model.train()
    
    if update_e:
        for data in train_loader:
            model.update_embed(data)

        for data in val_loader:
            model.update_embed(data)

        for data in test_loader:
            model.update_embed(data)
    
    for data in train_loader:
        optimizer.zero_grad()
        targets = data['targets'].to(device, dtype = torch.float32)

        outputs = model(data)
        loss = loss_function(outputs, targets)
        tr_loss += loss.item() * targets.size(0)
        
        preds = (outputs.data).float() > 0
        n_correct += calcuate_accuracy(preds, targets)

        nb_tr_steps += 1
        nb_tr_examples+=targets.size(0)

        
        loss.backward()
        # # When using GPU
        optimizer.step()
        
    model.eval()
    
    
    if verbose:
        print()
        print(f"Training Stats Epoch: {eval_model(model, train_loader)}")
        print(f"Val Stats Epoch: {eval_model(model, val_loader)}")
        print(f"Test Stats Epoch: {eval_model(model, test_loader)}")
        print(f"t={time.time() - start}")

In [47]:
EPOCHS = 10000
for epoch in range(EPOCHS):
    train(epoch, epoch % 100 == 0, epoch % 5 == 0)


Training Stats Epoch: (0.6275670451799952, 0.6247108947048082, 0.5523030563925958, 0.7189688988512188)
Val Stats Epoch: (0.6342494714587738, 0.6289146289146289, 0.5625479662317728, 0.7130350194552529)
Test Stats Epoch: (0.6500422654268808, 0.6336283185840706, 0.559375, 0.7306122448979592)
t=16.59629464149475

Training Stats Epoch: (0.5612466779415318, 0.7189724543484989, 1.0, 0.5612466779415318)
Val Stats Epoch: (0.5509513742071882, 0.7104689203925845, 1.0, 0.5509513742071882)
Test Stats Epoch: (0.5409974640743872, 0.7021393307734504, 1.0, 0.5409974640743872)
t=15.17582893371582

Training Stats Epoch: (0.684223242329065, 0.7777210884353742, 0.9842875591907017, 0.6428169806016306)
Val Stats Epoch: (0.6866807610993657, 0.7767399819222657, 0.9892555640828856, 0.6393849206349206)
Test Stats Epoch: (0.687235841081995, 0.7738386308068459, 0.9890625, 0.6355421686746988)
t=13.953425645828247

Training Stats Epoch: (0.7666102923411452, 0.8155784650630011, 0.9195006457167456, 0.732761578044597)

Test Stats Epoch: (0.7996618765849535, 0.8216704288939052, 0.853125, 0.7924528301886793)
t=14.161805391311646

Training Stats Epoch: (0.9170089393573327, 0.9269537480063795, 0.938226431338786, 0.9159487287245219)
Val Stats Epoch: (0.7932346723044398, 0.8189559422436135, 0.8488104374520338, 0.7911301859799714)
Test Stats Epoch: (0.7920540997464074, 0.8127853881278538, 0.834375, 0.7922848664688428)
t=15.558864116668701

Training Stats Epoch: (0.9176129499879198, 0.9276776246023329, 0.941455015066724, 0.9142976588628763)
Val Stats Epoch: (0.7978858350951374, 0.8233555062823354, 0.8549501151189562, 0.7940128296507484)
Test Stats Epoch: (0.7928994082840237, 0.8139711465451784, 0.8375, 0.7917282127031019)
t=13.993781328201294

Training Stats Epoch: (0.9238946605460256, 0.9337121212121212, 0.9550150667240637, 0.9133388225607246)
Val Stats Epoch: (0.7966173361522199, 0.8238740388136213, 0.8633921719109747, 0.7878151260504201)
Test Stats Epoch: (0.7895181741335587, 0.8134831460674158, 0.8484375

Val Stats Epoch: (0.786046511627907, 0.8142437591776799, 0.8511128165771297, 0.7804363124560169)
Test Stats Epoch: (0.7751479289940828, 0.799396681749623, 0.828125, 0.7725947521865889)
t=14.029000759124756

Training Stats Epoch: (0.9695578642184103, 0.9727802981205443, 0.969220835126991, 0.9763660017346054)
Val Stats Epoch: (0.7906976744186046, 0.8171407462135204, 0.8488104374520338, 0.7877492877492878)
Test Stats Epoch: (0.7810650887573964, 0.8030418250950571, 0.825, 0.7822222222222223)
t=15.460454940795898

Training Stats Epoch: (0.9666586131915922, 0.9700130378096479, 0.9608265174343521, 0.9793769197016235)
Val Stats Epoch: (0.7966173361522199, 0.8215213358070501, 0.8495778971603991, 0.7952586206896551)
Test Stats Epoch: (0.7810650887573964, 0.8030418250950571, 0.825, 0.7822222222222223)
t=15.164224863052368

Training Stats Epoch: (0.971490698236289, 0.9745744451626804, 0.9735256134309083, 0.975625539257981)
Val Stats Epoch: (0.7873150105708245, 0.8158183815452216, 0.854950115118956


Training Stats Epoch: (0.9834501087219135, 0.9851811790156842, 0.9801980198019802, 0.9902152641878669)
Val Stats Epoch: (0.7915433403805496, 0.8188166115398752, 0.8549501151189562, 0.7856135401974612)
Test Stats Epoch: (0.7836010143702451, 0.8083832335329341, 0.84375, 0.7758620689655172)
t=10.220168590545654

Training Stats Epoch: (0.9845373278569702, 0.9861741196802765, 0.9825656478691347, 0.9898091934084996)
Val Stats Epoch: (0.7915433403805496, 0.8180140273163529, 0.8503453568687643, 0.7880512091038406)
Test Stats Epoch: (0.7692307692307693, 0.7936507936507936, 0.8203125, 0.7686676427525623)
t=13.809845685958862

Training Stats Epoch: (0.9817588789562697, 0.9836562398527979, 0.9780456306500215, 0.9893315915523623)
Val Stats Epoch: (0.7940803382663848, 0.821021683204704, 0.8572524942440521, 0.7877291960507757)
Test Stats Epoch: (0.7844463229078613, 0.8098434004474274, 0.8484375, 0.7746077032810271)
t=10.022267818450928

Training Stats Epoch: (0.9869533703793186, 0.9883495145631067, 

In [None]:
# longformer graph

In [19]:
eval_model(model, train_loader)

(0.8249577192558589,
 0.8502325581395349,
 0.8852776582006027,
 0.8178564326903957)

In [20]:
eval_model(model, val_loader)

(0.8038054968287527,
 0.8292862398822662,
 0.8649270913277053,
 0.7964664310954064)

In [21]:
eval_model(model, test_loader)

(0.8157227387996618, 0.835843373493976, 0.8671875, 0.8066860465116279)