In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os
import time
import pickle
import random
import numpy as np
from math import isfinite

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

# ---------------- CONFIG ----------------
DATA_DIR = "/content/drive/MyDrive/labtop2/v1"  # <-- adjust
MAX_LEN = 1024            # must be <= preprocessor max_len
BATCH_SIZE = 32
LR = 5e-4                 # a bit higher for small model
NUM_EPOCHS = 10
PATIENCE = 3
GRAD_CLIP = 1.0

# checkpoint 路径
CKPT_LAST = os.path.join(DATA_DIR, "labtop_checkpoint_last.pt")  # 断点续训用
CKPT_BEST = os.path.join(DATA_DIR, "labtop_best.pth")      # 最优模型，用于 eval


# ---------------- UTILS ----------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ---------------- DATASET ----------------
class LabTOPDataset(Dataset):
    """
    Uses only input_ids/type_ids from your preprocessed pkl.
    Training now ignores type_ids for loss masking (LabTOP-style).
    """
    def __init__(self, data_path):
        with open(data_path, "rb") as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "type_ids": torch.tensor(item["type_ids"], dtype=torch.long),  # kept for potential analysis
        }


# ---------------- COLLATE FN ----------------
def collate_fn(batch, pad_token_id):
    """
    - Truncates to MAX_LEN (from the *right* by default)
    - Pads to the longest length in batch
    - Labels = input_ids shifted internally by GPT2; here we just
      pass input_ids as labels with pad positions set to -100.
    """
    input_ids = [b["input_ids"] for b in batch]

    # truncate sequences (keep last MAX_LEN tokens)
    truncated = []
    for seq in input_ids:
        if len(seq) > MAX_LEN:
            truncated.append(seq[-MAX_LEN:])
        else:
            truncated.append(seq)

    input_ids_pad = torch.nn.utils.rnn.pad_sequence(
        truncated, batch_first=True, padding_value=pad_token_id
    )
    attention_mask = (input_ids_pad != pad_token_id).long()

    labels = input_ids_pad.clone()
    labels[input_ids_pad == pad_token_id] = -100   # ignore pads only

    return input_ids_pad, attention_mask, labels


# ---------------- MODEL (small GPT2-style LM) ----------------
class LabTOPGPT2Small(nn.Module):
    """
    Small GPT-2 style LM:
    - fewer layers / smaller hidden size for your budget
    - still uses GPT2LMHeadModel for correct autoregressive behavior
    """
    def __init__(self, tokenizer, d_model=256, n_heads=4, num_layers=4, max_len=MAX_LEN, dropout=0.1):
        super().__init__()
        vocab_size = len(tokenizer)
        config = GPT2Config(
            vocab_size=vocab_size,
            n_embd=d_model,
            n_head=n_heads,
            n_layer=num_layers,
            n_positions=max_len,
            n_ctx=max_len,
            resid_pdrop=dropout,
            embd_pdrop=dropout,
            attn_pdrop=dropout,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        self.model = GPT2LMHeadModel(config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        GPT2LMHeadModel will:
        - apply causal masking internally
        - compute next-token cross-entropy loss if labels is provided
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs  # has .logits and .loss

    def generate_next_tokens(self, input_ids, attention_mask=None, max_new_tokens=6, bad_ids=None, eos_id=None):
        """
        Simple greedy generation of a few tokens.
        """
        self.eval()
        device = input_ids.device
        generated = []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                if input_ids.size(1) > MAX_LEN:
                    input_ids = input_ids[:, -MAX_LEN:]
                    if attention_mask is not None:
                        attention_mask = attention_mask[:, -MAX_LEN:]

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits[:, -1, :]  # (B, vocab)
                if bad_ids:
                    logits[:, bad_ids] = -1e9

                next_token = torch.argmax(logits, dim=-1)  # (B,)
                next_id = next_token.item()
                generated.append(next_id)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if attention_mask is not None:
                    next_mask_token = torch.ones_like(next_token).unsqueeze(0)
                    attention_mask = torch.cat([attention_mask, next_mask_token], dim=1)

                if eos_id is not None and next_id == eos_id:
                    break

        return generated


# ---------------- DECODE NUMERIC VALUE ----------------
def decode_value(token_ids, tokenizer):
    """
    Char-level decode: keep digits, '.', '-' and parse as float.
    Returns None if parsing fails.
    """
    text = tokenizer.decode(token_ids)
    text = text.replace(" ", "")
    filtered = "".join(ch for ch in text if ch.isdigit() or ch in ".-")
    if filtered == "":
        return None
    try:
        return float(filtered)
    except Exception:
        return None


# ---------------- TRAIN ----------------
def train_model(resume=False):
    """
    完全断点续训版本：
      - 如果 resume=True 且 CKPT_LAST 存在，则从上次保存的 epoch 继续
      - 否则从头开始训练
    """
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(DATA_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    train_dataset = LabTOPDataset(os.path.join(DATA_DIR, "train.pkl"))
    val_dataset   = LabTOPDataset(os.path.join(DATA_DIR, "val.pkl"))
    print("Train dataset size:", len(train_dataset))
    print("Seq length example:", len(train_dataset[0]["input_ids"]))

    def collate_func(batch):
        return collate_fn(batch, pad_token_id=pad_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_func,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_func,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )

    model = LabTOPGPT2Small(tokenizer).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # ---- 断点恢复逻辑 ----
    start_epoch = 0
    best_val_loss = float("inf")
    patience_counter = 0

    if resume and os.path.exists(CKPT_LAST):
        print(f"Resuming from checkpoint: {CKPT_LAST}")
        checkpoint = torch.load(CKPT_LAST, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scaler.load_state_dict(checkpoint["scaler_state_dict"])

        start_epoch      = checkpoint.get("epoch", 0) + 1   # 下一个 epoch
        best_val_loss    = checkpoint.get("best_val_loss", float("inf"))
        patience_counter = checkpoint.get("patience_counter", 0)

        print(
            f"  -> start_epoch = {start_epoch}, "
            f"best_val_loss = {best_val_loss:.4f}, "
            f"patience_counter = {patience_counter}"
        )
    else:
        if resume:
            print(f"resume=True but checkpoint not found at {CKPT_LAST}, training from scratch.")
        else:
            print("Training from scratch.")

    # ---- 训练循环 ----
    for epoch in range(start_epoch, NUM_EPOCHS):
        # ---------- TRAIN ----------
        model.train()
        total_loss = 0.0
        count = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [train]"):
            input_ids, attention_mask, labels = [b.to(device) for b in batch]

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            count += 1

        avg_train_loss = total_loss / max(count, 1)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        val_count = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
                input_ids, attention_mask, labels = [b.to(device) for b in batch]
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                val_loss += loss.item()
                val_count += 1

        avg_val_loss = val_loss / max(val_count, 1)
        print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f}")

        # ---------- 保存最优模型（仅参数，用于 eval） ----------
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), CKPT_BEST)
            print(f"Saved BEST model (val loss={best_val_loss:.4f}) to {CKPT_BEST}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")

        # ---------- 保存完整断点（每个 epoch 都保存） ----------
        last_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_val_loss": best_val_loss,
            "patience_counter": patience_counter,
        }
        torch.save(last_state, CKPT_LAST)
        print(f"Saved training checkpoint to {CKPT_LAST}")

        # ---------- EARLY STOP ----------
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    print("Training complete.")


