In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import os

# ==== Paths ====
# Fine-tuned checkpoint directory, input posts, and the CSV the predictions go to.
model_path = "../../MentalBert/mentalbert_fine_tuned_6class_learningrate_5e-5_weightdecay_0.01/best_model"
data_path = "../../Data for analysis/data_for_further_analysis.csv"
output_csv = "../Output/mentalbert_7labels_predictions.csv"

# Make sure the output folder exists before anything is written into it.
output_dir = os.path.dirname(output_csv)
os.makedirs(output_dir, exist_ok=True)

# ==== Device ====
# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cpu


In [2]:
# Load tokenizer and classifier strictly from the local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    local_files_only=True,
)

# Place the weights on the selected device and switch to evaluation mode.
model = model.to(device).eval()

print("✅ MentalBERT 7-label model loaded successfully.")

✅ MentalBERT 7-label model loaded successfully.


In [3]:
# Read the posts to be scored, report the frame size, and preview the first rows.
df = pd.read_csv(data_path)
n_rows, n_cols = df.shape
print("Dataset shape:", (n_rows, n_cols))
df.head(5)

Dataset shape: (95250, 8)


Unnamed: 0,subreddit,author,date,clean_text,title,tokens,score,url
0,ADHD,crabgal,2024-12-12,psychiatrist decided switch concerta strattera...,Does Strattera work for anyone?,"['psychiatrist', 'decided', 'switch', 'concert...",1,https://www.reddit.com/r/ADHD/comments/1hcl16z...
1,ADHD,Thirdworldnerd02,2024-12-12,’ trouble uni spoke favorite teacher another t...,How can I become a better student and stop fee...,"['’', 'trouble', 'uni', 'spoke', 'favorite', '...",2,https://www.reddit.com/r/ADHD/comments/1hc3g0n...
2,ADHD,brohno,2024-12-12,socially shy growing usually quiet situation n...,able to read social cues fine but not know how...,"['socially', 'shy', 'growing', 'usually', 'qui...",7,https://www.reddit.com/r/ADHD/comments/1hc30d4...
3,ADHD,Munchkin_Hound,2024-12-12,please help 19 undiagnosed waiting list tested...,Focussing and Concentrating on dull tasks is a...,"['please', 'help', '19', 'undiagnosed', 'waiti...",1,https://www.reddit.com/r/ADHD/comments/1hc346v...
4,ADHD,pinklemon36,2024-12-12,random request label maker adhd struggle organ...,random - label maker recommendations,"['random', 'request', 'label', 'maker', 'adhd'...",1,https://www.reddit.com/r/ADHD/comments/1hc3axk...


In [4]:
# Build the model input as "<title>. <clean_text>" for richer context, treating
# missing pieces as empty strings (same recipe as the Roberta & 28-label setup).
post_titles = df["title"].fillna("")
post_bodies = df["clean_text"].fillna("")
df["combined_text"] = (post_titles + ". " + post_bodies).str.strip()

# Plain Python list of strings, in row order, for batched tokenization.
texts = df["combined_text"].astype(str).tolist()
print("Total posts for prediction:", len(texts))

Total posts for prediction: 95250


