## Dataset Download

In [None]:
# Download the dataset archive from Google Drive by file id (needs the gdown CLI).
!gdown 1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU

Downloading...
From (original): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU
From (redirected): https://drive.google.com/uc?id=1doU9FE1tJ-0IL4tKZUOOp76JkXKuTkFU&confirm=t&uuid=5e3bef8b-d79b-4ea5-810e-5ce7a67ce72a
To: /kaggle/working/anxiety_dataset_complete.zip
100%|██████████████████████████████████████| 1.41G/1.41G [00:17<00:00, 80.8MB/s]


In [None]:
# Extract the archive, discarding the file listing. `-o` overwrites existing
# files without asking, so re-running this cell does not stall on unzip's
# interactive "replace?" prompt.
!unzip -o anxiety_dataset_complete.zip > /dev/null

In [40]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPVisionModel
from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import json
import math

# Dataset split proportions: 70% train, 10% validation, 20% test.
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.7, 0.1, 0.2

# Prefer the GPU whenever CUDA is available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation hyper-parameters.
BATCH_SIZE = 16
LR = 2e-5
FUSION_DIM = 768

# Class names; a name's position in this list is its integer class id.
LABELS = [
    "Nervousness",
    "Lack of Worry Control",
    "Excessive Worry",
    "Difficulty Relaxing",
    "Restlessness",
    "Impending Doom",
]
LABEL_MAP = dict(zip(LABELS, range(len(LABELS))))

## Custom Dataset

