## Dataset Download

In [1]:
!gdown 1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU

Downloading...
From (original): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU
From (redirected): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU&confirm=t&uuid=9749e577-5267-4601-8b6e-08c110657b91
To: /kaggle/working/anxiety_dataset_complete.zip
100%|██████████████████████████████████████| 1.41G/1.41G [00:31<00:00, 45.4MB/s]


In [2]:
!unzip anxiety_dataset_complete.zip > /dev/null

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor, CLIPVisionModel
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import json
import math

TRAIN_RATIO = 0.7  # 70% train
VAL_RATIO = 0.1  # 10% validation
TEST_RATIO = 0.2 # 20% test
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
LR = 2e-5
MAX_LEN = 512
FUSION_DIM = 768
LABELS = ["Nervousness", "Lack of Worry Control", "Excessive Worry",
          "Difficulty Relaxing", "Restlessness", "Impending Doom"]
LABEL_MAP = {i: idx for idx, i in enumerate(LABELS)} 

## Custom Dataset

In [6]:
class MultimodalAnxietyDataset(Dataset):
    def __init__(self, data, image_path, tokenizer, image_processor, max_len=512):
        self.data = data
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_len = max_len
        self.img_path = image_path
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        ocr_text = sample["ocr_text"]
        figurative_reasoning = sample["figurative_reasoning"]
        combined_text = ocr_text + " [SEP] " + figurative_reasoning

        encoding = self.tokenizer(
            combined_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

        image_path = os.path.join(self.img_path, sample["sample_id"] + ".jpg")
        image = Image.open(image_path).convert("RGB")
        image_tensor = self.image_processor(image, return_tensors="pt")

        label = LABEL_MAP[sample["meme_anxiety_category"]]

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "image": image_tensor,
            "label": torch.tensor(label, dtype=torch.long)
        }

def custom_collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])

    images = {}
    for key in batch[0]['image'].keys():
        if isinstance(batch[0]['image'][key], torch.Tensor):
            images[key] = torch.stack([item['image'][key].squeeze(0) for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'image': images,
        'label': labels,
    }

## Model Definition

In [7]:
class CrossAttentionLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads=8, dropout=0.1):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.output = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_size)

        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Dropout(dropout)
        )
        self.ffn_layer_norm = nn.LayerNorm(hidden_size)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, x, context):
        mixed_query_layer = self.query(x)
        mixed_key_layer = self.key(context)
        mixed_value_layer = self.value(context)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.output(context_layer)
        attention_output = self.dropout(attention_output)
        attention_output = self.layer_norm(attention_output + x)

        ffn_output = self.ffn(attention_output)

        output = self.ffn_layer_norm(ffn_output + attention_output)

        return output

class MultimodalAttentionModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", num_classes=3, fusion_dim=768):
        super(MultimodalAttentionModel, self).__init__()

        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_dim = self.text_encoder.config.hidden_size

        self.vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_dim = self.vision_encoder.config.hidden_size

        self.text_dim = self.text_encoder.config.hidden_size
        self.vision_dim = self.vision_encoder.config.hidden_size
        fusion_dim = 1024

        self.text_projections = nn.ModuleList([
            nn.Linear(self.text_dim, fusion_dim) for _ in range(3)
        ])

        self.vision_projections = nn.ModuleList([
            nn.Linear(self.vision_dim, fusion_dim) for _ in range(3)
        ])

        self.co_attention_layers = nn.ModuleList([
            CrossAttentionLayer(fusion_dim) for _ in range(3)
        ])

        self.expert_nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(fusion_dim, fusion_dim),
                nn.LayerNorm(fusion_dim),
                nn.GELU(),
                nn.Dropout(0.2)
            ) for _ in range(4)
        ])

        self.moe_gate = nn.Linear(fusion_dim, 4)

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, num_classes)
        )

    def forward(self, input_ids, attention_mask, image_features):
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        text_features = [
            text_outputs.hidden_states[-i][:, 0] for i in range(1, 4)
        ]
        text_features = [
            proj(feat) for proj, feat in zip(self.text_projections, text_features)
        ]

        vision_outputs = self.vision_encoder(**image_features, output_hidden_states=True)
        vision_features = [
            vision_outputs.hidden_states[-i][:, 0] for i in range(1, 4)
        ]
        vision_features = [
            proj(feat) for proj, feat in zip(self.vision_projections, vision_features)
        ]

        fused_features = []
        for text_feat, vision_feat, co_attn in zip(text_features, vision_features, self.co_attention_layers):
            fused_feat = co_attn(text_feat.unsqueeze(1), vision_feat.unsqueeze(1))
            fused_features.append(fused_feat.squeeze(1))

        combined_feature = sum(fused_features) / len(fused_features)

        expert_outputs = [expert(combined_feature) for expert in self.expert_nets]
        expert_gates = F.softmax(self.moe_gate(combined_feature), dim=1)

        moe_output = torch.zeros_like(expert_outputs[0])
        for i, expert_out in enumerate(expert_outputs):
            moe_output += expert_out * expert_gates[:, i].unsqueeze(1)

        logits = self.classifier(moe_output)

        return logits

