In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import os
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter
from PIL import Image
import logging

In [4]:
# Report whether PyTorch can see a CUDA device (the cell displays the boolean).
torch.cuda.is_available()

True

In [5]:
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> 3-channel 256x256 image in [-1, 1].

    A linear layer projects the latent vector onto a 1024x4x4 feature map, which
    is then upsampled by six BatchNorm -> LeakyReLU -> ConvTranspose2d stages
    (each doubling height and width: 4 -> 8 -> ... -> 256) and squashed by Tanh.

    Args:
        input_size: dimensionality of the latent vector ``z``.
        alpha: negative slope shared by all LeakyReLU activations.
    """

    def __init__(self, input_size=200, alpha=0.2):
        super().__init__()

        # kernel 4 / stride 2 / padding 1 exactly doubles the spatial size.
        upsample = dict(kernel_size=4, stride=2, padding=1)
        # Channel count entering each stage; the last stage emits 3 (RGB) channels.
        widths = (1024, 512, 512, 512, 256, 128)

        self.input = nn.Linear(input_size, 4 * 4 * widths[0])

        # Modules are appended in the same order as the hand-written layout so
        # that ``state_dict`` keys (net.0, net.1, ...) stay checkpoint-compatible.
        stages = []
        for c_in, c_out in zip(widths, widths[1:] + (3,)):
            stages.append(nn.BatchNorm2d(c_in))
            stages.append(nn.LeakyReLU(alpha))
            stages.append(nn.ConvTranspose2d(c_in, c_out, **upsample))
        stages.append(nn.Tanh())
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Map latents of shape (batch, input_size) to images (batch, 3, 256, 256)."""
        feature_map = self.input(z).view(-1, 1024, 4, 4)
        return self.net(feature_map)
        
# Smoke test: one 200-dimensional latent vector should produce one 3x256x256 image.
gen = Generator()
gen(torch.ones((1, 200))).shape

torch.Size([1, 3, 256, 256])

