# Import Modules

In [19]:
import numpy as np
import random
import scipy.io as scio
import torch
import torch.nn as nn

from models.unrolled_cgan import Generator, Discriminator, GANLoops

# Define Train function

In [14]:
################### HYPERPARAMS ####################
unrolled_steps = 10
num_updates_per_epoch = 500
d_steps = 1
g_steps = 1
####################################################

def train(dataset, num_epochs=1000, batch_size=32, noise_size = 64, lr=0.0002, device='cpu', class_num=10, tag=""):
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    criterion = nn.BCELoss()
    D = Discriminator(num_classes=class_num).to(device)
    G = Generator(num_classes=class_num).to(device)
    
    fixed_noise = torch.randn([batch_size, noise_size])
    fixed_conditional = torch.randint(0, class_num, (batch_size,))
    
    discriminator_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    generator_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    
    global G_Loss, D_Loss, D_real, D_fake, samples
    G_Loss = []
    D_Loss = []
    D_real = []
    D_fake = []
    samples = []
    index = 0
    
    gan_loops = GANLoops(dataset, class_num, device = device)
    
    for epoch in np.arange(num_epochs):
        D.train()
        G.train()
        
        for t in np.arange(num_updates_per_epoch):
            d_infos = []
            for d_index in range(d_steps):
                d_info = gan_loops.d_loop(G, D, discriminator_optimizer, criterion)
                d_infos.append(d_info)
            d_infos = np.mean(d_infos, 0)
            d_real_loss, d_fake_loss, D_G_z, D_x = d_infos

            g_infos = []
            for g_index in range(g_steps):
                g_info = gan_loops.g_loop(G, D, generator_optimizer, discriminator_optimizer, 
                                          criterion, unrolled_steps=unrolled_steps)
                g_infos.append(g_info)
            g_infos = np.mean(g_infos)
            g_loss = g_infos


            errG = g_loss
            errD = d_real_loss + d_fake_loss
            G_Loss.append(errG)
            D_Loss.append(errD)
            D_real.append(D_x)
            D_fake.append(D_G_z)
            
            if(index%100 == 0):
                print("Epoch:", epoch, "Global Iter: ",index,"Current G Loss : %.2f" % errG,"Current D Loss : %.2f" % errD, end=" ")
                print("D(x): %.2f" %  D_x, "D(G(z)): %.2f" % D_G_z)
            index+=1
        
        # Evaluation
        with torch.no_grad():
            G.eval()
            sr = G(fixed_noise.to(device), fixed_conditional.to(device))
            ## make some visualization or saving file func
            samples.append(sr.detach().cpu())
            
            torch.save(G.state_dict(), f"./model_ckpt/ckpt{epoch}_generator_{tag}.pt")
            torch.save(D.state_dict(), f"./model_ckpt/ckpt{epoch}_discriminator_{tag}.pt")   

# Load Dataset & Data Batch Split

In [15]:
cgan_dataset = scio.loadmat("./cgan_dataset.mat")

In [17]:
batch_size = 32
x = []
y = []
iteration = cgan_dataset["X"].shape[0] // batch_size
for i in range((iteration-1)):
    x.append(cgan_dataset['X'][batch_size*i: batch_size*(i+1)])
    y.append(cgan_dataset['Y'][0][batch_size*i: batch_size*(i+1)])
    
x = np.array(x, dtype=float)
x = np.expand_dims(x, 2)
y = np.array(y, dtype=int)

x.shape, y.shape

((695, 32, 1, 200), (695, 32))

# TRAIN

In [None]:
train((x,y), num_epochs=3000, class_num=18, device="cuda", tag="unrolled")

---

# Load Model and Make Dataset

In [20]:
class_num = 18
batch_size = 100
noise_size = 64

model = Generator(num_classes=class_num)
_input = torch.randn([batch_size, noise_size])
_class =  torch.randint(0, class_num, (batch_size,))

### Load from state dict

In [21]:
########################################
i = 0
tag = "unrolled"
########################################

model.load_state_dict(torch.load(f"./model_ckpt/ckpt{i}_generator_{tag}.pt"))

<All keys matched successfully>

In [23]:
results = model(_input, _class)

In [27]:
np.save("./gan_results/gan_output.npy", results.detach().numpy())