#### Âä®ÊÄÅÂêØÁî®‰∏≠Èó¥ÁõëÁù£

In [32]:
# train_latent_idrr_corrected.py
import os, json, random
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import numpy as np
from typing import List, Dict
from tqdm import tqdm
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup,T5EncoderModel
from sklearn.metrics import f1_score, classification_report, accuracy_score

# -----------------------
# Config
# -----------------------
SEED = 42
EPOCHS = 4
BATCH_SIZE = 8
LR =3e-5
MAX_LEN = 256
MODEL_NAME = "../models/flan-t5-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_DIR = "../datasets/pdtb3_T5"
OUT_DIR = "outputs_latent_idrr_denoise_Intersupervised"
LOOP_K = 3
AUX_WEIGHT = 0.05
NOISE_STD = 0.02
LAMBDA_DENOISE = 0
DENOISE_LAST_ONLY = True
AUX_WARMUP_EPOCHS = 1  # üöÄ Âä®ÊÄÅÂêØÁî®‰∏≠Èó¥ÁõëÁù£
NUM_WORKERS = 0

In [33]:
# -----------------------
# Utilities
# -----------------------
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed()

# -----------------------
# Dataset
# -----------------------
class IDRRDataset(Dataset):
    def __init__(self, path, tokenizer, max_len=MAX_LEN, prompt_prefix="relation classification: Arg1: "):
        self.samples = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    self.samples.append(json.loads(line))
        self.tok = tokenizer; self.max_len = max_len; self.prompt = prompt_prefix

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        s = self.samples[idx]
        text = f"{self.prompt}{s['Arg1']}  Arg2: {s['Arg2']}"
        inputs = self.tok(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors="pt")
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": s["Label"]
        }

def collate_fn(batch, tokenizer):
    return {
        "input_ids": torch.stack([b['input_ids'] for b in batch]),
        "attention_mask": torch.stack([b['attention_mask'] for b in batch]),
        "labels": [b['label'] for b in batch]
    }

