In [1]:
pip install numpy==1.19.5 --user

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import numpy as np

#https://github.com/chongyangma/cs231n/blob/master/assignments/assignment3/style_transfer_pytorch.py
class TVLoss(nn.Module):
    """Total-variation penalty.

    Sums the squared differences between neighbouring entries along the last
    two dimensions of a 4-D input and scales the total by ``TVLoss_weight``.
    """

    def __init__(self, TVLoss_weight= 1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return the weighted sum of squared neighbour differences of ``x``."""
        # Neighbours along the last dimension.
        last_dim_diff = x[:, :, :, :-1] - x[:, :, :, 1:]
        # Neighbours along the second-to-last dimension.
        prev_dim_diff = x[:, :, :-1, :] - x[:, :, 1:, :]

        w_variance = last_dim_diff.pow(2).sum()
        h_variance = prev_dim_diff.pow(2).sum()
        return self.TVLoss_weight * (h_variance + w_variance)

#https://github.com/pytorch/pytorch/issues/9160#issuecomment-483048684
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its input untouched.

    Any constructor arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        nn.Module.__init__(self)

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x
    
#https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/models.py
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    Input and output both have ``in_features`` channels. Each conv is followed
    by instance normalisation; a ReLU sits between the two conv stages.
    """

    def __init__(self,in_features):
        super().__init__()

        def pad_conv_norm():
            # 1-pixel reflection pad + unpadded 3x3 conv keeps the spatial size.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        layers = pad_conv_norm()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(pad_conv_norm())

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the conv stack applied to ``x``."""
        residual = self.conv_block(x)
        return x + residual  # skip connection

class Encoder(nn.Module):
    """Convolutional encoder for the content image.

    A reflection-padded 7x7 conv stem (``in_nc`` -> ``ngf`` channels) is
    followed by two stride-2 3x3 convolutions, each doubling the channel
    count, so the output has ``ngf * 4`` channels.
    """

    def __init__(self, in_nc, ngf=64):
        super().__init__()

        # Initial conv block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_nc, ngf, 7),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: ngf -> 2*ngf -> 4*ngf
        channels = ngf
        for _ in range(2):
            doubled = channels * 2
            layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled

        self.model = nn.Sequential(*layers)

    def forward(self,x):
        """Encode ``x['content']`` and pass ``x['style_label']`` through.

        Returns a two-element list: ``[encoded_content, style_label]``.
        """
        encoded = self.model(x['content'])
        return [encoded, x['style_label']]

class Decoder(nn.Module):
    """Decoder: residual blocks, two 2x upsampling stages, then a 7x7 output conv.

    Args:
        out_nc: number of channels of the produced image.
        ngf: base channel width; the decoder expects ``ngf * 4`` input channels.
        n_residual_blocks: how many ``ResidualBlock`` layers precede upsampling.

    The output passes through ``Tanh``.
    """

    def __init__(self, out_nc, ngf, n_residual_blocks=5):
        super(Decoder, self).__init__()

        in_features = ngf * 4
        out_features = in_features//2

        model = []
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: each transposed conv doubles the spatial size and halves
        # the channel count (4*ngf -> 2*ngf -> ngf).
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer. Its input width is whatever the last upsampling stage
        # produced (== ngf); a literal 64 here only worked when ngf was 64.
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(in_features, out_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self,x):
        """Run ``x`` through the full decoder stack and return the result."""
        return self.model(x)

class Transformer(nn.Module):
    """Bank of per-style branches selected (and weighted) by a label vector.

    Holds one ``ResidualBlock(ngf*4)`` per style and, when ``auto_id`` is
    true, an extra trailing ``Identity`` branch.
    """

    def __init__(self,n_styles, ngf, auto_id=True):
        super(Transformer, self).__init__()
        self.t = nn.ModuleList([ResidualBlock(ngf*4) for i in range(n_styles)])
        if auto_id:
            self.t.append(Identity())

    def forward(self,x):
        """Mix the branch outputs according to the label weights.

        Args:
            x: two-element sequence; ``x[0]`` is the encoded content and
               ``x[1][0]`` is the label vector (only the first row of ``x[1]``
               is read). Entry ``i`` of the label weights branch ``i``.

        Returns:
            The sum over all non-zero label entries of ``branch_i(content) * weight_i``.

        Raises:
            ValueError: if every label entry is zero, so no branch is selected.
        """
        content = x[0]
        label = x[1][0]

        # Accumulate with tensor addition. Reducing a Python list of tensors
        # through NumPy only works by accident on some NumPy versions and
        # breaks for tensors that require grad or live on the GPU.
        mix = None
        for i, v in enumerate(label):
            if not v:
                continue
            branch = self.t[i](content) * v
            mix = branch if mix is None else mix + branch

        if mix is None:
            raise ValueError("style label selects no branch (all entries are zero)")
        return mix

class Generator(nn.Module):
    """Full generator: ``Encoder`` -> ``Transformer`` -> ``Decoder``."""

    def __init__(self,in_nc,out_nc,n_styles,ngf):
        super().__init__()

        self.encoder = Encoder(in_nc, ngf)
        self.transformer = Transformer(n_styles, ngf)
        self.decoder = Decoder(out_nc, ngf)

    def forward(self,x):
        """Feed ``x`` through encoder, transformer and decoder in that order."""
        return self.decoder(self.transformer(self.encoder(x)))

class Discriminator(nn.Module):
    """
    Patch-GAN style discriminator.

    A shared two-conv trunk feeds two 1x1-conv heads: a sigmoid real/fake map
    and a per-location style-class map with ``n_styles`` scores.
    ``ndf`` is accepted but not used; the trunk widths are fixed at 256 and 512.
    """

    def __init__(self, in_nc, n_styles, ndf=64):
        super().__init__()

        # Shared trunk: 4x4 conv (stride 4) then 2x2 conv (stride 2), ReLU after each.
        self.model = nn.Sequential(
            nn.Conv2d(in_nc, 256, 4, stride=4),
            nn.ReLU(),
            nn.Conv2d(256, 512, 2, stride=2),
            nn.ReLU(),
        )

        # GAN (real / not real) head.
        self.fldiscriminator = nn.Conv2d(512, 1, 1, stride=1, padding=0)
        self.sig = nn.Sigmoid()
        self.pool = nn.AvgPool2d(1)  # kernel size 1

        # Classification head. padding=2 with a 1x1 kernel makes this map
        # 4 larger than the trunk output in each spatial dimension.
        self.aux_clf = nn.Conv2d(512, n_styles, 1, padding=2)

    def forward(self, x):
        """Return ``[gan_map, class_map]`` for the image batch ``x``.

        For an input of shape [1, 3, 128, 128] the trunk yields
        [1, 512, 16, 16]; ``gan_map`` is [1, 1, 16, 16] and ``class_map``
        (dims 1 and 3 swapped) is [1, 20, 20, n_styles].
        """
        features = self.model(x)
        gan_map = self.pool(self.sig(self.fldiscriminator(features)))
        class_map = self.aux_clf(features).transpose(1, 3)
        return [gan_map, class_map]


In [3]:
print(np.__version__)

1.19.5


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import glob
import random
import os
import torch
from torch.autograd import Variable
from torchvision.transforms import InterpolationMode


class ImageDataset(Dataset):
    """Pairs each content image with a randomly drawn, labelled style image.

    Args:
        root_img: directory whose entries are the content images.
        root_style: directory containing one sub-directory per style; the
            sorted position of a sub-directory is its integer style label.
        transforms_: optional list of transforms to compose. When omitted,
            the default pipeline is used (bicubic resize to 143, random
            128 crop, random horizontal flip, to-tensor, normalise with
            mean/std 0.5 per channel).
        mode: accepted for interface compatibility; not used.

    Each item is a dict with keys ``'content'``, ``'style'`` and
    ``'style_label'``.
    """

    def __init__(self, root_img, root_style, transforms_=None, mode='train'):
        # Honour a caller-supplied pipeline instead of silently replacing it.
        if transforms_ is None:
            transforms_ = [ transforms.Resize(int(143), InterpolationMode.BICUBIC),
                    transforms.RandomCrop(128),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
                  ]
        self.transform = transforms.Compose(transforms_)

        # content source
        self.X = sorted(glob.glob(os.path.join(root_img, '*')))

        # style image source(s): list of (label, path) tuples
        self.Y = []
        style_sources = sorted(glob.glob(os.path.join(root_style, '*')))
        for label, style_dir in enumerate(style_sources):
            self.Y.extend((label, path) for path in sorted(glob.glob(style_dir + "/*")))

    def __len__(self):
        return max(len(self.X), len(self.Y))

    def _load(self, path):
        """Open ``path``, force 3-channel RGB, apply the transform, close the file."""
        # Converting to RGB lets greyscale / palette / RGBA files pass through
        # the 3-channel Normalize instead of raising.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        """Return the content image at ``index`` (wrapping) and a random style image."""
        output = {}
        output['content'] = self._load(self.X[index % len(self.X)])
        # select style
        label, style_path = random.choice(self.Y)
        output['style'] = self._load(style_path)
        output['style_label'] = label
        return output
from torch.utils.data import DataLoader
# Build the dataset: content images from the VOC2007 JPEG folder, style images
# from the sub-folders of Gated_patch_voc2007/trainB.
root = ImageDataset('./VOCdevkit/VOC2007/JPEGImages', './Gated_patch_voc2007/trainB')
# Sanity check: number of content images, number of (label, path) style
# entries, and the composed transform pipeline.
print(len(root.X), len(root.Y), root.transform)

9963 80 Compose(
    Resize(size=143, interpolation=bicubic, max_size=None, antialias=None)
    RandomCrop(size=(128, 128), padding=None)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
)


In [5]:
from torch.utils.data import TensorDataset, DataLoader

In [6]:
# Shuffled loader over the dataset, one sample per batch.
dataloader = DataLoader(root, batch_size=1, shuffle=True)
# Pull a single batch up front. NOTE(review): the training cell reads
# `batch['style']` before its loop starts, so this cell must run first.
batch = next(iter(dataloader))
# Marker that the batch was fetched without error.
print(1)


1


In [7]:
class ReplayBuffer():
    """Fixed-capacity pool of previously seen samples.

    Args:
        max_size: maximum number of samples kept; must be positive.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Insert each sample of ``data`` and return a batch of the same length.

        While the pool is not full, every incoming sample is stored and
        returned as-is. Once full, each incoming sample is, with probability
        one half, swapped with a random stored sample (the stored one is
        returned as a clone); otherwise it is returned without being stored.

        Args:
            data: tensor whose first dimension indexes the samples.

        Returns:
            A tensor detached from the autograd graph, formed by concatenating
            the selected samples along dimension 0.
        """
        to_return = []
        # detach() replaces the legacy `.data` access: same values, no graph.
        for element in data.detach():
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        # A plain tensor is returned; the deprecated Variable wrapper is a no-op.
        return torch.cat(to_return)

def label2tensor(label,tensor):
    """Broadcast each label over the matching slice of ``tensor``, in place.

    For every index ``i`` along the first dimension of ``label``, all entries
    of ``tensor[i]`` are overwritten with ``label[i]``. The same ``tensor``
    object is returned.
    """
    n_labels = label.size(0)
    row = 0
    while row < n_labels:
        tensor[row].fill_(label[row])
        row += 1
    return tensor

# def tensor2image(tensor):
#     image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
#     if image.shape[0] == 1:
#         image = np.tile(image, (3,1,1))
#     return image.astype(np.uint8)

def weights_init_normal(m):
    """Initialise a module's parameters in place; meant for ``Module.apply``.

    * Class name contains ``'Conv'``: weight drawn from N(0, 0.02).
    * Class name contains ``'BatchNorm2d'``: weight drawn from N(1, 0.02),
      bias set to 0.

    Any other module is left alone, and a missing (``None``) weight or bias is
    skipped rather than raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if classname.find('Conv') != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            # constant_ is the in-place initialiser; plain `constant` is deprecated.
            torch.nn.init.constant_(bias.data, 0.0)

class LambdaLR():
    """Multiplicative learning-rate factor with a linear ramp-down.

    ``step`` returns 1.0 until ``epoch + offset`` reaches
    ``decay_start_epoch`` and then decreases linearly, hitting 0.0 when
    ``epoch + offset`` equals ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        decay_span = n_epochs - decay_start_epoch
        assert (decay_span > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the factor for ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_span

In [None]:
#TRAIN OPTIONS FROM GATED GAN
# Hyper-parameters and switches for the run. Several of them (dataroot,
# loadSize, gpu, pool_size, resize_or_crop) are not read anywhere in this cell.
epoch = 0  # starting epoch; also passed to the LR lambda as its offset
n_epochs = 16
decay_epoch=1  # epoch at which the linear LR decay begins
batchSize = 1
dataroot = './photo2fourcollection'
loadSize = 143
fineSize = 128  # spatial size used for the pre-allocated input tensors
ngf = 64
ndf = 64    
in_nc = 3 
out_nc = 3 
lr = 0.0002 
gpu = 1 
lambda_A = 10.0  # multiplier on the style-classification losses
pool_size = 50
resize_or_crop = 'resize_and_crop'
autoencoder_constrain = 10  # multiplier on the reconstruction (L1) loss
n_styles = 4
cuda=True
tv_strength=2e-6  # weight handed to TVLoss

generator = Generator(in_nc, out_nc, n_styles, ngf)
discriminator= Discriminator(in_nc, n_styles, ndf)

if cuda:
    generator.cuda()
    discriminator.cuda()

#Losses Init
# MSE for the real/fake map when use_lsgan is set, BCE otherwise.
use_lsgan=True
if use_lsgan:
    criterion_GAN = nn.MSELoss()
else: 
    criterion_GAN = nn.BCELoss()
    

criterion_ACGAN = nn.CrossEntropyLoss()
criterion_Rec = nn.L1Loss()
criterion_TV = TVLoss(TVLoss_weight=tv_strength)

#Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(generator.parameters(),
                                lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), 
                               lr=lr, betas=(0.5, 0.999))


# The notebook's own LambdaLR helper supplies the per-epoch multiplier that
# torch's LambdaLR scheduler applies to each optimizer.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch,decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_epochs,epoch, decay_epoch).step)

#Set vars for training
# Pre-allocated buffers that each minibatch is copied into.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
input_A = Tensor(batchSize, in_nc, fineSize, fineSize)
input_B = Tensor(batchSize, out_nc, fineSize, fineSize)
# NOTE(review): target_real / target_fake are not referenced by the training loop below.
target_real = Variable(Tensor(batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(batchSize).fill_(0.0), requires_grad=False)

# Run the discriminator once on a style batch just to read off the shapes of
# its two outputs. `batch` is not defined in this cell; it must already exist.
D_A_size = discriminator(input_A.copy_(batch['style']))[0].size()  
D_AC_size = discriminator(input_B.copy_(batch['style']))[1].size()

# Integer target buffer sized from the first three dims of the class output;
# it is filled with the style label on every training step.
class_label_B = Tensor(D_AC_size[0],D_AC_size[1],D_AC_size[2]).long()

# Label vector with n_styles+1 entries whose only non-zero entry is the last
# one; it is passed as 'style_label' for the reconstruction pass.
autoflag_OHE = Tensor(1,n_styles+1).fill_(0).long()
autoflag_OHE[0][-1] = 1

# Targets for the real/fake map; the "real" target is 0.99 rather than 1.0.
fake_label = Tensor(D_A_size).fill_(0.0)
real_label = Tensor(D_A_size).fill_(0.99) 

# NOTE(review): rec_A_AE is allocated here but not used by the training loop below.
rec_A_AE = Tensor(batchSize,in_nc,fineSize,fineSize)

fake_buffer = ReplayBuffer()

##Init Weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)



### TRAIN LOOP
# Checkpoints are written to 'output/' at the end of every epoch. Create the
# directory up front so a missing folder cannot abort the run after a full
# epoch of training.
os.makedirs('output', exist_ok=True)

for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(dataloader):
        ## Unpack minibatch
        # source content
        real_content = Variable(input_A.copy_(batch['content']))
        # target style
        real_style = Variable(input_B.copy_(batch['style']))
        # style label
        style_label = batch['style_label']
        # one-hot encoded style
        style_OHE = F.one_hot(style_label,n_styles).long()
        # style label broadcast over the pre-allocated class-target buffer
        # (its shape was taken from the discriminator's class output)
        class_label = class_label_B.copy_(label2tensor(style_label,class_label_B)).long()
        
        #### Update Discriminator
        optimizer_D.zero_grad()
        
        # Generate style-transferred image
        genfake = generator({
            'content':real_content,
            'style_label': style_OHE})
        
        # Add generated image to image pool and randomly sample pool 
        fake = fake_buffer.push_and_pop(genfake)
        # Discriminator forward pass with sampled fake 
        out_gan, out_class = discriminator(fake)
        # Discriminator Fake loss (correctly identify generated images)
        errD_fake = criterion_GAN(out_gan, fake_label)
        # Backward pass and parameter optimization
        errD_fake.backward()
        optimizer_D.step()
        
        optimizer_D.zero_grad()
        # Discriminator forward pass with target style
        out_gan, out_class = discriminator(real_style)
        # Discriminator Style Classification loss
        # (transpose puts the class scores back on dim 1 for CrossEntropyLoss)
        errD_real_class = criterion_ACGAN(out_class.transpose(1,3),class_label)*lambda_A
        # Discriminator Real loss (correctly identify real style images)
        errD_real = criterion_GAN(out_gan, real_label)        
        errD_real_total = errD_real + errD_real_class
        # Backward pass and parameter optimization
        errD_real_total.backward()
        optimizer_D.step()
        
        
        # Mean of the real and fake GAN losses (kept for inspection; not logged below)
        errD = (errD_real+errD_fake)/2.0
        
                
        #### Generator Update
        ## Style Transfer Loss
        optimizer_G.zero_grad()
        
        # Discriminator forward pass with generated style transfer
        out_gan, out_class = discriminator(genfake)
        
        # Generator gan (real/fake) loss
        err_gan = criterion_GAN(out_gan, real_label)
        # Generator style class loss
        err_class = criterion_ACGAN(out_class.transpose(1,3), class_label)*lambda_A
        # Total Variation loss
        err_TV = criterion_TV(genfake)
        
        errG_tot = err_gan + err_class + err_TV
        errG_tot.backward()
        optimizer_G.step()
        
        ## Auto-Encoder (Recreation) Loss
        # The label with only its last entry set is used here, and the output
        # is compared against the content image itself.
        optimizer_G.zero_grad()
        identity = generator({
            'content': real_content,
            'style_label': autoflag_OHE,
        })
        err_ae = criterion_Rec(identity,real_content)*autoencoder_constrain
        err_ae.backward()
        optimizer_G.step()
        if i % 20 == 0:
            print('Batch:', i)
            # The "Discriminator loss" shown is the real-branch total (GAN + class).
            print("Discriminator loss:", errD_real_total.item(), "Generator loss:", errG_tot.item())

        
    
    ##update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D.step()
    
    #Save model (overwrites the previous epoch's checkpoint)
    torch.save(generator.state_dict(), 'output/netG.pth')
    torch.save(discriminator.state_dict(), 'output/netD.pth')

Batch: 0
Discriminator loss: 14.162635803222656 Generator loss: 13.963260650634766
Batch: 20
Discriminator loss: 14.413558959960938 Generator loss: 14.272843360900879
Batch: 40
Discriminator loss: 10.688887596130371 Generator loss: 13.021153450012207
Batch: 60
Discriminator loss: 13.797469139099121 Generator loss: 13.379288673400879
Batch: 80
Discriminator loss: 11.066908836364746 Generator loss: 10.75394344329834
Batch: 100
Discriminator loss: 8.040660858154297 Generator loss: 8.807954788208008
Batch: 120
Discriminator loss: 6.967514991760254 Generator loss: 5.838262557983398
Batch: 140
Discriminator loss: 8.105093955993652 Generator loss: 7.01444673538208
Batch: 160
Discriminator loss: 10.052586555480957 Generator loss: 8.930877685546875
Batch: 180
Discriminator loss: 11.936324119567871 Generator loss: 9.769805908203125
Batch: 200
Discriminator loss: 15.019433975219727 Generator loss: 15.117688179016113
Batch: 220
Discriminator loss: 9.412156105041504 Generator loss: 8.64853572845459

Batch: 1960
Discriminator loss: 6.226527214050293 Generator loss: 5.91234827041626
Batch: 1980
Discriminator loss: 5.648400783538818 Generator loss: 5.976004600524902
Batch: 2000
Discriminator loss: 7.083575248718262 Generator loss: 5.740179061889648
Batch: 2020
Discriminator loss: 6.060420513153076 Generator loss: 5.652252197265625
Batch: 2040
Discriminator loss: 7.406171798706055 Generator loss: 5.904869079589844
Batch: 2060
Discriminator loss: 5.869866847991943 Generator loss: 5.529762268066406
Batch: 2080
Discriminator loss: 5.735332489013672 Generator loss: 5.459377765655518
Batch: 2100
Discriminator loss: 6.81845235824585 Generator loss: 5.674100399017334
Batch: 2120
Discriminator loss: 9.379009246826172 Generator loss: 5.796261310577393
Batch: 2140
Discriminator loss: 5.353666305541992 Generator loss: 5.522758960723877
Batch: 2160
Discriminator loss: 6.169582843780518 Generator loss: 5.509429931640625
Batch: 2180
Discriminator loss: 7.250353813171387 Generator loss: 5.7624588012