Train grid-based and pointwise volatility-surface VAEs

In [1]:
# Setup cell: load environment variables from a .env file, make the project
# source importable, then import the libraries used by the notebook.
import sys
import os
from dotenv import load_dotenv
load_dotenv()  # populate os.environ from .env (SRC_PATH / CSV_PATH are read below)
# Put SRC_PATH first on the import path so the `src.*` imports resolve.
# NOTE(review): if SRC_PATH is unset, os.getenv returns None and None is inserted
# into sys.path; the `src.*` imports would then likely fail — confirm .env is present.
sys.path.insert(0, os.getenv('SRC_PATH'))

import numpy as np
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D  # NOTE(review): not referenced in the visible cells
import matplotlib.pyplot as plt
import seaborn as sns
import sqlite3  # NOTE(review): not referenced in the visible cells
from src.volsurface import GridInterpVolSurface, KernelVolSurface

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from src.train import Trainer

import json  # used to read the vol-surface dataset file

In [2]:
def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or empty. This reports a missing
            .env entry by name instead of letting it surface later as an opaque
            TypeError (e.g. ``None + "/..."`` or ``os.chdir(None)``).
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set; define it in your .env file."
        )
    return value


CSV_PATH = _require_env('CSV_PATH')  # directory containing the dataset files
SRC_PATH = _require_env('SRC_PATH')
# Switch into SRC_PATH; relative paths used later (e.g. params/...) resolve against it.
os.chdir(SRC_PATH)

In [12]:
from src.utils.logger import setup_logger

# Notebook-wide logger; the training cell reports per-epoch progress through it.
LOG_LEVEL = "INFO"
logger = setup_logger(__name__, level=LOG_LEVEL)

## Grid-based VAE


In [14]:
# Configuration for the grid-based VAE run.
model_name = "vae_v2"
train_model = True  # build the model and run training/evaluation in this cell
load_model = False  # True: load saved weights instead of training from scratch
save_model = True  # write weights to the checkpoint path after training
# Path to the JSON file of volatility surfaces (a file path, despite the name).
data_dir = CSV_PATH + "/predicted_vol_surfaces.json"
batch_size = 32
epochs = 10

# Initialize the trainer
trainer = Trainer(model_name)
# Single definition of the checkpoint location, shared by load and save.
checkpoint_path = f"params/{trainer.model_name}.pth"

if train_model:
    trainer.create_model()

    if load_model:
        trainer.load_model(checkpoint_path)
    else:
        # Load the JSON file: a mapping whose values are the surfaces.
        with open(data_dir, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not data:
            # torch.stack on an empty list fails with an opaque error; say what is wrong.
            raise ValueError(f"No volatility surfaces found in {data_dir}")

        # Flatten each 2D surface to a 1D row. torch.stack requires every
        # surface to contain the same number of points.
        vol_surfaces = [
            torch.tensor(data[key], dtype=torch.float32).flatten() for key in data
        ]
        data_tensor = torch.stack(vol_surfaces)
        # All-zero placeholder labels so each batch is an (x, y) pair;
        # presumably ignored by Trainer.train — confirm.
        dummy_labels = torch.zeros(len(data_tensor))

        dataset = TensorDataset(data_tensor, dummy_labels)
        train_loader = DataLoader(
            dataset,
            batch_size=batch_size,  # honour this cell's configured batch size
            shuffle=True,
        )

        # Train the model
        for epoch in range(epochs):
            logger.info(f"Epoch {epoch + 1}/{epochs}")
            trainer.train(train_loader)

        if save_model:
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(trainer.model.state_dict(), checkpoint_path)

    # Evaluate the model (runs for both freshly trained and loaded weights).
    trainer.evaluate("output")

[2025-04-14 23:10:15] [INFO] src.train: Using device: mps
[2025-04-14 23:10:15] [INFO] __main__: Epoch 1/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 7.8364
[2025-04-14 23:10:15] [INFO] __main__: Epoch 2/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 7.2382
[2025-04-14 23:10:15] [INFO] __main__: Epoch 3/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 6.7007
[2025-04-14 23:10:15] [INFO] __main__: Epoch 4/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 6.1634
[2025-04-14 23:10:15] [INFO] __main__: Epoch 5/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 5.7021
[2025-04-14 23:10:15] [INFO] __main__: Epoch 6/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 5.2446
[2025-04-14 23:10:15] [INFO] __main__: Epoch 7/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 4.7676
[2025-04-14 23:10:15] [INFO] __main__: Epoch 8/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 4.3850
[2025-04-14 23:10:15] [INFO] __main__: Epoch 9/10
[2025-04-14 23:10:15] [INFO] src.train: Loss: 3.9390
[2025-04-14 23:

## Pointwise VAE