In [5]:
import numpy as np
import os
import glob
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid


import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm_notebook as tqdm

random.seed(42)
import warnings
warnings.filterwarnings("ignore")

In [6]:
# number of epochs of training
n_epochs = 75
# size of the batches
batch_size = 16
# name of the dataset
dataset_name = "../input/celeba-dataset/img_align_celeba/img_align_celeba"
# adam: learning rate
lr = 0.00008
# adam: decay of first order momentum of gradient
b1 = 0.5
# adam: decay of first order momentum of gradient
b2 = 0.999
# number of cpu threads to use during batch generation
n_cpu = 4
# size of each image dimension
img_size = 128
# size of random mask
mask_size = 64
# number of image channels
channels = 3
# interval between image sampling
sample_interval = 5000
#adversarial loss ratio
Lambda=0.001

cuda = True if torch.cuda.is_available() else False
# os.makedirs("testImages", exist_ok=True)
# os.makedirs("trainImages", exist_ok=True)
# os.makedirs("saved_models", exist_ok=True)

# Calculate output dims of image discriminator (PatchGAN)
patch_h, patch_w = int(img_size / 2 ** 3), int(img_size / 2 ** 3)
patch = (1, patch_h, patch_w)

In [7]:
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, img_size=128, mask_size=64, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mask_size = mask_size
        self.mode = mode
        self.files = sorted(glob.glob("%s/*.jpg" % root))
        self.files = self.files[:-4000] if mode == "train" else self.files[-4000:]

    def apply_random_mask(self, img):
        """Randomly masks image"""
        y1, x1 = np.random.randint(0, self.img_size - self.mask_size, 2)
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size
        masked_part = img[:, y1:y2, x1:x2]
        masked_img = img.clone()
        masked_img[:, y1:y2, x1:x2] = 0

        return masked_img

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img = self.transform(img)
        masked_img= self.apply_random_mask(img)

        return img, masked_img

    def __len__(self):
        return len(self.files)

In [8]:
# transforms_ = [
#     transforms.Resize((img_size, img_size), Image.BICUBIC),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ]
# dataloader = DataLoader(
#     ImageDataset(dataset_name, transforms_=transforms_),
#     batch_size=batch_size,
#     shuffle=True,
#     num_workers=n_cpu,
# )
# test_dataloader = DataLoader(
#     ImageDataset(dataset_name, transforms_=transforms_, mode="test"),
#     batch_size=batch_size,
#     shuffle=True,
#     num_workers=n_cpu,
# )

In [9]:
class Generator(nn.Module):
    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, normalize=True):
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers

        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 64),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            nn.Conv2d(512, 4000, 1),
            *upsample(4000, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128,64),
            *upsample(64,32),
            nn.Conv2d(32, channels, 3, 1, 1),
        )

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in [(128, 2, False), (256, 2, True), (512, 2, True), (1024, 1, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [12]:
image = torch.randn(16,3,128,128)
gen = Generator(channels = 3)
gen_output = gen(image)
image.shape

torch.Size([16, 3, 128, 128])

In [13]:
disc = Discriminator(channels = 3)
disc_out = disc(image)
disc_out.shape

torch.Size([16, 1, 16, 16])

In [10]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def save_sample_test(batches_done):
    samples, masked_samples = next(iter(test_dataloader))
    samples = Variable(samples.type(Tensor))
    masked_samples = Variable(masked_samples.type(Tensor))
    # Generate inpainted image
    gen_img = generator(masked_samples)
    # Save sample
    sample = torch.cat((samples.data ,masked_samples.data, gen_img.data), -2)
    save_image(sample, "testImages/%d.png" % batches_done, nrow=4, normalize=True)

def save_sample_train(batches_done):
    samples, masked_samples = next(iter(dataloader))
    samples = Variable(samples.type(Tensor))
    masked_samples = Variable(masked_samples.type(Tensor))
    # Generate inpainted image
    gen_img = generator(masked_samples)
    # Save sample
    sample = torch.cat((samples.data ,masked_samples.data, gen_img.data), -2)
    save_image(sample, "trainImages/%d.png" % batches_done, nrow=4, normalize=True)
    
# Loss function
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize generator and discriminator
generator = Generator(channels=channels)
discriminator = Discriminator(channels=channels)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [11]:
for epoch in range(n_epochs):
    print(f'epoch :{epoch}')
    
    gen_adv_loss, gen_pixel_loss, disc_loss = 0, 0, 0
    tqdm_bar = tqdm(dataloader, desc=f'Training Epoch {epoch} ', total=int(len(dataloader)))
    for i, (imgs, masked_imgs) in enumerate(tqdm_bar):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)

        # Configure input
        imgs = Variable(imgs.type(Tensor))
        masked_imgs = Variable(masked_imgs.type(Tensor))

        ## Train Generator ##
        with torch.cuda.amp.autocast():
            # Generate a batch of images
            gen = generator(masked_imgs)

            # Adversarial and pixelwise loss
            g_adv = adversarial_loss(discriminator(gen), valid)
            g_pixel = pixelwise_loss(gen, imgs)
            # Total loss
            g_loss = Lambda * g_adv + (1-Lambda) * g_pixel
            
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        ## Train Discriminator ##
        with torch.cuda.amp.autocast():
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()
        
        gen_adv_loss += g_adv.item()
        gen_pixel_loss += g_pixel.item()
        disc_loss += d_loss.item()
        tqdm_bar.set_postfix(gen_adv_loss=gen_adv_loss/(i+1), gen_pixel_loss=gen_pixel_loss/(i+1), disc_loss=disc_loss/(i+1))
        
        
        # Generate sample at sample interval
        generator.eval()
        with torch.inference_mode():
            batches_done = epoch * len(dataloader) + i
            if batches_done % sample_interval == 0:
                save_sample_test(batches_done)
                save_sample_train(batches_done)
                
    print(f'gen_adv_loss={gen_adv_loss/(i+1)}, gen_pixel_loss={gen_pixel_loss/(i+1)}, disc_loss={disc_loss/(i+1)}')
    torch.save(generator.state_dict(), "/kaggle/working/saved_models/generator.pth")
    torch.save(discriminator.state_dict(), "/kaggle/working/saved_models/discriminator.pth")

epoch :0


NameError: name 'dataloader' is not defined