In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm.auto import tqdm
import argparse
from factorvae import FactorVAE, FeatureExtractor, FactorDecoder, FactorEncoder, FactorPredictor, AlphaLayer, BetaLayer
from dataset import StockDataset

#### **Set Parameters**

In [3]:
args = {
    'batch_size': 300,
    'seq_len': 20,
    'num_latent': 158,
    'hidden_size': 20,
    'num_factor': 8,
    'lr': 0.0005,
    'num_epochs': 25
}

#### **Load Datasets**

In [5]:
df_train = pd.read_pickle('data/train.pkl')
df_valid = pd.read_pickle('data/valid.pkl')
df_test = pd.read_pickle('data/test.pkl')

df_train.columns = df_train.columns.droplevel(level=0)
df_valid.columns = df_valid.columns.droplevel(level=0)
df_test.columns = df_test.columns.droplevel(level=0)

In [6]:
ds_train = StockDataset(df_train, args['batch_size'], args['seq_len'])
ds_valid = StockDataset(df_valid, args['batch_size'], args['seq_len'])
ds_test = StockDataset(df_test, args['batch_size'], args['seq_len'])

In [7]:
train_dataloader = DataLoader(ds_train, batch_size=300, shuffle=False)
valid_dataloader = DataLoader(ds_valid, batch_size=300, shuffle=False)
test_dataloader = DataLoader(ds_test, batch_size=300, shuffle=False)

In [8]:
check_dataloader = DataLoader(ds_valid, batch_size=1, shuffle=False)

In [10]:
for hist, futr in check_dataloader:
    print(hist)
    print(hist.shape)
    print(futr)
    print(futr.shape)
    break

tensor([[[-1.1951,  1.8856, -0.5353,  ...,  0.5397,  1.7871,  1.5821],
         [ 0.0000, -1.7979,  0.0000,  ...,  3.0000,  3.0000,  3.0000],
         [-3.0000,  3.0000, -1.2462,  ...,  3.0000,  3.0000,  3.0000],
         ...,
         [ 1.7401,  1.6219,  0.8395,  ...,  0.2356, -0.5692, -0.9325],
         [ 0.0000,  1.0907,  0.0000,  ...,  2.3406,  2.8888,  3.0000],
         [ 1.0017,  0.3257,  0.7783,  ..., -0.1469, -1.2197,  0.6544]]],
       dtype=torch.float64)
torch.Size([1, 20, 158])
tensor([[0.0714, 0.1000, 0.1001, 0.1000, 0.0605, 0.1002, 0.0545, 0.0400, 0.0838,
         0.1000, 0.0725, 0.0778, 0.0900, 0.0627, 0.1002, 0.1004, 0.1009, 0.0985,
         0.0999, 0.0882]], dtype=torch.float64)
torch.Size([1, 20])


#### **Build FactorVAE Model**

In [18]:
feature_extractor = FeatureExtractor(num_latent = args['num_latent'], hidden_size = args['hidden_size'])

factor_encoder = FactorEncoder(num_factors = args['num_factor'], num_portfolio = args['num_latent'], hidden_size = args['hidden_size'])

alpha_layer = AlphaLayer(args['hidden_size'])
beta_layer = BetaLayer(args['hidden_size'], args['num_factor'])
factor_decoder = FactorDecoder(alpha_layer, beta_layer)

factor_predictor = FactorPredictor(args['batch_size'], args['hidden_size'], args['num_factor'])

factorVAE = FactorVAE(feature_extractor, factor_encoder, factor_decoder, factor_predictor)

#### **Train the Model**

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [20]:
factorVAE.to(device)

best_val_loss = 10000.0
optimizer = torch.optim.Adam(factorVAE.parameters(), lr = args['lr'])
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = args['lr'], \
    steps_per_epoch = len(train_dataloader), epochs=args['num_epochs'])