In [6]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator for 3-channel 256x256 images.

    Six stride-2 convolutions halve the resolution each time (256 -> 4) while
    widening the channels to 1024; a final linear layer reduces the flattened
    1024x4x4 feature map to one score per image.

    Note the mode-dependent output: in training mode ``forward`` returns the raw
    logit (as expected by ``binary_cross_entropy_with_logits``); in eval mode it
    returns ``sigmoid(logit)``, i.e. a probability in (0, 1).

    Args:
        alpha: negative slope shared by all LeakyReLU activations.
    """

    def __init__(self, alpha=0.2):
        super().__init__()

        # kernel 4 / stride 2 / padding 1 exactly halves the spatial size.
        kernel_size = 4
        padding = 1
        stride = 2

        self.net = nn.Sequential(
            # No BatchNorm directly on the input layer.
            nn.Conv2d(3, 128, kernel_size, stride, padding),
            nn.LeakyReLU(alpha),
            nn.Conv2d(128, 256, kernel_size, stride, padding),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(alpha),
            nn.Conv2d(256, 512, kernel_size, stride, padding),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 512, kernel_size, stride, padding),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 512, kernel_size, stride, padding),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 1024, kernel_size, stride, padding),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(alpha),
        )
        self.output = nn.Linear(4 * 4 * 1024, 1)

    def forward(self, x):
        """Score a batch of images of shape (batch, 3, 256, 256).

        Returns:
            Tensor of shape (batch, 1): raw logits while ``self.training`` is
            True, sigmoid probabilities otherwise.
        """
        x = self.net(x)
        x = torch.reshape(x, (-1, 4 * 4 * 1024))
        x = self.output(x)

        if self.training:
            return x

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(x)
    
# Smoke test: a freshly constructed discriminator (still in training mode) maps
# one 3x256x256 image to a single raw logit of shape (1, 1).
dis = Discriminator()
dis(torch.ones(1, 3, 256, 256))

tensor([[ 0.3871]])

In [7]:
class ImageFolderEX(datasets.ImageFolder):
    """ImageFolder variant for GAN training that tolerates unreadable files.

    ``__getitem__`` returns only the transformed image (no class label),
    rescaled with ``* 2 - 1``. With a ``ToTensor``-based transform this maps the
    [0, 1] range onto [-1, 1], matching the generator's Tanh output.
    """

    def __getitem__(self, index):
        """Return the image at ``index``, skipping forward past unreadable files.

        If loading fails, the following samples are tried in order (wrapping
        around the end of the dataset) until one loads.

        Raises:
            IndexError: if ``index`` is outside the dataset.
            RuntimeError: if no sample in the dataset could be loaded.
        """
        n_samples = len(self.imgs)
        # Validates the index (IndexError when out of range) and normalises
        # negative indices, so the wrap-around below never hides a bad index.
        start = range(n_samples)[index]

        last_error = None
        for offset in range(n_samples):
            # Entries in ``imgs`` already contain the dataset root, so the path
            # is used as-is; joining it with ``self.root`` again would duplicate
            # the root for relative roots other than '.'.
            path, _label = self.imgs[(start + offset) % n_samples]
            try:
                img = self.loader(path)
            except Exception as err:  # best effort: skip corrupt/unreadable files
                logging.warning('skipping unreadable image %s: %s', path, err)
                last_error = err
                continue
            return self.transform(img) * 2 - 1

        raise RuntimeError(
            f'none of the {n_samples} images under {self.root!r} could be loaded'
        ) from last_error

In [8]:
# Preprocessing shared by every loader in this notebook: resize to 256x256
# (interpolation code 2) and convert to a tensor.
trans = transforms.Compose(
    [
        transforms.Resize(size=(256, 256), interpolation=2),
        transforms.ToTensor(),
    ]
)

# Sanity check: pull a single image through the pipeline and print its value
# range (ImageFolderEX rescales the transformed image with `* 2 - 1`).
data = torch.utils.data.DataLoader(
    dataset=ImageFolderEX('.', trans),
    batch_size=1,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
x = next(iter(data))
print(x.min(), x.max())

RuntimeError: Found 0 files in subfolders of: .
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif

In [9]:
def train_dis(dis, gen, x, optimizer=None, latent_dim=200, flip_prob=0.03):
    """Run one discriminator update on a batch of real images.

    The discriminator is trained to score ``x`` as real and freshly generated
    images as fake, using noisy targets:

    * one random offset per batch puts real targets in [0.9, 1.1] and fake
      targets in [0.0, 0.2];
    * each sample is independently label-flipped with probability ``flip_prob``
      (real target -> 0, fake target -> 1 at the same batch position).

    Args:
        dis: discriminator returning one logit per image.
        gen: generator mapping (batch, latent_dim) latents to images.
        x: batch of real images; moved to the discriminator's device.
        optimizer: optimizer over ``dis`` parameters. Defaults to the
            notebook-level ``optimizer_dis``.
        latent_dim: size of the latent vectors fed to ``gen``.
        flip_prob: per-sample probability of flipping the target labels.

    Returns:
        Detached scalar tensor: real loss + fake loss, measured before the step.
    """
    if optimizer is None:
        optimizer = optimizer_dis  # notebook-level default

    # The real batch and the targets live where the discriminator lives; the
    # latents live where the generator lives.
    dis_device = next(dis.parameters()).device
    gen_device = next(gen.parameters()).device
    x = x.to(dis_device)
    # Use the actual batch size of x so the fake predictions always line up with
    # the targets (no dependence on a global batch_size).
    z = torch.randn(x.shape[0], latent_dim, device=gen_device)

    dis.zero_grad()
    y_real_pred = dis(x)

    # Noisy targets built directly on the prediction's device and shape.
    real_target = torch.full_like(y_real_pred, 1.0 + float(np.random.uniform(-0.1, 0.1)))
    fake_target = torch.full_like(y_real_pred, float(np.random.uniform(0, 0.2)))
    # Boolean mask: flips exactly the sampled positions. (Indexing with
    # np.argwhere output would also flip row 0 whenever any label is flipped.)
    flip = torch.rand_like(y_real_pred) < flip_prob
    real_target[flip] = 0.0
    fake_target[flip] = 1.0

    loss_real = F.binary_cross_entropy_with_logits(y_real_pred, real_target)

    # The generator is not updated here, so skip building its autograd graph.
    with torch.no_grad():
        generated = gen(z).to(dis_device)
    y_fake_pred = dis(generated)

    loss_fake = F.binary_cross_entropy_with_logits(y_fake_pred, fake_target)
    loss = loss_fake + loss_real
    loss.backward()
    optimizer.step()
    return loss.detach()

            
def train_gen(gen, batch_size, discriminator=None, optimizer=None, latent_dim=200):
    """Run one generator update.

    Samples ``batch_size`` latent vectors, generates images, and steps the
    generator so that the discriminator scores those images as real
    (target 1 for every sample).

    Args:
        gen: generator mapping (batch, latent_dim) latents to images.
        batch_size: number of images to generate for this step.
        discriminator: model scoring the generated images. Defaults to the
            notebook-level ``dis``.
        optimizer: optimizer over ``gen`` parameters. Defaults to the
            notebook-level ``optimizer_gen``.
        latent_dim: size of the latent vectors fed to ``gen``.

    Returns:
        ``(loss, generated)``: the scalar loss and the generated batch, both
        detached so callers can log them without keeping the autograd graph
        alive.
    """
    if discriminator is None:
        discriminator = dis  # notebook-level default
    if optimizer is None:
        optimizer = optimizer_gen  # notebook-level default

    z = torch.randn(batch_size, latent_dim, device=next(gen.parameters()).device)

    gen.zero_grad()
    generated = gen(z)
    y_fake = discriminator(generated.to(next(discriminator.parameters()).device))

    # ones_like already matches y_fake's device and dtype; no extra move needed.
    loss = F.binary_cross_entropy_with_logits(y_fake, torch.ones_like(y_fake))
    # NOTE: this also accumulates gradients in the discriminator's parameters;
    # they are discarded because train_dis zeroes them before its own backward.
    loss.backward()
    optimizer.step()
    return loss.detach(), generated.detach()

In [6]:
def show_img(x, real=True, title=None, figsize=(6, 6)):
    """Display a single channels-first image with matplotlib.

    Args:
        x: image as a tensor or array shaped (C, H, W), optionally with extra
            singleton dimensions (e.g. a batch of one); those are squeezed away.
        real: if True, ``x`` is plotted as-is; if False, ``x`` is treated as
            generator output in [-1, 1] and rescaled to integer values 0..255.
        title: optional figure title.
        figsize: matplotlib figure size.
    """
    plt.figure(figsize=figsize)
    if title:
        plt.title(title)
    if isinstance(x, torch.Tensor):
        # Decide from the tensor itself (not from an unrelated global model)
        # whether a device copy is needed; detach drops any autograd history.
        x = x.detach().cpu().numpy()

    # (C, H, W) -> (H, W, C), the layout imshow expects.
    x = np.transpose(np.squeeze(x), [1, 2, 0])
    if not real:
        # Clip guards against values marginally outside [-1, 1].
        x = np.clip((x + 1) / 2 * 255, 0, 255).astype(int)
    plt.imshow(x)

In [10]:
# Send INFO-level records from the root logger to 256.log.
logger = logging.getLogger()
_LOG_FILE = "256.log"
# Re-running this cell must not stack a second FileHandler on the root logger,
# otherwise every message is written to the file once per execution.
if not any(
    isinstance(handler, logging.FileHandler)
    and handler.baseFilename == os.path.abspath(_LOG_FILE)
    for handler in logger.handlers
):
    logger.addHandler(logging.FileHandler(_LOG_FILE))
logger.setLevel(logging.INFO)

In [11]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``filename`` is a path, its parent directory is created first, so
    saving to e.g. ``backup/...`` works on a fresh checkout.

    Args:
        state: object to serialize (typically a dict of state_dicts).
        filename: destination path, or an open file-like object.
    """
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)

