In [None]:
from transformers import AutoTokenizer, CanineForMultipleChoice
import torch
import random
import numpy as np
from google.colab import drive

#drive.mount('/content/drive')

# Seed every RNG in play before loading: the multiple-choice classifier head of
# CanineForMultipleChoice is newly (randomly) initialized on load, so without a
# seed each run starts from different weights and results are not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
model = CanineForMultipleChoice.from_pretrained("google/canine-s")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Some weights of CanineForMultipleChoice were not initialized from the model checkpoint at google/canine-s and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def val():
    """Evaluate the multiple-choice model on the validation split and print metrics.

    Iterates over the module-level ``SP_val`` and ``WP_val`` arrays, skipping
    reconstructed questions (ids containing ``'_'``) and malformed questions whose
    answer is not one of the choices. Prints the mean cross-entropy loss plus
    accuracy and macro-averaged precision / recall / F1 for WP, SP and both combined.

    Uses the globals ``model``, ``tokenizer`` and ``device``. Leaves the model in
    eval mode; a caller that keeps training must call ``model.train()`` again.

    Returns:
        dict: the printed metrics, keyed by ``"WP"``, ``"SP"`` and ``"total"``; a
        split with no evaluated questions maps to ``None``.
    """
    model.eval()
    WP_true, WP_pred = [], []
    SP_true, SP_pred = [], []
    total_loss_val = 0.0
    question_count = 0

    def _scores(true, pred):
        # Macro averages over answer positions; zero_division=0 avoids warnings
        # for positions that are never predicted.
        if not true:
            return None
        return {
            'accuracy': accuracy_score(true, pred),
            'precision': precision_score(true, pred, average='macro', zero_division=0),
            'recall': recall_score(true, pred, average='macro', zero_division=0),
            'f1': f1_score(true, pred, average='macro', zero_division=0),
        }

    with torch.no_grad():
        for data in np.concatenate((SP_val, WP_val)):
            qid = data['id']
            if '_' in qid:  # exclude reconstructed questions during eval
                continue

            choices = list(data['choice_list'])
            # sanitize dataset: skip questions whose answer is not among the choices.
            # Checked before the forward pass so no compute is spent on them.
            if data['answer'] not in choices:
                continue
            correct_index = choices.index(data['answer'])
            question = data['question']

            # Target is the position of the correct choice (batch size 1), so the loss is meaningful.
            labels = torch.tensor([correct_index], device=device)
            # One (question, choice) pair per candidate answer.
            encoding = tokenizer([question] * len(choices), choices, return_tensors="pt", padding=True)
            # unsqueeze(0) adds the batch dimension (batch size is 1).
            outputs = model(**{k: v.to(device).unsqueeze(0) for k, v in encoding.items()}, labels=labels)

            total_loss_val += outputs.loss.item()
            question_count += 1
            # argmax of the raw logits equals argmax of softmax(logits): softmax is monotonic.
            predicted_answer = outputs.logits.argmax(dim=-1).item()

            if "WP" in qid:
                WP_true.append(correct_index)
                WP_pred.append(predicted_answer)
            if "SP" in qid:
                SP_true.append(correct_index)
                SP_pred.append(predicted_answer)

    metrics = {
        "WP": _scores(WP_true, WP_pred),
        "SP": _scores(SP_true, SP_pred),
        "total": _scores(WP_true + SP_true, WP_pred + SP_pred),
    }
    avg_loss = total_loss_val / question_count if question_count else float('nan')
    print("----------------------Validation----------------------")
    print(f"Validation Loss: {avg_loss}")
    print(metrics)
    return metrics

