In [None]:
import torch
import torchvision
import os
import PIL
import pdb
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,datasets
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pylab as plt
from PIL import Image
from torchvision.utils import save_image

In [None]:
def show(tensor, num=25,name=''):
    """Plot the first `num` images of a batch as a grid, five per row.

    Args:
        tensor: batch of images; it is detached and moved to the CPU first.
        num: maximum number of images taken from the front of the batch.
        name: text used as the figure title.
    """
    images = tensor.detach().cpu()[:num]
    grid = make_grid(images, nrow=5)
    # Axes are reordered (1, 2, 0) before the grid is handed to imshow.
    plt.imshow(grid.permute(1, 2, 0))
    plt.title(name)
    plt.show()

In [None]:
# --- Training configuration -------------------------------------------------
# Upper bound of the epoch range used by the training loop.
epochs = 130
# Intended starting epoch -- NOTE(review): confirm the training loop reads it.
start_epoch = 0
# Batch size for the DataLoaders and for the sample-generation cells.
batch = 64
# Adam learning rates for the generator and the critic optimizers.
gen_lr = 1e-4
crit_lr = 0.00008
# Length of the noise vector fed to the generator.
z_dim = 200
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# --- Training-loop bookkeeping ----------------------------------------------
# Global step counter, incremented once per batch.
cur_step = 0
# Number of critic updates performed for every generator update.
crit_cycls = 5
# Per-step loss histories that the training loop appends to and plots.
gan_losses =[]
crit_losses = []
# Sample images and loss curves are shown every `show_step` steps.
show_step = 35

In [None]:
class Dog_Dataset(Dataset):
    """Image-folder dataset that yields resized RGB tensors plus file names.

    Every regular file directly inside `paths` is treated as one sample.
    Sub-directories are skipped, because opening one as an image would raise
    in `__getitem__`.

    Args:
        paths: directory containing the image files.
        size: edge length; images are resized to `size` x `size`.
    """

    def __init__(self, paths, size=128):
        self.sizes=[size, size]
        items, labels=[],[]

        for data in os.listdir(paths):
            item = os.path.join(paths,data)
            # Only regular files can be images; skip folders such as
            # `.ipynb_checkpoints` that notebooks tend to create.
            if not os.path.isfile(item):
                continue
            items.append(item)
            labels.append(data)
        self.items=items
        self.labels=labels


    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.items)

    def __getitem__(self,idx):
        """Load sample `idx`.

        Returns:
            A `(image, label)` pair: `image` is a float32 tensor laid out as
            (channel, height, width) with values divided by 255, and `label`
            is the file name of the image.
        """
        # Close the file handle deterministically; `convert` returns an
        # independent, fully loaded image, so it stays valid after the close.
        with PIL.Image.open(self.items[idx]) as raw:
            data = raw.convert('RGB')
        data = np.asarray(torchvision.transforms.Resize(self.sizes)(data))
        # HWC uint8 -> CHW float32 (the dtype change allocates a fresh array).
        data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False)
        data = torch.from_numpy(data).div(255)
        return data, self.labels[idx]


# Build the dog dataset at 512x512 and wrap it in a shuffling DataLoader.
# NOTE(review): `paths`, `ds` and `dataloader` are rebound by the cat-loading
# cell further down, so whichever of the two cells ran last decides what the
# training loop iterates over.
paths = "afhq/train/dog"
ds = Dog_Dataset(paths, size=512)
dataloader = DataLoader(ds, batch_size=batch, shuffle=True)

# Pull one batch as a sanity check: x is the image batch, y the file names.
x,y= next(iter(dataloader))
show(x)

# Report the number of batches, the number of images and the batch shape.
print('dataloader length: {}'.format(len(dataloader)))
print('Dataset length: {}'.format(len(ds)))
print(x.shape)

In [None]:
class Cat_Dataset(Dataset):
    """Image-folder dataset that yields resized RGB tensors plus file names.

    Every regular file directly inside `paths` is treated as one sample.
    Sub-directories are skipped, because opening one as an image would raise
    in `__getitem__`.

    Args:
        paths: directory containing the image files.
        size: edge length; images are resized to `size` x `size`.
    """

    def __init__(self, paths, size=128):
        self.sizes=[size, size]
        items, labels=[],[]

        for data in os.listdir(paths):
            item = os.path.join(paths,data)
            # Only regular files can be images; skip folders such as
            # `.ipynb_checkpoints` that notebooks tend to create.
            if not os.path.isfile(item):
                continue
            items.append(item)
            labels.append(data)
        self.items=items
        self.labels=labels

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.items)

    def __getitem__(self,idx):
        """Load sample `idx`.

        Returns:
            A `(image, label)` pair: `image` is a float32 tensor laid out as
            (channel, height, width) with values divided by 255, and `label`
            is the file name of the image.
        """
        # Close the file handle deterministically; `convert` returns an
        # independent, fully loaded image, so it stays valid after the close.
        with PIL.Image.open(self.items[idx]) as raw:
            data = raw.convert('RGB')
        data = np.asarray(torchvision.transforms.Resize(self.sizes)(data))
        # HWC uint8 -> CHW float32 (the dtype change allocates a fresh array).
        data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False)
        data = torch.from_numpy(data).div(255)
        return data, self.labels[idx]


