# Run your own images through StarGAN
By Spencer Carter


Last updated: 2/3/18

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
from torch.autograd import grad
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import transforms
from model import Generator
from model import Discriminator
from PIL import Image
from model import Generator

In [2]:
class TrainedGenerator(object):
    '''Load a trained StarGAN generator and run individual images through it.

    The generator weights are read from a ``.pth`` state-dict file
    (``model_save_path``).  ``score_image`` then takes one image file plus
    its five ground-truth attributes and writes a strip of translated images.
    '''

    # Positions of the three hair-colour flags inside the attribute vector
    # [black_hair, blond_hair, brown_hair, male, young].
    _HAIR_INDICES = (0, 1, 2)

    def __init__(self, model_save_path, cuda=True):
        '''Build the generator, load its weights and set up preprocessing.

        model_save_path: path of the generator state dict (.pth file).
        cuda: request the GPU; silently falls back to the CPU when CUDA
              is not available.
        '''
        self.model_save_path = model_save_path
        self.cuda = cuda and torch.cuda.is_available()

        # Initialize model
        self.G = Generator()
        # Deserialize onto the CPU first so that a checkpoint written from a
        # GPU run can still be loaded on a CPU-only machine; the model is
        # moved to the GPU below when one is requested and present.
        state_dict = torch.load(self.model_save_path,
                                map_location=lambda storage, loc: storage)
        self.G.load_state_dict(state_dict)
        print("Generator weights loaded")

        # Cuda (self.cuda already includes the availability check)
        if self.cuda:
            self.G = self.G.cuda()
            print("Using GPU")
        else:
            print("Using CPU")

        # Assuming fixed attributes, but could make these inputs to the class
        self.crop_size = 178
        self.image_size = 128
        self.image_shape = (3, self.image_size, self.image_size)
        # No random flip here: this pipeline is only used for inference, and
        # a random mirror would make the saved result differ between runs.
        self.transform = transforms.Compose([
            transforms.CenterCrop(self.crop_size),
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    def denorm(self, x):
        '''Map values from [-1, 1] back to [0, 1], clamping anything outside.

        Lifted from solver.py.
        '''
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def to_var(self, x, volatile=False):
        '''Wrap a tensor in a Variable, moving it to the GPU if one is in use.

        Lifted from solver.py.
        '''
        if self.cuda:
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def transform_img(self, img):
        '''Apply the transformations StarGAN uses,
        and reshape the tensor from
        (3, 128, 128) -> (1, 3, 128, 128)
        since the model is looking for a batch size
        '''
        tf = self.transform(img).view(1, *self.image_shape)
        return self.to_var(tf)

    def _target_labels(self, attrs):
        '''Return one (1, len(attrs)) target tensor per attribute.

        Target j is the cleaned ground truth with attribute j flipped.  When
        j is a hair colour, all hair-colour flags are cleared first so that
        at most one hair colour is set in the target.
        '''
        # Sanitize once, up front: any value that is not 0/1 counts as 0 in
        # every target vector, not only in the slot that is being flipped.
        clean = [a * 1 if a in (0, 1) else 0 for a in attrs]

        targets = []
        for j, val in enumerate(clean):
            target_c = torch.Tensor(clean)
            if j in self._HAIR_INDICES:
                for hair_idx in self._HAIR_INDICES:
                    target_c[hair_idx] = 0
            target_c[j] = 1 - val
            targets.append(target_c.view(1, -1))
        return targets

    def score_image(self, in_file, out_path, black_hair=1, blond_hair=0, brown_hair=1, male=0, young=1):
        '''Translate one image and save the results as a single strip.

        Take in an image file, and the output directory,
        along with the 5 ground truth attributes for the model
        to perturb.  The saved image is the preprocessed input followed by
        one generated image per attribute, concatenated along the last
        (width) axis, written to ``<out_path>/<image name>_fake.png``.

        Returns the path of the saved image.

        TODO (original author's note): Remove hair color. It's effectively
        not used.
        '''
        # Force three channels: the preprocessing and the fixed
        # (1, 3, 128, 128) view cannot handle RGBA or grayscale input.
        # The context manager also closes the file handle.
        with Image.open(in_file) as img:
            real_x = self.transform_img(img.convert('RGB'))

        attr_list = [black_hair, blond_hair, brown_hair, male, young]
        target_c_list = [self.to_var(target_c, volatile=True)
                         for target_c in self._target_labels(attr_list)]

        fake_image_list = [real_x]
        for target_c in target_c_list:
            fake_image_list.append(self.G(real_x, target_c))

        fake_images = torch.cat(fake_image_list, dim=3)

        # Strip only the final extension so dotted names ("a.1.jpg",
        # "a.2.jpg") do not collapse onto the same output file.
        name = os.path.splitext(os.path.basename(in_file))[0]
        # Create the output directory up front rather than failing on the
        # save after the generator has already run.
        if out_path and not os.path.isdir(out_path):
            os.makedirs(out_path)
        save_path = os.path.join(out_path, '{}_fake.png'.format(name))
        save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
        return save_path

In [3]:
# Load the trained generator checkpoint, asking for the GPU.
weights_path = './stargan_celebA/models/20_1000_G.pth'
gen = TrainedGenerator(weights_path, cuda=True)

Generator weights loaded
Using GPU


### Below is the quick-and-dirty approach. For many images, use the CelebA approach instead: keep the ground-truth (GT) attributes in a delimited text file.


In [5]:
# Directory holding the input photos, and directory for the result strips.
in_path = './data/self_data/'
out_path = './stargan_celebA/self_results/'

# One row per photo:
# [file name, Black_Hair, Blond_Hair, Brown_Hair, Male, Young]
friends = [
    ['spencer_1.jpg', 0, 0, 1, 1, 1],
    ['spencer_2.jpg', 0, 0, 1, 1, 1],
    ['spencer_3.jpg', 0, 0, 1, 1, 1],

    ['lucy_1.jpg', 0, 0, 1, 0, 1],
    ['lucy_2.jpg', 0, 0, 1, 0, 1],
    ['lucy_3.jpg', 0, 0, 1, 0, 1],

    ['tron_1.jpg', 0, 0, 1, 1, 1],
    ['tron_2.jpg', 0, 0, 1, 1, 1],

    ['doss_1.jpg', 0, 0, 1, 0, 1],
    ['doss_2.jpg', 0, 0, 1, 0, 1],

    ['brian_1.jpg', 1, 0, 0, 1, 1],
    ['brian_2.jpg', 1, 0, 0, 1, 1],

    ['ben_1.jpg', 1, 0, 0, 1, 1],

    ['katy_1.jpg', 1, 0, 0, 0, 1],

    ['audra_1.jpg', 1, 0, 0, 0, 1],
    ['audra_2.jpg', 1, 0, 0, 0, 1],

    ['ethan_1.jpg', 0, 0, 1, 1, 1],

    ['ryan_1.jpg', 0, 0, 1, 1, 1],
    ['ryan_2.jpg', 0, 0, 1, 1, 1],

    ['nga_1.jpg', 1, 0, 0, 0, 1],

    ['abby_1.jpg', 0, 0, 1, 0, 1],
    ['abby_2.jpg', 0, 0, 1, 0, 1],
]

# Run every photo through the generator with its ground-truth attributes.
for entry in friends:
    image_name, attrs = entry[0], entry[1:]
    gen.score_image(os.path.join(in_path, image_name), out_path, *attrs)