## Training Functions

In [8]:
def train_multimodal_model(model, train_data, val_data, img_path, epochs, model_save_name):
    tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    train_dataset = MultimodalAnxietyDataset(train_data, img_path, tokenizer, image_processor, max_len=MAX_LEN)
    val_dataset = MultimodalAnxietyDataset(val_data, img_path, tokenizer, image_processor, max_len=MAX_LEN)

    print("Train Set Size:", len(train_dataset))
    print("Validation Set Size:", len(val_dataset))

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=custom_collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    optimizer = optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2, verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    model = model.to(DEVICE)
    model = nn.DataParallel(model)

    best_f1 = 0
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        model.train()
        train_loss = 0
        train_preds, train_labels = [], []

        for batch in tqdm(train_loader, desc="Training"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            image_features = {k: v.to(DEVICE) for k, v in batch["image"].items()}

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask, image_features)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            train_preds.extend(predictions)
            train_labels.extend(labels.cpu().numpy())

        train_loss = train_loss / len(train_loader)
        train_macro_f1 = f1_score(train_labels, train_preds, average="macro")
        train_weighted_f1 = f1_score(train_labels, train_preds, average="weighted")

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Macro-F1: {train_macro_f1:.4f}, Weighted-F1: {train_weighted_f1:.4f}")

        val_loss, val_macro_f1, val_weighted_f1 = evaluate_multimodal_model(
            model, val_loader, criterion
        )

        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Macro-F1: {val_macro_f1:.4f}, Weighted-F1: {val_weighted_f1:.4f}")

        scheduler.step(val_macro_f1)

        f1_hm = 2 * val_macro_f1 * val_weighted_f1 / (val_macro_f1 + val_weighted_f1)
        if f1_hm > best_f1:
            best_f1 = f1_hm
            torch.save(model.state_dict(), f"{model_save_name}_anxiety.pth")
            print("Best model saved!")

    return model

def evaluate_multimodal_model(model, loader, criterion):
    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            image_features = {k: v.to(DEVICE) for k, v in batch["image"].items()}

            logits = model(input_ids, attention_mask, image_features)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(predictions)
            val_labels.extend(labels.cpu().numpy())

    val_loss = val_loss / len(loader)
    val_macro_f1 = f1_score(val_labels, val_preds, average="macro")
    val_weighted_f1 = f1_score(val_labels, val_preds, average="weighted")

    target_names = [k for k, v in sorted(LABEL_MAP.items(), key=lambda x: x[1])]
    report = classification_report(val_labels, val_preds, target_names=target_names)
    print(report)

    return val_loss, val_macro_f1, val_weighted_f1

## Model Training

