## Mental Health NLP Project

## Imports & Dependencies

In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt

from src.utils import get_device
from src.inference import predict_text_rnn
from src.training.training import (
    train_rnn,
    train_baselines,
    train_transformer_experiment,
)

from src.data.preprocess import clean_text


## Configurations & Hyperparameters

In [2]:
CSV_PATH = "data/mental_health.csv"

## Train Models

### Train Baselines

In [None]:
baseline_results = train_baselines(CSV_PATH)

print("Majority baseline metrics:")
print(baseline_results["majority"])

print("\nTF-IDF + Logistic Regression metrics:")
print(baseline_results["tfidf_logreg"])

print("\nTF-IDF + Linear SVM metrics:")
print(baseline_results["tfidf_svm"])

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Majority baseline metrics:
{'accuracy': 0.3102619258509427, 'f1_macro': 0.06765537697454646, 'classification_report': {'Anxiety': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 576.0}, 'Bipolar': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 417.0}, 'Depression': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 2311.0}, 'Normal': {'precision': 0.3102619258509427, 'recall': 1.0, 'f1-score': 0.4735876388218252, 'support': 2452.0}, 'Personality disorder': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 161.0}, 'Stress': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 388.0}, 'Suicidal': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 1598.0}, 'accuracy': 0.3102619258509427, 'macro avg': {'precision': 0.04432313226442038, 'recall': 0.14285714285714285, 'f1-score': 0.06765537697454646, 'support': 7903.0}, 'weighted avg': {'precision': 0.09626246263273586, 'recall': 0.3102619258509427, 'f1-score': 0.1469362

### Train LSTM Model

In [4]:
rnn_result = train_rnn(CSV_PATH, model_type="lstm")
print("LSTM test metrics:", rnn_result["test_metrics"])

Epoch 1/32: 100%|██████████| 1153/1153 [00:12<00:00, 92.03it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 276.45it/s]


Epoch 1: train_loss=1.0326, val_acc=0.6741, val_f1=0.5321


Epoch 2/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.87it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 273.42it/s]


Epoch 2: train_loss=0.7149, val_acc=0.7240, val_f1=0.6042


Epoch 3/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.82it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 273.19it/s]


Epoch 3: train_loss=0.5862, val_acc=0.7274, val_f1=0.6491


Epoch 4/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.60it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 274.37it/s]


Epoch 4: train_loss=0.4836, val_acc=0.7447, val_f1=0.6735


Epoch 5/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.77it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 275.64it/s]


Epoch 5: train_loss=0.4051, val_acc=0.7460, val_f1=0.6888


Epoch 6/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.97it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 267.19it/s]


Epoch 6: train_loss=0.3319, val_acc=0.7468, val_f1=0.6875


Epoch 7/32: 100%|██████████| 1153/1153 [00:12<00:00, 93.53it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.37it/s]


Epoch 7: train_loss=0.2702, val_acc=0.7545, val_f1=0.6962


Epoch 8/32: 100%|██████████| 1153/1153 [00:12<00:00, 92.78it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 276.17it/s]


Epoch 8: train_loss=0.2227, val_acc=0.7508, val_f1=0.6941


Epoch 9/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.93it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 266.79it/s]


Epoch 9: train_loss=0.1773, val_acc=0.7488, val_f1=0.7027


Epoch 10/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.42it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 273.15it/s]


Epoch 10: train_loss=0.1452, val_acc=0.7506, val_f1=0.6892


Epoch 11/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.70it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 276.02it/s]


Epoch 11: train_loss=0.1163, val_acc=0.7515, val_f1=0.7047


Epoch 12/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.79it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.59it/s]


Epoch 12: train_loss=0.0901, val_acc=0.7402, val_f1=0.6927


Epoch 13/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.77it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.20it/s]


Epoch 13: train_loss=0.0819, val_acc=0.7431, val_f1=0.6933


Epoch 14/32: 100%|██████████| 1153/1153 [00:12<00:00, 93.01it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 275.14it/s]


Epoch 14: train_loss=0.0704, val_acc=0.7407, val_f1=0.6869


