In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


import glob
import random
import os
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import Dataset
from torch.autograd import Variable
from PIL import Image

In [12]:
import numpy as np
from tqdm import tqdm

import math

In [13]:
class TVLoss(nn.Module):
    def __init__(self, TVLoss_weight= 1):
        super(TVLoss,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):    
        w_variance = torch.sum(torch.pow(x[:,:,:,:-1] - x[:,:,:,1:], 2))
        h_variance = torch.sum(torch.pow(x[:,:,:-1,:] - x[:,:,1:,:], 2))
        loss = self.TVLoss_weight * (h_variance + w_variance)
        return loss

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x):
        return x
    
class ResidualBlock(nn.Module):
    def __init__(self,in_features):
        super(ResidualBlock,self).__init__()
        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x) #skip connection

class Encoder(nn.Module):
    def __init__(self, in_nc, ngf=64):
        super(Encoder, self).__init__()

        #Inital Conv Block
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(in_nc, ngf, 7),
                    nn.InstanceNorm2d(ngf),
                    nn.ReLU(inplace=True) ]

        in_features = ngf
        out_features = in_features *2

        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]

            in_features = out_features
            out_features = in_features * 2

        self.model = nn.Sequential(*model)

    def forward(self,x):
        #Return batch w/ encoded content picture
        return [self.model(x['content']), x['style_label']]

class Transformer(nn.Module):
    def __init__(self,n_styles, ngf, auto_id=True):
        super(Transformer, self).__init__()
        self.t = nn.ModuleList([ResidualBlock(ngf*4) for i in range(n_styles)])
        if auto_id:
            self.t.append(Identity())

    def forward(self,x):
        #x0 - конетент, x[1][0] - номер стиля
        label = x[1][0]
        mix = np.sum([self.t[i](x[0])*v for (i,v) in enumerate(label) if v])
        return mix

class Decoder(nn.Module):
    def __init__(self, out_nc, ngf, n_residual_blocks=5):
        super(Decoder, self).__init__()

        in_features = ngf * 4
        out_features = in_features//2

        model = []
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, out_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self,x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self,in_nc,out_nc,n_styles,ngf):
        super(Generator, self).__init__()

        self.encoder = Encoder(in_nc,ngf)
        self.transformer = Transformer(n_styles,ngf)
        self.decoder = Decoder(out_nc,ngf)

    def forward(self,x):
        e = self.encoder(x)
        t = self.transformer(e)
        d = self.decoder(t)
        return d

class Discriminator(nn.Module):
    """
    Patch-Gan discriminator 
    """
    def __init__(self, in_nc, n_styles, ndf=64):
        super(Discriminator, self).__init__()

        model = [   nn.Conv2d(in_nc, 64, 4, stride=2, padding=2),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=2),
                    nn.InstanceNorm2d(128),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=2),
                    nn.InstanceNorm2d(256),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4,stride=1, padding=2),
                    nn.InstanceNorm2d(512),
                    nn.LeakyReLU(0.2, inplace=True) ]

        self.model = nn.Sequential(*model)

        # GAN (real/notreal) Output-
        self.fldiscriminator = nn.Conv2d(512, 1, 4, padding = 2)

        # Classification Output
        self.aux_clf = nn.Conv2d(512, n_styles, 4, padding = 2)

    def forward(self, x):
        base =  self.model(x)
        discrim = self.fldiscriminator(base)
        clf = self.aux_clf(base).transpose(1,3)

        return [discrim,clf]

In [14]:

class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, mode='train'):
        transforms_ = [ transforms.Resize(int(143), Image.BICUBIC), 
                transforms.RandomCrop(128), 
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 
              ]
        #content source
        self.transform = transforms.Compose(transforms_)
        self.X = sorted(glob.glob(os.path.join(root, f'{mode}A', '*')))
        
        #style image source(s)
        self.Y = []
        style_sources = sorted(glob.glob(os.path.join(root, f'{mode}B', '*')))
        for label,style in enumerate(style_sources):
            temp = [(label,x) for x in sorted(glob.glob(style_sources[label]+"/*"))]
            self.Y.extend(temp)
        
        
    def __getitem__(self,index):
                                                 
        output = {}
        output['content'] = self.transform(Image.open(self.X[index % len(self.X)]))
        
        #select style
        selection = self.Y[random.randint(0, len(self.Y) - 1)]
        
        try:                                         
            output['style'] = self.transform(Image.open(selection[1]))
        except:
            print('Серый')
            print(selection)
                        
        output['style_label'] = selection[0]
    
        return output
    
    def __len__(self):
        return max(len(self.X), len(self.Y))

class ReplayBuffer():
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0,1) > 0.5:
                    i = random.randint(0, self.max_size-1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))

In [15]:
def label2tensor(label,tensor):
    for i in range(label.size(0)):
        tensor[i].fill_(label[i])
    return tensor

def tensor2image(tensor):
    image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        image = np.tile(image, (3,1,1))
    return image.astype(np.uint8)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant(m.bias.data, 0.0)


In [16]:
#TRAIN OPTIONS FROM GATED GAN
epoch = 0
n_epochs = 200
decay_epoch=100
batchSize = 1
dataroot = '/content/drive/My Drive/Colab Notebooks/photo2fourcollection'
loadSize = 143
fineSize = 128
ngf = 64
ndf = 64    
in_nc = 3 
out_nc = 3 
lr = 0.0002 
gpu = 1 
lambda_A = 10.0
pool_size = 50
resize_or_crop = 'resize_and_crop'
autoencoder_constrain = 10 
n_styles = 4
cuda=True
tv_strength=1e-6

