<a href="https://colab.research.google.com/github/satyabratkumarsingh/option-portfolio-encoder-decoder/blob/main/SetTransformer_Optimized_Standarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torch
!pip install comet_ml
!pip install tqdm
!pip install matplotlib



#Mount to Google drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


#Deele scalars

In [4]:
import os
def delete_file_from_drive(full_file_path):
  if os.path.exists(full_file_path):
      try:
          os.remove(full_file_path)
          print(f"File '{full_file_path}' successfully deleted from Google Drive.")
      except Exception as e:
          print(f"Error deleting file '{full_file_path}': {e}")
  else:
      print(f"File '{full_file_path}' not found at '{full_file_path}'.")


# Portfolio Generation functions

In [5]:
import random
import numpy as np
import torch
import itertools
from itertools import product
from torch.utils.data import Dataset, DataLoader
import gc # For garbage collection

# Parameters
MU = 0.05
SIGMA = 0.2
T = 1.0 # Time to maturity
NOISE_STD = 0.01
MIN_PRICE_RANGE = 100
MAX_PRICE_RANGE = 1000

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_option_prices_for_idx(idx, n, weights=None):
    # Use numpy's default_rng for better random number generation and seeding
    rng = np.random.default_rng(idx)
    # Use torch.manual_seed for PyTorch operations
    torch.manual_seed(idx)
    if DEVICE.type == 'cuda':
        torch.cuda.manual_seed_all(idx)

    # Generate S_0
    random_number = rng.integers(MIN_PRICE_RANGE, MAX_PRICE_RANGE + 1) # +1 because randint is inclusive
    min_price = random_number
    max_price = random_number + 5
    S_0 = rng.uniform(min_price, max_price)

    # Generate option types
    option_types = rng.choice(["call", "put"], size=n)
    option_types_numeric = np.where(option_types == "call", 1, 0).astype(np.float32) # Ensure float32

    # Generate X_prices (strike prices)
    K_prices = np.zeros(n, dtype=np.float32) # Ensure float32
    for i in range(n):
        K_prices[i] = S_0 * rng.uniform(0.90, 1.20)

    # Generate or use weights
    if weights is None:
        # If weights are not provided, generate them using generate_combinatorial_weights_manageable
        weight_sets = generate_combinatorial_weights_manageable(n)
        weights_array = weight_sets[0]  # Use the first (and only) set of weights
    else:
        weights_array = np.array(weights, dtype=np.float32)  # Ensure float32

    return K_prices, option_types_numeric, S_0, weights_array


def generate_combinatorial_weights_manageable(n, base_weights=[-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75]):
    weight_sets = []

    # Handle the case where n < 2
    if n < 2:
        weights = np.zeros(n, dtype=np.float32)
        if n == 1:
            # If only one position, assign a long position (1.0)
            weights[0] = 1.0
        weight_sets.append(weights)
        return weight_sets

    # Generate a single portfolio: either one long or one short, and the rest from combinatorics
    weights = np.zeros(n, dtype=np.float32)

    # Randomly choose if we want a long or short portfolio
    is_long = random.choice([True, False])

    if is_long:
        # Choose one position to be long (1.0)
        long_idx = random.randint(0, n - 1)
        weights[long_idx] = 1.0
    else:
        # Choose one position to be short (-1.0)
        short_idx = random.randint(0, n - 1)
        weights[short_idx] = -1.0

    # Fill remaining positions with combinatorial weights from base_weights
    remaining_positions = [i for i in range(n) if weights[i] == 0]  # Find positions not yet filled
    combinatorics = np.random.choice(base_weights, size=len(remaining_positions), replace=True)

    # Assign combinatorial weights to the remaining positions without normalization
    weights[remaining_positions] = combinatorics

    weight_sets.append(weights)

    return weight_sets



import torch

def black_scholes_delta(S, K, T, r, sigma, option_type):
    """
    Computes the Black-Scholes delta for a call or put option.

    Args:
        S (Tensor): Spot price [any shape]
        K (Tensor): Strike price [same shape as S or broadcastable]
        T (float or Tensor): Time to maturity (scalar or broadcastable)
        r (float or Tensor): Risk-free rate (scalar or broadcastable)
        sigma (float or Tensor): Volatility (scalar or broadcastable)
        option_type (Tensor): 1 for call, 0 for put [same shape as S]

    Returns:
        Tensor: Delta of the option [same shape as S]
    """
    eps = 1e-8  # Numerical stability for sqrt
    device = S.device

    T = torch.as_tensor(T, device=device, dtype=S.dtype)
    r = torch.as_tensor(r, device=device, dtype=S.dtype)
    sigma = torch.as_tensor(sigma, device=device, dtype=S.dtype)

    d1 = (torch.log(S / K + eps) + (r + 0.5 * sigma ** 2) * T) / (sigma * torch.sqrt(T + eps))

    # More stable, recommended: torch.special.ndtr(d1) (if available)
    N_d1 = torch.distributions.Normal(0.0, 1.0).cdf(d1)

    delta = torch.where(option_type == 1, N_d1, N_d1 - 1.0)
    return delta


def compute_cashflow(portfolio, S_T, weights):
    strikes = portfolio[..., 0]
    types = portfolio[..., 1]
    weights = weights.to(DEVICE)

    # Compute option payoffs
    payoffs = torch.where(
        types == 1,
        torch.relu(S_T - strikes),
        torch.relu(strikes - S_T)
    )
    weighted_payoffs = payoffs * weights
    cashflow = weighted_payoffs.sum(dim=-1, keepdim=True)

    # --- Continuous Derivative (Black-Scholes Delta) ---
    delta = black_scholes_delta(
        S_T, strikes, T=T, r=MU, sigma=SIGMA, option_type=types
    )
    weighted_delta = weights * delta
    derivative = weighted_delta.sum(dim=-1, keepdim=True)

    return cashflow.to(torch.float32), derivative.to(torch.float32)


In [6]:

DRIVE_PATH = "/content/drive/MyDrive/Ucl/"
K_SCALAR_FILE = os.path.join(DRIVE_PATH, 'K_Scalar.pkl')
ST_SCALAR_FILE = os.path.join(DRIVE_PATH, 'S_T_scalar.pkl')

delete_file_from_drive(K_SCALAR_FILE)
delete_file_from_drive(ST_SCALAR_FILE)

File '/content/drive/MyDrive/Ucl/K_Scalar.pkl' successfully deleted from Google Drive.
File '/content/drive/MyDrive/Ucl/S_T_scalar.pkl' successfully deleted from Google Drive.


#Dataset

In [8]:

import os
import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler

DRIVE_PATH = "/content/drive/MyDrive/Ucl/"
K_SCALAR_FILE = os.path.join(DRIVE_PATH, 'K_Scalar.pkl')
ST_SCALAR_FILE = os.path.join(DRIVE_PATH, 'S_T_scalar.pkl')
SCALAR_SAMPLE_NO = 80000000



