## Dataset Download

In [None]:
!gdown 1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU

Downloading...
From (original): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU
From (redirected): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU&confirm=t&uuid=952822ae-d209-4cee-a667-1111276a6e59
To: /kaggle/working/anxiety_dataset_complete.zip
100%|███████████████████████████████████████| 1.41G/1.41G [00:06<00:00, 215MB/s]


In [None]:
!unzip anxiety_dataset_complete.zip > /dev/null

In [18]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor, CLIPVisionModel
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import json
import math

TRAIN_RATIO = 0.7  # 70% train
VAL_RATIO = 0.1  # 10% validation
TEST_RATIO = 0.2 # 20% test
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
LR = 2e-5
MAX_LEN = 512
FUSION_DIM = 768
LABELS = ["Nervousness", "Lack of Worry Control", "Excessive Worry",
          "Difficulty Relaxing", "Restlessness", "Impending Doom"]
LABEL_MAP = {i: idx for idx, i in enumerate(LABELS)}  

## Custom Dataset

In [19]:
class MultimodalAnxietyDataset(Dataset):
    def __init__(self, data, image_path, tokenizer, image_processor, max_len=512):
        self.data = data
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_len = max_len
        self.img_path = image_path
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        ocr_text = sample["ocr_text"]
        figurative_reasoning = sample["figurative_reasoning"]
        combined_text = ocr_text + " [SEP] " + figurative_reasoning

        encoding = self.tokenizer(
            combined_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

        image_path = os.path.join(self.img_path, sample["sample_id"] + ".jpg")
        image = Image.open(image_path).convert("RGB")
        image_tensor = self.image_processor(image, return_tensors="pt")

        label = LABEL_MAP[sample["meme_anxiety_category"]]

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "image": image_tensor,
            "label": torch.tensor(label, dtype=torch.long)
        }

def custom_collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])

    images = {}
    for key in batch[0]['image'].keys():
        if isinstance(batch[0]['image'][key], torch.Tensor):
            images[key] = torch.stack([item['image'][key].squeeze(0) for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'image': images,
        'label': labels,
    }

## Model Definition

In [20]:
class MultimodalConcatenationModel(nn.Module):
    def __init__(self, text_model_name="bert-base-uncased", num_classes=3, fusion_dim=768):
        super(MultimodalConcatenationModel, self).__init__()

        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.text_dim = self.text_encoder.config.hidden_size

        self.vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_dim = self.vision_encoder.config.hidden_size

        fusion_dim = 1024
        contrastive_dim = 256

        self.text_projection = nn.Linear(self.text_dim, fusion_dim)
        self.vision_projection = nn.Linear(self.vision_dim, fusion_dim)

        self.fusion_layer = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, num_classes)
        )

        self.contrastive_projection = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.LayerNorm(fusion_dim // 2),
            nn.GELU(),
            nn.Linear(fusion_dim // 2, contrastive_dim)
        )

        self.text_contrastive_proj = nn.Sequential(
            nn.Linear(self.text_dim, contrastive_dim),
            nn.LayerNorm(contrastive_dim),
            nn.GELU()
        )

        self.image_contrastive_proj = nn.Sequential(
            nn.Linear(self.vision_dim, contrastive_dim),
            nn.LayerNorm(contrastive_dim),
            nn.GELU()
        )

    def forward(self, input_ids, attention_mask, image_features, get_embeddings=False):
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_cls = text_outputs.last_hidden_state[:, 0]

        vision_outputs = self.vision_encoder(**image_features)
        image_cls = vision_outputs.pooler_output

        text_features = self.text_projection(text_cls)
        vision_features = self.vision_projection(image_cls)

        concatenated_features = torch.cat([text_features, vision_features], dim=1)

        fused_features = self.fusion_layer(concatenated_features)

        logits = self.classifier(fused_features)

        if get_embeddings:
            multimodal_contrastive = self.contrastive_projection(fused_features)
            text_contrastive = self.text_contrastive_proj(text_cls)
            image_contrastive = self.image_contrastive_proj(image_cls)

            multimodal_contrastive = F.normalize(multimodal_contrastive, p=2, dim=1)
            text_contrastive = F.normalize(text_contrastive, p=2, dim=1)
            image_contrastive = F.normalize(image_contrastive, p=2, dim=1)

            return {
                "logits": logits,
                "multimodal_embedding": multimodal_contrastive,
                "text_embedding": text_contrastive,
                "image_embedding": image_contrastive
            }

        return logits

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, modal1, modal2):
        batch_size = modal1.shape[0]

        features = torch.cat([modal1, modal2], dim=0)
        similarity_matrix = torch.matmul(features, features.T)

        mask = (~torch.eye(2 * batch_size, dtype=bool, device=features.device)).float()

        similarity_matrix = similarity_matrix * mask
        similarity_matrix = similarity_matrix / self.temperature

        labels = torch.arange(batch_size, device=features.device, dtype=torch.long)
        labels = torch.cat([labels + batch_size, labels], dim=0)

        loss = self.criterion(similarity_matrix, labels)
        loss = loss / (2 * batch_size)

        return loss

