In [None]:
# -*- coding: utf-8 -*-
"""
Teacher PLM Training for BERT-ERC (RoBERTa) with Fine-Grained Classification Head
+ Freeze first 8 encoder layers (paper-style)
+ Robust teacher predictions export (train/dev with confidences)

- Tri-pooling (Fp, Fq, Ff) over token embeddings of:
    Fp = tokens before utterance span
    Fq = utterance span
    Ff = tokens after utterance span
- 2-layer MLP head (Tanh + Dropout) -> 7 emotions
- Exports: Dialogue_ID, Utterance_ID, Speaker, Utterance, pred_id, pred, conf
"""

# =========================
# 1) Setup (Colab-friendly)
# =========================
!pip -q install transformers datasets evaluate

import os
import torch
import pandas as pd
from torch.utils.data import TensorDataset
from transformers import (
    AutoTokenizer, RobertaConfig, Trainer, TrainingArguments,
    RobertaPreTrainedModel, RobertaModel
)
import torch.nn as nn
import evaluate

# Optional Google Drive mount for Colab sessions. This is best-effort: any
# failure (including google.colab not being importable) is reported and the
# script carries on.
try:
    from google.colab import drive
    if os.path.ismount('/content/drive'):
        # Nothing to do: the mount point is already live.
        print("✅ Drive already mounted at /content/drive")
    else:
        drive.mount('/content/drive', force_remount=False)
except Exception as e:
    print("ℹ️ Not running on Colab or Drive not available:", e)

# ======================
# 2) Configuration
# ======================
# Root folder for every input and output path below.
BASE_DIR           = "/content/drive/MyDrive/MELD"
# Input CSVs. The code below reads the columns 'Emotion', 'bert_input' and
# 'Utterance' from them, plus dialogue/utterance/speaker id columns at export.
RAW_TRAIN_CSV      = os.path.join(BASE_DIR, "train_with_context.csv")
RAW_DEV_CSV        = os.path.join(BASE_DIR, "dev_with_context.csv")
# Directory used as the Trainer output_dir and for the final model/tokenizer.
TEACHER_OUTPUT     = os.path.join(BASE_DIR, "teacher_roberta_fg_head_marked")
# Destination CSVs for the exported teacher predictions (train / dev).
TEACHER_PRED_TRAIN = os.path.join(BASE_DIR, "teacher_predictions_train.csv")
TEACHER_PRED_DEV   = os.path.join(BASE_DIR, "teacher_predictions_dev.csv")

# Pretrained checkpoint name used for the tokenizer, config and model weights.
MODEL_CHECKPOINT   = "roberta-base"
# Class names; a label id is the position of the emotion in this list.
EMOTIONS           = ["anger","disgust","fear","joy","neutral","sadness","surprise"]
NUM_LABELS         = len(EMOTIONS)

# Every tokenized input is padded/truncated to this many tokens.
MAX_LEN            = 128
# Used for both the per-device train and eval batch size.
BATCH_SIZE         = 8
EPOCHS             = 4
LR                 = 9e-5
WEIGHT_DECAY       = 0.01
# The model is moved to this device before training (GPU when available).
DEVICE             = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make sure the output directory exists before the Trainer writes to it.
os.makedirs(TEACHER_OUTPUT, exist_ok=True)

# ======================
# 3) Load data
# ======================
df_train = pd.read_csv(RAW_TRAIN_CSV)
df_dev   = pd.read_csv(RAW_DEV_CSV)

# Label id == position of the emotion name in EMOTIONS.
label2id = {e: i for i, e in enumerate(EMOTIONS)}
id2label = {i: e for e, i in label2id.items()}