# ---------------- EVALUATION ----------------
def evaluate_model(resume=False, save_every=10000):
    test_eval_path = os.path.join(DATA_DIR, "test_eval.pkl")

    if not os.path.exists(test_eval_path):
        print("test_eval.pkl not found; nothing to evaluate.")
        return
    if not os.path.exists(CKPT_BEST):
        print(f"Best model checkpoint {CKPT_BEST} not found; train first.")
        return

    # 进度文件，用来断点续 eval
    PROGRESS_PATH = os.path.join(DATA_DIR, "eval_progress.pkl")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(DATA_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    with open(test_eval_path, "rb") as f:
        test_data = pickle.load(f)
    print(f"Total examples in test_eval: {len(test_data)}")

    model = LabTOPGPT2Small(tokenizer).to(device)
    model.load_state_dict(torch.load(CKPT_BEST, map_location=device))
    model.eval()

    use_amp = (device.type == "cuda")

    # tokens we never want as value predictions
    bad_tokens = [
        "labevent", "inputevent", "outputevent",
        "gender", "age", "race",
        "procedureevent", "emarevent", "microevent"
    ]
    vocab = tokenizer.get_vocab()
    bad_ids = [tokenizer.convert_tokens_to_ids(t) for t in bad_tokens if t in vocab]

    eoe_id = tokenizer.convert_tokens_to_ids("[EOE]")
    max_new_tokens = 6

    # --------- 恢复 / 初始化进度 ---------
    if resume and os.path.exists(PROGRESS_PATH):
        print(f"Resuming evaluation from {PROGRESS_PATH}")
        prog = pickle.load(open(PROGRESS_PATH, "rb"))
        predictions     = prog["predictions"]
        ground_truths   = prog["ground_truths"]
        itemids         = prog["itemids"]
        event_types     = prog["event_types"]
        example_indices = prog["example_indices"]
        start_idx       = prog["next_idx"]
    else:
        predictions     = []
        ground_truths   = []
        itemids         = []
        event_types     = []
        example_indices = []
        start_idx       = 0

    print(f"Start evaluating from index {start_idx} ...")

    start_time = time.time()

    for idx in tqdm(range(start_idx, len(test_data)), desc="Evaluating"):
        item = test_data[idx]
        prompt_ids = item["prompt_ids"]
        true_val   = item["valuenum"]
        itemid     = item.get("itemid", None)
        e_type     = item.get("event_type", "unknown")

        # truncate prompt from the left to MAX_LEN
        if len(prompt_ids) > MAX_LEN:
            prompt_ids = prompt_ids[-MAX_LEN:]

        input_tensor = torch.tensor([prompt_ids], dtype=torch.long).to(device)
        attn_mask    = torch.ones_like(input_tensor, dtype=torch.long).to(device)

        # 只前几个样本打印 debug
        if idx < start_idx + 3:
            decoded_prompt_tail = tokenizer.decode(prompt_ids[-80:])
            print(f"\n=== Debug sample {idx} ===")
            print(f"Prompt tail: {decoded_prompt_tail}")
            print(f"True value: {true_val}, itemid: {itemid}, event_type: {e_type}")

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=use_amp):
                generated_ids = model.generate_next_tokens(
                    input_ids=input_tensor,
                    attention_mask=attn_mask,
                    max_new_tokens=max_new_tokens,
                    bad_ids=bad_ids,
                    eos_id=eoe_id,
                )

        pred_val = decode_value(generated_ids, tokenizer)
        if pred_val is None or not isfinite(pred_val) or abs(pred_val) > 1e4:
            # 无效预测，跳过但仍然推进 idx
            continue

        if idx < start_idx + 3:
            decoded_generated = tokenizer.decode(generated_ids)
            print(f"Generated tokens: {decoded_generated}")
            print(f"Decoded pred_val: {pred_val}")

        predictions.append(pred_val)
        ground_truths.append(true_val)
        itemids.append(itemid)
        event_types.append(e_type)
        example_indices.append(idx)

        # --------- 定期保存 eval 进度 ---------
        if (idx + 1) % save_every == 0:
            prog = {
                "predictions": predictions,
                "ground_truths": ground_truths,
                "itemids": itemids,
                "event_types": event_types,
                "example_indices": example_indices,
                "next_idx": idx + 1,  # 下次从这里继续
            }
            with open(PROGRESS_PATH, "wb") as f:
                pickle.dump(prog, f)
            print(f"\n[Checkpoint] Saved eval progress at idx={idx+1} -> {PROGRESS_PATH}")

    # 全部跑完，删除进度文件
    if os.path.exists(PROGRESS_PATH):
        os.remove(PROGRESS_PATH)
        print(f"Removed eval progress file {PROGRESS_PATH}")

    if len(predictions) == 0:
        print("No valid predictions generated.")
        return

    preds  = np.array(predictions)
    truths = np.array(ground_truths)

    mae  = np.mean(np.abs(preds - truths))
    rmse = np.sqrt(np.mean((preds - truths) ** 2))

    p1   = np.percentile(truths, 1)
    p99  = np.percentile(truths, 99)
    range_val = p99 - p1 if p99 > p1 else 1e-6
    nmae = mae / range_val

    denom = (np.abs(truths) + np.abs(preds))
    mask = denom > 0
    if np.sum(mask) > 0:
        smape = np.mean(2 * np.abs(preds[mask] - truths[mask]) / denom[mask])
    else:
        smape = 0.0

    elapsed = time.time() - start_time
    print(f"\nEvaluated {len(preds)} samples in {elapsed:.2f} seconds.")
    print(f"Test MAE:   {mae:.4f}")
    print(f"Test RMSE:  {rmse:.4f}")
    print(f"Test NMAE:  {nmae:.4f}")
    print(f"Test SMAPE: {smape:.4f}")

    # 保存 sample predictions
    csv_path = os.path.join(DATA_DIR, "test_predictions.csv")
    print(f"Saving predictions to: {csv_path}")
    import csv
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "itemid", "event_type", "true_value", "pred_value"])
        for i, iid, et, truth, pred in zip(example_indices, itemids, event_types, truths, preds):
            writer.writerow([i, iid, et, truth, pred])


if __name__ == "__main__":
    # 训练
    RESUME_TRAIN = False
    train_model(resume=RESUME_TRAIN)

    # eval：第一次跑用 resume=False，之后断了想接着跑就改成 True
    # RESUME_EVAL = False
    # evaluate_model(resume=RESUME_EVAL)

Using device: cuda
Train dataset size: 367830
Seq length example: 276


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Training from scratch.


  with torch.cuda.amp.autocast(enabled=use_amp):
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
Epoch 1 [train]:  37%|███▋      | 4278/11495 [08:06<13:40,  8.79it/s]


KeyboardInterrupt: 

In [None]:
import os
import time
import pickle
import random
import numpy as np
from math import isfinite

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

# ---------------- CONFIG ----------------
DATA_DIR = "/content/drive/MyDrive/labtop2/v1"  # <-- adjust
MAX_LEN = 1024            # must be <= preprocessor max_len
BATCH_SIZE = 32
LR = 5e-4                 # a bit higher for small model
NUM_EPOCHS = 10
PATIENCE = 3
GRAD_CLIP = 1.0

# checkpoint 路径
CKPT_LAST = os.path.join(DATA_DIR, "labtop_checkpoint_last.pt")  # 断点续训用
CKPT_BEST = os.path.join(DATA_DIR, "labtop_v1.pth")      # 最优模型，用于 eval


# ---------------- UTILS ----------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ---------------- DATASET ----------------
class LabTOPDataset(Dataset):
    """
    Uses only input_ids/type_ids from your preprocessed pkl.
    Training now ignores type_ids for loss masking (LabTOP-style).
    """
    def __init__(self, data_path):
        with open(data_path, "rb") as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "type_ids": torch.tensor(item["type_ids"], dtype=torch.long),  # kept for potential analysis
        }


# ---------------- COLLATE FN ----------------
def collate_fn(batch, pad_token_id):
    """
    - Truncates to MAX_LEN (from the *right* by default)
    - Pads to the longest length in batch
    - Labels = input_ids shifted internally by GPT2; here we just
      pass input_ids as labels with pad positions set to -100.
    """
    input_ids = [b["input_ids"] for b in batch]

    # truncate sequences (keep last MAX_LEN tokens)
    truncated = []
    for seq in input_ids:
        if len(seq) > MAX_LEN:
            truncated.append(seq[-MAX_LEN:])
        else:
            truncated.append(seq)

    input_ids_pad = torch.nn.utils.rnn.pad_sequence(
        truncated, batch_first=True, padding_value=pad_token_id
    )
    attention_mask = (input_ids_pad != pad_token_id).long()

    labels = input_ids_pad.clone()
    labels[input_ids_pad == pad_token_id] = -100   # ignore pads only

    return input_ids_pad, attention_mask, labels


