In [1]:
%cd /home/ec2-user
import glob

matches = glob.glob("data/cleaned_partitioned_ais/*/*/*/*.parquet")
print(f"Found {len(matches)} parquet files:")
for p in matches[:10]:
    print(" ", p)

/home/ec2-user
Found 2 parquet files:
  data/cleaned_partitioned_ais/year=2025/month=2/day=2/part-0.parquet
  data/cleaned_partitioned_ais/year=2025/month=2/day=1/part-0.parquet


In [2]:
%cd /home/ec2-user
# ─── 1) Setup imports & module path ──────────────────────────────────────────
import sys
from pathlib import Path
import time

# add your Fourier head module to path
fourier_dir = Path.home() / "repos" / "fourier-head" / "notebooks"
sys.path.insert(0, str(fourier_dir))

import polars as pl
import torch
from torch import nn, optim
from torch.utils.data import IterableDataset, DataLoader

from four_head_2D_LR import FourierHead2DLR  # your FourierHead2D_FFT implementation

# ─── 2) Data loading helper ─────────────────────────────────────────────────
def load_cleaned_data(cleaned_root: str) -> pl.DataFrame:
    print("🟡 Entering load_cleaned_data()", flush=True)
    p = Path(cleaned_root).resolve()
    print(f"    Looking under {p}", flush=True)
    files = list(p.rglob("*.parquet"))
    print(f"    Found {len(files)} parquet files", flush=True)
    if not files:
        raise FileNotFoundError(f"No parquet files under {p}")
    print("    Reading into Polars...", flush=True)
    df = pl.read_parquet([str(f) for f in files])
    print(f"🟢 Loaded DataFrame: {df.height} rows, {len(df.columns)} cols", flush=True)
    return df

# ─── 3) Streaming dataset ───────────────────────────────────────────────────
class AISForecastIterableDataset(IterableDataset):
    """
    Streams windows of past->next positions, one vessel at a time.
    """
    def __init__(self, df: pl.DataFrame, seq_len: int = 10):
        print("🟡 Initializing streaming dataset", flush=True)
        start = time.time()
        # normalize coordinates
        df = df.with_columns([
            (pl.col("lat") / 90.0).alias("lat_n"),
            (pl.col("lon") / 180.0).alias("lon_n"),
        ])
        self.df = df.sort(["mmsi", "timestamp"])
        self.mmsis = self.df["mmsi"].unique().to_list()
        self.seq_len = seq_len
        print(f"🟢 Dataset init: {len(self.mmsis)} vessels in {time.time()-start:.1f}s", flush=True)

    def __iter__(self):
        for idx, m in enumerate(self.mmsis, 1):
            grp = self.df.filter(pl.col("mmsi") == m).select(["lat_n", "lon_n"])
            coords = torch.tensor(grp.to_numpy(), dtype=torch.float32)
            N = coords.size(0)
            if N <= self.seq_len:
                continue
            for i in range(self.seq_len, N):
                past   = coords[i-self.seq_len:i]  # (seq_len,2)
                target = coords[i]                 # (2,)
                yield past, target
            if idx % 500 == 0:
                print(f"    [Dataset] streamed {idx}/{len(self.mmsis)} vessels", flush=True)

