<a href="https://colab.research.google.com/github/shackste/galaxy-generator/blob/separate_py_files/ACGAN_CVAE_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PyTorch implementation of a combined Conditional Variational Autoencoder (CVAE) and Auxiliary Classifier Generative Adversarial Network (ACGAN).

Continuation of work by Mohamad Dia.

# Environment Setup

## Modules

In [None]:
# Install third-party packages via IPython shell escapes (`!` runs a shell command).
!pip install torchviz
!pip install photutils
!pip install statmorph
!pip install wandb -qqq

In [None]:
import sys 
from time import time
from pdb import set_trace

import matplotlib.pyplot as plt
import numpy as np
from torch import cat, add, ones, zeros, tensor_split
import torchvision.models as models

## include separate module files
#!git clone -b separate_py_files https://github.com/shackste/galaxy-generator.git
#sys.path.insert(0,"/content/galaxy-generator/python_modules/")

# Mount Google Drive at /drive so the project's module folder on Drive becomes reachable.
from google.colab import drive
drive.mount("/drive")

# Make the project modules importable. NOTE: hard-coded path into one user's Drive layout; adjust as needed.
sys.path.insert(0,"/drive/MyDrive/FHNW/galaxy_generator/galaxy-generator/python_modules/")

from parameter import labels_dim, input_size, parameter
from file_system import root, folder_results
from helpful_functions import summarize, write_generated_galaxy_images_iteration
from sampler import make_training_sample_generator
from dataset import get_x_train, get_labels_train
from loss import loss_discriminator, loss_generator

import pipeline
from pipeline import VAE, VAEGAN
from discriminator import Discriminator4
from encoder import Encoder4
from decoder import Decoder4

# Default: no Weights & Biases tracking (the following cell sets this to True).
track_hyperparameter = False


# Track hyperparameter search via [wandb.ai](https://wandb.ai/shackste/galaxy-generator)

In [None]:
# Enable Weights & Biases tracking for this session.
track_hyperparameter = True

if track_hyperparameter:
    import wandb
    # interactive login via IPython shell escape
    !wandb login



# The string below is only a usage note; it is never executed.
"""  USAGE
### the following has to be adjusted for every training run
### since hyperparameters can be changed on the fly, you find this in Training section

wandb.init(project="galaxy-generator", # top level identifier
           group="first", # second level identifier, to seperate several groups of tests
           job_type="training", # third level identifier, organize different jobs like training and evaluation
           tags=["first"], # temporary tags to organize different tasks together
           name="first", # bottom level identifier, label of graph in UI
           config=parameter.return_parameter_dict()  # here we fill the hyperparameters
)

## to follow evolution of loss or other measures wich epoch, after each iteration use:
wandb.log({"loss":loss})  ## here we usually pass loss and accuracy measures

## after training is done and all measures are written, finalize with
wandb.finish()
"""


# new architecture

In [None]:
# Random dummy batch on the GPU for shape-checking the networks (requires CUDA).
# NOTE(review): assumes 3 colour channels and 64x64 images -- confirm against colors_dim / input_size.
from torch import rand
batch_size = 64
# NOTE(review): overrides the `labels_dim` imported from `parameter`; the imports in the
# next code cell rebind it again. Keep the values in sync.
labels_dim = 37

images = rand(batch_size, 3, 64, 64).cuda()
labels = rand(batch_size, labels_dim).cuda()


In [None]:
# Inspect the ResNet-50 module tree (the bare `net` expression is displayed as cell output).
import torchvision.models as models
net = models.resnet50()
net
### no pretrained inverse available...


In [None]:
from functools import wraps, partial
import numpy as np

from torch import cat, randn, rand, tensor, sum, mean, abs, all, squeeze
from torch.nn import Sequential, \
                     Conv2d, ConvTranspose2d, Linear, \
                     LeakyReLU, Sigmoid, Softmax, Tanh, Softplus, \
                     BatchNorm1d, BatchNorm2d, Flatten


from neuralnetwork import NeuralNetwork, update_networks_on_loss
from parameter import colors_dim, labels_dim, parameter
from additional_layers import Reshape
from sampler import gaussian_sampler, generate_latent, generate_noise, generate_galaxy_labels
from loss import loss_adversarial, loss_reconstruction, loss_kl, loss_class, loss_metric, loss_latent
from labeling import labels_dim, label_group_sizes, class_groups, class_groups_indices, make_galaxy_labels_hierarchical
from weight_functions import constant_weight, rising_weight, falling_weight, cyclical_weight
from accuracy_measures import accuracy_discriminator, accuracy_classifier
from decorators import loss_to_value

## prebuilt model for Encoder
# Randomly initialised ResNet used as a convolutional trunk; N_resnet is the feature count
# the trunk produces before its final fully connected layer (512 for ResNet-18, 2048 for ResNet-50).
# NOTE(review): Encoder and ImageClassifier both wrap `resnet.children()`, so they share these
# module instances (and weights) -- confirm this is intended.
resnet = models.resnet18(pretrained=False)
N_resnet = 512
# alternative, larger trunk:
#resnet = models.resnet50(pretrained=False)
#N_resnet = 2048

## 3x3 kernel
# Default hyperparameters for create_conv_layer / create_iconv_layer.
# stride 2 downsamples in conv layers and upsamples in transposed conv layers.

kernel_size = 3
stride = 2
padding = NeuralNetwork.same_padding(kernel_size)
output_padding = padding

## 5x5 kernel
# Alternative configuration, disabled by enclosing it in a string literal.
# To enable it, prefix the opening ''' with '#' (the closing #''' is then already a comment).
'''
kernel_size = 5
stride = 2
padding = NeuralNetwork.same_padding(kernel_size)
output_padding = 1
#'''

# Galaxy Zoo label groups whose columns the image classifier is trained/evaluated on.
groups_classified = [1, 2, 7]

# Collect the column indices of those groups; the stored indices are shifted by one
# (presumably 1-based in `class_groups_indices`) to index the label tensors.
ixs_labels = []
for g in groups_classified:
    ixs_labels += list(class_groups_indices[g] - 1)
ixs_labels = tensor(ixs_labels)





## weights
# Loss-weight schedules. Each value is a callable `iteration -> weight`, looked up by
# `weighted_loss` (and directly in Pipeline) with the pipeline's current iteration.
# A weight of 0 disables a term: `weighted_loss` then returns 0 without computing the loss.
# NOTE(review): the rising schedules presumably ramp from min_weight at iteration_rise to
# max_weight at iteration_max -- see `rising_weight`.

weight_1 = partial(constant_weight, weight=1.)
weight_0 = partial(constant_weight, weight=0.)
weight_rise_0_3k__1_4k = partial(rising_weight, iteration_rise=3000, iteration_max=4000, max_weight=1., min_weight=0.)
weight_rise_0_6k__1_7k = partial(rising_weight, iteration_rise=6000, iteration_max=7000, max_weight=1., min_weight=0.)
weight_rise_0_10k__1_12k = partial(rising_weight, iteration_rise=10000, iteration_max=12000, max_weight=1., min_weight=0.)

# Currently only "label_image" is active; trailing comments record previously used schedules.
weights = {
    # discriminator
    "discriminate_real" : weight_0, # weight_1,
    "discriminate_reproduced" : weight_0, # weight_1,
    "discriminate_latent" : weight_0, # weight_1,
    "discriminate_labels" : weight_0, # weight_rise_0_10k__1_12k,
    "discriminate_fake_labels" : weight_0, # weight_rise_0_10k__1_12k,
    # generator
    "regularize_latent" : weight_0, # partial(cyclical_weight, full_cycle=2000, rise_cycle=1000, max_weight=1.),
    "reproduce_image" : weight_0, # partial(constant_weight, weight=1e-2),
    "trick_reproduced" : weight_0, # partial(constant_weight, weight=1e1),
    "reproduce_image_labels" : weight_0, # weight_1, # weight_rise_0_6k__1_7k,
    "trick_labels" : weight_0, # weight_rise_0_10k__1_12k,
    "reproduce_labels" : weight_0, # weight_rise_0_10k__1_12k,
    "trick_fake_labels" : weight_0, # weight_rise_0_10k__1_12k,
    "reproduce_fake_labels" : weight_0, # weight_rise_0_10k__1_12k,
    # image classifier
    "image_label" : weight_0,
    # classifier
    "label_image" : weight_1, # weight_rise_0_3k__1_4k,
    "label_reproduced_image" : weight_0, # weight_rise_0_3k__1_4k,
    "label_labels" : weight_0, # weight_rise_0_10k__1_12k,
    "label_fake_labels" : weight_0, # weight_rise_0_10k__1_12k,
    # inverse classifier
    "unity_labels" : weight_0, # weight_rise_0_6k__1_7k,
    "unity_latent" : weight_0, # weight_rise_0_6k__1_7k,


}




## layer creators

def create_conv_layer(input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding):
    """Build one Conv2d -> BatchNorm2d -> LeakyReLU stage.

    The kernel_size/stride/padding defaults are the module-level values at definition time.
    """
    stage = [
        Conv2d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding),
        BatchNorm2d(output_dim, momentum=parameter.momentum),
        LeakyReLU(negative_slope=parameter.negative_slope),
    ]
    return Sequential(*stage)

def create_iconv_layer(input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding):
    """Build one ConvTranspose2d -> BatchNorm2d -> LeakyReLU (upsampling) stage.

    The kernel_size/stride/padding/output_padding defaults are the module-level values
    at definition time.
    """
    stage = [
        ConvTranspose2d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding),
        BatchNorm2d(output_dim, momentum=parameter.momentum),
        LeakyReLU(negative_slope=parameter.negative_slope),
    ]
    return Sequential(*stage)