# ---------------- MODEL (small GPT2-style LM) ----------------
class LabTOPGPT2Small(nn.Module):
    """
    Small GPT-2 style LM:
    - fewer layers / smaller hidden size for your budget
    - still uses GPT2LMHeadModel for correct autoregressive behavior
    """
    def __init__(self, tokenizer, d_model=256, n_heads=4, num_layers=4, max_len=MAX_LEN, dropout=0.1):
        super().__init__()
        vocab_size = len(tokenizer)
        config = GPT2Config(
            vocab_size=vocab_size,
            n_embd=d_model,
            n_head=n_heads,
            n_layer=num_layers,
            n_positions=max_len,
            n_ctx=max_len,
            resid_pdrop=dropout,
            embd_pdrop=dropout,
            attn_pdrop=dropout,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        self.model = GPT2LMHeadModel(config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        GPT2LMHeadModel will:
        - apply causal masking internally
        - compute next-token cross-entropy loss if labels is provided
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs  # has .logits and .loss

    def generate_next_tokens(self, input_ids, attention_mask=None, max_new_tokens=6, bad_ids=None, eos_id=None):
        """
        Simple greedy generation of a few tokens.
        """
        self.eval()
        device = input_ids.device
        generated = []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                if input_ids.size(1) > MAX_LEN:
                    input_ids = input_ids[:, -MAX_LEN:]
                    if attention_mask is not None:
                        attention_mask = attention_mask[:, -MAX_LEN:]

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits[:, -1, :]  # (B, vocab)
                if bad_ids:
                    logits[:, bad_ids] = -1e9

                next_token = torch.argmax(logits, dim=-1)  # (B,)
                next_id = next_token.item()
                generated.append(next_id)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if attention_mask is not None:
                    next_mask_token = torch.ones_like(next_token).unsqueeze(0)
                    attention_mask = torch.cat([attention_mask, next_mask_token], dim=1)

                if eos_id is not None and next_id == eos_id:
                    break

        return generated


# ---------------- DECODE NUMERIC VALUE ----------------
def decode_value(token_ids, tokenizer):
    """
    Char-level decode: keep digits, '.', '-' and parse as float.
    Returns None if parsing fails.
    """
    text = tokenizer.decode(token_ids)
    text = text.replace(" ", "")
    filtered = "".join(ch for ch in text if ch.isdigit() or ch in ".-")
    if filtered == "":
        return None
    try:
        return float(filtered)
    except Exception:
        return None


# ---------------- TRAIN ----------------
def train_model(resume=False):
    """
    完全断点续训版本：
      - 如果 resume=True 且 CKPT_LAST 存在，则从上次保存的 epoch 继续
      - 否则从头开始训练
    """
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(DATA_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    train_dataset = LabTOPDataset(os.path.join(DATA_DIR, "train.pkl"))
    val_dataset   = LabTOPDataset(os.path.join(DATA_DIR, "val.pkl"))
    print("Train dataset size:", len(train_dataset))
    print("Seq length example:", len(train_dataset[0]["input_ids"]))

    def collate_func(batch):
        return collate_fn(batch, pad_token_id=pad_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_func,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_func,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )

    model = LabTOPGPT2Small(tokenizer).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # ---- 断点恢复逻辑 ----
    start_epoch = 0
    best_val_loss = float("inf")
    patience_counter = 0

    if resume and os.path.exists(CKPT_LAST):
        print(f"Resuming from checkpoint: {CKPT_LAST}")
        checkpoint = torch.load(CKPT_LAST, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scaler.load_state_dict(checkpoint["scaler_state_dict"])

        start_epoch      = checkpoint.get("epoch", 0) + 1   # 下一个 epoch
        best_val_loss    = checkpoint.get("best_val_loss", float("inf"))
        patience_counter = checkpoint.get("patience_counter", 0)

        print(
            f"  -> start_epoch = {start_epoch}, "
            f"best_val_loss = {best_val_loss:.4f}, "
            f"patience_counter = {patience_counter}"
        )
    else:
        if resume:
            print(f"resume=True but checkpoint not found at {CKPT_LAST}, training from scratch.")
        else:
            print("Training from scratch.")

    # ---- 训练循环 ----
    for epoch in range(start_epoch, NUM_EPOCHS):
        # ---------- TRAIN ----------
        model.train()
        total_loss = 0.0
        count = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [train]"):
            input_ids, attention_mask, labels = [b.to(device) for b in batch]

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            count += 1

        avg_train_loss = total_loss / max(count, 1)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        val_count = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
                input_ids, attention_mask, labels = [b.to(device) for b in batch]
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                val_loss += loss.item()
                val_count += 1

        avg_val_loss = val_loss / max(val_count, 1)
        print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f}")

        # ---------- 保存最优模型（仅参数，用于 eval） ----------
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), CKPT_BEST)
            print(f"Saved BEST model (val loss={best_val_loss:.4f}) to {CKPT_BEST}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")

        # ---------- 保存完整断点（每个 epoch 都保存） ----------
        last_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_val_loss": best_val_loss,
            "patience_counter": patience_counter,
        }
        torch.save(last_state, CKPT_LAST)
        print(f"Saved training checkpoint to {CKPT_LAST}")

        # ---------- EARLY STOP ----------
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    print("Training complete.")


# ---------------- EVALUATION ----------------
def evaluate_model(resume=False, save_every=10000):
    test_eval_path = os.path.join(DATA_DIR, "test_eval.pkl")

    if not os.path.exists(test_eval_path):
        print("test_eval.pkl not found; nothing to evaluate.")
        return
    if not os.path.exists(CKPT_BEST):
        print(f"Best model checkpoint {CKPT_BEST} not found; train first.")
        return

    # 进度文件，用来断点续 eval
    PROGRESS_PATH = os.path.join(DATA_DIR, "eval_progress.pkl")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(DATA_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    with open(test_eval_path, "rb") as f:
        test_data = pickle.load(f)
    print(f"Total examples in test_eval: {len(test_data)}")

    model = LabTOPGPT2Small(tokenizer).to(device)
    model.load_state_dict(torch.load(CKPT_BEST, map_location=device))
    model.eval()

    use_amp = (device.type == "cuda")

    # tokens we never want as value predictions
    bad_tokens = [
        "labevent", "inputevent", "outputevent",
        "gender", "age", "race",
        "procedureevent", "emarevent", "microevent"
    ]
    vocab = tokenizer.get_vocab()
    bad_ids = [tokenizer.convert_tokens_to_ids(t) for t in bad_tokens if t in vocab]

    eoe_id = tokenizer.convert_tokens_to_ids("[EOE]")
    max_new_tokens = 6

    # --------- 恢复 / 初始化进度 ---------
    if resume and os.path.exists(PROGRESS_PATH):
        print(f"Resuming evaluation from {PROGRESS_PATH}")
        prog = pickle.load(open(PROGRESS_PATH, "rb"))
        predictions     = prog["predictions"]
        ground_truths   = prog["ground_truths"]
        itemids         = prog["itemids"]
        event_types     = prog["event_types"]
        example_indices = prog["example_indices"]
        start_idx       = prog["next_idx"]
    else:
        predictions     = []
        ground_truths   = []
        itemids         = []
        event_types     = []
        example_indices = []
        start_idx       = 0

    print(f"Start evaluating from index {start_idx} ...")

    start_time = time.time()

    for idx in tqdm(range(start_idx, len(test_data)), desc="Evaluating"):
        item = test_data[idx]
        prompt_ids = item["prompt_ids"]
        true_val   = item["valuenum"]
        itemid     = item.get("itemid", None)
        e_type     = item.get("event_type", "unknown")

        # truncate prompt from the left to MAX_LEN
        if len(prompt_ids) > MAX_LEN:
            prompt_ids = prompt_ids[-MAX_LEN:]

        input_tensor = torch.tensor([prompt_ids], dtype=torch.long).to(device)
        attn_mask    = torch.ones_like(input_tensor, dtype=torch.long).to(device)

        # 只前几个样本打印 debug
        if idx < start_idx + 3:
            decoded_prompt_tail = tokenizer.decode(prompt_ids[-80:])
            print(f"\n=== Debug sample {idx} ===")
            print(f"Prompt tail: {decoded_prompt_tail}")
            print(f"True value: {true_val}, itemid: {itemid}, event_type: {e_type}")

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=use_amp):
                generated_ids = model.generate_next_tokens(
                    input_ids=input_tensor,
                    attention_mask=attn_mask,
                    max_new_tokens=max_new_tokens,
                    bad_ids=bad_ids,
                    eos_id=eoe_id,
                )

        pred_val = decode_value(generated_ids, tokenizer)
        if pred_val is None or not isfinite(pred_val) or abs(pred_val) > 1e4:
            # 无效预测，跳过但仍然推进 idx
            continue

        if idx < start_idx + 3:
            decoded_generated = tokenizer.decode(generated_ids)
            print(f"Generated tokens: {decoded_generated}")
            print(f"Decoded pred_val: {pred_val}")

        predictions.append(pred_val)
        ground_truths.append(true_val)
        itemids.append(itemid)
        event_types.append(e_type)
        example_indices.append(idx)

        # --------- 定期保存 eval 进度 ---------
        if (idx + 1) % save_every == 0:
            prog = {
                "predictions": predictions,
                "ground_truths": ground_truths,
                "itemids": itemids,
                "event_types": event_types,
                "example_indices": example_indices,
                "next_idx": idx + 1,  # 下次从这里继续
            }
            with open(PROGRESS_PATH, "wb") as f:
                pickle.dump(prog, f)
            print(f"\n[Checkpoint] Saved eval progress at idx={idx+1} -> {PROGRESS_PATH}")

    # 全部跑完，删除进度文件
    if os.path.exists(PROGRESS_PATH):
        os.remove(PROGRESS_PATH)
        print(f"Removed eval progress file {PROGRESS_PATH}")

    if len(predictions) == 0:
        print("No valid predictions generated.")
        return

    preds  = np.array(predictions)
    truths = np.array(ground_truths)

    mae  = np.mean(np.abs(preds - truths))
    rmse = np.sqrt(np.mean((preds - truths) ** 2))

    p1   = np.percentile(truths, 1)
    p99  = np.percentile(truths, 99)
    range_val = p99 - p1 if p99 > p1 else 1e-6
    nmae = mae / range_val

    denom = (np.abs(truths) + np.abs(preds))
    mask = denom > 0
    if np.sum(mask) > 0:
        smape = np.mean(2 * np.abs(preds[mask] - truths[mask]) / denom[mask])
    else:
        smape = 0.0

    elapsed = time.time() - start_time
    print(f"\nEvaluated {len(preds)} samples in {elapsed:.2f} seconds.")
    print(f"Test MAE:   {mae:.4f}")
    print(f"Test RMSE:  {rmse:.4f}")
    print(f"Test NMAE:  {nmae:.4f}")
    print(f"Test SMAPE: {smape:.4f}")

    # 保存 sample predictions
    csv_path = os.path.join(DATA_DIR, "test_predictions.csv")
    print(f"Saving predictions to: {csv_path}")
    import csv
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "itemid", "event_type", "true_value", "pred_value"])
        for i, iid, et, truth, pred in zip(example_indices, itemids, event_types, truths, preds):
            writer.writerow([i, iid, et, truth, pred])


