In [None]:
import torch
import torch.nn as nn

import numpy as np

from factorVAE.factor_VAE import FactorVAE

In [None]:
# Load the pre-computed feature and return tensors from disk.
features_dataset = torch.load("./dataset/basic_feat/features.pt")
returns_dataset = torch.load("./dataset/basic_feat/returns.pt")

# Report the first four axes of the feature tensor, in axis order.
for _axis, _label in enumerate(("Total step", "Time span", "Stock size", "Feature size")):
    print(f"{_label}: {features_dataset.shape[_axis]}")

In [None]:
# --- Data / model dimensions ---------------------------------------------------
# NOTE(review): these are hard-coded; they presumably must agree with the tensor
# shapes printed by the data-loading cell (time span, stock count, feature count)
# — confirm before switching datasets.
batch_size = 128
characteristic_size = 5
stock_size = 100
latent_size = 16
factor_size = 8
time_span = 60
gru_input_size = 8
hidden_size = [16, 32, 16]

# --- Optimisation --------------------------------------------------------------
lr = 1e-4
epochs = 80

# --- Reproducibility -----------------------------------------------------------
# Seed PyTorch's global RNG (CPU and CUDA) before the model is built, so weight
# initialisation, DataLoader shuffling and distribution sampling are repeatable
# under Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
from torch.utils.data import TensorDataset, DataLoader

def get_dataloader(data, label, device=device, batch_size=batch_size, shuffle=True):
    """Wrap a feature tensor and its target returns in a DataLoader.

    Both tensors are converted to float32 and moved to ``device`` up front, so
    every batch the loader yields already lives on that device.

    Args:
        data: features (tensor or array-like); first axis indexes samples.
        label: target returns aligned with ``data`` along the first axis.
        device: device to place both tensors on (defaults to the notebook's
            ``device`` at definition time).
        batch_size: samples per batch (defaults to the notebook's ``batch_size``).
        shuffle: whether to reshuffle samples every epoch.

    Returns:
        A ``DataLoader`` over ``TensorDataset(data, label)``.
    """
    # as_tensor accepts tensors of any dtype as well as numpy arrays, unlike the
    # legacy ``torch.Tensor(...)`` constructor.
    data = torch.as_tensor(data, dtype=torch.float32, device=device)
    # Returns are continuous targets: never cast them to an integer dtype,
    # which would truncate fractional returns toward zero.
    label = torch.as_tensor(label, dtype=torch.float32, device=device)
    ds = TensorDataset(data, label)
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Build the (shuffled) training loader from the full feature/return tensors.
train_dl = get_dataloader(data=features_dataset, label=returns_dataset)

In [None]:
# Instantiate the FactorVAE from the hyperparameter cell.
factor_VAE = FactorVAE(
    hidden_size=hidden_size,
    gru_input_size=gru_input_size,
    time_span=time_span,
    factor_size=factor_size,
    latent_size=latent_size,
    stock_size=stock_size,
    characteristic_size=characteristic_size,
)
# Place the model on the selected device (presumably an nn.Module, whose .to
# returns the module itself).
factor_VAE = factor_VAE.to(device)

In [None]:
# Adam over every trainable parameter of the model, at the configured learning rate.
optimizer = torch.optim.Adam(params=factor_VAE.parameters(), lr=lr)

In [None]:
def train_loop(dataloader, model, optimizer, log_every=10):
    """Run one training epoch of ``model`` over ``dataloader``.

    Args:
        dataloader: iterable yielding ``(feat, ret)`` batches.
        model: module exposing ``run_model(feat, ret)``, which must return a
            scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_every: print the batch loss every ``log_every`` batches; pass
            0/None to disable logging.

    Returns:
        The mean loss over the epoch's batches as a Python float, or NaN if the
        loader yielded no batches.
    """
    # Ensure train-mode behaviour (dropout etc.) even if an earlier cell left
    # the model in eval mode.
    model.train()

    total_loss = 0.0
    n_batches = 0
    # Iterate the loader that was passed in, not a notebook global.
    for batch, (feat, ret) in enumerate(dataloader):
        loss = model.run_model(feat, ret)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device (detached) to avoid a host sync every batch.
        total_loss = total_loss + loss.detach()
        n_batches += 1

        if log_every and batch % log_every == 0:
            print(f"batch: {batch}, loss: {loss.item()}")

    if n_batches == 0:
        return float("nan")
    return float(total_loss) / n_batches

In [None]:
# Train for `epochs` full passes over the training loader, announcing each epoch.
for epoch in range(epochs):
    print(f"=== Epoch: {epoch} ===")
    train_loop(train_dl, factor_VAE, optimizer)

In [None]:
# Ground-truth returns for sample index 100, displayed for comparison with the
# model's prediction for the same sample index.
returns_dataset[100]

In [None]:
# Predict for sample index 100 (a batch of one, hence unsqueeze(0)).
# Inference runs in eval mode without building an autograd graph; the model's
# previous train/eval mode is restored afterwards, even if prediction raises.
_was_training = factor_VAE.training
factor_VAE.eval()
try:
    with torch.no_grad():
        result = factor_VAE.prediction(features_dataset[100].unsqueeze(0).to(device))
finally:
    factor_VAE.train(_was_training)

In [None]:
# Second element of the prediction output. It is later used as the location of a
# Normal distribution, so presumably the predicted return mean — TODO confirm
# against FactorVAE.prediction.
result[1]


In [None]:
# Third element of the prediction output. It is later used as the scale of a
# Normal distribution, so presumably the predicted return std — TODO confirm
# against FactorVAE.prediction.
result[2]

In [None]:
from torch.distributions import Normal

# Build a Gaussian from the prediction outputs (result[1] as location,
# result[2] as scale) and draw one sample from it.
n = Normal(loc=result[1], scale=result[2])
n.sample()
