# model architecture

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import random
import numpy as np
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from evaluator import evaluation_model as pre_cla
import copy
from torchvision.utils import make_grid, save_image
import dataset
from dataset import *
from tqdm import tqdm 
import torch.nn.utils.spectral_norm as spectral_norm 
import torch


'''
This model implements the simplest cGAN: the latent vector and the label
are concatenated and fed into the generator.
This file is intended to use spectral normalization instead of batch norm
(note: spectral_norm is imported, but the layers below still use BatchNorm2d).
https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html
'''

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print('GPU State:', device)

def weights_init_normal(m):
    """DCGAN-style initializer, intended for ``model.apply(weights_init_normal)``.

    - Modules whose class name contains "Conv": weight ~ N(0, 0.02); bias untouched.
    - Modules whose class name contains "BatchNorm2d": weight ~ N(1, 0.02), bias = 0.
    - Every other module is left unchanged.

    Modules that match by name but carry no weight/bias (e.g. ``BatchNorm2d``
    built with ``affine=False``, or a container class named ``*Conv*``) are
    skipped instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)
        

class ResidualBlock(nn.Module):
    """Residual unit computing ``x + BN(conv(PReLU(BN(conv(x)))))``.

    Both 3x3 convolutions keep the channel count and, with padding=1, the
    spatial size, so the skip connection is a plain element-wise addition.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        c = channels
        # Layers are created in the same order as before so parameter
        # initialization under a fixed seed is unchanged.
        self.block1 = nn.Conv2d(c, c, 3, padding=1)
        self.block2 = nn.Conv2d(c, c, 3, padding=1)
        self.act = nn.PReLU()
        self.bn1 = nn.BatchNorm2d(c)
        self.bn2 = nn.BatchNorm2d(c)

    def forward(self, x):
        h = self.act(self.bn1(self.block1(x)))
        h = self.bn2(self.block2(h))
        return x + h

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling by ``up_scale`` followed by PReLU.

    A 3x3 conv expands the channels by ``up_scale ** 2``; PixelShuffle then
    folds them into an ``up_scale``-times larger grid, so the channel count
    is unchanged: [bs, C, H, W] -> [bs, C, H*up_scale, W*up_scale].
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.act = nn.PReLU()

    def forward(self, x):
        return self.act(self.pixel_shuffle(self.conv(x)))
    
class generator(nn.Module):
    '''
    Conditional SRGAN-style generator.

    input:
    noise: [bs, z_dim], randomly sampled by the caller
    labels: [bs, n_classes] multi-hot condition vector, e.g. for bs=3
            labels = [[0, 1, 1, 0, ..., 0],
                      [1, 0, 0, 0, ..., 1],
                      [0, 0, 0, 1, ..., 0]]
    return:
    a [bs, 3, H, W] image with values in [0, 1], where H = W = 2 ** upsample_block_num.
    The network starts from a 1x1 feature map and each UpsampleBlock doubles it,
    so the output size is set by upsample_block_num; img_size is only stored.

    Because the residual blocks run BatchNorm on 1x1 maps, training mode
    needs bs >= 2 (PyTorch BatchNorm rejects a single value per channel).
    '''

    def __init__(self, n_classes, img_size, z_dim, upsample_block_num, c_dim=256):
        super(generator, self).__init__()

        self.n_classes = n_classes
        self.z_dim = z_dim
        # Kept for backward compatibility; the first conv actually consumes
        # c_dim + z_dim channels, not n_classes + z_dim.
        self.in_size = n_classes + z_dim
        self.img_size = img_size
        self.upsample_block_num = upsample_block_num
        self.c_dim = c_dim

        # An nn.Embedding does not fit multi-label conditions, so the multi-hot
        # label vector is projected to c_dim features by a linear layer.
        self.conditionExpand = nn.Sequential(
            nn.Linear(self.n_classes, self.c_dim),
            nn.ReLU()
        )

        # 9x9 conv with padding 4 keeps the 1x1 spatial size of the input.
        self.block1 = nn.Sequential(
            nn.Conv2d(self.c_dim + z_dim, 64, 9, padding=4),
            nn.PReLU()
        )

        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)

        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )

        # Each UpsampleBlock doubles the spatial size; the final conv maps to RGB.
        block8 = [UpsampleBlock(64, 2) for _ in range(self.upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, noise, labels):
        # Reference: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py
        # [bs, n_classes] -> [bs, c_dim]
        cond = self.conditionExpand(labels.float())
        # [bs, z_dim + c_dim]
        gen_input = torch.cat((noise, cond), -1)
        # Treat the feature vector as a 1x1 map: [bs, z_dim + c_dim, 1, 1]
        gen_input = gen_input.view(gen_input.size(0), gen_input.size(-1), 1, 1)

        block1 = self.block1(gen_input)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)

        # Long skip connection around the residual trunk, then upsample.
        block8 = self.block8(block1 + block7)

        # Map tanh output from [-1, 1] to [0, 1].
        return (torch.tanh(block8) + 1) / 2
    
    
