In [None]:
import polars as pl
import pandas as pd
import numpy as np
import torch.nn as nn

from datetime import timedelta

TRAIN_PATH = "/kaggle/input/drw-crypto-market-prediction/train.parquet"
NUM_DAYS = 7

# Step 1: Get max timestamp from the file (only this column is scanned)
latest_ts = (
    pl.scan_parquet(TRAIN_PATH)
    .select(pl.col("timestamp").max())
    .collect()
    .item()
)

# Step 2: Calculate a concrete cutoff value (not a lazy polars expression), so it
# can be printed/inspected and is compared as a literal in the filter below.
# NOTE(review): assumes `timestamp` is a Datetime column, so `.item()` yields a
# Python datetime -- confirm against the parquet schema.
cutoff_ts = latest_ts - timedelta(days=NUM_DAYS)

# Step 3: Lazily scan and filter; sorting happens after collect() in the next cell.
lf = pl.scan_parquet(TRAIN_PATH)
lf_filtered = lf.filter(pl.col("timestamp") >= cutoff_ts)


In [None]:
# Materialise the filtered scan, then order rows chronologically for windowing.
train = (
    lf_filtered
    .collect()
    .sort(by="timestamp")
)

In [None]:
# Separate the target from the model inputs; timestamp only orders the rows
# and is not fed to the model.
target = train["label"]
features = train.drop(["timestamp", "label"])
y = target.to_numpy()
X = features.to_numpy()


In [None]:
# Step 0: polars' to_numpy() may return a read-only, zero-copy view. The steps
# below write into X in place, so make sure it is writable first.
if not X.flags.writeable:
    X = X.copy()

# Step 1: Replace inf/-inf with NaN
X[np.isinf(X)] = np.nan

# Step 2: Drop columns with all NaNs (this mask is reused on the test set).
# NOTE(review): X is rebound to the filtered array, so re-running this cell
# without re-running the previous one recomputes valid_cols against the
# already-filtered X and it will no longer match the raw test columns.
valid_cols = ~np.isnan(X).all(axis=0)
X = X[:, valid_cols]

# Step 3: Impute remaining NaNs with column means (also reused on the test set)
col_means = np.nanmean(X, axis=0)
inds = np.where(np.isnan(X))
X[inds] = np.take(col_means, inds[1])

# Now scale
# NOTE(review): col_means and the scaler are fit on every row, including rows
# that later land in the validation split, so validation metrics are slightly
# optimistic.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

In [None]:
import numpy as np

LOOKBACK = 45    # 45 rows (minutes) of history per window
HORIZON = 10     # target is 10 rows (minutes) after the window's last row
BATCH_SIZE = 32

def make_windows(X, y, lookback=LOOKBACK, horizon=HORIZON):
    """Cut a time-ordered (X, y) into overlapping look-back windows.

    Window ``i`` is ``X[i : i + lookback]`` and its target is
    ``y[i + lookback + horizon - 1]``, the label ``horizon`` rows after the
    window's last row. Every start index whose target lies inside ``y`` is
    used, including the one whose target is the final row.

    Args:
        X: array of shape (n_rows, ...) in time order.
        y: array-like of length n_rows.
        lookback: rows per window.
        horizon: distance in rows from a window's last row to its target.

    Returns:
        (windows, targets) with shapes (n_windows, lookback, ...) and
        (n_windows,). Both are empty when the series is too short.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    offset = lookback + horizon - 1            # index of window 0's target
    n_windows = max(len(X) - offset, 0)        # last usable target is y[len(X) - 1]
    if n_windows == 0:
        return (np.empty((0, lookback) + X.shape[1:], dtype=X.dtype),
                np.empty((0,), dtype=y.dtype))
    # sliding_window_view appends the window axis last: (n, ..., lookback).
    # Move it to position 1 to get (n, lookback, ...).
    views = np.lib.stride_tricks.sliding_window_view(X, lookback, axis=0)[:n_windows]
    windows = np.ascontiguousarray(np.moveaxis(views, -1, 1))
    targets = y[offset:offset + n_windows].copy()
    return windows, targets

# One (LOOKBACK, n_features) window per sample, paired with its forward label.
X_seq, y_seq = make_windows(X_scaled, y, lookback=LOOKBACK, horizon=HORIZON)


In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset

class CryptoDataset(Dataset):
    """Map-style dataset over pre-windowed arrays.

    Stores float32 copies of X and y. y gets a trailing singleton axis, so a
    batch of targets has shape (B, 1) like the models' linear head output.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.float32)
        self.X = features
        self.y = targets.unsqueeze(-1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        return sample, label

dataset = CryptoDataset(X=X_seq, y=y_seq)  # one sample per window, in time order



In [None]:
from torch.utils.data import DataLoader, Subset

# Total number of samples (one per window, in chronological order)
total_size = len(dataset)

# Proportions
train_pct = 0.7
val_pct   = 0.3  # informational: validation gets everything after the purge gap

# Consecutive windows share LOOKBACK - 1 input rows. A random split therefore
# puts near-duplicates of every validation sample into training and leaks
# future data. Split chronologically instead, and skip enough windows that no
# validation input overlaps a training window or a training target row.
PURGE_GAP = LOOKBACK + HORIZON - 1

# Integer sizes
train_size = int(train_pct * total_size)
val_start  = min(train_size + PURGE_GAP, total_size)
val_size   = total_size - val_start
assert train_size > 0 and val_size > 0, (
    f"Not enough windows ({total_size}) for a {train_pct:.0%} split "
    f"with a {PURGE_GAP}-window purge gap"
)

# Split the dataset into contiguous index ranges, oldest data first
train_ds = Subset(dataset, range(train_size))
val_ds   = Subset(dataset, range(val_start, total_size))

# Create DataLoaders; shuffling inside the training range is fine
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE)