class OperatorDatasetStandarization(Dataset):

    def __init__(self, num_samples, portfolio_size, num_samples_S_T,
                 K_Scalar=None, S_T_scalar=None, cashflow_scaler=None,
                 is_fitting_mode=False): # NEW: Add is_fitting_mode

        self.num_samples = num_samples
        self.portfolio_size = portfolio_size
        self.num_samples_S_T = num_samples_S_T
        self.is_fitting_mode = is_fitting_mode # Store the mode

        # Load or generate K_Scalar and S_T_scalar
        if K_Scalar is None or S_T_scalar is None:
            if os.path.exists(K_SCALAR_FILE) and os.path.exists(ST_SCALAR_FILE):
                print("Loading K and S_T scalers from Google Drive...")
                self.K_Scalar = joblib.load(K_SCALAR_FILE)
                self.S_T_scalar = joblib.load(ST_SCALAR_FILE)
            else:
                print("Generating and saving new K and S_T scalers to Google Drive...")
                self.K_Scalar, self.S_T_scalar = self._generate_and_save_K_ST_scalers(SCALAR_SAMPLE_NO)
                print("***** Generated K_Scalar and S_T_Scalar and saved into Google drive")
        else:
            self.K_Scalar = K_Scalar
            self.S_T_scalar = S_T_scalar


        if not self.is_fitting_mode:
            if cashflow_scaler is None:
                 if os.path.exists(CASHFLOW_SCALAR_FILE):
                    print("Loading Cashflow scaler from Google Drive...")
                    self.cashflow_scaler = joblib.load(CASHFLOW_SCALAR_FILE)
                 else:
                    print("WARNING: Cashflow scaler not provided and not found on drive. "
                          "Cashflows will not be normalized. Consider running fitting process.")
                    self.cashflow_scaler = None
            else:
                self.cashflow_scaler = cashflow_scaler
        else:
            self.cashflow_scaler = None # No scaler in fitting mode


    def _generate_and_save_K_ST_scalers(self, num_samples):
        # Your existing code for generating and saving K_Scalar and S_T_scalar
        S_0_values = np.random.uniform(MIN_PRICE_RANGE, MAX_PRICE_RANGE, num_samples)
        K_prices = S_0_values * np.random.uniform(0.90, 1.20, num_samples)

        Z = np.random.randn(num_samples) # One random shock per S_T value
        S_T_values = S_0_values * np.exp((MU - 0.5 * SIGMA**2) * T + SIGMA * np.sqrt(T) * Z)

        K_scalar = MinMaxScaler()
        K_scalar.fit(K_prices.reshape(-1, 1))

        S_T_scalar = MinMaxScaler()
        S_T_scalar.fit(S_T_values.reshape(-1, 1))

        # Save the scalers
        joblib.dump(K_scalar, K_SCALAR_FILE)
        joblib.dump(S_T_scalar, ST_SCALAR_FILE)

        return K_scalar, S_T_scalar


    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Generates and returns a single data sample (portfolio, S_T, cashflow, derivative).
        This method is called by the DataLoader.
        """

        K, option_types, S_0, weights = generate_option_prices_for_idx(
            idx, self.portfolio_size
        )

        portfolio_features_tensor = torch.stack([
            torch.tensor(K, dtype=torch.float32, device=DEVICE),
            torch.tensor(option_types, dtype=torch.float32, device=DEVICE),
            torch.tensor(weights, dtype=torch.float32, device=DEVICE)
        ], dim=-1)

        weights_i = torch.tensor(weights, dtype=torch.float32, device=DEVICE)
        K_i = torch.tensor(K, dtype=torch.float32, device=DEVICE)
        S_0_i = torch.tensor(S_0, dtype=torch.float32, device=DEVICE)
        Z = torch.clamp(torch.randn(self.num_samples_S_T, device=DEVICE), -3, 3)

        S_T_i = S_0_i * torch.exp((MU - 0.5 * SIGMA**2) * T + SIGMA * torch.sqrt(torch.tensor(T, device=DEVICE)) * Z)
        S_T_i += torch.randn_like(S_T_i, device=DEVICE) * (NOISE_STD * S_T_i)

        # Normalization for K and S_T (always apply if scalers are present and not in fitting mode)
        # Note: In fitting mode, K and S_T are still generated, but their *normalized* versions
        # are not what we're collecting for the cashflow scaler.
        if not self.is_fitting_mode:
            K_i_cpu = K_i.reshape(-1, 1).cpu()
            K_i_normalized = torch.tensor(self.K_Scalar.transform(K_i_cpu), dtype=torch.float32, device=DEVICE)
            S_T_i_normalized = torch.from_numpy(self.S_T_scalar.transform(S_T_i.cpu().numpy().reshape(-1, 1))).to(DEVICE).squeeze()
        else: # In fitting mode, just use raw K_i and S_T_i for generating raw cashflow
            K_i_normalized = K_i # This won't be used for input features directly, but kept for clarity
            S_T_i_normalized = S_T_i


        # Compute cashflow and derivative (ALWAYS raw when generated in __getitem__)
        cashflow_i_raw, derivative_i_raw = compute_cashflow(
            portfolio_features_tensor.expand(self.num_samples_S_T, -1, -1),
            S_T_i.unsqueeze(-1),
            weights_i.expand(self.num_samples_S_T, -1)
        )

        # ===== CASHFLOW NORMALIZATION (Apply only if a scaler is provided and not in fitting mode) =====
        if self.cashflow_scaler is not None and not self.is_fitting_mode:
            # Convert to numpy on CPU for scaler, then back to tensor
            cashflow_i_normalized_np = self.cashflow_scaler.transform(cashflow_i_raw.cpu().numpy().reshape(-1, 1))
            cashflow_i_to_return = torch.from_numpy(cashflow_i_normalized_np).to(DEVICE).squeeze()
        else:
            # If no scaler or in fitting mode, return the raw cashflow
            cashflow_i_to_return = cashflow_i_raw.squeeze()

        # Update portfolio_i_normalized with normalized K_i if not in fitting mode
        if not self.is_fitting_mode:
            portfolio_i_normalized = portfolio_features_tensor.clone()
            portfolio_i_normalized[:, 0] = K_i_normalized.squeeze()
        else: # In fitting mode, just return original K in portfolio_features_tensor for now
            portfolio_i_normalized = portfolio_features_tensor.clone()


        return portfolio_i_normalized, S_T_i_normalized, cashflow_i_to_return, derivative_i_raw.squeeze()


In [9]:
dataset = OperatorDatasetStandarization(num_samples=3, portfolio_size=3, num_samples_S_T=2, is_fitting_mode=True)


# Access the first sample in the dataset (index 0)
portfolio_features_tensor, S_T_i, cashflow_i, derivative_i = dataset[0]

# Print the cashflow
print("Portfolio Features:")
print(portfolio_features_tensor)

print("ST")
print(S_T_i)
print("Cashflow for first sample:")
print(cashflow_i)
print("DErivative first sample:")
print(derivative_i)


Generating and saving new K and S_T scalers to Google Drive...
***** Generated K_Scalar and S_T_Scalar and saved into Google drive
Portfolio Features:
tensor([[ 7.8491e+02,  0.0000e+00,  2.5000e-01],
        [ 9.9223e+02,  1.0000e+00,  1.0000e+00],
        [ 1.0181e+03,  1.0000e+00, -2.5000e-01]], device='cuda:0')
ST
tensor([744.2042, 816.3430], device='cuda:0')
Cashflow for first sample:
tensor([10.1776,  0.0000], device='cuda:0')
DErivative first sample:
tensor([-0.0064,  0.1364], device='cuda:0')


In [10]:

import torch.multiprocessing as mp
CASHFLOW_SCALAR_FILE = os.path.join(DRIVE_PATH, 'Cashflow_scalar.pkl')

def get_raw_cashflow_sample(self, idx):
        """
        Generates and returns only the raw cashflow for a given index.
        Useful for fitting the cashflow scaler.
        """
        # Re-using the logic from __getitem__ to generate the necessary components
        K, option_types, S_0, weights = generate_option_prices_for_idx(
            idx, self.portfolio_size
        )
        portfolio_features_tensor = torch.stack([
            torch.tensor(K, dtype=torch.float32, device=DEVICE),
            torch.tensor(option_types, dtype=torch.float32, device=DEVICE),
            torch.tensor(weights, dtype=torch.float32, device=DEVICE)
        ], dim=-1)
        weights_i = torch.tensor(weights, dtype=torch.float32, device=DEVICE)
        S_0_i = torch.tensor(S_0, dtype=torch.float32, device=DEVICE)
        Z = torch.clamp(torch.randn(self.num_samples_S_T, device=DEVICE), -3, 3)
        S_T_i = S_0_i * torch.exp((MU - 0.5 * SIGMA**2) * T + SIGMA * torch.sqrt(torch.tensor(T, device=DEVICE)) * Z)
        S_T_i += torch.randn_like(S_T_i, device=DEVICE) * (NOISE_STD * S_T_i)

        cashflow_i_raw, _ = compute_cashflow(
            portfolio_features_tensor.expand(self.num_samples_S_T, -1, -1),
            S_T_i.unsqueeze(-1),
            weights_i.expand(self.num_samples_S_T, -1)
        )
        return cashflow_i_raw.squeeze().cpu().numpy()


CASHFLOW_SCALER_FIT_SAMPLES = 100000

print(f"Generating {CASHFLOW_SCALER_FIT_SAMPLES} samples to fit Cashflow Scaler...")
fitting_dataset = OperatorDatasetStandarization(
    num_samples=CASHFLOW_SCALER_FIT_SAMPLES,
    portfolio_size=200, # Use your desired portfolio_size
    num_samples_S_T=1, # Use your desired num_samples_S_T
    is_fitting_mode=True # This tells the dataset to return raw cashflows
)


fitting_loader = DataLoader(fitting_dataset, batch_size=128, shuffle=False)

all_raw_cashflows_for_fitting = []
for i, (_, _, raw_cashflow_batch, _) in enumerate(fitting_loader):
    all_raw_cashflows_for_fitting.append(raw_cashflow_batch.view(-1).cpu().numpy()) # Flatten, move to CPU, and convert to numpy


# Concatenate all collected cashflows
all_raw_cashflows_for_fitting = np.concatenate(all_raw_cashflows_for_fitting).reshape(-1, 1)

print(f"Fitting Cashflow Scaler on {len(all_raw_cashflows_for_fitting)} samples...")
cashflow_scaler = StandardScaler() # StandardScaler is usually good for financial values
cashflow_scaler.fit(all_raw_cashflows_for_fitting)

# Save the cashflow scaler
joblib.dump(cashflow_scaler, CASHFLOW_SCALAR_FILE)
print(f"Cashflow Scaler fitted and saved to {CASHFLOW_SCALAR_FILE}")
print(f"Cashflow Mean: {cashflow_scaler.mean_[0]:.4f}, Std Dev: {cashflow_scaler.scale_[0]:.4f}")


Generating 100000 samples to fit Cashflow Scaler...
Loading K and S_T scalers from Google Drive...
Fitting Cashflow Scaler on 100000 samples...
Cashflow Scaler fitted and saved to /content/drive/MyDrive/Ucl/Cashflow_scalar.pkl
Cashflow Mean: 0.1338, Std Dev: 704.7646


#Models

In [12]:

import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super(AttentionPooling, self).__init__()
        # Define a learnable weight for attention scores (query)
        self.attention_weights = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, hidden_dim] - sequence of input features
        Returns:
            pooled: [batch_size, hidden_dim] - attention-based pooled representation
        """
        # Calculate attention scores (dot product between sequence and attention weights)
        attention_scores = torch.matmul(x, self.attention_weights)  # [batch_size, seq_len]

        # Apply softmax to get the attention weights (normalized)
        attention_weights = torch.softmax(attention_scores, dim=1)  # [batch_size, seq_len]

        # Apply attention weights to the sequence
        weighted_sum = torch.sum(x * attention_weights.unsqueeze(-1), dim=1)  # [batch_size, hidden_dim]

        return weighted_sum

class MinMaxPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        pass

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, dim]
        Returns:
            pooled: [batch_size, dim]
        """
        min_pooled, _ = torch.min(x, dim=1)  # [batch_size, dim]
        max_pooled, _ = torch.max(x, dim=1)  # [batch_size, dim]
        pooled = 0.5 * (min_pooled + max_pooled)
        return pooled



class InducedSetTransformerEncoder(nn.Module):
    """Set Transformer with Induced Self-Attention using PyTorch components"""
    def __init__(self, portfolio_feature_dim=3, latent_dim=128, hidden_dim=64,
                 num_layers=2, num_heads=4, num_inducing=16, dropout_prob=0.1):
        super().__init__()

        if hidden_dim % num_heads != 0:
            hidden_dim = ((hidden_dim // num_heads) + 1) * num_heads

        self.hidden_dim = hidden_dim
        self.num_inducing = num_inducing

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(portfolio_feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_prob * 0.5)
        )

        # Learnable inducing points (using PyTorch default initialization)
        self.inducing_points = nn.Parameter(torch.randn(1, num_inducing, hidden_dim))

        # ISAB layers using MultiheadAttention
        self.isab_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.isab_layers.append(ISABLayer(hidden_dim, num_heads, dropout_prob))

        # Final pooling and projection
        self.final_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout_prob, batch_first=True
        )
        self.norm = nn.LayerNorm(hidden_dim)

        self.output_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, latent_dim),
            nn.Tanh()
        )

        # Custom weight initialization removed - using PyTorch defaults

    def forward(self, portfolio):
        """
        Args:
            portfolio: [batch_size, portfolio_size, portfolio_feature_dim]
        Returns:
            [batch_size, latent_dim]
        """
        batch_size = portfolio.size(0)
        x = self.input_proj(portfolio)  # [B, P, H]

        # Expand inducing points for batch
        inducing = self.inducing_points.expand(batch_size, -1, -1)  # [B, I, H]

        # Apply ISAB layers
        for isab in self.isab_layers:
            x = isab(x, inducing)

        # Final attention pooling using inducing points as queries
        pooled, _ = self.final_attention(inducing, x, x)  # [B, I, H]

        # Global average pooling over inducing points
        pooled = pooled.mean(dim=1)  # [B, H]

        out = self.output_proj(pooled)  # [B, latent_dim]
        return out


class ISABLayer(nn.Module):
    """Induced Self-Attention Block using PyTorch's MultiheadAttention"""
    def __init__(self, hidden_dim, num_heads, dropout_prob=0.1):
        super().__init__()

        self.attn1 = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout_prob, batch_first=True
        )
        self.attn2 = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout_prob, batch_first=True
        )

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * FEED_FWD_DEPTH),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_dim * FEED_FWD_DEPTH, hidden_dim)
        )

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, inducing_points):
        """
        Args:
            x: [batch_size, seq_len, hidden_dim] - input sequence
            inducing_points: [batch_size, num_inducing, hidden_dim]
        Returns:
            [batch_size, seq_len, hidden_dim]
        """
        # Step 1: Inducing points attend to input
        h, _ = self.attn1(
            self.norm1(inducing_points),  # query
            self.norm2(x),                # key
            self.norm2(x)                 # value
        )  # [B, I, H]

        # Step 2: Input attends to processed inducing points
        x_out, _ = self.attn2(
            self.norm3(x),  # query
            h,              # key
            h               # value
        )  # [B, P, H]

        # Residual connection
        x = x + self.dropout(x_out)

        # Feedforward with residual
        x = x + self.dropout(self.ff(self.norm3(x)))

        return x

