In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import os
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter
from PIL import Image

In [85]:
torch.cuda.is_available()

True

In [279]:
class PixelNormLayer(nn.Module):
    """
    Pixelwise feature vector normalization.
    """
    def __init__(self, eps=1e-8):
        super(PixelNormLayer, self).__init__()
        self.eps = eps
    
    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)

    def __repr__(self):
        return self.__class__.__name__ + '(eps = %s)' % (self.eps)
    

    
class GInput(nn.Module):
    def __init__(self, alpha=0.2):
        super(GInput, self).__init__()
        self.layer = nn.Sequential(
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.ConvTranspose2d(512, 512, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
        )
                
    def forward(self, x):
        return self.layer(x)
    
    
class UpsampleG(nn.Module):
    def __init__(self, alpha=0.2, ch_in=512, ch_out=512):
        super(UpsampleG, self).__init__()
        self.layer = nn.Sequential(
            nn.ConvTranspose2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.ConvTranspose2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
        )
    
    def forward(self, x):
        return self.layer(x)
    
    
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer_in = GInput()
        self.layers = [self.layer_in]
        self.net = nn.Sequential(*self.layers)
        self.rgb = None
        self.det_to_rgb()
        
    def det_to_rgb(self):
        c = self.net[-1].layer[3].out_channels
        self.rgb = nn.Sequential(
            nn.Conv2d(c, 3, kernel_size=1, stride=1, padding=0),
            nn.Tanh()
        )
        
    def upsample(self, ch_in=512, ch_out=512):
        self.layers.append(UpsampleG(ch_in=ch_in, ch_out=ch_out))
        self.net = nn.Sequential(*self.layers)
        self.det_to_rgb()
        
        if next(self.parameters()).is_cuda:
            self.to_cuda()
        else:
            self.to_cpu()
        
    def to_cuda(self):
        self.cuda()
        self.net.cuda()
        self.rgb.cuda()
    
    def to_cpu(self):
        self.cpu()
        self.net.cpu()
        self.rgb.cpu()
                
    def forward(self, x, alpha=None):
        if alpha is not None and alpha < 0.999:
            x = self.net[:-1](x)
            x_left = F.upsample(x, scale_factor=2, mode='nearest')
            x_right = self.net[-1](x)
            
            alpha = torch.tensor(alpha)
            one = torch.ones(1)
            if x.is_cuda:
                alpha = alpha.cuda()
                one = one.cuda()
            
            x = (one - alpha) * x_left + alpha * x_right
            
        else:
            x = self.net(x)        
        x = self.rgb(x)
        print(x.shape)
                
a = Generator()
a.upsample()

a(torch.ones(1, 512, 1, 1), 1)

torch.Size([1, 3, 8, 8])


In [328]:
class UpsampleD(nn.Module):
    def __init__(self, alpha=0.2, ch_in=512, ch_out=512):
        super(UpsampleD, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.AvgPool2d(2)
        )
    
    def forward(self, x):
        return self.layer(x)

class DOutput(nn.Module):
    def __init__(self, alpha=0.2):
        super(DOutput, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(513, 512, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
        )
        self.out = nn.Linear(512, 1)
                
    def forward(self, x):
        std = x.std()
        ones = torch.ones((1, 1, 4, 4))
        
        if x.is_cuda:
            ones = ones.cuda()
        
        x = torch.cat((x, (ones * std).float()), dim=1)
        
        x = self.layer(x)
        return self.out(x.view(512))

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer_out = DOutput()
        self.layers = [self.layer_out]
        self.net = nn.Sequential(*self.layers)
        self.rgb = None
        self.rgb_previous_size = None
        self.det_from_rgb()
        
    def det_from_rgb(self):
        c = self.net[0].layer[0].in_channels
        c = 512 if c == 513 else c
        
        self.rgb_previous_size = self.rgb
        self.rgb = nn.Conv2d(3, c, kernel_size=1, stride=1, padding=0, bias=False)
        
    def upsample(self, ch_in=512, ch_out=512):
        self.layers.insert(0, UpsampleD(ch_in=ch_in, ch_out=ch_out))
        self.net = nn.Sequential(*self.layers)
        self.det_from_rgb()
                
        if next(self.parameters()).is_cuda:
            self.to_cuda()
        else:
            self.to_cpu()
    
    def to_cuda(self):
        self.cuda()
        self.net.cuda()
        self.rgb.cuda()
    
    def to_cpu(self):
        self.cpu()
        self.net.cpu()
        self.rgb.cpu()
    
    def forward(self, x, alpha=None):
        if alpha is not None and alpha < 0.999:
            x_left = F.avg_pool2d(x, 2)
            x_left = self.rgb_previous_size(x_left)
            x_right = self.rgb(x)
            x_right = self.net[0](x_right)
            
            alpha = torch.tensor(alpha)
            one = torch.ones(1)
            if x.is_cuda:
                alpha = alpha.cuda()
                one = one.cuda()
            
            x = (one - alpha) * x_left + alpha * x_right
            x = self.net[1:](x)
        else:
            x = self.rgb(x)
            x = self.net(x)
        
        if self.training:
            return x
        
        return F.sigmoid(x)
        
a = Discriminator()
a.upsample()
a.upsample()
a(torch.rand(1, 3, 16, 16), 1)

tensor([ 0.3352])

In [None]:
def det_loss(gen, input_z, dis, input_x):
    y_real = dis(input_x)
    generated = gen(input_z)
    y_fake = dis(generated)
    
    ones = torch.ones_like(y_real)
    zeros = torch.zeros_like(y_fake)
    if next(gen.parameters()).is_cuda:
        ones = ones.cuda()
        zeros = zeros.cuda()
    
    loss_real = F.binary_cross_entropy_with_logits(y_real, ones)
    loss_fake = F.binary_cross_entropy_with_logits(y_fake, zeros)
    
    loss_dis = loss_real + loss_fake
    loss_gen = F.binary_cross_entropy_with_logits(y_fake, ones)
    
    return loss_dis, loss_gen, generated
    
det_loss(gen, torch.ones((1, 100)), dis, torch.ones(1, 3, 64, 64))

In [6]:
def show_img(x, real=True):
    plt.figure(figsize=(6, 6))
    if isinstance(x, torch.Tensor):
        if next(gen.parameters()).is_cuda:
            x = x.cpu()
        x = x.data.numpy()
    
    x = np.transpose(np.squeeze(x), [1, 2, 0]) 
    if not real:
        x = np.array((x + 1) / 2 * 255, int)

    plt.imshow(x)


In [13]:
class ImageFolderEX(datasets.ImageFolder):
    def __getitem__(self, index):
        def get_img(index):
            path, label = self.imgs[index]
            try:
                img = self.loader(os.path.join(self.root, path))
            except:
                img = get_img(index + 1)
            return img
        img = get_img(index)
        return self.transform(img)

In [19]:
trans = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomResizedCrop(64, scale=(0.4, 0.8), ratio=(1, 1)),
#     transforms.Resize((64, 64), interpolation=2),
    transforms.ToTensor(),
])

