## 9 · PositionalEncoding + Transformer model  
This is the neural architecture you train: a standard PyTorch encoder-decoder Transformer with sinusoidal positional encodings. The model takes a sequence of past observations and returns predictions for the future window. During training, you can optionally pass an explicit decoder input (used for teacher forcing). When no decoder input is given, the decoder is fed the last observed value repeated across the forecast horizon.

In [None]:
# 09. PositionalEncoding + Transformer model (univariate)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Column ``2i`` of the table holds ``sin(pos / 10000^(2i/d_model))`` and column
    ``2i+1`` holds the matching cosine.

    Args:
        d_model: Width of the last dimension of the input.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model, dropout=0.2, max_len=1024):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: follows the module across devices and is
        # stored in state_dict, but is not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:, :seq_len])

class TimeSeriesTransformer(nn.Module):
    """Univariate encoder-decoder Transformer forecaster.

    Scalar inputs are lifted to ``d_model`` by a shared linear projection (used
    for both encoder and decoder), scaled by sqrt(d_model), and given sinusoidal
    position encodings. Decoder outputs are projected back to one value per step.

    Args:
        input_window: Length of the encoder history. Not used to build any layer;
            kept for interface compatibility.
        horizon: Number of future steps produced when no decoder input is given.
        d_model: Transformer embedding width.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and, separately, decoder layers.
    """

    def __init__(self, input_window, horizon, d_model=64, nhead=8, num_layers=2):
        super().__init__()
        self.horizon = horizon
        self.d_model = d_model

        self.in_proj = nn.Linear(1, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.tr_model = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True,
        )
        self.out_proj = nn.Linear(d_model, 1)

    def forward(self, past, decoder_input=None):
        """
        Args:
            past           : (B, T, 1) or (B, T) — encoder input
            decoder_input  : (B, F, 1) or (B, F) — optional decoder input (teacher
                             forcing). If omitted, the last value of ``past`` is
                             repeated ``horizon`` times.
        Returns:
            preds          : (B, F) — predicted future values
        """
        # Accept squeezed univariate inputs as well as explicit (…, 1) feature dims.
        if past.dim() == 2:
            past = past.unsqueeze(-1)

        scale = math.sqrt(self.d_model)

        # Encoder input
        src = self.pos_enc(self.in_proj(past) * scale)

        # Decoder input
        if decoder_input is None:
            decoder_input = past[:, -1:, :].repeat(1, self.horizon, 1)
        elif decoder_input.dim() == 2:
            decoder_input = decoder_input.unsqueeze(-1)
        tgt = self.pos_enc(self.in_proj(decoder_input) * scale)

        # Causal mask: decoder step k may only attend to steps <= k. Without it,
        # a teacher-forced input [0, y0, ..., y_{F-2}] lets step k see step k+1,
        # which holds its own target y_k.
        tgt_len = tgt.size(1)
        tgt_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf"), device=tgt.device, dtype=tgt.dtype),
            diagonal=1,
        )

        output = self.tr_model(src, tgt, tgt_mask=tgt_mask)  # (B, F, d_model)
        return self.out_proj(output).squeeze(-1)  # (B, F)

## 10 · Ray Train training loop (with teacher forcing)
This is the heart of Ray Train. Each worker runs its own copy of this loop, while Ray orchestrates checkpointing, reporting, and failure recovery. The loop uses teacher forcing: during training, the decoder receives the ground-truth future shifted right by one step (with a zero start token), which typically speeds up learning compared with making the model rely on its own predictions. Validation runs without teacher forcing. The loop also logs training and validation loss per epoch, and rank 0 saves a checkpoint to the shared filesystem.

In [None]:
# 10. Ray Train train_loop_per_worker with checkpointing, teacher forcing, and clean structure

def _unwrap_model(model):
    """Return the module inside a DDP-style wrapper, or ``model`` itself if unwrapped."""
    return getattr(model, "module", model)


def _restore_checkpoint(model, optimizer):
    """Load model/optimizer state from the Ray checkpoint, if one exists.

    Returns:
        The epoch to resume from (0 when there is no checkpoint).
    """
    checkpoint = get_checkpoint()
    if not checkpoint:
        return 0

    with checkpoint.as_directory() as ckpt_dir:
        # map_location="cpu": the files may have been written on another device;
        # load_state_dict then copies each tensor onto its parameter's device.
        model_state = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
        optim_state = torch.load(os.path.join(ckpt_dir, "optim.pt"), map_location="cpu")
        extra = torch.load(os.path.join(ckpt_dir, "extra.pt"), map_location="cpu")

    # Weights saved from a DDP-wrapped model carry a "module." prefix. Strip it and
    # load into the bare module so resuming works whether or not either run was wrapped.
    prefix = "module."
    model_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in model_state.items()
    }
    _unwrap_model(model).load_state_dict(model_state)
    optimizer.load_state_dict(optim_state)

    start_epoch = extra["epoch"] + 1
    print(f"[Rank {get_context().get_world_rank()}] Resumed @ epoch {start_epoch}")
    return start_epoch