Epoch 15/32: 100%|██████████| 1153/1153 [00:12<00:00, 95.04it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 290.43it/s]


Epoch 15: train_loss=0.0615, val_acc=0.7423, val_f1=0.6925


Epoch 16/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.35it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 270.88it/s]


Epoch 16: train_loss=0.0568, val_acc=0.7408, val_f1=0.6890


Epoch 17/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.83it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.80it/s]


Epoch 17: train_loss=0.0486, val_acc=0.7480, val_f1=0.6973


Epoch 18/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.73it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.36it/s]


Epoch 18: train_loss=0.0418, val_acc=0.7382, val_f1=0.6824


Epoch 19/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.33it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 269.37it/s]


Epoch 19: train_loss=0.0415, val_acc=0.7442, val_f1=0.6977


Epoch 20/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.83it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 271.03it/s]


Epoch 20: train_loss=0.0440, val_acc=0.7439, val_f1=0.6963


Epoch 21/32: 100%|██████████| 1153/1153 [00:12<00:00, 93.60it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.45it/s]


Epoch 21: train_loss=0.0368, val_acc=0.7421, val_f1=0.6990


Epoch 22/32: 100%|██████████| 1153/1153 [00:12<00:00, 90.43it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 237.45it/s]


Epoch 22: train_loss=0.0367, val_acc=0.7408, val_f1=0.6903


Epoch 23/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.81it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 271.55it/s]


Epoch 23: train_loss=0.0372, val_acc=0.7431, val_f1=0.7003


Epoch 24/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.83it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 273.72it/s]


Epoch 24: train_loss=0.0307, val_acc=0.7425, val_f1=0.6929


Epoch 25/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.92it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 273.41it/s]


Epoch 25: train_loss=0.0341, val_acc=0.7437, val_f1=0.6918


Epoch 26/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.67it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.43it/s]


Epoch 26: train_loss=0.0277, val_acc=0.7356, val_f1=0.6886


Epoch 27/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.60it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 270.13it/s]


Epoch 27: train_loss=0.0331, val_acc=0.7472, val_f1=0.7006


Epoch 28/32: 100%|██████████| 1153/1153 [00:12<00:00, 95.08it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 271.60it/s]


Epoch 28: train_loss=0.0315, val_acc=0.7454, val_f1=0.6997


Epoch 29/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.93it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 273.89it/s]


Epoch 29: train_loss=0.0293, val_acc=0.7445, val_f1=0.6941


Epoch 30/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.93it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 266.02it/s]


Epoch 30: train_loss=0.0250, val_acc=0.7412, val_f1=0.6882


Epoch 31/32: 100%|██████████| 1153/1153 [00:12<00:00, 94.89it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 269.44it/s]


Epoch 31: train_loss=0.0256, val_acc=0.7418, val_f1=0.6890


Epoch 32/32: 100%|██████████| 1153/1153 [00:12<00:00, 95.12it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:00<00:00, 272.67it/s]


Epoch 32: train_loss=0.0244, val_acc=0.7439, val_f1=0.6952


Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 230.81it/s]


