## Dataset Download

In [None]:
!gdown 1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU

Downloading...
From (original): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU
From (redirected): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU&confirm=t&uuid=52071f04-cc2f-414e-b37c-ce4972b9c905
To: /kaggle/working/anxiety_dataset_complete.zip
100%|███████████████████████████████████████| 1.41G/1.41G [00:04<00:00, 321MB/s]


In [None]:
!unzip anxiety_dataset_complete.zip > /dev/null

In [27]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
import math

TRAIN_RATIO = 0.7  # 70% train
VAL_RATIO = 0.1    # 10% validation
TEST_RATIO = 0.2   # 20% test
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
LR = 2e-5
MAX_LEN = 512
FUSION_DIM = 768
LABELS = ["Nervousness", "Lack of Worry Control", "Excessive Worry",
          "Difficulty Relaxing", "Restlessness", "Impending Doom"]
LABEL_MAP = {i: idx for idx, i in enumerate(LABELS)}

## Custom Dataset

In [28]:
class TextOnlyAnxietyDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        ocr_text = sample["ocr_text"]
        figurative_reasoning = sample["figurative_reasoning"]
        combined_text = ocr_text + " [SEP] " + figurative_reasoning

        encoding = self.tokenizer(
            combined_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

        label = LABEL_MAP[sample["meme_anxiety_category"]]

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }

def custom_collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'label': labels,
    }

## Model Definition

In [29]:
class TextOnlyModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", num_classes=6, hidden_dim=768):
        super(TextOnlyModel, self).__init__()

        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_dim = self.text_encoder.config.hidden_size

        self.text_projections = nn.ModuleList([
            nn.Linear(self.text_dim, hidden_dim) for _ in range(3)
        ])

        self.expert_nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.2)
            ) for _ in range(4)
        ])

        self.moe_gate = nn.Linear(hidden_dim, 4)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask):
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        text_features = [
            text_outputs.hidden_states[-i][:, 0] for i in range(1, 4)
        ]
        text_features = [
            proj(feat) for proj, feat in zip(self.text_projections, text_features)
        ]

        combined_feature = sum(text_features) / len(text_features)

        expert_outputs = [expert(combined_feature) for expert in self.expert_nets]
        expert_gates = F.softmax(self.moe_gate(combined_feature), dim=1)

        moe_output = torch.zeros_like(expert_outputs[0])
        for i, expert_out in enumerate(expert_outputs):
            moe_output += expert_out * expert_gates[:, i].unsqueeze(1)

        logits = self.classifier(moe_output)

        return logits

## Training Functions

In [30]:
def train_text_model(model, train_data, val_data, epochs, model_save_name):
    tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")

    train_dataset = TextOnlyAnxietyDataset(train_data, tokenizer, max_len=MAX_LEN)
    val_dataset = TextOnlyAnxietyDataset(val_data, tokenizer, max_len=MAX_LEN)

    print("Train Set Size:", len(train_dataset))
    print("Validation Set Size:", len(val_dataset))
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=custom_collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    optimizer = optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2, verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    model = model.to(DEVICE)
    model = nn.DataParallel(model)

    best_f1 = 0
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        model.train()
        train_loss = 0
        train_preds, train_labels = [], []

        for batch in tqdm(train_loader, desc="Training"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            train_preds.extend(predictions)
            train_labels.extend(labels.cpu().numpy())

        train_loss = train_loss / len(train_loader)
        train_macro_f1 = f1_score(train_labels, train_preds, average="macro")
        train_weighted_f1 = f1_score(train_labels, train_preds, average="weighted")

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Macro-F1: {train_macro_f1:.4f}, Weighted-F1: {train_weighted_f1:.4f}")

        val_loss, val_macro_f1, val_weighted_f1 = evaluate_text_model(
            model, val_loader, criterion
        )

        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Macro-F1: {val_macro_f1:.4f}, Weighted-F1: {val_weighted_f1:.4f}")

        scheduler.step(val_macro_f1)

        f1_hm = 2 * val_macro_f1 * val_weighted_f1 / (val_macro_f1 + val_weighted_f1)
        if f1_hm > best_f1:
            best_f1 = f1_hm
            torch.save(model.state_dict(), f"{model_save_name}_anxiety.pth")
            print("Best model saved!")

    return model

def evaluate_text_model(model, loader, criterion):
    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(predictions)
            val_labels.extend(labels.cpu().numpy())

    val_loss = val_loss / len(loader)
    val_macro_f1 = f1_score(val_labels, val_preds, average="macro")
    val_weighted_f1 = f1_score(val_labels, val_preds, average="weighted")

    target_names = [k for k, v in sorted(LABEL_MAP.items(), key=lambda x: x[1])]
    report = classification_report(val_labels, val_preds, target_names=target_names)
    print(report)

    return val_loss, val_macro_f1, val_weighted_f1

## Model Training

In [None]:
full_train_data = json.load(open("anxiety_train_llava_dataset.json", "r"))
test_data = json.load(open("anxiety_test_llava_dataset.json", "r"))

labels = [LABEL_MAP[item["meme_anxiety_category"]] for item in full_train_data]

train_size = math.ceil(len(full_train_data) * TRAIN_RATIO / (TRAIN_RATIO + VAL_RATIO))
train_data, val_data = train_test_split(
    full_train_data, train_size=train_size, stratify=labels, random_state=42
)

tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")

test_dataset = TextOnlyAnxietyDataset(test_data, tokenizer, max_len=MAX_LEN)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn
)

