<a href="https://colab.research.google.com/github/puru9860/GAN/blob/main/pytorch_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

%load_ext tensorboard


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator: maps an image batch to per-image "real" probabilities.

  Args:
    channel_img: number of channels in the input images.
    features_d: base channel width; it doubles after every downsampling stage.

  Sized for 64x64 inputs: four stride-2 convolutions take the spatial size
  64 -> 32 -> 16 -> 8 -> 4, and the final unpadded 4x4 convolution collapses
  that to 1x1. The output therefore has shape (N, 1, 1, 1) with values in (0, 1)
  from the closing Sigmoid.
  """

  def __init__(self,channel_img,features_d):
    super().__init__()
    widths = [features_d * mult for mult in (1, 2, 4, 8)]
    # Entry layer: plain conv (with bias) + LeakyReLU, no BatchNorm.
    layers = [
        nn.Conv2d(channel_img, widths[0], kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Three Conv-BN-LeakyReLU stages, each halving resolution and doubling width.
    for c_in, c_out in zip(widths[:-1], widths[1:]):
      layers.append(self._block(c_in, c_out, 4, 2, 1))
    # Head: 4x4 feature map -> single probability per image.
    layers.extend([
        nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid(),
    ])
    self.disc = nn.Sequential(*layers)

  def _block(self,in_channel,out_channel,kernel_size,stride,padding):
    """Conv (no bias; BatchNorm supplies the shift) -> BatchNorm -> LeakyReLU(0.2)."""
    conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=False)
    return nn.Sequential(conv, nn.BatchNorm2d(out_channel), nn.LeakyReLU(0.2))

  def forward(self,x):
    """Run the discriminator stack; returns a (N, 1, 1, 1) tensor of probabilities."""
    return self.disc(x)

In [None]:
class Generator(nn.Module):
  """DCGAN-style generator: maps latent noise to images squashed to (-1, 1).

  Args:
    z_dim: number of latent channels in the input noise.
    channels_img: number of channels in the generated images.
    features_g: base channel width; the first stage emits features_g*16 channels
      and every following upsampling stage halves that.

  The input is a noise tensor of shape (N, z_dim, 1, 1), which is how this notebook
  samples it. Spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output is
  (N, channels_img, 64, 64) after the closing Tanh.
  """

  def __init__(self,z_dim,channels_img,features_g):
    super().__init__()
    # Stage 0: project the 1x1 latent to a 4x4 map (stride 1, no padding).
    stages = [self._block(z_dim, features_g * 16, 4, 1, 0)]
    # Stages 1-3: each stride-2 transpose conv doubles resolution, halves width.
    for mult in (16, 8, 4):
      stages.append(self._block(features_g * mult, features_g * (mult // 2), 4, 2, 1))
    # Output stage: 32x32 -> 64x64 image, no BatchNorm, Tanh range.
    stages.append(
        nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1)
    )
    stages.append(nn.Tanh())
    self.gen = nn.Sequential(*stages)

  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    """ConvTranspose (no bias; BatchNorm supplies the shift) -> BatchNorm -> ReLU."""
    layers = (
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )
    return nn.Sequential(*layers)

  def forward(self,x):
    """Run the generator stack on noise x; returns a (N, channels_img, 64, 64) image batch."""
    return self.gen(x)

In [None]:
def initialize_weights(model):
  """Apply DCGAN-style initialization, in place, to every conv / BatchNorm layer of `model`.

  - Conv2d / ConvTranspose2d weights ~ N(0, 0.02).
  - BatchNorm2d scale (weight) ~ N(1, 0.02) and shift (bias) = 0.

  Centering the BatchNorm scale at 1 keeps normalized activations at roughly unit
  scale. A scale drawn around 0 would multiply every normalized activation by ~0
  and starve both networks of signal at the start of training.
  Conv biases (where present) keep PyTorch's default initialization.

  Args:
    model: any nn.Module; all submodules are visited via model.modules().
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      # affine=False BatchNorm has no weight/bias to initialize; skip it.
      nn.init.normal_(m.weight, 1.0, 0.02)
      nn.init.constant_(m.bias, 0.0)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
SEED = 0                 # fixed seed so weight init / noise / shuffling are reproducible
LEARNING_RATE = 2E-4     # Adam learning rate for both networks
BATCH_SIZE = 128
IMAGE_SIZE = 64          # Generator and Discriminator are built for 64x64 images
CHANNELS_IMG = 3
Z_DIM = 100              # latent noise channels
NUM_EPOCHS = 50
FEATURE_DISC = 64        # base width of the Discriminator
FEATURES_GEN = 64        # base width of the Generator

torch.manual_seed(SEED)

# Resize the shorter side to IMAGE_SIZE, then center-crop, so every sample is exactly
# IMAGE_SIZE x IMAGE_SIZE. The Discriminator only collapses a 64x64 input to 1x1,
# and DataLoader cannot batch images of differing shapes.
# Normalize with mean=std=0.5 maps [0, 1] to [-1, 1], the Generator's Tanh range.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
])
   


In [None]:
# Load FFHQ from disk. ImageFolder treats each sub-directory of the root as a class
# (images are expected under /content/FFHQ/images); the GAN ignores those labels.
dataset = datasets.ImageFolder(
    root='/content/FFHQ',
    transform=transform,
)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Models, optimizers, loss and logging for training.
# `dataset` / `loader` are built by the data-loading cell above; building them again
# here would re-scan the entire image folder for no benefit.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURE_DISC).to(device)

