In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/json-data/assignment_3_ai_tutors_dataset.json


In [2]:
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizer
from sklearn.metrics import accuracy_score, f1_score, classification_report
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
import json
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

2025-04-25 04:32:11.837934: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745555532.091726      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745555532.159899      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class FocalLossWithWeights(nn.Module):
    def __init__(self, weight=None, gamma=2.0):
        super(FocalLossWithWeights, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none', weight=weight)

    def forward(self, inputs, targets):
        logp = self.ce(inputs, targets)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()


In [5]:
# ---------- Dataset Definition ----------
class TutorEvalSingleTaskDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }

    def __len__(self):
        return len(self.labels)

In [6]:
# ---------- Preprocessing ----------
def load_and_flatten(json_path):
    with open(json_path) as f:
        data = json.load(f)

    rows = []
    for instance in data:
        convo_id = instance["conversation_id"]
        history = instance["conversation_history"]
        for tutor_id, tutor_data in instance["tutor_responses"].items():
            row = {
                "conversation_id": convo_id,
                "tutor_id": tutor_id,
                "conversation_history": history,
                "tutor_response": tutor_data["response"],
                "Mistake_Identification": tutor_data["annotation"]["Mistake_Identification"],
                "Mistake_Location": tutor_data["annotation"]["Mistake_Location"],
                "Pedagogical_Guidance": tutor_data["annotation"]["Providing_Guidance"],
                "Actionability": tutor_data["annotation"]["Actionability"]
            }
            rows.append(row)
    return pd.DataFrame(rows)

In [7]:
def build_input_text(row):
    return f"Context:\n{row['conversation_history']}\n\nTutor Response:\n{row['tutor_response']}"

LABEL_MAP = {"Yes": 0, "To some extent": 1, "No": 2}
MERGED_LABEL_MAP = {"Yes": 1, "To some extent": 1, "No": 0}

def encode_labels(df):
    for task in ["Mistake_Identification", "Mistake_Location", "Pedagogical_Guidance", "Actionability"]:
        df[f"{task}_label"] = df[task].map(LABEL_MAP)
        df[f"{task}_binary"] = df[task].map(MERGED_LABEL_MAP)
    return df

def tokenize_inputs(tokenizer, texts, max_length=256):
    return tokenizer(
        texts,
        add_special_tokens=True,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt"
    )

In [8]:
def preprocess_dataset(json_path, task_label):
    df = load_and_flatten(json_path)
    df["input_text"] = df.apply(build_input_text, axis=1)
    df = encode_labels(df)

    train_df, val_df = train_test_split(df, test_size=0.1, stratify=df[task_label], random_state=42)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    train_enc = tokenize_inputs(tokenizer, train_df["input_text"].tolist())
    val_enc = tokenize_inputs(tokenizer, val_df["input_text"].tolist())

    train_labels = torch.tensor(train_df[task_label].tolist())
    val_labels = torch.tensor(val_df[task_label].tolist())

    train_dataset = TutorEvalSingleTaskDataset(train_enc, train_labels)
    val_dataset = TutorEvalSingleTaskDataset(val_enc, val_labels)

    return train_dataset, val_dataset, tokenizer, df


In [9]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

def get_class_weights(labels, num_classes):
    class_weights = compute_class_weight(class_weight='balanced', classes=np.arange(num_classes), y=labels)
    return torch.tensor(class_weights, dtype=torch.float).to(device)


In [10]:
# ---------- Model ----------
class SingleTaskBertClassifier(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.bert = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)

    def forward(self, input_ids, attention_mask):
        return self.bert(input_ids=input_ids, attention_mask=attention_mask).logits

In [11]:
# ---------- Training ----------
def evaluate_model(model, val_loader):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].cpu().numpy()

            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels)

    acc = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"Validation Accuracy: {acc:.4f}  Validation Macro F1: {macro_f1:.4f}")
    print(classification_report(all_labels, all_preds, target_names=["Yes/To some extent", "No"], zero_division=0))

def train_model(loss_type, train_loader, val_loader, num_labels, epochs=15):
    print(f"\n🔁 Training with: {loss_type.upper()} Loss\n")

    model = SingleTaskBertClassifier(num_labels).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    train_labels_list = [label.item() for batch in train_loader for label in batch['labels']]
    class_weights = get_class_weights(train_labels_list, num_labels)

    if loss_type == "focal":
        criterion = FocalLossWithWeights(class_weights)
    elif loss_type == "smoothing":
        criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)


    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")
        evaluate_model(model, val_loader)



