<a href="https://colab.research.google.com/github/sycho2003/20252R0136COSE36203/blob/main/%EA%B8%B0%EA%B3%84%ED%95%99%EC%8A%B5_Term_Project_%EC%99%9C%EA%B3%A1%EB%B3%84_%EC%9D%B4%EC%A7%84_%EB%B6%84%EB%A5%98.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. All-or-nothing Thinking

## 1. 모델 정의

### 1-1. 기본 준비

In [None]:
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this blanket filter also hides pandas and deprecation warnings
# raised by later cells — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

In [None]:
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the three preprocessed CSV files (in this order) into data1, data2, data3.
data1, data2, data3 = (
    pd.read_csv(csv_path)
    for csv_path in (
        'd01_preprocessed_revised.csv',
        'd02_preprocessed.csv',
        'd03_preprocessed.csv',
    )
)

### 1-2. 데이터 증강 및 구조화

In [None]:
# Replace every entry of data2['thought'] that is not a str (presumably NaN from
# empty CSV cells — TODO confirm) with an empty string.
# A boolean mask plus a single .loc assignment is used instead of
# data2['thought'][idx] = '': chained indexing may write to a temporary copy
# (and never updates data2 under pandas copy-on-write), and it mixed positional
# enumerate() indices with label-based assignment.
non_string_thought = data2['thought'].map(lambda value: type(value) != str).astype(bool)
data2.loc[non_string_thought, 'thought'] = ''

In [None]:
# Keep only the rows whose has_distortion flag equals 1, then renumber each
# frame from 0 so later label lookups by index stay aligned.
data1 = data1.loc[data1['has_distortion'].eq(1)].reset_index(drop=True)
data2 = data2.loc[data2['has_distortion'].eq(1)].reset_index(drop=True)
data3 = data3.loc[data3['has_distortion'].eq(1)].reset_index(drop=True)

In [None]:
# Build the model input text per row: "<situation> <thought>" for data1 and
# data3, and the situation column alone for data2.
data1_1 = data1['situation'].add(' ').add(data1['thought'])
data2_1 = data2['situation']
data3_1 = data3['situation'].add(' ').add(data3['thought'])

In [None]:
# Remove repeated input texts, keeping the first occurrence of each.
# Reassignment replaces inplace=True so the Series taken straight from a
# DataFrame column (data2_1 is data2['situation']) is never mutated in place.
# The original index labels are deliberately kept: they are used afterwards to
# look up the label of each surviving row.
data1_1 = data1_1.drop_duplicates()
data2_1 = data2_1.drop_duplicates()
data3_1 = data3_1.drop_duplicates()

In [None]:
def normalize_text(s):
    """Return a normalized copy of the string `s`.

    The steps run in this order: lowercase, strip ASCII punctuation, replace
    the standalone articles "a", "an" and "the" with a space, then collapse
    every run of whitespace into a single space (trimming both ends).
    """
    import string, re

    lowered = s.lower()

    # Drop every character listed in string.punctuation.
    punctuation = set(string.punctuation)
    without_punctuation = "".join(ch for ch in lowered if ch not in punctuation)

    # Articles are matched as whole words only, after punctuation is gone.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE)

    # split() with no argument discards leading, trailing and repeated whitespace.
    return " ".join(without_articles.split())


In [None]:
from transformers import BertTokenizer, BertModel, BertConfig

# Load the pretrained 'bert-base-uncased' tokenizer, its configuration, and the
# encoder itself; the encoder is moved to the selected device right away.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_config = BertConfig.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
# Tokenization

def tokenize_and_pad(data, tokenizer, max_len=512):
    """Normalize and tokenize every text in `data`.

    Each text goes through `normalize_text` and is then encoded with
    `tokenizer`, padded to and truncated at `max_len` tokens, with PyTorch
    tensors requested via return_tensors="pt".

    Returns a list holding one tokenizer output per input text, in input order.
    """
    return [
        tokenizer(
            normalize_text(text),
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_len,
        )
        for text in data
    ]

# Tokenize the deduplicated texts of each source with the shared tokenizer.
data1_1_encoded, data2_1_encoded, data3_1_encoded = (
    tokenize_and_pad(texts, tokenizer)
    for texts in (data1_1, data2_1, data3_1)
)

In [None]:
# Display the column names of data1 (text fields, has_distortion, and one
# column per distortion type) to pick the label column used below.
data1.columns

Index(['situation', 'thought', 'reframe', 'has_distortion',
       'all-or-nothing thinking', 'comparing and despairing',
       'disqualifying the positive', 'emotional reasoning', 'fortune telling',
       'labeling', 'magnification', 'mind reading', 'overgeneralizing',
       'should statements', 'mental filter', 'personalization and blaming'],
      dtype='object')

In [None]:
# Add labels: for every text that survived deduplication, look up its
# 'all-or-nothing thinking' value in the source frame via the text's index label.
data1_1_labels, data2_1_labels, data3_1_labels = (
    list(frame['all-or-nothing thinking'].loc[texts.index])
    for frame, texts in ((data1, data1_1), (data2, data2_1), (data3, data3_1))
)

In [None]:
# Merging Data: concatenate the three sources in the same order for both lists
# so that encoding i and label i always describe the same text.
data_encoded = [*data1_1_encoded, *data2_1_encoded, *data3_1_encoded]
data_labels = [*data1_1_labels, *data2_1_labels, *data3_1_labels]

In [None]:
class CustomDatasetWithLabels(Dataset):
    """Dataset pairing tokenizer outputs with their labels.

    `data` is a sequence of mappings with 'input_ids' and 'attention_mask'
    tensors (assumed to carry a leading batch dimension of size 1, as produced
    by a tokenizer called with return_tensors="pt" — TODO confirm);
    `labels` is a sequence of the same length.
    """

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        encoded = self.data[idx]
        # squeeze(0) removes only the leading batch dimension. A bare squeeze()
        # would also drop the sequence dimension whenever it has length 1,
        # producing a 0-d tensor that no longer batches with the others.
        return {"input_ids": encoded['input_ids'].squeeze(0),
                "attention_mask": encoded['attention_mask'].squeeze(0),
                "y": self.labels[idx]}

In [None]:
dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% train, 10% validation, and the
# remainder (including any rounding leftovers) as test.
n_total = len(dataset_with_labels)
train_size = int(0.8 * n_total)
val_size = int(0.1 * n_total)
test_size = n_total - train_size - val_size

# Split the dataset. A seeded generator fixes which examples land in each
# split, so validation/test metrics are comparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create DataLoaders for each set (only the training batches are shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding per distinct label value, with the same
# width as the BERT hidden state; drawn from the seeded generator as well.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_generator)

### 1-3. 모델 평가 함수 정의

In [None]:
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, dataloader, device="cpu"):
    """Score `model` on every batch of `dataloader`.

    Each batch is encoded with the module-level `bert_model`; the hidden state
    at position 0 (the [CLS] token) is fed to `model`, and the arg-max over
    its logits is taken as the predicted label.

    Args:
        model: classifier applied to the [CLS] embeddings; switched to eval mode.
        dataloader: yields dicts with "input_ids", "attention_mask" and "y".
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        dict with "accuracy" and macro-averaged "f1_macro" over all batches.
    """
    model.eval()
    predictions = []
    targets = []

    # No gradients are needed while scoring.
    with torch.no_grad():
        for batch in dataloader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            y = batch["y"].to(device)

            # [CLS] embedding = first position of the last hidden state
            encoded = bert_model(input_ids=ids, attention_mask=mask)
            cls_vectors = encoded.last_hidden_state[:, 0, :]

            predicted = model(cls_vectors).argmax(dim=1)
            predictions += predicted.cpu().tolist()
            targets += y.cpu().tolist()

    return {
        "accuracy": accuracy_score(targets, predictions),
        "f1_macro": f1_score(targets, predictions, average="macro", zero_division=0),
    }

In [None]:
# Classifier scoring each input by its inner product with per-label embeddings
class InnerProductClassifier(nn.Module):
    """Linear projection followed by a dot product with label embeddings.

    Args:
        input_dim: width of the incoming feature vectors.
        label_embeddings: 2-D tensor with one row per label; its second
            dimension is the size of the projection space.
        trainable_label_emb: when True the (cloned) embeddings are a learnable
            parameter, otherwise they are registered as a fixed buffer.
    """

    def __init__(self, input_dim, label_embeddings, trainable_label_emb=True):
        super().__init__()
        # Maps input features into the label-embedding space
        self.proj = nn.Linear(input_dim, label_embeddings.shape[1])

        # Always work on a copy so the caller's tensor is never modified
        label_table = label_embeddings.clone()
        if not trainable_label_emb:
            # Fixed embeddings: follow the module across devices, never updated
            self.register_buffer("label_emb", label_table)
        else:
            # Learnable embeddings: updated together with the projection
            self.label_emb = nn.Parameter(label_table)

    def forward(self, x):
        """Return one logit per label: the projected input dotted with each embedding."""
        projected = self.proj(x)
        return projected @ self.label_emb.T

### 1-4. 모델 생성

In [None]:
# Build the classifier on top of the label embeddings and place it on the device
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Cross-entropy over the per-label logits
criterion = nn.CrossEntropyLoss()

# AdamW updates only the classifier's own parameters
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

## 2. 모델 학습

In [None]:
EPOCHS = 10

# Move the BERT model to the device. It only supplies features here (its
# forward pass runs under no_grad and the optimizer holds only the classifier's
# parameters), so keep it in eval mode to guarantee dropout stays disabled.
bert_model.to(device)
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar itself: a print() per batch
        # interleaves with the bar and floods the cell output.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; max(..., 1) guards against an empty loader
    print(f"Epoch {epoch+1} Mean Train Loss: {running_loss / max(len(train_dataloader), 1)}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:02<01:00,  2.41s/it]

Loss: 1.4566470384597778


Epoch 1:   8%|▊         | 2/26 [00:04<00:46,  1.94s/it]

Loss: 18.050392150878906


Epoch 1:  12%|█▏        | 3/26 [00:05<00:41,  1.81s/it]

Loss: 0.5756654739379883


Epoch 1:  15%|█▌        | 4/26 [00:07<00:38,  1.74s/it]

Loss: 1.2390334606170654


Epoch 1:  19%|█▉        | 5/26 [00:08<00:35,  1.70s/it]

Loss: 5.171160697937012


Epoch 1:  23%|██▎       | 6/26 [00:10<00:33,  1.68s/it]

Loss: 4.379364490509033


Epoch 1:  27%|██▋       | 7/26 [00:12<00:31,  1.67s/it]

Loss: 3.4084506034851074


Epoch 1:  31%|███       | 8/26 [00:13<00:29,  1.66s/it]

Loss: 3.2323710918426514


Epoch 1:  35%|███▍      | 9/26 [00:15<00:28,  1.65s/it]

Loss: 8.642477035522461


Epoch 1:  38%|███▊      | 10/26 [00:17<00:26,  1.65s/it]

Loss: 6.635566711425781


Epoch 1:  42%|████▏     | 11/26 [00:18<00:24,  1.65s/it]

Loss: 11.360759735107422


Epoch 1:  46%|████▌     | 12/26 [00:20<00:23,  1.66s/it]

Loss: 3.51615834236145


Epoch 1:  50%|█████     | 13/26 [00:22<00:21,  1.66s/it]

Loss: 5.15452766418457


Epoch 1:  54%|█████▍    | 14/26 [00:23<00:19,  1.66s/it]

Loss: 2.8383162021636963


Epoch 1:  58%|█████▊    | 15/26 [00:25<00:18,  1.66s/it]

Loss: 4.4014573097229


Epoch 1:  62%|██████▏   | 16/26 [00:27<00:16,  1.66s/it]

Loss: 3.294191360473633


Epoch 1:  65%|██████▌   | 17/26 [00:28<00:14,  1.67s/it]

Loss: 1.2320595979690552


Epoch 1:  69%|██████▉   | 18/26 [00:30<00:13,  1.67s/it]

Loss: 2.0341086387634277


Epoch 1:  73%|███████▎  | 19/26 [00:32<00:11,  1.67s/it]

Loss: 2.1707301139831543


Epoch 1:  77%|███████▋  | 20/26 [00:33<00:10,  1.68s/it]

Loss: 2.995382308959961


Epoch 1:  81%|████████  | 21/26 [00:35<00:08,  1.68s/it]

Loss: 2.2258424758911133


Epoch 1:  85%|████████▍ | 22/26 [00:37<00:06,  1.68s/it]

Loss: 1.8356106281280518


Epoch 1:  88%|████████▊ | 23/26 [00:38<00:05,  1.68s/it]

Loss: 1.015738606452942


Epoch 1:  92%|█████████▏| 24/26 [00:40<00:03,  1.68s/it]

Loss: 0.7638213634490967


Epoch 1:  96%|█████████▌| 25/26 [00:42<00:01,  1.68s/it]

Loss: 2.02744460105896


Epoch 1: 100%|██████████| 26/26 [00:43<00:00,  1.65s/it]

Loss: 3.162057876586914





Epoch 1 Validation Accuracy: 0.916256157635468, F1-macro: 0.4781491002570694


Epoch 2:   4%|▍         | 1/26 [00:01<00:42,  1.70s/it]

Loss: 2.117368459701538


Epoch 2:   8%|▊         | 2/26 [00:03<00:40,  1.70s/it]

Loss: 1.081818699836731


Epoch 2:  12%|█▏        | 3/26 [00:05<00:39,  1.70s/it]

Loss: 1.7912166118621826


Epoch 2:  15%|█▌        | 4/26 [00:06<00:37,  1.70s/it]

Loss: 0.5940866470336914


Epoch 2:  19%|█▉        | 5/26 [00:08<00:35,  1.71s/it]

Loss: 0.7043902277946472


Epoch 2:  23%|██▎       | 6/26 [00:10<00:34,  1.71s/it]

Loss: 1.5568548440933228


Epoch 2:  27%|██▋       | 7/26 [00:11<00:32,  1.71s/it]

Loss: 1.0843801498413086


Epoch 2:  31%|███       | 8/26 [00:13<00:30,  1.71s/it]

Loss: 1.1250886917114258


Epoch 2:  35%|███▍      | 9/26 [00:15<00:28,  1.71s/it]

Loss: 0.06791543960571289


Epoch 2:  38%|███▊      | 10/26 [00:17<00:27,  1.70s/it]

Loss: 2.091064691543579


Epoch 2:  42%|████▏     | 11/26 [00:18<00:25,  1.70s/it]

Loss: 1.2587509155273438


Epoch 2:  46%|████▌     | 12/26 [00:20<00:23,  1.70s/it]

Loss: 0.9479106664657593


Epoch 2:  50%|█████     | 13/26 [00:22<00:22,  1.70s/it]

Loss: 1.9226386547088623


Epoch 2:  54%|█████▍    | 14/26 [00:23<00:20,  1.70s/it]

Loss: 0.5986491441726685


Epoch 2:  58%|█████▊    | 15/26 [00:25<00:18,  1.71s/it]

Loss: 0.942541241645813


Epoch 2:  62%|██████▏   | 16/26 [00:27<00:17,  1.71s/it]

Loss: 0.9678117036819458


Epoch 2:  65%|██████▌   | 17/26 [00:29<00:15,  1.72s/it]

Loss: 1.2922148704528809


Epoch 2:  69%|██████▉   | 18/26 [00:30<00:13,  1.72s/it]

Loss: 0.8507363200187683


Epoch 2:  73%|███████▎  | 19/26 [00:32<00:12,  1.72s/it]

Loss: 0.9680067896842957


Epoch 2:  77%|███████▋  | 20/26 [00:34<00:10,  1.72s/it]

Loss: 0.49939167499542236


Epoch 2:  81%|████████  | 21/26 [00:35<00:08,  1.73s/it]

Loss: 0.9658344388008118


Epoch 2:  85%|████████▍ | 22/26 [00:37<00:06,  1.73s/it]

Loss: 0.9325984716415405


Epoch 2:  88%|████████▊ | 23/26 [00:39<00:05,  1.73s/it]

Loss: 1.3955531120300293


Epoch 2:  92%|█████████▏| 24/26 [00:41<00:03,  1.74s/it]

Loss: 0.876897931098938


Epoch 2:  96%|█████████▌| 25/26 [00:42<00:01,  1.74s/it]

Loss: 0.4714553654193878


Epoch 2: 100%|██████████| 26/26 [00:43<00:00,  1.68s/it]

Loss: 1.1144391298294067





Epoch 2 Validation Accuracy: 0.729064039408867, F1-macro: 0.49627791563275436


Epoch 3:   4%|▍         | 1/26 [00:01<00:43,  1.75s/it]

Loss: 0.5708145499229431


Epoch 3:   8%|▊         | 2/26 [00:03<00:42,  1.75s/it]

Loss: 0.44430994987487793


Epoch 3:  12%|█▏        | 3/26 [00:05<00:40,  1.76s/it]

Loss: 0.7175259590148926


Epoch 3:  15%|█▌        | 4/26 [00:07<00:38,  1.76s/it]

Loss: 0.7165786623954773


Epoch 3:  19%|█▉        | 5/26 [00:08<00:37,  1.76s/it]

Loss: 0.6616666913032532


Epoch 3:  23%|██▎       | 6/26 [00:10<00:35,  1.76s/it]

Loss: 0.6908207535743713


Epoch 3:  27%|██▋       | 7/26 [00:12<00:33,  1.76s/it]

Loss: 0.4260439872741699


Epoch 3:  31%|███       | 8/26 [00:14<00:31,  1.77s/it]

Loss: 0.9975721836090088


Epoch 3:  35%|███▍      | 9/26 [00:15<00:30,  1.77s/it]

Loss: 0.7019128799438477


Epoch 3:  38%|███▊      | 10/26 [00:17<00:28,  1.77s/it]

Loss: 0.8821214437484741


Epoch 3:  42%|████▏     | 11/26 [00:19<00:26,  1.77s/it]

Loss: 0.9956616759300232


Epoch 3:  46%|████▌     | 12/26 [00:21<00:24,  1.77s/it]

Loss: 0.5907788276672363


Epoch 3:  50%|█████     | 13/26 [00:22<00:23,  1.78s/it]

Loss: 0.6979962587356567


Epoch 3:  54%|█████▍    | 14/26 [00:24<00:21,  1.78s/it]

Loss: 0.6076666116714478


Epoch 3:  58%|█████▊    | 15/26 [00:26<00:19,  1.78s/it]

Loss: 0.8571183681488037


Epoch 3:  62%|██████▏   | 16/26 [00:28<00:17,  1.78s/it]

Loss: 0.729981541633606


Epoch 3:  65%|██████▌   | 17/26 [00:30<00:16,  1.79s/it]

Loss: 0.15611177682876587


Epoch 3:  69%|██████▉   | 18/26 [00:31<00:14,  1.79s/it]

Loss: 0.3649470806121826


Epoch 3:  73%|███████▎  | 19/26 [00:33<00:12,  1.79s/it]

Loss: 1.0063241720199585


Epoch 3:  77%|███████▋  | 20/26 [00:35<00:10,  1.79s/it]

Loss: 0.7471842169761658


Epoch 3:  81%|████████  | 21/26 [00:37<00:08,  1.80s/it]

Loss: 0.6766709089279175


Epoch 3:  85%|████████▍ | 22/26 [00:39<00:07,  1.80s/it]

Loss: 0.5673350691795349


Epoch 3:  88%|████████▊ | 23/26 [00:40<00:05,  1.80s/it]

Loss: 0.5439006090164185


Epoch 3:  92%|█████████▏| 24/26 [00:42<00:03,  1.81s/it]

Loss: 0.26104727387428284


Epoch 3:  96%|█████████▌| 25/26 [00:44<00:01,  1.81s/it]

Loss: 0.5476090908050537


Epoch 3: 100%|██████████| 26/26 [00:45<00:00,  1.75s/it]

Loss: 1.2635418176651





Epoch 3 Validation Accuracy: 0.8275862068965517, F1-macro: 0.544813889422769


Epoch 4:   4%|▍         | 1/26 [00:01<00:45,  1.82s/it]

Loss: 0.4374149441719055


Epoch 4:   8%|▊         | 2/26 [00:03<00:43,  1.81s/it]

Loss: 0.8650455474853516


Epoch 4:  12%|█▏        | 3/26 [00:05<00:41,  1.80s/it]

Loss: 0.6366218328475952


Epoch 4:  15%|█▌        | 4/26 [00:07<00:39,  1.80s/it]

Loss: 0.7003613710403442


Epoch 4:  19%|█▉        | 5/26 [00:09<00:37,  1.80s/it]

Loss: 0.9340442419052124


Epoch 4:  23%|██▎       | 6/26 [00:10<00:36,  1.80s/it]

Loss: 0.5092313289642334


Epoch 4:  27%|██▋       | 7/26 [00:12<00:34,  1.81s/it]

Loss: 0.9278044700622559


Epoch 4:  31%|███       | 8/26 [00:14<00:32,  1.81s/it]

Loss: 0.4796421229839325


Epoch 4:  35%|███▍      | 9/26 [00:16<00:30,  1.81s/it]

Loss: 0.7456859946250916


Epoch 4:  38%|███▊      | 10/26 [00:18<00:28,  1.81s/it]

Loss: 0.14380186796188354


Epoch 4:  42%|████▏     | 11/26 [00:19<00:27,  1.82s/it]

Loss: 0.7190815806388855


Epoch 4:  46%|████▌     | 12/26 [00:21<00:25,  1.82s/it]

Loss: 0.37774500250816345


Epoch 4:  50%|█████     | 13/26 [00:23<00:23,  1.82s/it]

Loss: 0.5012857913970947


Epoch 4:  54%|█████▍    | 14/26 [00:25<00:21,  1.82s/it]

Loss: 0.16564899682998657


Epoch 4:  58%|█████▊    | 15/26 [00:27<00:20,  1.82s/it]

Loss: 0.352311909198761


Epoch 4:  62%|██████▏   | 16/26 [00:29<00:18,  1.83s/it]

Loss: 0.7378534078598022


Epoch 4:  65%|██████▌   | 17/26 [00:30<00:16,  1.83s/it]

Loss: 0.5085819363594055


Epoch 4:  69%|██████▉   | 18/26 [00:32<00:14,  1.83s/it]

Loss: 0.5680617094039917


Epoch 4:  73%|███████▎  | 19/26 [00:34<00:12,  1.84s/it]

Loss: 1.0692944526672363


Epoch 4:  77%|███████▋  | 20/26 [00:36<00:11,  1.84s/it]

Loss: 0.35725316405296326


Epoch 4:  81%|████████  | 21/26 [00:38<00:09,  1.84s/it]

Loss: 0.8200703859329224


Epoch 4:  85%|████████▍ | 22/26 [00:40<00:07,  1.84s/it]

Loss: 0.7183581590652466


Epoch 4:  88%|████████▊ | 23/26 [00:41<00:05,  1.84s/it]

Loss: 1.0636662244796753


Epoch 4:  92%|█████████▏| 24/26 [00:43<00:03,  1.85s/it]

Loss: 0.5678508281707764


Epoch 4:  96%|█████████▌| 25/26 [00:45<00:01,  1.85s/it]

Loss: 0.1375110149383545


Epoch 4: 100%|██████████| 26/26 [00:46<00:00,  1.79s/it]

Loss: 0.48492711782455444





Epoch 4 Validation Accuracy: 0.8522167487684729, F1-macro: 0.5645022883295194


Epoch 5:   4%|▍         | 1/26 [00:01<00:46,  1.86s/it]

Loss: 0.34405791759490967


Epoch 5:   8%|▊         | 2/26 [00:03<00:44,  1.86s/it]

Loss: 0.7378703951835632


Epoch 5:  12%|█▏        | 3/26 [00:05<00:42,  1.86s/it]

Loss: 0.5406850576400757


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.86s/it]