In [41]:
class ImageOnlyAnxietyDataset(Dataset):
    """Image-only dataset of anxiety memes.

    Each item pairs the processed image of one record with the integer class
    id of its ``meme_anxiety_category`` (looked up in ``LABEL_MAP``).

    Args:
        data: Indexable collection of records; every record must provide the
            keys ``"sample_id"`` and ``"meme_anxiety_category"``.
        image_path: Directory that holds one ``<sample_id>.jpg`` per record.
        image_processor: Callable invoked as
            ``image_processor(image, return_tensors="pt")``.
    """

    def __init__(self, data, image_path, image_processor):
        self.data = data
        self.image_processor = image_processor
        self.img_path = image_path
        # NOTE: this torchvision pipeline is never applied in __getitem__;
        # images go through `image_processor` only. The attribute is kept so
        # code that reads `dataset.transform` keeps working.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load and process the image for record ``idx``.

        Returns:
            dict with ``"image"`` (the image processor's output) and
            ``"label"`` (0-dim ``torch.long`` tensor holding the class id).

        Raises:
            KeyError: if the record's category is not a key of ``LABEL_MAP``.
        """
        sample = self.data[idx]

        image_path = os.path.join(self.img_path, sample["sample_id"] + ".jpg")
        # The context manager closes the file handle even when decoding fails;
        # convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = self.image_processor(image, return_tensors="pt")

        label = LABEL_MAP[sample["meme_anxiety_category"]]

        return {
            "image": image_tensor,
            "label": torch.tensor(label, dtype=torch.long)
        }

def custom_collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Labels are stacked into a single tensor. For every key of the first
    sample's ``'image'`` mapping whose value is a tensor, the per-sample
    tensors are squeezed on dim 0 and stacked; non-tensor entries are dropped.

    Returns:
        dict with ``'image'`` (mapping of key -> stacked tensor) and
        ``'label'`` (stacked label tensor).
    """
    stacked_labels = torch.stack([sample['label'] for sample in batch])

    reference = batch[0]['image']
    stacked_images = {
        name: torch.stack([sample['image'][name].squeeze(0) for sample in batch])
        for name in reference.keys()
        if isinstance(reference[name], torch.Tensor)
    }

    return {'image': stacked_images, 'label': stacked_labels}

## Model Definition

In [42]:
class ImageOnlyModel(nn.Module):
    """Image classifier: CLIP vision encoder plus a mixture-of-experts head.

    The index-0 token of the last three encoder hidden states is projected to
    1024 dimensions and averaged. Four expert MLPs process that vector, a
    softmax gate blends their outputs, and a two-layer head emits the logits.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_encoder = encoder
        self.vision_dim = encoder.config.hidden_size

        # One projection per tapped hidden state.
        projections = []
        for _ in range(3):
            projections.append(nn.Linear(self.vision_dim, 1024))
        self.vision_projections = nn.ModuleList(projections)

        # Four identically shaped expert MLPs.
        experts = []
        for _ in range(4):
            experts.append(
                nn.Sequential(
                    nn.Linear(1024, 1024),
                    nn.LayerNorm(1024),
                    nn.GELU(),
                    nn.Dropout(0.2),
                )
            )
        self.expert_nets = nn.ModuleList(experts)

        # Produces one mixing score per expert.
        self.moe_gate = nn.Linear(1024, 4)

        self.classifier = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LayerNorm(1024),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024, num_classes),
        )

    def forward(self, image_features):
        """Compute class logits.

        Args:
            image_features: Mapping of keyword arguments forwarded to the
                vision encoder.

        Returns:
            The classifier output for the gated expert mixture.
        """
        encoder_out = self.vision_encoder(**image_features, output_hidden_states=True)

        # Index-0 token of the last, second-to-last and third-to-last states.
        tapped = [encoder_out.hidden_states[-depth][:, 0] for depth in (1, 2, 3)]
        projected = []
        for projection, token in zip(self.vision_projections, tapped):
            projected.append(projection(token))

        fused = sum(projected) / len(projected)

        # Experts run before the gate, in order, as in the original layout.
        per_expert = [net(fused) for net in self.expert_nets]
        gate_weights = F.softmax(self.moe_gate(fused), dim=1)

        # Gate-weighted sum of expert outputs, accumulated expert by expert.
        mixture = torch.zeros_like(per_expert[0])
        for position, expert_out in enumerate(per_expert):
            weight = gate_weights[:, position].unsqueeze(1)
            mixture = mixture + expert_out * weight

        return self.classifier(mixture)

## Training Functions

In [43]:
def train_image_only_model(model, train_data, val_data, img_path, epochs, model_save_name):
    """Fine-tune ``model`` on ``train_data`` and checkpoint the best epoch.

    Builds train/validation loaders over ``ImageOnlyAnxietyDataset``, trains
    with AdamW and cross-entropy for ``epochs`` epochs, evaluates after every
    epoch and saves the state dict whenever the harmonic mean of validation
    macro-F1 and weighted-F1 improves.

    Args:
        model: Module called as ``model(image_features)`` that returns logits.
        train_data: Records for the training dataset.
        val_data: Records for the validation dataset.
        img_path: Directory with the images for both record lists.
        epochs: Number of passes over the training data.
        model_save_name: Prefix of the checkpoint file
            (``f"{model_save_name}_anxiety.pth"``).

    Returns:
        The ``nn.DataParallel``-wrapped model as it stands after the final
        epoch. This is a single object, not a tuple, and not necessarily the
        best checkpoint (that one is only written to disk).
    """
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    train_dataset = ImageOnlyAnxietyDataset(train_data, img_path, image_processor)
    val_dataset = ImageOnlyAnxietyDataset(val_data, img_path, image_processor)

    print("Train Set Size:", len(train_dataset))
    print("Validation Set Size:", len(val_dataset))

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=custom_collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    optimizer = optim.AdamW(model.parameters(), lr=LR)
    # mode='max': the scheduler is stepped with validation macro-F1 (higher is
    # better) and halves the LR after 2 epochs without improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2, verbose=True
    )
    criterion = nn.CrossEntropyLoss()
    
    model = model.to(DEVICE)
    # From here on `model` is the DataParallel wrapper; that wrapper is what
    # gets checkpointed and returned.
    model = nn.DataParallel(model)

    best_f1 = 0
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        model.train()
        train_loss = 0
        train_preds, train_labels = [], []

        for batch in tqdm(train_loader, desc="Training"):
            labels = batch["label"].to(DEVICE)
            image_features = {k: v.to(DEVICE) for k, v in batch["image"].items()}

            optimizer.zero_grad()
            logits = model(image_features)
            loss = criterion(logits, labels)
            loss.backward()
            # Clip the global gradient norm to 1.0 before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            train_preds.extend(predictions)
            train_labels.extend(labels.cpu().numpy())

        # Mean of the per-batch losses (not weighted by batch size).
        train_loss = train_loss / len(train_loader)
        train_macro_f1 = f1_score(train_labels, train_preds, average="macro")
        train_weighted_f1 = f1_score(train_labels, train_preds, average="weighted")

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train Macro-F1: {train_macro_f1:.4f}, Weighted-F1: {train_weighted_f1:.4f}")

        val_loss, val_macro_f1, val_weighted_f1 = evaluate_image_model(
            model, val_loader, criterion
        )

        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Macro-F1: {val_macro_f1:.4f}, Weighted-F1: {val_weighted_f1:.4f}")

        scheduler.step(val_macro_f1)

        # Model selection uses the harmonic mean of the two validation F1 scores.
        f1_hm = 2 * val_macro_f1 * val_weighted_f1 / (val_macro_f1 + val_weighted_f1)
        if f1_hm > best_f1:
            best_f1 = f1_hm
            # NOTE(review): the state dict comes from the DataParallel wrapper,
            # so its keys presumably carry a "module." prefix - confirm that
            # whatever loads this checkpoint expects that.
            torch.save(model.state_dict(), f"{model_save_name}_anxiety.pth")
            print("Best model saved!")

    return model

def evaluate_image_model(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients and print a report.

    Args:
        model: Module called as ``model(image_features)`` that returns logits.
            It is switched to eval mode and left in that mode.
        loader: Iterable of batches with ``"label"`` and ``"image"`` entries,
            as produced by ``custom_collate_fn``.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.

    Returns:
        Tuple ``(val_loss, val_macro_f1, val_weighted_f1)`` where ``val_loss``
        is the mean of the per-batch losses.
    """
    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            labels = batch["label"].to(DEVICE)
            image_features = {k: v.to(DEVICE) for k, v in batch["image"].items()}

            logits = model(image_features)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(predictions)
            val_labels.extend(labels.cpu().numpy())

    val_loss = val_loss / len(loader)
    val_macro_f1 = f1_score(val_labels, val_preds, average="macro")
    val_weighted_f1 = f1_score(val_labels, val_preds, average="weighted")

    # Class names ordered by their integer id.
    target_names = sorted(LABEL_MAP, key=LABEL_MAP.get)
    # Pass the label ids explicitly: without `labels`, classification_report
    # infers the classes from the data and raises ValueError when a class is
    # missing from both labels and predictions (length mismatch with
    # target_names).
    report = classification_report(
        val_labels,
        val_preds,
        labels=[LABEL_MAP[name] for name in target_names],
        target_names=target_names,
    )
    print(report)

    return val_loss, val_macro_f1, val_weighted_f1

## Model Training

In [44]:
# Load the annotation files; the context managers close the handles right
# after parsing instead of leaving them to the garbage collector.
with open("anxiety_train_llava_dataset.json", "r") as train_file:
    full_train_data = json.load(train_file)
with open("anxiety_test_llava_dataset.json", "r") as test_file:
    test_data = json.load(test_file)

# Integer class ids, used to stratify the train/validation split.
labels = [LABEL_MAP[item["meme_anxiety_category"]] for item in full_train_data]

# The test set ships as its own file, so the training file is divided between
# train and validation in the ratio TRAIN_RATIO : VAL_RATIO.
train_size = math.ceil(len(full_train_data) * TRAIN_RATIO / (TRAIN_RATIO + VAL_RATIO))
train_data, val_data = train_test_split(
    full_train_data, train_size=train_size, stratify=labels, random_state=42
)

img_path = "anxiety_train_image"

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Held-out test loader; order is kept fixed (no shuffling).
test_dataset = ImageOnlyAnxietyDataset(test_data, "anxiety_test_image", image_processor)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn
)

