In [1]:
# import MBGAN.py and utils.py functions
from MBGAN import MBGAN
from scipy.stats import describe
from utils import *

# disable eager execution
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()

In [2]:
# Experiment identity: the GAN's name and the directory its outputs go under.
NAME = "mbgan"
EXP_DIR = "mbgan_test"

# Input data file (a pickle, read by load_sample_pickle_data below).
FILE = "raw_data.pkl"

# Training schedule. EPOCHS is passed to `mbgan.train` as `iteration=`, so it
# counts training iterations, not full passes over the data.
EPOCHS = 10_000
BATCH_SIZE = 32
SAVE_INTERVAL = 1_000  # iterations between saves of generated samples / model

In [3]:
# define a function for saving GAN output
def get_save_fn(taxa_list, n_samples=1000):
    """Build the ``save_fn`` callback passed to ``MBGAN.train``.

    Args:
        taxa_list: column names for the generated table (one per taxon).
        n_samples: number of samples generated at each save (default 1000).

    Returns:
        A function ``fn(model, epoch)``. Each call does three things:
        generates ``n_samples`` samples with ``model.predict``, prints
        ``scipy.stats.describe`` summaries of their sparsity and Shannon
        entropy, and writes the samples to
        ``<model.log_dir>/tables/<model_name>_<epoch:06d>--<mean sparsity>--<mean entropy>.csv``.
    """
    # Explicit imports: these names otherwise only arrive via `from utils import *`.
    import os

    import numpy as np
    import pandas as pd

    def fn(model, epoch):
        # make a directory for output tables (no-op if it already exists)
        table_dir = os.path.join(model.log_dir, "tables")
        os.makedirs(table_dir, exist_ok=True)

        # generate samples and measure sparsity and entropy
        res = model.predict(n_samples, transform=None, seed=None)
        sparsity, entropy = get_sparsity(res), shannon_entropy(res)
        print("sparsity: %s" % str(describe(sparsity)))
        print("entropy: %s" % str(describe(entropy)))

        # encode the epoch and the mean statistics in the file name
        filename = "{:s}_{:06d}--{:.4f}--{:.4f}.csv".format(
            model.model_name, epoch, np.mean(sparsity), np.mean(entropy))

        # save output to csv
        pd.DataFrame(res, columns=taxa_list).to_csv(os.path.join(table_dir, filename))

    # This return belongs to get_save_fn, not fn. If it is nested inside fn,
    # get_save_fn returns None and training receives no save callback.
    return fn

In [4]:
# Load the case/control data and the taxa names with the utils.py loader
# (use the matching loader instead if FILE is a .csv rather than a .pkl).
# NOTE(review): presumably this unpickles FILE -- only load pickles you trust.
data_o_case, data_o_ctrl, taxa_list = load_sample_pickle_data(FILE)

# Build the phylogenetic adjacency matrix and taxa indices from the taxa names.
adj_matrix, taxa_indices = expand_phylo(taxa_list)

# Convert the adjacency matrix to a dense (len(taxa_list), len(taxa_indices)) matrix.
tf_matrix = adjmatrix_to_dense(adj_matrix, shape=(len(taxa_list), len(taxa_indices)))

# Model configuration. 'ntaxa' is derived from the loaded data so it always
# agrees with the first dimension of tf_matrix (it was hard-coded to 719,
# which only matched one particular dataset).
model_config = {
    'ntaxa': len(taxa_list),  # number of taxa
    'latent_dim': 100,        # size of the noise vector z
    'generator': {
        'n_channels': 512,    # channels in each generator layer
    },
    'critic': {
        'n_channels': 256,       # channels for the critic
        'dropout_rate': 0.25,    # critic dropout
        'tf_matrix': tf_matrix,  # phylogenetic matrix
        't_pow': 1000.,          # scale parameter for the critic
    },
}

# Training configuration: optimizer and learning rate for each network.
# NOTE(review): 'optimizer' looks like (name, kwargs); the third critic loss
# weight is presumably the gradient-penalty term -- confirm in MBGAN.py.
train_config = {
    'generator': {
        'optimizer': ('RMSprop', {}),
        'lr': 0.00005,
    },
    'critic': {
        'loss_weights': [1, 1, 10],
        'optimizer': ('RMSprop', {}),
        'lr': 0.00005,
    },
}

In [5]:
# Instantiate the MB-GAN model (MBGAN class from MBGAN.py) from the experiment
# name and the model / training configuration dicts defined in the previous cell.
mbgan = MBGAN(NAME, model_config, train_config)

Interpolated Sample Shape: Tensor("random_weighted_average/add:0", shape=(None, 719), dtype=float32)


In [6]:
# Train the model on the case samples (data_o_case) only.
# EPOCHS is the number of training iterations; every SAVE_INTERVAL iterations
# the save_fn callback built by get_save_fn is used to save generated samples.
mbgan.train(
    data_o_case,
    iteration=EPOCHS,
    batch_size=BATCH_SIZE,
    n_critic=5,        # presumably critic updates per iteration
    n_generator=1,     # presumably generator updates per iteration
    save_fn=get_save_fn(taxa_list),
    save_interval=SAVE_INTERVAL,
    experiment_dir=EXP_DIR,
    pre_processor=None,
    verbose=0,
)

#####################################################
Training start at: 2022-04-12 13:35:16
Run MB-GAN for 10000 iterations with batch_size=32
Save generated samples and model every 1000 iters
Results are exported to folder: mbgan_test\mbgan_20220412T133516
    Create log folder: mbgan_test\mbgan_20220412T133516
    Create model folder: mbgan_test\mbgan_20220412T133516\models
Generator structure:
Model: "generator_graph"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
noise (InputLayer)           [(None, 100)]             0         
_________________________________________________________________
generator (Functional)       (None, 719)               952015    
_________________________________________________________________
critic (Functional)          (None, 1)                 449281    
Total params: 1,401,296
Trainable params: 948,943
Non-trainable params: 452,353
________________________________

KeyboardInterrupt: 