class OptimizedTrunkNet(nn.Module):
    """Trunk Network with increased depth/layers"""
    def __init__(self, input_dim=1, latent_dim=128, hidden_dim=64, dropout_prob=0.1, num_hidden_layers=2):
        # ^^^^^ ADDED num_hidden_layers parameter
        super().__init__()

        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_prob)
        )

        # Dynamically create more hidden layers
        layers = []
        for _ in range(num_hidden_layers): # This now represents the additional blocks
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout_prob))

        # Remove the last dropout if it's not desired before the final layer
        if layers and isinstance(layers[-1], nn.Dropout):
            layers = layers[:-1]

        self.ff_block = nn.Sequential(*layers)


        self.output_proj = nn.Linear(hidden_dim, latent_dim)

    def forward(self, S_T):
        batch_size = S_T.shape[0]

        if S_T.dim() == 2:
            num_samples = S_T.shape[1]
            S_T_flat = S_T.reshape(-1, 1)
        else:
            raise ValueError(f"Expected shape [B, N], got {S_T.shape}")

        x = self.input_proj(S_T_flat)            # [B * N, H]
        x = self.ff_block(x)                     # [B * N, H]
        x = self.output_proj(x)                  # [B * N, latent_dim]
        x = x.view(batch_size, num_samples, -1)  # [B, N, latent_dim]

        return x



In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import math

FEED_FWD_DEPTH = 2


class TrunkNet(nn.Module):
    def __init__(self, input_dim=1,  # S_T is scalar
                 latent_dim=64,
                 hidden_dim=32,
                 num_layers=4,
                 dropout_prob=0.3):
          super(TrunkNet, self).__init__()
          layers = []

          layers.append(nn.Linear(input_dim, hidden_dim))
          layers.append(nn.LayerNorm(hidden_dim))
          layers.append(nn.ReLU())
          layers.append(nn.Dropout(dropout_prob)) # <--- UNCOMMENT THIS
          for _ in range(num_layers - 1):
              layers.append(nn.Linear(hidden_dim, hidden_dim))
              layers.append(nn.LayerNorm(hidden_dim))
              layers.append(nn.ReLU())
              layers.append(nn.Dropout(dropout_prob)) # <--- AND THIS
          layers.append(nn.Linear(hidden_dim, latent_dim))
          self.net = nn.Sequential(*layers)

    def forward(self, S_T):
        if S_T.dim() == 1:
          S_T = S_T.unsqueeze(-1)  # (batch_size,) → (batch_size, 1)
        elif S_T.dim() == 2:
          S_T = S_T.unsqueeze(-1)
        result = self.net(S_T)
        return result      # [B, N, latent_dim]


class OptimizedSetTransformerEncoder(nn.Module):
    """Set Transformer using PyTorch's built-in TransformerEncoder"""
    def __init__(self, portfolio_feature_dim=3, latent_dim=128, hidden_dim=64,
                 num_layers=1, num_heads=2, dropout_prob=0.1):
        super().__init__()

        # Ensure hidden_dim is divisible by num_heads
        if hidden_dim % num_heads != 0:
            hidden_dim = ((hidden_dim // num_heads) + 1) * num_heads

        self.hidden_dim = hidden_dim

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(portfolio_feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),  #
            nn.Dropout(dropout_prob * 0.5)
        )

        # PyTorch's built-in TransformerEncoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * FEED_FWD_DEPTH,
            dropout=dropout_prob,
            activation='gelu',
            batch_first=True,  # Important: input shape is [batch, seq, feature]
            norm_first=True    # Pre-norm (more stable)
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False  # For compatibility
        )

        # Pooling and output projection
        self.min_max_pool = MinMaxPooling(hidden_dim)

        self.output_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout_prob * 0.5),  # Additional dropout
            nn.Linear(hidden_dim, latent_dim)
        )

        # Custom weight initialization removed - using PyTorch defaults

    def forward(self, portfolio):
        """
        Args:
            portfolio: [batch_size, portfolio_size, portfolio_feature_dim]
        Returns:
            [batch_size, latent_dim]
        """
        x = self.input_proj(portfolio)  # [B, P, H]

        # PyTorch transformer expects [batch, seq, feature] with batch_first=True
        x = self.transformer_encoder(x)  # [B, P, H]

        # Pool across sequence dimension
        x = torch.mean(x, dim=1)   # [B, H]

        #x = self.min_max_pool(x)  # [B, H]

        # Final projection
        out = self.output_proj(x)  # [B, latent_dim]

        return out

class OptimizedDeepONet(nn.Module):
    """DeepONet with choice of Set Transformer architectures"""
    def __init__(self, portfolio_feature_dim=3, hidden_dim=64, latent_dim=128,
                 dropout_prob=0.1, use_induced=False, num_inducing=16, num_heads=2):
        super().__init__()

        if hidden_dim % num_heads != 0:
          recommended = ((hidden_dim // num_heads) + 1) * num_heads
          raise ValueError(
              f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads}). "
              f"Try hidden_dim={recommended}"
          )

        self.latent_dim = latent_dim

        # Choose branch network architecture
        if use_induced:
            self.branch_net = InducedSetTransformerEncoder(
                portfolio_feature_dim=portfolio_feature_dim,
                latent_dim=latent_dim,
                hidden_dim=hidden_dim,
                num_inducing=num_inducing,
                dropout_prob=dropout_prob
            )
        else:
            self.branch_net = OptimizedSetTransformerEncoder(
                portfolio_feature_dim=portfolio_feature_dim,
                latent_dim=latent_dim,
                hidden_dim=hidden_dim,
                dropout_prob=dropout_prob
            )

        # Trunk network
        self.trunk_net = TrunkNet(
            input_dim=1,
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            dropout_prob=dropout_prob
        )

        # DeepONet parameters
        self.bias = nn.Parameter(torch.zeros(1))
        self.branch_scale = nn.Parameter(torch.ones(1) * 0.8)
        self.trunk_scale = nn.Parameter(torch.ones(1) * 0.8)

    def forward(self, portfolio, S_T):
        """
        Args:
            portfolio: [batch_size, portfolio_size, 3] - portfolio features
            S_T: [batch_size, num_S_T_samples] - multiple S_T values per portfolio

        Returns:
            cashflows: [batch_size, num_S_T_samples] - predicted cashflows for each S_T
        """
        # Branch network: encode portfolio
        branch_out = self.branch_net(portfolio)  # [batch_size, latent_dim]
        branch_out = branch_out * self.branch_scale

        # Trunk network: process S_T values
        trunk_out = self.trunk_net(S_T)  # [batch_size, num_S_T_samples, latent_dim]
        trunk_out = trunk_out * self.trunk_scale

        # Compute dot product: branch ⊗ trunk
        branch_expanded = branch_out.unsqueeze(1)  # [batch_size, 1, latent_dim]


        interaction = (branch_expanded * trunk_out).sum(dim=-1)  # [batch_size, num_S_T_samples]

        # Add bias
        cashflows = interaction + self.bias
        return cashflows

# Trainers

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.amp import autocast, GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.amp import autocast, GradScaler

# REMOVE these global constants, they will be passed during Trainer initialization
# LEARNING_RATE = 5e-6
# LAMBDA_DERIV = 0.1

# ===== DERIVATIVE RESCALING =====
def rescale_derivative_autograd(pred_derivative_from_normalized_input, S_T_scalar_normalizer):
    data_range = S_T_scalar_normalizer.data_max_[0] - S_T_scalar_normalizer.data_min_[0]

    if data_range < 1e-6:
        print(f"WARNING: S_T data_range is extremely small ({data_range}). Derivative might be unstable.")
        return torch.zeros_like(pred_derivative_from_normalized_input)

    # ✅ Multiply to rescale back to original S_T scale
    rescaled_derivative = pred_derivative_from_normalized_input * data_range
    return rescaled_derivative


class ExtendedEarlyStopping:
    # ... (no changes needed here) ...
    def __init__(self, patience=30, min_delta=0.0005, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.wait = 0
        self.stopped_epoch = 0
        self.best = float('inf')
        self.best_weights = None

    def __call__(self, val_loss, model=None):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
            if model is not None and self.restore_best_weights:
                self.best_weights = model.state_dict().copy()
        else:
            self.wait += 1

        if self.wait >= self.patience:
            self.stopped_epoch = True
            if model is not None and self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)

        return self.stopped_epoch

