# Transformers for Time Series Forecasting

In this lab, you will implement and experiment with transformer-based models for time series
forecasting. You will learn about:
- RevIN (Reversible Instance Normalization) for handling distribution shifts
- PatchTST: a patch-based transformer architecture for time series
- Training with the SAM (Sharpness-Aware Minimization) optimizer

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

## Part 1: Unified Data Loaders

In this section, you will create data loaders for ETTh1, reusing code from previous lab sessions.

**Question 1.** Create a unified data loader function. It should:
- Support the natively multivariate `ETTh1Dataset` from Lab 2
  (multivariate inputs, univariate outputs)
- Leave input features unscaled (your models will use RevIN layers instead)
- Return training and validation data loaders
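
Lab 2's `ETTh1Dataset` is not reproduced here, so the sketch below uses a hypothetical `SlidingWindowDataset` over a raw `(T, C)` array (for example `df.values` after dropping the date column), with a chronological train/validation split. It assumes the target is the last column (`OT` in ETTh1); `make_loaders`, `train_frac` and the default lengths are illustrative choices, not part of the lab's specification.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class SlidingWindowDataset(Dataset):
    """Hypothetical stand-in for Lab 2's ETTh1Dataset: multivariate input, univariate target."""

    def __init__(self, data, input_len, horizon, target_idx=-1):
        # data: (T, C) array of raw, UNSCALED features (RevIN handles scaling later)
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.input_len = input_len
        self.horizon = horizon
        self.target_idx = target_idx

    def __len__(self):
        # number of complete (input window, forecast window) pairs
        return len(self.data) - self.input_len - self.horizon + 1

    def __getitem__(self, i):
        x = self.data[i : i + self.input_len]  # (L, C): all channels as input
        end = i + self.input_len + self.horizon
        y = self.data[i + self.input_len : end, self.target_idx]  # (H,): target only
        return x, y


def make_loaders(data, input_len=96, horizon=24, batch_size=32, train_frac=0.8, target_idx=-1):
    """Chronological split: the validation set follows the training set in time."""
    split = int(len(data) * train_frac)
    train_ds = SlidingWindowDataset(data[:split], input_len, horizon, target_idx)
    val_ds = SlidingWindowDataset(data[split:], input_len, horizon, target_idx)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
```

Splitting before windowing keeps validation windows strictly in the future of the training data; the price is that the first validation window cannot use the last training time steps as context.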

## Part 2: RevIN (Reversible Instance Normalization)

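RevIN is a normalization technique that helps forecasting models, transformers included,
handle distribution shifts in time series data. It normalizes each instance (input window)
independently using its own statistics, and the normalization is reversed on the model output
so that forecasts are returned on the original scale.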
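
Written per channel $c$ over an input window of length $L$, the forward and reverse operations can be sketched as follows. This is one common formulation, assuming the biased (population) variance and a small stabilizing constant $\epsilon$; $\hat{y}$ denotes the model output in the normalized space.

$$\mu_c = \frac{1}{L}\sum_{t=1}^{L} x_{t,c}, \qquad \sigma_c = \sqrt{\frac{1}{L}\sum_{t=1}^{L}\left(x_{t,c}-\mu_c\right)^2 + \epsilon}$$

$$\text{normalize:}\quad \hat{x}_{t,c} = \gamma_c\,\frac{x_{t,c}-\mu_c}{\sigma_c} + \beta_c \qquad\qquad \text{denormalize:}\quad y_{t,c} = \sigma_c\,\frac{\hat{y}_{t,c}-\beta_c}{\gamma_c} + \mu_c$$

The statistics $\mu_c, \sigma_c$ are computed on the input window only and reused for the forecast, which is why they must be stored between the two operations.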

**Question 2.** Implement the RevIN module:
- Normalize input by subtracting mean and dividing by standard deviation (per instance)
- Store the statistics to reverse the normalization after processing
- Define learnable affine parameters $\gamma$ (scale) and $\beta$ (shift)
- Support both forward (normalize) and reverse (denormalize) operations
- Allow specification of a target channel for the multivariate-to-univariate
  forecasting use case.
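
A minimal sketch of such a module, assuming inputs of shape `(B, L, C)` and a `mode` string argument (`"norm"` / `"denorm"`); the constructor arguments and the `target_idx` mechanism are illustrative design choices, not a fixed API. Statistics are detached so that gradients do not flow through them.

```python
import torch
import torch.nn as nn


class RevIN(nn.Module):
    """Reversible instance normalization (sketch). Inputs: (B, L, C)."""

    def __init__(self, num_features, eps=1e-5, affine=True, target_idx=None):
        super().__init__()
        self.eps = eps
        self.affine = affine
        self.target_idx = target_idx  # channel to denormalize for univariate outputs
        if affine:
            self.gamma = nn.Parameter(torch.ones(num_features))   # learnable scale
            self.beta = nn.Parameter(torch.zeros(num_features))   # learnable shift

    def forward(self, x, mode):
        if mode == "norm":
            return self._normalize(x)
        if mode == "denorm":
            return self._denormalize(x)
        raise ValueError(f"Unknown mode: {mode}")

    def _normalize(self, x):
        # Per-instance, per-channel statistics over the time axis, stored for later
        self.mean = x.mean(dim=1, keepdim=True).detach()  # (B, 1, C)
        self.std = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + self.eps).detach()
        x = (x - self.mean) / self.std
        if self.affine:
            x = x * self.gamma + self.beta
        return x

    def _denormalize(self, y):
        # y: (B, H, C) for multivariate outputs, or (B, H) when target_idx is set
        mean, std = self.mean, self.std
        gamma = self.gamma if self.affine else None
        beta = self.beta if self.affine else None
        if self.target_idx is not None:
            mean = mean[..., self.target_idx]  # (B, 1), broadcasts over the horizon
            std = std[..., self.target_idx]
            if self.affine:
                gamma = gamma[self.target_idx]
                beta = beta[self.target_idx]
        if self.affine:
            y = (y - beta) / (gamma + self.eps ** 2)  # eps**2 guards against gamma -> 0
        return y * std + mean
```

In a model, call `revin(x, "norm")` on the input window and `revin(out, "denorm")` on the `(B, H)` forecast, with `target_idx` pointing at the target channel.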

**Question 3.** Visualize the impact of this RevIN normalization/denormalization on a small set of time series from your data loader. Focus your visualizations on the target feature.
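
One way to approach this is sketched below. To stay self-contained it recomputes the normalization inline with $\gamma = 1$, $\beta = 0$ (the initial state of RevIN's affine parameters); with a trained RevIN module you would call it instead. The function name and figure layout are illustrative.

```python
import torch
import matplotlib.pyplot as plt


def plot_revin_examples(x, target_idx=-1, n_examples=4, eps=1e-5):
    """x: (B, L, C) batch from the data loader; plots the target channel only."""
    target = x[:n_examples, :, target_idx]  # (n, L)
    mean = target.mean(dim=1, keepdim=True)
    std = torch.sqrt(target.var(dim=1, keepdim=True, unbiased=False) + eps)
    normed = (target - mean) / std        # forward RevIN with gamma=1, beta=0
    restored = normed * std + mean        # reverse operation with stored statistics

    fig, axes = plt.subplots(1, 3, figsize=(15, 3.5))
    for i in range(target.shape[0]):
        axes[0].plot(target[i].numpy())
        axes[1].plot(normed[i].numpy())
        axes[2].plot(restored[i].numpy(), linestyle="--")
    titles = ["Raw target", "After RevIN (normalized)", "After reverse (denormalized)"]
    for ax, title in zip(axes, titles):
        ax.set_title(title)
        ax.set_xlabel("Time step")
    fig.tight_layout()
    return fig
```

Usage: `plot_revin_examples(next(iter(train_loader))[0])`. The first and last panels should coincide, while the middle panel puts all windows on a common scale.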

**Question 4.** Now visualize the impact of RevIN normalization at the distribution level. Once again, focus on the target feature and plot the empirical distributions of the training and validation sets before and after normalization.
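
A possible sketch, assuming the loaders yield `(x, y)` batches with `x` of shape `(B, L, C)`; as above, the normalization uses $\gamma = 1$, $\beta = 0$ and is recomputed inline. Overlapping windows mean each time step is counted several times, which is acceptable for comparing distribution shapes.

```python
import torch
import matplotlib.pyplot as plt


def collect_target_values(loader, target_idx=-1, eps=1e-5):
    """Flatten raw and per-window-normalized target values from (x, y) batches."""
    raw, normed = [], []
    for x, _ in loader:
        t = x[:, :, target_idx]  # (B, L) target channel
        mean = t.mean(dim=1, keepdim=True)
        std = torch.sqrt(t.var(dim=1, keepdim=True, unbiased=False) + eps)
        raw.append(t.flatten())
        normed.append(((t - mean) / std).flatten())
    return torch.cat(raw).numpy(), torch.cat(normed).numpy()


def plot_target_distributions(train_loader, val_loader, target_idx=-1, bins=100):
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        raw, normed = collect_target_values(loader, target_idx)
        axes[0].hist(raw, bins=bins, density=True, alpha=0.5, label=name)
        axes[1].hist(normed, bins=bins, density=True, alpha=0.5, label=name)
    axes[0].set_title("Target before RevIN")
    axes[1].set_title("Target after RevIN normalization")
    for ax in axes:
        ax.set_xlabel("Value")
        ax.set_ylabel("Density")
        ax.legend()
    fig.tight_layout()
    return fig
```

If the raw training and validation histograms are visibly offset, per-window normalization should bring them much closer together; that gap is the distribution shift RevIN targets.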

## Part 3: PatchTST

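PatchTST (Patch-based Time Series Transformer) divides the input time series into
patches of consecutive time steps and processes them with a transformer, one token per patch.
This is more efficient than processing individual time steps: with patch length $P$ and stride $S$,
a window of length $L$ yields $\lfloor (L - P)/S \rfloor + 1$ tokens instead of $L$, and the cost
of self-attention grows quadratically with the number of tokens.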
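
The token-count arithmetic can be checked directly with `Tensor.unfold`, which extracts all length-$P$ slices spaced $S$ steps apart. The values $L = 96$, $P = 16$, $S = 8$ are illustrative only.

```python
import torch

L, P, S = 96, 16, 8                       # window length, patch length, stride (illustrative)
x = torch.arange(L, dtype=torch.float32)  # one univariate window: 0, 1, ..., 95
patches = x.unfold(0, P, S)               # (N, P) with N = (L - P) // S + 1
n_tokens = patches.shape[0]

print(patches.shape)        # torch.Size([11, 16])
print(patches[1, :3])       # values 8, 9, 10: the second patch starts at t = S
print(L ** 2, n_tokens ** 2)  # 9216 121: attention scores per head, time steps vs patches
```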

**Question 5.** Implement a minimal PatchTST model for univariate forecasting  
*At this stage, keep the simple channel-mixing implementation (no channel independence yet).*
- Divide the input sequence into patches of a given length
- Project patches to a model dimension
- Add learnable positional encodings
- Apply a standard transformer encoder 
  (use `nn.TransformerEncoder` and `nn.TransformerEncoderLayer` classes)
- Mean-pool patch representations and predict the forecast horizon using a linear head
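
A minimal sketch of the channel-mixing version, assuming inputs of shape `(B, L, C)` and a `(B, H)` univariate output. Hyperparameter defaults, non-overlapping patches (`stride = patch_len`) and the truncated-normal initialization of the positional encodings are assumptions; time steps beyond the last full patch are dropped by `unfold`.

```python
import torch
import torch.nn as nn


class PatchTST(nn.Module):
    """Minimal channel-mixing PatchTST sketch (hypothetical interface)."""

    def __init__(self, n_channels, input_len, horizon, patch_len=16, stride=None,
                 d_model=64, n_heads=4, n_layers=2, dropout=0.1):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride or patch_len  # non-overlapping patches by default
        n_patches = (input_len - patch_len) // self.stride + 1
        # Channel mixing: each patch is flattened across time AND channels before projection
        self.patch_proj = nn.Linear(patch_len * n_channels, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, n_patches, d_model))  # learnable positions
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, horizon)

    def forward(self, x):
        # x: (B, L, C) -> patches: (B, N, C, P)
        patches = x.unfold(1, self.patch_len, self.stride)
        B, N, C, P = patches.shape
        tokens = self.patch_proj(patches.reshape(B, N, C * P))  # (B, N, d_model)
        z = self.encoder(tokens + self.pos_emb)                  # (B, N, d_model)
        return self.head(z.mean(dim=1))                          # mean-pool -> (B, H)
```

With `patch_len=1` every time step becomes a token, which gives the baseline for Question 6. RevIN can be wrapped around `forward` by normalizing `x` first and denormalizing the `(B, H)` output with the target channel's statistics.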

**Question 6.** Compare performance of the following models on the ETTh1 dataset:
- PatchTST without RevIN and using patches of size 1
- PatchTST with RevIN and using patches of size 1
- PatchTST without RevIN and using patches of size 16
- PatchTST with RevIN and using patches of size 16

**Question 6bis.** Channel independence in PatchTST  
One hallmark of PatchTST is channel-independent patch processing (depthwise patch embedding), which is not implemented above.  
- Implement a channel-independent variant (one embedding per channel, no cross-channel mixing before the transformer). A common implementation trick is to reshape the input so that each channel is processed as an independent univariate time series, i.e., the batch dimension becomes B*C.
- Compare its performance with the mixed-channel version from Question 6 on ETTh1.  
- Discuss when channel independence helps or hurts.
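
The B*C reshaping trick can be sketched as follows. As in the PatchTST paper, the patch projection and transformer weights are shared across channels; the interface, defaults and the choice to return only the target channel are assumptions.

```python
import torch
import torch.nn as nn


class ChannelIndependentPatchTST(nn.Module):
    """Channel-independent PatchTST sketch (hypothetical interface)."""

    def __init__(self, input_len, horizon, target_idx=-1, patch_len=16, stride=None,
                 d_model=64, n_heads=4, n_layers=2, dropout=0.1):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride or patch_len
        self.target_idx = target_idx
        n_patches = (input_len - patch_len) // self.stride + 1
        # Tokens are univariate patches: the projection sees patch_len values, not patch_len * C
        self.patch_proj = nn.Linear(patch_len, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, n_patches, d_model))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, horizon)

    def forward(self, x):
        B, L, C = x.shape
        # (B, L, C) -> (B*C, L): every channel becomes an independent series in the batch
        series = x.permute(0, 2, 1).reshape(B * C, L)
        patches = series.unfold(1, self.patch_len, self.stride)   # (B*C, N, P)
        z = self.encoder(self.patch_proj(patches) + self.pos_emb)  # (B*C, N, d_model)
        forecasts = self.head(z.mean(dim=1)).reshape(B, C, -1)    # (B, C, H)
        return forecasts[:, self.target_idx, :]                    # (B, H): target channel
```

One point worth raising in the discussion: with a univariate loss on the target only, the non-target rows of the B*C batch receive no gradient, so this variant effectively becomes a univariate model of the target's own history. Letting the other channels contribute through the shared weights would require supervising all channels, i.e., multivariate targets from the data loader.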

## Part 4: Training with SAM (Sharpness-Aware Minimization)

Sharpness-Aware Minimization (SAM) aims to improve generalization by seeking parameters
whose entire neighborhood has low loss, i.e., solutions that lie in flat regions of the
loss landscape rather than in sharp minima.
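
In the usual formulation, SAM solves a min-max problem over an $\ell_2$ ball of radius $\rho$ and approximates the inner maximization to first order. This is how it maps onto the optimizer given below: `first_step` computes $\hat{\epsilon}(w)$ from the global gradient norm (the `1e-12` avoids division by zero) and adds it to the weights; `second_step` removes it and lets the base optimizer apply $g_{\text{SAM}}$ at the original weights $w$.

$$\min_{w}\; \max_{\|\epsilon\|_2 \le \rho} \mathcal{L}(w + \epsilon)$$

$$\hat{\epsilon}(w) = \rho\,\frac{\nabla_w \mathcal{L}(w)}{\|\nabla_w \mathcal{L}(w)\|_2}, \qquad g_{\text{SAM}} = \nabla_w \mathcal{L}(w)\Big|_{w + \hat{\epsilon}(w)}$$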

In this part:
- You are GIVEN an implementation of the SAM optimizer
- You must implement a SAM training step
- You must adapt the training loop accordingly

Important notes:
- The loss function is NOT modified
- The validation loop remains unchanged
- Only the training step differs from standard optimization

In [None]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization optimizer wrapper."""
    
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        
        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)
        
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
    
    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w
        
        if zero_grad:
            self.zero_grad()
    
    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
        
        self.base_optimizer.step()  # do the actual "sharpness-aware" update
        
        if zero_grad:
            self.zero_grad()
    
    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm
    
    def step(self, closure=None):
        raise NotImplementedError("SAM doesn't work like the other optimizers, you should first call `first_step` and then `second_step`")

**Question 7.** Implement a SAM training step

The SAM optimizer exposes two methods:
  - `optimizer.first_step()`
  - `optimizer.second_step()`

Implement a function `sam_step` that:
  1. Computes the loss and gradients at the current parameters
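  2. Calls `first_step()` to move to a nearby point of higher loss (the first-order worst case within radius $\rho$); pass `zero_grad=True` so that the gradients from step 1 are cleared before step 3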
  3. Recomputes the loss and gradients at the perturbed parameters
  4. Calls `second_step()` to update the model

The function should return the final loss value.
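
A minimal sketch, assuming `optimizer` is the `SAM` wrapper defined above and that gradients are zero on entry (which holds here because `second_step(zero_grad=True)` clears them after every update). The signature `(model, optimizer, loss_fn, x, y)` is an assumption. This version returns the loss at the unperturbed weights, which is directly comparable with a standard training curve; if "final loss" is read as the loss at the perturbed point, return the second value instead.

```python
import torch
import torch.nn as nn


def sam_step(model: nn.Module, optimizer, loss_fn, x: torch.Tensor, y: torch.Tensor) -> float:
    """One SAM update on a single batch (sketch)."""
    # 1. Loss and gradients at the current weights w
    loss = loss_fn(model(x), y)
    loss.backward()
    # 2. Climb to w + e(w); zero_grad=True clears the step-1 gradients
    optimizer.first_step(zero_grad=True)
    # 3. Loss and gradients at the perturbed weights w + e(w)
    loss_fn(model(x), y).backward()
    # 4. Return to w and let the base optimizer apply the perturbed gradients
    optimizer.second_step(zero_grad=True)
    return loss.item()
```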

**Question 8.** Implement a training epoch using SAM

Using the `sam_step` function, implement a training epoch.
The structure should be similar to the standard training loop.
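
A possible sketch, assuming the loader yields `(x, y)` batches. Passing the step function as an argument (hypothetical `step_fn`, expected to take `(model, optimizer, loss_fn, x, y)` and return a Python float) is a design choice: the same epoch loop then serves both the SAM step from Question 7 and a plain Adam step.

```python
import torch
import torch.nn as nn


def train_epoch_sam(model: nn.Module, loader, optimizer, loss_fn, step_fn, device="cpu") -> float:
    """One training epoch; step_fn performs the per-batch update (e.g. sam_step)."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        loss = step_fn(model, optimizer, loss_fn, x, y)  # Python float
        total_loss += loss * x.size(0)  # weight by batch size: the last batch may be smaller
        n_samples += x.size(0)
    return total_loss / n_samples
```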

**Question 9.** Full training loop with SAM

Complete the training-and-validation loop using the SAM-based training epoch.
The validation loop remains unchanged.
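
A sketch of the outer loop, assuming `train_epoch_fn(model, loader, optimizer, loss_fn, device=...)` returns the mean training loss. `evaluate` below is a generic stand-in for your existing validation loop; since SAM only changes the training step, it is identical for Adam and SAM runs. The names `fit`, `evaluate` and `history` are illustrative.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate(model: nn.Module, loader, loss_fn, device="cpu") -> float:
    """Standard validation loop (unchanged by SAM)."""
    model.eval()
    total_loss, n_samples = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        total_loss += loss_fn(model(x), y).item() * x.size(0)
        n_samples += x.size(0)
    return total_loss / n_samples


def fit(model, train_loader, val_loader, optimizer, loss_fn, train_epoch_fn,
        n_epochs=20, device="cpu"):
    """Alternate training and validation epochs; returns per-epoch losses."""
    model.to(device)
    history = {"train": [], "val": []}
    for epoch in range(n_epochs):
        train_loss = train_epoch_fn(model, train_loader, optimizer, loss_fn, device=device)
        val_loss = evaluate(model, val_loader, loss_fn, device)
        history["train"].append(train_loss)
        history["val"].append(val_loss)
        print(f"epoch {epoch + 1:3d} | train {train_loss:.4f} | val {val_loss:.4f}")
    return history
```

For the SAM run, the optimizer can be built as `SAM(model.parameters(), torch.optim.Adam, rho=0.05, lr=1e-3)`, since extra keyword arguments are forwarded to the base optimizer; if your Question 8 epoch takes extra arguments, wrap it with `functools.partial` before passing it as `train_epoch_fn`. The returned `history` is what you would plot as training curves.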

**Question 10.** Compare the performance of:
- PatchTST (Adam)
- PatchTST (SAM)

Visualize the training curves and some example forecasts.