class discriminator(nn.Module):
    '''
    Conditional discriminator.

    The multi-hot label vector is projected to one img_size x img_size plane
    and stacked onto the 3-channel image as a 4th channel. A conv stack with
    four stride-2 stages plus global average pooling reduces this to a single
    score per sample, squashed with a sigmoid.
    '''

    def __init__(self, n_classes, img_size):
        super(discriminator, self).__init__()
        self.img_size = img_size  # side of the square input image (64 in this notebook)
        self.n_classes = n_classes
        # An nn.Embedding was used here earlier; it was dropped because one
        # image carries more than one label, so a multi-hot vector is used.

        side = self.img_size
        self.convert_label_layer = nn.Sequential(
            nn.Linear(self.n_classes, side * side),
            nn.LeakyReLU()
        )

        # Stem: 3 image channels + 1 label channel -> 64 features.
        layers = [nn.Conv2d(4, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        # (out_channels, stride) of each conv + BN + LeakyReLU stage.
        stages = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2))
        in_ch = 64
        for out_ch, stride in stages:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
            in_ch = out_ch
        # Head: global pool, then two 1x1 convs down to one logit.
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, img, labels):
        '''
        input:
        img: [bs, 3, img_size, img_size]
        labels: [bs, n_classes] multi-hot vector

        return: [bs] tensor of probabilities that each image is real
        '''
        label_plane = self.convert_label_layer(labels.float()).view(-1, 1, self.img_size, self.img_size)
        stacked = torch.cat((img, label_plane), 1)
        logits = self.net(stacked).view(stacked.size(0))
        return torch.sigmoid(logits)




GPU State: cpu


# loss function


In [2]:
##########################################LOSS part########################################
class GeneratorLoss(nn.Module):
    """SRGAN-style composite generator loss.

    total = MSE(out, target)
            + 0.001 * BCE(D(out), smoothed "real" targets)
            + 0.006 * MSE(VGG16(out), VGG16(target))
            + 2e-8  * TV(out)

    The VGG16 feature extractor is frozen (eval mode, requires_grad=False).
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = self._load_vgg16()
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()
        # Built once here instead of on every forward call; it has no parameters,
        # so the state_dict is unchanged.
        self.bce_loss = nn.BCELoss()

    @staticmethod
    def _load_vgg16():
        """Load ImageNet-pretrained VGG16 through whichever API torchvision offers."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=`; IMAGENET1K_V1 is the
            # checkpoint that `pretrained=True` loads.
            return models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        return models.vgg16(pretrained=True)

    def forward(self, out_labels, out_images, target_images):
        """Return the weighted generator loss.

        Args:
            out_labels: discriminator outputs for the generated images
                (probabilities, as BCELoss requires).
            out_images: generated images.
            target_images: real images the generated ones are compared with.
        """
        # Adversarial loss: push D(G(z)) towards smoothed "real" targets.
        adversarial_loss = self.bce_loss(out_labels, smooth_label('real', out_labels.shape))
        # Perceptual loss on frozen VGG16 features.
        # NOTE(review): no ImageNet normalization is applied to the inputs here;
        # confirm that is intended.
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Pixel-wise loss.
        image_loss = self.mse_loss(out_images, target_images)
        # Total-variation smoothness penalty.
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