if __name__ == "__main__":
    # 训练
    # RESUME_TRAIN = False
    # train_model(resume=RESUME_TRAIN)

    # eval：第一次跑用 resume=False，之后断了想接着跑就改成 True
    RESUME_EVAL = False
    evaluate_model(resume=RESUME_EVAL)

Using device: cuda
Total examples in test_eval: 1229791
Start evaluating from index 0 ...


  with torch.cuda.amp.autocast(enabled=use_amp):
Evaluating:   0%|          | 4/1229791 [00:00<10:05:38, 33.84it/s]


=== Debug sample 0 ===
Prompt tail: gender f age 53 race other [DAY0] [SUN] [00h] [20m] labevent base excess
True value: 0.0, itemid: 50802, event_type: labevent
Generated tokens: - 1 meq / l
Decoded pred_val: -1.0

=== Debug sample 1 ===
Prompt tail: gender f age 53 race other [DAY0] [SUN] [00h] [20m] labevent base excess 0 meq / l [EOE] [DAY0] [SUN] [00h] [20m] labevent sodium, whole blood
True value: 137.0, itemid: 50824, event_type: labevent
Generated tokens: 1 3 7 meq /
Decoded pred_val: 137.0

=== Debug sample 2 ===
Prompt tail: gender f age 53 race other [DAY0] [SUN] [00h] [20m] labevent base excess 0 meq / l [EOE] [DAY0] [SUN] [00h] [20m] labevent sodium, whole blood 1 3 7 meq / l [EOE] [DAY0] [SUN] [00h] [20m] labevent potassium, whole blood
True value: 3.4, itemid: 50822, event_type: labevent
Generated tokens: 4. 1 meq /
Decoded pred_val: 4.1


Evaluating:   1%|          | 10003/1229791 [04:45<10:08:08, 33.43it/s]


[Checkpoint] Saved eval progress at idx=10000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   2%|▏         | 20004/1229791 [09:31<9:40:11, 34.75it/s]


[Checkpoint] Saved eval progress at idx=20000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   2%|▏         | 30004/1229791 [14:16<10:41:15, 31.18it/s]


[Checkpoint] Saved eval progress at idx=30000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   3%|▎         | 40005/1229791 [19:05<9:50:30, 33.58it/s]


[Checkpoint] Saved eval progress at idx=40000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   4%|▍         | 50005/1229791 [23:48<10:03:43, 32.57it/s]


[Checkpoint] Saved eval progress at idx=50000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   5%|▍         | 60004/1229791 [28:32<10:11:08, 31.90it/s]


[Checkpoint] Saved eval progress at idx=60000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   6%|▌         | 70005/1229791 [33:17<9:55:27, 32.46it/s] 


[Checkpoint] Saved eval progress at idx=70000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   7%|▋         | 80007/1229791 [38:01<9:44:36, 32.78it/s]


[Checkpoint] Saved eval progress at idx=80000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   7%|▋         | 90005/1229791 [42:45<10:00:51, 31.62it/s]


[Checkpoint] Saved eval progress at idx=90000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   8%|▊         | 100006/1229791 [47:29<9:40:53, 32.42it/s] 


[Checkpoint] Saved eval progress at idx=100000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   9%|▉         | 110007/1229791 [52:09<9:28:55, 32.80it/s]


[Checkpoint] Saved eval progress at idx=110000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  10%|▉         | 120006/1229791 [56:53<9:37:25, 32.03it/s] 


[Checkpoint] Saved eval progress at idx=120000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  11%|█         | 130003/1229791 [1:01:39<9:26:08, 32.38it/s]


[Checkpoint] Saved eval progress at idx=130000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  11%|█▏        | 140005/1229791 [1:06:26<10:11:42, 29.69it/s]


[Checkpoint] Saved eval progress at idx=140000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  12%|█▏        | 150005/1229791 [1:11:09<9:45:19, 30.75it/s] 


[Checkpoint] Saved eval progress at idx=150000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  13%|█▎        | 160004/1229791 [1:15:52<9:08:04, 32.53it/s]


[Checkpoint] Saved eval progress at idx=160000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  14%|█▍        | 170006/1229791 [1:20:35<9:07:51, 32.24it/s]


[Checkpoint] Saved eval progress at idx=170000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  15%|█▍        | 180007/1229791 [1:25:21<9:30:10, 30.69it/s] 


[Checkpoint] Saved eval progress at idx=180000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  15%|█▌        | 190006/1229791 [1:30:07<9:03:41, 31.87it/s]


[Checkpoint] Saved eval progress at idx=190000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  16%|█▋        | 200006/1229791 [1:34:53<9:20:30, 30.62it/s]


[Checkpoint] Saved eval progress at idx=200000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  17%|█▋        | 210005/1229791 [1:39:36<9:20:51, 30.30it/s] 


[Checkpoint] Saved eval progress at idx=210000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  18%|█▊        | 220004/1229791 [1:44:20<9:22:33, 29.92it/s]


[Checkpoint] Saved eval progress at idx=220000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  19%|█▊        | 230004/1229791 [1:49:05<9:35:51, 28.94it/s] 


[Checkpoint] Saved eval progress at idx=230000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  20%|█▉        | 240003/1229791 [1:53:52<9:48:52, 28.01it/s]


[Checkpoint] Saved eval progress at idx=240000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  20%|██        | 250004/1229791 [1:58:35<9:05:30, 29.93it/s]


[Checkpoint] Saved eval progress at idx=250000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  21%|██        | 260006/1229791 [2:03:19<9:11:58, 29.28it/s]


[Checkpoint] Saved eval progress at idx=260000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  22%|██▏       | 270007/1229791 [2:08:06<9:02:12, 29.50it/s]


[Checkpoint] Saved eval progress at idx=270000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  23%|██▎       | 280003/1229791 [2:12:50<10:22:29, 25.43it/s]


[Checkpoint] Saved eval progress at idx=280000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  24%|██▎       | 290003/1229791 [2:17:35<9:48:09, 26.63it/s]


[Checkpoint] Saved eval progress at idx=290000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  24%|██▍       | 300004/1229791 [2:22:19<8:23:02, 30.81it/s]


[Checkpoint] Saved eval progress at idx=300000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  25%|██▌       | 310004/1229791 [2:27:02<8:05:45, 31.56it/s]


[Checkpoint] Saved eval progress at idx=310000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  26%|██▌       | 320004/1229791 [2:31:44<9:13:15, 27.41it/s]


[Checkpoint] Saved eval progress at idx=320000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  27%|██▋       | 330008/1229791 [2:36:31<8:06:55, 30.80it/s]


[Checkpoint] Saved eval progress at idx=330000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  28%|██▊       | 340006/1229791 [2:41:12<8:18:00, 29.78it/s]


[Checkpoint] Saved eval progress at idx=340000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  28%|██▊       | 350005/1229791 [2:45:57<8:27:55, 28.87it/s]


[Checkpoint] Saved eval progress at idx=350000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  29%|██▉       | 360005/1229791 [2:50:43<8:54:58, 27.10it/s]


[Checkpoint] Saved eval progress at idx=360000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  30%|███       | 370005/1229791 [2:55:28<8:25:24, 28.35it/s]