## Training Functions

In [21]:
def train_multimodal_model(model, train_data, val_data, img_path, epochs, model_save_name):
    tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    train_dataset = MultimodalAnxietyDataset(train_data, img_path, tokenizer, image_processor, max_len=MAX_LEN)
    val_dataset = MultimodalAnxietyDataset(val_data, img_path, tokenizer, image_processor, max_len=MAX_LEN)

    print("Train Set Size:", len(train_dataset))
    print("Validation Set Size:", len(val_dataset))

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=custom_collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    optimizer = optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2, verbose=True
    )
    criterion = nn.CrossEntropyLoss()
    contrastive_criterion = ContrastiveLoss(temperature=0.07)
    contrastive_weight = 0.3

    model = model.to(DEVICE)
    model = nn.DataParallel(model)

    best_f1 = 0
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        model.train()
        train_loss = 0
        train_preds, train_labels = [], []

        for batch in tqdm(train_loader, desc="Training"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            image_features = {k: v.to(DEVICE) for k, v in batch["image"].items()}

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, image_features, get_embeddings=True)
            logits = outputs["logits"]

            classification_loss = criterion(logits, labels)

            multimodal_text_loss = contrastive_criterion(
                outputs["multimodal_embedding"],
                outputs["text_embedding"]
            )

            multimodal_image_loss = contrastive_criterion(
                outputs["multimodal_embedding"],
                outputs["image_embedding"]
            )

            text_image_loss = contrastive_criterion(
                outputs["text_embedding"],
                outputs["image_embedding"]
            )

            contrastive_loss = (multimodal_text_loss + multimodal_image_loss + text_image_loss) / 3
            loss = (1 - contrastive_weight) * classification_loss + contrastive_weight * contrastive_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            train_preds.extend(predictions)
            train_labels.extend(labels.cpu().numpy())

        train_loss = train_loss / len(train_loader)
        train_macro_f1 = f1_score(train_labels, train_preds, average="macro")
        train_weighted_f1 = f1_score(train_labels, train_preds, average="weighted")

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Macro-F1: {train_macro_f1:.4f}, Weighted-F1: {train_weighted_f1:.4f}")

        val_loss, val_macro_f1, val_weighted_f1 = evaluate_multimodal_model(
            model, val_loader, criterion
        )

        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Macro-F1: {val_macro_f1:.4f}, Weighted-F1: {val_weighted_f1:.4f}")

        scheduler.step(val_macro_f1)

        f1_hm = 2 * val_macro_f1 * val_weighted_f1 / (val_macro_f1 + val_weighted_f1)
        if f1_hm > best_f1:
            best_f1 = f1_hm
            torch.save(model.state_dict(), f"{model_save_name}_anxiety.pth")
            print("Best model saved!")

    return model