Loss: 0.3559906482696533


Epoch 5:  19%|█▉        | 5/26 [00:09<00:39,  1.87s/it]

Loss: 0.6351155638694763


Epoch 5:  23%|██▎       | 6/26 [00:11<00:37,  1.87s/it]

Loss: 0.6637234687805176


Epoch 5:  27%|██▋       | 7/26 [00:13<00:35,  1.87s/it]

Loss: 0.5028127431869507


Epoch 5:  31%|███       | 8/26 [00:14<00:33,  1.87s/it]

Loss: 0.6274831891059875


Epoch 5:  35%|███▍      | 9/26 [00:16<00:31,  1.87s/it]

Loss: 0.2643142342567444


Epoch 5:  38%|███▊      | 10/26 [00:18<00:29,  1.87s/it]

Loss: 0.7349002361297607


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.87s/it]

Loss: 0.46140819787979126


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.87s/it]

Loss: 0.6611006855964661


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.88s/it]

Loss: 0.7681806087493896


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.88s/it]

Loss: 0.43547558784484863


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.88s/it]

Loss: 0.7748555541038513


Epoch 5:  62%|██████▏   | 16/26 [00:29<00:18,  1.88s/it]

Loss: 0.42661917209625244


Epoch 5:  65%|██████▌   | 17/26 [00:31<00:16,  1.88s/it]

Loss: 0.5960219502449036


Epoch 5:  69%|██████▉   | 18/26 [00:33<00:15,  1.88s/it]

Loss: 0.22553083300590515


Epoch 5:  73%|███████▎  | 19/26 [00:35<00:13,  1.88s/it]

Loss: 0.6395750045776367


Epoch 5:  77%|███████▋  | 20/26 [00:37<00:11,  1.89s/it]

Loss: 0.8974602818489075


Epoch 5:  81%|████████  | 21/26 [00:39<00:09,  1.89s/it]

Loss: 0.2851725220680237


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.89s/it]

Loss: 0.4740316867828369


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.89s/it]

Loss: 0.27961331605911255


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.89s/it]

Loss: 0.6432197690010071


Epoch 5:  96%|█████████▌| 25/26 [00:46<00:01,  1.89s/it]

Loss: 0.5701252222061157


Epoch 5: 100%|██████████| 26/26 [00:47<00:00,  1.84s/it]

Loss: 0.4351448714733124





Epoch 5 Validation Accuracy: 0.8817733990147784, F1-macro: 0.5068825910931174


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.3551412522792816


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.4441754221916199


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.201445534825325


Epoch 6:  15%|█▌        | 4/26 [00:07<00:41,  1.90s/it]

Loss: 0.46208587288856506


Epoch 6:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.2815588116645813


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.90s/it]

Loss: 0.6007078289985657


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.90s/it]

Loss: 0.30220329761505127


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.90s/it]

Loss: 0.930885910987854


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.90s/it]

Loss: 0.38886240124702454


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.90s/it]

Loss: 0.35551610589027405


Epoch 6:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.1330773830413818


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3805089592933655


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8559824228286743


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5104929208755493


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.9030967950820923


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.48010456562042236


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5038645267486572


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3209690451622009


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6017467975616455


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.92s/it]

Loss: 0.3170323967933655


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.92s/it]

Loss: 0.4249286353588104


Epoch 6:  85%|████████▍ | 22/26 [00:41<00:07,  1.92s/it]

Loss: 0.557874321937561


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.92s/it]

Loss: 0.21578624844551086


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.92s/it]

Loss: 0.4072372317314148


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.92s/it]

Loss: 0.3063594698905945


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.7481141686439514





Epoch 6 Validation Accuracy: 0.8226600985221675, F1-macro: 0.558695652173913


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.43795764446258545


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.29084932804107666


Epoch 7:  12%|█▏        | 3/26 [00:05<00:44,  1.91s/it]

Loss: 0.657396137714386


Epoch 7:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.20893456041812897


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.31033384799957275


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.6897841691970825


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.42968010902404785


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3707422912120819


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.8122372627258301


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.20137013494968414


Epoch 7:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.13410070538520813


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5096104145050049


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.4981188476085663


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.47817763686180115


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:20,  1.90s/it]

Loss: 0.26383811235427856


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.39109688997268677


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.90s/it]

Loss: 0.32746419310569763


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.90s/it]

Loss: 0.22442637383937836


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.90s/it]

Loss: 0.46799957752227783


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.90s/it]

Loss: 0.3882253170013428


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.90s/it]

Loss: 0.70601487159729


Epoch 7:  85%|████████▍ | 22/26 [00:41<00:07,  1.90s/it]

Loss: 0.130735844373703


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.90s/it]

Loss: 0.3352527618408203


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.90s/it]

Loss: 0.9743393659591675


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.90s/it]

Loss: 0.33469751477241516


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5312036871910095





Epoch 7 Validation Accuracy: 0.4729064039408867, F1-macro: 0.406286729533962


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.120894432067871


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.25376856327056885


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.4108927249908447


Epoch 8:  15%|█▌        | 4/26 [00:07<00:41,  1.90s/it]

Loss: 1.2272050380706787


Epoch 8:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.6458331346511841


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.90s/it]

Loss: 0.29389405250549316


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.90s/it]

Loss: 0.516508936882019


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.90s/it]

Loss: 0.7138612270355225


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.90s/it]

Loss: 0.5036657452583313


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.90s/it]

Loss: 0.4133458733558655


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.90s/it]

Loss: 0.4732494652271271


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.90s/it]

Loss: 0.7018386125564575


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.90s/it]

Loss: 0.40534794330596924


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.90s/it]

Loss: 0.7043266296386719


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.90s/it]

Loss: 0.2592070400714874


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.90s/it]

Loss: 0.7203128337860107


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.45185256004333496


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.21177202463150024


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.21142104268074036


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5588212013244629


Epoch 8:  81%|████████  | 21/26 [00:39<00:09,  1.91s/it]

Loss: 0.6006197333335876


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.36599671840667725


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3316906690597534


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.28812268376350403


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5380284786224365


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.86s/it]

Loss: 0.3749788701534271





Epoch 8 Validation Accuracy: 0.7881773399014779, F1-macro: 0.5334331070607728


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.5437136292457581


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.3010217547416687


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.4818519353866577


Epoch 9:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.09246662259101868


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.21429820358753204


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.1382281482219696


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.2703782320022583


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.37745973467826843


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.15281942486763


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.37182772159576416


Epoch 9:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.28077399730682373


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.4674490690231323


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.3513967990875244


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.22057290375232697


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.12082123756408691


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3232276737689972


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.3382676839828491


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.36049723625183105


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.2743668556213379


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.37113919854164124


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.4134613871574402


Epoch 9:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.4072524905204773


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.2404804527759552


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.29137927293777466


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.447298139333725


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5801486968994141





Epoch 9 Validation Accuracy: 0.9113300492610837, F1-macro: 0.6014397905759162


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.3233177065849304


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.32391637563705444


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.4229435920715332


Epoch 10:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.25755152106285095


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.273669958114624


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.1631828397512436


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.4245586097240448


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.2030135840177536


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.1422646939754486


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.28177034854888916


Epoch 10:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.30519765615463257


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3007891774177551


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.13957755267620087


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.431402325630188


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.01681768149137497


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.47128286957740784


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.3647668361663818


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5395845174789429


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.39217036962509155


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.28103089332580566


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.48092949390411377


Epoch 10:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.5730733871459961


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3315696716308594


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.22182932496070862


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5097233653068542


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.29963621497154236





Epoch 10 Validation Accuracy: 0.9064039408866995, F1-macro: 0.5229437229437229


## 3. 모델 테스트

In [None]:
# Score the trained classifier on the held-out test split and report both metrics
test_metrics = evaluate(model=model, dataloader=test_dataloader, device=device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8480392156862745, F1-macro: 0.5229437229437229


In [None]:
# The twelve distortion label names, in the order of the label columns of data1.
distortion_types = [
    'all-or-nothing thinking',
    'comparing and despairing',
    'disqualifying the positive',
    'emotional reasoning',
    'fortune telling',
    'labeling',
    'magnification',
    'mind reading',
    'overgeneralizing',
    'should statements',
    'mental filter',
    'personalization and blaming',
]

In [None]:
# Summary table with one row per distortion type. The metric columns start as
# NaN and are filled in after each per-type model has been tested.
# The column length follows `distortion_types` instead of a hard-coded 12, so
# editing the list cannot cause a length-mismatch error here.
results_df = pd.DataFrame({
    "distortion_type": distortion_types,
    "test_accuracy": [np.nan] * len(distortion_types),
    "f1_macro": [np.nan] * len(distortion_types)
})

In [None]:
# Store this distortion type's test metrics in its row of the summary table.
current_type = 'all-or-nothing thinking'

results_df.loc[
    results_df["distortion_type"].eq(current_type),
    ["test_accuracy", "f1_macro"],
] = [test_metrics[metric] for metric in ('accuracy', 'f1_macro')]

# 2. Comparing and Despairing

In [None]:
# Add labels: pick the label of every row that survived de-duplication,
# using the surviving row index of the text series.
data1_1_labels = list(data1['comparing and despairing'][data1_1.index])
# 'comparing and despairing' is only annotated in data1.

# Merging Data (a single source for this distortion type)
data_encoded = data1_1_encoded
data_labels = data1_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% train, 10% validation, and the test
# split receives whatever remains after integer truncation.
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset with a seeded generator so the train/val/test partition
# (and therefore the reported metrics) is the same on every run.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders for each set; only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, with the
# same width as the BERT hidden state.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Cross-entropy over the classifier logits
criterion = nn.CrossEntropyLoss()

# Fresh InnerProductClassifier built from the label embeddings, placed on the
# compute device
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# AdamW over the classifier's own parameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-4,
)

In [None]:
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# BERT is only used as a fixed feature extractor below (its forward pass runs
# under no_grad), so keep it in eval mode: dropout is disabled and the
# embeddings are deterministic.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients flow into BERT)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar instead of printing one line
        # per batch, which floods the cell output and breaks up the bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; max() guards against an empty loader.
    print(f"Epoch {epoch+1} Mean Train Loss: {epoch_loss / max(num_batches, 1)}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:  33%|███▎      | 1/3 [00:01<00:03,  1.89s/it]

Loss: 10.964028358459473


Epoch 1:  67%|██████▋   | 2/3 [00:03<00:01,  1.90s/it]

Loss: 1.0615003108978271


Epoch 1: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 1.4241770505905151





Epoch 1 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 2:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 0.8279706835746765


Epoch 2:  67%|██████▋   | 2/3 [00:03<00:01,  1.91s/it]

Loss: 4.615095138549805


Epoch 2: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 2.3962879180908203





Epoch 2 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 3:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 2.862060785293579


Epoch 3:  67%|██████▋   | 2/3 [00:03<00:01,  1.90s/it]

Loss: 2.5891737937927246


Epoch 3: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 4.666739463806152





Epoch 3 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 4:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 3.254875421524048


Epoch 4:  67%|██████▋   | 2/3 [00:03<00:01,  1.90s/it]

Loss: 2.879023551940918


Epoch 4: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 4.7457380294799805





Epoch 4 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 5:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 3.6033143997192383


Epoch 5:  67%|██████▋   | 2/3 [00:03<00:01,  1.90s/it]

Loss: 4.2465410232543945


Epoch 5: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 2.7322750091552734





Epoch 5 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 6:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 2.4997787475585938


Epoch 6:  67%|██████▋   | 2/3 [00:03<00:01,  1.91s/it]

Loss: 5.943368911743164


Epoch 6: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 1.284012794494629





Epoch 6 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 7:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 2.599189281463623


Epoch 7:  67%|██████▋   | 2/3 [00:03<00:01,  1.90s/it]

Loss: 2.5339674949645996


Epoch 7: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 3.090691566467285





Epoch 7 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 8:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 1.6891505718231201


Epoch 8:  67%|██████▋   | 2/3 [00:03<00:01,  1.91s/it]

Loss: 1.747880458831787


Epoch 8: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 2.9969167709350586





Epoch 8 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 9:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 1.6957786083221436


Epoch 9:  67%|██████▋   | 2/3 [00:03<00:01,  1.91s/it]

Loss: 0.0


Epoch 9: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 2.6097278594970703





Epoch 9 Validation Accuracy: 1.0, F1-macro: 1.0


Epoch 10:  33%|███▎      | 1/3 [00:01<00:03,  1.90s/it]

Loss: 2.1008002758026123


Epoch 10:  67%|██████▋   | 2/3 [00:03<00:01,  1.90s/it]

Loss: 0.39907124638557434


Epoch 10: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]

Loss: 0.29249656200408936





Epoch 10 Validation Accuracy: 0.8260869565217391, F1-macro: 0.4523809523809524


In [None]:
# Score the trained classifier on the held-out test split and report it.
# NOTE(review): in the saved outputs the test F1-macro matches the final
# validation F1-macro digit-for-digit while the accuracies differ -- verify
# that `evaluate` computes F1 from the loader it is given.
test_metrics = evaluate(model, test_dataloader, device)
print(
    f"Test Accuracy: {test_metrics['accuracy']}, "
    f"F1-macro: {test_metrics['f1_macro']}"
)

Test Accuracy: 0.84, F1-macro: 0.4523809523809524


In [None]:
# Store this distortion type's test metrics in its row of the summary table.
current_type = 'comparing and despairing'

results_df.loc[
    results_df["distortion_type"].eq(current_type),
    ["test_accuracy", "f1_macro"],
] = [test_metrics[metric] for metric in ('accuracy', 'f1_macro')]

# 3. Disqualifying the Positive

In [None]:
# Add labels: pick the label of every row that survived de-duplication,
# using the surviving row index of each text series.
data1_1_labels = list(data1['disqualifying the positive'][data1_1.index])
data2_1_labels = list(data2['disqualifiying the positive'][data2_1.index]) # the column name is misspelled in the data2 file; keep the typo
data3_1_labels = list(data3['disqualifying the positive'][data3_1.index])

# Merging Data: concatenate the three sources in the same order for inputs
# and labels so they stay aligned.
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% train, 10% validation, and the test
# split receives whatever remains after integer truncation.
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset with a seeded generator so the train/val/test partition
# (and therefore the reported metrics) is the same on every run.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders for each set; only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, with the
# same width as the BERT hidden state.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim)

In [None]:
# Cross-entropy over the classifier logits
criterion = nn.CrossEntropyLoss()

# Fresh InnerProductClassifier built from the label embeddings, placed on the
# compute device
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# AdamW over the classifier's own parameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-4,
)

