In [None]:
import os
import sys

def add_to_path(new_path: str):
    """Make ``new_path`` importable by appending its absolute form to ``sys.path``.

    The path is resolved against the current working directory. If the resolved
    path is already present in ``sys.path``, the function does nothing.
    """
    resolved = os.path.abspath(new_path)
    if resolved in sys.path:
        return
    sys.path.append(resolved)
    

is_colab = False

if is_colab:
    !git clone https://github.com/pdkary/Karys.git
    !cd Karys && git fetch && git pull
    !cd Karys && pip install -r requirements.txt --quiet
    add_to_path("Karys/")
    from google.colab import drive
    drive.mount("/content/drive")
    !cd ../../ && pip install -r requirements.txt --quiet
else:
    add_to_path("../../")

In [None]:
from data.configs.TextDataConfig import TextDataConfig
from data.configs.RandomDataConfig import RandomDataConfig
from data.wrappers.RandomDataWrapper import RandomDataWrapper
from data.wrappers.TextDataWrapper import TextDataWrapper

# Positional arguments are forwarded unchanged to TextDataConfig. Their meaning is
# defined in data.configs.TextDataConfig, which is not visible here.
# TODO: pass them by keyword.
text_config = TextDataConfig(
    25000, 20, 5, 1,
)
# The corpus path is resolved relative to the current working directory.
text_data_wrapper = TextDataWrapper.load_from_file(
    'test_input/corpus.txt',
    text_config,
)

In [None]:
import numpy as np
from models.ModelWrapper import ModelWrapper
from tensorflow.keras.layers import Dense, LeakyReLU, ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MSE, MSLE, binary_crossentropy

# Hyperparameters shared by the generator and the discriminator.
a = 0.08              # LeakyReLU negative slope
hidden_units = 1024   # width of every hidden Dense layer
hidden_depth = 4      # number of Dense + LeakyReLU pairs in each network
learning_rate = 5e-4  # Adam learning rate for both optimizers


def leaky_dense_stack(depth: int, units: int, alpha: float) -> list:
    """Build ``depth`` hidden blocks of ``Dense(units)`` followed by ``LeakyReLU(alpha)``.

    Args:
        depth: number of Dense/LeakyReLU pairs.
        units: output width of every Dense layer.
        alpha: negative slope passed to every LeakyReLU.

    Returns:
        A flat list of ``2 * depth`` newly constructed layer objects. Each call builds
        fresh instances, so the generator and discriminator never share layer objects.
    """
    layers = []
    for _ in range(depth):
        layers.extend([Dense(units), LeakyReLU(alpha)])
    return layers


# Generator. The head is Dense followed by its own ReLU layer, the same pattern as the
# discriminator. ReLU is not passed as Dense's `activation` argument.
# int() converts the NumPy integer returned by np.prod into a plain Python int for `units`.
glayers = leaky_dense_stack(hidden_depth, hidden_units, a) + [
    Dense(int(np.prod(text_config.output_shape))),
    ReLU(),
]
goptimizer = Adam(learning_rate=learning_rate)
gloss = MSLE
# should take in input shape, and give out output shape
text_generator_model = ModelWrapper(text_config.input_shape, text_config.output_shape, glayers, goptimizer, gloss)
text_generator_model.build()

# Discriminator: takes an output-shaped sample and produces a single non-negative (ReLU) score.
# NOTE(review): a ReLU head trained with MSLE is unusual for a GAN discriminator
# (binary_crossentropy is imported but unused). Confirm against the labels TextGanTrainer uses.
dlayers = leaky_dense_stack(hidden_depth, hidden_units, a) + [Dense(1), ReLU()]
doptimizer = Adam(learning_rate=learning_rate)
dloss = MSLE
text_discriminator_model = ModelWrapper(text_config.output_shape, [1], dlayers, doptimizer, dloss)
text_discriminator_model.build()

In [None]:
from trainers.TextGanTrainer import TextGanTrainer
from plotting.TrainPlotter import TrainPlotter

# Destination for generated samples: a file on the mounted Google Drive under Colab,
# a local test_output folder otherwise.
file_output = "drive/MyDrive/Colab/Language/seinfeld.txt" if is_colab else "./test_output/test.txt"

# Series shown in the loss plot, in the same order as the values given to loss_plot.batch_update.
train_columns = [
    "Gen Train Loss",
    "Disc Train Loss",
    "Gen Test Loss",
    "Disc Test Loss",
]
loss_plot = TrainPlotter(labels=train_columns, moving_average_size=5)

# Training schedule.
epochs = 10000           # iterations of the training loop
trains_per_test = 20     # run a test pass every this many epochs (skipping epoch 0)
batch_size = 20          # first argument to text_trainer.train
batches_per_train = 5    # second argument to text_trainer.train

# NOTE(review): the same wrapper is passed for both data arguments. If these are the
# train and test sets, test losses are not measured on held-out data. Confirm.
text_trainer = TextGanTrainer(
    text_generator_model,
    text_discriminator_model,
    text_data_wrapper,
    text_data_wrapper,
)
text_trainer.load_datasets()

# Arguments for text_trainer.test (previously the literals 10 and 1). Presumably they
# mirror train(batch_size, batches_per_train). TODO: confirm.
test_batch_size = 10
test_batches = 1

# Create the sample-output folder up front so the first write cannot fail on a missing directory.
output_dir = os.path.dirname(file_output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Test losses are refreshed only every `trains_per_test` epochs. In between, the most
# recent values (initially 0) are re-plotted so every series gets a point each epoch.
g_test_loss, d_test_loss = 0, 0
for i in range(epochs):
    loss_plot.start_epoch()
    g_train_loss, d_train_loss = text_trainer.train(batch_size, batches_per_train)

    if i % trains_per_test == 0 and i != 0:
        g_test_loss, d_test_loss = text_trainer.test(test_batch_size, test_batches)
        # Overwrite the sample file with the latest generator inputs and outputs. The file
        # is `file_output`, so the Colab/Drive destination chosen above is honoured.
        with open(file_output, 'w', encoding='utf-8') as f:
            for gen_in, gen_out in zip(text_trainer.most_recent_gen_input, text_trainer.most_recent_gen_output):
                input_text = "\n".join(text_data_wrapper.translate_sentence(sentence) for sentence in gen_in)
                output_text = text_data_wrapper.translate_sentence(gen_out[0])
                f.write("\n".join(["INPUT:", input_text, "OUTPUT:", output_text]) + "\n\n")

    loss_plot.batch_update([g_train_loss, d_train_loss, g_test_loss, d_test_loss])
    loss_plot.log_epoch()