# model architecture

In [73]:
import torch
import torch.nn as nn
from torch.nn import init
import random
import numpy as np
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from evaluator import evaluation_model as pre_cla
import copy
from torchvision.utils import make_grid, save_image
import dataset
from dataset import *
from tqdm import tqdm 
import torch.nn.utils.spectral_norm as spectral_norm 
import torch
import torch.nn.utils.spectral_norm as sn


'''
this file just use pretrained model paras from acgan+dcgan.py
to be the weight init of it again
'''

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
        
class Generator(nn.Module):
    def __init__(self, n_classes, img_size, z_dim, c_dim=256):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.init_size = self.img_size // 4  #16
        
        self.z_dim=z_dim
        self.c_dim=c_dim
        self.latent_dim=self.z_dim + self.c_dim
        
        self.conditionExpand=nn.Sequential(
            nn.Linear(24, 256),
            nn.ReLU()
        )
        
        self.l1 = nn.Sequential(nn.Linear(self.latent_dim, 128 * self.init_size ** 2))
        
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        labels = self.conditionExpand(labels.float())#[bs, 256]
        z = torch.cat((noise, labels), -1)#(bs, feature_dim(z_dim + n_classes))
        
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

# Discriminator
class Discriminator(nn.Module):
    def __init__(self, n_classes, img_size):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        self.n_classes = n_classes
        
        self.convert_label_layer = nn.Sequential(
            nn.Linear(self.n_classes, self.img_size**2),
            nn.LeakyReLU()
        )
        
        self.main = nn.Sequential(
            nn.Conv2d(4, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Flatten()
        )
        
        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(1024*4*4, 1), 
                                       nn.Sigmoid())
        
        self.aux_layer = nn.Sequential(nn.Linear(1024*4*4, self.n_classes), 
                                       nn.Sigmoid())


    def forward(self, img, labels):
        c = self.convert_label_layer(labels.float()).view(-1, 1, self.img_size, self.img_size)
        out = torch.cat((img, c), 1)
        out = self.main(out)
        adv = self.adv_layer(out)
        aux = self.aux_layer(out)
        return adv, aux

GPU State: cpu


# loss function

In [135]:

def adversarial_criterion():
    return nn.BCELoss().to(device)

def auxiliary_criterion():
    return nn.BCELoss().to(device)

class GeneratorLoss(nn.Module):
    def __init__(self):
        #modify the path to your own path
        super(GeneratorLoss, self).__init__()
        checkpoint = torch.load('data/classifier_weight.pth')
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.fc = nn.Sequential(
            nn.Linear(512, 24),
            nn.Sigmoid()
        )
        self.resnet.load_state_dict(checkpoint['model'])
        self.loss_network = nn.Sequential(*list(self.resnet.children())[:-1])
        for param in self.loss_network.parameters():
            param.requires_grad = False
        self.loss_network.eval()
        self.loss_network = self.loss_network.cuda()

        self.classnum = 24
        self.mse_loss = nn.MSELoss()
#         self.L1loss = nn.L1Loss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        adversarial_loss = nn.BCELoss()(out_labels, smooth_label('real', out_labels.shape))
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.1 * perception_loss + 2e-8 * tv_loss
#         return adversarial_loss

# introduction of tv loss: https://www.daimajiaoliu.com/daima/479773d4d1003fe
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

# train.py

In [104]:
def random_z(batch_size, z_dim):
    return torch.randn(batch_size, z_dim, device=device)

def save_img(images, path):
    out = make_grid(images)
    save_image(out, path)
def smooth_label(mode, shape):
    
    if mode == 'real':
        return (torch.rand(shape, device=device) / 10) + 0.9
    elif mode == 'fake':
        return torch.rand(shape, device=device) / 10