[Checkpoint] Saved eval progress at idx=370000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  31%|███       | 380005/1229791 [3:00:15<8:39:28, 27.26it/s]


[Checkpoint] Saved eval progress at idx=380000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  32%|███▏      | 390003/1229791 [3:05:01<9:35:29, 24.32it/s]


[Checkpoint] Saved eval progress at idx=390000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  33%|███▎      | 400005/1229791 [3:09:45<8:26:28, 27.31it/s]


[Checkpoint] Saved eval progress at idx=400000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  33%|███▎      | 410006/1229791 [3:14:24<8:06:10, 28.10it/s]


[Checkpoint] Saved eval progress at idx=410000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  34%|███▍      | 420006/1229791 [3:19:06<8:28:29, 26.54it/s]


[Checkpoint] Saved eval progress at idx=420000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  35%|███▍      | 430004/1229791 [3:23:46<7:36:24, 29.21it/s]


[Checkpoint] Saved eval progress at idx=430000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  36%|███▌      | 440005/1229791 [3:28:24<8:13:46, 26.66it/s]


[Checkpoint] Saved eval progress at idx=440000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  37%|███▋      | 450003/1229791 [3:33:01<8:43:33, 24.82it/s]


[Checkpoint] Saved eval progress at idx=450000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  37%|███▋      | 460005/1229791 [3:37:41<7:50:25, 27.27it/s]


[Checkpoint] Saved eval progress at idx=460000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  38%|███▊      | 470006/1229791 [3:42:27<7:32:02, 28.01it/s]


[Checkpoint] Saved eval progress at idx=470000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  39%|███▉      | 480007/1229791 [3:47:14<7:38:50, 27.23it/s]


[Checkpoint] Saved eval progress at idx=480000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  40%|███▉      | 490004/1229791 [3:52:00<8:21:49, 24.57it/s]


[Checkpoint] Saved eval progress at idx=490000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  41%|████      | 500004/1229791 [3:56:49<7:54:39, 25.63it/s]


[Checkpoint] Saved eval progress at idx=500000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  41%|████▏     | 510004/1229791 [4:01:38<7:52:46, 25.37it/s]


[Checkpoint] Saved eval progress at idx=510000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  42%|████▏     | 520007/1229791 [4:06:23<7:17:43, 27.03it/s]


[Checkpoint] Saved eval progress at idx=520000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  43%|████▎     | 530005/1229791 [4:11:13<7:10:22, 27.10it/s]


[Checkpoint] Saved eval progress at idx=530000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  44%|████▍     | 540005/1229791 [4:16:01<7:36:26, 25.19it/s]


[Checkpoint] Saved eval progress at idx=540000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  45%|████▍     | 550004/1229791 [4:20:46<7:40:57, 24.58it/s]


[Checkpoint] Saved eval progress at idx=550000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  46%|████▌     | 560003/1229791 [4:25:33<8:22:51, 22.20it/s]


[Checkpoint] Saved eval progress at idx=560000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  46%|████▋     | 570003/1229791 [4:30:24<8:07:03, 22.58it/s]


[Checkpoint] Saved eval progress at idx=570000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  47%|████▋     | 580006/1229791 [4:35:11<7:06:16, 25.41it/s]


[Checkpoint] Saved eval progress at idx=580000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  48%|████▊     | 590003/1229791 [4:40:00<8:24:00, 21.16it/s]


[Checkpoint] Saved eval progress at idx=590000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  49%|████▉     | 600005/1229791 [4:44:47<7:10:37, 24.37it/s]


[Checkpoint] Saved eval progress at idx=600000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  50%|████▉     | 610007/1229791 [4:49:35<6:48:26, 25.29it/s]


[Checkpoint] Saved eval progress at idx=610000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  50%|█████     | 620003/1229791 [4:54:25<8:02:52, 21.05it/s]


[Checkpoint] Saved eval progress at idx=620000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  51%|█████     | 630004/1229791 [4:59:14<6:45:36, 24.65it/s]


[Checkpoint] Saved eval progress at idx=630000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  52%|█████▏    | 640004/1229791 [5:04:04<6:31:51, 25.09it/s]


[Checkpoint] Saved eval progress at idx=640000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  53%|█████▎    | 650004/1229791 [5:08:52<6:46:44, 23.76it/s]


[Checkpoint] Saved eval progress at idx=650000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  54%|█████▎    | 660003/1229791 [5:13:42<7:27:38, 21.21it/s]


[Checkpoint] Saved eval progress at idx=660000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  54%|█████▍    | 670003/1229791 [5:18:30<7:36:26, 20.44it/s]


[Checkpoint] Saved eval progress at idx=670000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  55%|█████▌    | 680006/1229791 [5:23:19<6:27:44, 23.63it/s]


[Checkpoint] Saved eval progress at idx=680000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  56%|█████▌    | 690004/1229791 [5:28:05<6:09:55, 24.32it/s]


[Checkpoint] Saved eval progress at idx=690000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  57%|█████▋    | 700003/1229791 [5:32:50<7:20:45, 20.03it/s]


[Checkpoint] Saved eval progress at idx=700000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  58%|█████▊    | 710006/1229791 [5:37:37<5:54:02, 24.47it/s]


[Checkpoint] Saved eval progress at idx=710000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  59%|█████▊    | 720004/1229791 [5:42:17<5:58:34, 23.70it/s]


[Checkpoint] Saved eval progress at idx=720000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  59%|█████▉    | 730004/1229791 [5:46:57<5:59:32, 23.17it/s]


[Checkpoint] Saved eval progress at idx=730000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  60%|██████    | 740004/1229791 [5:51:38<5:38:41, 24.10it/s]


[Checkpoint] Saved eval progress at idx=740000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  61%|██████    | 750007/1229791 [5:56:19<5:29:44, 24.25it/s]


[Checkpoint] Saved eval progress at idx=750000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  62%|██████▏   | 760005/1229791 [6:00:58<5:32:27, 23.55it/s]


[Checkpoint] Saved eval progress at idx=760000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  63%|██████▎   | 770003/1229791 [6:05:38<6:25:08, 19.90it/s]


[Checkpoint] Saved eval progress at idx=770000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  63%|██████▎   | 780004/1229791 [6:10:26<5:27:29, 22.89it/s]


[Checkpoint] Saved eval progress at idx=780000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  64%|██████▍   | 790004/1229791 [6:15:17<5:36:43, 21.77it/s]


[Checkpoint] Saved eval progress at idx=790000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  65%|██████▌   | 800006/1229791 [6:20:04<5:08:06, 23.25it/s]


[Checkpoint] Saved eval progress at idx=800000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  66%|██████▌   | 810006/1229791 [6:24:50<4:59:11, 23.38it/s]


[Checkpoint] Saved eval progress at idx=810000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  67%|██████▋   | 820005/1229791 [6:29:36<4:51:41, 23.41it/s]


[Checkpoint] Saved eval progress at idx=820000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  67%|██████▋   | 830003/1229791 [6:34:24<5:44:37, 19.33it/s]


[Checkpoint] Saved eval progress at idx=830000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  68%|██████▊   | 840003/1229791 [6:39:11<5:23:49, 20.06it/s]


[Checkpoint] Saved eval progress at idx=840000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  69%|██████▉   | 850003/1229791 [6:44:00<5:07:50, 20.56it/s]


[Checkpoint] Saved eval progress at idx=850000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  70%|██████▉   | 860004/1229791 [6:48:47<4:47:50, 21.41it/s]


[Checkpoint] Saved eval progress at idx=860000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  71%|███████   | 870005/1229791 [6:53:36<4:13:59, 23.61it/s]


[Checkpoint] Saved eval progress at idx=870000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  72%|███████▏  | 880007/1229791 [6:58:23<4:25:08, 21.99it/s]


[Checkpoint] Saved eval progress at idx=880000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  72%|███████▏  | 890006/1229791 [7:03:13<4:24:48, 21.39it/s]


[Checkpoint] Saved eval progress at idx=890000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  73%|███████▎  | 900003/1229791 [7:08:03<4:41:45, 19.51it/s]


[Checkpoint] Saved eval progress at idx=900000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  74%|███████▍  | 910004/1229791 [7:12:53<4:08:38, 21.44it/s]


[Checkpoint] Saved eval progress at idx=910000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  75%|███████▍  | 920005/1229791 [7:17:40<4:02:25, 21.30it/s]


[Checkpoint] Saved eval progress at idx=920000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  76%|███████▌  | 930006/1229791 [7:22:29<3:52:19, 21.51it/s]


[Checkpoint] Saved eval progress at idx=930000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  76%|███████▋  | 940003/1229791 [7:27:18<4:10:25, 19.29it/s]


