In [None]:
import os
import sys

def add_to_path(new_path: str):
    """Make modules under ``new_path`` importable.

    The path is resolved to an absolute path and appended to ``sys.path``
    unless that exact absolute path is already listed, so calling this
    repeatedly (e.g. when the cell is re-run) does not add duplicates.

    Args:
        new_path: Directory to expose to the import system; may be relative
            to the current working directory.
    """
    resolved = os.path.abspath(new_path)
    if resolved in sys.path:
        return
    sys.path.append(resolved)
    
is_colab = False

if is_colab:
    !git clone https://github.com/pdkary/Karys.git
    !cd Karys && git fetch && git pull
    !cd Karys && pip install -r requirements.txt --quiet
    add_to_path("Karys/")
    from google.colab import drive
    drive.mount("/content/drive")
else:
    !cd ../../ && pip install -r requirements.txt --quiet
    add_to_path("../../")

In [None]:
from data.configs.ImageDataConfig import ImageDataConfig
from data.configs.RandomDataConfig import RandomDataConfig
from data.wrappers.RandomDataWrapper import RandomDataWrapper
from data.wrappers.ImageDataWrapper import ImageDataWrapper

# Directory the training images are loaded from: a Drive-relative folder on
# Colab, a local folder next to the notebook otherwise.
image_path = "drive/MyDrive/Colab/Mons" if is_colab else "./test_input"

# 64x64 three-channel PNG images; preview grid of 2 rows by 3 columns.
image_config = ImageDataConfig(
    image_shape=(64, 64, 3),
    image_type=".png",
    preview_rows=2,
    preview_cols=3,
)
# NOTE(review): arguments are positional; the first is presumably the noise
# shape (it is read back as `random_config.shape` when the generator is built).
# The meaning of 0.0, 1.0 and 5000 is not visible here -- confirm against
# RandomDataConfig.
random_config = RandomDataConfig(
    [512],
    0.0,
    1.0,
    5000,
)

# Every image found under image_path gets the single label "REAL"; 25% is
# requested as validation data.
image_data_wrapper = ImageDataWrapper.load_from_file_with_single_label(
    image_path,
    "REAL",
    image_config,
    validation_percentage=0.25,
)
random_data_wrapper = RandomDataWrapper(random_config)

### Discriminator

In [None]:
from models.ClassificationModel import ClassificationModel
from tensorflow.keras.layers import Conv2D, Dense, LeakyReLU, MaxPooling2D, Flatten, Activation, BatchNormalization
from tensorflow_addons.layers import InstanceNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy, Reduction

# Class names the discriminator predicts; the last Dense layer is sized from
# this list.
classification_labels = ["REAL", "FAKE"]

# Coefficient passed to every LeakyReLU in the stack.
a = 0.08

# Three convolutional stages (128 -> 256 -> 512 filters), each followed by
# max-pooling, then a dense classification head. No `padding` argument is
# given, so every 3x3 convolution uses Keras' default "valid" padding.
disc_layers = [
    Conv2D(128,3), BatchNormalization(), LeakyReLU(a),
    Conv2D(128,3), BatchNormalization(), LeakyReLU(a),
    MaxPooling2D(),
    Conv2D(256,3), BatchNormalization(), LeakyReLU(a),
    Conv2D(256,3), BatchNormalization(), LeakyReLU(a),
    Conv2D(256,3), BatchNormalization(), LeakyReLU(a),
    MaxPooling2D(),
    Conv2D(512,3), BatchNormalization(), LeakyReLU(a),
    Conv2D(512,3), BatchNormalization(), LeakyReLU(a),
    Conv2D(512,3), BatchNormalization(), LeakyReLU(a),
    MaxPooling2D(),
    Flatten(),
    Dense(4096), Activation('relu'),
    Dense(1000), Activation('relu'),
    # One sigmoid-activated output per classification label.
    Dense(len(classification_labels)), Activation('sigmoid'),
]
optimizer = Adam(learning_rate=5e-4)
# The stack ends in a sigmoid, so its outputs are already squashed into (0, 1)
# and are not logits. With from_logits=True the loss would apply a softmax on
# top of those values, flattening the predicted distribution and weakening the
# gradients; from_logits=False makes the loss treat them as probabilities.
# NOTE(review): assumes ClassificationModel hands the output of this layer
# stack to the loss unchanged -- confirm. A softmax output would be the more
# conventional pairing for mutually exclusive labels.
loss = CategoricalCrossentropy(from_logits=False, reduction=Reduction.SUM)
discriminator = ClassificationModel(image_config.image_shape, classification_labels, disc_layers, optimizer, loss)
discriminator.build()