In [None]:
model = TextOnlyModel(
    text_model_name="mental/mental-roberta-base",
    num_classes=len(LABEL_MAP),
    hidden_dim=FUSION_DIM
)

trained_model, tokenizer = train_text_model(
    model,
    train_data,
    val_data,
    epochs=30,
    model_save_name="only_text"
)

Some weights of RobertaModel were not initialized from the model checkpoint at mental/mental-roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train Set Size: 2153
Validation Set Size: 307

Epoch 1/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 1.7209
Train Macro-F1: 0.2361, Weighted-F1: 0.2382


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.41      0.79      0.54        53
Lack of Worry Control       0.51      0.51      0.51        47
      Excessive Worry       0.47      0.20      0.28        46
  Difficulty Relaxing       0.83      0.78      0.81        51
         Restlessness       0.35      0.45      0.39        58
       Impending Doom       0.69      0.21      0.32        52

             accuracy                           0.50       307
            macro avg       0.54      0.49      0.48       307
         weighted avg       0.54      0.50      0.48       307

Validation Loss: 1.3568
Validation Macro-F1: 0.4753, Weighted-F1: 0.4761
Best model saved!

Epoch 2/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 1.1704
Train Macro-F1: 0.5547, Weighted-F1: 0.5545


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.60      0.63        53
Lack of Worry Control       0.57      0.57      0.57        47
      Excessive Worry       0.63      0.37      0.47        46
  Difficulty Relaxing       0.82      0.92      0.87        51
         Restlessness       0.49      0.69      0.57        58
       Impending Doom       0.71      0.62      0.66        52

             accuracy                           0.64       307
            macro avg       0.65      0.63      0.63       307
         weighted avg       0.64      0.64      0.63       307

Validation Loss: 1.0349
Validation Macro-F1: 0.6282, Weighted-F1: 0.6304
Best model saved!

Epoch 3/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.7756
Train Macro-F1: 0.7311, Weighted-F1: 0.7311


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.58      0.63        53
Lack of Worry Control       0.74      0.62      0.67        47
      Excessive Worry       0.66      0.46      0.54        46
  Difficulty Relaxing       0.90      0.84      0.87        51
         Restlessness       0.45      0.74      0.56        58
       Impending Doom       0.67      0.60      0.63        52

             accuracy                           0.64       307
            macro avg       0.68      0.64      0.65       307
         weighted avg       0.68      0.64      0.65       307

Validation Loss: 1.0623
Validation Macro-F1: 0.6498, Weighted-F1: 0.6490
Best model saved!

Epoch 4/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.4776
Train Macro-F1: 0.8468, Weighted-F1: 0.8472


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.55      0.77      0.65        53
Lack of Worry Control       0.80      0.60      0.68        47
      Excessive Worry       0.58      0.41      0.48        46
  Difficulty Relaxing       0.91      0.76      0.83        51
         Restlessness       0.55      0.55      0.55        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.65       307
            macro avg       0.67      0.64      0.65       307
         weighted avg       0.67      0.65      0.65       307

Validation Loss: 1.2091
Validation Macro-F1: 0.6468, Weighted-F1: 0.6470

