In [1]:
import sys
sys.path.append('../../../')
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from time import time

# Choose the compute device once; later cells refer to `device` / `is_cuda`.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = 'cuda'
else:
    device = 'cpu'
    print("Warning: CUDA not available; falling back to CPU but this is likely to be very slow.")

# Compact, non-scientific formatting for printed arrays and tensors.
np.set_printoptions(suppress=True)
torch.set_printoptions(precision=3, sci_mode=False)

In [2]:
from lib.Baseline import ConditionalGenerator, Critic
from lib.Utilities import get_n_params
from lib.Training_wgan import gradient_penalty

In [3]:
# Load the saved mapping of named tensors and cast every entry to float32
# in place (the keys are unchanged, only the values are replaced).
data = torch.load('../../../AR/data/data.pt')

for dataset, tensor in data.items():
    data[dataset] = tensor.float()

In [4]:
# Build the critic and the conditional generator, cast them to float32 and
# place both on the device chosen in the first cell.
# `sep` is the size of dim 1 of X_train; presumably it tells the critic where
# the conditioning part ends inside the concatenated sequence -- TODO confirm
# against lib.Baseline.
C = Critic(input_size = [1,1], hidden_size = [32,32], num_layers = [5,5], sep=data['X_train'].shape[1]).float().to(device)
# Use `device` instead of a hard-coded 'cuda' so the CPU fallback announced in
# the first cell actually works on machines without a GPU.
G = ConditionalGenerator(1, 1, 32, 5, 5).float().to(device)

print(f"Total number of parameters of the Generator: {get_n_params(G):10}")
print(f"Total number of parameters of the Critic: {get_n_params(C):10}")

Total number of parameters of the Generator:      77089
Total number of parameters of the Critic:      76609


In [5]:
# One RMSprop optimizer per network, both with learning rate 1e-3.
C_optimizer = torch.optim.RMSprop(params=C.parameters(), lr=1e-3)
G_optimizer = torch.optim.RMSprop(params=G.parameters(), lr=1e-3)

# One less than the size of dim 1 of Y_train -- the same value the training
# loop passes to the generator as its second argument.
q = data['Y_train'].size(1) - 1

In [6]:
# Hyper-parameters bundled in one dict. The benchmark loop below reads
# 'gp_lambda' (weight of the gradient penalty) and 'nsteps_disc' (critic
# updates per generator update).
hp = dict(
    C_optimizer=C_optimizer,
    G_optimizer=G_optimizer,
    gp_lambda=10,
    nsteps_disc=10,
    batch_size=528,
)

In [7]:
# Batch sizes to benchmark: 10, 20, ..., 590.
batch_size = list(range(10, 600, 10))

# Result buffers: one row per batch size, one column per measured step.
batch_size_time = torch.zeros(len(batch_size), 25)
batch_size_memory = torch.zeros_like(batch_size_time)

In [8]:
def _set_requires_grad(model, flag):
    """Switch gradient tracking on or off for every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad = flag


def _cuda_sync():
    """Block until queued CUDA work has finished; does nothing without CUDA.

    CUDA kernels are launched asynchronously, so wall-clock timestamps taken
    without synchronizing can miss work that is still running on the GPU.
    """
    if is_cuda:
        torch.cuda.synchronize()


# Benchmark: for every batch size, run a fixed number of WGAN-GP training
# steps (several critic updates followed by one generator update) and record
# the wall-clock time and the peak CUDA memory of each step.

# Loop-invariant setup, hoisted out of the batch-size loop.
# The [:, :, 1:] slice drops index 0 of the last dimension of X and Y.
train_dataset = TensorDataset(data['X_train'][:, :, 1:], data['Y_train'][:, :, 1:])
G, C = G.to(device), C.to(device)
# Take the step count from the result buffer so the two cannot disagree.
n_steps = batch_size_time.shape[1]

for i, batch_size_ in enumerate(batch_size):
    train_dataloader = DataLoader(train_dataset, batch_size = batch_size_, shuffle=True)
    # Endless stream of batches: iter(callable, None) keeps handing back the
    # same DataLoader, and each pass over it starts a new (reshuffled) epoch.
    infinite_dataloader = (elem for it in iter(lambda: train_dataloader, None) for elem in it)

    for step in range(n_steps):
        if is_cuda:
            # Reset the peak counter *before* the step so that every
            # measurement, including the very first one, covers only this step.
            torch.cuda.reset_peak_memory_stats(device=device)
        _cuda_sync()
        start_time = time()

        # ---- Critic phase: generator frozen, critic trainable ----
        _set_requires_grad(G, False)
        _set_requires_grad(C, True)

        for _ in range(hp['nsteps_disc']):
            C_optimizer.zero_grad()
            batch_x, batch_y_real = next(infinite_dataloader)
            batch_x, batch_y_real = batch_x.to(device), batch_y_real.to(device)
            batch_y_fake = G(batch_x, batch_y_real.shape[1]-1)
            # Real / generated continuations are appended to the conditioning
            # input along dim 1 before being scored by the critic.
            batch_real = torch.cat([batch_x, batch_y_real], dim=1)
            batch_fake = torch.cat([batch_x, batch_y_fake], dim=1)
            fake_score, real_score = torch.mean(C(batch_fake)), torch.mean(C(batch_real))
            gp = gradient_penalty(C, batch_real.detach(), batch_fake.detach())
            # Critic loss: fake score minus real score plus weighted penalty.
            loss = fake_score - real_score + hp['gp_lambda']*gp
            loss.backward()
            C_optimizer.step()

            # Drop references early so the tensors can be freed between updates.
            del batch_x, batch_y_real, batch_y_fake, batch_real, batch_fake

        # ---- Generator phase: critic frozen, generator trainable ----
        _set_requires_grad(G, True)
        _set_requires_grad(C, False)

        G_optimizer.zero_grad()
        batch_x, batch_y_real = next(infinite_dataloader)
        batch_x, batch_y_real = batch_x.to(device), batch_y_real.to(device)
        batch_y_fake = G(batch_x, batch_y_real.shape[1]-1)
        batch_fake = torch.cat([batch_x, batch_y_fake], dim=1)
        fake_score = torch.mean(C(batch_fake))
        # The generator is updated to increase the critic's score on fakes.
        loss = -fake_score
        loss.backward()
        G_optimizer.step()

        del batch_x, batch_y_real, batch_y_fake, batch_fake

        # Wait for outstanding GPU work so the elapsed time covers the step.
        _cuda_sync()
        end_time = time()
        batch_size_time[i, step] = end_time-start_time
        if is_cuda:
            # max_memory_allocated reports bytes; 1e-6 converts to megabytes.
            batch_size_memory[i, step] = torch.cuda.max_memory_allocated(device=device)*1e-6
        else:
            # No CUDA allocator statistics on CPU: mark the value as missing
            # instead of crashing inside torch.cuda.
            batch_size_memory[i, step] = float('nan')

In [9]:
# Persist the per-step wall-clock timings (rows: batch sizes, columns: steps).
torch.save(obj=batch_size_time, f='batch_size_time.pt')

In [10]:
# Persist the per-step peak-memory readings (rows: batch sizes, columns: steps).
torch.save(obj=batch_size_memory, f='batch_size_memory.pt')