In [None]:
model = ImageOnlyModel(num_classes=len(LABEL_MAP))

# train_image_only_model returns one object (the trained model), so bind a
# single name. Unpacking the result into two names raises TypeError only after
# every epoch has already run.
trained_model = train_image_only_model(
    model,
    train_data,
    val_data,
    img_path,
    epochs=30,
    model_save_name="only_image"
)

Train Set Size: 2153
Validation Set Size: 307

Epoch 1/30


Training: 100%|██████████| 135/135 [01:19<00:00,  1.69it/s]


Train Loss: 1.8536
Train Macro-F1: 0.2953, Weighted-F1: 0.2953


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


                       precision    recall  f1-score   support

          Nervousness       0.67      0.04      0.07        53
Lack of Worry Control       0.71      0.32      0.44        47
      Excessive Worry       0.00      0.00      0.00        46
  Difficulty Relaxing       0.33      0.88      0.48        51
         Restlessness       0.24      0.26      0.25        58
       Impending Doom       0.32      0.50      0.39        52

             accuracy                           0.34       307
            macro avg       0.38      0.33      0.27       307
         weighted avg       0.38      0.34      0.27       307

Validation Loss: 1.6980
Validation Macro-F1: 0.2708, Weighted-F1: 0.2716
Best model saved!

Epoch 2/30


Training: 100%|██████████| 135/135 [01:20<00:00,  1.68it/s]


Train Loss: 1.4580
Train Macro-F1: 0.5562, Weighted-F1: 0.5561


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.23it/s]


                       precision    recall  f1-score   support

          Nervousness       0.38      0.51      0.43        53
Lack of Worry Control       0.59      0.49      0.53        47
      Excessive Worry       0.55      0.37      0.44        46
  Difficulty Relaxing       0.86      0.47      0.61        51
         Restlessness       0.25      0.22      0.24        58
       Impending Doom       0.43      0.71      0.54        52

             accuracy                           0.46       307
            macro avg       0.51      0.46      0.47       307
         weighted avg       0.50      0.46      0.46       307

Validation Loss: 1.4569
Validation Macro-F1: 0.4651, Weighted-F1: 0.4595
Best model saved!

Epoch 3/30


Training: 100%|██████████| 135/135 [01:20<00:00,  1.67it/s]


Train Loss: 0.9964
Train Macro-F1: 0.7621, Weighted-F1: 0.7629


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


                       precision    recall  f1-score   support

          Nervousness       0.39      0.51      0.44        53
Lack of Worry Control       0.49      0.49      0.49        47
      Excessive Worry       0.42      0.24      0.31        46
  Difficulty Relaxing       0.88      0.43      0.58        51
         Restlessness       0.35      0.40      0.37        58
       Impending Doom       0.46      0.65      0.54        52

             accuracy                           0.46       307
            macro avg       0.50      0.45      0.45       307
         weighted avg       0.50      0.46      0.45       307