Epoch 5/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.3084
Train Macro-F1: 0.9072, Weighted-F1: 0.9072


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.53      0.72      0.61        53
Lack of Worry Control       0.82      0.66      0.73        47
      Excessive Worry       0.59      0.48      0.53        46
  Difficulty Relaxing       0.82      0.88      0.85        51
         Restlessness       0.61      0.53      0.57        58
       Impending Doom       0.70      0.73      0.72        52

             accuracy                           0.67       307
            macro avg       0.68      0.67      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 1.2773
Validation Macro-F1: 0.6671, Weighted-F1: 0.6660
Best model saved!

Epoch 6/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.1928
Train Macro-F1: 0.9442, Weighted-F1: 0.9447


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.54      0.77      0.64        53
Lack of Worry Control       0.72      0.70      0.71        47
      Excessive Worry       0.47      0.50      0.48        46
  Difficulty Relaxing       0.91      0.78      0.84        51
         Restlessness       0.68      0.47      0.55        58
       Impending Doom       0.73      0.73      0.73        52

             accuracy                           0.66       307
            macro avg       0.67      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 1.5622
Validation Macro-F1: 0.6589, Weighted-F1: 0.6587

Epoch 7/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.1604
Train Macro-F1: 0.9531, Weighted-F1: 0.9540


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.47it/s]


                       precision    recall  f1-score   support

          Nervousness       0.58      0.68      0.63        53
Lack of Worry Control       0.62      0.70      0.66        47
      Excessive Worry       0.83      0.33      0.47        46
  Difficulty Relaxing       0.87      0.78      0.82        51
         Restlessness       0.54      0.64      0.59        58
       Impending Doom       0.65      0.75      0.70        52

             accuracy                           0.65       307
            macro avg       0.68      0.65      0.64       307
         weighted avg       0.68      0.65      0.65       307

Validation Loss: 1.6942
Validation Macro-F1: 0.6439, Weighted-F1: 0.6453

Epoch 8/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.1207
Train Macro-F1: 0.9666, Weighted-F1: 0.9671


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


                       precision    recall  f1-score   support

          Nervousness       0.71      0.60      0.65        53
Lack of Worry Control       0.76      0.62      0.68        47
      Excessive Worry       0.79      0.41      0.54        46
  Difficulty Relaxing       0.90      0.75      0.82        51
         Restlessness       0.50      0.72      0.59        58
       Impending Doom       0.57      0.81      0.67        52

             accuracy                           0.66       307
            macro avg       0.71      0.65      0.66       307
         weighted avg       0.70      0.66      0.66       307

Validation Loss: 1.9583
Validation Macro-F1: 0.6589, Weighted-F1: 0.6590

Epoch 9/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0642
Train Macro-F1: 0.9820, Weighted-F1: 0.9819


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.64      0.65        53
Lack of Worry Control       0.70      0.74      0.72        47
      Excessive Worry       0.55      0.48      0.51        46
  Difficulty Relaxing       0.88      0.73      0.80        51
         Restlessness       0.56      0.71      0.63        58
       Impending Doom       0.75      0.73      0.74        52

             accuracy                           0.67       307
            macro avg       0.68      0.67      0.67       307
         weighted avg       0.68      0.67      0.68       307

Validation Loss: 1.8230
Validation Macro-F1: 0.6744, Weighted-F1: 0.6754
Best model saved!

Epoch 10/30


Training: 100%|██████████| 135/135 [02:02<00:00,  1.10it/s]


Train Loss: 0.0376
Train Macro-F1: 0.9897, Weighted-F1: 0.9898


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.68      0.67        53
Lack of Worry Control       0.72      0.66      0.69        47
      Excessive Worry       0.62      0.46      0.52        46
  Difficulty Relaxing       0.90      0.75      0.82        51
         Restlessness       0.54      0.69      0.61        58
       Impending Doom       0.68      0.77      0.72        52

             accuracy                           0.67       307
            macro avg       0.69      0.67      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 1.9831
Validation Macro-F1: 0.6708, Weighted-F1: 0.6716

Epoch 11/30


Training: 100%|██████████| 135/135 [02:02<00:00,  1.11it/s]


Train Loss: 0.0365
Train Macro-F1: 0.9906, Weighted-F1: 0.9907


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.70      0.69        53
Lack of Worry Control       0.65      0.77      0.71        47
      Excessive Worry       0.62      0.46      0.52        46
  Difficulty Relaxing       0.89      0.80      0.85        51
         Restlessness       0.66      0.64      0.65        58
       Impending Doom       0.66      0.77      0.71        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 1.9283
Validation Macro-F1: 0.6864, Weighted-F1: 0.6880
Best model saved!