In [None]:
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    Precomputes a (1, max_len, d_model) table once (sine on even feature
    indices, cosine on odd ones) and adds its first ``seq_len`` rows to the
    input.

    Args:
        d_model: feature size of the inputs; odd sizes are supported.
        max_len: longest sequence length the table covers.
    """

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1)                                     # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))     # (ceil(d_model / 2),)
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[0, :, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encoding to ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len exceeds the precomputed max_len.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"Input sequence length ({seq_len}) exceeds maximum positional encoding length ({self.pe.size(1)}).")
        return x + self.pe[:, :seq_len, :]

In [None]:
class GatedSSMBlock(nn.Module):
    """Per-token gated update that blends the input with a tanh candidate.

    There is no recurrence over the time axis: each position goes through three
    linear projections (update, forget gate, output gate), then LayerNorm and
    dropout.
    """

    def __init__(self, d_model, dropout=0.2):
        super().__init__()
        # Creation order (U, F, O, norm, dropout) matches the original, so
        # initialisation and state_dict keys are unchanged.
        self.U = nn.Linear(d_model, d_model)   # candidate update
        self.F = nn.Linear(d_model, d_model)   # forget / keep gate
        self.O = nn.Linear(d_model, d_model)   # output gate
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        keep = torch.sigmoid(self.F(x))
        candidate = torch.tanh(self.U(x))
        expose = torch.sigmoid(self.O(x))
        blended = keep * x + (1 - keep) * candidate
        gated = expose * blended
        return self.dropout(self.norm(gated))


In [None]:
class CrossHeadRouter(nn.Module):
    """Fuses several head outputs with per-token softmax weights.

    Each head h owns a learned scoring vector. A token's score for head h is
    the dot product of that head's output with the vector. A softmax over heads
    turns the scores into mixing weights.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn_weights = nn.Parameter(torch.randn(num_heads, d_model))

    def forward(self, head_outputs):
        # head_outputs: sequence of H tensors, each (B, T, D)
        per_head = torch.stack(head_outputs, dim=2)                          # (B, T, H, D)
        scores = torch.einsum("bthd,hd->bth", per_head, self.attn_weights)   # (B, T, H)
        mix = scores.softmax(dim=-1)[..., None]                              # (B, T, H, 1)
        return (per_head * mix).sum(dim=2)                                   # (B, T, D)


In [None]:
class TokenWiseFeedForward(nn.Module):
    """Position-wise MLP (D -> 2D -> D with GELU), applied to the last axis of every token."""

    def __init__(self, d_model):
        super().__init__()
        hidden = 2 * d_model
        layers = [nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model)]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        out = self.ffn(x)
        return out