# ─── 4) Model definition ────────────────────────────────────────────────────
class TransformerForecaster(nn.Module):
    def __init__(self, seq_len: int, d_model: int, nhead: int,
                 num_layers: int, ff_hidden: int, fourier_m: int, rank: int):
        super().__init__()
        print("🟡 Building model", flush=True)
        self.input_proj = nn.Linear(2, d_model)
        self.pos_emb = nn.Parameter(torch.randn(seq_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, ff_hidden)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fh = FourierHead2DLR(dim_input=d_model, num_frequencies=fourier_m,rank=rank)
        print("🟢 Model built", flush=True)

    def forward(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # x: (B, seq_len, 2), targets: (B, 2)
        h = self.input_proj(x) + self.pos_emb.unsqueeze(0)   # (B,S,d_model)
        h = self.transformer(h.transpose(0,1))               # (S,B,d_model)
        last = h[-1]                                         # (B,d_model)
        return self.fh(last, targets)                        # (B,)

# ─── 5) Training loop ───────────────────────────────────────────────────────
def train(
    df: pl.DataFrame,
    seq_len: int = 10,
    d_model: int = 64,
    nhead: int = 4,
    num_layers: int = 2,
    ff_hidden: int = 128,
    fourier_m: int = 8,
    rank: int = 4,
    batch_size: int = 64,
    lr: float = 1e-6,
    epochs: int = 5,
    device: str = None
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🟡 Training on device: {device}", flush=True)

    ds = AISForecastIterableDataset(df, seq_len=seq_len)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False,
                        drop_last=True, num_workers=0, pin_memory=True)
    print(f"🟢 DataLoader ready with batch_size={batch_size}", flush=True)

    model = TransformerForecaster(seq_len, d_model, nhead, num_layers, ff_hidden, fourier_m, rank)
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=lr)

    for ep in range(1, epochs+1):
        print(f"\n🟡 Epoch {ep}/{epochs} start", flush=True)
        total_nll = 0.0
        count = 0
        for batch_i, (xb, yb) in enumerate(loader, 1):
            xb, yb = xb.to(device), yb.to(device)
            pdf = model(xb, yb)
            loss = -(pdf + 1e-12).log().mean()
            opt.zero_grad(); loss.backward(); opt.step()

            total_nll += loss.item() * xb.size(0)
            count += xb.size(0)
            if batch_i % 50 == 0:
                print(f"    [Epoch {ep}] Batch {batch_i} NLL {loss.item():.4f}", flush=True)
        avg_nll = total_nll / count
        print(f"🟢 Epoch {ep} done — Avg NLL: {avg_nll:.4f}", flush=True)

    return model

# ─── 6) Run everything ───────────────────────────────────────────────────────
print("🟢 Cell start: loading data", flush=True)
df = load_cleaned_data("data/cleaned_partitioned_ais")

print("🟢 Starting training", flush=True)
model = train(
    df,
    seq_len=20,
    d_model=128,
    nhead=4,
    num_layers=4,
    ff_hidden=512,
    fourier_m=512,
    batch_size=64,
    lr=1e-3,
    epochs=100,
    rank = 4
)
print("🟢 Training complete", flush=True)


/home/ec2-user
🟢 Cell start: loading data
🟡 Entering load_cleaned_data()
    Looking under /home/ec2-user/data/cleaned_partitioned_ais
    Found 2 parquet files
    Reading into Polars...
🟢 Loaded DataFrame: 19082312 rows, 9 cols
🟢 Starting training
🟡 Training on device: cuda
🟡 Initializing streaming dataset
🟢 Dataset init: 3866 vessels in 2.8s
🟢 DataLoader ready with batch_size=64
🟡 Building model




🟢 Model built

🟡 Epoch 1/100 start
    [Epoch 1] Batch 50 NLL -6.5805
    [Epoch 1] Batch 100 NLL -8.7749
    [Epoch 1] Batch 150 NLL -9.7512
    [Epoch 1] Batch 200 NLL -10.2883
    [Epoch 1] Batch 250 NLL -10.6036
    [Epoch 1] Batch 300 NLL -5.0510
    [Epoch 1] Batch 350 NLL -9.2293
    [Epoch 1] Batch 400 NLL -10.1182
    [Epoch 1] Batch 450 NLL -10.5253
    [Epoch 1] Batch 500 NLL -10.7483
    [Epoch 1] Batch 550 NLL -10.8768
    [Epoch 1] Batch 600 NLL -8.7911
    [Epoch 1] Batch 650 NLL -10.1051
    [Epoch 1] Batch 700 NLL -10.5682
    [Epoch 1] Batch 750 NLL -10.7944
    [Epoch 1] Batch 800 NLL -10.9186
    [Epoch 1] Batch 850 NLL -7.2878
    [Epoch 1] Batch 900 NLL -9.8364
    [Epoch 1] Batch 950 NLL -10.4597
    [Epoch 1] Batch 1000 NLL -10.7429
    [Epoch 1] Batch 1050 NLL -10.8915
    [Epoch 1] Batch 1100 NLL -10.9725
    [Epoch 1] Batch 1150 NLL -11.0187


KeyboardInterrupt: 

In [1]:
%%bash
# ─── cell: end-to-end local-window training ────────────────────────────────
# Run this once; it replaces the previous “global PDF” cell.
set -euo pipefail
python - <<'PY'
# ------------------------------------------------------------
# 0) Imports  (all standard – nothing beyond PyTorch & Polars)
# ------------------------------------------------------------
import sys, time, math
from pathlib import Path

import polars as pl
import torch
from torch import nn, optim
from torch.utils.data import IterableDataset, DataLoader

# Make sure the Fourier-head repo is on path
fourier_dir = Path.home() / "repos" / "fourier-head" / "notebooks"
sys.path.insert(0, str(fourier_dir))
from four_head_2D_LR import FourierHead2DLR   # low-rank Fourier head

# ============================================================
# 1) Utility: lat/lon  →  local (u,v)  in [-1,1]²
# ============================================================
def latlon_to_local_uv(lat, lon, lat0, lon0, half_side_mi=50.0):
    """
    lat, lon, lat0, lon0  : tensors in **degrees**  (any shape …)
    Returns:
        uv    : (..., 2)  where each component is in [-1,1]  (clipped)
        logJ  : (...)    log-Jacobian Δ(area_u,v) → Δ(area_lat,lon)
    """
    lat_rad   = torch.deg2rad(lat)
    lat0_rad  = torch.deg2rad(lat0)

    # miles per degree (≈ 69 mi) and longitude scaling with cos(lat)
    R_mi      = 69.0
    dx_mi     = R_mi * torch.cos(lat0_rad) * (lon - lon0)
    dy_mi     = R_mi * (lat - lat0)

    u         = dx_mi / half_side_mi
    v         = dy_mi / half_side_mi

    # ---------- Jacobian |∂(x,y)/∂(u,v)| = (half_side_mi)²
    # plus lat scaling factors (R cosφ, R)
    logJ      = (
        -2 * math.log(half_side_mi)
        - torch.log(R_mi * torch.cos(lat0_rad))
        - math.log(R_mi)
    )

    # ----- Option A: clip targets that wander outside the window
    u_clipped = torch.clamp(u, -1.0, 1.0)
    v_clipped = torch.clamp(v, -1.0, 1.0)
    uv        = torch.stack([u_clipped, v_clipped], dim=-1)
    return uv, logJ

# ============================================================
# 2) Load all cleaned-AIS parquet shards
# ============================================================
def load_cleaned_data(cleaned_root: str) -> pl.DataFrame:
    p = Path(cleaned_root).resolve()
    files = list(p.rglob("*.parquet"))
    if not files:
        raise FileNotFoundError(f"No parquet files under {p}")
    print(f"🟢 Reading {len(files)} parquet file(s)…", flush=True)
    return pl.read_parquet([str(f) for f in files])

# ============================================================
# 3) Streaming dataset  (local-window version)
# ============================================================
class AISForecastIterableDataset(IterableDataset):
    def __init__(self, df: pl.DataFrame, seq_len: int = 10, half_side_mi=50.0):
        start = time.time()
        self.df = df.sort(["mmsi", "timestamp"])
        self.mmsis = self.df["mmsi"].unique().to_list()
        self.seq_len = seq_len
        self.half_side_mi = half_side_mi
        print(f"🟢 Dataset: {len(self.mmsis)} vessels loaded in {time.time()-start:.1f}s", flush=True)

    def __iter__(self):
        for m in self.mmsis:
            grp = (
                self.df
                .filter(pl.col("mmsi") == m)
                .select(["lat", "lon"])
            )
            coords = torch.tensor(grp.to_numpy(), dtype=torch.float32)  # (N,2)
            N = coords.size(0)
            if N <= self.seq_len + 1:
                continue

            # rolling windows
            for i in range(self.seq_len, N - 1):
                past_abs = coords[i - self.seq_len : i]   # (S,2)
                target_abs = coords[i]                    # (2,)

                lat0, lon0 = past_abs[-1]                 # window centre
                past_uv, _     = latlon_to_local_uv(
                    past_abs[:, 0], past_abs[:, 1],
                    lat0, lon0, self.half_side_mi
                )
                target_uv, logJ = latlon_to_local_uv(
                    target_abs[0], target_abs[1],
                    lat0, lon0, self.half_side_mi
                )

                centre_ll_norm = torch.tensor([lat0 / 90.0, lon0 / 180.0])

                yield (
                    past_uv,              # (S,2)  local motion
                    centre_ll_norm,       # (2,)    global context
                    target_uv,            # (2,)    local target
                    logJ                  # ()      log-Jacobian
                )

# ============================================================
# 4) Model  (local motion + global context)
# ============================================================
class TransformerForecaster(nn.Module):
    def __init__(
        self,
        seq_len: int,
        d_model: int,
        nhead: int,
        num_layers: int,
        ff_hidden: int,
        fourier_m: int,
        rank: int,
    ):
        super().__init__()
        self.seq_len = seq_len

        # each timestep now has 4 features: (u,v, lat_norm, lon_norm)
        self.input_proj = nn.Linear(4, d_model)
        self.pos_emb = nn.Parameter(torch.randn(seq_len, d_model))
        enc_layer = nn.TransformerEncoderLayer(d_model, nhead, ff_hidden)
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers)
        self.fh = FourierHead2DLR(
            dim_input=d_model,
            num_frequencies=fourier_m,
            rank=rank
        )

    def forward(self, past_uv, centre_ll, target_uv):
        """
        past_uv   : (B, S, 2)
        centre_ll : (B, 2)
        target_uv : (B, 2)
        """
        B, S, _ = past_uv.shape
        # repeat centre_lat/lon to every timestep
        centre_rep = centre_ll.unsqueeze(1).expand(-1, S, -1)  # (B,S,2)
        x = torch.cat([past_uv, centre_rep], dim=-1)           # (B,S,4)
        h = self.input_proj(x) + self.pos_emb.unsqueeze(0)     # (B,S,d)
        h = self.transformer(h.transpose(0, 1))                # (S,B,d)
        last = h[-1]                                           # (B,d)
        return self.fh(last, target_uv)                        # (B,)