def create_dense_layer(input_dim, output_dim):
    """Build one Linear -> BatchNorm1d -> LeakyReLU stage."""
    stage = [
        Linear(input_dim, output_dim),
        BatchNorm1d(output_dim, momentum=parameter.momentum),
        LeakyReLU(negative_slope=parameter.negative_slope),
    ]
    return Sequential(*stage)

# Latest accuracy per term name; written by the Pipeline `train_*` methods when train=False.
accuracy = {}


## useful functions



## useful decorators

# Latest (weighted) loss value per term name; written by `weighted_loss` and the `train_*` methods.
losses = {}


def weighted_loss(name):
    """Decorator factory that scales a loss method by the schedule ``weights[name]``.

    The decorated function must be a method: ``args[0]`` is the instance, whose
    ``iteration`` attribute is passed to ``weights[name]`` to get the current weight.
    If the weight is zero, the loss is not computed and a zero CUDA tensor is
    returned. Otherwise the loss is multiplied by the weight, its scalar value is
    stored in ``losses[name]``, and the weighted loss is returned.
    """
    def actual_decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            weight = weights[name](args[0].iteration)
            if not weight:
                # skip the (possibly expensive) loss computation entirely
                return tensor(0.).cuda()
            # Out-of-place multiply: an in-place `*=` on the returned loss can make
            # backward() fail if the loss's last operation saved its output for autograd.
            loss = func(*args, **kwargs) * weight
            losses[name] = loss.item()
            return loss
        return wrapper
    return actual_decorator


## my architecture



class Encoder(NeuralNetwork):
    """Convolutional encoder: image -> (mu, std) of a latent Gaussian.

    A 1x1 conv lifts the colour channels to 8 feature maps. Five conv stages then
    double the channel count each time. A dense layer feeds two parallel linear heads:
    the mean, and a Softplus-positive std.
    NOTE: another class named ``Encoder`` is defined right after this one and replaces it.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        base = 8  # channel count after the first 1x1 conv

        self.conv0 = Sequential(
            Conv2d(colors_dim, base, kernel_size=1, stride=1),
            LeakyReLU(negative_slope=parameter.negative_slope),
        )
        self.conv1 = create_conv_layer(base, 2 * base)
        self.conv2 = create_conv_layer(2 * base, 4 * base)
        self.conv3 = create_conv_layer(4 * base, 8 * base)
        self.conv4 = create_conv_layer(8 * base, 16 * base)
        # 4x4 kernel, no padding (presumably reduces a 4x4 map to 1x1)
        self.conv5 = create_conv_layer(16 * base, 32 * base, kernel_size=4, stride=1, padding=0)

        self.flat = Flatten()
        self.dense = create_dense_layer(32 * base, 64 * base)

        # both heads read the output of self.dense
        self.dense_z_mu = Sequential(
            Linear(64 * base, parameter.latent_dim),
        )
        self.dense_z_std = Sequential(
            Linear(64 * base, parameter.latent_dim),
            Softplus(),
        )
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, images):
        x = images
        for stage in (self.conv0, self.conv1, self.conv2, self.conv3,
                      self.conv4, self.conv5, self.flat, self.dense):
            x = stage(x)
        return self.dense_z_mu(x), self.dense_z_std(x)


class Encoder(NeuralNetwork):
    """ transforms image to latent vector (distribution)

    Uses the trunk built from the module-level ``resnet``: all of its children except
    the final fully connected layer. The pooled features are split in half along the
    feature axis; the first half is returned as the mean, the second as the std.
    NOTE(review): the split presumably requires parameter.latent_dim == N_resnet // 2,
    and the std half is not constrained to be positive (no Softplus).
    NOTE(review): ImageClassifier also wraps ``resnet.children()``, so both networks
    share these trunk modules and weights -- confirm this is intended.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv = Sequential(*(list(resnet.children())[:-1]))
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, images):
        latent = self.conv(images)
        # (batch, C, 1, 1) -> (batch, C). Unlike squeeze(), this keeps the batch axis
        # for a batch of one image, so the split along dim=1 below still works.
        latent = latent.flatten(start_dim=1)
        latent_mu, latent_std = tensor_split(latent, 2, dim=1)
        return latent_mu, latent_std


class Decoder(NeuralNetwork):
    """Generate an image from a latent vector.

    Two dense layers are followed by a reshape to (2**5*8, 1, 1) and five
    transposed-conv upsampling stages. A final 1x1 conv with Sigmoid produces
    colors_dim output channels with values in (0, 1).
    """
    def __init__(self):
        super(Decoder, self).__init__()
        base = 8  # channel count before the output conv

        self.dense1 = create_dense_layer(parameter.latent_dim, 64 * base)
        self.dense2 = create_dense_layer(64 * base, 32 * base)
        self.reshape = Reshape(32 * base, 1, 1)

        # first stage: 4x4 kernel, stride 1, no padding
        self.iconv1 = create_iconv_layer(32 * base, 16 * base, kernel_size=4, stride=1, padding=0, output_padding=0)
        self.iconv2 = create_iconv_layer(16 * base, 8 * base)
        self.iconv3 = create_iconv_layer(8 * base, 4 * base)
        self.iconv4 = create_iconv_layer(4 * base, 2 * base)
        self.iconv5 = create_iconv_layer(2 * base, base)
        self.conv_out = Sequential(
            Conv2d(base, colors_dim, kernel_size=1, stride=1),
            Sigmoid(),
        )

        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, latent):
        x = latent
        for stage in (self.dense1, self.dense2, self.reshape,
                      self.iconv1, self.iconv2, self.iconv3, self.iconv4, self.iconv5,
                      self.conv_out):
            x = stage(x)
        return x