In [None]:
class ParallelSSMHeads(nn.Module):
    """Runs several full-width GatedSSMBlock heads on the same input and fuses them.

    The router mixes the head outputs per token. A token-wise FFN then adds a
    residual refinement, and LayerNorm normalises the result.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(GatedSSMBlock(d_model) for _ in range(num_heads))
        self.router = CrossHeadRouter(d_model, num_heads)
        self.tokenwise_ffn = TokenWiseFeedForward(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B, T, D); every head sees the full feature width
        fused = self.router([head(x) for head in self.heads])
        refined = fused + self.tokenwise_ffn(fused)
        return self.norm(refined)


In [None]:
class MultiHeadSSM(nn.Module):
    """Splits the feature axis into num_heads slices and gates each slice with its own GatedSSMBlock.

    The slices are concatenated back and passed through a linear projection.
    Not referenced by LazyAttentionSSM or Transformer in this notebook.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
        self.head_dim = d_model // num_heads
        self.heads = nn.ModuleList([GatedSSMBlock(self.head_dim) for _ in range(num_heads)])
        self.output_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """x: (B, T, D) with D == num_heads * head_dim. Returns (B, T, D)."""
        B, T, D = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. after a transpose, also work
        x = x.reshape(B, T, len(self.heads), self.head_dim)
        x = torch.stack([head(x[:, :, i, :]) for i, head in enumerate(self.heads)], dim=2)
        x = x.reshape(B, T, D)
        return self.output_proj(x)