def build_label_maps(paths: List[str]):
    labels = set()
    for p in paths:
        with open(p, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    labels.add(json.loads(line)['Label'])
    labels = sorted(list(labels))
    id2label = {i:l for i,l in enumerate(labels)}
    label2id = {l:i for i,l in id2label.items()}
    return label2id, id2label

# -----------------------
# Model components
# -----------------------
class RefinerBlock(nn.Module):
    def __init__(self, d_model, nhead=8, dim_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_ff), nn.GELU(), nn.Linear(dim_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, r, H, H_mask=None):
        key_padding_mask = (H_mask == 0) if H_mask is not None else None
        att_out, _ = self.attn(query=r, key=H, value=H, key_padding_mask=key_padding_mask)
        r = self.ln1(r + self.dropout(att_out))
        r = self.ln2(r + self.dropout(self.ff(r)))
        return r

class LatentIDRRModel(nn.Module):
    def __init__(self, model_name=MODEL_NAME, label2id=None, loop_k=LOOP_K, aux_weight=AUX_WEIGHT, noise_std=NOISE_STD, use_separate_refiners=False):
        super().__init__()
        assert label2id is not None
        base = AutoModel.from_pretrained(model_name)
        # if seq2seq (T5), get encoder
        self.encoder = base.get_encoder() if hasattr(base, "get_encoder") else base
        self.loop_k = loop_k
        self.noise_std = noise_std
        self.aux_weight = aux_weight
        self.d_model = self.encoder.config.hidden_size
        self.use_separate_refiners = use_separate_refiners
        if use_separate_refiners:
            self.refiners = nn.ModuleList([RefinerBlock(d_model=self.d_model) for _ in range(self.loop_k)])
        else:
            self.refiner = RefinerBlock(d_model=self.d_model)
        self.pool_proj = nn.Linear(self.d_model, self.d_model)
        # classifier
        hid = max(self.d_model//2, 32)
        self.classifier = nn.Sequential(nn.Linear(self.d_model, hid), nn.GELU(), nn.LayerNorm(hid), nn.Linear(hid, len(label2id)))
        self.label2id = label2id

    def forward(self, input_ids, attention_mask):
        # ---- 1Ô∏è‚É£ ÁºñÁ†ÅËæìÂÖ• ----
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        H = enc.last_hidden_state   # (B,L,D) ÂéüÂßã‰∏ä‰∏ãÊñáÈöêÁ©∫Èó¥Ë°®ÂæÅ (contextual embeddings)
        mask = attention_mask.unsqueeze(-1)  # (B,L,1)

        # ---- 2Ô∏è‚É£ Âπ≥ÂùáÊ±†ÂåñÂæóÂà∞ÂàùÂßãÂÖ≥Á≥ªË°®Á§∫ r ----
        # r Áõ∏ÂΩì‰∫é ‚ÄúÂàùÂßãÊΩúÂú®ÊÄùÁª¥Áä∂ÊÄÅÔºàinitial latent reasoning stateÔºâ‚Äù
        sum_h = (H * mask).sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1)
        r = (sum_h / lengths).unsqueeze(1)   # (B,1,D)
        r = self.pool_proj(r)

        aux_logits = []

        # =====================================================
        # üöÄ 3Ô∏è‚É£ ËøõÂÖ•‚ÄúÂæ™ÁéØÈöêÂºèÊÄùËÄÉ (Iterative Latent Reasoning Loop)‚Äù
        # =====================================================
        for t in range(self.loop_k):

            # ---- (a) Âô™Â£∞Ê≥®ÂÖ• (Latent Noise Injection) ----
            # Âú®ÊØè‰∏ÄËΩÆÂæ™ÁéØÂâçÔºåÂØπ‰∏ä‰∏ãÊñáË°®ÂæÅ H Âä†ÈöèÊú∫Âô™Â£∞Ôºå
            # ÁõÆÁöÑÊòØÊ®°Êãü‚Äú‰ø°ÊÅØ‰∏çÁ°ÆÂÆöÊÄß‚ÄùÂπ∂Ëø´‰ΩøÊ®°ÂûãÂ≠¶‰ºö‰ªéÂô™Â£∞‰∏≠ÊÅ¢Â§çËØ≠‰πâ„ÄÇ
            # Ëøô‰∏ÄÊ≠•ÊòØ‚ÄúÂéªÂô™Êé®ÁêÜ‚ÄùÁöÑËæìÂÖ•ÂáÜÂ§á„ÄÇ
            if self.noise_std is not None and self.noise_std > 0:
                noisy_H = H + torch.randn_like(H) * (self.noise_std)
            else:
                noisy_H = H

            # ---- (b) Á≤æÂåñÊõ¥Êñ∞ (Refinement Step) ----
            # RefinerBlock ‰ΩøÁî®Ê≥®ÊÑèÂäõÊú∫Âà∂ÔºåËÆ©ÂΩìÂâçÁöÑÂÖ≥Á≥ªÁä∂ÊÄÅ r
            # Âú® noisy_H ÁöÑ‰∏ä‰∏ãÊñá‰∏ãÂæóÂà∞Êõ¥Êñ∞„ÄÇ
            # ÂèØÁêÜËß£‰∏∫ ‚Äúr Âú®Âô™Â£∞Âπ≤Êâ∞ÁöÑÊΩúÁ©∫Èó¥‰∏≠ÈáçÊñ∞ÊÄùËÄÉÂπ∂Ê†°Ê≠£Ëá™Ë∫´ËØ≠‰πâ‚Äù„ÄÇ
            if self.use_separate_refiners:
                r = r + self.refiners[t](r, noisy_H, H_mask=attention_mask)
            else:
                r = r + self.refiner(r, noisy_H, H_mask=attention_mask)

            # ---- (c) ‰∏≠Èó¥ÁõëÁù£ (Auxiliary Supervision) ----
            # Â¶ÇÊûúÂêØÁî®ËæÖÂä©ÁõëÁù£ÔºåÊØè‰∏ÄËΩÆÂæ™ÁéØÈÉΩ‰ºö‰∫ßÁîü‰∏Ä‰∏™È¢ÑÊµãÔºå
            # Áî®‰∫éÈºìÂä±Ê®°ÂûãÂú®ÊØèÊ¨°‚ÄúÊÄùËÄÉ‚ÄùÂêéÈÉΩÊúâÊõ¥Ê∏ÖÊô∞ÁöÑÂÖ≥Á≥ªË°®Á§∫„ÄÇ
            if self.aux_weight is not None and self.aux_weight > 0:
                aux_logits.append(self.classifier(r.squeeze(1)))

        # =====================================================
        # üß© 4Ô∏è‚É£ ÊúÄÁªàÂàÜÁ±ªÔºöËæìÂá∫ÊúÄÂêé‰∏ÄÊ≠•Á≤æÂåñÁªìÊûú
        # =====================================================
        logits = self.classifier(r.squeeze(1))
        return logits, aux_logits



# -----------------------
# Train / Eval
# -----------------------
def encode_labels(labels, label2id):
    return torch.tensor([label2id[l] for l in labels], dtype=torch.long)

def evaluate(model, dataloader, label2id, device):
    model.eval()
    id2label = {v:k for k,v in label2id.items()}
    preds, gts = [], []
    with torch.no_grad():
        for b in tqdm(dataloader, desc="Eval", leave=False):
            input_ids = b['input_ids'].to(device)
            attention_mask = b['attention_mask'].to(device)
            labels_raw = b['labels']
            logits, _ = model(input_ids, attention_mask)
            pred_ids = logits.argmax(dim=-1).cpu().numpy().tolist()
            preds.extend([id2label[i] for i in pred_ids])
            gts.extend(labels_raw)
    macro = f1_score(gts, preds, average='macro', labels=list(label2id.keys()))
    acc = accuracy_score(gts, preds)
    print(classification_report(gts, preds, digits=4))
    return {"macro_f1": macro, "accuracy": acc}

def train(model, train_dl, dev_dl, optimizer, scheduler, label2id, device):
    model.to(device)
    best_f1 = 0.0
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        losses = []
        # linear warm-up of aux: increases from 0 -> model.aux_weight over AUX_WARMUP_EPOCHS
        if AUX_WARMUP_EPOCHS <= 0:
            current_aux_w = float(model.aux_weight or 0.0)
        else:
            current_aux_w = float(model.aux_weight or 0.0) * min(1.0, (epoch + 1) / float(AUX_WARMUP_EPOCHS))
        use_aux = current_aux_w > 0.0

        pbar = tqdm(train_dl, desc=f"Train Epoch {epoch+1} (aux_w={current_aux_w:.4f})")
        for b in pbar:
            input_ids = b['input_ids'].to(device)
            attention_mask = b['attention_mask'].to(device)
            labels_raw = b['labels']
            labels = encode_labels(labels_raw, label2id).to(device)

            logits, aux_logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, labels)
            if use_aux and len(aux_logits) > 0:
                for al in aux_logits:
                    loss = loss + current_aux_w * loss_fn(al, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())
            pbar.set_postfix({"loss": f"{np.mean(losses[-50:]):.4f}"})

        dev_metrics = evaluate(model, dev_dl, label2id, device)
        print(f"Epoch {epoch+1} dev macro-F1: {dev_metrics['macro_f1']:.4f}, acc: {dev_metrics['accuracy']:.4f}")
        if dev_metrics['macro_f1'] > best_f1:
            best_f1 = dev_metrics['macro_f1']
            os.makedirs(OUT_DIR, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(OUT_DIR, "best_model.pt"))
            print("‚úÖ Saved best model.")
    print("Training complete.")

