<a href="https://colab.research.google.com/github/shackste/galaxy-generator/blob/separate_py_files/ACGAN_CVAE_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PyTorch implementation of a combined Conditional Variational Autoencoder (CVAE) and Auxiliary Classifier Generative Adversarial Network (ACGAN).

Continuation of work by Mohamad Dia.

# Environment Setup

## Modules

In [None]:
!pip install torchviz
!pip install wandb -qqq

In [None]:
import sys 

import matplotlib.pyplot as plt
from torch import cat, add, ones, zeros

## include separate module files
## alternative to Google Drive: clone the repository into the Colab runtime instead
#!git clone -b separate_py_files https://github.com/shackste/galaxy-generator.git
#sys.path.insert(0,"/content/galaxy-generator/python_modules/")

# Mount Google Drive and put the project's python_modules first on the import path,
# so the project imports below resolve to the copy stored on Drive.
DRIVE_MOUNT_POINT = "/drive"
PYTHON_MODULES_DIR = "/drive/MyDrive/FHNW/galaxy_generator/galaxy-generator/python_modules/"

from google.colab import drive
drive.mount(DRIVE_MOUNT_POINT)

sys.path.insert(0, PYTHON_MODULES_DIR)

from parameter import labels_dim, input_size, parameter
from file_system import root, folder_results
from helpful_functions import summarize, write_generated_galaxy_images_iteration
from sampler import make_training_sample_generator
from dataset import get_x_train, get_labels_train
from loss import loss_discriminator, loss_generator

import pipeline
from pipeline import VAE, VAEGAN
from discriminator import Discriminator4
from encoder import Encoder4
from decoder import Decoder4


In [None]:
from pdb import set_trace

# Track hyperparameter search via [wandb.ai](https://wandb.ai/shackste/galaxy-generator)

In [None]:
track_hyperparameter = False

if track_hyperparameter:
    import wandb
    !wandb login



"""  USAGE
### the following has to be adjusted for every training run
### since hyperparameters can be changed on the fly, you find this in Training section

wandb.init(project="galaxy-generator", # top level identifier
           group="first", # second level identifier, to seperate several groups of tests
           job_type="training", # third level identifier, organize different jobs like training and evaluation
           tags=["first"], # temporary tags to organize different tasks together
           name="first", # bottom level identifier, label of graph in UI
           config=parameter.return_parameter_dict()  # here we fill the hyperparameters
)

## to follow evolution of loss or other measures wich epoch, after each iteration use:
wandb.log({"loss":loss})  ## here we usually pass loss and accuracy measures

## after training is done and all measures are written, finalize with
wandb.finish()
"""


# Basic tests of the neural networks
Check whether the pipeline components work on their own.

In [None]:
# Summarize the discriminator architecture with the project helper `summarize`
# (presumably prints a layer/parameter overview — see helpful_functions).
summarize(Discriminator4)

In [None]:
import numpy as np
from torch import rand


# Encoder: verify the parameter count of the architecture, then run a forward pass
# on a dummy batch of 3 images with matching random labels.
net = Encoder4()

n_parameters_encoder = sum(p.numel() for p in net.parameters())
# Fail loudly instead of printing a True/False that is easy to overlook.
assert n_parameters_encoder == 8508656, \
    f"Encoder4 has {n_parameters_encoder} parameters, expected 8508656"
input_dummy = rand(3, *input_size)
label_dummy = rand(3, labels_dim)

net(input_dummy, label_dummy)

In [None]:
# Decoder: print the total number of weights, then decode a dummy batch of
# 3 latent vectors (with matching random labels) and display the output shape.
net = Decoder4()

print(np.sum([np.prod(tuple(weights.shape)) for weights in net.parameters()]))
input_dummy = rand(3, parameter.latent_dim)
label_dummy = rand(3, labels_dim)

net(input_dummy, label_dummy).shape