def evaluate_multimodal_model(model, loader, criterion):
    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            image_features = {k: v.to(DEVICE) for k, v in batch["image"].items()}

            logits = model(input_ids, attention_mask, image_features)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(predictions)
            val_labels.extend(labels.cpu().numpy())

    val_loss = val_loss / len(loader)
    val_macro_f1 = f1_score(val_labels, val_preds, average="macro")
    val_weighted_f1 = f1_score(val_labels, val_preds, average="weighted")

    target_names = [k for k, v in sorted(LABEL_MAP.items(), key=lambda x: x[1])]
    report = classification_report(val_labels, val_preds, target_names=target_names)
    print(report)

    return val_loss, val_macro_f1, val_weighted_f1

## Model Training

In [22]:
full_train_data = json.load(open("anxiety_train_llava_dataset.json", "r"))
test_data = json.load(open("anxiety_test_llava_dataset.json", "r"))

labels = [LABEL_MAP[item["meme_anxiety_category"]] for item in full_train_data]

train_size = math.ceil(len(full_train_data) * TRAIN_RATIO / (TRAIN_RATIO + VAL_RATIO))
train_data, val_data = train_test_split(
    full_train_data, train_size=train_size, stratify=labels, random_state=42
)

img_path = "anxiety_train_image"

tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

test_dataset = MultimodalAnxietyDataset(test_data, "anxiety_test_image", tokenizer, image_processor, max_len=MAX_LEN)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn
)

In [None]:
model = MultimodalConcatenationModel(
    text_model_name="mental/mental-roberta-base",
    num_classes=len(LABEL_MAP),
    fusion_dim=FUSION_DIM
)

trained_model, tokenizer, image_processor = train_multimodal_model(
    model,
    train_data,
    val_data,
    img_path,
    epochs=30,
    model_save_name="no_fusion"
)

Some weights of RobertaModel were not initialized from the model checkpoint at mental/mental-roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train Set Size: 2153
Validation Set Size: 307

Epoch 1/30


Training: 100%|██████████| 135/135 [03:21<00:00,  1.49s/it]


Train Loss: 2.2484
Train Macro-F1: 0.3441, Weighted-F1: 0.3460


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]


                       precision    recall  f1-score   support

          Nervousness       0.49      0.70      0.57        53
Lack of Worry Control       0.63      0.57      0.60        47
      Excessive Worry       0.56      0.39      0.46        46
  Difficulty Relaxing       0.77      0.90      0.83        51
         Restlessness       0.43      0.28      0.34        58
       Impending Doom       0.56      0.63      0.59        52

             accuracy                           0.58       307
            macro avg       0.57      0.58      0.57       307
         weighted avg       0.57      0.58      0.56       307

Validation Loss: 1.1235
Validation Macro-F1: 0.5659, Weighted-F1: 0.5621
Best model saved!

Epoch 2/30


Training: 100%|██████████| 135/135 [03:24<00:00,  1.51s/it]


Train Loss: 1.4261
Train Macro-F1: 0.6749, Weighted-F1: 0.6760


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]


                       precision    recall  f1-score   support

          Nervousness       0.69      0.64      0.67        53
Lack of Worry Control       0.73      0.51      0.60        47
      Excessive Worry       0.59      0.41      0.49        46
  Difficulty Relaxing       0.73      0.96      0.83        51
         Restlessness       0.45      0.50      0.47        58
       Impending Doom       0.66      0.77      0.71        52

             accuracy                           0.64       307
            macro avg       0.64      0.63      0.63       307
         weighted avg       0.64      0.64      0.63       307

Validation Loss: 1.0687
Validation Macro-F1: 0.6273, Weighted-F1: 0.6269
Best model saved!

Epoch 3/30


Training: 100%|██████████| 135/135 [03:23<00:00,  1.51s/it]


Train Loss: 0.8245
Train Macro-F1: 0.8625, Weighted-F1: 0.8631


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.38it/s]


                       precision    recall  f1-score   support

          Nervousness       0.69      0.72      0.70        53
Lack of Worry Control       0.64      0.64      0.64        47
      Excessive Worry       0.75      0.33      0.45        46
  Difficulty Relaxing       0.84      0.80      0.82        51
         Restlessness       0.44      0.69      0.54        58
       Impending Doom       0.65      0.58      0.61        52

             accuracy                           0.63       307
            macro avg       0.67      0.63      0.63       307
         weighted avg       0.66      0.63      0.63       307

