In [None]:
! pip3 install evaluate
import os
try:
    from google.colab import drive
    drive.mount('/content/drive')
    os.chdir(next((root for root, _, files in os.walk(".") if "dsait4090_project_location" in files), "."))
    print(f'Google Colab: {os.getcwd()}')
except ImportError:
    print(f'Local: {os.getcwd()}')

In [None]:
import logging
import torch

from src.classification_training import ClassificationTraining
from src.common import read_data, QTDataset, get_device, save_data
from src.models.bart_large_mnli import BartLargeMnliTokenizer, BartLargeMnliClassifier
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import TensorDataset
from tqdm.auto import tqdm
from torch import nn

# Taxonomy categories the routing classifier predicts.
CATEGORIES = ["statistical", "temporal", "interval", "comparison"]

# Set the root logger threshold to ERROR to cut log noise
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.ERROR)

# Seed torch's global RNG (all devices) before the model is constructed, so any
# randomly initialised weights are reproducible; same seed as random_state=0 given
# to ClassificationTraining.
torch.manual_seed(0)

device = get_device()

# LabelEncoder stores classes_ sorted, so integer ids follow alphabetical order,
# not the order of CATEGORIES. The (misspelled) name `label_endocer` is kept
# because the rest of the notebook refers to it.
label_endocer = LabelEncoder()
label_endocer.fit(CATEGORIES)
label_endocer.classes_

In [None]:
# Classifier sized for len(CATEGORIES) labels, moved to the selected device.
model = BartLargeMnliClassifier(
    labels_count=len(CATEGORIES),
).to(device)
# Tokenizer for the same backbone; build_dataset calls it once per claim text.
tokenizer = BartLargeMnliTokenizer()

In [None]:
# Raw claim splits. build_dataset reads each claim's 'claim' text and
# 'taxonomy_label' from these records.
train_claims = read_data('raw_data/train_claims.json')
val_claims = read_data('raw_data/val_claims.json')
test_claims = read_data('raw_data/test_claims.json')

def build_dataset(claims: QTDataset) -> TensorDataset:
    """Tokenize claim texts and encode their taxonomy labels into a TensorDataset.

    Each claim's ``'claim'`` text is passed to the global ``tokenizer``, whose result
    is unpacked as an ``(input_ids, attention_mask)`` pair; the per-claim tensors are
    concatenated along dim 0. Gold labels are read from ``'taxonomy_label'``, stripped
    of surrounding whitespace, and mapped to integer ids with the global
    ``label_endocer``.

    Returns:
        TensorDataset of (input_ids, attention_masks, encoded_labels).
    """
    texts = [entry['claim'] for entry in claims]

    token_pairs = []
    for text in tqdm(texts):
        ids, mask = tokenizer(text)
        token_pairs.append((ids, mask))

    # NOTE(review): concatenating along dim 0 assumes each tokenizer output is
    # (1, seq_len) with a fixed seq_len — confirm against BartLargeMnliTokenizer.
    input_ids = torch.cat([ids for ids, _ in token_pairs])
    masks = torch.cat([mask for _, mask in token_pairs])

    gold_labels = [entry['taxonomy_label'].strip() for entry in claims]
    label_ids = torch.tensor(label_endocer.transform(gold_labels))

    return TensorDataset(input_ids, masks, label_ids)

# Tokenize and encode every split once up front (one progress bar per split).
train_dataset = build_dataset(train_claims)
val_dataset = build_dataset(val_claims)
test_dataset = build_dataset(test_claims)

In [None]:
# Every model parameter is handed to AdamW; the loss is cross-entropy over the
# integer category ids produced by label_endocer.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8)
loss_function = nn.CrossEntropyLoss()

training = ClassificationTraining(
    model_name="routing_model",
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    batch_size=16,
    device=device,
    random_state=0,
)


In [None]:
# Begin a fresh run (NOTE(review): presumably discards earlier training state —
# confirm in ClassificationTraining.start_new_training).
training.start_new_training()
# At most 20 epochs; `patience` presumably stops early after 3 epochs without
# validation improvement — confirm in ClassificationTraining.train.
training.train(patience=3, epochs=20)

In [None]:
# Restore the checkpoint ClassificationTraining kept as "best"
# (NOTE(review): presumably selected on validation performance — confirm), then
# predict class indices for the test split; they are compared with encoded gold
# labels in the next cell.
training.load_best_model()
test_predictions = training.predict(test_dataset)