In [None]:
# Full VAE forward pass on a dummy batch of 5 images. Shapes come from the shared
# `input_size` / `labels_dim` (like the encoder and decoder checks) instead of a
# hard-coded 3x64x64 image and 3 labels, so the check matches the real data.
net = VAE()
images_dummy = rand(5, *input_size).cuda()
labels_dummy = rand(5, labels_dim).cuda()
pred = net(images_dummy, labels_dummy)
pred.shape

In [None]:
# Full VAEGAN forward pass on a dummy batch of 3 images. Shapes come from the shared
# `input_size` / `labels_dim` instead of a hard-coded 3x64x64 image and 3 labels.
# Displays the output shape and column 0 (presumably the real/fake score — confirm).
net = VAEGAN()
input_dummy = rand(3, *input_size).cuda()
labels_dummy = rand(3, labels_dim).cuda()
pred = net(input_dummy, labels_dummy)
pred.shape, pred[:,0]

# Training Data
Load the training data from files on Google Drive.

In [None]:
# Load the full training set (images and their labels) once;
# N_samples is the number of training images.
x_train, labels_train = get_x_train(), get_labels_train()
N_samples = len(x_train)

# Training

In [None]:
## if you want to change any parameter:
parameter.learning_rate = 0.0002
## the following can also be changed during runtime
parameter.alpha = 1.
parameter.beta = 1.
parameter.gamma = 1.
parameter.delta = 1.
parameter.zeta = 1.


## start training with fresh, untrained networks
pipeline.decoder = Decoder4().cuda()
pipeline.encoder = Encoder4().cuda()
pipeline.discriminator = Discriminator4().cuda()

## if you want to change discriminator, encoder or decoder network:
#pipeline.decoder = YourDecoder().cuda()
## before you create the VAE and VAEGAN

if track_hyperparameter:
    wandb.init(project="galaxy-generator",  # top-level identifier
               group="check_losses",  # second-level identifier, to separate several groups of tests
               job_type="training",  # third-level identifier, organizes different jobs like training and evaluation
               tags=["all one"],  # temporary tags to organize different tasks together
               # bottom-level identifier, label of the graph in the UI; delta is read from
               # `parameter` so the label cannot contradict the value actually used
               name=f"no class loss, delta={parameter.delta}",
               config=parameter.return_parameter_dict()  # here we fill the hyperparameters
    )


## training schedule: `steps` batches per epoch (floor division, so a final
## partial batch is never requested)
epochs = 20
batch_size = 128
steps = N_samples // batch_size
save_interval = 200  # write generated images every `save_interval` iterations

## loss histories, one entry appended per training iteration
discriminator_losses = []
discriminator_losses_real = []
discriminator_losses_fake = []
generator_losses = []

## targets for the discriminator's first output column: 1 = real, 0 = generated
valid = ones((batch_size,1)).cuda()
fake = zeros((batch_size,1)).cuda()

vae = VAE()
vaegan = VAEGAN()

## progress counters, read and advanced by the training loop
iteration = 0
epoch = 0
step = 0





In [None]:
def _train_discriminator(images, labels):
    """Update the discriminator once on real images and on VAE reconstructions.

    Returns the losses (mean of real and fake, real, fake) as Python floats.
    """
    vae.train(False)
    pipeline.discriminator.train(True)
    pipeline.discriminator.zero_grad()

    # detached: the discriminator update needs no graph through (and must not
    # accumulate gradients into) the encoder/decoder
    generated_images = vae(images, labels).detach()
    target_real = cat((valid, labels), dim=1)
    prediction_real = pipeline.discriminator(images)[:, :1+labels_dim]
    target_fake = cat((fake, labels), dim=1)
    prediction_fake = pipeline.discriminator(generated_images)[:, :1+labels_dim]

    d_loss_real = loss_discriminator(target_real, prediction_real)
    d_loss_fake = loss_discriminator(target_fake, prediction_fake)
    # real and fake terms are back-propagated separately (gradient of their sum);
    # the reported loss is their mean
    d_loss_real.backward()
    d_loss_fake.backward()
    pipeline.discriminator.optimizer.step()

    # plain floats: stored tensors would pin autograd graphs and GPU memory
    d_loss_real, d_loss_fake = d_loss_real.item(), d_loss_fake.item()
    return 0.5 * (d_loss_fake + d_loss_real), d_loss_real, d_loss_fake