class OptimizedTrainer:
    def __init__(self, model, device='cuda', monitor_gradients=True,
                 scale_warmup_epochs=5, initial_scale=0.05, final_scale=1.0,
                 learning_rate=5e-6, lambda_deriv_weight=0.01, weight_decay=1e-4): # ADDED: learning_rate, lambda_deriv_weight, weight_decay for consistency
        self.model = model.to(device)
        self.device = device
        self.monitor_gradients = monitor_gradients
        self.lambda_deriv_weight = lambda_deriv_weight # Store this for compute_loss

        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate, # Use the passed learning_rate
            weight_decay=weight_decay, # Use the passed weight_decay (or make it a constant here if not configurable)
            betas=(0.9, 0.999),
            eps=1e-8
        )

        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=100, T_mult=2, eta_min=1e-6 # T_0 might also be a hyperparameter
        )

        self.scaler = GradScaler()
        self.S_T_scalar = joblib.load(ST_SCALAR_FILE)

        self.mse_loss = nn.MSELoss()
        self.huber_loss = nn.SmoothL1Loss(beta=1.0)

        # Scale warmup parameters
        self.scale_warmup_epochs = scale_warmup_epochs
        self.initial_scale = initial_scale
        self.final_scale = final_scale

        # Initialize model scales
        if hasattr(self.model, 'branch_scale') and hasattr(self.model, 'trunk_scale'):
            with torch.no_grad():
                self.model.branch_scale.fill_(initial_scale)
                self.model.trunk_scale.fill_(initial_scale)

    def check_model_health(self, epoch, batch_idx):
        """Check model parameters for NaN/Inf values"""
        for name, param in self.model.named_parameters():
            if torch.isnan(param).any():
                print(f"🚨 NaN parameter found: {name} at Epoch {epoch}, Batch {batch_idx}")
                return False
            if torch.isinf(param).any():
                print(f"🚨 Inf parameter found: {name} at Epoch {epoch}, Batch {batch_idx}")
                return False
        return True

    def compute_loss(self, pred_cashflow, true_cashflow, pred_deriv, true_deriv,
                     lambda_reg=1e-4): # REMOVED lambda_deriv parameter, use self.lambda_deriv_weight
        cashflow_loss = self.huber_loss(pred_cashflow, true_cashflow)

        if true_deriv is not None:
            pred_deriv_rescaled = rescale_derivative_autograd(pred_deriv, self.S_T_scalar)
            deriv_loss = self.huber_loss(pred_deriv_rescaled, true_deriv)
            total_loss  =  cashflow_loss + self.lambda_deriv_weight * deriv_loss # Use the stored weight
        else:
             deriv_loss = torch.tensor(0.0, device=pred_cashflow.device)

        # l2_reg = sum(p.pow(2).sum() for p in self.model.parameters())
        # total_loss += lambda_reg * l2_reg

        return total_loss, cashflow_loss, deriv_loss



    def train_step(self, portfolio, S_T, cashflow, true_derivative,
                  epoch=0, batch_idx=0, experiment=None):

        # Check model health before training step
        if not self.check_model_health(epoch, batch_idx):
            return float('inf'), float('inf'), float('inf')

        self.optimizer.zero_grad()

        # Enable gradients for S_T
        S_T = S_T.clone().detach().requires_grad_(True)

        try:
            #with autocast(device_type=self.device.type):
            pred_cashflow = self.model(portfolio, S_T)

            # Check prediction health
            if torch.isnan(pred_cashflow).any() or torch.isinf(pred_cashflow).any():
                print(f"🚨 Invalid predictions at Epoch {epoch}, Batch {batch_idx}")
                return float('inf'), float('inf'), float('inf')

            # Compute derivatives with error handling
            pred_deriv_from_autograd = None
            if true_derivative is not None:
                try:
                    pred_deriv_from_autograd = torch.autograd.grad(
                        outputs=pred_cashflow.sum(),  # Sum to get scalar output
                        inputs=S_T,
                        retain_graph=True,
                        create_graph=True,
                        allow_unused=True
                    )[0]

                    # Check derivative health
                    if pred_deriv_from_autograd is not None:
                        if torch.isnan(pred_deriv_from_autograd).any() or torch.isinf(pred_deriv_from_autograd).any():
                            print(f"🚨 Invalid derivatives at Epoch {epoch}, Batch {batch_idx}")
                            pred_deriv_from_autograd = None

                except RuntimeError as e:
                    print(f"⚠️ Derivative computation failed: {e}")
                    pred_deriv_from_autograd = None

            total_loss, cashflow_loss, deriv_loss = self.compute_loss(
                pred_cashflow, cashflow, pred_deriv_from_autograd, true_derivative
            )

        except RuntimeError as e:
            print(f"⚠️ Forward pass failed at Epoch {epoch}, Batch {batch_idx}: {e}")
            return float('inf'), float('inf'), float('inf')

        # Check loss health
        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print(f"⚠️ Invalid total loss at Epoch {epoch}, Batch {batch_idx}")
            return float('inf'), float('inf'), float('inf')

        # Scale and backward pass
        scaled_loss = self.scaler.scale(total_loss)
        scaled_loss.backward()

        # Unscale gradients for inspection
        self.scaler.unscale_(self.optimizer)

        # Check gradients for NaN/Inf
        gradient_health = True
        max_grad_norm = 0.0
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                max_grad_norm = max(max_grad_norm, grad_norm)

                if torch.isnan(param.grad).any():
                    print(f"🚨 NaN gradient found: {name} at Epoch {epoch}, Batch {batch_idx}")
                    gradient_health = False
                if torch.isinf(param.grad).any():
                    print(f"🚨 Inf gradient found: {name} at Epoch {epoch}, Batch {batch_idx}")
                    gradient_health = False

        if not gradient_health or max_grad_norm > 50.0:
            print(f"🛑 Unhealthy gradients detected. Max norm: {max_grad_norm:.4f}")
            self.scaler.update()  # Update scaler even on failure
            return float('inf'), float('inf'), float('inf')

        # Gradient clipping with more conservative threshold
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1)

        # Gradient monitoring
        if self.monitor_gradients and (batch_idx % 500 == 0 or max_grad_norm > 1.0):
            total_norm, gradient_stats, param_count = compute_gradient_stats(self.model)
            print_gradient_summary(gradient_stats, total_norm, epoch, batch_idx)

        # Optimizer step
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()

        return total_loss.item(), cashflow_loss.item(), deriv_loss.item()


    def val_step(self, portfolio, S_T, cashflow, derivative): # derivative is the true derivative
        self.model.eval()

        # Ensure S_T requires gradients for derivative calculation in validation
        S_T.requires_grad_(True) # Temporarily enable for this specific calculation

        with torch.no_grad(): # Outer no_grad block for validation
            # Use torch.enable_grad() only around the autograd.grad call if necessary
            # For validation, we don't need to retain graph or create graph.
            # The model forward is within no_grad, so its parameters won't get gradients.
            # Only S_T will have its gradient computed.
            with torch.enable_grad(): # Re-enable graph building for derivative computation
                pred_cashflow = self.model(portfolio, S_T)
                pred_deriv_from_autograd = torch.autograd.grad(
                    outputs=pred_cashflow,
                    inputs=S_T,
                    grad_outputs=torch.ones_like(pred_cashflow),
                    retain_graph=False, # No need to retain graph in val
                    create_graph=False  # No need for higher-order in val
                )[0]

            val_total_loss, val_cashflow_loss, val_deriv_loss = self.compute_loss(
                pred_cashflow, cashflow, pred_deriv_from_autograd, derivative
            )

        # Reset requires_grad for S_T after use
        if S_T.requires_grad:
            S_T.requires_grad_(False)

        return val_total_loss.item(), val_cashflow_loss.item(), val_deriv_loss.item()


    def update_scale(self, current_epoch):
        # ... (no changes needed here) ...
        if hasattr(self.model, 'branch_scale') and hasattr(self.model, 'trunk_scale'):
            # Linear warmup from initial_scale to final_scale over scale_warmup_epochs
            if current_epoch < self.scale_warmup_epochs:
                factor = (current_epoch + 1) / self.scale_warmup_epochs
                new_scale = self.initial_scale + (self.final_scale - self.initial_scale) * factor
            else:
                new_scale = self.final_scale

            with torch.no_grad():
                self.model.branch_scale.fill_(new_scale)
                self.model.trunk_scale.fill_(new_scale)

            print(f"🔧 [Epoch {current_epoch}] Updated branch/trunk scale → {new_scale:.4f}")

# ... (compute_gradient_stats, print_gradient_summary, get_stable_hyperparameters remain the same) ...
# Ensure get_stable_hyperparameters defines learning_rate and lambda_deriv as it will be used directly.
def get_stable_hyperparameters():
    """Return more stable hyperparameters"""
    return {
        "learning_rate": 1e-5,
        "weight_decay": 5e-4,
        "lambda_deriv": 0.1,
        "lambda_reg": 1e-4,
        "gradient_clip_norm": 1,
        "batch_size": 32,
        "scheduler_T0": 200,
        "early_stopping_patience": 50,
    }


def compute_gradient_stats(model):
    """Compute gradient statistics - fixed for single-element tensors"""
    total_norm = 0
    param_count = 0
    gradient_stats = {}

    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = param.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
            param_count += 1

            grad_data = param.grad.data
            if grad_data.numel() > 1:
                grad_std = grad_data.std(unbiased=False).item()
            else:
                grad_std = 0.0 # Handle single-element tensors

            gradient_stats[name] = {
                'norm': param_norm.item(),
                'mean': grad_data.mean().item(),
                'std': grad_std,
                'max': grad_data.max().item(),
                'min': grad_data.min().item(),
                'shape': list(param.grad.shape),
                'numel': grad_data.numel()
            }

    total_norm = total_norm ** (1. / 2)
    return total_norm, gradient_stats, param_count

def print_gradient_summary(gradient_stats, total_norm, epoch, batch_idx=None):
    """Enhanced gradient summary with more details"""
    prefix = f"Epoch {epoch}" + (f", Batch {batch_idx}" if batch_idx is not None else "")
    print(f"\n🔍 === Gradient Analysis - {prefix} ===")
    print(f"Total Gradient Norm: {total_norm:.6f}")

    if total_norm > 20.0:
        print("🚨 CRITICAL: Severe gradient explosion! Consider stopping training.")
    elif total_norm > 10.0:
        print("⚠️  SEVERE: Major gradient explosion detected!")
    elif total_norm > 5.0:
        print("⚠️  WARNING: Moderate gradient explosion detected!")
    elif total_norm < 1e-6:
        print("⚠️  WARNING: Vanishing gradients detected!")
    else:
        print("✅ Gradient norm is healthy")

    sorted_layers = sorted(gradient_stats.items(), key=lambda x: x[1]['norm'], reverse=True)
    print(f"\nTop 5 layers by gradient norm (out of {len(gradient_stats)} total):")
    for i, (layer_name, stats) in enumerate(sorted_layers[:5]):
        status = "🔥" if stats['norm'] > 3.0 else "⚠️" if stats['norm'] > 1.0 else "✅"
        print(f"  {status} {i+1}. {layer_name}: {stats['norm']:.4f}")
        print(f"      Shape: {stats['shape']}, Elements: {stats['numel']}")
        print(f"      Mean: {stats['mean']:.6f}, Std: {stats['std']:.6f}")

    print("=" * 60)

# Start Training

In [15]:

experiment.end()

NameError: name 'experiment' is not defined

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from comet_ml import start
from comet_ml.integration.pytorch import log_model
from tqdm import tqdm

# Assuming OptimizedDeepONet, OptimizedTrainer, ExtendedEarlyStopping, OperatorDatasetStandarization
# and other utility functions (compute_gradient_stats, print_gradient_summary, get_stable_hyperparameters)
# are available in the scope or imported.

# Hyperparameters (these will now be primarily driven by get_stable_hyperparameters)
hidden_dim = 32
latent_dim = 32
batch_size = 32
epochs = 1000
portfolio_feature_dim = 3

PORT_LEN = 200
PORT_SAMPLE_SIZE = 51200
FEED_ST_LEN_EACH_PORT = 50
SAMPLE_SIZE_SCALAR = 20
# REMOVED: LAMBDA_DERIV = 0.1 (it's now part of hyperparameters)


