# Testing cosmogan
Nov 19, 2020

Borrowing pieces of code from : 

- https://github.com/pytorch/tutorials/blob/11569e0db3599ac214b03e01956c2971b02c64ce/beginner_source/dcgan_faces_tutorial.py
- https://github.com/exalearn/epiCorvid/tree/master/cGAN

In [1]:
import argparse
import os
import random
import logging
import sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import time
from datetime import datetime
import glob
import pickle
import yaml
import collections

In [2]:
%matplotlib widget

## Modules

In [3]:
def f_load_config(config_file):
    with open(config_file) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    return config

### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4.) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s)


In [4]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Generator Code
class View(nn.Module):
    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


In [5]:
def f_get_model(model_name,gdict):
    ''' Module to define Generator and Discriminator'''
    print("Model name",model_name)
    if model_name==1: ## With embeddings
        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.label_embedding=nn.Embedding(num_classes,num_classes)
                self.main = nn.Sequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz+num_classes,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                    nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                labels=labels.unsqueeze(-1).long()
                gen_input=torch.cat((self.label_embedding(labels),noise),-1)
                img=self.main(gen_input)
        #         print(type(img),img.size())
        #         img=img.view(128,nc,128,128))

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())
                self.label_embedding=nn.Embedding(num_classes,num_classes)

                self.linear_transf=nn.Linear(4,4)
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                    nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):
                labels=labels.unsqueeze(-1).long()
                img_size=gdict['image_size']
                a=self.label_embedding(labels)
                x=a.view(a.size(0),-1)
                x=self.linear_transf(x)
                x=torch.repeat_interleave(x,int((img_size*img_size)/4))
                x=x.view(a.size(0),1,img_size,img_size)
        #         print(x.size())
                d_input=torch.cat((img,x),axis=1)
        #         d_input=torch.cat((img,self.label_embedding(labels)),-1)
                pred=self.main(d_input)
                return pred

    elif model_name==2: #### Model 2: without embeddings
        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = nn.Sequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz+1,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                    nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                x=labels.unsqueeze(-1).unsqueeze(-1).float()
                gen_input=torch.cat((noise,x),-1)
                img=self.main(gen_input)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.linear_transf=nn.Linear(4,4)
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                    nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):
                img_size=gdict['image_size']
                x=labels.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,4).float() # get to size (128,1,1,4)
                x=self.linear_transf(x)
                x=torch.repeat_interleave(x,int((img_size*img_size)/4)) # get to size (128,1, 128, 128)
                x=x.view(labels.size(0),1,img_size,img_size)
#                 print(x.size())
                d_input=torch.cat((img,x),axis=1)
                pred=self.main(d_input)
                return pred

    elif model_name==3:#### Model 3: with ConditionalInstanceNorm2d
        class ConditionalInstanceNorm2d(nn.Module):
            def __init__(self, num_features, num_params):
                super().__init__()
                self.num_features = num_features
                self.InstNorm = nn.InstanceNorm2d(num_features, affine=False)
                self.affine = nn.Linear(num_params, num_features * 2)
                self.affine.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
                self.affine.weight.data[:, num_features:].zero_()  # Initialise bias at 0

            def forward(self, x, y):
                out = self.InstNorm(x)
                gamma, beta = self.affine(y).chunk(2, 1)
                out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
                return out

        class ConditionalSequential(nn.Sequential):
            def __init__(self,*args):
                super(ConditionalSequential, self).__init__(*args)

            def forward(self, inputs, labels):
                for module in self:
                    if module.__class__ is ConditionalInstanceNorm2d:
                        inputs = module(inputs, labels.float())
                    else:
                        inputs = module(inputs)

                return inputs

        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
        #             nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*4,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*2,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                img=self.main(noise,labels)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc, ndf,kernel_size, stride, d_padding,  bias=True),
        #             nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*2,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*4,1),

                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*8,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):   
                pred=self.main(img,labels)
                return pred

    elif model_name==4: #### Model 4: without embeddings
        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = nn.Sequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz+1,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                    nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                x=labels.unsqueeze(-1).unsqueeze(-1).float()
                gen_input=torch.cat((noise,x),-1)
                img=self.main(gen_input)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.linear_transf=nn.Linear(4,4)
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                    nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):
                img_size=gdict['image_size']
                x=labels.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,4).float() # get to size (128,1,1,4)
                x=self.linear_transf(x)
                x=torch.repeat_interleave(x,int((img_size*img_size)/4)) # get to size (128,1, 128, 128)
                x=x.view(labels.size(0),1,img_size,img_size)
#                 print(x.size())
                d_input=torch.cat((img,x),axis=1)
                pred=self.main(d_input)
                return pred

    elif model_name==5:#### Model 5: with ConditionalInstanceNorm2d
        class ConditionalInstanceNorm2d(nn.Module):
            def __init__(self, num_features, num_params):
                super().__init__()
                self.num_features = num_features
                self.InstNorm = nn.InstanceNorm2d(num_features, affine=False)
                self.affine = nn.Linear(num_params, num_features * 2)
                self.affine.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
                self.affine.weight.data[:, num_features:].zero_()  # Initialise bias at 0

            def forward(self, x, y):
                out = self.InstNorm(x)
                gamma, beta = self.affine(y).chunk(2, 1)
                out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
                return out

        class ConditionalSequential(nn.Sequential):
            def __init__(self,*args):
                super(ConditionalSequential, self).__init__(*args)

            def forward(self, inputs, labels):
                for module in self:
                    if module.__class__ is ConditionalInstanceNorm2d:
                        inputs = module(inputs, labels.float())
                    else:
                        inputs = module(inputs)

                return inputs

        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
        #             nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*4,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*2,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                img=self.main(noise,labels)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc, ndf,kernel_size, stride, d_padding,  bias=True),
        #             nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*2,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*4,1),

                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*8,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):   
                pred=self.main(img,labels)
                return pred

    return Generator, Discriminator


In [6]:
def f_get_sigma(ip_categories,gdict):
    sigma_list=gdict['sigma_list']
    return torch.tensor([sigma_list[i] for i in ip_categories.long()],device=gdict['device']).unsqueeze(-1)