In [23]:
def train(factor_model, dataloader, optimizer, args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    factor_model.to(device)
    factor_model.train()

    total_loss = 0

    with tqdm(total=len(dataloader)-args['seq_len']+1) as pbar:

        for char, returns in dataloader:
            if char.shape[1] != args['seq_len']:
                continue
            inputs = char.to(device)
            labels = returns[:,-1].reshape(-1,1).to(device)
            inputs = inputs.float()
            labels = labels.float()
            
            optimizer.zero_grad()
            # print(inputs.shape)
            # print(labels.shape)
            loss, reconstruction, factor_mu, factor_sigma, pred_mu, pred_sigma = factor_model(inputs, labels)
            total_loss += loss.item() * inputs.size(0)
            loss.backward()
            optimizer.step()
            pbar.update(1)
        # print(loss)
    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss


@torch.no_grad()
def validate(factor_model, dataloader, args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    factor_model.to(device)
    factor_model.eval()
    total_loss = 0

    with tqdm(total=len(dataloader)-args['seq_len']+1) as pbar:
        for char, returns in dataloader:
            if char.shape[1] != args['seq_len']:
                continue
            inputs = char.to(device)
            labels = returns[:,-1].reshape(-1,1).to(device)
            inputs = inputs.float()
            labels = labels.float()
            
            loss, reconstruction, factor_mu, factor_sigma, pred_mu, pred_sigma = factor_model(inputs, labels)
            total_loss += loss.item() * inputs.size(0)
            pbar.update(1)
            
    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss

In [24]:
for epoch in tqdm(range(args['num_epochs'])):

    train_loss = train(factorVAE, train_dataloader, optimizer, args)
    val_loss = validate(factorVAE, valid_dataloader, args)

    scheduler.step()
    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}") 

    if val_loss < best_val_loss:
        best_val_loss = val_loss
 
        torch.save(factorVAE.state_dict(), "model.pt")

100%|██████████| 1671/1671 [03:26<00:00,  8.08it/s]
100%|██████████| 244/244 [00:21<00:00, 11.23it/s]
  4%|▍         | 1/25 [03:48<1:31:22, 228.45s/it]

Epoch 1: Train Loss: 0.3387, Validation Loss: 0.1029


100%|██████████| 1671/1671 [03:27<00:00,  8.06it/s]
100%|██████████| 244/244 [00:21<00:00, 11.41it/s]
  8%|▊         | 2/25 [07:37<1:27:37, 228.59s/it]

Epoch 2: Train Loss: 0.0502, Validation Loss: 0.0182


100%|██████████| 1671/1671 [03:25<00:00,  8.14it/s]
100%|██████████| 244/244 [00:21<00:00, 11.41it/s]
 12%|█▏        | 3/25 [11:23<1:23:29, 227.71s/it]

Epoch 3: Train Loss: 0.0118, Validation Loss: 0.0059


100%|██████████| 1671/1671 [03:26<00:00,  8.09it/s]
100%|██████████| 244/244 [00:21<00:00, 11.27it/s]
 16%|█▌        | 4/25 [15:11<1:19:45, 227.89s/it]

Epoch 4: Train Loss: 0.0046, Validation Loss: 0.0027


100%|██████████| 1671/1671 [03:27<00:00,  8.06it/s]
100%|██████████| 244/244 [00:21<00:00, 11.39it/s]
 20%|██        | 5/25 [19:00<1:16:04, 228.20s/it]

Epoch 5: Train Loss: 0.0024, Validation Loss: 0.0015


100%|██████████| 1671/1671 [03:25<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.32it/s]
 24%|██▍       | 6/25 [22:47<1:12:05, 227.66s/it]

Epoch 6: Train Loss: 0.0015, Validation Loss: 0.0010


100%|██████████| 1671/1671 [03:24<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.29it/s]
 28%|██▊       | 7/25 [26:33<1:08:11, 227.32s/it]

Epoch 7: Train Loss: 0.0011, Validation Loss: 0.0008


100%|██████████| 1671/1671 [03:25<00:00,  8.12it/s]
100%|██████████| 244/244 [00:21<00:00, 11.39it/s]
 32%|███▏      | 8/25 [30:21<1:04:23, 227.25s/it]

Epoch 8: Train Loss: 0.0009, Validation Loss: 0.0006


100%|██████████| 1671/1671 [03:25<00:00,  8.13it/s]
100%|██████████| 244/244 [00:21<00:00, 11.40it/s]
 36%|███▌      | 9/25 [34:08<1:00:34, 227.17s/it]

Epoch 9: Train Loss: 0.0008, Validation Loss: 0.0006


100%|██████████| 1671/1671 [03:25<00:00,  8.12it/s]
100%|██████████| 244/244 [00:21<00:00, 11.34it/s]
 40%|████      | 10/25 [37:55<56:48, 227.21s/it] 

Epoch 10: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.40it/s]
 44%|████▍     | 11/25 [41:41<52:58, 227.02s/it]

