# In [None]:
import os
import json
import random
import time
import psutil
import gc
from datetime import datetime
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
)

from peft import AdaLoraConfig, get_peft_model, TaskType

# ---------------- CONFIG ----------------
# Pretrained encoder checkpoint and JSONL dataset splits.
INLEGALBERT_MODEL_NAME = "law-ai/InLegalBERT"
TRAIN_PATH = "build_jsonl/build_train.jsonl"
DEV_PATH = "build_jsonl/build_dev.jsonl"
TEST_PATH = "build_jsonl/build_test.jsonl"

# Checkpoints, training history and prediction dumps are written here.
OUT_DIR = "inlegalbert_AdaLoRA_RPL_RTM_fixed_imbalance"
os.makedirs(OUT_DIR, exist_ok=True)

SEED = 42
MAX_SEQ_LENGTH = 128      # tokenizer truncation length per sentence
MAX_SENTS_PER_DOC = 64    # documents are cut to this many sentences
BATCH_DOCS = 2            # documents per training batch

# Optimisation hyperparameters
NUM_EPOCHS = 20
LR = 1e-4
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
PROTO_WEIGHT = 0.15            # weight of the prototype auxiliary loss
RPL_WEIGHT = 0.05              # weight of the role-prototypical loss
RTM_LAMBDA = 0.02              # scale of the role-transition adjustment
PROTO_AUX_TEMPERATURE = 5.0    # multiplier applied to cosine similarities
DROPOUT = 0.3
LSTM_HIDDEN = 384
LABEL_SMOOTHING = 0.02

# Focal loss parameters
FOCAL_GAMMA = 2.0
FOCAL_ALPHA = 0.25

# AdaLoRA rank schedule
ADALORA_INIT_R = 12
ADALORA_TARGET_R = 8
ADALORA_TINIT = 200
ADALORA_TFINAL = 1500
ADALORA_DELTA_T = 100

# Feature switches and class-imbalance handling
USE_WEIGHTED_SAMPLER = True
USE_POSITIONAL_EMB = True
POS_EMB_DIM = 32
USE_KNN_PRIOR = True
KNN_K = 3
KNN_PRIOR_DIM = 64
MINORITY_BOOST = 3.0

# Label inventory; a label's id is its position in this list.
LABELS = [
    "PREAMBLE", "FAC", "RLC", "ISSUE", "ARG_PETITIONER",
    "ARG_RESPONDENT", "ANALYSIS", "STA", "PRE_RELIED",
    "PRE_NOT_RELIED", "RATIO", "RPC", "NONE",
]
label2id = {name: index for index, name in enumerate(LABELS)}
id2label = dict(enumerate(LABELS))
NUM_LABELS = len(LABELS)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Classes that get extra sampling weight and dedicated metrics.
MINORITY_CLASSES = ["RLC", "ISSUE", "STA", "RATIO", "PRE_RELIED", "PRE_NOT_RELIED", "RPC"]
minority_ids = [label2id[name] for name in MINORITY_CLASSES if name in label2id]

# ---------------- UTILITIES ----------------
def get_memory_usage(device=DEVICE):
    """Report current memory consumption in MiB.

    Returns CUDA allocator statistics (``allocated``, ``reserved``,
    ``max_allocated``) when CUDA is available and ``device`` is a CUDA
    device; otherwise the resident set size of this process under
    ``cpu_ram_mb``.
    """
    mib = 1024**2
    on_gpu = torch.cuda.is_available() and device.type == 'cuda'
    if not on_gpu:
        rss_bytes = psutil.Process().memory_info().rss
        return {'cpu_ram_mb': rss_bytes / mib}

    return {
        'allocated': torch.cuda.memory_allocated(device) / mib,
        'reserved': torch.cuda.memory_reserved(device) / mib,
        'max_allocated': torch.cuda.max_memory_allocated(device) / mib,
    }

def compute_detailed_metrics(y_true, y_pred):
    """Compute overall, per-class and minority-class classification metrics.

    Args:
        y_true: Sequence of gold label ids.
        y_pred: Sequence of predicted label ids, aligned with ``y_true``.

    Returns:
        dict with ``accuracy``, ``macro_f1``, ``weighted_f1``,
        ``macro_precision``, ``macro_recall``, ``minority_macro_f1``,
        ``minority_f1_per_class`` and the per-class precision / recall /
        F1 / support tables keyed by label name.

    Minority metrics are taken from the per-class scores computed over the
    *whole* evaluation set (one-vs-rest F1 of each minority class, and
    their mean). Scoring only the rows whose gold label is a minority
    class would hide false positives, and ``average='binary'`` on such a
    subset is rejected by scikit-learn because the default positive label
    (1) is not the class being scored.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    try:
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0,
            labels=list(range(NUM_LABELS)),
        )
    except ValueError:
        # Degenerate input: report zeros but keep the result shape intact.
        precision = np.zeros(NUM_LABELS)
        recall = np.zeros(NUM_LABELS)
        f1 = np.zeros(NUM_LABELS)
        support = np.bincount(y_true.astype(np.int64), minlength=NUM_LABELS)

    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    weighted_f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    # Classes without any support get F1 == 0 through zero_division=0.
    minority_f1_per_class = {
        id2label[cls_id]: float(f1[cls_id]) for cls_id in minority_ids
    }
    if minority_f1_per_class:
        minority_f1 = float(np.mean(list(minority_f1_per_class.values())))
    else:
        minority_f1 = 0.0

    return {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'macro_f1': float(macro_f1),
        'weighted_f1': float(weighted_f1),
        'macro_precision': float(np.mean(precision)),
        'macro_recall': float(np.mean(recall)),
        'minority_macro_f1': minority_f1,
        'minority_f1_per_class': minority_f1_per_class,
        'per_class_precision': {id2label[i]: float(precision[i]) for i in range(NUM_LABELS)},
        'per_class_recall': {id2label[i]: float(recall[i]) for i in range(NUM_LABELS)},
        'per_class_f1': {id2label[i]: float(f1[i]) for i in range(NUM_LABELS)},
        'per_class_support': {id2label[i]: int(support[i]) for i in range(NUM_LABELS)},
    }

def set_seed(seed=SEED):
    """Seed the Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed once at import time so everything below starts from a fixed RNG state.