In [12]:
# ---------- Run ----------
json_path = "/kaggle/input/json-data/assignment_3_ai_tutors_dataset.json"
train_dataset, val_dataset, tokenizer, df = preprocess_dataset(json_path, "Mistake_Identification_binary")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [13]:
train_model("ce", train_loader, val_loader, num_labels=2)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`



🔁 Training with: CE Loss



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/15: 100%|██████████| 557/557 [01:53<00:00,  4.91it/s]


Epoch 1 Loss: 0.6352
Validation Accuracy: 0.8508  Validation Macro F1: 0.4597
                    precision    recall  f1-score   support

Yes/To some extent       0.00      0.00      0.00        37
                No       0.85      1.00      0.92       211

          accuracy                           0.85       248
         macro avg       0.43      0.50      0.46       248
      weighted avg       0.72      0.85      0.78       248



Epoch 2/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 2 Loss: 0.5997
Validation Accuracy: 0.8508  Validation Macro F1: 0.4852
                    precision    recall  f1-score   support

Yes/To some extent       0.50      0.03      0.05        37
                No       0.85      1.00      0.92       211

          accuracy                           0.85       248
         macro avg       0.68      0.51      0.49       248
      weighted avg       0.80      0.85      0.79       248



Epoch 3/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 3 Loss: 0.5699
Validation Accuracy: 0.8387  Validation Macro F1: 0.5387
                    precision    recall  f1-score   support

Yes/To some extent       0.36      0.11      0.17        37
                No       0.86      0.97      0.91       211

          accuracy                           0.84       248
         macro avg       0.61      0.54      0.54       248
      weighted avg       0.79      0.84      0.80       248



Epoch 4/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 4 Loss: 0.5318
Validation Accuracy: 0.8306  Validation Macro F1: 0.5636
                    precision    recall  f1-score   support

Yes/To some extent       0.35      0.16      0.22        37
                No       0.87      0.95      0.90       211

          accuracy                           0.83       248
         macro avg       0.61      0.56      0.56       248
      weighted avg       0.79      0.83      0.80       248



Epoch 5/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 5 Loss: 0.5126
Validation Accuracy: 0.8347  Validation Macro F1: 0.5937
                    precision    recall  f1-score   support

Yes/To some extent       0.40      0.22      0.28        37
                No       0.87      0.94      0.91       211

          accuracy                           0.83       248
         macro avg       0.64      0.58      0.59       248
      weighted avg       0.80      0.83      0.81       248



Epoch 6/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 6 Loss: 0.4894
Validation Accuracy: 0.8347  Validation Macro F1: 0.5808
                    precision    recall  f1-score   support

Yes/To some extent       0.39      0.19      0.25        37
                No       0.87      0.95      0.91       211

          accuracy                           0.83       248
         macro avg       0.63      0.57      0.58       248
      weighted avg       0.80      0.83      0.81       248



Epoch 7/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 7 Loss: 0.4780
Validation Accuracy: 0.8508  Validation Macro F1: 0.5811
                    precision    recall  f1-score   support

Yes/To some extent       0.50      0.16      0.24        37
                No       0.87      0.97      0.92       211

          accuracy                           0.85       248
         macro avg       0.68      0.57      0.58       248
      weighted avg       0.81      0.85      0.82       248



Epoch 8/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 8 Loss: 0.4589
Validation Accuracy: 0.8226  Validation Macro F1: 0.5705
                    precision    recall  f1-score   support

Yes/To some extent       0.33      0.19      0.24        37
                No       0.87      0.93      0.90       211

          accuracy                           0.82       248
         macro avg       0.60      0.56      0.57       248
      weighted avg       0.79      0.82      0.80       248



Epoch 9/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 9 Loss: 0.4598
Validation Accuracy: 0.8508  Validation Macro F1: 0.5652
                    precision    recall  f1-score   support

Yes/To some extent       0.50      0.14      0.21        37
                No       0.87      0.98      0.92       211

          accuracy                           0.85       248
         macro avg       0.68      0.56      0.57       248
      weighted avg       0.81      0.85      0.81       248



Epoch 10/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 10 Loss: 0.4487
Validation Accuracy: 0.8266  Validation Macro F1: 0.5603
                    precision    recall  f1-score   support

Yes/To some extent       0.33      0.16      0.22        37
                No       0.87      0.94      0.90       211

          accuracy                           0.83       248
         macro avg       0.60      0.55      0.56       248
      weighted avg       0.79      0.83      0.80       248



Epoch 11/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 11 Loss: 0.4369
Validation Accuracy: 0.8306  Validation Macro F1: 0.6018
                    precision    recall  f1-score   support

Yes/To some extent       0.39      0.24      0.30        37
                No       0.88      0.93      0.90       211

          accuracy                           0.83       248
         macro avg       0.63      0.59      0.60       248
      weighted avg       0.80      0.83      0.81       248



Epoch 12/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 12 Loss: 0.4209
Validation Accuracy: 0.8306  Validation Macro F1: 0.5900
                    precision    recall  f1-score   support

Yes/To some extent       0.38      0.22      0.28        37
                No       0.87      0.94      0.90       211

          accuracy                           0.83       248
         macro avg       0.63      0.58      0.59       248
      weighted avg       0.80      0.83      0.81       248



Epoch 13/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 13 Loss: 0.4145
Validation Accuracy: 0.8266  Validation Macro F1: 0.5864
                    precision    recall  f1-score   support

Yes/To some extent       0.36      0.22      0.27        37
                No       0.87      0.93      0.90       211

          accuracy                           0.83       248
         macro avg       0.62      0.57      0.59       248
      weighted avg       0.80      0.83      0.81       248



Epoch 14/15: 100%|██████████| 557/557 [02:03<00:00,  4.53it/s]


Epoch 14 Loss: 0.4211
Validation Accuracy: 0.8105  Validation Macro F1: 0.5479
                    precision    recall  f1-score   support

Yes/To some extent       0.27      0.16      0.20        37
                No       0.86      0.92      0.89       211

          accuracy                           0.81       248
         macro avg       0.57      0.54      0.55       248
      weighted avg       0.77      0.81      0.79       248



Epoch 15/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 15 Loss: 0.4041
Validation Accuracy: 0.8145  Validation Macro F1: 0.5760
                    precision    recall  f1-score   support

Yes/To some extent       0.32      0.22      0.26        37
                No       0.87      0.92      0.89       211

          accuracy                           0.81       248
         macro avg       0.59      0.57      0.58       248
      weighted avg       0.79      0.81      0.80       248



In [14]:
# ---------- Run ----------
json_path = "/kaggle/input/json-data/assignment_3_ai_tutors_dataset.json"
train_dataset, val_dataset, tokenizer, df = preprocess_dataset(json_path, "Mistake_Location_binary")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)



In [15]:
train_model("ce", train_loader, val_loader, num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🔁 Training with: CE Loss



Epoch 1/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 1 Loss: 0.6698
Validation Accuracy: 0.4718  Validation Macro F1: 0.4707
                    precision    recall  f1-score   support

Yes/To some extent       0.34      0.90      0.49        71
                No       0.88      0.30      0.45       177

          accuracy                           0.47       248
         macro avg       0.61      0.60      0.47       248
      weighted avg       0.73      0.47      0.46       248



Epoch 2/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 2 Loss: 0.6331
Validation Accuracy: 0.7298  Validation Macro F1: 0.4850
                    precision    recall  f1-score   support

Yes/To some extent       0.83      0.07      0.13        71
                No       0.73      0.99      0.84       177

          accuracy                           0.73       248
         macro avg       0.78      0.53      0.48       248
      weighted avg       0.76      0.73      0.64       248



Epoch 3/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 3 Loss: 0.6009
Validation Accuracy: 0.7298  Validation Macro F1: 0.4738
                    precision    recall  f1-score   support

Yes/To some extent       1.00      0.06      0.11        71
                No       0.73      1.00      0.84       177

          accuracy                           0.73       248
         macro avg       0.86      0.53      0.47       248
      weighted avg       0.80      0.73      0.63       248



Epoch 4/15: 100%|██████████| 557/557 [02:03<00:00,  4.53it/s]


Epoch 4 Loss: 0.5718
Validation Accuracy: 0.7298  Validation Macro F1: 0.4738
                    precision    recall  f1-score   support

Yes/To some extent       1.00      0.06      0.11        71
                No       0.73      1.00      0.84       177

          accuracy                           0.73       248
         macro avg       0.86      0.53      0.47       248
      weighted avg       0.80      0.73      0.63       248



Epoch 5/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 5 Loss: 0.5532
Validation Accuracy: 0.6452  Validation Macro F1: 0.6275
                    precision    recall  f1-score   support

Yes/To some extent       0.43      0.75      0.55        71
                No       0.86      0.60      0.71       177

          accuracy                           0.65       248
         macro avg       0.64      0.68      0.63       248
      weighted avg       0.73      0.65      0.66       248



Epoch 6/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 6 Loss: 0.5534
Validation Accuracy: 0.7298  Validation Macro F1: 0.4850
                    precision    recall  f1-score   support

Yes/To some extent       0.83      0.07      0.13        71
                No       0.73      0.99      0.84       177

          accuracy                           0.73       248
         macro avg       0.78      0.53      0.48       248
      weighted avg       0.76      0.73      0.64       248



Epoch 7/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 7 Loss: 0.5472
Validation Accuracy: 0.7258  Validation Macro F1: 0.4828
                    precision    recall  f1-score   support

Yes/To some extent       0.71      0.07      0.13        71
                No       0.73      0.99      0.84       177

          accuracy                           0.73       248
         macro avg       0.72      0.53      0.48       248
      weighted avg       0.72      0.73      0.63       248



Epoch 8/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 8 Loss: 0.5161
Validation Accuracy: 0.7339  Validation Macro F1: 0.5668
                    precision    recall  f1-score   support

Yes/To some extent       0.61      0.20      0.30        71
                No       0.75      0.95      0.84       177

          accuracy                           0.73       248
         macro avg       0.68      0.57      0.57       248
      weighted avg       0.71      0.73      0.68       248



Epoch 9/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 9 Loss: 0.5017
Validation Accuracy: 0.7298  Validation Macro F1: 0.5330
                    precision    recall  f1-score   support

Yes/To some extent       0.62      0.14      0.23        71
                No       0.74      0.97      0.84       177

          accuracy                           0.73       248
         macro avg       0.68      0.55      0.53       248
      weighted avg       0.70      0.73      0.66       248



Epoch 10/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 10 Loss: 0.4881
Validation Accuracy: 0.6129  Validation Macro F1: 0.5921
                    precision    recall  f1-score   support

Yes/To some extent       0.40      0.68      0.50        71
                No       0.82      0.59      0.68       177

          accuracy                           0.61       248
         macro avg       0.61      0.63      0.59       248
      weighted avg       0.70      0.61      0.63       248



Epoch 11/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 11 Loss: 0.4757
Validation Accuracy: 0.7379  Validation Macro F1: 0.5699
                    precision    recall  f1-score   support

Yes/To some extent       0.64      0.20      0.30        71
                No       0.75      0.95      0.84       177

          accuracy                           0.74       248
         macro avg       0.69      0.58      0.57       248
      weighted avg       0.72      0.74      0.68       248



Epoch 12/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 12 Loss: 0.4624
Validation Accuracy: 0.7056  Validation Macro F1: 0.6443
                    precision    recall  f1-score   support

Yes/To some extent       0.49      0.51      0.50        71
                No       0.80      0.79      0.79       177

          accuracy                           0.71       248
         macro avg       0.64      0.65      0.64       248
      weighted avg       0.71      0.71      0.71       248



Epoch 13/15: 100%|██████████| 557/557 [02:03<00:00,  4.53it/s]


Epoch 13 Loss: 0.4620
Validation Accuracy: 0.7298  Validation Macro F1: 0.6253
                    precision    recall  f1-score   support

Yes/To some extent       0.54      0.35      0.43        71
                No       0.77      0.88      0.82       177

          accuracy                           0.73       248
         macro avg       0.66      0.62      0.63       248
      weighted avg       0.71      0.73      0.71       248



Epoch 14/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 14 Loss: 0.4462
Validation Accuracy: 0.6895  Validation Macro F1: 0.6278
                    precision    recall  f1-score   support

Yes/To some extent       0.46      0.49      0.48        71
                No       0.79      0.77      0.78       177

          accuracy                           0.69       248
         macro avg       0.63      0.63      0.63       248
      weighted avg       0.70      0.69      0.69       248



Epoch 15/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 15 Loss: 0.4394
Validation Accuracy: 0.6774  Validation Macro F1: 0.6400
                    precision    recall  f1-score   support

Yes/To some extent       0.45      0.62      0.52        71
                No       0.82      0.70      0.76       177

          accuracy                           0.68       248
         macro avg       0.64      0.66      0.64       248
      weighted avg       0.72      0.68      0.69       248



In [16]:
# ---------- Run ----------
json_path = "/kaggle/input/json-data/assignment_3_ai_tutors_dataset.json"
train_dataset, val_dataset, tokenizer, df = preprocess_dataset(json_path, "Pedagogical_Guidance_binary")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)

In [17]:
train_model("ce", train_loader, val_loader, num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🔁 Training with: CE Loss



Epoch 1/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 1 Loss: 0.6544
Validation Accuracy: 0.7984  Validation Macro F1: 0.5842
                    precision    recall  f1-score   support

Yes/To some extent       0.77      0.18      0.29        57
                No       0.80      0.98      0.88       191

          accuracy                           0.80       248
         macro avg       0.78      0.58      0.58       248
      weighted avg       0.79      0.80      0.75       248



Epoch 2/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 2 Loss: 0.6322
Validation Accuracy: 0.7742  Validation Macro F1: 0.5987
                    precision    recall  f1-score   support

Yes/To some extent       0.52      0.25      0.33        57
                No       0.81      0.93      0.86       191

          accuracy                           0.77       248
         macro avg       0.66      0.59      0.60       248
      weighted avg       0.74      0.77      0.74       248



Epoch 3/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 3 Loss: 0.5941
Validation Accuracy: 0.8024  Validation Macro F1: 0.6233
                    precision    recall  f1-score   support

Yes/To some extent       0.70      0.25      0.36        57
                No       0.81      0.97      0.88       191

          accuracy                           0.80       248
         macro avg       0.76      0.61      0.62       248
      weighted avg       0.79      0.80      0.76       248



Epoch 4/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 4 Loss: 0.5458
Validation Accuracy: 0.7298  Validation Macro F1: 0.6061
                    precision    recall  f1-score   support

Yes/To some extent       0.40      0.37      0.39        57
                No       0.82      0.84      0.83       191

          accuracy                           0.73       248
         macro avg       0.61      0.60      0.61       248
      weighted avg       0.72      0.73      0.73       248



Epoch 5/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 5 Loss: 0.5137
Validation Accuracy: 0.8145  Validation Macro F1: 0.6572
                    precision    recall  f1-score   support

Yes/To some extent       0.74      0.30      0.42        57
                No       0.82      0.97      0.89       191

          accuracy                           0.81       248
         macro avg       0.78      0.63      0.66       248
      weighted avg       0.80      0.81      0.78       248



Epoch 6/15: 100%|██████████| 557/557 [02:02<00:00,  4.53it/s]


Epoch 6 Loss: 0.4679
Validation Accuracy: 0.7782  Validation Macro F1: 0.6361
                    precision    recall  f1-score   support

Yes/To some extent       0.53      0.33      0.41        57
                No       0.82      0.91      0.86       191

          accuracy                           0.78       248
         macro avg       0.67      0.62      0.64       248
      weighted avg       0.75      0.78      0.76       248



Epoch 7/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 7 Loss: 0.4433
Validation Accuracy: 0.7782  Validation Macro F1: 0.6234
                    precision    recall  f1-score   support

Yes/To some extent       0.53      0.30      0.38        57
                No       0.81      0.92      0.86       191

          accuracy                           0.78       248
         macro avg       0.67      0.61      0.62       248
      weighted avg       0.75      0.78      0.75       248



Epoch 8/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 8 Loss: 0.4391
Validation Accuracy: 0.7782  Validation Macro F1: 0.6361
                    precision    recall  f1-score   support

Yes/To some extent       0.53      0.33      0.41        57
                No       0.82      0.91      0.86       191

          accuracy                           0.78       248
         macro avg       0.67      0.62      0.64       248
      weighted avg       0.75      0.78      0.76       248



Epoch 9/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 9 Loss: 0.4538
Validation Accuracy: 0.7460  Validation Macro F1: 0.6477
                    precision    recall  f1-score   support

Yes/To some extent       0.45      0.47      0.46        57
                No       0.84      0.83      0.83       191

          accuracy                           0.75       248
         macro avg       0.65      0.65      0.65       248
      weighted avg       0.75      0.75      0.75       248



Epoch 10/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 10 Loss: 0.4157
Validation Accuracy: 0.7581  Validation Macro F1: 0.6242
                    precision    recall  f1-score   support

Yes/To some extent       0.47      0.35      0.40        57
                No       0.82      0.88      0.85       191

          accuracy                           0.76       248
         macro avg       0.64      0.62      0.62       248
      weighted avg       0.74      0.76      0.75       248



Epoch 11/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 11 Loss: 0.4134
Validation Accuracy: 0.8145  Validation Macro F1: 0.6704
                    precision    recall  f1-score   support

Yes/To some extent       0.70      0.33      0.45        57
                No       0.83      0.96      0.89       191

          accuracy                           0.81       248
         macro avg       0.77      0.65      0.67       248
      weighted avg       0.80      0.81      0.79       248



Epoch 12/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 12 Loss: 0.4236
Validation Accuracy: 0.7944  Validation Macro F1: 0.6310
                    precision    recall  f1-score   support

Yes/To some extent       0.62      0.28      0.39        57
                No       0.82      0.95      0.88       191

          accuracy                           0.79       248
         macro avg       0.72      0.61      0.63       248
      weighted avg       0.77      0.79      0.76       248



Epoch 13/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 13 Loss: 0.4052
Validation Accuracy: 0.7903  Validation Macro F1: 0.6471
                    precision    recall  f1-score   support

Yes/To some extent       0.58      0.33      0.42        57
                No       0.82      0.93      0.87       191

          accuracy                           0.79       248
         macro avg       0.70      0.63      0.65       248
      weighted avg       0.77      0.79      0.77       248



Epoch 14/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 14 Loss: 0.4200
Validation Accuracy: 0.7460  Validation Macro F1: 0.6764
                    precision    recall  f1-score   support

Yes/To some extent       0.46      0.61      0.53        57
                No       0.87      0.79      0.83       191

          accuracy                           0.75       248
         macro avg       0.67      0.70      0.68       248
      weighted avg       0.78      0.75      0.76       248



Epoch 15/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 15 Loss: 0.4184
Validation Accuracy: 0.7742  Validation Macro F1: 0.6493
                    precision    recall  f1-score   support

Yes/To some extent       0.51      0.39      0.44        57
                No       0.83      0.89      0.86       191

          accuracy                           0.77       248
         macro avg       0.67      0.64      0.65       248
      weighted avg       0.76      0.77      0.76       248



In [18]:
# ---------- Run ----------
json_path = "/kaggle/input/json-data/assignment_3_ai_tutors_dataset.json"
train_dataset, val_dataset, tokenizer, df = preprocess_dataset(json_path, "Actionability_binary")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)

In [19]:
train_model("ce", train_loader, val_loader, num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🔁 Training with: CE Loss



Epoch 1/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 1 Loss: 0.6946
Validation Accuracy: 0.3226  Validation Macro F1: 0.2439
                    precision    recall  f1-score   support

Yes/To some extent       0.32      1.00      0.49        80
                No       0.00      0.00      0.00       168

          accuracy                           0.32       248
         macro avg       0.16      0.50      0.24       248
      weighted avg       0.10      0.32      0.16       248



Epoch 2/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 2 Loss: 0.6655
Validation Accuracy: 0.7298  Validation Macro F1: 0.6525
                    precision    recall  f1-score   support

Yes/To some extent       0.63      0.40      0.49        80
                No       0.76      0.89      0.82       168

          accuracy                           0.73       248
         macro avg       0.69      0.64      0.65       248
      weighted avg       0.71      0.73      0.71       248



Epoch 3/15: 100%|██████████| 557/557 [02:02<00:00,  4.53it/s]


Epoch 3 Loss: 0.6263
Validation Accuracy: 0.6855  Validation Macro F1: 0.4302
                    precision    recall  f1-score   support

Yes/To some extent       1.00      0.03      0.05        80
                No       0.68      1.00      0.81       168

          accuracy                           0.69       248
         macro avg       0.84      0.51      0.43       248
      weighted avg       0.79      0.69      0.57       248



Epoch 4/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 4 Loss: 0.6175
Validation Accuracy: 0.7581  Validation Macro F1: 0.6449
                    precision    recall  f1-score   support

Yes/To some extent       0.86      0.30      0.44        80
                No       0.75      0.98      0.85       168

          accuracy                           0.76       248
         macro avg       0.80      0.64      0.64       248
      weighted avg       0.78      0.76      0.72       248



Epoch 5/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 5 Loss: 0.5700
Validation Accuracy: 0.7581  Validation Macro F1: 0.6583
                    precision    recall  f1-score   support

Yes/To some extent       0.79      0.34      0.47        80
                No       0.75      0.96      0.84       168

          accuracy                           0.76       248
         macro avg       0.77      0.65      0.66       248
      weighted avg       0.77      0.76      0.72       248



Epoch 6/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 6 Loss: 0.5502
Validation Accuracy: 0.7702  Validation Macro F1: 0.6734
                    precision    recall  f1-score   support

Yes/To some extent       0.85      0.35      0.50        80
                No       0.76      0.97      0.85       168

          accuracy                           0.77       248
         macro avg       0.80      0.66      0.67       248
      weighted avg       0.79      0.77      0.74       248



Epoch 7/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 7 Loss: 0.5517
Validation Accuracy: 0.7581  Validation Macro F1: 0.6583
                    precision    recall  f1-score   support

Yes/To some extent       0.79      0.34      0.47        80
                No       0.75      0.96      0.84       168

          accuracy                           0.76       248
         macro avg       0.77      0.65      0.66       248
      weighted avg       0.77      0.76      0.72       248



Epoch 8/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 8 Loss: 0.5271
Validation Accuracy: 0.7621  Validation Macro F1: 0.6619
                    precision    recall  f1-score   support

Yes/To some extent       0.82      0.34      0.48        80
                No       0.75      0.96      0.85       168

          accuracy                           0.76       248
         macro avg       0.79      0.65      0.66       248
      weighted avg       0.77      0.76      0.73       248



Epoch 9/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 9 Loss: 0.5157
Validation Accuracy: 0.7661  Validation Macro F1: 0.6612
                    precision    recall  f1-score   support

Yes/To some extent       0.87      0.33      0.47        80
                No       0.75      0.98      0.85       168

          accuracy                           0.77       248
         macro avg       0.81      0.65      0.66       248
      weighted avg       0.79      0.77      0.73       248



Epoch 10/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 10 Loss: 0.5066
Validation Accuracy: 0.7621  Validation Macro F1: 0.6700
                    precision    recall  f1-score   support

Yes/To some extent       0.78      0.36      0.50        80
                No       0.76      0.95      0.84       168

          accuracy                           0.76       248
         macro avg       0.77      0.66      0.67       248
      weighted avg       0.77      0.76      0.73       248



Epoch 11/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 11 Loss: 0.4771
Validation Accuracy: 0.7540  Validation Macro F1: 0.6628
                    precision    recall  f1-score   support

Yes/To some extent       0.74      0.36      0.49        80
                No       0.76      0.94      0.84       168

          accuracy                           0.75       248
         macro avg       0.75      0.65      0.66       248
      weighted avg       0.75      0.75      0.73       248



Epoch 12/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 12 Loss: 0.4729
Validation Accuracy: 0.7056  Validation Macro F1: 0.6414
                    precision    recall  f1-score   support

Yes/To some extent       0.56      0.44      0.49        80
                No       0.76      0.83      0.79       168

          accuracy                           0.71       248
         macro avg       0.66      0.64      0.64       248
      weighted avg       0.69      0.71      0.70       248



Epoch 13/15: 100%|██████████| 557/557 [02:03<00:00,  4.52it/s]


Epoch 13 Loss: 0.4572
Validation Accuracy: 0.7056  Validation Macro F1: 0.6383
                    precision    recall  f1-score   support

Yes/To some extent       0.56      0.42      0.48        80
                No       0.75      0.84      0.79       168

          accuracy                           0.71       248
         macro avg       0.66      0.63      0.64       248
      weighted avg       0.69      0.71      0.69       248



Epoch 14/15: 100%|██████████| 557/557 [02:02<00:00,  4.53it/s]


Epoch 14 Loss: 0.4440
Validation Accuracy: 0.7097  Validation Macro F1: 0.6610
                    precision    recall  f1-score   support

Yes/To some extent       0.55      0.51      0.53        80
                No       0.78      0.80      0.79       168

          accuracy                           0.71       248
         macro avg       0.66      0.66      0.66       248
      weighted avg       0.70      0.71      0.71       248



Epoch 15/15: 100%|██████████| 557/557 [02:03<00:00,  4.51it/s]


Epoch 15 Loss: 0.4423
Validation Accuracy: 0.7218  Validation Macro F1: 0.6806
                    precision    recall  f1-score   support

Yes/To some extent       0.57      0.56      0.57        80
                No       0.79      0.80      0.80       168

          accuracy                           0.72       248
         macro avg       0.68      0.68      0.68       248
      weighted avg       0.72      0.72      0.72       248