Validation Loss: 1.6143
Validation Macro-F1: 0.4545, Weighted-F1: 0.4548

Epoch 4/30


Training: 100%|██████████| 135/135 [01:20<00:00,  1.67it/s]


Train Loss: 0.7085
Train Macro-F1: 0.8694, Weighted-F1: 0.8698


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


                       precision    recall  f1-score   support

          Nervousness       0.39      0.51      0.44        53
Lack of Worry Control       0.57      0.51      0.54        47
      Excessive Worry       0.52      0.30      0.38        46
  Difficulty Relaxing       0.74      0.57      0.64        51
         Restlessness       0.31      0.47      0.37        58
       Impending Doom       0.60      0.48      0.53        52

             accuracy                           0.48       307
            macro avg       0.52      0.47      0.49       307
         weighted avg       0.51      0.48      0.48       307

Validation Loss: 1.7553
Validation Macro-F1: 0.4851, Weighted-F1: 0.4833
Best model saved!

Epoch 5/30


Training: 100%|██████████| 135/135 [01:20<00:00,  1.67it/s]


Train Loss: 0.6106
Train Macro-F1: 0.9057, Weighted-F1: 0.9058


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


                       precision    recall  f1-score   support

          Nervousness       0.39      0.51      0.44        53
Lack of Worry Control       0.44      0.40      0.42        47
      Excessive Worry       0.52      0.30      0.38        46
  Difficulty Relaxing       0.70      0.51      0.59        51
         Restlessness       0.26      0.19      0.22        58
       Impending Doom       0.41      0.69      0.51        52

             accuracy                           0.43       307
            macro avg       0.45      0.43      0.43       307
         weighted avg       0.45      0.43      0.42       307

Validation Loss: 2.1478
Validation Macro-F1: 0.4286, Weighted-F1: 0.4250

Epoch 6/30


Training: 100%|██████████| 135/135 [01:20<00:00,  1.67it/s]


Train Loss: 0.5254
Train Macro-F1: 0.9368, Weighted-F1: 0.9368


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


                       precision    recall  f1-score   support

          Nervousness       0.41      0.51      0.45        53
Lack of Worry Control       0.57      0.43      0.49        47
      Excessive Worry       0.28      0.41      0.33        46
  Difficulty Relaxing       0.80      0.47      0.59        51
         Restlessness       0.29      0.22      0.25        58
       Impending Doom       0.49      0.60      0.54        52

             accuracy                           0.44       307
            macro avg       0.47      0.44      0.44       307
         weighted avg       0.47      0.44      0.44       307

Validation Loss: 2.5853
Validation Macro-F1: 0.4432, Weighted-F1: 0.4404

Epoch 7/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.67it/s]


Train Loss: 0.5776
Train Macro-F1: 0.9267, Weighted-F1: 0.9266


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.28it/s]


                       precision    recall  f1-score   support

          Nervousness       0.38      0.57      0.45        53
Lack of Worry Control       0.61      0.53      0.57        47
      Excessive Worry       0.40      0.46      0.42        46
  Difficulty Relaxing       0.76      0.55      0.64        51
         Restlessness       0.38      0.34      0.36        58
       Impending Doom       0.55      0.46      0.50        52

             accuracy                           0.48       307
            macro avg       0.51      0.48      0.49       307
         weighted avg       0.51      0.48      0.49       307

Validation Loss: 2.4035
Validation Macro-F1: 0.4906, Weighted-F1: 0.4875
Best model saved!

Epoch 8/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.66it/s]


Train Loss: 0.4973
Train Macro-F1: 0.9485, Weighted-F1: 0.9485


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


                       precision    recall  f1-score   support

          Nervousness       0.35      0.58      0.44        53
Lack of Worry Control       0.46      0.51      0.48        47
      Excessive Worry       0.47      0.39      0.43        46
  Difficulty Relaxing       0.66      0.57      0.61        51
         Restlessness       0.35      0.19      0.25        58
       Impending Doom       0.51      0.52      0.51        52

             accuracy                           0.46       307
            macro avg       0.47      0.46      0.45       307
         weighted avg       0.46      0.46      0.45       307

Validation Loss: 2.8378
Validation Macro-F1: 0.4537, Weighted-F1: 0.4491

Epoch 9/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.67it/s]


Train Loss: 0.5114
Train Macro-F1: 0.9496, Weighted-F1: 0.9494


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


                       precision    recall  f1-score   support

          Nervousness       0.34      0.68      0.45        53
Lack of Worry Control       0.78      0.30      0.43        47
      Excessive Worry       0.48      0.22      0.30        46
  Difficulty Relaxing       0.56      0.49      0.52        51
         Restlessness       0.30      0.31      0.30        58
       Impending Doom       0.53      0.56      0.54        52

             accuracy                           0.43       307
            macro avg       0.49      0.43      0.42       307
         weighted avg       0.49      0.43      0.42       307

Validation Loss: 3.0701
Validation Macro-F1: 0.4241, Weighted-F1: 0.4239

