In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sympy import symbols, lambdify
from sympy import symbols, hermite

# ─────────────────────────────────────────────────────────────────────────────
# 1) DATA‐GEN HELPERS
# ─────────────────────────────────────────────────────────────────────────────

# Module-level SymPy symbol `x`, created once when the module is loaded.
x_sym = symbols('x')

def sample_hermite_function(p, q, coeff_variance=1.0):
    """
    Sample a random linear combination of physicists' Hermite polynomials.

    Returns a NumPy-vectorized function f(x) = sum_{i=p}^q c_i H_i(x),
    with independent c_i ~ N(0, coeff_variance).

    Args:
        p: lowest Hermite order included (must be >= 0 when p <= q).
        q: highest Hermite order included.  An empty range (p > q) yields
            the zero function.
        coeff_variance: variance of each Gaussian coefficient.

    Returns:
        A callable mapping a scalar or array-like x to f(x).  The result
        always has the shape of x, including when f is constant
        (e.g. p == q == 0).

    Raises:
        ValueError: if the order range is non-empty and p is negative.
    """
    # numpy.polynomial.hermite uses the physicists' convention
    # (H_2(x) = 4x^2 - 2), the same polynomials sympy.hermite produces, so
    # the series can be evaluated numerically without building and
    # lambdifying a symbolic expression for every task.
    from numpy.polynomial.hermite import hermval

    orders = range(p, q + 1)
    if orders and p < 0:
        raise ValueError(f"Hermite orders must be non-negative (got p={p})")

    std = np.sqrt(coeff_variance)
    # coeffs[i] multiplies H_i; orders below p stay at zero.  At least one
    # entry is kept so the empty range evaluates to the zero function.
    coeffs = np.zeros(max(q + 1, 1))
    for i in orders:
        # One scalar draw per order, in increasing order of i.
        coeffs[i] = np.random.normal(scale=std)

    def f(x):
        """Evaluate the sampled Hermite series at x (scalar or array-like)."""
        return hermval(np.asarray(x, dtype=float), coeffs)

    return f

def generate_dataset(num_tasks, p, q, seq_len,
                     coeff_variance=1.0,
                     x_variance=1.0):
    """
    Generates `num_tasks` tasks.  Each task is a (seq_len × 2) NumPy array:
      - col0:  x ~ N(0, x_variance)
      - col1:  y = f(x) for a fresh random Hermite f of orders [p..q]

    Args:
        num_tasks: number of independent tasks (arrays) to generate.
        p, q: inclusive range of Hermite orders passed to
            `sample_hermite_function`.
        seq_len: number of (x, y) rows per task.
        coeff_variance: variance of the Hermite coefficients.
        x_variance: variance of the Gaussian inputs.

    Returns:
        A list of `num_tasks` float arrays, each of shape (seq_len, 2).
    """
    x_std = np.sqrt(x_variance)  # loop-invariant, computed once
    tasks = []
    for _ in range(num_tasks):
        # Inputs are drawn before the function is sampled; keeping this
        # order keeps results reproducible for a given np.random seed.
        xs = np.random.normal(scale=x_std, size=seq_len)
        f = sample_hermite_function(p, q, coeff_variance)
        # f may return a bare scalar when the sampled polynomial is constant
        # (e.g. p == q == 0); broadcasting guarantees one target per input,
        # so column_stack never sees mismatched lengths.
        ys = np.broadcast_to(np.asarray(f(xs), dtype=float), xs.shape)
        tasks.append(np.column_stack((xs, ys)))
    return tasks

# ─────────────────────────────────────────────────────────────────────────────
# 2) THE MODEL
# ─────────────────────────────────────────────────────────────────────────────

class TinyTransformer(nn.Module):
    """
    Minimal sequence regressor: per-token MLP embedding, one causally masked
    Transformer encoder layer, and a linear read-out of the last position.

    Args:
        d_model: embedding / attention width.
        nhead: number of attention heads (must divide d_model).

    The feed-forward width and dropout of the encoder layer are left at the
    nn.TransformerEncoderLayer defaults.
    """

    def __init__(self, d_model=16, nhead=4):
        super().__init__()
        # (a) embed each 2-vector [x_t, y_t] → d_model
        self.embedding_mlp = nn.Sequential(
            nn.Linear(2, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # (b) one layer of causal self-attention
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            batch_first=True,   # inputs are [B, S, d_model]
        )
        # Cached causal mask, shape [S, S], True means "may not attend".
        # It is a buffer so it follows the module across .to(device), but it
        # is non-persistent: its shape depends on the last sequence length
        # seen, so it must not end up in state_dict() (a persistent buffer
        # would make checkpoints fail to load into a fresh model).
        self.register_buffer("mask", None, persistent=False)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # (c) read-out final token → scalar
        self.readout = nn.Linear(d_model, 1)

    def forward(self, x_seq):
        """
        x_seq: [B, S, 2] tensor of (x_t, y_t) pairs
        returns y_pred: [B] of final predictions

        Raises:
            ValueError: if x_seq is not 3-dimensional.
        """
        if x_seq.dim() != 3:
            raise ValueError(
                f"expected x_seq of shape [B, S, 2], got {tuple(x_seq.shape)}"
            )
        seq_len = x_seq.size(1)

        # 1) embed every time-step; nn.Linear acts on the last dimension, so
        #    no flattening is needed and non-contiguous inputs work too
        emb = self.embedding_mlp(x_seq)            # → [B, S, d_model]

        # 2) (re)build the causal mask only when length or device changed
        mask = self.mask
        if mask is None or mask.size(0) != seq_len or mask.device != emb.device:
            # True strictly above the diagonal: position t cannot see t' > t
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=emb.device),
                diagonal=1,
            )
            self.mask = mask

        # 3) apply masked Transformer
        attn_out = self.transformer(emb, mask=mask)  # → [B, S, d_model]

        # 4) grab last time-step, 5) project to a scalar
        last = attn_out[:, -1, :]                   # → [B, d_model]
        return self.readout(last).squeeze(-1)       # → [B]