set_seed()

def load_jsonl(path):
    """Parse a JSON Lines file into a list of objects.

    Lines that are empty after stripping whitespace are skipped.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def extract_data(docs, max_sents=MAX_SENTS_PER_DOC):
    """Turn raw document dicts into parallel sentence / label-id lists.

    Three document layouts are recognised:
      * ``{"sentences": [...], "labels": [...]}``
      * ``{"sentences": [...], "annotation": [...]}``
      * ``{"annotations": [{"result": [{"value": {"text", "labels"}}]}]}``

    Unknown label names map to the id of ``"NONE"``. A document is dropped
    when it yields no sentences or when its sentence and label counts
    differ. That check runs on the full document, *before* truncation, so
    a long document with misaligned labels cannot slip through just
    because both lists were cut to the same length.

    Args:
        docs: Iterable of document dicts.
        max_sents: Maximum number of sentences kept per document.

    Returns:
        ``(all_sents, all_labels, doc_ids)``: three lists of equal length.
    """
    none_id = label2id["NONE"]
    all_sents, all_labels, doc_ids = [], [], []
    for doc in docs:
        sents, labs = [], []
        if "sentences" in doc and ("labels" in doc or "annotation" in doc):
            # "labels" takes precedence over "annotation" when both exist.
            label_key = "labels" if "labels" in doc else "annotation"
            sents = doc["sentences"]
            labs = [label2id.get(name, none_id) for name in doc[label_key]]
        elif "annotations" in doc:
            for annotation in doc.get("annotations", []):
                for item in annotation.get("result", []):
                    val = item.get("value", {})
                    text = (val.get("text") or "").strip()
                    if not text:
                        continue
                    # A missing or empty label list falls back to "NONE".
                    names = val.get("labels") or ["NONE"]
                    sents.append(text)
                    labs.append(label2id.get(names[0], none_id))

        if not sents or len(sents) != len(labs):
            continue

        sents, labs = sents[:max_sents], labs[:max_sents]
        if sents:
            all_sents.append(sents)
            all_labels.append(labs)
            doc_ids.append(doc.get("id", ""))
    return all_sents, all_labels, doc_ids

# ---------------- PROTOTYPE MANAGER ----------------
class ClassPrototypeManager:
    """Holds one mean-embedding prototype per class and answers cosine queries."""

    def __init__(self):
        self.prototypes = None
        self.fitted = False

    def _require_fitted(self):
        """Raise if ``fit`` has not been called yet."""
        if not self.fitted:
            raise RuntimeError("Prototypes not fitted")

    def _cosine_to_prototypes(self, embeddings, eps):
        """Cosine similarity of every embedding row to every prototype."""
        emb = np.asarray(embeddings)
        emb_unit = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + eps)
        proto_unit = self.prototypes / (
            np.linalg.norm(self.prototypes, axis=1, keepdims=True) + eps
        )
        return emb_unit @ proto_unit.T

    def fit(self, embeddings, labels):
        """Set each class prototype to the mean embedding of its members.

        Classes with no member receive a small random vector instead.
        """
        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)
        dim = embeddings.shape[1]
        protos = np.zeros((NUM_LABELS, dim), dtype=np.float32)
        for class_id in range(NUM_LABELS):
            members = labels == class_id
            if members.any():
                protos[class_id] = embeddings[members].mean(axis=0)
            else:
                protos[class_id] = np.random.randn(dim).astype(np.float32) * 1e-3
        self.prototypes = protos
        self.fitted = True
        print(f"[Prototypes] fitted, shape {protos.shape}")

    def get_all_tensor(self, device=None):
        """Return the prototypes as a float32 tensor, optionally on ``device``."""
        self._require_fitted()
        protos = torch.tensor(self.prototypes, dtype=torch.float32)
        return protos if device is None else protos.to(device)

    def get_nearest_index(self, embeddings):
        """Index of the most cosine-similar prototype for each embedding row."""
        self._require_fitted()
        return np.argmax(self._cosine_to_prototypes(embeddings, 1e-8), axis=1)

    def knn_prior(self, embeddings, topk=KNN_K):
        """Top-``topk`` cosine similarities and prototype indices per row.

        Returns ``(topk_sims, topk_idx)``, ordered by descending similarity.
        """
        self._require_fitted()
        sims = self._cosine_to_prototypes(embeddings, 1e-12)
        topk_idx = np.argsort(-sims, axis=1)[:, :topk]
        return np.take_along_axis(sims, topk_idx, axis=1), topk_idx

# ---------------- DATASET ----------------
class DocumentTextDataset(Dataset):
    """Document-level dataset: each item is one whole document.

    Wraps two parallel lists: per-document sentence lists and
    per-document label-id lists.
    """

    def __init__(self, docs_sents, docs_labels):
        self.docs_sents, self.docs_labels = docs_sents, docs_labels

    def __len__(self):
        # Length is counted in documents, not sentences.
        return len(self.docs_sents)

    def __getitem__(self, idx):
        """Return the raw sentences and a ``torch.long`` label tensor."""
        sentences = self.docs_sents[idx]
        label_tensor = torch.tensor(self.docs_labels[idx], dtype=torch.long)
        return {"sentences": sentences, "labels": label_tensor}

def collate_docs(batch, tokenizer):
    """Collate documents into padded ``(doc, sentence, token)`` tensors.

    All sentences of the batch are tokenized in a single call, then copied
    into zero-initialised per-document slots. Label positions that do not
    correspond to a real sentence are filled with ``-100``.

    Returns:
        ``(input_ids, attention_mask, labels, lengths)`` where ``lengths``
        holds the number of sentences of each document.
    """
    doc_sizes = [len(item["sentences"]) for item in batch]
    widest_doc = max(doc_sizes)
    n_docs = len(batch)

    flat_sents = [sent for item in batch for sent in item["sentences"]]
    enc = tokenizer(
        flat_sents,
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors="pt",
    )
    token_ids, token_mask = enc["input_ids"], enc["attention_mask"]
    n_tokens = token_ids.shape[1]

    input_ids_padded = torch.zeros((n_docs, widest_doc, n_tokens), dtype=torch.long)
    attn_mask_padded = torch.zeros((n_docs, widest_doc, n_tokens), dtype=torch.long)
    labels_padded = torch.full((n_docs, widest_doc), -100, dtype=torch.long)
    lengths = torch.tensor(doc_sizes, dtype=torch.long)

    # Walk through the flat encoding, handing each document its own slice.
    cursor = 0
    for row, (item, size) in enumerate(zip(batch, doc_sizes)):
        stop = cursor + size
        input_ids_padded[row, :size] = token_ids[cursor:stop]
        attn_mask_padded[row, :size] = token_mask[cursor:stop]
        labels_padded[row, :size] = item["labels"]
        cursor = stop

    return input_ids_padded, attn_mask_padded, labels_padded, lengths

# ---------------- FOCAL LOSS ----------------
def focal_loss(logits_masked, labels_masked, gamma=FOCAL_GAMMA, alpha=FOCAL_ALPHA, label_smoothing=LABEL_SMOOTHING):
    """Mean focal loss built on label-smoothed cross-entropy.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p) ** gamma``
    with ``p = exp(-CE)``, which down-weights samples whose loss is
    already small.
    """
    per_sample_ce = F.cross_entropy(
        logits_masked,
        labels_masked,
        reduction='none',
        label_smoothing=label_smoothing,
    )
    confidence = torch.exp(-per_sample_ce)
    modulator = (1 - confidence) ** gamma
    return (alpha * modulator * per_sample_ce).mean()

# ---------------- MODEL COMPONENTS ----------------
class RolePrototypicalLayer(nn.Module):
    """Learnable per-class prototypes scored by cosine similarity."""

    def __init__(self, hidden_dim):
        super().__init__()
        # One prototype row per label, initialised with small Gaussian noise.
        self.proto = nn.Parameter(torch.randn(NUM_LABELS, hidden_dim) * 0.02)

    def forward(self, h):
        """Return the cosine similarity of ``h`` to every prototype."""
        unit_h = F.normalize(h, dim=-1)
        unit_proto = F.normalize(self.proto, dim=-1)
        return torch.matmul(unit_h, unit_proto.t())

class RoleTransitionMatrix(nn.Module):
    """Learned label-transition prior added to per-sentence logits.

    ``A`` holds unnormalised transition scores; each row is log-softmaxed
    into ``log P(next label | previous label)``. For every position
    ``t >= 1`` the log-probability of each label given the predicted
    distribution at ``t - 1`` is added to the logits, scaled by
    ``rtm_lambda``.
    """

    def __init__(self, rtm_lambda=RTM_LAMBDA):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(NUM_LABELS, NUM_LABELS))
        self.rtm_lambda = rtm_lambda

    def forward(self, logits):
        """Return transition-adjusted logits without modifying the input.

        Args:
            logits: Tensor unpacked as ``(batch, steps, classes)``.

        The previous-step distribution always comes from the *unadjusted*
        logits, so every position can be handled in a single vectorised
        expression instead of a Python loop over steps.
        """
        _, n_steps, _ = logits.shape
        if n_steps < 2:
            return logits

        log_probs = logits.log_softmax(-1)        # (B, S, C)
        log_trans = self.A.log_softmax(-1)        # (C_prev, C_next)
        # (B, S-1, C_prev, 1) + (C_prev, C_next), reduced over C_prev.
        transition_prior = torch.logsumexp(
            log_probs[:, :-1].unsqueeze(-1) + log_trans, dim=2
        )
        adjusted_tail = logits[:, 1:] + self.rtm_lambda * transition_prior
        # The first position has no predecessor and is passed through as is.
        return torch.cat([logits[:, :1], adjusted_tail], dim=1)

class SentenceEncoderFFN(nn.Module):
    """Residual feed-forward block: ``LayerNorm(x + FFN(x))``."""

    def __init__(self, sent_dim, hidden=512, dropout=0.2):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state-dict keys and the order in which init consumes the RNG.
        self.fc1 = nn.Linear(sent_dim, hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden, sent_dim)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(sent_dim)

    def forward(self, x):
        """Expand, activate, drop out, project back, then add-and-normalise."""
        expanded = self.dropout(self.act(self.fc1(x)))
        update = self.fc2(expanded)
        return self.ln(x + update)

class PrototypeAttention(nn.Module):
    """Attend from sentence vectors over a set of class prototypes."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, sent_emb, prototypes):
        """Return ``(prototype_context, attention_weights)``.

        Sentence vectors are projected by ``W``, scored against every
        prototype with a dot product, softmaxed over prototypes, and used
        to mix the prototypes into one context vector per sentence.
        """
        # Unpacking doubles as a rank check: the input must be 3-D.
        _batch, _n_sents, _hidden = sent_emb.shape
        queries = self.W(sent_emb)
        attn_weights = torch.softmax(queries @ prototypes.t(), dim=-1)
        return attn_weights @ prototypes, attn_weights