Epoch 10/30


Training: 100%|██████████| 135/135 [01:20<00:00,  1.67it/s]


Train Loss: 0.4919
Train Macro-F1: 0.9487, Weighted-F1: 0.9490


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


                       precision    recall  f1-score   support

          Nervousness       0.38      0.43      0.41        53
Lack of Worry Control       0.63      0.40      0.49        47
      Excessive Worry       0.54      0.28      0.37        46
  Difficulty Relaxing       0.80      0.55      0.65        51
         Restlessness       0.28      0.52      0.36        58
       Impending Doom       0.46      0.44      0.45        52

             accuracy                           0.44       307
            macro avg       0.52      0.44      0.46       307
         weighted avg       0.51      0.44      0.45       307

Validation Loss: 3.0115
Validation Macro-F1: 0.4559, Weighted-F1: 0.4543

Epoch 11/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.66it/s]


Train Loss: 0.4204
Train Macro-F1: 0.9783, Weighted-F1: 0.9787


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


                       precision    recall  f1-score   support

          Nervousness       0.35      0.57      0.43        53
Lack of Worry Control       0.55      0.57      0.56        47
      Excessive Worry       0.52      0.37      0.43        46
  Difficulty Relaxing       0.78      0.57      0.66        51
         Restlessness       0.27      0.21      0.23        58
       Impending Doom       0.50      0.56      0.53        52

             accuracy                           0.47       307
            macro avg       0.49      0.47      0.47       307
         weighted avg       0.49      0.47      0.47       307

Validation Loss: 3.0952
Validation Macro-F1: 0.4745, Weighted-F1: 0.4685

Epoch 12/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.66it/s]


Train Loss: 0.3869
Train Macro-F1: 0.9882, Weighted-F1: 0.9884


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


                       precision    recall  f1-score   support

          Nervousness       0.58      0.34      0.43        53
Lack of Worry Control       0.53      0.55      0.54        47
      Excessive Worry       0.46      0.46      0.46        46
  Difficulty Relaxing       0.59      0.69      0.64        51
         Restlessness       0.33      0.45      0.38        58
       Impending Doom       0.48      0.40      0.44        52

             accuracy                           0.48       307
            macro avg       0.50      0.48      0.48       307
         weighted avg       0.49      0.48      0.48       307

Validation Loss: 3.1221
Validation Macro-F1: 0.4805, Weighted-F1: 0.4774

Epoch 13/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.66it/s]


Train Loss: 0.3913
Train Macro-F1: 0.9874, Weighted-F1: 0.9875


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


                       precision    recall  f1-score   support

          Nervousness       0.45      0.45      0.45        53
Lack of Worry Control       0.63      0.40      0.49        47
      Excessive Worry       0.50      0.50      0.50        46
  Difficulty Relaxing       0.74      0.57      0.64        51
         Restlessness       0.33      0.55      0.41        58
       Impending Doom       0.66      0.52      0.58        52

             accuracy                           0.50       307
            macro avg       0.55      0.50      0.51       307
         weighted avg       0.55      0.50      0.51       307

Validation Loss: 3.1540
Validation Macro-F1: 0.5136, Weighted-F1: 0.5116
Best model saved!

Epoch 14/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.67it/s]


Train Loss: 0.3830
Train Macro-F1: 0.9905, Weighted-F1: 0.9907


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


                       precision    recall  f1-score   support

          Nervousness       0.54      0.53      0.53        53
Lack of Worry Control       0.49      0.49      0.49        47
      Excessive Worry       0.45      0.43      0.44        46
  Difficulty Relaxing       0.72      0.57      0.64        51
         Restlessness       0.33      0.41      0.37        58
       Impending Doom       0.54      0.54      0.54        52

             accuracy                           0.50       307
            macro avg       0.51      0.50      0.50       307
         weighted avg       0.51      0.50      0.50       307

Validation Loss: 3.2205
Validation Macro-F1: 0.5020, Weighted-F1: 0.5004

Epoch 15/30


Training: 100%|██████████| 135/135 [01:21<00:00,  1.66it/s]


Train Loss: 0.3665
Train Macro-F1: 0.9924, Weighted-F1: 0.9926


Evaluating: 100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


                       precision    recall  f1-score   support

          Nervousness       0.41      0.47      0.44        53
Lack of Worry Control       0.51      0.53      0.52        47
      Excessive Worry       0.56      0.30      0.39        46
  Difficulty Relaxing       0.78      0.57      0.66        51
         Restlessness       0.34      0.50      0.40        58
       Impending Doom       0.57      0.54      0.55        52

             accuracy                           0.49       307
            macro avg       0.53      0.49      0.50       307
         weighted avg       0.52      0.49      0.49       307

Validation Loss: 3.1817
Validation Macro-F1: 0.4950, Weighted-F1: 0.4940

Epoch 16/30


Training: 100%|██████████| 135/135 [01:19<00:00,  1.70it/s]


