In [3]:
%run generator.ipynb
%run discriminator.ipynb
import torch
import numpy as np
import pandas as pd 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
class Gan(nn.Module):

    """
        Class to represent the Generative Adversarial Network (GAN) structure.

        Both networks are optimized with plain SGD against binary cross-entropy
        (nn.BCELoss), so the discriminator is expected to output probabilities
        in [0, 1] (e.g. through a final sigmoid).

        Attributes:
            generator (nn.Module): Network mapping noise vectors of size
                ``noise`` to synthetic samples.
            discriminator (nn.Module): Network scoring samples; trained to
                output 1 for real data and 0 for generated data.
            lr_g (float): Learning rate for the generator.
            lr_d (float): Learning rate for the discriminator.
            optimize_generator (torch.optim.SGD): Optimizer for the generator.
            optimize_discriminator (torch.optim.SGD): Optimizer for the discriminator.
            criterion_d (nn.BCELoss): Loss function for the discriminator.
            criterion_g (nn.BCELoss): Loss function for the generator.
            noise (int): Dimension of the noise vector fed to the generator.
            epochs (int): Number of passes over the training data.
    """

    def __init__(self, generator, discriminator, lr_g=0.01, lr_d=0.01, noise=64, epochs=5):

        super().__init__()

        self.generator = generator
        self.discriminator = discriminator
        self.noise = noise
        self.epochs = epochs

        self.lr_g = lr_g
        self.lr_d = lr_d

        self.optimize_generator = torch.optim.SGD(self.generator.parameters(), lr=self.lr_g)
        self.optimize_discriminator = torch.optim.SGD(self.discriminator.parameters(), lr=self.lr_d)

        self.criterion_d = nn.BCELoss()
        self.criterion_g = nn.BCELoss()

    @staticmethod
    def _set_requires_grad(module, flag):
        """Set ``requires_grad`` on every parameter of *module* to *flag*."""
        for param in module.parameters():
            param.requires_grad = flag

    @staticmethod
    def _features(batch):
        """Return the input tensor of a batch, ignoring any labels that follow it."""
        if isinstance(batch, (list, tuple)):
            return batch[0]
        return batch

    def _discriminator_step(self, X):
        """Run one optimizer step on the discriminator using real batch X and fresh fakes."""
        # The generator needs no gradients here; freezing it avoids building its graph.
        self._set_requires_grad(self.generator, False)
        self._set_requires_grad(self.discriminator, True)

        self.optimize_discriminator.zero_grad()

        noise_dim = torch.randn(X.size(0), self.noise, device=X.device)

        discriminator_real = self.discriminator(X)
        # *_like keeps labels matched to the output's shape, dtype and device.
        loss_real = self.criterion_d(discriminator_real, torch.ones_like(discriminator_real))

        fake_ = self.generator(noise_dim)
        disc_fake = self.discriminator(fake_.detach())
        loss_fake = self.criterion_d(disc_fake, torch.zeros_like(disc_fake))

        # Mean of the real-sample and fake-sample losses.
        total_disc_loss = (loss_real + loss_fake) / 2
        total_disc_loss.backward()
        self.optimize_discriminator.step()

    def _generator_step(self, batch_size, device):
        """Run one optimizer step on the generator so its fakes are scored as real."""
        # Freeze the discriminator so only the generator is updated.
        self._set_requires_grad(self.generator, True)
        self._set_requires_grad(self.discriminator, False)

        self.optimize_generator.zero_grad()

        noise_dim = torch.randn(batch_size, self.noise, device=device)
        disc_pred = self.discriminator(self.generator(noise_dim))
        # Non-saturating GAN loss: the generator is rewarded when fakes are labelled 1.
        loss_gen = self.criterion_g(disc_pred, torch.ones_like(disc_pred))

        loss_gen.backward()
        self.optimize_generator.step()

    def train(self, train_data_loader=True):

        """
            Trains the GAN on the given data loader, or switches the module mode.

            This method overrides ``nn.Module.train(mode)``, and ``nn.Module.eval()``
            is implemented as ``self.train(False)``. So a boolean argument (or no
            argument) only sets train/eval mode on the GAN and its sub-networks,
            exactly like ``nn.Module.train``, and returns ``self``.

            Parameters:
                train_data_loader (iterable or bool): Iterable of batches. Each
                    batch is either an input tensor or a sequence such as
                    ``(X, y)`` whose first element is the input; labels are
                    ignored. A bool selects the module mode instead of training.

            Returns:
                nn.Module: The trained generator (``self`` when given a bool).
        """
        if isinstance(train_data_loader, bool):
            return super().train(train_data_loader)

        # Re-enable training mode in case eval() or creating_new() switched it off.
        super().train(True)

        # The step helpers toggle requires_grad; remember the caller's flags so
        # the networks are not left frozen once training finishes.
        saved_flags = [(param, param.requires_grad) for param in self.parameters()]
        try:
            for _ in range(self.epochs):
                for batch in train_data_loader:
                    X = self._features(batch)
                    self._discriminator_step(X)
                    self._generator_step(X.size(0), X.device)
        finally:
            for param, flag in saved_flags:
                param.requires_grad = flag

        return self.generator

    def creating_new(self, samples_needed):

        """
            Creates new samples using the trained generator.

            Args:
                samples_needed (int): Number of synthetic samples to generate,
                    e.g. to offset a class imbalance.

            Returns:
                new_samples (numpy.ndarray): Generator output for
                    ``samples_needed`` noise vectors (first axis indexes samples,
                    assuming the generator preserves the batch dimension).
        """
        # Restore the generator's mode afterwards so later training is unaffected.
        was_training = self.generator.training
        first_param = next(self.generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        self.generator.eval()
        try:
            with torch.no_grad():
                noise_generate = torch.randn(samples_needed, self.noise, device=device)
                new_samples = self.generator(noise_generate)
        finally:
            self.generator.train(was_training)

        # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU.
        return new_samples.cpu().numpy()
        

## References

1- Medium Data Science. *Fraud Detection with Generative Adversarial Nets (GANs)*. Accessed on May 14, 2025, from https://medium.com/data-science/fraud-detection-with-generative-adversarial-nets-gans-26bea360870d

2- Me, W. *Building a simple GAN model*. Accessed on May 14, 2025, from https://medium.com/@wasuratme96/building-a-simple-gan-model-9bfea22c651f

3- JakeTae GitHub. *PyTorch-GAN tutorial*. Accessed on May 10, 2025, from https://jaketae.github.io/study/pytorch-gan/

4- Stifi, M. *How GANs generate new data: a step-by-step guide with sine waves*. Accessed on May 14, 2025, from https://mohamed-stifi.medium.com/how-gans-generate-new-data-a-step-by-step-guide-with-sine-waves-1c6aa4049357

5- Medium Data Science. *Conquer class-imbalanced dataset issues using GANs*. Accessed on May 14, 2025, from https://medium.com/data-science/conquer-class-imbalanced-dataset-issues-using-gans-2482b52593aa

6- Medium Prabhatzade. *Freezing layers and fine-tuning transformer models in PyTorch: a simple guide*. Accessed on May 13, 2025, from https://medium.com/@prabhatzade/freezing-layers-and-fine-tuning-transformer-models-in-pytorch-a-simple-guide-119cad0980c6