Test metrics: 0.7382006832848286 0.6872085557596751
LSTM test metrics: {'accuracy': 0.7382006832848286, 'f1_macro': 0.6872085557596751, 'classification_report': {'Anxiety': {'precision': 0.7896678966789668, 'recall': 0.7430555555555556, 'f1-score': 0.7656529516994633, 'support': 576.0}, 'Bipolar': {'precision': 0.781491002570694, 'recall': 0.7290167865707434, 'f1-score': 0.7543424317617866, 'support': 417.0}, 'Depression': {'precision': 0.6714285714285714, 'recall': 0.6914755517092168, 'f1-score': 0.6813046258793434, 'support': 2311.0}, 'Normal': {'precision': 0.8884086444007858, 'recall': 0.9221044045676998, 'f1-score': 0.9049429657794676, 'support': 2452.0}, 'Personality disorder': {'precision': 0.5517241379310345, 'recall': 0.4968944099378882, 'f1-score': 0.5228758169934641, 'support': 161.0}, 'Stress': {'precision': 0.5788113695090439, 'recall': 0.5773195876288659, 'f1-score': 0.5780645161290323, 'support': 388.0}, 'Suicidal': {'precision': 0.6198019801980198, 'recall': 0.587609511

### Train GRU Model

In [5]:
gru_result = train_rnn(CSV_PATH, model_type="gru")
print("GRU test metrics:", gru_result["test_metrics"])

Epoch 1/32: 100%|██████████| 1153/1153 [00:34<00:00, 33.44it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 127.37it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 1: train_loss=1.0412, val_acc=0.6878, val_f1=0.4747


Epoch 2/32: 100%|██████████| 1153/1153 [00:34<00:00, 33.53it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 127.96it/s]


Epoch 2: train_loss=0.6970, val_acc=0.7426, val_f1=0.6360


Epoch 3/32: 100%|██████████| 1153/1153 [00:34<00:00, 33.83it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 133.03it/s]


Epoch 3: train_loss=0.5542, val_acc=0.7493, val_f1=0.6792


Epoch 4/32: 100%|██████████| 1153/1153 [00:34<00:00, 33.02it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 128.19it/s]


Epoch 4: train_loss=0.4535, val_acc=0.7482, val_f1=0.6843


Epoch 5/32: 100%|██████████| 1153/1153 [00:34<00:00, 32.99it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:02<00:00, 123.42it/s]


Epoch 5: train_loss=0.3670, val_acc=0.7558, val_f1=0.7006


Epoch 6/32: 100%|██████████| 1153/1153 [00:36<00:00, 31.93it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:02<00:00, 114.74it/s]


Epoch 6: train_loss=0.2878, val_acc=0.7456, val_f1=0.6992


Epoch 7/32: 100%|██████████| 1153/1153 [00:36<00:00, 31.85it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 126.56it/s]


Epoch 7: train_loss=0.2261, val_acc=0.7499, val_f1=0.7002


Epoch 8/32: 100%|██████████| 1153/1153 [00:36<00:00, 31.91it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 128.97it/s]


Epoch 8: train_loss=0.1719, val_acc=0.7459, val_f1=0.6976


Epoch 9/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.09it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 125.52it/s]


Epoch 9: train_loss=0.1362, val_acc=0.7413, val_f1=0.6901


Epoch 10/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.05it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 123.86it/s]


Epoch 10: train_loss=0.1057, val_acc=0.7363, val_f1=0.6891


Epoch 11/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.25it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 125.44it/s]


Epoch 11: train_loss=0.0898, val_acc=0.7377, val_f1=0.6869


Epoch 12/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.13it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 124.81it/s]


Epoch 12: train_loss=0.0726, val_acc=0.7316, val_f1=0.6762


Epoch 13/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.31it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 124.89it/s]


Epoch 13: train_loss=0.0649, val_acc=0.7377, val_f1=0.6913


Epoch 14/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.19it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 124.81it/s]


Epoch 14: train_loss=0.0607, val_acc=0.7389, val_f1=0.6911


Epoch 15/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.52it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 124.87it/s]


Epoch 15: train_loss=0.0502, val_acc=0.7379, val_f1=0.6900


Epoch 16/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.22it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 124.33it/s]


Epoch 16: train_loss=0.0478, val_acc=0.7331, val_f1=0.6859


Epoch 17/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.50it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 124.92it/s]


Epoch 17: train_loss=0.0468, val_acc=0.7356, val_f1=0.6888


Epoch 18/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.44it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 125.57it/s]


Epoch 18: train_loss=0.0456, val_acc=0.7355, val_f1=0.6889


Epoch 19/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.21it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 138.58it/s]


Epoch 19: train_loss=0.0378, val_acc=0.7346, val_f1=0.6946


Epoch 20/32: 100%|██████████| 1153/1153 [00:35<00:00, 32.22it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 142.03it/s]


Epoch 20: train_loss=0.0361, val_acc=0.7308, val_f1=0.6783


Epoch 21/32: 100%|██████████| 1153/1153 [00:33<00:00, 34.93it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 137.69it/s]