Train Loss: 0.3680
Train Macro-F1: 0.9909, Weighted-F1: 0.9912


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.36it/s]


                       precision    recall  f1-score   support

          Nervousness       0.43      0.55      0.48        53
Lack of Worry Control       0.59      0.51      0.55        47
      Excessive Worry       0.42      0.41      0.42        46
  Difficulty Relaxing       0.66      0.57      0.61        51
         Restlessness       0.35      0.31      0.33        58
       Impending Doom       0.49      0.54      0.51        52

             accuracy                           0.48       307
            macro avg       0.49      0.48      0.48       307
         weighted avg       0.48      0.48      0.48       307

Validation Loss: 3.1794
Validation Macro-F1: 0.4823, Weighted-F1: 0.4791

Epoch 17/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.75it/s]


Train Loss: 0.3714
Train Macro-F1: 0.9915, Weighted-F1: 0.9916


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.38it/s]


                       precision    recall  f1-score   support

          Nervousness       0.50      0.51      0.50        53
Lack of Worry Control       0.63      0.51      0.56        47
      Excessive Worry       0.52      0.33      0.40        46
  Difficulty Relaxing       0.67      0.63      0.65        51
         Restlessness       0.34      0.41      0.37        58
       Impending Doom       0.48      0.62      0.54        52

             accuracy                           0.50       307
            macro avg       0.52      0.50      0.50       307
         weighted avg       0.52      0.50      0.50       307

Validation Loss: 3.3746
Validation Macro-F1: 0.5043, Weighted-F1: 0.5023

Epoch 18/30


Training: 100%|██████████| 135/135 [01:19<00:00,  1.70it/s]


Train Loss: 0.3622
Train Macro-F1: 0.9939, Weighted-F1: 0.9940


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.28it/s]


                       precision    recall  f1-score   support

          Nervousness       0.42      0.58      0.49        53
Lack of Worry Control       0.61      0.49      0.54        47
      Excessive Worry       0.56      0.30      0.39        46
  Difficulty Relaxing       0.82      0.53      0.64        51
         Restlessness       0.31      0.41      0.35        58
       Impending Doom       0.49      0.56      0.52        52

             accuracy                           0.48       307
            macro avg       0.53      0.48      0.49       307
         weighted avg       0.53      0.48      0.49       307

Validation Loss: 3.4535
Validation Macro-F1: 0.4903, Weighted-F1: 0.4882

Epoch 19/30


Training: 100%|██████████| 135/135 [01:18<00:00,  1.72it/s]


Train Loss: 0.3523
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.38it/s]


                       precision    recall  f1-score   support

          Nervousness       0.49      0.49      0.49        53
Lack of Worry Control       0.48      0.55      0.51        47
      Excessive Worry       0.55      0.37      0.44        46
  Difficulty Relaxing       0.72      0.57      0.64        51
         Restlessness       0.34      0.43      0.38        58
       Impending Doom       0.50      0.54      0.52        52

             accuracy                           0.49       307
            macro avg       0.51      0.49      0.50       307
         weighted avg       0.51      0.49      0.50       307

Validation Loss: 3.4072
Validation Macro-F1: 0.4974, Weighted-F1: 0.4955

Epoch 20/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.75it/s]


Train Loss: 0.3508
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.34it/s]


                       precision    recall  f1-score   support

          Nervousness       0.48      0.43      0.46        53
Lack of Worry Control       0.50      0.53      0.52        47
      Excessive Worry       0.55      0.39      0.46        46
  Difficulty Relaxing       0.73      0.59      0.65        51
         Restlessness       0.33      0.48      0.39        58
       Impending Doom       0.51      0.50      0.50        52

             accuracy                           0.49       307
            macro avg       0.52      0.49      0.50       307
         weighted avg       0.51      0.49      0.49       307

Validation Loss: 3.4333
Validation Macro-F1: 0.4963, Weighted-F1: 0.4942

Epoch 21/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.74it/s]


Train Loss: 0.3500
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.33it/s]


                       precision    recall  f1-score   support

          Nervousness       0.43      0.43      0.43        53
Lack of Worry Control       0.56      0.53      0.54        47
      Excessive Worry       0.56      0.39      0.46        46
  Difficulty Relaxing       0.78      0.57      0.66        51
         Restlessness       0.31      0.50      0.38        58
       Impending Doom       0.52      0.46      0.49        52

             accuracy                           0.48       307
            macro avg       0.53      0.48      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4373
Validation Macro-F1: 0.4947, Weighted-F1: 0.4916

Epoch 22/30


Training: 100%|██████████| 135/135 [01:16<00:00,  1.76it/s]


Train Loss: 0.3498
Train Macro-F1: 0.9943, Weighted-F1: 0.9944


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.32it/s]


                       precision    recall  f1-score   support

          Nervousness       0.45      0.43      0.44        53
Lack of Worry Control       0.56      0.53      0.54        47
      Excessive Worry       0.59      0.37      0.45        46
  Difficulty Relaxing       0.78      0.57      0.66        51
         Restlessness       0.30      0.50      0.37        58
       Impending Doom       0.49      0.44      0.46        52

             accuracy                           0.48       307
            macro avg       0.53      0.47      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4416