# ============================================================
# 5) Training loop
# ============================================================
def train(
    df: pl.DataFrame,
    seq_len: int = 10,
    d_model: int = 128,
    nhead: int = 4,
    num_layers: int = 4,
    ff_hidden: int = 512,
    fourier_m: int = 512,
    rank: int = 4,
    batch_size: int = 64,
    lr: float = 1e-3,
    epochs: int = 5,
    device: str | None = None,
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🟡 Training on {device}", flush=True)

    ds = AISForecastIterableDataset(df, seq_len=seq_len)
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False,
        drop_last=True, num_workers=0, pin_memory=True
    )
    model = TransformerForecaster(
        seq_len, d_model, nhead, num_layers,
        ff_hidden, fourier_m, rank
    ).to(device)
    opt = optim.Adam(model.parameters(), lr=lr)

    for ep in range(1, epochs + 1):
        total_nll, count = 0.0, 0
        print(f"\n🟡 Epoch {ep}/{epochs}", flush=True)

        for b_idx, (past_uv, centre_ll, target_uv, logJ) in enumerate(loader, 1):
            past_uv   = past_uv.to(device)
            centre_ll = centre_ll.to(device)
            target_uv = target_uv.to(device)
            logJ      = logJ.to(device)

            pdf_uv = model(past_uv, centre_ll, target_uv)      # (B,)
            # log-pdf in lat/lon space = log(pdf_uv) + logJ
            loss   = -(torch.log(pdf_uv + 1e-12) + logJ).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_nll += loss.item() * past_uv.size(0)
            count     += past_uv.size(0)
            if b_idx % 50 == 0:
                print(f"    Batch {b_idx}  NLL {loss.item():.4f}", flush=True)

        print(f"🟢 Epoch {ep}  Avg NLL: {total_nll / count:.4f}", flush=True)

    return model

