GANForge

Python library for a wide variety of GANs (Generative Adversarial Networks) based on TensorFlow and Keras.

Table of Contents

  1. Installation
  2. Quick Start
  3. Examples
  4. Supported GANs
  5. Custom Callbacks

Installation

To install GANForge directly from its GitHub repository, run the following pip command in your command prompt or terminal:

pip install git+https://github.com/quadeer15sh/GANForge.git
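
To confirm that the install succeeded before running the examples, a small check like the one below can help. It is a minimal sketch using only the Python standard library; `is_installed` is a hypothetical helper name, and the import name `GANForge` is taken from the import statements used in this README.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be found on the import path."""
    return importlib.util.find_spec(package_name) is not None


# "GANForge" is the import name used in the examples in this README.
if not is_installed("GANForge"):
    print("GANForge was not found; re-run the pip command above.")
```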

Quick Start

You can get started with building GANs in just a few lines of code.

Example:

import tensorflow as tf
from GANForge.dcgan import DCGAN

# train_ds: your image dataset

model = DCGAN(input_shape=(64, 64, 3), latent_dim=128)
model.compile(d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              loss_fn=tf.keras.losses.BinaryCrossentropy())

model.fit(train_ds, epochs=25)
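
The example above leaves `train_ds` undefined. One possible way to build it is sketched below with `tf.data`, under some assumptions that this README does not confirm: that the model accepts batches of unlabeled images, and that pixels should be scaled to [-1, 1]. The helper name `make_train_ds`, the batch size, and the random placeholder images are all hypothetical; substitute your own data.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 32  # hypothetical batch size


def make_train_ds(images: np.ndarray, batch_size: int = BATCH_SIZE) -> tf.data.Dataset:
    """Turn a uint8 image array of shape (N, 64, 64, 3) into a batched dataset."""
    ds = tf.data.Dataset.from_tensor_slices(images)
    # Scale pixels from [0, 255] to [-1, 1]; this range is an assumption,
    # so check what the GAN you are using expects.
    ds = ds.map(lambda x: (tf.cast(x, tf.float32) - 127.5) / 127.5,
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.shuffle(1024).batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Random placeholder images stand in for a real dataset.
fake_images = np.random.randint(0, 256, size=(100, 64, 64, 3), dtype=np.uint8)
train_ds = make_train_ds(fake_images)
```

The image size matches the `input_shape=(64, 64, 3)` passed to `DCGAN` in the example above.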

Examples

Feel free to explore the notebook files, which cover each of the GAN models available in GANForge.

Supported GANs

Note: This list is updated frequently. Check back to see whether the GAN architecture you want to use has become available.

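| Sr. | GAN Architecture | Status      |
|-----|------------------|-------------|
| 1   | DC GAN           | Available   |
| 2   | Conditional GAN  | Available   |
| 3   | Info GAN         | In Progress |
| 4   | SR GAN           | Available   |
| 5   | ESR GAN          | In Progress |
| 6   | Pix2Pix GAN      | In Progress |
| 7   | Cycle GAN        | In Progress |
| 8   | Attention GAN    | In Progress |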

Custom Callbacks

The following custom callbacks are available for use during training:

| Sr. | Callback                    | Applicable GAN    |
|-----|-----------------------------|-------------------|
| 1   | DCGANVisualization          | DC GAN            |
| 2   | ConditionalGANVisualization | Conditional DCGAN |

Example:

import tensorflow as tf
from GANForge.dcgan import DCGAN
from GANForge.callbacks import DCGANVisualization

# train_ds: your image dataset

model = DCGAN(input_shape=(64, 64, 3), latent_dim=128)
model.compile(d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              loss_fn=tf.keras.losses.BinaryCrossentropy())
visualizer = DCGANVisualization(n_epochs=5)

model.fit(train_ds, epochs=25, callbacks=[visualizer])
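
For a GAN that has no callback in the table above, a standard Keras callback can be passed to `fit` in the same way. The sketch below is not part of GANForge: `EveryNEpochs` is a hypothetical class built on `tf.keras.callbacks.Callback`, and its `n_epochs` argument means "run once every n epochs", which may or may not match how `DCGANVisualization` interprets its own `n_epochs`.

```python
import tensorflow as tf


class EveryNEpochs(tf.keras.callbacks.Callback):
    """Hypothetical helper: call `action(epoch)` once every `n_epochs` epochs."""

    def __init__(self, action, n_epochs=5):
        super().__init__()
        self.action = action
        self.n_epochs = n_epochs

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a zero-based epoch index, so shift it to one-based.
        if (epoch + 1) % self.n_epochs == 0:
            self.action(epoch + 1)


# Simulate the calls that fit() makes at the end of each epoch.
seen = []
callback = EveryNEpochs(action=seen.append, n_epochs=5)
for epoch in range(10):
    callback.on_epoch_end(epoch)
```

In real training, `action` could save or plot generated samples, and the callback would be passed as `callbacks=[callback]`.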