In [None]:
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# BERT is only used as a fixed feature extractor below (its forward pass runs
# under no_grad), so keep it in eval mode: dropout is disabled and the
# embeddings are deterministic.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients flow into BERT)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar instead of printing one line
        # per batch, which floods the cell output and breaks up the bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; max() guards against an empty loader.
    print(f"Epoch {epoch+1} Mean Train Loss: {epoch_loss / max(num_batches, 1)}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.89s/it]

Loss: 0.07945582270622253


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 14.826980590820312


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.6776041984558105


Epoch 1:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 8.754407332389746e-08


Epoch 1:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 2.2280735969543457


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.0


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.0


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.0510597229003906


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.0


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.5051530599594116


Epoch 1:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 2.6690292358398438


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.3516380786895752


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 2.7947168350219727


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.1552159786224365


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 2.123791217803955


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.1418296098709106


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 2.8537535667419434


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.0


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.5366392135620117


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.0


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.8424761295318604


Epoch 1:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 1.3759100437164307


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.0


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.2397066354751587


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.0


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.0





Epoch 1 Validation Accuracy: 0.9901477832512315, F1-macro: 0.4975247524752475


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.1772477626800537


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 3.5680291652679443


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.4555320739746094


Epoch 2:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 3.9531567096710205


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 2.5658323764801025


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 2.7251224517822266


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.0


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.0


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 1.3038462043368781e-07


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 4.187013109913096e-05


Epoch 2:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.4876694083213806


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.8626450382086546e-09


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.43056520819664


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.0


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.1773476004600525


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.05676108971238136


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.029375109821558


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.1733091026544571


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.951551079750061


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.471587061882019


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6756206750869751


Epoch 2:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.8908208608627319


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.009201440960168839


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.1484332084655762


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.022556478157639503


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 2.5544839132862762e-08





Epoch 2 Validation Accuracy: 0.9901477832512315, F1-macro: 0.7475124378109452


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.0


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.3302273750305176


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5696174502372742


Epoch 3:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 3.7252860352054995e-08


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.426430881023407


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.1158463954925537


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.0012780196266248822


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.861177384853363


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 5.066327730673947e-07


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.09812236577272415


Epoch 3:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.24937498569488525


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.0010552958119660616


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.698161780834198


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.7665977478027344


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.8626450382086546e-09


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.31740567088127136


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.235647052526474


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.24933616816997528


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.010145306587219238


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.010789868421852589


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.27643054723739624


Epoch 3:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.023240990936756134


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.023599956184625626


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.508349597454071


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.1556280553340912


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.1526473760604858





Epoch 3 Validation Accuracy: 0.9802955665024631, F1-macro: 0.6616666666666666


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.051931533962488174


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.6730899214744568


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.11793676763772964


Epoch 4:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.039488036185503006


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.3727342188358307


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.1869286596775055


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6905163526535034


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.015440838411450386


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.19039441645145416


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.005896697286516428


Epoch 4:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 2.7955845780525124e-06


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.8581803441047668


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.10359571874141693


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.14656196534633636


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 9.113355190493166e-05


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.38506123423576355


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 4.320939478930086e-05


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.0004424040380399674


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.11765557527542114


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.24049007892608643


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.51878422498703


Epoch 4:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.2480827122926712


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.011285640299320221


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.03713955357670784


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.10141334682703018


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.16813962161540985





Epoch 4 Validation Accuracy: 0.9852216748768473, F1-macro: 0.6962593516209477


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.3843376636505127


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.0919877365231514


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.12887656688690186


Epoch 5:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.32703423500061035


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 6.997015589149669e-05


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.160146564245224


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.004308711271733046


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.05053717643022537


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.23047584295272827


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.18269583582878113


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 3.7625108006977825e-07


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.3597289694189385e-07


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.1173376813530922


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 2.4214378058218244e-08


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.6120040416717529


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.0002623690234031528


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.466744065284729


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.007480932865291834


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5229176878929138


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.027269674465060234


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.49663031101226807


Epoch 5:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.0745440125465393


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3796159625053406


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.16096539795398712


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.002567967865616083


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 5.185471763979876e-06





Epoch 5 Validation Accuracy: 0.9901477832512315, F1-macro: 0.4975247524752475


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.1297270655632019


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.0014708193484693766


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 2.713734829740133e-06


Epoch 6:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 7.450570649325527e-08


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.2689387798309326


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.0692441388964653


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.14659647643566132


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.03962670639157295


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.10539596527814865


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.11249706149101257


Epoch 6:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.35410168766975403


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.27836766839027405


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.000721554271876812


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.04084276035428047


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 8.049267489695922e-05


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.10430654138326645


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.2228348168719094e-05


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.078932024538517


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.004234504420310259


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.00012982386397197843


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.00019567611161619425


Epoch 6:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.0010097608901560307


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.03316805884242058


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.25356605648994446


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.0003080075839534402


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.37523674964904785





Epoch 6 Validation Accuracy: 0.9852216748768473, F1-macro: 0.49627791563275436


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.031218327581882477


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.20373505353927612


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.09899625927209854


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.1375001072883606


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.005795453675091267


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.000252744706813246


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3634048104286194


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.007156712003052235


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.008538060821592808


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.00035570524050854146


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.015849126502871513


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.003290103515610099


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.0225453469902277


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.09042493253946304


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.09928659349679947


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.002122360747307539


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.051046617329120636


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.002400253666564822


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.0016162157990038395


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.0029531570617109537


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.016573289409279823


Epoch 7:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.23641757667064667


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.0012335615465417504


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.024397291243076324


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.1384788304567337


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.002217769157141447





Epoch 7 Validation Accuracy: 0.9753694581280788, F1-macro: 0.4937655860349127


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.04807324707508087


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.0005730831762775779


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.02709042653441429


Epoch 8:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.025617776438593864


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.0184495747089386


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.2469974011182785


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.023323416709899902


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.0009317841031588614


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.04133794829249382


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.005632970482110977


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.00035058805951848626


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.001007581246085465


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.00032380642369389534


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.006743151228874922


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.0010589339071884751


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.18689493834972382


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.08754444122314453


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.10646030306816101


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.033423520624637604


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.01676911674439907


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.0005958633264526725


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.015565616078674793


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.022831685841083527


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.0685371682047844


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.015903381630778313


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.29278138279914856





Epoch 8 Validation Accuracy: 0.9901477832512315, F1-macro: 0.4975247524752475


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.00018964536138810217


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.08319388329982758


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.0011586982291191816


Epoch 9:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.003164962399750948


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.005241198930889368


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.012929619289934635


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.027671964839100838


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.02241915836930275


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.02816883474588394


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.02277732454240322


Epoch 9:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.00019154259643983096


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.00035224476596340537


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.11697974801063538


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.08726748079061508


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 5.451714514492778e-06


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 9.378066351928283e-06


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.00029307117802090943


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.0002694327267818153


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.060020726174116135


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.005859867669641972


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.06541220098733902


Epoch 9:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.004503071308135986


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.11195389926433563


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.07311663031578064


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.21330846846103668


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 6.59669385640882e-05





Epoch 9 Validation Accuracy: 0.9753694581280788, F1-macro: 0.6365914786967418


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.13435226678848267


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.0008437223150394857


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.17318354547023773


Epoch 10:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.014898445457220078


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.009830060414969921


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.0004288622585590929


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.00021985523926559836


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 2.2026548322173767e-05


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 3.801040293183178e-05


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.00014683068729937077


Epoch 10:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 4.62836123915622e-06


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.17574404180049896


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.12523026764392853


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 5.0291404818381125e-08


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 2.2716687453794293e-05


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.00030827970476821065


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.04025647044181824


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.000811938545666635


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.03995256498456001


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.060013867914676666


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.13133780658245087


Epoch 10:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.004928476642817259


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.017068352550268173


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.004286451730877161


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.0840369313955307


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 2.171305339970786e-07





Epoch 10 Validation Accuracy: 0.9901477832512315, F1-macro: 0.4975247524752475


In [None]:
# Score the trained classifier on the held-out test split and report the metrics.
# NOTE(review): in the saved run the test F1-macro is identical, digit for digit, to the
# last validation F1-macro while the accuracies differ — verify that evaluate() computes
# F1 from the dataloader it is given and does not reuse stale predictions.
test_metrics = evaluate(model, test_dataloader, device)
test_report = f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}"
print(test_report)

Test Accuracy: 0.9901960784313726, F1-macro: 0.4975247524752475


In [None]:
# Record this distortion type's test metrics in the shared results table.
current_type = 'disqualifying the positive'

# Boolean mask selecting the row(s) of results_df for this distortion type.
type_mask = results_df["distortion_type"] == current_type

# With an all-False mask the .loc assignment below would silently do nothing and the
# scores would be lost, so fail loudly instead (e.g. on a misspelled type name).
assert type_mask.any(), f"No row with distortion_type == {current_type!r} in results_df"

results_df.loc[type_mask, ["test_accuracy", "f1_macro"]] = [
    test_metrics['accuracy'],
    test_metrics['f1_macro'],
]

# 4. Emotional Reasoning

In [None]:
# Build the labelled dataset, the train/val/test loaders and the initial label
# embeddings for the "emotional reasoning" classifier.

# Seed for everything random in this cell (split, shuffle order, label-embedding init)
# so that Restart & Run All reproduces the same data split and starting weights.
SPLIT_SEED = 42
split_rng = torch.Generator().manual_seed(SPLIT_SEED)

# Add labels
# Labels are looked up by the index of each de-duplicated text series.
# NOTE(review): presumably data*_1_encoded was built from data*_1 in the same row
# order, so labels and encodings line up — confirm where the encodings are created.
label_column = 'emotional reasoning'
data1_1_labels = list(data1.loc[data1_1.index, label_column])
data2_1_labels = list(data2.loc[data2_1.index, label_column])
data3_1_labels = list(data3.loc[data3_1.index, label_column])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% train, 10% validation, remainder test
# (the test split absorbs the rounding remainder so the sizes always sum up).
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset with the seeded generator so the partition is reproducible
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_rng,
)

# Create DataLoaders for each set (only the training loader shuffles)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=split_rng)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding vector per distinct label value, with the same
# width as the BERT hidden state.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_rng)

In [None]:
# Classifier head: built from the label embeddings, then placed on the training device.
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Cross-entropy over the classifier's logits.
criterion = nn.CrossEntropyLoss()

# AdamW updates only the classifier head's parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
# Train the classifier head on frozen BERT [CLS] embeddings and validate after
# every epoch.
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# BERT is only a frozen feature extractor here (its forward pass runs under no_grad and
# the optimizer holds `model`'s parameters only), so make sure dropout is disabled and
# the embeddings are deterministic.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar instead of printing one line per
        # batch, which floods the cell output and breaks up the tqdm bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; max(..., 1) guards against an empty loader.
    print(f"Epoch {epoch+1} Mean Training Loss: {epoch_loss / max(len(train_dataloader), 1)}")

    # Evaluate on the validation set after each epoch
    # NOTE(review): validation metrics are only printed; the weights left in `model`
    # afterwards are those of the final epoch, not of the best validation epoch.
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.89s/it]

Loss: 17.062763214111328


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 3.0802483558654785


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 4.951032638549805


Epoch 1:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 4.744932174682617


Epoch 1:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 3.8652615547180176


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 11.824575424194336


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 3.863915205001831


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 4.9907989501953125


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 6.035246849060059


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 2.575094699859619


Epoch 1:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 7.864439010620117


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 4.9446868896484375


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 5.907814979553223


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 2.0671987533569336


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.125411033630371


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 2.3939080238342285


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 6.466042518615723


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 2.7649824619293213


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.7610833644866943


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.90384840965271


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 2.660557270050049


Epoch 1:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 3.4666147232055664


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.8097872734069824


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.3461271524429321


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 3.6636388301849365


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.5539804697036743





Epoch 1 Validation Accuracy: 0.9064039408866995, F1-macro: 0.4754521963824289


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.960782766342163


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.080037236213684


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 2.8051886558532715


Epoch 2:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 2.54613995552063


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.8120659589767456


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.39914876222610474


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 2.270496129989624


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 2.781620502471924


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 2.1706764698028564


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.3214812278747559


Epoch 2:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.3762989938259125


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.965266466140747


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8506954908370972


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6548368334770203


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.4922051429748535


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.8784430027008057


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.0606184005737305


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7686144113540649


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.810791015625


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.1890192031860352


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 2.714202880859375


Epoch 2:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 2.652357578277588


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.7375973463058472


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5047447681427002


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.41642290353775024


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.2874211072921753





Epoch 2 Validation Accuracy: 0.7931034482758621, F1-macro: 0.503840782122905


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.4921785593032837


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.09241783618927


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.45672109723091125


Epoch 3:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 1.7717275619506836


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.8910747170448303


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.7949361205101013


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.0459275245666504


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.001887321472168


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.9709329009056091


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.7401731610298157


Epoch 3:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.6575189232826233


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.1465989351272583


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.1084822416305542


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.1887059211730957


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.0915640592575073


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.501024603843689


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.6387027502059937


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.6150586009025574


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.2121947705745697


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.1744509935379028


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.728515625


Epoch 3:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.9610674381256104


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5363936424255371


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.31534644961357117


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7500960826873779


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.0110142230987549





Epoch 3 Validation Accuracy: 0.8029556650246306, F1-macro: 0.5728114478114478


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.5557078123092651


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.33008915185928345


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.7074391841888428


Epoch 4:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.6652920246124268


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4297005534172058


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.38317015767097473


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.37838947772979736


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.31839025020599365


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.39348480105400085


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.5853657722473145


Epoch 4:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.8287099003791809


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.7224483489990234


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.0683941841125488


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.2485075294971466


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.5502775311470032


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.0919208526611328


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.32206523418426514


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7019980549812317


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.4795326888561249


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6056844592094421


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.3474440574645996


Epoch 4:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.27640005946159363


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4507630467414856


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5249307751655579


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4571946859359741


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5177074670791626





Epoch 4 Validation Accuracy: 0.896551724137931, F1-macro: 0.5160631172664321


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.5830777287483215


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.41123032569885254


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5604559183120728


Epoch 5:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.6438726782798767


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4235330820083618


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.3786512017250061


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6098175048828125


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.37934064865112305


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.48246657848358154


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6525806188583374


Epoch 5:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.4505666196346283


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.9263461828231812


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.0358147621154785


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3083999752998352


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.9605759978294373


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5035786032676697


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.7305954694747925


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.099663496017456


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.18954700231552124


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.46775883436203003


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6862531304359436


Epoch 5:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.2514146566390991


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.19692106544971466


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.0977421998977661


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7146292924880981


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.25469422340393066





Epoch 5 Validation Accuracy: 0.7931034482758621, F1-macro: 0.5514520202020202


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.539968729019165


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.3077024817466736


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.6117141246795654


Epoch 6:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.4067673981189728


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6354615092277527


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.34268656373023987


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6339073777198792


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.36137449741363525


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.3032327890396118


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6691765189170837


Epoch 6:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.12800469994544983


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3054576814174652


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.30366066098213196


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.31020718812942505


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.4196222126483917


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5081971287727356


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.38796401023864746


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.4540928602218628


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.3965096175670624


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.19372060894966125


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5683157444000244


Epoch 6:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.43595993518829346


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3974130153656006


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.22595569491386414


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4990074634552002


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.6510501503944397





Epoch 6 Validation Accuracy: 0.8472906403940886, F1-macro: 0.5153638814016173


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.3328458368778229


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.38352829217910767


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.3820558786392212


Epoch 7:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.44525381922721863


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.3741346001625061


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5194562673568726


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.9288638234138489


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.45430564880371094


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 1.0004360675811768


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.491166353225708


Epoch 7:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.4290170669555664


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.8318270444869995


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.2448873370885849


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.25383609533309937


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.547400712966919


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.22402384877204895


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.8507946133613586


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.0702073574066162


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.22924906015396118


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6994640231132507


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.2588196396827698


Epoch 7:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.2767413258552551


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.747413158416748


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.24191540479660034


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.15599334239959717


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5573471784591675





Epoch 7 Validation Accuracy: 0.7635467980295566, F1-macro: 0.5657754010695186


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.5222079157829285


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.6142699122428894


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5882073640823364


Epoch 8:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.3562822639942169


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.749689519405365


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.40569403767585754


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.7511558532714844


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.2626662850379944


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.38341614603996277


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.4988805949687958


Epoch 8:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.6041668653488159


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.2861948311328888


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5969439744949341


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.2104962319135666


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.41119444370269775


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.6532660722732544


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.17977291345596313


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3799440264701843


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5174612402915955


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.28108927607536316


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.26239013671875


Epoch 8:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.832993745803833


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.20320098102092743


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.616267740726471


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.20353877544403076


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.058540284633636475





Epoch 8 Validation Accuracy: 0.8866995073891626, F1-macro: 0.4699738903394256


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.6841535568237305


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.4609583616256714


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.1800079345703125


Epoch 9:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.3729211688041687


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.1797170639038086


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.37504178285598755


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.42261967062950134


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.5065535306930542


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.18637600541114807


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.32946717739105225


Epoch 9:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.4264908730983734


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.34838730096817017


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.19166424870491028


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3224145770072937


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.4583233892917633


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3160442113876343


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.4785574674606323


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3712550699710846


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.31438902020454407


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.19687457382678986


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.27591729164123535


Epoch 9:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.5418747067451477


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3118785619735718


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.47199922800064087


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4358230531215668


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.4058949947357178





Epoch 9 Validation Accuracy: 0.8522167487684729, F1-macro: 0.49114304812834225


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.19198378920555115


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7128294706344604


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.253839910030365


Epoch 10:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.14406761527061462


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.5633878707885742


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.2934267520904541


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.4796562194824219


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.2612587809562683


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4434029459953308


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.3867691159248352


Epoch 10:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.43034839630126953


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.23486211895942688


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.2107633501291275


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.2640320062637329


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.2875014841556549


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.47085073590278625


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.284310519695282


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.34333088994026184


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5904964208602905


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.40953630208969116


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.21360328793525696


Epoch 10:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.5170606374740601


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.2406170815229416


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.3787812292575836


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.8008947372436523


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.6787471771240234





Epoch 10 Validation Accuracy: 0.8325123152709359, F1-macro: 0.4543010752688172


In [None]:
# Score the trained classifier on the held-out test split and report the metrics.
# NOTE(review): in the saved run the test F1-macro is identical, digit for digit, to the
# last validation F1-macro while the accuracies differ — verify that evaluate() computes
# F1 from the dataloader it is given and does not reuse stale predictions.
test_metrics = evaluate(model, test_dataloader, device)
test_report = f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}"
print(test_report)

Test Accuracy: 0.8137254901960784, F1-macro: 0.4543010752688172


In [None]:
# Record this distortion type's test metrics in the shared results table.
current_type = 'emotional reasoning'

# Boolean mask selecting the row(s) of results_df for this distortion type.
type_mask = results_df["distortion_type"] == current_type

# With an all-False mask the .loc assignment below would silently do nothing and the
# scores would be lost, so fail loudly instead (e.g. on a misspelled type name).
assert type_mask.any(), f"No row with distortion_type == {current_type!r} in results_df"

results_df.loc[type_mask, ["test_accuracy", "f1_macro"]] = [
    test_metrics['accuracy'],
    test_metrics['f1_macro'],
]

# 5. Fortune Telling

In [None]:
# Build the labelled dataset, the train/val/test loaders and the initial label
# embeddings for the "fortune telling" classifier.

# Seed for everything random in this cell (split, shuffle order, label-embedding init)
# so that Restart & Run All reproduces the same data split and starting weights.
SPLIT_SEED = 42
split_rng = torch.Generator().manual_seed(SPLIT_SEED)

# Add labels
# Labels are looked up by the index of each de-duplicated text series.
# NOTE(review): presumably data*_1_encoded was built from data*_1 in the same row
# order, so labels and encodings line up — confirm where the encodings are created.
label_column = 'fortune telling'
data1_1_labels = list(data1.loc[data1_1.index, label_column])
data2_1_labels = list(data2.loc[data2_1.index, label_column])
data3_1_labels = list(data3.loc[data3_1.index, label_column])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% train, 10% validation, remainder test
# (the test split absorbs the rounding remainder so the sizes always sum up).
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset with the seeded generator so the partition is reproducible
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_rng,
)

# Create DataLoaders for each set (only the training loader shuffles)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=split_rng)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding vector per distinct label value, with the same
# width as the BERT hidden state.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_rng)

In [None]:
# Classifier head: built from the label embeddings, then placed on the training device.
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Cross-entropy over the classifier's logits.
criterion = nn.CrossEntropyLoss()

# AdamW updates only the classifier head's parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
# Train the classifier head on frozen BERT [CLS] embeddings and validate after
# every epoch.
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# BERT is only a frozen feature extractor here (its forward pass runs under no_grad and
# the optimizer holds `model`'s parameters only), so make sure dropout is disabled and
# the embeddings are deterministic.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar instead of printing one line per
        # batch, which floods the cell output and breaks up the tqdm bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; max(..., 1) guards against an empty loader.
    print(f"Epoch {epoch+1} Mean Training Loss: {epoch_loss / max(len(train_dataloader), 1)}")

    # Evaluate on the validation set after each epoch
    # NOTE(review): validation metrics are only printed; the weights left in `model`
    # afterwards are those of the final epoch, not of the best validation epoch.
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.88s/it]

Loss: 1.2807241678237915


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 4.526332378387451


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.7799018621444702


Epoch 1:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 2.8697123527526855


Epoch 1:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 2.070244789123535


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 2.456763982772827


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 2.1762025356292725


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.8471379280090332


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 1.3571405410766602


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.3345723152160645


Epoch 1:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 1.1968262195587158


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.8071523904800415


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.0488237142562866


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6367789506912231


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.0892564058303833


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.1028259992599487


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.7103679776191711


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7898262739181519


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.7684493660926819


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6245958209037781


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.9309926629066467


Epoch 1:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 1.2852487564086914


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.0812945365905762


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.983379602432251


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.0021429061889648


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.2449065446853638





Epoch 1 Validation Accuracy: 0.8669950738916257, F1-macro: 0.49867374005305043


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.667216420173645


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7285743355751038


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.6696280241012573


Epoch 2:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.82417231798172


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.49421820044517517


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.8864606618881226


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.9885851144790649


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.375326931476593


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4698525667190552


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.9578812122344971


Epoch 2:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.48072317242622375


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6788203716278076


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8274105787277222


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.272873878479004


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.7936132550239563


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5901023149490356


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.1359443664550781


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.27199918031692505


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.7165916562080383


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.7456344366073608


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.49616700410842896


Epoch 2:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.8256616592407227


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3736814260482788


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.1508287191390991


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5161435604095459


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5615646839141846





Epoch 2 Validation Accuracy: 0.7536945812807881, F1-macro: 0.5583884441350505


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.4882917106151581


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.8962527513504028


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.42936909198760986


Epoch 3:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 1.0140944719314575


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.46439680457115173


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.2539103627204895


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.8554956912994385


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.6088250279426575


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5869991779327393


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.547999382019043


Epoch 3:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.43042582273483276


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.7336955070495605


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5099453330039978


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.2571116089820862


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.42324742674827576


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.4000440835952759


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.4697427451610565


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.6514499187469482


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6539767384529114


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.0049257278442383


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5248520374298096


Epoch 3:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.45965662598609924


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4401121735572815


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7984366416931152


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4720015823841095


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.19925734400749207





Epoch 3 Validation Accuracy: 0.8522167487684729, F1-macro: 0.5427927927927928


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.4883788526058197


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.4429478049278259


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.2470058798789978


Epoch 4:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.25119906663894653


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.0433753728866577


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.1810189485549927


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.206894874572754


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3576919436454773


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.8292580842971802


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.0397694110870361


Epoch 4:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.6360683441162109


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.9515962600708008


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.6140482425689697


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.082627534866333


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.4798769950866699


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5419294834136963


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6748543977737427


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.6303184032440186


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.047004222869873


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5188179016113281


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.7552145719528198