[Checkpoint] Saved eval progress at idx=940000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  77%|███████▋  | 950007/1229791 [7:32:07<3:40:20, 21.16it/s]


[Checkpoint] Saved eval progress at idx=950000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  78%|███████▊  | 960005/1229791 [7:36:56<3:37:01, 20.72it/s]


[Checkpoint] Saved eval progress at idx=960000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  79%|███████▉  | 970003/1229791 [7:41:46<4:02:18, 17.87it/s]


[Checkpoint] Saved eval progress at idx=970000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  79%|███████▉  | 975457/1229791 [7:44:24<2:04:17, 34.10it/s]

In [3]:
import os
import time
import pickle
import random
import numpy as np
from math import isfinite

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

# ---------------- CONFIG ----------------
DATA_DIR = "/content/drive/MyDrive/labtop2/v1"  # <-- adjust
MAX_LEN = 1024            # must be <= preprocessor max_len
BATCH_SIZE = 32
LR = 5e-4                 # a bit higher for small model
NUM_EPOCHS = 10
PATIENCE = 3
GRAD_CLIP = 1.0

# checkpoint 路径
CKPT_LAST = os.path.join(DATA_DIR, "labtop_checkpoint_last.pt")  # 断点续训用
CKPT_BEST = os.path.join(DATA_DIR, "labtop_v1.pth")      # 最优模型，用于 eval


# ---------------- UTILS ----------------
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ---------------- DATASET ----------------
class LabTOPDataset(Dataset):
    """
    Uses only input_ids/type_ids from your preprocessed pkl.
    Training now ignores type_ids for loss masking (LabTOP-style).
    """
    def __init__(self, data_path):
        with open(data_path, "rb") as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "type_ids": torch.tensor(item["type_ids"], dtype=torch.long),  # kept for potential analysis
        }


# ---------------- COLLATE FN ----------------
def collate_fn(batch, pad_token_id):
    """
    - Truncates to MAX_LEN (from the *right* by default)
    - Pads to the longest length in batch
    - Labels = input_ids shifted internally by GPT2; here we just
      pass input_ids as labels with pad positions set to -100.
    """
    input_ids = [b["input_ids"] for b in batch]

    # truncate sequences (keep last MAX_LEN tokens)
    truncated = []
    for seq in input_ids:
        if len(seq) > MAX_LEN:
            truncated.append(seq[-MAX_LEN:])
        else:
            truncated.append(seq)

    input_ids_pad = torch.nn.utils.rnn.pad_sequence(
        truncated, batch_first=True, padding_value=pad_token_id
    )
    attention_mask = (input_ids_pad != pad_token_id).long()

    labels = input_ids_pad.clone()
    labels[input_ids_pad == pad_token_id] = -100   # ignore pads only

    return input_ids_pad, attention_mask, labels


# ---------------- MODEL (small GPT2-style LM) ----------------
class LabTOPGPT2Small(nn.Module):
    """
    Small GPT-2 style LM:
    - fewer layers / smaller hidden size for your budget
    - still uses GPT2LMHeadModel for correct autoregressive behavior
    """
    def __init__(self, tokenizer, d_model=256, n_heads=4, num_layers=4, max_len=MAX_LEN, dropout=0.1):
        super().__init__()
        vocab_size = len(tokenizer)
        config = GPT2Config(
            vocab_size=vocab_size,
            n_embd=d_model,
            n_head=n_heads,
            n_layer=num_layers,
            n_positions=max_len,
            n_ctx=max_len,
            resid_pdrop=dropout,
            embd_pdrop=dropout,
            attn_pdrop=dropout,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        self.model = GPT2LMHeadModel(config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        GPT2LMHeadModel will:
        - apply causal masking internally
        - compute next-token cross-entropy loss if labels is provided
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs  # has .logits and .loss

    def generate_next_tokens(self, input_ids, attention_mask=None, max_new_tokens=6, bad_ids=None, eos_id=None):
        """
        Simple greedy generation of a few tokens.
        """
        self.eval()
        device = input_ids.device
        generated = []

        with torch.no_grad():
            for _ in range(max_new_tokens):
                if input_ids.size(1) > MAX_LEN:
                    input_ids = input_ids[:, -MAX_LEN:]
                    if attention_mask is not None:
                        attention_mask = attention_mask[:, -MAX_LEN:]

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits[:, -1, :]  # (B, vocab)
                if bad_ids:
                    logits[:, bad_ids] = -1e9

                next_token = torch.argmax(logits, dim=-1)  # (B,)
                next_id = next_token.item()
                generated.append(next_id)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if attention_mask is not None:
                    next_mask_token = torch.ones_like(next_token).unsqueeze(0)
                    attention_mask = torch.cat([attention_mask, next_mask_token], dim=1)

                if eos_id is not None and next_id == eos_id:
                    break

        return generated


# ---------------- DECODE NUMERIC VALUE ----------------
def decode_value(token_ids, tokenizer):
    """
    Char-level decode: keep digits, '.', '-' and parse as float.
    Returns None if parsing fails.
    """
    text = tokenizer.decode(token_ids)
    text = text.replace(" ", "")
    filtered = "".join(ch for ch in text if ch.isdigit() or ch in ".-")
    if filtered == "":
        return None
    try:
        return float(filtered)
    except Exception:
        return None


# ---------------- TRAIN ----------------
def train_model(resume=False):
    """
    完全断点续训版本：
      - 如果 resume=True 且 CKPT_LAST 存在，则从上次保存的 epoch 继续
      - 否则从头开始训练
    """
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(DATA_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    train_dataset = LabTOPDataset(os.path.join(DATA_DIR, "train.pkl"))
    val_dataset   = LabTOPDataset(os.path.join(DATA_DIR, "val.pkl"))
    print("Train dataset size:", len(train_dataset))
    print("Seq length example:", len(train_dataset[0]["input_ids"]))

    def collate_func(batch):
        return collate_fn(batch, pad_token_id=pad_id)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_func,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_func,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )

    model = LabTOPGPT2Small(tokenizer).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # ---- 断点恢复逻辑 ----
    start_epoch = 0
    best_val_loss = float("inf")
    patience_counter = 0

    if resume and os.path.exists(CKPT_LAST):
        print(f"Resuming from checkpoint: {CKPT_LAST}")
        checkpoint = torch.load(CKPT_LAST, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scaler.load_state_dict(checkpoint["scaler_state_dict"])

        start_epoch      = checkpoint.get("epoch", 0) + 1   # 下一个 epoch
        best_val_loss    = checkpoint.get("best_val_loss", float("inf"))
        patience_counter = checkpoint.get("patience_counter", 0)

        print(
            f"  -> start_epoch = {start_epoch}, "
            f"best_val_loss = {best_val_loss:.4f}, "
            f"patience_counter = {patience_counter}"
        )
    else:
        if resume:
            print(f"resume=True but checkpoint not found at {CKPT_LAST}, training from scratch.")
        else:
            print("Training from scratch.")

    # ---- 训练循环 ----
    for epoch in range(start_epoch, NUM_EPOCHS):
        # ---------- TRAIN ----------
        model.train()
        total_loss = 0.0
        count = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [train]"):
            input_ids, attention_mask, labels = [b.to(device) for b in batch]

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            count += 1

        avg_train_loss = total_loss / max(count, 1)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        val_count = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
                input_ids, attention_mask, labels = [b.to(device) for b in batch]
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                val_loss += loss.item()
                val_count += 1

        avg_val_loss = val_loss / max(val_count, 1)
        print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f}")

        # ---------- 保存最优模型（仅参数，用于 eval） ----------
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), CKPT_BEST)
            print(f"Saved BEST model (val loss={best_val_loss:.4f}) to {CKPT_BEST}")
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")

        # ---------- 保存完整断点（每个 epoch 都保存） ----------
        last_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_val_loss": best_val_loss,
            "patience_counter": patience_counter,
        }
        torch.save(last_state, CKPT_LAST)
        print(f"Saved training checkpoint to {CKPT_LAST}")

        # ---------- EARLY STOP ----------
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    print("Training complete.")