# Build the cat dataset at 512x512 and wrap it in a shuffling DataLoader.
# NOTE(review): this rebinds `paths`, `ds` and `dataloader`, replacing the
# dog loader created in the earlier cell.
paths = "afhq/train/cat"
ds = Cat_Dataset(paths, size=512)
dataloader = DataLoader(ds, batch_size=batch, shuffle=True)

# Pull one batch as a sanity check: x is the image batch, y the file names.
x,y= next(iter(dataloader))
show(x)

# Report the number of batches, the number of images and the batch shape.
print('dataloader length: {}'.format(len(dataloader)))
print('Dataset length: {}'.format(len(ds)))
print(x.shape)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a noise vector to an RGB image.

    The first layer expands the 1x1 input to 4x4; each of the seven following
    stride-2 layers doubles the spatial size, giving a 3x512x512 output that
    is squashed by `tanh`.

    Args:
        z_dim: length of the input noise vector.
        d_dim: base channel width; hidden layers use multiples of it.
    """

    def __init__(self, z_dim=64, d_dim=16):
        super(Generator, self).__init__()
        self.z_dim=z_dim

        self.gen = nn.Sequential(
            nn.ConvTranspose2d(z_dim, d_dim * 64, 4, 1, 0,bias=False),   # 1 -> 4
            nn.BatchNorm2d(d_dim*64),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim*64, d_dim*32, 4, 2, 1,bias=False),  # 4 -> 8
            nn.BatchNorm2d(d_dim*32),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim*32, d_dim*16, 4, 2, 1,bias=False),  # 8 -> 16
            nn.BatchNorm2d(d_dim*16),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim*16, d_dim*8, 4, 2, 1,bias=False),   # 16 -> 32
            nn.BatchNorm2d(d_dim*8),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim*8, d_dim*4, 4, 2, 1,bias=False),    # 32 -> 64
            nn.BatchNorm2d(d_dim*4),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim*4, d_dim*2, 4, 2, 1,bias=False),    # 64 -> 128
            nn.BatchNorm2d(d_dim*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim*2, d_dim, 4, 2, 1,bias=False),      # 128 -> 256
            nn.BatchNorm2d(d_dim),
            nn.ReLU(True),

            nn.ConvTranspose2d(d_dim, 3, 4, 2, 1,bias=False),            # 256 -> 512
            nn.Tanh()
        )


    def forward(self, noise):
        """Reshape `noise` to (batch, z_dim, 1, 1) and run it through the stack."""
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)


    @staticmethod
    def gen_noise(num, z_dim, device=None):
        """Draw `num` standard-normal noise vectors of length `z_dim`.

        Args:
            num: number of vectors (batch size).
            z_dim: length of each vector.
            device: target device; when None, CUDA is used if available and
                the CPU otherwise, so the call also works without a GPU.

        Returns:
            Tensor of shape (num, z_dim) on the chosen device.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.randn(num, z_dim, device=device)


def gen_noise(num, z_dim, device=None):
    """Module-level entry point for `Generator.gen_noise`.

    The rest of the notebook calls `gen_noise(...)` as a bare name, which is
    undefined while the helper only lives inside the class body.
    """
    return Generator.gen_noise(num, z_dim, device=device)



In [None]:
class Critic(nn.Module):
    """Convolutional critic that reduces an RGB image to one score per sample.

    Seven stride-2 blocks (Conv2d -> InstanceNorm2d -> LeakyReLU) each halve
    the spatial size while the channel count goes d_dim, 2*d_dim, ...,
    64*d_dim. A final 4x4 convolution with a single output channel follows.

    Args:
        d_dim: base channel width of the first block.
    """

    def __init__(self, d_dim=16):
        super(Critic, self).__init__()

        # Channel plan: 3 input channels followed by seven doubling widths.
        widths = [3] + [d_dim * 2 ** power for power in range(7)]

        stack = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            stack.append(nn.InstanceNorm2d(c_out, affine=True))
            stack.append(nn.LeakyReLU(0.2))

        # Scoring head: 4x4 kernel, stride 1, no padding, one output channel.
        stack.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))

        self.crit = nn.Sequential(*stack)


    def forward(self, image):
        """Score `image` and flatten the result to (batch, -1)."""
        scores = self.crit(image)
        return scores.view(len(scores), -1)
  