class Discriminator(NeuralNetwork):
    """Score images as real (close to 1) or fake (close to 0).

    A 1x1 conv is followed by five conv stages, a dense layer, and a single-unit
    Sigmoid output.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        base = 8  # channel count after the first 1x1 conv
        self.conv0 = Sequential(
            Conv2d(colors_dim, base, kernel_size=1, stride=1),
            LeakyReLU(negative_slope=parameter.negative_slope),
        )
        self.conv1 = create_conv_layer(base, 2 * base)
        self.conv2 = create_conv_layer(2 * base, 4 * base)
        self.conv3 = create_conv_layer(4 * base, 8 * base)
        self.conv4 = create_conv_layer(8 * base, 16 * base)
        self.conv5 = create_conv_layer(16 * base, 32 * base, kernel_size=4, stride=1, padding=0)

        self.flat = Flatten()
        self.dense1 = create_dense_layer(32 * base, 32 * base)
        self.dense_out = Sequential(
            Linear(32 * base, 1),
            Sigmoid(),
        )
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, images):
        x = images
        for stage in (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.flat, self.dense1, self.dense_out):
            x = stage(x)
        return x


class ImageClassifier(NeuralNetwork):
    """Predict Galaxy Zoo labels, plus a noise vector, directly from images.

    Uses the trunk built from the module-level ``resnet`` (all children except the
    final fully connected layer). It feeds one Softmax head per label group and a
    linear noise head. The group outputs are combined by
    ``make_galaxy_labels_hierarchical``.
    NOTE(review): Encoder also wraps ``resnet.children()``, so both networks share the
    same trunk modules and weights -- confirm this is intended.
    """
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.conv = Sequential(*(list(resnet.children())[:-1]))
        # one Softmax head per label group; registered via setattr, also kept in a list
        self.dense_groups = []
        for i, N_label in enumerate(label_group_sizes):
            layer = f"dense_group{i}"
            setattr(self, layer,
                Sequential(
                    Linear(N_resnet, N_label),
                    Softmax(dim=1)
                )
            )
            self.dense_groups.append( getattr(self, layer) )
        self.dense_noise = Linear(N_resnet, labels_dim)

        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, image):
        x = self.conv(image)
        # (batch, N_resnet, 1, 1) -> (batch, N_resnet). Unlike squeeze(), this keeps the
        # batch axis for a single image, so the Linear / Softmax(dim=1) heads still work.
        x = x.flatten(start_dim=1)
        noise = self.dense_noise(x)
        label_groups = [dense(x) for dense in self.dense_groups]
        ## renormalize label groups to fit settings of Galaxy zoo
        labels = make_galaxy_labels_hierarchical(label_groups)
        return labels, noise


class Classifier(NeuralNetwork):
    """ classifies image from latent vector (distribution).
    Reproduce classification scheme of Galaxy Zoo, see
    https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge/overview/the-galaxy-zoo-decision-tree

    The forward pass first draws a latent sample from (latent_mu, latent_sigma) with
    ``gaussian_sampler``. It then applies four dense layers, one Softmax head per label
    group, and a linear noise head. The group outputs are combined by
    ``make_galaxy_labels_hierarchical``.
    """
    def __init__(self):
        super(Classifier, self).__init__()
        width = 2 * parameter.latent_dim
        self.dense1 = create_dense_layer(parameter.latent_dim, width)
        self.dense2 = create_dense_layer(width, width)
        self.dense3 = create_dense_layer(width, width)
        self.dense4 = create_dense_layer(width, width)
        # all heads read the same hidden features
        self.dense_groups = []
        for i, n_label in enumerate(label_group_sizes):
            head = Sequential(
                Linear(width, n_label),
                Softmax(dim=1),
            )
            setattr(self, f"dense_group{i}", head)  # registers the head as a submodule
            self.dense_groups.append(head)
        self.dense_noise = Linear(width, labels_dim)

        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, latent_mu, latent_sigma):
        hidden = gaussian_sampler(latent_mu, latent_sigma)
        for dense in (self.dense1, self.dense2, self.dense3, self.dense4):
            hidden = dense(hidden)
        noise = self.dense_noise(hidden)
        groups = [head(hidden) for head in self.dense_groups]
        # renormalize label groups to fit the Galaxy Zoo decision tree
        return make_galaxy_labels_hierarchical(groups), noise

class InverseClassifier(NeuralNetwork):
    """Map (labels, noise) to the parameters (mu, sigma) of a latent Gaussian.

    labels and noise are concatenated along dim 1 (2*labels_dim inputs). Two dense
    layers follow, then two parallel linear heads; the sigma head ends in Softplus,
    so sigma is positive.
    """
    def __init__(self):
        super(InverseClassifier, self).__init__()
        width = 4 * labels_dim
        self.dense1 = create_dense_layer(2 * labels_dim, width)
        self.dense2 = create_dense_layer(width, width)

        # both heads read the output of dense2
        self.dense_mu = Sequential(
            Linear(width, parameter.latent_dim),
        )
        self.dense_sigma = Sequential(
            Linear(width, parameter.latent_dim),
            Softplus(),
        )
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, labels, noise):
        joined = cat((labels, noise), dim=1)
        hidden = self.dense2(self.dense1(joined))
        return self.dense_mu(hidden), self.dense_sigma(hidden)




### DCGAN architecture (Redford et al 2015)


class Encoder_(NeuralNetwork):
    """ transforms image to latent vector (distribution)

    DCGAN-style encoder (not used by Pipeline). Four strided conv stages take the
    channels colors_dim -> 128 -> 256 -> 512 -> 1024. The result is flattened and fed
    to two linear heads: the latent mean and a Softplus-positive std.
    NOTE(review): the heads expect a 1024 x 4 x 4 feature map, i.e. presumably
    64 x 64 input images -- confirm against input_size.
    """
    def __init__(self):
        # Must name this class: super(Encoder, self) raised TypeError because an
        # Encoder_ instance is not an instance of Encoder.
        super(Encoder_, self).__init__()

        self.conv1 = create_conv_layer(colors_dim, 128)
        self.conv2 = create_conv_layer(128, 256)
        self.conv3 = create_conv_layer(256, 512)
        self.conv4 = create_conv_layer(512, 1024)

        self.flat = Flatten()

        ## the following take the same input
        self.dense_z_mu = Sequential(
            Linear(1024*4*4, parameter.latent_dim),
        )
        self.dense_z_std = Sequential(
            Linear(1024*4*4, parameter.latent_dim),
            Softplus()
        )
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, images):
        x = self.conv1(images)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flat(x)
        z_mu = self.dense_z_mu(x)
        z_std = self.dense_z_std(x)
        return z_mu, z_std


class Decoder_(NeuralNetwork):
    """ generates image from latent vector

    DCGAN-style decoder (not used by Pipeline). A dense layer is reshaped to
    1024 x 4 x 4, then four transposed-conv stages take the channels
    1024 -> 512 -> 256 -> 128 -> colors_dim.
    NOTE(review): the last stage still applies BatchNorm + LeakyReLU, so outputs are
    not bounded to an image range as in Decoder (Sigmoid).
    """
    def __init__(self):
        # Must name this class: super(Decoder, self) raised TypeError because a
        # Decoder_ instance is not an instance of Decoder.
        super(Decoder_, self).__init__()

        self.dense1 = create_dense_layer(parameter.latent_dim, 4*4*1024)
        self.reshape = Reshape(1024,4,4)

        self.iconv1 = create_iconv_layer(1024, 512)
        self.iconv2 = create_iconv_layer(512, 256)
        self.iconv3 = create_iconv_layer(256, 128)
        # colors_dim (not a hard-coded 3), consistent with Encoder_/Discriminator_ inputs
        self.iconv4 = create_iconv_layer(128, colors_dim)

        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, latent):
        x = self.dense1(latent)
        x = self.reshape(x)
        x = self.iconv1(x)
        x = self.iconv2(x)
        x = self.iconv3(x)
        x = self.iconv4(x)
        return x

class Discriminator_(NeuralNetwork):
    """ discriminate real and fake images

    DCGAN-style discriminator (not used by Pipeline). Four strided conv stages take
    the channels colors_dim -> 128 -> 256 -> 512 -> 1024; a single-unit Sigmoid
    output follows.
    NOTE(review): the output layer expects a 1024 x 4 x 4 feature map, i.e. presumably
    64 x 64 input images.
    """
    def __init__(self):
        # Must name this class: super(Discriminator, self) raised TypeError because a
        # Discriminator_ instance is not an instance of Discriminator.
        super(Discriminator_, self).__init__()

        self.conv1 = create_conv_layer(colors_dim, 128)
        self.conv2 = create_conv_layer(128, 256)
        self.conv3 = create_conv_layer(256, 512)
        self.conv4 = create_conv_layer(512, 1024)

        self.flat = Flatten()

        self.dense = Sequential(
            Linear(1024*4*4, 1),
            Sigmoid(),
        )
        self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)

    def forward(self, images):
        x = self.conv1(images)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flat(x)
        genuine = self.dense(x)
        return genuine



class Pipeline:
    def __init__(self):
        """Create all sub-networks on the GPU and reset the training step counter."""
        # Keep this order: constructors that create new layers draw their initial
        # weights from torch's global RNG, so reordering changes the initialisation.
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.discriminator = Discriminator().cuda()
        self.classifier = Classifier().cuda()
        self.image_classifier = ImageClassifier().cuda()
        self.inverse_classifier = InverseClassifier().cuda()
        self.iteration = 0  # step counter passed to the loss-weight schedules
        self.discriminator_train_interval = 1  # train_step trains the discriminator every n-th step

    def discriminate(self, images):
        """Return the discriminator's real/fake score for ``images``."""
        return self.discriminator(images)


    def classify(self, images, detach=False):
        """Predict (labels, noise) for ``images`` via the encoder and the latent classifier.

        With ``detach=True`` the encoded latent distribution is detached, so gradients
        of the result do not reach the encoder.
        """
        mu, sigma = self.encoder(images)
        if detach:
            mu, sigma = mu.detach(), sigma.detach()
        return self.labels_from_latent(mu, sigma)

    def classify_image(self, images, detach=False):
        """Predict (labels, noise) directly from ``images`` with the image classifier.

        ``detach`` is accepted for signature parity with ``classify`` but is ignored.
        """
        return self.image_classifier(images)

    def labels_from_latent(self, latent_mu, latent_sigma):
        """Return the classifier's (labels, noise) for a latent distribution (mu, sigma)."""
        return self.classifier(latent_mu, latent_sigma)

    def latent_from_images(self, images):
        """Return the encoder's latent distribution (mu, sigma) for ``images``."""
        return self.encoder(images)

    def latent_from_labels(self, labels, noise):
        """Return the inverse classifier's latent distribution (mu, sigma) for (labels, noise)."""
        return self.inverse_classifier(labels, noise)

    def reproduce(self, images, detach=False):
        """Encode ``images``, sample a latent vector and decode it back into images.

        With ``detach=True`` the sampled latent is detached, so gradients of the
        result reach only the decoder.
        """
        mu, sigma = self.encoder(images)
        sample = gaussian_sampler(mu, sigma)
        return self.generate_from_latent(sample.detach() if detach else sample)

    def generate_from_latent(self, latent):
        """Decode a latent vector into images."""
        return self.decoder(latent)

    def generate_from_latent_distribution(self, latent_mu, latent_sigma):
        """Sample a latent vector from (latent_mu, latent_sigma) and decode it into images."""
        return self.generate_from_latent(gaussian_sampler(latent_mu, latent_sigma))

    def generate_from_labels(self, labels, noise):
        """Generate images from class labels (and noise) via the inverse classifier's latent distribution."""
        mu, sigma = self.latent_from_labels(labels, noise)
        return self.generate_from_latent_distribution(mu, sigma)

    ## training
    def train_step(self, images, labels):
        """Run one optimisation step of the sub-networks on a batch.

        Draws random noise, galaxy labels and latent vectors, in that order, and sets
        the real/fake target tensors. It then trains the image classifier (only while
        its weight is non-zero), the classifier, the inverse classifier, the generator,
        and, every ``discriminator_train_interval`` steps, the discriminator.
        Finally it increments ``iteration``.
        """
        n = images.shape[0]
        noise_gen = generate_noise(n)
        labels_gen = generate_galaxy_labels(n)
        latent_gen = generate_latent(n)
        # target values for the adversarial losses
        self.fake = zeros((n, 1)).cuda()
        self.real = ones((n, 1)).cuda()

        generated = (noise_gen, labels_gen, latent_gen)
        if weights["image_label"](self.iteration):
            self.train_image_classifier(images, labels)
        self.train_classifier(images, labels, *generated)
        self.train_inverse_classifier(images, labels, *generated)
        self.train_generator(images, labels, *generated)
        if self.iteration % self.discriminator_train_interval == 0:
            self.train_discriminator(images, labels, *generated)
        self.iteration += 1


    def estimate_accuracy(self, images, labels):
        """Evaluate the sub-networks on a batch without training them.

        Mirrors ``train_step`` but calls each ``train_*`` method with ``train=False``.
        For the methods visible here, that records values in the module-level
        ``accuracy`` dict. Unlike ``train_step``, it always evaluates the discriminator
        and does not advance ``iteration``. Everything runs under ``torch.no_grad()``:
        no backward pass happens, so building autograd graphs would only waste memory.
        """
        from torch import no_grad

        batch_size = images.shape[0]
        with no_grad():
            noise_gen = generate_noise(batch_size)
            labels_gen = generate_galaxy_labels(batch_size)
            latent_gen = generate_latent(batch_size)
            self.fake = zeros((batch_size,1)).cuda() # target values
            self.real = ones((batch_size,1)).cuda() # target values

            if weights["image_label"](self.iteration):
                self.train_image_classifier(images, labels, train=False)
            self.train_classifier(images, labels, noise_gen, labels_gen, latent_gen, train=False)
            self.train_inverse_classifier(images, labels, noise_gen, labels_gen, latent_gen, train=False)
            self.train_generator(images, labels, noise_gen, labels_gen, latent_gen, train=False)
            self.train_discriminator(images, labels, noise_gen, labels_gen, latent_gen, train=False)
        

    def train_step_fails(self, images, labels):
        """Earlier single-pass variant of ``train_step`` (kept for reference).

        Computes every forward path once and updates all networks from the resulting
        losses. A path whose loss weight is zero is replaced by 0 (or (0, 0) where a
        pair is unpacked).
        NOTE(review): some paths use values computed only under a *different* weight,
        so certain weight combinations still fail. ``noise_gen`` exists only when
        "trick_labels" is non-zero but is also used for "unity_labels". ``z`` exists
        only for "unity_labels" but is also used for "trick_fake_labels". The images
        from the "trick_labels" path are also classified for "label_labels".
        """
        batch_size = images.shape[0]
        noise_gen = generate_noise(batch_size) if weights["trick_labels"](self.iteration) else 0
        labels_gen = generate_galaxy_labels(batch_size) if weights["unity_labels"](self.iteration) else 0
        latent_gen = generate_latent(batch_size)

        # starting from images
        d1 = self.discriminate(images)
        latent1 = self.latent_from_images(images)
        labels1, _ = self.labels_from_latent(*latent1)
        images1 = self.generate_from_latent_distribution(*latent1)
        d2 = self.discriminate(images1)
        labels2, _ = self.classify(images1) if weights["label_reproduced_image"](self.iteration) else (0, 0)

        # starting from labels
        imgs = self.generate_from_labels(labels, noise_gen) if weights["trick_labels"](self.iteration) else 0
        d3 = self.discriminate(imgs) if weights["trick_labels"](self.iteration) else 0
        # fallback must be a pair: unpacking a bare 0 raised TypeError
        labels3, _ = self.classify(imgs) if weights["label_labels"](self.iteration) else (0, 0)

        # starting from generated labels
        z = self.latent_from_labels(labels_gen, noise_gen) if weights["unity_labels"](self.iteration) else 0
        labels4, noise4 = self.labels_from_latent(*z) if weights["unity_labels"](self.iteration) else (0, 0)
        imgs = self.generate_from_latent_distribution(*z) if weights["trick_fake_labels"](self.iteration) else 0
        d4 = self.discriminate(imgs) if weights["trick_fake_labels"](self.iteration) else 0
        labels5, _ = self.classify(imgs) if weights["reproduce_fake_labels"](self.iteration) else (0, 0)

        # starting from generated latent
        latent2 = self.latent_from_labels(*self.labels_from_latent(*latent_gen)) if weights["unity_latent"](self.iteration) else 0
        d5 = self.discriminate(self.generate_from_latent_distribution(*latent_gen))

        self.fake = zeros((batch_size,1)).cuda() # target values
        self.real = ones((batch_size,1)).cuda() # target values
        loss_discriminator = self.get_loss_discriminator(d1, d2, d3, d4, d5)
        loss_generator = self.get_loss_generator(images, labels, labels_gen, images1, latent1, d2, d3, d4, labels2, labels3, labels5)
        loss_classifier = self.get_loss_classifier(labels, labels_gen, labels1, labels2, labels3, labels5)
        loss_inverse_classifier = self.get_loss_inverse_classifier(latent_gen, labels_gen, noise_gen, latent2, labels4, noise4)

        update_networks_on_loss(loss_discriminator, self.discriminator)
        update_networks_on_loss(loss_generator, self.encoder, self.decoder)
        update_networks_on_loss(loss_classifier, self.classifier)
        update_networks_on_loss(loss_inverse_classifier, self.inverse_classifier)

        self.iteration += 1

    ## discriminator
    def train_discriminator(self, images, labels, noise_gen, labels_gen, latent_gen, train=True):
        """Train (or, with ``train=False``, evaluate) the discriminator on one batch.

        Real ``images`` should be scored as real. The following should be scored as
        fake: images reproduced from ``images``, images generated from ``labels`` and
        from ``labels_gen`` (both with ``noise_gen``), and images generated from
        ``latent_gen``. Each term is computed only when its weight is non-zero;
        otherwise 0 stands in. Generated images are detached, so the discriminator
        loss does not backpropagate into the encoder, decoder or inverse classifier.
        In train mode the summed loss updates the discriminator and is stored in
        ``losses["discriminator"]``; otherwise per-term accuracies go to ``accuracy``.
        """
        # from images
        d1 = self.discriminate(images) if weights["discriminate_real"](self.iteration) else 0
        if weights["discriminate_reproduced"](self.iteration):
            images1 = self.reproduce(images).detach()
            d2 = self.discriminate(images1)
        else:
            d2 = 0
        # from labels
        if weights["discriminate_labels"](self.iteration):
            imgs = self.generate_from_labels(labels, noise_gen).detach()
            d3 = self.discriminate(imgs)
        else:
            d3 = 0
        # from fake labels
        if weights["discriminate_fake_labels"](self.iteration):
            imgs = self.generate_from_labels(labels_gen, noise_gen).detach()
            d4 = self.discriminate(imgs)
        else:
            d4 = 0
        # from latent
        if weights["discriminate_latent"](self.iteration):
            imgs = self.generate_from_latent_distribution(*latent_gen).detach()
            d5 = self.discriminate(imgs)
        else:
            d5 = 0

        if train:
            loss = self.get_loss_discriminator(d1, d2, d3, d4, d5)
            update_networks_on_loss(loss, self.discriminator)
            losses["discriminator"] = loss.item()
        else:
            accuracy["discriminate_real"] = accuracy_discriminator(self.real, d1)
            accuracy["discriminate_reproduced"] = accuracy_discriminator(self.fake, d2)
            accuracy["discriminate_labels"] = accuracy_discriminator(self.fake, d3)
            accuracy["discriminate_fake_labels"] = accuracy_discriminator(self.fake, d4)
            accuracy["discriminate_latent"] = accuracy_discriminator(self.fake, d5)



    def get_loss_discriminator(self, d1, d2, d3, d4, d5):
        """Sum of the weighted discriminator terms: real, reproduced, labels, fake labels, latent."""
        terms = [
            self.get_loss_discriminate_real(d1),
            self.get_loss_discriminate_reproduced(d2),
            self.get_loss_discriminate_labels(d3),
            self.get_loss_discriminate_fake_labels(d4),
            self.get_loss_discriminate_latent(d5),
        ]
        total = terms[0]
        for term in terms[1:]:
            total += term
        return total

    @weighted_loss("discriminate_real")
    def get_loss_discriminate_real(self, d1):
        """Adversarial loss rewarding a 'real' score ``d1`` on real images."""
        target = self.real
        return loss_adversarial(target, d1)

    @weighted_loss("discriminate_reproduced")
    def get_loss_discriminate_reproduced(self, d2):
        """Adversarial loss rewarding a 'fake' score ``d2`` on reproduced images."""
        target = self.fake
        return loss_adversarial(target, d2)

    @weighted_loss("discriminate_labels")
    def get_loss_discriminate_labels(self, d3):
        """Adversarial loss rewarding a 'fake' score ``d3`` on images generated from real labels."""
        target = self.fake
        return loss_adversarial(target, d3)

    @weighted_loss("discriminate_fake_labels")
    def get_loss_discriminate_fake_labels(self, d4):
        """Adversarial loss rewarding a 'fake' score ``d4`` on images generated from sampled labels."""
        target = self.fake
        return loss_adversarial(target, d4)

    @weighted_loss("discriminate_latent")
    def get_loss_discriminate_latent(self, d5):
        """Adversarial loss rewarding a 'fake' score ``d5`` on images decoded from sampled latents."""
        target = self.fake
        return loss_adversarial(target, d5)


    ## generator
    def train_generator(self, images, labels, noise_gen, labels_gen, latent_gen, train=True):
        """Train (or, with ``train=False``, evaluate) the encoder/decoder on one batch.

        The reconstruction path is always computed. The discriminator/classifier
        outputs of each path are computed only when the corresponding weight is
        non-zero; otherwise 0 (or (0, 0)) stands in. In train mode the summed loss
        updates the encoder and decoder, is stored in ``losses["generator"]``, and a
        NaN loss drops into the pdb debugger. Otherwise per-term accuracies are
        written to ``accuracy``.
        ``latent_gen`` is accepted for signature parity but not used here.
        """
        # from images
        latent1 = self.latent_from_images(images)
        images1 = self.generate_from_latent_distribution(*latent1)
        d2 = self.discriminate(images1) if weights["trick_reproduced"](self.iteration) else 0
        labels2, _ = self.classify(images1) if weights["reproduce_image_labels"](self.iteration) else (0, 0)
        # from labels
        if weights["trick_labels"](self.iteration):
            imgs = self.generate_from_labels(labels, noise_gen)
            d3 = self.discriminate(imgs)
            labels3, _ = self.classify(imgs)
        else:
            d3, labels3 = 0, 0
        # from generated labels
        if weights["trick_fake_labels"](self.iteration):
            z = self.latent_from_labels(labels_gen, noise_gen)
            imgs = self.generate_from_latent_distribution(*z)
            d4 = self.discriminate(imgs)
            labels5, _ = self.classify(imgs)
        else:
            d4, labels5 = 0, 0

        if train:
            loss = self.get_loss_generator(images, labels, labels_gen, images1, latent1, d2, d3, d4, labels2, labels3, labels5)
            update_networks_on_loss(loss, self.encoder, self.decoder)
            losses["generator"] = loss.item()
            # debugging aid: interactive breakpoint on a NaN loss (blocks non-interactive runs)
            if np.isnan(loss.item()):
                set_trace()
        else:
            accuracy["trick_reproduced"] = accuracy_discriminator(self.real, d2)
            accuracy["reproduce_image_labels"] = accuracy_classifier(labels, labels2)
            accuracy["trick_labels"] = accuracy_discriminator(self.real, d3)
            accuracy["reproduce_labels"] = accuracy_classifier(labels, labels3)
            accuracy["trick_fake_labels"] = accuracy_discriminator(self.real, d4)
            accuracy["reproduce_fake_labels"] = accuracy_classifier(labels_gen, labels5)


    def get_loss_generator(self, images, labels, labels_gen, images1, latent1, d2, d3, d4, labels2, labels3, labels5):
        """Sum of the weighted encoder/decoder loss terms.

        In order: latent regularisation (loss_kl) of ``latent1``, image reconstruction,
        fooling the discriminator on reproduced images (``d2``), label agreement for
        reproduced images, fooling it on images from ``labels`` (``d3``), label agreement
        for those, fooling it on images from ``labels_gen`` (``d4``), and label agreement
        for those.
        """
        terms = [
            self.get_loss_regularize_latent(latent1),
            self.get_loss_reproduce_image(images, images1),
            self.get_loss_trick_reproduced(d2),
            self.get_loss_reproduce_image_labels(labels, labels2),
            self.get_loss_trick_labels(d3),
            self.get_loss_reproduce_labels(labels, labels3),
            self.get_loss_trick_fake_labels(d4),
            self.get_loss_reproduce_fake_labels(labels_gen, labels5),
        ]
        total = terms[0]
        for term in terms[1:]:
            total += term
        return total

    @weighted_loss("regularize_latent")
    def get_loss_regularize_latent(self, latent1):
        """KL regularization of the latent encoding ``latent1`` of real images.

        Weighting/gating by iteration is presumably applied by the
        ``@weighted_loss`` decorator (defined elsewhere) — confirm.
        """
        return loss_kl(latent1)

    @weighted_loss("reproduce_image")
    def get_loss_reproduce_image(self, images, images1):
        """Reconstruction loss between real ``images`` and reproduced ``images1``."""
        return loss_reconstruction(images, images1)

    @weighted_loss("trick_reproduced")
    def get_loss_trick_reproduced(self, d2):
        """Adversarial loss: discriminator output ``d2`` on reproduced images vs. the "real" target."""
        return loss_adversarial(self.real, d2)

    @weighted_loss("trick_labels")
    def get_loss_trick_labels(self, d3):
        """Adversarial loss: discriminator output ``d3`` on images generated from true labels vs. "real"."""
        return loss_adversarial(self.real, d3)

    @weighted_loss("trick_fake_labels")
    def get_loss_trick_fake_labels(self, d4):
        """Adversarial loss: discriminator output ``d4`` on images generated from sampled labels vs. "real"."""
        return loss_adversarial(self.real, d4)

    @weighted_loss("reproduce_image_labels")
    def get_loss_reproduce_image_labels(self, labels, labels2):
        """Class loss between true ``labels`` and labels classified from reproduced images."""
        return loss_class(labels, labels2)

    @weighted_loss("reproduce_labels")
    def get_loss_reproduce_labels(self, labels, labels3):
        """Class loss between true ``labels`` and labels classified from images generated from them."""
        return loss_class(labels, labels3)

    @weighted_loss("reproduce_fake_labels")
    def get_loss_reproduce_fake_labels(self, labels, labels5):
        """Class loss between sampled labels (passed as ``labels``) and labels classified from the images generated from them."""
        return loss_class(labels, labels5)


    ## image classifier
    def train_image_classifier(self, images, labels, train=True):
        """Train or evaluate the image classifier on the selected label columns.

        Only the columns listed in the global ``ixs_labels`` are compared.
        With ``train=True`` the class loss updates ``self.image_classifier``
        and is stored in ``losses["image_classifier"]``; otherwise the
        accuracy is written to ``accuracy["image_label"]``.
        """
        predicted, _ = self.classify_image(images)
        # restrict targets and predictions to the label groups chosen for training
        target_subset = labels[:, ixs_labels]
        predicted_subset = predicted[:, ixs_labels]

        if not train:
            accuracy["image_label"] = accuracy_classifier(target_subset, predicted_subset)
            return

        loss = loss_class(target_subset, predicted_subset)
        update_networks_on_loss(loss, self.image_classifier)
        losses["image_classifier"] = loss.item()



    ## classifier
    def train_classifier(self, images, labels, noise_gen, labels_gen, latent_gen, train=True):
        """Train or evaluate the classifier on up to four label-consistency objectives.

        Depending on the per-iteration weights, labels are predicted
          * from the latent encoding of real ``images`` (``labels1``) and from
            images reproduced from that encoding (``labels2``),
          * from images generated from the true ``labels`` (``labels3``),
          * from images generated from the sampled ``labels_gen`` (``labels5``).
        Disabled objectives are represented by the placeholder ``0``.

        With ``train=True`` the summed loss updates classifier and encoder and
        is stored in ``losses["classifier"]``; otherwise accuracies are written
        to the global ``accuracy`` dict. ``latent_gen`` is unused here.
        """
        # starting from images
        if weights["label_image"](self.iteration):
            latent1 = self.latent_from_images(images)
            labels1, _ = self.labels_from_latent(*latent1)
            images1 = self.generate_from_latent_distribution(*latent1)
            labels2, _ = self.classify(images1)
        else:
            labels1, labels2 = 0, 0
        # starting from labels
        if weights["label_labels"](self.iteration):
            imgs = self.generate_from_labels(labels, noise_gen)
            labels3, _ = self.classify(imgs)
        else:
            labels3 = 0

        # starting from generated labels
        if weights["label_fake_labels"](self.iteration):
            z = self.latent_from_labels(labels_gen, noise_gen)
            imgs = self.generate_from_latent_distribution(*z)
            labels5, _ = self.classify(imgs)
        else:
            labels5 = 0

        if train:
            loss = self.get_loss_classifier(labels, labels_gen, labels1, labels2, labels3, labels5)
            update_networks_on_loss(loss, self.classifier, self.encoder)
            losses["classifier"] = loss.item()
            if np.isnan(loss.item()):
                # show the offending predictions *before* entering the debugger
                print(labels1, labels2, labels3, labels5)
                set_trace()
        else:
            # labels1 is the int placeholder 0 when "label_image" is disabled;
            # indexing it would raise TypeError, so pass it through like the others
            if isinstance(labels1, (int, float)):
                accuracy["label_image"] = accuracy_classifier(labels, labels1)
            else:
                accuracy["label_image"] = accuracy_classifier(labels[:, ixs_labels], labels1[:, ixs_labels])
            accuracy["label_reproduced_image"] = accuracy_classifier(labels, labels2)
            accuracy["label_labels"] = accuracy_classifier(labels, labels3)
            accuracy["label_fake_labels"] = accuracy_classifier(labels_gen, labels5)
            if False:
                self.plot_mislabeled_galaxies(images, labels, labels1)

    def plot_mislabeled_galaxies(self, images, labels, prediction, N=10, p=0.1):
        """Show up to ``N`` images whose predicted labels deviate from the target.

        Only the label columns selected by the global ``ixs_labels`` are
        compared. An image counts as mislabeled when any selected column
        differs from the target by ``p`` or more. For each such image the
        picture is shown and the rounded target, prediction and absolute
        error are printed.
        """
        target = labels[:, ixs_labels]
        predicted = prediction[:, ixs_labels]
        # per-image flag: True when every selected column is within tolerance p.
        # Tensor.all is used because the builtin `all` does not accept `dim`.
        accurate = (abs(target - predicted) < p).all(dim=1)
        shown = 0
        for ix, acc in enumerate(accurate):
            if acc:
                continue
            # the (1, 2, 0) transpose moves the leading (channel) axis last for imshow
            plt.imshow(np.transpose(images[ix].detach().cpu().numpy(), (1, 2, 0)))
            plt.show()
            print((target[ix] * 100).round() / 100)
            print((predicted[ix] * 100).round() / 100)
            print((abs(predicted[ix] - target[ix]) * 100).round() / 100)
            shown += 1
            if shown >= N:
                break


    def get_loss_classifier(self, labels, labels_gen, labels1, labels2, labels3, labels5):
        """Return the sum of the four weighted classifier loss terms.

        ``labels1`` (labels predicted from the latent of real images) is
        compared only on the columns selected by the global ``ixs_labels``.
        Any prediction may be the int placeholder ``0`` when its objective is
        disabled; ``labels1`` is then passed through unsliced, because
        indexing an int raises ``TypeError``.
        """
        if isinstance(labels1, (int, float)):
            loss = self.get_loss_label_image(labels, labels1)
        else:
            loss = self.get_loss_label_image(labels[:, ixs_labels], labels1[:, ixs_labels])
        loss += self.get_loss_label_reproduced_image(labels, labels2)
        loss += self.get_loss_label_labels(labels, labels3)
        loss += self.get_loss_label_fake_labels(labels_gen, labels5)
        return loss

    @weighted_loss("label_image")
    def get_loss_label_image(self, labels, labels1):
        """Class loss between target labels and labels predicted from the latent of real images."""
        return loss_class(labels, labels1)

    @weighted_loss("label_reproduced_image")
    def get_loss_label_reproduced_image(self, labels, labels2):
        """Class loss between target labels and labels classified from reproduced images."""
        return loss_class(labels, labels2)

    @weighted_loss("label_labels")
    def get_loss_label_labels(self, labels, labels3):
        """Class loss between true labels and labels classified from images generated from them."""
        return loss_class(labels, labels3)

    @weighted_loss("label_fake_labels")
    def get_loss_label_fake_labels(self, labels, labels5):
        """Class loss between sampled labels (passed as ``labels``) and labels classified from the images generated from them."""
        return loss_class(labels, labels5)

    ## inverse classifier
    def train_inverse_classifier(self, images, labels, noise_gen, labels_gen, latent_gen, train=True):
        """Train or evaluate the inverse classifier with two cycle-consistency objectives.

        * "unity_labels": sampled (labels, noise) -> latent -> (labels4, noise4)
          should reproduce ``labels_gen`` / ``noise_gen``.
        * "unity_latent": sampled latent -> (labels, noise) -> ``latent2``
          should reproduce ``latent_gen``.
        Disabled objectives are represented by the placeholder ``0``.
        ``images`` and ``labels`` are unused here.
        """
        # from generated labels
        if weights["unity_labels"](self.iteration):
            labels4, noise4 = self.labels_from_latent(*self.latent_from_labels(labels_gen, noise_gen))
        else:
            labels4, noise4 = 0, 0
        # from latent (computed once; the previous duplicate forward pass is dropped)
        if weights["unity_latent"](self.iteration):
            label, noise = self.labels_from_latent(*latent_gen)
            latent2 = self.latent_from_labels(label, noise)
        else:
            latent2 = 0
        if train:
            loss = self.get_loss_inverse_classifier(latent_gen, labels_gen, noise_gen, latent2, labels4, noise4)
            update_networks_on_loss(loss, self.inverse_classifier)
            losses["inverse_classifier"] = loss.item()
        else:
            accuracy["unity_labels"] = accuracy_classifier(labels_gen, labels4)
            accuracy["unity_latent"] = accuracy_classifier(latent_gen, latent2)


    def get_loss_inverse_classifier(self, latent_gen, labels_gen, noise_gen, latent2, labels4, noise4):
        """Return the weighted unity-labels loss plus the weighted unity-latent loss."""
        labels_term = self.get_loss_unity_labels(labels_gen, noise_gen, labels4, noise4)
        latent_term = self.get_loss_unity_latent(latent_gen, latent2)
        labels_term += latent_term
        return labels_term

    @weighted_loss("unity_latent")
    def get_loss_unity_latent(self, latent_gen, latent2):
        """Latent loss between the sampled latent and its latent -> labels -> latent reconstruction."""
        return loss_latent(latent_gen, latent2)

    @weighted_loss("unity_labels")
    def get_loss_unity_labels(self, labels_gen, noise_gen, labels4, noise4):
        """Class loss on the reconstructed labels plus metric loss on the reconstructed noise."""
        return loss_class(labels_gen, labels4) + loss_metric(noise_gen, noise4)