def f_gen_images(gdict,netG,optimizerG,label,ip_fname,op_loc,op_strg='inf_img_',op_size=500):
    '''Generate images for best saved models
     Arguments: gdict, netG, optimizerG, 
                 label : class label
                 ip_fname: name of input file
                op_strg: [string name for output file]
                op_size: Number of images to generate
    '''

    nz,device=gdict['nz'],gdict['device']

    try:
        if torch.cuda.is_available(): checkpoint=torch.load(ip_fname)
        else: checkpoint=torch.load(ip_fname,map_location=torch.device('cpu'))
    except Exception as e:
        print(e)
        print("skipping generation of images for ",ip_fname)
        return
    
    ## Load checkpoint
    if gdict['multi-gpu']:
        netG.module.load_state_dict(checkpoint['G_state'])
    else:
        netG.load_state_dict(checkpoint['G_state'])
    
    ## Load other stuff
    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    
    # Generate batch of latent vectors
    noise = torch.randn(op_size, 1, 1, nz, device=device)
    tnsr_categories=(torch.ones(op_size,device=device)*label).view(op_size,1)
    if gdict['model']>3: tnsr_categories=f_get_sigma(tnsr_categories,gdict)

    # Generate fake image batch with G
    netG.eval() ## This is required before running inference
    gen = netG(noise,tnsr_categories)
    gen_images=gen.detach().cpu().numpy()[:,0,:,:]
    print(gen_images.shape)

    op_fname='%s_epoch-%s_step-%s_label-%s.npy'%(op_strg,epoch,iters,label)
    np.save(op_loc+op_fname,gen_images)
    
    print("Image saved in ",op_fname)
    
def f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc):
    ''' Checkpoint model '''
    
    if gdict['multi-gpu']: ## Dataparallel
        torch.save({'epoch':epoch,'iters':iters,'best_chi1':best_chi1,'best_chi2':best_chi2,
                'G_state':netG.module.state_dict(),'D_state':netD.module.state_dict(),'optimizerG_state_dict':optimizerG.state_dict(),
                'optimizerD_state_dict':optimizerD.state_dict()}, save_loc) 
    else :
        torch.save({'epoch':epoch,'iters':iters,'best_chi1':best_chi1,'best_chi2':best_chi2,
                'G_state':netG.state_dict(),'D_state':netD.state_dict(),'optimizerG_state_dict':optimizerG.state_dict(),
                'optimizerD_state_dict':optimizerD.state_dict()}, save_loc)


def f_load_checkpoint(ip_fname,netG,netD,optimizerG,optimizerD,gdict):
    ''' Load saved checkpoint
    Also loads step, epoch, best_chi1, best_chi2'''
    
    try:
        checkpoint=torch.load(ip_fname)
    except Exception as e:
        print(e)
        print("skipping generation of images for ",ip_fname)
        raise SystemError
    
    ## Load checkpoint
    if gdict['multi-gpu']:
        netG.module.load_state_dict(checkpoint['G_state'])
        netD.module.load_state_dict(checkpoint['D_state'])
    else:
        netG.load_state_dict(checkpoint['G_state'])
        netD.load_state_dict(checkpoint['D_state'])
    
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    
    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    best_chi1=checkpoint['best_chi1']
    best_chi2=checkpoint['best_chi2']

    netG.train()
    netD.train()
    
    return iters,epoch,best_chi1,best_chi2

In [7]:
####################
### Pytorch code ###
####################

def f_torch_radial_profile(img, center=(None,None)):
    ''' Module to compute radial profile of a 2D image 
    Bincount causes issues with backprop, so not using this code
    '''
    
    y,x=torch.meshgrid(torch.arange(0,img.shape[0]),torch.arange(0,img.shape[1])) # Get a grid of x and y values
    if center[0]==None and center[1]==None:
        center = torch.Tensor([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0]) # compute centers

    # get radial values of every pair of points
    r = torch.sqrt((x - center[0])**2 + (y - center[1])**2)
    r= r.int()
    
#     print(r.shape,img.shape)
    # Compute histogram of r values
    tbin=torch.bincount(torch.reshape(r,(-1,)),weights=torch.reshape(img,(-1,)).type(torch.DoubleTensor))
    nr = torch.bincount(torch.reshape(r,(-1,)))
    radialprofile = tbin / nr
    
    return radialprofile[1:-1]


def f_torch_get_azimuthalAverage_with_batch(image, center=None): ### Not used in this code.
    """
    Calculate the azimuthally averaged radial profile. Only use if you need to combine batches

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
    batch, channel, height, width = image.shape
    # Create a grid of points with x and y coordinates
    y, x = np.indices([height,width])

    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1]))

    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (batch, channel,-1)))
    r_sorted = torch.gather(torch.reshape(r, (batch, channel, -1,)),2, ind)
    i_sorted = torch.gather(torch.reshape(image, (batch, channel, -1,)),2, ind)

    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[:,:,1:] - r_int[:,:,:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[2], (batch, -1))    # location of changes in radius
    rind=torch.unsqueeze(rind,1)
    nr = (rind[:,:,1:] - rind[:,:,:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin

    csum = torch.cumsum(i_sorted, axis=-1)
#     print(csum.shape,rind.shape,nr.shape)

    tbin = torch.gather(csum, 2, rind[:,:,1:]) - torch.gather(csum, 2, rind[:,:,:-1])
    radial_prof = tbin / nr

    return radial_prof


def f_get_rad(img):
    ''' Get the radial tensor for use in f_torch_get_azimuthalAverage '''
    
    height,width=img.shape[-2:]
    # Create a grid of points with x and y coordinates
    y, x = np.indices([height,width])
    
    center=[]
    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1]))
    
    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (-1,)))
    
    return r.detach(),ind.detach()


def f_torch_get_azimuthalAverage(image,r,ind):
    """
    Calculate the azimuthally averaged radial profile.

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
#     height, width = image.shape
#     # Create a grid of points with x and y coordinates
#     y, x = np.indices([height,width])

#     if not center:
#         center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

#     # Get the radial coordinate for every grid point. Array has the shape of image
#     r = torch.tensor(np.hypot(x - center[0], y - center[1]))

