In [12]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

class GAN_Dis(nn.Module):
    def __init__(self):
        super(GAN_Dis, self).__init__()
        self.fc1 = nn.Conv2d(3, 64, 3, 2)
        self.fc2 = nn.BatchNorm2d(64, 31, 31)
        
        self.fc3 = nn.Conv2d(64, 128, 3, 2)
        self.fc4 = nn.BatchNorm2d(128, 15, 15)
        
        self.fc5 = nn.Conv2d(128, 64, 3, 2)
        self.fc6 = nn.BatchNorm2d(64, 7, 7)
        
        self.fc7 = nn.Conv2d(64, 64, 3, 2)
        self.fc8 = nn.BatchNorm2d(64, 3, 3)
        
        self.fc9 = nn.Conv2d(64, 1, 3, 2)
        
        self.relu = nn.LeakyReLU()
        self.output = nn.Sigmoid()

    def forward(self, x):
        x = x.view(-1, 3, 64, 64)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        x = self.fc6(x)
        x = self.relu(x)
        x = self.fc7(x)
        x = self.fc8(x)
        x = self.relu(x)
        x = self.fc9(x)
        x = self.output(x)
        return x.view(-1)

        
class GAN_Gen(nn.Module):
    def __init__(self):
        super(GAN_Gen, self).__init__()
        self.fc1 = nn.Linear(128, 1024*4*4)
        self.fc2 = nn.BatchNorm2d(1024, 4, 4)
        
        self.fc3 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)
        self.fc4 = nn.BatchNorm2d(512, 8, 8)
        
        self.fc5 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.fc6 = nn.BatchNorm2d(256, 16, 16)
        
        self.fc7 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.fc8 = nn.BatchNorm2d(128, 32, 32)
        
        self.fc9 = nn.ConvTranspose2d(128, 3, 4, 2, 1)
        self.fc10 = nn.BatchNorm2d(3, 64, 64)
        
        self.relu = nn.ReLU()
        self.output = nn.Sigmoid()
        
    def forward(self, x):
        x = self.fc1(x)
        x = x.view(-1, 1024, 4, 4)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        x = self.fc6(x)
        x = self.relu(x)
        x = self.fc7(x)
        x = self.fc8(x)
        x = self.relu(x)
        x = self.fc9(x)
        x = self.fc10(x)
        x = self.output(x)
        return x
    
    def generate_image(self, save_path):
        with torch.no_grad():
            seed = torch.rand(128)
            image = self.forward(seed)
            image = F.to_pil_image(image[0])
            image.save(save_path)        

In [13]:
# Tests
a = GAN_Gen()
b = GAN_Dis()
i = torch.rand(2, 128)
c = a(i)
#print(c.shape)
#print(b(c))

torch.Size([2, 3, 64, 64])
tensor([0.5085, 0.5085], grad_fn=<ViewBackward0>)