Validation Loss: 1.1290
Validation Macro-F1: 0.6282, Weighted-F1: 0.6294
Best model saved!

Epoch 4/30


Training: 100%|██████████| 135/135 [03:24<00:00,  1.51s/it]


Train Loss: 0.4862
Train Macro-F1: 0.9557, Weighted-F1: 0.9563


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.55      0.77      0.64        53
Lack of Worry Control       0.69      0.57      0.63        47
      Excessive Worry       0.59      0.41      0.49        46
  Difficulty Relaxing       0.91      0.78      0.84        51
         Restlessness       0.48      0.55      0.52        58
       Impending Doom       0.69      0.67      0.68        52

             accuracy                           0.63       307
            macro avg       0.65      0.63      0.63       307
         weighted avg       0.65      0.63      0.63       307

Validation Loss: 1.3264
Validation Macro-F1: 0.6323, Weighted-F1: 0.6322
Best model saved!

Epoch 5/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.3243
Train Macro-F1: 0.9630, Weighted-F1: 0.9633


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.45it/s]


                       precision    recall  f1-score   support

          Nervousness       0.58      0.72      0.64        53
Lack of Worry Control       0.82      0.57      0.68        47
      Excessive Worry       0.56      0.52      0.54        46
  Difficulty Relaxing       0.87      0.78      0.82        51
         Restlessness       0.58      0.43      0.50        58
       Impending Doom       0.57      0.85      0.68        52

             accuracy                           0.64       307
            macro avg       0.66      0.65      0.64       307
         weighted avg       0.66      0.64      0.64       307

Validation Loss: 1.3781
Validation Macro-F1: 0.6434, Weighted-F1: 0.6414
Best model saved!

Epoch 6/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.2554
Train Macro-F1: 0.9681, Weighted-F1: 0.9684


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.72      0.68      0.70        53
Lack of Worry Control       0.56      0.51      0.53        47
      Excessive Worry       0.61      0.41      0.49        46
  Difficulty Relaxing       0.91      0.76      0.83        51
         Restlessness       0.50      0.47      0.48        58
       Impending Doom       0.51      0.85      0.64        52

             accuracy                           0.62       307
            macro avg       0.63      0.61      0.61       307
         weighted avg       0.63      0.62      0.61       307

Validation Loss: 1.4467
Validation Macro-F1: 0.6126, Weighted-F1: 0.6132

Epoch 7/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.1866
Train Macro-F1: 0.9788, Weighted-F1: 0.9791


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.45it/s]


                       precision    recall  f1-score   support

          Nervousness       0.62      0.68      0.65        53
Lack of Worry Control       0.64      0.62      0.63        47
      Excessive Worry       0.52      0.37      0.43        46
  Difficulty Relaxing       0.88      0.71      0.78        51
         Restlessness       0.43      0.52      0.47        58
       Impending Doom       0.62      0.73      0.67        52

             accuracy                           0.61       307
            macro avg       0.62      0.60      0.61       307
         weighted avg       0.62      0.61      0.61       307

Validation Loss: 1.5166
Validation Macro-F1: 0.6062, Weighted-F1: 0.6062

Epoch 8/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.1559
Train Macro-F1: 0.9807, Weighted-F1: 0.9810


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.58      0.63        53
Lack of Worry Control       0.62      0.64      0.63        47
      Excessive Worry       0.71      0.33      0.45        46
  Difficulty Relaxing       0.83      0.84      0.83        51
         Restlessness       0.46      0.66      0.54        58
       Impending Doom       0.67      0.75      0.71        52

             accuracy                           0.64       307
            macro avg       0.66      0.63      0.63       307
         weighted avg       0.66      0.64      0.63       307

Validation Loss: 1.5378
Validation Macro-F1: 0.6321, Weighted-F1: 0.6333

Epoch 9/30


Training: 100%|██████████| 135/135 [03:23<00:00,  1.51s/it]


Train Loss: 0.1173
Train Macro-F1: 0.9886, Weighted-F1: 0.9888


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.40it/s]


                       precision    recall  f1-score   support

          Nervousness       0.62      0.77      0.69        53
