# In [None]:
import os
import json
import random
import time
import psutil
import gc
from datetime import datetime
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
)

# ==================== CONFIGURATION ====================
# Encoder checkpoint: default `bert_name` of PromptTuningHSLNModel, loaded with
# AutoModel.from_pretrained.
INLEGALBERT_MODEL_NAME = "law-ai/InLegalBERT"
# JSONL dataset splits (presumably read with load_jsonl further down the file).
TRAIN_PATH = "build_jsonl/build_train.jsonl"
DEV_PATH = "build_jsonl/build_dev.jsonl"
TEST_PATH = "build_jsonl/build_test.jsonl"

# Output directory; created as a side effect of running this module.
OUT_DIR = "prompt_tuning_enhanced_macro_f1"
os.makedirs(OUT_DIR, exist_ok=True)

SEED = 42                # default seed used by set_seed()
MAX_SEQ_LENGTH = 128     # tokenizer max_length per sentence (see collate_docs)
MAX_SENTS_PER_DOC = 64   # default sentence cap per document (see extract_data)
BATCH_DOCS = 4           # documents per DataLoader batch

# ==================== PROMPT CONFIGURATION ====================
# Hard Prompt (prepended as text - provides legal context)
# NOTE(review): in the part of the file visible here the prompt is only used
# for the word-count line of the start-up banner; confirm where it is actually
# prepended to the model input.
HARD_LEGAL_PROMPT = """Indian Legal Judgment Rhetorical Role Classification System:
Task: Classify each sentence into one of 13 legal rhetorical roles.
Classes: PREAMBLE (document introduction), FAC (facts of the case), 
RLC (rulings by lower court), ISSUE (key legal issues), 
ARG_PETITIONER (petitioner arguments), ARG_RESPONDENT (respondent arguments),
ANALYSIS (court's reasoning), STA (statutes cited), 
PRE_RELIED (precedents followed), PRE_NOT_RELIED (precedents distinguished),
RATIO (legal principles established), RPC (rulings by present court), NONE (other).
Context: Indian Supreme Court judgments follow specific rhetorical structure.
Sentence to classify:"""

# Soft Prompt Configuration
NUM_SOFT_PROMPT_TOKENS = 35  # Learnable continuous embeddings
CLASS_PROMPT_TOKENS = 20     # Per-class specific prompts
# NOTE(review): passed to AdvancedPromptTuning as `context_tokens`, but that
# constructor does not reference the value.
CONTEXT_PROMPT_TOKENS = 10   # Context-adaptive prompts

# ==================== TRAINING HYPERPARAMETERS ====================
NUM_EPOCHS = 20
PROMPT_LR = 8e-4      # High LR for prompt parameters
CLASSIFIER_LR = 3e-5  # Low LR for classifier layers
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
WARMUP_RATIO = 0.15

# Loss Configuration
FOCAL_GAMMA = 2.5             # focal_loss focusing exponent (default `gamma`)
FOCAL_ALPHA = 0.25            # focal_loss scalar multiplier (default `alpha`)
LABEL_SMOOTHING = 0.08        # label smoothing passed to F.cross_entropy in focal_loss
PROTO_WEIGHT = 0.18           # default `proto_weight` of Trainer.train
RPL_WEIGHT = 0.08             # default `rpl_weight` of Trainer.train
RTM_LAMBDA = 0.03             # default scale of the RoleTransitionMatrix adjustment
PROTO_AUX_TEMPERATURE = 5.0   # multiplies cosine similarities before cross-entropy

# Minority Class Strategy
MINORITY_BOOST = 6.0  # Aggressive boosting for macro F1
USE_WEIGHTED_SAMPLER = True   # Trainer._build_train_loader uses WeightedRandomSampler when True
USE_POSITIONAL_EMB = True     # add sentence-position embeddings in the model
POS_EMB_DIM = 32              # width of the position embedding before projection
USE_KNN_PRIOR = True          # concatenate projected KNN-prototype similarities
KNN_K = 3                     # default `topk` of knn_prior; input width of the KNN projection
KNN_PRIOR_DIM = 64            # output width of the KNN projection

# Architecture
LSTM_HIDDEN = 512   # per-direction hidden size of the document BiLSTM
DROPOUT = 0.35      # dropout used by the sentence encoder, LSTM and classifier head

# Class names; the list position of each name is its integer class id.
LABELS = [
    "PREAMBLE", "FAC", "RLC", "ISSUE", "ARG_PETITIONER",
    "ARG_RESPONDENT", "ANALYSIS", "STA", "PRE_RELIED",
    "PRE_NOT_RELIED", "RATIO", "RPC", "NONE",
]

# Bidirectional name <-> id lookups derived from LABELS.
label2id = {label: idx for idx, label in enumerate(LABELS)}
id2label = {idx: label for label, idx in label2id.items()}
NUM_LABELS = len(LABELS)

# Classes treated as "minority" for the dedicated metrics and for sampler boosting.
MINORITY_CLASSES = ["RLC", "ISSUE", "STA", "RATIO", "PRE_RELIED", "PRE_NOT_RELIED", "RPC"]
minority_ids = [label2id[label] for label in MINORITY_CLASSES if label in label2id]

# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start-up banner summarising the run configuration.
print("="*80)
print("ðŸš€ ADVANCED PROMPT TUNING FOR LEGAL CLASSIFICATION")
print("="*80)
print(f"Device: {DEVICE}")
print(f"Hard Prompt: {len(HARD_LEGAL_PROMPT.split())} words")
print(f"Soft Prompt Tokens: {NUM_SOFT_PROMPT_TOKENS}")
print(f"Class-Specific Prompts: {CLASS_PROMPT_TOKENS} tokens Ã— {NUM_LABELS} classes")
print(f"Minority Boost Factor: {MINORITY_BOOST}x")
print(f"Target: Macro F1 > 0.59")
print("="*80)

# ==================== UTILITIES ====================
def set_seed(seed=SEED):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator. CUDA generators are
            seeded as well when a CUDA device is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed once at import time so everything below starts from a fixed RNG state.
set_seed()