In [None]:
# Build a fresh pipeline instance for a quick smoke test.
pipe = Pipeline()

In [None]:
# Smoke test: run one training step of the pipeline on random CUDA data.
# `rand` is otherwise only imported in a later cell, so import it here too
# to keep top-to-bottom execution working.
from torch import rand

batch_size = 64

# NOTE(review): presumably a large iteration so that all iteration-dependent
# loss weights are active — confirm
pipe.iteration = 20001

images = rand(batch_size, 3, 64, 64).cuda()
labels = rand(batch_size, labels_dim).cuda()

pipe.train_step(images, labels)


In [None]:
# Debug: print the device of every tensor returned by the encoder and of the KL loss.
latent = pipe.latent_from_images(images)
for l in latent:
    print(l.device)
loss_kl(latent).device  # displayed as the cell output in the notebook

# Basic tests of the neural networks
Check whether the pipeline components work on their own.

In [None]:
# Print a summary of the Discriminator4 architecture (helper from helpful_functions).
summarize(Discriminator4)

In [None]:
import numpy as np
from torch import rand


# NOTE(review): `Encoder` is not imported in the visible header (which imports
# Encoder4); presumably defined in an earlier cell — confirm.
net = Encoder()

# total number of trainable scalars across all parameter tensors
print("N_parameters:", np.sum([np.prod(p.size()) for p in net.parameters()]))
# three random inputs shaped like a training image
input_dummy = rand(3, *input_size)
label_dummy = rand(3, labels_dim)  # not used below