Epoch 21: train_loss=0.0440, val_acc=0.7222, val_f1=0.6731


Epoch 22/32: 100%|██████████| 1153/1153 [00:33<00:00, 34.61it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 141.42it/s]


Epoch 22: train_loss=0.0412, val_acc=0.7336, val_f1=0.6845


Epoch 23/32: 100%|██████████| 1153/1153 [00:33<00:00, 34.66it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 137.83it/s]


Epoch 23: train_loss=0.0329, val_acc=0.7317, val_f1=0.6834


Epoch 24/32: 100%|██████████| 1153/1153 [00:33<00:00, 34.91it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 139.82it/s]


Epoch 24: train_loss=0.0347, val_acc=0.7334, val_f1=0.6918


Epoch 25/32: 100%|██████████| 1153/1153 [00:32<00:00, 35.23it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 139.45it/s]


Epoch 25: train_loss=0.0312, val_acc=0.7330, val_f1=0.6853


Epoch 26/32: 100%|██████████| 1153/1153 [00:33<00:00, 34.88it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 142.43it/s]


Epoch 26: train_loss=0.0338, val_acc=0.7326, val_f1=0.6868


Epoch 27/32: 100%|██████████| 1153/1153 [00:32<00:00, 35.34it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 137.70it/s]


Epoch 27: train_loss=0.0267, val_acc=0.7278, val_f1=0.6818


Epoch 28/32: 100%|██████████| 1153/1153 [00:33<00:00, 34.82it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 138.41it/s]


Epoch 28: train_loss=0.0346, val_acc=0.7311, val_f1=0.6846


Epoch 29/32: 100%|██████████| 1153/1153 [00:32<00:00, 35.04it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 141.22it/s]


Epoch 29: train_loss=0.0301, val_acc=0.7265, val_f1=0.6804


Epoch 30/32: 100%|██████████| 1153/1153 [00:32<00:00, 35.12it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 139.68it/s]


Epoch 30: train_loss=0.0273, val_acc=0.7325, val_f1=0.6955


Epoch 31/32: 100%|██████████| 1153/1153 [00:32<00:00, 35.02it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 142.27it/s]


Epoch 31: train_loss=0.0306, val_acc=0.7287, val_f1=0.6785


Epoch 32/32: 100%|██████████| 1153/1153 [00:32<00:00, 35.28it/s]
Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 141.60it/s]


Epoch 32: train_loss=0.0266, val_acc=0.7304, val_f1=0.6834


Evaluating RNN: 100%|██████████| 247/247 [00:01<00:00, 138.87it/s]