In [None]:
full_train_data = json.load(open("anxiety_train_llava_dataset.json", "r"))
test_data = json.load(open("anxiety_test_llava_dataset.json", "r"))

labels = [LABEL_MAP[item["meme_anxiety_category"]] for item in full_train_data]

train_size = math.ceil(len(full_train_data) * TRAIN_RATIO / (TRAIN_RATIO + VAL_RATIO))
train_data, val_data = train_test_split(
    full_train_data, train_size=train_size, stratify=labels, random_state=42
)

img_path = "anxiety_train_image"

tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

test_dataset = MultimodalAnxietyDataset(test_data, "anxiety_test_image", tokenizer, image_processor, max_len=MAX_LEN)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn
)

In [None]:
model = MultimodalAttentionModel(
    text_model_name="mental/mental-roberta-base",
    num_classes=len(LABEL_MAP),
    fusion_dim=FUSION_DIM
)

trained_model, tokenizer, image_processor = train_multimodal_model(
    model,
    train_data,
    val_data,
    img_path,
    epochs=30,
    model_save_name="no_contrastive"
)

Some weights of RobertaModel were not initialized from the model checkpoint at mental/mental-roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train Set Size: 2153
Validation Set Size: 307

Epoch 1/30


Training: 100%|██████████| 135/135 [03:33<00:00,  1.58s/it]


Train Loss: 1.7318
Train Macro-F1: 0.2420, Weighted-F1: 0.2442


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.55      0.62      0.58        53
Lack of Worry Control       0.33      0.72      0.45        47
      Excessive Worry       0.00      0.00      0.00        46
  Difficulty Relaxing       0.92      0.71      0.80        51
         Restlessness       0.33      0.48      0.39        58
       Impending Doom       0.72      0.25      0.37        52

             accuracy                           0.47       307
            macro avg       0.48      0.46      0.43       307
         weighted avg       0.48      0.47      0.44       307

Validation Loss: 1.3845
Validation Macro-F1: 0.4334, Weighted-F1: 0.4401
Best model saved!

Epoch 2/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 1.2146
Train Macro-F1: 0.5356, Weighted-F1: 0.5355


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.70      0.57      0.62        53
Lack of Worry Control       0.68      0.57      0.62        47
      Excessive Worry       0.63      0.37      0.47        46
  Difficulty Relaxing       0.82      0.80      0.81        51
         Restlessness       0.42      0.69      0.52        58
       Impending Doom       0.61      0.60      0.60        52

             accuracy                           0.61       307
            macro avg       0.64      0.60      0.61       307
         weighted avg       0.64      0.61      0.61       307

Validation Loss: 1.0530
Validation Macro-F1: 0.6075, Weighted-F1: 0.6077
Best model saved!

Epoch 3/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.7739
Train Macro-F1: 0.7318, Weighted-F1: 0.7312


Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.33it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.77      0.71        53
Lack of Worry Control       0.66      0.57      0.61        47
      Excessive Worry       0.64      0.39      0.49        46
  Difficulty Relaxing       0.93      0.80      0.86        51
         Restlessness       0.53      0.57      0.55        58
       Impending Doom       0.63      0.85      0.72        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.67      0.66      0.66       307

Validation Loss: 1.0143
Validation Macro-F1: 0.6579, Weighted-F1: 0.6594
Best model saved!

Epoch 4/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.4749
Train Macro-F1: 0.8421, Weighted-F1: 0.8424


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.61      0.66      0.64        53
Lack of Worry Control       0.81      0.45      0.58        47
      Excessive Worry       0.61      0.41      0.49        46
  Difficulty Relaxing       0.91      0.78      0.84        51
         Restlessness       0.50      0.72      0.59        58
       Impending Doom       0.58      0.73      0.65        52

             accuracy                           0.64       307
            macro avg       0.67      0.63      0.63       307
         weighted avg       0.67      0.64      0.63       307

Validation Loss: 1.2196
Validation Macro-F1: 0.6314, Weighted-F1: 0.6336