# Comet Experiment Setup
experiment = start(api_key="iatWnXT4JyBtDQhn7OfgISQoF", project_name="option-portfolio-encoder-decoder", workspace="satyabratkumarsingh")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    # Retrieve hyperparameters
    hparams = get_stable_hyperparameters()
    experiment.log_parameters(hparams) # Log all hparams to Comet

    # Initialize model
    deeponet_model = OptimizedDeepONet(portfolio_feature_dim=portfolio_feature_dim, hidden_dim=hidden_dim, latent_dim=latent_dim, use_induced=False).to(DEVICE)

    # Initialize optimized trainer with explicit hyperparameters
    trainer = OptimizedTrainer(
        deeponet_model,
        device=DEVICE,
        learning_rate=hparams["learning_rate"],
        lambda_deriv_weight=hparams["lambda_deriv"],
        weight_decay=hparams["weight_decay"], # Pass weight_decay as well for consistency
        # You can add other configurable hparams from get_stable_hyperparameters here
    )

    # Assuming dataset has been created
    dataset = OperatorDatasetStandarization(num_samples=PORT_SAMPLE_SIZE, portfolio_size=PORT_LEN, num_samples_S_T=FEED_ST_LEN_EACH_PORT, is_fitting_mode=False)

    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    val_size = (val_size // batch_size) * batch_size  # Round to nearest multiple of batch_size
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create DataLoader for both train and validation datasets
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Early stopping and learning rate scheduler setup
    best_val_total_loss = float('inf')  # Track best validation loss for early stopping
    early_stopping = ExtendedEarlyStopping(patience=hparams["early_stopping_patience"], min_delta=0.0005) # Use hparams patience
    # REMOVED: lambda_deriv = LAMBDA_DERIV # No longer needed, trainer uses its internal stored value
    plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(trainer.optimizer, mode='min', factor=0.7, patience=10, verbose=True)


    # Training loop
    for epoch in range(epochs):
        train_epoch_total_losses = []
        train_epoch_cashflow_losses = []
        train_epoch_deriv_losses = []

        trainer.update_scale(epoch)
        trainer.model.train()

        for batch_idx, (portfolio_real, s_t_real, cashflows_real, derivs_real) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch}", leave=False)):
            portfolio_real = portfolio_real.to(torch.float32).to(DEVICE)
            cashflows_real = cashflows_real.to(torch.float32).to(DEVICE)
            s_t_real = s_t_real.to(torch.float32).to(DEVICE)
            derivs_real = derivs_real.to(torch.float32).to(DEVICE)


            # Get all loss components from train_step
            total_loss, cashflow_loss, deriv_loss = trainer.train_step(
                portfolio=portfolio_real,
                S_T=s_t_real,
                cashflow=cashflows_real,
                true_derivative=derivs_real, # This is the true derivative for loss comparison
                batch_idx=batch_idx,
                experiment=experiment
            )
            train_epoch_total_losses.append(total_loss)
            train_epoch_cashflow_losses.append(cashflow_loss)
            train_epoch_deriv_losses.append(deriv_loss)

            # Log individual batch training losses
            experiment.log_metric("training_total_loss_batch", total_loss, step=epoch * len(train_loader) + batch_idx)
            experiment.log_metric("training_cashflow_loss_batch", cashflow_loss, step=epoch * len(train_loader) + batch_idx)
            experiment.log_metric("training_deriv_loss_batch", deriv_loss, step=epoch * len(train_loader) + batch_idx)

        # Average epoch losses for training
        avg_train_total_loss = np.mean(train_epoch_total_losses)
        avg_train_cashflow_loss = np.mean(train_epoch_cashflow_losses)
        avg_train_deriv_loss = np.mean(train_epoch_deriv_losses)

        # Log average training losses for the epoch
        experiment.log_metric("training_total_loss_epoch", avg_train_total_loss, epoch=epoch)
        experiment.log_metric("training_cashflow_loss_epoch", avg_train_cashflow_loss, epoch=epoch)
        experiment.log_metric("training_deriv_loss_epoch", avg_train_deriv_loss, epoch=epoch)

        # Validation loop
        trainer.model.eval()
        val_epoch_total_losses = []
        val_epoch_cashflow_losses = []
        val_epoch_deriv_losses = []

        with torch.no_grad(): # Outer no_grad is for overall validation efficiency
            for batch_idx, (portfolio_real, s_t_real, cashflows_real, derivs_real) in enumerate(tqdm(val_loader, desc=f"Validation Epoch {epoch}", leave=False)):

                portfolio_real = portfolio_real.to(torch.float32).to(DEVICE)
                cashflows_real = cashflows_real.to(torch.float32).to(DEVICE)
                s_t_real = s_t_real.to(torch.float32).to(DEVICE) # Ensure S_T is float32
                derivs_real = derivs_real.to(torch.float32).to(DEVICE)

                # val_step now implicitly calculates the derivative
                val_total_loss, val_cashflow_loss, val_deriv_loss = trainer.val_step(
                    portfolio=portfolio_real,
                    S_T=s_t_real,
                    cashflow=cashflows_real,
                    derivative=derivs_real, # This is the true derivative for loss comparison
                )


                val_epoch_total_losses.append(val_total_loss)
                val_epoch_cashflow_losses.append(val_cashflow_loss)
                val_epoch_deriv_losses.append(val_deriv_loss)

        # Average validation losses for the epoch
        avg_val_total_loss = np.mean(val_epoch_total_losses)
        avg_val_cashflow_loss = np.mean(val_epoch_cashflow_losses)
        avg_val_deriv_loss = np.mean(val_epoch_deriv_losses)

        # Log validation losses for the epoch
        experiment.log_metric("validation_total_loss_epoch", avg_val_total_loss, epoch=epoch)
        experiment.log_metric("validation_cashflow_loss_epoch", avg_val_cashflow_loss, epoch=epoch)
        experiment.log_metric("validation_deriv_loss_epoch", avg_val_deriv_loss, epoch=epoch)

        # Step the plateau scheduler based on validation total loss
        plateau_scheduler.step(avg_val_total_loss)

        # Progress reporting
        if epoch % 10 == 0:
            current_lr = trainer.optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch}/{epochs}], Train Total Loss: {avg_train_total_loss:.6f}, Val Total Loss: {avg_val_total_loss:.6f}, '
                  f'Train CF Loss: {avg_train_cashflow_loss:.6f}, Val CF Loss: {avg_val_cashflow_loss:.6f}, '
                  f'Train Deriv Loss: {avg_train_deriv_loss:.6f}, Val Deriv Loss: {avg_val_deriv_loss:.6f}')
            print(f'Current Learning Rate: {current_lr:.8f}')

        # Save checkpoint every 50 epochs
        if epoch % 50 == 0 and epoch > 0:
            checkpoint_path = f'/content/drive/MyDrive/Ucl/checkpoint_epoch_{epoch}_V2.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': trainer.model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'scheduler_state_dict': trainer.scheduler.state_dict(),
                'train_total_loss': avg_train_total_loss,
                'val_total_loss': avg_val_total_loss,
                'best_val_loss': best_val_total_loss
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch}")

        # Early stopping check (using total loss)
        if early_stopping(avg_val_total_loss, trainer.model):
            print(f'Early stopping triggered at epoch {epoch}')
            save_path = '/content/drive/MyDrive/Ucl/best_deeponet_model_V2.pt'
            print(f"Saving model to drive: {save_path}")
            torch.save(trainer.model.state_dict(), save_path)
            print(f"Model saved to: {save_path}")
            break

        if avg_val_total_loss < best_val_total_loss:
            best_val_total_loss = avg_val_total_loss

    # Save final model even if early stopping doesn't trigger
    final_save_path = '/content/drive/MyDrive/Ucl/final_deeponet_model_V2.pt'
    torch.save(trainer.model.state_dict(), final_save_path)
    print(f"Final model saved to: {final_save_path}")

    # End the Comet experiment
    experiment.end()

if __name__ == "__main__":
    main()