In [None]:
from models.GenerationModel import GenerationModel
from tensorflow.keras.layers import Conv2D, Dense, LeakyReLU, Activation, Flatten, Reshape, UpSampling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy, Reduction
from tensorflow_addons.layers import InstanceNormalization

# Coefficient passed to every LeakyReLU in the stack.
a = 0.08

# Dense stem -> 2x2x1024 feature map -> five 2x upsampling stages
# (2 -> 4 -> 8 -> 16 -> 32 -> 64, UpSampling2D's default factor is 2) with
# "same"-padded convolutions, ending in a 3-channel sigmoid image. The
# resulting 64x64x3 output matches image_config.image_shape.
gen_layers = [
    Dense(512), Activation('relu'),
    Dense(4096), Activation('relu'),
    # 4096 units == 2 * 2 * 1024
    Reshape((2,2,1024)),
    UpSampling2D(),
    Conv2D(1024,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(1024,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(1024,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    UpSampling2D(),
    Conv2D(256,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(256,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(256,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    UpSampling2D(),
    Conv2D(64,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(64,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(64,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    UpSampling2D(),
    Conv2D(32,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(32,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(32,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    UpSampling2D(),
    Conv2D(16,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    Conv2D(16,3,padding="same"), BatchNormalization(), LeakyReLU(a),
    # Three output channels, squashed into (0, 1) by the sigmoid.
    Conv2D(3,3,padding="same"), Activation('sigmoid'),
]

optimizer = Adam(learning_rate=5e-4)
# Both networks defined in this notebook end in a sigmoid activation, so
# whichever prediction this loss is evaluated on is already in (0, 1) rather
# than a raw logit. from_logits=True would apply a softmax on top of those
# values; from_logits=False treats them as probabilities instead.
# NOTE(review): presumably the trainer evaluates this loss on the
# discriminator's prediction for generated images -- confirm in
# GenerationModel / ImageGanTrainer.
loss = CategoricalCrossentropy(from_logits=False, reduction=Reduction.SUM)
generator = GenerationModel(random_config.shape,image_config.image_shape,gen_layers,optimizer,loss)
generator.build()

In [None]:
from plotting.TrainPlotter import TrainPlotter
from trainers.ImageGanTrainer import ImageGanTrainer

# Training schedule.
epochs = 100
trains_per_test = 2      # run a test pass (and save images) every N-th epoch
batch_size = 6
batches_per_loop = 3

train_columns = ["Gen Train Loss", "Disc Train Loss", "Gen Test Loss", "Disc Test Loss"]
loss_plot = TrainPlotter(moving_average_size=5, labels=train_columns)
trainer = ImageGanTrainer(generator, discriminator, random_data_wrapper, image_data_wrapper)

# Generated previews go next to the training images on Colab, or into a local
# folder otherwise.
if is_colab:
    image_output_path = image_path + "/images"
else:
    image_output_path = "./test_output"

# Create the output directory up front so the first image save cannot fail on
# a missing folder (`os` is imported in the notebook's first cell).
# NOTE(review): harmless if save_generated_images already creates it.
os.makedirs(image_output_path, exist_ok=True)

# Test losses are only refreshed on test epochs; in between, the most recent
# values (initially 0) are re-logged so every epoch reports all four columns.
g_test_loss = 0
d_test_loss = 0
for i in range(epochs):
    loss_plot.start_epoch()
    g_loss, d_loss = trainer.train(batch_size, batches_per_loop)

    # Test on every `trains_per_test`-th epoch, skipping epoch 0.
    if i % trains_per_test == 0 and i != 0:
        # NOTE(review): (12, 1) presumably mirrors train's
        # (batch_size, batches_per_loop) arguments -- confirm.
        g_test_loss, d_test_loss = trainer.test(12, 1)
        test_output_filename = f"{image_output_path}/train-{i}.png"
        image_data_wrapper.save_generated_images(test_output_filename, trainer.most_recent_gen_output)

    # Value order matches `train_columns`.
    loss_plot.batch_update([g_loss, d_loss, g_test_loss, d_test_loss])
    loss_plot.log_epoch()