In [None]:
!pip install transformers[torch] -U
!pip install wandb -qU

In [None]:
import pandas as pd
import numpy as np
import csv
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import torch
from transformers import BertTokenizer, AutoModelForSequenceClassification

In [None]:
def change_labels(labels, label_map=None):
    """Map string class labels to integer class ids.

    Args:
        labels: iterable of label strings (``'OFF'`` / ``'NOT'`` for this task).
        label_map: optional dict from label string to class id. Defaults to
            ``{'OFF': 0, 'NOT': 1}``, the index order that compute_metrics assumes
            (index 0 = OFF, index 1 = NOT).

    Returns:
        list[int]: one class id per input label, in the same order, so the result
        stays aligned with the corresponding inputs.

    Raises:
        ValueError: if a label is not a key of ``label_map``. Silently skipping it
            would make the label list shorter than the inputs and misalign them.
    """
    if label_map is None:
        label_map = {'OFF': 0, 'NOT': 1}

    converted_labels = []
    for label in labels:
        try:
            converted_labels.append(label_map[label])
        except KeyError:
            raise ValueError(
                f"Unknown label {label!r}; expected one of {sorted(label_map)}"
            ) from None

    return converted_labels

In [None]:
def read_data_tsv(input):
    """Read a text file whose last whitespace-separated token on each line is the label.

    Args:
        input: path to a UTF-8 text file. The parameter name is kept for keyword
            callers even though it shadows the ``input`` builtin.

    Returns:
        ``(inputs, labels)``: parallel lists. ``inputs[i]`` is the text of the i-th
        non-blank line with its final token removed and every run of whitespace
        (including the tab before the label) collapsed to a single space.
        ``labels[i]`` is that line's final token.

    Blank or whitespace-only lines are skipped. Previously they raised IndexError,
    for example on an empty line at the end of the file.
    """
    inputs = []
    labels = []

    with open(input, 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            # Last token is the label; everything before it is the example text.
            labels.append(tokens[-1])
            inputs.append(' '.join(tokens[:-1]))

    return inputs, labels

def read_data_csv(input):
    """Read a headered CSV file with ``document`` and ``label`` columns.

    Args:
        input: path to a UTF-8 encoded CSV file whose first row is the header.

    Returns:
        ``(documents, labels)``: parallel lists of the ``document`` and ``label``
        field strings, one entry per data row, in file order.

    Raises:
        KeyError: if the file has data rows but its header lacks a ``document`` or
            ``label`` column.
    """
    documents, labels = [], []
    # The csv module requires newline='': without it, universal-newline mode rewrites
    # line breaks embedded in quoted fields before the parser sees them.
    with open(input, encoding="UTF-8", newline="") as f:
        for row in csv.DictReader(f):
            documents.append(row["document"])
            labels.append(row["label"])

    return documents, labels


In [None]:
# Train/dev splits are CSV files with "document"/"label" columns; the test split is a
# whitespace-separated OLID file whose last token on each line is the label.
X_train, Y_train = read_data_csv(input="data/name/train.csv")
X_dev, Y_dev = read_data_csv(input="data/name/dev.csv")
X_test, Y_test = read_data_tsv(input="data/OLID/test_preprocessed.tsv")

In [None]:
# Tokenizer for the same 'bert-base-uncased' checkpoint that model_init loads.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")


In [None]:
# Tokenize every split with identical settings. padding=True pads each split to its
# own longest sequence, and truncation caps sequences at 512 tokens.
X_train_tokenized, X_dev_tokenized, X_test_tokenized = (
    tokenizer(texts, padding=True, truncation=True, max_length=512)
    for texts in (X_train, X_dev, X_test)
)

# Turn the string labels of each split into integer class ids.
Y_train_converted, Y_dev_converted, Y_test_converted = (
    change_labels(raw_labels) for raw_labels in (Y_train, Y_dev, Y_test)
)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset over pre-tokenized features and optional class labels.

    Args:
        encodings: mapping of feature name -> per-example sequences (for example a
            tokenizer's BatchEncoding). It must contain ``"input_ids"``, which
            defines the dataset length.
        labels: optional sequence of class ids aligned with ``encodings`` (a list or
            a numpy array). When ``None``, items carry no ``"labels"`` key, as for
            unlabeled test data.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, plus ``"labels"`` when available."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Check against None instead of relying on truthiness: a numpy label array
        # has no unambiguous truth value and would raise ValueError here.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the ``input_ids`` feature."""
        return len(self.encodings["input_ids"])

In [None]:
train_dataset = Dataset(X_train_tokenized, Y_train_converted)
dev_dataset = Dataset(X_dev_tokenized, Y_dev_converted)
test_dataset = Dataset(X_test_tokenized)

In [None]:
def compute_metrics(p):
    """Compute accuracy and per-class precision/recall/F1 for the OFF/NOT task.

    Args:
        p: a ``(predictions, label_ids)`` pair (the Trainer passes an EvalPrediction,
            which unpacks this way). ``predictions`` are per-class scores whose
            argmax over axis 1 is the predicted class id.

    Returns:
        dict with ``accuracy`` plus ``recall_*``, ``precision_*`` and ``f1_*`` for
        class 0 (OFF) and class 1 (NOT).
    """
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)

    # Pin the class order so index 0 is always OFF and index 1 is always NOT. Without
    # labels=, average=None returns only the classes that occur, so an eval set with a
    # single class gives a length-1 array and indexing [1] raises IndexError.
    # zero_division=0 returns the same 0.0 as the default, but without the warning.
    class_ids = [0, 1]
    recall = recall_score(y_true=labels, y_pred=pred, labels=class_ids, average=None, zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, labels=class_ids, average=None, zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, labels=class_ids, average=None, zero_division=0)

    return {
        "accuracy": accuracy,
        "recall_OFF": recall[0],
        "recall_NOT": recall[1],
        "precision_OFF": precision[0],
        "precision_NOT": precision[1],
        "f1_OFF": f1[0],
        "f1_NOT": f1[1],
    }

In [None]:
# Define Trainer
def model_init():
    """Build a fresh 'bert-base-uncased' sequence classifier with 2 output labels (OFF / NOT).

    This is passed to the Trainer as ``model_init``, so the Trainer instantiates the
    model itself after applying ``args.seed``. That makes the randomly initialised
    classification head reproducible.

    Device placement is left to the Trainer, which moves the model to its training
    device. The previous hard-coded ``.to('cuda')`` crashed on CPU-only machines.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2,
    )


from transformers import TrainingArguments, Trainer

# Fine-tuning hyperparameters; the dev set is evaluated once per epoch.
# `eval_steps` was dropped: it only applies when eval_strategy="steps" and was
# silently ignored with the per-epoch strategy.
args = TrainingArguments(
    output_dir="output_bert",
    num_train_epochs=4,
    per_device_train_batch_size=16,
    # With model_init, this seed is applied before the model is created, so the
    # classifier-head initialisation is reproducible as well.
    seed=123,
    # `evaluation_strategy` was renamed `eval_strategy` (transformers 4.41) and the
    # old name was later removed; the install cell upgrades to the latest release.
    eval_strategy="epoch",
    warmup_ratio=0.1,
    learning_rate=2e-5,
)
trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    compute_metrics=compute_metrics,
)


In [None]:
# Fine-tune: the Trainer builds the model through model_init and evaluates on
# dev_dataset at the end of every epoch, as configured in TrainingArguments.
trainer.train()

In [None]:
# Save the fine-tuned model together with its tokenizer. The Trainer was not given the
# tokenizer, so save_model() alone leaves a directory from which the tokenizer
# cannot be reloaded with from_pretrained().
MODEL_DIR = "models/name"
trainer.save_model(MODEL_DIR)
tokenizer.save_pretrained(MODEL_DIR)

In [None]:
# Inference on the test split, which carries no labels; the per-class scores are in
# `pred.predictions` (element 0 of the returned tuple).
pred = trainer.predict(test_dataset=test_dataset)

In [None]:
# Predicted class id per test example (0 = OFF, 1 = NOT, as assigned by change_labels),
# computed as a vectorised argmax over the class axis of `pred.predictions` (same array
# as pred[0]). `.tolist()` gives plain Python ints rather than numpy scalars.
preds = np.argmax(pred.predictions, axis=-1).tolist()

In [None]:
# Display the predicted class ids (0 = OFF, 1 = NOT).
# NOTE(review): this renders the whole list; consider preds[:20] or a count per class
# for large test sets.
preds