def prototypical_cosine_loss(reprs, prototypes_tensor, labels, temperature=PROTO_AUX_TEMPERATURE):
    """Cross-entropy over temperature-scaled cosine similarities to prototypes.

    Returns:
        ``(loss, sims)`` where ``sims`` are the scaled similarities that
        were used as logits.
    """
    unit_reprs = F.normalize(reprs, p=2, dim=1)
    unit_protos = F.normalize(prototypes_tensor, p=2, dim=1)
    sims = torch.matmul(unit_reprs, unit_protos.t()) * temperature
    return F.cross_entropy(sims, labels), sims

def compute_losses(logits, labels, sent_emb_flat, doc_out, prototypes_tensor, rpl_proto, temperature=PROTO_AUX_TEMPERATURE):
    """Compute the three training losses on labelled sentences only.

    Positions whose label is ``-100`` are ignored.

    Args:
        logits: Per-sentence class logits, flattened to ``(-1, NUM_LABELS)``.
        labels: Label ids aligned with ``logits``.
        sent_emb_flat: One sentence embedding per flattened position.
        doc_out: Document-encoder outputs, flattened over their last dim.
        prototypes_tensor: Fixed class prototypes for the prototype loss.
        rpl_proto: Learnable prototypes for the role-prototypical loss.
        temperature: Similarity multiplier used by both cosine losses.

    Returns:
        ``(focal_loss, prototype_loss, rpl_loss)``, always three scalar
        tensors. When no position is labelled, all three are zeros that
        stay attached to ``logits``' graph, so callers can still call
        ``.backward()`` and ``.item()`` on them.
    """
    logits_flat = logits.reshape(-1, NUM_LABELS)
    labels_flat = labels.reshape(-1)
    mask = labels_flat != -100

    if not mask.any():
        zero = logits.sum() * 0.0
        return zero, zero, zero

    logits_masked = logits_flat[mask]
    labels_masked = labels_flat[mask]

    ce_loss = focal_loss(logits_masked, labels_masked)

    valid_sent_emb = sent_emb_flat[mask]
    proto_loss, _ = prototypical_cosine_loss(
        valid_sent_emb, prototypes_tensor, labels_masked, temperature=temperature
    )

    valid_doc_out = doc_out.reshape(-1, doc_out.size(-1))[mask]
    rpl_sim = F.normalize(valid_doc_out, dim=1) @ F.normalize(rpl_proto, dim=1).T
    rpl_loss = F.cross_entropy(rpl_sim * temperature, labels_masked)

    return ce_loss, proto_loss, rpl_loss