Training Epoch 29:  47%|████▋     | 599/1280 [01:27<01:39,  6.87it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 597 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5943
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.010723
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4820
      Shape: [32, 32], Elements: 1024
      Mean: 0.000006, Std: 0.015062
  ✅ 3. branch_net.output_proj.2.weight: 0.3200
      Shape: [32, 32], Elements: 1024
      Mean: 0.000005, Std: 0.010001
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2539
      Shape: [32, 64], Elements: 2048
      Mean: -0.000008, Std: 0.005610
  ✅ 5. trunk_net.net.16.weight: 0.2350
      Shape: [32, 32], Elements: 1024
      Mean: 0.000413, Std: 0.007332

🔍 === Gradient Analysis - Epoch 0, Batch 598 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  47%|████▋     | 601/1280 [01:27<01:37,  7.00it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 599 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5923
      Shape: [96, 32], Elements: 3072
      Mean: 0.000003, Std: 0.010686
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4527
      Shape: [32, 32], Elements: 1024
      Mean: -0.000002, Std: 0.014148
  ✅ 3. branch_net.output_proj.2.weight: 0.3440
      Shape: [32, 32], Elements: 1024
      Mean: 0.000025, Std: 0.010749
  ✅ 4. trunk_net.net.16.weight: 0.2760
      Shape: [32, 32], Elements: 1024
      Mean: 0.000194, Std: 0.008622
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2259
      Shape: [32, 64], Elements: 2048
      Mean: 0.000013, Std: 0.004991


Training Epoch 29:  47%|████▋     | 603/1280 [01:27<01:37,  6.91it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 601 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5166
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009321
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4475
      Shape: [32, 32], Elements: 1024
      Mean: -0.000060, Std: 0.013984
  ✅ 3. branch_net.output_proj.2.weight: 0.3756
      Shape: [32, 32], Elements: 1024
      Mean: 0.000068, Std: 0.011738
  ✅ 4. trunk_net.net.16.weight: 0.3024
      Shape: [32, 32], Elements: 1024
      Mean: -0.001340, Std: 0.009356
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2629
      Shape: [32, 64], Elements: 2048
      Mean: 0.000049, Std: 0.005809

🔍 === Gradient Analysis - Epoch 0, Batch 602 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  47%|████▋     | 605/1280 [01:28<01:37,  6.90it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 603 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5699
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.010283
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4451
      Shape: [32, 32], Elements: 1024
      Mean: 0.000013, Std: 0.013911
  ✅ 3. branch_net.output_proj.2.weight: 0.3221
      Shape: [32, 32], Elements: 1024
      Mean: -0.000030, Std: 0.010064
  ✅ 4. trunk_net.net.16.weight: 0.2589
      Shape: [32, 32], Elements: 1024
      Mean: -0.001020, Std: 0.008027
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2375
      Shape: [32, 64], Elements: 2048
      Mean: -0.000018, Std: 0.005249

🔍 === Gradient Analysis - Epoch 0, Batch 604 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  47%|████▋     | 607/1280 [01:28<01:38,  6.84it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 605 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5932
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.010703
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4521
      Shape: [32, 32], Elements: 1024
      Mean: -0.000002, Std: 0.014129
  ✅ 3. branch_net.output_proj.2.weight: 0.3430
      Shape: [32, 32], Elements: 1024
      Mean: -0.000032, Std: 0.010720
  ✅ 4. trunk_net.net.16.weight: 0.2622
      Shape: [32, 32], Elements: 1024
      Mean: -0.000194, Std: 0.008192
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2426
      Shape: [32, 64], Elements: 2048
      Mean: 0.000011, Std: 0.005361

🔍 === Gradient Analysis - Epoch 0, Batch 606 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  48%|████▊     | 610/1280 [01:29<01:40,  6.67it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 608 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5723
      Shape: [96, 32], Elements: 3072
      Mean: 0.000003, Std: 0.010325
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4573
      Shape: [32, 32], Elements: 1024
      Mean: -0.000025, Std: 0.014291
  ✅ 3. branch_net.output_proj.2.weight: 0.3768
      Shape: [32, 32], Elements: 1024
      Mean: -0.000016, Std: 0.011775
  ✅ 4. trunk_net.net.16.weight: 0.2872
      Shape: [32, 32], Elements: 1024
      Mean: -0.000222, Std: 0.008973
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2396
      Shape: [32, 64], Elements: 2048
      Mean: 0.000000, Std: 0.005295

🔍 === Gradient Analysis - Epoch 0, Batch 609 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  48%|████▊     | 613/1280 [01:29<01:35,  6.98it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 611 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5718
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.010317
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4661
      Shape: [32, 32], Elements: 1024
      Mean: 0.000013, Std: 0.014567
  ✅ 3. branch_net.output_proj.2.weight: 0.3651
      Shape: [32, 32], Elements: 1024
      Mean: 0.000010, Std: 0.011409
  ✅ 4. trunk_net.net.16.weight: 0.2752
      Shape: [32, 32], Elements: 1024
      Mean: -0.000064, Std: 0.008599
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2549
      Shape: [32, 64], Elements: 2048
      Mean: 0.000015, Std: 0.005632

🔍 === Gradient Analysis - Epoch 0, Batch 612 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  48%|████▊     | 615/1280 [01:29<01:36,  6.91it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 613 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4972
      Shape: [32, 32], Elements: 1024
      Mean: 0.000025, Std: 0.015536
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4926
      Shape: [96, 32], Elements: 3072
      Mean: 0.000001, Std: 0.008887
  ✅ 3. branch_net.output_proj.2.weight: 0.3779
      Shape: [32, 32], Elements: 1024
      Mean: -0.000084, Std: 0.011809
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2804
      Shape: [32, 64], Elements: 2048
      Mean: 0.000025, Std: 0.006196
  ✅ 5. trunk_net.net.16.weight: 0.2713
      Shape: [32, 32], Elements: 1024
      Mean: -0.000966, Std: 0.008422

🔍 === Gradient Analysis - Epoch 0, Batch 614 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  48%|████▊     | 617/1280 [01:30<01:35,  6.93it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 615 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5104
      Shape: [96, 32], Elements: 3072
      Mean: 0.000001, Std: 0.009209
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4416
      Shape: [32, 32], Elements: 1024
      Mean: -0.000005, Std: 0.013801
  ✅ 3. branch_net.output_proj.2.weight: 0.3982
      Shape: [32, 32], Elements: 1024
      Mean: 0.000012, Std: 0.012445
  ✅ 4. trunk_net.net.16.weight: 0.2891
      Shape: [32, 32], Elements: 1024
      Mean: 0.001477, Std: 0.008913
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2550
      Shape: [32, 64], Elements: 2048
      Mean: 0.000011, Std: 0.005636

🔍 === Gradient Analysis - Epoch 0, Batch 616 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total)

Training Epoch 29:  48%|████▊     | 620/1280 [01:30<01:34,  6.96it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 618 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.5095
      Shape: [32, 32], Elements: 1024
      Mean: -0.000006, Std: 0.015921
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5020
      Shape: [96, 32], Elements: 3072
      Mean: -0.000000, Std: 0.009057
  ✅ 3. branch_net.output_proj.2.weight: 0.2743
      Shape: [32, 32], Elements: 1024
      Mean: 0.000031, Std: 0.008573
  ✅ 4. branch_net.input_proj.0.bias: 0.2722
      Shape: [32], Elements: 32
      Mean: 0.000000, Std: 0.048119
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2674
      Shape: [32, 64], Elements: 2048
      Mean: 0.000004, Std: 0.005909

🔍 === Gradient Analysis - Epoch 0, Batch 619 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total)

Training Epoch 29:  49%|████▊     | 622/1280 [01:30<01:37,  6.72it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 620 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4976
      Shape: [32, 32], Elements: 1024
      Mean: 0.000009, Std: 0.015549
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4965
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.008958
  ✅ 3. branch_net.output_proj.2.weight: 0.3736
      Shape: [32, 32], Elements: 1024
      Mean: -0.000081, Std: 0.011674
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2762
      Shape: [32, 64], Elements: 2048
      Mean: 0.000019, Std: 0.006102
  ✅ 5. trunk_net.net.16.weight: 0.2629
      Shape: [32, 32], Elements: 1024
      Mean: -0.000797, Std: 0.008178

🔍 === Gradient Analysis - Epoch 0, Batch 621 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  49%|████▉     | 624/1280 [01:31<01:36,  6.80it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 622 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4751
      Shape: [32, 32], Elements: 1024
      Mean: -0.000011, Std: 0.014847
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4685
      Shape: [96, 32], Elements: 3072
      Mean: -0.000000, Std: 0.008452
  ✅ 3. branch_net.output_proj.2.weight: 0.3747
      Shape: [32, 32], Elements: 1024
      Mean: -0.000007, Std: 0.011709
  ✅ 4. trunk_net.net.16.weight: 0.2694
      Shape: [32, 32], Elements: 1024
      Mean: 0.001347, Std: 0.008309
  ✅ 5. branch_net.input_proj.0.bias: 0.2606
      Shape: [32], Elements: 32
      Mean: 0.000000, Std: 0.046077

🔍 === Gradient Analysis - Epoch 0, Batch 623 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transforme

Training Epoch 29:  49%|████▉     | 626/1280 [01:31<01:36,  6.78it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 624 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5391
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009727
  ✅ 2. branch_net.output_proj.2.weight: 0.4236
      Shape: [32, 32], Elements: 1024
      Mean: 0.000020, Std: 0.013236
  ✅ 3. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4133
      Shape: [32, 32], Elements: 1024
      Mean: 0.000004, Std: 0.012915
  ✅ 4. trunk_net.net.16.weight: 0.3317
      Shape: [32, 32], Elements: 1024
      Mean: -0.000905, Std: 0.010327
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2431
      Shape: [32, 64], Elements: 2048
      Mean: -0.000007, Std: 0.005371


Training Epoch 29:  49%|████▉     | 628/1280 [01:31<01:34,  6.88it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 626 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5029
      Shape: [96, 32], Elements: 3072
      Mean: -0.000001, Std: 0.009074
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.5015
      Shape: [32, 32], Elements: 1024
      Mean: 0.000027, Std: 0.015673
  ✅ 3. branch_net.output_proj.2.weight: 0.3643
      Shape: [32, 32], Elements: 1024
      Mean: 0.000082, Std: 0.011383
  ✅ 4. trunk_net.net.16.weight: 0.2743
      Shape: [32, 32], Elements: 1024
      Mean: 0.000548, Std: 0.008555
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2624
      Shape: [32, 64], Elements: 2048
      Mean: 0.000002, Std: 0.005798

🔍 === Gradient Analysis - Epoch 0, Batch 627 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total)

Training Epoch 29:  49%|████▉     | 630/1280 [01:31<01:33,  6.98it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 628 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5586
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.010078
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4663
      Shape: [32, 32], Elements: 1024
      Mean: -0.000023, Std: 0.014571
  ✅ 3. branch_net.output_proj.2.weight: 0.3631
      Shape: [32, 32], Elements: 1024
      Mean: 0.000060, Std: 0.011347
  ✅ 4. trunk_net.net.16.weight: 0.2835
      Shape: [32, 32], Elements: 1024
      Mean: 0.000852, Std: 0.008820
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2610
      Shape: [32, 64], Elements: 2048
      Mean: 0.000029, Std: 0.005767


Training Epoch 29:  49%|████▉     | 632/1280 [01:32<01:32,  7.00it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 630 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5504
      Shape: [96, 32], Elements: 3072
      Mean: 0.000003, Std: 0.009931
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4634
      Shape: [32, 32], Elements: 1024
      Mean: -0.000004, Std: 0.014482
  ✅ 3. branch_net.output_proj.2.weight: 0.3949
      Shape: [32, 32], Elements: 1024
      Mean: -0.000002, Std: 0.012339
  ✅ 4. trunk_net.net.16.weight: 0.3023
      Shape: [32, 32], Elements: 1024
      Mean: -0.000336, Std: 0.009441
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2367
      Shape: [32, 64], Elements: 2048
      Mean: 0.000014, Std: 0.005230

🔍 === Gradient Analysis - Epoch 0, Batch 631 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  50%|████▉     | 635/1280 [01:32<01:34,  6.83it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 633 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5177
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.009341
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4747
      Shape: [32, 32], Elements: 1024
      Mean: -0.000026, Std: 0.014834
  ✅ 3. branch_net.output_proj.2.weight: 0.3799
      Shape: [32, 32], Elements: 1024
      Mean: 0.000015, Std: 0.011871
  ✅ 4. trunk_net.net.16.weight: 0.3017
      Shape: [32, 32], Elements: 1024
      Mean: 0.000080, Std: 0.009429
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2685
      Shape: [32, 64], Elements: 2048
      Mean: 0.000025, Std: 0.005933

🔍 === Gradient Analysis - Epoch 0, Batch 634 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total)

Training Epoch 29:  50%|████▉     | 637/1280 [01:32<01:33,  6.89it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 635 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5161
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009312
  ✅ 2. branch_net.output_proj.2.weight: 0.4877
      Shape: [32, 32], Elements: 1024
      Mean: -0.000039, Std: 0.015240
  ✅ 3. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.3946
      Shape: [32, 32], Elements: 1024
      Mean: -0.000011, Std: 0.012331
  ✅ 4. trunk_net.net.16.weight: 0.3520
      Shape: [32, 32], Elements: 1024
      Mean: 0.000014, Std: 0.011001
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2639
      Shape: [32, 64], Elements: 2048
      Mean: 0.000001, Std: 0.005832

🔍 === Gradient Analysis - Epoch 0, Batch 636 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  50%|████▉     | 639/1280 [01:33<01:33,  6.89it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 637 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5898
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.010641
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4657
      Shape: [32, 32], Elements: 1024
      Mean: -0.000017, Std: 0.014552
  ✅ 3. branch_net.output_proj.2.weight: 0.3385
      Shape: [32, 32], Elements: 1024
      Mean: 0.000030, Std: 0.010578
  ✅ 4. trunk_net.net.16.weight: 0.2765
      Shape: [32, 32], Elements: 1024
      Mean: 0.000060, Std: 0.008639
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2365
      Shape: [32, 64], Elements: 2048
      Mean: -0.000002, Std: 0.005225

🔍 === Gradient Analysis - Epoch 0, Batch 638 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  50%|█████     | 641/1280 [01:33<01:30,  7.03it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 639 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.6071
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.010953
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4744
      Shape: [32, 32], Elements: 1024
      Mean: -0.000008, Std: 0.014825
  ✅ 3. branch_net.output_proj.2.weight: 0.2905
      Shape: [32, 32], Elements: 1024
      Mean: 0.000023, Std: 0.009077
  ✅ 4. trunk_net.net.16.weight: 0.2371
      Shape: [32, 32], Elements: 1024
      Mean: -0.000120, Std: 0.007408
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2359
      Shape: [32, 64], Elements: 2048
      Mean: -0.000011, Std: 0.005214

🔍 === Gradient Analysis - Epoch 0, Batch 640 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  50%|█████     | 643/1280 [01:33<01:31,  6.95it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 641 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5484
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.009895
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4639
      Shape: [32, 32], Elements: 1024
      Mean: -0.000038, Std: 0.014497
  ✅ 3. branch_net.output_proj.2.weight: 0.3780
      Shape: [32, 32], Elements: 1024
      Mean: 0.000000, Std: 0.011814
  ✅ 4. trunk_net.net.16.weight: 0.2981
      Shape: [32, 32], Elements: 1024
      Mean: 0.000166, Std: 0.009313
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2485
      Shape: [32, 64], Elements: 2048
      Mean: 0.000009, Std: 0.005491

🔍 === Gradient Analysis - Epoch 0, Batch 642 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  50%|█████     | 646/1280 [01:34<01:32,  6.84it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 644 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5725
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.010329
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4563
      Shape: [32, 32], Elements: 1024
      Mean: 0.000004, Std: 0.014258
  ✅ 3. branch_net.output_proj.2.weight: 0.3331
      Shape: [32, 32], Elements: 1024
      Mean: -0.000003, Std: 0.010409
  ✅ 4. trunk_net.net.16.weight: 0.2662
      Shape: [32, 32], Elements: 1024
      Mean: 0.000611, Std: 0.008297
  ✅ 5. branch_net.input_proj.0.bias: 0.2218
      Shape: [32], Elements: 32
      Mean: -0.000000, Std: 0.039209

🔍 === Gradient Analysis - Epoch 0, Batch 645 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer

Training Epoch 29:  51%|█████     | 648/1280 [01:34<01:30,  6.95it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 646 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5890
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.010627
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4721
      Shape: [32, 32], Elements: 1024
      Mean: -0.000011, Std: 0.014754
  ✅ 3. branch_net.output_proj.2.weight: 0.3292
      Shape: [32, 32], Elements: 1024
      Mean: 0.000015, Std: 0.010289
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2516
      Shape: [32, 64], Elements: 2048
      Mean: -0.000010, Std: 0.005560
  ✅ 5. trunk_net.net.16.weight: 0.2444
      Shape: [32, 32], Elements: 1024
      Mean: 0.000239, Std: 0.007634


Training Epoch 29:  51%|█████     | 651/1280 [01:34<01:29,  7.02it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 649 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4804
      Shape: [32, 32], Elements: 1024
      Mean: -0.000042, Std: 0.015014
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4644
      Shape: [96, 32], Elements: 3072
      Mean: -0.000001, Std: 0.008379
  ✅ 3. branch_net.output_proj.2.weight: 0.3873
      Shape: [32, 32], Elements: 1024
      Mean: 0.000091, Std: 0.012103
  ✅ 4. trunk_net.net.16.weight: 0.2743
      Shape: [32, 32], Elements: 1024
      Mean: 0.001241, Std: 0.008481
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2607
      Shape: [32, 64], Elements: 2048
      Mean: -0.000013, Std: 0.005762

🔍 === Gradient Analysis - Epoch 0, Batch 650 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  51%|█████     | 655/1280 [01:35<01:27,  7.17it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 653 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5623
      Shape: [96, 32], Elements: 3072
      Mean: -0.000001, Std: 0.010144
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4693
      Shape: [32, 32], Elements: 1024
      Mean: -0.000003, Std: 0.014666
  ✅ 3. branch_net.output_proj.2.weight: 0.3044
      Shape: [32, 32], Elements: 1024
      Mean: 0.000013, Std: 0.009512
  ✅ 4. trunk_net.net.16.weight: 0.2468
      Shape: [32, 32], Elements: 1024
      Mean: -0.000891, Std: 0.007660
  ✅ 5. branch_net.input_proj.0.bias: 0.2417
      Shape: [32], Elements: 32
      Mean: -0.000000, Std: 0.042736

🔍 === Gradient Analysis - Epoch 0, Batch 654 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transform

Training Epoch 29:  51%|█████▏    | 657/1280 [01:35<01:28,  7.01it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 655 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4938
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.008910
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4855
      Shape: [32, 32], Elements: 1024
      Mean: 0.000001, Std: 0.015173
  ✅ 3. branch_net.output_proj.2.weight: 0.3986
      Shape: [32, 32], Elements: 1024
      Mean: -0.000046, Std: 0.012456
  ✅ 4. trunk_net.net.16.weight: 0.2781
      Shape: [32, 32], Elements: 1024
      Mean: -0.000825, Std: 0.008651
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2576
      Shape: [32, 64], Elements: 2048
      Mean: -0.000001, Std: 0.005693

🔍 === Gradient Analysis - Epoch 0, Batch 656 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  51%|█████▏    | 659/1280 [01:36<01:28,  7.03it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 657 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5341
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009636
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4358
      Shape: [32, 32], Elements: 1024
      Mean: 0.000005, Std: 0.013618
  ✅ 3. branch_net.output_proj.2.weight: 0.4259
      Shape: [32, 32], Elements: 1024
      Mean: -0.000031, Std: 0.013308
  ✅ 4. trunk_net.net.16.weight: 0.3300
      Shape: [32, 32], Elements: 1024
      Mean: -0.000313, Std: 0.010307
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2643
      Shape: [32, 64], Elements: 2048
      Mean: 0.000040, Std: 0.005840

🔍 === Gradient Analysis - Epoch 0, Batch 658 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  52%|█████▏    | 661/1280 [01:36<01:31,  6.78it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 659 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5160
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009310
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4861
      Shape: [32, 32], Elements: 1024
      Mean: -0.000051, Std: 0.015191
  ✅ 3. branch_net.output_proj.2.weight: 0.3549
      Shape: [32, 32], Elements: 1024
      Mean: -0.000010, Std: 0.011092
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2733
      Shape: [32, 64], Elements: 2048
      Mean: -0.000018, Std: 0.006039
  ✅ 5. trunk_net.net.16.weight: 0.2582
      Shape: [32, 32], Elements: 1024
      Mean: 0.000561, Std: 0.008048

🔍 === Gradient Analysis - Epoch 0, Batch 660 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  52%|█████▏    | 663/1280 [01:36<01:29,  6.90it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 661 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5912
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.010667
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4593
      Shape: [32, 32], Elements: 1024
      Mean: -0.000010, Std: 0.014352
  ✅ 3. branch_net.output_proj.2.weight: 0.3284
      Shape: [32, 32], Elements: 1024
      Mean: 0.000001, Std: 0.010261
  ✅ 4. trunk_net.net.16.weight: 0.2749
      Shape: [32, 32], Elements: 1024
      Mean: 0.000228, Std: 0.008587
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2194
      Shape: [32, 64], Elements: 2048
      Mean: -0.000000, Std: 0.004849

🔍 === Gradient Analysis - Epoch 0, Batch 662 ===
Total Gradient Norm: 0.999999
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  52%|█████▏    | 665/1280 [01:36<01:27,  7.02it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 663 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5491
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009907
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4825
      Shape: [32, 32], Elements: 1024
      Mean: -0.000005, Std: 0.015077
  ✅ 3. branch_net.output_proj.2.weight: 0.3806
      Shape: [32, 32], Elements: 1024
      Mean: 0.000046, Std: 0.011893
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2805
      Shape: [32, 64], Elements: 2048
      Mean: -0.000005, Std: 0.006199
  ✅ 5. trunk_net.net.16.weight: 0.2656
      Shape: [32, 32], Elements: 1024
      Mean: 0.000364, Std: 0.008292


Training Epoch 29:  52%|█████▏    | 667/1280 [01:37<01:31,  6.67it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 665 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4963
      Shape: [32, 32], Elements: 1024
      Mean: 0.000029, Std: 0.015508
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4943
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.008919
  ✅ 3. branch_net.output_proj.2.weight: 0.3834
      Shape: [32, 32], Elements: 1024
      Mean: 0.000012, Std: 0.011982
  ✅ 4. trunk_net.net.16.weight: 0.2745
      Shape: [32, 32], Elements: 1024
      Mean: -0.000619, Std: 0.008556
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2646
      Shape: [32, 64], Elements: 2048
      Mean: 0.000030, Std: 0.005848

🔍 === Gradient Analysis - Epoch 0, Batch 666 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total)

Training Epoch 29:  52%|█████▏    | 669/1280 [01:37<01:29,  6.79it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 667 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5326
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.009609
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4865
      Shape: [32, 32], Elements: 1024
      Mean: -0.000029, Std: 0.015205
  ✅ 3. branch_net.output_proj.2.weight: 0.3623
      Shape: [32, 32], Elements: 1024
      Mean: -0.000011, Std: 0.011322
  ✅ 4. trunk_net.net.16.weight: 0.2561
      Shape: [32, 32], Elements: 1024
      Mean: 0.000845, Std: 0.007959
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2503
      Shape: [32, 64], Elements: 2048
      Mean: -0.000015, Std: 0.005532

🔍 === Gradient Analysis - Epoch 0, Batch 668 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  52%|█████▎    | 672/1280 [01:37<01:29,  6.81it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 670 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.6188
      Shape: [96, 32], Elements: 3072
      Mean: -0.000003, Std: 0.011164
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4849
      Shape: [32, 32], Elements: 1024
      Mean: -0.000040, Std: 0.015154
  ✅ 3. branch_net.output_proj.2.weight: 0.2473
      Shape: [32, 32], Elements: 1024
      Mean: 0.000022, Std: 0.007727
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2355
      Shape: [32, 64], Elements: 2048
      Mean: -0.000008, Std: 0.005204
  ✅ 5. trunk_net.net.16.weight: 0.2248
      Shape: [32, 32], Elements: 1024
      Mean: -0.000554, Std: 0.007004

🔍 === Gradient Analysis - Epoch 0, Batch 671 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  53%|█████▎    | 674/1280 [01:38<01:28,  6.87it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 672 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5661
      Shape: [96, 32], Elements: 3072
      Mean: 0.000003, Std: 0.010214
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4610
      Shape: [32, 32], Elements: 1024
      Mean: 0.000014, Std: 0.014406
  ✅ 3. branch_net.output_proj.2.weight: 0.3799
      Shape: [32, 32], Elements: 1024
      Mean: -0.000021, Std: 0.011873
  ✅ 4. trunk_net.net.16.weight: 0.2818
      Shape: [32, 32], Elements: 1024
      Mean: -0.000448, Std: 0.008793
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2486
      Shape: [32, 64], Elements: 2048
      Mean: 0.000003, Std: 0.005494

🔍 === Gradient Analysis - Epoch 0, Batch 673 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  53%|█████▎    | 676/1280 [01:38<01:27,  6.87it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 674 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5981
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.010791
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4653
      Shape: [32, 32], Elements: 1024
      Mean: -0.000021, Std: 0.014540
  ✅ 3. branch_net.output_proj.2.weight: 0.3091
      Shape: [32, 32], Elements: 1024
      Mean: 0.000049, Std: 0.009659
  ✅ 4. trunk_net.net.16.weight: 0.2698
      Shape: [32, 32], Elements: 1024
      Mean: -0.000267, Std: 0.008428
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2179
      Shape: [32, 64], Elements: 2048
      Mean: 0.000015, Std: 0.004816

🔍 === Gradient Analysis - Epoch 0, Batch 675 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tota

Training Epoch 29:  53%|█████▎    | 678/1280 [01:38<01:28,  6.81it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 676 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5612
      Shape: [96, 32], Elements: 3072
      Mean: -0.000002, Std: 0.010126
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4586
      Shape: [32, 32], Elements: 1024
      Mean: -0.000031, Std: 0.014331
  ✅ 3. branch_net.output_proj.2.weight: 0.3356
      Shape: [32, 32], Elements: 1024
      Mean: -0.000047, Std: 0.010487
  ✅ 4. trunk_net.net.16.weight: 0.2709
      Shape: [32, 32], Elements: 1024
      Mean: -0.000480, Std: 0.008452
  ✅ 5. branch_net.input_proj.0.weight: 0.2277
      Shape: [32, 3], Elements: 96
      Mean: -0.000000, Std: 0.023243

🔍 === Gradient Analysis - Epoch 0, Batch 677 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.tra

Training Epoch 29:  53%|█████▎    | 683/1280 [01:39<01:24,  7.06it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 681 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5854
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.010562
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4919
      Shape: [32, 32], Elements: 1024
      Mean: 0.000033, Std: 0.015373
  ✅ 3. branch_net.output_proj.2.weight: 0.3221
      Shape: [32, 32], Elements: 1024
      Mean: -0.000003, Std: 0.010065
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2497
      Shape: [32, 64], Elements: 2048
      Mean: -0.000010, Std: 0.005518
  ✅ 5. trunk_net.net.16.weight: 0.2283
      Shape: [32, 32], Elements: 1024
      Mean: -0.000397, Std: 0.007124


Training Epoch 29:  54%|█████▎    | 685/1280 [01:39<01:25,  6.97it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 683 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5947
      Shape: [96, 32], Elements: 3072
      Mean: 0.000003, Std: 0.010730
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4678
      Shape: [32, 32], Elements: 1024
      Mean: 0.000026, Std: 0.014620
  ✅ 3. branch_net.output_proj.2.weight: 0.2926
      Shape: [32, 32], Elements: 1024
      Mean: -0.000009, Std: 0.009145
  ✅ 4. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2344
      Shape: [32, 64], Elements: 2048
      Mean: -0.000009, Std: 0.005179
  ✅ 5. trunk_net.net.16.weight: 0.2340
      Shape: [32, 32], Elements: 1024
      Mean: 0.000422, Std: 0.007302

🔍 === Gradient Analysis - Epoch 0, Batch 684 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  54%|█████▎    | 687/1280 [01:40<01:25,  6.95it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 685 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.6011
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.010845
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4619
      Shape: [32, 32], Elements: 1024
      Mean: -0.000005, Std: 0.014433
  ✅ 3. branch_net.output_proj.2.weight: 0.3153
      Shape: [32, 32], Elements: 1024
      Mean: -0.000001, Std: 0.009853
  ✅ 4. trunk_net.net.16.weight: 0.2405
      Shape: [32, 32], Elements: 1024
      Mean: 0.000494, Std: 0.007500
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2368
      Shape: [32, 64], Elements: 2048
      Mean: 0.000002, Std: 0.005233

🔍 === Gradient Analysis - Epoch 0, Batch 686 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  54%|█████▍    | 690/1280 [01:40<01:22,  7.11it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 688 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.5505
      Shape: [96, 32], Elements: 3072
      Mean: -0.000001, Std: 0.009932
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4257
      Shape: [32, 32], Elements: 1024
      Mean: -0.000005, Std: 0.013303
  ✅ 3. branch_net.output_proj.2.weight: 0.4153
      Shape: [32, 32], Elements: 1024
      Mean: -0.000031, Std: 0.012979
  ✅ 4. trunk_net.net.16.weight: 0.3219
      Shape: [32, 32], Elements: 1024
      Mean: -0.000643, Std: 0.010038
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2124
      Shape: [32, 64], Elements: 2048
      Mean: 0.000003, Std: 0.004694

🔍 === Gradient Analysis - Epoch 0, Batch 689 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 tot

Training Epoch 29:  54%|█████▍    | 692/1280 [01:40<01:26,  6.76it/s]


🔍 === Gradient Analysis - Epoch 0, Batch 690 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total):
  ✅ 1. branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: 0.4925
      Shape: [96, 32], Elements: 3072
      Mean: 0.000002, Std: 0.008886
  ✅ 2. branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: 0.4873
      Shape: [32, 32], Elements: 1024
      Mean: 0.000012, Std: 0.015227
  ✅ 3. branch_net.output_proj.2.weight: 0.3813
      Shape: [32, 32], Elements: 1024
      Mean: -0.000008, Std: 0.011914
  ✅ 4. trunk_net.net.16.weight: 0.2677
      Shape: [32, 32], Elements: 1024
      Mean: -0.001356, Std: 0.008255
  ✅ 5. branch_net.transformer_encoder.layers.0.linear2.weight: 0.2472
      Shape: [32, 64], Elements: 2048
      Mean: 0.000020, Std: 0.005463

🔍 === Gradient Analysis - Epoch 0, Batch 691 ===
Total Gradient Norm: 1.000000
✅ Gradient norm is healthy

Top 5 layers by gradient norm (out of 41 total

Training Epoch 29:  54%|█████▍    | 694/1280 [01:41<01:27,  6.67it/s]

#Evaluation

In [None]:
# Load the trained model

hidden_dim = 64
latent_dim = 128
batch_size = 128
epochs = 1000
portfolio_feature_dim = 3

PORT_LEN = 200
PORT_SAMPLE_SIZE = 25600
FEED_ST_LEN_EACH_PORT = 20
SAMPLE_SIZE_SCALAR = 20
LAMBDA_DERIV = 0.1



final_save_path = '/content/drive/MyDrive/Ucl/'
model_path = final_save_path + 'best_deeponet_model_V2.pt' # Construct the full path to the model file

deeponet_model = OptimizedDeepONet(portfolio_feature_dim=portfolio_feature_dim, hidden_dim=hidden_dim, latent_dim=latent_dim).to(DEVICE)

# Load the state_dict, mapping to the correct device
deeponet_model.load_state_dict(torch.load(model_path, map_location=DEVICE))

deeponet_model.eval()




dataset = OperatorDatasetStandarization(num_samples=1, portfolio_size=100, num_samples_S_T=20)
portfolio, S_T_i, true_cashflows, derivative_i = dataset[0]


# Predict cashflows with the model
with torch.no_grad():
    predicted_cashflows = deeponet_model(portfolio, S_T_i)

s_t_values_sorted, indices = torch.sort(S_T_i, dim=0)

# Debugging prints
print('=====Portfolio=====')
print(portfolio)
print('=====S_T Sorted=====')
print(s_t_values_sorted)
print('=====True Cashflows=====')
print(true_cashflows)
print('=====Predicted Cashflows=====')
print(predicted_cashflows)

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(s_t_values_sorted.cpu().numpy(), true_cashflows, label='True Cashflow', color='blue')
plt.plot(s_t_values_sorted.cpu().numpy(), predicted_cashflows, label='Predicted Cashflow', color='orange', linestyle='--')
plt.xlabel('S_T (Absolute Value)')
plt.ylabel('Cashflow')
plt.title('True vs. Predicted Cashflow for a Single Portfolio (Sorted S_T)')
plt.legend()
plt.grid(True)
plt.show()


RuntimeError: Error(s) in loading state_dict for OptimizedDeepONet:
	size mismatch for branch_net.input_proj.0.weight: copying a param with shape torch.Size([32, 3]) from checkpoint, the shape in current model is torch.Size([64, 3]).
	size mismatch for branch_net.input_proj.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.input_proj.1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.input_proj.1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.transformer_encoder.layers.0.self_attn.in_proj_weight: copying a param with shape torch.Size([96, 32]) from checkpoint, the shape in current model is torch.Size([192, 64]).
	size mismatch for branch_net.transformer_encoder.layers.0.self_attn.in_proj_bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for branch_net.transformer_encoder.layers.0.self_attn.out_proj.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for branch_net.transformer_encoder.layers.0.self_attn.out_proj.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.transformer_encoder.layers.0.linear1.weight: copying a param with shape torch.Size([64, 32]) from checkpoint, the shape in current model is torch.Size([128, 64]).
	size mismatch for branch_net.transformer_encoder.layers.0.linear1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for branch_net.transformer_encoder.layers.0.linear2.weight: copying a param with shape torch.Size([32, 64]) from checkpoint, the shape in current model is torch.Size([64, 128]).
	size mismatch for branch_net.transformer_encoder.layers.0.linear2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.transformer_encoder.layers.0.norm1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.transformer_encoder.layers.0.norm1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.transformer_encoder.layers.0.norm2.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.transformer_encoder.layers.0.norm2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.output_proj.0.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.output_proj.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for branch_net.output_proj.2.weight: copying a param with shape torch.Size([64, 32]) from checkpoint, the shape in current model is torch.Size([128, 64]).
	size mismatch for branch_net.output_proj.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for trunk_net.net.0.weight: copying a param with shape torch.Size([32, 1]) from checkpoint, the shape in current model is torch.Size([64, 1]).
	size mismatch for trunk_net.net.0.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.4.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for trunk_net.net.4.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.5.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.5.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.8.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for trunk_net.net.8.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.9.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.9.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.12.weight: copying a param with shape torch.Size([32, 32]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for trunk_net.net.12.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.13.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.13.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for trunk_net.net.16.weight: copying a param with shape torch.Size([64, 32]) from checkpoint, the shape in current model is torch.Size([128, 64]).
	size mismatch for trunk_net.net.16.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).