In [1]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as SN
import torch
from IPython.display import clear_output

import os
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import pytorch_lightning as pl

In [2]:
data = pd.read_csv("./dataset/creditcard.csv")


In [3]:
data.drop(["Time", "Class"],axis=1, inplace= True)
cuda = True if torch.cuda.is_available() else False
cuda

False

In [4]:
from sklearn.preprocessing import MinMaxScaler as mms
num_scaler = mms(feature_range=(-1,1))
columns = data.columns.tolist()
data[columns] = num_scaler.fit_transform(data[columns])
data_np = data.values

In [5]:
class TabularDataModule(pl.LightningDataModule) :
    def __init__(self, data , batch_size:int = 32 , num_workers:int=3) :
        super().__init__()
        self.data = data
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dims = self.data.shape[1]
    
    def prepare_data(self,) :
        pass

    def setup(self,stage=None) :
        if stage == "fit" or state is None :
            train_length = int(len(self.data)*0.8)
            lengths = [train_length, int(len(self.data)-train_length)]
            self.train, self.val = random_split(self.data, lengths)
        
        if stage == "test" or stage is None :
            self.test = self.data   


    def train_dataloader(self):
        return DataLoader(self.train , batch_size= self.batch_size , num_workers=self.num_workers )

    def valid_dataloader(self):
        return DataLoader(self.val , batch_size= self.batch_size , num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test , batch_size= self.batch_size , num_workers=self.num_workers)



In [6]:
class Generator(nn.Module):
    def __init__(self, latent_dim, out_shape, scaler):
        super().__init__()
        self.out_shape = out_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, self.out_shape),
            nn.Tanh()
        )
        self.scaler = scaler 

    def forward(self, z):
        x = self.model(z)
        # img = img.view(img.size(0), *self.out_shape)
        return x

    def inference(self,z) :
        x = self.model(z).detach().numpy()
        x = self.scaler.inverse_transform(x)
        return x


In [7]:
class Discriminator(nn.Module):
    def __init__(self, out_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(out_shape, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            # nn.ReLU(), ##CHANGED
        )

    def forward(self, out):
        validity = self.model(out)
        return validity

In [12]:
class WGAN(pl.LightningModule): 
    def __init__(self,
                input_dim = None,
                scaler = None,
                latent_dim =100 , 
                lr: float = 0.0002,
                b1: float = 0.5,
                b2: float = 0.999,
                batch_size: int = 64,
                **kwargs) :
        super().__init__()
        self.save_hyperparameters()

        self.generator = Generator(latent_dim=self.hparams.latent_dim,
        out_shape=self.hparams.input_dim,scaler=scaler)

        self.discriminator = Discriminator(out_shape=self.hparams.input_dim)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)
    
    def inference(self,z) :
        return self.generator.inference(z)

    # def adversarial_loss(self, y_hat, y):
    #     return F.mse_loss(y_hat, y) ##CHANGED

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = batch
        # sample noise
        z = torch.randn(x.shape[0], self.hparams.latent_dim)
        z = z.type_as(x)
        # train generator
        if optimizer_idx == 0:

            # generate images
            self.generated_imgs = self(z)

            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(x.size(0), 1)
            valid = valid.type_as(x)

            # adversarial loss is binary cross-entropy
            # g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            g_loss = -torch.mean(self.discriminator(self(z))) ## cahnged
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # # how well can it label as real?
            # valid = torch.ones(x.size(0), 1)
            # valid = valid.type_as(x)

            # real_loss = self.adversarial_loss(self.discriminator(x), valid)

            # # how well can it label as fake?
            # fake = torch.zeros(x.size(0), 1)
            # fake = fake.type_as(x)

            # fake_loss = self.adversarial_loss(
            #     self.discriminator(self(z).detach()), fake)

            # # discriminator loss is the average of these
            # d_loss = (real_loss + fake_loss) / 2

            d_loss = -torch.mean(self.discriminator(x)) + torch.mean(self.discriminator(self(z)))
             
            clip_value = 0.01
            for p in self.discriminator.parameters():
                p.data.clamp_(-clip_value, clip_value)

            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)


In [11]:
dm = TabularDataModule(data_np.astype(np.float32))
model = WGAN(dm.size(),num_scaler)
trainer = pl.Trainer(gpus=0, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type          | Params | In sizes | Out sizes
-----------------------------------------------------------------------
0 | generator     | Generator     | 736 K  | [2, 100] | [2, 29]  
1 | discriminator | Discriminator | 146 K  | ?        | ?        
-----------------------------------------------------------------------
883 K     Trainable params
0         Non-trainable params
883 K     Total params
3.532     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

NameError: name 'clip_value' is not defined

In [84]:
import matplotlib.pyplot as plt
model.eval()
z = torch.randn(1, model.hparams.latent_dim)
model.inference(z)

array([[ 2.3204608 ,  3.4700842 ,  1.2084354 ,  0.3087364 , -6.6636    ,
        -9.814687  , -6.738342  ,  2.5460782 ,  0.18513215, -0.90156585,
         0.32576516,  1.8583136 ,  0.18644534,  0.66480637,  0.0244148 ,
        -1.1541592 ,  0.64006954,  0.6347178 ,  0.18365327, -1.170155  ,
        -2.9694934 , -0.22930194,  0.9808853 , -0.10587221,  0.38431725,
         0.13272496,  2.2867596 , -1.3832966 , 23.648825  ]],
      dtype=float32)