Epoch 4:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.9385752081871033


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.8637961149215698


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.8766013383865356


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5639714002609253


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.0648350715637207





Epoch 4 Validation Accuracy: 0.8472906403940886, F1-macro: 0.5968351592030239


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.5972508192062378


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.5634515285491943


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.38410502672195435


Epoch 5:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.4670642614364624


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.5453723073005676


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4674665033817291


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.28947675228118896


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4313507676124573


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.976961076259613


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.8426361083984375


Epoch 5:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.6260578632354736


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.7139112949371338


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.34506163001060486


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6324039697647095


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.618362545967102


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.4815063178539276


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.204353928565979


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5232605934143066


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.3806118965148926


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.2737439274787903


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6964415311813354


Epoch 5:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.4195079207420349


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5416730046272278


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.46944892406463623


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.336631715297699


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.4333355128765106





Epoch 5 Validation Accuracy: 0.8669950738916257, F1-macro: 0.5285161290322581


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.920135498046875


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.5081340074539185


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.7887699604034424


Epoch 6:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.4227217435836792


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 2.4941344261169434


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 2.141066074371338


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.8466477394104004


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.526729166507721


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5548786520957947


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6674926280975342


Epoch 6:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.28923332691192627


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.57621169090271


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.7886568903923035


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3769795894622803


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.0936858654022217


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5063850283622742


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6893032193183899


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7296783924102783


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.7005469799041748


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.0292199850082397


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5948774814605713


Epoch 6:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.7191634178161621


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.2217713594436646


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.9320101737976074


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.8096888661384583


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.43815892934799194





Epoch 6 Validation Accuracy: 0.4630541871921182, F1-macro: 0.42216593111012457


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 2.0525364875793457


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7372158765792847


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.8322734236717224


Epoch 7:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 2.0310800075531006


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.705681324005127


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.897314190864563


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 2.3128371238708496


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.143056035041809


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.508658766746521


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.6059733629226685


Epoch 7:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 2.4285812377929688


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.5173909664154053


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 2.0073118209838867


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.2625566720962524


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.6660068035125732


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.6506164073944092


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.9380568265914917


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.597430944442749


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.7053718566894531


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.9128577709197998


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.18479059636592865


Epoch 7:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.5575481653213501


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.2589039206504822


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7929493188858032


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5091845989227295


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.817652702331543





Epoch 7 Validation Accuracy: 0.8423645320197044, F1-macro: 0.5123123123123123


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.46044185757637024


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.15287408232688904


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.39126265048980713


Epoch 8:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.8517872095108032


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6509354710578918


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.247891902923584


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.4816514551639557


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.8781223297119141


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.9574077129364014


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.36751508712768555


Epoch 8:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.7081555724143982


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6708479523658752


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5233442783355713


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.9908314943313599


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.6086813807487488


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5160071849822998


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6534985303878784


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.48445841670036316


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6958845257759094


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.34311267733573914


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5158443450927734


Epoch 8:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.6345652937889099


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3439824879169464


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6168147325515747


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5919317007064819


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.06216945871710777





Epoch 8 Validation Accuracy: 0.8719211822660099, F1-macro: 0.46578947368421053


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 1.1575374603271484


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.8788806200027466


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5323529243469238


Epoch 9:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.4264901578426361


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.5493401885032654


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4871227741241455


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.902215838432312


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.39486733078956604


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4792396128177643


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.19955545663833618


Epoch 9:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.3390641212463379


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3913707137107849


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.23526698350906372


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.31682726740837097


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.4821076989173889


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.21332582831382751


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.8109254837036133


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3063158392906189


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.4788304269313812


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.3254682719707489


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6173779964447021


Epoch 9:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.27854180335998535


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.454328328371048


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.21562571823596954


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5601451992988586


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.48488086462020874





Epoch 9 Validation Accuracy: 0.8325123152709359, F1-macro: 0.6122471910112359


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.10540631413459778


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.4080628752708435


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5471776723861694


Epoch 10:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.19872449338436127


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6035897135734558


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.03629158437252045


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5108382701873779


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.21714358031749725


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.38499873876571655


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.29926633834838867


Epoch 10:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.5511878132820129


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.968876838684082


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.7476524710655212


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.4060019254684448


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.745347261428833


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.464887410402298


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.1407722234725952


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.680192768573761


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.8970171809196472


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5600894093513489


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6391329765319824


Epoch 10:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.673553466796875


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4479787349700928


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.3985259532928467


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.798467218875885


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.37465932965278625





Epoch 10 Validation Accuracy: 0.8620689655172413, F1-macro: 0.4960992907801418


In [None]:
# Score the trained classifier once on the held-out test split.
# NOTE(review): in the saved run the test F1-macro printed below matches the
# epoch-10 validation F1-macro digit for digit while the accuracies differ;
# worth confirming that evaluate() computes F1 from the loader it is given.
test_metrics = evaluate(model, test_dataloader, device)
print("Test Accuracy: {}, F1-macro: {}".format(test_metrics['accuracy'], test_metrics['f1_macro']))

Test Accuracy: 0.8823529411764706, F1-macro: 0.4960992907801418


In [None]:
current_type = 'fortune telling'

# Store this distortion type's test metrics in the summary table: rows are
# selected by matching "distortion_type", and the two metric columns are
# filled in the same order as the values on the right-hand side.
results_df.loc[
    results_df["distortion_type"].eq(current_type),
    ["test_accuracy", "f1_macro"],
] = [test_metrics[key] for key in ('accuracy', 'f1_macro')]

# 6. Labeling

In [None]:
# Collect the "labeling" target for every de-duplicated text. Each *_1 series
# keeps the row labels of the frame it was built from, so selecting by its
# index picks the label of exactly the rows that were kept.
# .loc is used instead of chained [] indexing to make the label-based lookup explicit.
data1_1_labels = list(data1.loc[data1_1.index, 'labeling'])
data2_1_labels = list(data2.loc[data2_1.index, 'labeling'])
data3_1_labels = list(data3.loc[data3_1.index, 'labeling'])

# Merge the three sources; encodings and labels are concatenated in the same
# order so that position i in both lists refers to the same example.
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# 80 / 10 / 10 split; the test split takes the remainder so the three sizes
# always sum to the dataset length despite integer truncation.
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Use a dedicated, seeded generator so that re-running this cell yields the
# same train/val/test membership and the same initial label embeddings.
# (Training batch order below is still drawn from the global RNG.)
SPLIT_SEED = 42
seeded_rng = torch.Generator().manual_seed(SPLIT_SEED)

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=seeded_rng,
)

# Create DataLoaders for each set; only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, with the
# same width as the encoder's hidden state.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=seeded_rng)

In [None]:
# Build the inner-product classification head from the label embeddings and
# move it to the training device.
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Cross-entropy between the head's logits and the integer class targets.
criterion = nn.CrossEntropyLoss()

# AdamW updates only the head's parameters; the BERT encoder's parameters are
# not passed to the optimizer.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
EPOCHS = 10

# The encoder is used purely as a frozen feature extractor (its forward pass
# runs under torch.no_grad()), so move it to the device and switch it to eval
# mode so that any train-only layers it has (e.g. dropout) cannot perturb the
# embeddings fed to the classifier.
bert_model.to(device)
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()

    # Accumulate the batch losses so one summary line per epoch can be printed
    # instead of one line per batch, which floods the cell output and
    # interleaves with the progress bar.
    running_loss = 0.0
    num_batches = 0

    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model without tracking gradients
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the current batch loss inside the progress bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # max(..., 1) guards against an empty training loader.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Mean Training Loss: {mean_loss}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.88s/it]

Loss: 1.6382334232330322


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 5.039063930511475


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 5.623708724975586


Epoch 1:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 3.0775136947631836


Epoch 1:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 4.422934532165527


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 4.249041557312012


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.9358594417572021


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 3.4917309284210205


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 4.352371692657471


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 4.139320373535156


Epoch 1:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 3.9211621284484863


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 2.0920138359069824


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.3633627891540527


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 4.841593265533447


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.4038745164871216


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.157052755355835


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.964346170425415


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 2.189875841140747


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 3.8981220722198486


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.3678028583526611


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 5.0475568771362305


Epoch 1:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 3.8327293395996094


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.8042316436767578


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.1470633745193481


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 2.848949432373047


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 2.4998838901519775





Epoch 1 Validation Accuracy: 0.7487684729064039, F1-macro: 0.47956567636857184


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.8282188773155212


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.6656678915023804


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.55320405960083


Epoch 2:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 1.3789865970611572


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 2.635657548904419


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 3.591520309448242


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.671290636062622


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.9510765671730042


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 2.1740493774414062


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 3.4401960372924805


Epoch 2:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.982653021812439


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.38773077726364136


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.7044062614440918


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.4061861038208008


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.9974192380905151


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.0071102380752563


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.884520411491394


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.8542007207870483


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.0697518587112427


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.0143126249313354


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.3291382789611816


Epoch 2:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.7087782621383667


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.8972347378730774


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.3168408870697021


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.0642136335372925


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.895236611366272





Epoch 2 Validation Accuracy: 0.7980295566502463, F1-macro: 0.4876577408433364


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.3960949182510376


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.5757136344909668


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.6716372966766357


Epoch 3:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.6910014152526855


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.9834904670715332


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.507138192653656


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6879472732543945


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4083550274372101


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.9977719783782959


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6849976778030396


Epoch 3:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.5074118375778198


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.7611240148544312


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.6188144683837891


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6831462383270264


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.3637416362762451


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.598733127117157


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.3866499066352844


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5651666522026062


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.022214651107788


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.762056291103363


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.0600082874298096


Epoch 3:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.9335411787033081


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.8590319156646729


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.0440937280654907


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.2943814694881439


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.6933866143226624





Epoch 3 Validation Accuracy: 0.8177339901477833, F1-macro: 0.44986449864498645


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.2553099393844604


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.0400387048721313


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.7724860310554504


Epoch 4:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.8635246753692627


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.3459228575229645


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.6652419567108154


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6428332328796387


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.5000776052474976


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.7123528122901917


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6306058764457703


Epoch 4:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.5900453329086304


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.48676353693008423


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.505403995513916


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.333551287651062


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.5356637835502625


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.6372305154800415


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.4511745274066925


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.052695631980896


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.47023677825927734


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.459020733833313


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.8952484130859375


Epoch 4:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.5240585803985596


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4668745696544647


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5187841653823853


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.38393402099609375


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5981417894363403





Epoch 4 Validation Accuracy: 0.6009852216748769, F1-macro: 0.4920767306088407


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.8398082256317139


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.4788071811199188


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.5250952243804932


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.705818772315979


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.2770495116710663


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.44229674339294434


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5994628667831421


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.90s/it]

Loss: 0.336531400680542


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.8770533204078674


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6154147386550903


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.604093074798584


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.2646562457084656


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.90s/it]

Loss: 0.4710747003555298


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3680568039417267


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.6205010414123535


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5733782649040222


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.191003680229187


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.4457499384880066


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6130616068840027


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.7749935388565063


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.9246630668640137


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.5689545273780823


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5293471813201904


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.3193178176879883


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.92081618309021


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.839052140712738





Epoch 5 Validation Accuracy: 0.5812807881773399, F1-macro: 0.5042948662702175


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.7951438426971436


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.41420167684555054


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.7663511037826538


Epoch 6:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.7204210162162781


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4084344506263733


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.9016289710998535


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3550567030906677


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.9250055551528931


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 1.1785141229629517


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6962246894836426


Epoch 6:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.2904900312423706


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.2098808288574219


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.2804780602455139


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.7790943384170532


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.4157824516296387


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.023959755897522


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.35056784749031067


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.9838721752166748


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.077620506286621


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.38306522369384766


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5362464189529419


Epoch 6:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.7093912363052368


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.24409520626068115


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.4334834814071655


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.44551604986190796


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.7362231016159058





Epoch 6 Validation Accuracy: 0.7832512315270936, F1-macro: 0.49820224719101125


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.27175572514533997


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.4098781943321228


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.21964086592197418


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.23167067766189575


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.39014819264411926


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5460226535797119


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5003018379211426


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.15940366685390472


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.3383346199989319


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.4513513147830963


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.23188848793506622


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.36855509877204895


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5546261072158813


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.33897238969802856


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.690523624420166


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.45196467638015747


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5311195254325867


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.4877840280532837


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.9856774210929871


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.3816531002521515


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5227304100990295


Epoch 7:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.2529504299163818


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.7763700485229492


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.36620432138442993


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.9301812648773193


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5590689778327942





Epoch 7 Validation Accuracy: 0.8078817733990148, F1-macro: 0.44686648501362397


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.4786396622657776


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.6172557473182678


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.9037629961967468


Epoch 8:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.6840828061103821


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.427910566329956


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.34272444248199463


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.90s/it]

Loss: 0.35375961661338806


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.90s/it]

Loss: 1.0843310356140137


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.90s/it]

Loss: 0.3489935100078583


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.90s/it]

Loss: 0.4178984463214874


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.90s/it]

Loss: 0.9632951021194458


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.90s/it]

Loss: 0.5374242067337036


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5991805791854858


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.125866413116455


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.5226573944091797


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.5590195655822754


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.3974115252494812


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.9445418119430542


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.3332417011260986


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5850015878677368


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.4875231683254242


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.730846643447876


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3544831871986389


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.4859081506729126


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.41427046060562134


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.32647600769996643





Epoch 8 Validation Accuracy: 0.7536945812807881, F1-macro: 0.5110789980732178


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.47117841243743896


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.41806137561798096


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.1358678787946701


Epoch 9:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.6463010311126709


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.29155290126800537


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.7520166635513306


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.8150612711906433


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.5834530591964722


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.26194465160369873


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.18136771023273468


Epoch 9:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.5619713068008423


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.725823700428009


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8296211957931519


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6006363034248352


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.5029217004776001


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3745891749858856


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5467751622200012


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5174494385719299


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.7763944864273071


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.4017944037914276


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.8160278797149658


Epoch 9:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.4836056232452393


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.4876468181610107


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.9534850120544434


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.4007861614227295


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.8230497241020203





Epoch 9 Validation Accuracy: 0.8029556650246306, F1-macro: 0.46886446886446886


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.16321012377738953


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 1.1995822191238403


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 1.070167064666748


Epoch 10:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 1.9102294445037842


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.7514371871948242


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.5423610210418701


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5112309455871582


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.6021196842193604


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 1.013796329498291


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.5575847625732422


Epoch 10:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.4254000186920166


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5726253986358643


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.37432748079299927


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.31038427352905273


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.36866921186447144


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5755787491798401


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5213171243667603


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3657069206237793


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6236549615859985


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.28816869854927063


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.7458536624908447


Epoch 10:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.15634045004844666


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.2182798534631729


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.4728226959705353


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.3551477789878845


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.09348882734775543





Epoch 10 Validation Accuracy: 0.8029556650246306, F1-macro: 0.46886446886446886


In [None]:
# Score the trained classifier on the test split. The result is indexed by
# the 'accuracy' and 'f1_macro' keys here and by the results-table cell below.
test_metrics = evaluate(model, test_dataloader, device)
print(
    f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}"
)

Test Accuracy: 0.8480392156862745, F1-macro: 0.46886446886446886


In [None]:
# Distortion type whose row(s) in results_df receive the test metrics.
current_type = 'labeling'

# Write accuracy and macro-F1 into the two metric columns of every row whose
# distortion_type equals current_type.
results_df.loc[
    results_df["distortion_type"].eq(current_type),
    ["test_accuracy", "f1_macro"]
] = [test_metrics[metric_key] for metric_key in ('accuracy', 'f1_macro')]

# 7. Magnification