Epoch 5/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.2657
Train Macro-F1: 0.9147, Weighted-F1: 0.9150


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.66      0.67        53
Lack of Worry Control       0.65      0.60      0.62        47
      Excessive Worry       0.61      0.41      0.49        46
  Difficulty Relaxing       0.88      0.86      0.87        51
         Restlessness       0.51      0.55      0.53        58
       Impending Doom       0.59      0.77      0.67        52

             accuracy                           0.64       307
            macro avg       0.65      0.64      0.64       307
         weighted avg       0.65      0.64      0.64       307

Validation Loss: 1.4764
Validation Macro-F1: 0.6415, Weighted-F1: 0.6419

Epoch 6/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.2074
Train Macro-F1: 0.9345, Weighted-F1: 0.9350


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.55      0.79      0.65        53
Lack of Worry Control       0.92      0.47      0.62        47
      Excessive Worry       0.55      0.39      0.46        46
  Difficulty Relaxing       0.89      0.80      0.85        51
         Restlessness       0.51      0.64      0.57        58
       Impending Doom       0.65      0.69      0.67        52

             accuracy                           0.64       307
            macro avg       0.68      0.63      0.63       307
         weighted avg       0.67      0.64      0.64       307

Validation Loss: 1.8626
Validation Macro-F1: 0.6348, Weighted-F1: 0.6367

Epoch 7/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.1187
Train Macro-F1: 0.9657, Weighted-F1: 0.9661


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.68      0.67        53
Lack of Worry Control       0.75      0.64      0.69        47
      Excessive Worry       0.55      0.46      0.50        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.56      0.60      0.58        58
       Impending Doom       0.65      0.75      0.70        52

             accuracy                           0.67       307
            macro avg       0.67      0.67      0.67       307
         weighted avg       0.67      0.67      0.67       307

Validation Loss: 1.7143
Validation Macro-F1: 0.6679, Weighted-F1: 0.6686
Best model saved!

Epoch 8/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.0700
Train Macro-F1: 0.9794, Weighted-F1: 0.9796


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.36it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.62      0.64        53
Lack of Worry Control       0.72      0.55      0.63        47
      Excessive Worry       0.58      0.46      0.51        46
  Difficulty Relaxing       0.87      0.88      0.87        51
         Restlessness       0.55      0.57      0.56        58
       Impending Doom       0.56      0.79      0.66        52

             accuracy                           0.65       307
            macro avg       0.66      0.65      0.64       307
         weighted avg       0.65      0.65      0.65       307

Validation Loss: 1.9985
Validation Macro-F1: 0.6448, Weighted-F1: 0.6452

Epoch 9/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.56s/it]


Train Loss: 0.0473
Train Macro-F1: 0.9867, Weighted-F1: 0.9870


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.36it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.70      0.67        53
Lack of Worry Control       0.62      0.60      0.61        47
      Excessive Worry       0.69      0.39      0.50        46
  Difficulty Relaxing       0.91      0.78      0.84        51
         Restlessness       0.53      0.53      0.53        58
       Impending Doom       0.58      0.85      0.69        52

             accuracy                           0.64       307
            macro avg       0.66      0.64      0.64       307
         weighted avg       0.66      0.64      0.64       307

Validation Loss: 2.1113
Validation Macro-F1: 0.6399, Weighted-F1: 0.6405

Epoch 10/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.0393
Train Macro-F1: 0.9891, Weighted-F1: 0.9893


Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.33it/s]


                       precision    recall  f1-score   support

          Nervousness       0.60      0.74      0.66        53
Lack of Worry Control       0.86      0.51      0.64        47
      Excessive Worry       0.56      0.54      0.55        46
  Difficulty Relaxing       0.88      0.84      0.86        51
         Restlessness       0.58      0.62      0.60        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.67       307
            macro avg       0.69      0.67      0.67       307
         weighted avg       0.69      0.67      0.67       307

Validation Loss: 2.2772
Validation Macro-F1: 0.6699, Weighted-F1: 0.6708
Best model saved!