In [None]:
def init_weights(m):
    """Re-initialise one module in place; meant for `nn.Module.apply`.

    Conv2d and ConvTranspose2d weights are drawn from N(0, 0.02).
    BatchNorm2d weights are drawn from N(1, 0.02) and its bias is zeroed.
    Any other module type is left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias, 0)

In [None]:
def get_gp(real , fake , crit, alpha, gamma=10):
    """Gradient penalty evaluated on a per-sample blend of real and fake images.

    Args:
        real: batch of real images.
        fake: batch of generated images, same shape as `real`.
        crit: callable scoring an image batch; gradients are taken through it.
        alpha: blend weights that broadcast against the image batches.
        gamma: multiplier applied to the penalty.

    Returns:
        Scalar tensor `gamma * mean((||grad||_2 - 1) ** 2)`, where the gradient
        is that of the critic scores with respect to the blended images.
    """
    mix_img = alpha*real + (1-alpha)*fake
    # autograd.grad needs an input that tracks gradients; this only triggers
    # when neither `alpha` nor the image batches require grad.
    if not mix_img.requires_grad:
        mix_img.requires_grad_(True)

    # Score with the critic that was passed in, not a module-level global.
    mix_scores = crit(mix_img)

    gradient = torch.autograd.grad(
        inputs= mix_img,
        outputs = mix_scores,
        grad_outputs = torch.ones_like(mix_scores),
        retain_graph = True,
        create_graph= True)[0]

    # reshape (not view) because the returned gradient may be non-contiguous.
    gradient = gradient.reshape(len(gradient), -1)
    gradient_norm = gradient.norm(2,dim=1)
    gp = gamma*((gradient_norm-1)**2).mean()

    return gp

In [None]:
def save_checkpoint(gan_path,critic_path):
    """Write generator and critic checkpoints to the two given paths.

    Reads the notebook globals `epoch`, `gan`, `gan_opt`, `critic` and
    `critic_opt`. Each file holds a dict with the keys 'epoch', 'model' and
    'optimizer'; the model and optimizer objects themselves are stored.
    """
    targets = (
        (gan_path, gan, gan_opt),
        (critic_path, critic, critic_opt),
    )
    states = [
        {'epoch': epoch, 'model': model, 'optimizer': optimizer}
        for _, model, optimizer in targets
    ]
    # Generator file first, critic file second.
    for (path, _, _), state in zip(targets, states):
        torch.save(state, path)
    
def load_checkpoint(gan_path,critic_path, map_location=None):
    """Load the generator and critic checkpoint files.

    Args:
        gan_path: path of the generator checkpoint.
        critic_path: path of the critic checkpoint.
        map_location: forwarded to `torch.load`; pass e.g. 'cpu' to open a
            checkpoint written on a GPU machine. None keeps the saved devices.

    Returns:
        Tuple `(state_gan, state_critic)` with the loaded objects.
    """
    # SECURITY: torch.load unpickles the file, so only open trusted checkpoints.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects pickled nn.Module objects such as the ones save_checkpoint
    # stores -- confirm against the installed version.
    state_gan = torch.load(gan_path, map_location=map_location)
    state_critic = torch.load(critic_path, map_location=map_location)

    return state_gan,state_critic

In [None]:
# Instantiate the generator and the critic on the selected device.
gan = Generator(z_dim).to(device)
critic = Critic().to(device)
# The models are moved to `device` a second time here.
gan =gan.to(device)
critic =critic.to(device) 

# Custom weight initialisation is currently disabled.
#gan=gan.apply(init_weights)
#critic=critic.apply(init_weights)


# One Adam optimizer per network, each with its own learning rate.
# NOTE(review): beta1 is 0.2 for the generator but 0.5 for the critic --
# confirm the difference is intentional.
gan_opt = torch.optim.Adam(gan.parameters(),lr= gen_lr, betas=(0.2,0.9))
critic_opt = torch.optim.Adam(critic.parameters(),lr= crit_lr, betas=(0.5,0.9))

In [None]:
# Training loop: `crit_cycls` critic updates (with gradient penalty) for every
# generator update, with periodic sample/loss plots and one checkpoint pair
# per epoch.
for epoch in range(start_epoch, epochs):
    for real , _ in tqdm(dataloader):
        cur_bs = len(real)
        real = real.to(device)

        # ---- critic updates ------------------------------------------------
        maen_critic_loss = 0
        for _ in range(crit_cycls):
            critic_opt.zero_grad()

            # `gen_noise` is defined on the Generator class, so call it there
            # and put the noise explicitly on the training device.
            noise = Generator.gen_noise(cur_bs, z_dim, device=device)
            # The critic step must not backpropagate into the generator.
            with torch.no_grad():
                fake = gan(noise)
            crit_fake_prad = critic(fake)
            crit_real_prad = critic(real)

            # alpha tracks gradients so the blended images inside get_gp do too.
            alpha = torch.rand((cur_bs,1,1,1), device=device, requires_grad=True)
            gp = get_gp(real , fake , critic, alpha)
            crit_loss = crit_fake_prad.mean() - crit_real_prad.mean() + gp

            maen_critic_loss += crit_loss.item() / crit_cycls

            # A fresh graph is built every cycle, so nothing has to be retained.
            crit_loss.backward()
            critic_opt.step()

        crit_losses.append(maen_critic_loss)

        # ---- generator update ----------------------------------------------
        gan_opt.zero_grad()
        noise = Generator.gen_noise(cur_bs, z_dim, device=device)
        fake = gan(noise)
        crit_fake_prad = critic(fake)

        # The generator tries to raise the critic's score on its samples.
        gen_loss = - crit_fake_prad.mean()
        gen_loss.backward()
        gan_opt.step()

        gan_losses.append(gen_loss.item())

        # ---- periodic visualisation ----------------------------------------
        if (cur_step % show_step == 0 and cur_step > 0):
            show(fake, name='fake')
            show(real, name='real')
            # Print plain numbers rather than tensor reprs.
            print("epoch: {} , step: {} ,gen_loss: {}, crit_loss: {}".format(epoch,cur_step,gen_loss.item(),crit_loss.item()))

            plt.plot(
              range(len(gan_losses)),
              torch.Tensor(gan_losses),
              label="Generator Loss"
            )

            # Both histories grow by one entry per step, so they share an x axis.
            plt.plot(
              range(len(gan_losses)),
              torch.Tensor(crit_losses),
              label="Critic Loss"
            )

            plt.ylim(-150,150)
            plt.legend()
            plt.show()

        cur_step+=1

    print("Saving checkpoint: ", cur_step, epoch)
    save_checkpoint("gen_cat " + str(epoch),"crit_cat " + str(epoch))

In [None]:
# Restore the "dog" models. Despite the `_path` suffix, these variables hold
# the loaded checkpoint dicts; the 'model' entry is the stored network.
# NOTE(review): these are the same "1gan_cat 78" / "1crit_cat 78" files that
# the cat models load below, so gan_dog and gan_cat come from identical
# checkpoints -- confirm whether a dog checkpoint path was intended.
gan_dog_path,critic_dog_path = load_checkpoint("1gan_cat 78","1crit_cat 78")
gan_dog = gan_dog_path['model']
critic_dog = critic_dog_path['model']


# Restore the "cat" models from the epoch-78 checkpoint files.
gan_cat_path,critic_cat_path = load_checkpoint("1gan_cat 78","1crit_cat 78")
gan_cat = gan_cat_path['model']
critic_cat = critic_cat_path['model']

In [None]:
# Preview one batch from each restored generator: a grid first, then the first
# sample on its own. Generation runs under no_grad because nothing is trained
# here, and the noise is created on the device the generator lives on.
with torch.no_grad():
    noise0 = Generator.gen_noise(batch, z_dim, device=next(gan_dog.parameters()).device)
    fake_dog = gan_dog(noise0)
show(fake_dog)
plt.imshow(fake_dog[0].detach().cpu().permute(1,2,0).squeeze().clip(0,1))
# Flush this figure now; otherwise the next show() draws its grid on top of it.
plt.show()

with torch.no_grad():
    noise1 = Generator.gen_noise(batch, z_dim, device=next(gan_cat.parameters()).device)
    fake_cat = gan_cat(noise1)
show(fake_cat)
plt.imshow(fake_cat[0].detach().cpu().permute(1,2,0).squeeze().clip(0,1))
plt.show()

In [None]:
# Export 20 batches of generated images: dog samples go to gan_img/1 and cat
# samples to gan_img/0, numbered with one running counter.
count = 0

# Create the output folders once; exist_ok keeps the cell re-runnable.
for label_dir in ("gan_img/1", "gan_img/0"):
    os.makedirs(label_dir, exist_ok=True)

for _ in range(0,20):
    # Inference only: skip autograd bookkeeping, and create the noise on the
    # device each generator lives on.
    with torch.no_grad():
        noise0 = Generator.gen_noise(batch, z_dim, device=next(gan_dog.parameters()).device)
        fake_dog = gan_dog(noise0)

        noise1 = Generator.gen_noise(batch, z_dim, device=next(gan_cat.parameters()).device)
        fake_cat = gan_cat(noise1)

    for idx in range(0,len(fake_cat)):
        img = fake_dog[idx].detach().cpu().squeeze().clip(0,1)
        save_image(img, os.path.join("gan_img/1","dog " + str(count) + ".jpg"))

        img = fake_cat[idx].detach().cpu().squeeze().clip(0,1)
        save_image(img, os.path.join("gan_img/0","cat " + str(count) + ".jpg"))

        count+=1