#     # Get sorted radii
#     ind = torch.argsort(torch.reshape(r, (-1,)))

    r_sorted = torch.gather(torch.reshape(r, ( -1,)),0, ind)
    i_sorted = torch.gather(torch.reshape(image, ( -1,)),0, ind)
    
    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[1:] - r_int[:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[0], (-1,))    # location of changes in radius
    nr = (rind[1:] - rind[:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin
    
    csum = torch.cumsum(i_sorted, axis=-1)
    tbin = torch.gather(csum, 0, rind[1:]) - torch.gather(csum, 0, rind[:-1])
    radial_prof = tbin / nr

    return radial_prof

def f_torch_fftshift(real, imag):
    for dim in range(0, len(real.size())):
        real = torch.roll(real, dims=dim, shifts=real.size(dim)//2)
        imag = torch.roll(imag, dims=dim, shifts=imag.size(dim)//2)
    return real, imag

def f_torch_compute_spectrum(arr,r,ind):
    
    GLOBAL_MEAN=1.0
    arr=(arr-GLOBAL_MEAN)/(GLOBAL_MEAN)
    y1=torch.rfft(arr,signal_ndim=2,onesided=False)
    real,imag=f_torch_fftshift(y1[:,:,0],y1[:,:,1])    ## last index is real/imag part
    y2=real**2+imag**2     ## Absolute value of each complex number
    
#     print(y2.shape)
    z1=f_torch_get_azimuthalAverage(y2,r,ind)     ## Compute radial profile
    
    return z1

def f_torch_compute_batch_spectrum(arr,r,ind):
    
    batch_pk=torch.stack([f_torch_compute_spectrum(i,r,ind) for i in arr])
    
    return batch_pk

def f_torch_image_spectrum(x,num_channels,r,ind):
    '''
    Data has to be in the form (batch,channel,x,y)
    '''
    
    mean=[[] for i in range(num_channels)]    
    sdev=[[] for i in range(num_channels)]    

    for i in range(num_channels):
        arr=x[:,i,:,:]
        batch_pk=f_torch_compute_batch_spectrum(arr,r,ind)
        mean[i]=torch.mean(batch_pk,axis=0)
#         sdev[i]=torch.std(batch_pk,axis=0)/np.sqrt(batch_pk.shape[0])
#         sdev[i]=torch.std(batch_pk,axis=0)
        sdev[i]=torch.var(batch_pk,axis=0)
    
    mean=torch.stack(mean)
    sdev=torch.stack(sdev)
        
    return mean,sdev

def f_compute_hist(data,bins):
    
    try: 
        hist_data=torch.histc(data,bins=bins)
        ## A kind of normalization of histograms: divide by total sum
        hist_data=(hist_data*bins)/torch.sum(hist_data)
    except Exception as e:
        print(e)
        hist_data=torch.zeros(bins)

    return hist_data

### Losses 
def loss_spectrum(spec_mean,spec_mean_ref,spec_std,spec_std_ref,image_size,lambda1):
    ''' Loss function for the spectrum : mean + variance 
    Log(sum( batch value - expect value) ^ 2 )) '''
    
    idx=int(image_size/2) ### For the spectrum, use only N/2 indices for loss calc.
    ### Warning: the first index is the channel number.For multiple channels, you are averaging over them, which is fine.
        
    spec_mean=torch.log(torch.mean(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)))
    spec_sdev=torch.log(torch.mean(torch.pow(spec_std[:,:idx]-spec_std_ref[:,:idx],2)))
    
    lambda1=lambda1;
    lambda2=lambda1;
    ans=lambda1*spec_mean+lambda2*spec_sdev
    
    if torch.isnan(spec_sdev).any():    print("spec loss with nan",ans)
    
    return ans
    

def loss_hist(hist_sample,hist_ref):
    
    lambda1=1.0
    return lambda1*torch.log(torch.mean(torch.pow(hist_sample-hist_ref,2)))


In [8]:

# def f_get_hist_cond(img_tensor,categories,bins,gdict,hist_val_tnsr):
#     ''' Module to compute pixel intensity histogram loss for conditional GAN '''
#     num_classes=gdict['num_classes'];device=gdict['device']
    
#     loss_hist_tensor=torch.zeros(num_classes,device=device)
#     for i in np.arange(num_classes):    
#         idxs=torch.where(categories==i)[0] ## Get indices for that category
#     #     print(i,idxs.size(0))
#         if idxs.size(0)>1: 
#             img=img_tensor[idxs]
#             loss_hist_tensor[i]=loss_hist(f_compute_hist(img,bins),hist_val_tnsr[i])
#     hist_loss=loss_hist_tensor.sum()
    
#     return hist_loss

# def f_get_spec_cond(img_tensor,categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind):
#     ''' Module to compute spectral loss for conditional GAN '''
#     num_classes=gdict['num_classes'];device=gdict['device']
    
#     loss_spec_tensor=torch.zeros(num_classes,device=device)
#     for i in np.arange(num_classes):    
#         idxs=torch.where(categories==i)[0] ## Get indices for that category
#     #     print(i,idxs.size(0))
#         if idxs.size(0)>1: 
#             img=img_tensor[idxs]
#             mean,sdev=f_torch_image_spectrum(f_invtransform(img),1,r,ind)
#             loss_spec_tensor[i]=loss_spectrum(mean,spec_mean_tnsr[i],sdev,spec_sdev_tnsr[i],gdict['image_size'],gdict['lambda1'])
#     spec_loss=loss_spec_tensor.sum()
#     return spec_loss



def f_get_hist_cond(img_tensor,categories,bins,gdict,hist_val_tnsr):
    ''' Module to compute pixel intensity histogram loss for conditional GAN '''
    num_classes=gdict['num_classes'];device=gdict['device']
    
    loss_hist_tensor=torch.zeros(num_classes,device=device)
    for count,i in enumerate(gdict['sigma_list']):    
        idxs=torch.where(categories==i)[0] ## Get indices for that category
        if idxs.size(0)>1: 
            num_frac=idxs.size(0)/img_tensor.shape[0] ## Fraction of points in the category
            img=img_tensor[idxs]
            loss_hist_tensor[count]=loss_hist(f_compute_hist(img,bins),hist_val_tnsr[count])*num_frac
    hist_loss=loss_hist_tensor.sum()
    
    return hist_loss

def f_get_spec_cond(img_tensor,categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind):
    ''' Module to compute spectral loss for conditional GAN '''
    num_classes=gdict['num_classes'];device=gdict['device']
    
    loss_spec_tensor=torch.zeros(num_classes,device=device)
    for count,i in enumerate(gdict['sigma_list']):    
        idxs=torch.where(categories==i)[0] ## Get indices for that category
        if idxs.size(0)>1: 
            num_frac=idxs.size(0)/img_tensor.shape[0] ## Fraction of points in the category
            img=img_tensor[idxs]
            mean,sdev=f_torch_image_spectrum(f_invtransform(img),1,r,ind)
            loss_spec_tensor[count]=loss_spectrum(mean,spec_mean_tnsr[count],sdev,spec_sdev_tnsr[count],gdict['image_size'],gdict['lambda1'])*num_frac
    spec_loss=loss_spec_tensor.sum()
    return spec_loss



In [9]:
# def f_size(ip):
#     p=2;s=2
# #     return (ip + 2 * 0 - 1 * (p-1) -1 )/ s + 1

#     return (ip-1)*s - 2 * p + 1 *(5-1)+ 1 + 1

# f_size(128)

In [10]:
# logging.basicConfig(filename=save_dir+'/log.log',filemode='w',format='%(name)s - %(levelname)s - %(message)s')

## Main code

In [11]:
def f_train_loop(dataloader,metrics_df,gdict):
    ''' Train single epoch '''
    
    ## Define new variables from dict
    keys=['start_epoch','epochs','iters','best_chi1','best_chi2','save_dir','device','flip_prob','nz','num_classes','batch_size','bns']
    start_epoch,epochs,iters,best_chi1,best_chi2,save_dir,device,flip_prob,nz,num_classes,batch_size,bns=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())
    
    for epoch in range(start_epoch,epochs):
        t_epoch_start=time.time()
        for count, data in enumerate(dataloader, 0):
            
            ####### Train GAN ########
            netG.train(); netD.train();  ### Need to add these after inference and before training

            tme1=time.time()
            ### Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()

            real_cpu = data[0].to(device)
            real_categories=data[1].to(device)
#             real_categories=real_categories.squeeze(-1)
            
            b_size = real_cpu.size(0)
            real_label = torch.full((b_size,), 1, device=device)
            fake_label = torch.full((b_size,), 0, device=device)
            g_label = torch.full((b_size,), 1, device=device) # No flipping for Generator labels
            # Flip labels with probability flip_prob for Discriminator
            for idx in np.random.choice(np.arange(b_size),size=int(np.ceil(b_size*flip_prob))):
                real_label[idx]=0; fake_label[idx]=1
            
            # Generate fake image batch with G
            noise = torch.randn(b_size, 1, 1, nz, device=device)
            fake_categories=torch.randint(gdict['num_classes'],(gdict['batch_size'],1),device=gdict['device'])
            if gdict['model']>3: fake_categories=f_get_sigma(fake_categories,gdict)
            fake = netG(noise,fake_categories)  
            
            # Forward pass real batch through D
            output = netD(real_cpu,real_categories).view(-1)
            errD_real = criterion(output, real_label)
            errD_real.backward()
            D_x = output.mean().item()
            
            # Forward pass fake batch through D
            output = netD(fake.detach(),fake_categories).view(-1)
            errD_fake = criterion(output, fake_label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()
            
            ###Update G network: maximize log(D(G(z)))
            netG.zero_grad()
            output = netD(fake,fake_categories).view(-1)
            errG_adv = criterion(output, g_label)
            # Histogram pixel intensity loss
            
#             hist_gen=f_compute_hist(fake,bins=bns)
#             hist_loss=loss_hist(hist_gen,hist_val.to(device))
            hist_loss=f_get_hist_cond(fake,fake_categories,bns,gdict,hist_val_tnsr)
            
            # Add spectral loss
#             mean,sdev=f_torch_image_spectrum(f_invtransform(fake),1,r.to(device),ind.to(device))
#             spec_loss=loss_spectrum(mean,mean_spec_val.to(device),sdev,sdev_spec_val.to(device),image_size)
            spec_loss=f_get_spec_cond(fake,fake_categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind)

            if gdict['spec_loss_flag']: errG=errG_adv+spec_loss
            else: errG=errG_adv
#             errG=errG_adv
            if torch.isnan(errG).any():
                print(errG)
                raise SystemError
            
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            
            tme2=time.time()
            
            ####### Store metrics ########
            # Output training stats
            if count % gdict['checkpoint_size'] == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_adv: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, epochs, count, len(dataloader), errD.item(), errG_adv.item(),errG.item(), D_x, D_G_z1, D_G_z2)),
                print("Spec loss: %s,\t hist loss: %s"%(spec_loss.item(),hist_loss.item())),
                print("Training time for step %s : %s"%(iters, tme2-tme1))

            # Save metrics
            cols=['step','epoch','Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','D(x)','D_G_z1','D_G_z2','time']
            vals=[iters,epoch,errD_real.item(),errD_fake.item(),errD.item(),errG_adv.item(),errG.item(),spec_loss.item(),hist_loss.item(),D_x,D_G_z1,D_G_z2,tme2-tme1]
            for col,val in zip(cols,vals):  metrics_df.loc[iters,col]=val

            ### Checkpoint the best model
            checkpoint=True
            iters += 1  ### Model has been updated, so update iters before saving metrics and model.

            ### Compute validation metrics for updated model
            netG.eval()
            with torch.no_grad():
                fake = netG(fixed_noise,fixed_categories)
#                 hist_gen=f_compute_hist(fake,bins=bns)
#                 hist_chi=loss_hist(hist_gen,hist_val.to(device))
#                 mean,sdev=f_torch_image_spectrum(f_invtransform(fake),1,r.to(device),ind.to(device))
#                 spec_chi=loss_spectrum(mean,mean_spec_val.to(device),sdev,sdev_spec_val.to(device),image_size)      
                hist_chi=f_get_hist_cond(fake,fixed_categories,bns,gdict,hist_val_tnsr)
                spec_chi=f_get_spec_cond(fake,fixed_categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind)

            # Storing chi for next step
            for col,val in zip(['spec_chi','hist_chi'],[spec_chi.item(),hist_chi.item()]):  metrics_df.loc[iters,col]=val            

            # Checkpoint model for continuing run
            if count == len(dataloader)-1: ## Check point at last step of epoch
                f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_last.tar')  

            if (checkpoint and (epoch > 1)): # Choose best models by metric
                if hist_chi< best_chi1:
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_best_hist.tar')
                    best_chi1=hist_chi.item()
                    print("Saving best hist model at epoch %s, step %s."%(epoch,iters))

                if  spec_chi< best_chi2:
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_best_spec.tar')
                    best_chi2=spec_chi.item()
                    print("Saving best spec model at epoch %s, step %s"%(epoch,iters))
                    
                if iters in gdict['save_steps_list']:
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_{0}.tar'.format(iters))
                    print("Saving given-step at epoch %s, step %s."%(epoch,iters))
                    
            # Save G's output on fixed_noise
            if ((iters % gdict['checkpoint_size'] == 0) or ((epoch == epochs-1) and (count == len(dataloader)-1))):
                netG.eval()
                with torch.no_grad():
                    for category in range(num_classes):
#                         cat_tensor=torch.LongTensor(np.ones(batch_size)*category).view(batch_size,1,1)
                        tnsr_categories=(torch.ones(batch_size,device=device)*category).view(batch_size,1)
                        if gdict['model']>3: tnsr_categories=f_get_sigma(tnsr_categories,gdict)
                        fake = netG(fixed_noise,tnsr_categories).detach().cpu()
                        img_arr=np.array(fake[:,0,:,:])
                        fname='gen_img_label-%s_epoch-%s_step-%s'%(category,epoch,iters)
                        np.save(save_dir+'/images/'+fname,img_arr)
        
        t_epoch_end=time.time()
        print("Time taken for epoch %s: %s"%(epoch,t_epoch_end-t_epoch_start))
        # Save Metrics to file after each epoch
        metrics_df.to_pickle(save_dir+'/df_metrics.pkle')
        
    print("best chis: {0}, {1}".format(best_chi1,best_chi2))

In [12]:
def f_init_gdict(gdict,config_dict):
    ''' Initialize the global dictionary gdict with values in config file'''
    keys1=['workers','nc','nz','ngf','ndf','beta1','kernel_size','stride','g_padding','d_padding','flip_prob']
    keys2=['image_size','checkpoint_size','num_imgs','ip_fname','op_loc']
    for key in keys1: gdict[key]=config_dict['training'][key]
    for key in keys2: gdict[key]=config_dict['data'][key]



In [16]:
if __name__=="__main__":
    torch.backends.cudnn.benchmark=True
#     torch.backends.cudnn.deterministic=True
    
    t0=time.time()
    #################################
#     args=f_parse_args()
    # Manually add args ( different for jupyter notebook)
    args=argparse.Namespace()
    args.config='1_main_code/config_128.yaml'
    args.ngpu=1
    args.batch_size=128
    args.spec_loss_flag=False
    args.checkpoint_size=100
    args.epochs=5
    args.learn_rate=0.0002
    args.mode='fresh'
#     args.mode='continue'
#     args.ip_fldr='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20201211_093818_nb_test/'    args.run_suffix='_nb_test'
    args.run_suffix='nb_test'
    args.deterministic=False
    args.seed='234373'
    args.lambda1=0.1
    args.model=4
    args.save_steps_list=[5,10]
    
    ### Set up ###
    config_file=args.config
    config_dict=f_load_config(config_file)

    # Initilize variables    
    gdict={}
    f_init_gdict(gdict,config_dict)
    
    ## Add args variables to gdict
    gdict['sigma_list']=[0.5,0.65,0.8,1.1]
    for key in ['ngpu','batch_size','mode','spec_loss_flag','epochs','learn_rate','lambda1','save_steps_list','model']:
        gdict[key]=vars(args)[key]
       
    ###### Set up directories #######
    if gdict['mode']=='fresh':
        # Create prefix for foldername        
        fldr_name=datetime.now().strftime('%Y%m%d_%H%M%S') ## time format
        gdict['save_dir']=gdict['op_loc']+fldr_name+'_'+args.run_suffix
        
        if not os.path.exists(gdict['save_dir']):
            os.makedirs(gdict['save_dir']+'/models')
            os.makedirs(gdict['save_dir']+'/images')
        
    elif args.mode=='continue': ## For checkpointed runs
        gdict['save_dir']=args.ip_fldr
        ### Read loss data
        with open (gdict['save_dir']+'df_metrics.pkle','rb') as f:
            metrics_dict=pickle.load(f) 

#     ### Write all print statements to stdout and log file (different for jpt notebooks)
#     logfile=gdict['save_dir']+'/log.log'
#     logging.basicConfig(level=logging.DEBUG, filename=logfile, filemode="a+", format="%(asctime)-15s %(levelname)-8s %(message)s")
    
#     Lg = logging.getLogger()
#     Lg.setLevel(logging.DEBUG)
#     lg_handler_file = logging.FileHandler(logfile)
#     lg_handler_stdout = logging.StreamHandler(sys.stdout)
#     Lg.addHandler(lg_handler_file)
#     Lg.addHandler(lg_handler_stdout)
    
#     print('Args: {0}'.format(args))
#     print(config_dict)
#     print('Start: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))
#     if gdict['spec_loss_flag']: print("Using Spectral loss")

    ### Override (different for jpt notebooks)
    gdict['num_imgs']=2000
#     gdict['learn_rate']=0.0008

    ## Special declarations
    gdict['bns']=50
    gdict['num_classes']=4
    gdict['device']=torch.device("cuda" if (torch.cuda.is_available() and gdict['ngpu'] > 0) else "cpu")
    gdict['ngpu']=torch.cuda.device_count()
    
    gdict['multi-gpu']=True if (gdict['device'].type == 'cuda') and (gdict['ngpu'] > 1) else False 
    print(gdict)
    
    ### Initialize random seed
    if args.seed=='random': manualSeed = np.random.randint(1, 10000)
    else: manualSeed=int(args.seed)
    print("Seed:{0}".format(manualSeed))
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    print('Device:{0}'.format(gdict['device']))

    if args.deterministic: 
        print("Running with deterministic sequence. Performance will be slower")
        torch.backends.cudnn.deterministic=True
#         torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
    
    #################################   
    ### Read input data from different files
    for count,sigma in enumerate(gdict['sigma_list']):
        fname=gdict['ip_fname']+'/norm_1_sig_%s_train_val.npy'%(sigma)
        x=np.load(fname,mmap_mode='r')[:gdict['num_imgs']].transpose(0,1,2,3)
        size=x.shape[0]
        y=count*np.ones(size)

        if count==0:
            img=x[:]
            lab=y[:]
        else: 
            img=np.vstack([img,x]) # Store images
            lab=np.hstack([lab,y]) # Store class labels

    t_img=torch.from_numpy(img)
    labels=torch.LongTensor(lab).view(size*4,1)
    print("%s, %s"%(labels.shape,t_img.shape))
    
    dataset=TensorDataset(t_img,labels)
    dataloader=DataLoader(dataset,batch_size=gdict['batch_size'],shuffle=True,num_workers=0,drop_last=True)

    # Precompute metrics with validation data for computing losses
    with torch.no_grad():
        spec_mean_list=[];spec_sdev_list=[];hist_val_list=[]
        
        for count,sigma in enumerate(gdict['sigma_list']):
            ip_fname=gdict['ip_fname']+'/norm_1_sig_%s_train_val.npy'%(sigma)
            val_img=np.load(ip_fname,mmap_mode='r')[-3000:].transpose(0,1,2,3)
            t_val_img=torch.from_numpy(val_img).to(gdict['device'])

            # Precompute radial coordinates
            if count==0: 
                r,ind=f_get_rad(img)
                r=r.to(gdict['device']); ind=ind.to(gdict['device'])
            # Stored mean and std of spectrum for full input data once
            mean_spec_val,sdev_spec_val=f_torch_image_spectrum(f_invtransform(t_val_img),1,r,ind)
            hist_val=f_compute_hist(t_val_img,bins=gdict['bns'])
            
            spec_mean_list.append(mean_spec_val)
            spec_sdev_list.append(sdev_spec_val)
            hist_val_list.append(hist_val)
        spec_mean_tnsr=torch.stack(spec_mean_list)
        spec_sdev_tnsr=torch.stack(spec_sdev_list)
        hist_val_tnsr=torch.stack(hist_val_list)
        
        del val_img; del t_val_img; del img; del t_img; del spec_mean_list; del spec_sdev_list; del hist_val_list
    
    #################################
    ###### Build Networks ###
    # Define Models
    Generator, Discriminator=f_get_model(gdict['model'],gdict)
    print("Building GAN networks")
    # Create Generator
    netG = Generator(gdict).to(gdict['device'])
    netG.apply(weights_init)
#     print(netG)
#     summary(netG,(1,1,64))
    # Create Discriminator
    netD = Discriminator(gdict).to(gdict['device'])
    netD.apply(weights_init)
#     print(netD)
#     summary(netD,(1,128,128))
    
    print("Number of GPUs used %s"%(gdict['ngpu']))
    if (gdict['multi-gpu']):
        netG = nn.DataParallel(netG, list(range(gdict['ngpu'])))
        netD = nn.DataParallel(netD, list(range(gdict['ngpu'])))
    
    #### Initialize networks ####
    # criterion = nn.BCELoss()
    criterion = nn.BCEWithLogitsLoss()
    
    if gdict['mode']=='fresh':
        optimizerD = optim.Adam(netD.parameters(), lr=gdict['learn_rate'], betas=(gdict['beta1'], 0.999),eps=1e-7)
        optimizerG = optim.Adam(netG.parameters(), lr=gdict['learn_rate'], betas=(gdict['beta1'], 0.999),eps=1e-7)
        ### Initialize variables      
        iters,start_epoch,best_chi1,best_chi2=0,0,1e10,1e10    
    
    ### Load network weights for continuing run
    elif gdict['mode']=='continue':
        iters,start_epoch,best_chi1,best_chi2=f_load_checkpoint(gdict['save_dir']+'/models/checkpoint_last.tar',netG,netD,optimizerG,optimizerD,gdict) 
        print("Continuing existing run. Loading checkpoint with epoch {0} and step {1}".format(start_epoch,iters))
        start_epoch+=1  ## Start with the next epoch  

    ## Add to gdict
    for key,val in zip(['best_chi1','best_chi2','iters','start_epoch'],[best_chi1,best_chi2,iters,start_epoch]): gdict[key]=val
    print(gdict)
    
    fixed_noise = torch.randn(gdict['batch_size'], 1, 1, gdict['nz'], device=gdict['device']) #Latent vectors to view G progress
    fixed_categories=torch.randint(gdict['num_classes'],(gdict['batch_size'],1),device=gdict['device'])
    if gdict['model']>3: fixed_categories=f_get_sigma(fixed_categories,gdict)


{'workers': 2, 'nc': 1, 'nz': 64, 'ngf': 64, 'ndf': 64, 'beta1': 0.5, 'kernel_size': 5, 'stride': 2, 'g_padding': 2, 'd_padding': 2, 'flip_prob': 0.01, 'image_size': 128, 'checkpoint_size': 10, 'num_imgs': 2000, 'ip_fname': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/128_square/dataset_5_4univ_cgan', 'op_loc': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/', 'sigma_list': [0.5, 0.65, 0.8, 1.1], 'ngpu': 1, 'batch_size': 128, 'mode': 'fresh', 'spec_loss_flag': False, 'epochs': 5, 'learn_rate': 0.0002, 'lambda1': 0.1, 'save_steps_list': [5, 10], 'model': 4, 'save_dir': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20201214_184914_nb_test', 'bns': 50, 'num_classes': 4, 'device': device(type='cuda'), 'multi-gpu': False}
Seed:234373
Device:cuda
torch.Size([8000, 1]), torch.Size([8000, 1, 128, 128])
Model name 4
Building GAN networks
Number of GPUs used 1
{'workers': 2, 'nc': 1, 'nz': 64, 'n

In [17]:
if __name__=="__main__":
    #################################       
    ### Set up metrics dataframe
    cols=['step','epoch','Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','spec_chi','hist_chi','D(x)','D_G_z1','D_G_z2','time']
    # size=int(len(dataloader) * epochs)+1
    metrics_df=pd.DataFrame(columns=cols)
    
    #################################
    ########## Train loop and save metrics and images ######
    print("Starting Training Loop...")
    f_train_loop(dataloader,metrics_df,gdict)
    
    ### Generate images for best saved models ######
    for cl in np.arange(gdict['num_classes']):
        op_loc=gdict['save_dir']+'/images/'
        ip_fname=gdict['save_dir']+'/models/checkpoint_best_spec.tar'
        f_gen_images(gdict,netG,optimizerG,cl,ip_fname,op_loc,op_strg='best_spec',op_size=200)

        ip_fname=gdict['save_dir']+'/models/checkpoint_best_hist.tar'
        f_gen_images(gdict,netG,optimizerG,cl,ip_fname,op_loc,op_strg='best_hist',op_size=200)

    tf=time.time()
    print("Total time %s"%(tf-t0))
    print('End: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))
    

Starting Training Loop...
[0/5][0/62]	Loss_D: 1.3360	Loss_adv: 2.1226	Loss_G: 2.1226	D(x): -0.1150	D(G(z)): -0.2790 / -1.9769
Spec loss: 8.47372817993164,	 hist loss: 1.7015306949615479
Training time for step 0 : 0.3444478511810303
[0/5][10/62]	Loss_D: 0.3116	Loss_adv: 3.3980	Loss_G: 3.3980	D(x): 1.6894	D(G(z)): -3.4901 / -3.3620
Spec loss: 8.6365327835083,	 hist loss: 1.6095385551452637
Training time for step 10 : 0.3219606876373291
[0/5][20/62]	Loss_D: 0.1948	Loss_adv: 0.8379	Loss_G: 0.8379	D(x): 5.5186	D(G(z)): -2.5910 / -0.1358
Spec loss: 8.612759590148926,	 hist loss: 1.4996614456176758
Training time for step 20 : 0.3218653202056885
[0/5][30/62]	Loss_D: 0.2266	Loss_adv: 3.8646	Loss_G: 3.8646	D(x): 3.9871	D(G(z)): -2.9224 / -3.8374
Spec loss: 8.577536582946777,	 hist loss: 1.256093978881836
Training time for step 30 : 0.3233358860015869
[0/5][40/62]	Loss_D: 0.1805	Loss_adv: 6.4732	Loss_G: 6.4732	D(x): 5.9008	D(G(z)): -5.1730 / -6.4715
Spec loss: 8.540903091430664,	 hist loss: 1.080

In [19]:
metrics_df.shape

(311, 15)

In [18]:
metrics_df.plot(x='step',y=['hist_loss','spec_chi'],kind='line')
# metrics_df.plot(x='step',y=['hist_loss','hist_chi'],kind='line')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:xlabel='step'>

In [13]:

def f_get_hist_cond(img_tensor,categories,bins,gdict,hist_val_tnsr):
    ''' Module to compute pixel intensity histogram loss for conditional GAN '''
    num_classes=gdict['num_classes'];device=gdict['device']
    
    loss_hist_tensor=torch.zeros(num_classes,device=device)
    for count,i in enumerate(gdict['sigma_list']):    
        idxs=torch.where(categories==i)[0] ## Get indices for that category
        if idxs.size(0)>1: 
            num_frac=idxs.size(0)/img_tensor.shape[0] ## Fraction of points in the category
            img=img_tensor[idxs]
            loss_hist_tensor[count]=loss_hist(f_compute_hist(img,bins),hist_val_tnsr[count])*num_frac
    hist_loss=loss_hist_tensor.sum()
    
    return hist_loss

def f_get_spec_cond(img_tensor,categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind):
    ''' Module to compute spectral loss for conditional GAN '''
    num_classes=gdict['num_classes'];device=gdict['device']
    
    loss_spec_tensor=torch.zeros(num_classes,device=device)
    for count,i in enumerate(gdict['sigma_list']):    
        idxs=torch.where(categories==i)[0] ## Get indices for that category
        if idxs.size(0)>1: 
            num_frac=idxs.size(0)/img_tensor.shape[0] ## Fraction of points in the category
            img=img_tensor[idxs]
            mean,sdev=f_torch_image_spectrum(f_invtransform(img),1,r,ind)
            loss_spec_tensor[count]=loss_spectrum(mean,spec_mean_tnsr[count],sdev,spec_sdev_tnsr[count],gdict['image_size'],gdict['lambda1'])*num_frac
    spec_loss=loss_spec_tensor.sum()
    return spec_loss


In [None]:
fixed_categories=torch.randint(gdict['num_classes'],(gdict['batch_size'],1),device=gdict['device'])
# fixed_categories=(torch.ones(gdict['batch_size'],device=gdict['device'])*3).view(gdict['batch_size'],1)
if gdict['model']>3: fixed_categories=f_get_sigma(fixed_categories,gdict)

In [None]:
# def f_get_hist_cond(img_tensor,categories,bins,gdict,hist_val_tnsr):
#     ''' Module to compute pixel intensity histogram loss for conditional GAN '''
#     num_classes=gdict['num_classes'];device=gdict['device']
#     print(img_tensor.shape[0])
#     loss_hist_tensor=torch.zeros(num_classes,device=device)
#     for count,i in enumerate(gdict['sigma_list']):    
#         idxs=torch.where(categories==i)[0] ## Get indices for that category
# #         print(categories.shape,idxs)
# #         print(i,idxs.size(0))
#         if idxs.size(0)>1: 
#             num_frac=idxs.size(0)/img_tensor.shape[0]
#             print(idxs.shape,count,i,num_frac)
#             img=img_tensor[idxs]
#             loss_hist_tensor[count]=loss_hist(f_compute_hist(img,bins),hist_val_tnsr[count])*num_frac
# #             print(f_compute_hist(img,bins),hist_val_tnsr[count])
# #             print(torch.log(torch.mean(torch.pow(f_compute_hist(img,bins)-hist_val_tnsr[count],2))))
# #             print(loss_hist_tensor)
#     hist_loss=loss_hist_tensor.sum()
    
#     return hist_loss

hist_chi=f_get_hist_cond(t_val_img[:128],fixed_categories,50,gdict,hist_val_tnsr)
print(hist_chi)
spec_chi=f_get_spec_cond(t_val_img[:128],fixed_categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind)
print(spec_chi)

In [None]:
# fixed_categories
# idxs.shape[0]

In [None]:
hist_chi=f_get_hist_cond(fake,fixed_categories,50,gdict,hist_val_tnsr)
print(hist_chi)
spec_chi=f_get_spec_cond(fake,fixed_categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind)

In [None]:
t_val_img.shape

In [None]:
f_compute_hist(t_val_img,50),hist_val_tnsr[2]

### Test models


In [None]:
def f_get_model(model_name,gdict):
    ''' Module to define Generator and Discriminator'''
    print("Model name",model_name)
    if model_name==1: ## With embeddings
        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.label_embedding=nn.Embedding(num_classes,num_classes)
                self.main = nn.Sequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz+num_classes,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                    nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                labels=labels.unsqueeze(-1).long()
                gen_input=torch.cat((self.label_embedding(labels),noise),-1)
                img=self.main(gen_input)
        #         print(type(img),img.size())
        #         img=img.view(128,nc,128,128))

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())
                self.label_embedding=nn.Embedding(num_classes,num_classes)

                self.linear_transf=nn.Linear(4,4)
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                    nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):
                labels=labels.unsqueeze(-1).long()
                img_size=gdict['image_size']
                a=self.label_embedding(labels)
                x=a.view(a.size(0),-1)
                x=self.linear_transf(x)
                x=torch.repeat_interleave(x,int((img_size*img_size)/4))
                x=x.view(a.size(0),1,img_size,img_size)
        #         print(x.size())
                d_input=torch.cat((img,x),axis=1)
        #         d_input=torch.cat((img,self.label_embedding(labels)),-1)
                pred=self.main(d_input)
                return pred

    elif model_name==2: #### Model 2: without embeddings
        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = nn.Sequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz+1,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                    nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                x=labels.unsqueeze(-1).unsqueeze(-1).float()
                print(x.shape)
                gen_input=torch.cat((noise,x),-1)
                img=self.main(gen_input)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.linear_transf=nn.Linear(4,4)
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                    nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):
                img_size=gdict['image_size']
                x=labels.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,4).float() # get to size (128,1,1,4)
                print(x.shape)
                x=self.linear_transf(x)
                x=torch.repeat_interleave(x,int((img_size*img_size)/4)) # get to size (128,1, 128, 128)
                x=x.view(labels.size(0),1,img_size,img_size)