# ---------------- EVALUATION ----------------
def evaluate_model(resume=False, save_every=10000):
    test_eval_path = os.path.join(DATA_DIR, "test_eval.pkl")

    if not os.path.exists(test_eval_path):
        print("test_eval.pkl not found; nothing to evaluate.")
        return
    if not os.path.exists(CKPT_BEST):
        print(f"Best model checkpoint {CKPT_BEST} not found; train first.")
        return

    # 进度文件，用来断点续 eval
    PROGRESS_PATH = os.path.join(DATA_DIR, "eval_progress.pkl")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(DATA_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    with open(test_eval_path, "rb") as f:
        test_data = pickle.load(f)
    print(f"Total examples in test_eval: {len(test_data)}")

    model = LabTOPGPT2Small(tokenizer).to(device)
    model.load_state_dict(torch.load(CKPT_BEST, map_location=device))
    model.eval()

    use_amp = (device.type == "cuda")

    # tokens we never want as value predictions
    bad_tokens = [
        "labevent", "inputevent", "outputevent",
        "gender", "age", "race",
        "procedureevent", "emarevent", "microevent"
    ]
    vocab = tokenizer.get_vocab()
    bad_ids = [tokenizer.convert_tokens_to_ids(t) for t in bad_tokens if t in vocab]

    eoe_id = tokenizer.convert_tokens_to_ids("[EOE]")
    max_new_tokens = 6

    # --------- 恢复 / 初始化进度 ---------
    if resume and os.path.exists(PROGRESS_PATH):
        print(f"Resuming evaluation from {PROGRESS_PATH}")
        prog = pickle.load(open(PROGRESS_PATH, "rb"))
        predictions     = prog["predictions"]
        ground_truths   = prog["ground_truths"]
        itemids         = prog["itemids"]
        event_types     = prog["event_types"]
        example_indices = prog["example_indices"]
        start_idx       = prog["next_idx"]
    else:
        predictions     = []
        ground_truths   = []
        itemids         = []
        event_types     = []
        example_indices = []
        start_idx       = 0

    print(f"Start evaluating from index {start_idx} ...")

    start_time = time.time()

    for idx in tqdm(range(start_idx, len(test_data)), desc="Evaluating"):
        item = test_data[idx]
        prompt_ids = item["prompt_ids"]
        true_val   = item["valuenum"]
        itemid     = item.get("itemid", None)
        e_type     = item.get("event_type", "unknown")

        # truncate prompt from the left to MAX_LEN
        if len(prompt_ids) > MAX_LEN:
            prompt_ids = prompt_ids[-MAX_LEN:]

        input_tensor = torch.tensor([prompt_ids], dtype=torch.long).to(device)
        attn_mask    = torch.ones_like(input_tensor, dtype=torch.long).to(device)

        # 只前几个样本打印 debug
        if idx < start_idx + 3:
            decoded_prompt_tail = tokenizer.decode(prompt_ids[-80:])
            print(f"\n=== Debug sample {idx} ===")
            print(f"Prompt tail: {decoded_prompt_tail}")
            print(f"True value: {true_val}, itemid: {itemid}, event_type: {e_type}")

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=use_amp):
                generated_ids = model.generate_next_tokens(
                    input_ids=input_tensor,
                    attention_mask=attn_mask,
                    max_new_tokens=max_new_tokens,
                    bad_ids=bad_ids,
                    eos_id=eoe_id,
                )

        pred_val = decode_value(generated_ids, tokenizer)
        if pred_val is None or not isfinite(pred_val) or abs(pred_val) > 1e4:
            # 无效预测，跳过但仍然推进 idx
            continue

        if idx < start_idx + 3:
            decoded_generated = tokenizer.decode(generated_ids)
            print(f"Generated tokens: {decoded_generated}")
            print(f"Decoded pred_val: {pred_val}")

        predictions.append(pred_val)
        ground_truths.append(true_val)
        itemids.append(itemid)
        event_types.append(e_type)
        example_indices.append(idx)

        # --------- 定期保存 eval 进度 ---------
        if (idx + 1) % save_every == 0:
            prog = {
                "predictions": predictions,
                "ground_truths": ground_truths,
                "itemids": itemids,
                "event_types": event_types,
                "example_indices": example_indices,
                "next_idx": idx + 1,  # 下次从这里继续
            }
            with open(PROGRESS_PATH, "wb") as f:
                pickle.dump(prog, f)
            print(f"\n[Checkpoint] Saved eval progress at idx={idx+1} -> {PROGRESS_PATH}")

    # 全部跑完，删除进度文件
    if os.path.exists(PROGRESS_PATH):
        os.remove(PROGRESS_PATH)
        print(f"Removed eval progress file {PROGRESS_PATH}")

    if len(predictions) == 0:
        print("No valid predictions generated.")
        return

    preds  = np.array(predictions)
    truths = np.array(ground_truths)

    mae  = np.mean(np.abs(preds - truths))
    rmse = np.sqrt(np.mean((preds - truths) ** 2))

    p1   = np.percentile(truths, 1)
    p99  = np.percentile(truths, 99)
    range_val = p99 - p1 if p99 > p1 else 1e-6
    nmae = mae / range_val

    denom = (np.abs(truths) + np.abs(preds))
    mask = denom > 0
    if np.sum(mask) > 0:
        smape = np.mean(2 * np.abs(preds[mask] - truths[mask]) / denom[mask])
    else:
        smape = 0.0

    elapsed = time.time() - start_time
    print(f"\nEvaluated {len(preds)} samples in {elapsed:.2f} seconds.")
    print(f"Test MAE:   {mae:.4f}")
    print(f"Test RMSE:  {rmse:.4f}")
    print(f"Test NMAE:  {nmae:.4f}")
    print(f"Test SMAPE: {smape:.4f}")

    # 保存 sample predictions
    csv_path = os.path.join(DATA_DIR, "test_predictions.csv")
    print(f"Saving predictions to: {csv_path}")
    import csv
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "itemid", "event_type", "true_value", "pred_value"])
        for i, iid, et, truth, pred in zip(example_indices, itemids, event_types, truths, preds):
            writer.writerow([i, iid, et, truth, pred])


if __name__ == "__main__":
    # 训练
    # RESUME_TRAIN = False
    # train_model(resume=RESUME_TRAIN)

    # eval：第一次跑用 resume=False，之后断了想接着跑就改成 True
    RESUME_EVAL = True
    evaluate_model(resume=RESUME_EVAL)

Using device: cuda
Total examples in test_eval: 1229791
Resuming evaluation from /content/drive/MyDrive/labtop2/v1/eval_progress.pkl
Start evaluating from index 970000 ...


Evaluating:   0%|          | 0/259791 [00:00<?, ?it/s]


=== Debug sample 970000 ===
Prompt tail: [10m] labevent bicarbonate 2 1 meq / l [EOE] [DAY0] [SUN] [09h] [10m] labevent anion gap 1 7 meq / l [EOE] [DAY0] [SUN] [09h] [10m] labevent potassium 4 meq / l [EOE] [DAY0] [SUN] [09h] [10m] labevent troponin t 0. 0 5 ng / ml [EOE] [DAY0] [SUN] [09h] [10m] labevent urea nitrogen 5 4 mg / dl [EOE] [DAY0] [SUN] [09h] [10m] labevent sodium
True value: 146.0, itemid: 50983, event_type: labevent


  with torch.cuda.amp.autocast(enabled=use_amp):
Evaluating:   0%|          | 4/259791 [00:00<11:45:50,  6.13it/s]

Generated tokens: 1 4 1 meq /
Decoded pred_val: 141.0

=== Debug sample 970001 ===
Prompt tail: ##roponin t 0. 0 5 ng / ml [EOE] [DAY0] [SUN] [09h] [10m] labevent urea nitrogen 5 4 mg / dl [EOE] [DAY0] [SUN] [09h] [10m] labevent sodium 1 4 6 meq / l [EOE] [DAY0] [SUN] [09h] [20m] inputevent solution 1 1. 1 7 ml [EOE] [DAY0] [SUN] [09h] [20m] inputevent propofol 1 1 1. 7 2 mg [EOE] [DAY0] [SUN] [09h] [20m] labevent oxygen saturation
True value: 97.0, itemid: 50817, event_type: labevent
Generated tokens: 9 8 % [EOE]
Decoded pred_val: 98.0

=== Debug sample 970002 ===
Prompt tail: ##l [EOE] [DAY0] [SUN] [09h] [10m] labevent urea nitrogen 5 4 mg / dl [EOE] [DAY0] [SUN] [09h] [10m] labevent sodium 1 4 6 meq / l [EOE] [DAY0] [SUN] [09h] [20m] inputevent solution 1 1. 1 7 ml [EOE] [DAY0] [SUN] [09h] [20m] inputevent propofol 1 1 1. 7 2 mg [EOE] [DAY0] [SUN] [09h] [20m] labevent oxygen saturation 9 7 % [EOE] [DAY0] [SUN] [09h] [20m] labevent free calcium
True value: 1.13, itemid: 50808, event_

Evaluating:   4%|▍         | 10007/259791 [04:30<2:40:07, 26.00it/s]


[Checkpoint] Saved eval progress at idx=980000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:   8%|▊         | 20006/259791 [08:59<2:43:22, 24.46it/s]


[Checkpoint] Saved eval progress at idx=990000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  12%|█▏        | 30007/259791 [13:26<2:26:11, 26.20it/s]