Epoch 11/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0354
Train Macro-F1: 0.9925, Weighted-F1: 0.9926


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.74      0.70        53
Lack of Worry Control       0.82      0.60      0.69        47
      Excessive Worry       0.57      0.59      0.58        46
  Difficulty Relaxing       0.88      0.86      0.87        51
         Restlessness       0.56      0.62      0.59        58
       Impending Doom       0.68      0.69      0.69        52

             accuracy                           0.68       307
            macro avg       0.70      0.68      0.69       307
         weighted avg       0.69      0.68      0.69       307

Validation Loss: 2.1422
Validation Macro-F1: 0.6859, Weighted-F1: 0.6855
Best model saved!

Epoch 12/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0293
Train Macro-F1: 0.9925, Weighted-F1: 0.9926


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.61      0.74      0.67        53
Lack of Worry Control       0.84      0.55      0.67        47
      Excessive Worry       0.62      0.46      0.52        46
  Difficulty Relaxing       0.84      0.82      0.83        51
         Restlessness       0.57      0.64      0.60        58
       Impending Doom       0.60      0.73      0.66        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.3044
Validation Macro-F1: 0.6588, Weighted-F1: 0.6596

Epoch 13/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.0274
Train Macro-F1: 0.9948, Weighted-F1: 0.9949


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.59      0.72      0.65        53
Lack of Worry Control       0.86      0.51      0.64        47
      Excessive Worry       0.55      0.50      0.52        46
  Difficulty Relaxing       0.91      0.78      0.84        51
         Restlessness       0.52      0.53      0.53        58
       Impending Doom       0.58      0.77      0.66        52

             accuracy                           0.64       307
            macro avg       0.67      0.64      0.64       307
         weighted avg       0.66      0.64      0.64       307

Validation Loss: 2.3376
Validation Macro-F1: 0.6402, Weighted-F1: 0.6396

Epoch 14/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.56s/it]


Train Loss: 0.0253
Train Macro-F1: 0.9952, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.31it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.75      0.70        53
Lack of Worry Control       0.84      0.55      0.67        47
      Excessive Worry       0.54      0.46      0.49        46
  Difficulty Relaxing       0.88      0.84      0.86        51
         Restlessness       0.51      0.64      0.56        58
       Impending Doom       0.68      0.69      0.69        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.3073
Validation Macro-F1: 0.6612, Weighted-F1: 0.6619

Epoch 15/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0156
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.63      0.74      0.68        53
Lack of Worry Control       0.86      0.53      0.66        47
      Excessive Worry       0.66      0.46      0.54        46
  Difficulty Relaxing       0.88      0.84      0.86        51
         Restlessness       0.58      0.57      0.57        58
       Impending Doom       0.55      0.83      0.66        52

             accuracy                           0.66       307
            macro avg       0.69      0.66      0.66       307
         weighted avg       0.69      0.66      0.66       307

Validation Loss: 2.4506
Validation Macro-F1: 0.6617, Weighted-F1: 0.6618

Epoch 16/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.56s/it]


Train Loss: 0.0117
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.72      0.68        53
Lack of Worry Control       0.81      0.53      0.64        47
      Excessive Worry       0.49      0.48      0.48        46
  Difficulty Relaxing       0.90      0.88      0.89        51
         Restlessness       0.51      0.64      0.57        58
       Impending Doom       0.71      0.69      0.70        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.3573
Validation Macro-F1: 0.6614, Weighted-F1: 0.6628

Epoch 17/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.56s/it]


Train Loss: 0.0121
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.62      0.74      0.67        53
Lack of Worry Control       0.83      0.53      0.65        47
      Excessive Worry       0.68      0.46      0.55        46
  Difficulty Relaxing       0.88      0.88      0.88        51
         Restlessness       0.54      0.59      0.56        58
       Impending Doom       0.59      0.79      0.68        52

             accuracy                           0.67       307
            macro avg       0.69      0.66      0.66       307
         weighted avg       0.69      0.67      0.66       307

