In [None]:
def composite_schedule_sampling(iter: int, decoder_step: int, max_iters: int, max_decoder_step: int = 128) -> float:
    """
    Evaluate the Composite scheduled-sampling schedule for one decoding step.

    The schedule combines a training-progress term with the decoding position:

        f(i)   = max(0, 1 - iter / max_iters)                  # 1 -> 0 over training
        x      = (decoder_step / max_decoder_step) * (1 - f(i))
        g(x)   = 1 - sigmoid(10 * (x - 0.5))                   # decreasing in x
        h(i,t) = g(x)

    This function returns ``1 - h(i, t) = sigmoid(10 * (x - 0.5))``. That value is
    close to 0 at the start of training (x = 0). It rises with both training
    progress and decoding step, and equals 0.5 at x = 0.5.

    NOTE(review): the previous docstring called the return value the probability
    of sampling the *golden* token. The value actually returned is the complement
    of g(x) above. Confirm which interpretation the caller relies on.

    Args:
        iter (int): Current training iteration. The name is kept for keyword
            callers even though it shadows the builtin.
        decoder_step (int): Current decoding step.
        max_iters (int): Maximum number of training iterations (must be non-zero).
        max_decoder_step (int): Maximum number of decoding steps (must be non-zero).

    Returns:
        float: ``sigmoid(10 * (x - 0.5))`` with ``x`` as defined above, in [0, 1].
    """
    import math  # local import: this cell does not import numpy/math itself

    # f(i): linear decay from 1 to 0 over training, clamped so it never goes negative.
    f_i = max(0.0, 1.0 - iter / max_iters)
    x = (decoder_step / max_decoder_step) * (1.0 - f_i)
    z = 10.0 * (x - 0.5)

    # Numerically stable logistic. Only a non-positive value is ever exponentiated,
    # so math.exp cannot overflow for extreme (e.g. negative) inputs.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from itertools import chain

# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py


class SOAP(optim.Optimizer):
    """
    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

    SOAP runs Adam in the eigenbasis of Shampoo's preconditioner:
      1. Gradients are rotated into the eigenbases (``state['Q']``) of per-dimension
         second-moment matrices (``state['GG']``).
      2. Adam's moment estimates are kept in that rotated space.
      3. The resulting update is rotated back before being applied to the parameter.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.003):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
            Adam's betas parameters (b1, b2).
        shampoo_beta (`float`, *optional*, defaults to -1):
            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
        eps (`float`, *optional*, defaults to 1e-08):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
        precondition_frequency (`int`, *optional*, defaults to 10):
            How often to update the preconditioner.
        max_precond_dim (`int`, *optional*, defaults to 10000):
            Maximum dimension of the preconditioner.
            Set to 10000, so that we exclude most common vocab sizes while including layers.
        merge_dims (`bool`, *optional*, defaults to `False`):
            Whether or not to merge dimensions of the preconditioner.
        precondition_1d (`bool`, *optional*, defaults to `False`):
            Whether or not to precondition 1D gradients.
        normalize_grads (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize gradients per layer.
            Helps at large precondition_frequency (~100 in our experiments),
            but hurts performance at small precondition_frequency (~10 in our experiments).
        data_format (`str`, *optional*, defaults to `channels_first`):
            Data format of the input for convolutional layers.
            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias correction in Adam.
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.95, 0.95),
        shampoo_beta: float = -1,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 10,
        max_precond_dim: int = 10000,
        merge_dims: bool = False,  # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        data_format: str = "channels_first",
        correct_bias: bool = True,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
            "precondition_frequency": precondition_frequency,
            "max_precond_dim": max_precond_dim,
            "merge_dims": merge_dims,
            "precondition_1d": precondition_1d,
            "normalize_grads": normalize_grads,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)
        # Shared by all groups; only consulted for 4-D tensors.
        self._data_format = data_format

    def merge_dims(self, grad, max_precond_dim):
        """
        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.

        For ``channels_last`` 4-D tensors the axes are first permuted to NCHW order.
        Returns a reshaped view/copy of ``grad``; the caller is responsible for
        reshaping back.
        """
        assert self._data_format in ["channels_first", "channels_last"]
        if self._data_format == "channels_last" and grad.dim() == 4:
            grad = grad.permute(0, 3, 1, 2)
        shape = grad.shape
        new_shape = []

        curr_shape = 1
        for sh in shape:
            temp_shape = curr_shape * sh
            if temp_shape > max_precond_dim:
                if curr_shape > 1:
                    new_shape.append(curr_shape)
                    curr_shape = sh
                else:
                    new_shape.append(sh)
                    curr_shape = 1
            else:
                curr_shape = temp_shape

        if curr_shape > 1 or len(new_shape) == 0:
            new_shape.append(curr_shape)

        new_grad = grad.reshape(new_shape)
        return new_grad

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # ``step`` itself runs under ``torch.no_grad()``. The closure must be able
            # to build a graph and call ``backward()``, so autograd is re-enabled for
            # it (same pattern as torch.optim.AdamW).
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                if 'Q' not in state:
                    self.init_preconditioner(
                        grad,
                        state,
                        precondition_frequency=group['precondition_frequency'],
                        precondition_1d=group['precondition_1d'],
                        shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group["betas"][1]),
                        max_precond_dim=group['max_precond_dim'],
                        merge_dims=group["merge_dims"],
                    )
                    self.update_preconditioner(grad, state,
                                               max_precond_dim=group['max_precond_dim'],
                                               merge_dims=group["merge_dims"],
                                               precondition_1d=group["precondition_1d"])
                    continue  # first step is skipped so that we never use the current gradients in the projection.

                # Project the gradient into the eigenbases of Shampoo's preconditioner
                # (the eigenbases of the matrices in state['GG']).
                grad_projected = self.project(grad, state, merge_dims=group["merge_dims"],
                                              max_precond_dim=group['max_precond_dim'])

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running averages in place.
                exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))

                denom = exp_avg_sq.sqrt().add_(group["eps"])

                # exp_avg needs no projection here. update_preconditioner projects it
                # back before changing Q and re-projects it afterwards, so it is
                # always stored in the current eigenbasis.
                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** (state["step"])
                    bias_correction2 = 1.0 - beta2 ** (state["step"])
                    step_size = step_size * (bias_correction2 ** .5) / bias_correction1

                # Project the Adam-preconditioned first moment back to the original space.
                norm_grad = self.project_back(exp_avg / denom, state, merge_dims=group["merge_dims"],
                                              max_precond_dim=group['max_precond_dim'])

                if group["normalize_grads"]:
                    norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2)**0.5)

                p.add_(norm_grad, alpha=-step_size)

                # Decoupled (AdamW-style) weight decay. It is applied directly to the
                # weights instead of through the gradient, so it does not interact
                # with the m/v moment estimates.
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

                # Update is done after the gradient step to avoid using current gradients in the projection.
                self.update_preconditioner(grad, state,
                                           max_precond_dim=group['max_precond_dim'],
                                           merge_dims=group["merge_dims"],
                                           precondition_1d=group["precondition_1d"])

        return loss

    def init_preconditioner(self, grad, state, precondition_frequency=10,
                            shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                            merge_dims=False):
        """
        Initializes the preconditioner matrices (L and R in the paper).

        ``state['GG']`` gets one entry per (possibly merged) gradient dimension. Each
        entry is a zero ``(sh, sh)`` matrix, or ``[]`` when that dimension exceeds
        ``max_precond_dim``. A 1-D gradient gets ``[]`` unless ``precondition_1d``
        is set. ``state['Q']`` is reset to ``None``.
        """
        state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
        else:
            if merge_dims:
                grad = self.merge_dims(grad, max_precond_dim)

            for sh in grad.shape:
                if sh > max_precond_dim:
                    state['GG'].append([])
                else:
                    state['GG'].append(torch.zeros(sh, sh, device=grad.device))

        state['Q'] = None  # Will hold all the eigenbases of the preconditioner.
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta

    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient to the eigenbases of the preconditioner.

        The loop visits the dimensions in order, always working on the current
        leading axis:
          - If that dimension has an eigenbasis ``Q``, the axis is contracted with
            ``Q``'s first axis (i.e. ``Q^T`` is applied). ``tensordot`` appends the
            resulting axis at the end.
          - Otherwise the axis is rotated to the back.
        After one pass over all dimensions the original axis order is restored.
        """
        original_shape = grad.shape
        if merge_dims:
            if grad.dim() == 4 and self._data_format == 'channels_last':
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)

        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(grad, mat, dims=[[0], [0]])
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)

        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def update_preconditioner(self, grad, state,
                              max_precond_dim=10000, merge_dims=False, precondition_1d=False):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).

        Steps:
          1. ``exp_avg`` is moved back to the original space.
          2. Each ``GG`` matrix is lerped toward the outer product of the gradient
             along its dimension.
          3. ``Q`` is recomputed: with ``eigh`` on the first call, and with one
             power-iteration + QR step every ``precondition_frequency`` steps.
          4. ``exp_avg`` is re-projected into the (possibly new) eigenbasis.
        """
        if state["Q"] is not None:
            state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state['shampoo_beta'])
        else:
            new_grad = self.merge_dims(grad, max_precond_dim) if merge_dims else grad
            for idx, sh in enumerate(new_grad.shape):
                if sh <= max_precond_dim:
                    # Contracts across all dimensions except for idx.
                    outer_product = torch.tensordot(
                        new_grad,
                        new_grad,
                        dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
                    )
                    state['GG'][idx].lerp_(outer_product, 1 - state['shampoo_beta'])

        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])
        if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
            state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)

        if state["step"] > 0:
            state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient back to the original space.

        This is the inverse of ``project``: each eigenbasis ``Q`` is applied by
        contracting the leading axis with ``Q``'s second axis.
        """
        original_shape = grad.shape
        if merge_dims:
            if self._data_format == 'channels_last' and grad.dim() == 4:
                permuted_shape = grad.permute(0, 3, 1, 2).shape
            grad = self.merge_dims(grad, max_precond_dim)
        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(grad, mat, dims=[[0], [1]])
            else:
                permute_order = list(range(1, len(grad.shape))) + [0]
                grad = grad.permute(permute_order)

        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def get_orthogonal_matrix(self, mat):
        """
        Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.

        Args:
            mat: list of preconditioner matrices. ``[]`` marks a dimension that is
                not preconditioned.

        Returns:
            One entry per input. Each entry is either ``[]`` or an eigenvector matrix
            whose columns are ordered by descending eigenvalue, cast back to that
            matrix's own dtype/device.
        """
        final = []
        for m in mat:
            if len(m) == 0:
                final.append([])
                continue
            # Track dtype/device per matrix, so mixed-dtype lists are cast back correctly.
            original_type = m.data.dtype
            original_device = m.data.device
            m_float = m.data if original_type == torch.float else m.data.float()
            try:
                _, Q = torch.linalg.eigh(m_float + 1e-30 * torch.eye(m_float.shape[0], device=m_float.device))
            except RuntimeError:
                # eigh may fail to converge in float32 (torch.linalg.LinAlgError is a
                # RuntimeError subclass). Retry in float64 and cast back.
                _, Q = torch.linalg.eigh(m_float.to(torch.float64) + 1e-30 * torch.eye(m_float.shape[0], device=m_float.device))
                Q = Q.to(m_float.dtype)
            # eigh returns eigenvalues in ascending order; flip so columns are descending.
            Q = torch.flip(Q, [1])

            if original_type != torch.float:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)
        return final

    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
        """
        Computes the eigenbases of the preconditioner using one round of power iteration
        followed by torch.linalg.qr decomposition.

        For each dimension:
          - The columns of the current ``Q`` are re-sorted by their estimated
            eigenvalues (``diag(Q^T GG Q)``, descending).
          - ``state['exp_avg_sq']`` is re-indexed along that dimension with the same
            permutation so it stays aligned with the basis.
          - ``GG @ Q`` is orthonormalized with QR.
        Mutates ``state['exp_avg_sq']``; returns the new list of bases.
        """
        precond_list = state['GG']
        orth_list = state['Q']

        orig_shape = state['exp_avg_sq'].shape
        if self._data_format == 'channels_last' and len(orig_shape) == 4:
            permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape
        if merge_dims:
            exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
        else:
            exp_avg_sq = state['exp_avg_sq']

        final = []
        for ind, (m, o) in enumerate(zip(precond_list, orth_list)):
            if len(m) == 0:
                final.append([])
                continue
            # Track dtype/device per matrix, so mixed-dtype lists are cast back correctly.
            original_type = m.data.dtype
            original_device = m.data.device
            m = m.data.float()
            o = o.data.float()

            est_eig = torch.diag(o.T @ m @ o)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
            o = o[:, sort_idx]
            power_iter = m @ o
            Q, _ = torch.linalg.qr(power_iter)

            if original_type != torch.float:
                Q = Q.to(original_device).type(original_type)
            final.append(Q)

        if merge_dims:
            if self._data_format == 'channels_last' and len(orig_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)

        state['exp_avg_sq'] = exp_avg_sq
        return final