In [None]:
dataloader = DataLoader(ImageDataset('/content/drive/My Drive/Colab Notebooks/photo2fourcollection'), 
                        batch_size=1, shuffle=True, num_workers=4)

batch = next(iter(dataloader))

batch['style_label']

tensor([2])

In [17]:
generator = Generator(in_nc, out_nc, n_styles, ngf)
discriminator = Discriminator(in_nc, n_styles, ndf)

#generator.load_state_dict(torch.load('./output/netG.pth'))


if cuda:
    generator.cuda()
    discriminator.cuda()

In [18]:
#Losses Init
use_lsgan=True
if use_lsgan:
    criterion_GAN = nn.MSELoss()
else: 
    criterion_GAN = nn.BCELoss()
    
    
criterion_ACGAN = nn.CrossEntropyLoss()
criterion_Rec = nn.L1Loss()
criterion_TV = TVLoss(TVLoss_weight=tv_strength)


#Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(generator.parameters(),
                                lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), 
                               lr=lr, betas=(0.5, 0.999))


class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
        
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch,decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_epochs,epoch, decay_epoch).step)

In [22]:
#Set vars for training
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
input_A = Tensor(batchSize, in_nc, fineSize, fineSize)
input_B = Tensor(batchSize, out_nc, fineSize, fineSize)
target_real = Variable(Tensor(batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(batchSize).fill_(0.0), requires_grad=False)

D_A_size = discriminator(input_A.copy_(batch['style']))[0].size()  
D_AC_size = discriminator(input_B.copy_(batch['style']))[1].size()

class_label_B = Tensor(D_AC_size[0],D_AC_size[1],D_AC_size[2]).long()

autoflag_OHE = Tensor(1,n_styles+1).fill_(0).long()
autoflag_OHE[0][-1] = 1

fake_label = Tensor(D_A_size).fill_(0.0)
real_label = Tensor(D_A_size).fill_(0.99) 

rec_A_AE = Tensor(batchSize,in_nc,fineSize,fineSize)

fake_buffer = ReplayBuffer()

##Init Weights
generator.apply(weights_init_normal)

Generator(
  (encoder): Encoder(
    (model): Sequential(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
      (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (9): ReLU(inplace=True)
    )
  )
  (transformer): Transformer(
    (t): ModuleList(
      (0): ResidualBlock(
        (conv_block): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=Fals

In [23]:
discriminator.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (fldiscriminator): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
  (aux_clf): Conv2d(512, 4, kernel_size=(4, 4), stride=(1, 1), paddin

In [None]:
logger = Logger(n_epochs, len(dataloader))

In [24]:
### TRAIN LOOP
for epoch in tqdm(range(epoch,n_epochs - 1)):
    for i, batch in enumerate(dataloader):
        # если картинка серая, то пропускаем ее
        if len(batch) == 2:
          continue
        real_content = Variable(input_A.copy_(batch['content']))
        # target style
        real_style = Variable(input_B.copy_(batch['style']))
        # style label
        style_label = batch['style_label']
        # one-hot encoded style
        style_OHE = F.one_hot(style_label,n_styles).long()
        # номер стиля расширенный до 1x19x19 для дискриминатора
        class_label = class_label_B.copy_(label2tensor(style_label,class_label_B)).long()
        

        optimizer_D.zero_grad()
        
        # Generate style-transfered image
        genfake = generator({
            'content':real_content,
            'style_label': style_OHE})
        
        fake = fake_buffer.push_and_pop(genfake)
        # Дискриминатор с фейком
        out_gan, out_class = discriminator(fake)
        errD_fake = criterion_GAN(out_gan, fake_label)
        errD_fake.backward()
        optimizer_D.step()
        
        optimizer_D.zero_grad()
        # Дискриминатор с реальным изображением
        out_gan, out_class = discriminator(real_style)
        errD_real_class = criterion_ACGAN(out_class.transpose(1,3),class_label)*lambda_A
        errD_real = criterion_GAN(out_gan, real_label)        
        errD_real_total = errD_real + errD_real_class
        errD_real_total.backward()
        optimizer_D.step()
        
        
        errD = (errD_real+errD_fake)/2.0
        
                

        # Генератор 
        optimizer_G.zero_grad()
        
        # Дискриминатор с генерированным фейком
        out_gan, out_class = discriminator(genfake)
        
        err_gan = criterion_GAN(out_gan, real_label)
        err_class = criterion_ACGAN(out_class.transpose(1,3), class_label)*lambda_A
        err_TV = criterion_TV(genfake)
        
        errG_tot = err_gan + err_class + err_TV 
        
        errG_tot.backward()
        optimizer_G.step()
        
        ## Auto-Encoder (Recreation) Loss
        optimizer_G.zero_grad()
        identity = generator({
            'content': real_content,
            'style_label': autoflag_OHE,
        })
        err_ae = criterion_Rec(identity,real_content)*autoencoder_constrain
        err_ae.backward()
        optimizer_G.step()
        
        

    
    lr_scheduler_G.step()
    lr_scheduler_D.step()
    
    #Save model
    torch.save(generator.state_dict(), '/content/drive/My Drive/stepik/netG1.pth')
    torch.save(discriminator.state_dict(), '/content/drive/My Drive/stepik/netD1.pth')

  0%|          | 0/199 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

thisuns grey
(2, '/content/drive/My Drive/Colab Notebooks/photo2fourcollection/trainB/ukiyoe/00688.jpg')
