# In [None]:  (Jupyter cell marker left over from the notebook export)
import random
import os
import re
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

from transformers import DistilBertTokenizer, DistilBertModel, logging as hf_logging

hf_logging.set_verbosity_error()  # silence some HF logs

# ------------------- Reproducibility -------------------
# Seed every RNG used below (Python, NumPy, PyTorch CPU and CUDA) so that data
# splits, weight initialisation and DataLoader shuffling repeat across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Seeding alone does not make GPU runs repeatable: cuDNN may pick
    # non-deterministic kernels or autotune between algorithms. Pin it down.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ------------------- 1. Load & Clean Dataset -------------------
# Read the raw CSV into a DataFrame (hard-coded absolute path under /content).
df = pd.read_csv("/content/Mental Health Disorder Detection Dataset.csv")
# Drop, in place, every row that is missing either the text ('body') or the
# label ('category'); both columns are required by the steps below.
df.dropna(subset=['body', 'category'], inplace=True)

def clean_text(text):
    """Normalise *text* to lowercase ``a``-``z`` words separated by single spaces.

    The value is first coerced with ``str()``, so non-string inputs are accepted.
    Every character that is neither a lowercase ASCII letter nor whitespace is
    deleted outright (not replaced by a space), then whitespace runs are
    collapsed and leading/trailing whitespace is removed.
    """
    lowered = str(text).lower()
    letters_only = re.sub(r'[^a-z\s]', '', lowered)
    # split() with no argument discards leading/trailing whitespace and treats
    # any whitespace run as one separator, so re-joining normalises spacing.
    return ' '.join(letters_only.split())

# Normalise every post with clean_text before any tokenisation happens.
df['body'] = df['body'].map(clean_text)

# Encode labels: replace the category values with integer class ids.
label_enc = LabelEncoder()
encoded_categories = label_enc.fit_transform(df['category'])
df['category'] = encoded_categories
num_classes = len(label_enc.classes_)

# Initial train-test split: hold out 20% for testing, stratified by class.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    df['body'],
    df['category'],
    test_size=0.2,
    random_state=SEED,
    stratify=df['category'],
)

# Carve a validation set (10% of the remaining training rows) out of train.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts,
    train_labels,
    test_size=0.1,
    random_state=SEED,
    stratify=train_labels,
)

# ------------------- 2. LSTM (tokenizer + sequences) -------------------
# MAX_WORDS is handed to the Keras Tokenizer as `num_words` and later caps the
# number of rows in the embedding matrix.
MAX_WORDS = 20000
# Fixed length every LSTM input sequence is padded/truncated to.
MAX_LEN_LSTM = 100

# Words outside the kept vocabulary are represented by the "<OOV>" token.
tokenizer_lstm = Tokenizer(num_words=MAX_WORDS, oov_token="<OOV>")
# Fit on the training split only, so validation/test text never shapes the vocabulary.
tokenizer_lstm.fit_on_texts(train_texts)

def texts_to_padded_tensor(texts):
    """Convert raw texts into a ``torch.long`` tensor of fixed-length token ids.

    Texts are mapped to integer sequences with the module-level
    ``tokenizer_lstm`` and padded at the end ('post') to ``MAX_LEN_LSTM``
    columns by ``pad_sequences``.
    """
    token_id_seqs = tokenizer_lstm.texts_to_sequences(texts)
    fixed_length = pad_sequences(
        token_id_seqs,
        padding='post',
        maxlen=MAX_LEN_LSTM,
    )
    return torch.tensor(fixed_length, dtype=torch.long)

# Padded token-id tensors for the three splits (train, validation, test).
X_train_lstm, X_val_lstm, X_test_lstm = (
    texts_to_padded_tensor(list(split_texts))
    for split_texts in (train_texts, val_texts, test_texts)
)

# Integer class labels for the same three splits, as torch.long tensors.
y_train_tensor, y_val_tensor, y_test_tensor = (
    torch.tensor(split_labels.values, dtype=torch.long)
    for split_labels in (train_labels, val_labels, test_labels)
)

# ------------------- 3. GloVe embeddings -------------------
# Download "glove.6B.100d.txt" and provide path
GLOVE_PATH = "glove.6B.100d.txt"
EMBED_DIM = 100

# Parse the GloVe text file: each line is a token followed by its coefficients.
embedding_index = {}
with open(GLOVE_PATH, encoding="utf8") as glove_file:
    for raw_line in glove_file:
        parts = raw_line.split()
        embedding_index[parts[0]] = np.asarray(parts[1:], dtype="float32")