data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=1, shuffle=True, drop_last=True, num_workers=0)
x = next(iter(data))
# show_img(x[0], False)

In [24]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)



In [9]:
dis = Discriminator().cuda()
gen = Generator().cuda()
writer = SummaryWriter(log_dir='tb/7')

state = torch.load('gen.pth')
gen.load_state_dict(state['state_dict'])
state = torch.load('dis.pth')
dis.load_state_dict(state['state_dict'])

In [10]:
lr = 0.0002
beta_1 = 0.5
optimizer_gen = torch.optim.Adam(gen.parameters(), lr, betas=(beta_1, 0.999))
optimizer_dis = torch.optim.Adam(dis.parameters(), lr, betas=(beta_1, 0.999))

In [15]:
state.keys()

dict_keys(['epoch', 'state_dict', 'optimizer', 'total_step'])

In [17]:
total_step = state['total_step']

In [None]:
epochs = 20
batch_size = 256

data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)

for epoch in range(epochs):
    
    c = 0
    for x in iter(data):
        c += 1
        dis.zero_grad()
        
        z = torch.tensor(np.random.uniform(-1, 1, (batch_size, 100)), dtype=torch.float32)
        
        if next(gen.parameters()).is_cuda:
            x = x.cuda()
            z = z.cuda()
            
        y_real = dis(x)
        generated = gen(z)
        y_fake = dis(generated)
        
        ones = torch.ones_like(y_real)
        zeros = torch.zeros_like(y_fake)
        if next(gen.parameters()).is_cuda:
            ones = ones.cuda()
            zeros = zeros.cuda()
        
        loss_real = F.binary_cross_entropy_with_logits(y_real, ones)
        loss_fake = F.binary_cross_entropy_with_logits(y_fake, zeros)
        loss_dis = loss_real + loss_fake
        
        loss_dis.backward()
        optimizer_dis.step()
                
        gen.zero_grad()
        generated = gen(z)
        y_fake = dis(generated)
        loss_gen = F.binary_cross_entropy_with_logits(y_fake, ones)
        loss_gen.backward()
        optimizer_gen.step()
        
        global_step = total_step + epoch * len(data) + c
        
        if c % 1 == 0:
            writer.add_scalar('loss_dis', loss_dis.item(), global_step)
            writer.add_scalar('loss_gen', loss_gen.item(), global_step)
        
        if c % 4 == 0:
            print(loss_dis.item(), loss_gen.item(), 'step', global_step)
            writer.add_image('img', generated[0], global_step)
    print('finished epoch', epoch)
        
    

In [25]:
save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': dis.state_dict(),
            'optimizer' : optimizer_dis.state_dict(),
            'total_step': global_step,
        }, 'dis.pth')

save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': gen.state_dict(),
            'optimizer' : optimizer_gen.state_dict(),
            'total_step': global_step,
        }, 'gen.pth')