Epoch 12/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0275
Train Macro-F1: 0.9935, Weighted-F1: 0.9935


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.68      0.67        53
Lack of Worry Control       0.69      0.74      0.71        47
      Excessive Worry       0.63      0.48      0.54        46
  Difficulty Relaxing       0.89      0.76      0.82        51
         Restlessness       0.59      0.64      0.61        58
       Impending Doom       0.66      0.75      0.70        52

             accuracy                           0.68       307
            macro avg       0.68      0.68      0.68       307
         weighted avg       0.68      0.68      0.68       307

Validation Loss: 2.0381
Validation Macro-F1: 0.6766, Weighted-F1: 0.6768

Epoch 13/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0243
Train Macro-F1: 0.9939, Weighted-F1: 0.9940


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.58      0.74      0.65        53
Lack of Worry Control       0.72      0.62      0.67        47
      Excessive Worry       0.52      0.50      0.51        46
  Difficulty Relaxing       0.88      0.75      0.81        51
         Restlessness       0.63      0.62      0.63        58
       Impending Doom       0.68      0.73      0.70        52

             accuracy                           0.66       307
            macro avg       0.67      0.66      0.66       307
         weighted avg       0.67      0.66      0.66       307

Validation Loss: 2.1167
Validation Macro-F1: 0.6610, Weighted-F1: 0.6627

Epoch 14/30


Training: 100%|██████████| 135/135 [02:02<00:00,  1.11it/s]


Train Loss: 0.0267
Train Macro-F1: 0.9929, Weighted-F1: 0.9930


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


                       precision    recall  f1-score   support

          Nervousness       0.70      0.72      0.71        53
Lack of Worry Control       0.68      0.68      0.68        47
      Excessive Worry       0.66      0.50      0.57        46
  Difficulty Relaxing       0.88      0.82      0.85        51
         Restlessness       0.66      0.67      0.67        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.70       307
            macro avg       0.70      0.69      0.69       307
         weighted avg       0.70      0.70      0.70       307

Validation Loss: 2.1297
Validation Macro-F1: 0.6940, Weighted-F1: 0.6957
Best model saved!

Epoch 15/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0301
Train Macro-F1: 0.9916, Weighted-F1: 0.9916


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


                       precision    recall  f1-score   support

          Nervousness       0.60      0.74      0.66        53
Lack of Worry Control       0.74      0.66      0.70        47
      Excessive Worry       0.57      0.46      0.51        46
  Difficulty Relaxing       0.88      0.71      0.78        51
         Restlessness       0.58      0.66      0.62        58
       Impending Doom       0.68      0.75      0.72        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.67      0.66      0.66       307

Validation Loss: 2.1898
Validation Macro-F1: 0.6633, Weighted-F1: 0.6645

Epoch 16/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0246
Train Macro-F1: 0.9939, Weighted-F1: 0.9940


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.71      0.64      0.67        53
Lack of Worry Control       0.71      0.74      0.73        47
      Excessive Worry       0.59      0.59      0.59        46
  Difficulty Relaxing       0.88      0.75      0.81        51
         Restlessness       0.60      0.66      0.63        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.0588
Validation Macro-F1: 0.6892, Weighted-F1: 0.6889

Epoch 17/30


Training: 100%|██████████| 135/135 [02:02<00:00,  1.10it/s]


Train Loss: 0.0223
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.70      0.68        53
Lack of Worry Control       0.69      0.74      0.71        47
      Excessive Worry       0.63      0.48      0.54        46
  Difficulty Relaxing       0.89      0.78      0.83        51
         Restlessness       0.62      0.67      0.64        58
       Impending Doom       0.68      0.75      0.72        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.1529
Validation Macro-F1: 0.6883, Weighted-F1: 0.6894

Epoch 18/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0181
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.48it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.70      0.67        53
Lack of Worry Control       0.73      0.74      0.74        47
      Excessive Worry       0.65      0.48      0.55        46
  Difficulty Relaxing       0.89      0.80      0.85        51
         Restlessness       0.65      0.69      0.67        58
       Impending Doom       0.67      0.77      0.71        52

             accuracy                           0.70       307
            macro avg       0.70      0.70      0.70       307
         weighted avg       0.70      0.70      0.70       307

Validation Loss: 2.1537
Validation Macro-F1: 0.6976, Weighted-F1: 0.6987
Best model saved!

