<a href="https://colab.research.google.com/github/rakshitroshan/Gan-from-scratch/blob/master/Untitled12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch


In [None]:
from torch import nn

In [None]:
from torchvision.models import vgg19

In [None]:
class Conv1(nn.Module):
    """Conv2d followed by an optional BatchNorm2d and an optional activation.

    Args:
        incom: number of input channels.
        output: number of output channels.
        discriminator: if True the activation is LeakyReLU(0.2, in-place);
            otherwise it is a PReLU with one learnable slope per output channel.
        use_act: whether ``forward`` applies the activation.
        use_bn: whether a BatchNorm2d follows the convolution; when True the
            convolution is created without a bias term.
        **kwargs: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(
        self,
        incom,
        output,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        # Bias is dropped when BN follows: BN's own affine shift makes it redundant.
        self.cnn = nn.Conv2d(incom, output, **kwargs, bias=not use_bn)
        if use_bn:
            self.bn = nn.BatchNorm2d(output)
        else:
            self.bn = nn.Identity()
        # The activation module is always constructed (even if use_act is False),
        # so a PReLU's parameters are still registered on the module.
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=output)

    def forward(self, x):
        features = self.bn(self.cnn(x))
        if not self.use_act:
            return features
        return self.act(features)



In [None]:
class Upsample(nn.Module):
    """Sub-pixel upsampling stage: 3x3 conv -> PixelShuffle -> PReLU.

    The convolution expands ``inc`` channels to ``inc * scale**2``; PixelShuffle
    then folds those extra channels into space, giving ``inc`` channels at
    ``scale`` times the input height and width.

    Args:
        inc: number of input (and output) channels.
        scale: spatial upscaling factor.
    """

    def __init__(self, inc, scale):
        super().__init__()
        expanded_channels = inc * scale ** 2
        self.conv = nn.Conv2d(inc, expanded_channels, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale)
        self.act = nn.PReLU(num_parameters=inc)

    def forward(self, x):
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)


In [None]:
class Residualblock(nn.Module):
    """Residual block: two 3x3 ``Conv1`` layers plus an identity skip.

    Both convolutions use stride 1 and padding 1, so channel count and spatial
    size are preserved and the input can be added back to the output. The
    first ``Conv1`` applies its activation; the second has ``use_act=False``.

    Args:
        incom: number of input (and output) channels.
    """

    def __init__(self, incom):
        super().__init__()
        same_size = {"kernel_size": 3, "stride": 1, "padding": 1}
        self.block1 = Conv1(incom, incom, **same_size)
        self.block2 = Conv1(incom, incom, use_act=False, **same_size)

    def forward(self, x):
        residual = self.block2(self.block1(x))
        return residual + x

In [None]:
class Generator(nn.Module):
    """Residual generator with two x2 sub-pixel upsampling stages.

    Pipeline: 9x9 ``Conv1`` (no BN) -> ``num_blocks`` residual blocks ->
    3x3 ``Conv1`` (BN, no activation) added to the first layer's output ->
    two ``Upsample(num_channels, 2)`` stages -> 9x9 conv back to ``incom``
    channels -> tanh.

    Args:
        incom: channel count of both the input and the output image.
        num_channels: width of the internal feature maps.
        num_blocks: number of residual blocks.

    NOTE(review): the Discriminator in this notebook defaults to ``incom=3``,
    while this default is 16; pass ``incom=3`` for RGB images — confirm the
    intended default.
    """

    def __init__(self, incom=16, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = Conv1(
            incom, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
        )
        self.residuals = nn.Sequential(
            *[Residualblock(num_channels) for _ in range(num_blocks)]
        )
        # Fixed: the output channel count was missing here, which made
        # Generator construction raise TypeError.
        self.convblock = Conv1(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False
        )
        # Each stage doubles height and width.
        self.upsamples = nn.Sequential(
            Upsample(num_channels, 2), Upsample(num_channels, 2)
        )
        self.final = nn.Conv2d(num_channels, incom, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        # Long skip connection around the whole residual trunk.
        x = self.convblock(x) + initial
        x = self.upsamples(x)
        return torch.tanh(self.final(x))


In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one unnormalised score per image.

    Each entry of ``features`` adds a 3x3 ``Conv1`` with LeakyReLU(0.2). The
    stride alternates 1, 2, 1, 2, ... by position, and BatchNorm is used on
    every layer except the first. A classifier head then pools to 6x6,
    flattens, and applies Linear -> LeakyReLU -> Linear(1024, 1). No sigmoid
    is applied, so the output is a raw logit (presumably paired with a
    logits-based loss such as BCEWithLogitsLoss — confirm with training code).

    Args:
        incom: number of input image channels.
        features: output channel count of each conv layer, in order.
    """

    def __init__(self, incom=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        in_channels = incom
        for idx, feature in enumerate(features):
            blocks.append(
                Conv1(
                    in_channels,
                    feature,  # was missing: Conv1 requires the output channel count
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,  # was misspelled "paddinf", rejected by nn.Conv2d
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature
        # Built once after the loop (previously rebuilt on every iteration).
        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Width follows the last conv layer instead of a hard-coded 512.
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    # Fixed: forward was previously nested inside __init__, so the module
    # had no forward method at all.
    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)