#                 print(x.size())
                d_input=torch.cat((img,x),axis=1)
                pred=self.main(d_input)
                return pred

    elif model_name==3:#### Model 3: with ConditionalInstanceNorm2d
        class ConditionalInstanceNorm2d(nn.Module):
            def __init__(self, num_features, num_params):
                super().__init__()
                self.num_features = num_features
                self.InstNorm = nn.InstanceNorm2d(num_features, affine=False)
                self.affine = nn.Linear(num_params, num_features * 2)
                self.affine.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
                self.affine.weight.data[:, num_features:].zero_()  # Initialise bias at 0

            def forward(self, x, y):
                out = self.InstNorm(x)
                gamma, beta = self.affine(y).chunk(2, 1)
                out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
                return out

        class ConditionalSequential(nn.Sequential):
            def __init__(self,*args):
                super(ConditionalSequential, self).__init__(*args)

            def forward(self, inputs, labels):
                for module in self:
                    if module.__class__ is ConditionalInstanceNorm2d:
                        inputs = module(inputs, labels.float())
                    else:
                        inputs = module(inputs)

                return inputs

        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
        #             nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*4,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*2,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                img=self.main(noise,labels)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc, ndf,kernel_size, stride, d_padding,  bias=True),
        #             nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*2,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*4,1),

                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*8,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):   
                pred=self.main(img,labels)
                return pred

    elif model_name==4: #### Model 4: without embeddings
        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = nn.Sequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz+1,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                    nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                    nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                x=labels.unsqueeze(-1).unsqueeze(-1).float()
                print(x.shape)
                gen_input=torch.cat((noise,x),-1)
                img=self.main(gen_input)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.linear_transf=nn.Linear(4,4)
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                    nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                    nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):
                img_size=gdict['image_size']
                x=labels.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,4).float() # get to size (128,1,1,4)
                print(x.shape)
                x=self.linear_transf(x)
                x=torch.repeat_interleave(x,int((img_size*img_size)/4)) # get to size (128,1, 128, 128)
                x=x.view(labels.size(0),1,img_size,img_size)
