In [1]:
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import wandb
import tqdm

from generation.config import WANDB_PROJECT, SIGNALS_TRAINING_CONFIG as config
from generation.dataset.signals_dataset import SignalsDataset
from generation.nets.signals import Generator, Discriminator
from generation.training.wgan_trainer import WganTrainer
from generation.utils import set_seed


In [2]:
# Notebook-level overrides of the shared signals training config.
# NOTE: `config` is the imported SIGNALS_TRAINING_CONFIG dict, so these
# assignments mutate that module-level object in place.
# Use the first GPU when present; fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of crashing at `.to(device)`.
config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config['lr'] = 3e-4
config['x_dim'] = 512

# Seed torch's global RNG (all devices) so latent sampling, DataLoader
# shuffling and any torch-based weight init are reproducible. Keeping the seed
# in `config` means it is also recorded when the config is logged to W&B.
# NOTE(review): generation.utils.set_seed is imported but unused; it may also
# seed numpy/random — confirm its signature and prefer it here if so.
config.setdefault('seed', 42)
torch.manual_seed(config['seed'])

In [3]:
# Signals dataset (constructed with signal_dim taken from config['x_dim']),
# served in shuffled mini-batches of config["batch_size"].
dataset = SignalsDataset(signal_dim=config['x_dim'])
dataloader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    shuffle=True,
)

# Generator placed on the configured device; Adam updates all of its parameters.
generator = Generator(config).to(device=config['device'])
optimizer = torch.optim.Adam(params=generator.parameters(), lr=config['lr'])

In [4]:
# Element-wise absolute error, averaged over all elements (L1 / MAE loss).
criterion = torch.nn.L1Loss(reduction='mean')

In [5]:
# Start a W&B run in the project's workspace that records the training config,
# then register the generator with wandb.watch for model tracking.
# The trailing ';' stops the notebook from echoing wandb.watch's return value
# (a list of TorchGraph objects) into the cell output.
wandb.init(config=config, project=WANDB_PROJECT)
wandb.watch(generator);

wandb: Currently logged in as: whitera2bit (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.7 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.10.2
wandb: Run data is saved locally in wandb/run-20201021_185006-2j0cnm37
wandb: Syncing run upbeat-waterfall-293





[<wandb.wandb_torch.TorchGraph at 0x7f76c3603e10>]

In [None]:
# Train the generator by minimising `criterion` between a batch of generated
# signals and a batch of real signals; every `log_each` epochs, log the epoch's
# mean loss to W&B and visualise the last batch.
# NOTE(review): each latent vector z is paired with an arbitrary (shuffled)
# real signal X, so there is no per-sample target. With an L1 criterion the
# optimum maps every z to the same element-wise median signal. Confirm this is
# an intentional sanity check rather than the intended training objective.
for epoch in tqdm.tqdm(range(config['epochs_num'])):
    # Accumulate on-device so the inner loop does not force a host sync per batch.
    epoch_loss_sum = torch.zeros((), device=config['device'])
    n_batches = 0
    for data in dataloader:
        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4;
        # plain tensors participate in autograd directly.
        X = data.to(config['device'])
        # Sample the latent batch directly on the target device (no host->device copy).
        z = torch.randn(X.shape[0], config['z_dim'], device=config['device'])

        g_sample = generator(z)

        optimizer.zero_grad()
        # PyTorch loss convention: prediction first, target second.
        loss = criterion(g_sample, X)
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.detach()
        n_batches += 1

    # With an empty dataloader there is no loss or sample to log/visualise.
    if n_batches and epoch % config['log_each'] == 0:
        wandb.log(
            {
                # Derive the key from the criterion so the label cannot drift
                # from the loss actually optimised; log a plain Python float.
                f"{type(criterion).__name__} (epoch mean)": (epoch_loss_sum / n_batches).item(),
            },
            step=epoch,
        )
        generator.visualize(g_sample, X, epoch)

 29%|██▉       | 1466/5000 [1:11:12<2:50:25,  2.89s/it]