# One embedding row per token id; ids at or above vocab_size are skipped.
word_index = tokenizer_lstm.word_index
vocab_size = min(MAX_WORDS, len(word_index) + 1)
embedding_matrix = np.zeros((vocab_size, EMBED_DIM), dtype=np.float32)

for token, token_id in word_index.items():
    if token_id >= vocab_size:
        continue
    pretrained = embedding_index.get(token)
    if pretrained is None:
        # No GloVe vector for this token: its row stays all zeros.
        continue
    embedding_matrix[token_id] = pretrained

embedding_matrix = torch.tensor(embedding_matrix)  # float32

# ------------------- 4. DistilBERT preprocessing -------------------
# Fixed token length every DistilBERT input is padded/truncated to.
MAX_LEN_BERT = 128
# Tokenizer matching the "distilbert-base-uncased" checkpoint used by the classifier.
tokenizer_bert = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def encode_texts(texts):
    """Tokenize *texts* with the module-level ``tokenizer_bert``.

    Every sequence is truncated and padded to ``MAX_LEN_BERT`` tokens and the
    result is returned as PyTorch tensors (``return_tensors="pt"``).
    """
    batch = list(texts)
    encoded = tokenizer_bert(
        batch,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN_BERT,
        return_tensors="pt",
    )
    return encoded

# DistilBERT encodings for the three splits (train, validation, test).
train_encodings, val_encodings, test_encodings = (
    encode_texts(split_texts)
    for split_texts in (train_texts, val_texts, test_texts)
)

