In [1]:
# default_exp gan.model

# gan.model

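> Building blocks for GAN models: the `LinBnAct` helper (linear layer with optional batch norm and activation) and `GANMLP`, an MLP with optional categorical embeddings that can serve as generator or discriminator.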

In [2]:
# export

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from functools import partial

from tqdm import tqdm

from fastrenewables.utils import flatten_ts
from fastrenewables.synthetic_data import GaussianDataset, plot_class_hists, DummyDataset
from fastrenewables.timeseries.model import TemporalCNN
from fastrenewables.tabular.model import EmbeddingModule

In [3]:
# export

def LinBnAct(si, so, use_bn, act_cls):
    """Build a ``Linear -> [BatchNorm1d] -> [activation]`` block.

    Args:
        si: Number of input features of the linear layer.
        so: Number of output features of the linear layer (also the
            number of features normalised by the optional batch norm).
        use_bn: If truthy, a ``nn.BatchNorm1d(so)`` follows the linear layer.
        act_cls: An activation module *instance* (despite the name, it is
            appended as-is and never called as a constructor), or ``None``
            to omit the activation.

    Returns:
        An ``nn.Sequential`` holding the selected modules in the order
        linear, batch norm, activation.
    """
    # The linear layer is always present and is created first.
    linear = nn.Linear(si, so)
    norm = [nn.BatchNorm1d(so)] if use_bn else []
    activation = [] if act_cls is None else [act_cls]

    return nn.Sequential(linear, *norm, *activation)

In [4]:
# export

class GANMLP(torch.nn.Module):
    """Multi-layer perceptron with optional categorical embeddings for GANs.

    The network is a stack of ``LinBnAct`` blocks, one per consecutive pair
    of entries in ``ann_structure``, optionally followed by a final
    activation.

    Args:
        ann_structure: Sequence of layer widths ``[n_in, h1, ..., n_out]``.
            It is copied, so the caller's sequence is never modified.
        bn_cont: If truthy, every block except the last one uses batch norm.
        act_fct: Zero-argument callable returning the activation module used
            in every block except the last one.
        final_act_fct: Zero-argument callable returning the module appended
            after the last block, or ``None`` for no final activation.
        embedding_module: Optional module applied to ``x_cat`` in ``forward``.
            Its ``no_of_embeddings`` attribute is added to the width of the
            first layer (presumably the embedding output width -- confirm
            against ``EmbeddingModule``).
        transpose: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, ann_structure, bn_cont=False, act_fct=torch.nn.ReLU, final_act_fct=nn.Sigmoid, embedding_module=None, transpose=False):
        super(GANMLP, self).__init__()

        # Work on a private copy: widening the first entry in place would
        # leak into the caller's list, so reusing one structure for several
        # models would add the embedding width again on every construction.
        ann_structure = list(ann_structure)
        if embedding_module is not None:
            ann_structure[0] = ann_structure[0] + embedding_module.no_of_embeddings

        self.embedding_module = embedding_module

        last_idx = len(ann_structure) - 1
        layers = []
        for idx in range(1, len(ann_structure)):
            is_last = idx == last_idx
            # The output block is a plain linear layer: no batch norm and no
            # hidden activation (the final activation is appended below).
            use_bn = bn_cont and not is_last
            activation = None if is_last else act_fct()
            layers.append(LinBnAct(ann_structure[idx - 1], ann_structure[idx], use_bn, activation))

        if final_act_fct is not None:
            layers.append(final_act_fct())

        self.model = nn.Sequential(*layers)

    def forward(self, x_cat, x_cont):
        """Run the network.

        Args:
            x_cat: Categorical input; only used when an embedding module was
                given, in which case it is passed through that module.
            x_cont: Continuous input fed to the MLP.

        Returns:
            The output of the layer stack. With an embedding module the MLP
            input is the embedded ``x_cat`` concatenated in front of
            ``x_cont`` along dimension 1.
        """
        if self.embedding_module is not None:
            x_cat = self.embedding_module(x_cat)
            x_cont = torch.cat([x_cat, x_cont], 1)

        return self.model(x_cont)

In [5]:
# #hide

# n_samples = 1024
# n_classes = 2
# n_features = 1
# batch_size = 512
# n_z = 10
# n_in = n_features
# n_hidden = 64
# epochs = 2

# data = GaussianDataset(n_samples, n_classes)
# dl = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
# emb_module = EmbeddingModule(categorical_dimensions=[n_classes+1])

# model = get_gan_model(structure=[n_z, n_hidden, n_hidden, n_in], n_classes=n_classes, emb_module=emb_module, bn=True, gan_type='bce')
# #print(model)    
# for e in tqdm(range(epochs)):
#     for x_cat, x_cont, y in dl:
#         x_cat = x_cat.to(model.device).long()
#         x_cont = x_cont.to(model.device)
#         y = y.to(model.device)
#         model.train_discriminator(x_cat, x_cont, y)
#         model.train_generator(x_cat, x_cont, y)

# plt.figure(figsize=(16, 9))
# plt.plot(model.real_loss, label='Real Loss')
# plt.plot(model.fake_loss, label='Fake Loss')
# plt.plot(model.aux_loss, label='Aux Loss')
# plt.legend()
# plt.show()

# model.eval()
# z = model.noise(x_cont)
# x_fake = model.generator(x_cat, z)

# print('distribution of real data:')
# plot_class_hists(x_cat.cpu(), x_cont.cpu())

# print('distribution of generated data:')
# plot_class_hists(x_cat.cpu(), x_fake.cpu().detach())

In [6]:
# # hide
# from fastai.data.core import DataLoader,DataLoaders


# opt = torch.optim.Adam(params=model.parameters())
# loss = torch.nn.MSELoss()

# n_cats, n_z, n_targets = 1,2, 2
# data = DummyDataset(n_samples=100, n_cat_feats=n_cats, n_cont_feats=n_z, n_targets=n_targets, n_dim=3)


# dl_train = DataLoader(data, batch_size=8, shuffle=True, drop_last=True)
# dl_valid = DataLoader(data, batch_size=8, shuffle=True, drop_last=True)



In [7]:
# from fastai.data.core import *

# cats, conts, ys = DataLoaders(dl_train, dl_valid).one_batch()

In [8]:
# cats.shape, conts.shape, ys.shape

In [9]:
# n_hidden = 5
# gen_structure = [n_z, n_hidden, n_hidden, n_targets]
# dis_structure = [n_targets, n_hidden, n_hidden, 1]

In [10]:
# emb_module = EmbeddingModule(categorical_dimensions=None, embedding_dimensions=[(2,1)])

In [11]:
# generator = TemporalCNN(cnn_structure=gen_structure, batch_norm_cont=False, 
#                                    cnn_type='tcn', 
#                                    final_activation=nn.Sigmoid,
#                                    embedding_module=emb_module, 
#                                    add_embedding_at_layer=[idx for idx in range(len(gen_structure)-2)],
#                        )
# generator

In [12]:
# discriminator = TemporalCNN(cnn_structure=dis_structure, batch_norm_cont=False, 
#                                    cnn_type='tcn', 
#                                    final_activation=nn.Sigmoid,
#                                    embedding_module=emb_module, 
#                                    add_embedding_at_layer=[idx for idx in range(len(dis_structure)-2)],
#                        )
# discriminator

In [13]:
# opt_fct = torch.optim.Adam
# gan_class = GAN
# auxiliary = False
# aux_factor=0.1
# label_noise=0
# label_bias=0

# gen_opt = opt_fct(params=generator.parameters())
# dis_opt = opt_fct(params=discriminator.parameters())

# model = gan_class(generator=generator, discriminator=discriminator, \
#                   gen_optim=gen_opt, dis_optim=dis_opt, n_z=n_z, auxiliary=auxiliary,\
#                   auxiliary_weighting_factor=aux_factor, label_noise=label_noise, label_bias=label_bias)

In [14]:
# x_cat.shape, x_cont.shape, y.shape

In [15]:
# x_cont.shape, y.shape

In [16]:
# epochs = 2
# for e in tqdm(range(epochs)):
#     for x_cat, x_cont, y in dl_train:
#         x_cat = x_cat.to(model.device).long()
#         x_cont = x_cont.to(model.device)
#         y = y.to(model.device)
        
#         model.train_generator(x_cat, x_cont, y)
#         model.train_discriminator(x_cat, y, _)

In [17]:
# from fastrenewables.gan.learner import GANLearner

In [18]:
# n_gen, n_dis = 1,1
# epochs = 5
# lr = 1e-4
# plot_epochs = 2
# learner = GANLearner(gan=model, n_gen=n_gen, n_dis=n_dis)
# learner.fit(dl_train, epochs=epochs, lr=lr, plot_epochs=plot_epochs, save_model=False)
    

In [19]:
#hide

#generator = GANMLP([n_z, n_hidden, n_in], embedding_module=emb, bn_cont=True)
#discriminator = GANMLP([n_in, n_hidden, 1], bn_cont=True)
#gen_opt = torch.optim.Adam(generator.parameters())
#dis_opt = torch.optim.Adam(discriminator.parameters())
#model = GAN(generator, discriminator, gen_opt, dis_opt, n_z=n_z)
#print(model)
#
#for e in tqdm(range(epochs)):
#    for x_cat, x_cont, y in dl:
#        x_cat = x_cat.to(model.device).long()
#        x_cont = x_cont.to(model.device)
#        y = y.to(model.device)
#
#        model.train_discriminator(x_cat, x_cont, y)
#        model.train_generator(x_cat, x_cont, y)
#
#plt.figure()
#plt.plot(model.real_loss, label='Real Loss')
#plt.plot(model.fake_loss, label='Fake Loss')
#plt.legend()
#plt.show()
#
#assert(np.abs(model.real_loss[-1] - model.fake_loss[-1]) < 0.5)
#
#z = model.noise(x_cont)
#x_fake = model.generator(x_cat, z)
#assert((x_fake - x_cont).mean().abs().item() < 0.5)
#
#print('distribution of real data:')
#plot_class_hists(x_cat.cpu(), x_cont.cpu())
#
#print('distribution of generated data:')
#plot_class_hists(x_cat.cpu(), x_fake.cpu().detach())

In [20]:
#hide

#generator = GANMLP([n_z, n_hidden,  n_in], embedding_module=emb, bn_cont=True)
#discriminator = GANMLP([n_in, n_hidden, 1], final_act_fct=nn.Identity, bn_cont=True)
#gen_opt = torch.optim.RMSprop(generator.parameters())
#dis_opt = torch.optim.RMSprop(discriminator.parameters())
#model = WGAN(generator, discriminator, gen_opt, dis_opt, n_z=n_z)
#print(model)
#
#for e in tqdm(range(epochs)):
#    for x_cat, x_cont, y in dl:
#        x_cat = x_cat.to(model.device).long()
#        x_cont = x_cont.to(model.device)
#        y = y.to(model.device)
#
#        model.train_discriminator(x_cat, x_cont, y)
#        model.train_generator(x_cat, x_cont, y)
#
#plt.figure()
#plt.plot(model.real_loss, label='Real Loss')
#plt.plot(model.fake_loss, label='Fake Loss')
#plt.legend()
#plt.show()
#
#z = model.noise(x_cont)
#x_fake = model.generator(x_cat, z)
#
#print('distribution of real data:')
#plot_class_hists(x_cat.cpu(), x_cont.cpu())
#
#print('distribution of generated data:')
#plot_class_hists(x_cat.cpu(), x_fake.cpu().detach())

In [21]:
#hide

#generator = GANMLP([n_z, n_hidden, n_hidden, n_in], embedding_module=emb, bn_cont=True)
#discriminator = GANMLP([n_in, n_hidden, n_hidden], final_act_fct=nn.ReLU, bn_cont=True)
#discriminator = AuxiliaryDiscriminator(basic_discriminator=discriminator, n_classes=n_classes, final_input_size=n_hidden, len_ts=1)
#gen_opt = torch.optim.Adam(generator.parameters())
#dis_opt = torch.optim.Adam(discriminator.parameters())
#model = GAN(generator, discriminator, gen_opt, dis_opt, n_z=n_z, auxiliary=True, auxiliary_weighting_factor=1)
#print(model)
#
#for e in tqdm(range(epochs)):
#    for x_cat, x_cont, y in dl:
#        x_cat = x_cat.to(model.device).long()
#        x_cont = x_cont.to(model.device)
#        y = y.to(model.device)
#
#        model.train_discriminator(x_cat, x_cont, y)
#        model.train_generator(x_cat, x_cont, y)
#
#plt.figure()
#plt.plot(model.real_loss, label='Real Loss')
#plt.plot(model.fake_loss, label='Fake Loss')
#plt.plot([a.item() for a in model.aux_loss], label='Aux Loss')
#plt.legend()
#plt.show()
#
#z = model.noise(x_cont)
#x_fake = model.generator(x_cat, z)
#
#print('distribution of real data:')
#plot_class_hists(x_cat.cpu(), x_cont.cpu())
#
#print('distribution of generated data:')
#plot_class_hists(x_cat.cpu(), x_fake.cpu().detach())