# ---------------- FIXED MODEL (No enable_adapters) ----------------
class ProtoHSLNModel(nn.Module):
    """Hierarchical sentence-labelling model with prototype and transition heads.

    Pipeline implemented by ``forward``:
      1. every sentence is encoded by an AdaLoRA-wrapped BERT and mean-pooled;
      2. a residual feed-forward block refines the sentence vectors;
      3. optional learned sentence-position embeddings are added;
      4. attention over the supplied class prototypes is added as context;
      5. optional KNN-prior features are concatenated;
      6. a 2-layer bidirectional LSTM runs over the sentences of a document;
      7. a classifier head and a cosine-prototype head are blended with a
         learned gate, then adjusted by the role-transition matrix.
    """
    def __init__(self, bert_name=INLEGALBERT_MODEL_NAME, pos_dim=POS_EMB_DIM,
                use_pos_emb=USE_POSITIONAL_EMB, use_knn_prior=USE_KNN_PRIOR,
                knn_prior_dim=KNN_PRIOR_DIM, doc_hidden=LSTM_HIDDEN, dropout=DROPOUT):
        super().__init__()

        base_model = AutoModel.from_pretrained(bert_name)
        
        # AdaLoRA configuration; enable_adapters() is deliberately not called.
        # NOTE(review): no `total_step` is passed here; some peft versions
        # appear to need it for the rank schedule. TODO confirm against the
        # installed peft version.
        adalora_config = AdaLoraConfig(
            r=ADALORA_INIT_R,
            target_r=ADALORA_TARGET_R,
            init_r=ADALORA_INIT_R,
            tinit=ADALORA_TINIT,
            tfinal=ADALORA_TFINAL,
            deltaT=ADALORA_DELTA_T,
            lora_alpha=32,
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.FEATURE_EXTRACTION,
            # Names of the sub-modules that receive adapters.
            target_modules=["query", "key", "value", "dense"],
        )
        self.bert = get_peft_model(base_model, adalora_config)

        # Report the trainable-parameter count as peft sees it, before the
        # requires_grad flags are touched below.
        print("\nüìä BERT Trainable Parameters:")
        self.bert.print_trainable_parameters()
        
        # Explicitly mark every parameter whose name contains "lora_" or
        # "ranknum" as trainable.
        for name, param in self.bert.named_parameters():
            if any(x in name for x in ["lora_", "ranknum"]):
                param.requires_grad = True

        self.hidden_size = self.bert.config.hidden_size
        
        # Optional sentence-position embedding (table of 1024 positions),
        # projected to the encoder width when the dimensions differ.
        self.use_pos_emb = use_pos_emb
        if use_pos_emb and pos_dim > 0:
            self.pos_emb = nn.Embedding(1024, pos_dim)
            self.pos_proj = nn.Linear(pos_dim, self.hidden_size) if pos_dim != self.hidden_size else None
        else:
            self.pos_emb = None
            self.pos_proj = None

        # Optional MLP mapping the KNN_K prototype similarities of a sentence
        # to a `knn_prior_dim`-dimensional feature.
        self.use_knn_prior = use_knn_prior
        self.knn_prior_dim = knn_prior_dim
        if use_knn_prior and knn_prior_dim > 0:
            self.knn_proj = nn.Sequential(
                nn.Linear(KNN_K, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, knn_prior_dim),
            )
        else:
            self.knn_proj = None

        self.sent_encoder = SentenceEncoderFFN(self.hidden_size, hidden=512, dropout=dropout)
        self.proto_attn = PrototypeAttention(self.hidden_size)
        
        # The LSTM input is the sentence vector, widened by the KNN feature
        # when that branch is enabled.
        final_in_dim = self.hidden_size + (self.knn_prior_dim if self.knn_proj else 0)
        self.doc_lstm = nn.LSTM(final_in_dim, doc_hidden, 2, bidirectional=True, 
                               batch_first=True, dropout=dropout)
        
        # Bidirectional: forward and backward states are concatenated.
        lstm_out_dim = doc_hidden * 2
        
        self.ce_classifier = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(lstm_out_dim, doc_hidden),
            nn.ReLU(), nn.Dropout(dropout), nn.Linear(doc_hidden, NUM_LABELS),
        )
        self.rpl = RolePrototypicalLayer(lstm_out_dim)
        self.rtm = RoleTransitionMatrix()
        # Raw gate value; forward() passes it through a sigmoid to blend heads.
        self.head_alpha = nn.Parameter(torch.tensor(2.0))

    def encode_sentences(self, input_ids, attention_mask):
        """Encode every sentence with BERT and mean-pool its token states.

        Args:
            input_ids: Token ids, unpacked as ``(B, S, T)``.
            attention_mask: Mask with the same shape, forwarded to BERT.

        Returns:
            ``(sent_emb, sent_emb_flat)``: the pooled embeddings viewed as
            ``(B, S, hidden)`` and as ``(B * S, hidden)``.
        """
        B, S, T = input_ids.shape
        input_ids_flat = input_ids.view(B * S, T)
        attn_flat = attention_mask.view(B * S, T)
        outputs = self.bert(input_ids_flat, attention_mask=attn_flat)
        # NOTE(review): the mean runs over all T token positions; the
        # attention mask is not used for pooling, so padded positions are
        # averaged in as well. Confirm this is intended.
        sent_emb_flat = outputs.last_hidden_state.mean(dim=1)
        sent_emb = sent_emb_flat.view(B, S, -1)
        return sent_emb, sent_emb_flat

    def forward(self, input_ids, attention_mask, lengths, prototypes_tensor, proto_idx_batch=None, knn_sims=None):
        """Compute per-sentence label logits for a batch of documents.

        Args:
            input_ids: Token ids, unpacked as ``(B, S, T)``.
            attention_mask: Mask with the same shape as ``input_ids``.
            lengths: Per-document sequence lengths used to pack the LSTM input.
            prototypes_tensor: Class prototypes; a non-tensor value is
                converted to a float32 tensor on the input device.
            proto_idx_batch: Accepted but not used by this method.
            knn_sims: Optional prototype similarities, reshaped to
                ``(-1, KNN_K)``; ignored when the KNN branch is disabled.

        Returns:
            ``(final_logits, sent_emb_flat, doc_out)``: the transition-adjusted
            logits, the raw mean-pooled BERT sentence embeddings (before the
            feed-forward block), and the LSTM outputs.
        """
        B, S, T = input_ids.shape
        device = input_ids.device

        sent_emb, sent_emb_flat = self.encode_sentences(input_ids, attention_mask)
        sent_emb = self.sent_encoder(sent_emb)

        # Add a learned embedding of each sentence's index within the document.
        if self.pos_emb is not None:
            pos_idx = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
            pos_vec = self.pos_emb(pos_idx)
            if self.pos_proj is not None:
                pos_vec = self.pos_proj(pos_vec)
            sent_emb = sent_emb + pos_vec

        # Residual prototype context.
        protos = prototypes_tensor if isinstance(prototypes_tensor, torch.Tensor) else \
                torch.tensor(prototypes_tensor, device=device, dtype=torch.float32)
        proto_ctx, _ = self.proto_attn(sent_emb, protos)
        sent_emb = sent_emb + proto_ctx

        if self.knn_proj is not None and knn_sims is not None:
            knn_feat = self.knn_proj(knn_sims.view(-1, KNN_K)).view(B, S, -1)
            doc_in = torch.cat([sent_emb, knn_feat], dim=2)
        else:
            doc_in = sent_emb

        # pack_padded_sequence expects the lengths on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(doc_in, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.doc_lstm(packed)
        doc_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Blend the classifier head and the prototype head with a learned
        # gate in (0, 1), then apply the transition adjustment.
        ce_logits = self.ce_classifier(doc_out)
        rpl_logits = self.rpl(doc_out)
        alpha = torch.sigmoid(self.head_alpha)
        blended_logits = alpha * ce_logits + (1 - alpha) * rpl_logits
        final_logits = self.rtm(blended_logits)
        
        return final_logits, sent_emb_flat, doc_out

