# Import Modules

In [1]:
import numpy as np
import random
import scipy.io as scio
import torch
import torch.nn as nn

from models.cgan import Generator, Discriminator

# CGAN

A conditional generative adversarial network (CGAN) is a type of GAN that also takes advantage of labels during the training process.  

Generator — Given a label and random array as input, this network generates data with the same structure as the training data observations corresponding to the same label.  

Discriminator — Given batches of labeled data containing observations from both the training data and generated data from the generator, this network attempts to classify the observations as "real" or "generated".  

![Gan](https://www.mathworks.com/help/examples/nnet/win64/TrainConditionalGenerativeAdversarialNetworkCGANExample_02.png)

## GAN Loss

- Loss_D - discriminator loss calculated as the sum of losses for the all real and all fake batches (log(D(x)) + log(1 - D(G(z)))  : WANT
- Loss_G - generator loss calculated as log(D(G(z)))  : WANT maximize
- D(x) - the average output (across the batch) of the discriminator for the all real batch. This should start close to 1 then theoretically converge to 0.5 when G gets better. Think about why this is.  
- D(G(z)) - average discriminator outputs for the all fake batch. The first number is before D is updated and the second number is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better.  

# Define Train Function

In [2]:
# train
# Loss of original GAN paper.
def train(data, num_epochs=10, batch_size=32, noise_size = 64, lr=0.0002, device='cpu', class_num=10, tag=""):
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    
    # original was MSE
    adversarial_criterion = nn.BCELoss()
    discriminator = Discriminator(num_classes=class_num).to(device)
    generator = Generator(num_classes=class_num).to(device)
    
    fixed_noise = torch.randn([batch_size, noise_size])
    fixed_conditional = torch.randint(0, class_num, (batch_size,))
    
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    
    x = data[0]
    y = data[1]
    global G_Loss, D_Loss, D_real, D_fake, sr_list
    G_Loss = []
    D_Loss = []
    sr_list = []
    D_real = []
    D_fake = []
    
    index = 0
    for epoch in np.arange(num_epochs):
        discriminator.train()
        generator.train()
        
        
        for i, (inputs, target) in enumerate(zip(x,y)):
            inputs = torch.Tensor(inputs).to(device)
            target = torch.IntTensor(target).to(device)
            
            real_label = torch.ones(batch_size, 1).to(device)
            fake_label = torch.zeros(batch_size, 1).to(device)
            
            noise = torch.randn([batch_size, noise_size]).to(device)
            conditional = torch.randint(0, class_num, (batch_size,)).to(device)
            
            # d_loss true
            real_output = discriminator(inputs, target)
            d_loss_real = adversarial_criterion(real_output, real_label)
            D_x = real_output.detach().cpu().mean().item()
            
            # d_loss fake
            fake = generator(noise, conditional)
            fake_output = discriminator(fake.detach(), conditional)
            d_loss_fake = adversarial_criterion(fake_output, fake_label)
            D_G_z = fake_output.detach().cpu().mean().item()
            
            # d_loss total
            d_loss = d_loss_fake + d_loss_real
            
            # d train
            discriminator.zero_grad()
            generator.zero_grad()
            d_loss.backward()
            discriminator_optimizer.step()
            
            # g loss
            fake = generator(noise, conditional)
            fake_output = discriminator(fake.detach(), conditional)
            g_loss = adversarial_criterion(fake_output, real_label)
            
            # g train
            discriminator.zero_grad()
            generator.zero_grad()
            g_loss.backward()
            discriminator_optimizer.step()
            
            errG = g_loss.detach().cpu().mean().item()
            errD = d_loss.detach().cpu().mean().item()
            
            G_Loss.append(errG)
            D_Loss.append(errD)
            
            if(index%100 == 0):
                print("Epoch:", epoch, "Global Iter: ",index,"Current G Loss : %.2f" % errG,"Current D Loss : %.2f" % errD, end=" ")
                print("D(x): %.2f" %  D_x, "D(G(z)): %.2f" % D_G_z)
            index+=1

        # Evaluation
        with torch.no_grad():
            generator.eval()
            sr = generator(fixed_noise.to(device), fixed_conditional.to(device))
            ## make some visualization or saving file func
            # Func(sr, path)
            sr_list.append(sr.detach().cpu())
            
            torch.save(generator.state_dict(), f"./model_ckpt/ckpt{epoch}_generator_{tag}.pt")
            torch.save(discriminator.state_dict(), f"./model_ckpt/ckpt{epoch}_discriminator_{tag}.pt")   

# Load Dataset & Data Batch Split

In [3]:
cgan_dataset = scio.loadmat("./cgan_dataset.mat")

In [4]:
batch_size = 32
x = []
y = []
iteration = cgan_dataset["X"].shape[0] // batch_size
for i in range((iteration-1)):
    x.append(cgan_dataset['X'][batch_size*i: batch_size*(i+1)])
    y.append(cgan_dataset['Y'][0][batch_size*i: batch_size*(i+1)])
    
x = np.array(x, dtype=float)
x = np.expand_dims(x, 2)
y = np.array(y, dtype=int)

x.shape, y.shape

((695, 32, 1, 200), (695, 32))

# TRAIN

In [9]:
train(data=[x,y], num_epochs=1000, batch_size=32, lr=0.00002, device='cuda', class_num=18, tag="")

Epoch: 0 Global Iter:  0 Current G Loss : 0.69 Current D Loss : 1.39 D(x): 0.50 D(G(z)): 0.50
Epoch: 0 Global Iter:  100 Current G Loss : 0.68 Current D Loss : 1.33 D(x): 0.54 D(G(z)): 0.51
Epoch: 0 Global Iter:  200 Current G Loss : 0.54 Current D Loss : 1.15 D(x): 0.77 D(G(z)): 0.58
Epoch: 0 Global Iter:  300 Current G Loss : 0.48 Current D Loss : 1.17 D(x): 0.83 D(G(z)): 0.62
Epoch: 0 Global Iter:  400 Current G Loss : 0.51 Current D Loss : 1.02 D(x): 0.93 D(G(z)): 0.61
Epoch: 0 Global Iter:  500 Current G Loss : 0.52 Current D Loss : 1.19 D(x): 0.77 D(G(z)): 0.60
Epoch: 0 Global Iter:  600 Current G Loss : 0.55 Current D Loss : 0.97 D(x): 0.93 D(G(z)): 0.58
Epoch: 1 Global Iter:  700 Current G Loss : 0.59 Current D Loss : 0.88 D(x): 0.95 D(G(z)): 0.56
Epoch: 1 Global Iter:  800 Current G Loss : 0.63 Current D Loss : 0.92 D(x): 0.88 D(G(z)): 0.54
Epoch: 1 Global Iter:  900 Current G Loss : 0.61 Current D Loss : 0.88 D(x): 0.95 D(G(z)): 0.55


KeyboardInterrupt: 

---

# Load Model and Make Dataset

In [3]:
class_num = 18
batch_size = 100
noise_size = 64

model = Generator(num_classes=class_num)
_input = torch.randn([batch_size, noise_size])
_class =  torch.randint(0, class_num, (batch_size,))

### Load from state dict

In [5]:
########################################
i = 694
tag = ""
########################################

model.load_state_dict(torch.load(f"./model_ckpt/ckpt{i}_generator_{tag}.pt"))

<All keys matched successfully>

In [None]:
results = model(_input, _class)

In [None]:
np.save("./gan_results/gan_output.npy", results.detach().numpy())