Epoch 11: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:24<00:00,  8.16it/s]
100%|██████████| 244/244 [00:21<00:00, 11.40it/s]
 48%|████▊     | 12/25 [45:28<49:07, 226.75s/it]

Epoch 12: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.38it/s]
 52%|█████▏    | 13/25 [49:14<45:20, 226.69s/it]

Epoch 13: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.34it/s]
 56%|█████▌    | 14/25 [53:01<41:33, 226.68s/it]

Epoch 14: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.11it/s]
100%|██████████| 244/244 [00:21<00:00, 11.30it/s]
 60%|██████    | 15/25 [56:48<37:49, 226.94s/it]

Epoch 15: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.13it/s]
100%|██████████| 244/244 [00:21<00:00, 11.37it/s]
 64%|██████▍   | 16/25 [1:00:35<34:02, 226.99s/it]

Epoch 16: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:24<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.42it/s]
 68%|██████▊   | 17/25 [1:04:22<30:14, 226.78s/it]

Epoch 17: Train Loss: 0.0007, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.14it/s]
100%|██████████| 244/244 [00:21<00:00, 11.39it/s]
 72%|███████▏  | 18/25 [1:08:08<26:27, 226.78s/it]

Epoch 18: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:24<00:00,  8.16it/s]
100%|██████████| 244/244 [00:21<00:00, 11.48it/s]
 76%|███████▌  | 19/25 [1:11:55<22:39, 226.55s/it]

Epoch 19: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:24<00:00,  8.16it/s]
100%|██████████| 244/244 [00:21<00:00, 11.48it/s]
 80%|████████  | 20/25 [1:15:41<18:52, 226.41s/it]

Epoch 20: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:25<00:00,  8.15it/s]
100%|██████████| 244/244 [00:21<00:00, 11.42it/s]
 84%|████████▍ | 21/25 [1:19:27<15:05, 226.41s/it]

Epoch 21: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:22<00:00,  8.24it/s]
100%|██████████| 244/244 [00:21<00:00, 11.61it/s]
 88%|████████▊ | 22/25 [1:23:11<11:16, 225.66s/it]

Epoch 22: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:15<00:00,  8.54it/s]
100%|██████████| 244/244 [00:21<00:00, 11.39it/s]
 92%|█████████▏| 23/25 [1:26:48<07:26, 223.11s/it]

Epoch 23: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:20<00:00,  8.32it/s]
100%|██████████| 244/244 [00:20<00:00, 11.68it/s]
 96%|█████████▌| 24/25 [1:30:30<03:42, 222.71s/it]

Epoch 24: Train Loss: 0.0006, Validation Loss: 0.0005


100%|██████████| 1671/1671 [03:23<00:00,  8.21it/s]
100%|██████████| 244/244 [00:21<00:00, 11.52it/s]
100%|██████████| 25/25 [1:34:15<00:00, 226.20s/it]

Epoch 25: Train Loss: 0.0006, Validation Loss: 0.0005



