In [1]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transform
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision.utils import save_image
import os
from tqdm.notebook import tqdm

In [2]:
# Image pipeline for the anime-face folder: resize the shorter side to 64 px,
# center-crop to 64x64, convert to a tensor, then normalize each channel with
# mean 0.5 / std 0.5 so values map from [0, 1] to [-1, 1] (the same range the
# generator's final Tanh produces).
face_transforms = transform.Compose([
    transform.Resize(64),
    transform.CenterCrop(64),
    transform.ToTensor(),
    transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = ImageFolder('/kaggle/input/animefacedataset', transform=face_transforms)

In [3]:
# Mini-batch size, reused below when sampling latent batches for training.
batch_size = 128
# Reshuffle the images every epoch.
trainloader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

In [4]:
def denorm(img_tensors):
    """Invert the (mean=0.5, std=0.5) normalization: map [-1, 1] back to [0, 1].

    Accepts anything supporting ``*`` and ``+`` with floats (tensors, scalars).
    """
    scale, shift = 0.5, 0.5
    return shift + scale * img_tensors

In [5]:
def show_images(images, nmax = 64):
    """Display up to ``nmax`` images from a normalized batch as an 8-column grid.

    Args:
        images: image batch, presumably (N, C, H, W) in [-1, 1] (the range of the
            dataset normalization and of the Tanh generator) — TODO confirm for
            other callers. May require grad and may live on any device.
        nmax: maximum number of images to draw.
    """
    fig, ax = plt.subplots(figsize = (8,8))
    ax.set_xticks([])
    ax.set_yticks([])
    # detach() drops autograd history; cpu() is required because matplotlib can
    # only render host memory (the original failed for CUDA tensors). Slice first
    # so only the displayed images are copied.
    grid = make_grid(denorm(images.detach()[:nmax].cpu()), nrow = 8)
    ax.imshow(grid.permute(1,2,0))

def show_batch(dl, nmax = 64):
    """Show the images of the first ``(images, labels)`` batch from ``dl``.

    Labels are ignored; an empty loader shows nothing.
    """
    batch = next(iter(dl), None)
    if batch is None:
        return
    first_images, _ = batch
    show_images(first_images, nmax)


In [6]:
# Preview one shuffled batch of real training images (up to 64).
show_batch(trainloader, nmax=64)

In [7]:
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
def to_device(data, device):
    """Move a tensor/module, or a (possibly nested) list/tuple of them, to ``device``.

    Lists and tuples are always returned as lists. Each leaf is moved with
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [9]:
class device_dataloader():
    """Wrap a data loader so that every batch it yields is moved to ``device``."""

    def __init__(self, dl, device):
        # Wrapped iterable of batches and the destination device.
        self.dl, self.device = dl, device

    def __iter__(self):
        # Transfer lazily, one batch at a time, as the wrapped loader produces them.
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        # Batch count is delegated to the wrapped loader.
        return len(self.dl)
        

In [10]:
# Training loader whose batches arrive already on `device`.
train_dl = device_dataloader(dl=trainloader, device=device)

In [11]:
# DCGAN-style discriminator. Four stride-2 4x4 conv blocks halve a 64x64 input
# (64 -> 32 -> 16 -> 8 -> 4) while widening channels. A final 4x4 valid conv
# reduces the map to one value per image, flattened to (N, 1) and squashed to
# a probability by Sigmoid. The layer order (and thus state_dict keys) is
# identical to the explicit listing this replaces.
_disc_layers = []
for _in_ch, _out_ch in [(3, 64), (64, 128), (128, 256), (256, 512)]:
    _disc_layers += [
        nn.Conv2d(_in_ch, _out_ch, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(_out_ch),
        nn.LeakyReLU(0.2, inplace=True),
    ]
_disc_layers += [
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Flatten(),
    nn.Sigmoid(),
]
discriminator = nn.Sequential(*_disc_layers)
# Drop the construction temporaries so they do not linger in the notebook namespace.
del _disc_layers, _in_ch, _out_ch

In [14]:
# Place the discriminator's parameters and buffers on the training device.
discriminator = to_device(data=discriminator, device=device)

In [15]:
# Dimensionality of the generator's noise input (channels of the (N, latent, 1, 1) tensors).
latent: int = 128

In [16]:
# DCGAN-style generator. A stride-1 4x4 transposed conv expands (latent, 1, 1)
# noise to 512x4x4. Four stride-2 4x4 transposed convs then double the spatial
# size (4 -> 8 -> 16 -> 32 -> 64) while narrowing channels down to RGB, and Tanh
# maps pixel values into [-1, 1]. The layer order (and thus state_dict keys) is
# identical to the explicit listing this replaces.
_gen_layers = [
    nn.ConvTranspose2d(latent, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(),
]
for _in_ch, _out_ch in [(512, 256), (256, 128), (128, 64)]:
    _gen_layers += [
        nn.ConvTranspose2d(_in_ch, _out_ch, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(_out_ch),
        nn.ReLU(),
    ]
_gen_layers += [
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
]
generator = nn.Sequential(*_gen_layers)
# Drop the construction temporaries so they do not linger in the notebook namespace.
del _gen_layers, _in_ch, _out_ch

In [17]:
# Smoke test of the untrained generator (still on the CPU at this point):
# one batch of random latent vectors in, one batch of images out.
test1 = torch.randn(batch_size, latent, 1, 1)
output_test1 = generator(test1)
print(output_test1.size())
show_images(output_test1, nmax=64)

In [18]:
# Place the generator's parameters and buffers on the training device.
generator = to_device(data=generator, device=device)

In [19]:
def train_discriminator(real_images, dopt):
    """Run one discriminator update on a real batch and a same-sized fake batch.

    The discriminator is trained with binary cross-entropy to output 1 for real
    images and 0 for images produced by the (global) generator.

    Args:
        real_images: batch of real images, already on ``device``.
        dopt: optimizer over the discriminator's parameters.

    Returns:
        (loss, real_score, fake_score): the summed real+fake BCE loss as a Python
        float (consistent with ``train_generator``), and the mean predicted
        probability on the real and on the fake batch.
    """
    dopt.zero_grad()
    n = real_images.size(0)

    real_targets = torch.ones(n, 1, device = device)
    real_predictions = discriminator(real_images)
    real_loss = nn.functional.binary_cross_entropy(real_predictions, real_targets)
    real_score = torch.mean(real_predictions).item()

    # Size the fake batch like the real one: the last loader batch can be smaller
    # than batch_size (DataLoader keeps partial batches by default).
    ls = torch.randn(n, latent, 1, 1, device = device)
    # detach(): this loss only updates the discriminator, so there is no reason to
    # back-propagate into (and compute gradients for) the generator.
    fakes = generator(ls).detach()

    fake_targets = torch.zeros(fakes.size(0), 1, device = device)
    fake_predictions = discriminator(fakes)
    fake_loss = nn.functional.binary_cross_entropy(fake_predictions, fake_targets)
    fake_score = torch.mean(fake_predictions).item()

    loss = fake_loss + real_loss
    loss.backward()
    dopt.step()

    return loss.item(), real_score, fake_score
    

In [20]:
def train_generator(gopt):
    """Run one generator update by pushing the discriminator to call fresh fakes real.

    Args:
        gopt: optimizer over the generator's parameters.

    Returns:
        The generator's BCE loss for this step as a Python float.
    """
    gopt.zero_grad()

    noise = torch.randn(batch_size, latent, 1, 1, device = device)
    verdict = discriminator(generator(noise))
    # Non-saturating GAN objective: label the generated images as real (1).
    wanted = torch.ones(batch_size, 1, device = device)
    g_loss = nn.functional.binary_cross_entropy(verdict, wanted)

    g_loss.backward()
    gopt.step()
    return g_loss.item()

In [21]:
# Output folder for the per-epoch sample grids; created if it does not exist yet.
sample_dir = 'generated'
os.makedirs(name=sample_dir, exist_ok=True)

In [22]:
def save_samples(index, latent_tensors, show=True):
    """Generate images from ``latent_tensors``, save them as a PNG grid, optionally display it.

    Args:
        index: integer used in the zero-padded file name ``generated-images-NNNN.png``.
        latent_tensors: latent batch for the generator, e.g. ``(N, latent, 1, 1)`` on ``device``.
        show: when True, also render the grid in the notebook.
    """
    # Sampling only; no_grad avoids building an autograd graph through the generator.
    # NOTE(review): the generator stays in train mode, so BatchNorm still uses (and
    # updates running stats from) this batch — same as before; confirm this is intended.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # Denormalize to [0, 1] exactly like the saved file. The raw Tanh output
        # lies in [-1, 1], which imshow would clip, making the preview too dark.
        ax.imshow(make_grid(denorm(fake_images).cpu(), nrow=8).permute(1, 2, 0))

In [23]:
# One fixed batch of 64 latent vectors (an 8x8 grid), reused after every epoch so
# the saved sample grids are directly comparable over the course of training.
fixed = torch.randn((64, latent, 1, 1), device=device)

In [24]:
def fit(epochs, lr, start_idx = 1):
    """Train the GAN for ``epochs`` epochs and save a sample grid after each one.

    Both networks use Adam with ``betas=(0.5, 0.999)``. Each batch performs one
    discriminator step followed by one generator step.

    Args:
        epochs: number of passes over ``train_dl``.
        lr: learning rate for both optimizers.
        start_idx: number given to the first saved sample file; later epochs count up from it.

    Returns:
        (losses_g, losses_d, real_scores, fake_scores): per-epoch lists of plain
        floats, each taken from the last batch of that epoch.
    """
    torch.cuda.empty_cache()
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    dopt = torch.optim.Adam(discriminator.parameters(), lr = lr, betas = (0.5, 0.999))
    gopt = torch.optim.Adam(generator.parameters(), lr = lr, betas = (0.5, 0.999))

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            loss_d, real_score, fake_score = train_discriminator(real_images, dopt)
            loss_g = train_generator(gopt)
        # float() accepts Python numbers and 0-dim tensors alike, so the history
        # never holds tensors (which would keep device memory alive).
        loss_d = float(loss_d)
        loss_g = float(loss_g)
        losses_d.append(loss_d)
        losses_g.append(loss_g)
        real_scores.append(real_score)
        fake_scores.append(fake_score)
        print("Epoch [{}/{}], loss_g: {:.4f} , loss_d: {:.4f}, Real_Score: {:.4f}, Fake_Score: {:.4f}".format(epoch+1, epochs, loss_g, loss_d, real_score, fake_score))

        save_samples(start_idx + epoch, fixed)

    # Previously returned the undefined name `fakes_scores`, raising NameError
    # only after the whole training run had finished.
    return losses_g, losses_d, real_scores, fake_scores

In [25]:
# 25 epochs at lr 2e-4; `history` is (losses_g, losses_d, real_scores, fake_scores).
history = fit(epochs=25, lr=2e-4)

In [30]:
# Persist only the generator's parameters/buffers (its state_dict), not the module object.
torch.save(generator.state_dict(), f="anime_faces_generator.pt")

In [31]:
# Persist only the discriminator's parameters/buffers (its state_dict), not the module object.
torch.save(discriminator.state_dict(), f="anime_faces_discriminator.pt")