In [1]:
# default_exp gan.learner

# gan.learner

> A Wasserstein GAN (with weight-clipped critic) and a simple learner that runs its training loop.

In [2]:
# TODOs:
# rework device handling
# plot every n_epochs

In [3]:
# export

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from tqdm import tqdm

In [10]:
# export
# first drafts for the actual learner and model classes

class DummyDataset(torch.utils.data.Dataset):
    """Random stand-in for a tabular dataset with categorical and continuous parts.

    Both feature groups are drawn from a standard normal distribution and stored
    feature-major, i.e. with shape ``(n_feats, n_samples)``. Indexing returns the
    ``idx``-th column of each group as a ``(cat, cont)`` tuple.
    """

    def __init__(self, n_samples=1000, n_cat_feats=10, n_cont_feats=10):
        self.n_samples = n_samples
        # Drawn in order: categorical block first, then continuous block.
        self.cat_data, self.cont_data = (
            torch.randn(n_feats, n_samples) for n_feats in (n_cat_feats, n_cont_feats)
        )

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return tuple(block[:, idx] for block in (self.cat_data, self.cont_data))

class W_Gan(nn.Module):
    """Wasserstein GAN with a weight-clipped critic.

    Bundles a generator, a discriminator (critic) and their optimizers, and
    exposes one-step training routines for each network.

    Args:
        generator: module mapping a noise batch ``z`` to fake samples.
        discriminator: critic module mapping samples to one score per sample.
            For a WGAN this score should be unbounded (no final ReLU/sigmoid).
        gen_optim: optimizer over the generator's parameters.
        dis_optim: optimizer over the discriminator's parameters.
        clip: after every critic update all critic parameters are clamped to
            ``[-clip, clip]`` (the WGAN paper uses 0.01).

    Attributes:
        real_loss: history of ``-D(x_real).mean()``, one entry per critic step.
        fake_loss: history of ``D(G(z)).mean()``, one entry per critic step.
        gen_loss: history of ``-D(G(z)).mean()``, one entry per generator step.
    """

    def __init__(self, generator, discriminator, gen_optim, dis_optim, clip=0.001):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.gen_optim = gen_optim
        self.dis_optim = dis_optim
        self.clip = clip
        self.real_loss = []
        self.fake_loss = []
        self.gen_loss = []

    def train_generator(self, z):
        """Run one generator update on the noise batch ``z``.

        Minimises ``-D(G(z)).mean()``. Gradients also reach the critic's
        parameters, but the critic optimizer is not stepped here, and those
        gradients are zeroed at the start of the next critic update.

        Returns:
            The generator loss as a Python float.
        """
        self.generator.zero_grad()
        y_fake = self.discriminator(self.generator(z))
        loss = -y_fake.mean()
        loss.backward()
        self.gen_optim.step()
        self.gen_loss.append(loss.item())
        return loss.item()

    def train_discriminator(self, z, x_cont):
        """Run one critic update on the real batch ``x_cont`` and noise ``z``.

        Minimises ``D(G(z)).mean() - D(x_real).mean()`` (the negated Wasserstein
        estimate) with a single optimizer step, then clamps every critic
        parameter to ``[-clip, clip]``.

        Returns:
            The total critic loss as a Python float.
        """
        self.discriminator.zero_grad()
        real_loss = -self.discriminator(x_cont).mean()
        # The critic step must not backpropagate into the generator.
        with torch.no_grad():
            x_fake = self.generator(z)
        fake_loss = self.discriminator(x_fake).mean()
        # One combined step: both terms are evaluated against the same critic.
        total = real_loss + fake_loss
        total.backward()
        self.dis_optim.step()
        self.real_loss.append(real_loss.item())
        self.fake_loss.append(fake_loss.item())

        # Clipping must modify parameters in place; rebinding the loop
        # variable (``p = torch.clamp(p, ...)``) would leave weights unchanged.
        with torch.no_grad():
            for p in self.discriminator.parameters():
                p.clamp_(-self.clip, self.clip)
        return total.item()

class GanLearner:
    """Training-loop driver for a GAN wrapper such as ``W_Gan``.

    Args:
        gan: object exposing ``train_discriminator(z, x_cont)`` and
            ``train_generator(z)``, each performing one update step and
            recording its own losses.
    """

    def __init__(self, gan):
        self.gan = gan

    def generate_samples(self, x, n_z=100):
        """Return a standard-normal noise batch of shape ``(x.shape[0], n_z)``.

        Despite the name, this returns generator *input* noise, not generated
        samples. The noise is allocated on ``x``'s device so it matches data
        (and models) that have been moved to an accelerator.
        """
        return torch.randn(x.shape[0], n_z, device=x.device)

    def fit(self, dl, epochs=5, n_gen=1, n_dis=1, n_z=100):
        """Train the wrapped GAN.

        Args:
            dl: iterable of ``(x_cat, x_cont)`` batches; only ``x_cont`` is
                used as real data.
            epochs: number of passes over ``dl``.
            n_gen: generator steps per batch.
            n_dis: critic steps per batch (the WGAN paper uses 5).
            n_z: noise dimension; must equal the generator's input size.
        """
        for _ in tqdm(range(epochs)):
            for _x_cat, x_cont in dl:
                for _ in range(n_dis):
                    z = self.generate_samples(x_cont, n_z)
                    self.gan.train_discriminator(z, x_cont)
                for _ in range(n_gen):
                    z = self.generate_samples(x_cont, n_z)
                    self.gan.train_generator(z)

In [11]:
# Seed once so data, initial weights and noise are reproducible on Run All.
torch.manual_seed(42)

ds = DummyDataset()
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=64, shuffle=True, drop_last=True)

# Sanity-check the shapes of one batch.
y_cat, y_cont = next(iter(dl))
print(y_cat.shape, y_cont.shape)

# Both networks end in a plain Linear layer. The generator must be able to emit
# negative values (the real data is standard normal), and a WGAN critic must
# output an unbounded score. A final BatchNorm + ReLU breaks both requirements.
generator = nn.Sequential(nn.Linear(100, 50), nn.BatchNorm1d(50), nn.ReLU(), nn.Linear(50, 10))
discriminator = nn.Sequential(nn.Linear(10, 50), nn.BatchNorm1d(50), nn.ReLU(), nn.Linear(50, 1))
# RMSprop with lr=5e-5, the setting used in the WGAN paper (Algorithm 1).
gen_optim = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
dis_optim = torch.optim.RMSprop(discriminator.parameters(), lr=5e-5)
gan = W_Gan(generator, discriminator, gen_optim, dis_optim)
learn = GanLearner(gan)
learn.fit(dl)

torch.Size([64, 10]) torch.Size([64, 10])


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 19.58it/s]