In [None]:
def accuracy(output: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of predictions that equal the gold labels.

    Args:
        output: Predicted class indices (any shape; flattened), or a 2-D
            ``(n_samples, n_classes)`` score matrix that is reduced with argmax
            over dim 1 first.
        labels: Gold class indices (any shape; flattened).

    Returns:
        Accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: If ``labels`` is empty, or if the number of predictions differs
            from the number of labels (instead of silently broadcasting).
    """
    if output.dim() == 2:
        output = torch.argmax(output, dim=1)
    predictions = output.flatten()
    # Align devices so GPU predictions can be compared with CPU labels.
    labels_flat = labels.flatten().to(predictions.device)

    if labels_flat.numel() == 0:
        raise ValueError('accuracy is undefined for an empty label tensor')
    if predictions.numel() != labels_flat.numel():
        raise ValueError(
            f'output has {predictions.numel()} predictions but labels has '
            f'{labels_flat.numel()} entries'
        )

    return (predictions == labels_flat).sum().item() / labels_flat.numel()

# accuracy(output, labels): predictions go first and gold labels (third tensor of
# the dataset) second, so the argmax path would apply to model scores if predict
# ever returned them. as_tensor avoids a copy/warning if predictions are already a tensor.
test_accuracy = accuracy(torch.as_tensor(test_predictions), test_dataset.tensors[2])
print(f'Test accuracy: {test_accuracy:.4f}')

#### Inference: attach predicted taxonomy labels to the decomposed claims

In [None]:
# (Re)load the best checkpoint so this section does not depend on the evaluation
# cell above, then predict class indices for the val and test splits. The next
# cell pairs these predictions with the decomposed claims by position.
training.load_best_model()

val_predictions = training.predict(val_dataset)
test_predictions = training.predict(test_dataset)

In [None]:
decomposition = "flant5"
train_claims = read_data(f'{decomposition}/train_decomposed_{decomposition}.json')
val_claims = read_data(f'{decomposition}/val_decomposed_{decomposition}.json')
test_claims = read_data(f'{decomposition}/test_decomposed_{decomposition}.json')


def attach_predictions(claims, predictions, split: str) -> float:
    """Store predicted taxonomy labels on ``claims`` and print/return their accuracy.

    ``predictions`` are encoded class indices from ``training.predict`` on the raw
    dataset of the same split. They are paired with ``claims`` by position, so the
    two must have the same length and order. Only the length can be checked here.

    Args:
        claims: Decomposed claim dicts; each gets a ``'predicted_taxonomy_label'`` key.
        predictions: Sequence of class indices, one per claim.
        split: Display name used in the printed message (e.g. ``'Val'``).

    Returns:
        Accuracy against the whitespace-stripped gold ``'taxonomy_label'``, or NaN
        for an empty split.

    Raises:
        ValueError: If the number of predictions differs from the number of claims
            (``zip`` would otherwise silently truncate).
    """
    if len(claims) != len(predictions):
        raise ValueError(
            f'{split}: {len(claims)} decomposed claims but {len(predictions)} '
            'predictions; they must be aligned one-to-one'
        )
    if len(claims) == 0:
        print(f'{split} accuracy: n/a (no claims)')
        return float('nan')

    # One batched call instead of one inverse_transform per claim.
    predicted_labels = label_endocer.inverse_transform(predictions)
    correct = 0
    for claim, predicted_label in zip(claims, predicted_labels):
        claim['predicted_taxonomy_label'] = str(predicted_label)
        # Gold labels are stripped, as in build_dataset, so stray whitespace does
        # not count a correct prediction as wrong.
        correct += claim['predicted_taxonomy_label'] == claim['taxonomy_label'].strip()

    split_accuracy = correct / len(claims)
    print(f'{split} accuracy: {split_accuracy:.4f}')
    return split_accuracy


# Training claims are not predicted: their gold label doubles as the routing label.
# It is stripped so it lives in the same label space as the val/test predictions.
for claim in train_claims:
    claim['predicted_taxonomy_label'] = claim['taxonomy_label'].strip()

attach_predictions(val_claims, val_predictions, 'Val')
attach_predictions(test_claims, test_predictions, 'Test')

save_data(f'custom_decomposition/NEW_train_decomposed_{decomposition}.json', train_claims)
save_data(f'custom_decomposition/NEW_val_decomposed_{decomposition}.json', val_claims)
save_data(f'custom_decomposition/NEW_test_decomposed_{decomposition}.json', test_claims)