Lack of Worry Control       0.75      0.64      0.69        47
      Excessive Worry       0.58      0.48      0.52        46
  Difficulty Relaxing       0.84      0.82      0.83        51
         Restlessness       0.60      0.52      0.56        58
       Impending Doom       0.65      0.79      0.71        52

             accuracy                           0.67       307
            macro avg       0.67      0.67      0.67       307
         weighted avg       0.67      0.67      0.67       307

Validation Loss: 1.4638
Validation Macro-F1: 0.6671, Weighted-F1: 0.6669
Best model saved!

Epoch 10/30


Training: 100%|██████████| 135/135 [03:24<00:00,  1.52s/it]


Train Loss: 0.0771
Train Macro-F1: 0.9948, Weighted-F1: 0.9949


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]


                       precision    recall  f1-score   support

          Nervousness       0.63      0.75      0.69        53
Lack of Worry Control       0.73      0.64      0.68        47
      Excessive Worry       0.66      0.41      0.51        46
  Difficulty Relaxing       0.81      0.86      0.84        51
         Restlessness       0.48      0.53      0.51        58
       Impending Doom       0.62      0.67      0.65        52

             accuracy                           0.65       307
            macro avg       0.66      0.65      0.65       307
         weighted avg       0.65      0.65      0.64       307

Validation Loss: 1.5531
Validation Macro-F1: 0.6454, Weighted-F1: 0.6444

Epoch 11/30


Training: 100%|██████████| 135/135 [03:24<00:00,  1.51s/it]


Train Loss: 0.0672
Train Macro-F1: 0.9947, Weighted-F1: 0.9949


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.41it/s]


                       precision    recall  f1-score   support

          Nervousness       0.61      0.74      0.67        53
Lack of Worry Control       0.71      0.68      0.70        47
      Excessive Worry       0.65      0.43      0.52        46
  Difficulty Relaxing       0.83      0.88      0.86        51
         Restlessness       0.49      0.50      0.50        58
       Impending Doom       0.65      0.67      0.66        52

             accuracy                           0.65       307
            macro avg       0.66      0.65      0.65       307
         weighted avg       0.65      0.65      0.65       307

Validation Loss: 1.5853
Validation Macro-F1: 0.6492, Weighted-F1: 0.6473

Epoch 12/30


Training: 100%|██████████| 135/135 [03:23<00:00,  1.51s/it]


Train Loss: 0.0612
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.43it/s]


                       precision    recall  f1-score   support

          Nervousness       0.55      0.79      0.65        53
Lack of Worry Control       0.67      0.60      0.63        47
      Excessive Worry       0.75      0.33      0.45        46
  Difficulty Relaxing       0.83      0.88      0.86        51
         Restlessness       0.48      0.52      0.50        58
       Impending Doom       0.69      0.67      0.68        52

             accuracy                           0.64       307
            macro avg       0.66      0.63      0.63       307
         weighted avg       0.65      0.64      0.63       307

Validation Loss: 1.7188
Validation Macro-F1: 0.6271, Weighted-F1: 0.6272

Epoch 13/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0497
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.63      0.75      0.69        53
Lack of Worry Control       0.67      0.64      0.65        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.86      0.84      0.85        51
         Restlessness       0.49      0.45      0.47        58
       Impending Doom       0.63      0.77      0.70        52

             accuracy                           0.65       307
            macro avg       0.65      0.65      0.64       307
         weighted avg       0.65      0.65      0.64       307

Validation Loss: 1.6304
Validation Macro-F1: 0.6440, Weighted-F1: 0.6426

Epoch 14/30


Training: 100%|██████████| 135/135 [03:23<00:00,  1.51s/it]


Train Loss: 0.0368
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.61      0.79      0.69        53
Lack of Worry Control       0.69      0.62      0.65        47
      Excessive Worry       0.62      0.39      0.48        46
  Difficulty Relaxing       0.84      0.90      0.87        51
         Restlessness       0.52      0.50      0.51        58
       Impending Doom       0.68      0.73      0.70        52

             accuracy                           0.66       307
            macro avg       0.66      0.66      0.65       307
         weighted avg       0.66      0.66      0.65       307