In [None]:
def test():
    """Evaluate the multiple-choice model on the held-out test split and print metrics.

    Iterates over the module-level ``SP_test`` and ``WP_test`` arrays, skipping
    reconstructed questions (ids containing ``'_'``) and malformed questions whose
    answer is not one of the choices. Prints the mean cross-entropy loss plus
    accuracy and macro-averaged precision / recall / F1 for WP, SP and both combined.

    Uses the globals ``model``, ``tokenizer`` and ``device``. Leaves the model in
    eval mode.

    Returns:
        dict: the printed metrics, keyed by ``"WP"``, ``"SP"`` and ``"total"``; a
        split with no evaluated questions maps to ``None``.
    """
    model.eval()
    WP_true, WP_pred = [], []
    SP_true, SP_pred = [], []
    total_loss_test = 0.0
    question_count = 0

    def _scores(true, pred):
        # Macro averages over answer positions; zero_division=0 avoids warnings
        # for positions that are never predicted.
        if not true:
            return None
        return {
            'accuracy': accuracy_score(true, pred),
            'precision': precision_score(true, pred, average='macro', zero_division=0),
            'recall': recall_score(true, pred, average='macro', zero_division=0),
            'f1': f1_score(true, pred, average='macro', zero_division=0),
        }

    # Evaluation only: don't build the autograd graph (saves memory and time).
    with torch.no_grad():
        for data in np.concatenate((SP_test, WP_test)):
            qid = data['id']
            if '_' in qid:  # exclude reconstructed questions during eval
                continue

            choices = list(data['choice_list'])
            # sanitize dataset: skip questions whose answer is not among the choices.
            # Checked before the forward pass so no compute is spent on them.
            if data['answer'] not in choices:
                continue
            correct_index = choices.index(data['answer'])
            question = data['question']

            # Target is the position of the correct choice (batch size 1), so the loss is meaningful.
            labels = torch.tensor([correct_index], device=device)
            # One (question, choice) pair per candidate answer.
            encoding = tokenizer([question] * len(choices), choices, return_tensors="pt", padding=True)
            # unsqueeze(0) adds the batch dimension (batch size is 1).
            outputs = model(**{k: v.to(device).unsqueeze(0) for k, v in encoding.items()}, labels=labels)

            total_loss_test += outputs.loss.item()
            question_count += 1
            # argmax of the raw logits equals argmax of softmax(logits): softmax is monotonic.
            predicted_answer = outputs.logits.argmax(dim=-1).item()

            if "WP" in qid:
                WP_true.append(correct_index)
                WP_pred.append(predicted_answer)
            if "SP" in qid:
                SP_true.append(correct_index)
                SP_pred.append(predicted_answer)

    metrics = {
        "WP": _scores(WP_true, WP_pred),
        "SP": _scores(SP_true, SP_pred),
        "total": _scores(WP_true + SP_true, WP_pred + SP_pred),
    }
    avg_loss = total_loss_test / question_count if question_count else float('nan')
    print("----------------------Test----------------------")
    print(f"Test Loss: {avg_loss}")
    print(metrics)
    return metrics

In [None]:
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
model.to(device)


def _split_train_val_test(samples):
    """Split ``samples`` into (train, val, test) using the fixed random_state 42.

    The first split holds out 20% as test. Then 25% of the remaining 80% (20% of
    the total) becomes validation, leaving a 60/20/20 split.
    """
    train_and_val, held_out = train_test_split(samples, test_size = 0.2, random_state=42)
    train_part, val_part = train_test_split(train_and_val, test_size = 0.25, random_state=42)
    return train_part, val_part, held_out


# Load the data
# NOTE: allow_pickle=True unpickles arbitrary Python objects; only load trusted .npy files.
#SP_all = np.load('/content/drive/My Drive/C5470prj/BrainTeaser/data/SP-train.npy', allow_pickle=True)
#WP_all = np.load('/content/drive/My Drive/C5470prj/BrainTeaser/data/WP-train.npy', allow_pickle=True)
SP_all = np.load('SP-train.npy', allow_pickle=True)
WP_all = np.load('WP-train.npy', allow_pickle=True)

SP_train, SP_val, SP_test = _split_train_val_test(SP_all)
WP_train, WP_val, WP_test = _split_train_val_test(WP_all)

# train
num_epochs = 10

optimizer = optim.Adam(model.parameters(), lr=8e-5)
#optimizer = optim.RMSprop(model.parameters(), lr=0.01)
# Multiply the learning rate by 0.1 every 4 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