Epoch 19/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0137
Train Macro-F1: 0.9966, Weighted-F1: 0.9968


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.68it/s]


                       precision    recall  f1-score   support

          Nervousness       0.62      0.72      0.67        53
Lack of Worry Control       0.73      0.68      0.70        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.88      0.86      0.87        51
         Restlessness       0.62      0.69      0.66        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.69       307
            macro avg       0.70      0.69      0.69       307
         weighted avg       0.70      0.69      0.69       307

Validation Loss: 2.2576
Validation Macro-F1: 0.6887, Weighted-F1: 0.6904

Epoch 20/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0109
Train Macro-F1: 0.9958, Weighted-F1: 0.9958


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.68      0.67        53
Lack of Worry Control       0.71      0.72      0.72        47
      Excessive Worry       0.60      0.46      0.52        46
  Difficulty Relaxing       0.89      0.76      0.82        51
         Restlessness       0.56      0.69      0.62        58
       Impending Doom       0.69      0.71      0.70        52

             accuracy                           0.67       307
            macro avg       0.68      0.67      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 2.2565
Validation Macro-F1: 0.6734, Weighted-F1: 0.6742

Epoch 21/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0099
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.70      0.68        53
Lack of Worry Control       0.70      0.70      0.70        47
      Excessive Worry       0.68      0.46      0.55        46
  Difficulty Relaxing       0.87      0.78      0.82        51
         Restlessness       0.56      0.69      0.62        58
       Impending Doom       0.69      0.73      0.71        52

             accuracy                           0.68       307
            macro avg       0.69      0.68      0.68       307
         weighted avg       0.69      0.68      0.68       307

Validation Loss: 2.3001
Validation Macro-F1: 0.6795, Weighted-F1: 0.6800

Epoch 22/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0102
Train Macro-F1: 0.9967, Weighted-F1: 0.9967


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.67it/s]


                       precision    recall  f1-score   support

          Nervousness       0.63      0.72      0.67        53
Lack of Worry Control       0.70      0.68      0.69        47
      Excessive Worry       0.59      0.48      0.53        46
  Difficulty Relaxing       0.87      0.80      0.84        51
         Restlessness       0.63      0.67      0.65        58
       Impending Doom       0.69      0.73      0.71        52

             accuracy                           0.68       307
            macro avg       0.69      0.68      0.68       307
         weighted avg       0.69      0.68      0.68       307

Validation Loss: 2.2568
Validation Macro-F1: 0.6813, Weighted-F1: 0.6830

Epoch 23/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0090
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.70      0.67        53
Lack of Worry Control       0.68      0.72      0.70        47
      Excessive Worry       0.66      0.50      0.57        46
  Difficulty Relaxing       0.89      0.76      0.82        51
         Restlessness       0.62      0.69      0.66        58
       Impending Doom       0.67      0.73      0.70        52

             accuracy                           0.69       307
            macro avg       0.69      0.68      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.3268
Validation Macro-F1: 0.6859, Weighted-F1: 0.6869

Epoch 24/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0114
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.68      0.67        53
Lack of Worry Control       0.65      0.72      0.69        47
      Excessive Worry       0.62      0.52      0.56        46
  Difficulty Relaxing       0.89      0.80      0.85        51
         Restlessness       0.64      0.64      0.64        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.2796
Validation Macro-F1: 0.6861, Weighted-F1: 0.6870

Epoch 25/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0082
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.70      0.67        53
Lack of Worry Control       0.67      0.72      0.69        47
      Excessive Worry       0.63      0.48      0.54        46
  Difficulty Relaxing       0.90      0.84      0.87        51
         Restlessness       0.65      0.64      0.64        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.2764
Validation Macro-F1: 0.6875, Weighted-F1: 0.6887

Epoch 26/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0066
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.69it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.70      0.67        53
Lack of Worry Control       0.67      0.70      0.69        47
      Excessive Worry       0.61      0.48      0.54        46
  Difficulty Relaxing       0.90      0.84      0.87        51
         Restlessness       0.66      0.66      0.66        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.2978
Validation Macro-F1: 0.6873, Weighted-F1: 0.6889

Epoch 27/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0080
Train Macro-F1: 0.9971, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.70      0.67        53
Lack of Worry Control       0.69      0.70      0.69        47
      Excessive Worry       0.61      0.48      0.54        46
  Difficulty Relaxing       0.90      0.84      0.87        51
         Restlessness       0.65      0.67      0.66        58
       Impending Doom       0.67      0.73      0.70        52

             accuracy                           0.69       307
            macro avg       0.69      0.69      0.69       307
         weighted avg       0.69      0.69      0.69       307