Validation Loss: 1.6224
Validation Macro-F1: 0.6501, Weighted-F1: 0.6501

Epoch 15/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0365
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.42it/s]


                       precision    recall  f1-score   support

          Nervousness       0.68      0.74      0.71        53
Lack of Worry Control       0.66      0.62      0.64        47
      Excessive Worry       0.57      0.46      0.51        46
  Difficulty Relaxing       0.82      0.92      0.87        51
         Restlessness       0.53      0.48      0.50        58
       Impending Doom       0.66      0.75      0.70        52

             accuracy                           0.66       307
            macro avg       0.65      0.66      0.66       307
         weighted avg       0.65      0.66      0.65       307

Validation Loss: 1.5689
Validation Macro-F1: 0.6550, Weighted-F1: 0.6547

Epoch 16/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0306
Train Macro-F1: 0.9967, Weighted-F1: 0.9967


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.62      0.75      0.68        53
Lack of Worry Control       0.71      0.62      0.66        47
      Excessive Worry       0.71      0.37      0.49        46
  Difficulty Relaxing       0.87      0.88      0.87        51
         Restlessness       0.48      0.55      0.52        58
       Impending Doom       0.65      0.75      0.70        52

             accuracy                           0.66       307
            macro avg       0.67      0.65      0.65       307
         weighted avg       0.67      0.66      0.65       307

Validation Loss: 1.6409
Validation Macro-F1: 0.6525, Weighted-F1: 0.6524

Epoch 17/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0317
Train Macro-F1: 0.9967, Weighted-F1: 0.9967


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.65      0.77      0.71        53
Lack of Worry Control       0.72      0.62      0.67        47
      Excessive Worry       0.64      0.46      0.53        46
  Difficulty Relaxing       0.87      0.90      0.88        51
         Restlessness       0.49      0.48      0.49        58
       Impending Doom       0.66      0.77      0.71        52

             accuracy                           0.67       307
            macro avg       0.67      0.67      0.66       307
         weighted avg       0.67      0.67      0.66       307

Validation Loss: 1.6354
Validation Macro-F1: 0.6641, Weighted-F1: 0.6626

Epoch 18/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0302
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.74      0.70        53
Lack of Worry Control       0.74      0.60      0.66        47
      Excessive Worry       0.62      0.43      0.51        46
  Difficulty Relaxing       0.90      0.84      0.87        51
         Restlessness       0.49      0.57      0.52        58
       Impending Doom       0.66      0.79      0.72        52

             accuracy                           0.66       307
            macro avg       0.68      0.66      0.66       307
         weighted avg       0.67      0.66      0.66       307

Validation Loss: 1.6552
Validation Macro-F1: 0.6633, Weighted-F1: 0.6630

Epoch 19/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0269
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.43it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.77      0.72        53
Lack of Worry Control       0.66      0.62      0.64        47
      Excessive Worry       0.62      0.43      0.51        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.46      0.45      0.46        58
       Impending Doom       0.66      0.77      0.71        52

             accuracy                           0.65       307
            macro avg       0.65      0.65      0.65       307
         weighted avg       0.65      0.65      0.65       307

Validation Loss: 1.6529
Validation Macro-F1: 0.6498, Weighted-F1: 0.6484

Epoch 20/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0280
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.68      0.75      0.71        53
Lack of Worry Control       0.64      0.60      0.62        47
      Excessive Worry       0.61      0.43      0.51        46
  Difficulty Relaxing       0.86      0.86      0.86        51
         Restlessness       0.46      0.45      0.46        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.64       307
            macro avg       0.65      0.64      0.64       307
         weighted avg       0.64      0.64      0.64       307

Validation Loss: 1.6607
Validation Macro-F1: 0.6408, Weighted-F1: 0.6397

Epoch 21/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0269
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.43it/s]


                       precision    recall  f1-score   support

          Nervousness       0.64      0.77      0.70        53