Validation Macro-F1: 0.4891, Weighted-F1: 0.4859

Epoch 23/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.75it/s]


Train Loss: 0.3498
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.47      0.43      0.45        53
Lack of Worry Control       0.57      0.53      0.55        47
      Excessive Worry       0.57      0.37      0.45        46
  Difficulty Relaxing       0.80      0.55      0.65        51
         Restlessness       0.30      0.50      0.37        58
       Impending Doom       0.47      0.46      0.47        52

             accuracy                           0.48       307
            macro avg       0.53      0.47      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4790
Validation Macro-F1: 0.4895, Weighted-F1: 0.4864

Epoch 24/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.75it/s]


Train Loss: 0.3495
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.45      0.43      0.44        53
Lack of Worry Control       0.58      0.53      0.56        47
      Excessive Worry       0.55      0.37      0.44        46
  Difficulty Relaxing       0.79      0.53      0.64        51
         Restlessness       0.29      0.48      0.37        58
       Impending Doom       0.49      0.50      0.50        52

             accuracy                           0.48       307
            macro avg       0.53      0.47      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4558
Validation Macro-F1: 0.4893, Weighted-F1: 0.4861

Epoch 25/30


Training: 100%|██████████| 135/135 [01:16<00:00,  1.76it/s]


Train Loss: 0.3492
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.33it/s]


                       precision    recall  f1-score   support

          Nervousness       0.44      0.45      0.45        53
Lack of Worry Control       0.58      0.53      0.56        47
      Excessive Worry       0.55      0.37      0.44        46
  Difficulty Relaxing       0.82      0.53      0.64        51
         Restlessness       0.31      0.47      0.37        58
       Impending Doom       0.50      0.56      0.53        52

             accuracy                           0.49       307
            macro avg       0.53      0.48      0.50       307
         weighted avg       0.53      0.49      0.49       307

Validation Loss: 3.4568
Validation Macro-F1: 0.4976, Weighted-F1: 0.4946

Epoch 26/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.75it/s]


Train Loss: 0.3483
Train Macro-F1: 0.9953, Weighted-F1: 0.9954


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.28it/s]


                       precision    recall  f1-score   support

          Nervousness       0.44      0.45      0.45        53
Lack of Worry Control       0.58      0.53      0.56        47
      Excessive Worry       0.55      0.37      0.44        46
  Difficulty Relaxing       0.82      0.55      0.66        51
         Restlessness       0.31      0.47      0.37        58
       Impending Doom       0.49      0.54      0.51        52

             accuracy                           0.49       307
            macro avg       0.53      0.48      0.50       307
         weighted avg       0.53      0.49      0.50       307

Validation Loss: 3.4423
Validation Macro-F1: 0.4980, Weighted-F1: 0.4950

Epoch 27/30


Training: 100%|██████████| 135/135 [01:16<00:00,  1.76it/s]


Train Loss: 0.3482
Train Macro-F1: 0.9957, Weighted-F1: 0.9958


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.37it/s]


                       precision    recall  f1-score   support

          Nervousness       0.46      0.43      0.45        53
Lack of Worry Control       0.58      0.53      0.56        47
      Excessive Worry       0.55      0.37      0.44        46
  Difficulty Relaxing       0.82      0.53      0.64        51
         Restlessness       0.29      0.47      0.36        58
       Impending Doom       0.49      0.54      0.51        52

             accuracy                           0.48       307
            macro avg       0.53      0.48      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4700
Validation Macro-F1: 0.4930, Weighted-F1: 0.4897

Epoch 28/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.75it/s]


Train Loss: 0.3482
Train Macro-F1: 0.9962, Weighted-F1: 0.9963


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.35it/s]


                       precision    recall  f1-score   support

          Nervousness       0.46      0.43      0.45        53
Lack of Worry Control       0.58      0.53      0.56        47
      Excessive Worry       0.53      0.37      0.44        46
  Difficulty Relaxing       0.82      0.53      0.64        51
         Restlessness       0.29      0.47      0.36        58
       Impending Doom       0.49      0.54      0.51        52

             accuracy                           0.48       307
            macro avg       0.53      0.48      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4812
Validation Macro-F1: 0.4924, Weighted-F1: 0.4893

Epoch 29/30


Training: 100%|██████████| 135/135 [01:17<00:00,  1.74it/s]


Train Loss: 0.3476
Train Macro-F1: 0.9971, Weighted-F1: 0.9972


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.28it/s]


                       precision    recall  f1-score   support

          Nervousness       0.47      0.43      0.45        53
Lack of Worry Control       0.58      0.53      0.56        47
      Excessive Worry       0.53      0.37      0.44        46
  Difficulty Relaxing       0.82      0.53      0.64        51
         Restlessness       0.29      0.47      0.36        58
       Impending Doom       0.49      0.54      0.51        52

             accuracy                           0.48       307
            macro avg       0.53      0.48      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4880