# ---------------- FIXED TRAINER ----------------
class Trainer:
    """Training and evaluation driver for the document-level labelling model.

    Owns the model (moved to ``device``), the tokenizer used for collation
    and the fitted prototype manager that supplies fixed class prototypes
    and KNN-prior features.
    """

    def __init__(self, model, tokenizer, prototype_manager, device=DEVICE):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.prototype_manager = prototype_manager
        self.device = device
        self.global_step = 0
        # Ensures an AdaLoRA failure is reported once, not on every attempt.
        self._adalora_error_reported = False

    def _collate(self, batch):
        """Collate function bound to this trainer's tokenizer."""
        return collate_docs(batch, self.tokenizer)

    def _build_train_loader(self, train_dataset):
        """Build the training loader, optionally with class-balanced sampling.

        With ``USE_WEIGHTED_SAMPLER`` each document is represented by its
        most frequent label; documents are drawn with replacement with
        probability inversely proportional to that label's document
        frequency, multiplied by ``MINORITY_BOOST`` for minority labels.
        """
        if not USE_WEIGHTED_SAMPLER:
            return DataLoader(train_dataset, batch_size=BATCH_DOCS, shuffle=True,
                              collate_fn=self._collate)

        # Documents without labels fall back to label id 0.
        major_labels = [
            Counter(doc_labels).most_common(1)[0][0] if doc_labels else 0
            for doc_labels in train_dataset.docs_labels
        ]

        counts = np.bincount(major_labels, minlength=NUM_LABELS)
        inv_freq = 1.0 / (counts + 1e-6)

        major_arr = np.array(major_labels)
        weights = inv_freq[major_arr]
        weights[np.isin(major_arr, minority_ids)] *= MINORITY_BOOST

        sampler = WeightedRandomSampler(torch.tensor(weights, dtype=torch.double),
                                        num_samples=len(weights), replacement=True)
        return DataLoader(train_dataset, batch_size=BATCH_DOCS, sampler=sampler,
                          collate_fn=self._collate)

    def safe_adalora_update(self, global_step):
        """Best-effort AdaLoRA budget update.

        The update is attempted only when ``global_step`` lies in
        ``[ADALORA_TINIT, ADALORA_TFINAL]`` and is a multiple of
        ``ADALORA_DELTA_T``. A failure never aborts training, but the first
        one is printed so that a systematically failing update does not go
        unnoticed.

        Returns:
            True if ``update_and_allocate`` ran without raising.
        """
        in_window = ADALORA_TINIT <= global_step <= ADALORA_TFINAL
        if not in_window or global_step % ADALORA_DELTA_T != 0:
            return False
        try:
            if not hasattr(self.model.bert, 'update_and_allocate'):
                return False
            self.model.bert.update_and_allocate(global_step)
        except Exception as exc:
            if not getattr(self, "_adalora_error_reported", False):
                print(f"[AdaLoRA] update_and_allocate failed at step {global_step}: "
                      f"{exc!r}; training continues.")
                self._adalora_error_reported = True
            return False
        return True

    def _knn_features(self, input_ids, attn_mask):
        """Compute prototype-similarity features for a batch, without gradients.

        Returns:
            ``(knn_sims, proto_idx)``: the top-``KNN_K`` cosine similarities
            of every sentence to the class prototypes (float32), and the
            index of the nearest prototype shaped like the first two
            dimensions of ``input_ids``.
        """
        with torch.no_grad():
            _, emb_flat = self.model.encode_sentences(input_ids, attn_mask)
            emb_np = emb_flat.view(-1, emb_flat.size(-1)).cpu().numpy()
            sims, _ = self.prototype_manager.knn_prior(emb_np, topk=KNN_K)
            knn_sims = torch.tensor(sims, dtype=torch.float32, device=self.device)
            nearest = self.prototype_manager.get_nearest_index(emb_np)
            proto_idx = torch.tensor(nearest, dtype=torch.long,
                                     device=self.device).view(input_ids.shape[:2])
        return knn_sims, proto_idx

    @staticmethod
    def _empty_eval_result():
        """Result returned when a dataset has no labelled sentence.

        Carries every key of a regular evaluation result so that callers
        can read it without special-casing.
        """
        def zeros():
            return {label: 0.0 for label in LABELS}

        metrics = {
            'accuracy': 0.0,
            'macro_f1': 0.0,
            'weighted_f1': 0.0,
            'macro_precision': 0.0,
            'macro_recall': 0.0,
            'minority_macro_f1': 0.0,
            'minority_f1_per_class': {id2label[i]: 0.0 for i in minority_ids},
            'per_class_precision': zeros(),
            'per_class_recall': zeros(),
            'per_class_f1': zeros(),
            'per_class_support': {label: 0 for label in LABELS},
        }
        return {
            "metrics": metrics,
            "total_loss": 0.0,
            "classification_report": "",
            "all_preds": [],
            "all_trues": [],
            "sample_count": 0,
            "inference_time": 0.0,
        }

    def train(self, train_dataset, dev_dataset, num_epochs=NUM_EPOCHS, lr=LR, 
              proto_weight=PROTO_WEIGHT, rpl_weight=RPL_WEIGHT):
        """Train the model and checkpoint the best dev macro-F1.

        The objective is ``focal + proto_weight * prototype + rpl_weight *
        rpl``. After every epoch the model is evaluated on ``dev_dataset``;
        a checkpoint is written to ``OUT_DIR`` whenever macro-F1 improves by
        more than 1e-4, and the per-epoch history is saved as CSV at the end.

        Returns:
            ``(best_checkpoint_path_or_None, total_training_seconds)``.
        """
        start_time = time.time()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
        train_loader = self._build_train_loader(train_dataset)
        total_steps = max(1, len(train_loader) * num_epochs)
        warmup_steps = max(1, int(0.05 * total_steps))
        scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

        prototypes_tensor = self.prototype_manager.get_all_tensor(device=self.device)
        best_macro_f1 = -1.0
        best_ckpt = None
        history = []
        adalora_updates = 0

        print("üöÄ Starting training with FIXED AdaLoRA + imbalance handling...")
        print(f"‚úÖ AdaLoRA: init_r={ADALORA_INIT_R}, target_r={ADALORA_TARGET_R}")
        print(f"‚úÖ Minority boost: {MINORITY_BOOST}x | Focal gamma: {FOCAL_GAMMA}")

        for epoch in range(1, num_epochs + 1):
            self.model.train()
            t0 = time.time()
            running_ce = running_proto = running_rpl = running_total = 0.0
            n_samples = 0

            for input_ids, attn_mask, labels, lengths in train_loader:
                input_ids = input_ids.to(self.device)
                attn_mask = attn_mask.to(self.device)
                labels = labels.to(self.device)
                lengths = lengths.to(self.device)

                # Without a single labelled sentence there is nothing to
                # optimise, and backward() would have no graph to traverse.
                n = (labels.view(-1) != -100).sum().item()
                if n == 0:
                    continue

                knn_sims_tensor, proto_idx_batch = self._knn_features(input_ids, attn_mask)

                logits, sent_emb_flat, doc_out = self.model(
                    input_ids, attn_mask, lengths, prototypes_tensor,
                    proto_idx_batch, knn_sims_tensor
                )

                ce_loss, proto_loss, rpl_loss = compute_losses(
                    logits, labels, sent_emb_flat, doc_out,
                    prototypes_tensor, self.model.rpl.proto
                )
                total_loss = ce_loss + proto_weight * proto_loss + rpl_weight * rpl_loss

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), GRAD_CLIP)
                optimizer.step()
                scheduler.step()

                # Runs after optimizer.step() and before the next zero_grad(),
                # so the gradients of this step are still available.
                self.global_step += 1
                if self.safe_adalora_update(self.global_step):
                    adalora_updates += 1

                # Running sums are weighted by the number of labelled sentences.
                running_ce += float(ce_loss) * n
                running_proto += float(proto_loss) * n
                running_rpl += float(rpl_loss) * n
                running_total += float(total_loss) * n
                n_samples += n

            val_results = self.evaluate(dev_dataset, measure_time=True)
            val_metrics = val_results["metrics"]
            denom = max(1, n_samples)
            avg_ce = running_ce / denom
            avg_proto = running_proto / denom
            avg_rpl = running_rpl / denom
            avg_total = running_total / denom

            # Epoch time deliberately includes the validation pass above.
            epoch_time = time.time() - t0
            mem_usage = get_memory_usage(self.device)

            history.append({
                "epoch": epoch, 
                "train_ce": avg_ce, "train_proto": avg_proto, 
                "train_rpl": avg_rpl, "train_total": avg_total,
                "val_acc": val_metrics["accuracy"], 
                "val_macro_f1": val_metrics["macro_f1"],
                "epoch_time_s": epoch_time,
                "mem_usage": mem_usage,
                "adalora_updates": adalora_updates,
                "time": time.time() - start_time,
            })

            print(f"Epoch {epoch}/{num_epochs} | train:{avg_total:.4f} | "
                  f"val:{val_results['total_loss']:.4f} acc:{val_metrics['accuracy']:.4f} "
                  f"macroF1:{val_metrics['macro_f1']:.4f} | "
                  f"time:{epoch_time/60:.1f}m | mem:{mem_usage.get('allocated', 0):.1f}MB | "
                  f"AdaLoRA updates: {adalora_updates}")

            if val_metrics["macro_f1"] > best_macro_f1 + 1e-4:
                best_macro_f1 = val_metrics["macro_f1"]
                best_ckpt = os.path.join(OUT_DIR, f"best_epoch{epoch}_f1{best_macro_f1:.4f}.pt")
                torch.save({
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "epoch": epoch,
                    "macro_f1": best_macro_f1,
                }, best_ckpt)

        total_train_time = time.time() - start_time
        pd.DataFrame(history).to_csv(os.path.join(OUT_DIR, "training_history.csv"), index=False)

        print(f"üèÜ Total Training Time: {total_train_time/60:.1f} minutes")
        print(f"üèÜ AdaLoRA updates performed: {adalora_updates}")
        print(f"üèÜ Best checkpoint: {best_ckpt}")
        return best_ckpt, total_train_time

    def evaluate(self, dataset, measure_time=False):
        """Evaluate the model on ``dataset``.

        Returns:
            dict with ``metrics`` (see ``compute_detailed_metrics``),
            ``total_loss`` (mean focal loss over labelled sentences),
            ``classification_report``, ``all_preds``, ``all_trues`` and
            ``sample_count``. With ``measure_time`` it also carries
            ``inference_time`` (seconds) and ``mem_usage``. A dataset
            without any labelled sentence yields an all-zero result that
            has the same keys.
        """
        start_time = time.time()
        self.model.eval()
        # Fixed evaluation batch size of two documents.
        loader = DataLoader(dataset, batch_size=2, shuffle=False,
                            collate_fn=self._collate)
        prototypes_tensor = self.prototype_manager.get_all_tensor(device=self.device)
        all_preds, all_trues = [], []
        running_ce = 0.0
        n_samples = 0

        with torch.no_grad():
            for input_ids, attn_mask, labels, lengths in loader:
                input_ids = input_ids.to(self.device)
                attn_mask = attn_mask.to(self.device)
                labels = labels.to(self.device)
                lengths = lengths.to(self.device)

                knn_sims_tensor, proto_idx_batch = self._knn_features(input_ids, attn_mask)

                logits, _, _ = self.model(
                    input_ids, attn_mask, lengths, prototypes_tensor,
                    proto_idx_batch, knn_sims_tensor
                )

                logits_flat = logits.view(-1, NUM_LABELS)
                labels_flat = labels.view(-1)
                mask = labels_flat != -100
                if mask.sum() == 0:
                    continue

                logits_masked = logits_flat[mask]
                labels_masked = labels_flat[mask]

                ce_loss = focal_loss(logits_masked, labels_masked)
                preds = torch.argmax(logits_masked, dim=1).cpu().numpy()

                all_preds.extend(preds.tolist())
                all_trues.extend(labels_masked.cpu().numpy().tolist())

                n = labels_masked.size(0)
                running_ce += ce_loss.item() * n
                n_samples += n

        if n_samples == 0:
            return self._empty_eval_result()

        avg_ce = running_ce / n_samples
        detailed_metrics = compute_detailed_metrics(all_trues, all_preds)

        cls_report = classification_report(
            [id2label[x] for x in all_trues], [id2label[x] for x in all_preds],
            digits=4, zero_division=0
        )

        result = {
            "metrics": detailed_metrics,
            "total_loss": avg_ce,
            "classification_report": cls_report,
            "all_preds": all_preds, 
            "all_trues": all_trues,
            "sample_count": n_samples
        }

        if measure_time:
            result["inference_time"] = time.time() - start_time
            result["mem_usage"] = get_memory_usage(self.device)

        return result