net(input_dummy)

In [None]:
# Smoke test of the decoder: decode three random latent vectors and show the output shape.
net = Decoder()

print("N_parameters:", np.sum([np.prod(p.size()) for p in net.parameters()]))
input_dummy = rand(3, parameter.latent_dim)
label_dummy = rand(3, labels_dim)  # not used below

net(input_dummy).shape

In [None]:
# Smoke test of the discriminator on random input.
net = Discriminator()

print("N_parameters:", np.sum([np.prod(p.size()) for p in net.parameters()]))
# NOTE(review): the training code feeds the discriminator images; a latent-shaped
# input here only makes sense if this `Discriminator` works on latent vectors — confirm.
input_dummy = rand(3, parameter.latent_dim)
label_dummy = rand(3, labels_dim)  # not used below

net(input_dummy)

In [None]:
# Smoke test of the classifier on three random latent vectors.
net = Classifier()

print("N_parameters:", np.sum([np.prod(p.size()) for p in net.parameters()]))
input_dummy = rand(3, parameter.latent_dim)
label_dummy = rand(3, labels_dim)  # not used below

net(input_dummy)

In [None]:
# Smoke test of the inverse classifier on three random label vectors.
net = InverseClassifier()

print("N_parameters:", np.sum([np.prod(p.size()) for p in net.parameters()]))
input_dummy = rand(3, labels_dim)
label_dummy = rand(3, labels_dim)  # not used below