initialize_weights(gen)
initialize_weights(disc)

# Same Adam settings (beta1 = 0.5) for both networks.
optim_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optim_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# The Discriminator ends in Sigmoid, so BCE on probabilities is the matching loss.
criterion = nn.BCELoss()

# A fixed batch of 32 latent vectors, reused at every logging step so TensorBoard
# shows how the same noise evolves over training.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

step = 0  # TensorBoard global_step, advanced once per logged batch

gen.train()
disc.train()


In [None]:
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.to(device)
    # Size the noise by the actual batch: the last batch of an epoch is smaller than
    # BATCH_SIZE whenever len(dataset) is not a multiple of it.
    noise = torch.randn((real.size(0), Z_DIM, 1, 1)).to(device)
    fake = gen(noise)

    # --- Discriminator step: BCE with target 1 on real images, 0 on generated ones ---
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach(): this loss must not backprop into the generator. It also means no
    # retain_graph is needed; the generator graph stays alive through `fake` below.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_fake + loss_disc_real) / 2

    disc.zero_grad()
    loss_disc.backward()
    optim_disc.step()

    # --- Generator step: non-saturating loss, BCE(D(G(z)), 1) through the updated D ---
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    optim_gen.step()

    if batch_idx % 100 == 0:
      # Adjacent f-strings concatenate without the stray indentation that a backslash
      # continuation inside one f-string would embed in the message.
      print(
          f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {loss_disc.item():.4f},loss G: {loss_gen.item():.4f} "
      )

      with torch.no_grad():
        # NOTE(review): gen is still in train mode, so this forward pass also updates
        # its BatchNorm running statistics.
        fake_fixed = gen(fixed_noise)

        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)
        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

      step += 1



Epoch [0/50] Batch 0/547       Loss D: 0.5843,loss G: 2.1252 
Epoch [0/50] Batch 100/547       Loss D: 0.5703,loss G: 2.1670 
Epoch [0/50] Batch 200/547       Loss D: 0.5709,loss G: 1.1170 
Epoch [0/50] Batch 300/547       Loss D: 0.4873,loss G: 1.2094 
Epoch [0/50] Batch 400/547       Loss D: 0.4704,loss G: 1.5430 
Epoch [0/50] Batch 500/547       Loss D: 0.4461,loss G: 1.5401 
Epoch [1/50] Batch 0/547       Loss D: 0.4398,loss G: 1.6066 
Epoch [1/50] Batch 100/547       Loss D: 0.4329,loss G: 1.4726 
Epoch [1/50] Batch 200/547       Loss D: 0.7805,loss G: 2.3341 
Epoch [1/50] Batch 300/547       Loss D: 0.4474,loss G: 2.0663 
Epoch [1/50] Batch 400/547       Loss D: 0.3992,loss G: 1.6029 
Epoch [1/50] Batch 500/547       Loss D: 0.6513,loss G: 1.2738 
Epoch [2/50] Batch 0/547       Loss D: 0.4011,loss G: 2.2950 
Epoch [2/50] Batch 100/547       Loss D: 0.3788,loss G: 1.3448 
Epoch [2/50] Batch 200/547       Loss D: 0.3695,loss G: 2.1263 
Epoch [2/50] Batch 300/547       Loss D: 0.552

In [None]:
# Show TensorBoard inline. The SummaryWriters log to the relative paths "logs/real" and
# "logs/fake"; this assumes the working directory is /content (the Colab default, also set
# by the `%cd /content` cell) — adjust --logdir if it differs.
%tensorboard --logdir=/content/logs/

In [None]:
!mkdir FFHQ/images

In [None]:
# Download an FFHQ copy from Kaggle (produces flickrfaceshq-dataset-ffhq.zip).
# NOTE(review): the unzip cell below extracts ffhq-face-data-set.zip instead, which comes
# from the greatgamedota download cell, so this archive appears unused. Confirm which
# dataset is intended and drop the other download.
!kaggle datasets download -d arnaud58/flickrfaceshq-dataset-ffhq

In [None]:
# Extract the downloaded archive into the current working directory.
# NOTE(review): training reads ImageFolder(root='/content/FFHQ'), i.e. images under
# /content/FFHQ/<subfolder>/ — confirm the archive's top-level folder matches that layout.
!unzip /content/ffhq-face-data-set.zip

In [None]:
import os

# Sanity-check the extracted images: report how many files there are and show one name.
# Guarding the empty case avoids printing a stale/undefined loop variable.
image_files = os.listdir('/content/FFHQ/images')
print(len(image_files), image_files[-1] if image_files else '<no files>')

In [None]:
import os
import shutil

# Remove previous TensorBoard runs. Guarded so the cell is a no-op instead of raising
# FileNotFoundError when there is nothing to clear (e.g. on a fresh runtime).
if os.path.isdir('/content/logs'):
  shutil.rmtree('/content/logs')

In [None]:
# Download the FFHQ archive that the unzip cell extracts (saved as ffhq-face-data-set.zip).
# Requires Kaggle API credentials to be configured in the runtime.
!kaggle datasets download -d greatgamedota/ffhq-face-data-set

Downloading ffhq-face-data-set.zip to /content
100% 1.96G/1.97G [00:24<00:00, 98.7MB/s]
100% 1.97G/1.97G [00:24<00:00, 85.1MB/s]


In [None]:
# Work from /content so relative paths (e.g. the "logs/..." SummaryWriter dirs) resolve there.
%cd /content

/content


In [None]:
!lss