<a href="https://colab.research.google.com/github/sandeshar/google/blob/main/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torch.cuda as cuda
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
!mkdir -p ~/.kaggle
!mkdir -p ~/images
!mkdir -p ~/checkpoint
!cp kaggle.json ~/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json
!kaggle datasets download -d karnikakapoor/art-portraits
!unzip art-portraits.zip -d images

Downloading art-portraits.zip to /content
100% 1.30G/1.30G [00:11<00:00, 144MB/s]
100% 1.30G/1.30G [00:11<00:00, 124MB/s]
Archive:  art-portraits.zip
  inflating: images/Portraits/000c6828b825f032af6047b46eba2686c.jpg  
  inflating: images/Portraits/0010cbc73014ac5e7ac81fd44eff1f3dc.jpg  
  inflating: images/Portraits/004b5f7cc82dadaa51dbb3b2230b5f85c.jpg  
  inflating: images/Portraits/004d60b7e881eb08966f711ce80523ecc.jpg  
  inflating: images/Portraits/007c5bf3a436793544a83c4a73c5cb4fc.jpg  
  inflating: images/Portraits/007f332f33bd1a8541912ca2b1701252c.jpg  
  inflating: images/Portraits/009c616c4a6415c96f795aa920dc2e85c.jpg  
  inflating: images/Portraits/00afb8e719aa2ea716a5b6a54c5c55fbc.jpg  
  inflating: images/Portraits/00bd05a5d525f451228196e47d51e243c.jpg  
  inflating: images/Portraits/00c775299a9b11d6a4d310a1464d7493c.jpg  
  inflating: images/Portraits/00ca56f16c0bae52185ea31f95f0484cc.jpg  
  inflating: images/Portraits/00d643034afe01ab875b817dc5de3af5c.jpg  
  inflatin

In [None]:
from torchvision.utils import make_grid
bs = 64  # mini-batch size used by the DataLoader

# (mean, std) per RGB channel. These must match the transforms.Normalize values
# used for the dataset; all three channels share the same 0.5 / 0.5 here.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
def denorm(img_tensors):
    """Invert the dataset normalization, mapping [-1, 1] values back to [0, 1].

    Works on anything supporting ``*`` and ``+`` with a float (tensors or plain
    numbers). It uses the first channel's mean/std because all channels share
    the same statistics.
    """
    mean = stats[0][0]
    std = stats[1][0]
    return img_tensors * std + mean