Validation Macro-F1: 0.4928, Weighted-F1: 0.4896

Epoch 30/30


Training: 100%|██████████| 135/135 [01:18<00:00,  1.73it/s]


Train Loss: 0.3474
Train Macro-F1: 0.9966, Weighted-F1: 0.9968


Evaluating: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s]


                       precision    recall  f1-score   support

          Nervousness       0.47      0.43      0.45        53
Lack of Worry Control       0.57      0.53      0.55        47
      Excessive Worry       0.53      0.37      0.44        46
  Difficulty Relaxing       0.82      0.53      0.64        51
         Restlessness       0.30      0.48      0.37        58
       Impending Doom       0.49      0.52      0.50        52

             accuracy                           0.48       307
            macro avg       0.53      0.48      0.49       307
         weighted avg       0.52      0.48      0.49       307

Validation Loss: 3.4840
Validation Macro-F1: 0.4920, Weighted-F1: 0.4892


## Inference

In [71]:
# Fetch the trained image-only checkpoint from Google Drive; per the download
# log it is saved as only_image_anxiety.pth in the working directory.
!gdown 1rdLwO2mZTYoD7NtMYjpZW_owsPzyHjvi

Downloading...
From (original): https://drive.google.com/uc?id=1rdLwO2mZTYoD7NtMYjpZW_owsPzyHjvi
From (redirected): https://drive.google.com/uc?id=1rdLwO2mZTYoD7NtMYjpZW_owsPzyHjvi&confirm=t&uuid=e3037a52-acd3-480e-9130-43d9e05bc16b
To: /kaggle/working/only_image_anxiety.pth
100%|█████████████████████████████████████████| 380M/380M [00:02<00:00, 151MB/s]


In [72]:
def inference(test_data, model_path, image_dir="anxiety_test_image"):
    """Evaluate a saved image-only checkpoint on a set of samples.

    Builds an ``ImageOnlyModel``, loads the state dict stored at
    ``model_path``, runs ``evaluate_image_model`` over ``test_data`` and
    prints the resulting loss and F1 scores.

    Args:
        test_data: Samples handed to ``ImageOnlyAnxietyDataset``.
        model_path: Path to a state dict readable by ``torch.load``.
        image_dir: Directory containing the ``<sample_id>.jpg`` images for
            ``test_data``. Defaults to ``"anxiety_test_image"``.

    Returns:
        Tuple ``(macro_f1, weighted_f1)`` as reported by
        ``evaluate_image_model``.
    """
    model = ImageOnlyModel(num_classes=len(LABEL_MAP))
    model = model.to(DEVICE)

    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)

    # nn.DataParallel prefixes every parameter name with "module.". Strip it
    # only when it is a leading prefix: a blanket str.replace would also mangle
    # keys that merely contain the substring (e.g. "encoder.submodule.weight").
    dp_prefix = "module."
    cleaned_state = {}
    for key, tensor in state_dict.items():
        # contrastive_projection entries are not loaded into the model
        # (presumably a training-only head absent from ImageOnlyModel -- confirm).
        if "contrastive_projection" in key:
            continue
        if key.startswith(dp_prefix):
            key = key[len(dp_prefix):]
        cleaned_state[key] = tensor

    model.load_state_dict(cleaned_state)

    # A freshly constructed nn.Module starts in training mode; switch dropout /
    # normalisation layers to inference behaviour before scoring. This is a
    # no-op if evaluate_image_model already calls model.eval().
    model.eval()

    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    test_dataset = ImageOnlyAnxietyDataset(test_data, image_dir, image_processor)
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # keep sample order stable across runs
        collate_fn=custom_collate_fn
    )

    loss, macro_f1, weighted_f1 = evaluate_image_model(
        model, test_loader, nn.CrossEntropyLoss()
    )

    print(f"Test Loss: {loss:.4f}")
    print(f"Test Macro-F1: {macro_f1:.4f}, Weighted-F1: {weighted_f1:.4f}")

    return macro_f1, weighted_f1

In [73]:
# Score the downloaded checkpoint on test_data; the returned
# (macro_f1, weighted_f1) tuple is this cell's displayed value.
inference(test_data=test_data, model_path="only_image_anxiety.pth")

Evaluating: 100%|██████████| 39/39 [00:17<00:00,  2.23it/s]


                       precision    recall  f1-score   support

          Nervousness       0.44      0.36      0.40       106
Lack of Worry Control       0.67      0.39      0.50        94
      Excessive Worry       0.41      0.39      0.40        92
  Difficulty Relaxing       0.76      0.58      0.66       102
         Restlessness       0.37      0.74      0.49       116
       Impending Doom       0.55      0.40      0.46       105

             accuracy                           0.48       615
            macro avg       0.53      0.48      0.48       615
         weighted avg       0.53      0.48      0.49       615

Test Loss: 3.2406
Test Macro-F1: 0.4845, Weighted-F1: 0.4852


(0.48453214917837184, 0.48522518139802373)