In [None]:
import numpy as np
import torch
from datasets.mnist1d import load_MNIST1D
from datasets.control import get_mixture_distribution, load_control
import os
import matplotlib.pyplot as plt

In [None]:
# Seed the global RNGs so the synthetic control dataset is reproducible across kernel restarts.
# torch is seeded as well so the training cells further down start from the same state.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N_DIMS = 4                  # dimensionality of every control sample
N_CONTROL_SAMPLES = 10000   # number of samples drawn from the mixture

# Each component is (distribution name, first parameter array, second parameter array), with
# one value per dimension. The parameter meanings noted below follow the original annotations;
# NOTE(review): confirm them against get_mixture_distribution.
mixture_list = [
    # uniform, presumably (low, high) per dimension:
    # low = [-10, -5, -10, -10], high = [5, 10, -5, 5]
    ("uniform",
     np.array([-10.0, -5.0, -10.0, -10.0]),
     np.array([5.0, 10.0, -5.0, 5.0])),

    # normal with mean [10, 5, -5, -5] and std [2, 12, 1, 1]
    ("normal",
     np.array([10.0, 5.0, -5.0, -5.0]),
     np.array([2.0, 12.0, 1.0, 1.0])),

    # normal centered at 0 with unit std in all 4 dims
    ("normal",
     np.array([0.0, 0.0, 0.0, 0.0]),
     np.array([1.0, 1.0, 1.0, 1.0])),

    # normal with mean [3, 6, 2, 2] and std [4, 5, 3, 3]
    ("normal",
     np.array([3.0, 6.0, 2.0, 2.0]),
     np.array([4.0, 5.0, 3.0, 3.0])),

    # laplace centered at 4 with scale 1 in all 4 dims
    ("laplace",
     np.array([4.0, 4.0, 4.0, 4.0]),
     np.array([1.0, 1.0, 1.0, 1.0])),

    # normal centered at -5 with std 3 in all 4 dims
    ("normal",
     np.array([-5.0, -5.0, -5.0, -5.0]),
     np.array([3.0, 3.0, 3.0, 3.0])),

    # exponential with scale 1 in all 4 dims, then shifted by -5
    ("exponential",
     np.array([1.0, 1.0, 1.0, 1.0]),      # scale
     np.array([-5.0, -5.0, -5.0, -5.0])), # shift
]

# One mixture weight per component above; expected to form a probability distribution.
weights = [0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.1]

# Catch inconsistent configuration here rather than deep inside the sampler.
assert len(weights) == len(mixture_list), "need exactly one weight per mixture component"
assert np.isclose(sum(weights), 1.0), "mixture weights must sum to 1"
assert all(
    p1.shape == p2.shape == (N_DIMS,) for _, p1, p2 in mixture_list
), "every parameter array must have one entry per dimension"

control_data = get_mixture_distribution(mixture_list, weights, size=(N_CONTROL_SAMPLES, N_DIMS))

# Train GMMN

## Parameters

In [None]:
# Directory where trained model weights are written; created if missing.
model_dir = "./model_weights"

BATCH_SIZE = 1000  # taken from original paper

# Choose the training data explicitly instead of loading one dataset and overwriting it with
# another: "control" = the synthetic mixture defined above, "mnist1d" = MNIST-1D.
# The visualization sections below are dataset-specific; run the one matching this choice.
DATASET = "control"
if DATASET == "mnist1d":
    trainloader, testloader, min_value, max_value = load_MNIST1D(batch_size=BATCH_SIZE)
elif DATASET == "control":
    # Use BATCH_SIZE (not a literal) so the loader agrees with the batch size given to train_gmmn.
    trainloader, testloader, min_value, max_value = load_control(control_data, batch_size=BATCH_SIZE)
else:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected 'control' or 'mnist1d'")

# Peek at a single batch to infer the input dimensionality (dim 1 of the sample tensor).
first_batch_samples = next(iter(trainloader))[0]
print(first_batch_samples.shape)
N_INP = first_batch_samples.shape[1]
NOISE_SIZE = 10
ENCODED_SIZE = N_INP // 2
N_ENCODER_EPOCHS = 200
N_GEN_EPOCHS = 200