net(input_dummy)

In [None]:
# Smoke test of the VAE: reproduce five random images conditioned on random labels.
net = VAE()
images_dummy = rand(5,3,64,64).cuda()
# labels need labels_dim columns, like the labels passed to the VAE during training
labels_dummy = rand(5, labels_dim).cuda()
pred = net(images_dummy, labels_dummy)
pred.shape

In [None]:
# Smoke test of the VAEGAN on three random images and labels; shows the output
# shape and its first column.
net = VAEGAN()
input_dummy = rand(3,3,64,64).cuda()
# labels need labels_dim columns, like the labels used during training
labels_dummy = rand(3, labels_dim).cuda()
pred = net(input_dummy, labels_dummy)
pred.shape, pred[:,0]

# Training Data
Load the data from files on Google Drive.

In [None]:
# Load the full image and label arrays via the project's dataset helpers.
x_train = get_x_train()
labels_train = get_labels_train()
# number of samples BEFORE the train/test split below
N_samples = x_train.shape[0]

In [None]:
from torch import tensor
from torch.utils.data import Dataset, random_split  # NOTE(review): Dataset and random_split are unused in this cell

N_test = 3000

# Reproducible split: seed the global NumPy RNG (with N_test as the seed),
# shuffle all sample indices and take the first N_test as the test set.
np.random.seed(N_test)
ixs_train = np.arange(x_train.shape[0])
np.random.shuffle(ixs_train)
ixs_test, ixs_train = np.split(ixs_train, (N_test,))

# re-seed from fresh entropy so later random draws are not deterministic
np.random.seed()

x_test = x_train[ixs_test]
labels_test = labels_train[ixs_test]

# x_train / labels_train are overwritten with the training subset, so re-running
# this cell splits the already-reduced set again
x_train = x_train[ixs_train]
labels_train = labels_train[ixs_train]


# Training a new network

In [None]:
## if you want to change any parameter:
parameter.learning_rate = 0.002 #0.01
parameter.latent_dim = 2048
# NOTE(review): immediately overridden below; N_resnet is not defined in this
# view (presumably an earlier cell) — confirm it exists before running
parameter.latent_dim = N_resnet // 2

pipe = Pipeline()

#'''  load pretrained models
pipe.encoder.load()
pipe.classifier.load()
#'''

#'''  choose which class label groups to train on
groups_classified = [6, 8] #[1, 2, 6, 7]
ixs_labels = []
for g in groups_classified:
    # the -1 turns the group's (presumably 1-based) answer numbers into 0-based columns
    ixs_labels.extend(class_groups_indices[g]-1)
ixs_labels = tensor(ixs_labels)
#'''

pipe.discriminator_train_interval = 1 #10

if track_hyperparameter:
    wandb.init(project="galaxy-generator", # top level identifier
            group="check_losses", # second level identifier, to separate several groups of tests
            job_type="training", # third level identifier, organize different jobs like training and evaluation
            tags=["all one"], # temporary tags to organize different tasks together
            name="no class loss, delta=0", # bottom level identifier, label of graph in UI
            config=parameter.return_parameter_dict()  # here we fill the hyperparameters
    )

epochs = 50
batch_size = 64
steps = x_train.shape[0] // batch_size
save_interval = 500
loss_interval = 100

# loss history: one series per weighted loss term plus the total loss of each network
losses_series = {name:[] for name in list(weights.keys())+["discriminator", "generator", "classifier","inverse_classifier", "image_classifier"]}
# names of the accuracy measures written by the pipeline into the global `accuracy`
# dict ("label_image" is the key used by train_classifier)
accuracies = [
              # discriminator
              "discriminate_real", "discriminate_reproduced", "discriminate_latent", "discriminate_labels", "discriminate_fake_labels",
              # generator
              "trick_reproduced", "trick_labels", "trick_fake_labels",
              "reproduce_image_labels", "reproduce_labels", "reproduce_fake_labels",
              # classifier
              "label_image", "label_reproduced_image", "label_labels", "label_fake_labels",
              # inverse classifier
              "unity_labels", "unity_latent"
              ]