#                 print(x.size())
                d_input=torch.cat((img,x),axis=1)
                pred=self.main(d_input)
                return pred

    elif model_name==5:#### Model 5: with ConditionalInstanceNorm2d
        class ConditionalInstanceNorm2d(nn.Module):
            def __init__(self, num_features, num_params):
                super().__init__()
                self.num_features = num_features
                self.InstNorm = nn.InstanceNorm2d(num_features, affine=False)
                self.affine = nn.Linear(num_params, num_features * 2)
                self.affine.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
                self.affine.weight.data[:, num_features:].zero_()  # Initialise bias at 0

            def forward(self, x, y):
                out = self.InstNorm(x)
                gamma, beta = self.affine(y).chunk(2, 1)
                out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
                return out

        class ConditionalSequential(nn.Sequential):
            def __init__(self,*args):
                super(ConditionalSequential, self).__init__(*args)

            def forward(self, inputs, labels):
                for module in self:
                    if module.__class__ is ConditionalInstanceNorm2d:
                        inputs = module(inputs, labels.float())
                    else:
                        inputs = module(inputs)

                return inputs

        class Generator(nn.Module):
            def __init__(self, gdict):
                super(Generator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ngf','kernel_size','stride','g_padding']
                num_classes, ngpu, nz,nc,ngf,kernel_size,stride,g_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Linear(nz,nc*ngf*8*8*8),# 32768
                    nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
                    nn.ReLU(inplace=True),
                    View(shape=[-1,ngf*8,8,8]),
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
        #             nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*4,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf*2,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
        #             nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ngf,1),
                    nn.ReLU(inplace=True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                    nn.Tanh()
                )

            def forward(self, noise,labels):
                img=self.main(noise,labels)

                return img

        class Discriminator(nn.Module):
            def __init__(self, gdict):
                super(Discriminator, self).__init__()

                ## Define new variables from dict
                keys=['num_classes','ngpu','nz','nc','ndf','kernel_size','stride','d_padding']
                num_classes, ngpu, nz,nc,ndf,kernel_size,stride,d_padding=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())

                self.main = ConditionalSequential(
                    # input is (nc) x 64 x 64
                    # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                    nn.Conv2d(nc, ndf,kernel_size, stride, d_padding,  bias=True),
        #             nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*2,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*4,1),

                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
        #             nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                    ConditionalInstanceNorm2d(ndf*8,1),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Flatten(),
                    nn.Linear(nc*ndf*8*8*8, 1)
        #             nn.Sigmoid()
                )

            def forward(self, img,labels):   
                pred=self.main(img,labels)
                return pred

    return Generator, Discriminator