def load_jsonl(path):
    """Read a JSON Lines file into a list of decoded objects.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        path: Location of the ``.jsonl`` file (read as UTF-8).

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # The comprehension is fully evaluated before the file is closed.
        return [json.loads(text) for text in stripped_lines if text]

def extract_data(docs, max_sents=MAX_SENTS_PER_DOC):
    """Turn raw document dicts into parallel sentence / label-id lists.

    Three document layouts are recognised, checked in this order:

    1. ``{"sentences": [...], "labels": [...]}``
    2. ``{"sentences": [...], "annotation": [...]}``
    3. ``{"annotations": [{"result": [{"value": {"text": ..., "labels": [...]}}]}]}``
       where entries with empty text are skipped and only the first label of
       each entry is used.

    Label names missing from ``label2id`` map to the ``"NONE"`` class.

    Args:
        docs: Iterable of document dicts.
        max_sents: Keep at most this many leading sentences per document.

    Returns:
        ``(all_sents, all_labels, doc_ids)``: per-document sentence lists,
        per-document label-id lists, and the documents' ``"id"`` values
        (``""`` when absent). Documents with no sentences, or whose sentence
        and label counts differ, are dropped.
    """
    none_id = label2id["NONE"]

    def to_id(name):
        return label2id.get(name, none_id)

    all_sents, all_labels, doc_ids = [], [], []
    for doc in docs:
        sents, labs = [], []

        if "sentences" in doc and "labels" in doc:
            sents = doc["sentences"]
            labs = [to_id(name) for name in doc["labels"]]
        elif "sentences" in doc and "annotation" in doc:
            sents = doc["sentences"]
            labs = [to_id(name) for name in doc["annotation"]]
        elif "annotations" in doc:
            # `or` fallbacks tolerate explicit nulls / empty lists in the JSON.
            for annotation in doc.get("annotations") or []:
                for item in annotation.get("result") or []:
                    value = item.get("value") or {}
                    text = (value.get("text") or "").strip()
                    if not text:
                        continue
                    names = value.get("labels") or ["NONE"]
                    sents.append(text)
                    labs.append(to_id(names[0]))

        # Check alignment on the FULL lists. Truncating first would make a
        # document with, say, 70 sentences and 80 labels look aligned (64/64)
        # and silently pair sentences with the wrong labels.
        if len(sents) != len(labs):
            continue

        if len(sents) > max_sents:
            sents = sents[:max_sents]
            labs = labs[:max_sents]

        if sents:
            all_sents.append(sents)
            all_labels.append(labs)
            doc_ids.append(doc.get("id", ""))

    return all_sents, all_labels, doc_ids

def get_memory_usage(device=DEVICE):
    """Report current memory usage in MiB.

    Args:
        device: ``torch.device`` to query. It is only inspected when CUDA is
            available.

    Returns:
        For a CUDA device: a dict with ``'allocated'``, ``'reserved'`` and
        ``'max_allocated'`` as reported by ``torch.cuda``. Otherwise a dict
        with ``'cpu_ram_mb'``, the resident set size of the current process.
    """
    bytes_per_mib = 1024**2
    on_gpu = torch.cuda.is_available() and device.type == 'cuda'
    if not on_gpu:
        resident_bytes = psutil.Process().memory_info().rss
        return {'cpu_ram_mb': resident_bytes / bytes_per_mib}

    gpu_counters = {
        'allocated': torch.cuda.memory_allocated,
        'reserved': torch.cuda.memory_reserved,
        'max_allocated': torch.cuda.max_memory_allocated,
    }
    return {name: read(device) / bytes_per_mib for name, read in gpu_counters.items()}

def compute_detailed_metrics(y_true, y_pred):
    """Compute overall, per-class and minority-class classification metrics.

    Args:
        y_true: Sequence of gold class ids.
        y_pred: Sequence of predicted class ids, same length as ``y_true``.

    Returns:
        A dict of plain Python numbers / dicts:

        * ``accuracy``, ``macro_f1``, ``weighted_f1``
        * ``macro_precision`` / ``macro_recall``: unweighted means over all
          ``NUM_LABELS`` classes
        * ``minority_macro_f1``: macro F1 restricted to the samples whose gold
          label is in ``minority_ids`` (0.0 when there are none)
        * ``minority_f1_per_class``: one-vs-rest F1 of each minority class
        * ``per_class_precision`` / ``per_class_recall`` / ``per_class_f1`` /
          ``per_class_support``: keyed by label name
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    all_class_ids = list(range(NUM_LABELS))
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0, labels=all_class_ids
    )

    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    weighted_f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    # Minority class metrics: scored only on samples whose gold label is a
    # minority class.
    minority_mask = np.isin(y_true, minority_ids)
    if minority_mask.any():
        minority_f1 = f1_score(
            y_true[minority_mask], y_pred[minority_mask],
            average='macro', zero_division=0,
        )
    else:
        minority_f1 = 0.0

    # Per-class minority F1 is the one-vs-rest F1 already computed above.
    # The previous implementation called f1_score(..., average='binary') on
    # class-id labels, which raises (multiclass target / pos_label=1 missing);
    # a bare `except` swallowed that, so every class was reported as 0.0.
    minority_f1_per_class = {
        id2label[cls_id]: float(f1[cls_id]) for cls_id in minority_ids
    }

    return {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'macro_f1': float(macro_f1),
        'weighted_f1': float(weighted_f1),
        'macro_precision': float(np.mean(precision)),
        'macro_recall': float(np.mean(recall)),
        'minority_macro_f1': float(minority_f1),
        'minority_f1_per_class': minority_f1_per_class,
        'per_class_precision': {id2label[i]: float(precision[i]) for i in all_class_ids},
        'per_class_recall': {id2label[i]: float(recall[i]) for i in all_class_ids},
        'per_class_f1': {id2label[i]: float(f1[i]) for i in all_class_ids},
        'per_class_support': {id2label[i]: int(support[i]) for i in all_class_ids},
    }

# ==================== PROMPT TUNING MODULE ====================
class AdvancedPromptTuning(nn.Module):
    """
    Prompt-based refinement of sentence embeddings.

    Components:
    1. A bank of global soft prompts shared by every input.
    2. One bank of soft prompts per class, mixed with weights predicted
       from the mean sentence embedding of each document.
    3. Dot-product attention from sentences to the combined prompt bank,
       followed by a fusion layer and a sigmoid-gated blend with the input.

    The ``context_tokens`` constructor argument is accepted but not used.
    """
    def __init__(self, hidden_size, num_soft_tokens=NUM_SOFT_PROMPT_TOKENS,
                 class_tokens=CLASS_PROMPT_TOKENS, context_tokens=CONTEXT_PROMPT_TOKENS):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_soft_tokens = num_soft_tokens
        self.class_tokens = class_tokens

        # Small-magnitude random initialisation for both prompt banks.
        init_scale = 0.02
        self.global_soft_prompt = nn.Parameter(
            init_scale * torch.randn(1, num_soft_tokens, hidden_size)
        )
        self.class_specific_prompts = nn.Parameter(
            init_scale * torch.randn(NUM_LABELS, class_tokens, hidden_size)
        )

        # Maps a document summary vector to a distribution over classes,
        # used as mixing weights for the class-specific prompt banks.
        selector_layers = [
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, NUM_LABELS),
            nn.Softmax(dim=-1),
        ]
        self.prompt_selector = nn.Sequential(*selector_layers)

        # Merges [sentence ; prompt context] back down to hidden_size.
        fusion_layers = [
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.15),
        ]
        self.prompt_fusion = nn.Sequential(*fusion_layers)

        # Per-feature gate deciding how much of the fused signal to keep.
        self.prompt_gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid(),
        )

        total_params = sum(p.numel() for p in self.parameters())
        print("âœ… Prompt Tuning Module Initialized:")
        print(f"   - Global Soft Prompts: {num_soft_tokens} tokens")
        print(f"   - Class-Specific Prompts: {class_tokens} Ã— {NUM_LABELS} classes")
        print(f"   - Total Prompt Parameters: {total_params:,}")

    def forward(self, sent_emb):
        """
        Args:
            sent_emb: [B, S, H] sentence embeddings
        Returns:
            [B, S, H] prompt-enhanced embeddings
        """
        batch_size, _, hidden = sent_emb.shape

        # Class mixing weights from the mean sentence embedding: [B, NUM_LABELS]
        doc_summary = sent_emb.mean(dim=1)
        mixing = self.prompt_selector(doc_summary)

        # Weighted sum of the class prompt banks: [B, class_tokens, H]
        blended_class_prompts = torch.einsum(
            'bc,cth->bth', mixing, self.class_specific_prompts
        )

        # Global prompts broadcast over the batch, then stacked with the
        # class prompts: [B, num_soft_tokens + class_tokens, H]
        shared_prompts = self.global_soft_prompt.expand(batch_size, -1, -1)
        prompt_bank = torch.cat([shared_prompts, blended_class_prompts], dim=1)

        # Scaled dot-product attention from each sentence to every prompt.
        raw_scores = torch.bmm(sent_emb, prompt_bank.transpose(1, 2))
        attention = F.softmax(raw_scores / np.sqrt(hidden), dim=-1)
        prompt_ctx = torch.bmm(attention, prompt_bank)  # [B, S, H]

        # Fuse, then blend with the untouched input through the gate.
        fused = self.prompt_fusion(torch.cat([sent_emb, prompt_ctx], dim=-1))
        gate = self.prompt_gate(sent_emb)
        return gate * fused + (1 - gate) * sent_emb