# -----------------------
# Main
# -----------------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    label2id, id2label = build_label_maps([
        os.path.join(DATA_DIR, "train.jsonl"),
        os.path.join(DATA_DIR, "dev.jsonl"),
        os.path.join(DATA_DIR, "test.jsonl")
    ])
    print("Labels:", label2id)

    train_ds = IDRRDataset(os.path.join(DATA_DIR, "train.jsonl"), tokenizer)
    dev_ds = IDRRDataset(os.path.join(DATA_DIR, "dev.jsonl"), tokenizer)
    test_ds = IDRRDataset(os.path.join(DATA_DIR, "test.jsonl"), tokenizer)

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=lambda x: collate_fn(x, tokenizer), num_workers=NUM_WORKERS)
    dev_dl = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=lambda x: collate_fn(x, tokenizer), num_workers=NUM_WORKERS)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=lambda x: collate_fn(x, tokenizer), num_workers=NUM_WORKERS)

    model = LatentIDRRModel(
        model_name=MODEL_NAME,
        label2id=label2id,
        loop_k=LOOP_K,
        aux_weight=AUX_WEIGHT,
        noise_std=NOISE_STD,
        use_separate_refiners=False
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    total_steps = len(train_dl) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max(1,int(0.06*total_steps)), num_training_steps=max(1,total_steps))

    print("Device:", DEVICE)
    train(model, train_dl, dev_dl, optimizer, scheduler, label2id, DEVICE)

    # final test
    best_path = os.path.join(OUT_DIR, "best_model.pt")
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=DEVICE))
        print("\nüß™ Final Test Evaluation")
        test_metrics = evaluate(model, test_dl, label2id, DEVICE)
        print(f"Final Test Macro-F1: {test_metrics['macro_f1']:.4f}, Acc: {test_metrics['accuracy']:.4f}")
    else:
        print("No best model saved, skipping test eval.")


