# Import Modules

In [1]:
import numpy as np
import random
import scipy.io as scio
import torch
import torch.nn as nn

from models.unrolled_cgan import Generator, Discriminator, GANLoops

# Define Train function

In [2]:
################### HYPERPARAMS ####################
unrolled_steps = 10
num_updates_per_epoch = 500
d_steps = 1
g_steps = 1
####################################################

def train(dataset, num_epochs=1000, batch_size=32, noise_size = 64, lr=0.0002, device='cpu', class_num=10, tag=""):
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    criterion = nn.BCELoss()
    D = Discriminator(num_classes=class_num).to(device)
    G = Generator(num_classes=class_num).to(device)
    
    fixed_noise = torch.randn([batch_size, noise_size])
    fixed_conditional = torch.randint(0, class_num, (batch_size,))
    
    discriminator_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    generator_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    
    global G_Loss, D_Loss, D_real, D_fake, samples
    G_Loss = []
    D_Loss = []
    D_real = []
    D_fake = []
    samples = []
    index = 0
    
    gan_loops = GANLoops(dataset, class_num, device = device)
    
    for epoch in np.arange(num_epochs):
        D.train()
        G.train()
        
        for t in np.arange(num_updates_per_epoch):
            d_infos = []
            for d_index in range(d_steps):
                d_info = gan_loops.d_loop(G, D, discriminator_optimizer, criterion)
                d_infos.append(d_info)
            d_infos = np.mean(d_infos, 0)
            d_real_loss, d_fake_loss, D_G_z, D_x = d_infos

            g_infos = []
            for g_index in range(g_steps):
                g_info = gan_loops.g_loop(G, D, generator_optimizer, discriminator_optimizer, 
                                          criterion, unrolled_steps=unrolled_steps)
                g_infos.append(g_info)
            g_infos = np.mean(g_infos)
            g_loss = g_infos


            errG = g_loss
            errD = d_real_loss + d_fake_loss
            G_Loss.append(errG)
            D_Loss.append(errD)
            D_real.append(D_x)
            D_fake.append(D_G_z)
            
            if(index%100 == 0):
                print("Epoch:", epoch, "Global Iter: ",index,"Current G Loss : %.2f" % errG,"Current D Loss : %.2f" % errD, end=" ")
                print("D(x): %.2f" %  D_x, "D(G(z)): %.2f" % D_G_z)
            index+=1
        
        # Evaluation
        with torch.no_grad():
            G.eval()
            sr = G(fixed_noise.to(device), fixed_conditional.to(device))
            ## make some visualization or saving file func
            samples.append(sr.detach().cpu())
            
            torch.save(G.state_dict(), f"./model_ckpt/ckpt{epoch}_generator_{tag}.pt")
            torch.save(D.state_dict(), f"./model_ckpt/ckpt{epoch}_discriminator_{tag}.pt")   

# Load Dataset & Data Batch Split

In [3]:
cgan_dataset = scio.loadmat("./gan_dataset/cgan_dataset.mat")

In [4]:
batch_size = 32
x = []
y = []
iteration = cgan_dataset["X"].shape[0] // batch_size
for i in range((iteration-1)):
    x.append(cgan_dataset['X'][batch_size*i: batch_size*(i+1)])
    y.append(cgan_dataset['Y'][0][batch_size*i: batch_size*(i+1)])
    
x = np.array(x, dtype=float)
x = np.expand_dims(x, 2)
y = np.array(y, dtype=int)

x.shape, y.shape

((695, 32, 1, 200), (695, 32))

# TRAIN

In [5]:
train((x,y), num_epochs=1000, class_num=18, device="cuda", tag="unrolled")

Epoch: 0 Global Iter:  0 Current G Loss : 0.69 Current D Loss : 1.39 D(x): 0.50 D(G(z)): 0.50
Epoch: 0 Global Iter:  100 Current G Loss : 1.59 Current D Loss : 0.96 D(x): 0.70 D(G(z)): 0.36
Epoch: 0 Global Iter:  200 Current G Loss : 3.03 Current D Loss : 0.98 D(x): 0.68 D(G(z)): 0.35
Epoch: 0 Global Iter:  300 Current G Loss : 3.18 Current D Loss : 0.79 D(x): 0.78 D(G(z)): 0.29
Epoch: 0 Global Iter:  400 Current G Loss : 6.29 Current D Loss : 1.06 D(x): 0.62 D(G(z)): 0.30
Epoch: 1 Global Iter:  500 Current G Loss : 6.13 Current D Loss : 0.78 D(x): 0.74 D(G(z)): 0.27
Epoch: 1 Global Iter:  600 Current G Loss : 9.13 Current D Loss : 0.93 D(x): 0.81 D(G(z)): 0.35
Epoch: 1 Global Iter:  700 Current G Loss : 8.48 Current D Loss : 0.95 D(x): 0.65 D(G(z)): 0.28
Epoch: 1 Global Iter:  800 Current G Loss : 8.04 Current D Loss : 0.95 D(x): 0.76 D(G(z)): 0.35
Epoch: 1 Global Iter:  900 Current G Loss : 9.74 Current D Loss : 0.83 D(x): 0.82 D(G(z)): 0.32
Epoch: 2 Global Iter:  1000 Current G Loss

## Save Train Results

In [6]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go

In [7]:
import datetime

In [8]:
time_tag = datetime.datetime.now().date().strftime("%m%d")

In [9]:
def make_average_epoch(data, iter_in_epoch):
    arr = np.array(data)
    return np.nanmean(np.pad(arr.astype(float), (0, iter_in_epoch - arr.size%iter_in_epoch), mode='constant', constant_values=np.NaN).reshape(-1, iter_in_epoch), axis=1)

In [10]:
iter_in_epoch = x.shape[0]
avg_results = pd.DataFrame(np.array(list(map(lambda x : make_average_epoch(x,iter_in_epoch), [G_Loss, D_Loss, D_real, D_fake]))).T)
avg_results.columns = ["G_Loss", "D_Loss", "D_real", "D_fake"]
data = [go.Scatter(x=avg_results.index, y=avg_results.G_Loss, name="G_Loss"),
        go.Scatter(x=avg_results.index, y=avg_results.D_Loss, name="D_Loss"),
        go.Scatter(x=avg_results.index, y=avg_results.D_real, name="D_real"),
        go.Scatter(x=avg_results.index, y=avg_results.D_fake, name="D_fake")]
fig = go.Figure(data=data)
fig.write_html(f"./gan_loss/cgan_results_averaged_{time_tag}.html")

---

# Load Model and Make Dataset

In [20]:
class_num = 18
batch_size = 100
noise_size = 64

model = Generator(num_classes=class_num)
_input = torch.randn([batch_size, noise_size])
_class =  torch.randint(0, class_num, (batch_size,))

### Load from state dict

In [21]:
########################################
i = 0
tag = "unrolled"
########################################

model.load_state_dict(torch.load(f"./model_ckpt/ckpt{i}_generator_{tag}.pt"))

<All keys matched successfully>

## Make dataset

In [23]:
results = model(_input, _class)

In [27]:
np.save("./gan_results/gan_output.npy", results.detach().numpy())