# model architecture

In [3]:
import torch
import torch.nn as nn
from torch.nn import init
import random
import numpy as np
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from evaluator import evaluation_model as pre_cla
import copy
from torchvision.utils import make_grid, save_image
import dataset
from dataset import *
from tqdm import tqdm 
import torch.nn.utils.spectral_norm as spectral_norm 
import torch


# This notebook trains a conditional GAN (cGAN) with DCGAN-style conv
# networks, optimised with the WGAN objective (critic weight clipping).
# NOTE(review): the link below explains spectral normalization; spectral_norm
# is imported above but is not applied to any layer in this file.
# https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

def weights_init_normal(m):
    """DCGAN-style initialiser, meant to be used via ``module.apply(...)``.

    * Any module whose class name contains ``"Conv"``: weight ~ N(0, 0.02).
    * Any module whose class name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02),
      bias = 0.
    * Every other module is left untouched.
    """
    layer_name = type(m).__name__
    nn_init = torch.nn.init

    if "Conv" in layer_name:
        nn_init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm2d" in layer_name:
        nn_init.normal_(m.weight.data, 1.0, 0.02)
        nn_init.constant_(m.bias.data, 0.0)
        
class Generator(nn.Module):
    """Conditional generator: (noise, condition vector) -> 3-channel image.

    The condition vector is embedded to ``c_dim`` features (Linear + ReLU) and
    concatenated with the noise. The result is projected to a 128-channel map
    of side ``img_size // 4``, then upsampled twice by 2 to ``img_size``.
    The output passes through ``Tanh``, so values lie in [-1, 1].

    Args:
        n_classes: length of the condition (label) vector.
        img_size: output side length; should be divisible by 4, otherwise the
            output side is ``4 * (img_size // 4)``.
        z_dim: length of the noise vector.
        upsample_block_num: accepted for interface compatibility; not used.
        c_dim: width of the condition embedding.
    """

    def __init__(self, n_classes, img_size, z_dim, upsample_block_num, c_dim=256):
        super(Generator, self).__init__()
        self.img_size = img_size
        # Two x2 upsamplings in conv_blocks bring init_size back to img_size.
        self.init_size = self.img_size // 4

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.latent_dim = self.z_dim + self.c_dim

        # Built from n_classes/c_dim (previously hard-coded Linear(24, 256),
        # which broke l1's input width whenever c_dim != 256).
        self.conditionExpand = nn.Sequential(
            nn.Linear(n_classes, c_dim),
            nn.ReLU(),
        )

        self.l1 = nn.Sequential(nn.Linear(self.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # NOTE(review): 0.8 is BatchNorm2d's *eps* (not momentum). Kept to
            # preserve existing behaviour; likely unintended — confirm.
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: tensor of shape (bs, z_dim).
            labels: tensor of shape (bs, n_classes); cast to float.

        Returns:
            Tensor of shape (bs, 3, 4 * init_size, 4 * init_size) in [-1, 1].
        """
        cond = self.conditionExpand(labels.float())  # (bs, c_dim)
        z = torch.cat((noise, cond), -1)  # (bs, z_dim + c_dim)

        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)

# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator / critic.

    The label vector is mapped to a single ``img_size`` x ``img_size`` plane
    (Linear + LeakyReLU). That plane is concatenated to the 3-channel image as
    a 4th channel, and the result is scored by a stack of strided convs. For
    ``img_size == 64`` the output has shape (bs, 1, 1, 1).

    Args:
        n_classes: length of the label vector.
        img_size: side length of the input images.
        use_sigmoid: keyword-only. When True (default, original behaviour) the
            score is squashed into (0, 1). Pass False to get the raw,
            unbounded score a WGAN critic is expected to produce.
    """

    def __init__(self, n_classes, img_size, *, use_sigmoid=True):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        self.n_classes = n_classes
        self.convert_label_layer = nn.Sequential(
            nn.Linear(self.n_classes, self.img_size**2),
            nn.LeakyReLU()
        )

        layers = [
            # 4 input channels = 3 image channels + 1 label plane.
            nn.Conv2d(4, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(1024, 1, 4, 1, 0, bias=False),
        ]
        # Sigmoid has no parameters, so the state_dict keys are identical
        # whether or not it is appended.
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score a batch of (image, label) pairs.

        Args:
            img: tensor of shape (bs, 3, img_size, img_size).
            labels: tensor of shape (bs, n_classes); cast to float.
        """
        c = self.convert_label_layer(labels.float()).view(-1, 1, self.img_size, self.img_size)  # (bs, 1, img_size, img_size)
        out = torch.cat((img, c), 1)  # append the label plane as a 4th channel
        return self.main(out)

#


            

GPU State: cpu


# train.py

In [None]:
def random_z(batch_size, z_dim):
    """Draw a (batch_size, z_dim) tensor of N(0, 1) noise on ``device``."""
    noise_shape = (batch_size, z_dim)
    return torch.randn(*noise_shape, device=device)

def save_img(images, path):
    """Tile ``images`` into a single grid with ``make_grid`` and write it to ``path``."""
    save_image(make_grid(images), path)
def smooth_label(mode, shape):
    """Return randomly smoothed target labels on ``device``.

    Args:
        mode: ``'real'`` -> values uniform in [0.9, 1.0);
              ``'fake'`` -> values uniform in [0.0, 0.1).
        shape: shape of the returned tensor (anything ``torch.rand`` accepts).

    Raises:
        ValueError: for any other ``mode``. Previously the function silently
            returned ``None``, which only failed later in the loss.
    """
    if mode == 'real':
        return (torch.rand(shape, device=device) / 10) + 0.9
    if mode == 'fake':
        return torch.rand(shape, device=device) / 10
    raise ValueError(f"mode must be 'real' or 'fake', got {mode!r}")

    
def train(G, D, epochs, lr_g, lr_d, train_loader, test_loader, z_dim=None,
          out_dir='cgan_dcgan_wgan'):
    """Train a conditional WGAN (critic weight clipping) and checkpoint G.

    Both networks are re-initialised with ``weights_init_normal`` and moved to
    ``device``. Each training batch gets one critic update followed by one
    generator update. After every epoch:

    * G is scored by the external evaluator on a fixed (noise, label) batch.
      The labels come from the first batch of ``test_loader``.
    * G's weights are saved under ``<out_dir>/paras`` whenever the score
      ties or beats the best so far.
    * The evaluation images are saved under ``<out_dir>/results``.

    Args:
        G, D: generator and critic modules.
        epochs: number of epochs to run (epochs 1..epochs inclusive).
        lr_g, lr_d: Adam learning rates for G and D.
        train_loader: yields ``(image, label)`` batches.
        test_loader: yields label batches; only the first one is used.
        z_dim: noise length; defaults to ``G.z_dim``.
        out_dir: root directory for checkpoints and sample images.

    Returns:
        The best evaluator score observed.
    """
    import os  # not imported explicitly at the top of this file

    if z_dim is None:
        z_dim = G.z_dim

    ckpt_dir = os.path.join(out_dir, 'paras')
    img_dir = os.path.join(out_dir, 'results')
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)

    # Built once instead of once per epoch.
    # NOTE(review): assumes the evaluator keeps no per-call state — confirm.
    evaluator = pre_cla()

    # Fixed evaluation inputs so scores are comparable across epochs.
    test_label = next(iter(test_loader)).to(device)
    test_z = random_z(test_label.size(0), z_dim)

    G.to(device)
    D.to(device)
    G.apply(weights_init_normal)
    D.apply(weights_init_normal)

    optimizer_g = optim.Adam(G.parameters(), lr=lr_g, betas=(0.5, 0.9))
    optimizer_d = optim.Adam(D.parameters(), lr=lr_d, betas=(0.5, 0.9))

    best_score = 0
    for epoch in range(1, epochs + 1):
        # Evaluation below switches to eval mode; go back to train mode every
        # epoch so BatchNorm uses batch statistics while training.
        G.train()
        D.train()
        # Per-epoch running sums of batch-mean losses.
        total_loss_d = 0.0
        total_loss_g = 0.0
        n_batches = 0

        for image, label in tqdm(train_loader):
            real = image.to(device)
            label = label.to(device)
            bs = label.size(0)
            n_batches += 1

            # Critic step: minimise E[D(G(z))] - E[D(x)].
            fake_img = G(random_z(bs, z_dim), label).detach()
            optimizer_d.zero_grad()
            loss_d = -torch.mean(D(real, label)) + torch.mean(D(fake_img, label))
            loss_d.backward()
            optimizer_d.step()
            total_loss_d += loss_d.item()

            # WGAN weight clipping.
            for p in D.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Generator step: maximise E[D(G(z))].
            optimizer_g.zero_grad()
            fake_img = G(random_z(bs, z_dim), label)
            loss_g = -torch.mean(D(fake_img, label))
            loss_g.backward()
            optimizer_g.step()
            total_loss_g += loss_g.item()

        # Evaluate on the fixed batch.
        G.eval()
        D.eval()
        with torch.no_grad():
            gen_imgs = G(test_z, test_label)
        score = evaluator.eval(gen_imgs, test_label)
        gen_imgs = denorm(gen_imgs, device)

        if score >= best_score:
            best_score = score
            torch.save(copy.deepcopy(G.state_dict()),
                       os.path.join(ckpt_dir, f'epoch{epoch}_score{score:.2f}.pth'))

        # Average batch-mean losses over batches (guard against an empty loader).
        denom = max(n_batches, 1)
        print(f"Epoch[{epoch}/{epochs}]")
        print(f'score: {score}')
        print(f'generator_loss:{total_loss_g / denom:.8f}')
        print(f'discriminator_loss:{total_loss_d / denom:.8f}')
        print()
        save_image(gen_imgs, os.path.join(img_dir, f'epoch{epoch}.png'), nrow=8, normalize=False)

    return best_score
               
            

# start training

In [4]:
import dataset
from dataset import *

# Training configuration.
root_folder = 'data'
z_dim = 100
n_classes = 24
img_size = 64
batch_size = 64
upsample_block_num = 6  # passed to Generator, which does not use it
epochs = 200
lr_d = 0.00005
lr_g = 0.00005

train_set = ICLEVRLoader(root_folder, mode='train')
test_set = ICLEVRLoader(root_folder, mode='test')
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Unshuffled, so the evaluation labels are identical across runs.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Both networks are built from the same config values (previously D used
# hard-coded 24 / 64 literals that could drift from n_classes / img_size).
G = Generator(n_classes, img_size, z_dim, upsample_block_num)
D = Discriminator(n_classes, img_size)

train(G, D, epochs, lr_g, lr_d, train_loader, test_loader)

  0%|          | 0/282 [00:00<?, ?it/s]

> Found 18009 images...


  0%|          | 1/282 [00:20<1:33:46, 20.02s/it]


KeyboardInterrupt: 