# exist_ok avoids the check-then-create race and makedirs also creates missing parents.
os.makedirs(model_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Train Autoencoder


In [None]:
from visualization.loss import plot_loss
from models.gmmn.train_autoencoder import train_autoencoder

# Checkpoint location for the autoencoder trained on the control data.
ENCODER_SAVE_PATH = f"{model_dir}/autoencoder_control.pth"

# train_autoencoder returns (model, losses); the losses are plotted right after.
autoencoder, losses_autoencoder = train_autoencoder(
    trainloader,
    N_INP,
    ENCODED_SIZE,
    N_ENCODER_EPOCHS,
    device,
    ENCODER_SAVE_PATH,
)
plot_loss(losses_autoencoder, title="Autoencoder Loss")

## Train GMMN (using the trained autoencoder)


In [None]:
from models.gmmn.train_gmmn import train_gmmn
# Import plot_loss here (instead of the unused `visualization.loss` module) so this cell
# does not silently depend on the autoencoder cell's import.
from visualization.loss import plot_loss

GMMN_SAVE_PATH = model_dir + "/gmmn_control-3.pth"

# The trained autoencoder and its encoded size are passed to the GMMN trainer;
# train_gmmn returns (model, losses).
gmm_net, losses_gmmn = train_gmmn(
    trainloader,
    autoencoder,
    ENCODED_SIZE,
    NOISE_SIZE,
    BATCH_SIZE,
    N_GEN_EPOCHS,
    device,
    GMMN_SAVE_PATH,
)
plot_loss(losses_gmmn, title="GMMN Loss")

## Sample Visualizations

##### MNIST-1D
**Run only if the training data is MNIST-1D.**

In [None]:
# Project helpers: sampling from the trained GMMN and plotting 1-D signals.
from models.gmmn.gmmn import generate_gmmn_samples
from visualization.visualize_1d_data import visualize_mnist1d

# Plot one real training batch as a visual reference for the generated samples below.
(samples, labels) = next(iter(trainloader))
visualize_mnist1d(
    samples,
    labels,
    title_prefix="Real",
)

In [None]:
# Draw 10 samples from the trained generator and plot them next to the real batch above.
gen_samples = generate_gmmn_samples(
    gmm_net,
    autoencoder,
    NOISE_SIZE,
    10,
)
# NOTE(review): `labels` come from the real batch in the previous cell; the generator takes no
# labels as input, so these labels do not describe the generated samples.
visualize_mnist1d(gen_samples, labels, title_prefix="Generated")

##### Control Data
**Run only if the training data is the synthetic control mixture.**

In [None]:
import visualization.plots

# Reference KDE plot of the real control mixture, to compare against the generated samples.
visualization.plots.plot_3d_kde(
    control_data
)

In [None]:
# Imported here so this section works even when the MNIST-1D section above is skipped.
import visualization.plots
from models.gmmn.gmmn import generate_gmmn_samples

# Generate as many samples as the real control set (10000) so the two KDE plots are comparable.
gen_samples = generate_gmmn_samples(gmm_net, autoencoder, NOISE_SIZE, 10000)
# Print the shape of the samples just generated (previously printed before assignment,
# showing stale data or raising NameError).
print(gen_samples.shape)
visualization.plots.plot_3d_kde(gen_samples)

## Bootstrapping Hypothesis Test

In [None]:
from utilities.bootstrapping_test import bootstrap_hypothesis_test
# Imported here so this cell does not depend on the optional MNIST-1D visualization cell.
from models.gmmn.gmmn import generate_gmmn_samples

# Gather every training batch into a single CPU array; labels are not needed for the test.
original_data = torch.cat([batch.cpu() for batch, _ in trainloader], dim=0).numpy()
# Drop singleton feature axes (e.g. a size-1 channel axis) but never axis 0, which a bare
# np.squeeze would also remove if the training set held a single sample.
singleton_axes = tuple(ax for ax in range(1, original_data.ndim) if original_data.shape[ax] == 1)
original_data = np.squeeze(original_data, axis=singleton_axes)

N_GENERATED_SAMPLES = 20000
generating_function = generate_gmmn_samples
gen_args = (gmm_net, autoencoder, NOISE_SIZE, N_GENERATED_SAMPLES)
alpha = 0.05            # presumably the test's significance level
num_iterations = 1000   # number of bootstrap iterations

bootstrap_hypothesis_test(original_data, generating_function, gen_args, alpha, num_iterations)