# Introduction to TV (total variation) loss: https://www.daimajiaoliu.com/daima/479773d4d1003fe
class TVLoss(nn.Module):
    """Total-variation loss over a [bs, C, H, W] batch.

    Returns
        tv_loss_weight * 2 * (sum(dh**2) / (C*(H-1)*W) + sum(dw**2) / (C*H*(W-1))) / bs
    where dh / dw are the vertical / horizontal neighbour differences.
    A direction with no neighbouring pairs (H == 1 or W == 1) contributes 0
    instead of producing 0/0 = NaN.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size(0)
        dh = x[:, :, 1:, :] - x[:, :, :-1, :]
        dw = x[:, :, :, 1:] - x[:, :, :, :-1]
        h_tv = self._normalized_sq_sum(dh)
        w_tv = self._normalized_sq_sum(dw)
        return self.tv_loss_weight * 2 * (h_tv + w_tv) / batch_size

    @staticmethod
    def tensor_size(t):
        """Number of elements per sample (C * H * W) of a 4-D tensor."""
        return t.size()[1] * t.size()[2] * t.size()[3]

    @classmethod
    def _normalized_sq_sum(cls, diff):
        """Sum of squared differences divided by per-sample element count (0 if empty)."""
        sq_sum = diff.pow(2).sum()
        count = cls.tensor_size(diff)
        if count == 0:
            # Empty difference tensor: the sum is already 0; avoid 0/0.
            return sq_sum
        return sq_sum / count

# train.py

In [3]:
#####################################################################################################

def random_z(batch_size, z_dim):
    """Sample a [batch_size, z_dim] standard-normal latent batch on the module-level ``device``."""
    shape = (batch_size, z_dim)
    return torch.randn(shape, device=device)

def save_img(images, path):
    """Tile ``images`` into one grid with ``make_grid`` and write it to ``path``."""
    save_image(make_grid(images), path)
def smooth_label(mode, shape):
    """Return noisy soft BCE targets (label smoothing) on the module-level ``device``.

    mode='real' -> values drawn uniformly from [0.9, 1.0)
    mode='fake' -> values drawn uniformly from [0.0, 0.1)

    Raises:
        ValueError: if ``mode`` is neither 'real' nor 'fake'. Without this, an
            unknown mode returns None and only fails later inside the loss.
    """
    if mode == 'real':
        return (torch.rand(shape, device=device) / 10) + 0.9
    if mode == 'fake':
        return torch.rand(shape, device=device) / 10
    raise ValueError(f"smooth_label: mode must be 'real' or 'fake', got {mode!r}")


def _train_one_epoch(G, D, generator_criterion, optimizer_g, optimizer_d, train_loader, noise_dim):
    """Run one D step + one G step per batch; return (sum_loss_d, sum_loss_g, n_samples).

    Loss sums are weighted by batch size, so dividing by n_samples gives a
    per-sample mean regardless of batch-size variation.
    """
    # Re-enter training mode every epoch: evaluation switches G/D to eval().
    G.train()
    D.train()
    total_loss_d = 0.0
    total_loss_g = 0.0
    num = 0

    for image, label in tqdm(train_loader):
        real = image.to(device)
        label = label.to(device)
        bs = label.size(0)
        num += bs

        ############################
        #   Update D: BCE(D(x), ~1) + BCE(D(G(z)), ~0) with smoothed targets
        ############################
        # detach(): D's loss must not backprop into G, so no retain_graph is needed.
        fake_img = G(random_z(bs, noise_dim), label).detach()

        optimizer_d.zero_grad()
        real_out = D(real, label)
        fake_out = D(fake_img, label)
        loss_real = nn.BCELoss()(real_out, smooth_label('real', real_out.shape))
        loss_fake = nn.BCELoss()(fake_out, smooth_label('fake', fake_out.shape))
        loss_d = loss_real + loss_fake
        loss_d.backward()
        optimizer_d.step()
        total_loss_d += loss_d.item() * bs

        ############################
        #   Update G: adversarial + perception + image + TV loss
        ############################
        optimizer_g.zero_grad()
        fake_img = G(random_z(bs, noise_dim), label)
        fake_out = D(fake_img, label)
        loss_g = generator_criterion(fake_out, fake_img, real)
        loss_g.backward()
        optimizer_g.step()
        total_loss_g += loss_g.item() * bs

    return total_loss_d, total_loss_g, num


def train(G, D, epochs, lr_g, lr_d, train_loader, test_loader):
    """Adversarially train conditional generator ``G`` against discriminator ``D``.

    Runs epochs 1..``epochs`` inclusive. Each epoch trains on every batch of
    ``train_loader``. It then generates images for the first ``test_loader``
    batch using fixed noise and scores them with the evaluator. G's weights are
    checkpointed whenever the score ties or beats the best so far, and an image
    grid is saved.

    Args:
        G: generator; its ``z_dim`` attribute gives the noise width (falls back
            to the module-level ``z_dim`` if absent).
        D: discriminator returning a [bs] probability tensor.
        epochs: number of training epochs.
        lr_g, lr_d: Adam learning rates for G and D.
        train_loader: yields (image, label) batches.
        test_loader: yields label batches; only the first batch is used.

    Side effects: re-initializes G and D weights; writes checkpoints to
    cgan_srgan/paras_sn/ and image grids to cgan_srgan/results_sn/ (created
    if missing).
    """
    import os

    noise_dim = getattr(G, 'z_dim', None)
    if noise_dim is None:
        noise_dim = z_dim  # module-level config, as the original code used

    ckpt_dir = os.path.join('cgan_srgan', 'paras_sn')
    img_dir = os.path.join('cgan_srgan', 'results_sn')
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(img_dir, exist_ok=True)

    generator_criterion = GeneratorLoss().to(device)
    best_score = 0

    test_label = next(iter(test_loader)).to(device)  # [bs, n_classes]
    # Fixed noise so samples are comparable across epochs.
    test_z = random_z(test_label.size(0), noise_dim)

    G.train().to(device)
    D.train().to(device)

    G.apply(weights_init_normal)
    D.apply(weights_init_normal)

    optimizer_g = optim.Adam(G.parameters(), lr=lr_g, betas=(0, 0.9))
    optimizer_d = optim.Adam(D.parameters(), lr=lr_d, betas=(0, 0.9))

    for epoch in range(1, epochs + 1):
        total_loss_d, total_loss_g, num = _train_one_epoch(
            G, D, generator_criterion, optimizer_g, optimizer_d, train_loader, noise_dim)

        # evaluate
        G.eval()
        D.eval()
        with torch.no_grad():
            gen_imgs = G(test_z, test_label)
        score = pre_cla().eval(gen_imgs, test_label)
        # NOTE(review): G already returns values in [0, 1]; confirm denorm expects that range.
        gen_imgs = denorm(gen_imgs, device)

        if score >= best_score:
            best_score = score
            best_model_wts = copy.deepcopy(G.state_dict())
            torch.save(best_model_wts, os.path.join(ckpt_dir, f'epoch{epoch}_score{score:.2f}.pth'))
        print(f"Epoch[{epoch}/{epochs}]")
        denom = max(num, 1)  # guard against an empty train_loader

        print(f'score: {score}')
        print(f'generator_loss:{total_loss_g / denom:.8f}')
        print(f'discriminator_loss:{total_loss_d / denom:.8f}')
        print()
        # savefig
        save_image(gen_imgs, os.path.join(img_dir, f'epoch{epoch}.png'), nrow=8, normalize=False)
               
            
            

# start training

In [4]:
import dataset
from dataset import *

root_folder = 'data'
z_dim = 100
n_classes = 24
img_size = 64
batch_size = 32
upsample_block_num = 6
epochs = 200
lr_d = 0.00004
lr_g = 0.00001

# The generator starts from a 1x1 map and doubles it upsample_block_num times,
# so its output side is 2 ** upsample_block_num; it must match the img_size
# the discriminator is built for.
if 2 ** upsample_block_num != img_size:
    raise ValueError(
        f"generator output size 2**{upsample_block_num} != img_size {img_size}"
    )

train_set = ICLEVRLoader(root_folder, mode='train')
test_set = ICLEVRLoader(root_folder, mode='test')
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          )
test_loader = DataLoader(test_set,
                         batch_size=32,
                         shuffle=False,
                         )

G = generator(n_classes, img_size, z_dim, upsample_block_num)
# Use the shared config instead of hard-coded (24, 64) so G and D stay in sync.
D = discriminator(n_classes, img_size)


# train(G, D, epochs, lr_g, lr_d, train_loader, test_loader)

> Found 18009 images...