In [None]:
# Add labels
# NOTE(review): the notebook heading for this section is "7. Magnification",
# yet the labels are read from the 'labeling' column. This looks like a
# copy-paste leftover from the previous section -- confirm the intended column
# name before trusting the metrics produced below.
data1_1_labels = list(data1['labeling'][data1_1.index])
data2_1_labels = list(data2['labeling'][data2_1.index])
data3_1_labels = list(data3['labeling'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting (80% train / 10% validation / remainder test)
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Seeded generator: without it the train/val/test membership and the random
# label embeddings change on every run, so reported metrics are not reproducible.
split_generator = torch.Generator().manual_seed(42)

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create DataLoaders for each set (only the training loader is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, with the
# same width as the BERT hidden state.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_generator)

In [None]:
# Build the InnerProductClassifier from the embedding width and the label
# embeddings, then move it onto the compute device.
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Cross-entropy loss applied to the classifier's logits.
criterion = nn.CrossEntropyLoss()

# AdamW over model.parameters() with a learning rate of 2e-4.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# The encoder is used purely as a frozen feature extractor (its forward pass
# runs under torch.no_grad() below), so pin it to eval mode to guarantee that
# dropout is disabled and the [CLS] embeddings are deterministic.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()
    loss_sum, samples_seen = 0.0, 0

    # Keep a handle on the progress bar so the running loss can be shown in it
    # instead of printing one line per batch (which floods the cell output and
    # breaks the bar into fragments).
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients: the encoder is not trained)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a sample-weighted sum so the epoch mean is not skewed by
        # a smaller final batch.
        batch_loss = loss.item()
        loss_sum += batch_loss * y.size(0)
        samples_seen += y.size(0)
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; skipped if the training loader was empty.
    if samples_seen:
        print(f"Epoch {epoch+1} Mean Training Loss: {loss_sum / samples_seen}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.89s/it]

Loss: 2.5669641494750977


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 4.225327968597412


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 6.898487091064453


Epoch 1:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 8.876835823059082


Epoch 1:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 7.660277366638184


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 3.9702022075653076


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.7789842486381531


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.9379198551177979


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 5.788169860839844


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 3.495021343231201


Epoch 1:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.4611440896987915


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 3.07470965385437


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 4.24186897277832


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 4.323698043823242


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 5.06662654876709


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 6.734378337860107


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 4.315465927124023


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.993066430091858


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.9578880071640015


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 2.1891584396362305


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 2.3523130416870117


Epoch 1:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.7751592397689819


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.1120014190673828


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7360212802886963


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 3.8071985244750977


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.1130977869033813





Epoch 1 Validation Accuracy: 0.8374384236453202, F1-macro: 0.45576407506702415


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.4566636085510254


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.3839530944824219


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.9931828379631042


Epoch 2:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 2.3389298915863037


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.3643428087234497


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.378415822982788


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.6541674137115479


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.8443273305892944


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 2.215365409851074


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 2.47981858253479


Epoch 2:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.7936493754386902


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 2.219316005706787


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.6649593114852905


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.9125398397445679


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 1.4157476425170898


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.6462068557739258


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.2169827222824097


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.036221981048584


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.9982481598854065


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.0050266981124878


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.7519268989562988


Epoch 2:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.8347549438476562


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.0700805187225342


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.1044790744781494


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.9085710048675537


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 2.0197653770446777





Epoch 2 Validation Accuracy: 0.8423645320197044, F1-macro: 0.4864010120177103


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.5939844846725464


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 1.0973992347717285


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.6525840759277344


Epoch 3:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 1.1012805700302124


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.3336179852485657


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.0020802021026611


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.8693126440048218


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.8458381295204163


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.357948899269104


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.0173524618148804


Epoch 3:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.0529664754867554


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6246858239173889


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.3351520597934723


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5564868450164795


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.6867684721946716


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5847443342208862


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.9689459204673767


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.6997653245925903


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.0143625736236572


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.8930218815803528


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.734319806098938


Epoch 3:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.8520646095275879


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5169901251792908


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5570031404495239


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5378599762916565


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5211999416351318





Epoch 3 Validation Accuracy: 0.8078817733990148, F1-macro: 0.5126500461680518


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.8851429224014282


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.834342360496521


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.39974355697631836


Epoch 4:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.6473882794380188


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.504846453666687


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4870608448982239


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5647699236869812


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4811464250087738


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5440330505371094


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.5962293148040771


Epoch 4:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.6958621144294739


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6201059818267822


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8979142904281616


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.7263116240501404


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.5943115949630737


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.28247547149658203


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.3447014093399048


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5869297981262207


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5042721033096313


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5000345706939697


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5475504398345947


Epoch 4:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.8006916046142578


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.47164177894592285


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5373767018318176


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.6346796751022339


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.32176968455314636





Epoch 4 Validation Accuracy: 0.812807881773399, F1-macro: 0.5941708754208754


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.38332635164260864


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.45213979482650757


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5511105060577393


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.2767074406147003


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6235456466674805


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.6556408405303955


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.5936591625213623


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.9017602801322937


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.6691803932189941


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.7370567321777344


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.41213512420654297


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3122277557849884


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.435955286026001


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5992876291275024


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.6153934597969055


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.24269843101501465


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.3811137080192566


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.663088321685791


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.4731389880180359


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5918651819229126


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.7219990491867065


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.39729613065719604


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.9258453845977783


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5415016412734985


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5051262378692627


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5044880509376526





Epoch 5 Validation Accuracy: 0.8177339901477833, F1-macro: 0.5188032545326414


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.6426412463188171


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.4397096037864685


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5632728338241577


Epoch 6:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.35401713848114014


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.39727115631103516


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.43935590982437134


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.45046621561050415


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.24378401041030884


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.38370782136917114


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.3102806806564331


Epoch 6:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.39397522807121277


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3595794141292572


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.4453008770942688


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5212079882621765


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.6071828603744507


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5400530099868774


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.25630491971969604


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.37072041630744934


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.43587160110473633


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.24183981120586395


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.3566038906574249


Epoch 6:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.9427809119224548


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5872857570648193


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.9319219589233398


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.43632665276527405


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.5119352340698242





Epoch 6 Validation Accuracy: 0.7783251231527094, F1-macro: 0.5538025692375323


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.2014576941728592


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.5307000875473022


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.2943628430366516


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.22634565830230713


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.7839301824569702


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.36825841665267944


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5699021220207214


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.402920126914978


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5451096892356873


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.3654346168041229


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.39479440450668335


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.4050471782684326


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.30045002698898315


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.18682971596717834


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.34116852283477783


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.6506538987159729


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5223534107208252


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.527518093585968


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.3874967098236084


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.7309751510620117


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.8638402223587036


Epoch 7:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.3181455433368683


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.6988265514373779


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5203967690467834


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5421240329742432


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.3403645157814026





Epoch 7 Validation Accuracy: 0.7093596059113301, F1-macro: 0.5532470439031668


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.4409165680408478


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.30700328946113586


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.4301944673061371


Epoch 8:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.5637531280517578


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.5476216673851013


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4022454619407654


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.40063679218292236


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.2128826081752777


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.38890540599823


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.49260836839675903


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.47585779428482056


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5015017986297607


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.4223363399505615


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5286614894866943


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.32808399200439453


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5945337414741516


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.492262065410614


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.4668973684310913


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.4268971085548401


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.27425557374954224


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5284339189529419


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.5672687292098999


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4153665602207184


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6502389907836914


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4520733654499054


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5509653687477112





Epoch 8 Validation Accuracy: 0.7389162561576355, F1-macro: 0.5021056041464205


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.46223655343055725


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.27706027030944824


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.3083742558956146


Epoch 9:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.2604454755783081


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6302690505981445


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.19695034623146057


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.46913760900497437


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.2742736041545868


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.27555525302886963


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.3563418984413147


Epoch 9:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.24351109564304352


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6635037660598755


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.47108688950538635


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3842412233352661


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.4301033616065979


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.2809159755706787


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.4936921000480652


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3113526701927185


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5379087924957275


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6892321109771729


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6230624914169312


Epoch 9:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.5850590467453003


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4526224434375763


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.3287851810455322


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4597957730293274


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.17066967487335205





Epoch 9 Validation Accuracy: 0.7980295566502463, F1-macro: 0.5688681688681689


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.25088661909103394


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.20655596256256104


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.44540339708328247


Epoch 10:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.295808881521225


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.2800682783126831


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.36258482933044434


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3144022822380066


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4331991672515869


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.3215753734111786


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.48301154375076294


Epoch 10:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.42014363408088684


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.38693904876708984


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5074863433837891


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.42862668633461


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.2168656587600708


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.48615169525146484


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.41336768865585327


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.41691574454307556


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.428270548582077


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.4129948914051056


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.33300545811653137


Epoch 10:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.4207206964492798


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.6454613208770752


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.36608216166496277


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.4486496448516846


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.1349002122879028





Epoch 10 Validation Accuracy: 0.8325123152709359, F1-macro: 0.5064359267734554


In [None]:
# Final held-out check: score the trained classifier on the test split and
# report accuracy together with macro-averaged F1.
# NOTE(review): in the saved run this F1-macro is identical, digit for digit, to
# the epoch-10 validation F1-macro although the accuracies differ — worth
# confirming that evaluate() derives F1 from the dataloader it is handed.
test_metrics = evaluate(model, test_dataloader, device)
print(
    f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}"
)

Test Accuracy: 0.8676470588235294, F1-macro: 0.5064359267734554


In [None]:
# Write the test metrics into every results_df row whose distortion_type
# matches current_type (rows that do not match are left untouched).
current_type = 'magnification'

results_df.loc[
    results_df["distortion_type"].eq(current_type),
    ["test_accuracy", "f1_macro"],
] = [test_metrics[metric] for metric in ('accuracy', 'f1_macro')]

# 8. Mind Reading

In [None]:
# Build the labelled dataset for one distortion type, split it 80/10/10 and
# prepare the label embeddings used by the classifier.
label_column = 'mind reading'

# A single seeded generator drives the split, the training shuffle and the
# label-embedding initialisation so a fresh "Run All" reproduces the same run.
SPLIT_SEED = 42
rng = torch.Generator().manual_seed(SPLIT_SEED)

# Add labels
# The deduplicated text Series are indexed with the row labels of their source
# frames, so .loc fetches the label belonging to each remaining text.
data1_1_labels = list(data1.loc[data1_1.index, label_column])
data2_1_labels = list(data2.loc[data2_1.index, label_column])
data3_1_labels = list(data3.loc[data3_1.index, label_column])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

# Every encoded sample needs exactly one label; fail loudly on misalignment.
assert len(data_encoded) == len(data_labels), "encoded samples and labels are misaligned"

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting (test takes the remainder so sizes sum exactly)
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Split the dataset (seeded, so the same samples land in the same split each run)
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=rng,
)

# Create DataLoaders for each set; only the training loader shuffles
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=rng)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

input_dim = bert_config.hidden_size
# One embedding row per distinct label value.
# NOTE(review): assumes the labels are the integers 0..num_labels-1, which is
# what nn.CrossEntropyLoss expects as class indices — confirm for this column.
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=rng)

In [None]:
# Loss: multi-class cross-entropy applied to the classifier's logits
criterion = nn.CrossEntropyLoss()

# Classifier head built from the label embeddings, then placed on the training device
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# AdamW updates the classifier's own parameters only (bert_model is not passed in)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
# Train the classifier head on [CLS] embeddings from bert_model, validating
# after every epoch.
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    model.train()
    # BERT is used purely as a feature extractor (no gradients flow into it
    # below); eval mode keeps dropout off so its embeddings are deterministic.
    # Re-applied each epoch in case evaluate() changes the mode.
    bert_model.eval()

    running_loss = 0.0
    samples_seen = 0
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar instead of printing one line
        # per batch, which buries the notebook under thousands of log lines.
        batch_loss = loss.item()
        running_loss += batch_loss * y.size(0)
        samples_seen += y.size(0)
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # Sample-weighted mean so a smaller final batch does not skew the epoch figure
    print(f"Epoch {epoch+1} Train Loss: {running_loss / max(samples_seen, 1)}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.88s/it]

Loss: 2.9873878955841064


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 16.783336639404297


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 2.747978687286377


Epoch 1:  15%|█▌        | 4/26 [00:07<00:41,  1.90s/it]

Loss: 3.711137533187866


Epoch 1:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 6.734628200531006


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.90s/it]

Loss: 11.721153259277344


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 9.924911499023438


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 6.926122665405273


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 9.229948043823242


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 11.236526489257812


Epoch 1:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 5.6506452560424805


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 4.670531272888184


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.9895316362380981


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 4.978365898132324


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 6.11867094039917


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 4.3163838386535645


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.8663723468780518


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 3.11868953704834


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 4.095675468444824


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 5.248601913452148


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 4.492910385131836


Epoch 1:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 4.211960792541504


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 4.713515281677246


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 2.067028045654297


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.6189348697662354


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 2.447408676147461





Epoch 1 Validation Accuracy: 0.541871921182266, F1-macro: 0.5147410358565737


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.6424283981323242


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 2.3374688625335693


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.8695341348648071


Epoch 2:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 1.1615052223205566


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.2907143831253052


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.8024840354919434


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.7748382091522217


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 2.474282741546631


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.9761592149734497


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.6321622133255005


Epoch 2:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.9265379905700684


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.1452038288116455


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.239227533340454


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6796367764472961


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.9221755266189575


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.1648447513580322


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5960458517074585


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.2745249271392822


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.3891892433166504


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6490066647529602


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.8240312337875366


Epoch 2:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.2871251106262207


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.5271344184875488


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.49826279282569885


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.1654213666915894


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5716024041175842





Epoch 2 Validation Accuracy: 0.6354679802955665, F1-macro: 0.5276100628930818


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.7717134952545166


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7396944761276245


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.619899332523346


Epoch 3:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.870506227016449


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4872632920742035


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5307095050811768


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.0183898210525513


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.209704875946045


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.6784113645553589


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.9417917728424072


Epoch 3:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.170961856842041


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6635226011276245


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.6151567697525024


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.104972004890442


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.4865633249282837


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 1.6572579145431519


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.9051185250282288


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.0304714441299438


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.1059677600860596


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6211480498313904


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.3168390989303589


Epoch 3:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.7758535146713257


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 1.4403218030929565


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.9746054410934448


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.9637441039085388


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.7609134316444397





Epoch 3 Validation Accuracy: 0.6995073891625616, F1-macro: 0.5610888597455074


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.5730301141738892


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7210158109664917


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.688377857208252


Epoch 4:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 1.0112473964691162


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6063519716262817


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.9410876035690308


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.2129557132720947


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4456067681312561


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 1.3269429206848145


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.8889816403388977


Epoch 4:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.8330955505371094


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5427824258804321


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.9340602159500122


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6165969371795654


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.7035799026489258


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.9340308904647827


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.1405763626098633


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5450117588043213


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5254533886909485


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.555091142654419


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.4802018105983734


Epoch 4:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.2197949886322021


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5611087679862976


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.8941527605056763


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.010183572769165


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.2053061723709106





Epoch 4 Validation Accuracy: 0.7438423645320197, F1-macro: 0.5050637659414854


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.6006544232368469


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.49733954668045044


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.586387574672699


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.6460105180740356


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.5240355730056763


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5986462831497192


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.7233669757843018


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.6122134327888489


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.3909326195716858


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6465482711791992


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.40270382165908813


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6410016417503357


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8833451271057129


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6242934465408325


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.6666685342788696


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.7923409938812256


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6121875643730164


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.8444907069206238


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.9580085277557373


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.3244782090187073


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.902641773223877


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.906571090221405


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.7736325860023499


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.957485556602478


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.6666905879974365


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.3738188147544861





Epoch 5 Validation Accuracy: 0.6354679802955665, F1-macro: 0.5664396213345648


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.5529443621635437


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7309290170669556


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.3818672299385071


Epoch 6:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.4718599319458008


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.2900271415710449


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.6080721616744995


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6323574781417847


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3297685980796814


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4540680944919586


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.1927192360162735


Epoch 6:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.6141976714134216


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.48362332582473755


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.3133885860443115


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.4794142246246338


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.37001317739486694


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.6743125319480896


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5463582277297974


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.29004570841789246


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.3587627708911896


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.36627936363220215


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6811375617980957


Epoch 6:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.8036065101623535


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.6391710042953491


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6616564989089966


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5937381386756897


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.2245490550994873





Epoch 6 Validation Accuracy: 0.7142857142857143, F1-macro: 0.5207587105177467


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.37713637948036194


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.6025168895721436


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.2788858711719513


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.43532538414001465


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4893539845943451


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.6924228072166443


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.785942018032074


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.5688502192497253


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4841673672199249


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.21865998208522797


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.3854963183403015


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.2641430199146271


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.34073564410209656


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3347904682159424


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.40540140867233276


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.7883373498916626


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6954253911972046


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.2773456871509552


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.8594381809234619


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5279874205589294


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6019710302352905


Epoch 7:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.406634658575058


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.57032310962677


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.4962749183177948


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7574830055236816


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.8983039855957031





Epoch 7 Validation Accuracy: 0.7339901477832512, F1-macro: 0.5230595196658545


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.2458477020263672


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.9640829563140869


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.30712494254112244


Epoch 8:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.5055091381072998


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.5650674104690552


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.514018177986145


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.727940559387207


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4238896369934082


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5058252811431885


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.39664414525032043


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.3648260235786438


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.4353429079055786


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.6780911684036255


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3689180910587311


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.4550643861293793


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3432241976261139


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6266213655471802


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.8318151235580444


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5346708297729492


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5421384572982788


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.8645792007446289


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.7743618488311768


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.5086440443992615


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6640837788581848


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.3053835332393646


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.9598106741905212





Epoch 8 Validation Accuracy: 0.7487684729064039, F1-macro: 0.4943095784692033


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.22055678069591522


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.7579830884933472


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.2360858917236328


Epoch 9:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.7515139579772949


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6505343317985535


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.2447710037231445


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5079065561294556


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.188506841659546


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4424954652786255


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.047867774963379


Epoch 9:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.454294204711914


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.2384790182113647


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.21579711139202118


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.1998891830444336


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.8994961380958557


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5460013151168823


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.1864542961120605


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.563625156879425


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.35797011852264404


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.46460744738578796


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.2744526863098145


Epoch 9:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.3881473541259766


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.9422550201416016


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6196693778038025


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5626071691513062


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 2.0247576236724854





Epoch 9 Validation Accuracy: 0.7389162561576355, F1-macro: 0.5145950823370178


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.7720204591751099


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.9555809497833252


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.6498369574546814


Epoch 10:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.6342319250106812


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4535999894142151


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.270136833190918


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.5723767280578613


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.5417965650558472


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.9951167106628418


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.789286732673645


Epoch 10:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.4029916226863861


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.0910913944244385


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.40966734290122986


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5005422830581665


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.7619900107383728


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.6705589294433594


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5491143465042114


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.356773853302002


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6592299938201904


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.43403762578964233


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6319849491119385


Epoch 10:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.7907341122627258


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.40795063972473145


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6909492015838623


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7020993232727051


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.673779308795929





Epoch 10 Validation Accuracy: 0.7586206896551724, F1-macro: 0.5512294157455448