# NOTE(review): accuracy series are keyed by the loss-weight names, not by `accuracies`
accuracy_series = {name:[] for name in list(weights.keys())}
tests = [
         "N_pairs_real", "N_pairs_reproduced", "N_pairs_fake", 
         "average_residual_real", "average_residual_reproduced", "average_residual_fake", 
         "intensity_real", "intensity_reproduced", "intensity_fake", 
         "bluriness_real", "bluriness_reproduced", "bluriness_fake"
        ]
test_series = {name:[] for name in tests}
losses = {}

from statistical_tests import get_image_pair_residual_statistics, compute_residual_average_image, compute_average_intensity, compute_average_bluriness

discriminator_losses = []
generator_losses = []
encoder_losses = []
classifier_losses = []
inverse_classifier_losses = []

iteration = 1
epoch = 1




In [None]:
# Train the pipeline for `epochs` epochs. Every `loss_interval` iterations the
# losses and held-out accuracies are recorded; every `save_interval` iterations
# example images and image statistics are written; after each epoch the loss,
# test-statistic and accuracy curves are plotted and encoder/classifier are saved.
t0 = time()
while epoch <= epochs:
    sample_generator = make_training_sample_generator(batch_size, x_train, labels_train)
    step = 1
    while step <= steps:
        images, labels = next(sample_generator)
        pipe.train_step(images, labels)

        if not pipe.iteration % loss_interval:
            t1 = time()
            # `step` already counts from 1, so it is printed as-is
            print(f"iteration {pipe.iteration}, epoch {epoch}, batch {step}/{steps}, took {t1-t0:.1f} s")
            t0 = t1
            # estimate accuracies on 300 random held-out samples
            ixs = np.random.choice(np.arange(x_test.shape[0]), size=300, replace=False )
            pipe.estimate_accuracy(images=x_test[ixs].cuda(), labels=labels_test[ixs].cuda())
            # measures not produced in this interval are recorded as 0. so that
            # all series stay aligned; only a missing key is treated that way
            for name in accuracy_series.keys():
                try:
                    accuracy_series[name].append(accuracy[name])
                except KeyError:
                    accuracy_series[name].append(0.)
            for name in losses_series.keys():
                try:
                    losses_series[name].append(losses[name])
                except KeyError:
                    losses_series[name].append(0.)
                if name in ["discriminator", "generator", "classifier","inverse_classifier", "image_classifier"] and name in losses:
                    print(f"{name}: {losses[name]:.2}", end=", ")
            print("")

        if not pipe.iteration % save_interval:
            write_generated_galaxy_images_iteration(iteration=pipe.iteration-1, images=images.detach().cpu().numpy())
            reproduced_images = pipe.reproduce(images)
            write_generated_galaxy_images_iteration(iteration=pipe.iteration, images=reproduced_images.detach().cpu().numpy())
            noise = generate_noise(batch_size)
            generated_images = pipe.generate_from_labels(labels, noise)
            write_generated_galaxy_images_iteration(iteration=pipe.iteration+1, images=generated_images.detach().cpu().numpy())
            test_series["N_pairs_real"].append(get_image_pair_residual_statistics(images)[0])
            test_series["N_pairs_reproduced"].append(get_image_pair_residual_statistics(reproduced_images)[0])
            test_series["N_pairs_fake"].append(get_image_pair_residual_statistics(generated_images)[0])
            # NOTE(review): `last_images` is only defined after the first step
            test_series["average_residual_real"].append(compute_residual_average_image(images, last_images))
            test_series["average_residual_reproduced"].append(compute_residual_average_image(images, reproduced_images))
            test_series["average_residual_fake"].append(compute_residual_average_image(images, generated_images))
            test_series["intensity_real"].append(compute_average_intensity(images.detach().cpu().numpy()))
            test_series["bluriness_real"].append(compute_average_bluriness(images.detach().cpu()))
            test_series["bluriness_reproduced"].append(compute_average_bluriness(reproduced_images.detach().cpu()))
            test_series["bluriness_fake"].append(compute_average_bluriness(generated_images.detach().cpu()))
        step += 1
        iteration += 1
        last_images = images
    epoch += 1

    # save a plot of the costs
    fig, axs = plt.subplots(1, 6, figsize=(30,8))
    for name, loss_series in losses_series.items():
        if not loss_series:
            continue
        if name in ["discriminator", "generator", "classifier","inverse_classifier", "image_classifier"]:
            ax = axs[0]
        elif "discriminate" in name:
            ax = axs[1]
        elif "label_" in name:
            ax = axs[3]
        elif "unity_" in name:
            ax = axs[4]
        elif "image_" in name:
            ax = axs[5]
        else:
            ax = axs[2]
        # one recorded point per loss_interval iterations
        ax.plot(np.arange(1,len(loss_series)+1) * loss_interval, loss_series, label=name)
    axs[0].set_title("total loss")
    axs[1].set_title("discriminator")
    axs[2].set_title("generator")
    axs[3].set_title("classifier")
    axs[4].set_title("inverse classifier")
    axs[5].set_title("image classifier")
    for ax in axs:
        ax.set_yscale("log")
        ax.legend()
    plt.savefig(folder_results+"cost_vs_iteration.png")
    plt.show()
    plt.close()

    # save plot of tests
    fig, axs = plt.subplots(1, 4, figsize=(30,8))
    for name, series in test_series.items():
        if not series:
            continue
        if "N_pairs" in name:
            ax = axs[0]
        elif "residual" in name:
            ax = axs[1]
        elif "intensity" in name:
            ax = axs[2]
        elif "bluriness" in name:
            ax = axs[3]
        # test statistics are recorded once per save_interval iterations
        ax.plot(np.arange(1,len(series)+1) * save_interval, series, label=name, linestyle=":" if "real" in name else "-")
    axs[0].set_title("N_pairs")
    axs[1].set_title("average residual")
    axs[2].set_title("intensity")
    axs[3].set_title("bluriness")
    for ax in axs:
        ax.set_yscale("log")
        ax.legend()
    plt.savefig(folder_results+"tests_vs_iteration.png")
    plt.show()
    plt.close()

    # save plot of accuracies
    fig, axs = plt.subplots(1, 5, figsize=(30,8))
    for name, series in accuracy_series.items():
        linestyle = "-"
        if not series:
            continue
        if "discriminate" in name:
            ax = axs[0]
        elif "reproduce_" in name:
            ax = axs[1]
        elif "label_" in name:
            ax = axs[2]
        elif "unity_" in name:
            ax = axs[3]
        elif "image_" in name:
            # image-classifier accuracy is drawn dash-dotted next to the classifier's
            ax = axs[2]
            linestyle = "-."
        else:
            ax = axs[1]
        ax.plot(np.arange(1,len(series)+1) * loss_interval, series, label=name, linestyle=linestyle)
    axs[0].set_title("discriminator")
    axs[1].set_title("generator")
    axs[2].set_title("classifier")
    axs[3].set_title("inverse classifier")
    axs[4].set_title("image classifier")
    for ax in axs:
        ax.legend()
        ax.set_ylim(0,1)
    plt.savefig(folder_results+"accuracy_vs_iteration.png")
    plt.show()
    plt.close()

    ### really save?
    #pipeline.decoder.save()
    pipe.encoder.save()
    pipe.classifier.save()
    #pipeline.discriminator.save()


if track_hyperparameter:
    ## after training is done and all measures are written, finalize with
    wandb.finish()


In [None]:
# Overview of the recorded loss series on a log scale: discriminator terms,
# short-named generator terms, and everything else.
fig, axs = plt.subplots(1, 3, figsize=(12,3))
for name, loss_series in losses_series.items():
    if not loss_series:
        continue
    if "discriminate" in name:
        ax = axs[0]
    elif "generator" in name:
        if len(name) > 22:
            continue
        ax = axs[1]
    else:
        ax = axs[2]
    # x values derived from the number of recorded points (one per loss_interval),
    # so they always match len(loss_series); the global `iteration` counter need not
    ax.plot(np.arange(1, len(loss_series)+1) * loss_interval, loss_series, label=name)
for ax in axs:
    ax.set_yscale("log")
    ax.legend()
plt.show()


In [None]:
def dot_flat(a, b):
    """Return the outer product of ``a`` and ``b`` flattened to one dimension.

    Every element of ``a`` is multiplied by every element of ``b``; in the
    flat result the index of ``b`` varies fastest.
    """
    column = a[..., None]   # (len(a), 1) for 1-D input
    row = b[None, ...]      # (1, len(b)) for 1-D input
    outer = column @ row
    return outer.flatten()


def label_leaves_likelihood(labels):
    """ compute likelihood for all leaves in the hierarchical label tree

    ``labels`` is one label vector (presumably the per-answer probabilities of
    the hierarchical galaxy questionnaire — confirm). Each leaf likelihood is a
    product of answer probabilities along one path through the tree, divided by
    the totals of the groups that were multiplied in (P3, P5, P6, P11) so those
    groups act as conditional distributions. Zero totals are replaced by 1 to
    avoid division by zero.

    Returns a tensor of 1265 leaf likelihoods (per-branch counts are noted in
    the comments below).
    """
    def P_group(g):
        # the -1 turns the group's (presumably 1-based) answer numbers into 0-based columns
        return labels[class_groups_indices[g]-1]
    # group totals used as normalizers
    P3 = P_group(3).sum()
    P5 = P_group(5).sum()
    P6 = P_group(6).sum()