Validation Loss: 2.5054
Validation Macro-F1: 0.6649, Weighted-F1: 0.6648

Epoch 18/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.0097
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.74      0.70        53
Lack of Worry Control       0.83      0.53      0.65        47
      Excessive Worry       0.62      0.43      0.51        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.60      0.55        58
       Impending Doom       0.59      0.79      0.67        52

             accuracy                           0.66       307
            macro avg       0.69      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.4603
Validation Macro-F1: 0.6628, Weighted-F1: 0.6632

Epoch 19/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.0083
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.72      0.69        53
Lack of Worry Control       0.82      0.57      0.68        47
      Excessive Worry       0.59      0.43      0.50        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.62      0.56        58
       Impending Doom       0.63      0.79      0.70        52

             accuracy                           0.67       307
            macro avg       0.69      0.67      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 2.4237
Validation Macro-F1: 0.6697, Weighted-F1: 0.6702

Epoch 20/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0091
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.74      0.68        53
Lack of Worry Control       0.78      0.53      0.63        47
      Excessive Worry       0.65      0.43      0.52        46
  Difficulty Relaxing       0.91      0.84      0.88        51
         Restlessness       0.51      0.59      0.54        58
       Impending Doom       0.59      0.79      0.68        52

             accuracy                           0.66       307
            macro avg       0.68      0.65      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.4825
Validation Macro-F1: 0.6560, Weighted-F1: 0.6562

Epoch 21/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.56s/it]


Train Loss: 0.0080
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.75      0.70        53
Lack of Worry Control       0.82      0.60      0.69        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.52      0.62      0.57        58
       Impending Doom       0.67      0.77      0.71        52

             accuracy                           0.68       307
            macro avg       0.69      0.67      0.68       307
         weighted avg       0.69      0.68      0.68       307

Validation Loss: 2.4638
Validation Macro-F1: 0.6758, Weighted-F1: 0.6761

Epoch 22/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0062
Train Macro-F1: 0.9966, Weighted-F1: 0.9967


Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.33it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.74      0.68        53
Lack of Worry Control       0.82      0.57      0.68        47
      Excessive Worry       0.60      0.46      0.52        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.57      0.54        58
       Impending Doom       0.63      0.79      0.70        52

             accuracy                           0.67       307
            macro avg       0.69      0.66      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 2.4596
Validation Macro-F1: 0.6673, Weighted-F1: 0.6669

Epoch 23/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0084
Train Macro-F1: 0.9967, Weighted-F1: 0.9967


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.75      0.70        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.58      0.46      0.51        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.51      0.57      0.54        58
       Impending Doom       0.63      0.77      0.70        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.4886
Validation Macro-F1: 0.6631, Weighted-F1: 0.6630

Epoch 24/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0060
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.63      0.75      0.69        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.62      0.46      0.52        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.51      0.57      0.54        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.5116
Validation Macro-F1: 0.6632, Weighted-F1: 0.6629

Epoch 25/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0052
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.75      0.70        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.60      0.46      0.52        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.51      0.57      0.54        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.5004
Validation Macro-F1: 0.6631, Weighted-F1: 0.6629

Epoch 26/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0071
Train Macro-F1: 0.9967, Weighted-F1: 0.9968


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.75      0.70        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.59      0.54        58
       Impending Doom       0.63      0.79      0.70        52

             accuracy                           0.67       307
            macro avg       0.69      0.66      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 2.5031
Validation Macro-F1: 0.6657, Weighted-F1: 0.6659

Epoch 27/30


Training: 100%|██████████| 135/135 [03:30<00:00,  1.56s/it]


Train Loss: 0.0046
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.75      0.70        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.90      0.86      0.88        51
         Restlessness       0.51      0.59      0.54        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.68      0.66      0.66       307

Validation Loss: 2.5050
Validation Macro-F1: 0.6623, Weighted-F1: 0.6625