[Checkpoint] Saved eval progress at idx=1000000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  15%|█▌        | 40007/259791 [17:55<2:19:32, 26.25it/s]


[Checkpoint] Saved eval progress at idx=1010000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  19%|█▉        | 50005/259791 [22:23<2:04:29, 28.09it/s]


[Checkpoint] Saved eval progress at idx=1020000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  23%|██▎       | 60003/259791 [26:52<2:32:37, 21.82it/s]


[Checkpoint] Saved eval progress at idx=1030000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  27%|██▋       | 70003/259791 [31:17<2:25:20, 21.76it/s]


[Checkpoint] Saved eval progress at idx=1040000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  31%|███       | 80006/259791 [35:43<1:58:55, 25.19it/s]


[Checkpoint] Saved eval progress at idx=1050000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  35%|███▍      | 90005/259791 [40:10<1:59:29, 23.68it/s]


[Checkpoint] Saved eval progress at idx=1060000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  38%|███▊      | 100005/259791 [44:37<1:43:41, 25.68it/s]


[Checkpoint] Saved eval progress at idx=1070000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  42%|████▏     | 110008/259791 [49:05<1:31:07, 27.40it/s]


[Checkpoint] Saved eval progress at idx=1080000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  46%|████▌     | 120007/259791 [53:34<1:33:59, 24.79it/s]


[Checkpoint] Saved eval progress at idx=1090000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  50%|█████     | 130003/259791 [58:06<1:39:25, 21.76it/s]


[Checkpoint] Saved eval progress at idx=1100000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  54%|█████▍    | 140008/259791 [1:02:33<1:20:24, 24.83it/s]


[Checkpoint] Saved eval progress at idx=1110000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  58%|█████▊    | 150006/259791 [1:07:01<1:02:25, 29.31it/s]


[Checkpoint] Saved eval progress at idx=1120000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  62%|██████▏   | 160004/259791 [1:11:30<1:10:20, 23.64it/s]


[Checkpoint] Saved eval progress at idx=1130000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  65%|██████▌   | 170007/259791 [1:15:56<58:41, 25.50it/s]  


[Checkpoint] Saved eval progress at idx=1140000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  69%|██████▉   | 180005/259791 [1:20:23<52:26, 25.36it/s]


[Checkpoint] Saved eval progress at idx=1150000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  73%|███████▎  | 190005/259791 [1:24:49<45:12, 25.73it/s]


[Checkpoint] Saved eval progress at idx=1160000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  77%|███████▋  | 200005/259791 [1:29:15<38:27, 25.91it/s]


[Checkpoint] Saved eval progress at idx=1170000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  81%|████████  | 210007/259791 [1:33:39<31:17, 26.52it/s]


[Checkpoint] Saved eval progress at idx=1180000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  85%|████████▍ | 220006/259791 [1:38:04<27:33, 24.06it/s]


[Checkpoint] Saved eval progress at idx=1190000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  89%|████████▊ | 230006/259791 [1:42:33<20:38, 24.04it/s]


[Checkpoint] Saved eval progress at idx=1200000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  92%|█████████▏| 240006/259791 [1:47:00<13:52, 23.76it/s]


[Checkpoint] Saved eval progress at idx=1210000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating:  96%|█████████▌| 250004/259791 [1:51:27<07:03, 23.10it/s]


[Checkpoint] Saved eval progress at idx=1220000 -> /content/drive/MyDrive/labtop2/v1/eval_progress.pkl


Evaluating: 100%|██████████| 259791/259791 [1:55:49<00:00, 37.38it/s]


Removed eval progress file /content/drive/MyDrive/labtop2/v1/eval_progress.pkl

Evaluated 1229124 samples in 6950.21 seconds.
Test MAE:   35.0470
Test RMSE:  1321.8479
Test NMAE:  0.0767
Test SMAPE: 0.2935
Saving predictions to: /content/drive/MyDrive/labtop2/v1/test_predictions.csv


In [3]:
import torch, os

path = "/content/drive/MyDrive/labtop2/v1/labtop_v1.pth"  # 改成你想查的
state = torch.load(path, map_location="cpu")
print(state.keys())

ckpt = torch.load(path, map_location="cpu")
wte = ckpt["model.transformer.wte.weight"]
wpe = ckpt["model.transformer.wpe.weight"]

print("wte shape:", wte.shape)  # (vocab_size, d_model)
print("wpe shape:", wpe.shape)  # (max_positions, d_model)

odict_keys(['model.transformer.wte.weight', 'model.transformer.wpe.weight', 'model.transformer.h.0.ln_1.weight', 'model.transformer.h.0.ln_1.bias', 'model.transformer.h.0.attn.c_attn.weight', 'model.transformer.h.0.attn.c_attn.bias', 'model.transformer.h.0.attn.c_proj.weight', 'model.transformer.h.0.attn.c_proj.bias', 'model.transformer.h.0.ln_2.weight', 'model.transformer.h.0.ln_2.bias', 'model.transformer.h.0.mlp.c_fc.weight', 'model.transformer.h.0.mlp.c_fc.bias', 'model.transformer.h.0.mlp.c_proj.weight', 'model.transformer.h.0.mlp.c_proj.bias', 'model.transformer.h.1.ln_1.weight', 'model.transformer.h.1.ln_1.bias', 'model.transformer.h.1.attn.c_attn.weight', 'model.transformer.h.1.attn.c_attn.bias', 'model.transformer.h.1.attn.c_proj.weight', 'model.transformer.h.1.attn.c_proj.bias', 'model.transformer.h.1.ln_2.weight', 'model.transformer.h.1.ln_2.bias', 'model.transformer.h.1.mlp.c_fc.weight', 'model.transformer.h.1.mlp.c_fc.bias', 'model.transformer.h.1.mlp.c_proj.weight', 'mode

In [2]:
import os
import pickle
from collections import Counter
from transformers import AutoTokenizer

OUT_DIR = "/content/drive/MyDrive/labtop2/v1"

def load_pkl(name):
    with open(os.path.join(OUT_DIR, name), "rb") as f:
        return pickle.load(f)

# 1. 看一下有哪些 pkl，大小
# for fn in os.listdir(OUT_DIR):
#     if fn.endswith(".pkl"):
#         path = os.path.join(OUT_DIR, fn)
#         print(fn, os.path.getsize(path) / 1024 / 1024, "MB")

# 2. 读一小段 train / val
train = load_pkl("train.pkl")[:200]
val   = load_pkl("val.pkl")[:200]

print("sample train size:", len(train))

# 3. 检查每条的 keys、长度、type_ids 分布
print("keys in one example:", train[0].keys())
print("len(input_ids):", len(train[0]["input_ids"]))
print("len(type_ids):", len(train[0]["type_ids"]))

from collections import Counter
c = Counter()
for x in train:
    c.update(x["type_ids"])
print("type_ids counts:", dict(c))

# 4. 解码几条看看文本大概长什么样
tok = AutoTokenizer.from_pretrained(OUT_DIR)
for i in range(3):
    ids = train[i]["input_ids"][:300]
    print(f"\n=== train sample {i} ===")
    print(tok.decode(ids))

sample train size: 200
keys in one example: dict_keys(['stay_id', 'input_ids', 'type_ids'])
len(input_ids): 276
len(type_ids): 276
type_ids counts: {0: 147166, 1: 36570}

=== train sample 0 ===
gender f age 52 race white [DAY0] [SUN] [00h] [20m] procedureevent 18 gauge 5 6 6 min [EOE] [DAY0] [SUN] [00h] [20m] procedureevent 20 gauge 5 6 6 min [EOE] [DAY0] [SUN] [01h] [00m] outputevent void 1 7 5 ml [EOE] [DAY0] [SUN] [03h] [00m] inputevent albumin 25 % 5 0 ml [EOE] [DAY0] [SUN] [03h] [00m] inputevent po intake 2 0 0 ml [EOE] [DAY0] [SUN] [03h] [30m] inputevent albumin 25 % 5 0 ml [EOE] [DAY0] [SUN] [04h] [50m] inputevent po intake 1 0 0 ml [EOE] [DAY0] [SUN] [07h] [10m] inputevent po intake 1 0 0 ml [EOE] [DAY0] [SUN] [07h] [40m] labevent urea nitrogen 3 3 mg / dl [EOE] [DAY0] [SUN] [07h] [40m] labevent anion gap 1 4 meq / l [EOE] [DAY0] [SUN] [07h] [40m] labevent phosphate 2. 4 mg / dl [EOE] [DAY0] [SUN] [07h] [40m] labevent magnesium 2. 3 mg / dl [EOE] [DAY0] [SUN] [07h] [40m] labeve