# Bayesian GAN

In [1]:
import gc
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.utils import resample
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

import numpy as np
import random

import pandas as pd


from matplotlib import pyplot as plt

#################### Seed  #####################
# Seed every random number generator touched by this notebook with the same
# value: PyTorch (CPU), NumPy, PyTorch on all CUDA devices and Python's
# `random` module.
torch.manual_seed(1235)
np.random.seed(1235)
torch.cuda.manual_seed_all(1235)
# Request deterministic cuDNN kernels and turn off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(1235)
################################################

################ For KL-prior ######################


def knn_distance(point, sample, k):
    """Return the Euclidean distance from `point` to one of the rows of `sample`.

    The distances from `point` to every row of `sample` are sorted in
    ascending order and the entry at (0-based) position `k` is returned, so
    `k=0` yields the distance to the closest row.

    point:  1-D tensor that broadcasts against the rows of `sample`
    sample: 2-D tensor, one candidate neighbour per row
    k:      0-based rank of the neighbour whose distance is returned
    return: 0-dim tensor holding that distance
    """
    offsets = sample - point
    distances = torch.linalg.norm(offsets, dim=1)
    ascending = torch.sort(distances).values
    return ascending[k]


def verify_sample_shapes(s1, s2, k):
    """Assert that `s1` and `s2` are both 2-D and share their second dimension.

    s1, s2: objects exposing a `.shape` sequence, expected to be [N, D]
    k:      accepted for signature compatibility; not inspected here
    raises: AssertionError when either check fails
    """
    rank_1, rank_2 = len(s1.shape), len(s2.shape)
    # Both samples must be matrices of shape [N, D].
    assert rank_1 == rank_2 and rank_2 == 2
    # The feature dimension D must agree between the two samples.
    assert s1.shape[1] == s2.shape[1]


def naive_estimator(s1, s2, k=1):
    """Estimate the KL divergence D(P|Q) with a brute-force k-NN search.

    For every row of `s1` the distance to its k-th nearest neighbour in `s2`
    and in `s1` is looked up with `knn_distance` (torch operations), and the
    log-ratio of the two distances is accumulated.

        s1: (N_1,D) Sample drawn from distribution P
        s2: (N_2,D) Sample drawn from distribution Q
        k: Number of neighbours considered (default 1)
        return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n_p = len(s1)
    n_q = len(s2)
    dim = float(s1.shape[1])

    # Constant offset of the estimator; the running sum starts from it.
    estimate = np.log(n_q / (n_p - 1))

    for row in s1:
        # `row` is not a member of `s2`, so its k-th neighbour there sits at
        # sorted index k-1 ...
        dist_q = knn_distance(row, s2, k - 1)
        # ... whereas inside `s1` sorted index 0 is `row` itself, so the k-th
        # neighbour sits at index k.
        dist_p = knn_distance(row, s1, k)
        estimate += (dim / n_p) * torch.log(dist_q / dist_p)
    return estimate

#################################
### Reading Dataset from File ###
#################################


# Load the patient matrix from a local .npy file.
# NOTE(review): the path is absolute and machine-specific; the notebook will
# not run elsewhere without editing it.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, so only load files from a trusted source.
input_data = np.load(
    r'C:\Users\preet\OneDrive\Documents\CS578_project\useful_mimic3\patient_matrix.npy', allow_pickle=True)

# Rows are samples and columns are features.
total_samples = input_data.shape[0]
feature_size = input_data.shape[1]
print("total samples:", total_samples)
print("feature size:", feature_size)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#################################
### hyperparams ###
#################################
# Batch size of both dataloaders; the train/test splits are truncated to a
# multiple of this value.
batchSize = 1000
lr = 0.001  # learning rate of the autoencoder's Adam optimizer
lr_g = 0.2  # Generator learning rate
n_epoch_ae = 300  # number of autoencoder epochs
num_gen = 10  # number of G_i sub-networks built inside Generator
# When True the generator emits codes of width autoencoder_inner_dim that are
# passed through the autoencoder's decoder; when False it emits
# feature_size-wide samples directly.
autoencoder_flag = False
autoencoder_inner_dim = 128  # width of the autoencoder bottleneck
DEBUG = False  # enables the tensor-shape debug prints

################ Preparing Dataset #############


#####################
### Dataset Model ###
#####################


class EHRDataset(torch.utils.data.Dataset):
    """Thin `Dataset` wrapper around a 2-D array with one sample per row.

    Items are returned exactly as stored in the wrapped array; the
    `transform` attribute is created but is not applied by `__getitem__`.
    """

    def __init__(self, dataset):
        n_rows, n_cols = dataset.shape[0], dataset.shape[1]
        self.dataset = dataset
        self.sample_size = n_rows
        self.feature_size = n_cols
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Number of rows in the wrapped array."""
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        """Return row `idx` of the wrapped array, untransformed."""
        row = self.dataset[idx]
        return row

##########################
### Dataset Processing ###
##########################


# The first 80% of the rows form the training split (the rows are not
# shuffled before splitting).
train_data = input_data[:int(0.8 * total_samples)]
# print(train_data[0])
# print(type(train_data[0]))
# Drop the trailing rows so that the split is an exact multiple of batchSize
# and every batch is full.
train_data_len = len(train_data)//batchSize*batchSize
train_data = train_data[:train_data_len]

# The remaining 20% form the test split, truncated the same way.
test_data = input_data[int(0.8 * total_samples):]
test_data_len = len(test_data)//batchSize*batchSize
test_data = test_data[:test_data_len]

print('total samples: {}, features: {}'.format(total_samples, feature_size))
print('training data shape: {}, testing data shape: {}, dataset type: {}'.format(
    train_data.shape, test_data.shape, input_data.dtype))
# Batches are drawn in a reshuffled order every epoch.
training_dataloader = DataLoader(
    EHRDataset(dataset=train_data),
    batch_size=batchSize,
    shuffle=True
    # num_workers=opt.n_cpu
)