# ─────────────────────────────────────────────────────────────────────────────
# 3) DRIVER / TRAINING + VALIDATION
# ─────────────────────────────────────────────────────────────────────────────

def hide_last_target(tasks):
    """
    Record and blank out the final y-value of every task, in place.

    For each 2-column array in `tasks`, the value at the last row of column 1
    is saved and then overwritten with 0.0, so the array can be fed to the
    model without revealing the value it has to predict.

    Args:
        tasks: iterable of NumPy arrays indexed as task[row, col].

    Returns:
        List of the saved values as Python floats, in task order.
    """
    targets = []
    for task in tasks:
        targets.append(float(task[-1, 1]))
        task[-1, 1] = 0.0
    return targets


if __name__ == "__main__":
    torch.manual_seed(0)
    np.random.seed(0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3.1) HYPERPARAMETERS
    # Naming matches the generate_dataset(...) calls below and the validation
    # section: T* = number of tasks, N* = sequence length of each task.
    # NOTE(review): earlier comments described these the other way round
    # ("50 tasks, each length 8").  If many short sequences were intended,
    # swap the values here rather than the call arguments.
    p, q   = 0, 3    # Hermite orders in [p..q]
    N1, T1 = 500,  8  # first group:  T1 = 8 tasks, each of length N1 = 500
    N2, T2 = 800, 12  # second group: T2 = 12 tasks, each of length N2 = 800
    num_epochs = 10000
    log_every  = 1000

    # 3.2) GENERATE + CORRUPT DATA
    training_data1 = generate_dataset(T1, p, q, N1, coeff_variance=2.0, x_variance=1.0)
    training_data2 = generate_dataset(T2, p, q, N2, coeff_variance=2.0, x_variance=1.0)

    # extract true last-y and zero it in place
    y_last1 = hide_last_target(training_data1)
    y_last2 = hide_last_target(training_data2)

    # combine both groups into one training list
    train_tasks = training_data1 + training_data2
    train_y     = y_last1 + y_last2
    total_tasks = len(train_tasks)  # = T1 + T2

    # Convert every task once instead of on every step of every epoch.
    train_inputs = [
        torch.from_numpy(task[:, :2]).unsqueeze(0).float().to(device)  # [1, N, 2]
        for task in train_tasks
    ]
    train_targets = [
        torch.tensor([y], dtype=torch.float32, device=device)          # [1]
        for y in train_y
    ]

    # 3.3) INSTANTIATE MODEL, LOSS, OPTIMIZER
    model     = TinyTransformer(d_model=5, nhead=1).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # 3.4) TRAINING LOOP (no freezing!)
    print("\n--- Joint training of MLP+Transformer on all tasks ---")
    # Inclusive upper bound so the final epoch is run and logged.
    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0

        # shuffle tasks each epoch; one optimizer step per task
        for i in np.random.permutation(total_tasks):
            pred = model(train_inputs[i])           # → [1]
            loss = criterion(pred, train_targets[i])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if epoch % log_every == 0:
            avg = total_loss / total_tasks
            print(f"Epoch {epoch:4d}  avg loss = {avg:.4f}")

    # 3.5) VALIDATION
    model.eval()
    T_val  = 100   # number of validation tasks
    N_star = 15    # length of each validation task
    val_tasks = generate_dataset(T_val, p, q, N_star, coeff_variance=2.0, x_variance=1.0)

    val_y_true = hide_last_target(val_tasks)

    val_preds = []
    with torch.no_grad():
        for task in val_tasks:
            x_seq = torch.from_numpy(task[:, :2])\
                          .unsqueeze(0)\
                          .float()\
                          .to(device)    # [1, N_star, 2]
            val_preds.append(model(x_seq).item())

    val_preds = np.array(val_preds)
    val_trues = np.array(val_y_true)
    val_mse   = np.mean((val_preds - val_trues)**2)

    print(f"\nValidation on {T_val} tasks of length {N_star}")
    print(f"  Avg MSE = {val_mse:.4f}")
    # Distinct loop names so the Hermite order `p` is not overwritten.
    for i, (y_true_i, y_pred_i) in enumerate(zip(val_trues, val_preds), 1):
        print(f" Task {i:2d}: true={y_true_i:.4f}  pred={y_pred_i:.4f}")





--- Joint training of MLP+Transformer on all tasks ---
Epoch 1000  avg loss = 855.5754
Epoch 2000  avg loss = 719.9027
Epoch 3000  avg loss = 573.1282
Epoch 4000  avg loss = 399.9585
