# GAN

A GAN can generate new content that appears original.

<img src="./img/gan.png" alt="gan.png" style="width: 600px;"/>

A Generative Adversarial Network (GAN) consists of two primary components:

* Generator, which creates synthetic text data from noise
* Discriminator, which distinguishes between real and generated text data

Here, noise is a random input vector, for example values sampled uniformly from [0, 1) with `torch.rand`. The Generator learns to map this noise to data that resembles real samples.

The two components are trained against each other (hence "adversarial"): the Generator learns to produce more convincing fakes while the Discriminator learns to detect them, ideally until the generated text is indistinguishable from real text.
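For reference, the standard GAN formulation writes this adversarial game as a minimax objective, where $D(x)$ is the Discriminator's estimated probability that $x$ is real and $G(z)$ is the Generator's output for a noise vector $z$:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The Discriminator maximizes this value by assigning high probability to real data and low probability to fakes. In practice, and in the training loop of this notebook, the Generator instead minimizes $-\log D(G(z))$ (binary cross-entropy against a target of 1), a common "non-saturating" variant that gives the Generator stronger gradients early in training.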


Create a model that generates credible reviews.

```python
import torch
import torch.nn as nn
```

```python
class Generator(nn.Module):
    def __init__(self, seq_length=100):
        super().__init__()
        # Maps a noise vector to a vector of the same length with values in (0, 1)
        self.model = nn.Sequential(
            nn.Linear(seq_length, seq_length),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, seq_length=100):
        super().__init__()
        # Maps an input vector to a single probability that it is real
        self.model = nn.Sequential(
            nn.Linear(seq_length, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)
```

The Generator has a single linear layer that maps its input to an output of the same dimension (`seq_length`), followed by a sigmoid function that squashes each value into the range (0, 1).

The Discriminator has a linear layer that maps its input to a single value, followed by a sigmoid function.

Its output is the estimated probability that the input data is real.
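To make the tensor shapes concrete, here is a minimal standalone sketch that rebuilds the same two architectures inline with `nn.Sequential` (so it runs without the classes above) and passes one noise vector through both networks, assuming the default `seq_length` of 100:

```python
import torch
import torch.nn as nn

seq_length = 100  # same default as the Generator/Discriminator classes

# Same layer structure as the classes, written inline so this sketch runs on its own
generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())

noise = torch.rand((1, seq_length))  # one noise vector, values in [0, 1)
fake = generator(noise)              # shape (1, 100), each value squashed into (0, 1)
prob_real = discriminator(fake)      # shape (1, 1): estimated probability that `fake` is real

print(fake.shape)       # torch.Size([1, 100])
print(prob_real.shape)  # torch.Size([1, 1])
```

The leading dimension of 1 is the batch dimension; both networks accept any batch size as long as the last dimension equals `seq_length`.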

```python
generator = Generator()
discriminator = Discriminator()
```

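```python
# Binary cross-entropy: compares predicted probabilities with 0/1 targets
loss_function = nn.BCELoss()

# Separate optimizers so each network only updates its own parameters
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=0.001)
```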
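Both networks are trained with `nn.BCELoss`. As a small sanity check (the prediction and target values here are made up for illustration), this sketch compares the built-in loss with the binary cross-entropy formula $-[y \log p + (1 - y)\log(1 - p)]$, which `BCELoss` averages over elements by default:

```python
import math
import torch
import torch.nn as nn

loss_function = nn.BCELoss()

# Illustrative Discriminator outputs (probabilities) and labels: 1 = real, 0 = fake
pred = torch.tensor([0.9, 0.2])
target = torch.tensor([1.0, 0.0])

loss = loss_function(pred, target)

# Manual BCE: the real sample contributes -log(0.9), the fake one -log(1 - 0.2);
# BCELoss uses reduction="mean" by default, so the two terms are averaged
manual = -(math.log(0.9) + math.log(1 - 0.2)) / 2

print(loss.item(), manual)
```

Confident, correct predictions (0.9 for a real sample, 0.2 for a fake one) give a small loss; the loss grows quickly as a prediction moves toward the wrong label.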

## Training the Discriminator

In the training loop, we take real samples from the data loader and feed random noise to the Generator to create fake data.

We obtain the Discriminator's predictions for both real and fake data. Calling `detach()` on the fake data cuts it off from the Generator's computation graph, so the Discriminator's loss does not backpropagate into the Generator.

The Discriminator loss uses `torch.ones_like` and `torch.zeros_like` to build target labels of 1 for real data and 0 for fake data.

We reset the gradients in the Discriminator's optimizer with `zero_grad()`, backpropagate with `backward()` to compute gradients, and update the Discriminator's parameters with `step()`.
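The following standalone sketch isolates one Discriminator step to show the effect of `detach()`. It assumes a single placeholder "real" sample drawn with `torch.rand` (a stand-in for an encoded review) and uses inline copies of the two architectures:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_length = 100

generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())
loss_function = nn.BCELoss()
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=0.001)

real_data = torch.rand((1, seq_length))  # placeholder for one real, encoded review
fake = generator(torch.rand((1, seq_length)))

disc_real = discriminator(real_data)
disc_fake = discriminator(fake.detach())  # detach: no gradient flows back into the Generator

# Real samples should be classified as 1, fakes as 0
loss_disc = (loss_function(disc_real, torch.ones_like(disc_real))
             + loss_function(disc_fake, torch.zeros_like(disc_fake)))

optim_dis.zero_grad()
loss_disc.backward()
optim_dis.step()

# The Generator received no gradients from this update
print(all(p.grad is None for p in generator.parameters()))  # True
```

Without `detach()`, `loss_disc.backward()` would also fill the Generator's `.grad` fields with gradients that push it in the wrong direction (helping the Discriminator rather than fooling it).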

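## Training the Generator

We calculate the Generator's loss based on how well it fooled the Discriminator.

The loss is the binary cross-entropy between the Discriminator's predictions on the fake data and a tensor of ones, so it is low when the Discriminator classifies the fakes as real. This time the fake data is not detached, so gradients flow through the Discriminator back into the Generator.

We then reset the gradients in the Generator's optimizer, backpropagate to compute gradients, and update the Generator's parameters. Only `optim_gen` takes a step here, so this update does not change the Discriminator's weights.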
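A minimal standalone sketch of one Generator step, using inline copies of the two architectures, confirms that only the Generator's parameters move when `optim_gen.step()` runs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_length = 100

generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())
loss_function = nn.BCELoss()
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.001)

# Snapshot the weights (detached copies) before the update
gen_before = [p.detach().clone() for p in generator.parameters()]
disc_before = [p.detach().clone() for p in discriminator.parameters()]

fake = generator(torch.rand((1, seq_length)))  # no detach: gradients must reach the Generator
disc_fake = discriminator(fake)
loss_gen = loss_function(disc_fake, torch.ones_like(disc_fake))  # target 1 = "real"

optim_gen.zero_grad()
loss_gen.backward()  # also fills the Discriminator's .grad, but optim_gen ignores it
optim_gen.step()     # updates only the Generator's parameters

gen_changed = any(not torch.equal(a, b) for a, b in zip(gen_before, generator.parameters()))
disc_unchanged = all(torch.equal(a, b) for a, b in zip(disc_before, discriminator.parameters()))
print(gen_changed, disc_unchanged)  # True True
```

The Discriminator's leftover gradients from `loss_gen.backward()` are harmless in the training loop, because `optim_dis.zero_grad()` clears them before the next Discriminator update.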

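```python
# training loop
seq_length = 100

# Hypothetical placeholder data: replace with real reviews encoded as fixed-length
# vectors of values in [0, 1]. Iterating over this tensor yields rows of shape (seq_length,).
dataloader = torch.rand((1000, seq_length))

for epoch in range(100):
    for data in dataloader:
        real_data = data.unsqueeze(0)  # add a batch dimension: (1, seq_length)

        # Generate a fake sample from random noise
        noise = torch.rand((1, seq_length))
        fake = generator(noise)

        # Training Discriminator: real -> 1, fake -> 0
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake.detach())  # detach: don't update the Generator here

        loss_disc = (loss_function(disc_real, torch.ones_like(disc_real))
                     + loss_function(disc_fake, torch.zeros_like(disc_fake)))

        optim_dis.zero_grad()
        loss_disc.backward()
        optim_dis.step()

        # Training Generator: push the Discriminator to output 1 for fakes
        disc_fake = discriminator(fake)
        loss_gen = loss_function(disc_fake, torch.ones_like(disc_fake))

        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

    if (epoch + 1) % 10 == 0:
        print(f"epoch {epoch + 1}: gen loss {loss_gen.item():.4f}, disc loss {loss_disc.item():.4f}")
```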
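The loop above trains on one sample per step. A minimal batched sketch using PyTorch's `DataLoader` and `TensorDataset` follows; it assumes the real reviews are already encoded as fixed-length vectors in [0, 1], and `real_reviews` (random placeholder data) and `batch_size` are hypothetical names introduced here for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
seq_length = 100
batch_size = 32

# Hypothetical placeholder: 256 "reviews" already encoded as vectors of values in [0, 1]
real_reviews = torch.rand((256, seq_length))
dataloader = DataLoader(TensorDataset(real_reviews), batch_size=batch_size, shuffle=True)

generator = nn.Sequential(nn.Linear(seq_length, seq_length), nn.Sigmoid())
discriminator = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())
loss_function = nn.BCELoss()
optim_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optim_dis = torch.optim.Adam(discriminator.parameters(), lr=0.001)

for epoch in range(5):
    for (real_data,) in dataloader:  # each batch is a 1-element list; real_data: (B, seq_length)
        noise = torch.rand((real_data.size(0), seq_length))  # one noise vector per real sample
        fake = generator(noise)

        # Discriminator step: real -> 1, fake -> 0
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake.detach())
        loss_disc = (loss_function(disc_real, torch.ones_like(disc_real))
                     + loss_function(disc_fake, torch.zeros_like(disc_fake)))
        optim_dis.zero_grad()
        loss_disc.backward()
        optim_dis.step()

        # Generator step: try to make the Discriminator output 1 for fakes
        disc_fake = discriminator(fake)
        loss_gen = loss_function(disc_fake, torch.ones_like(disc_fake))
        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

    print(f"epoch {epoch + 1}: gen loss {loss_gen.item():.4f}, disc loss {loss_disc.item():.4f}")
```

Batching averages the loss over many samples per update, which usually gives smoother gradients and much faster epochs than updating after every single review.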