In [2]:
# ----------------------- 
# Utilities
# -----------------------
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed()

# -----------------------
# Dataset
# -----------------------
class IDRRDataset(Dataset):
    def __init__(self, path, tokenizer, max_len=MAX_LEN, prompt_prefix="relation classification: Arg1: "):
        self.samples = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    self.samples.append(json.loads(line))
        self.tok = tokenizer; self.max_len = max_len; self.prompt = prompt_prefix

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        s = self.samples[idx]
        text = f"{self.prompt}{s['Arg1']}  Arg2: {s['Arg2']}"
        inputs = self.tok(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors="pt")
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": s["Label"]
        }

def collate_fn(batch, tokenizer):
    return {
        "input_ids": torch.stack([b['input_ids'] for b in batch]),
        "attention_mask": torch.stack([b['attention_mask'] for b in batch]),
        "labels": [b['label'] for b in batch]
    }

def build_label_maps(paths: List[str]):
    labels = set()
    for p in paths:
        with open(p, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    labels.add(json.loads(line)['Label'])
    labels = sorted(list(labels))
    id2label = {i:l for i,l in enumerate(labels)}
    label2id = {l:i for i,l in id2label.items()}
    return label2id, id2label

# -----------------------
# Model components
# -----------------------
class RefinerBlock(nn.Module):
    def __init__(self, d_model, nhead=8, dim_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_ff), nn.GELU(), nn.Linear(dim_ff, d_model))
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, r, H, H_mask=None):
        key_padding_mask = (H_mask == 0) if H_mask is not None else None
        att_out, _ = self.attn(query=r, key=H, value=H, key_padding_mask=key_padding_mask)
        r = self.ln1(r + self.dropout(att_out))
        r = self.ln2(r + self.dropout(self.ff(r)))
        return r


class LatentIDRRModel(nn.Module):
    def __init__(self, model_name="google/flan-t5-base", num_labels=11,
                 loop_K=2, noise_std=0.02,
                 aux_weight=0.05, lambda_denoise=0.01,
                 denoise_last_only=True, aux_warmup_epochs=1,
                 use_separate_refiners=False):
        super().__init__()

        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.loop_K = loop_K
        self.noise_std = noise_std
        self.aux_weight = aux_weight
        self.lambda_denoise = lambda_denoise
        self.denoise_last_only = denoise_last_only
        self.aux_warmup_epochs = aux_warmup_epochs
        self.use_separate_refiners = use_separate_refiners

        hidden_size = self.encoder.config.d_model
        self.pool_proj = nn.Linear(hidden_size, hidden_size)

        def make_refiner():
            return nn.Sequential(
                nn.Linear(hidden_size * 2, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size)
            )

        if self.use_separate_refiners:
            self.refiners = nn.ModuleList([make_refiner() for _ in range(loop_K)])
        else:
            self.refiner = make_refiner()

        self.classifier = nn.Linear(hidden_size, num_labels)
        self.aux_head = nn.Linear(hidden_size, num_labels)

    def refine_once(self, r, H, H_mask, refiner):
        B, L, D = H.size()
        r_expand = r.expand(B, L, D)
        concat = torch.cat([r_expand, H], dim=-1)
        delta = refiner(concat)
        return (delta * H_mask.unsqueeze(-1)).mean(dim=1, keepdim=True)

    def forward(self, input_ids, attention_mask, labels=None, epoch=0):
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        H = enc.last_hidden_state  # (B, L, D)

        mask = attention_mask.unsqueeze(-1)
        sum_h = (H * mask).sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1)
        r_clean = (sum_h / lengths).unsqueeze(1)
        r_clean_proj = self.pool_proj(r_clean)

        if self.noise_std > 0:
            H_noisy = H + torch.randn_like(H) * self.noise_std
        else:
            H_noisy = H

        r = self.pool_proj(r_clean)
        aux_logits, denoise_preds = [], []

        for t in range(self.loop_K):
            if self.use_separate_refiners:
                r = r + self.refine_once(r, H_noisy, attention_mask, self.refiners[t])
            else:
                r = r + self.refine_once(r, H_noisy, attention_mask, self.refiner)
            denoise_preds.append(r.squeeze(1))

            if (self.aux_weight > 0) and (epoch >= self.aux_warmup_epochs):
                aux_logits.append(self.aux_head(r.squeeze(1)))

        logits = self.classifier(r.squeeze(1))

        if labels is not None:
            loss_main = F.cross_entropy(logits, labels)

            # ‰∏≠Èó¥ÁõëÁù£
            if aux_logits:
                loss_aux = sum(F.cross_entropy(aux, labels) for aux in aux_logits) / len(aux_logits)
            else:
                loss_aux = torch.tensor(0.0, device=logits.device)

            # ÂéªÂô™Á∫¶ÊùüÊçüÂ§±
            if self.lambda_denoise > 0 and len(denoise_preds) > 0:
                if self.denoise_last_only:
                    loss_denoise = F.mse_loss(denoise_preds[-1], r_clean_proj.squeeze(1))
                else:
                    loss_denoise = sum(F.mse_loss(r_t, r_clean_proj.squeeze(1)) for r_t in denoise_preds) / len(denoise_preds)
            else:
                loss_denoise = torch.tensor(0.0, device=logits.device)

            loss = loss_main + self.aux_weight * loss_aux + self.lambda_denoise * loss_denoise

            return {
                "loss": loss,
                "loss_main": loss_main.detach(),
                "loss_aux": loss_aux.detach(),
                "loss_denoise": loss_denoise.detach(),
                "logits": logits
            }

        return {"logits": logits}