def _train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one teacher-forced pass over ``loader``; return the mean per-batch loss."""
    model.train()
    loss_sum, n_batches = 0.0, 0
    for past, future in loader:
        past, future = past.to(device), future.to(device)
        optimizer.zero_grad()

        # Teacher forcing: the decoder sees [0, y0, ..., y_{F-2}] and predicts
        # [y0, ..., y_{F-1}]. Assumes the loader yields `future` as (B, F).
        target = future.unsqueeze(-1)                                   # (B, F, 1)
        start_token = torch.zeros_like(target[:, :1])                   # (B, 1, 1)
        decoder_input = torch.cat([start_token, target[:, :-1]], dim=1)  # (B, F, 1)

        pred = model(past, decoder_input)                               # (B, F)
        loss = loss_fn(pred, future)                                    # (B, F) vs (B, F)

        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        n_batches += 1
    # Count batches rather than using len(loader): avoids ZeroDivisionError on an
    # empty loader and does not require the loader to define __len__.
    return loss_sum / max(n_batches, 1)


def _evaluate(model, loader, loss_fn, device):
    """Return the mean per-batch loss on ``loader`` without teacher forcing."""
    model.eval()
    loss_sum, n_batches = 0.0, 0
    with torch.no_grad():
        for past, future in loader:
            past, future = past.to(device), future.to(device)
            # No decoder_input: the model feeds its decoder the last observed
            # value of `past` repeated over the horizon.
            pred = model(past)
            loss_sum += loss_fn(pred, future).item()
            n_batches += 1
    return loss_sum / max(n_batches, 1)


def _save_checkpoint(model, optimizer, epoch):
    """Write model, optimizer and epoch state to a fresh directory; return it as a Checkpoint."""
    ckpt_dir = f"{DATA_DIR}/tmp_ckpts/epoch_{epoch}_{uuid.uuid4().hex}"
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "model.pt"))
    torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optim.pt"))
    torch.save({"epoch": epoch}, os.path.join(ckpt_dir, "extra.pt"))
    return Checkpoint.from_directory(ckpt_dir)


def _append_history(epoch, train_loss, val_loss):
    """Append one "epoch,train_loss,val_loss" row to DATA_DIR/results/history.csv."""
    hist_path = os.path.join(DATA_DIR, "results", "history.csv")
    # The results directory may not exist yet on a fresh run.
    os.makedirs(os.path.dirname(hist_path), exist_ok=True)
    with open(hist_path, "a") as f:
        f.write(f"{epoch},{train_loss:.6f},{val_loss:.6f}\n")


def train_loop_per_worker(config):
    """Per-worker Ray Train loop with checkpoint resume, teacher forcing and validation.

    Args:
        config: dict with keys "d_model", "nhead", "num_layers", "lr", "bs", "epochs".

    Every rank calls ``train.report`` once per epoch. Only rank 0 prints the
    metrics, writes a checkpoint directory under ``DATA_DIR/tmp_ckpts`` and
    appends a row to ``DATA_DIR/results/history.csv``.
    """
    torch.manual_seed(0)
    rank = get_context().get_world_rank()
    # Batches must live on the same device prepare_model moves the model to.
    device = train.torch.get_device()

    # 1. Instantiate and prepare the model
    model = TimeSeriesTransformer(
        input_window=INPUT_WINDOW,
        horizon=HORIZON,
        d_model=config["d_model"],
        nhead=config["nhead"],
        num_layers=config["num_layers"],
    )
    model = train.torch.prepare_model(model)  # wrap in DDP if needed

    # 2. Define optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.SmoothL1Loss()

    # 3. Restore checkpoint if available
    start_epoch = _restore_checkpoint(model, optimizer)

    # 4. Load data for this worker
    train_loader = build_dataloader(
        os.path.join(PARQUET_DIR, "train.parquet"),
        batch_size=config["bs"],
        shuffle=True,
    )
    val_loader = build_dataloader(
        os.path.join(PARQUET_DIR, "val.parquet"),
        batch_size=config["bs"],
        shuffle=False,
    )

    # 5. Epoch loop
    for epoch in range(start_epoch, config["epochs"]):
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        avg_val_loss = _evaluate(model, val_loader, loss_fn, device)

        # 6. Report metrics; rank 0 also saves a checkpoint and the loss history
        metrics = {
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
        }

        checkpoint_out = None
        if rank == 0:
            print(metrics)
            checkpoint_out = _save_checkpoint(model, optimizer, epoch)
            _append_history(epoch, avg_train_loss, avg_val_loss)

        # Called on every rank; only rank 0 attaches a checkpoint.
        train.report(metrics, checkpoint=checkpoint_out)