In [None]:
Generator, Discriminator=f_get_model(4,gdict)

In [None]:
netG = Generator(gdict).to(gdict['device'])
netG.apply(weights_init)
netD = Discriminator(gdict).to(gdict['device'])
netD.apply(weights_init)
print()

In [None]:
noise = torch.randn(gdict['batch_size'], 1, 1, gdict['nz'], device=gdict['device'])
fake_categories=torch.randint(gdict['num_classes'],(gdict['batch_size'],1),device=gdict['device'])
fake_categories=f_get_sigma(fake_categories,gdict)
fake = netG(noise,fake_categories)    

print(noise.shape,fake_categories.shape,fake.shape)

In [None]:
output=netD(fake,fake_categories)
print(output.shape)

In [None]:
## Total number of parameters
sum(p.numel() for p in netD.parameters() if p.requires_grad)

In [None]:
# summary(netG,[(1,1,64),(1,1)])
# summary(netD,[(1,1,64),(1,1,1)])

In [None]:
f_get_spec_cond(fake,fake_categories,gdict,spec_mean_tnsr,spec_sdev_tnsr,r,ind)


In [None]:
fake_categories

In [None]:
sigma_list[

In [None]:
fake_categories.shape

In [None]:
for i in fake_categories[:10]:
#     print(i.shape)
    print(sigma_list[i])

In [None]:
if gdict['model']>3: fake_categories=f_get_sigma(fake_categories,gdict,sigma_list)