# -----------------------
# Train / Eval
# -----------------------
def encode_labels(labels, label2id):
    return torch.tensor([label2id[l] for l in labels], dtype=torch.long)

def evaluate(model, dataloader, label2id, device):
    model.eval()
    id2label = {v:k for k,v in label2id.items()}
    preds, gts = [], []
    with torch.no_grad():
        for b in tqdm(dataloader, desc="Eval", leave=False):
            input_ids = b['input_ids'].to(device)
            attention_mask = b['attention_mask'].to(device)
            labels_raw = b['labels']
            out = model(input_ids, attention_mask)
            logits = out["logits"]
            pred_ids = logits.argmax(dim=-1).cpu().numpy().tolist()
            preds.extend([id2label[i] for i in pred_ids])
            gts.extend(labels_raw)
    macro = f1_score(gts, preds, average='macro', labels=list(label2id.keys()))
    acc = accuracy_score(gts, preds)
    print(classification_report(gts, preds, digits=4))
    return {"macro_f1": macro, "accuracy": acc}

def train(model, train_dl, dev_dl, optimizer, scheduler, label2id, device):
    model.to(device)
    best_f1 = 0.0

    for epoch in range(EPOCHS):
        model.train()
        losses = []
        pbar = tqdm(train_dl, desc=f"Train Epoch {epoch+1}")

        for b in pbar:
            input_ids = b['input_ids'].to(device)
            attention_mask = b['attention_mask'].to(device)
            labels = encode_labels(b['labels'], label2id).to(device)

            out = model(input_ids, attention_mask, labels=labels, epoch=epoch)
            loss = out["loss"]

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            losses.append(loss.item())
            pbar.set_postfix({"loss": f"{np.mean(losses[-50:]):.4f}"})

        dev_metrics = evaluate(model, dev_dl, label2id, device)
        print(f"Epoch {epoch+1} dev macro-F1: {dev_metrics['macro_f1']:.4f}, acc: {dev_metrics['accuracy']:.4f}")
        if dev_metrics['macro_f1'] > best_f1:
            best_f1 = dev_metrics['macro_f1']
            os.makedirs(OUT_DIR, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(OUT_DIR, "best_model.pt"))
            print("‚úÖ Saved best model.")
    print("Training complete.")