# ==================== PROTOTYPE MANAGER ====================
class ClassPrototypeManager:
    """Holds one mean-embedding prototype per class and answers similarity queries."""

    def __init__(self):
        self.prototypes = None
        self.fitted = False

    def _require_fitted(self):
        """Raise RuntimeError unless fit() has been called."""
        if not self.fitted:
            raise RuntimeError("Prototypes not fitted")

    def _cosine_to_prototypes(self, embeddings, eps):
        """Return the [N, NUM_LABELS] cosine similarity of each row to each prototype."""
        queries = np.asarray(embeddings)
        query_norms = np.linalg.norm(queries, axis=1, keepdims=True)
        proto_norms = np.linalg.norm(self.prototypes, axis=1, keepdims=True)
        unit_queries = queries / (query_norms + eps)
        unit_protos = self.prototypes / (proto_norms + eps)
        return unit_queries @ unit_protos.T

    def fit(self, embeddings, labels):
        """Compute per-class mean embeddings.

        Args:
            embeddings: 2-D array-like, one embedding per row.
            labels: Class id of each row.

        Classes with no examples receive a tiny random vector instead of a mean.
        """
        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)
        dim = embeddings.shape[1]
        protos = np.zeros((NUM_LABELS, dim), dtype=np.float32)

        for class_id in range(NUM_LABELS):
            members = labels == class_id
            protos[class_id] = (
                embeddings[members].mean(axis=0)
                if members.sum() > 0
                else np.random.randn(dim).astype(np.float32) * 1e-3
            )

        self.prototypes = protos
        self.fitted = True
        print(f"[Prototypes] Fitted {NUM_LABELS} class prototypes, shape: {protos.shape}")

    def get_all_tensor(self, device=None):
        """Return the prototypes as a float32 tensor, moved to ``device`` if given."""
        self._require_fitted()
        as_tensor = torch.tensor(self.prototypes, dtype=torch.float32)
        return as_tensor if device is None else as_tensor.to(device)

    def get_nearest_index(self, embeddings):
        """Return, for each row, the class id of the most cosine-similar prototype."""
        self._require_fitted()
        return self._cosine_to_prototypes(embeddings, 1e-8).argmax(axis=1)

    def knn_prior(self, embeddings, topk=KNN_K):
        """Return the ``topk`` highest prototype similarities per row.

        Returns:
            ``(topk_sims, topk_idx)``, both of shape [N, topk], ordered from
            most to least similar.
        """
        self._require_fitted()
        sims = self._cosine_to_prototypes(embeddings, 1e-12)
        topk_idx = np.argsort(-sims, axis=1)[:, :topk]
        return np.take_along_axis(sims, topk_idx, axis=1), topk_idx

# ==================== DATASET ====================
class DocumentTextDataset(Dataset):
    """Dataset whose items are whole documents: raw sentences plus label ids."""

    def __init__(self, docs_sents, docs_labels):
        # Parallel per-document lists; kept by reference, not copied.
        self.docs_sents, self.docs_labels = docs_sents, docs_labels

    def __len__(self):
        """Number of documents."""
        return len(self.docs_sents)

    def __getitem__(self, idx):
        """Return ``{"sentences": list, "labels": LongTensor}`` for document ``idx``."""
        sentences = self.docs_sents[idx]
        label_ids = torch.tensor(self.docs_labels[idx], dtype=torch.long)
        return {"sentences": sentences, "labels": label_ids}

def collate_docs(batch, tokenizer):
    """Collate document items into padded [doc, sentence, token] tensors.

    All sentences of the batch are tokenized in a single call (padded to the
    longest sentence, truncated at ``MAX_SEQ_LENGTH``) and then scattered back
    into per-document slots.

    Args:
        batch: List of items with ``"sentences"`` (list of str) and
            ``"labels"`` (1-D LongTensor).
        tokenizer: Callable returning ``"input_ids"`` and ``"attention_mask"``
            tensors when invoked with ``return_tensors="pt"``.

    Returns:
        ``(input_ids, attention_mask, labels, lengths)``. The first two are
        [B, max_sents, max_tokens] and zero-padded; ``labels`` is
        [B, max_sents] padded with -100; ``lengths`` holds each document's
        sentence count.
    """
    doc_sizes = [len(item["sentences"]) for item in batch]
    longest_doc = max(doc_sizes)
    num_docs = len(batch)

    every_sentence = [sent for item in batch for sent in item["sentences"]]
    encoded = tokenizer(
        every_sentence,
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors="pt",
    )
    token_ids, token_mask = encoded["input_ids"], encoded["attention_mask"]
    width = token_ids.shape[1]

    grid_shape = (num_docs, longest_doc, width)
    input_ids_padded = torch.zeros(grid_shape, dtype=torch.long)
    attn_mask_padded = torch.zeros(grid_shape, dtype=torch.long)
    # -100 marks padded sentence slots so the losses can ignore them.
    labels_padded = torch.full((num_docs, longest_doc), -100, dtype=torch.long)
    lengths = torch.tensor(doc_sizes, dtype=torch.long)

    cursor = 0  # index of the current document's first row in the flat encoding
    for row, (item, size) in enumerate(zip(batch, doc_sizes)):
        stop = cursor + size
        input_ids_padded[row, :size] = token_ids[cursor:stop]
        attn_mask_padded[row, :size] = token_mask[cursor:stop]
        labels_padded[row, :size] = item["labels"]
        cursor = stop

    return input_ids_padded, attn_mask_padded, labels_padded, lengths

# ==================== LOSS FUNCTIONS ====================
def focal_loss(logits_masked, labels_masked, gamma=FOCAL_GAMMA, alpha=FOCAL_ALPHA, label_smoothing=LABEL_SMOOTHING):
    """Mean focal loss built on label-smoothed cross-entropy.

    Args:
        logits_masked: [N, C] logits of the valid positions.
        labels_masked: [N] target class ids.
        gamma: Focusing exponent; larger values down-weight easy examples more.
        alpha: Scalar multiplier applied to every example.
        label_smoothing: Passed through to ``F.cross_entropy``.

    Returns:
        Scalar tensor.
    """
    per_example_ce = F.cross_entropy(
        logits_masked,
        labels_masked,
        reduction='none',
        label_smoothing=label_smoothing,
    )
    # exp(-CE) acts as the model's confidence in the target.
    confidence = torch.exp(-per_example_ce)
    modulation = (1 - confidence) ** gamma
    return (alpha * modulation * per_example_ce).mean()

def prototypical_cosine_loss(reprs, prototypes_tensor, labels, temperature=PROTO_AUX_TEMPERATURE):
    """Cross-entropy over temperature-scaled cosine similarities to prototypes.

    Args:
        reprs: [N, D] representations.
        prototypes_tensor: [C, D] class prototypes.
        labels: [N] target class ids.
        temperature: Factor the cosine similarities are multiplied by.

    Returns:
        ``(loss, sims)``: the scalar loss and the scaled [N, C] similarities.
    """
    unit_reprs = F.normalize(reprs, p=2, dim=1)
    unit_protos = F.normalize(prototypes_tensor, p=2, dim=1)
    sims = torch.matmul(unit_reprs, unit_protos.t()) * temperature
    return F.cross_entropy(sims, labels), sims