In [None]:
# Evaluate on the test set
# Final held-out evaluation of the classifier trained in the loop above. The result is
# read below as a mapping with 'accuracy' and 'f1_macro' entries.
# NOTE(review): in the recorded output the test F1-macro is digit-for-digit identical to
# the epoch-10 validation F1-macro while the accuracy differs — confirm that evaluate()
# computes F1 from the dataloader it is given rather than from stale state.
test_metrics = evaluate(model, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.803921568627451, F1-macro: 0.5512294157455448


In [None]:
# Store this section's test metrics in the shared results table.
current_type = 'mind reading'

# Boolean mask selecting the summary row(s) for this distortion type.
type_mask = results_df["distortion_type"] == current_type

# An all-False mask selects no rows, so the assignment below would store nothing and the
# metrics would be lost without any visible error (e.g. after a typo in `current_type`).
if not type_mask.any():
    raise KeyError(f"results_df has no row with distortion_type == {current_type!r}")

results_df.loc[type_mask, ["test_accuracy", "f1_macro"]] = [
    test_metrics['accuracy'],
    test_metrics['f1_macro'],
]

# 9. Overgeneralizing

In [None]:
# Add labels
# Look up the label column for the rows referenced by each text Series' index.
data1_1_labels = list(data1['overgeneralizing'][data1_1.index]) # the column is named differently in data1
data2_1_labels = list(data2['overgeneralization'][data2_1.index])
data3_1_labels = list(data3['overgeneralization'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting (80% train / 10% validation / remainder test)
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Dedicated, seeded RNG so that the split, the initial label embeddings and the training
# shuffle order are the same on every "Restart & Run All" instead of depending on
# whatever the global RNG state happens to be when this cell runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create DataLoaders for each set (only the training loader is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=split_generator)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, sized to the BERT hidden state
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_generator)

In [None]:
# Loss function: cross-entropy over the classifier logits
criterion = nn.CrossEntropyLoss()

# Build the InnerProductClassifier from the label embeddings, then place it on the target device
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Optimizer: AdamW over the classifier's parameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-4,
)

In [None]:
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# The encoder is used purely as a frozen feature extractor (its forward pass runs under
# torch.no_grad() below), so make sure it is in eval mode and dropout is disabled.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()

    # Running totals used to report one mean training loss per epoch.
    epoch_loss_sum = 0.0
    epoch_batches = 0

    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients flow into the encoder)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar itself; a print() per batch breaks the
        # bar onto a new line every step and floods the cell output.
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        epoch_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # max(..., 1) guards against an empty training loader
    print(f"Epoch {epoch+1} Mean Training Loss: {epoch_loss_sum / max(epoch_batches, 1)}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:42,  1.71s/it]

Loss: 1.657505750656128


Epoch 1:   8%|▊         | 2/26 [00:03<00:40,  1.68s/it]

Loss: 6.536026954650879


Epoch 1:  12%|█▏        | 3/26 [00:05<00:38,  1.67s/it]

Loss: 4.204126834869385


Epoch 1:  15%|█▌        | 4/26 [00:06<00:36,  1.67s/it]

Loss: 1.8255808353424072


Epoch 1:  19%|█▉        | 5/26 [00:08<00:35,  1.67s/it]

Loss: 12.133673667907715


Epoch 1:  23%|██▎       | 6/26 [00:10<00:33,  1.67s/it]

Loss: 3.252427101135254


Epoch 1:  27%|██▋       | 7/26 [00:11<00:31,  1.68s/it]

Loss: 2.714059829711914


Epoch 1:  31%|███       | 8/26 [00:13<00:30,  1.68s/it]

Loss: 2.823896884918213


Epoch 1:  35%|███▍      | 9/26 [00:15<00:28,  1.68s/it]

Loss: 5.253157615661621


Epoch 1:  38%|███▊      | 10/26 [00:16<00:26,  1.69s/it]

Loss: 4.7388715744018555


Epoch 1:  42%|████▏     | 11/26 [00:18<00:25,  1.68s/it]

Loss: 6.660444259643555


Epoch 1:  46%|████▌     | 12/26 [00:20<00:23,  1.68s/it]

Loss: 4.923271179199219


Epoch 1:  50%|█████     | 13/26 [00:21<00:21,  1.69s/it]

Loss: 5.979859828948975


Epoch 1:  54%|█████▍    | 14/26 [00:23<00:20,  1.69s/it]

Loss: 6.840394973754883


Epoch 1:  58%|█████▊    | 15/26 [00:25<00:18,  1.69s/it]

Loss: 5.415790557861328


Epoch 1:  62%|██████▏   | 16/26 [00:26<00:16,  1.70s/it]

Loss: 2.182864189147949


Epoch 1:  65%|██████▌   | 17/26 [00:28<00:15,  1.70s/it]

Loss: 1.4134105443954468


Epoch 1:  69%|██████▉   | 18/26 [00:30<00:13,  1.70s/it]

Loss: 4.534759044647217


Epoch 1:  73%|███████▎  | 19/26 [00:32<00:11,  1.70s/it]

Loss: 3.447801113128662


Epoch 1:  77%|███████▋  | 20/26 [00:33<00:10,  1.71s/it]

Loss: 1.504724383354187


Epoch 1:  81%|████████  | 21/26 [00:35<00:08,  1.71s/it]

Loss: 1.9479584693908691


Epoch 1:  85%|████████▍ | 22/26 [00:37<00:06,  1.71s/it]

Loss: 3.9133224487304688


Epoch 1:  88%|████████▊ | 23/26 [00:38<00:05,  1.71s/it]

Loss: 3.099050998687744


Epoch 1:  92%|█████████▏| 24/26 [00:40<00:03,  1.71s/it]

Loss: 1.5850259065628052


Epoch 1:  96%|█████████▌| 25/26 [00:42<00:01,  1.71s/it]

Loss: 1.6374300718307495


Epoch 1: 100%|██████████| 26/26 [00:43<00:00,  1.66s/it]

Loss: 1.3731634616851807





Epoch 1 Validation Accuracy: 0.7044334975369458, F1-macro: 0.5907258064516129


Epoch 2:   4%|▍         | 1/26 [00:01<00:42,  1.71s/it]

Loss: 1.3131194114685059


Epoch 2:   8%|▊         | 2/26 [00:03<00:41,  1.72s/it]

Loss: 2.0387206077575684


Epoch 2:  12%|█▏        | 3/26 [00:05<00:39,  1.72s/it]

Loss: 1.460680365562439


Epoch 2:  15%|█▌        | 4/26 [00:06<00:37,  1.72s/it]

Loss: 0.6026325821876526


Epoch 2:  19%|█▉        | 5/26 [00:08<00:36,  1.72s/it]

Loss: 1.8378205299377441


Epoch 2:  23%|██▎       | 6/26 [00:10<00:34,  1.72s/it]

Loss: 1.056325912475586


Epoch 2:  27%|██▋       | 7/26 [00:12<00:32,  1.72s/it]

Loss: 1.480748176574707


Epoch 2:  31%|███       | 8/26 [00:13<00:31,  1.73s/it]

Loss: 0.9045896530151367


Epoch 2:  35%|███▍      | 9/26 [00:15<00:29,  1.73s/it]

Loss: 1.15848708152771


Epoch 2:  38%|███▊      | 10/26 [00:17<00:27,  1.73s/it]

Loss: 1.2970315217971802


Epoch 2:  42%|████▏     | 11/26 [00:19<00:26,  1.74s/it]

Loss: 0.9758695363998413


Epoch 2:  46%|████▌     | 12/26 [00:20<00:24,  1.74s/it]

Loss: 1.147385835647583


Epoch 2:  50%|█████     | 13/26 [00:22<00:22,  1.74s/it]

Loss: 1.207533836364746


Epoch 2:  54%|█████▍    | 14/26 [00:24<00:20,  1.75s/it]

Loss: 0.8360916376113892


Epoch 2:  58%|█████▊    | 15/26 [00:26<00:19,  1.75s/it]

Loss: 1.2368900775909424


Epoch 2:  62%|██████▏   | 16/26 [00:27<00:17,  1.75s/it]

Loss: 1.2721903324127197


Epoch 2:  65%|██████▌   | 17/26 [00:29<00:15,  1.75s/it]

Loss: 1.257763385772705


Epoch 2:  69%|██████▉   | 18/26 [00:31<00:14,  1.76s/it]

Loss: 1.365992784500122


Epoch 2:  73%|███████▎  | 19/26 [00:33<00:12,  1.76s/it]

Loss: 0.6055692434310913


Epoch 2:  77%|███████▋  | 20/26 [00:34<00:10,  1.76s/it]

Loss: 0.8397058248519897


Epoch 2:  81%|████████  | 21/26 [00:36<00:08,  1.76s/it]

Loss: 1.4946489334106445


Epoch 2:  85%|████████▍ | 22/26 [00:38<00:07,  1.77s/it]

Loss: 0.6358153820037842


Epoch 2:  88%|████████▊ | 23/26 [00:40<00:05,  1.77s/it]

Loss: 1.0872011184692383


Epoch 2:  92%|█████████▏| 24/26 [00:41<00:03,  1.77s/it]

Loss: 0.9156894087791443


Epoch 2:  96%|█████████▌| 25/26 [00:43<00:01,  1.77s/it]

Loss: 1.0651261806488037


Epoch 2: 100%|██████████| 26/26 [00:44<00:00,  1.71s/it]

Loss: 0.8272861242294312





Epoch 2 Validation Accuracy: 0.7931034482758621, F1-macro: 0.4647162230035158


Epoch 3:   4%|▍         | 1/26 [00:01<00:44,  1.79s/it]

Loss: 0.8477904796600342


Epoch 3:   8%|▊         | 2/26 [00:03<00:43,  1.80s/it]

Loss: 0.6361202597618103


Epoch 3:  12%|█▏        | 3/26 [00:05<00:41,  1.79s/it]

Loss: 1.0991325378417969


Epoch 3:  15%|█▌        | 4/26 [00:07<00:39,  1.80s/it]

Loss: 0.9991301894187927


Epoch 3:  19%|█▉        | 5/26 [00:08<00:37,  1.80s/it]

Loss: 0.5888310074806213


Epoch 3:  23%|██▎       | 6/26 [00:10<00:36,  1.80s/it]

Loss: 0.7081252932548523


Epoch 3:  27%|██▋       | 7/26 [00:12<00:34,  1.81s/it]

Loss: 1.4088537693023682


Epoch 3:  31%|███       | 8/26 [00:14<00:32,  1.81s/it]

Loss: 0.7027519941329956


Epoch 3:  35%|███▍      | 9/26 [00:16<00:30,  1.81s/it]

Loss: 0.5796459913253784


Epoch 3:  38%|███▊      | 10/26 [00:18<00:29,  1.81s/it]

Loss: 0.8643797039985657


Epoch 3:  42%|████▏     | 11/26 [00:19<00:27,  1.82s/it]

Loss: 1.2425899505615234


Epoch 3:  46%|████▌     | 12/26 [00:21<00:25,  1.81s/it]

Loss: 0.49866148829460144


Epoch 3:  50%|█████     | 13/26 [00:23<00:23,  1.81s/it]

Loss: 0.7189621925354004


Epoch 3:  54%|█████▍    | 14/26 [00:25<00:21,  1.80s/it]

Loss: 0.8566776514053345


Epoch 3:  58%|█████▊    | 15/26 [00:27<00:19,  1.80s/it]

Loss: 0.863457441329956


Epoch 3:  62%|██████▏   | 16/26 [00:28<00:18,  1.80s/it]

Loss: 0.506223738193512


Epoch 3:  65%|██████▌   | 17/26 [00:30<00:16,  1.81s/it]

Loss: 0.5796521902084351


Epoch 3:  69%|██████▉   | 18/26 [00:32<00:14,  1.81s/it]

Loss: 0.5478360056877136


Epoch 3:  73%|███████▎  | 19/26 [00:34<00:12,  1.81s/it]

Loss: 0.7544068694114685


Epoch 3:  77%|███████▋  | 20/26 [00:36<00:10,  1.81s/it]

Loss: 0.6816093325614929


Epoch 3:  81%|████████  | 21/26 [00:37<00:09,  1.81s/it]

Loss: 0.7674964666366577


Epoch 3:  85%|████████▍ | 22/26 [00:39<00:07,  1.81s/it]

Loss: 1.0914424657821655


Epoch 3:  88%|████████▊ | 23/26 [00:41<00:05,  1.82s/it]

Loss: 0.697940468788147


Epoch 3:  92%|█████████▏| 24/26 [00:43<00:03,  1.82s/it]

Loss: 0.641213059425354


Epoch 3:  96%|█████████▌| 25/26 [00:45<00:01,  1.82s/it]

Loss: 0.5481401681900024


Epoch 3: 100%|██████████| 26/26 [00:46<00:00,  1.77s/it]

Loss: 0.34813934564590454





Epoch 3 Validation Accuracy: 0.7881773399014779, F1-macro: 0.561202433016639


Epoch 4:   4%|▍         | 1/26 [00:01<00:45,  1.83s/it]

Loss: 0.6547352075576782


Epoch 4:   8%|▊         | 2/26 [00:03<00:44,  1.83s/it]

Loss: 0.699188232421875


Epoch 4:  12%|█▏        | 3/26 [00:05<00:42,  1.84s/it]

Loss: 1.0426771640777588


Epoch 4:  15%|█▌        | 4/26 [00:07<00:40,  1.84s/it]

Loss: 0.7493646144866943


Epoch 4:  19%|█▉        | 5/26 [00:09<00:38,  1.84s/it]

Loss: 0.5382634401321411


Epoch 4:  23%|██▎       | 6/26 [00:11<00:36,  1.84s/it]

Loss: 0.9880825281143188


Epoch 4:  27%|██▋       | 7/26 [00:12<00:35,  1.84s/it]

Loss: 0.3957729637622833


Epoch 4:  31%|███       | 8/26 [00:14<00:33,  1.85s/it]

Loss: 0.8406367301940918


Epoch 4:  35%|███▍      | 9/26 [00:16<00:31,  1.85s/it]

Loss: 0.7308945059776306


Epoch 4:  38%|███▊      | 10/26 [00:18<00:29,  1.85s/it]

Loss: 0.7162966728210449


Epoch 4:  42%|████▏     | 11/26 [00:20<00:27,  1.85s/it]

Loss: 0.5991370677947998


Epoch 4:  46%|████▌     | 12/26 [00:22<00:25,  1.85s/it]

Loss: 0.6280054450035095


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.85s/it]

Loss: 0.7625551223754883


Epoch 4:  54%|█████▍    | 14/26 [00:25<00:22,  1.86s/it]

Loss: 1.357206106185913


Epoch 4:  58%|█████▊    | 15/26 [00:27<00:20,  1.86s/it]

Loss: 1.1538307666778564


Epoch 4:  62%|██████▏   | 16/26 [00:29<00:18,  1.86s/it]

Loss: 0.7618488669395447


Epoch 4:  65%|██████▌   | 17/26 [00:31<00:16,  1.86s/it]

Loss: 0.5715628862380981


Epoch 4:  69%|██████▉   | 18/26 [00:33<00:14,  1.86s/it]

Loss: 0.824925422668457


Epoch 4:  73%|███████▎  | 19/26 [00:35<00:13,  1.86s/it]

Loss: 1.1007297039031982


Epoch 4:  77%|███████▋  | 20/26 [00:37<00:11,  1.87s/it]

Loss: 0.8139176368713379


Epoch 4:  81%|████████  | 21/26 [00:38<00:09,  1.87s/it]

Loss: 1.2785825729370117


Epoch 4:  85%|████████▍ | 22/26 [00:40<00:07,  1.87s/it]

Loss: 0.7034227848052979


Epoch 4:  88%|████████▊ | 23/26 [00:42<00:05,  1.87s/it]

Loss: 0.38635343313217163


Epoch 4:  92%|█████████▏| 24/26 [00:44<00:03,  1.87s/it]

Loss: 0.7506933808326721


Epoch 4:  96%|█████████▌| 25/26 [00:46<00:01,  1.87s/it]

Loss: 0.7624847888946533


Epoch 4: 100%|██████████| 26/26 [00:47<00:00,  1.82s/it]

Loss: 0.6269463300704956





Epoch 4 Validation Accuracy: 0.6551724137931034, F1-macro: 0.5635749385749386


Epoch 5:   4%|▍         | 1/26 [00:01<00:46,  1.88s/it]

Loss: 0.7192707061767578


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.88s/it]

Loss: 0.5618783831596375


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.88s/it]

Loss: 0.7496674060821533


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.88s/it]

Loss: 0.4479106366634369


Epoch 5:  19%|█▉        | 5/26 [00:09<00:39,  1.89s/it]

Loss: 0.7516856789588928


Epoch 5:  23%|██▎       | 6/26 [00:11<00:37,  1.89s/it]

Loss: 0.6023619174957275


Epoch 5:  27%|██▋       | 7/26 [00:13<00:35,  1.89s/it]

Loss: 0.5700479745864868


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.89s/it]

Loss: 0.43210750818252563


Epoch 5:  35%|███▍      | 9/26 [00:16<00:32,  1.89s/it]

Loss: 0.49802422523498535


Epoch 5:  38%|███▊      | 10/26 [00:18<00:30,  1.89s/it]

Loss: 0.3846628665924072


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.90s/it]

Loss: 0.6336958408355713


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.90s/it]

Loss: 0.5761820673942566


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.90s/it]

Loss: 0.7877520322799683


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.90s/it]

Loss: 0.4022252559661865


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.90s/it]

Loss: 1.636815071105957


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.90s/it]

Loss: 0.7970861196517944


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6600724458694458


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7655812501907349


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.8441262245178223


Epoch 5:  77%|███████▋  | 20/26 [00:37<00:11,  1.91s/it]

Loss: 0.8635995984077454


Epoch 5:  81%|████████  | 21/26 [00:39<00:09,  1.91s/it]

Loss: 0.5659265518188477


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.92s/it]

Loss: 0.5224575400352478


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.92s/it]

Loss: 0.6483087539672852


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.92s/it]

Loss: 0.465619832277298


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.92s/it]

Loss: 0.5769513845443726


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.86s/it]

Loss: 0.6743192672729492





Epoch 5 Validation Accuracy: 0.5024630541871922, F1-macro: 0.47032989743987186


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.136136531829834


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.5486981868743896


Epoch 6:  12%|█▏        | 3/26 [00:05<00:44,  1.91s/it]

Loss: 0.7041075825691223


Epoch 6:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 1.087554931640625


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.30776727199554443


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4838438034057617


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.4184315800666809


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.6641153693199158


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.40797585248947144


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.2584243416786194


Epoch 6:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.5114433169364929


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5582347512245178


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.8644964098930359


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5016258955001831


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.49922922253608704


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.7614097595214844


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.35938704013824463


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.90s/it]

Loss: 0.9060907363891602


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.90s/it]

Loss: 0.3278552293777466


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.90s/it]

Loss: 0.882207989692688


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.90s/it]

Loss: 0.403812438249588


Epoch 6:  85%|████████▍ | 22/26 [00:41<00:07,  1.90s/it]

Loss: 0.43207207322120667


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.90s/it]

Loss: 0.7015318870544434


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.90s/it]

Loss: 0.6986759305000305


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.90s/it]

Loss: 0.25372573733329773


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.7195907831192017





Epoch 6 Validation Accuracy: 0.8177339901477833, F1-macro: 0.5188032545326414


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.43289047479629517


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.5551241636276245


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.6584416031837463


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.90s/it]

Loss: 0.3886095881462097


Epoch 7:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.4510555565357208


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.90s/it]

Loss: 0.5441249012947083


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.90s/it]

Loss: 0.48643651604652405


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.90s/it]

Loss: 0.47106772661209106


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.90s/it]

Loss: 0.7073904871940613


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.90s/it]

Loss: 0.4833637773990631


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.90s/it]

Loss: 0.36467301845550537


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.90s/it]

Loss: 0.2812909781932831


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.90s/it]

Loss: 0.4027464985847473


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.90s/it]

Loss: 0.5741031169891357


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:20,  1.90s/it]

Loss: 0.46104347705841064


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.4811353087425232


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.90s/it]

Loss: 0.3318479061126709


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.90s/it]

Loss: 0.29488405585289


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.90s/it]

Loss: 0.28527143597602844


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.48194408416748047


Epoch 7:  81%|████████  | 21/26 [00:39<00:09,  1.90s/it]

Loss: 0.5475896596908569


Epoch 7:  85%|████████▍ | 22/26 [00:41<00:07,  1.90s/it]

Loss: 0.5979914665222168


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.8220226764678955


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.41318103671073914


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.565233051776886


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.86s/it]

Loss: 0.6147609353065491





Epoch 7 Validation Accuracy: 0.8078817733990148, F1-macro: 0.5307888342322052


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.6738172769546509


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.3121791183948517


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.481465607881546


Epoch 8:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.5356380939483643


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.36285609006881714


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.7633683085441589


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3494718074798584


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4851865768432617


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4120796322822571


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.18066152930259705


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.4298693537712097


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5634798407554626


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.28279566764831543


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.7117403149604797


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.6211825013160706


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3143458664417267


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.7533478736877441


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.5434372425079346


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.634558379650116


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.6076721549034119


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.702229380607605


Epoch 8:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.4299502968788147


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.6674308180809021


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6047710180282593


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.32951676845550537


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.7474647760391235





Epoch 8 Validation Accuracy: 0.7881773399014779, F1-macro: 0.501000400160064


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.7504377365112305


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.36065351963043213


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.3609553277492523


Epoch 9:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.25527751445770264


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.4854903817176819


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5000091791152954


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3375281095504761


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3843074440956116


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.3107960522174835


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.629743218421936


Epoch 9:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.8353604078292847


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.27736330032348633


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.6564840078353882


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.6202782392501831


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.33220306038856506


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.2735491991043091


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5924687385559082


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.4947638511657715


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5222413539886475


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.7913207411766052


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6411073803901672


Epoch 9:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.4567495286464691


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.6212248802185059


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.92s/it]

Loss: 1.0161947011947632


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7069445252418518


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.4424806237220764





Epoch 9 Validation Accuracy: 0.8078817733990148, F1-macro: 0.44686648501362397


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 1.2610855102539062


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.8830419778823853


Epoch 10:  12%|█▏        | 3/26 [00:05<00:44,  1.91s/it]

Loss: 1.494279384613037


Epoch 10:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.28375405073165894


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.226658821105957


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.92s/it]

Loss: 0.6919376254081726


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.92s/it]

Loss: 0.6332236528396606


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4507962465286255


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.92s/it]

Loss: 1.3549563884735107


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.92s/it]

Loss: 0.5269737243652344


Epoch 10:  42%|████▏     | 11/26 [00:21<00:28,  1.92s/it]

Loss: 0.7415227890014648


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.92s/it]

Loss: 0.6940248012542725


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.0554434061050415


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.92s/it]

Loss: 0.6270861625671387


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.4281952381134033


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.322698712348938


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.30196020007133484


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.8937848806381226


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.92s/it]

Loss: 0.5609683394432068


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.5151781439781189


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.2514919638633728


Epoch 10:  85%|████████▍ | 22/26 [00:42<00:07,  1.92s/it]

Loss: 0.7358967065811157


Epoch 10:  88%|████████▊ | 23/26 [00:44<00:05,  1.92s/it]

Loss: 1.2874693870544434


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.92s/it]

Loss: 0.6850442886352539


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.92s/it]

Loss: 1.2493101358413696


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5622022151947021





Epoch 10 Validation Accuracy: 0.8177339901477833, F1-macro: 0.5548509452972205


In [None]:
# Evaluate on the test set
# Final held-out evaluation of the classifier trained in the loop above. The result is
# read below as a mapping with 'accuracy' and 'f1_macro' entries.
# NOTE(review): in the recorded output the test F1-macro is digit-for-digit identical to
# the epoch-10 validation F1-macro while the accuracy differs — confirm that evaluate()
# computes F1 from the dataloader it is given rather than from stale state.
test_metrics = evaluate(model, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8186274509803921, F1-macro: 0.5548509452972205


In [None]:
# Store this section's test metrics in the shared results table.
current_type = 'overgeneralizing'

# Boolean mask selecting the summary row(s) for this distortion type.
type_mask = results_df["distortion_type"] == current_type

# An all-False mask selects no rows, so the assignment below would store nothing and the
# metrics would be lost without any visible error (e.g. after a typo in `current_type`).
if not type_mask.any():
    raise KeyError(f"results_df has no row with distortion_type == {current_type!r}")

results_df.loc[type_mask, ["test_accuracy", "f1_macro"]] = [
    test_metrics['accuracy'],
    test_metrics['f1_macro'],
]

# 10. Should Statements

In [None]:
# Add labels
# Look up the label column for the rows referenced by each text Series' index.
data1_1_labels = list(data1['should statements'][data1_1.index])
data2_1_labels = list(data2['should statements'][data2_1.index])
data3_1_labels = list(data3['should statements'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting (80% train / 10% validation / remainder test)
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Dedicated, seeded RNG so that the split, the initial label embeddings and the training
# shuffle order are the same on every "Restart & Run All" instead of depending on
# whatever the global RNG state happens to be when this cell runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create DataLoaders for each set (only the training loader is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=split_generator)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, sized to the BERT hidden state
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_generator)

In [None]:
# Loss function: cross-entropy over the classifier logits
criterion = nn.CrossEntropyLoss()

# Build the InnerProductClassifier from the label embeddings, then place it on the target device
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Optimizer: AdamW over the classifier's parameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-4,
)

In [None]:
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)
# The encoder is used purely as a frozen feature extractor (its forward pass runs under
# torch.no_grad() below), so make sure it is in eval mode and dropout is disabled.
bert_model.eval()

for epoch in range(EPOCHS):
    model.train()

    # Running totals used to report one mean training loss per epoch.
    epoch_loss_sum = 0.0
    epoch_batches = 0

    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients flow into the encoder)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar itself; a print() per batch breaks the
        # bar onto a new line every step and floods the cell output.
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        epoch_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # max(..., 1) guards against an empty training loader
    print(f"Epoch {epoch+1} Mean Training Loss: {epoch_loss_sum / max(epoch_batches, 1)}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.89s/it]

Loss: 1.4640066623687744


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 2.801388740539551


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.6426094770431519


Epoch 1:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 1.2571823596954346


Epoch 1:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 4.027986526489258


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.256540298461914


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.92s/it]

Loss: 2.2885921001434326


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.92s/it]

Loss: 8.412786483764648


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.92s/it]

Loss: 2.1133832931518555


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.92s/it]

Loss: 0.08105773478746414


Epoch 1:  42%|████▏     | 11/26 [00:21<00:28,  1.92s/it]

Loss: 2.600219488143921


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.92s/it]

Loss: 3.3731493949890137


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.92s/it]

Loss: 2.8659114837646484


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.92s/it]

Loss: 6.064884662628174


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:21,  1.92s/it]

Loss: 1.419960856437683


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.92s/it]

Loss: 2.1114349365234375


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.92s/it]

Loss: 2.611464738845825


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.92s/it]

Loss: 3.632119655609131


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.92s/it]

Loss: 3.7102813720703125


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.92s/it]

Loss: 2.825608491897583


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.92s/it]

Loss: 1.465686321258545


Epoch 1:  85%|████████▍ | 22/26 [00:42<00:07,  1.92s/it]

Loss: 0.1625308245420456


Epoch 1:  88%|████████▊ | 23/26 [00:44<00:05,  1.92s/it]

Loss: 3.7154481410980225


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.92s/it]