# Encode the 'Emotion' column into an integer 'label' column on both splits.
# Values are normalised (surrounding whitespace stripped, lower-cased) before
# the lookup. Anything that still does not map is reported immediately:
# a plain Series.map(label2id) would silently produce NaN labels that only
# fail much later, when the labels are converted to tensors.
for _split_name, _frame in (("train", df_train), ("dev", df_dev)):
    _encoded = _frame['Emotion'].map(
        lambda value: label2id.get(str(value).strip().lower())
    )
    _unmapped = _encoded.isna()
    if _unmapped.any():
        _bad_values = sorted(_frame.loc[_unmapped, 'Emotion'].astype(str).unique())
        raise ValueError(
            f"{_split_name} split has {int(_unmapped.sum())} row(s) whose 'Emotion' "
            f"is not one of {EMOTIONS}: {_bad_values}"
        )
    _frame['label'] = _encoded.astype('int64')

# ======================
# 4) Tokenizer & Config
# ======================
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT, use_fast=True)
# Attach the emotion names to the config so the saved checkpoint carries the
# real label names rather than the generic LABEL_<i> placeholders.
config    = RobertaConfig.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
)

# ======================
# 5) Fine-grained head
# ======================
class FineGrainedHead(nn.Module):
    """Two-layer MLP classifier: Linear -> Tanh -> Dropout -> Linear.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the hidden layer.
        num_labels: Number of output logits.
        dropout: Dropout probability applied after the Tanh activation.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_labels: int,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # Attribute names double as state_dict keys, so they stay fixed.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.Tanh()
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits for the feature tensor ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.fc2(hidden)

# ===============================================
# 6) Teacher model (RoBERTa + tri-pool + FG head)
# ===============================================
class RobertaTeacherFG(RobertaPreTrainedModel):
    """RoBERTa encoder with tri-pooling and a fine-grained classification head.

    The encoder's last hidden states are mean-pooled over three token ranges
    defined by the utterance span ``(a, b)``:

    * ``Fp``: tokens before the span (indices ``< a``),
    * ``Fq``: the span itself (indices ``a..b`` inclusive),
    * ``Ff``: tokens after the span (indices ``> b``).

    The three pooled vectors are concatenated and fed to ``FineGrainedHead``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        H = config.hidden_size
        self.head = FineGrainedHead(
            input_dim=3*H,
            hidden_dim=H,
            num_labels=config.num_labels,
            dropout=config.hidden_dropout_prob
        )
        self.post_init()

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` ([B, S, H]) over positions where ``mask`` ([B, S]) is True.

        Rows whose mask is empty yield an all-zero vector.
        """
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        # clamp avoids 0/0 for empty ranges; the numerator is already zero there.
        counts = weights.sum(dim=1).clamp(min=1.0)
        return (hidden * weights).sum(dim=1) / counts

    def forward(self, input_ids=None, attention_mask=None, spans=None, labels=None, **kwargs):
        """Run the encoder, tri-pool around the utterance span and classify.

        Args:
            input_ids: Token ids, ``[B, S]``.
            attention_mask: ``[B, S]`` mask with 1 for real tokens and 0 for
                padding. Padding positions are excluded from every pooled
                mean. When ``None`` all positions are treated as real.
            spans: ``[B, 2]`` integer tensor of inclusive ``(a, b)`` token
                indices of the utterance; ``(-1, -1)`` marks "not found", in
                which case ``Fp``/``Fq`` are zero and ``Ff`` pools every
                real token.
            labels: Optional ``[B]`` class ids; when given, cross-entropy
                loss is computed.
            **kwargs: Ignored.

        Returns:
            dict with ``'loss'`` (tensor or ``None``) and ``'logits'``
            (``[B, num_labels]``).

        Raises:
            ValueError: If ``spans`` is not provided.
        """
        if spans is None:
            raise ValueError("`spans` ([B, 2] utterance start/end token indices) is required")

        out = self.roberta(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        hs = out.last_hidden_state  # [B, S, H]
        B, S, _ = hs.size()

        # Inputs are padded to a fixed length, so without this mask the
        # "after" range would mostly average hidden states of <pad> tokens.
        if attention_mask is None:
            valid = torch.ones(B, S, dtype=torch.bool, device=hs.device)
        else:
            valid = attention_mask.to(hs.device).bool()

        spans = spans.to(hs.device)
        a = spans[:, 0].unsqueeze(1)  # [B, 1]
        b = spans[:, 1].unsqueeze(1)  # [B, 1]
        pos = torch.arange(S, device=hs.device).unsqueeze(0)  # [1, S]

        # Boolean [B, S] masks for the three ranges. For a <= 0 the "before"
        # range is empty; for a not-found span (-1, -1) the "query" range is
        # empty and the "after" range covers every real token.
        before_mask = valid & (pos < a)
        query_mask  = valid & (a >= 0) & (pos >= a) & (pos <= b)
        after_mask  = valid & (b >= -1) & (pos > b)

        Fp = self._masked_mean(hs, before_mask)
        Fq = self._masked_mean(hs, query_mask)
        Ff = self._masked_mean(hs, after_mask)

        cat    = torch.cat([Fp, Fq, Ff], dim=1)
        logits = self.head(cat)
        loss   = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {'loss': loss, 'logits': logits}

# ======================
# 7) Freeze first 8 layers
# ======================
def freeze_first_n_layers_roberta(model, n=8, freeze_embeddings=False, verbose=True):
    """Disable gradients for the first ``n`` encoder blocks of ``model.roberta``.

    Args:
        model: Object exposing the backbone as ``model.roberta``; the blocks
            are read from ``model.roberta.encoder.layer``.
        n: Number of leading blocks to freeze; clamped to the number of
            available blocks.
        freeze_embeddings: When true (and the backbone has ``embeddings``),
            the embedding parameters are frozen as well.
        verbose: Print a one-line summary of what was frozen.

    Raises:
        ValueError: If ``model`` has no ``roberta`` attribute.
    """
    if not hasattr(model, "roberta"):
        raise ValueError("Expected the backbone at model.roberta")
    encoder = model.roberta

    if freeze_embeddings and hasattr(encoder, "embeddings"):
        for weight in encoder.embeddings.parameters():
            weight.requires_grad = False

    blocks = encoder.encoder.layer
    frozen_count = min(n, len(blocks))
    for block in (blocks[idx] for idx in range(frozen_count)):
        for weight in block.parameters():
            weight.requires_grad = False

    if verbose:
        print(f"🔒 Froze encoder layers [0..{frozen_count - 1}] / {len(blocks)} (freeze_embeddings={freeze_embeddings})")

# ======================
# 8) Build tensors
# ======================
def _find_subsequence(haystack, needle):
    """Return the start index of the first occurrence of ``needle`` in ``haystack``, or -1."""
    width = len(needle)
    if width == 0:
        # An empty needle would "match" at index 0; treat it as not found.
        return -1
    for start in range(len(haystack) - width + 1):
        if haystack[start:start + width] == needle:
            return start
    return -1


def build_tensors(df):
    """Tokenize every row of ``df`` and locate the utterance span in it.

    For each row, ``row['bert_input']`` is tokenized (padded/truncated to
    ``MAX_LEN``) and the token ids of ``row['Utterance']`` are searched
    for inside the result to obtain an inclusive ``(a, b)`` span.

    RoBERTa's byte-level BPE encodes a word differently depending on whether
    it is preceded by a space, so the utterance is tried both as-is and with
    a leading space; the first variant that matches wins. If neither matches
    (for example because the utterance was truncated away) the span is
    ``(-1, -1)``.

    Args:
        df: DataFrame with columns ``'bert_input'``, ``'Utterance'`` and
            ``'label'``.

    Returns:
        Tuple ``(input_ids, attention_masks, spans, labels)`` of stacked
        tensors, one row per row of ``df``, in the same order.
    """
    ids, masks, spans, labels = [], [], [], []
    for _, row in df.iterrows():
        enc = tokenizer(
            row['bert_input'],
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors='pt'
        )
        input_ids = enc['input_ids'].squeeze(0)
        ids.append(input_ids)
        masks.append(enc['attention_mask'].squeeze(0))

        # Exact sub-token match of the utterance inside the full input.
        token_ids = input_ids.tolist()
        utterance = row['Utterance']
        a, b = -1, -1  # not found
        for candidate in (utterance, " " + utterance):
            utt_enc = tokenizer(candidate, add_special_tokens=False)['input_ids']
            start = _find_subsequence(token_ids, utt_enc)
            if start >= 0:
                a, b = start, start + len(utt_enc) - 1
                break

        spans.append(torch.tensor([a, b], dtype=torch.long))
        # int() makes a missing (NaN) label fail here with a clear message.
        labels.append(torch.tensor(int(row['label']), dtype=torch.long))

    return (
        torch.stack(ids),
        torch.stack(masks),
        torch.stack(spans),
        torch.stack(labels)
    )

# Tokenize both splits once; the individual tensors are kept as module-level
# names because the prediction export further down reuses them.
_train_tensors = build_tensors(df_train)
_dev_tensors   = build_tensors(df_dev)

train_ids, train_masks, train_spans, train_labels = _train_tensors
dev_ids,   dev_masks,   dev_spans,   dev_labels   = _dev_tensors

# Each dataset item is an (input_ids, attention_mask, span, label) tuple.
train_ds = TensorDataset(*_train_tensors)
dev_ds   = TensorDataset(*_dev_tensors)

def collate_fn(batch):
    """Stack a list of dataset items into the keyword batch the model expects.

    Args:
        batch: Sequence of items indexable as
            ``(input_ids, attention_mask, span, label)``.

    Returns:
        dict with keys ``'input_ids'``, ``'attention_mask'``, ``'spans'`` and
        ``'labels'``, each a tensor stacked along a new leading dimension.
    """
    field_names = ('input_ids', 'attention_mask', 'spans', 'labels')
    return {
        name: torch.stack([item[position] for item in batch])
        for position, name in enumerate(field_names)
    }

# ======================
# 9) Train
# ======================
# Load the pretrained backbone into the teacher architecture, move it to the
# target device, then freeze the first 8 encoder blocks (embeddings stay
# trainable).
model_teacher = RobertaTeacherFG.from_pretrained(MODEL_CHECKPOINT, config=config)
model_teacher = model_teacher.to(DEVICE)
freeze_first_n_layers_roberta(
    model_teacher,
    n=8,
    freeze_embeddings=False,
    verbose=True,
)

# Metric objects used by compute_metrics.
acc_metric = evaluate.load('accuracy')
f1_metric  = evaluate.load('f1')

def compute_metrics(p):
    """Compute accuracy and weighted F1 from an evaluation prediction object.

    Args:
        p: Object exposing ``predictions`` (per-class scores; argmax over the
            last axis gives the predicted class) and ``label_ids``.

    Returns:
        dict with keys ``'acc'`` and ``'f1_weighted'``.
    """
    predicted = p.predictions.argmax(-1)
    gold = p.label_ids
    accuracy = acc_metric.compute(predictions=predicted, references=gold)
    weighted_f1 = f1_metric.compute(predictions=predicted, references=gold, average='weighted')
    return {
        'acc': accuracy['accuracy'],
        'f1_weighted': weighted_f1['f1'],
    }

# Training configuration: evaluation and checkpointing both use the 'epoch'
# strategy, and the best checkpoint is selected by the 'f1_weighted' value
# returned from compute_metrics.
args = TrainingArguments(
    output_dir=TEACHER_OUTPUT,
    eval_strategy='epoch',         # keyword is `eval_strategy` in this transformers version
    save_strategy='epoch',
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WEIGHT_DECAY,
    load_best_model_at_end=True,
    metric_for_best_model='f1_weighted',
    dataloader_pin_memory=False,
    report_to=["none"],
)

# collate_fn turns TensorDataset tuples into the keyword batch
# (input_ids / attention_mask / spans / labels) consumed by the model.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
# of `processing_class=` — confirm against the installed version.
trainer = Trainer(
    model=model_teacher,
    args=args,
    train_dataset=train_ds,
    eval_dataset=dev_ds,
    data_collator=collate_fn,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Train, then write the final model and tokenizer to TEACHER_OUTPUT.
trainer.train()
trainer.save_model(TEACHER_OUTPUT)
tokenizer.save_pretrained(TEACHER_OUTPUT)

print("✅ Training done. Model & tokenizer saved to:", TEACHER_OUTPUT)

# =======================================
# 10) Export teacher predictions (robust)
# =======================================
def _pick(df_src, names):
    low = {c.lower(): c for c in df_src.columns}
    for n in names:
        if n in df_src.columns: return n
        if n.lower() in low: return low[n.lower()]
    raise KeyError(f"Need one of {names}, have {list(df_src.columns)}")

def export_teacher_preds(df_src, ids, masks, spans, out_csv):
    """Run the trained teacher over one split and write its predictions to CSV.

    Args:
        df_src: Source DataFrame; supplies the dialogue id, utterance id,
            speaker and utterance columns (resolved via ``_pick``).
        ids: Input-id tensor, row-aligned with ``df_src``.
        masks: Attention-mask tensor, row-aligned with ``df_src``.
        spans: Utterance-span tensor, row-aligned with ``df_src``.
        out_csv: Destination path of the CSV file.

    The CSV has the columns Dialogue_ID, Utterance_ID, Speaker, Utterance,
    pred_id, pred and conf, where conf is the softmax probability of the
    predicted class.
    """
    # The dataset needs a fourth tensor because the collator stacks labels;
    # zeros are used as placeholders and row order is left untouched.
    placeholder_labels = torch.zeros(len(df_src), dtype=torch.long)
    prediction = trainer.predict(TensorDataset(ids, masks, spans, placeholder_labels))

    raw_logits = torch.from_numpy(prediction.predictions)
    probabilities = torch.softmax(raw_logits, dim=-1).numpy()
    predicted_ids = probabilities.argmax(axis=-1)
    confidences = probabilities.max(axis=-1)

    # Resolve the metadata columns, tolerating common naming variants.
    dialogue_col = _pick(df_src, ["Dialogue_ID","dialogue_id","Conversation_ID","conv_id"])
    utt_id_col = _pick(df_src, ["Utterance_ID","utterance_id","Utterance_ID_in_Dialogue","utt_id"])
    speaker_col = _pick(df_src, ["Speaker","speaker","speaker_id"])
    text_col = _pick(df_src, ["Utterance","utterance","text"])

    columns = {
        "Dialogue_ID": df_src[dialogue_col].values,
        "Utterance_ID": df_src[utt_id_col].values,
        "Speaker": df_src[speaker_col].astype(str).values,
        "Utterance": df_src[text_col].astype(str).values,
        "pred_id": predicted_ids,
        "pred": [EMOTIONS[class_id] for class_id in predicted_ids],
        "conf": confidences,
    }
    table = pd.DataFrame(columns)
    table.to_csv(out_csv, index=False)
    print(f"✓ Saved {out_csv} (rows: {len(table)})")

# Export predictions for the train split first, then the dev split.
_export_jobs = (
    (df_train, train_ids, train_masks, train_spans, TEACHER_PRED_TRAIN),
    (df_dev, dev_ids, dev_masks, dev_spans, TEACHER_PRED_DEV),
)
for _job in _export_jobs:
    export_teacher_preds(*_job)

print("✅ Done. Exports:")
for _exported_path in (TEACHER_PRED_TRAIN, TEACHER_PRED_DEV):
    print("  -", _exported_path)