def compute_losses(logits, labels, sent_emb_flat, doc_out, prototypes_tensor, rpl_proto, temperature=PROTO_AUX_TEMPERATURE):
    """Compute the three training losses over the non-padded sentences.

    Args:
        logits: [B, S, NUM_LABELS] classifier outputs.
        labels: [B, S] class ids, with -100 marking positions to ignore.
        sent_emb_flat: [B*S, D] sentence embeddings, row-aligned with the
            flattened labels.
        doc_out: [B, S, D'] document-level sentence representations.
        prototypes_tensor: [NUM_LABELS, D] prototypes for the sentence
            embeddings.
        rpl_proto: [NUM_LABELS, D'] prototypes for ``doc_out``.
        temperature: Scale applied to the cosine similarities of BOTH
            prototype-based losses.

    Returns:
        ``(ce_loss, proto_loss, rpl_loss)`` scalar tensors. When every
        position is ignored, all three are zeros that remain attached to
        ``logits`` in the autograd graph, so calling ``backward()`` on their
        sum is still valid.
    """
    logits_flat = logits.reshape(-1, NUM_LABELS)
    labels_flat = labels.reshape(-1)
    mask = labels_flat != -100

    if not mask.any():
        # Sum over an empty slice: exactly 0, NaN-safe, and graph-connected.
        zero = logits_flat[:0].sum()
        return zero, zero.clone(), zero.clone()

    logits_masked = logits_flat[mask]
    labels_masked = labels_flat[mask]

    # 1. Focal Cross-Entropy Loss
    ce_loss = focal_loss(logits_masked, labels_masked)

    # 2. Prototypical Loss (forward `temperature`; it used to be dropped here,
    #    so a caller-supplied value only affected the RPL loss).
    valid_sent_emb = sent_emb_flat[mask]
    proto_loss, _ = prototypical_cosine_loss(
        valid_sent_emb, prototypes_tensor, labels_masked, temperature=temperature
    )

    # 3. Role Prototypical Layer Loss
    valid_doc_out = doc_out.reshape(-1, doc_out.size(-1))[mask]
    rpl_sim = F.normalize(valid_doc_out, dim=1) @ F.normalize(rpl_proto, dim=1).T
    rpl_loss = F.cross_entropy(rpl_sim * temperature, labels_masked)

    return ce_loss, proto_loss, rpl_loss

# ==================== MODEL COMPONENTS ====================
class RolePrototypicalLayer(nn.Module):
    """Classification head scoring inputs by cosine similarity to learnable class prototypes."""

    def __init__(self, hidden_dim):
        super().__init__()
        # One learnable prototype per class, initialised with small noise.
        initial = torch.randn(NUM_LABELS, hidden_dim) * 0.02
        self.proto = nn.Parameter(initial)

    def forward(self, h):
        """Return cosine similarities of shape [..., NUM_LABELS] for ``h`` of shape [..., hidden_dim]."""
        unit_h, unit_proto = (F.normalize(t, dim=-1) for t in (h, self.proto))
        return torch.matmul(unit_h, unit_proto.T)

class RoleTransitionMatrix(nn.Module):
    """Adds a learned role-to-role transition bias to per-sentence logits.

    ``A[i, j]`` is an unnormalised score for moving from role ``i`` to role
    ``j``; each row is log-softmax-normalised before use.
    """
    def __init__(self, rtm_lambda=RTM_LAMBDA):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(NUM_LABELS, NUM_LABELS))
        self.rtm_lambda = rtm_lambda

    def forward(self, logits):
        """Bias the logits at step ``t`` by the expected transition from step ``t-1``.

        Args:
            logits: [B, S, C] per-sentence class logits. The tensor is NOT
                modified; earlier versions updated it in place, which also
                changed the caller's tensor.

        Returns:
            [B, S, C] adjusted logits. Step 0 is unchanged; for ``t >= 1``
            the adjustment is ``rtm_lambda * logsumexp_i(log p_{t-1}(i) +
            log_softmax(A)[i, :])``, where ``p`` comes from the original
            (unadjusted) logits. With fewer than two steps the input is
            returned as is.
        """
        _, num_steps, _ = logits.shape
        if num_steps < 2:
            return logits

        log_probs = logits.log_softmax(-1)
        log_trans = self.A.log_softmax(-1)  # [C_prev, C_next], computed once

        # [B, S-1, C_prev, 1] + [C_prev, C_next] -> reduce over C_prev
        expected = torch.logsumexp(
            log_probs[:, :-1].unsqueeze(-1) + log_trans, dim=2
        )  # [B, S-1, C_next]

        adjusted_tail = logits[:, 1:] + self.rtm_lambda * expected
        return torch.cat([logits[:, :1], adjusted_tail], dim=1)