Loss: 1.242293119430542


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 2.639899730682373


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.88s/it]

Loss: 0.9421493411064148





Epoch 1 Validation Accuracy: 0.8768472906403941, F1-macro: 0.5040555066940291


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 1.1546390056610107


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.92s/it]

Loss: 0.3933939039707184


Epoch 2:  12%|█▏        | 3/26 [00:05<00:44,  1.92s/it]

Loss: 1.253158450126648


Epoch 2:  15%|█▌        | 4/26 [00:07<00:42,  1.92s/it]

Loss: 1.5637502670288086


Epoch 2:  19%|█▉        | 5/26 [00:09<00:40,  1.92s/it]

Loss: 1.502450704574585


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.92s/it]

Loss: 1.2492833137512207


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.4401182532310486


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.34316056966781616


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.596600353717804


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.5552201867103577


Epoch 2:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.6896437406539917


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.92s/it]

Loss: 0.3636954128742218


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.7169504761695862


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.4918428063392639


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:21,  1.92s/it]

Loss: 1.0768766403198242


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.92s/it]

Loss: 1.1791753768920898


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.29128527641296387


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.8273646235466003


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.2798956632614136


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.39034661650657654


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.763490617275238


Epoch 2:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.7336219549179077


Epoch 2:  88%|████████▊ | 23/26 [00:44<00:05,  1.91s/it]

Loss: 0.417108416557312


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.9690066576004028


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5675293207168579


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.1467899084091187





Epoch 2 Validation Accuracy: 0.6798029556650246, F1-macro: 0.5533931291250634


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.794577956199646


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.8976892828941345


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5334936380386353


Epoch 3:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.4547036290168762


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.0917959213256836


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.0158127546310425


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3757621645927429


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3858676254749298


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.4034235179424286


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.6188843250274658


Epoch 3:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.6001810431480408


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.27863454818725586


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.9376555681228638


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.102530837059021


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.8409407734870911


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5763545036315918


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.45221492648124695


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.47077834606170654


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.23945698142051697


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.4972090423107147


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.7112190127372742


Epoch 3:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.20392858982086182


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.27021506428718567


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6958402395248413


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4648463726043701


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.10221409797668457





Epoch 3 Validation Accuracy: 0.8916256157635468, F1-macro: 0.6040780141843972


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.32923898100852966


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.36515071988105774


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.09483855962753296


Epoch 4:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.4786708354949951


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.1679668128490448


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5477207899093628


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.23662245273590088


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.45867419242858887


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.6814966201782227


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.08940320461988449


Epoch 4:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.21453773975372314


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.33782267570495605


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.3498687148094177


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.42578360438346863


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.47810113430023193


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.13099031150341034


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.18760091066360474


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.47527390718460083


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.3597647249698639


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.3248029947280884


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.5177206993103027


Epoch 4:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.5624377131462097


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.491388738155365


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7370895147323608


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.28769418597221375


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.27398520708084106





Epoch 4 Validation Accuracy: 0.896551724137931, F1-macro: 0.5834066256229844


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.4056178629398346


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.1990782618522644


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.24290074408054352


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.30109459161758423


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.34790924191474915


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.28672856092453003


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.1503579318523407


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3191007375717163


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.9092426896095276


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.1958034634590149


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.39114612340927124


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3477707803249359


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.18458899855613708


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.28589221835136414


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.41642701625823975


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.8087517023086548


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.11745864152908325


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.12643544375896454


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.24839213490486145


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.45586782693862915


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.34226179122924805


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.3844078779220581


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3819909691810608


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.23517194390296936


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.4162902235984802


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.1915917694568634





Epoch 5 Validation Accuracy: 0.8916256157635468, F1-macro: 0.5128708551483421


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.3094562888145447


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.22451096773147583


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.2998427450656891


Epoch 6:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.49054282903671265


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.2833305299282074


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.08693281561136246


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6447916030883789


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.18008597195148468


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.24345368146896362


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.2687043249607086


Epoch 6:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.23108097910881042


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6989853978157043


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.282792866230011


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.41914790868759155


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.5097165107727051


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.2512156367301941


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5262503623962402


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.14649337530136108


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.26084819436073303


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.22150015830993652


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.1647220402956009


Epoch 6:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.42294153571128845


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.10497473925352097


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.2877498269081116


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.1787145733833313


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.2511516213417053





Epoch 6 Validation Accuracy: 0.896551724137931, F1-macro: 0.5160631172664321


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.32734501361846924


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.2105432152748108


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.25952816009521484


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.17023740708827972


Epoch 7:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.29634585976600647


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.42989689111709595


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.13603763282299042


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.33766162395477295


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5323776602745056


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.2509669065475464


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.5097602605819702


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.20783588290214539


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.30295315384864807


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.43201398849487305


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.27345287799835205


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.37328651547431946


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.408597469329834


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.6332524418830872


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.2255052626132965


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.24435701966285706


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.670853316783905


Epoch 7:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.2671384811401367


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.20576857030391693


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7128610610961914


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 1.0466439723968506


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.4006160795688629





Epoch 7 Validation Accuracy: 0.7684729064039408, F1-macro: 0.5466381563316702


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.27417778968811035


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.45991963148117065


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.4934041500091553


Epoch 8:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.2765956521034241


Epoch 8:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.08458105474710464


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.0359580516815186


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.21387916803359985


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.5301046371459961


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5370197892189026


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.31328684091567993


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.15685777366161346


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.1265714317560196


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.3659555912017822


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.8444483280181885


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.2115732729434967


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.271187424659729


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.24975243210792542


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.3356407880783081


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.24977809190750122


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.25163358449935913


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.190819650888443


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.17575284838676453


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.13265429437160492


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.14943884313106537


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5682193040847778


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.09381765872240067





Epoch 8 Validation Accuracy: 0.7980295566502463, F1-macro: 0.5816116221786558


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.27286046743392944


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.2799336314201355


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.04623296111822128


Epoch 9:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.17669382691383362


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.6192940473556519


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.5340499877929688


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.3927459716796875


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.24990957975387573


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.11467143148183823


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.5792794227600098


Epoch 9:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.1867508888244629


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.18922647833824158


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.13043838739395142


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3894086480140686


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.08055976778268814


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.8561818599700928


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.6419830322265625


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.308008074760437


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.5223416090011597


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.09605167806148529


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.022701343521475792


Epoch 9:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.8243069648742676


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.35872405767440796


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.2916269600391388


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.72774338722229


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5531620383262634





Epoch 9 Validation Accuracy: 0.7241379310344828, F1-macro: 0.5556597873671044


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.5181573629379272


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.06308099627494812


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.1017027199268341


Epoch 10:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.5977165699005127


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.51601243019104


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.7394294142723083


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.16131168603897095


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.3864450454711914


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.22480344772338867


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.03732297196984291


Epoch 10:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.5599596500396729


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.544436514377594


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.14931732416152954


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.2315296232700348


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.029800295829773


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.11295247822999954


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.8202337026596069


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7045934200286865


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.40150123834609985


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.26661935448646545


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.32999294996261597


Epoch 10:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.6206018924713135


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.13915851712226868


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.6631505489349365


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.42217063903808594


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5678803324699402





Epoch 10 Validation Accuracy: 0.8866995073891626, F1-macro: 0.5437310661585069


In [None]:
# Score the trained classifier on the held-out test split and report the
# accuracy and macro-F1 entries of the returned metrics mapping.
# NOTE(review): the output recorded for this cell shows an F1-macro identical,
# to every digit, to the last validation epoch's F1-macro while the accuracy
# differs -- verify that evaluate() computes F1 from the loader it is given.
test_metrics = evaluate(model, test_dataloader, device)
print(
    "Test Accuracy: {}, F1-macro: {}".format(
        test_metrics['accuracy'],
        test_metrics['f1_macro'],
    )
)

Test Accuracy: 0.9117647058823529, F1-macro: 0.5437310661585069


In [None]:
# Write this distortion type's test metrics into the shared results table.
# NOTE(review): presumably results_df already holds a row for this type; if
# the label does not match any row the assignment likely changes nothing --
# consider asserting that the mask selects at least one row.
current_type = 'should statements'

is_current_type = results_df["distortion_type"] == current_type
results_df.loc[is_current_type, ["test_accuracy", "f1_macro"]] = [
    test_metrics['accuracy'],
    test_metrics['f1_macro'],
]

# 11. Mental Filter

In [None]:
# Build the labelled dataset, the train/val/test loaders and the initial
# label embeddings for the 'mental filter' classifier.

# Add labels
# Pick the 'mental filter' label of the rows referenced by each data*_1 index
data1_1_labels = list(data1['mental filter'][data1_1.index])
data2_1_labels = list(data2['mental filter'][data2_1.index])
data3_1_labels = list(data3['mental filter'][data3_1.index])

# Merging Data
data_encoded = data1_1_encoded + data2_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data2_1_labels + data3_1_labels

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% train, 10% validation, remainder test
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# Dedicated, seeded generator: without it the split membership (and the random
# label embeddings below) depend on the global RNG state and change on every
# run, so the reported metrics cannot be reproduced.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create DataLoaders for each set (only the training loader is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value, with the
# same width as the BERT hidden state
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_generator)

In [None]:
# Loss: cross-entropy computed over the classifier's logits
criterion = nn.CrossEntropyLoss()

# Classifier head: built from input_dim and label_emb, then placed on the
# target device (nn.Module.to returns the module itself)
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Optimizer: AdamW over the classifier's parameters only
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
# Train the classifier head for a fixed number of epochs on [CLS] embeddings
# produced by the BERT encoder, validating after every epoch.
EPOCHS = 10

# Move the BERT model to the device
bert_model.to(device)