# -----------------------
# Main
# -----------------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    label2id, id2label = build_label_maps([
        os.path.join(DATA_DIR, "train.jsonl"),
        os.path.join(DATA_DIR, "dev.jsonl"),
        os.path.join(DATA_DIR, "test.jsonl")
    ])
    print("Labels:", label2id)

    train_ds = IDRRDataset(os.path.join(DATA_DIR, "train.jsonl"), tokenizer)
    dev_ds = IDRRDataset(os.path.join(DATA_DIR, "dev.jsonl"), tokenizer)
    test_ds = IDRRDataset(os.path.join(DATA_DIR, "test.jsonl"), tokenizer)

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=lambda x: collate_fn(x, tokenizer), num_workers=NUM_WORKERS)
    dev_dl = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=lambda x: collate_fn(x, tokenizer), num_workers=NUM_WORKERS)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=lambda x: collate_fn(x, tokenizer), num_workers=NUM_WORKERS)

    model = LatentIDRRModel(
        model_name=MODEL_NAME,
        num_labels=len(label2id),
        loop_K=LOOP_K,
        aux_weight=AUX_WEIGHT,
        noise_std=NOISE_STD,
        lambda_denoise=LAMBDA_DENOISE,
        denoise_last_only=DENOISE_LAST_ONLY,
        aux_warmup_epochs=AUX_WARMUP_EPOCHS,
        use_separate_refiners=False
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    total_steps = len(train_dl) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max(1,int(0.06*total_steps)),
        num_training_steps=max(1,total_steps)
    )

    print("Device:", DEVICE)
    train(model, train_dl, dev_dl, optimizer, scheduler, label2id, DEVICE)

    best_path = os.path.join(OUT_DIR, "best_model.pt")
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=DEVICE))
        print("\nüß™ Final Test Evaluation")
        test_metrics = evaluate(model, test_dl, label2id, DEVICE)
        print(f"Final Test Macro-F1: {test_metrics['macro_f1']:.4f}, Acc: {test_metrics['accuracy']:.4f}")
    else:
        print("No best model saved, skipping test eval.")

In [None]:

if __name__ == "__main__":
    main()


Labels: {'Comparison': 0, 'Contingency': 1, 'Expansion': 2, 'Temporal': 3}
Device: cuda


Train Epoch 1 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:45<00:00,  7.80it/s, loss=1.0429]
                                                     

              precision    recall  f1-score   support

  Comparison     0.5104    0.5326    0.5213        92
 Contingency     0.7783    0.6798    0.7257       253
   Expansion     0.6783    0.8241    0.7441       307
    Temporal     0.5152    0.2394    0.3269        71

    accuracy                         0.6791       723
   macro avg     0.6205    0.5690    0.5795       723
weighted avg     0.6759    0.6791    0.6684       723

Epoch 1 dev macro-F1: 0.5795, acc: 0.6791
‚úÖ Saved best model.


Train Epoch 2 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:35<00:00,  8.11it/s, loss=0.8100]
                                                     

              precision    recall  f1-score   support

  Comparison     0.5934    0.5870    0.5902        92
 Contingency     0.7559    0.7589    0.7574       253
   Expansion     0.7251    0.7818    0.7524       307
    Temporal     0.6383    0.4225    0.5085        71

    accuracy                         0.7137       723
   macro avg     0.6782    0.6375    0.6521       723
weighted avg     0.7106    0.7137    0.7095       723

Epoch 2 dev macro-F1: 0.6521, acc: 0.7137
‚úÖ Saved best model.


Train Epoch 3 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:35<00:00,  8.08it/s, loss=0.7820]
                                                     

              precision    recall  f1-score   support

  Comparison     0.6173    0.5435    0.5780        92
 Contingency     0.7860    0.7115    0.7469       253
   Expansion     0.7060    0.8371    0.7660       307
    Temporal     0.6122    0.4225    0.5000        71

    accuracy                         0.7151       723
   macro avg     0.6804    0.6287    0.6477       723
weighted avg     0.7135    0.7151    0.7093       723

Epoch 3 dev macro-F1: 0.6477, acc: 0.7151


Train Epoch 4 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:35<00:00,  8.09it/s, loss=0.6152]
                                                     

              precision    recall  f1-score   support

  Comparison     0.6044    0.5978    0.6011        92
 Contingency     0.7705    0.7431    0.7565       253
   Expansion     0.7188    0.8078    0.7607       307
    Temporal     0.6512    0.3944    0.4912        71

    accuracy                         0.7178       723
   macro avg     0.6862    0.6358    0.6524       723
weighted avg     0.7157    0.7178    0.7125       723

Epoch 4 dev macro-F1: 0.6524, acc: 0.7178
‚úÖ Saved best model.
Training complete.