def train(G, D, epochs,lr_g, lr_d, train_loader, test_loader):
    
    
    generator_criterion = GeneratorLoss().to(device)
    adversarial_loss = adversarial_criterion()
    auxiliary_loss = auxiliary_criterion()
    
    total_loss_g = 0
    total_loss_d = 0
    num = 0
    best_score = 0
    
    flag=4#first train discriminator flag times 
    
    test_label = next(iter(test_loader)).to(device)#[bs, 24]
    test_z = random_z(test_label.size(0), z_dim).to(device)
    
    G.train().to(device)
    D.train().to(device)
    
    G.apply(weights_init_normal)
    D.apply(weights_init_normal)
    
    optimizer_g = optim.Adam(G.parameters(), lr=lr_g, betas=(0, 0.9))
    optimizer_d = optim.Adam(D.parameters(), lr=lr_d, betas=(0, 0.9))
    
    
    for epoch in range(1, epochs):
        num=0
        
        for image, label in tqdm(train_loader):
            image = image.to(device)
            label = label.to(device).float()
            num += label.size(0)
            real = image.detach()
            
            ############################
            #   Update D network: maximize D(x)-1-D(G(z))  
            ############################
#             for _ in range(4):

            fake = G(random_z(label.size(0), z_dim), label)

            optimizer_d.zero_grad()

            real_pred, real_aux = D(real, label)

            loss_real = (adversarial_loss(real_pred, smooth_label('real', real_pred.shape)) 
                         + auxiliary_loss(real_aux, label)) / 2


            fake_pred, fake_aux = D(fake, label)
            loss_fake = (adversarial_loss(fake_pred, torch.zeros(fake_pred.shape, device=device)) 
                         + auxiliary_loss(fake_aux, label)) / 2

            loss_d = (loss_real + loss_fake) / 2
#                 if flag!=1:
#                     print(loss_d.item())

            loss_d.backward()
            optimizer_d.step()

            total_loss_d += loss_d.item()
            
            ############################
            #   Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ############################
            #generate fake img
            for _ in range(4):
                optimizer_g.zero_grad()

                fake_img = G(random_z(label.size(0), z_dim), label)

                validity, pred_label = D(fake_img, label)

                loss_au = 0.2 * auxiliary_loss(pred_label, label)
                
                loss_ge = 0.8 * generator_criterion(validity, fake_img, real)
                
                loss_g = loss_au + loss_ge
                
                loss_g.backward()


                optimizer_g.step()

            total_loss_g += loss_g.item()
            
        # evaluate    
        G.eval()
        D.eval()
        
        with torch.no_grad():
            gen_imgs = G(test_z, test_label)
        score = pre_cla().eval(gen_imgs, test_label)
        gen_imgs = denorm(gen_imgs, device)
        
        if score >= best_score:
            best_score = score
            best_model_wts = copy.deepcopy(G.state_dict())
            torch.save(best_model_wts, os.path.join('acgan_dcgan', 'paras_new', f'epoch{epoch}_score{score:.2f}.pth'))
        print(f"Epoch[{epoch}/{epochs}]")
        total_loss_d /= num
        total_loss_g /= num
        
        print(f'score: {score}')
        print(f'generator_loss:{total_loss_g:.8f}')
        print(f'discriminator_loss:{total_loss_d:.8f}')
        print()
        # savefig
        save_image(gen_imgs, os.path.join('acgan_dcgan', 'results_new', f'epoch{epoch}.png'), nrow=8, normalize=False)
               
            
            

# start training

In [76]:
import dataset
from dataset import *

root_folder = 'data'
z_dim = 100
n_classes = 24
img_size = 64
batch_size = 64
epochs = 200
lr_d = 0.0003
lr_g = 0.0001

train_set = ICLEVRLoader(root_folder, mode = 'train')
test_set = ICLEVRLoader(root_folder, mode = 'test')
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                         )
test_loader = DataLoader(test_set,
                          batch_size=batch_size,
                          shuffle=False,
                         )

G = Generator(n_classes, img_size, z_dim, )
D = Discriminator(24, 64)


train(G, D, epochs, lr_g, lr_d, train_loader, test_loader)

> Found 18009 images...


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.