Validation Loss: 2.2993
Validation Macro-F1: 0.6875, Weighted-F1: 0.6891

Epoch 28/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0064
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.72it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.70      0.67        53
Lack of Worry Control       0.69      0.70      0.69        47
      Excessive Worry       0.62      0.50      0.55        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.66      0.67      0.67        58
       Impending Doom       0.68      0.75      0.72        52

             accuracy                           0.70       307
            macro avg       0.70      0.70      0.70       307
         weighted avg       0.70      0.70      0.70       307

Validation Loss: 2.2882
Validation Macro-F1: 0.6973, Weighted-F1: 0.6989

Epoch 29/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0063
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.70it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.70      0.67        53
Lack of Worry Control       0.69      0.70      0.69        47
      Excessive Worry       0.62      0.50      0.55        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.65      0.67      0.66        58
       Impending Doom       0.68      0.73      0.70        52

             accuracy                           0.70       307
            macro avg       0.70      0.69      0.69       307
         weighted avg       0.70      0.70      0.70       307

Validation Loss: 2.3026
Validation Macro-F1: 0.6944, Weighted-F1: 0.6958

Epoch 30/30


Training: 100%|██████████| 135/135 [02:01<00:00,  1.11it/s]


Train Loss: 0.0059
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:05<00:00,  3.71it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.70      0.67        53
Lack of Worry Control       0.69      0.70      0.69        47
      Excessive Worry       0.62      0.50      0.55        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.65      0.67      0.66        58
       Impending Doom       0.68      0.73      0.70        52

             accuracy                           0.70       307
            macro avg       0.70      0.69      0.69       307
         weighted avg       0.70      0.70      0.70       307

Validation Loss: 2.2974
Validation Macro-F1: 0.6944, Weighted-F1: 0.6958


## Inference

In [32]:
!gdown 1nYT2Dw1vUTh5iK7FNSLwlJq1921_FjTO

Downloading...
From (original): https://drive.google.com/uc?id=1nYT2Dw1vUTh5iK7FNSLwlJq1921_FjTO
From (redirected): https://drive.google.com/uc?id=1nYT2Dw1vUTh5iK7FNSLwlJq1921_FjTO&confirm=t&uuid=3f0e6ee0-c984-43dc-aa27-4d8cb5bb78a3
To: /kaggle/working/only_text_anxiety.pth
100%|████████████████████████████████████████| 518M/518M [00:12<00:00, 42.1MB/s]


In [38]:
def inference(test_data, model_path):
    model = TextOnlyModel(
        text_model_name="mental/mental-roberta-base",
        num_classes=len(LABEL_MAP),
        hidden_dim=FUSION_DIM
    )

    model = model.to(DEVICE)

    weights = torch.load(model_path, map_location=DEVICE, weights_only=True)
    weights_single = {k.replace("module.", ""): v for k, v in weights.items()}

    model.load_state_dict(weights_single)
    tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")

    test_dataset = TextOnlyAnxietyDataset(test_data, tokenizer, max_len=MAX_LEN)
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    loss, macro_f1, weighted_f1 = evaluate_text_model(
        model, test_loader, nn.CrossEntropyLoss()
    )

    print(f"Test Loss: {loss:.4f}")
    print(f"Test Macro-F1: {macro_f1:.4f}, Weighted-F1: {weighted_f1:.4f}")

    return macro_f1, weighted_f1

In [39]:
inference(test_data, "only_text_anxiety.pth")

Some weights of RobertaModel were not initialized from the model checkpoint at mental/mental-roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Evaluating: 100%|██████████| 39/39 [00:20<00:00,  1.93it/s]

                       precision    recall  f1-score   support

          Nervousness       0.70      0.60      0.65       106
Lack of Worry Control       0.65      0.67      0.66        94
      Excessive Worry       0.62      0.60      0.61        92
  Difficulty Relaxing       0.79      0.76      0.78       102
         Restlessness       0.50      0.59      0.54       116
       Impending Doom       0.74      0.72      0.73       105

             accuracy                           0.66       615
            macro avg       0.67      0.66      0.66       615
         weighted avg       0.66      0.66      0.66       615

Test Loss: 2.5080
Test Macro-F1: 0.6606, Weighted-F1: 0.6590





(0.6606387991478248, 0.658952957544853)