In [None]:
import torch
import torch.nn as nn
from torch.nn import init
import random
import numpy as np
import torch.optim as optim
from torchvision import models, utils
import torch.nn.functional as F
from evaluator import evaluation_model as pre_cla
import copy
from torchvision.utils import make_grid, save_image
import dataset
from dataset import *
from tqdm import tqdm 
import torch.nn.utils.spectral_norm as spectral_norm 
import torch


'''
This model implements the ACGAN architecture,
but additionally feeds the label information to the discriminator.
'''

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print('GPU State:', device)

class Generator(nn.Module):
    """Conditional generator for the ACGAN model.

    The condition vector is projected to ``c_dim`` features (Linear + ReLU) and
    concatenated with the noise vector. A Linear layer maps the result to a
    ``128 x (img_size // 4) x (img_size // 4)`` feature map, which is upsampled
    twice by a factor of 2 into a 3-channel image squashed by ``Tanh``.

    Args:
        n_classes: length of the condition vector passed to ``forward``.
        img_size: output height/width. The network upsamples ``img_size // 4``
            by 4, so the output matches ``img_size`` only when it is divisible by 4.
        z_dim: length of the noise vector passed to ``forward``.
        upsample_block_num: not used by this class; kept so existing call sites
            keep working.
        c_dim: width of the projected condition embedding.
    """

    def __init__(self, n_classes, img_size, z_dim, upsample_block_num, c_dim=256):
        super(Generator, self).__init__()
        self.n_classes = n_classes
        self.img_size = img_size
        # Spatial size before the two x2 upsamples (16 for img_size=64).
        self.init_size = self.img_size // 4

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.latent_dim = self.z_dim + self.c_dim

        # Project the condition vector to c_dim features. This is sized from the
        # constructor arguments; it was previously hard-coded to 24 -> 256, which
        # broke any other n_classes / c_dim. With n_classes=24 and c_dim=256 (the
        # values used in this notebook) the parameter shapes are identical, so
        # existing checkpoints still load.
        self.conditionExpand = nn.Sequential(
            nn.Linear(n_classes, self.c_dim),
            nn.ReLU()
        )

        self.l1 = nn.Sequential(nn.Linear(self.latent_dim, 128 * self.init_size ** 2))

        # NOTE: the positional 0.8 in BatchNorm2d(C, 0.8) is BatchNorm's `eps`,
        # not `momentum`. It is unusually large but is kept unchanged because
        # trained weights were produced with it.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images from noise and condition vectors.

        Args:
            noise: ``(bs, z_dim)`` tensor.
            labels: ``(bs, n_classes)`` tensor; cast to float before projection.

        Returns:
            ``(bs, 3, 4 * init_size, 4 * init_size)`` tensor with values in [-1, 1].
        """
        cond = self.conditionExpand(labels.float())  # [bs, c_dim]
        z = torch.cat((noise, cond), -1)  # [bs, z_dim + c_dim]

        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
    
def random_z(batch_size, z_dim):
    """Return a ``(batch_size, z_dim)`` standard-normal noise tensor allocated on ``device``."""
    shape = (batch_size, z_dim)
    return torch.randn(shape, device=device)
            

In [None]:
import dataset
from dataset import *

# Configuration shared by the evaluation cells below.
root_folder = 'data'       # dataset root directory
img_size = 64              # generated image height / width
batch_size = 64            # not referenced in the visible cells
n_classes = 24             # label-vector length (matches the 24-dim test labels)
z_dim = 100                # noise-vector length
upsample_block_num = 6     # passed to Generator, which does not use it


In [None]:
import torchvision.transforms.functional as F

def show(imgs):
    """Display one image tensor, or a list of them, side by side in a single row.

    Each tensor is detached and converted with ``F.to_pil_image``. In this cell,
    ``F`` is ``torchvision.transforms.functional``, which shadows the earlier
    ``torch.nn.functional`` alias. The image is then drawn with ticks and tick
    labels hidden.

    NOTE(review): ``plt`` is not imported anywhere in this file; presumably it
    arrives via ``from dataset import *`` — confirm, or import matplotlib explicitly.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _fig, axes = plt.subplots(ncols=len(images), squeeze=False)
    for col, tensor in enumerate(images):
        pil_img = F.to_pil_image(tensor.detach())
        ax = axes[0, col]
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        
def test(model_path, G, mode, filename, n_rounds=10, root='data'):
    '''
    Load generator weights and score generated images with the evaluator.

    The generator is scored on the first batch of test conditions. Each round
    uses fresh noise, and the per-round and average scores are printed. The
    images from the last round are denormalized and saved as ``{filename}.png``.

    model_path: path to the saved generator state_dict
    G: generator instance (its weights are overwritten from model_path)
    mode: 'test' or 'new_test' (passed to ICLEVRLoader)
    filename: output image path without the '.png' extension
    n_rounds: number of independent noise draws to score and average (default 10)
    root: dataset root folder passed to ICLEVRLoader (default 'data')

    Raises ValueError if n_rounds < 1, which would otherwise divide by zero
    and leave no images to save.
    '''
    if n_rounds < 1:
        raise ValueError(f'n_rounds must be >= 1, got {n_rounds}')

    # Use the generator's own noise size rather than a separately hard-coded 100.
    z_dim = getattr(G, 'z_dim', 100)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # torch.load unpickles the file, so only load trusted checkpoints.
    G.load_state_dict(torch.load(model_path, map_location=device))
    G.eval()

    test_set = ICLEVRLoader(root, mode=mode)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=False)
    # Only the first batch (up to 32 conditions) is evaluated.
    # NOTE(review): presumably the whole test split fits in one batch — confirm.
    test_label = next(iter(test_loader)).to(device)  # [bs, 24]

    # Build the evaluator once instead of re-creating it every round.
    evaluation_model = pre_cla()
    avg_score = 0

    # Inference only: skip autograd bookkeeping for the generated images.
    with torch.no_grad():
        for _ in range(n_rounds):
            z = torch.randn(test_label.size(0), z_dim).to(device)  # (N, z_dim) tensor
            gen_imgs = G(z, test_label)
            score = evaluation_model.eval(gen_imgs, test_label)
            print(f'score: {score:.2f}')
            avg_score += score

    # Save the images from the last round only.
    gen_imgs = denorm(gen_imgs, device)
    save_image(gen_imgs, f'{filename}.png', nrow=8, normalize=False)
    print()
    print(f'avg score: {avg_score/n_rounds:.2f}')
    

In [None]:
# Score the saved generator checkpoint on the 'test' split.
mode = 'test'
model_path = './acgan_dcgan/paras_new/epoch55_score0.85.pth'
G = Generator(n_classes, img_size, z_dim, upsample_block_num).to(device)
test(model_path, G, mode, f'best_eval_{mode}')
print()

In [None]:
# Score the same checkpoint on the 'new_test_2021_summer' split,
# using a freshly constructed generator.
mode = 'new_test_2021_summer'
model_path = './acgan_dcgan/paras_new/epoch55_score0.85.pth'
G = Generator(n_classes, img_size, z_dim, upsample_block_num).to(device)
test(model_path, G, mode, f'best_eval_{mode}')
print()