In [None]:
import os, time, pickle, random
import traceback
import numpy as np, pandas as pd, torch
from tqdm import tqdm

from data_loader import load_data
from models import Tramba

def set_seed(seed=42):
    """Seed every RNG this notebook touches and pin cuDNN behaviour.

    Seeds Python's `random`, NumPy's legacy global RNG, PyTorch's CPU RNG and
    the RNGs of all CUDA devices with the same `seed`. It also turns off cuDNN
    autotuning and requests deterministic cuDNN kernels, to favour repeatable runs.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(42)

In [None]:
# ── Configurations ──────────────────────────────────────────
# Sweep grid: every seq_len in SEQ_LIST is paired with every pred_len in PRED_LIST.
SEQ_LIST = [36]
PRED_LIST = [1, 6, 12, 24, 36]

# Model size (Tramba in_features / d_model) and loader/training settings.
N_FEATURE = 3
D_MODEL = 16
BATCH_SIZE = 8
EPOCHS = 10

# Inputs handed to load_data, and the root directory for all saved artifacts.
DATA_PATH = "data/seoul_traffic_speed.csv"
LINK_PATH = "data/GN_links.csv"
BASE_DIR = "results"
os.makedirs(BASE_DIR, exist_ok=True)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)


Device: cuda


In [None]:
def _epoch_loss(model, loader, crit, opt=None, desc=None):
    """Make one pass over `loader` and return the element-weighted mean loss.

    With `opt` given, the pass trains (forward, backward, optimizer step) behind
    a tqdm bar labelled `desc`. Without it, the pass only evaluates, with
    autograd disabled. Each batch loss is weighted by `y.numel()`, which assumes
    `crit` averages over all elements of `y` (true for the default `MSELoss`).
    This stops a smaller final batch from being over-weighted. An empty loader
    returns NaN instead of calling `np.mean([])`.
    """
    training = opt is not None
    model.train(training)  # train(False) is equivalent to eval()
    batches = tqdm(loader, desc=desc, leave=False, ncols=80) if training else loader
    loss_sum, n_elems = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, y in batches:
            x, y = x.to(device), y.to(device)
            loss = crit(model(x), y)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            loss_sum += loss.item() * y.numel()
            n_elems += y.numel()
    return loss_sum / n_elems if n_elems else float("nan")


def _collect_predictions(model, loader):
    """Run `model` over `loader` in eval mode; return stacked (preds, trues) numpy arrays."""
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for x, y in loader:
            preds.append(model(x.to(device)).cpu().numpy())
            trues.append(y.cpu().numpy())
    if not preds:
        raise ValueError("evaluation loader yielded no batches")
    return np.concatenate(preds), np.concatenate(trues)


def _last_step_metrics(preds, trues, scaler):
    """Compute MAPE (%), MAE and MSE on the last output time step, after `scaler.inverse_transform`."""
    if preds.ndim == 4:
        # NOTE(review): assumes (B, T, L, F) with the target in feature 0 — confirm against Tramba.
        preds, trues = preds[:, -1, :, 0], trues[:, -1, :, 0]  # shape (B, L)
    else:
        preds, trues = preds[:, -1, :], trues[:, -1, :]        # shape (B, L)

    def _inverse(a):
        return scaler.inverse_transform(a.reshape(-1, a.shape[-1])).reshape(a.shape)

    pr, tr = _inverse(preds), _inverse(trues)

    # MAPE divides by the target, so near-zero targets are excluded; NaN if none remain.
    mask = np.abs(tr) > 1e-3
    mape = (np.mean(np.abs((tr[mask] - pr[mask]) / tr[mask])) * 100
            if mask.any() else float("nan"))

    return dict(
        MAPE = round(mape, 2),
        MAE  = round(np.mean(np.abs(tr - pr)), 3),
        MSE  = round(np.mean((tr - pr) ** 2), 3),
    )


def train_tramba(model, tr_loader, vl_loader, label, scaler, save_dir, epochs=None, lr=1e-3):
    """Train `model` with Adam + MSE, save its loss history, and score it on `vl_loader`.

    Args:
        model: torch module to train; it is moved to the global `device`.
        tr_loader, vl_loader: iterables of (x, y) tensor batches.
        label: run name, used for the progress bar and the history file name.
        scaler: object whose `inverse_transform` maps 2-D arrays back to original units.
        save_dir: directory that receives `<label>_hist.pkl` with per-epoch
            {"train": [...], "val": [...]} losses.
        epochs: number of epochs; defaults to the notebook-level `EPOCHS`.
        lr: Adam learning rate (default 1e-3, the value previously hard-coded).

    Returns:
        (metrics, model). `metrics` is {"MAPE", "MAE", "MSE"}, computed on the
        last output time step of `vl_loader`. `model` holds the final-epoch
        weights; there is no best-epoch checkpointing.

    NOTE(review): the metrics come from the same loader used for validation
    loss; there is no separate held-out test split here.
    """
    if epochs is None:
        epochs = EPOCHS
    model.to(device)
    opt  = torch.optim.Adam(model.parameters(), lr=lr)
    crit = torch.nn.MSELoss()
    hist = {"train": [], "val": []}

    for ep in range(1, epochs + 1):
        hist["train"].append(_epoch_loss(model, tr_loader, crit, opt, desc=f"{label} {ep}/{epochs}"))
        hist["val"].append(_epoch_loss(model, vl_loader, crit))

    with open(os.path.join(save_dir, f"{label}_hist.pkl"), "wb") as f:
        pickle.dump(hist, f)

    preds, trues = _collect_predictions(model, vl_loader)
    return _last_step_metrics(preds, trues, scaler), model


# ── Run for SEQ_LEN × PRED_LEN combinations ────────────────
# A saved weight file marks a (seq_len, pred_len) run as done; finished runs are skipped.
summary = []
for SEQ_LEN in SEQ_LIST:
    SAVE_DIR = os.path.join(BASE_DIR, f"s{SEQ_LEN}")
    os.makedirs(SAVE_DIR, exist_ok=True)

    for P in PRED_LIST:
        print(f"\n=== Tramba | seq_len {SEQ_LEN} | pred_len {P} ===")

        label   = f"Tramba_p{P}_s{SEQ_LEN}"
        wt_path = os.path.join(SAVE_DIR, f"{label}_wt.pth")

        # Check before load_data so skipped runs don't pay for data loading.
        if os.path.exists(wt_path):
            print(f"{label} exists, skipping.")
            continue

        # Reseed per run: each result is then independent of run order and of
        # which earlier runs were skipped.
        set_seed(42)

        tr_loader, vl_loader, scaler = load_data(
            DATA_PATH, LINK_PATH, seq_len=SEQ_LEN, pred_len=P, batch_size=BATCH_SIZE)

        model = Tramba(d_model=D_MODEL, in_features=N_FEATURE, pred_len=P)

        t0 = time.time()
        try:
            metrics, trained = train_tramba(model, tr_loader, vl_loader, label, scaler, SAVE_DIR)
            runtime = round(time.time() - t0, 1)

            torch.save(trained.state_dict(), wt_path)
            with open(os.path.join(SAVE_DIR, f"{label}_runtime.txt"), "w") as fp:
                fp.write(str(runtime))

            metrics.update(Model="Tramba", pred_len=P, seq_len=SEQ_LEN, Time_s=runtime)
            summary.append(metrics)

            print(f"{label} done ({runtime}s)")
        except Exception:
            # Best-effort sweep: move on to the next horizon, but keep the full
            # traceback visible. No weights are saved, so the run is retried next time.
            print(f"{label} failed:")
            traceback.print_exc()

# ── Save Summary Table ──────────────────────────────────────
# Merge this session's metrics into the cumulative CSV, one row per (Model, pred_len, seq_len).
csv_path = os.path.join(BASE_DIR, "results_tramba.csv")
if summary:
    df = pd.DataFrame(summary)
    if os.path.exists(csv_path):
        # keep="last": on a key collision the row from this session wins. A run
        # only happens when its weight file was missing, so the on-disk row is stale.
        df_all = (
            pd.concat([pd.read_csv(csv_path), df], ignore_index=True)
            .drop_duplicates(subset=["Model", "pred_len", "seq_len"], keep="last")
        )
    else:
        df_all = df
    df_all = df_all.sort_values(["Model", "seq_len", "pred_len"]).reset_index(drop=True)
    df_all.to_csv(csv_path, index=False)
    display(df_all)
    print("saved to", csv_path)
else:
    print("No new runs.")