In [None]:
import numpy as np

def downsample_array(arr: np.ndarray, period: int) -> np.ndarray:
    """
    Average consecutive, non-overlapping windows of ``period`` elements.

    Trailing elements that do not fill a whole window are discarded.

    Parameters:
    -----------
    arr : np.ndarray
        One-dimensional input array.
    period : int
        Window length (number of elements averaged into each output value).

    Returns:
    --------
    np.ndarray
        Array of length ``len(arr) // period`` holding the mean of each window.

    Examples:
    --------
    >>> arr = np.array([1, 2, 3, 4, 5, 6])
    >>> downsample_array(arr, 2)
    array([1.5, 3.5, 5.5])
    >>> downsample_array(arr, 3)
    array([2., 5.])
    """
    # Largest prefix whose length is an exact multiple of the window size.
    usable = len(arr) - len(arr) % period
    windows = arr[:usable].reshape(-1, period)
    return windows.mean(axis=1)

In [None]:
import polars as pl
from tqdm import tqdm
import pandas as pd
import numpy as np
import math
import time
from dataclasses import dataclass
import torch
from contextlib import nullcontext

DATA = pl.scan_parquet('train.parquet/')

SYMBOLS = sorted(DATA.select('symbol_id').unique().collect().to_numpy().flatten())
print(SYMBOLS)

# Resolve the lazy frame's column names once and derive every column list from them.
_schema_cols = DATA.collect_schema().names()
RESPONDER_COLS = [col for col in _schema_cols if 'responder' in col]
# Raw 'feature_' columns, followed by a '<responder>_downsampled' name per responder column.
FEATURE_COLS = [col for col in _schema_cols if 'feature_' in col] + [f'{col}_downsampled' for col in RESPONDER_COLS]  # + ['weight']

DFS_USED = {}
DATES_USED = {}
DATES_USED = {}

In [None]:
for k in tqdm(SYMBOLS, total=len(SYMBOLS), desc='Indexing DFS...'):
    # Each symbol's frame is loaded fully into memory, which is acceptable here
    # because the per-symbol data is small.
    lazy = pl.scan_parquet(f'df_train.parquet/{k}.parquet')
    df = lazy.sort(['date_id', 'time_id']).collect()
    DFS_USED[k] = df
    # Derive the dates from the frame already in memory instead of scanning the parquet file a second time.
    DATES_USED[k] = sorted(df.get_column('date_id').unique().to_numpy())

In [None]:
# For every symbol, the largest 'index' value among its rows dated DATES_USED[k][1]
# (the second date in its sorted date list).
# NOTE(review): this needs at least two dates per symbol, otherwise it raises IndexError.
# The intended meaning of this cut-off isn't visible here; confirm against its users.
MIN_IDXS = {
    k: (
        DFS_USED[k]
        .filter(pl.col('date_id') == DATES_USED[k][1])
        .select(pl.col('index').max())
        .to_numpy()
        .flatten()[0]
    )
    for k in tqdm(SYMBOLS, total=len(SYMBOLS))
}

In [None]:
# Total number of rows across all per-symbol frames.
DATA_LEN = sum(frame.height for frame in DFS_USED.values())
DATA_LEN

In [None]:
import inspect
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Callable
import numpy as np
import re
import polars as pl

In [None]:
class Rotary(nn.Module):
    """Module that implements RoPE (Rotary Positional Embeddings).

    It produces the cos/sin angle tables for a sequence; ``apply_rotary_emb``
    applies them. The tables are cached and recomputed whenever the sequence length
    or the input's device changes.
    """
    def __init__(self, dim, base=10_000):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer). It therefore stays out
        # of the state_dict and is not down-cast by ``module.half()``/``.to(dtype)``;
        # forward() moves it to the input's device on demand.
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """Return ``(cos, sin)`` tables, each shaped ``(1, seq_len, 1, len(inv_freq))``.

        Only ``x.shape[1]`` (the sequence length) and ``x.device`` are used.
        """
        seq_len = x.shape[1]
        # A cache built on one device must not be returned for inputs on another
        # (e.g. after moving the model from CPU to GPU), so the device is part of the key.
        if (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        ):
            self.seq_len_cached = seq_len
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last axis of a 4-D tensor by the given angles.

    With ``x`` split at ``x.shape[3] // 2`` into ``(a, b)``, the result is
    ``concat(a*cos + b*sin, -a*sin + b*cos)`` along axis 3, cast to ``x``'s type.
    """
    assert x.ndim == 4
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = first * (-sin) + second * cos
    return torch.cat((rotated_first, rotated_second), 3).type_as(x)

In [None]:
def zero_weighted_rsquared(y_pred, y_true, weights):
    """Return ``1 - R^2`` for the zero-mean weighted R^2 (PyTorch tensors).

    The score is R^2 = 1 - sum(w * (y - y_hat)^2) / sum(w * y^2). The denominator
    has no mean subtraction. Returning ``1 - R^2`` gives a quantity to minimize.
    """
    residual_ss = torch.sum(weights * (y_true - y_pred) ** 2)  # sum of w_i * (y_i - yhat_i)^2
    reference_ss = torch.sum(weights * y_true ** 2)            # sum of w_i * y_i^2
    r_squared = 1 - (residual_ss / reference_ss)
    return 1 - r_squared

def zero_weighted_rsquared_np(y_pred, y_true, weights):
    """Return ``1 - R^2`` for the zero-mean weighted R^2 (NumPy arrays).

    This is the NumPy counterpart of ``zero_weighted_rsquared``:
    R^2 = 1 - sum(w * (y - y_hat)^2) / sum(w * y^2).
    """
    residual_ss = np.sum(weights * (y_true - y_pred) ** 2)  # sum of w_i * (y_i - yhat_i)^2
    reference_ss = np.sum(weights * y_true ** 2)            # sum of w_i * y_i^2
    r_squared = 1 - (residual_ss / reference_ss)
    return 1 - r_squared

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable

class GatedNormalization(nn.Module):
    """Blend InstanceNorm1d and BatchNorm1d outputs with a learned per-feature gate.

    The output is ``sigmoid(gate) * instance_norm(x) + (1 - sigmoid(gate)) * batch_norm(x)``.
    The gate starts at zero, i.e. an even 50/50 mix. The goal is to handle severely
    non-stationary features.

    Note: ``epsilon`` is accepted for API compatibility but is not used; both
    normalization layers keep PyTorch's default eps.
    """
    def __init__(self, num_features, epsilon=1e-8):
        super().__init__()
        self.feature_wise_norm = nn.InstanceNorm1d(num_features)
        self.batch_norm = nn.BatchNorm1d(num_features)
        self.gate = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # x: (batch, features, time)
        per_instance = self.feature_wise_norm(x)
        per_batch = self.batch_norm(x)
        mix = torch.sigmoid(self.gate)[None, :, None]  # broadcast over batch and time
        return mix * per_instance + (1 - mix) * per_batch

class BlockTimeReducer(nn.Module):
    """
    Module that reduces the time dimension of a tensor of size (batch, features, time, embd).

    The features are highly autocorrelated across time. This module exploits that
    to shrink the time axis and so reduce the cost of the attention that follows.

    Processing:
      1. Each feature is multiplied by a learned scale.
      2. A sequence of depthwise convolutions runs, one per entry of
         ``reduction_steps``. Each has kernel = stride = (step, 1), and each is
         followed by the activation.
      3. Every convolution divides the time length by its step (rounded down) and
         leaves ``embd`` unchanged.

    Args:
        reduction_steps: iterable of int time-reduction factors, applied in order.
        n_features: number of feature channels.
        n_embd: embedding size (stored for reference).
        activation: one of 'gelu', 'gelu_new' (tanh-approximated GELU), 'silu', 'mish'.

    Raises:
        ValueError: for an unsupported ``activation``.
    """
    def __init__(self, reduction_steps, n_features=79, n_embd=64, activation='gelu'):
        super().__init__()

        self.n_embd = n_embd
        self.n_features = n_features

        # Learnable feature-wise scaling factors, shape (n_features, 1, 1) to broadcast over time and embd.
        self.feature_scales = nn.Parameter(torch.ones(n_features, 1, 1))

        if activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'gelu_new':
            # "gelu_new" is the tanh approximation of GELU.
            self.activation = nn.GELU(approximate='tanh')
        elif activation == 'silu':
            self.activation = nn.SiLU()
        elif activation == 'mish':
            self.activation = nn.Mish()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # groups=n_features -> each feature is convolved independently (depthwise).
        self.conv_layers = nn.ModuleList([
            nn.Conv2d(
                in_channels=n_features,
                out_channels=n_features,
                kernel_size=(step, 1),
                stride=(step, 1),
                groups=n_features,
            ) for step in reduction_steps
        ])

    def forward(self, x):
        """Reduce ``x`` of shape (batch, n_features, time, n_embd) along the time axis."""
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (batch, n_features, time, n_embd), got shape {tuple(x.shape)}"
            )
        x = x * self.feature_scales
        for conv in self.conv_layers:
            x = self.activation(conv(x))
        return x
    
class RMSNorm(nn.Module):
    """Normalize to unit L2 norm, then rescale by ``sqrt(dim)`` and a learned gain ``g``.

    - ``norm_dim == 0``: normalization runs over the last axis; ``g`` has shape ``(dim,)``.
    - Any other ``norm_dim``: normalization runs over axis 1; ``g`` has shape ``(dim, 1)``.
    """
    def __init__(self, dim, norm_dim=0):
        super().__init__()
        self.norm_dim = norm_dim
        self.scale = dim**0.5
        gain_shape = (dim,) if norm_dim == 0 else (dim, 1)
        self.g = nn.Parameter(torch.ones(gain_shape))

    def forward(self, x):
        axis = -1 if self.norm_dim == 0 else 1
        return F.normalize(x, dim=axis) * self.scale * self.g

class LayerNorm(nn.Module):
    """Layer normalization over the last ``ndim`` features with an optional bias.

    Passing ``bias=False`` drops the additive term entirely (``self.bias`` is None);
    eps is fixed at 1e-5.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)