class SentenceEncoderFFN(nn.Module):
    """Residual two-layer MLP (Linear -> ReLU -> Dropout -> Linear) with post-LayerNorm."""

    def __init__(self, sent_dim, hidden=512, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(sent_dim, hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden, sent_dim)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(sent_dim)

    def forward(self, x):
        """Return ``LayerNorm(x + MLP(x))``; the output has the same shape as ``x``."""
        residual = x
        for stage in (self.fc1, self.act, self.dropout, self.fc2):
            x = stage(x)
        return self.ln(residual + x)

class PrototypeAttention(nn.Module):
    """Soft attention from sentence embeddings over a set of class prototypes."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Bias-free projection applied to the sentences before scoring.
        self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, sent_emb, prototypes):
        """
        Args:
            sent_emb: [B, S, H] sentence embeddings (must be 3-D).
            prototypes: [C, H] prototype matrix.
        Returns:
            (proto_ctx [B, S, H], attn_weights [B, S, C])
        """
        _batch, _num_sents, _hidden = sent_emb.shape  # rejects non-3-D input
        logits = self.W(sent_emb) @ prototypes.t()
        attn_weights = logits.softmax(dim=-1)
        return attn_weights @ prototypes, attn_weights

# ==================== MAIN MODEL ====================
class PromptTuningHSLNModel(nn.Module):
    """
    Prompt Tuning + Hierarchical Sentence-Level Network
    with Role Prototypical Layer and Transition Modeling.

    Pipeline: frozen BERT sentence encoder -> prompt tuning -> sentence FFN
    -> (+ positional embedding) -> (+ prototype attention context)
    -> (concat KNN prior features) -> 2-layer BiLSTM over sentences
    -> blend of an MLP head and a prototype head -> transition adjustment.
    """
    def __init__(self, bert_name=INLEGALBERT_MODEL_NAME, pos_dim=POS_EMB_DIM,
                 use_pos_emb=USE_POSITIONAL_EMB, use_knn_prior=USE_KNN_PRIOR,
                 knn_prior_dim=KNN_PRIOR_DIM, doc_hidden=LSTM_HIDDEN, dropout=DROPOUT):
        super().__init__()

        # Frozen BERT (no LoRA/PEFT): every encoder weight is excluded from training.
        self.bert = AutoModel.from_pretrained(bert_name)
        for param in self.bert.parameters():
            param.requires_grad = False

        print("\nâœ… BERT Model: COMPLETELY FROZEN (no gradients)")
        print("âœ… Training Strategy: PURE PROMPT TUNING")

        self.hidden_size = self.bert.config.hidden_size

        # Trainable prompt tuning module applied to the pooled sentence embeddings.
        self.prompt_tuning = AdvancedPromptTuning(self.hidden_size)

        # Positional Embeddings (sentence index within the document)
        self.use_pos_emb = use_pos_emb
        if use_pos_emb and pos_dim > 0:
            self.pos_emb = nn.Embedding(1024, pos_dim)
            self.pos_proj = nn.Linear(pos_dim, self.hidden_size) if pos_dim != self.hidden_size else None
        else:
            self.pos_emb = None
            self.pos_proj = None

        # KNN Prior Features: projects the top-KNN_K prototype similarities.
        self.use_knn_prior = use_knn_prior
        self.knn_prior_dim = knn_prior_dim
        if use_knn_prior and knn_prior_dim > 0:
            self.knn_proj = nn.Sequential(
                nn.Linear(KNN_K, 64),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(64, knn_prior_dim),
            )
        else:
            self.knn_proj = None

        # Sentence Encoder
        self.sent_encoder = SentenceEncoderFFN(self.hidden_size, hidden=512, dropout=dropout)

        # Prototype Attention
        self.proto_attn = PrototypeAttention(self.hidden_size)

        # Document-Level LSTM; its input is widened by the KNN features when enabled.
        final_in_dim = self.hidden_size + (self.knn_prior_dim if self.knn_proj else 0)
        self.doc_lstm = nn.LSTM(final_in_dim, doc_hidden, 2, bidirectional=True,
                               batch_first=True, dropout=dropout)

        lstm_out_dim = doc_hidden * 2

        # Classification Heads
        self.ce_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(lstm_out_dim, doc_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(doc_hidden, NUM_LABELS),
        )

        self.rpl = RolePrototypicalLayer(lstm_out_dim)
        self.rtm = RoleTransitionMatrix()
        # sigmoid(head_alpha) is the weight of the MLP head in the blend.
        self.head_alpha = nn.Parameter(torch.tensor(2.0))

    def encode_sentences(self, input_ids, attention_mask):
        """Encode sentences with the frozen BERT and mean-pool over real tokens.

        Padding positions (``attention_mask == 0``) are excluded from the
        average, so a sentence's embedding no longer depends on how much
        padding the batch happens to add. Sentences with no real tokens (the
        padded sentence slots of short documents) become all-zero vectors.

        NOTE(review): prototypes fitted on embeddings produced with a
        different pooling should be re-fitted to stay consistent.

        Args:
            input_ids: [B, S, T] token ids.
            attention_mask: [B, S, T] with 1 for real tokens, 0 for padding.

        Returns:
            ``(sent_emb [B, S, H], sent_emb_flat [B*S, H])``.
        """
        B, S, T = input_ids.shape
        input_ids_flat = input_ids.view(B * S, T)
        attn_flat = attention_mask.view(B * S, T)

        with torch.no_grad():  # BERT is frozen
            outputs = self.bert(input_ids_flat, attention_mask=attn_flat)

        hidden = outputs.last_hidden_state  # [B*S, T, H]
        is_pad = (attn_flat == 0).unsqueeze(-1)
        # masked_fill (rather than multiplying by the mask) also neutralises
        # any non-finite values at padded positions.
        token_sum = hidden.masked_fill(is_pad, 0.0).sum(dim=1)
        token_count = attn_flat.sum(dim=1, keepdim=True).clamp(min=1).to(hidden.dtype)
        sent_emb_flat = token_sum / token_count

        sent_emb = sent_emb_flat.view(B, S, -1)
        return sent_emb, sent_emb_flat

    def forward(self, input_ids, attention_mask, lengths, prototypes_tensor,
                proto_idx_batch=None, knn_sims=None):
        """Classify every sentence of every document in the batch.

        Args:
            input_ids: [B, S, T] token ids.
            attention_mask: [B, S, T] token mask.
            lengths: [B] number of real sentences per document (each >= 1).
            prototypes_tensor: [NUM_LABELS, H] class prototypes (tensor or
                array-like); converted to float32 on the input device.
            proto_idx_batch: Unused; kept for interface compatibility.
            knn_sims: Optional tensor reshapeable to [B*S, KNN_K] with the
                prototype similarities. When the KNN projection is enabled
                but this is omitted, zero features are used instead.

        Returns:
            ``(final_logits [B, S, NUM_LABELS], sent_emb_flat [B*S, H],
            doc_out [B, S, 2*doc_hidden])``. ``sent_emb_flat`` is the pooled
            BERT output before prompt tuning.
        """
        B, S, _ = input_ids.shape
        device = input_ids.device

        # 1. Encode Sentences (frozen BERT)
        sent_emb, sent_emb_flat = self.encode_sentences(input_ids, attention_mask)

        # 2. Apply prompt tuning (trainable), then the sentence FFN
        sent_emb = self.prompt_tuning(sent_emb)
        sent_emb = self.sent_encoder(sent_emb)

        # 3. Add Positional Embeddings
        if self.pos_emb is not None:
            pos_idx = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
            pos_vec = self.pos_emb(pos_idx)
            if self.pos_proj is not None:
                pos_vec = self.pos_proj(pos_vec)
            sent_emb = sent_emb + pos_vec

        # 4. Prototype Attention (accept tensors on any device, or array-likes)
        protos = torch.as_tensor(prototypes_tensor, dtype=torch.float32, device=device)
        proto_ctx, _ = self.proto_attn(sent_emb, protos)
        sent_emb = sent_emb + proto_ctx

        # 5. Add KNN Prior Features. The LSTM input width always includes the
        #    KNN slot when the projection exists, so a missing `knn_sims` is
        #    filled with zeros instead of failing on a width mismatch.
        if self.knn_proj is not None:
            if knn_sims is not None:
                knn_feat = self.knn_proj(knn_sims.reshape(-1, KNN_K)).view(B, S, -1)
            else:
                knn_feat = sent_emb.new_zeros(B, S, self.knn_prior_dim)
            doc_in = torch.cat([sent_emb, knn_feat], dim=2)
        else:
            doc_in = sent_emb

        # 6. Document-Level LSTM. total_length=S keeps the output aligned with
        #    the [B, S] labels even if no document in the batch spans all S slots.
        packed = nn.utils.rnn.pack_padded_sequence(
            doc_in, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.doc_lstm(packed)
        doc_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=S
        )

        # 7. Multi-Head Classification
        ce_logits = self.ce_classifier(doc_out)
        rpl_logits = self.rpl(doc_out)
        alpha = torch.sigmoid(self.head_alpha)
        blended_logits = alpha * ce_logits + (1 - alpha) * rpl_logits

        # 8. Role Transition Modeling
        final_logits = self.rtm(blended_logits)

        return final_logits, sent_emb_flat, doc_out

# ==================== TRAINER ====================
class Trainer:
    """Drive training and evaluation of the prompt-tuned document model.

    The trainer holds the model (moved onto ``device``), the tokenizer that
    ``collate_docs`` needs, and a prototype manager that provides the class
    prototype tensor plus per-sentence KNN similarity features.
    """

    def __init__(self, model, tokenizer, prototype_manager, device=DEVICE):
        """Store the collaborators and move ``model`` onto ``device``."""
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.prototype_manager = prototype_manager
        self.device = device

    def _collate(self, batch):
        """Collate a list of documents using the trainer's tokenizer."""
        return collate_docs(batch, self.tokenizer)

    @torch.no_grad()
    def _knn_features(self, input_ids, attn_mask):
        """Compute prototype-based features for one batch without gradients.

        Args:
            input_ids: Token ids for the batch; the first two dimensions are
                used as the (document, sentence) layout of the returned
                index tensor.
            attn_mask: Attention mask matching ``input_ids``.

        Returns:
            ``(knn_sims, proto_idx)`` where ``knn_sims`` is a float32 tensor
            built from ``prototype_manager.knn_prior`` and ``proto_idx`` is a
            long tensor of nearest-prototype indices reshaped to
            ``input_ids.shape[:2]``. Both live on ``self.device``.
        """
        _, flat_emb = self.model.encode_sentences(input_ids, attn_mask)
        flat_emb_np = flat_emb.cpu().numpy()
        sims, _ = self.prototype_manager.knn_prior(flat_emb_np, topk=KNN_K)
        nearest = self.prototype_manager.get_nearest_index(flat_emb_np)
        knn_sims = torch.tensor(sims, dtype=torch.float32, device=self.device)
        proto_idx = torch.tensor(
            nearest, dtype=torch.long, device=self.device
        ).view(input_ids.shape[:2])
        return knn_sims, proto_idx

    def _build_train_loader(self, train_dataset):
        """Build the training ``DataLoader``.

        When ``USE_WEIGHTED_SAMPLER`` is false a plain shuffled loader is
        returned. Otherwise every document is weighted by the inverse
        frequency of its majority label (documents without labels count as
        label 0), and documents whose majority label is in ``minority_ids``
        are additionally multiplied by ``MINORITY_BOOST``. Sampling is done
        with replacement, one epoch drawing ``len(dataset)`` documents.
        """
        if not USE_WEIGHTED_SAMPLER:
            return DataLoader(train_dataset, batch_size=BATCH_DOCS, shuffle=True,
                              collate_fn=self._collate)

        # Majority label per document; empty documents fall back to label 0.
        major_labels = [
            Counter(doc_labels).most_common(1)[0][0] if doc_labels else 0
            for doc_labels in train_dataset.docs_labels
        ]
        major_idx = np.asarray(major_labels, dtype=np.int64)

        # Inverse document frequency per label; the epsilon avoids a division
        # by zero for labels that are never a document's majority label.
        counts = np.bincount(major_idx, minlength=NUM_LABELS)
        inv_freq = 1.0 / (counts + 1e-6)

        # Fancy indexing returns a fresh array, so it is safe to scale it.
        weights = inv_freq[major_idx]
        is_minority = np.array(
            [label in minority_ids for label in major_labels], dtype=bool
        )
        weights[is_minority] *= MINORITY_BOOST

        sampler = WeightedRandomSampler(
            torch.tensor(weights, dtype=torch.double),
            num_samples=len(weights),
            replacement=True,
        )
        return DataLoader(train_dataset, batch_size=BATCH_DOCS, sampler=sampler,
                          collate_fn=self._collate)

    def train(self, train_dataset, dev_dataset, num_epochs=NUM_EPOCHS,
              proto_weight=PROTO_WEIGHT, rpl_weight=RPL_WEIGHT):
        """Train the model and keep the checkpoint with the best dev macro-F1.

        Args:
            train_dataset: Dataset used to build the training loader.
            dev_dataset: Dataset evaluated after every epoch.
            num_epochs: Number of passes over the training loader.
            proto_weight: Weight of the prototype loss in the total loss.
            rpl_weight: Weight of the RPL loss in the total loss.

        Returns:
            ``(best_ckpt, total_train_time)``: path of the best checkpoint
            (``None`` if none was saved) and the wall-clock training time in
            seconds. The per-epoch history is also written to
            ``training_history.csv`` inside ``OUT_DIR``.
        """
        start_time = time.time()

        # Two optimizers so the prompt parameters and the remaining trainable
        # parameters can use different learning rates.
        prompt_params = list(self.model.prompt_tuning.parameters())
        other_params = [
            param for name, param in self.model.named_parameters()
            if param.requires_grad and 'prompt_tuning' not in name
        ]
        prompt_optimizer = torch.optim.AdamW(prompt_params, lr=PROMPT_LR, weight_decay=WEIGHT_DECAY)
        other_optimizer = torch.optim.AdamW(other_params, lr=CLASSIFIER_LR, weight_decay=WEIGHT_DECAY)

        train_loader = self._build_train_loader(train_dataset)
        total_steps = max(1, len(train_loader) * num_epochs)
        warmup_steps = max(1, int(WARMUP_RATIO * total_steps))

        prompt_scheduler = get_linear_schedule_with_warmup(prompt_optimizer, warmup_steps, total_steps)
        other_scheduler = get_linear_schedule_with_warmup(other_optimizer, warmup_steps, total_steps)

        prototypes_tensor = self.prototype_manager.get_all_tensor(device=self.device)
        best_macro_f1 = -1.0
        best_ckpt = None
        history = []

        print("\n" + "="*80)
        print("ðŸš€ STARTING PROMPT TUNING TRAINING")
        print("="*80)
        print("Training Strategy:")
        print(f"  - Prompt LR: {PROMPT_LR} (aggressive)")
        print(f"  - Classifier LR: {CLASSIFIER_LR} (conservative)")
        print(f"  - Minority Boost: {MINORITY_BOOST}x")
        print(f"  - Focal Loss Gamma: {FOCAL_GAMMA}")
        print(f"  - Total Epochs: {num_epochs}")
        print("="*80 + "\n")

        for epoch in range(1, num_epochs + 1):
            self.model.train()
            t0 = time.time()
            running_ce = running_proto = running_rpl = running_total = 0.0
            n_samples = 0

            for input_ids, attn_mask, labels, lengths in train_loader:
                input_ids = input_ids.to(self.device)
                attn_mask = attn_mask.to(self.device)
                labels = labels.to(self.device)
                lengths = lengths.to(self.device)

                # KNN / nearest-prototype features are computed without
                # gradients and fed to the model as extra inputs.
                knn_sims_tensor, proto_idx_batch = self._knn_features(input_ids, attn_mask)

                # Forward pass
                logits, sent_emb_flat, doc_out = self.model(
                    input_ids, attn_mask, lengths, prototypes_tensor,
                    proto_idx_batch, knn_sims_tensor
                )

                # Compute losses
                ce_loss, proto_loss, rpl_loss = compute_losses(
                    logits, labels, sent_emb_flat, doc_out,
                    prototypes_tensor, self.model.rpl.proto
                )
                total_loss = ce_loss + proto_weight * proto_loss + rpl_weight * rpl_loss

                # Backward pass; both optimizers share one backward and one
                # gradient clip over all model parameters.
                prompt_optimizer.zero_grad()
                other_optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), GRAD_CLIP)
                prompt_optimizer.step()
                other_optimizer.step()
                prompt_scheduler.step()
                other_scheduler.step()

                # Weight the running losses by the number of labelled
                # sentences (label != -100) in the batch.
                n = (labels.view(-1) != -100).sum().item()
                running_ce += ce_loss.item() * n
                running_proto += proto_loss.item() * n
                running_rpl += rpl_loss.item() * n
                running_total += total_loss.item() * n
                n_samples += n

            # Validation
            val_results = self.evaluate(dev_dataset, measure_time=True)
            val_metrics = val_results["metrics"]
            denom = max(1, n_samples)
            avg_ce = running_ce / denom
            avg_proto = running_proto / denom
            avg_rpl = running_rpl / denom
            avg_total = running_total / denom

            epoch_time = time.time() - t0
            mem_usage = get_memory_usage(self.device)

            # Log history
            history.append({
                "epoch": epoch,
                "train_ce": avg_ce,
                "train_proto": avg_proto,
                "train_rpl": avg_rpl,
                "train_total": avg_total,
                "val_acc": val_metrics["accuracy"],
                "val_macro_f1": val_metrics["macro_f1"],
                "val_minority_f1": val_metrics["minority_macro_f1"],
                "epoch_time_s": epoch_time,
                "mem_allocated_mb": mem_usage.get('allocated', 0),
                "total_time_s": time.time() - start_time,
            })

            # Print progress
            print(f"Epoch {epoch:02d}/{num_epochs} | "
                  f"Train Loss: {avg_total:.4f} (CE:{avg_ce:.4f} Proto:{avg_proto:.4f} RPL:{avg_rpl:.4f}) | "
                  f"Val Acc: {val_metrics['accuracy']:.4f} | "
                  f"Val Macro-F1: {val_metrics['macro_f1']:.4f} | "
                  f"Val Minority-F1: {val_metrics['minority_macro_f1']:.4f} | "
                  f"Time: {epoch_time:.1f}s | "
                  f"Mem: {mem_usage.get('allocated', 0):.0f}MB")

            # Save best checkpoint; the new file is written before the old
            # one is removed so a best checkpoint always exists on disk.
            if val_metrics["macro_f1"] > best_macro_f1 + 1e-5:
                best_macro_f1 = val_metrics["macro_f1"]
                ckpt_path = os.path.join(OUT_DIR, f"best_epoch{epoch}_macroF1_{best_macro_f1:.4f}.pt")
                torch.save({
                    "model_state_dict": self.model.state_dict(),
                    "epoch": epoch,
                    "macro_f1": best_macro_f1,
                    "metrics": val_metrics,
                }, ckpt_path)
                if best_ckpt and os.path.exists(best_ckpt):
                    os.remove(best_ckpt)
                best_ckpt = ckpt_path
                print(f"  ðŸ’¾ Saved new best checkpoint: {ckpt_path}")

        total_train_time = time.time() - start_time

        # Save training history
        history_df = pd.DataFrame(history)
        history_df.to_csv(os.path.join(OUT_DIR, "training_history.csv"), index=False)

        print("\n" + "="*80)
        print("âœ… TRAINING COMPLETED")
        print(f"   Total Time: {total_train_time/60:.2f} minutes")
        print(f"   Best Macro F1: {best_macro_f1:.4f}")
        print(f"   Best Checkpoint: {best_ckpt}")
        print("="*80 + "\n")

        return best_ckpt, total_train_time

    @torch.no_grad()
    def evaluate(self, dataset, measure_time=False, batch_size=8):
        """Evaluate the model on ``dataset`` without gradient tracking.

        Args:
            dataset: Dataset to evaluate; batches are built by ``collate_docs``.
            measure_time: When true, add ``inference_time`` (seconds) and
                ``mem_usage`` to the result.
            batch_size: Number of documents per evaluation batch.

        Returns:
            A dict with ``metrics`` (from ``compute_detailed_metrics``),
            ``total_loss`` (sample-weighted mean focal loss),
            ``classification_report``, ``all_preds``, ``all_trues`` and
            ``sample_count``. Sentences labelled -100 are ignored. If no
            labelled sentence is found, a zero-valued result with the same
            top-level keys is returned; its ``metrics`` only holds
            ``accuracy``, ``macro_f1`` and ``minority_macro_f1``.
        """
        start_time = time.time()
        self.model.eval()

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=self._collate)
        prototypes_tensor = self.prototype_manager.get_all_tensor(device=self.device)

        all_preds, all_trues = [], []
        running_ce = 0.0
        n_samples = 0

        for input_ids, attn_mask, labels, lengths in loader:
            input_ids = input_ids.to(self.device)
            attn_mask = attn_mask.to(self.device)
            labels = labels.to(self.device)
            lengths = lengths.to(self.device)

            knn_sims_tensor, proto_idx_batch = self._knn_features(input_ids, attn_mask)

            # Forward pass; only the logits are needed for evaluation.
            logits, _, _ = self.model(
                input_ids, attn_mask, lengths, prototypes_tensor,
                proto_idx_batch, knn_sims_tensor
            )

            # Keep only positions with a real label.
            labels_flat = labels.view(-1)
            mask = labels_flat != -100
            if mask.sum() == 0:
                continue

            logits_masked = logits.view(-1, NUM_LABELS)[mask]
            labels_masked = labels_flat[mask]

            ce_loss = focal_loss(logits_masked, labels_masked)
            all_preds.extend(torch.argmax(logits_masked, dim=1).cpu().tolist())
            all_trues.extend(labels_masked.cpu().tolist())

            n = labels_masked.size(0)
            running_ce += ce_loss.item() * n
            n_samples += n

        if n_samples == 0:
            # Return every key that train() reads so an unlabelled or empty
            # dataset does not raise KeyError in the training loop.
            return {
                "metrics": {
                    "accuracy": 0.0,
                    "macro_f1": 0.0,
                    "minority_macro_f1": 0.0,
                },
                "total_loss": 0.0,
                "classification_report": "",
                "all_preds": [],
                "all_trues": [],
                "sample_count": 0,
                "inference_time": 0,
            }

        # Compute metrics
        avg_ce = running_ce / n_samples
        detailed_metrics = compute_detailed_metrics(all_trues, all_preds)

        cls_report = classification_report(
            [id2label[x] for x in all_trues],
            [id2label[x] for x in all_preds],
            digits=4,
            zero_division=0
        )

        result = {
            "metrics": detailed_metrics,
            "total_loss": avg_ce,
            "classification_report": cls_report,
            "all_preds": all_preds,
            "all_trues": all_trues,
            "sample_count": n_samples
        }

        if measure_time:
            result["inference_time"] = time.time() - start_time
            result["mem_usage"] = get_memory_usage(self.device)

        return result