# ============================================================
# 6) Execute
# ============================================================
df = load_cleaned_data("data/cleaned_partitioned_ais")
model = train(
    df,
    seq_len      = 20,
    d_model      = 128,
    nhead        = 4,
    num_layers   = 4,
    ff_hidden    = 512,
    fourier_m    = 64,     # local window → far fewer freqs needed
    rank         = 4,
    batch_size   = 64,
    lr           = 1e-3,
    epochs       = 100,
)
print("✅ Training complete", flush=True)
PY


Traceback (most recent call last):
  File "<stdin>", line 169, in <module>
TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'


CalledProcessError: Command 'b'# \xe2\x94\x80\xe2\x94\x80\xe2\x94\x80 cell: end-to-end local-window training \xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\xe2\x94\x80\n# Run this once; it replaces the previous \xe2\x80\x9cglobal PDF\xe2\x80\x9d cell.\nset -euo pipefail\npython - <<\'PY\'\n# ------------------------------------------------------------\n# 0) Imports  (all standard \xe2\x80\x93 nothing beyond PyTorch & Polars)\n# ------------------------------------------------------------\nimport sys, time, math\nfrom pathlib import Path\n\nimport polars as pl\nimport torch\nfrom torch import nn, optim\nfrom torch.utils.data import IterableDataset, DataLoader\n\n# Make sure the Fourier-head repo is on path\nfourier_dir = Path.home() / "repos" / "fourier-head" / "notebooks"\nsys.path.insert(0, str(fourier_dir))\nfrom four_head_2D_LR import FourierHead2DLR   # low-rank Fourier head\n\n# ============================================================\n# 1) Utility: lat/lon  \xe2\x86\x92  local (u,v)  in [-1,1]\xc2\xb2\n# ============================================================\ndef latlon_to_local_uv(lat, lon, lat0, lon0, half_side_mi=50.0):\n    """\n    lat, lon, lat0, lon0  : tensors in **degrees**  (any shape \xe2\x80\xa6)\n    Returns:\n        uv    : (..., 2)  where each component is in [-1,1]  (clipped)\n        logJ  : (...)    log-Jacobian \xce\x94(area_u,v) \xe2\x86\x92 \xce\x94(area_lat,lon)\n    """\n    lat_rad   = torch.deg2rad(lat)\n    lat0_rad  = torch.deg2rad(lat0)\n\n    # miles per degree (\xe2\x89\x88 69 mi) and longitude scaling with cos(lat)\n    R_mi      = 69.0\n    dx_mi     = R_mi * torch.cos(lat0_rad) * (lon - lon0)\n    dy_mi     = R_mi * (lat - lat0)\n\n    u         = dx_mi / half_side_mi\n    v         = dy_mi / half_side_mi\n\n    # ---------- Jacobian |\xe2\x88\x82(x,y)/\xe2\x88\x82(u,v)| = (half_side_mi)\xc2\xb2\n    # plus lat scaling factors (R cos\xcf\x86, R)\n    logJ      = (\n        -2 * math.log(half_side_mi)\n        - torch.log(R_mi * torch.cos(lat0_rad))\n        - math.log(R_mi)\n    )\n\n    # ----- Option A: clip targets that wander outside the window\n    u_clipped = torch.clamp(u, -1.0, 1.0)\n    v_clipped = torch.clamp(v, -1.0, 1.0)\n    uv        = torch.stack([u_clipped, v_clipped], dim=-1)\n    return uv, logJ\n\n# ============================================================\n# 2) Load all cleaned-AIS parquet shards\n# ============================================================\ndef load_cleaned_data(cleaned_root: str) -> pl.DataFrame:\n    p = Path(cleaned_root).resolve()\n    files = list(p.rglob("*.parquet"))\n    if not files:\n        raise FileNotFoundError(f"No parquet files under {p}")\n    print(f"\xf0\x9f\x9f\xa2 Reading {len(files)} parquet file(s)\xe2\x80\xa6", flush=True)\n    return pl.read_parquet([str(f) for f in files])\n\n# ============================================================\n# 3) Streaming dataset  (local-window version)\n# ============================================================\nclass AISForecastIterableDataset(IterableDataset):\n    def __init__(self, df: pl.DataFrame, seq_len: int = 10, half_side_mi=50.0):\n        start = time.time()\n        self.df = df.sort(["mmsi", "timestamp"])\n        self.mmsis = self.df["mmsi"].unique().to_list()\n        self.seq_len = seq_len\n        self.half_side_mi = half_side_mi\n        print(f"\xf0\x9f\x9f\xa2 Dataset: {len(self.mmsis)} vessels loaded in {time.time()-start:.1f}s", flush=True)\n\n    def __iter__(self):\n        for m in self.mmsis:\n            grp = (\n                self.df\n                .filter(pl.col("mmsi") == m)\n                .select(["lat", "lon"])\n            )\n            coords = torch.tensor(grp.to_numpy(), dtype=torch.float32)  # (N,2)\n            N = coords.size(0)\n            if N <= self.seq_len + 1:\n                continue\n\n            # rolling windows\n            for i in range(self.seq_len, N - 1):\n                past_abs = coords[i - self.seq_len : i]   # (S,2)\n                target_abs = coords[i]                    # (2,)\n\n                lat0, lon0 = past_abs[-1]                 # window centre\n                past_uv, _     = latlon_to_local_uv(\n                    past_abs[:, 0], past_abs[:, 1],\n                    lat0, lon0, self.half_side_mi\n                )\n                target_uv, logJ = latlon_to_local_uv(\n                    target_abs[0], target_abs[1],\n                    lat0, lon0, self.half_side_mi\n                )\n\n                centre_ll_norm = torch.tensor([lat0 / 90.0, lon0 / 180.0])\n\n                yield (\n                    past_uv,              # (S,2)  local motion\n                    centre_ll_norm,       # (2,)    global context\n                    target_uv,            # (2,)    local target\n                    logJ                  # ()      log-Jacobian\n                )\n\n# ============================================================\n# 4) Model  (local motion + global context)\n# ============================================================\nclass TransformerForecaster(nn.Module):\n    def __init__(\n        self,\n        seq_len: int,\n        d_model: int,\n        nhead: int,\n        num_layers: int,\n        ff_hidden: int,\n        fourier_m: int,\n        rank: int,\n    ):\n        super().__init__()\n        self.seq_len = seq_len\n\n        # each timestep now has 4 features: (u,v, lat_norm, lon_norm)\n        self.input_proj = nn.Linear(4, d_model)\n        self.pos_emb = nn.Parameter(torch.randn(seq_len, d_model))\n        enc_layer = nn.TransformerEncoderLayer(d_model, nhead, ff_hidden)\n        self.transformer = nn.TransformerEncoder(enc_layer, num_layers)\n        self.fh = FourierHead2DLR(\n            dim_input=d_model,\n            num_frequencies=fourier_m,\n            rank=rank\n        )\n\n    def forward(self, past_uv, centre_ll, target_uv):\n        """\n        past_uv   : (B, S, 2)\n        centre_ll : (B, 2)\n        target_uv : (B, 2)\n        """\n        B, S, _ = past_uv.shape\n        # repeat centre_lat/lon to every timestep\n        centre_rep = centre_ll.unsqueeze(1).expand(-1, S, -1)  # (B,S,2)\n        x = torch.cat([past_uv, centre_rep], dim=-1)           # (B,S,4)\n        h = self.input_proj(x) + self.pos_emb.unsqueeze(0)     # (B,S,d)\n        h = self.transformer(h.transpose(0, 1))                # (S,B,d)\n        last = h[-1]                                           # (B,d)\n        return self.fh(last, target_uv)                        # (B,)\n\n# ============================================================\n# 5) Training loop\n# ============================================================\ndef train(\n    df: pl.DataFrame,\n    seq_len: int = 10,\n    d_model: int = 128,\n    nhead: int = 4,\n    num_layers: int = 4,\n    ff_hidden: int = 512,\n    fourier_m: int = 512,\n    rank: int = 4,\n    batch_size: int = 64,\n    lr: float = 1e-3,\n    epochs: int = 5,\n    device: str | None = None,\n):\n    device = device or ("cuda" if torch.cuda.is_available() else "cpu")\n    print(f"\xf0\x9f\x9f\xa1 Training on {device}", flush=True)\n\n    ds = AISForecastIterableDataset(df, seq_len=seq_len)\n    loader = DataLoader(\n        ds, batch_size=batch_size, shuffle=False,\n        drop_last=True, num_workers=0, pin_memory=True\n    )\n    model = TransformerForecaster(\n        seq_len, d_model, nhead, num_layers,\n        ff_hidden, fourier_m, rank\n    ).to(device)\n    opt = optim.Adam(model.parameters(), lr=lr)\n\n    for ep in range(1, epochs + 1):\n        total_nll, count = 0.0, 0\n        print(f"\\n\xf0\x9f\x9f\xa1 Epoch {ep}/{epochs}", flush=True)\n\n        for b_idx, (past_uv, centre_ll, target_uv, logJ) in enumerate(loader, 1):\n            past_uv   = past_uv.to(device)\n            centre_ll = centre_ll.to(device)\n            target_uv = target_uv.to(device)\n            logJ      = logJ.to(device)\n\n            pdf_uv = model(past_uv, centre_ll, target_uv)      # (B,)\n            # log-pdf in lat/lon space = log(pdf_uv) + logJ\n            loss   = -(torch.log(pdf_uv + 1e-12) + logJ).mean()\n\n            opt.zero_grad()\n            loss.backward()\n            opt.step()\n\n            total_nll += loss.item() * past_uv.size(0)\n            count     += past_uv.size(0)\n            if b_idx % 50 == 0:\n                print(f"    Batch {b_idx}  NLL {loss.item():.4f}", flush=True)\n\n        print(f"\xf0\x9f\x9f\xa2 Epoch {ep}  Avg NLL: {total_nll / count:.4f}", flush=True)\n\n    return model\n\n# ============================================================\n# 6) Execute\n# ============================================================\ndf = load_cleaned_data("data/cleaned_partitioned_ais")\nmodel = train(\n    df,\n    seq_len      = 20,\n    d_model      = 128,\n    nhead        = 4,\n    num_layers   = 4,\n    ff_hidden    = 512,\n    fourier_m    = 64,     # local window \xe2\x86\x92 far fewer freqs needed\n    rank         = 4,\n    batch_size   = 64,\n    lr           = 1e-3,\n    epochs       = 100,\n)\nprint("\xe2\x9c\x85 Training complete", flush=True)\nPY\n'' returned non-zero exit status 1.