In [88]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Training hyper-parameters; only the batch size is configured so far.
params = {'batch-size': 128}

In [56]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent tensor to an image.

    Five stages, all with 4x4 kernels. For an (N, in_channels, 1, 1) input the
    spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, so the result is
    (N, out_channels, 64, 64), squashed into [-1, 1] by ``tanh``.

    Args:
        in_channels: number of channels of the latent input.
        out_channels: number of channels of the generated image.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def upsample_stage(c_in, c_out, stride, pad):
            # ConvTranspose -> BatchNorm -> ReLU (indices 0, 1, 2 inside the Sequential).
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        # 1x1 -> 4x4: stride 1, no padding.
        self.block1 = upsample_stage(in_channels, 1024, 1, 0)
        # Each following stage doubles height and width while halving channels.
        self.block2 = upsample_stage(1024, 512, 2, 1)
        self.block3 = upsample_stage(512, 256, 2, 1)
        self.block4 = upsample_stage(256, 128, 2, 1)
        # Output stage: no BatchNorm; tanh bounds the pixels.
        self.block5 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the latent batch ``x`` through all five stages in order."""
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for stage in stages:
            x = stage(x)
        return x

In [57]:
gen = Generator(in_channels=100, out_channels=3)  # 100-channel latent -> 3-channel image

In [61]:
x = torch.randn(128, 100, 1, 1)  # latent batch: 128 samples x 100 channels, 1x1 spatial

In [62]:
x.size()

torch.Size([128, 100, 1, 1])

In [63]:
# Shape probe only: skip building an autograd graph.
# NOTE(review): gen is still in training mode, so its BatchNorm layers update
# their running statistics during this call.
with torch.no_grad():
    y = gen(x)
assert y.shape == (x.shape[0], 3, 64, 64), y.shape

In [64]:
y.size()

torch.Size([128, 3, 64, 64])

In [89]:
class Discriminator(nn.Module):
    """Convolutional discriminator mapping an image batch to one score per sample.

    Blocks 1-4 are 4x4, stride-2 convolutions followed by LeakyReLU(0.2);
    blocks 2-4 also apply BatchNorm, block 1 does not. Each halves the spatial
    size (64 -> 32 -> 16 -> 8 -> 4 for a 64x64 input). Block 5 is a 4x4
    convolution with no padding that reduces 4x4 to 1x1. A (N, 3, 64, 64)
    input therefore yields an (N, 1, 1, 1) output squashed into [0, 1] by a
    sigmoid.

    Args:
        inchannels: currently unused -- the first convolution is hard-coded
            to 3 input channels.
        outchannels: currently unused -- the last convolution is hard-coded
            to 1 output channel.

    NOTE(review): the arguments are kept (and still ignored) for backward
    compatibility with existing calls such as ``Discriminator(1, 1)`` that
    are fed 3-channel images; honoring them would make those calls fail.
    """

    def __init__(self, inchannels, outchannels):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )

        # 4x4 kernel, stride 1, no padding: collapses the 4x4 map to 1x1.
        self.block5 = nn.Sequential(
            nn.Conv2d(512, 1, 4, 1, 0),
        )

    def forward(self, x):
        """Return sigmoid scores; shape (N, 1, 1, 1) for (N, 3, 64, 64) inputs."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        # torch.sigmoid is the canonical tensor op; older PyTorch releases warn
        # that nn.functional.sigmoid is deprecated.
        # NOTE(review): if training uses BCE, returning logits and using
        # BCEWithLogitsLoss is numerically safer -- TODO confirm with the loss code.
        return torch.sigmoid(x)

In [90]:
img = torch.randn(128, 3, 64, 64)  # random stand-in for a batch of 3-channel 64x64 images

In [91]:
# The images are 3-channel and the score is 1-channel. Discriminator currently
# ignores these arguments (hard-coded to 3 in / 1 out), but pass matching values
# so the call states the real channel counts.
dis = Discriminator(3, 1)

In [92]:
# Shape probe only: skip building an autograd graph.
with torch.no_grad():
    y = dis(img)
assert y.shape == (img.shape[0], 1, 1, 1), y.shape
y.shape

torch.Size([128, 1, 1, 1])