# ==================== MAIN EXECUTION ====================
def main():
    """Run the full experiment: load data, build prototypes, train, test, save.

    Steps, in order:
      1. Load the train/dev/test JSONL splits and print the training label
         distribution.
      2. Embed every training sentence with a temporary encoder and fit the
         class prototypes on those embeddings.
      3. Build the datasets and the model, report parameter counts, train.
      4. Reload the best checkpoint (if one was saved) and evaluate on test.
      5. Write the confusion matrix, result CSVs and prompt configuration
         JSON into ``OUT_DIR`` and print a run summary.
    """
    total_start_time = time.time()
    
    print("\n" + "="*80)
    print("LOADING DATASETS")
    print("="*80)
    train_docs = load_jsonl(TRAIN_PATH)
    dev_docs = load_jsonl(DEV_PATH)
    test_docs = load_jsonl(TEST_PATH)
    print(f"âœ… Train: {len(train_docs)} documents")
    print(f"âœ… Dev: {len(dev_docs)} documents")
    print(f"âœ… Test: {len(test_docs)} documents")

    # extract_data returns three values; only the per-document sentence and
    # label lists are used here.
    train_sents, train_labels, _ = extract_data(train_docs)
    dev_sents, dev_labels, _ = extract_data(dev_docs)
    test_sents, test_labels, _ = extract_data(test_docs)
    
    print(f"âœ… Train: {len(train_sents)} docs, {sum(len(s) for s in train_sents)} sentences")
    print(f"âœ… Dev: {len(dev_sents)} docs, {sum(len(s) for s in dev_sents)} sentences")
    print(f"âœ… Test: {len(test_sents)} docs, {sum(len(s) for s in test_sents)} sentences")

    # Class distribution over all training sentences, most frequent first
    flat_train_labels = [l for doc in train_labels for l in doc]
    label_dist = Counter(flat_train_labels)
    print("\nðŸ“Š Training Label Distribution:")
    for label_id, count in sorted(label_dist.items(), key=lambda x: x[1], reverse=True):
        print(f"   {id2label[label_id]:20s}: {count:6d} ({100*count/len(flat_train_labels):5.2f}%)")

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(INLEGALBERT_MODEL_NAME)
    
    # Compute prototypes from a separate, temporary encoder in eval mode
    print("\n" + "="*80)
    print("COMPUTING CLASS PROTOTYPES")
    print("="*80)
    temp_bert = AutoModel.from_pretrained(INLEGALBERT_MODEL_NAME).to(DEVICE)
    temp_bert.eval()
    
    with torch.no_grad():
        flat_train_sents = [s for doc in train_sents for s in doc]
        # Same flattening as flat_train_labels above, as an int64 array
        flat_train_labels_np = np.array([l for doc in train_labels for l in doc], dtype=np.int64)
        
        train_embs = []
        batch_size = 64
        for i in range(0, len(flat_train_sents), batch_size):
            batch = flat_train_sents[i:i+batch_size]
            enc = tokenizer(batch, padding=True, truncation=True,
                          max_length=MAX_SEQ_LENGTH, return_tensors="pt").to(DEVICE)
            # Unmasked mean over the token dimension: padded positions are
            # averaged in as well.
            # NOTE(review): confirm this pooling matches what
            # model.encode_sentences produces, since the prototypes are
            # compared against those embeddings during training.
            out = temp_bert(**enc).last_hidden_state.mean(dim=1)
            train_embs.append(out.cpu().numpy())
            
            # Progress line every 50 batches
            if (i // batch_size + 1) % 50 == 0:
                print(f"  Processed {i+len(batch)}/{len(flat_train_sents)} sentences...")
        
        train_embs = np.vstack(train_embs)
    
    # Drop the temporary encoder before the real model is created
    del temp_bert
    torch.cuda.empty_cache()
    print("âœ… Embeddings computed")

    # Fit prototypes on the sentence embeddings and their labels
    proto_mgr = ClassPrototypeManager()
    proto_mgr.fit(train_embs, flat_train_labels_np)

    # Create datasets
    train_dataset = DocumentTextDataset(train_sents, train_labels)
    dev_dataset = DocumentTextDataset(dev_sents, dev_labels)
    test_dataset = DocumentTextDataset(test_sents, test_labels)

    # Initialize model
    print("\n" + "="*80)
    print("INITIALIZING MODEL")
    print("="*80)
    model = PromptTuningHSLNModel()
    
    # Parameter bookkeeping: "frozen" means requires_grad is False
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    prompt_params = sum(p.numel() for p in model.prompt_tuning.parameters())
    frozen_params = total_params - trainable_params
    
    print(f"\nðŸ“Š Model Parameters:")
    print(f"   Total:     {total_params:,}")
    print(f"   Trainable: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
    print(f"   Frozen:    {frozen_params:,} ({100*frozen_params/total_params:.2f}%)")
    print(f"   Prompts:   {prompt_params:,} ({100*prompt_params/trainable_params:.2f}% of trainable)")
    print(f"\nâœ… Prompt Tuning Only - BERT Frozen")

    # Train; returns the best checkpoint path and the training time in seconds
    trainer = Trainer(model, tokenizer, proto_mgr, device=DEVICE)
    best_ckpt, train_time = trainer.train(
        train_dataset, dev_dataset,
        proto_weight=PROTO_WEIGHT,
        rpl_weight=RPL_WEIGHT
    )
    
    # Load best checkpoint into the same model object the trainer holds
    # NOTE(review): the checkpoint stores a metrics dict next to the weights;
    # confirm it loads under the installed PyTorch's torch.load defaults.
    if best_ckpt and os.path.exists(best_ckpt):
        print(f"\nâœ… Loading best checkpoint: {best_ckpt}")
        ckpt = torch.load(best_ckpt, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        print(f"   Macro F1: {ckpt['macro_f1']:.4f}")

    # Final test evaluation
    print("\n" + "="*80)
    print("FINAL TEST EVALUATION")
    print("="*80)
    test_results = trainer.evaluate(test_dataset, measure_time=True)
    
    print(f"\nðŸ“Š Test Results:")
    print(f"   Accuracy:          {test_results['metrics']['accuracy']:.4f}")
    print(f"   Macro F1:          {test_results['metrics']['macro_f1']:.4f}")
    print(f"   Weighted F1:       {test_results['metrics']['weighted_f1']:.4f}")
    print(f"   Macro Precision:   {test_results['metrics']['macro_precision']:.4f}")
    print(f"   Macro Recall:      {test_results['metrics']['macro_recall']:.4f}")
    print(f"   Minority Macro F1: {test_results['metrics']['minority_macro_f1']:.4f}")
    print(f"   Inference Time:    {test_results['inference_time']:.2f}s")
    
    print(f"\nðŸ“‹ Classification Report:")
    print(test_results["classification_report"])
    
    # Per-class listings are sorted by F1, best first
    print(f"\nðŸŽ¯ Minority Class F1 Scores:")
    for cls, f1 in sorted(test_results['metrics']['minority_f1_per_class'].items(),
                         key=lambda x: x[1], reverse=True):
        print(f"   {cls:20s}: {f1:.4f}")
    
    print(f"\nðŸ“Š Per-Class F1 Scores:")
    for cls, f1 in sorted(test_results['metrics']['per_class_f1'].items(),
                         key=lambda x: x[1], reverse=True):
        support = test_results['metrics']['per_class_support'][cls]
        print(f"   {cls:20s}: {f1:.4f} (n={support})")
    
    # Plot confusion matrix over all NUM_LABELS ids so absent classes still
    # get a row and a column
    plt.figure(figsize=(16, 14))
    cm = confusion_matrix(test_results["all_trues"], test_results["all_preds"],
                         labels=range(NUM_LABELS))
    
    # Underscores become line breaks so long label names fit on the axes
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[l.replace('_', '\n') for l in LABELS],
                yticklabels=[l.replace('_', '\n') for l in LABELS],
                cbar_kws={'label': 'Count'})
    
    plt.title(f'Prompt Tuning + HSLN + RPL + RTM\n'
              f'Test Macro F1: {test_results["metrics"]["macro_f1"]:.4f} | '
              f'Minority F1: {test_results["metrics"]["minority_macro_f1"]:.4f}\n'
              f'Accuracy: {test_results["metrics"]["accuracy"]:.4f}',
              fontsize=14, fontweight='bold')
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, "confusion_matrix.png"), dpi=300, bbox_inches='tight')
    print(f"\nðŸ’¾ Saved confusion matrix to {OUT_DIR}/confusion_matrix.png")
    
    # Save comprehensive results: run configuration plus every entry of the
    # test metrics dict (the per-class dict entries included) as one CSV row
    results_summary = {
        'method': 'Advanced_Prompt_Tuning_HSLN',
        'hard_prompt_words': len(HARD_LEGAL_PROMPT.split()),
        'soft_prompt_tokens': NUM_SOFT_PROMPT_TOKENS,
        'class_prompt_tokens': CLASS_PROMPT_TOKENS,
        'total_prompt_params': prompt_params,
        'trainable_params': trainable_params,
        'minority_boost': MINORITY_BOOST,
        'focal_gamma': FOCAL_GAMMA,
        'total_train_time_minutes': train_time / 60,
        'test_inference_time_seconds': test_results['inference_time'],
        **test_results['metrics']
    }
    
    summary_df = pd.DataFrame([results_summary])
    summary_df.to_csv(os.path.join(OUT_DIR, "final_results_summary.csv"), index=False)
    
    # Save predictions: one row per evaluated test sentence
    pred_df = pd.DataFrame({
        "true_label": [id2label[x] for x in test_results["all_trues"]],
        "pred_label": [id2label[x] for x in test_results["all_preds"]],
        "true_id": test_results["all_trues"],
        "pred_id": test_results["all_preds"],
        "correct": [t == p for t, p in zip(test_results["all_trues"], test_results["all_preds"])]
    })
    pred_df.to_csv(os.path.join(OUT_DIR, "final_test_predictions.csv"), index=False)
    
    # Save detailed per-class metrics, one row per entry of LABELS
    per_class_df = pd.DataFrame({
        'class': LABELS,
        'precision': [test_results['metrics']['per_class_precision'][l] for l in LABELS],
        'recall': [test_results['metrics']['per_class_recall'][l] for l in LABELS],
        'f1': [test_results['metrics']['per_class_f1'][l] for l in LABELS],
        'support': [test_results['metrics']['per_class_support'][l] for l in LABELS],
    })
    per_class_df.to_csv(os.path.join(OUT_DIR, "per_class_metrics.csv"), index=False)
    
    # Save the prompt configuration used for this run
    prompt_info = {
        'hard_prompt': HARD_LEGAL_PROMPT,
        'num_soft_tokens': NUM_SOFT_PROMPT_TOKENS,
        'class_prompt_tokens': CLASS_PROMPT_TOKENS,
        'context_prompt_tokens': CONTEXT_PROMPT_TOKENS,
        'prompt_lr': PROMPT_LR,
        'classifier_lr': CLASSIFIER_LR,
    }
    with open(os.path.join(OUT_DIR, "prompt_configuration.json"), 'w') as f:
        json.dump(prompt_info, f, indent=2)
    
    total_runtime = time.time() - total_start_time
    
    print("\n" + "="*80)
    print("RESULTS SAVED")
    print("="*80)
    print(f"Output Directory: {OUT_DIR}/")
    print(f"  âœ… final_results_summary.csv")
    print(f"  âœ… final_test_predictions.csv")
    print(f"  âœ… per_class_metrics.csv")
    print(f"  âœ… training_history.csv")
    print(f"  âœ… confusion_matrix.png")
    print(f"  âœ… prompt_configuration.json")
    print(f"  âœ… Best model checkpoint")
    
    print("\n" + "="*80)
    print("EXECUTION SUMMARY")
    print("="*80)
    print(f"Total Runtime:      {total_runtime/60:.2f} minutes")
    print(f"Training Time:      {train_time/60:.2f} minutes")
    print(f"Test Macro F1:      {test_results['metrics']['macro_f1']:.4f}")
    print(f"Test Accuracy:      {test_results['metrics']['accuracy']:.4f}")
    print(f"Minority Macro F1:  {test_results['metrics']['minority_macro_f1']:.4f}")
    print("="*80)
    
    # Compare the test macro-F1 against the fixed 0.59 target
    if test_results['metrics']['macro_f1'] >= 0.59:
        print("ðŸŽ‰ TARGET ACHIEVED: Macro F1 â‰¥ 0.59!")
    else:
        print(f"ðŸ“Š Current Macro F1: {test_results['metrics']['macro_f1']:.4f} (Target: 0.59)")
    
    print("\nâœ… PROMPT TUNING EXPERIMENT COMPLETED SUCCESSFULLY!\n")

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()