In [1]:
# default_exp gan.learner

# gan.learner

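> `GANLearner`: a minimal training loop that alternates discriminator and generator updates for the GAN models in `fastrenewables.gan.model`.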

In [2]:
# export

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from tqdm import tqdm
from fastrenewables.synthetic_data import *
from fastrenewables.gan.model import *
from fastrenewables.tabular.model import EmbeddingModule

#import glob

In [5]:
# export

class GANLearner():
    """Minimal training loop for a GAN wrapper object.

    `gan` should bundle a generator and a discriminator/critic. As used below, it must
    expose: `device`, `to_device(device)`, `noise(x)`, `generator(x_cat, z)`,
    `train_discriminator(x_cat, x_cont, y)`, `train_generator(x_cat, x_cont, y)` and the
    loss histories `real_loss` / `fake_loss` (plotted during `fit`).

    Args:
        gan: the GAN wrapper to train.
        n_gen: number of generator updates per batch.
        n_dis: number of discriminator/critic updates per batch.
    """

    def __init__(self, gan, n_gen=1, n_dis=1):
        self.gan = gan
        self.n_gen = n_gen
        self.n_dis = n_dis

    def generate_samples(self, x):
        """Generate a batch of synthetic samples.

        Args:
            x: tensor passed to `gan.noise` to build the latent input (presumably only its
                batch size matters — TODO confirm in `gan.noise`).

        Returns:
            The generator output, detached from the autograd graph.
        """
        # no_grad avoids building an autograd graph that detach() would discard anyway.
        with torch.no_grad():
            z = self.gan.noise(x)
            # NOTE(review): no categorical input is passed here, unlike during training.
            fake_samples = self.gan.generator(None, z)
        return fake_samples.detach()

    def _prepare_batch(self, x_cat, x_cont, y):
        # Move one batch to the GAN's device and flatten it into the layout the MLP models use.
        device = self.gan.device
        x_cat = x_cat.to(device)
        x_cont = x_cont.to(device)
        y = y.to(device)
        # TODO: this flattening only suits MLP models. Inputs of rank > 2 are merged along
        # dims 1 and 2; already-flat (rank-2) inputs are passed through instead of raising.
        if x_cat.dim() > 2:
            x_cat = x_cat.flatten(1, 2)
        if x_cont.dim() > 2:
            x_cont = x_cont.flatten(1, 2)
        # Integer indices, presumably for the embedding lookup of categorical features.
        x_cat = x_cat.long()
        # Only the first entry along dim 1 of the target is used.
        y = y[:, 0]
        return x_cat, x_cont, y

    def _plot_losses(self):
        # Plot the real/fake loss histories accumulated by the GAN object.
        fig, ax = plt.subplots(figsize=(16, 9))
        ax.plot(self.gan.real_loss, label='Real Loss')
        ax.plot(self.gan.fake_loss, label='Fake Loss')
        ax.set(title='GAN training losses', ylabel='loss')
        ax.legend()
        plt.show()

    def fit(self, dl, epochs=10, plot_epochs=10, save_model=False):
        """Train the GAN for `epochs` passes over `dl`.

        Each batch gets `n_dis` discriminator updates followed by `n_gen` generator updates.

        Args:
            dl: iterable yielding `(x_cat, x_cont, y)` batches.
            epochs: number of passes over `dl`.
            plot_epochs: plot the loss curves every `plot_epochs` epochs. A falsy value
                (0 or None) disables plotting instead of raising.
            save_model: if True, move the GAN to CPU after training. This does NOT write
                anything to disk.
        """
        self.gan.to_device(self.gan.device)
        for e in tqdm(range(epochs)):
            for x_cat, x_cont, y in dl:
                x_cat, x_cont, y = self._prepare_batch(x_cat, x_cont, y)

                for _ in range(self.n_dis):
                    self.gan.train_discriminator(x_cat, x_cont, y)

                for _ in range(self.n_gen):
                    self.gan.train_generator(x_cat, x_cont, y)

            if plot_epochs and (e + 1) % plot_epochs == 0:
                self._plot_losses()

        if save_model:
            self.gan.to_device('cpu')

### Examples:

In [6]:
# Example run: train an 'aux' GAN with MLP generator/discriminator on SineDataset.
SEED = 42
# Seeds torch's default generator (weight init, DataLoader shuffling);
# seed any other RNG SineDataset may use as well.
torch.manual_seed(SEED)

n_z = 100
n_cat_feats = 1
n_cont_feats = 4
n_targets = 1
len_ts = 24
n_samples = 1000
# n_classes is the number of levels of the single embedded categorical feature
# (categorical_dimensions=[n_classes] below); n_cat_feats counts categorical features.
n_classes = 2
batch_size = 256
epochs = 100

gan_type = 'aux'
model_type = 'mlp'

emb = EmbeddingModule(categorical_dimensions=[n_classes])

structure = [n_z, 64, n_cont_feats]

# WGAN critics get several updates per generator update; other GAN types alternate 1:1.
n_gen, n_dis = (1, 4) if gan_type == 'wgan' else (1, 1)

# NOTE(review): the saved run of this cell failed inside fit() with
#   RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x102 and 2448x1536)
# The printed generator's first Linear expects 2448 = (n_z + 2) * len_ts inputs but received
# 102 = n_z + 2 per sample, so noise/embedding appear not to be expanded over len_ts.
# Verify the shapes produced by get_gan_model and the GAN's noise() against SineDataset.
model = get_gan_model(gan_type=gan_type, model_type=model_type, structure=structure,
                      len_ts=len_ts, n_classes=n_classes, emb_module=emb)
print(model)
data = SineDataset(n_samples=n_samples, n_classes=n_classes, n_features=n_cont_feats, noise=0)
dl = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
learner = GANLearner(gan=model, n_gen=n_gen, n_dis=n_dis)
learner.fit(dl, epochs=epochs, plot_epochs=epochs)

GAN(
  (generator): GANMLP(
    (embedding_module): EmbeddingModule(
      (embeddings): ModuleList(
        (0): Embedding(2, 2)
      )
    )
    (model): Sequential(
      (0): Sequential(
        (0): Linear(in_features=2448, out_features=1536, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=1536, out_features=96, bias=True)
      )
      (2): Sigmoid()
    )
  )
  (discriminator): AuxiliaryDiscriminator(
    (basic_discriminator): GANMLP(
      (model): Sequential(
        (0): Sequential(
          (0): Linear(in_features=96, out_features=1536, bias=True)
          (1): LeakyReLU(negative_slope=0.01)
        )
        (1): Sequential(
          (0): Linear(in_features=1536, out_features=24, bias=True)
        )
        (2): Sigmoid()
      )
    )
    (adv_layer): Sequential(
      (0): Linear(in_features=24, out_features=1, bias=True)
      (1): Sigmoid()
    )
    (aux_layer): Sequential(
      (0): Linear(in_features=24, out_featur

  0%|          | 0/100 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x102 and 2448x1536)

In [None]:
# Scratch check: shape of one categorical batch before and after the MLP-style flattening.
x_cat, x_cont, y = next(iter(dl))
x_cat_flat = x_cat.flatten(1, 2)
x_cat.shape, x_cat_flat.shape, x_cat_flat.ravel().shape

# Tests:

In [None]:
# Draw one real batch and generate a synthetic batch from the trained learner.
# NOTE(review): x_cont is passed unflattened here, unlike inside GANLearner.fit.
x_cat, x_cont, y = next(iter(dl))
x_fake = learner.generate_samples(x_cont).cpu()
x_fake.shape, x_fake[:2]

In [None]:
# Compare the first real and first synthetic sample, one panel per continuous feature.
# Only applies when both tensors are rank 3 with features along dim 1 (flat MLP outputs
# are skipped). Checks the tensors themselves instead of the never-defined `n_dim`.
if x_fake.dim() == 3 and x_cont.dim() == 3:
    n_cols = 2
    n_rows = -(-n_cont_feats // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 9), squeeze=False)
    for f, ax in enumerate(axes.flat):
        if f >= n_cont_feats:
            ax.set_visible(False)  # hide unused panels when n_cont_feats is odd
            continue
        ax.plot(x_cont[0, f, :], label='real')
        ax.plot(x_fake[0, f, :], label='fake')
        ax.set(title=f'Feature {f}')
        ax.legend()
    plt.show()