Lack of Worry Control       0.72      0.60      0.65        47
      Excessive Worry       0.62      0.43      0.51        46
  Difficulty Relaxing       0.85      0.86      0.85        51
         Restlessness       0.49      0.48      0.49        58
       Impending Doom       0.63      0.77      0.70        52

             accuracy                           0.65       307
            macro avg       0.66      0.65      0.65       307
         weighted avg       0.66      0.65      0.65       307

Validation Loss: 1.6711
Validation Macro-F1: 0.6503, Weighted-F1: 0.6493

Epoch 22/30


Training: 100%|██████████| 135/135 [03:23<00:00,  1.50s/it]


Train Loss: 0.0263
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.41it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.77      0.72        53
Lack of Worry Control       0.65      0.60      0.62        47
      Excessive Worry       0.62      0.43      0.51        46
  Difficulty Relaxing       0.83      0.88      0.86        51
         Restlessness       0.50      0.48      0.49        58
       Impending Doom       0.64      0.75      0.69        52

             accuracy                           0.65       307
            macro avg       0.65      0.65      0.65       307
         weighted avg       0.65      0.65      0.65       307

Validation Loss: 1.6577
Validation Macro-F1: 0.6488, Weighted-F1: 0.6484

Epoch 23/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0282
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.75      0.71        53
Lack of Worry Control       0.67      0.60      0.63        47
      Excessive Worry       0.62      0.43      0.51        46
  Difficulty Relaxing       0.85      0.86      0.85        51
         Restlessness       0.49      0.50      0.50        58
       Impending Doom       0.63      0.75      0.68        52

             accuracy                           0.65       307
            macro avg       0.65      0.65      0.65       307
         weighted avg       0.65      0.65      0.65       307

Validation Loss: 1.6605
Validation Macro-F1: 0.6474, Weighted-F1: 0.6469

Epoch 24/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0241
Train Macro-F1: 0.9971, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.44it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.77      0.71        53
Lack of Worry Control       0.67      0.62      0.64        47
      Excessive Worry       0.65      0.43      0.52        46
  Difficulty Relaxing       0.85      0.86      0.85        51
         Restlessness       0.49      0.48      0.49        58
       Impending Doom       0.63      0.75      0.68        52

             accuracy                           0.65       307
            macro avg       0.66      0.65      0.65       307
         weighted avg       0.65      0.65      0.65       307

Validation Loss: 1.6631
Validation Macro-F1: 0.6504, Weighted-F1: 0.6494

Epoch 25/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0233
Train Macro-F1: 0.9967, Weighted-F1: 0.9967


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.41it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.77      0.72        53
Lack of Worry Control       0.66      0.62      0.64        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.50      0.48      0.49        58
       Impending Doom       0.62      0.75      0.68        52

             accuracy                           0.66       307
            macro avg       0.66      0.66      0.65       307
         weighted avg       0.66      0.66      0.65       307

Validation Loss: 1.6643
Validation Macro-F1: 0.6530, Weighted-F1: 0.6521

Epoch 26/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0233
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.45it/s]


                       precision    recall  f1-score   support

          Nervousness       0.66      0.77      0.71        53
Lack of Worry Control       0.69      0.62      0.65        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.51      0.50      0.50        58
       Impending Doom       0.62      0.75      0.68        52

             accuracy                           0.66       307
            macro avg       0.67      0.66      0.66       307
         weighted avg       0.66      0.66      0.66       307

Validation Loss: 1.6637
Validation Macro-F1: 0.6565, Weighted-F1: 0.6557

Epoch 27/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0246
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.45it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.75      0.71        53
Lack of Worry Control       0.67      0.62      0.64        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.50      0.50      0.50        58
       Impending Doom       0.62      0.75      0.68        52

             accuracy                           0.66       307
            macro avg       0.66      0.66      0.65       307
         weighted avg       0.66      0.66      0.65       307

Validation Loss: 1.6565
Validation Macro-F1: 0.6537, Weighted-F1: 0.6529

Epoch 28/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0230
Train Macro-F1: 0.9972, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:14<00:00,  1.41it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.75      0.71        53
Lack of Worry Control       0.67      0.62      0.64        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.85      0.86      0.85        51
         Restlessness       0.49      0.50      0.50        58
       Impending Doom       0.62      0.75      0.68        52

             accuracy                           0.65       307
            macro avg       0.66      0.65      0.65       307
         weighted avg       0.66      0.65      0.65       307