def _train_generator(images, labels):
    """Update encoder and decoder once against the current discriminator.

    Returns the generator loss as a Python float and the generated images.
    """
    vae.train(True)
    pipeline.discriminator.train(False)
    pipeline.encoder.zero_grad()
    pipeline.decoder.zero_grad()

    generated_images = vae(images, labels)
    # target: discriminator output on the real images with column 0 set to 1 ("valid")
    # and columns 1..labels_dim replaced by the true labels; further columns are kept.
    # Detach before the in-place edits so no graph is built for the target.
    target = pipeline.discriminator(images).detach()
    target[:, 0] = 1
    target[:, 1:1+labels_dim] = labels
    prediction = pipeline.discriminator(generated_images)
    latent = pipeline.encoder(images, labels)

    g_loss = loss_generator(target, prediction, images, generated_images, latent)
    g_loss.backward()
    pipeline.encoder.optimizer.step()
    pipeline.decoder.optimizer.step()
    return g_loss.item(), generated_images


def training_step():
    """Run one discriminator update and one encoder/decoder update on the next batch.

    Draws the batch from the global ``sample_generator``, appends the losses as
    floats to the global history lists (and logs them to wandb when
    ``track_hyperparameter`` is set), increments the global ``iteration``, prints a
    progress line and, every ``save_interval`` iterations, writes the generated images.
    """
    global iteration
    images, labels = next(sample_generator)

    d_loss, d_loss_real, d_loss_fake = _train_discriminator(images, labels)
    g_loss, generated_images = _train_generator(images, labels)

    discriminator_losses.append(d_loss)
    discriminator_losses_fake.append(d_loss_fake)
    discriminator_losses_real.append(d_loss_real)
    generator_losses.append(g_loss)

    if track_hyperparameter:
        ## save measures to wandb.ai
        wandb.log({"loss discriminator": d_loss, "loss generator": g_loss})

    iteration += 1

    print(f"iteration {iteration}, epoch {epoch+1}, batch {step+1}/{steps}, "
          f"disc_loss {d_loss:.5}, (real {d_loss_real:.5}, fake {d_loss_fake:.5} ) gen_loss {g_loss:.5}")

    if not iteration % save_interval:
        write_generated_galaxy_images_iteration(iteration=iteration, images=generated_images.detach().cpu().numpy())


In [None]:
def _plot_costs(path):
    """Save the loss histories (one entry per iteration) to ``path`` on a log scale.

    Entries are converted with ``float`` so both plain numbers and scalar
    (possibly CUDA, grad-tracking) tensors can be plotted.
    """
    fig, ax = plt.subplots()
    ax.plot([float(c) for c in discriminator_losses], label='discriminator cost')
    ax.plot([float(c) for c in generator_losses], label='generator cost')
    ax.plot([float(c) for c in discriminator_losses_fake], label='discriminator cost fake', linestyle=":")
    ax.plot([float(c) for c in discriminator_losses_real], label='discriminator cost real', linestyle="-.")
    ax.set_yscale("log")
    ax.set_xlabel("iteration")
    ax.set_ylabel("cost")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)  # avoid accumulating open figures across epochs


# Counters are globals, so re-running this cell resumes from the current epoch.
while epoch < epochs:
    sample_generator = make_training_sample_generator(batch_size, x_train, labels_train)
    step = 0
    while step < steps:
        training_step()
        step += 1
    epoch += 1

    # save a plot of the costs after every epoch
    _plot_costs(folder_results+"cost_vs_iteration.png")


    ### really save?
    #pipeline.decoder.save()
    #pipeline.encoder.save()
    #pipeline.discriminator.save()


if track_hyperparameter:
    ## after training is done and all measures are written, finalize with
    wandb.finish()


# Testing