Epoch 28/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0057
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.75      0.70        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.60      0.56        58
       Impending Doom       0.63      0.79      0.70        52

             accuracy                           0.67       307
            macro avg       0.69      0.67      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 2.5110
Validation Macro-F1: 0.6686, Weighted-F1: 0.6691

Epoch 29/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.56s/it]


Train Loss: 0.0051
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:15<00:00,  1.33it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.74      0.69        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.60      0.56        58
       Impending Doom       0.62      0.79      0.69        52

             accuracy                           0.67       307
            macro avg       0.69      0.66      0.67       307
         weighted avg       0.68      0.67      0.67       307

Validation Loss: 2.5064
Validation Macro-F1: 0.6657, Weighted-F1: 0.6661

Epoch 30/30


Training: 100%|██████████| 135/135 [03:31<00:00,  1.57s/it]


Train Loss: 0.0053
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.74      0.70        53
Lack of Worry Control       0.81      0.55      0.66        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.92      0.86      0.89        51
         Restlessness       0.51      0.62      0.56        58
       Impending Doom       0.63      0.79      0.70        52

             accuracy                           0.67       307
            macro avg       0.69      0.67      0.67       307
         weighted avg       0.69      0.67      0.67       307

Validation Loss: 2.5001
Validation Macro-F1: 0.6689, Weighted-F1: 0.6695


## Inference

In [11]:
!gdown 1qF-7lyFOOLxn6XqBqRYp-UMFa6dF9IYv

Downloading...
From (original): https://drive.google.com/uc?id=1qF-7lyFOOLxn6XqBqRYp-UMFa6dF9IYv
From (redirected): https://drive.google.com/uc?id=1qF-7lyFOOLxn6XqBqRYp-UMFa6dF9IYv&confirm=t&uuid=1b026422-fcf7-4c65-9128-3fc9c7d551f8
To: /kaggle/working/no_contrastive_anxiety.pth
100%|██████████████████████████████████████| 1.04G/1.04G [00:18<00:00, 54.7MB/s]


In [12]:
def inference(test_data, model_path):
    model = MultimodalAttentionModel(
        text_model_name="mental/mental-roberta-base",
        num_classes=len(LABEL_MAP),
        fusion_dim=FUSION_DIM
    )

    model = model.to(DEVICE)

    weights = torch.load(model_path, map_location=DEVICE, weights_only=True)
    weights_single = {k.replace("module.", ""): v for k, v in weights.items()}

    model.load_state_dict(weights_single)
    tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    test_dataset = MultimodalAnxietyDataset(test_data, "anxiety_test_image", tokenizer, image_processor, max_len=MAX_LEN)
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    loss, macro_f1, weighted_f1 = evaluate_multimodal_model(
        model, test_loader, nn.CrossEntropyLoss()
    )

    print(f"Test Loss: {loss:.4f}")
    print(f"Test Macro-F1: {macro_f1:.4f}, Weighted-F1: {weighted_f1:.4f}")

    return macro_f1, weighted_f1

In [16]:
inference(test_data, "no_contrastive_anxiety.pth")

Some weights of RobertaModel were not initialized from the model checkpoint at mental/mental-roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Evaluating: 100%|██████████| 39/39 [00:32<00:00,  1.19it/s]

                       precision    recall  f1-score   support

          Nervousness       0.66      0.65      0.65       106
Lack of Worry Control       0.66      0.52      0.58        94
      Excessive Worry       0.52      0.65      0.58        92
  Difficulty Relaxing       0.83      0.77      0.80       102
         Restlessness       0.52      0.59      0.55       116
       Impending Doom       0.76      0.67      0.71       105

             accuracy                           0.64       615
            macro avg       0.66      0.64      0.65       615
         weighted avg       0.66      0.64      0.65       615

Test Loss: 2.4443
Test Macro-F1: 0.6469, Weighted-F1: 0.6471





(0.6468653448458661, 0.6470777612908264)