for epoch in range(num_epochs):
    # Train on the TRAIN split only; the test split must stay unseen until test().
    # sklearn.utils.shuffle returns a shuffled copy (it does not shuffle in place),
    # so its result must be kept. Shuffling the concatenation also interleaves SP and WP.
    train_data = shuffle(np.concatenate((SP_train, WP_train)), random_state=epoch)
    total_loss = 0.0
    bad_qs_train = 0
    correct_answers_train = 0
    model.train()  # val() switches the model to eval mode, so re-enable training mode each epoch
    for data in train_data:
        choices = list(data['choice_list'])
        # Malformed question: the answer is not among the choices, so there is no target to learn.
        if data['answer'] not in choices:
            bad_qs_train += 1
            continue
        correct_index = choices.index(data['answer'])
        question = data['question']

        optimizer.zero_grad()
        # Target is the position of the correct choice (batch size 1).
        labels = torch.tensor([correct_index], device=device)
        # One (question, choice) pair per candidate answer.
        encoding = tokenizer([question] * len(choices), choices, return_tensors="pt", padding=True)
        # unsqueeze(0) adds the batch dimension (batch size is 1).
        outputs = model(**{k: v.to(device).unsqueeze(0) for k, v in encoding.items()}, labels=labels)

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # argmax of the raw logits equals argmax of softmax(logits): softmax is monotonic.
        predicted_answer = outputs.logits.argmax(dim=-1).item()
        if predicted_answer == correct_index:
            correct_answers_train += 1
    scheduler.step()

    num_qs_train = len(train_data) - bad_qs_train
    denom = max(num_qs_train, 1)  # guard against an empty / all-malformed training split
    print("----------------------Train----------------------")
    print(f"Epoch: {epoch + 1}, Train Loss: {total_loss / denom}, Train Acc: {correct_answers_train / denom}")

    val()

In [None]:
# Run the held-out evaluation once, after training has finished.
test_metrics = test()

  predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())


----------------------Test----------------------
{'WP': {'accuracy': 0.4, 'precision': 0.38888888888888884, 'recall': 0.38619528619528615, 'f1': 0.3747252747252747}, 'SP': {'accuracy': 0.23076923076923078, 'precision': 0.15865384615384615, 'recall': 0.19444444444444442, 'f1': 0.16928571428571426}, 'total': {'accuracy': 0.30434782608695654, 'precision': 0.22838509316770186, 'recall': 0.24287878787878786, 'f1': 0.23472449536279322}}


## Results

SP acc: 0.26666666666666666 correct: 44  count: 165
WP acc: 0.25 correct: 32  count: 128
Total Accuracy: 0.26

10 epochs, 5e5
SP acc: 0.25 correct: 5  count: 20
WP acc: 0.3333333333333333 correct: 4  count: 12
Total Accuracy: 0.28

5 epochs, 8e6, scheduler
SP acc: 0.3076923076923077 correct: 12  count: 39
WP acc: 0.26666666666666666 correct: 8  count: 30
Total Accuracy: 0.29

RMSprop, lr 0.01
SP acc: 0.28205128205128205 correct: 11  count: 39
WP acc: 0.36666666666666664 correct: 11  count: 30
Total Accuracy: 0.32

----------------------Validation----------------------
Validation Loss: 1.3862944841384888
SP acc: 0.3125 precision: 0.3125  recall: 1.0 f1: 0.47619047619047616
WP acc: 0.24 precision: 0.24  recall: 1.0 f1: 0.3870967741935484
Total Accuracy: 0.28
Total Precision: 0.28
Total Recall: 1.00
Total F1: 0.44


----------------------Test----------------------
SP acc: 0.3076923076923077 precision: 0.3076923076923077  recall: 1.0 f1: 0.47058823529411764
WP acc: 0.4 precision: 0.4  recall: 1.0 f1: 0.5714285714285715
Total Accuracy: 0.35
Total Precision: 0.35
Total Recall: 1.00
Total F1: 0.52