In [1]:
import numpy as np
import pandas as pd

from tqdm import tqdm

import torch
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning import seed_everything

from utils.data import get_hsm_dataset, split_data, log_returns
from utils.metrics import MAPE, WAPE, MAE
from utils.TTS_GAN import TTS_GAN_Generator, TTS_GAN_Discriminator, weights_init, train_TTS_GAN

In [2]:
dataset_path = "data/huge_stock_market_dataset/"
synthetic_path = f"{dataset_path}synthetic/TTS_GAN/"
models_dir = "models/"

In [3]:
device = gpu = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

val_size = 0.15
test_size = 0.0

lr = 2e-4
wd = 0
ctrl_lr = 3.5e-4
beta1 = 0.0
beta2 = 0.9
max_epoch = 200
latent_dim = 128
batch_size = gen_batch_size = dis_batch_size = 64
ema = 0.995
ema_kimg = 500
ema_warmup = 0
world_size = 0
rank = - 1
print_freq = 200
n_critic = 1
phi = 1
accumulated_times = g_accumulated_times = 1
loss = "standard"
seq_len = 150

n_samples = 1600 * 127  # number of samples generated by QuantGAN

cuda:0


In [4]:
ts_iterator = get_hsm_dataset(dataset_path, selected_files=f"{dataset_path}/selected.csv")
seed_everything(0)

start_point = 11
for _ in range(start_point): next(ts_iterator)

for ts_index, time_series in enumerate(ts_iterator, start=start_point):
    print(f"Time Series #{ts_index}")
    
    (train_ts, *_), *_ = split_data(time_series, val_size=val_size, test_size=test_size)
    train_ts = log_returns(train_ts)
    # train_ts = train_ts[:len(train_ts) // 15 * 15]
    # train_dl = torch.utils.data.DataLoader([torch.from_numpy(train_ts.values.reshape(1, 1, - 1)).to(device)], batch_size=batch_size, shuffle=True)
    train_ts = np.array([train_ts[i: i + 150] for i in range(len(train_ts) - seq_len)])
    train_dl = torch.utils.data.DataLoader(torch.from_numpy(train_ts.reshape(- 1, 1, 1, seq_len)).to(device), batch_size=batch_size, shuffle=True)

    TTS_GAN_gen = TTS_GAN_Generator(seq_len=seq_len, channels=1, latent_dim=latent_dim, ).to(device)
    TTS_GAN_dis = TTS_GAN_Discriminator(seq_length=seq_len, in_channels=1).to(device)

    gen_optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, TTS_GAN_gen.parameters()), lr)
    dis_optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, TTS_GAN_dis.parameters()), lr)
    
    for epoch in range(max_epoch):
        losses = train_TTS_GAN(globals(), TTS_GAN_gen, TTS_GAN_dis, gen_optimizer, dis_optimizer, train_dl, epoch)
    tqdm.write(f"generator loss: {losses[0]: 0.4f} discriminator loss: {losses[1]: 0.4f}")
    del dis_optimizer, gen_optimizer, TTS_GAN_dis, train_dl
    torch.cuda.empty_cache()

    samples_to_gen = n_samples // seq_len
    synth_data = []
    with torch.no_grad():
        for _ in range(samples_to_gen):
            z = torch.cuda.FloatTensor(np.random.normal(0, 1, (1, latent_dim))).cuda(device, non_blocking=True)
            synth_data.append(TTS_GAN_gen(z).cpu().numpy())
            del z
            torch.cuda.empty_cache()
    np.save(synthetic_path + f"selected{ts_index}.npy", np.row_stack(synth_data))

    del TTS_GAN_gen, synth_data
    torch.cuda.empty_cache()

Global seed set to 0


Time Series #11
generator loss: -0.5841 discriminator loss:  1.4286
Time Series #12
generator loss: -0.6188 discriminator loss:  1.4447
Time Series #13
generator loss: -0.6191 discriminator loss:  1.4387
Time Series #14
generator loss: -0.5984 discriminator loss:  1.4094
Time Series #15
generator loss: -0.6206 discriminator loss:  1.4470
Time Series #16
generator loss: -0.6185 discriminator loss:  1.4455
Time Series #17
generator loss: -0.6186 discriminator loss:  1.4416
Time Series #18
generator loss: -0.6084 discriminator loss:  1.4259
Time Series #19
generator loss: -0.5791 discriminator loss:  1.3510
Time Series #20
generator loss: -0.3132 discriminator loss:  0.7786
Time Series #21
generator loss: -0.6185 discriminator loss:  1.4434
Time Series #22
generator loss: -0.6173 discriminator loss:  1.4418
Time Series #23
generator loss: -0.6180 discriminator loss:  1.4450


first 11 time: 38 min
another 13 time: 176 min