class GLU(nn.Module):
    """Gated linear unit.

    The input is projected to ``2 * dim_out`` and split into (value, gate) halves.
    The output is ``value * activation(gate) * mult_bias``.
    """
    def __init__(self, dim_in, dim_out, activation: Callable, mult_bias=False):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)
        # Optional learnable per-channel multiplier; the constant 1.0 otherwise.
        self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.0

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        gated = value * self.act(gate)
        return gated * self.mult_bias
    
class InformedDownsampler(nn.Module):
    """
    Module that reduces along the feature dimension of a tensor of size (batch, feature, time, embd).

    The features.csv and responders.csv tag tables suggest that some features belong
    to the same categories. This module therefore mixes the features down to
    ``output_dim`` learned channels. As with the time reducer, shrinking the tensor
    reduces the cost of the attention that follows.

    Per (time, embd) position, the feature vector goes through:
      - Linear(input_dim -> input_dim // 2) + LayerNorm + GELU
      - Linear(-> (input_dim // 2) // 2) + LayerNorm + GELU
      - Linear(-> output_dim)
      - plus a linear skip connection (input_dim -> output_dim).

    Note: when ``feature_relationships`` is given, the first layer's weights are
    initialized from the feature-by-feature correlation of its tag columns.
    """
    def __init__(self, input_dim, output_dim=8, feature_relationships=None):
        super().__init__()

        self.correlation_matrix = None
        if feature_relationships is not None:
            # Every column after the first is treated as a boolean tag flag.
            tag_matrix = feature_relationships.iloc[:, 1:].astype(bool).astype(float).values
            # Correlation between rows (features) of the tag table. A feature whose
            # tags are all equal has zero variance, and np.corrcoef returns NaN for
            # it. Those NaNs would end up in layer1's weights and every output, so
            # they are mapped to 0 (no correlation).
            with np.errstate(divide='ignore', invalid='ignore'):
                corr = np.corrcoef(tag_matrix)
            corr = np.atleast_2d(np.nan_to_num(corr, nan=0.0))
            self.correlation_matrix = torch.from_numpy(corr).float()

        hidden_dim = input_dim // 2

        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim, elementwise_affine=True),
            nn.GELU()
        )

        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2, elementwise_affine=True),
            nn.GELU()
        )

        self.layer3 = nn.Linear(hidden_dim // 2, output_dim)

        # Initialize weights using feature relationships if available
        if self.correlation_matrix is not None:
            self._initialize_weights()

        # Skip connection
        self.skip = nn.Linear(input_dim, output_dim)

    def _initialize_weights(self):
        """Overwrite layer1's Linear weight with the first rows of ``0.1 * correlation``.

        The correlation matrix is zero-padded or truncated to ``(in_features, in_features)``
        first, so its size need not match ``input_dim``.
        """
        with torch.no_grad():
            out_features, in_features = self.layer1[0].weight.shape

            # Ensure the correlation matrix matches the input dimension.
            if self.correlation_matrix.shape[0] != in_features:
                if self.correlation_matrix.shape[0] < in_features:
                    # Pad with zeros if the correlation matrix is too small.
                    padded = torch.zeros(in_features, in_features)
                    padded[:self.correlation_matrix.shape[0], :self.correlation_matrix.shape[1]] = self.correlation_matrix
                    self.correlation_matrix = padded
                else:
                    # Truncate if the correlation matrix is too large.
                    self.correlation_matrix = self.correlation_matrix[:in_features, :in_features]

            # Scale the correlation matrix and keep the first out_features rows.
            scaled_weights = (self.correlation_matrix * 0.1).T
            self.layer1[0].weight.data = scaled_weights[:out_features, :]

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim, time, embd) to (batch, output_dim, time, embd)."""
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (batch, n_features, time, embd), got shape {tuple(x.shape)}"
            )

        # Move the feature axis last so the Linear layers act on it: (batch, time, embd, n_features).
        x_reshaped = x.permute(0, 2, 3, 1)

        h = self.layer3(self.layer2(self.layer1(x_reshaped)))
        h = h + self.skip(x_reshaped)

        # Back to (batch, output_dim, time, embd).
        return h.permute(0, 3, 1, 2)

def create_downsampler(csv_path, input_dim, output_dim=1):
    """Build an ``InformedDownsampler`` from a CSV of feature relationships.

    When ``csv_path`` contains ``'feature'``, the rows at positions 9-11 (the 10th to
    12th rows) are moved to the bottom of the table. The model concatenates the
    continuous features before the categorical ones (feature_09..feature_11), so this
    presumably keeps the CSV row order aligned with the model's input order.

    Args:
        csv_path: path of the relationships CSV, read with ``pd.read_csv``.
        input_dim: number of input features of the downsampler.
        output_dim: number of output features.

    Returns:
        The constructed ``InformedDownsampler``.
    """
    relationships = pd.read_csv(csv_path)

    if 'feature' in csv_path:
        moved = relationships.iloc[9:12].copy()
        kept = relationships.drop(relationships.index[9:12])
        relationships = pd.concat([kept, moved], ignore_index=True)

    return InformedDownsampler(
        input_dim=input_dim,
        output_dim=output_dim,
        feature_relationships=relationships,
    )

class MLP(nn.Module):
    """Gated feed-forward block: SiLU-gated GLU to 4x width, LayerNorm, projection back, dropout."""

    def __init__(self, config, n_embd=None):
        super().__init__()
        width = config.n_embd if n_embd is None else n_embd
        hidden = 4 * width
        # Submodules are created in this order so initialisation and state_dict keys are unchanged.
        self.glu = GLU(width, hidden, nn.SiLU())
        self.norm = LayerNorm(hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return ``dropout(c_proj(norm(glu(x))))``."""
        return self.dropout(self.c_proj(self.norm(self.glu(x))))
    