for epoch in range(EPOCHS):
    # The encoder's forward pass below runs under no_grad, i.e. it is used as
    # a fixed feature extractor. Put it in eval mode explicitly so dropout is
    # disabled regardless of the mode earlier cells (or evaluate) left it in.
    bert_model.eval()
    model.train()

    running_loss = 0.0
    num_batches = 0

    # Keep a handle on the progress bar so the batch loss can be shown in its
    # postfix instead of printing one line per batch (which breaks the bar and
    # floods the cell output).
    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients flow into the encoder)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss.item() forces a host sync, so read it once and reuse the value
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # One summary line per epoch; guarded so an empty dataloader does not divide by zero
    if num_batches:
        print(f"Epoch {epoch+1} Mean Train Loss: {running_loss / num_batches}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

Epoch 1:   4%|▍         | 1/26 [00:01<00:47,  1.89s/it]

Loss: 0.6224747896194458


Epoch 1:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 3.9320268630981445


Epoch 1:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 2.0346739292144775


Epoch 1:  15%|█▌        | 4/26 [00:07<00:41,  1.90s/it]

Loss: 1.7619720697402954


Epoch 1:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.7500898838043213


Epoch 1:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 1.231073021888733


Epoch 1:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.734997034072876


Epoch 1:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.28590601682662964


Epoch 1:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.8128634691238403


Epoch 1:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 2.326658010482788


Epoch 1:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.47759848833084106


Epoch 1:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.4709866046905518


Epoch 1:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.5071121454238892


Epoch 1:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.7551562786102295


Epoch 1:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 2.034745693206787


Epoch 1:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5716239213943481


Epoch 1:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.1034671068191528


Epoch 1:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 1.6121766567230225


Epoch 1:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.0982016324996948


Epoch 1:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.7170232534408569


Epoch 1:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.9544140100479126


Epoch 1:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.047792911529541


Epoch 1:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 2.1369779109954834


Epoch 1:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.7167640924453735


Epoch 1:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.5382542610168457


Epoch 1: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5437771677970886





Epoch 1 Validation Accuracy: 0.5960591133004927, F1-macro: 0.4631707946336429


Epoch 2:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.0991666316986084


Epoch 2:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.8598349690437317


Epoch 2:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.7449178695678711


Epoch 2:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.3091592490673065


Epoch 2:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.6694560647010803


Epoch 2:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.3921857476234436


Epoch 2:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.26072609424591064


Epoch 2:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.0025373697280884


Epoch 2:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.5542253851890564


Epoch 2:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.24836619198322296


Epoch 2:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.24852073192596436


Epoch 2:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3177831768989563


Epoch 2:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.38930636644363403


Epoch 2:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.8256179094314575


Epoch 2:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.6709551811218262


Epoch 2:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.674775242805481


Epoch 2:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.22896482050418854


Epoch 2:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.26101022958755493


Epoch 2:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.3709447383880615


Epoch 2:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.34235626459121704


Epoch 2:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.8158877491950989


Epoch 2:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.4426514208316803


Epoch 2:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.7035520076751709


Epoch 2:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.33820822834968567


Epoch 2:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.1606869101524353


Epoch 2: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.23761340975761414





Epoch 2 Validation Accuracy: 0.9014778325123153, F1-macro: 0.6164021164021164


Epoch 3:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.18437406420707703


Epoch 3:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.22768637537956238


Epoch 3:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.5318028330802917


Epoch 3:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.5716545581817627


Epoch 3:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.2434412240982056


Epoch 3:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.24564777314662933


Epoch 3:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.7546949982643127


Epoch 3:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.02612007036805153


Epoch 3:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.44589605927467346


Epoch 3:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.806538462638855


Epoch 3:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.6153855323791504


Epoch 3:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 1.1218217611312866


Epoch 3:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.234412431716919


Epoch 3:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 1.056235432624817


Epoch 3:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.680928111076355


Epoch 3:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3899860084056854


Epoch 3:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 1.0653388500213623


Epoch 3:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.585912823677063


Epoch 3:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.7606202960014343


Epoch 3:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 1.222647786140442


Epoch 3:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.2540780007839203


Epoch 3:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.9743219614028931


Epoch 3:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.15910539031028748


Epoch 3:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 1.2159850597381592


Epoch 3:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.6979985237121582


Epoch 3: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.2925736904144287





Epoch 3 Validation Accuracy: 0.9113300492610837, F1-macro: 0.47680412371134023


Epoch 4:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.2159837484359741


Epoch 4:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.47064921259880066


Epoch 4:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.3913637101650238


Epoch 4:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.5530568361282349


Epoch 4:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.26074856519699097


Epoch 4:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.7613478899002075


Epoch 4:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.0444865226745605


Epoch 4:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.1036566495895386


Epoch 4:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.6844237446784973


Epoch 4:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.37733954191207886


Epoch 4:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.4208096265792847


Epoch 4:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5475776195526123


Epoch 4:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.252363681793213


Epoch 4:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.5440936088562012


Epoch 4:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.20809902250766754


Epoch 4:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.1765545904636383


Epoch 4:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.4211960434913635


Epoch 4:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.2717626988887787


Epoch 4:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.6489545106887817


Epoch 4:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.4167841672897339


Epoch 4:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.12467782199382782


Epoch 4:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.5234827399253845


Epoch 4:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3886946141719818


Epoch 4:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.34985536336898804


Epoch 4:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.795791745185852


Epoch 4: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.5989579558372498





Epoch 4 Validation Accuracy: 0.4482758620689655, F1-macro: 0.38082788671023965


Epoch 5:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 1.8504915237426758


Epoch 5:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.18339315056800842


Epoch 5:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.3284051418304443


Epoch 5:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.6041406393051147


Epoch 5:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 1.670546531677246


Epoch 5:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 2.8641197681427


Epoch 5:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 1.1671706438064575


Epoch 5:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 1.2924761772155762


Epoch 5:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.7594970464706421


Epoch 5:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 1.4134525060653687


Epoch 5:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.0508238077163696


Epoch 5:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6201173663139343


Epoch 5:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.7875266075134277


Epoch 5:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.7786862254142761


Epoch 5:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 1.0366851091384888


Epoch 5:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.5978765487670898


Epoch 5:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.8163793683052063


Epoch 5:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.4916089177131653


Epoch 5:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.49385756254196167


Epoch 5:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.7704395651817322


Epoch 5:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 1.0684839487075806


Epoch 5:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 1.3555338382720947


Epoch 5:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.4612368047237396


Epoch 5:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.46477702260017395


Epoch 5:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7421839833259583


Epoch 5: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.2901889979839325





Epoch 5 Validation Accuracy: 0.8719211822660099, F1-macro: 0.55899064171123


Epoch 6:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.6076936721801758


Epoch 6:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.09836888313293457


Epoch 6:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 1.1042134761810303


Epoch 6:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.14197564125061035


Epoch 6:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.1990666687488556


Epoch 6:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4696594476699829


Epoch 6:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.631839394569397


Epoch 6:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.10596782714128494


Epoch 6:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.24667666852474213


Epoch 6:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.41118744015693665


Epoch 6:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 1.311995267868042


Epoch 6:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.3159938156604767


Epoch 6:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 1.0181703567504883


Epoch 6:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.48171424865722656


Epoch 6:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.45562267303466797


Epoch 6:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.971582293510437


Epoch 6:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.7338030338287354


Epoch 6:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.6424034833908081


Epoch 6:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 1.3401528596878052


Epoch 6:  77%|███████▋  | 20/26 [00:38<00:11,  1.90s/it]

Loss: 0.4640291929244995


Epoch 6:  81%|████████  | 21/26 [00:40<00:09,  1.90s/it]

Loss: 0.9228047728538513


Epoch 6:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.4459168314933777


Epoch 6:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.37987184524536133


Epoch 6:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7878046035766602


Epoch 6:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7925603985786438


Epoch 6: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.13776226341724396





Epoch 6 Validation Accuracy: 0.8472906403940886, F1-macro: 0.5794854660875376


Epoch 7:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.33298826217651367


Epoch 7:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.31238698959350586


Epoch 7:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.15235093235969543


Epoch 7:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.2461533099412918


Epoch 7:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.41813960671424866


Epoch 7:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.7750775814056396


Epoch 7:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.07305624336004257


Epoch 7:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.21052774786949158


Epoch 7:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.7036210298538208


Epoch 7:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.08583390712738037


Epoch 7:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.4650647044181824


Epoch 7:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.6091010570526123


Epoch 7:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.550089418888092


Epoch 7:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.2514142096042633


Epoch 7:  58%|█████▊    | 15/26 [00:28<00:20,  1.90s/it]

Loss: 0.32156845927238464


Epoch 7:  62%|██████▏   | 16/26 [00:30<00:19,  1.90s/it]

Loss: 0.24591678380966187


Epoch 7:  65%|██████▌   | 17/26 [00:32<00:17,  1.90s/it]

Loss: 0.5615354776382446


Epoch 7:  69%|██████▉   | 18/26 [00:34<00:15,  1.90s/it]

Loss: 0.29752466082572937


Epoch 7:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.20287267863750458


Epoch 7:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.2906401753425598


Epoch 7:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.2188268005847931


Epoch 7:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.4555782675743103


Epoch 7:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.2543376386165619


Epoch 7:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.5454756021499634


Epoch 7:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.7600970268249512


Epoch 7: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 1.0372098684310913





Epoch 7 Validation Accuracy: 0.896551724137931, F1-macro: 0.4727272727272727


Epoch 8:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.299360990524292


Epoch 8:   8%|▊         | 2/26 [00:03<00:45,  1.90s/it]

Loss: 0.3265790641307831


Epoch 8:  12%|█▏        | 3/26 [00:05<00:43,  1.90s/it]

Loss: 0.5957916975021362


Epoch 8:  15%|█▌        | 4/26 [00:07<00:41,  1.90s/it]

Loss: 0.1745591163635254


Epoch 8:  19%|█▉        | 5/26 [00:09<00:39,  1.90s/it]

Loss: 0.5733409523963928


Epoch 8:  23%|██▎       | 6/26 [00:11<00:38,  1.90s/it]

Loss: 0.7556017637252808


Epoch 8:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.36498555541038513


Epoch 8:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.4318695664405823


Epoch 8:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.3519980311393738


Epoch 8:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.45055538415908813


Epoch 8:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.20624513924121857


Epoch 8:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.5369654893875122


Epoch 8:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.960931658744812


Epoch 8:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.3927844762802124


Epoch 8:  58%|█████▊    | 15/26 [00:28<00:20,  1.91s/it]

Loss: 0.530972957611084


Epoch 8:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.3136925995349884


Epoch 8:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.5581479072570801


Epoch 8:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.7472131848335266


Epoch 8:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.366788387298584


Epoch 8:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.15683580935001373


Epoch 8:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.3650130033493042


Epoch 8:  85%|████████▍ | 22/26 [00:41<00:07,  1.91s/it]

Loss: 0.17940416932106018


Epoch 8:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.1852908879518509


Epoch 8:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.30229073762893677


Epoch 8:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.3194109797477722


Epoch 8: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.6124065518379211





Epoch 8 Validation Accuracy: 0.7487684729064039, F1-macro: 0.47956567636857184


Epoch 9:   4%|▍         | 1/26 [00:01<00:47,  1.91s/it]

Loss: 0.30208471417427063


Epoch 9:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.2182798683643341


Epoch 9:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.3752429783344269


Epoch 9:  15%|█▌        | 4/26 [00:07<00:42,  1.91s/it]

Loss: 0.3561452627182007


Epoch 9:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.25255805253982544


Epoch 9:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.4313199818134308


Epoch 9:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.6790273785591125


Epoch 9:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.16669586300849915


Epoch 9:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.698843777179718


Epoch 9:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.48856788873672485


Epoch 9:  42%|████▏     | 11/26 [00:21<00:28,  1.91s/it]

Loss: 0.15922428667545319


Epoch 9:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.38512474298477173


Epoch 9:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.5016423463821411


Epoch 9:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.4359598159790039


Epoch 9:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.13598406314849854


Epoch 9:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.35667848587036133


Epoch 9:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.2152576446533203


Epoch 9:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.2527918517589569


Epoch 9:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.17467299103736877


Epoch 9:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.19256320595741272


Epoch 9:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.6810723543167114


Epoch 9:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.17385338246822357


Epoch 9:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.3129641115665436


Epoch 9:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.7366100549697876


Epoch 9:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.1794881671667099


Epoch 9: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.16150443255901337





Epoch 9 Validation Accuracy: 0.8325123152709359, F1-macro: 0.5832125603864734


Epoch 10:   4%|▍         | 1/26 [00:01<00:47,  1.90s/it]

Loss: 0.12169528007507324


Epoch 10:   8%|▊         | 2/26 [00:03<00:45,  1.91s/it]

Loss: 0.26977235078811646


Epoch 10:  12%|█▏        | 3/26 [00:05<00:43,  1.91s/it]

Loss: 0.4523339569568634


Epoch 10:  15%|█▌        | 4/26 [00:07<00:41,  1.91s/it]

Loss: 0.5280041694641113


Epoch 10:  19%|█▉        | 5/26 [00:09<00:40,  1.91s/it]

Loss: 0.3283437192440033


Epoch 10:  23%|██▎       | 6/26 [00:11<00:38,  1.91s/it]

Loss: 0.3328050971031189


Epoch 10:  27%|██▋       | 7/26 [00:13<00:36,  1.91s/it]

Loss: 0.2745865285396576


Epoch 10:  31%|███       | 8/26 [00:15<00:34,  1.91s/it]

Loss: 0.14264193177223206


Epoch 10:  35%|███▍      | 9/26 [00:17<00:32,  1.91s/it]

Loss: 0.2260616570711136


Epoch 10:  38%|███▊      | 10/26 [00:19<00:30,  1.91s/it]

Loss: 0.35342034697532654


Epoch 10:  42%|████▏     | 11/26 [00:20<00:28,  1.91s/it]

Loss: 0.4298384189605713


Epoch 10:  46%|████▌     | 12/26 [00:22<00:26,  1.91s/it]

Loss: 0.12643715739250183


Epoch 10:  50%|█████     | 13/26 [00:24<00:24,  1.91s/it]

Loss: 0.26731210947036743


Epoch 10:  54%|█████▍    | 14/26 [00:26<00:22,  1.91s/it]

Loss: 0.1771252453327179


Epoch 10:  58%|█████▊    | 15/26 [00:28<00:21,  1.91s/it]

Loss: 0.11364297568798065


Epoch 10:  62%|██████▏   | 16/26 [00:30<00:19,  1.91s/it]

Loss: 0.12439712882041931


Epoch 10:  65%|██████▌   | 17/26 [00:32<00:17,  1.91s/it]

Loss: 0.3382686376571655


Epoch 10:  69%|██████▉   | 18/26 [00:34<00:15,  1.91s/it]

Loss: 0.230375736951828


Epoch 10:  73%|███████▎  | 19/26 [00:36<00:13,  1.91s/it]

Loss: 0.11535342037677765


Epoch 10:  77%|███████▋  | 20/26 [00:38<00:11,  1.91s/it]

Loss: 0.38979512453079224


Epoch 10:  81%|████████  | 21/26 [00:40<00:09,  1.91s/it]

Loss: 0.29024243354797363


Epoch 10:  85%|████████▍ | 22/26 [00:42<00:07,  1.91s/it]

Loss: 0.23255857825279236


Epoch 10:  88%|████████▊ | 23/26 [00:43<00:05,  1.91s/it]

Loss: 0.21023324131965637


Epoch 10:  92%|█████████▏| 24/26 [00:45<00:03,  1.91s/it]

Loss: 0.267554372549057


Epoch 10:  96%|█████████▌| 25/26 [00:47<00:01,  1.91s/it]

Loss: 0.19576497375965118


Epoch 10: 100%|██████████| 26/26 [00:48<00:00,  1.87s/it]

Loss: 0.46493667364120483





Epoch 10 Validation Accuracy: 0.9014778325123153, F1-macro: 0.4740932642487047


In [None]:
# Evaluate on the test set: run evaluate() on test_dataloader and report accuracy / macro-F1.
# NOTE(review): in the recorded output below, the test F1-macro (0.4740932642487047) is
# digit-for-digit identical to the epoch-10 validation F1-macro, while the accuracies differ.
# Two different splits producing the same F1 to 16 digits is unlikely — confirm that
# evaluate() (defined elsewhere) derives F1 from the predictions of the loader it is given.
test_metrics = evaluate(model, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.9117647058823529, F1-macro: 0.4740932642487047


In [None]:
current_type = 'mental filter'

# Metrics to store for this distortion type (column name -> value).
metric_values = {
    "test_accuracy": test_metrics['accuracy'],
    "f1_macro": test_metrics['f1_macro'],
}

# Upsert into results_df. A boolean-mask .loc assignment silently does nothing
# when no row matches, which would drop this distortion type from the summary
# table without any error. Update the row if it exists, otherwise append it.
row_mask = results_df["distortion_type"] == current_type
if row_mask.any():
    results_df.loc[row_mask, list(metric_values)] = list(metric_values.values())
else:
    new_row = pd.DataFrame([{"distortion_type": current_type, **metric_values}])
    # Columns not supplied here (if any) are filled with NaN for the new row.
    results_df = pd.concat([results_df, new_row], ignore_index=True)

# 12. Personalization and Blaming

In [None]:
# Build the labelled dataset, the train/val/test loaders and the label
# embeddings for the "personalization and blaming" classifier.

# Add labels: look them up with the index of the de-duplicated text series so
# that only the rows kept by drop_duplicates() contribute a label.
label_column = 'personalization and blaming'
data1_1_labels = list(data1[label_column][data1_1.index])
data3_1_labels = list(data3[label_column][data3_1.index])  # data2 has no column for this distortion type.

# Merging Data (encoded inputs and labels are concatenated in the same source order)
data_encoded = data1_1_encoded + data3_1_encoded
data_labels = data1_1_labels + data3_1_labels
# NOTE(review): presumably data*_1_encoded follow the row order of data*_1;
# only the lengths can be checked from here.
assert len(data_encoded) == len(data_labels), "encoded samples and labels are misaligned"

dataset_with_labels = CustomDatasetWithLabels(data_encoded, data_labels)

# Define proportions for splitting: 80% / 10%, with the test split taking the
# remainder so the three sizes always sum to the dataset length.
train_size = int(0.8 * len(dataset_with_labels))
val_size = int(0.1 * len(dataset_with_labels))
test_size = len(dataset_with_labels) - train_size - val_size

# One seeded generator drives the split, the training shuffle and the label
# embedding initialisation, so Restart & Run All reproduces the same run.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(
    dataset_with_labels,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Create DataLoaders for each set (only the training loader is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, generator=split_generator)
val_dataloader = DataLoader(val_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

# One randomly initialised embedding row per distinct label value.
input_dim = bert_config.hidden_size
num_labels = len(set(data_labels))
label_emb = torch.randn(num_labels, input_dim, generator=split_generator)

In [None]:
# Build the classifier from the label embeddings, then place it on the target device.
model = InnerProductClassifier(input_dim, label_emb)
model = model.to(device)

# Cross-entropy over the classifier's logits is the training objective.
criterion = nn.CrossEntropyLoss()

# AdamW updates only the classifier's parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-4)

In [None]:
EPOCHS = 10

# Move the BERT model to the device. Its forward pass below runs under
# torch.no_grad(), i.e. it acts as a frozen feature extractor, so switch it to
# eval mode to guarantee dropout cannot perturb the [CLS] embeddings.
bert_model.to(device)
bert_model.eval()

# Validation scores fluctuate strongly between epochs, so remember the weights
# of the epoch with the best validation F1-macro instead of blindly keeping
# whatever the last epoch produced.
best_val_f1 = float("-inf")
best_state = None

for epoch in range(EPOCHS):
    model.train()
    epoch_loss, num_batches = 0.0, 0

    progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for batch in progress:

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["y"].to(device)

        # Get embeddings from the BERT model (no gradients flow into BERT)
        with torch.no_grad():
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state[:, 0, :] # Get the [CLS] token embedding

        # Pass embeddings to the InnerProductClassifier
        logits = model(embeddings)
        loss = criterion(logits, y)

        # Backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch loss on the progress bar instead of printing one line
        # per batch, which floods the notebook output.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # Guard against an empty training loader before averaging.
    if num_batches:
        print(f"Epoch {epoch+1} Mean Train Loss: {epoch_loss / num_batches}")

    # Evaluate on the validation set after each epoch
    val_metrics = evaluate(model, val_dataloader, device)
    print(f"Epoch {epoch+1} Validation Accuracy: {val_metrics['accuracy']}, F1-macro: {val_metrics['f1_macro']}")

    # Strict '>' keeps the earliest epoch on ties; tensors are cloned because
    # state_dict() returns references that later optimizer steps would overwrite.
    if val_metrics['f1_macro'] > best_val_f1:
        best_val_f1 = val_metrics['f1_macro']
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

# Restore the best-validation weights so later cells evaluate the selected model.
if best_state is not None:
    model.load_state_dict(best_state)

Epoch 1:  17%|█▋        | 1/6 [00:01<00:08,  1.77s/it]

Loss: 3.502189874649048


Epoch 1:  33%|███▎      | 2/6 [00:03<00:07,  1.75s/it]

Loss: 18.072551727294922


Epoch 1:  50%|█████     | 3/6 [00:05<00:05,  1.75s/it]

Loss: 3.4851465225219727


Epoch 1:  67%|██████▋   | 4/6 [00:06<00:03,  1.75s/it]

Loss: 5.072531223297119


Epoch 1:  83%|████████▎ | 5/6 [00:08<00:01,  1.75s/it]

Loss: 7.018558502197266


Epoch 1: 100%|██████████| 6/6 [00:09<00:00,  1.59s/it]

Loss: 8.504460334777832





Epoch 1 Validation Accuracy: 0.8372093023255814, F1-macro: 0.45569620253164556


Epoch 2:  17%|█▋        | 1/6 [00:01<00:08,  1.76s/it]

Loss: 9.02737808227539


Epoch 2:  33%|███▎      | 2/6 [00:03<00:07,  1.76s/it]

Loss: 12.456899642944336


Epoch 2:  50%|█████     | 3/6 [00:05<00:05,  1.77s/it]

Loss: 9.4509916305542


Epoch 2:  67%|██████▋   | 4/6 [00:07<00:03,  1.77s/it]

Loss: 11.700393676757812


Epoch 2:  83%|████████▎ | 5/6 [00:08<00:01,  1.77s/it]

Loss: 4.161364555358887


Epoch 2: 100%|██████████| 6/6 [00:09<00:00,  1.60s/it]

Loss: 3.9110167026519775





Epoch 2 Validation Accuracy: 0.8372093023255814, F1-macro: 0.45569620253164556


Epoch 3:  17%|█▋        | 1/6 [00:01<00:08,  1.78s/it]

Loss: 2.445507049560547


Epoch 3:  33%|███▎      | 2/6 [00:03<00:07,  1.79s/it]

Loss: 2.001100540161133


Epoch 3:  50%|█████     | 3/6 [00:05<00:05,  1.79s/it]

Loss: 4.417099952697754


Epoch 3:  67%|██████▋   | 4/6 [00:07<00:03,  1.80s/it]

Loss: 4.5123701095581055


Epoch 3:  83%|████████▎ | 5/6 [00:08<00:01,  1.80s/it]

Loss: 1.7941169738769531


Epoch 3: 100%|██████████| 6/6 [00:09<00:00,  1.63s/it]

Loss: 1.955578327178955





Epoch 3 Validation Accuracy: 0.8372093023255814, F1-macro: 0.45569620253164556


Epoch 4:  17%|█▋        | 1/6 [00:01<00:09,  1.82s/it]

Loss: 2.185885190963745


Epoch 4:  33%|███▎      | 2/6 [00:03<00:07,  1.80s/it]

Loss: 3.353130340576172


Epoch 4:  50%|█████     | 3/6 [00:05<00:05,  1.80s/it]

Loss: 3.666688919067383


Epoch 4:  67%|██████▋   | 4/6 [00:07<00:03,  1.80s/it]

Loss: 2.107501268386841


Epoch 4:  83%|████████▎ | 5/6 [00:09<00:01,  1.81s/it]

Loss: 2.2607791423797607


Epoch 4: 100%|██████████| 6/6 [00:09<00:00,  1.64s/it]

Loss: 0.7273802161216736





Epoch 4 Validation Accuracy: 0.7906976744186046, F1-macro: 0.5922023182297155


Epoch 5:  17%|█▋        | 1/6 [00:01<00:09,  1.81s/it]

Loss: 0.9331235289573669


Epoch 5:  33%|███▎      | 2/6 [00:03<00:07,  1.82s/it]

Loss: 1.7340811491012573


Epoch 5:  50%|█████     | 3/6 [00:05<00:05,  1.82s/it]

Loss: 2.0309371948242188


Epoch 5:  67%|██████▋   | 4/6 [00:07<00:03,  1.82s/it]

Loss: 0.927438497543335


Epoch 5:  83%|████████▎ | 5/6 [00:09<00:01,  1.83s/it]

Loss: 0.7587804794311523


Epoch 5: 100%|██████████| 6/6 [00:09<00:00,  1.66s/it]

Loss: 0.8422700762748718





Epoch 5 Validation Accuracy: 0.8372093023255814, F1-macro: 0.45569620253164556


Epoch 6:  17%|█▋        | 1/6 [00:01<00:09,  1.83s/it]

Loss: 1.6165452003479004


Epoch 6:  33%|███▎      | 2/6 [00:03<00:07,  1.84s/it]

Loss: 1.1484975814819336


Epoch 6:  50%|█████     | 3/6 [00:05<00:05,  1.84s/it]

Loss: 2.271476984024048


Epoch 6:  67%|██████▋   | 4/6 [00:07<00:03,  1.84s/it]

Loss: 0.8744885921478271


Epoch 6:  83%|████████▎ | 5/6 [00:09<00:01,  1.84s/it]

Loss: 0.740361750125885


Epoch 6: 100%|██████████| 6/6 [00:10<00:00,  1.67s/it]

Loss: 1.8269211053848267





Epoch 6 Validation Accuracy: 0.4418604651162791, F1-macro: 0.4342105263157895


Epoch 7:  17%|█▋        | 1/6 [00:01<00:09,  1.84s/it]

Loss: 1.6722252368927002


Epoch 7:  33%|███▎      | 2/6 [00:03<00:07,  1.85s/it]

Loss: 0.7328757643699646


Epoch 7:  50%|█████     | 3/6 [00:05<00:05,  1.85s/it]

Loss: 0.40334558486938477


Epoch 7:  67%|██████▋   | 4/6 [00:07<00:03,  1.85s/it]

Loss: 0.7605124115943909


Epoch 7:  83%|████████▎ | 5/6 [00:09<00:01,  1.85s/it]

Loss: 1.6656959056854248


Epoch 7: 100%|██████████| 6/6 [00:10<00:00,  1.68s/it]

Loss: 1.134954571723938





Epoch 7 Validation Accuracy: 0.813953488372093, F1-macro: 0.44871794871794873


Epoch 8:  17%|█▋        | 1/6 [00:01<00:09,  1.84s/it]

Loss: 1.6099941730499268


Epoch 8:  33%|███▎      | 2/6 [00:03<00:07,  1.85s/it]

Loss: 0.10987427085638046


Epoch 8:  50%|█████     | 3/6 [00:05<00:05,  1.85s/it]

Loss: 0.7248846292495728


Epoch 8:  67%|██████▋   | 4/6 [00:07<00:03,  1.86s/it]

Loss: 1.9300512075424194


Epoch 8:  83%|████████▎ | 5/6 [00:09<00:01,  1.86s/it]

Loss: 0.6614447832107544


Epoch 8: 100%|██████████| 6/6 [00:10<00:00,  1.68s/it]

Loss: 0.36889106035232544





Epoch 8 Validation Accuracy: 0.8372093023255814, F1-macro: 0.5656565656565656


Epoch 9:  17%|█▋        | 1/6 [00:01<00:09,  1.85s/it]

Loss: 0.5228935480117798


Epoch 9:  33%|███▎      | 2/6 [00:03<00:07,  1.86s/it]

Loss: 1.5527148246765137


Epoch 9:  50%|█████     | 3/6 [00:05<00:05,  1.86s/it]

Loss: 0.9450944662094116


Epoch 9:  67%|██████▋   | 4/6 [00:07<00:03,  1.86s/it]

Loss: 0.9083354473114014


Epoch 9:  83%|████████▎ | 5/6 [00:09<00:01,  1.86s/it]

Loss: 0.5601050853729248


Epoch 9: 100%|██████████| 6/6 [00:10<00:00,  1.69s/it]

Loss: 0.8299040198326111





Epoch 9 Validation Accuracy: 0.5348837209302325, F1-macro: 0.5023148148148149


Epoch 10:  17%|█▋        | 1/6 [00:01<00:09,  1.86s/it]

Loss: 1.87758207321167


Epoch 10:  33%|███▎      | 2/6 [00:03<00:07,  1.87s/it]

Loss: 0.626692533493042


Epoch 10:  50%|█████     | 3/6 [00:05<00:05,  1.87s/it]

Loss: 0.4340912699699402


Epoch 10:  67%|██████▋   | 4/6 [00:07<00:03,  1.87s/it]

Loss: 0.7579348683357239


Epoch 10:  83%|████████▎ | 5/6 [00:09<00:01,  1.88s/it]

Loss: 1.450488567352295


Epoch 10: 100%|██████████| 6/6 [00:10<00:00,  1.70s/it]

Loss: 1.0246431827545166





Epoch 10 Validation Accuracy: 0.813953488372093, F1-macro: 0.6126126126126126


In [None]:
# Evaluate on the test set: run evaluate() on test_dataloader and report accuracy / macro-F1.
# NOTE(review): in the recorded output below, the test F1-macro (0.6126126126126126) is
# digit-for-digit identical to the epoch-10 validation F1-macro, while the accuracies differ.
# Two different splits producing the same F1 to 16 digits is unlikely — confirm that
# evaluate() (defined elsewhere) derives F1 from the predictions of the loader it is given.
test_metrics = evaluate(model, test_dataloader, device)
print(f"Test Accuracy: {test_metrics['accuracy']}, F1-macro: {test_metrics['f1_macro']}")

Test Accuracy: 0.8, F1-macro: 0.6126126126126126


In [None]:
current_type = 'personalization and blaming'

# Metrics to store for this distortion type (column name -> value).
metric_values = {
    "test_accuracy": test_metrics['accuracy'],
    "f1_macro": test_metrics['f1_macro'],
}

# Upsert into results_df. A boolean-mask .loc assignment silently does nothing
# when no row matches, which would drop this distortion type from the summary
# table without any error. Update the row if it exists, otherwise append it.
row_mask = results_df["distortion_type"] == current_type
if row_mask.any():
    results_df.loc[row_mask, list(metric_values)] = list(metric_values.values())
else:
    new_row = pd.DataFrame([{"distortion_type": current_type, **metric_values}])
    # Columns not supplied here (if any) are filled with NaN for the new row.
    results_df = pd.concat([results_df, new_row], ignore_index=True)

# Accuracy

In [None]:
# Display the per-distortion-type summary (test accuracy and macro-F1).
# NOTE(review): the recorded output below lists 10 distortion types and contains no
# 'mental filter' or 'personalization and blaming' row — verify that results_df has a
# row for every distortion_type that the update cells assign to.
results_df

Unnamed: 0,distortion_type,test_accuracy,f1_macro
0,all-or-nothing thinking,0.848039,0.48897
1,comparing and despairing,0.84,0.456522
2,disqualifying the positive,0.990196,0.497537
3,emotional reasoning,0.813725,0.53447
4,fortune telling,0.882353,0.46875
5,labeling,0.848039,0.612856
6,magnification,0.867647,0.498863
7,mind reading,0.803922,0.59803
8,overgeneralizing,0.818627,0.498372
9,should statements,0.911765,0.567593