# ---------------- MAIN ----------------
def main():
    """Run the full pipeline: load data, fit prototypes, train, test, save.

    Steps:
      1. load the three JSONL splits and extract sentences / labels;
      2. embed all training sentences with the frozen pretrained encoder
         and fit one mean prototype per class;
      3. train the model, reload the best checkpoint (if any) and evaluate
         on the test split;
      4. write a one-row metrics summary and the test predictions to
         ``OUT_DIR``.

    Raises:
        FileNotFoundError: if a dataset file is missing.
        ValueError: if no usable training document could be extracted.
    """
    total_start_time = time.time()
    print("üöÄ Loading data...")
    train_docs = load_jsonl(TRAIN_PATH)
    dev_docs = load_jsonl(DEV_PATH)
    test_docs = load_jsonl(TEST_PATH)
    print(f"Dataset sizes - Train: {len(train_docs)}, Dev: {len(dev_docs)}, Test: {len(test_docs)}")

    train_sents, train_labels, _ = extract_data(train_docs)
    dev_sents, dev_labels, _ = extract_data(dev_docs)
    test_sents, test_labels, _ = extract_data(test_docs)

    # Fail early with a clear message instead of crashing inside np.vstack.
    if not train_sents:
        raise ValueError(f"No usable training documents were extracted from {TRAIN_PATH}")

    tokenizer = AutoTokenizer.from_pretrained(INLEGALBERT_MODEL_NAME)
    temp_bert = AutoModel.from_pretrained(INLEGALBERT_MODEL_NAME).to(DEVICE)
    temp_bert.eval()

    print("Computing train prototypes...")
    flat_train_sents = [s for doc in train_sents for s in doc]
    flat_train_labels = np.array([l for doc in train_labels for l in doc], dtype=np.int64)
    train_embs = []
    with torch.no_grad():
        for i in range(0, len(flat_train_sents), 32):
            batch = flat_train_sents[i:i+32]
            enc = tokenizer(batch, padding=True, truncation=True, 
                          max_length=MAX_SEQ_LENGTH, return_tensors="pt").to(DEVICE)
            # Mean over all token positions, the same pooling the model uses.
            out = temp_bert(**enc).last_hidden_state.mean(dim=1)
            train_embs.append(out.cpu().numpy())
    train_embs = np.vstack(train_embs)

    # The frozen encoder was only needed for the prototypes; release it so
    # it does not occupy device memory for the whole training run.
    del temp_bert
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    proto_mgr = ClassPrototypeManager()
    proto_mgr.fit(train_embs, flat_train_labels)

    train_dataset = DocumentTextDataset(train_sents, train_labels)
    dev_dataset = DocumentTextDataset(dev_sents, dev_labels)
    test_dataset = DocumentTextDataset(test_sents, test_labels)

    print("‚úÖ Initializing ProtoHSLNModel with FIXED AdaLoRA...")
    model = ProtoHSLNModel()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"üìä Total params: {total_params:,} | Trainable: {trainable_params:,} "
          f"({100.0 * trainable_params / total_params:.2f}%)")

    trainer = Trainer(model, tokenizer, proto_mgr, device=DEVICE)

    best_ckpt, train_time = trainer.train(train_dataset, dev_dataset, 
                                        proto_weight=PROTO_WEIGHT, rpl_weight=RPL_WEIGHT)

    # Test with the best dev checkpoint rather than the last-epoch weights.
    if best_ckpt:
        ckpt = torch.load(best_ckpt, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        print("‚úÖ Loaded best model")

    print("\nüîç FINAL TEST EVALUATION")
    test_results = trainer.evaluate(test_dataset, measure_time=True)
    test_inference_time = test_results.get("inference_time", 0)

    print(f"‚è±Ô∏è  Test Inference Time: {test_inference_time:.2f}s")
    print(f"üìä Test Memory Usage: {test_results.get('mem_usage', {})}")
    print(f"üìà Test Accuracy:     {test_results['metrics']['accuracy']:.4f}")
    print(f"üéØ Test Macro-F1:     {test_results['metrics']['macro_f1']:.4f}")
    print(f"üîç Minority Macro-F1: {test_results['metrics']['minority_macro_f1']:.4f}")
    print("\nüìã Test Classification Report:")
    print(test_results["classification_report"])

    print("\nüéØ Minority Class F1 Scores:")
    for cls, f1 in sorted(test_results['metrics']['minority_f1_per_class'].items()):
        print(f"  {cls}: {f1:.4f}")

    # One-row summary: timing plus every metric of the test evaluation.
    results_summary = {
        'total_train_time_minutes': train_time / 60,
        'test_inference_time_seconds': test_inference_time,
        **test_results['metrics']
    }
    pd.DataFrame([results_summary]).to_csv(os.path.join(OUT_DIR, "final_results_summary.csv"), index=False)

    pred_df = pd.DataFrame({
        "true": [id2label[x] for x in test_results["all_trues"]],
        "pred": [id2label[x] for x in test_results["all_preds"]],
    })
    pred_df.to_csv(os.path.join(OUT_DIR, "final_test_preds.csv"), index=False)

    total_runtime = time.time() - total_start_time
    print(f"\nüíæ Saved predictions to {OUT_DIR}/final_test_preds.csv")
    print(f"‚è±Ô∏è  Total Runtime: {total_runtime/60:.1f} minutes")
    print(f"üèÜ FIXED AdaLoRA + Imbalance Handling Training COMPLETED SUCCESSFULLY! ‚úÖ")


if __name__ == "__main__":
    main()