In [7]:
# Build both networks on the same device. (Previously a second, stray
# `gen = Generator()` replaced the CUDA generator with a CPU one, leaving the
# two models on different devices.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dis = Discriminator().to(device)
gen = Generator().to(device)

# Adam hyper-parameters shared by both optimizers.
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
optimizer_gen = torch.optim.Adam(gen.parameters(), lr, betas=(beta_1, beta_2))
optimizer_dis = torch.optim.Adam(dis.parameters(), lr, betas=(beta_1, beta_2))

epochs = 20
batch_size = 64
data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)


# Optionally resume model weights and the global step counter from disk.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
# NOTE(review): the optimizer state stored in the checkpoints is not restored.
from_checkpoint = True
if from_checkpoint:
    # map_location lets checkpoints written on a GPU load on whatever device is in use.
    state = torch.load('gen5.pth', map_location=device)
    gen.load_state_dict(state['state_dict'])
    state = torch.load('backup/dis1.pth', map_location=device)
    dis.load_state_dict(state['state_dict'])
    # The step offset is taken from the discriminator checkpoint.
    total_step = state['total_step']
else:
    total_step = 0
    

In [16]:
def save():
    """Checkpoint both networks and their optimizers for the current epoch.

    Reads the notebook-level ``epoch``, ``global_step``, ``total_step``,
    ``dis``/``gen`` and ``optimizer_dis``/``optimizer_gen`` and writes
    ``backup/dis{epoch}.pth`` and ``backup/gen{epoch}.pth``.

    The stored ``'total_step'`` is the cumulative step count
    (``total_step + global_step``), so that a run resumed from the checkpoint
    continues the step axis instead of dropping the offset of earlier runs.
    """
    cumulative_step = total_step + global_step
    targets = (
        ('dis', dis, optimizer_dis),
        ('gen', gen, optimizer_gen),
    )
    for tag, model, optimizer in targets:
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'total_step': cumulative_step,
        }, f'backup/{tag}{epoch}.pth')

