<a href="https://colab.research.google.com/github/uvais-6/Generative-AI/blob/main/Defining_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Generative Adversarial Networks

### Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Select the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Defining a transform

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then shift/scale
# every one of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

### Load the Dataset

In [None]:
# CIFAR-10 training split, downloaded into ./data if it is not already there,
# with the preprocessing pipeline applied to every image.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 samples.
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)

Files already downloaded and verified


In [None]:
# Hyperparameters
# Model / schedule sizes.
latent_dim, num_epochs = 100, 10

# Learning rate and the two beta coefficients.
lr = 2e-4
beta1, beta2 = 0.5, 0.999

### Define the Generator

In [None]:
class Generator(nn.Module):
    """Map latent noise vectors to 3-channel 32x32 images in [-1, 1].

    The network projects the latent vector to a 128-channel 8x8 feature map
    and then doubles the spatial resolution twice (8 -> 16 -> 32) while
    reducing the channel count (128 -> 128 -> 64 -> 3).

    Args:
        latent_dim: Size of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            # Project the noise vector and reshape it to (128, 8, 8).
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            # Stage 1: 8x8 -> 16x16, 128 channels.
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.78),
            nn.ReLU(),
            # Stage 2: 16x16 -> 32x32, 128 -> 64 channels.  Without this
            # stage the final convolution (which expects 64 input channels)
            # would receive 128 channels and the output would be only 16x16.
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            # Output head: 64 -> 3 channels, Tanh bounds values to [-1, 1].
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from a batch of latent vectors.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 32, 32).
        """
        img = self.model(z)
        return img

### Define the Discriminator

In [None]:
class Discriminator(nn.Module):
    """Score 3-channel 32x32 images with a value in (0, 1).

    Spatial size through the network for a 32x32 input:
    32 -> 16 (conv, stride 2) -> 17 (zero pad right/bottom) -> 9 (conv,
    stride 2) -> 5 (conv, stride 2) -> 5 (conv, stride 1), giving the
    256 * 5 * 5 features consumed by the final linear layer.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # Block 1: 3 -> 32 channels, 32x32 -> 16x16, padded to 17x17.
            nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=2),
            nn.ZeroPad2d((0, 1, 0, 1)),
            # BatchNorm width must equal the preceding conv's out_channels.
            nn.BatchNorm2d(32, momentum=0.8),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # Block 2: 32 -> 64 channels, 17x17 -> 9x9.
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # Block 3: 64 -> 128 channels, 9x9 -> 5x5.
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # Block 4: 128 -> 256 channels.  Stride 1 keeps the 5x5 map so
            # the flattened size matches the 256 * 5 * 5 linear layer below.
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # Classifier head: one sigmoid-activated score per image.
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the validity score for each image in the batch.

        Args:
            img: Tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        validity = self.model(img)
        return validity