# NOTE(review): the test loader is shuffled as well; the evaluation code in
# this notebook only aggregates over all batches, so order should not matter.
testing_dataloader = DataLoader(
    EHRDataset(dataset=test_data),
    batch_size=batchSize,
    shuffle=True
    # num_workers=opt.n_cpu
)



################################################

############## NN structures ###################

def initialize_weights(net):
    """Re-initialise the parameters of every supported layer inside `net`.

    * Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); bias left untouched.
    * Linear: weight ~ N(0, 0.02); bias set to zero.
    * BatchNorm1d / BatchNorm2d: weight ~ N(1, 0.02); bias set to zero.

    Layers created without a bias (`bias=False`) or without affine
    parameters (`affine=False`) are handled by skipping the missing
    parameter instead of failing on `None`.

    net: module whose sub-modules (as returned by `net.modules()`) are
         initialised in place
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # weight/bias are None when the layer was built with affine=False
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)



#############################
### Generator Model ###
# https://github.com/mohibeyki/SCorGAN/blob/main/medGAN/MIMIC-III/medGAN.ipynb
#############################

# Output should be 64 * 20
class Generator(nn.Module):
    """Mixture of `num_gen` generators followed by a shared output head.

    The noise batch is split along dim 0 into consecutive chunks; chunk i is
    passed through sub-network ``G_i`` (latent_dim -> genDim -> latent_dim)
    and the results are concatenated again. The concatenated batch then goes
    through a shared head:

    * ``autoencoder_flag`` True: two residual layers producing
      ``autoencoder_inner_dim``-wide codes. Because the first residual adds
      the ``latent_dim``-wide input to a ``genDim``-wide activation, this
      branch requires ``latent_dim == autoencoder_inner_dim``.
    * ``autoencoder_flag`` False: an MLP ending in ``Tanh`` that produces
      ``feature_size``-wide samples.

    latent_dim: width of the input noise vectors
    num_gen:    number of ``G_i`` sub-networks
    """

    def __init__(self, latent_dim, num_gen):
        super().__init__()

        self.name = 'generator'
        self.latent_dim = latent_dim  # default 128
        self.num_gen = num_gen
        self.genDim = autoencoder_inner_dim if autoencoder_flag else feature_size

        if autoencoder_flag:
            self.linear1 = nn.Linear(latent_dim, self.genDim)
            self.bn1 = nn.BatchNorm1d(self.genDim, eps=0.001, momentum=0.01)
            self.activation1 = nn.ReLU()
            # The second layer consumes the genDim-wide output of the first
            # residual layer, so its input width is genDim (it was declared
            # with latent_dim, which is only correct when both are equal).
            self.linear2 = nn.Linear(self.genDim, self.genDim)
            self.bn2 = nn.BatchNorm1d(self.genDim, eps=0.001, momentum=0.01)
            self.activation2 = nn.Tanh()
        else:
            self.model = nn.Sequential(
                nn.Linear(latent_dim, self.genDim//2),
                nn.BatchNorm1d(self.genDim//2, eps=0.001, momentum=0.01),
                nn.ReLU(),
                nn.Linear(self.genDim//2, self.genDim),
                nn.BatchNorm1d(self.genDim, eps=0.001, momentum=0.01),
                nn.Tanh()
            )

        # Plain list used for iteration in forward(); each sub-network is
        # also registered as attribute G_<i> so its parameters are tracked
        # by the module (and appear under that name in the state dict).
        self.gs = []
        for i in range(self.num_gen):
            g = nn.Sequential(
                # input size is z_size
                nn.Linear(latent_dim, self.genDim),
                nn.BatchNorm1d(self.genDim, eps=0.001, momentum=0.01),
                nn.ReLU(inplace=True),

                nn.Linear(self.genDim, latent_dim),
                nn.BatchNorm1d(latent_dim, eps=0.001, momentum=0.01),
                nn.ReLU(inplace=True)

            )
            setattr(self, 'G_{}'.format(i), g)
            self.gs.append(g)

        initialize_weights(self)

    def forward(self, x):
        """Map a noise batch `x` of shape [B, latent_dim] to generated rows.

        Returns a [B, genDim] tensor (codes when ``autoencoder_flag`` is set,
        samples otherwise).
        """
        # Chunk size = ceil(B / num_gen); when B is small some G_i may
        # receive no chunk at all because zip() stops at the shorter input.
        sp_size = (len(x) - 1) // len(self.gs) + 1
        y = []
        for _x, _g in zip(torch.split(x, sp_size, dim=0), self.gs):
            y.append(_g(_x))
        y = torch.cat(y, dim=0)

        if DEBUG:
            print("y dim:",y.size())

        if autoencoder_flag:
            # Layer 1
            residual = y
            temp = self.activation1(self.bn1(self.linear1(y)))
            out1 = temp + residual

            # Layer 2
            residual = out1
            temp = self.activation2(self.bn2(self.linear2(out1)))
            out2 = temp + residual
            return out2
        else:
            output = self.model(y)
            return output




###########################
### Discriminator Model ###
# https://github.com/mohibeyki/SCorGAN/blob/main/medGAN/MIMIC-III/medGAN.ipynb
###########################


class Discriminator(nn.Module):
    """Three-layer MLP mapping a `feature_size`-wide row to a probability.

    Hidden widths are `disDim` and `disDim / 2` with ReLU activations; the
    single output unit is squashed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        self.name = 'discriminator'

        # Discriminator's parameters
        self.disDim = 256
        half_dim = self.disDim // 2

        layers = [
            nn.Linear(feature_size, self.disDim),
            nn.ReLU(),
            nn.Linear(self.disDim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, x):
        """Return the discriminator output for the batch `x`."""
        return self.model(x)


#########################
### AutoEncoder Model ###
#########################

class Autoencoder(nn.Module):
    """Single-hidden-layer autoencoder.

    The encoder maps `feature_size` inputs to `autoencoder_inner_dim` codes
    through a Tanh; the decoder maps codes back to `feature_size` outputs
    through a Sigmoid.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [nn.Linear(feature_size, autoencoder_inner_dim), nn.Tanh()]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [nn.Linear(autoencoder_inner_dim, feature_size), nn.Sigmoid()]
        self.decoder = nn.Sequential(*decoder_layers)

        initialize_weights(self)

    def forward(self, x):
        """Encode `x` and decode it again."""
        code = self.encoder(x)
        return self.decoder(code)

    def decode(self, x):
        """Run only the decoder on the codes `x`."""
        return self.decoder(x)


########################
### AutoEncoder Loss ###
########################

class AutoEncoderLoss(nn.Module):
    """Binary cross-entropy summed over features and averaged over the batch."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the reconstruction loss as a 0-dim tensor.

        input:  reconstructed batch, [B, D]
        target: reference batch, [B, D]
        """
        # Small constant keeping both logarithms finite at 0 and 1.
        epsilon = 1e-12
        positive_part = target * torch.log(input + epsilon)
        negative_part = (1. - target) * torch.log(1. - input + epsilon)
        log_likelihood = positive_part + negative_part
        per_sample = -log_likelihood.sum(dim=1)
        return per_sample.mean(dim=0)


################################################
# Results are stored in a "med_ebgan" folder below the current working
# directory; create it when it does not exist yet.
directory = "med_ebgan"
current_path = os.getcwd()
result_path = os.path.join(current_path, directory)
os.makedirs(result_path, exist_ok=True)

################# EBGAN training ###############

############################
### Model Initialization ###
############################

# Release unreferenced Python objects and cached CUDA memory left over from
# earlier runs of this cell.
gc.collect()
torch.cuda.empty_cache()

latent_dim = 128
autoencoder = Autoencoder()
generator = Generator(latent_dim, num_gen)
discriminator = Discriminator()

# Move the models to the device selected earlier instead of calling .cuda()
# unconditionally, so the notebook also runs on CPU-only machines.
autoencoder.to(device)
generator.to(device)
discriminator.to(device)

# Tensor type used to cast incoming batches to float on the right device.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# NOTE(review): this parameter-group list (generator + decoder at lr 1e-4) is
# built but never handed to an optimizer; optimizer_G below only updates the
# generator's own parameters. Confirm whether the decoder was meant to be
# fine-tuned when autoencoder_flag is True.
generator_params = [{'params': generator.parameters(
)}, {'params': autoencoder.decoder.parameters(), 'lr': 1e-4}]

initial_lr_d = 0.002
optimizer_A = torch.optim.Adam(autoencoder.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(), lr=initial_lr_d, betas=(0.5, 0.999))
optimizer_G = torch.optim.SGD(generator.parameters(), lr=lr_g)



total samples: 46517
feature size: 1358
total samples: 46517, features: 1358
training data shape: (37000, 1358), testing data shape: (9000, 1358), dataset type: int32


In [4]:
# Train the autoencoder for n_epoch_ae epochs; after every epoch report the
# summed test loss and the number of mismatching entries on the test split,
# then store the final weights.
criterion = AutoEncoderLoss()


for epoch in range(n_epoch_ae):
    autoencoder.train()
    for batch in training_dataloader:
        batch = batch.type(Tensor)
        generated = autoencoder(batch)
        loss_A = criterion(generated, batch)
        optimizer_A.zero_grad()
        loss_A.backward()
        optimizer_A.step()

    errors = 0
    testing_loss = 0.0
    autoencoder.eval()
    # Evaluation needs no gradients: skip graph construction and accumulate
    # plain Python numbers so no graph is kept alive across batches.
    with torch.no_grad():
        for batch in testing_dataloader:
            batch = batch.type(Tensor)
            generated = autoencoder(batch)
            # An entry counts as an error when the rounded reconstruction
            # differs from the input by more than 0.5.
            mismatches = torch.abs(generated.round() - batch) > 0.5
            errors += int(mismatches.sum().item())
            testing_loss += criterion(generated, batch).item()

    print("[Epoch {:3d}/{:3d} of autoencoder training] [Loss: {:10.2f}] [errors: {:6d}]".format(
        epoch + 1, n_epoch_ae, testing_loss, errors), flush=True)

torch.save(autoencoder.state_dict(), os.path.join(result_path, 'autoencoder.model'))

[Epoch   1/100 of autoencoder training] [Loss:    7175.44] [errors: 161219]
[Epoch   2/100 of autoencoder training] [Loss:    6231.60] [errors: 148178]
[Epoch   3/100 of autoencoder training] [Loss:    4903.11] [errors: 143242]
[Epoch   4/100 of autoencoder training] [Loss:    3829.39] [errors: 139921]
[Epoch   5/100 of autoencoder training] [Loss:    2803.84] [errors: 135776]
[Epoch   6/100 of autoencoder training] [Loss:    1824.03] [errors: 130974]
[Epoch   7/100 of autoencoder training] [Loss:    1004.38] [errors: 124904]
[Epoch   8/100 of autoencoder training] [Loss:     269.32] [errors: 119574]
[Epoch   9/100 of autoencoder training] [Loss:    -410.26] [errors: 113852]
[Epoch  10/100 of autoencoder training] [Loss:    -962.85] [errors: 109989]
[Epoch  11/100 of autoencoder training] [Loss:   -1452.54] [errors: 105307]
[Epoch  12/100 of autoencoder training] [Loss:   -1863.63] [errors: 101411]
[Epoch  13/100 of autoencoder training] [Loss:   -2217.36] [errors:  98535]
[Epoch  14/1

In [2]:
# Restore the trained autoencoder weights; map_location remaps the stored
# tensors to the current device so a checkpoint written on a GPU can also be
# loaded on a CPU-only machine.
autoencoder.load_state_dict(torch.load(
    os.path.join(result_path, 'autoencoder.model'), map_location=device))


Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=1358, out_features=128, bias=True)
    (1): Tanh()
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=1358, bias=True)
    (1): Sigmoid()
  )
)

In [16]:
# Switch the autoencoder to evaluation mode.
autoencoder.eval()

Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=1358, out_features=128, bias=True)
    (1): Tanh()
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=1358, bias=True)
    (1): Sigmoid()
  )
)

In [6]:
# Count, over the whole test split, the entries whose rounded reconstruction
# differs from the input by more than 0.5.
errors = 0
# No gradients are needed for this evaluation pass.
with torch.no_grad():
    for batch in testing_dataloader:
        batch = batch.type(Tensor)
        generated = autoencoder(batch)
        mismatches = torch.abs(generated.round() - batch) > 0.5
        errors += int(mismatches.sum().item())
# NOTE: the figure is a count of individual matrix entries, not of rows.
print("total number of bad samples: {}".format(errors))

total number of bad digits: 58620


In [2]:
# GAN training: the discriminator is trained with Adam on a decaying
# learning rate, the generator with SGD followed by a momentum correction and
# injected Gaussian noise on its parameters.
N = len(train_data)

# Discriminator learning-rate schedule: lr_d = A / (iteration + B)
A = 0.05
B = 500

# One momentum buffer per generator parameter, on the parameter's device.
M = [torch.zeros_like(par) for par in generator.parameters()]

a = 1            # scale of the momentum correction
beta_1 = 0.9     # decay of the momentum buffers
n_epochs = 1000
tau = 0.001      # scale of the injected parameter noise

bce_loss = torch.nn.BCELoss()

iteration = 0

Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# The equilibrium check below fires at iterations 1, 101, 201, ..., i.e.
# ceil(total_iterations / 100) times. Rounding down (as int(.../100) does)
# overflows the buffers whenever the total is not a multiple of 100.
total_iterations = len(training_dataloader) * n_epochs
num_checks = (total_iterations + 99) // 100
Gen_Dis = torch.zeros(num_checks)
Real_Dis = torch.zeros(num_checks)
js = 0
###########################################

for epoch in range(n_epochs):
    for sample in training_dataloader:

        # ---------------------
        #  Train Discriminator
        # ---------------------
        lr_d = A*(iteration+B)**(-1)

        for param_group in optimizer_D.param_groups:
            param_group['lr'] = lr_d

        real_sample = sample.to(device=device, dtype=torch.float32)
        num_of_samples = real_sample.size(dim=0)

        # Target labels, shaped [B, 1] like the discriminator output.
        real = torch.ones(num_of_samples, 1, device=device)
        fake = torch.zeros(num_of_samples, 1, device=device)

        optimizer_D.zero_grad()

        # Sample random latent variables
        gen_sample = generator(torch.randn((num_of_samples, latent_dim)).to(device))

        if autoencoder_flag:
            gen_sample = autoencoder.decoder(gen_sample)

        # Detach: the discriminator step must not backpropagate into the
        # generator (those gradients were discarded before the generator
        # step anyway).
        D_gen = discriminator(gen_sample.detach())
        D_real = discriminator(real_sample)

        real_loss = bce_loss(D_real, real)
        fake_loss = bce_loss(D_gen, fake)
        d_loss = (real_loss + fake_loss)

        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        #  Train Generator
        # ---------------------

        optimizer_G.zero_grad()
        gen_sample = generator(torch.randn((num_of_samples, latent_dim)).to(device))

        if DEBUG:
            print("gen_sample.size() before decoding:", gen_sample.size())

        if autoencoder_flag:
            gen_sample = autoencoder.decoder(gen_sample)

        if DEBUG:
            print("gen_sample.size() after decoding:", gen_sample.size())

        D_gen = discriminator(gen_sample)
        v_loss = bce_loss(D_gen, real)

        # L2 penalty on the generator parameters, scaled by 20 / N.
        l2 = 0.0
        for p in generator.parameters():
            l2 += (p**2).sum()*20

        g_loss = v_loss+l2/N
        # Alternative regulariser that was tried instead of the L2 term:
        # g_loss = v_loss + naive_estimator(real_sample, gen_sample, k=1)/N*100

        g_loss.backward()
        optimizer_G.step()

        with torch.no_grad():
            # Momentum correction on top of the SGD step.
            for par, mom in zip(generator.parameters(), M):
                par.sub_(a*mom*lr_g/N)

            # Gaussian noise injected into every generator parameter.
            for param in generator.parameters():
                param.add_(torch.randn(param.size()).to(device) * np.sqrt(2*tau*lr_g/N))

            # Exponential moving average of the (rescaled) gradients.
            for par, mom in zip(generator.parameters(), M):
                mom.mul_(beta_1)
                mom.add_((1-beta_1)*par.grad*N)

        iteration += 1

        # Nash equilibrium check: record the mean discriminator output on
        # fresh fake samples and on the current real batch.
        if iteration % 100 == 1:
            with torch.no_grad():
                check_fake = generator(torch.randn(
                    (num_of_samples, latent_dim)).to(device))
                if autoencoder_flag:
                    check_fake = autoencoder.decoder(check_fake)
                Gen_Dis[js] = torch.mean(discriminator(check_fake)).item()
                Real_Dis[js] = torch.mean(discriminator(real_sample)).item()
            js += 1

    print("epoch",epoch,"g_loss",g_loss.item(),"d_loss",d_loss.item(),"E(D(fake))",torch.mean(D_gen).item(),"E(D(real))",torch.mean(D_real).item())


  real = Variable(Tensor(num_of_samples).fill_(1.0), requires_grad=False)


epoch 0 g_loss 10.698472023010254 d_loss 1.2571314573287964 E(D(fake)) 0.4227368235588074 E(D(real)) 0.49597620964050293
epoch 1 g_loss 10.639214515686035 d_loss 1.1008896827697754 E(D(fake)) 0.32931649684906006 E(D(real)) 0.4997137188911438
epoch 2 g_loss 10.385124206542969 d_loss 1.0596849918365479 E(D(fake)) 0.31453296542167664 E(D(real)) 0.5096796751022339
epoch 3 g_loss 10.136439323425293 d_loss 1.0953577756881714 E(D(fake)) 0.3008112609386444 E(D(real)) 0.4875369369983673
epoch 4 g_loss 9.768461227416992 d_loss 1.1097683906555176 E(D(fake)) 0.32704854011535645 E(D(real)) 0.49675440788269043
epoch 5 g_loss 9.436370849609375 d_loss 1.1314208507537842 E(D(fake)) 0.3468097150325775 E(D(real)) 0.5056874752044678
epoch 6 g_loss 9.217110633850098 d_loss 1.1770908832550049 E(D(fake)) 0.33142074942588806 E(D(real)) 0.4750141501426697
epoch 7 g_loss 8.94255256652832 d_loss 1.1747868061065674 E(D(fake)) 0.33619430661201477 E(D(real)) 0.48266786336898804
epoch 8 g_loss 8.652918815612793 d_lo

epoch 69 g_loss 1.750794768333435 d_loss 1.390334129333496 E(D(fake)) 0.518091082572937 E(D(real)) 0.5167074203491211
epoch 70 g_loss 1.7180185317993164 d_loss 1.390707015991211 E(D(fake)) 0.5177038908004761 E(D(real)) 0.5161169767379761
epoch 71 g_loss 1.6862425804138184 d_loss 1.390061616897583 E(D(fake)) 0.5173534154891968 E(D(real)) 0.5160425305366516
epoch 72 g_loss 1.6553239822387695 d_loss 1.3895219564437866 E(D(fake)) 0.5170745849609375 E(D(real)) 0.5160263776779175
epoch 73 g_loss 1.625251293182373 d_loss 1.3902382850646973 E(D(fake)) 0.5168657302856445 E(D(real)) 0.5154598355293274
epoch 74 g_loss 1.5960667133331299 d_loss 1.3901344537734985 E(D(fake)) 0.5167121291160583 E(D(real)) 0.5153355002403259
epoch 75 g_loss 1.5678025484085083 d_loss 1.3897544145584106 E(D(fake)) 0.5165628790855408 E(D(real)) 0.5153639316558838
epoch 76 g_loss 1.5404770374298096 d_loss 1.3892927169799805 E(D(fake)) 0.5163795351982117 E(D(real)) 0.515402615070343
epoch 77 g_loss 1.5140929222106934 d_lo

epoch 137 g_loss 0.8290218710899353 d_loss 1.3897690773010254 E(D(fake)) 0.5081712603569031 E(D(real)) 0.5066026449203491
epoch 138 g_loss 0.8268590569496155 d_loss 1.388775110244751 E(D(fake)) 0.5073831677436829 E(D(real)) 0.5062816143035889
epoch 139 g_loss 0.8250089287757874 d_loss 1.388009786605835 E(D(fake)) 0.5064778923988342 E(D(real)) 0.5057373046875
epoch 140 g_loss 0.8232433795928955 d_loss 1.386867880821228 E(D(fake)) 0.5055782794952393 E(D(real)) 0.5053892135620117
epoch 141 g_loss 0.8213327527046204 d_loss 1.3855102062225342 E(D(fake)) 0.5047959685325623 E(D(real)) 0.505271852016449
epoch 142 g_loss 0.8191914558410645 d_loss 1.383667230606079 E(D(fake)) 0.5041784048080444 E(D(real)) 0.5055640935897827
epoch 143 g_loss 0.8167522549629211 d_loss 1.382112741470337 E(D(fake)) 0.5037577152252197 E(D(real)) 0.5059199333190918
epoch 144 g_loss 0.8140862584114075 d_loss 1.380455493927002 E(D(fake)) 0.5034926533699036 E(D(real)) 0.5064957737922668
epoch 145 g_loss 0.809157133102417

epoch 205 g_loss 0.7069894075393677 d_loss 1.3832964897155762 E(D(fake)) 0.5209541320800781 E(D(real)) 0.5241629481315613
epoch 206 g_loss 0.6916112303733826 d_loss 1.4018210172653198 E(D(fake)) 0.5288916230201721 E(D(real)) 0.5231207609176636
epoch 207 g_loss 0.6811627149581909 d_loss 1.416473627090454 E(D(fake)) 0.5343993306159973 E(D(real)) 0.5216003656387329
epoch 208 g_loss 0.6800603270530701 d_loss 1.4199061393737793 E(D(fake)) 0.5349624752998352 E(D(real)) 0.5204564332962036
epoch 209 g_loss 0.6841068863868713 d_loss 1.417550802230835 E(D(fake)) 0.5327423810958862 E(D(real)) 0.519139289855957
epoch 210 g_loss 0.6890139579772949 d_loss 1.4171886444091797 E(D(fake)) 0.5300319790840149 E(D(real)) 0.5162543654441833
epoch 211 g_loss 0.6907625198364258 d_loss 1.4171152114868164 E(D(fake)) 0.5289874076843262 E(D(real)) 0.5151288509368896
epoch 212 g_loss 0.6929490566253662 d_loss 1.418137550354004 E(D(fake)) 0.5277097225189209 E(D(real)) 0.5131789445877075
epoch 213 g_loss 0.696635246

epoch 273 g_loss 0.717243492603302 d_loss 1.3965333700180054 E(D(fake)) 0.5065591931343079 E(D(real)) 0.5015403032302856
epoch 274 g_loss 0.7180468440055847 d_loss 1.3964996337890625 E(D(fake)) 0.5061221718788147 E(D(real)) 0.5011101365089417
epoch 275 g_loss 0.7188937067985535 d_loss 1.3960773944854736 E(D(fake)) 0.5056608319282532 E(D(real)) 0.500851571559906
epoch 276 g_loss 0.7196083068847656 d_loss 1.3957678079605103 E(D(fake)) 0.5052643418312073 E(D(real)) 0.5005986094474792
epoch 277 g_loss 0.7202295660972595 d_loss 1.3955702781677246 E(D(fake)) 0.5049128532409668 E(D(real)) 0.5003470778465271
epoch 278 g_loss 0.7206553816795349 d_loss 1.3951404094696045 E(D(fake)) 0.5046650767326355 E(D(real)) 0.5003131628036499
epoch 279 g_loss 0.720604658126831 d_loss 1.3954534530639648 E(D(fake)) 0.5046579837799072 E(D(real)) 0.5001477003097534
epoch 280 g_loss 0.7204632759094238 d_loss 1.3957431316375732 E(D(fake)) 0.5046915411949158 E(D(real)) 0.5000318288803101
epoch 281 g_loss 0.72037363

epoch 341 g_loss 0.7216761708259583 d_loss 1.391875982284546 E(D(fake)) 0.5036349892616272 E(D(real)) 0.5009353756904602
epoch 342 g_loss 0.7220112681388855 d_loss 1.3928816318511963 E(D(fake)) 0.5034793615341187 E(D(real)) 0.5002807974815369
epoch 343 g_loss 0.7233982086181641 d_loss 1.3928163051605225 E(D(fake)) 0.5027785301208496 E(D(real)) 0.4996134340763092
epoch 344 g_loss 0.7247170805931091 d_loss 1.3927712440490723 E(D(fake)) 0.5020987391471863 E(D(real)) 0.49894407391548157
epoch 345 g_loss 0.7257304191589355 d_loss 1.3924956321716309 E(D(fake)) 0.5015702247619629 E(D(real)) 0.4985562264919281
epoch 346 g_loss 0.7259218096733093 d_loss 1.3929369449615479 E(D(fake)) 0.5014389753341675 E(D(real)) 0.49819108843803406
epoch 347 g_loss 0.7256107926368713 d_loss 1.3930985927581787 E(D(fake)) 0.5015612840652466 E(D(real)) 0.49821990728378296
epoch 348 g_loss 0.7251038551330566 d_loss 1.3933908939361572 E(D(fake)) 0.5017954111099243 E(D(real)) 0.4983043372631073
epoch 349 g_loss 0.724

epoch 409 g_loss 0.7195353507995605 d_loss 1.3901352882385254 E(D(fake)) 0.5050246715545654 E(D(real)) 0.503237783908844
epoch 410 g_loss 0.7200562953948975 d_loss 1.3910672664642334 E(D(fake)) 0.5047808289527893 E(D(real)) 0.5025389194488525
epoch 411 g_loss 0.7207293510437012 d_loss 1.3923981189727783 E(D(fake)) 0.5044553279876709 E(D(real)) 0.5015164017677307
epoch 412 g_loss 0.7216809988021851 d_loss 1.39241623878479 E(D(fake)) 0.5039841532707214 E(D(real)) 0.5010248422622681
epoch 413 g_loss 0.722724199295044 d_loss 1.3924990892410278 E(D(fake)) 0.503463089466095 E(D(real)) 0.5004493594169617
epoch 414 g_loss 0.7235986590385437 d_loss 1.3926823139190674 E(D(fake)) 0.5030169486999512 E(D(real)) 0.49990156292915344
epoch 415 g_loss 0.7244849801063538 d_loss 1.3927738666534424 E(D(fake)) 0.5025722980499268 E(D(real)) 0.4993995130062103
epoch 416 g_loss 0.7252686023712158 d_loss 1.3931617736816406 E(D(fake)) 0.502184271812439 E(D(real)) 0.4988108277320862
epoch 417 g_loss 0.7260288596

epoch 477 g_loss 0.7287042140960693 d_loss 1.3904340267181396 E(D(fake)) 0.49972090125083923 E(D(real)) 0.49768903851509094
epoch 478 g_loss 0.7286074161529541 d_loss 1.3908627033233643 E(D(fake)) 0.4997599422931671 E(D(real)) 0.49751216173171997
epoch 479 g_loss 0.7285595536231995 d_loss 1.391418695449829 E(D(fake)) 0.4997751712799072 E(D(real)) 0.497252494096756
epoch 480 g_loss 0.7286895513534546 d_loss 1.391465187072754 E(D(fake)) 0.4996989667415619 E(D(real)) 0.49716103076934814
epoch 481 g_loss 0.7286547422409058 d_loss 1.3914875984191895 E(D(fake)) 0.4996984899044037 E(D(real)) 0.4971442222595215
epoch 482 g_loss 0.7283825874328613 d_loss 1.3916410207748413 E(D(fake)) 0.4998159110546112 E(D(real)) 0.4971746504306793
epoch 483 g_loss 0.7278952598571777 d_loss 1.3915865421295166 E(D(fake)) 0.5000432729721069 E(D(real)) 0.49742212891578674
epoch 484 g_loss 0.727410078048706 d_loss 1.3914281129837036 E(D(fake)) 0.5002764463424683 E(D(real)) 0.49773743748664856
epoch 485 g_loss 0.726

epoch 545 g_loss 0.7281749248504639 d_loss 1.3915438652038574 E(D(fake)) 0.5000314712524414 E(D(real)) 0.4974324107170105
epoch 546 g_loss 0.7277683615684509 d_loss 1.3917546272277832 E(D(fake)) 0.5002373456954956 E(D(real)) 0.49753573536872864
epoch 547 g_loss 0.7272170782089233 d_loss 1.3919003009796143 E(D(fake)) 0.5005194544792175 E(D(real)) 0.49774497747421265
epoch 548 g_loss 0.7266589403152466 d_loss 1.3918063640594482 E(D(fake)) 0.5008159279823303 E(D(real)) 0.49808719754219055
epoch 549 g_loss 0.7262903451919556 d_loss 1.3916040658950806 E(D(fake)) 0.5010212063789368 E(D(real)) 0.4983938932418823
epoch 550 g_loss 0.7261968851089478 d_loss 1.3918935060501099 E(D(fake)) 0.501089870929718 E(D(real)) 0.4983241856098175
epoch 551 g_loss 0.7263216376304626 d_loss 1.3919624090194702 E(D(fake)) 0.5010513663291931 E(D(real)) 0.49825960397720337
epoch 552 g_loss 0.7266736626625061 d_loss 1.3920683860778809 E(D(fake)) 0.50089430809021 E(D(real)) 0.4980465769767761
epoch 553 g_loss 0.7272

epoch 613 g_loss 0.7292415499687195 d_loss 1.3862783908843994 E(D(fake)) 0.49917787313461304 E(D(real)) 0.4992286264896393
epoch 614 g_loss 0.7304964661598206 d_loss 1.3861256837844849 E(D(fake)) 0.4985419511795044 E(D(real)) 0.49866220355033875
epoch 615 g_loss 0.7315810918807983 d_loss 1.386292815208435 E(D(fake)) 0.4979807138442993 E(D(real)) 0.49802565574645996
epoch 616 g_loss 0.7323033809661865 d_loss 1.3862831592559814 E(D(fake)) 0.4975917339324951 E(D(real)) 0.49763503670692444
epoch 617 g_loss 0.7325938940048218 d_loss 1.3865097761154175 E(D(fake)) 0.49741506576538086 E(D(real)) 0.4973408579826355
epoch 618 g_loss 0.7324305176734924 d_loss 1.3869794607162476 E(D(fake)) 0.49746719002723694 E(D(real)) 0.4971523582935333
epoch 619 g_loss 0.7318883538246155 d_loss 1.3874894380569458 E(D(fake)) 0.49770745635032654 E(D(real)) 0.49713295698165894
epoch 620 g_loss 0.7310605049133301 d_loss 1.3877838850021362 E(D(fake)) 0.498097687959671 E(D(real)) 0.4973663091659546
epoch 621 g_loss 0

epoch 680 g_loss 0.7306692004203796 d_loss 1.3872891664505005 E(D(fake)) 0.49827489256858826 E(D(real)) 0.49779409170150757
epoch 681 g_loss 0.7301167845726013 d_loss 1.387413740158081 E(D(fake)) 0.4985303282737732 E(D(real)) 0.4979822039604187
epoch 682 g_loss 0.7293972969055176 d_loss 1.3870937824249268 E(D(fake)) 0.4988696873188019 E(D(real)) 0.4984775483608246
epoch 683 g_loss 0.7285966873168945 d_loss 1.387001633644104 E(D(fake)) 0.4992564916610718 E(D(real)) 0.49890896677970886
epoch 684 g_loss 0.7278070449829102 d_loss 1.3868695497512817 E(D(fake)) 0.49964550137519836 E(D(real)) 0.49936580657958984
epoch 685 g_loss 0.7270050644874573 d_loss 1.3869513273239136 E(D(fake)) 0.5000491142272949 E(D(real)) 0.49973341822624207
epoch 686 g_loss 0.72626793384552 d_loss 1.3866944313049316 E(D(fake)) 0.5004270076751709 E(D(real)) 0.5002442002296448
epoch 687 g_loss 0.7256820797920227 d_loss 1.386488437652588 E(D(fake)) 0.5007326602935791 E(D(real)) 0.5006580948829651
epoch 688 g_loss 0.7253

epoch 748 g_loss 0.726153552532196 d_loss 1.3864760398864746 E(D(fake)) 0.5006442070007324 E(D(real)) 0.5005807280540466
epoch 749 g_loss 0.7261162400245667 d_loss 1.3860077857971191 E(D(fake)) 0.5006799101829529 E(D(real)) 0.5008524656295776
epoch 750 g_loss 0.726233720779419 d_loss 1.3868753910064697 E(D(fake)) 0.5006282925605774 E(D(real)) 0.5003688931465149
epoch 751 g_loss 0.7265669107437134 d_loss 1.3860037326812744 E(D(fake)) 0.5004751086235046 E(D(real)) 0.500654935836792
epoch 752 g_loss 0.7269761562347412 d_loss 1.386040449142456 E(D(fake)) 0.5002772212028503 E(D(real)) 0.5004395842552185
epoch 753 g_loss 0.7274782657623291 d_loss 1.3862340450286865 E(D(fake)) 0.5000311732292175 E(D(real)) 0.5000937581062317
epoch 754 g_loss 0.7280347347259521 d_loss 1.3857896327972412 E(D(fake)) 0.49975231289863586 E(D(real)) 0.5000391006469727
epoch 755 g_loss 0.7285947799682617 d_loss 1.3847784996032715 E(D(fake)) 0.4994746446609497 E(D(real)) 0.5002761483192444
epoch 756 g_loss 0.72904145

epoch 816 g_loss 0.7273269891738892 d_loss 1.3882148265838623 E(D(fake)) 0.5002338290214539 E(D(real)) 0.4992889165878296
epoch 817 g_loss 0.7272886037826538 d_loss 1.38798987865448 E(D(fake)) 0.5002561807632446 E(D(real)) 0.4994237422943115
epoch 818 g_loss 0.7271880507469177 d_loss 1.388063907623291 E(D(fake)) 0.5003035664558411 E(D(real)) 0.49943795800209045
epoch 819 g_loss 0.7272054553031921 d_loss 1.3875353336334229 E(D(fake)) 0.5003025531768799 E(D(real)) 0.4997091293334961
epoch 820 g_loss 0.7271929383277893 d_loss 1.3879644870758057 E(D(fake)) 0.5003141164779663 E(D(real)) 0.4995008111000061
epoch 821 g_loss 0.7270785570144653 d_loss 1.38761305809021 E(D(fake)) 0.5003777146339417 E(D(real)) 0.49974340200424194
epoch 822 g_loss 0.7270908951759338 d_loss 1.387498140335083 E(D(fake)) 0.500377893447876 E(D(real)) 0.49980270862579346
epoch 823 g_loss 0.7271352410316467 d_loss 1.3882858753204346 E(D(fake)) 0.5003618001937866 E(D(real)) 0.49939602613449097
epoch 824 g_loss 0.72743284

epoch 884 g_loss 0.7252109050750732 d_loss 1.3874149322509766 E(D(fake)) 0.5012478232383728 E(D(real)) 0.5007132291793823
epoch 885 g_loss 0.7248386144638062 d_loss 1.3870409727096558 E(D(fake)) 0.501443088054657 E(D(real)) 0.5010966062545776
epoch 886 g_loss 0.7245932817459106 d_loss 1.386987328529358 E(D(fake)) 0.5015951991081238 E(D(real)) 0.5012754201889038
epoch 887 g_loss 0.7245073914527893 d_loss 1.3872514963150024 E(D(fake)) 0.5016600489616394 E(D(real)) 0.5012155771255493
epoch 888 g_loss 0.7244290113449097 d_loss 1.3873913288116455 E(D(fake)) 0.5017210245132446 E(D(real)) 0.5012102127075195
epoch 889 g_loss 0.724480152130127 d_loss 1.387507677078247 E(D(fake)) 0.5017163157463074 E(D(real)) 0.5011470317840576
epoch 890 g_loss 0.7248011231422424 d_loss 1.3878982067108154 E(D(fake)) 0.5015755891799927 E(D(real)) 0.500813364982605
epoch 891 g_loss 0.7252581119537354 d_loss 1.3877692222595215 E(D(fake)) 0.5013691782951355 E(D(real)) 0.5006682872772217
epoch 892 g_loss 0.7258942723

epoch 952 g_loss 0.7275494337081909 d_loss 1.3874866962432861 E(D(fake)) 0.4998815655708313 E(D(real)) 0.49930715560913086
epoch 953 g_loss 0.7275234460830688 d_loss 1.3877099752426147 E(D(fake)) 0.4999048709869385 E(D(real)) 0.49921733140945435
epoch 954 g_loss 0.7276358604431152 d_loss 1.3872379064559937 E(D(fake)) 0.4998548924922943 E(D(real)) 0.4994043707847595
epoch 955 g_loss 0.727886974811554 d_loss 1.3873600959777832 E(D(fake)) 0.49973469972610474 E(D(real)) 0.49922534823417664
epoch 956 g_loss 0.7281976342201233 d_loss 1.3873271942138672 E(D(fake)) 0.49958252906799316 E(D(real)) 0.49909546971321106
epoch 957 g_loss 0.7284983396530151 d_loss 1.3876097202301025 E(D(fake)) 0.4994378387928009 E(D(real)) 0.4988064169883728
epoch 958 g_loss 0.7286868095397949 d_loss 1.3877601623535156 E(D(fake)) 0.49934104084968567 E(D(real)) 0.49862661957740784
epoch 959 g_loss 0.7289155125617981 d_loss 1.3875908851623535 E(D(fake)) 0.4992203116416931 E(D(real)) 0.49859023094177246
epoch 960 g_loss

In [63]:
# Store the trained generator and discriminator weights in the results folder.
torch.save(generator.state_dict(), os.path.join(result_path, 'generator_batch_1000.model'))
torch.save(discriminator.state_dict(), os.path.join(result_path, 'discriminator_batch_1000.model'))

In [3]:
# Generate as many synthetic rows as there are rows in the test split.
num_fake_batches = test_data_len//batchSize
# Start from an empty (0, feature_size) tensor so the result keeps that shape
# even when no batch is generated.
fake_batches = [torch.zeros((0, feature_size), device='cpu')]
# NOTE(review): the generator is left in whatever mode it is currently in
# (training mode after the training cell, so BatchNorm uses batch statistics).
# Call generator.eval() beforehand if running statistics are wanted.
with torch.no_grad():
    for _ in range(num_fake_batches):
        z = torch.randn(batchSize, latent_dim, device=device)
        # Run the generator once and reuse the result.
        generated_batch = generator(z)
        if autoencoder_flag:
            fake_batch = autoencoder.decoder(generated_batch)
        else:
            fake_batch = generated_batch
        fake_batches.append(fake_batch.round().to('cpu'))

# Concatenate once at the end instead of re-allocating on every iteration.
fake_data = torch.cat(fake_batches, 0).numpy()

In [9]:

# Shape of the generated data.
print(fake_data.shape)

(9300, 1358)


In [14]:
# Shape of the held-out test split, for comparison with the generated data.
print(test_data.shape)

(9300, 1358)


In [16]:
# Shape of a single feature column of the test split.
test_data[:, 0].shape

(9300,)

In [4]:
# Compare real and generated data feature by feature with a two-sample
# Kolmogorov-Smirnov test and report the fraction of features whose p-value
# exceeds 0.05.
# NOTE(review): the KS test is designed for continuous distributions; the
# columns here look discrete (rounded outputs, int32 inputs), so the p-values
# should be read with care.
from scipy.stats import kstest
count = 0
for i in range(test_data.shape[1]):
    stat, p_value = kstest(test_data[:, i], fake_data[:, i])    # dimension-wise ks test
    if p_value > 0.05:
        count+=1

# Fraction of features for which the test does not reject at the 5% level.
avg = count/test_data.shape[1]
print(avg)

  return ks_2samp(xvals, yvals, alternative=alternative, method=method)


0.8696612665684831