üß™ Final Test Evaluation


                                                       

              precision    recall  f1-score   support

  Comparison     0.5735    0.5652    0.5693       138
 Contingency     0.7476    0.6882    0.7167       340
   Expansion     0.7290    0.7958    0.7610       480
    Temporal     0.5405    0.3846    0.4494        52

    accuracy                         0.7069      1010
   macro avg     0.6477    0.6085    0.6241      1010
weighted avg     0.7043    0.7069    0.7038      1010

Final Test Macro-F1: 0.6241, Acc: 0.7069




: 

ÂèÇÊï∞Ôºö
SEED = 42
EPOCHS = 4
BATCH_SIZE = 8
LR = 3e-5
MAX_LEN = 256
MODEL_NAME = "../models/flan-t5-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATA_DIR = "../datasets/pdtb3_T5"
OUT_DIR = "outputs_latent_idrr_denoise_Intersupervised"
LOOP_K = 3
AUX_WEIGHT = 0.05
NOISE_STD = 0.02
LAMBDA_DENOISE = 0
DENOISE_LAST_ONLY = True
AUX_WARMUP_EPOCHS = 1  # üöÄ Âä®ÊÄÅÂêØÁî®‰∏≠Èó¥ÁõëÁù£
NUM_WORKERS = 0


Labels: {'Comparison': 0, 'Contingency': 1, 'Expansion': 2, 'Temporal': 3}
Device: cuda
Train Epoch 1 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:33<00:00,  8.14it/s, loss=1.0429]
                                                     
              precision    recall  f1-score   support

  Comparison     0.5104    0.5326    0.5213        92
 Contingency     0.7783    0.6798    0.7257       253
   Expansion     0.6783    0.8241    0.7441       307
    Temporal     0.5152    0.2394    0.3269        71

    accuracy                         0.6791       723
   macro avg     0.6205    0.5690    0.5795       723
weighted avg     0.6759    0.6791    0.6684       723

Epoch 1 dev macro-F1: 0.5795, acc: 0.6791
‚úÖ Saved best model.
Train Epoch 2 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:34<00:00,  8.14it/s, loss=0.8100]
                                                     
              precision    recall  f1-score   support

  Comparison     0.5934    0.5870    0.5902        92
 Contingency     0.7559    0.7589    0.7574       253
   Expansion     0.7251    0.7818    0.7524       307
    Temporal     0.6383    0.4225    0.5085        71

    accuracy                         0.7137       723
   macro avg     0.6782    0.6375    0.6521       723
weighted avg     0.7106    0.7137    0.7095       723

Epoch 2 dev macro-F1: 0.6521, acc: 0.7137
‚úÖ Saved best model.
Train Epoch 3 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:34<00:00,  8.11it/s, loss=0.7820]
                                                     
              precision    recall  f1-score   support

  Comparison     0.6173    0.5435    0.5780        92
 Contingency     0.7860    0.7115    0.7469       253
   Expansion     0.7060    0.8371    0.7660       307
    Temporal     0.6122    0.4225    0.5000        71

    accuracy                         0.7151       723
   macro avg     0.6804    0.6287    0.6477       723
weighted avg     0.7135    0.7151    0.7093       723

Epoch 3 dev macro-F1: 0.6477, acc: 0.7151
Train Epoch 4 (aux_w=0.0500): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2230/2230 [04:34<00:00,  8.12it/s, loss=0.6152]
                                                     
              precision    recall  f1-score   support

  Comparison     0.6044    0.5978    0.6011        92
 Contingency     0.7705    0.7431    0.7565       253
   Expansion     0.7188    0.8078    0.7607       307
    Temporal     0.6512    0.3944    0.4912        71

    accuracy                         0.7178       723
   macro avg     0.6862    0.6358    0.6524       723
weighted avg     0.7157    0.7178    0.7125       723

Epoch 4 dev macro-F1: 0.6524, acc: 0.7178
‚úÖ Saved best model.
Training complete.

üß™ Final Test Evaluation
                                                       
              precision    recall  f1-score   support

  Comparison     0.5735    0.5652    0.5693       138
 Contingency     0.7476    0.6882    0.7167       340
   Expansion     0.7290    0.7958    0.7610       480
    Temporal     0.5405    0.3846    0.4494        52

    accuracy                         0.7069      1010
   macro avg     0.6477    0.6085    0.6241      1010
weighted avg     0.7043    0.7069    0.7038      1010

Final Test Macro-F1: 0.6241, Acc: 0.7069