In [None]:
class HybridBlock(nn.Module):
    """Residual wrapper: x -> LayerNorm(x + ParallelSSMHeads(x))."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.parallel_ssm = ParallelSSMHeads(d_model, num_heads)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        return self.norm(residual + self.parallel_ssm(x))


In [None]:
import torch.nn as nn

class LazyAttentionSSM(nn.Module):
    """Sequence regressor built from a stack of HybridBlock layers.

    Pipeline: linear input projection -> sinusoidal positional encoding ->
    `depth` HybridBlocks -> LayerNorm -> mean over the time axis -> dropout ->
    linear head producing one value per sequence.
    """

    def __init__(self, input_dim, d_model=512, depth=4, num_heads=4):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.blocks = nn.Sequential(*[HybridBlock(d_model, num_heads) for _ in range(depth)])
        self.output_norm = nn.LayerNorm(d_model)   # normalise features before pooling
        self.head_dropout = nn.Dropout(0.2)
        self.head = nn.Linear(d_model, 1)

    def forward(self, x):
        h = x
        for stage in (self.input_proj, self.pos_enc, self.blocks, self.output_norm):
            h = stage(h)
        pooled = h.mean(dim=1)                     # global average pooling over time
        return self.head(self.head_dropout(pooled))


In [None]:
import torch.nn as nn

class SimpleTransformerBlock(nn.Module):
    """Post-norm encoder layer: self-attention, then a GELU FFN, each followed by residual + LayerNorm."""

    def __init__(self, d_model, num_heads, dropout=0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        hidden = d_model * 2
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (B, T, D) because the attention layer is batch_first
        h = self.norm1(x + self.attn(x, x, x)[0])
        return self.norm2(h + self.ffn(h))

class Transformer(nn.Module):
    """Encoder-only Transformer regressor.

    Pipeline: linear projection -> positional encoding -> SimpleTransformerBlock
    stack -> LayerNorm -> mean over time -> dropout -> linear head (one value
    per sequence).
    """

    def __init__(self, input_dim, d_model=1024, depth=8, num_heads=8):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        layers = [SimpleTransformerBlock(d_model, num_heads) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.output_norm = nn.LayerNorm(d_model)
        self.head_dropout = nn.Dropout(0.2)
        self.head = nn.Linear(d_model, 1)

    def forward(self, x):
        encoded = self.blocks(self.pos_enc(self.input_proj(x)))
        pooled = self.output_norm(encoded).mean(dim=1)
        return self.head(self.head_dropout(pooled))

In [None]:
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One input feature per cleaned column of X (the same width as X_scaled / X_seq).
model = LazyAttentionSSM(input_dim=X.shape[-1])
model = model.to(device)

loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)


In [None]:
from tqdm import tqdm

NUM_EPOCHS = 25

for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0
    n_train = 0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        xb, yb = xb.to(device), yb.to(device)
        if yb.ndim == 1:
            yb = yb.unsqueeze(1)  # Ensure shape is [B, 1]
        optimizer.zero_grad()
        preds = model(xb)
        loss = loss_fn(preds, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=.5)

        optimizer.step()
        # MSELoss returns a per-batch mean. Weight it by batch size so a short
        # final batch does not count as much as a full one in the epoch average.
        total_loss += loss.item() * xb.size(0)
        n_train += xb.size(0)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    n_val = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            if yb.ndim == 1:
                yb = yb.unsqueeze(1)  # Ensure shape is [B, 1]
            preds = model(xb)
            val_loss += loss_fn(preds, yb).item() * xb.size(0)
            n_val += xb.size(0)

    print(f"Train Loss: {total_loss/max(n_train, 1):.6f} | Val Loss: {val_loss/max(n_val, 1):.6f}")


In [None]:
import pandas as pd
import numpy as np
import polars as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

# === CONFIG ===
chunk_size = 25_000  # rows per slice (read_parquet below still loads the whole file once)
test_path = "/kaggle/input/drw-crypto-market-prediction/test.parquet"
submission_path = "3chunkedcryptopredsubmission.csv"
infer_batch_size = 32
# Row r is predicted from the window that ends HORIZON rows before it (the same
# alignment as make_windows), so LOOKBACK + HORIZON - 1 rows of history must be
# carried from one chunk into the next.
context_rows = LOOKBACK + HORIZON - 1

# NOTE(review): assumes test rows are in chronological order and that IDs are
# 0-based row positions -- confirm against the competition's sample submission.

# === STORAGE ===
submission_chunks = []
running_id = 0
history = None  # last `context_rows` scaled rows of the previous chunk

model.eval()

# === CHUNKED PROCESSING LOOP ===
for i, df_chunk in enumerate(pl.read_parquet(test_path).iter_slices(n_rows=chunk_size)):
    print(f" Processing chunk {i+1}")

    # Drop the same non-feature columns as the training cell, so column
    # positions line up with valid_cols / col_means / scaler.
    drop_cols = [c for c in ("timestamp", "label") if c in df_chunk.columns]
    if drop_cols:
        df_chunk = df_chunk.drop(drop_cols)

    # Convert to NumPy and clean (to_numpy may return a read-only view)
    X_chunk = df_chunk.to_numpy()
    if not X_chunk.flags.writeable:
        X_chunk = X_chunk.copy()
    X_chunk[np.isinf(X_chunk)] = np.nan

    # Apply valid column filter
    X_chunk = X_chunk[:, valid_cols]

    # Impute missing values with the training column means
    inds = np.where(np.isnan(X_chunk))
    X_chunk[inds] = np.take(col_means, inds[1])

    # Scale
    X_chunk_scaled = scaler.transform(X_chunk)
    n_rows = len(X_chunk_scaled)
    if n_rows == 0:
        continue

    # Prepend history so every row of this chunk gets a full window. The very
    # first rows of the file have no history, so repeat the first row instead.
    if history is None:
        history = np.repeat(X_chunk_scaled[:1], context_rows, axis=0)
    buffer = np.concatenate([history, X_chunk_scaled], axis=0)

    # windows[k] == buffer[k : k + LOOKBACK]: the window whose target is chunk row k
    windows = np.moveaxis(
        np.lib.stride_tricks.sliding_window_view(buffer, LOOKBACK, axis=0), -1, 1
    )[:n_rows]

    # Predict batch by batch straight from the strided view (no full copy of all windows)
    chunk_preds = []
    with torch.no_grad():
        for start in range(0, n_rows, infer_batch_size):
            batch = np.ascontiguousarray(windows[start:start + infer_batch_size], dtype=np.float32)
            xb = torch.from_numpy(batch).to(device)
            # reshape(-1) rather than squeeze(): a 1-row batch must stay 1-D
            chunk_preds.append(model(xb).cpu().numpy().reshape(-1))
    chunk_preds = np.concatenate(chunk_preds)

    # Store predictions with global ID (exactly one per test row)
    chunk_df = pd.DataFrame({
        "ID": np.arange(running_id, running_id + n_rows),
        "prediction": chunk_preds,
    })
    submission_chunks.append(chunk_df)
    running_id += n_rows
    history = buffer[len(buffer) - context_rows:].copy()

# === FINAL SUBMISSION FILE ===
submission_df = pd.concat(submission_chunks, ignore_index=True)
submission_df.to_csv(submission_path, index=False)
print(f" Submission saved to '{submission_path}' with shape:", submission_df.shape)


In [None]:
# Collect validation predictions and targets as flat 1-D arrays.
model.eval()
pred_chunks = []
target_chunks = []

with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        # reshape(-1) instead of squeeze(): a final batch of size 1 would
        # squeeze to a 0-d array, which cannot be iterated/extended.
        pred_chunks.append(model(xb).cpu().numpy().reshape(-1))
        target_chunks.append(yb.cpu().numpy().reshape(-1))

val_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.float32)
val_targets = np.concatenate(target_chunks) if target_chunks else np.empty(0, dtype=np.float32)


In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Standard regression metrics on the validation windows
mse, mae, r2 = (
    mean_squared_error(val_targets, val_preds),
    mean_absolute_error(val_targets, val_preds),
    r2_score(val_targets, val_preds),
)

print(f"Validation MSE:  {mse:.6f}")
print(f"Validation MAE:  {mae:.6f}")
print(f"Validation R²:   {r2:.4f}")


In [None]:
# Fraction of windows where prediction and target share a sign (np.sign(0) == 0 is its own class)
same_sign = np.sign(val_preds) == np.sign(val_targets)
direction_acc = same_sign.mean()
print(f"Directional Accuracy: {direction_acc:.2%}")


In [None]:
import matplotlib.pyplot as plt

# Overlay the first 300 validation targets and predictions
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(val_targets[:300], label="True", alpha=0.7)
ax.plot(val_preds[:300], label="Predicted", alpha=0.7)
ax.set_title("Validation Predictions vs Ground Truth")
ax.set_xlabel("Time step")
ax.set_ylabel("Label")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
from scipy.stats import pearsonr

# Pearson correlation between validation predictions and labels.
model.eval()
preds, labels = [], []

with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        pr = model(xb).cpu().numpy()
        lb = yb.cpu().numpy()
        # reshape(-1) rather than squeeze(): a batch of size 1 must stay a
        # 1-element array, otherwise list.extend fails on a 0-d array.
        preds.extend(pr.reshape(-1))
        labels.extend(lb.reshape(-1))

pearson_val = pearsonr(preds, labels)[0]
print(f" Pearson on validation set: {pearson_val:.6f}")


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler

# Create a synthetic sinusoidal dataset for demo

# Parameters for the synthetic data
N = 8192
t = np.arange(N)
freq = 0.05
amplitude = 1.0
noise_std = 0.25
DEMO_SEED = 0  # fixed, so the demo data (and the plots built from it) are reproducible

# Sinusoidal signal with Gaussian noise from a seeded generator
demo_rng = np.random.default_rng(DEMO_SEED)
signal = amplitude * np.sin(2 * np.pi * freq * t) + demo_rng.normal(0, noise_std, N)

# Prepare features and targets with the same lookback/horizon as the main model
LOOKBACK = 45
HORIZON = 10


def make_demo_windows(X, y, lookback=LOOKBACK, horizon=HORIZON):
    """Build (signal, position) windows and horizon-step delta targets for the demo.

    Window i holds X[i : i + lookback] as column 0 and a fixed linspace(0, 1)
    position ramp as column 1. Its target is
    y[i + lookback + horizon - 1] - y[i + lookback - 1], the change over
    `horizon` steps after the window's last point. Every start index whose
    target lies inside `y` is used.

    Named differently from `make_windows` so this cell does not replace the
    helper used for the market data.

    Returns:
        (windows, deltas) with shapes (n_windows, lookback, 2) and (n_windows,).
    """
    pos_window = np.linspace(0, 1, lookback)  # normalized position, same for every window
    n_windows = max(len(X) - lookback - horizon + 1, 0)
    Xs, Ys = [], []
    for i in range(n_windows):
        signal_window = X[i:i + lookback]
        Xs.append(np.stack([signal_window, pos_window], axis=-1))  # shape: (L, 2)
        future_val = y[i + lookback + horizon - 1]
        last_val = y[i + lookback - 1]
        Ys.append(future_val - last_val)
    if not Xs:
        return np.empty((0, lookback, 2)), np.empty((0,))
    return np.stack(Xs), np.array(Ys)


X_demo, y_demo = make_demo_windows(signal, signal)

# Standardize features and targets with SEPARATE scalers. Refitting one scaler
# on the targets would silently discard the feature statistics. `scaler_demo`
# is the target scaler (fit on the 1-D deltas).
scaler_demo_X = StandardScaler()
X_demo_flat = X_demo.reshape(-1, X_demo.shape[-1])
X_demo_scaled = scaler_demo_X.fit_transform(X_demo_flat).reshape(X_demo.shape)

scaler_demo = StandardScaler()
y_demo_scaled = scaler_demo.fit_transform(y_demo.reshape(-1, 1)).flatten()

# Torch dataset and loader
class DemoDataset(torch.utils.data.Dataset):
    """Map-style dataset over the pre-built demo windows.

    Item idx is (X[idx], y[idx] with a trailing singleton axis), both float32 tensors.
    """

    def __init__(self, X, y):
        self.X, self.y = (
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(y, dtype=torch.float32)[..., None],
        )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        window = self.X[idx]
        target = self.y[idx]
        return window, target

demo_dataset = DemoDataset(X_demo_scaled, y_demo_scaled)
# shuffle=False keeps the windows in chronological order
demo_loader = DataLoader(dataset=demo_dataset, batch_size=32, shuffle=False)

In [None]:
# -----------------  Multi-Step Rollout Training Loop -----------------
N_ROLL         = 5  # number of autoregressive steps unrolled per batch
demo_model     = Transformer(input_dim=2).to(device)
demo_loss_fn   = nn.MSELoss()
demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-5)
max_grad_norm  = 0.5
demo_epochs    = 25

# NOTE(review): the window values are standardized, but the targets below are raw
# `signal` differences. The appended row also uses a raw position of 1.0, while
# the position column was standardized too. Confirm the intended units.

for epoch in range(demo_epochs):
    demo_model.train()
    total_loss = 0.0

    for batch_idx, (xb, yb) in enumerate(demo_loader):   # xb: (B, L, 2), yb: (B, 1)
        xb, yb = xb.to(device), yb.to(device)
        batch_len, window_len = xb.size(0), xb.size(1)

        # ---- autoregressive rollout: feed each prediction back into the window ----
        current_window = xb.clone()         # shape: (B, L, 2)
        last_val = xb[:, -1, 0]             # last signal value in window
        preds_all = []

        for step in range(N_ROLL):
            pred_delta = demo_model(current_window).squeeze(1)  # (B,)
            preds_all.append(pred_delta)

            # Compute the next value and append it as a new row [next_val, pos=1.0]
            next_val = last_val + pred_delta                   # (B,)
            next_row = torch.stack([next_val, torch.ones_like(next_val)], dim=-1).unsqueeze(1)  # (B, 1, 2)

            # Roll window forward
            current_window = torch.cat([current_window[:, 1:], next_row], dim=1)
            last_val = next_val.detach()

        preds_all = torch.stack(preds_all, dim=1)  # (B, N_ROLL)

        # ---- ground-truth deltas for the windows actually in this batch ----
        # demo_loader is not shuffled, so sample b is window batch_start + b,
        # whose last point is signal[batch_start + b + L - 1]. Step s targets the
        # HORIZON-step change starting s * HORIZON after that point (step 0 equals
        # the window's delta target from the windowing cell).
        batch_start = batch_idx * demo_loader.batch_size
        anchors = batch_start + np.arange(batch_len) + window_len - 1           # (B,)
        idx = anchors[:, None] + HORIZON * np.arange(N_ROLL)[None, :]           # (B, N_ROLL)
        valid_np = idx + HORIZON < len(signal)
        safe_idx = np.where(valid_np, idx, 0)
        deltas_np = np.where(valid_np, signal[safe_idx + HORIZON] - signal[safe_idx], 0.0)
        true_deltas = torch.tensor(deltas_np, dtype=torch.float32, device=device)  # (B, N_ROLL)
        valid = torch.tensor(valid_np, device=device)

        # Score only steps whose target exists (zero padding would teach fake targets)
        loss = demo_loss_fn(preds_all[valid], true_deltas[valid])
        demo_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(demo_model.parameters(), max_grad_norm)
        demo_optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(demo_loader)
    print(f"Epoch {epoch+1:2d}/{demo_epochs} | Rollout Loss: {avg_loss:.6f}")
# ---------------------------------------------------------------------


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# ------------------------------------------------------------
# helper functions that reproduce the input scaling used in training

# The demo training windows (X_demo_scaled) were standardized column-wise over
# the flattened (signal, pos) windows of X_demo. `scaler_demo` ends up fit on the
# 1-D delta targets, so it cannot scale the inputs. Refit the feature statistics
# from X_demo so that inference inputs are on the training scale.
demo_feature_scaler = StandardScaler().fit(X_demo.reshape(-1, X_demo.shape[-1]))

def build_raw_window(start_idx):
    """
    returns ndarray (LOOKBACK, 2)
    [:,0] = raw signal value
    [:,1] = pos in [0,1]
    """
    sig  = signal[start_idx : start_idx + LOOKBACK]
    pos  = np.linspace(0, 1, LOOKBACK)
    return np.stack([sig, pos], axis=-1)

# ------------------------------------------------------------
def scale_batch(raw_np_window):
    """
    raw_np_window : (1, L, 2)  [signal, pos] in raw units
    standardizes BOTH columns with the training feature statistics and
    returns a float32 tensor on `device`
    """
    flat = raw_np_window.reshape(-1, raw_np_window.shape[-1])
    scaled_window = demo_feature_scaler.transform(flat).reshape(raw_np_window.shape)
    return torch.tensor(scaled_window, dtype=torch.float32, device=device)

def unscale_val(z):
    """
    z : scalar in `scaler_demo` (target-scaled) space  ->  raw scalar
    (not used below: the rollout-trained model already predicts raw deltas)
    """
    return scaler_demo.inverse_transform(np.array(z, ndmin=2))[0, 0]

# ----------------------------------------------------------------
# autoregressive rollout

seed_idx  = 100
num_steps = 100

raw_window = build_raw_window(seed_idx)                # (L,2)
preds = []

demo_model.eval()

for _ in range(num_steps):
    xb = scale_batch(raw_window[np.newaxis, ...])      # (1,L,2)

    with torch.no_grad():
        # The rollout training loop regresses onto raw `signal` differences,
        # so the output is already a raw delta (no target inverse-scaling).
        delta_raw = demo_model(xb).cpu().item()

    # NOTE(review): the model is trained on HORIZON-step deltas, but the result
    # is appended as the very next timestep -- confirm the intended step size.
    next_sig  = raw_window[-1, 0] + delta_raw          # absolute prediction

    preds.append(next_sig)

    # slide: new row = [next_sig , 1.0]  (pos==1), then re-normalise positions
    new_row  = np.array([next_sig, 1.0])
    raw_window = np.vstack([raw_window[1:], new_row])

    # re-normalise pos column to 0-1
    raw_window[:, 1] = np.linspace(0, 1, LOOKBACK)


# ----------------------------------------------------------------
# plot
# ----------------------------------------------------------------
plt.figure(figsize=(14, 5))

# true future (raw signal)
plt.plot(
    range(seed_idx + LOOKBACK, seed_idx + LOOKBACK + num_steps),
    signal[seed_idx + LOOKBACK : seed_idx + LOOKBACK + num_steps],
    label="True Future", color="tab:blue"
)

# model predictions
plt.plot(
    range(seed_idx + LOOKBACK, seed_idx + LOOKBACK + num_steps),
    preds,
    label="Model Prediction", color="tab:orange"
)

# seed window
plt.plot(
    range(seed_idx, seed_idx + LOOKBACK),
    signal[seed_idx : seed_idx + LOOKBACK],
    label="Seed Window", color="tab:green", linestyle="dashed"
)

plt.title("Autoregressive Inference: Predicting Future Timesteps w/ Transformer")
plt.xlabel("Timestep")
plt.ylabel("Label")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# -----------------  Multi-Step Rollout Training Loop -----------------
N_ROLL         = 5  # number of autoregressive steps unrolled per batch
demo_model     = LazyAttentionSSM(input_dim=2).to(device)
demo_loss_fn   = nn.MSELoss()
demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-5)
max_grad_norm  = 0.5
demo_epochs    = 25

# NOTE(review): the window values are standardized, but the targets below are raw
# `signal` differences. The appended row also uses a raw position of 1.0, while
# the position column was standardized too. Confirm the intended units.

for epoch in range(demo_epochs):
    demo_model.train()
    total_loss = 0.0

    for batch_idx, (xb, yb) in enumerate(demo_loader):   # xb: (B, L, 2), yb: (B, 1)
        xb, yb = xb.to(device), yb.to(device)
        batch_len, window_len = xb.size(0), xb.size(1)

        # ---- autoregressive rollout: feed each prediction back into the window ----
        current_window = xb.clone()         # shape: (B, L, 2)
        last_val = xb[:, -1, 0]             # last signal value in window
        preds_all = []

        for step in range(N_ROLL):
            pred_delta = demo_model(current_window).squeeze(1)  # (B,)
            preds_all.append(pred_delta)

            # Compute the next value and append it as a new row [next_val, pos=1.0]
            next_val = last_val + pred_delta                   # (B,)
            next_row = torch.stack([next_val, torch.ones_like(next_val)], dim=-1).unsqueeze(1)  # (B, 1, 2)

            # Roll window forward
            current_window = torch.cat([current_window[:, 1:], next_row], dim=1)
            last_val = next_val.detach()

        preds_all = torch.stack(preds_all, dim=1)  # (B, N_ROLL)

        # ---- ground-truth deltas for the windows actually in this batch ----
        # demo_loader is not shuffled, so sample b is window batch_start + b,
        # whose last point is signal[batch_start + b + L - 1]. Step s targets the
        # HORIZON-step change starting s * HORIZON after that point (step 0 equals
        # the window's delta target from the windowing cell).
        batch_start = batch_idx * demo_loader.batch_size
        anchors = batch_start + np.arange(batch_len) + window_len - 1           # (B,)
        idx = anchors[:, None] + HORIZON * np.arange(N_ROLL)[None, :]           # (B, N_ROLL)
        valid_np = idx + HORIZON < len(signal)
        safe_idx = np.where(valid_np, idx, 0)
        deltas_np = np.where(valid_np, signal[safe_idx + HORIZON] - signal[safe_idx], 0.0)
        true_deltas = torch.tensor(deltas_np, dtype=torch.float32, device=device)  # (B, N_ROLL)
        valid = torch.tensor(valid_np, device=device)

        # Score only steps whose target exists (zero padding would teach fake targets)
        loss = demo_loss_fn(preds_all[valid], true_deltas[valid])
        demo_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(demo_model.parameters(), max_grad_norm)
        demo_optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(demo_loader)
    print(f"Epoch {epoch+1:2d}/{demo_epochs} | Rollout Loss: {avg_loss:.6f}")
# ---------------------------------------------------------------------


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# ------------------------------------------------------------
# helper functions that reproduce the input scaling used in training

# The demo training windows (X_demo_scaled) were standardized column-wise over
# the flattened (signal, pos) windows of X_demo. `scaler_demo` ends up fit on the
# 1-D delta targets, so it cannot scale the inputs. Refit the feature statistics
# from X_demo so that inference inputs are on the training scale.
demo_feature_scaler = StandardScaler().fit(X_demo.reshape(-1, X_demo.shape[-1]))

def build_raw_window(start_idx):
    """
    returns ndarray (LOOKBACK, 2)
    [:,0] = raw signal value
    [:,1] = pos in [0,1]
    """
    sig  = signal[start_idx : start_idx + LOOKBACK]
    pos  = np.linspace(0, 1, LOOKBACK)
    return np.stack([sig, pos], axis=-1)

# ------------------------------------------------------------
def scale_batch(raw_np_window):
    """
    raw_np_window : (1, L, 2)  [signal, pos] in raw units
    standardizes BOTH columns with the training feature statistics and
    returns a float32 tensor on `device`
    """
    flat = raw_np_window.reshape(-1, raw_np_window.shape[-1])
    scaled_window = demo_feature_scaler.transform(flat).reshape(raw_np_window.shape)
    return torch.tensor(scaled_window, dtype=torch.float32, device=device)

def unscale_val(z):
    """
    z : scalar in `scaler_demo` (target-scaled) space  ->  raw scalar
    (not used below: the rollout-trained model already predicts raw deltas)
    """
    return scaler_demo.inverse_transform(np.array(z, ndmin=2))[0, 0]

# ----------------------------------------------------------------
# autoregressive rollout

seed_idx  = 100
num_steps = 100

raw_window = build_raw_window(seed_idx)                # (L,2)
preds = []

demo_model.eval()

for _ in range(num_steps):
    xb = scale_batch(raw_window[np.newaxis, ...])      # (1,L,2)

    with torch.no_grad():
        # The rollout training loop regresses onto raw `signal` differences,
        # so the output is already a raw delta (no target inverse-scaling).
        delta_raw = demo_model(xb).cpu().item()

    # NOTE(review): the model is trained on HORIZON-step deltas, but the result
    # is appended as the very next timestep -- confirm the intended step size.
    next_sig  = raw_window[-1, 0] + delta_raw          # absolute prediction

    preds.append(next_sig)

    # slide: new row = [next_sig , 1.0]  (pos==1), then re-normalise positions
    new_row  = np.array([next_sig, 1.0])
    raw_window = np.vstack([raw_window[1:], new_row])

    # re-normalise pos column to 0-1
    raw_window[:, 1] = np.linspace(0, 1, LOOKBACK)


# ----------------------------------------------------------------
# plot
# ----------------------------------------------------------------
plt.figure(figsize=(14, 5))

# true future (raw signal)
plt.plot(
    range(seed_idx + LOOKBACK, seed_idx + LOOKBACK + num_steps),
    signal[seed_idx + LOOKBACK : seed_idx + LOOKBACK + num_steps],
    label="True Future", color="tab:blue"
)

# model predictions
plt.plot(
    range(seed_idx + LOOKBACK, seed_idx + LOOKBACK + num_steps),
    preds,
    label="Model Prediction", color="tab:orange"
)

# seed window
plt.plot(
    range(seed_idx, seed_idx + LOOKBACK),
    signal[seed_idx : seed_idx + LOOKBACK],
    label="Seed Window", color="tab:green", linestyle="dashed"
)

plt.title("Autoregressive Inference: Predicting Future Timesteps w/ LazyAttentionSSM")
plt.xlabel("Timestep")
plt.ylabel("Label")
plt.legend()
plt.grid(True)
plt.show()