Validation Loss: 1.6538
Validation Macro-F1: 0.6512, Weighted-F1: 0.6502

Epoch 29/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0256
Train Macro-F1: 0.9976, Weighted-F1: 0.9977


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.43it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.75      0.71        53
Lack of Worry Control       0.69      0.62      0.65        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.50      0.50      0.50        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.66       307
            macro avg       0.67      0.66      0.66       307
         weighted avg       0.66      0.66      0.66       307

Validation Loss: 1.6558
Validation Macro-F1: 0.6568, Weighted-F1: 0.6559

Epoch 30/30


Training: 100%|██████████| 135/135 [03:22<00:00,  1.50s/it]


Train Loss: 0.0256
Train Macro-F1: 0.9971, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:13<00:00,  1.45it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.75      0.71        53
Lack of Worry Control       0.67      0.62      0.64        47
      Excessive Worry       0.67      0.43      0.53        46
  Difficulty Relaxing       0.85      0.88      0.87        51
         Restlessness       0.49      0.48      0.49        58
       Impending Doom       0.62      0.77      0.69        52

             accuracy                           0.66       307
            macro avg       0.66      0.66      0.65       307
         weighted avg       0.66      0.66      0.65       307

Validation Loss: 1.6587
Validation Macro-F1: 0.6535, Weighted-F1: 0.6523


## Inference

In [24]:
!gdown 1Tx3n9i3fjK1JKg9KipX_JlurFEvprBQs

Downloading...
From (original): https://drive.google.com/uc?id=1Tx3n9i3fjK1JKg9KipX_JlurFEvprBQs
From (redirected): https://drive.google.com/uc?id=1Tx3n9i3fjK1JKg9KipX_JlurFEvprBQs&confirm=t&uuid=8de1d34c-5411-473d-ad4a-386a9fb4ac9e
To: /kaggle/working/no_fusion_anxiety.pth
100%|████████████████████████████████████████| 872M/872M [00:29<00:00, 29.7MB/s]


In [25]:
def inference(test_data, model_path):
    model = MultimodalConcatenationModel(
        text_model_name="mental/mental-roberta-base",
        num_classes=len(LABEL_MAP),
        fusion_dim=FUSION_DIM
    )

    model = model.to(DEVICE)

    weights = torch.load(model_path, map_location=DEVICE, weights_only=True)
    weights_single = {k.replace("module.", ""): v for k, v in weights.items()}

    model.load_state_dict(weights_single)
    tokenizer = AutoTokenizer.from_pretrained("mental/mental-roberta-base")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    test_dataset = MultimodalAnxietyDataset(test_data, "anxiety_test_image", tokenizer, image_processor, max_len=MAX_LEN)
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    loss, macro_f1, weighted_f1 = evaluate_multimodal_model(
        model, test_loader, nn.CrossEntropyLoss()
    )

    print(f"Test Loss: {loss:.4f}")
    print(f"Test Macro-F1: {macro_f1:.4f}, Weighted-F1: {weighted_f1:.4f}")

    return macro_f1, weighted_f1

In [26]:
inference(test_data, "no_fusion_anxiety.pth")

Some weights of RobertaModel were not initialized from the model checkpoint at mental/mental-roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Evaluating: 100%|██████████| 39/39 [00:36<00:00,  1.08it/s]

                       precision    recall  f1-score   support

          Nervousness       0.71      0.67      0.69       106
Lack of Worry Control       0.75      0.60      0.66        94
      Excessive Worry       0.58      0.62      0.60        92
  Difficulty Relaxing       0.76      0.77      0.77       102
         Restlessness       0.51      0.56      0.53       116
       Impending Doom       0.64      0.67      0.65       105

             accuracy                           0.65       615
            macro avg       0.66      0.65      0.65       615
         weighted avg       0.65      0.65      0.65       615

Test Loss: 1.5172
Test Macro-F1: 0.6503, Weighted-F1: 0.6487





(0.6503389044020362, 0.6486789102737075)