In [None]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import functional as F

from utils import *

# Pick the compute device once: the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Scratch string (presumably for poking at the tokenizer by hand); nothing else
# in this notebook reads it — TODO: delete once no longer needed.
test_str = (
    'IM JUST A STOOPID LIL TEST'
)

In [None]:
# ---- Training configuration ----
epochs = 3

# Fine-tuning BERT with Adam at lr=0.1 wrecks the pretrained weights (loss
# diverges); BERT fine-tuning normally uses something in the 2e-5 .. 5e-5 range.
lr = 2e-5

# NOTE(review): this was an unfinished placeholder (a SyntaxError that stopped the
# whole notebook). 2 assumes the 'snip' task is binary, as the evaluation comments
# speak of labels 0/1 — confirm against the number of distinct labels in the data.
n_labels = 2

# True  -> loss = nn.CrossEntropyLoss()(logits, labels), computed by hand;
# False -> use the loss the Hugging Face model returns when `labels` are passed in.
manual_loss = False

# Seed torch so DataLoader shuffling and the classifier-head initialisation are
# reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [None]:
# ---- Data: tokenise each split and wrap it in a DataLoader ----
# include_dev=True  -> get_processed_data returns train/val/test splits;
# include_dev=False -> it returns only train/val.
include_dev = False
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
batch_size_train = 8
batch_size_test = 8
batch_size_dev = 8


def _build_split(split_df, batch_size, shuffle):
    """Prepare one processed split for BERT on the 'snip' task.

    Selects the BERT input columns, tokenises them with the notebook-level
    `tokenizer`, wraps the encodings in `CustomPropagandaDataset` and batches
    them with a `DataLoader`.

    Returns:
        (column-selected DataFrame, dataset, dataloader)
    """
    split_df = get_cols_for_bert(split_df, 'snip')
    encodings = format_and_tokenise_from_df(split_df, tokenizer, task='snip')
    dataset = CustomPropagandaDataset(encodings)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return split_df, dataset, loader


if include_dev:
    train_df, val_df, test_df = get_processed_data(dev=True)
    # Evaluation order does not affect loss/accuracy, so don't shuffle test data.
    test_df, test_dataset, test_dataloader = _build_split(test_df, batch_size_test, shuffle=False)
else:
    train_df, val_df = get_processed_data(dev=False)

# Only the training split is shuffled; validation is evaluated in a fixed order.
train_df, train_dataset, train_dataloader = _build_split(train_df, batch_size_train, shuffle=True)
val_df, val_dataset, val_dataloader = _build_split(val_df, batch_size_dev, shuffle=False)




In [None]:
# ---- Model, optimiser, loss ----
# Load the classifier from the same checkpoint the tokenizer was loaded from.
# Previously the model was 'bert-base-uncased' while the tokenizer was
# 'bert-base-cased', so token ids did not match the embedding vocabulary.
model = BertForSequenceClassification.from_pretrained(tokenizer.name_or_path, num_labels=n_labels)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
# Only used by the training loop when `manual_loss` is True.
criterion = nn.CrossEntropyLoss()

# Per-epoch history (mean loss and accuracy) for each split.
train_losses = []
train_accuracy = []
val_losses = []
val_accuracy = []

def _batch_loss(outputs, labels):
    """Loss for one batch, computed the same way for training and validation.

    When `manual_loss` is set, applies `criterion` to the logits; otherwise
    returns the loss the Hugging Face model computed because `labels` were
    included in its inputs.
    """
    if manual_loss:
        return criterion(outputs.logits, labels)
    return outputs.loss


def _n_correct(logits, labels):
    """Count rows whose highest-scoring class (arg-max over dim 1) equals the label."""
    return (logits.argmax(dim=1) == labels).sum().item()


model.to(device)
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_running_losses = []
    train_total = 0
    train_correct = 0
    for batch in tqdm(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = _batch_loss(outputs, batch['labels'])

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_total += batch['labels'].size(0)
        train_correct += _n_correct(outputs.logits, batch['labels'])
        train_running_losses.append(loss.item())

    train_losses.append(sum(train_running_losses) / len(train_running_losses))
    train_accuracy.append(train_correct / train_total)
    # 1-based epoch number, consistent with the VAL line below (was 0-based).
    print(f'TRAIN: Epoch [{epoch + 1}/{epochs}] Loss: {train_losses[-1]} Acc: {train_accuracy[-1]}')

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    val_running_losses = []
    val_total = 0
    val_correct = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            # Same loss definition as training, so the two curves are comparable.
            loss = _batch_loss(outputs, batch['labels'])

            # Prediction = index of the largest logit (no probability threshold).
            val_total += batch['labels'].size(0)
            val_correct += _n_correct(outputs.logits, batch['labels'])
            val_running_losses.append(loss.item())

    val_losses.append(sum(val_running_losses) / len(val_running_losses))
    val_accuracy.append(val_correct / val_total)
    print(f'VAL: Epoch [{epoch + 1}/{epochs}] Loss: {val_losses[-1]} Acc: {val_accuracy[-1]}')

# ---- Final evaluation on the held-out test split (only exists when include_dev) ----
if include_dev:
    print('TESTING...')
    test_losses = []
    test_accuracy = []
    test_running_losses = []
    test_total = 0
    test_correct = 0

    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            # Same loss definition as the training loop uses.
            if manual_loss:
                loss = criterion(outputs.logits, batch['labels'])
            else:
                loss = outputs.loss

            # Prediction = index of the largest logit (no probability threshold).
            predicted_labels = outputs.logits.argmax(dim=1)

            test_total += batch['labels'].size(0)
            test_correct += (predicted_labels == batch['labels']).sum().item()
            test_running_losses.append(loss.item())

    test_losses.append(sum(test_running_losses) / len(test_running_losses))
    test_accuracy.append(test_correct / test_total)
    # A single pass after training: no epoch number. The old message reused the
    # stale loop variable `epoch`, which is undefined when epochs == 0.
    print(f'TEST: Loss: {test_losses[-1]} Acc: {test_accuracy[-1]}')

