
GANs with f-Divergences and DGflow Refinement

This project explores the impact of different f-divergences and sampling strategies on the quality and diversity of samples generated by Generative Adversarial Networks (GANs).

In particular, we study f-GAN variants (Jensen–Shannon, Kullback–Leibler, Pearson χ²) and evaluate the effect of post-hoc refinement using DGflow.
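Each f-divergence corresponds to an output activation g_f and a Fenchel conjugate f* in the variational objective of Nowozin et al. (2016). The sketch below shows the pairs for the three divergences studied here; it uses NumPy for illustration and the function names are illustrative, not this repository's actual API (see fgan_utils.py for the real implementation):

```python
import numpy as np

# Output activations g_f and Fenchel conjugates f* (Nowozin et al., 2016).
# The variational objective is  F = E_p[g_f(V(x))] - E_q[f*(g_f(V(x)))],
# maximized over the discriminator V and minimized over the generator.
F_DIVERGENCES = {
    "kl": (lambda v: v,
           lambda t: np.exp(t - 1.0)),
    "js": (lambda v: np.log(2.0) - np.logaddexp(0.0, -v),  # log(2 sigmoid(v))
           lambda t: -np.log(2.0 - np.exp(t))),
    "pearson": (lambda v: v,
                lambda t: 0.25 * t ** 2 + t),
}

def fgan_losses(v_real, v_fake, divergence="js"):
    """Losses for one step, given raw discriminator outputs on real/fake data."""
    g, f_star = F_DIVERGENCES[divergence]
    # Discriminator ascends F, so its loss is -F.
    d_loss = -(np.mean(g(v_real)) - np.mean(f_star(g(v_fake))))
    # Generator descends F; only the fake-sample term depends on it.
    g_loss = -np.mean(f_star(g(v_fake)))
    return d_loss, g_loss
```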


📌 Project Overview

  • Use of a baseline GAN for comparison
  • Implementation of f-GAN with different divergences
  • Study of sampling strategies:
    • Normal sampling
    • Soft truncation
    • Hard truncation
    • DGflow refinement
  • Evaluation using:
    • FID
    • Precision / Recall
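The two truncation strategies above restrict or shrink latent vectors before they reach the generator. A minimal NumPy sketch (function name and signature are illustrative; the repo's actual implementation lives in sampling_utils.py, presumably on GPU tensors):

```python
import numpy as np

def sample_latents(n, dim, mode="normal", threshold=2.0, seed=None):
    """Draw n latent vectors from N(0, I), optionally truncated."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n, dim))
    if mode == "hard":
        # Resample any coordinate whose magnitude exceeds the threshold.
        mask = np.abs(z) > threshold
        while mask.any():
            z[mask] = rng.standard_normal(mask.sum())
            mask = np.abs(z) > threshold
    elif mode == "soft":
        # Smoothly squash latents so they stay within (-threshold, threshold).
        z = threshold * np.tanh(z / threshold)
    return z
```

Hard truncation trades diversity for fidelity by discarding rare latents; soft truncation shrinks them instead, which degrades diversity more gradually.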

📂 Repository Structure

.
├── checkpoints/              # Saved models (keep minimal)
├── data/                     # Dataset
├── samples/                  # Generated samples
├── model.py                  # Generator & Discriminator architectures
├── train.py                  # Baseline GAN training
├── train_fgan.py             # f-GAN training
├── generate.py               # Sample generation
├── sampling_utils.py         # Sampling methods (normal, truncation, DGflow)
├── fgan_utils.py             # f-divergence functions
├── metrics.py                # Evaluation metrics (FID, Precision, Recall)
├── evaluate_all.py           # Evaluation pipeline
├── utils.py                  # Utility functions
├── select_10img.py           # Sample selection utility
├── train_feature_extractor.py # Feature extractor for metrics
├── requirements.txt          # Dependencies
├── report.pdf                # Project report
├── slides.pdf                # Presentation slides
└── README.md

⚙️ Setup

1. Clone the repository

git clone <your-repo-url>
cd GAN

2. Set Up Your Environment

On Juliet (MesoNet's cluster), you need to:

  1. Create a virtual environment for Python:

    python -m venv venv
  2. Activate the environment:

    source venv/bin/activate
  3. Install the required dependencies:

    pip install -r requirements.txt

📦 requirements.txt

In line with data-science good practice, we encourage using conda or virtualenv to create an isolated Python environment.

To test your code on our platform, you must keep requirements.txt up to date with every library you use.

When your code is evaluated, the following command will be executed:

pip install -r requirements.txt

🚀 Usage

Train baseline GAN

python train.py

Train f-GAN

python train_fgan.py

Generate samples

python generate.py

Evaluate models

python evaluate_all.py
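The evaluation pipeline reports FID, which compares Gaussian fits to real and generated feature distributions: FID = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^½). A minimal NumPy sketch, assuming features have already been extracted (the repo's actual metric code is in metrics.py):

```python
import numpy as np

def _sqrtm_psd(m):
    """Matrix square root of a symmetric positive semi-definite matrix."""
    w, v = np.linalg.eigh(m)
    w = np.clip(w, 0.0, None)  # guard against tiny negative eigenvalues
    return (v * np.sqrt(w)) @ v.T

def fid(feats_real, feats_fake):
    """Frechet distance between Gaussian fits to two feature sets."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_fake, rowvar=False)
    # Tr((s1 s2)^1/2) computed via the similar symmetric matrix r s2 r.
    r = _sqrtm_psd(s1)
    covmean = _sqrtm_psd(r @ s2 @ r)
    return float(np.sum((mu1 - mu2) ** 2)
                 + np.trace(s1 + s2 - 2.0 * covmean))
```

Identical feature sets give FID ≈ 0; shifting one distribution increases it.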

🧠 DGflow Refinement

DGflow improves generated samples by refining latent vectors using discriminator gradients.

Implemented in:

sampling_utils.py

Key features:

  • Sample-specific refinement
  • No retraining of the generator
  • Step size adaptation depending on divergence
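The refinement loop can be sketched as a discretized gradient flow that ascends an estimated log density ratio (derived from the discriminator) with a small diffusion term, in the spirit of Ansari et al. (2021). This is a NumPy illustration, not the repo's implementation; grad_log_ratio stands in for the discriminator-based gradient:

```python
import numpy as np

def dgflow_refine(x, grad_log_ratio, n_steps=25, step_size=0.1,
                  noise=0.01, seed=None):
    """Refine samples x by Euler steps along an estimated gradient of
    log(p_data / p_gen), plus Gaussian diffusion noise."""
    rng = np.random.default_rng(seed)
    for _ in range(n_steps):
        x = (x
             + step_size * grad_log_ratio(x)
             + np.sqrt(2.0 * step_size * noise)
             * rng.standard_normal(x.shape))
    return x
```

With a well-behaved gradient estimate, refined samples drift toward high-density regions of the data distribution, which is why no generator retraining is needed.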

📊 Results

  • DGflow improves FID across JS, KL, and Pearson divergences
  • JS and KL provide more stable gradients and better performance
  • Pearson divergence is more sensitive and requires careful tuning
  • Truncation methods reduce diversity compared to DGflow

See:

  • report.pdf
  • slides.pdf

💾 Checkpoints

Push only a minimal number of models to the checkpoints/ folder.


📎 Notes

  • Large batch sizes are recommended for reliable evaluation metrics
  • Results may vary depending on hyperparameters and dataset
  • DGflow requires careful tuning of the step size

📚 References

  • Goodfellow et al., Generative Adversarial Networks, 2014
  • Nowozin et al., f-GAN, 2016
  • Ansari et al., DGflow, 2021

📜 License

This project is released under the MIT License.