In [None]:
# Main training loop: alternate one discriminator and one generator update per
# batch, log to TensorBoard and the log file, and checkpoint each epoch.
n = len(data)  # batches per epoch
for epoch in range(0, epochs):
    writer = SummaryWriter(log_dir=f'tb/epoch_{epoch}')
    try:
        for c, x in enumerate(data, start=1):
            loss_dis = train_dis(dis, gen, x)
            loss_gen, generated = train_gen(gen, batch_size)

            # Step within this run; total_step adds the offset of a resumed run.
            global_step = epoch * n + c
            step = total_step + global_step

            writer.add_scalar('loss_dis', loss_dis.item(), step)
            writer.add_scalar('loss_gen', loss_gen.item(), step)

            if c % 4 == 0:
                msg = f'loss: {loss_dis.item()}, \t {loss_gen.item()} \t epoch: {epoch}, \t {c}/{n}'
                logging.info(msg)

            if c % 2 == 0:
                # The generator ends in Tanh, so map [-1, 1] -> [0, 1] linearly
                # (a sigmoid would squash the preview into roughly [0.27, 0.73]).
                samples = (generated[:5].detach().cpu() + 1) / 2
                # Lay the samples side by side as one CHW image, which is the
                # layout add_image takes by default.
                writer.add_image('img', torch.cat(tuple(samples), dim=2), step)

            # Checkpoint once when the epoch passes its midpoint, instead of
            # rewriting both checkpoint files on every remaining step.
            if c == n // 2 + 1:
                save()

        # End-of-epoch checkpoint; skipped when the loader produced no batches.
        if n:
            save()
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    print('finished epoch', epoch)