# ------------------- 5. Dataset classes -------------------
class LSTMDataset(Dataset):
    """Index-aligned pairing of inputs ``X`` with targets ``y``."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset size is defined by the number of inputs.
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

class BERTDataset(Dataset):
    """Dataset yielding ``(features_dict, label)`` pairs.

    ``encodings`` is a mapping from field name to an indexable column; item
    ``idx`` contains the ``idx``-th entry of every column.
    """

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = column[idx]
        return sample, self.labels[idx]

BATCH_LSTM = 32
BATCH_BERT = 16

# LSTM loaders: only the training loader shuffles, so validation and test
# batches come out in a fixed order.
train_loader_lstm = DataLoader(
    LSTMDataset(X_train_lstm, y_train_tensor),
    shuffle=True,
    batch_size=BATCH_LSTM,
)
val_loader_lstm = DataLoader(
    LSTMDataset(X_val_lstm, y_val_tensor),
    batch_size=BATCH_LSTM,
)
test_loader_lstm = DataLoader(
    LSTMDataset(X_test_lstm, y_test_tensor),
    batch_size=BATCH_LSTM,
)

# DistilBERT loaders over the same label tensors, with a smaller batch size.
train_loader_bert = DataLoader(
    BERTDataset(train_encodings, y_train_tensor),
    shuffle=True,
    batch_size=BATCH_BERT,
)
val_loader_bert = DataLoader(
    BERTDataset(val_encodings, y_val_tensor),
    batch_size=BATCH_BERT,
)
test_loader_bert = DataLoader(
    BERTDataset(test_encodings, y_test_tensor),
    batch_size=BATCH_BERT,
)

# ------------------- 6. Models -------------------
class LSTMClassifier(nn.Module):
    """(Bi)LSTM sequence classifier on top of a pretrained embedding matrix.

    Args:
        embedding_matrix: 2-D float tensor ``(vocab_size, embed_dim)``; row 0 is
            used as the padding row (``padding_idx=0``).
        hidden_dim: Hidden size of each LSTM direction.
        num_classes: Number of output logits.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether to run the LSTM in both directions.
        freeze_embeddings: If True the embedding weights are not trained.
    """

    def __init__(self, embedding_matrix, hidden_dim, num_classes, num_layers=2, bidirectional=True, freeze_embeddings=False):
        super().__init__()
        _, embed_dim = embedding_matrix.shape
        # use Embedding.from_pretrained so we can choose freeze
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=freeze_embeddings, padding_idx=0)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_dim * self.num_directions, num_classes)

    def forward(self, x):
        """Return class logits ``(batch, num_classes)`` for token ids ``x`` ``(batch, seq_len)``.

        Trailing padding (id ``padding_idx``) is excluded from the recurrence, so
        the final hidden state summarises the real tokens only and does not
        depend on how much padding was appended.
        """
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)

        # True length of each row = 1-based position of its last non-pad token.
        # Rows that are entirely padding are clamped to length 1 because packed
        # sequences cannot be empty.
        pad_id = self.embedding.padding_idx
        positions = torch.arange(1, x.size(1) + 1, device=x.device)
        lengths = ((x != pad_id).long() * positions).max(dim=1).values.clamp(min=1)

        # Packing makes the LSTM stop at each row's true length; the lengths
        # tensor must live on the CPU. h_n comes back in the original row order.
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm(packed)  # h_n: (num_layers * num_directions, batch, hidden_dim)

        # reshape to (num_layers, num_directions, batch, hidden_dim)
        h_n = h_n.view(self.num_layers, self.num_directions, h_n.size(1), self.hidden_dim)
        last_layer = h_n[-1]  # (num_directions, batch, hidden_dim)
        if self.bidirectional:
            # concat forward and backward
            last = torch.cat([last_layer[0], last_layer[1]], dim=1)  # (batch, hidden_dim*2)
        else:
            last = last_layer[0]  # (batch, hidden_dim)
        return self.fc(last)  # (batch, num_classes)

class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder ("distilbert-base-uncased") followed by a linear classification head.

    Args:
        num_classes: Number of output logits.
        freeze_bert: If True, gradients are disabled for every encoder parameter,
            so only the linear head is trained.
    """

    def __init__(self, num_classes, freeze_bert=False):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        if freeze_bert:
            self.bert.requires_grad_(False)
        encoder_width = self.bert.config.hidden_size
        self.fc = nn.Linear(encoder_width, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits ``(batch, num_classes)``."""
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # DistilBERT has no pooled output, so the hidden state of the first
        # token (CLS-like) is used as the sequence representation.
        first_token = encoder_out.last_hidden_state[:, 0]  # (batch, hidden_size)
        logits = self.fc(first_token)
        return logits

# ------------------- 7. Initialize models, loss, optimizers, schedulers -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LSTM: allow embeddings to be fine-tuned (freeze_embeddings=False)
model_lstm = LSTMClassifier(embedding_matrix, hidden_dim=128, num_classes=num_classes,
                            num_layers=2, bidirectional=True, freeze_embeddings=False).to(device)
model_bert = DistilBERTClassifier(num_classes=num_classes, freeze_bert=False).to(device)

# One loss shared by both models (multi-class classification on raw logits).
criterion = nn.CrossEntropyLoss()

# Separate optimizers: a much smaller learning rate for the pretrained transformer.
optimizer_lstm = optim.Adam(model_lstm.parameters(), lr=1e-3)
optimizer_bert = optim.Adam(model_bert.parameters(), lr=2e-5)

# LR schedulers (ReduceLROnPlateau based on validation loss): halve the LR when
# the validation loss has not improved for more than `patience` epochs.
# NOTE: the `verbose` argument is deprecated in recent PyTorch releases, so it is
# not passed; read the current LR from optimizer.param_groups[0]['lr'] if needed.
scheduler_lstm = optim.lr_scheduler.ReduceLROnPlateau(optimizer_lstm, mode='min', factor=0.5, patience=1)
scheduler_bert = optim.lr_scheduler.ReduceLROnPlateau(optimizer_bert, mode='min', factor=0.5, patience=1)

# ------------------- 8. Training / evaluation helpers -------------------
def evaluate_lstm(model, data_loader, device, loss_fn=None):
    """Evaluate a model that maps one input tensor to class logits.

    Args:
        model: ``nn.Module`` called as ``model(X_batch)``; switched to eval mode.
        data_loader: Iterable of ``(X_batch, y_batch)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        loss_fn: Loss called as ``loss_fn(logits, labels)`` and expected to
            return the batch mean. Defaults to the module-level ``criterion``.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` where accuracy is a fraction in [0, 1].
    """
    if loss_fn is None:
        # Backward-compatible default: the loss defined at module level.
        loss_fn = criterion
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            outputs = model(X_batch)
            batch_size = X_batch.size(0)
            # Weight the batch-mean loss by batch size so that a smaller final
            # batch does not skew the per-sample average.
            total_loss += loss_fn(outputs, y_batch).item() * batch_size
            correct += (outputs.argmax(dim=1) == y_batch).sum().item()
            total += batch_size
    return total_loss / total, correct / total

def evaluate_bert(model, data_loader, device, loss_fn=None):
    """Evaluate a model that takes ``input_ids`` and ``attention_mask``.

    Args:
        model: ``nn.Module`` called as ``model(input_ids, attention_mask)``;
            switched to eval mode.
        data_loader: Iterable of ``(inputs, labels)`` where ``inputs`` is a dict
            of tensors containing at least ``'input_ids'`` and ``'attention_mask'``.
        device: Device the batches are moved to before the forward pass.
        loss_fn: Loss called as ``loss_fn(logits, labels)`` and expected to
            return the batch mean. Defaults to the module-level ``criterion``.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` where accuracy is a fraction in [0, 1].
    """
    if loss_fn is None:
        # Backward-compatible default: the loss defined at module level.
        loss_fn = criterion
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)
            outputs = model(inputs['input_ids'], inputs['attention_mask'])
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so that a smaller final
            # batch does not skew the per-sample average.
            total_loss += loss_fn(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    return total_loss / total, correct / total

# Utility to get probabilities for ensemble (softmax)
softmax = nn.Softmax(dim=1)

def predict_proba_lstm(model, data_loader, device):
    """Return class probabilities and labels for every sample in *data_loader*.

    Args:
        model: ``nn.Module`` called as ``model(X_batch)`` and returning logits of
            shape ``(batch, num_classes)``; switched to eval mode.
        data_loader: Iterable of ``(X_batch, y_batch)`` tensor pairs.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(probs, labels)`` as NumPy arrays: ``probs`` stacks the per-batch
        softmax outputs row-wise, ``labels`` concatenates the batch labels in
        the same order.
    """
    model.eval()
    probs_list = []
    labels_list = []
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            logits = model(X_batch.to(device))
            # Softmax over the class dimension; computed locally so the function
            # does not depend on module-level state.
            probs_list.append(torch.softmax(logits, dim=1).cpu().numpy())
            # Labels are moved to the CPU as well, so loaders that yield
            # device-resident labels work too.
            labels_list.append(y_batch.cpu().numpy())
    return np.vstack(probs_list), np.concatenate(labels_list)

def predict_proba_bert(model, data_loader, device):
    """Return class probabilities and labels for every sample in *data_loader*.

    Args:
        model: ``nn.Module`` called as ``model(input_ids, attention_mask)`` and
            returning logits of shape ``(batch, num_classes)``; switched to eval mode.
        data_loader: Iterable of ``(inputs, labels)`` where ``inputs`` is a dict
            of tensors containing at least ``'input_ids'`` and ``'attention_mask'``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(probs, labels)`` as NumPy arrays: ``probs`` stacks the per-batch
        softmax outputs row-wise, ``labels`` concatenates the batch labels in
        the same order.
    """
    model.eval()
    probs_list = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            logits = model(inputs['input_ids'], inputs['attention_mask'])
            # Softmax over the class dimension; computed locally so the function
            # does not depend on module-level state.
            probs_list.append(torch.softmax(logits, dim=1).cpu().numpy())
            # Labels are moved to the CPU as well, so loaders that yield
            # device-resident labels work too.
            labels_list.append(labels.cpu().numpy())
    return np.vstack(probs_list), np.concatenate(labels_list)

# ------------------- 9. Train loops with early stopping -------------------
# Upper bounds on the number of epochs; early stopping may end training sooner.
EPOCHS_LSTM = 10
EPOCHS_BERT = 4
PATIENCE = 3  # early stopping patience based on val loss
CLIP_NORM = 1.0  # max gradient norm passed to clip_grad_norm_ in both loops

# Best validation loss seen so far per model; starts at +inf so the first
# finite validation loss always counts as an improvement.
best_val_loss_lstm = float('inf')
best_val_loss_bert = float('inf')
# Consecutive epochs without improvement, per model.
patience_counter_lstm = 0
patience_counter_bert = 0

# Checkpoint paths (the directory is created if it does not exist yet)
os.makedirs("checkpoints", exist_ok=True)
path_lstm = "checkpoints/best_lstm.pt"
path_bert = "checkpoints/best_bert.pt"

# Train LSTM
for epoch in range(EPOCHS_LSTM):
    model_lstm.train()
    total_loss = 0.0
    for X_batch, y_batch in train_loader_lstm:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer_lstm.zero_grad()
        outputs = model_lstm(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        # Clip the global gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model_lstm.parameters(), CLIP_NORM)
        optimizer_lstm.step()
        # Accumulate the per-sample loss sum (batch mean * batch size).
        total_loss += loss.item() * X_batch.size(0)
    train_loss = total_loss / len(train_loader_lstm.dataset)
    val_loss, val_acc = evaluate_lstm(model_lstm, val_loader_lstm, device)
    print(f"[LSTM] Epoch {epoch+1} Train Loss: {train_loss:.4f} Val Loss: {val_loss:.4f} Val Acc: {val_acc:.4f}")

    # The plateau scheduler is driven by the validation loss.
    scheduler_lstm.step(val_loss)
    # early stopping & checkpoint: keep the weights with the lowest validation
    # loss; stop after PATIENCE consecutive epochs without improvement.
    if val_loss < best_val_loss_lstm - 1e-6:
        best_val_loss_lstm = val_loss
        torch.save(model_lstm.state_dict(), path_lstm)
        patience_counter_lstm = 0
        print("  -> Saved best LSTM model")
    else:
        patience_counter_lstm += 1
        if patience_counter_lstm >= PATIENCE:
            print("  -> Early stopping LSTM")
            break

# load best LSTM; map_location pins the tensors to the active device so the
# checkpoint also loads when it was written on a different device.
model_lstm.load_state_dict(torch.load(path_lstm, map_location=device))

# Train BERT
for epoch in range(EPOCHS_BERT):
    model_bert.train()
    total_loss = 0.0
    for batch in train_loader_bert:
        inputs, labels = batch
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        optimizer_bert.zero_grad()
        outputs = model_bert(inputs['input_ids'], inputs['attention_mask'])
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip the global gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model_bert.parameters(), CLIP_NORM)
        optimizer_bert.step()
        # Accumulate the per-sample loss sum (batch mean * batch size).
        total_loss += loss.item() * labels.size(0)
    train_loss = total_loss / len(train_loader_bert.dataset)
    val_loss, val_acc = evaluate_bert(model_bert, val_loader_bert, device)
    print(f"[BERT] Epoch {epoch+1} Train Loss: {train_loss:.4f} Val Loss: {val_loss:.4f} Val Acc: {val_acc:.4f}")

    # The plateau scheduler is driven by the validation loss.
    scheduler_bert.step(val_loss)
    # early stopping & checkpoint: keep the weights with the lowest validation
    # loss; stop after PATIENCE consecutive epochs without improvement.
    if val_loss < best_val_loss_bert - 1e-6:
        best_val_loss_bert = val_loss
        torch.save(model_bert.state_dict(), path_bert)
        patience_counter_bert = 0
        print("  -> Saved best BERT model")
    else:
        patience_counter_bert += 1
        if patience_counter_bert >= PATIENCE:
            print("  -> Early stopping BERT")
            break

# load best BERT; map_location pins the tensors to the active device so the
# checkpoint also loads when it was written on a different device.
model_bert.load_state_dict(torch.load(path_bert, map_location=device))

# ------------------- 10. Find best ensemble weight on validation set -------------------
# Validation-set probabilities from both models (the loaders are not shuffled,
# so rows line up sample by sample).
lstm_val_probs, val_labels_arr = predict_proba_lstm(model_lstm, val_loader_lstm, device)
bert_val_probs, _ = predict_proba_bert(model_bert, val_loader_bert, device)

# Grid search over w in [0, 1] in steps of 0.05: final = w*lstm + (1-w)*bert.
candidate_weights = np.linspace(0.0, 1.0, 21)
val_accuracies = [
    ((w * lstm_val_probs + (1.0 - w) * bert_val_probs).argmax(axis=1) == val_labels_arr).mean()
    for w in candidate_weights
]
# argmax returns the first maximum, i.e. the smallest w among tied weights.
best_idx = int(np.argmax(val_accuracies))
best_w = candidate_weights[best_idx]
best_acc = val_accuracies[best_idx]
print(f"Best ensemble weight on validation: w_lstm = {best_w:.2f}, val_acc = {best_acc:.4f}")

# ------------------- 11. Evaluate ensemble on test set -------------------
# Test-set probabilities from both models, in the same (unshuffled) order.
lstm_test_probs, test_labels_arr = predict_proba_lstm(model_lstm, test_loader_lstm, device)
bert_test_probs, _ = predict_proba_bert(model_bert, test_loader_bert, device)

# Blend with the weight chosen on the validation set and score the result.
final_test_probs = best_w * lstm_test_probs + (1.0 - best_w) * bert_test_probs
test_preds = np.argmax(final_test_probs, axis=1)
test_acc = np.mean(test_preds == test_labels_arr)
print(f"Ensemble Test Accuracy (w_lstm={best_w:.2f}): {100.0 * test_acc:.2f}%")

# Optional: show per-model test accuracies (the loss values are discarded)
_, lstm_test_acc = evaluate_lstm(model_lstm, test_loader_lstm, device)
print(f"LSTM Test Acc: {100.0 * lstm_test_acc:.2f}%")
_, bert_test_acc = evaluate_bert(model_bert, test_loader_bert, device)
print(f"BERT Test Acc: {100.0 * bert_test_acc:.2f}%")