Test metrics: 0.725673794761483 0.6735423396087429
GRU test metrics: {'accuracy': 0.725673794761483, 'f1_macro': 0.6735423396087429, 'classification_report': {'Anxiety': {'precision': 0.7641509433962265, 'recall': 0.703125, 'f1-score': 0.7323688969258589, 'support': 576.0}, 'Bipolar': {'precision': 0.7482678983833718, 'recall': 0.7769784172661871, 'f1-score': 0.7623529411764706, 'support': 417.0}, 'Depression': {'precision': 0.6552567237163814, 'recall': 0.6958026828212894, 'f1-score': 0.6749213011542498, 'support': 2311.0}, 'Normal': {'precision': 0.8872, 'recall': 0.9045676998368679, 'f1-score': 0.8957996768982229, 'support': 2452.0}, 'Personality disorder': {'precision': 0.5189873417721519, 'recall': 0.5093167701863354, 'f1-score': 0.5141065830721003, 'support': 161.0}, 'Stress': {'precision': 0.5590551181102362, 'recall': 0.5489690721649485, 'f1-score': 0.5539661898569571, 'support': 388.0}, 'Suicidal': {'precision': 0.6116102280580511, 'recall': 0.5538172715894869, 'f1-score': 0.5




### Train Tiny Transformer Model

In [3]:
transformer_result = train_transformer_experiment(CSV_PATH)
print("Transformer test metrics:", transformer_result["test_metrics"])

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy,F1 Macro
460,0.5778,0.525323,0.792837,0.748099


Error: command buffer exited with error status.
	The Metal Performance Shaders operations encoded on it may not have completed.
	Error: 
	(null)
	Insufficient Memory (00000008:kIOGPUCommandBufferCallbackErrorOutOfMemory)
	<AGXG15XFamilyCommandBuffer: 0x3e20cdcf0>
    label = <none> 
    device = <AGXG15SDevice: 0x3a0d6b200>
        name = Apple M3 Pro 
    commandQueue = <AGXG15XFamilyCommandQueue: 0x3a58cbe00>
        label = <none> 
        device = <AGXG15SDevice: 0x3a0d6b200>
            name = Apple M3 Pro 
    retainedReferences = 1


RuntimeError: MPS backend out of memory (MPS allocated: 36.86 GiB, other allocations: 6.01 GiB, max allowed: 47.74 GiB). Tried to allocate 5.79 GiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

## Inference & Predictions

### RNN Model

## Model Evaluation & Visualization

In [13]:
# Aggregate metrics for all models into a DataFrame
model_metrics = []

# Baselines
for name in ["majority", "tfidf_logreg", "tfidf_svm"]:
    m = baseline_results[name]
    model_metrics.append({
        "model": name,
        "accuracy": m["accuracy"],
        "f1_macro": m["f1_macro"],
    })

# RNNs
for name, res in [("rnn_lstm", rnn_result), ("rnn_gru", gru_result)]:
    m = res["test_metrics"]
    model_metrics.append({
        "model": name,
        "accuracy": m["accuracy"],
        "f1_macro": m["f1_macro"],
    })

# Tiny transformer
m = transformer_result["test_metrics"]
model_metrics.append({
    "model": "tiny_transformer",
    "accuracy": m["accuracy"],
    "f1_macro": m["f1_macro"],
})

metrics_df = pd.DataFrame(model_metrics)

# Ranking by F1-macro
metrics_df = metrics_df.sort_values("f1_macro", ascending=False).reset_index(drop=True)
metrics_df["rank"] = metrics_df.index + 1

print("Model ranking (by macro F1):")

display(metrics_df)

# Grouped bar chart for accuracy and F1-macro
plt.figure(figsize=(8, 5))
x = range(len(metrics_df))
bar_width = 0.35

plt.bar([i - bar_width/2 for i in x], metrics_df["accuracy"], width=bar_width, label="Accuracy")
plt.bar([i + bar_width/2 for i in x], metrics_df["f1_macro"], width=bar_width, label="F1-macro")

plt.xticks(list(x), metrics_df["model"], rotation=45)
plt.ylabel("Score")
plt.title("Model Performance Comparison")
plt.legend()
plt.tight_layout()
plt.show()

NameError: name 'transformer_result' is not defined

### Confusion Matrices per Model

In [None]:
from matplotlib import cm as cm_color

label_names = baseline_results["label_encoder"].classes_.tolist()

def plot_confusion_matrix(cm, labels, title):
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cm_color.Blues)
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=range(len(labels)),
        yticks=range(len(labels)),
        xticklabels=labels,
        yticklabels=labels,
        ylabel="True label",
        xlabel="Predicted label",
        title=title,
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # annotate
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax.text(j, i, cm[i][j], ha="center", va="center", color="black")
    plt.tight_layout()
    return ax

# Collect confusion matrices
cms = [
    ("majority", baseline_results["majority"]["confusion_matrix"]),
    ("tfidf_logreg", baseline_results["tfidf_logreg"]["confusion_matrix"]),
    ("tfidf_svm", baseline_results["tfidf_svm"]["confusion_matrix"]),
    ("rnn_lstm", rnn_result["test_metrics"]["confusion_matrix"]),
    ("rnn_gru", gru_result["test_metrics"]["confusion_matrix"]),
    ("tiny_transformer", transformer_result["test_metrics"]["confusion_matrix"]),
]

# Plot in a grid
plt.figure(figsize=(14, 10))
for idx, (name, cm_vals) in enumerate(cms, start=1):
    plt.subplot(2, 3, idx)
    plt.imshow(cm_vals, interpolation="nearest", cmap=cm_color.Blues)
    plt.title(name)
    plt.xticks(range(len(label_names)), label_names, rotation=45, ha="right")
    plt.yticks(range(len(label_names)), label_names)
    for i in range(len(label_names)):
        for j in range(len(label_names)):
            plt.text(j, i, cm_vals[i][j], ha="center", va="center", color="black", fontsize=8)
    plt.xlabel("Predicted")
    plt.ylabel("True")
plt.tight_layout()
plt.show()

## Sample Predictions Across Models

In [None]:
# Prepare sample texts: one per distinct category (up to 5)
label_encoder = rnn_result["label_encoder"]
labels = label_encoder.classes_.tolist()

def make_sample_for_label(label: str) -> str:
    lower = label.lower()
    if "depress" in lower:
        return "I feel sad, empty, and unmotivated most days, like nothing is worth it."
    if "anx" in lower:
        return "My heart races all the time and I cannot stop worrying about everything."
    if "stress" in lower:
        return "Work and life are overwhelming right now, I feel under constant pressure."
    if "suic" in lower or "self" in lower:
        return "Lately I have been thinking that people would be better off without me."
    if "normal" in lower or "none" in lower or "control" in lower:
        return "I have been feeling okay lately, managing my responsibilities without much trouble."
    return f"This is an example statement that might be categorized as '{label}' in a mental health context."

sample_texts = []
for label in labels[:5]:
    sample_texts.append((label, make_sample_for_label(label)))

vec_svm = baseline_results["vec_svm"]
clf_svm = baseline_results["clf_svm"]
tokenizer = transformer_result["tokenizer"]
transformer_model = transformer_result["model"]
device = get_device()
transformer_model.to(device)
rnn_model = rnn_result["model"]
rnn_model.to(device)
word2idx = rnn_result["word2idx"]

print("Sample predictions (label, text, predicted by each model):\n")

for true_label, text in sample_texts:
    cleaned = clean_text(text)
    # Baseline SVM
    X_vec = vec_svm.transform([cleaned])
    svm_pred_idx = clf_svm.predict(X_vec)[0]
    svm_pred = label_encoder.inverse_transform([svm_pred_idx])[0]

    # RNN LSTM
    rnn_pred, rnn_conf = predict_text_rnn(
        text,
        rnn_model,
        word2idx,
        label_encoder,
        device,
    )

    # Tiny transformer
    enc = tokenizer(
        [text],
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        outputs = transformer_model(**enc)
        probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
        pred_idx = probs.argmax()
    transformer_pred = label_encoder.inverse_transform([pred_idx])[0]

    print(f"True label: {true_label}")
    print(f"Text      : {text}")
    print(f"  TF-IDF + SVM       → {svm_pred}")
    print(f"  RNN (LSTM)         → {rnn_pred} (conf={rnn_conf:.3f})")
    print(f"  Tiny Transformer   → {transformer_pred}")
    print("-" * 80)

## Qualitative Error Analysis (TF-IDF + SVM)

In [None]:
# Inspect a few misclassified examples for the TF-IDF + SVM model
from src.data.preprocess import load_data, encode_labels

df_all = load_data(CSV_PATH)
# Use the same label encoder as in training
le = baseline_results["label_encoder"]
df_all["label"] = le.transform(df_all["status"])

texts = df_all["clean_text"].tolist()
true_labels = df_all["label"].tolist()

vec_svm = baseline_results["vec_svm"]
clf_svm = baseline_results["clf_svm"]

X_all = vec_svm.transform(texts)
pred_labels = clf_svm.predict(X_all)

mis_indices = [i for i in range(len(texts)) if pred_labels[i] != true_labels[i]]

print(f"Total samples: {len(texts)}, misclassified by SVM: {len(mis_indices)}")

for idx in mis_indices[:5]:
    text = texts[idx]
    true_lab = le.inverse_transform([true_labels[idx]])[0]
    pred_lab = le.inverse_transform([pred_labels[idx]])[0]
    print("\n--- Misclassified example ---")
    print("Text         :", text)
    print("True label   :", true_lab)
    print("Predicted    :", pred_lab)