In [5]:
def predict_classes(texts, batch_size=16, max_len=128):
    """Return softmax class probabilities for every text, in input order.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``device``.

    Parameters
    ----------
    texts : sequence of str
        Texts to score. May be empty.
    batch_size : int, default 16
        Number of texts tokenized and scored per forward pass. Must be >= 1.
    max_len : int, default 128
        Passed to the tokenizer as ``max_length``; longer inputs are truncated.

    Returns
    -------
    numpy.ndarray
        Array with one row per text and one column per class. For empty
        input the shape is ``(0, model.config.num_labels)``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # np.vstack cannot stack an empty list, so return a correctly-shaped
    # empty result instead of crashing when there is nothing to score.
    if len(texts) == 0:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    all_probs = []
    # Inference only: one no_grad context around the whole loop.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            enc = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt"
            ).to(device)

            logits = model(**enc).logits
            # single-label classification → softmax over the class axis
            probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
            all_probs.append(probs)

    return np.vstack(all_probs)

In [6]:
# Score every post in one pass; the batch size can be raised further on a GPU.
INFERENCE_BATCH_SIZE = 32
print("Running predictions using 7-label MentalBERT...")
probs = predict_classes(texts, batch_size=INFERENCE_BATCH_SIZE)
print("Predictions shape:", probs.shape)

Running predictions using 7-label MentalBERT...
Predictions shape: (95250, 7)


In [7]:
# ==== Label set (Ekman's six emotions + neutral) ====
# The position of each name must equal the class id used by the model,
# because column i of `probs` is labelled with emotion_labels[i].
emotion_labels = [
    "anger",    # id 0
    "disgust",  # id 1
    "fear",     # id 2
    "joy",      # id 3
    "neutral",  # id 4
    "sadness",  # id 5
    "surprise"  # id 6
]

# Fail fast if the probability matrix does not line up with the labels or
# with the rows of `df` (the concat below pairs rows purely by position).
assert probs.ndim == 2 and probs.shape[1] == len(emotion_labels), (
    f"Expected {len(emotion_labels)} probability columns, got shape {probs.shape}"
)
assert probs.shape[0] == len(df), (
    f"Got {probs.shape[0]} prediction rows for {len(df)} posts"
)

# Cross-check the hard-coded order against the checkpoint's own id→label map
# when it stores real names (generic "LABEL_<id>" placeholders are skipped).
id2label = getattr(model.config, "id2label", None) or {}
model_labels = [
    str(id2label.get(class_id, "")).strip().lower()
    for class_id in range(len(emotion_labels))
]
if all(name and not name.startswith("label_") for name in model_labels):
    if model_labels != emotion_labels:
        print("⚠️ emotion_labels differs from model.config.id2label:", model_labels)

# One probability column per emotion (probs is already a float array).
pred_df = pd.DataFrame(probs, columns=emotion_labels)

# Predicted emotion = label of the highest-probability column.
pred_df["predicted_emotion"] = pred_df[emotion_labels].idxmax(axis=1)

# Top-3 emotions per row, computed in one vectorised pass. A stable sort on the
# negated probabilities keeps the earlier label first on ties, matching the
# ordering of Series.nlargest(3).
top3_ids = np.argsort(-probs, axis=1, kind="stable")[:, :3]
label_array = np.array(emotion_labels)
pred_df["top3_emotions"] = [", ".join(names) for names in label_array[top3_ids]]

# Merge with base Reddit info (row i of pred_df belongs to row i of df).
final_df = pd.concat(
    [df[["subreddit", "title", "clean_text", "url"]].reset_index(drop=True), pred_df],
    axis=1
)

print("✅ Final DataFrame shape:", final_df.shape)
final_df.head()

✅ Final DataFrame shape: (95250, 13)


Unnamed: 0,subreddit,title,clean_text,url,anger,disgust,fear,joy,neutral,sadness,surprise,predicted_emotion,top3_emotions
0,ADHD,Does Strattera work for anyone?,psychiatrist decided switch concerta strattera...,https://www.reddit.com/r/ADHD/comments/1hcl16z...,0.00099,0.00028,0.000411,0.001127,0.985838,0.000148,0.011205,neutral,"neutral, surprise, joy"
1,ADHD,How can I become a better student and stop fee...,’ trouble uni spoke favorite teacher another t...,https://www.reddit.com/r/ADHD/comments/1hc3g0n...,0.273112,0.005573,0.001122,0.002145,0.373796,0.311687,0.032565,neutral,"neutral, sadness, anger"
2,ADHD,able to read social cues fine but not know how...,socially shy growing usually quiet situation n...,https://www.reddit.com/r/ADHD/comments/1hc30d4...,0.007374,0.001338,0.004227,0.00966,0.906869,0.011406,0.059126,neutral,"neutral, surprise, sadness"
3,ADHD,Focussing and Concentrating on dull tasks is a...,please help 19 undiagnosed waiting list tested...,https://www.reddit.com/r/ADHD/comments/1hc346v...,0.011763,0.001967,0.550103,0.004271,0.412945,0.014081,0.004871,fear,"fear, neutral, sadness"
4,ADHD,random - label maker recommendations,random request label maker adhd struggle organ...,https://www.reddit.com/r/ADHD/comments/1hc3axk...,0.000936,0.000246,0.000197,0.00253,0.995199,0.000197,0.000695,neutral,"neutral, joy, anger"


In [8]:
# Write the merged predictions to disk without the positional index column.
final_df.to_csv(path_or_buf=output_csv, index=False)
print(f"✅ Predictions saved to: {output_csv}")

✅ Predictions saved to: ../Output/mentalbert_7labels_predictions.csv