#    P6 = P_group(1)[:2].sum()
    P11 = P_group(11).sum()
    # avoid division by zero when a group has no probability mass
    P3 = P3 if P3 else 1.
    P5 = P5 if P5 else 1.
    P6 = P6 if P6 else 1.
    P11 = P11 if P11 else 1.
    # "regular" leaves multiply by answer P_group(6)[1]; "odd" leaves multiply by group 8 instead
    ## artifact                 1
    leaves = [P_group(1)[2]]
    ## round regular            3
    l = P_group(7) * P_group(6)[1] / P6
    leaves += list(l)
    ## round odd                21
    l = dot_flat(P_group(7), P_group(8)) / P6
    leaves += list(l)
    ## disk edge regular        3
    l = P_group(9) * P_group(6)[1] / P6
    leaves += list(l)
    ## disk edge odd            21
    l = dot_flat(P_group(9), P_group(8)) / P6
    leaves += list(l)
    ## disk arms regular        144
    l = dot_flat(dot_flat(dot_flat(P_group(3), P_group(10)), P_group(11)), P_group(5)) * P_group(6)[1] / P3 / P11 / P5 / P6 
    leaves += list(l)
    ## disk arms odd            1008
    l = dot_flat(dot_flat(dot_flat(dot_flat(P_group(3), P_group(10)), P_group(11)), P_group(5)), P_group(8)) / P3 / P11 / P5 / P6 
    leaves += list(l)
    ## disk no arms regular     8
    l = dot_flat(P_group(3), P_group(5)) * P_group(4)[1] * P_group(6)[1] / P3 / P5 / P6
    leaves += list(l)
    ## disk no arms odd         65
    l = dot_flat(dot_flat(P_group(3), P_group(5)), P_group(8)) * P_group(4)[1] / P3 / P5 / P6
    leaves += list(l)
    ## total                    1265
    return tensor(leaves)

# Sanity check of the label hierarchy on a label vector.
N = 2  # NOTE(review): unused below
ix = np.random.randint(x_train.shape[0]-5)

image = x_train[ix:ix+2].cuda()

label, noise = pipe.classify_image(image)
# NOTE(review): this overwrites the classifier prediction above with freshly
# generated labels, so the checks below test generate_galaxy_labels, not the classifier
label = generate_galaxy_labels(2)

# leaf likelihoods of the first label vector should sum to 1
print(f"{label_leaves_likelihood(label[0].cpu()).sum().item():.4} == {1.:.4}")
# each answer group should sum to the value on the right: 1 for groups 1 and 6,
# otherwise the label column it depends on (presumably its parent answer)
print(f"{label[0,class_groups_indices[1]-1].sum().item():.4} == {1.:.4}")
print(f"{label[0,class_groups_indices[2]-1].sum().item():.4} == {label[0,1].item():.4}")
print(f"{label[0,class_groups_indices[3]-1].sum().item():.4} == {label[0,4].item():.4}")
print(f"{label[0,class_groups_indices[4]-1].sum().item():.4} == {label[0,4].item():.4}")
print(f"{label[0,class_groups_indices[5]-1].sum().item():.4} == {label[0,4].item():.4}")
print(f"{label[0,class_groups_indices[6]-1].sum().item():.4} == {1.:.4}")
print(f"{label[0,class_groups_indices[7]-1].sum().item():.4} == {label[0,0].item():.4}")
print(f"{label[0,class_groups_indices[8]-1].sum().item():.4} == {label[0,13].item():.4}")
print(f"{label[0,class_groups_indices[9]-1].sum().item():.4} == {label[0,3].item():.4}")
print(f"{label[0,class_groups_indices[10]-1].sum().item():.4} == {label[0,7].item():.4}")
print(f"{label[0,class_groups_indices[11]-1].sum().item():.4} == {label[0,7].item():.4}")
print(label[0])

# Training

In [None]:
## if you want to change any parameter:
parameter.learning_rate = 0.0002
## the following can also be changed during runtime
parameter.alpha = 1.
parameter.beta = 1.
parameter.gamma = 1.
parameter.delta = 1.
parameter.zeta = 1.


## start training with fresh, untrained networks
pipeline.decoder = Decoder4().cuda()
pipeline.encoder = Encoder4().cuda()
pipeline.discriminator = Discriminator4().cuda()

## if you want to change discriminator, encoder or decoder network:
#pipeline.decoder = YourDecoder().cuda()
## before you create the VAE and VAEGAN

if track_hyperparameter:
    wandb.init(project="galaxy-generator", # top level identifier
            group="check_losses", # second level identifier, to separate several groups of tests
            job_type="training", # third level identifier, organize different jobs like training and evaluation
            tags=["all one"], # temporary tags to organize different tasks together
            name="no class loss, delta=0", # bottom level identifier, label of graph in UI
            config=parameter.return_parameter_dict()  # here we fill the hyperparameters
    )


epochs = 20
batch_size = 64
# count steps over the training split only: N_samples was computed before the
# held-out test samples were removed from x_train
steps = x_train.shape[0] // batch_size
save_interval = 200

discriminator_losses = []
discriminator_losses_real = []
discriminator_losses_fake = []
generator_losses = []

# constant targets for the discriminator's real/fake output, one row per sample
valid = ones((batch_size,1)).cuda()
fake = zeros((batch_size,1)).cuda()

vae = VAE()
vaegan = VAEGAN()

iteration = 0
epoch = 0
step = 0





In [None]:
def training_step(images, labels):
    """Run one adversarial VAE-GAN update on a batch and log the losses.

    1. Discriminator step: the first ``1 + labels_dim`` discriminator outputs
       are trained towards ``(1, labels)`` for real ``images`` and
       ``(0, labels)`` for images reproduced by the VAE.
    2. Generator step: encoder and decoder are updated with
       ``loss_generator``, whose target marks reproduced images as real with
       the true labels (further terms presumably cover reconstruction and the
       latent — see ``loss_generator``).

    Loss values are appended as Python floats to the module-level history
    lists; storing the tensors would keep every iteration's autograd graph
    alive. Generated images are written every ``save_interval`` iterations.
    """
    global iteration

    # -------------------
    # Train Discriminator
    # -------------------
    vae.train(False)
    pipeline.discriminator.train(True)
    pipeline.discriminator.zero_grad()

    # detached: the discriminator update must not backpropagate into the VAE
    generated_images = vae(images, labels).detach()
    target_real = cat((valid,labels), dim=1)
    prediction_real = pipeline.discriminator(images)[:,:1+labels_dim]
    target_fake = cat((fake, labels), dim=1)
    prediction_fake = pipeline.discriminator(generated_images)[:,:1+labels_dim]

    d_loss_real = loss_discriminator(target_real, prediction_real)
    d_loss_fake = loss_discriminator(target_fake, prediction_fake)
    d_loss = 0.5 * add(d_loss_fake, d_loss_real)
    d_loss_value = d_loss.item()
    d_loss_real_value = d_loss_real.item()
    d_loss_fake_value = d_loss_fake.item()
    discriminator_losses.append(d_loss_value)
    discriminator_losses_fake.append(d_loss_fake_value)
    discriminator_losses_real.append(d_loss_real_value)

    d_loss_real.backward()
    d_loss_fake.backward()
    pipeline.discriminator.optimizer.step()

    # ---------------
    # Train Generator
    # ---------------
    vae.train(True)
    pipeline.discriminator.train(False)
    pipeline.encoder.zero_grad()
    pipeline.decoder.zero_grad()

    generated_images = vae(images, labels)
    # constant target: discriminator output on real images with the real/fake
    # flag forced to 1 and the label slots replaced by the true labels
    target = pipeline.discriminator(images).detach().clone()
    target[:,0] = 1
    target[:,1:1+labels_dim] = labels
    prediction = pipeline.discriminator(generated_images)
    latent = pipeline.encoder(images, labels)

    g_loss = loss_generator(target, prediction, images, generated_images, latent)
    g_loss.backward()
    pipeline.encoder.optimizer.step()
    pipeline.decoder.optimizer.step()
    g_loss_value = g_loss.item()
    generator_losses.append(g_loss_value)

    if track_hyperparameter:
        ## save measures to wandb.ai
        wandb.log({"loss discriminator":d_loss_value, "loss generator":g_loss_value}) 


    iteration += 1

    print(f"iteration {iteration}, epoch {epoch+1}, batch {step+1}/{steps}," + \
          f"disc_loss {d_loss_value:.5}, (real {d_loss_real_value:.5}, fake {d_loss_fake_value:.5} ) gen_loss {g_loss_value:.5}")

    if not iteration % save_interval:
        write_generated_galaxy_images_iteration(iteration=iteration, images=generated_images.detach().cpu().numpy())


In [None]:
# Run the VAE-GAN training for the remaining epochs; after each epoch plot the
# loss histories to folder_results + "cost_vs_iteration.png".
while epoch < epochs:
    sample_generator = make_training_sample_generator(batch_size, x_train, labels_train)
    step = 0
    while step < steps:
        images, labels = next(sample_generator)
        training_step(images, labels)
        step += 1
    epoch += 1

    # save a plot of the costs; float() accepts plain numbers as well as 0-dim
    # (possibly CUDA, grad-tracking) tensors, so the plot works either way
    plt.clf()
    plt.plot([float(v) for v in discriminator_losses], label='discriminator cost')
    plt.plot([float(v) for v in generator_losses], label='generator cost')
    plt.plot([float(v) for v in discriminator_losses_fake], label='discriminator cost fake', linestyle=":")
    plt.plot([float(v) for v in discriminator_losses_real], label='discriminator cost real', linestyle="-.")
    plt.yscale("log")
    plt.legend()
    plt.savefig(folder_results+"cost_vs_iteration.png")
    plt.close()


    ### really save?
    #pipeline.decoder.save()
    #pipeline.encoder.save()
    #pipeline.discriminator.save()


if track_hyperparameter:
    ## after training is done and all measures are written, finalize with
    wandb.finish()


# Testing