class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings, QK RMS-norm and a value residual.

    Accepts either a 3D input ``(batch, time, n_embd)`` or a 4D input
    ``(batch, features, time, n_embd)``. A 4D input is flattened time-major into a
    single sequence of ``time * features`` tokens before attention and restored to
    its original layout afterwards.

    Args:
        config: model config providing ``n_embd``, ``n_head``, ``bias`` and ``dropout``.
        is_causal: whether to apply a causal mask. For 4D inputs the causal order is
            over the flattened (time-major) token sequence.
    """

    def __init__(self, config, is_causal):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.to_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.to_k = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.to_v = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.rotary = Rotary(config.n_embd // config.n_head)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Learnable weights mixing this layer's values with the first layer's values (v1).
        self.lamb1 = nn.Parameter(torch.tensor(0.5))
        self.lamb2 = nn.Parameter(torch.tensor(0.5))

        # regularization
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.dropout = config.dropout
        self.is_causal = is_causal
        # flash attention is way faster than MultiHeadAttention, requires Pytorch >= 2.0
        assert hasattr(torch.nn.functional, 'scaled_dot_product_attention'), "Flash Attention requires PyTorch >= 2.0"

    def forward(self, x, v1=None, cutoff=None):
        """Attend over all tokens of ``x``.

        Args:
            x: ``(B, T, C)`` or ``(B, FS, T, C)`` tensor.
            v1: value tensor returned by an earlier layer, or None for the first layer
                (then this layer's own values are used and returned).
            cutoff: accepted for interface compatibility; currently unused.

        Returns:
            ``(y, v1)``: ``y`` has the same shape as ``x``; ``v1`` is the
            ``(B, T * FS, n_head, head_dim)`` value tensor to pass to later layers.
        """
        if x.dim() == 3:
            B, T, C = x.size()
            FS = 1
            reshape_at_end = False
        else:
            B, FS, T, C = x.size()  # batch, features, time, embedding
            # Flatten time-major: token index = t * FS + f.
            x = x.permute(0, 2, 1, 3).reshape(B, T * FS, C)
            reshape_at_end = True

        seq_len = T * FS
        head_dim = C // self.n_head

        # Project and split into heads: (B, seq_len, n_head, head_dim).
        q = self.to_q(x).view(B, seq_len, self.n_head, head_dim)
        k = self.to_k(x).view(B, seq_len, self.n_head, head_dim)
        v = self.to_v(x).view(B, seq_len, self.n_head, head_dim)

        if v1 is None:
            v1 = v
        else:
            v1 = v1.view(B, seq_len, self.n_head, head_dim)
        # Value residual: blend this layer's values with the first layer's values.
        v = self.lamb1 * v + self.lamb2 * v1

        # Rotary embeddings are applied in float32, then cast back.
        orig_dtype = q.dtype
        q_f32, k_f32 = q.float(), k.float()
        cos, sin = self.rotary(q_f32)
        q = apply_rotary_emb(q_f32, cos, sin).to(orig_dtype)
        k = apply_rotary_emb(k_f32, cos, sin).to(orig_dtype)

        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))

        # (B, n_head, seq_len, head_dim) for SDPA.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            # Honour the constructor flag (previously hard-coded to False).
            is_causal=self.is_causal,
        )

        y = y.transpose(1, 2).contiguous().view(B, seq_len, C)
        y = self.resid_dropout(self.c_proj(y))

        if reshape_at_end:
            # Back to (batch, features, time, embedding).
            y = y.view(B, T, FS, C).permute(0, 2, 1, 3)

        return y, v1
    
class CrossAttention(nn.Module):
    """Multi-head attention from a 3D query sequence onto a 4D encoded tensor.

    Queries come from ``x`` of shape ``(B, T, C)``. Keys and values come from
    ``encoded_x`` of shape ``(B, FS, eT, eC)``, flattened time-major into
    ``eT * FS`` tokens. Queries and keys get separate rotary modules, with QK
    RMS-norm and no masking.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        head_dim = width // config.n_head
        # Creation order matches the state_dict/initialisation order used so far.
        self.to_q = nn.Linear(width, width, bias=config.bias)
        self.to_k = nn.Linear(width, width, bias=config.bias)
        self.to_v = nn.Linear(width, width, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        # Separate rotary modules: the query and key sequences have different lengths.
        self.rotary_q = Rotary(head_dim)
        self.rotary_k = Rotary(head_dim)

        # regularization
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        self.dropout = config.dropout
        assert hasattr(torch.nn.functional, 'scaled_dot_product_attention'), "Flash Attention requires PyTorch >= 2.0"

    def forward(self, x, encoded_x):
        """Return ``(B, T, C)`` attention output of ``x`` over the flattened ``encoded_x``."""
        B, FS, eT, eC = encoded_x.size()
        B, T, C = x.size()
        n_kv = eT * FS

        memory = encoded_x.permute(0, 2, 1, 3).reshape(B, n_kv, eC)

        q = self.to_q(x)
        k = self.to_k(memory)
        v = self.to_v(memory)

        # Split heads: queries (B, T, nh, hs); keys/values (B, eT*FS, nh, hs).
        k = k.view(B, n_kv, self.n_head, eC // self.n_head)
        q = q.view(B, T, self.n_head, C // self.n_head)
        v = v.view(B, n_kv, self.n_head, eC // self.n_head)

        # Rotary embeddings in float32, each sequence with its own module.
        q_dtype, k_dtype = q.dtype, k.dtype
        q_f32, k_f32 = q.float(), k.float()
        cos_q, sin_q = self.rotary_q(q_f32)
        cos_k, sin_k = self.rotary_k(k_f32)
        q = apply_rotary_emb(q_f32, cos_q, sin_q).to(q_dtype)
        k = apply_rotary_emb(k_f32, cos_k, sin_k).to(k_dtype)

        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))

        attended = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=False,
        )

        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(merged))

class Block(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention onto encoded data, then an MLP.

    With ``n_encoded_arrays > 1`` and a list input, one cross-attention + MLP branch
    runs per encoded array. The branch outputs are concatenated along the sequence
    axis and mapped back to ``downsampled_block_size`` positions by ``reducer``.
    """

    def __init__(self, config, n_encoded_arrays=1):
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = SelfAttention(config, False)
        self.rms_2 = RMSNorm(config.n_embd)
        self.cross_attn = nn.ModuleList([CrossAttention(config) for _ in range(n_encoded_arrays)])
        self.rms_3 = RMSNorm(config.n_embd)
        self.rms_4 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)
        if n_encoded_arrays > 1:
            # One branch per encoded array is concatenated along the sequence axis, so the
            # reducer's input size scales with n_encoded_arrays (it used to be fixed at 2x).
            self.reducer = nn.Linear(
                config.downsampled_block_size * n_encoded_arrays,
                config.downsampled_block_size,
            )

    def forward(self, x, encoded_x, self_v1=None, cutoff=None):
        """Apply the block.

        Args:
            x: decoder hidden states, ``(batch, T, n_embd)``.
            encoded_x: an encoded tensor, or a list of them (one per cross-attention module).
            self_v1: value tensor from an earlier self-attention layer (value residual).
            cutoff: forwarded to self-attention.

        Returns:
            ``(x, sv1)``: updated hidden states and the value tensor for later layers.
        """
        attn_out, sv1 = self.attn(self.rms_1(x), self_v1, cutoff)
        x = x + attn_out

        if not isinstance(encoded_x, list):
            x = x + self.cross_attn[0](self.rms_2(x), self.rms_3(encoded_x))
            x = x + self.mlp(self.rms_4(x))
            return x, sv1

        results = []
        for idx, encoded_arr in enumerate(encoded_x):
            branch = x + self.cross_attn[idx](self.rms_2(x), self.rms_3(encoded_arr))
            branch = branch + self.mlp(self.rms_4(branch))
            results.append(branch)

        if len(results) == 1:
            # Nothing to reduce (and no reducer exists when n_encoded_arrays == 1).
            return results[0], sv1

        # (B, n*T, C) -> reduce the sequence axis back to T.
        x = torch.cat(results, dim=1)
        x = self.reducer(x.transpose(1, 2)).transpose(1, 2)
        return x, sv1

class EncoderBlock(nn.Module):
    """Pre-norm encoder block: residual self-attention followed by a residual MLP."""

    def __init__(self, config, causal_attn=False):
        super().__init__()
        width = config.n_embd
        self.rms_1 = RMSNorm(width)
        self.attn = SelfAttention(config, causal_attn)
        self.rms_2 = RMSNorm(width)
        self.mlp = MLP(config)

    def forward(self, x, v1=None):
        """Return the updated hidden states and the value-residual tensor from self-attention."""
        attended, first_values = self.attn(self.rms_1(x), v1)
        hidden = x + attended
        return hidden + self.mlp(self.rms_2(hidden)), first_values

class Generator(nn.Module):
    """Encoder-decoder model that predicts the next responder value.

    The encoder embeds the categorical and continuous features (and, separately,
    the past responders). It reduces them with ``BlockTimeReducer`` and the
    correlation-informed feature downsamplers, then concatenates the two along the
    feature axis. The decoder embeds the downsampled input sequence ``s`` and, layer
    by layer, cross-attends to the encoder output. Only the prediction for the last
    decoder position is returned.
    """

    def __init__(self, config):
        super().__init__()
        self.n_cat_features = config.n_cat_features
        self.n_cont_features = config.n_cont_features
        self.vocab_sizes = config.vocab_sizes
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.downsampled_block_size = config.downsampled_block_size
        self.dropout = config.dropout
        self.n_responder_features = config.n_responder_features
        self.n_layer = config.n_layer
        assert self.n_cat_features == len(self.vocab_sizes)
        assert config.block_size % config.downsampled_block_size == 0
        self.max_iters = config.n_iters

        self.cont_norm = RMSNorm(config.n_embd)
        self.cat_norm = RMSNorm(config.n_embd)
        self.full_norm = RMSNorm(config.n_embd)

        # One embedding table per categorical feature.
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, config.n_embd)
            for vocab_size in config.vocab_sizes
        ])
        # Two-entry tables, indexed by the 0/1 NaN mask and the 0/1 synthetic flag.
        self.nan_embedding = nn.Embedding(2, config.n_embd)
        self.synthetic_embedding = nn.Embedding(2, config.n_embd)

        n_input_features = config.n_cont_features + config.n_cat_features
        # Submodule order is kept so initialisation and state_dict keys are unchanged.
        self.encoder = nn.ModuleDict(dict(
            fc_cont=nn.Linear(1, config.n_embd),
            # Indexed by t_ids, so every time id must be < 968.
            wpe=nn.Embedding(968, config.n_embd),
            down_time_cont=BlockTimeReducer(
                [2],
                n_input_features,
                n_embd=config.n_embd,
                activation='gelu'
            ),
            down_time_resp=BlockTimeReducer(
                [2],
                n_features=config.n_responder_features,
                n_embd=config.n_embd,
                activation='gelu'
            ),
            down_f_cont=create_downsampler('features.csv', n_input_features, 16),
            down_f_resp=create_downsampler('responders.csv', config.n_responder_features, 4),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([EncoderBlock(config, causal_attn=False) for _ in range(config.n_layer)]),
        ))

        self.decoder = nn.ModuleDict(dict(
            wte=nn.Linear(1, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
        self.regressor = nn.Linear(config.vocab_size, 1)

    def forward(self, s, cat_features, cont_features, nan_cont_features, responders, t_ids, it=None, synthetic_flag=None):
        """Run the encoder and decoder and return the last-position prediction.

        Args:
            s: decoder input sequence, ``(batch, downsampled_block_size)`` when the
                default synthetic flag is used.
            cat_features: categorical feature ids, ``(batch, n_cat_features, time)``.
            cont_features: continuous feature values, ``(batch, n_cont_features, time)``.
            nan_cont_features: 0/1 ids for ``nan_embedding``, same shape as ``cont_features``;
                positions with 0 are also zeroed after normalisation.
            responders: past responder values; presumably ``(batch, n_responder_features, time)``.
            t_ids: encoder time ids (< 968), ``(batch, time)``.
            it: training iteration; only used (via ``generate_synthetic_data``) when
                ``synthetic_flag`` is None.
            synthetic_flag: 0/1 ids ``(batch, seq_len)`` marking decoder positions filled
                with model predictions; defaults to all zeros.

        Returns:
            Tensor of shape ``(batch, 1, 1)``, clamped to [-5, 5].
        """
        assert cat_features.shape[1] == self.n_cat_features
        batch_size, _, _ = cont_features.shape

        if synthetic_flag is None and it is not None:
            # Not used in scheduled sampling.
            seq, synthetic_flag, cutoff = self.generate_synthetic_data(s, it)
        else:
            seq = s.clone()
            cutoff = None
            if synthetic_flag is None:
                synthetic_flag = torch.zeros((batch_size, self.downsampled_block_size), dtype=torch.int64, device=cont_features.device)

        synthetic_flag = self.synthetic_embedding(synthetic_flag)

        cont_features = self.encoder.fc_cont(cont_features.unsqueeze(-1))
        # Responders share the decoder's value embedding (wte).
        responders = self.decoder.wte(responders.unsqueeze(-1))

        nan_cont_features_embd = self.nan_embedding(nan_cont_features)
        pos_features = self.encoder.wpe(t_ids)

        # Embed each categorical feature with its own table.
        batch_size, num_features, _ = cat_features.shape
        embedded_features = [
            self.cat_embeddings[i](cat_features[:, i])  # (batch_size, time, n_embd)
            for i in range(num_features)
        ]

        # Stack along a new feature axis: (batch, n_cat_features, time, n_embd).
        cat_features = torch.stack(embedded_features, dim=1)
        cat_features = self.cat_norm(cat_features + pos_features.unsqueeze(1))

        cont_features = self.cont_norm(cont_features + nan_cont_features_embd + pos_features.unsqueeze(1))
        # Zero out positions where nan_cont_features == 0.
        cont_features = cont_features * nan_cont_features.unsqueeze(-1)
        # Continuous features first, then categorical. create_downsampler moves rows 9-11
        # of features.csv to the end, presumably to match this order.
        cont_features = self.full_norm(torch.cat([cont_features, cat_features], dim=1))

        cont_features = self.encoder.down_time_cont(cont_features)
        cont_features = self.encoder.down_f_cont(cont_features).squeeze(1)

        responders = self.encoder.down_time_resp(responders)
        responders = self.encoder.down_f_resp(responders).squeeze(1)

        # Decoder positional embeddings.
        pos = torch.arange(0, self.downsampled_block_size, dtype=torch.long, device=cont_features.device)
        pos_emb = self.decoder.wpe(pos)

        encoded_data = self.encoder.drop(torch.cat([cont_features, responders], dim=1))

        tok_emb = self.decoder.wte(seq.unsqueeze(-1))
        x = self.decoder.drop(tok_emb + pos_emb + synthetic_flag)

        # Encoder and decoder advance in lockstep; each keeps its own value-residual tensor.
        ve1 = None
        vd1 = None
        for encoder_block, decoder_block in zip(self.encoder.h, self.decoder.h):
            encoded_data, ve1 = encoder_block(encoded_data, ve1)
            x, vd1 = decoder_block(x, encoded_data, vd1, cutoff)

        x = self.decoder.ln_f(x)

        logits = self.lm_head(x)
        target = torch.clamp(self.regressor(logits), -5, 5)

        return target[:, -1:]

    def generate_synthetic_data(self, seq_tensor, it):
        """Return decoder inputs for the (now disabled) sequence-cutting augmentation.

        This used to cut sequences at a random index to simulate inference. The cutting
        is no longer performed, so this returns an unmodified copy of ``seq_tensor``,
        an all-zero int64 ``synthetic_flag`` with the same shape, and ``cutoff=None``.

        Args:
            seq_tensor: decoder input sequence.
            it: training iteration; kept for interface compatibility, unused.

        Returns:
            ``(tokens, synthetic_flag, cutoff)``.
        """
        repeated_tokens = seq_tensor.clone()
        synthetic_flag = torch.zeros_like(seq_tensor, dtype=torch.int64)
        # The previous version returned torch.tensor(None, ...), which raises. The decoder
        # blocks only forward cutoff to self-attention, which accepts None.
        cutoff = None
        return repeated_tokens, synthetic_flag, cutoff

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build a SOAP optimizer with three weight-decay groups.

        Parameters whose name contains ``"attention"`` or ``"attn"`` get 1% of
        ``weight_decay``. Other parameters with >= 2 dims get the full ``weight_decay``,
        and the rest get none. ``device_type`` is accepted for interface compatibility
        and is unused.
        """
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        attention_params = []
        decay_params = []
        nodecay_params = []

        for name, param in param_dict.items():
            if "attention" in name or "attn" in name:
                attention_params.append(param)
            elif param.dim() >= 2:
                decay_params.append(param)
            else:
                nodecay_params.append(param)

        # It has been shown that weight decay induces low rank attention layers (https://arxiv.org/html/2410.23819v1)
        optim_groups = [
            {'params': attention_params, 'weight_decay': weight_decay*0.01},
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        num_attention_params = sum(p.numel() for p in attention_params)
        print(f"num attention parameter tensors: {len(attention_params)}, with {num_attention_params:,} parameters")
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        optimizer = SOAP(optim_groups, lr=learning_rate, betas=betas, weight_decay=weight_decay, precondition_frequency=10)

        return optimizer

In [None]:
asset_probs = [round(len(DFS_USED[s])/DATA_LEN, 2) for s in SYMBOLS]
cat_ids = [9, 10, 11]
cat_vars = [f'feature_{i:02}' for i in cat_ids]
cont_vars = [f'feature_{i:02}' for i in range(79) if i not in cat_ids]
responder_vars = [f'responder_{n}' for n in range(9)]
print(responder_vars)

# Map each categorical feature's raw values to contiguous ids [0, n_unique) and
# record the vocabulary size. The unique values are sorted because polars' unique()
# does not preserve order by default; without sorting, the value -> id mapping
# could differ between sessions and no longer match previously trained embeddings.
# The loop variable is `cat_id` so the builtin `id` is not shadowed at module level.
mappings = {}
vocab_sizes = []
for cat_id in cat_ids:
    f_name = f'feature_{cat_id:02}'
    unique_values = DATA.select(f_name).unique().sort(f_name).collect().to_numpy().flatten()
    mappings[f_name] = {val: idx for idx, val in enumerate(unique_values)}
    vocab_sizes.append(len(unique_values))

In [None]:
# Batches are built from whole days, and the days are split by length: days whose
# time_id never exceeds 960 (about 848 steps) go to the *_1 lists, and days with
# some time_id > 960 (about 967 steps) go to the *_2 lists.
# Grouping days of the same length keeps the weighted r2 loss closer to the setting in which the model is evaluated.
# Each list entry is a (row 'index' of the day's time_id == 0 row, DFS_USED key) pair.
idxs_train_1 = []
idxs_train_2 = []
idxs_valid_1 = []
idxs_valid_2 = []
np.random.seed(42)  # seeds the np.random.shuffle calls below
count = 0  # NOTE(review): not used in this cell
for k, df in DFS_USED.items():
    # Dates that contain a time_id above 960 (the long days).
    days_900 = df.filter(pl.col('time_id')>960).select(pl.col('date_id')).unique().to_numpy().flatten()
    # All remaining dates (the short days).
    days_800 = df.filter(~pl.col('date_id').is_in(days_900)).select(pl.col('date_id')).unique().to_numpy().flatten()
    # Replace the dates by the row 'index' of each day's first (time_id == 0) row, in df row order
    # (presumably chronological if df is sorted by date — confirm).
    days_900 = df.filter(pl.col('time_id')==0, pl.col('date_id').is_in(days_900)).select(pl.col('index')).to_numpy().flatten().tolist()
    days_800 = df.filter(pl.col('time_id')==0, pl.col('date_id').is_in(days_800)).select(pl.col('index')).to_numpy().flatten().tolist()
    days_800 = list(zip(days_800, [k] * len(days_800)))
    # Skip the first 2 days, keep the last 5 for validation and the rest for training.
    # NOTE(review): with fewer than 5 days, len(...)-5 is negative and Python counts it from the end,
    # so the split no longer means "all but the last 5".
    idxs_train_1.extend(days_800[2:len(days_800)-5])
    idxs_valid_1.extend(days_800[len(days_800)-5:len(days_800)])

    days_900 = list(zip(days_900, [k] * len(days_900)))
    idxs_train_2.extend(days_900[2:len(days_900)-5])
    idxs_valid_2.extend(days_900[len(days_900)-5:len(days_900)])

print(len(idxs_train_1), len(idxs_valid_1), len(idxs_train_2), len(idxs_valid_2))
np.random.shuffle(idxs_train_1)
np.random.shuffle(idxs_valid_1)
np.random.shuffle(idxs_train_2)
np.random.shuffle(idxs_valid_2)

# The validation days are also appended after the shuffled training days, so they are
# used for training in the final steps (presumably, if training walks these lists in order).
# The idxs_valid_* lists themselves are left unchanged.
idxs_train_1 = idxs_train_1 + idxs_valid_1
idxs_train_2 = idxs_train_2 + idxs_valid_2

In [None]:
# Training-side global state shared with the batching/evaluation code.
# TRAIN_COUNT_* are saved and restored around evaluation by estimate_loss
# (presumably cursors into idxs_train_1 / idxs_train_2 — confirm).
TRAIN_COUNT_1 = TRAIN_COUNT_2 = 0
EPOCH_COUNT = 0
# Day-length group picked by get_batch_new_ss: 0 -> cutoff 848, otherwise cutoff 967.
CHOSEN_TIME = 0

In [None]:
# Validation-side global counters; estimate_loss resets VALID_COUNT_* to 0 after each evaluation.
VALID_COUNT_1 = VALID_COUNT_2 = CONT_COUNT = 0

In [None]:
def prepare_model():
    """Build the Generator and the dictionary of training hyperparameters.

    The schedule is sized from the module-level ``idxs_train_1``/``idxs_train_2``
    lists. The model is sized from the feature metadata (``vocab_sizes``,
    ``cat_vars``, ``cont_vars``, ``responder_vars``) defined in earlier cells.

    Returns:
        ``(model, train_config)``: the ``Generator`` moved to ``device`` and a dict of
        hyperparameters used by the training and evaluation code.
    """
    downsampled_block_size = 128 # We will mean pool block_size into 128 tokens. This allows for more information and not much is lost along the way
    vocab_size = 64  # width of the lm_head output; recorded in both model_args and train_config

    total_tokens = ((len(idxs_train_1) + len(idxs_train_2)) * downsampled_block_size)
    print("The total amount of tokens is", total_tokens)

    eval_iters = downsampled_block_size
    log_interval = 20

    always_save_checkpoint = True

    gradient_accumulation_steps = 32 # Below 32 scheduled sampling does not work
    batch_size = 16 # Highest as memory allows
    block_size = 896 # We will mean pool 896 tokens into 128 for the decoded sequence

    # Since both sequences are concatenated it is important that they have the same block_size. If the attention strategy is changed this can be changed as well.
    features_block_size = 32 # We will take the last 32 time steps of features
    responder_block_size = 32 # We will take the responders from the previous day taking the current time_id as a pivot

    tokens_per_iter = gradient_accumulation_steps * batch_size
    print(f"tokens per iteration will be: {tokens_per_iter:,}")

    learning_rate = 3e-4
    n_epochs = 1
    max_iters = int((total_tokens/tokens_per_iter)*n_epochs)
    warmup_iters = int(max_iters*0.01)
    lr_decay_iters = max_iters - warmup_iters

    print(f"We will train this dataset over {max_iters} steps")
    n_validations = 20 # The number of validations we will perform along the training
    # Clamp to at least 1: for small datasets (max_iters < 2 * n_validations) the raw
    # formula gives 0 or a negative value, which is not a usable evaluation interval.
    eval_interval = max(1, (max_iters // n_validations) - 1)

    min_lr = 3e-5
    # Since this objective function has a lot of local minima we need to tune the betas a bit.
    beta1 = 0.9
    beta2 = 0.999
    weight_decay = 1e-1

    # bfloat16 strongly preferred over float16. This allows us to double the batch size over float32.
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
    decay_lr = True
    eval_only = False

    # This can be set to 0.0 to disable gradient clipping,
    #    but since sequences can be "hard" or "easy" it can make the loss jumpy, which can lead to big updates, so I strongly suggest it is used.
    grad_clip = 1.0

    # Hard-coded: no CUDA availability check is made (the bf16 check above also queries CUDA).
    device = 'cuda'

    # Fixed seed so model initialisation is reproducible.
    torch.manual_seed(1337)
    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
    device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

    n_layer = 2
    # 1 head per every 64 embd as rule of thumb
    n_head = 2
    n_embd = 128
    bias = False
    dropout = 0.0

    # model init
    model_args = dict(n_embd=n_embd, block_size=block_size, bias=bias, dropout=dropout, vocab_sizes=vocab_sizes, n_cat_features=len(cat_vars), n_cont_features=len(cont_vars),
                      n_head=n_head, n_responder_features=len(responder_vars), n_layer=n_layer, vocab_size=vocab_size, n_iters=max_iters,
                      downsampled_block_size=downsampled_block_size, features_block_size=features_block_size, responder_block_size=responder_block_size)

    # This sets the matrix calculations precision to tensorfloat 32, which speeds up computation by a lot, with negligible cost for precision
    # Check https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html for more info
    torch.set_float32_matmul_precision('high')
    conf = ModelConfig(**model_args)
    model = Generator(conf)
    print(model_args)

    model.to(device)

    train_config = {
        'vocab_size': vocab_size, # Usually block_size // 2
        'downsampled_block_size': downsampled_block_size,
        'batch_size': batch_size,
        'block_size': block_size,
        'responder_block_size': responder_block_size,
        'features_block_size': features_block_size,
        'model_args': model_args,
        'eval_interval': eval_interval,
        'eval_iters': eval_iters,
        'log_interval': log_interval,
        'learning_rate': learning_rate,
        'always_save_checkpoint': always_save_checkpoint,
        'gradient_accumulation_steps': gradient_accumulation_steps,
        'n_epochs': n_epochs,
        'max_iters': max_iters,
        'warmup_iters': warmup_iters,
        'lr_decay_iters': lr_decay_iters,
        'min_lr': min_lr,
        'beta1': beta1,
        'beta2': beta2,
        'weight_decay': weight_decay,
        'dtype': dtype,
        'decay_lr': decay_lr,
        'eval_only': eval_only,
        'grad_clip': grad_clip,
        'device': device,
        'vocab_sizes': vocab_sizes
    }

    return model, train_config

@torch.no_grad()
def estimate_loss(model, ctx, train_config, it):
    """Evaluate the model auto-regressively on recent training data and on the validation split.

    For each split, ``eval_iters`` steps are run with 4x the training batch size. A
    fresh batch is drawn with ``get_batch_new_ss`` whenever the previous one cannot
    be continued. Otherwise the batch is advanced with ``update_batch``, and the
    model's previous predictions are fed back as the newest decoder inputs, with the
    synthetic flag marking those positions.

    The training cursors ``TRAIN_COUNT_*`` are rewound by the eval batch size so the
    train split revisits recent samples (presumably read by ``get_batch_new_ss``).
    They are restored, ``VALID_COUNT_*`` are reset to 0, and the model is put back
    in train mode afterwards, even if evaluation raises.

    Args:
        model: the ``Generator`` to evaluate.
        ctx: context manager wrapped around the forward pass (e.g. autocast).
        train_config: dict produced by ``prepare_model``.
        it: current training iteration; unused, kept for the caller's interface.

    Returns:
        Dict mapping ``'train'`` and ``'val'`` to the weighted r2 from ``zero_weighted_rsquared``.
    """
    global VALID_COUNT_1, VALID_COUNT_2, TRAIN_COUNT_1, TRAIN_COUNT_2
    out = {}
    device = train_config['device']
    eval_iters = train_config['eval_iters']
    batch_size = train_config['batch_size'] * 4

    batch_train_count_1, batch_train_count_2 = TRAIN_COUNT_1, TRAIN_COUNT_2
    TRAIN_COUNT_1 -= batch_size
    TRAIN_COUNT_2 -= batch_size
    model.eval()
    try:
        for split in ['train', 'val']:
            can_cont = False
            all_logits, all_targets, all_weights = [], [], []
            # First sample's prediction/target at every step, for plotting.
            xpoints = np.zeros(eval_iters)
            ypoints = np.zeros(eval_iters)

            for k in range(eval_iters):
                if not can_cont:
                    # Start a fresh batch with an all-zero synthetic flag.
                    X, Y, cat, cont, nan_mask, w, t_ids, rs, prev_idxs, can_cont = get_batch_new_ss(
                        split, batch_size, train_config['block_size'],
                        train_config['features_block_size'],
                        train_config['responder_block_size'],
                        device,
                        prev_idxs=None,
                        ss=False,
                    )
                    synthetic_flag = torch.zeros(
                        (batch_size, train_config['downsampled_block_size']),
                        dtype=torch.int64,
                        device=device,
                    )
                else:
                    # Advance the same batch and append the previous predictions as the newest inputs.
                    _, Y, cat, cont, nan_mask, w, t_ids, rs, prev_idxs, can_cont = update_batch(
                        X, cat, cont, nan_mask, t_ids, rs
                    )
                    X = torch.cat([sampled[:, 1:], logits_detached], dim=1)

                with ctx:
                    # Generator.forward has no targets/weights/sampling_prob parameters;
                    # passing them raised a TypeError.
                    logits = model(X, cat, cont, nan_mask, rs, t_ids, synthetic_flag=synthetic_flag)

                    all_logits.append(logits)
                    all_targets.append(Y)
                    all_weights.append(w)

                    xpoints[k] = logits[0].item()
                    ypoints[k] = Y[0, 0].item()

                    sampled = X.clone()
                    logits_detached = logits.squeeze(-1).detach()
                    # Shift the flag left and mark the newest position as model-generated.
                    synthetic_flag = torch.cat([
                        synthetic_flag[:, 1:],
                        torch.ones((batch_size, 1), dtype=torch.int64, device=device),
                    ], dim=1)

            plt.plot(xpoints, label=f'{split} pred')
            plt.plot(ypoints, label=f'{split} real')
            plt.legend()
            plt.show()

            combined_logits = torch.cat(all_logits, dim=0)
            combined_targets = torch.cat(all_targets, dim=0)
            combined_weights = torch.cat(all_weights, dim=0)

            loss = zero_weighted_rsquared(
                combined_logits.view(-1),
                combined_targets.view(-1),
                combined_weights.view(-1),
            )
            out[split] = loss.item()
    finally:
        model.train()
        VALID_COUNT_1 = 0
        VALID_COUNT_2 = 0
        TRAIN_COUNT_1 = batch_train_count_1
        TRAIN_COUNT_2 = batch_train_count_2
    return out


import os
from typing import List

@dataclass
class ModelConfig:
    """Hyperparameters shared by ``Generator`` and its sub-modules.

    Field order is part of the positional constructor and must not change.
    """

    vocab_sizes: List[int]  # vocabulary size of each categorical feature (one entry per feature)
    block_size: int = 896  # raw sequence length; must be a multiple of downsampled_block_size
    n_embd: int = 128  # embedding width
    dropout: float = 0.0
    bias: bool = False  # True: bias in Linears and LayerNorms. False: a bit better and faster
    n_cat_features: int = 0
    n_cont_features: int = 0
    n_responder_features: int = 9
    n_head: int = 2
    n_layer: int = 2
    vocab_size: int = 64  # width of the lm_head output that feeds the regressor
    n_iters: int = 0  # total training iterations (stored by Generator as max_iters)
    downsampled_block_size: int = 128
    features_block_size: int = 32
    responder_block_size: int = 32

In [None]:
# We group the days by their length.
# For each batch we randomly choose one of the two groups, and then keep advancing that batch until the day is finished.

# For context we use the last features_block_size time steps of the features.
# For the responders we will take the current time, look it up on the day before and take the responder_block_size//2 before and after.
# For example, if we are at time 200 and responder_block_size = 64 we will take all the responders from time 168 to 232.

# For the decoded sequence we want to contain as much information from that day into the sequence, so we downsample it using mean pooling.
# I noticed that after mean pooling the sequence the r2 hovers around 0.85, so not much information is lost,
#   which means a sequence of around 850-950 steps fits into 128 tokens.

# Per-sample metadata for the current batch. get_batch_new_ss clears it at the start of
# every call and builds one dict per sample with 'idx', 'asset' and 'offset' keys
# (presumably appended here — confirm in the rest of that function).
BATCH_INFO = []
def get_batch_new_ss(split, batch_size, block_size, features_block_size, responder_block_size, device='cuda', prev_idxs=None, ss=False, random=False):
    # Pre-calculate constants
    global CHOSEN_TIME
    global BATCH_INFO
    BATCH_INFO.clear()
    downsample_period = block_size // 128
    seq_block_size = 128
    cutoff = 848 if CHOSEN_TIME == 0 else 967
    
    # Initialize tensors on CPU with proper memory layout. Change bfloat16 for float16 if not supported.
    batch_tensors = {
        'X': torch.zeros((batch_size, seq_block_size), dtype=torch.bfloat16),
        'Y': torch.zeros((batch_size, 1), dtype=torch.bfloat16),
        'cat': torch.zeros((batch_size, len(cat_vars), features_block_size), dtype=torch.int64),
        'cont': torch.zeros((batch_size, len(cont_vars), features_block_size), dtype=torch.bfloat16),
        'nan_cont': torch.zeros((batch_size, len(cont_vars), features_block_size), dtype=torch.int64),
        't_ids': torch.zeros((batch_size, features_block_size), dtype=torch.int64),
        'responders': torch.zeros((batch_size, len(responder_vars), responder_block_size), dtype=torch.bfloat16),
        'weights': torch.zeros((batch_size, 1), dtype=torch.bfloat16)
    }
    
    # Batch processing variables
    idxs = []
    can_cont = True
    
    # Choose time period if not in sequential mode
    if not ss:
        total_samples = len(idxs_train_1) + len(idxs_train_2)
        p1 = len(idxs_train_1) / total_samples
        CHOSEN_TIME = np.random.choice([0, 1], p=[p1, 1-p1])
    
    # Process each item in batch
    for b in range(batch_size):
        b_info = {}
        # Get index and asset
        if ss:
            ix, asset = prev_idxs[b][0] + downsample_period, prev_idxs[b][1]
        else:
            ix, asset = _get_index_and_asset(split, random, cutoff)
        # print(ix, asset)
        b_info['idx'] = ix
        b_info['asset'] = asset
        b_info['offset'] = downsample_period
        idxs.append((ix, asset))
        
        # Get data frames and process sequence
        df = DFS_USED[asset]
        dates = DATES_USED[asset]
        sequence_data = df.slice(ix-block_size, block_size+downsample_period)
        
        # Process time and date information
        date = sequence_data['date_id'].to_numpy()[-1]
        time = sequence_data['time_id'].to_numpy()[-downsample_period]
        b_info['time'] = time
        
        if time + downsample_period*2 > cutoff:
            can_cont = False
            
        # Process responder data
        prev_date, second_prev_date = _get_previous_date(dates, date)
        responders_data = _process_responders(df, ix, prev_date, second_prev_date, time, responder_block_size, b_info)
        batch_tensors['responders'][b] = torch.tensor(
            responders_data[responder_vars].to_numpy().T[-responder_block_size:]
        )
        
        # Process main sequence data
        sequence = sequence_data['responder_6'].to_numpy().flatten()
        downsampled = downsample_array(sequence, downsample_period)
        batch_tensors['X'][b] = torch.tensor(downsampled[-(seq_block_size+1):-1].astype(np.float32))
        batch_tensors['Y'][b] = torch.tensor(np.clip(downsampled[-1], -5, 5).astype(np.float32))
        
        # Process time IDs
        time_slice = slice(-(features_block_size+downsample_period-1), -(downsample_period-1))
        batch_tensors['t_ids'][b] = torch.tensor(sequence_data['time_id'][time_slice].to_numpy())
        
        # Process categorical variables
        for f_id, f in enumerate(cat_vars):
            mapped_cat = np.vectorize(mappings[f].get)(sequence_data[f].to_numpy()[time_slice])
            batch_tensors['cat'][b, f_id] = torch.tensor(mapped_cat)
        
        # Process continuous variables
        cont_data = sequence_data[time_slice]
        for f_id, f in enumerate(cont_vars):
            seq = cont_data[f].to_numpy()
            batch_tensors['cont'][b, f_id] = torch.tensor(np.nan_to_num(seq).astype(np.float32))
            batch_tensors['nan_cont'][b, f_id] = torch.tensor((~np.isnan(seq)).astype(np.int64))
        
        # Process weights
        batch_tensors['weights'][b] = torch.tensor(
            np.mean(sequence_data['weight'][-downsample_period:].to_numpy().flatten())
        )
        BATCH_INFO.append(b_info)
    
    # Move tensors to device efficiently
    if device == 'cuda':
        device_tensors = {
            k: v.pin_memory().to(device, non_blocking=True) 
            for k, v in batch_tensors.items()
        }
    else:
        device_tensors = {
            k: v.to(device) for k, v in batch_tensors.items()
        }
    
    return (
        device_tensors['X'], device_tensors['Y'], device_tensors['cat'],
        device_tensors['cont'], device_tensors['nan_cont'], device_tensors['weights'],
        device_tensors['t_ids'], device_tensors['responders'], idxs, can_cont
    )

def _get_index_and_asset(split, random, cutoff):
    """Return the next ``(ix, asset)`` pair for ``split`` from the pool selected by ``CHOSEN_TIME``.

    Exhausted sequential counters are wrapped back to 0 first; each wrap of a training
    counter increments ``EPOCH_COUNT``. On the training split with ``random`` set, a random
    pool entry is taken and its index shifted by ``np.random.randint(cutoff - 1)``;
    otherwise the pool is walked sequentially via the global counters.
    """
    global TRAIN_COUNT_1, TRAIN_COUNT_2, VALID_COUNT_1, VALID_COUNT_2, EPOCH_COUNT

    # Wrap counters that ran past the end of their index list.
    if TRAIN_COUNT_1 >= len(idxs_train_1):
        TRAIN_COUNT_1 = 0
        EPOCH_COUNT += 1
    if TRAIN_COUNT_2 >= len(idxs_train_2):
        TRAIN_COUNT_2 = 0
        EPOCH_COUNT += 1
    if VALID_COUNT_1 >= len(idxs_valid_1):
        VALID_COUNT_1 = 0
    if VALID_COUNT_2 >= len(idxs_valid_2):
        VALID_COUNT_2 = 0

    first_group = CHOSEN_TIME == 0

    # Validation: always sequential.
    if split != 'train':
        if first_group:
            ix, asset = idxs_valid_1[VALID_COUNT_1]
            VALID_COUNT_1 += 1
        else:
            ix, asset = idxs_valid_2[VALID_COUNT_2]
            VALID_COUNT_2 += 1
        return ix, asset

    # Training, random draw: random pool entry plus a random offset into the day.
    if random:
        pool = idxs_train_1 if first_group else idxs_train_2
        ix, asset = pool[np.random.choice(len(pool))]
        return ix + np.random.randint(cutoff - 1), asset

    # Training, sequential walk.
    if first_group:
        ix, asset = idxs_train_1[TRAIN_COUNT_1]
        TRAIN_COUNT_1 += 1
    else:
        ix, asset = idxs_train_2[TRAIN_COUNT_2]
        TRAIN_COUNT_2 += 1
    return ix, asset

def _get_previous_date(dates, date):
    """Helper function to get the previous date"""
    prev_date = None
    second_prev_date = None
    for d in dates:
        if d == date:
            break
        second_prev_date = prev_date
        prev_date = d
    return prev_date, second_prev_date

def _process_responders(df, ix, prev_date, second_prev_date, t, responder_block_size, b_info):
    """Return a ``responder_block_size``-row window of responder rows around time ``t``.

    Rows of ``df`` dated between ``second_prev_date`` and ``prev_date`` (inclusive) are
    re-indexed from 0; the window ends at the last such row whose ``time_id`` is below
    ``min(t + responder_block_size // 2, number_of_rows)``. The re-indexed frame, that end
    row and the last row index are stored in ``b_info`` ('resp', 'resp_cut',
    'resp_max_index') for ``update_batch``. ``ix`` is accepted but not used.
    """
    window = (
        df.filter(pl.col("date_id") <= prev_date, pl.col("date_id") >= second_prev_date)
        .drop(['index'])
        .with_row_index()
    )
    b_info['resp'] = window

    upper_time = min(t + responder_block_size // 2, len(window))
    cut_row = window.filter(pl.col("time_id") < upper_time)['index'][-1]
    b_info['resp_cut'] = cut_row
    b_info['resp_max_index'] = window.shape[0] - 1

    # NOTE(review): a negative start offset makes polars count from the end of the frame;
    # this assumes cut_row >= responder_block_size - 1 — confirm for the earliest days.
    return window.slice(cut_row - responder_block_size + 1, responder_block_size)

def update_batch(x, cat, cont, nan_cont, t_ids, responders, block_size=896):
    """Slide every sequence of the current batch forward by one downsample period.

    Uses the bookkeeping that ``get_batch_new_ss`` left in the global ``BATCH_INFO`` to read
    only the new rows of each element instead of rebuilding the whole batch.

    Args:
        x: decoded sequence, (batch, seq); its first token is dropped and the mean
            ``responder_6`` of the period just passed is appended.
        cat, cont, nan_cont: feature windows, (batch, n_vars, features_block_size); they
            are shifted left by one downsample period and the new rows appended.
        t_ids: (batch, features_block_size) time ids of the feature windows.
        responders: (batch, n_responders, responder_block_size) responder window.
        block_size: raw sequence length; must match the value given to
            ``get_batch_new_ss`` (defaults to 896, the previously hard-coded value).

    Returns:
        ``(x, y, cat, cont, nan_cont, weights, t_ids, responders, idxs, can_cont)``, the
        same layout as ``get_batch_new_ss``.
    """
    global BATCH_INFO
    batch_size = x.shape[0]
    device = x.device
    downsample_period = block_size // 128
    responder_block_size = responders.shape[2]
    cutoff = 848 if CHOSEN_TIME == 0 else 967

    new_x = torch.zeros((batch_size, 1), dtype=torch.bfloat16, device=device)
    new_y = torch.zeros((batch_size, 1), dtype=torch.bfloat16, device=device)
    new_weights = torch.zeros((batch_size, 1), dtype=torch.bfloat16, device=device)
    new_t_ids = torch.zeros((batch_size, downsample_period), dtype=torch.int64, device=device)
    new_cat = torch.zeros((batch_size, len(cat_vars), downsample_period), dtype=torch.int64, device=device)
    new_cont = torch.zeros((batch_size, len(cont_vars), downsample_period), dtype=torch.bfloat16, device=device)
    new_nan_cont = torch.zeros((batch_size, len(cont_vars), downsample_period), dtype=torch.int64, device=device)
    # Each element's responder window may advance by a different amount, so rows are
    # rebuilt individually into a copy rather than in place (the caller passes a detached
    # view that may share storage with a tensor autograd still needs).
    new_responders = responders.clone()

    idxs = []
    can_cont = True

    for b in range(batch_size):
        info = BATCH_INFO[b]
        offset = info['offset']
        ix, asset = info['idx'] + offset, info['asset']
        idxs.append((ix, asset))

        # Rows [ix - dp, ix + dp): the first half is the period just passed, the second
        # half is the period to predict.
        df = DFS_USED[asset]
        sequence_data = df.slice(ix - downsample_period, downsample_period * 2)
        t = info['time'] + offset
        if t + downsample_period * 2 > cutoff:
            can_cont = False

        target_series = sequence_data['responder_6'].to_numpy()
        new_x[b] = torch.tensor(np.mean(target_series[:downsample_period]).astype(np.float32))
        # Clip the target to [-5, 5] like get_batch_new_ss does.
        new_y[b] = torch.tensor([np.clip(np.mean(target_series[-downsample_period:]), -5, 5)])
        new_weights[b] = torch.tensor([np.mean(sequence_data['weight'].to_numpy()[-downsample_period:])])

        # The downsample_period feature rows that follow the previous feature window.
        step_rows = sequence_data[1:downsample_period + 1]
        new_t_ids[b] = torch.tensor(step_rows['time_id'].to_numpy(), dtype=torch.int64, device=device)
        for f_id, f in enumerate(cat_vars):
            mapped_cat = np.vectorize(mappings[f].get)(step_rows[f].to_numpy())
            new_cat[b, f_id] = torch.tensor(mapped_cat)
        for f_id, f in enumerate(cont_vars):
            seq = step_rows[f].to_numpy()
            new_cont[b, f_id] = torch.tensor(np.nan_to_num(seq).astype(np.float32))
            new_nan_cont[b, f_id] = torch.tensor((~np.isnan(seq)).astype(np.int64))

        # Advance this element's responder window. Near the end of the stored frame it can
        # only move by the rows that remain; once it reached the last row it stays put.
        prev_day = info['resp']
        idx_to_cut = info['resp_cut'] + offset
        max_index = info['resp_max_index']
        if idx_to_cut > max_index:
            period = max(0, downsample_period - (idx_to_cut - max_index))
        else:
            period = downsample_period
        if period > 0:
            idx_to_cut = min(idx_to_cut, max_index)
            responders_data = prev_day.slice(idx_to_cut - responder_block_size + 1, responder_block_size)
            fresh = torch.tensor(
                responders_data[responder_vars].to_numpy().T[:, -period:], dtype=torch.bfloat16, device=device
            )
            new_responders[b] = torch.cat([responders[b, :, period:], fresh], dim=1)

        info['offset'] += downsample_period

    x = torch.cat([x[:, 1:], new_x], dim=1)
    cat = torch.cat([cat[:, :, downsample_period:], new_cat], dim=2)
    t_ids = torch.cat([t_ids[:, downsample_period:], new_t_ids], dim=1)
    cont = torch.cat([cont[:, :, downsample_period:], new_cont], dim=2)
    nan_cont = torch.cat([nan_cont[:, :, downsample_period:], new_nan_cont], dim=2)
    return (
        x, new_y, cat,
        cont, nan_cont, new_weights,
        t_ids, new_responders, idxs, can_cont
    )

In [None]:
# model, train_config = prepare_model()

In [None]:
def get_ss_prob(it, train_config):
    """Return the scheduled-sampling token budget for iteration ``it``.

    Currently a constant schedule (``it`` is ignored): always
    ``train_config['downsampled_block_size']``. ``main`` only checks whether it is > 0.
    """
    token_budget = train_config['downsampled_block_size']
    return token_budget

def get_lr(it, train_config):
    """Learning rate for iteration ``it``: linear warmup, then cosine decay to ``min_lr``.

    Reads ``warmup_iters``, ``lr_decay_iters``, ``learning_rate`` and ``min_lr`` from
    ``train_config``.
    """
    warmup = train_config['warmup_iters']
    decay_end = train_config['lr_decay_iters']
    max_lr = train_config['learning_rate']
    min_lr = train_config['min_lr']

    # 1) linear warmup for warmup_iters steps
    if it < warmup:
        return max_lr * it / warmup
    # 2) from lr_decay_iters on, return min learning rate. Using >= (the cosine already
    #    yields min_lr at it == lr_decay_iters) avoids a 0/0 when lr_decay_iters == warmup_iters.
    if it >= decay_end:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup) / (decay_end - warmup)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (max_lr - min_lr)

def main(init_mode='scratch', save_at_end=True, savename='ckpt', out_dir='models', load_only=False, checkpoint_name='ckpt.pt'):
    """Train (or resume / only load) the Generator model with scheduled sampling.

    Args:
        init_mode: 'resume' restores model, optimizer, sampling counters and config from
            ``out_dir/checkpoint_name``; any other value builds a fresh model via
            ``prepare_model()``.
        save_at_end: write ``out_dir/{savename}.pt`` after the training loop.
        savename: checkpoint file stem used when ``save_at_end`` is True.
        out_dir: checkpoint directory (created when a checkpoint is written).
        load_only: with ``init_mode='resume'``, only load the weights and return the
            uncompiled model moved to ``train_config['device']``.
        checkpoint_name: checkpoint path relative to ``out_dir``; only used when resuming.

    Returns:
        The compiled model after training, or the loaded model when ``load_only`` is set.

    Side effects: updates the global epoch/sample counters, writes 'out.txt' and
    checkpoint files.
    """
    global EPOCH_COUNT
    global TRAIN_COUNT_1
    global VALID_COUNT_1
    global TRAIN_COUNT_2
    global VALID_COUNT_2
    global idxs_train_1
    global idxs_train_2

    if save_at_end:
        assert isinstance(savename, str), "The name of the model hast to be a str"

    if init_mode == 'resume':
        print(f"Resuming training from {out_dir}")
        # resume training from a checkpoint.
        ckpt_path = os.path.join(out_dir, checkpoint_name)
        checkpoint = torch.load(ckpt_path, map_location='cuda')
        train_config = checkpoint['train_config']

        # In case the checkpoint was not correctly saved
        if 'vocab_sizes' not in train_config:
            train_config['vocab_sizes'] = vocab_sizes
        if 'model_args' not in train_config:
            model_args = dict(n_embd=64, block_size=896, bias=False, dropout=0.0, vocab_sizes=vocab_sizes, n_cat_features=len(cat_vars), n_cont_features=len(cont_vars))
            train_config['model_args'] = model_args

        torch.set_float32_matmul_precision('high')

        saved_args = train_config['model_args']
        # n_head / n_layer / downsampled_block_size may be missing from older or
        # reconstructed model_args (the fallback above lacks them); default to the
        # ModelConfig defaults instead of raising KeyError.
        model_args = dict(n_embd=saved_args['n_embd'], block_size=saved_args['block_size'], bias=saved_args['bias'],
                          features_block_size=train_config['features_block_size'], responder_block_size=train_config['responder_block_size'],
                          dropout=saved_args['dropout'], vocab_sizes=train_config['vocab_sizes'], n_cat_features=len(cat_vars), n_cont_features=len(cont_vars),
                          n_head=saved_args.get('n_head', 2), n_layer=saved_args.get('n_layer', 2), n_iters=train_config['max_iters'],
                          downsampled_block_size=saved_args.get('downsampled_block_size', 128))

        conf = ModelConfig(**model_args)
        model = Generator(conf)
        state_dict = checkpoint['model']

        # Checkpoints saved from a torch.compile'd module carry this prefix on every key;
        # strip it so both compiled and plain checkpoints load.
        unwanted_prefix = '_orig_mod.'
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model_keys = model.state_dict().keys()
        state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
        model.load_state_dict(state_dict, strict=True)

        if load_only:
            model.to(train_config['device'])
            return model

        iter_num = checkpoint['iter_num'] + 1
        best_val_loss = checkpoint['best_val_loss']
        EPOCH_COUNT = checkpoint['epoch_count']
        TRAIN_COUNT_1 = checkpoint['train_count_1']
        VALID_COUNT_1 = checkpoint['valid_count_1']
        TRAIN_COUNT_2 = checkpoint['train_count_2']
        VALID_COUNT_2 = checkpoint['valid_count_2']
        print(f"Epoch: {EPOCH_COUNT}, Train sample: {TRAIN_COUNT_1}, Valid sample: {VALID_COUNT_1}. Starting at {iter_num}")
    else:
        model, train_config = prepare_model()
        iter_num = 0
        best_val_loss = 1e9

    print("Testing batch...")
    X, Y, cat, cont, nan_mask, w, t_ids, rs, prev_idxs, can_cont = get_batch_new_ss('train', train_config['batch_size'], train_config['block_size'], train_config['features_block_size'], train_config['responder_block_size'], train_config['device'])
    print(X.shape, Y.shape, cat.shape, cont.shape, nan_mask.shape, t_ids.shape, rs.shape, can_cont)

    assert Y.size() == (train_config['batch_size'], 1)
    assert X.size() == (train_config['batch_size'],train_config['downsampled_block_size'])
    assert cat.size() == (train_config['batch_size'], len(cat_vars),train_config['features_block_size'])
    assert cont.size() == (train_config['batch_size'], len(cont_vars), train_config['features_block_size'])
    assert nan_mask.size() == (train_config['batch_size'], len(cont_vars), train_config['features_block_size'])
    assert w.size() == (train_config['batch_size'],1)
    assert t_ids.size() == (train_config['batch_size'],train_config['features_block_size'])
    assert rs.size() == (train_config['batch_size'],len(responder_vars), train_config['responder_block_size'])

    print("Test succesful!")

    # We pass the model to the device
    model.to(train_config['device'])
    # initialize a GradScaler. If enabled=False scaler is a no-op
    scaler = torch.amp.GradScaler('cuda', enabled=(train_config['dtype'] == 'float16'))

    # optimizer
    optimizer = model.configure_optimizers(train_config['weight_decay'], train_config['learning_rate'], (train_config['beta1'], train_config['beta2']), train_config['device'])
    if init_mode == 'resume':
        print("Loading optimizer...")
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Keep a handle on the uncompiled module for checkpointing: the torch.compile wrapper
    # shares its parameters but prefixes every state_dict key with '_orig_mod.'.
    raw_model = model

    # NOTE: Compile the model if we are running in a gpu!
    # If using RoPE the compiling raises some warnings, just ignore them...
    if True:
        print("compiling the model... (takes a ~minute)")
        model = torch.compile(model) # requires PyTorch 2.0

    t0 = time.time()
    outs = []
    ss_counter = 0
    # The synthetic flag is a tensor that represents if a token in the decoded sequence is model generated or is a golden token.
    synthetic_flag = torch.zeros((train_config['batch_size'], train_config['downsampled_block_size']), dtype=torch.int64, device=train_config['device'])
    update_sample = False
    xpoints = []
    ypoints = []
    wpoints = []
    scheduled_sampling_enabled = True
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[train_config['dtype']]
    ctx = nullcontext() if train_config['device'] == 'cpu' else torch.amp.autocast(device_type=train_config['device'], dtype=ptdtype)
    # In case we want to cut schedule sampling at some point to avoid fully model-generated sequences.
    # Since the scheduler is exponential I found that cutting at 0.5 is a good compromise.
    cutoff_pct = 1 # 0.5
    cutoff_iter = int(train_config['max_iters']*cutoff_pct)
    print(f"Cutting scheduled sampling at {cutoff_iter}")
    print("#### Starting Training ####")

    while True:

        # determine and set the learning rate for this iteration
        lr = get_lr(iter_num, train_config) if train_config['decay_lr'] else train_config['learning_rate']
        sampled_tokens = int(get_ss_prob(iter_num, train_config))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # evaluate the loss on train/val sets and write checkpoints
        if iter_num % train_config['eval_interval'] == 0 and iter_num > 0:
            losses = estimate_loss(model, train_config, iter_num)
            val_line = f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            outs.append(val_line)
            print(val_line)
            # Evaluation rebuilt BATCH_INFO, so the training batch cannot be continued.
            can_cont = False
            if losses['val'] < best_val_loss or train_config['always_save_checkpoint']:
                best_val_loss = losses['val']
                with open('out.txt', 'w') as file:
                    file.write('\n'.join(outs) + '\n')
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'epoch_count': EPOCH_COUNT,
                    'train_count_1': TRAIN_COUNT_1,
                    'valid_count_1': VALID_COUNT_1,
                    'train_count_2': TRAIN_COUNT_2,
                    'valid_count_2': VALID_COUNT_2,
                    'train_config': train_config,
                    'mappings': mappings,
                    'cat_ids': cat_ids
                }
                print(f"saving checkpoint to {out_dir}")
                os.makedirs(out_dir, exist_ok=True)
                torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}_2.pt'))
        if iter_num == 0 and train_config['eval_only']:
            break

        # This does {gradient_accumulation_steps} of forward steps, and performs a backwards pass after that.
        # Depending on the iter_num and decoded_step the model will have a higher probability to use its own predictions as context.
        # This is called scheduled sampling, and is used in order to reduce exposure bias in auto regressive inference scenarios.
        # (https://arxiv.org/abs/1506.03099 for the paper introducting scheduled sampling, https://aclanthology.org/2021.emnlp-main.264.pdf for the introduction of decoder-step-based probabilities)
        # Since the loss function is a weighted r2, it is imperative that the gradient_accumulation_steps are high (probably at least 32).
        # This is because in smaller values the loss can be quite jumpy (for example an "easy" sequence can have a loss of 0.5, while a "difficult" one can go up to 1), which makes the learning difficult.
        all_logits = []
        all_targets = []
        all_weights = []

        for micro_step in range(train_config['gradient_accumulation_steps']):
            # The max decoder step can count as an hyperparameter to tune.
            # In a case where the information of the sequence is lost on the first tokens is sensible to set it low so the model sampling probability is high.
            sampling_prob = composite_schedule_sampling(min(cutoff_iter, iter_num), max(ss_counter, 10), train_config['max_iters'], 7)

            with ctx:
                logits = model(X, cat, cont, nan_mask, rs, t_ids,
                               it=iter_num, synthetic_flag=synthetic_flag)
                if not scheduled_sampling_enabled:
                    # If we don't use scheduled sampling we can calculate the loss and do the backwards pass in the loop.
                    loss = zero_weighted_rsquared(logits.view(-1), Y.view(-1), w.view(-1))
                    loss = loss / train_config['gradient_accumulation_steps'] # scale the loss to account for gradient accumulation
                else:
                    # Store predictions and targets
                    all_logits.append(logits)
                    all_targets.append(Y)
                    all_weights.append(w)

            if not can_cont:
                ss_counter = 0
                # Uncomment this to see the generated predictions.
                #plt.plot(xpoints, label='pred')
                #plt.plot(ypoints, label='real')
                #plt.legend()
                #plt.show()
                #print("Loss: ", zero_weighted_rsquared_np(np.array(xpoints).flatten(), np.array(ypoints).flatten(), np.array(wpoints).flatten()))
                #xpoints = []
                #ypoints = []
                #wpoints = []
                synthetic_flag = torch.zeros((train_config['batch_size'],
                                            train_config['downsampled_block_size']),
                                           dtype=torch.int64,
                                           device=train_config['device'])
                update_sample = False

            if sampled_tokens > 0 and can_cont:
                sampled = X.clone()
                update_sample = True
                logits_detached = logits.squeeze(-1).detach()

            if sampled_tokens == 0:
                can_cont = False

            # Uncomment this to see the generated predictions
            #if update_sample:
            #    xpoints.append(logits_detached[0].cpu().numpy())
            #    ypoints.append(Y.to(dtype=torch.float32).detach().cpu().numpy()[0][0])
            #    wpoints.append(w.to(dtype=torch.float32).detach().cpu().numpy()[0][0])

            if can_cont:
                X, Y, cat, cont, nan_mask, w, t_ids, rs, prev_idxs, can_cont = update_batch(
                    X.detach(),
                    cat.detach(),
                    cont.detach(),
                    nan_mask.detach(),
                    t_ids.detach(),
                    rs.detach()
                )
            else:
                X, Y, cat, cont, nan_mask, w, t_ids, rs, prev_idxs, can_cont = get_batch_new_ss(
                    'train',
                    train_config['batch_size'],
                    train_config['block_size'],
                    train_config['features_block_size'],
                    train_config['responder_block_size'],
                    train_config['device'],
                    prev_idxs,
                    ss=can_cont,
                    random=False
                )

            if update_sample:
                ss_counter += 1
                # Random values for each sequence
                use_prediction = torch.rand((train_config['batch_size'], 1), device=train_config['device'])
                # Whether we want to use the golden tokens or the model prediction
                to_cat = torch.where(use_prediction < sampling_prob, logits_detached, X[:, [-1]])
                X = torch.cat([sampled[:, 1:], to_cat], dim=1)

                synthetic_flag = torch.cat([
                    synthetic_flag[:, 1:],
                    torch.where(
                        use_prediction < sampling_prob,
                        1,  # model generated context
                        0   # golden token
                    )
                ], dim=1)

            if not scheduled_sampling_enabled:
                scaler.scale(loss).backward()

        # Combine all stored tensors
        if scheduled_sampling_enabled:
            combined_logits = torch.cat(all_logits, dim=0)
            combined_targets = torch.cat(all_targets, dim=0)
            combined_weights = torch.cat(all_weights, dim=0)

            loss = zero_weighted_rsquared(combined_logits.view(-1), combined_targets.view(-1), combined_weights.view(-1))

            scaler.scale(loss).backward()

        if train_config['grad_clip'] != 0.0:
            scaler.unscale_(optimizer)
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
        else:
            norm = 0

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # timing and logging
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if iter_num % train_config['log_interval'] == 0:
            mul = train_config['gradient_accumulation_steps'] if not scheduled_sampling_enabled else 1
            lossf = loss.item() * mul
            log_line = f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, norm: {norm:.4f}, sampling prob: {sampling_prob}"
            outs.append(log_line)
            print(log_line)
        iter_num += 1

        # termination conditions
        if iter_num > train_config['max_iters']:
            break

    if save_at_end:
        os.makedirs(out_dir, exist_ok=True)
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter_num': iter_num,
            'best_val_loss': best_val_loss,
            'epoch_count': EPOCH_COUNT,
            'train_count_1': TRAIN_COUNT_1,
            'valid_count_1': VALID_COUNT_1,
            'train_count_2': TRAIN_COUNT_2,
            'valid_count_2': VALID_COUNT_2,
            'train_config': train_config,
            'mappings': mappings,
            'cat_ids': cat_ids
        }
        print(f"saving checkpoint to {out_dir}")
        torch.save(checkpoint, os.path.join(out_dir, f'{savename}.pt'))

    with open('out.txt', 'w') as file:
        file.write('\n'.join(outs) + '\n')

    return model

In [None]:
# Train from scratch and save the final weights as models/ckpt_final.pt.
# checkpoint_name is only read when init_mode='resume'.
model = main(
    init_mode='scratch',
    save_at_end=True,
    savename='ckpt_final',
    load_only=False,
    checkpoint_name='full_1/ckpt_10574_2.pt',
)

In [None]:
# Switch the model returned by main() to evaluation mode (nn.Module.eval()).
model.eval()