def show_images(images, nmax=64):
    """Plot up to ``nmax`` images from a batch as an 8-column grid.

    Args:
        images: image batch in the normalized [-1, 1] range, e.g. dataset
            samples or generator (Tanh) output. Assumed to be (N, C, H, W),
            as accepted by ``make_grid``; TODO confirm for other callers.
            It may live on any device and may require grad.
        nmax: maximum number of images to draw.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    grid = make_grid(denorm(images.detach().cpu()[:nmax]), nrow=8)
    # Float round-off can push denormalized values slightly outside [0, 1].
    # imshow would clip those with a warning, so clamp explicitly.
    ax.imshow(grid.clamp(0, 1).permute(1, 2, 0))
    plt.show()
    # This is called once per epoch from the training loop. Close the figure
    # so open figures do not accumulate.
    plt.close(fig)

def show_batch(dl, nmax=64):
    """Display the images of the first batch produced by ``dl``.

    Loaders over ``datasets.ImageFolder`` yield ``(images, labels)`` pairs, so
    when a batch is a tuple/list, its first element is taken as the images.
    Loaders that yield bare tensors are passed through unchanged. An empty
    loader displays nothing.
    """
    for batch in dl:
        images = batch[0] if isinstance(batch, (tuple, list)) else batch
        show_images(images, nmax)
        break

In [None]:
# Base channel width of the discriminator (the DCGAN paper's "ndf").
# It is kept independent of the batch size, so tuning the DataLoader batch size
# does not change the architecture. A width of 64 keeps the same layer shapes
# and state_dict keys, so existing checkpoints still load.
ndf = 64

# DCGAN discriminator for 3x64x64 inputs. Each stride-2 conv halves the spatial
# size (64 -> 32 -> 16 -> 8 -> 4). The final 4x4 unpadded conv reduces that to
# one value per image, which Sigmoid maps to a probability of being real.
# NOTE(review): the DCGAN paper omits BatchNorm on the discriminator's input
# layer. It is kept here because removing it would shift the Sequential indices
# and break loading existing checkpoints.
Discriminator = nn.Sequential(
     nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ndf),
     nn.LeakyReLU(0.2),

     nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ndf * 2),
     nn.LeakyReLU(0.2),

     nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ndf * 4),
     nn.LeakyReLU(0.2),

     nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ndf * 8),
     nn.LeakyReLU(0.2),

     nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0),
     nn.Sigmoid(),
)

In [None]:
# Base channel width of the generator (the DCGAN paper's "ngf").
# It is kept independent of the batch size, so tuning the DataLoader batch size
# does not change the architecture. A width of 64 keeps the same layer shapes
# and state_dict keys, so existing checkpoints still load.
ngf = 64

# DCGAN generator: maps a 100x1x1 latent vector to a 3x64x64 image.
# Spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, and Tanh bounds the output
# to [-1, 1], the same range as the normalized training images.
# NOTE(review): the DCGAN paper uses ReLU (not LeakyReLU) in the generator.
# The activation is unchanged so existing checkpoints keep their behavior.
Generator = nn.Sequential(
     nn.ConvTranspose2d(100, ngf * 8, kernel_size=4, stride=2, padding=0),
     nn.BatchNorm2d(ngf * 8),
     nn.LeakyReLU(0.2),

     nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ngf * 4),
     nn.LeakyReLU(0.2),

     nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ngf * 2),
     nn.LeakyReLU(0.2),

     nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1),
     nn.BatchNorm2d(ngf),
     nn.LeakyReLU(0.2),

     nn.ConvTranspose2d(ngf, 3, kernel_size=4, stride=2, padding=1),
     nn.Tanh(),
)

In [None]:
# Hyperparameters
device = "cuda" if cuda.is_available() else "cpu"

# Seed the global torch RNG, which affects later noise sampling and DataLoader
# shuffling. The model weights were already initialized when the model cells ran.
SEED = 0
torch.manual_seed(SEED)

# The DCGAN paper (Radford et al., 2015) found Adam's lr=0.001 too high and
# beta1=0.9 to cause oscillation; it uses lr=0.0002 with beta1=0.5.
# Note: optimizer.load_state_dict() restores the lr/betas stored in a checkpoint.
lr = 0.0002
betas = (0.5, 0.999)
epochs = 200
# BCELoss expects probabilities, which the Discriminator's final Sigmoid produces.
criterion = nn.BCELoss()

d = Discriminator.to(device)
g = Generator.to(device)

doptim = optim.Adam(d.parameters(), lr=lr, betas=betas)
goptim = optim.Adam(g.parameters(), lr=lr, betas=betas)

# Resize + center-crop to 64x64, then scale pixels from [0, 1] to [-1, 1] so real
# images match the range of the generator's Tanh output.
transfor = transforms.Compose([
         transforms.Resize(64),
         transforms.CenterCrop(64),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
data = datasets.ImageFolder('images', transform=transfor)
# drop_last skips the short final batch, so every training step sees a full
# `bs`-sized batch.
loader = DataLoader(data, batch_size=bs, num_workers=2, shuffle=True, drop_last=True)

In [None]:
# Train the DCGAN, optionally resuming from a checkpoint.
# NOTE(review): these paths are relative to the working directory. On Colab,
# drive/MyDrive presumably requires Google Drive to be mounted at /content/drive
# first; confirm before running, or checkpoints land on the ephemeral local disk.
CHECKPOINT_DIR = "drive/MyDrive/Model"
RESUME_PATH = os.path.join(CHECKPOINT_DIR, "model10.pt")  # set to None to start fresh
SAVE_EVERY = 10  # write a checkpoint after epochs divisible by this

start_epoch = 0
if RESUME_PATH and os.path.exists(RESUME_PATH):
    # map_location lets a checkpoint written on a GPU runtime load on CPU (and vice versa).
    checkpoint = torch.load(RESUME_PATH, map_location=device)
    g.load_state_dict(checkpoint['gsd'])
    d.load_state_dict(checkpoint['dsd'])
    goptim.load_state_dict(checkpoint['go'])
    doptim.load_state_dict(checkpoint['do'])
    # The stored epoch is already trained, so continue with the next one.
    # Restarting at 0 would also overwrite the earlier checkpoint files.
    start_epoch = checkpoint['epoch'] + 1
else:
    print(f"No checkpoint found at {RESUME_PATH}; training from scratch.")

# Read the latent size from the generator's first layer so the noise always matches it.
noise_dim = g[0].in_channels

g.train()
d.train()
for epoch in range(start_epoch, epochs):
  for idx, (real, _) in enumerate(loader):
    real = real.to(device)
    # Size the noise by the real batch so real and fake samples are paired 1:1,
    # even when the batch is short.
    noise = torch.randn(real.size(0), noise_dim, 1, 1, device=device)
    fake = g(noise)

    # Train discriminator. fake.detach() stops dloss from backpropagating into g,
    # so no retain_graph is needed and g's graph is intact for the generator step.
    df = d(fake.detach()).view(-1, 1, 1, 1)
    dlossF = criterion(df, torch.zeros_like(df))

    dr = d(real).view(-1, 1, 1, 1)
    dlossR = criterion(dr, torch.ones_like(dr))

    dloss = (dlossF + dlossR) / 2
    d.zero_grad()
    dloss.backward()
    doptim.step()

    # Train generator: score the same (non-detached) fakes with the just-updated d
    # and push them toward the "real" label.
    output = d(fake).view(-1, 1, 1, 1)
    gloss = criterion(output, torch.ones_like(output))
    g.zero_grad()
    gloss.backward()
    goptim.step()

    if idx == 0:
      show_images(fake)
      print(f"Epoch: {epoch}/{epochs} gloss: {gloss.item():.4f} dloss: {dloss.item():.4f}")

  # Save after the whole epoch, so the stored 'epoch' means "fully trained
  # through this epoch". That is the meaning the resume logic above relies on.
  if epoch % SAVE_EVERY == 0:
    state = {
      'epoch': epoch,
      'gsd': g.state_dict(),
      'dsd': d.state_dict(),
      'go': goptim.state_dict(),
      'do': doptim.state_dict(),
    }
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(state, os.path.join(CHECKPOINT_DIR, f"model{epoch}.pt"))