In [1]:
from trainers.gaaltrainer import GAALTrainer
from trainers.activesvmtrainer import ActiveSVMTrainer
from trainers.randomtrainer import RandomTrainer
from trainers.fullysupervisedtrainer import FSTrainer
from trainers.simplegantrainer import SimpleGANTrainer

from trainers.makegan import GANTrainer
from trainers.makewassersteingan import WGANTrainer

import dill as pickle

Usage instructions for each imported class are in its docstring, which can be viewed with the following command (substitute any other imported class name for `GAALTrainer`).

In [3]:
# Show the docstring of GAALTrainer using IPython's `?` help syntax;
# replace the class name to inspect any other imported trainer.
?GAALTrainer

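Run each method 10 times and record the classifier's accuracy as the number of labelled samples grows, starting at 50 labelled samples and ending at 350. Then save the results to a pickle file (the saving code in each cell is commented out; uncomment it to write the results to `./results/`).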

## GAAL

In [None]:
# GAAL example: train and test on MNIST 5 & 7.
# Ten independent repetitions; after each trainer is constructed, its accuracy
# history and labelled-sample counts are collected.
# The loop index is named (not `_`) so IPython's last-output variable `_`
# is not shadowed at notebook top level.
classifier_acc_mnistmnist_gaal, classifier_nsamples_mnistmnist_gaal = [], []
for run in range(10):
    gaal = GAALTrainer(
        traindatasettype='mnist',
        testdatasettype='mnist',
        generatorpath='./gans/mnist/generator_model_1000.h5',
        oraclepath='./oracles/mnist57.h5',
        latent_dim=100,
        threshold=1e-8,
        start_samples=50,
        n_samples_end=350,
    )
    classifier_acc_mnistmnist_gaal += [gaal.learner_acc_history]
    classifier_nsamples_mnistmnist_gaal += [gaal.n_samples]

# Uncomment to save the results to pickle (`gaal` is the trainer of the final run).
# result_gaal_mnistmnist = [
#     'GAAL train and test on mnist',
#     classifier_acc_mnistmnist_gaal,
#     classifier_nsamples_mnistmnist_gaal,
#     gaal.x_train_end,
#     gaal.y_train_end,
# ]
# with open('./results/result_gaal_mnistmnist.pkl', 'wb') as file:
#     pickle.dump(result_gaal_mnistmnist, file)

## GAAL with WGAN

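Same procedure as the GAAL example above, but the samples are generated by a Wasserstein GAN generator. Note that this example also switches the dataset and the oracle to CIFAR-10.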

In [None]:
# GAAL with a WGAN generator: train and test on CIFAR-10.
# Ten independent repetitions; after each trainer is constructed, its accuracy
# history and labelled-sample counts are collected.
# A named loop index avoids shadowing IPython's `_` output variable.
classifier_acc_cifar10_gaal, classifier_nsamples_cifar10_gaal = [], []
for run in range(10):
    gaal = GAALTrainer(
        traindatasettype='cifar10',
        testdatasettype='cifar10',
        generatorpath='./gans/cifar10-wgan/generator_model_1500.h5',
        oraclepath='./oracles/cifar10ha.h5',
        latent_dim=100,
        threshold=1e-8,
        start_samples=50,
        n_samples_end=350,
    )
    classifier_acc_cifar10_gaal += [gaal.learner_acc_history]
    classifier_nsamples_cifar10_gaal += [gaal.n_samples]

# Uncomment to save the results to pickle (`gaal` is the trainer of the final run).
# result_gaal_cifar10_wgan = [
#     'GAAL train and test on cifar10, WGAN',
#     classifier_acc_cifar10_gaal,
#     classifier_nsamples_cifar10_gaal,
#     gaal.x_train_end,
#     gaal.y_train_end,
# ]
# with open('./results/result_gaal_cifar10_wgan.pkl', 'wb') as file:
#     pickle.dump(result_gaal_cifar10_wgan, file)

## GAAL with Diversity Measure

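Use the `diversity='avgdist'` option to add a diversity measure (average distance) when generating samples. The experiment is repeated for several values of the parameter `L`, from 0.0001 to 1.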

In [None]:
# GAAL with a diversity measure: train and test on CIFAR-10 with the DC-GAN
# generator. For each value of L, repeat the experiment 10 times.
ls = [0.0001, 0.001, 0.01, 0.1, 1]

# Results for every L, keyed by L. `classifier_acc` / `classifier_nsamples` are
# rebuilt for each L, so without these dicts only the last L would survive.
classifier_acc_cifar10_gaal_div = {}
classifier_nsamples_cifar10_gaal_div = {}

for ll in ls:
    classifier_acc = []
    classifier_nsamples = []
    # Named loop index so IPython's `_` output variable is not shadowed.
    for run in range(10):
        gaal = GAALTrainer(traindatasettype='cifar10', testdatasettype='cifar10',
                           generatorpath='./gans/cifar10-dcgan/generator_model_1500.h5',
                           oraclepath='./oracles/cifar10ha.h5',
                           n_samples_end=350, threshold=1e-8,
                           start_samples=50, latent_dim=100,
                           diversity='avgdist', L=ll)  # diversity options
        classifier_acc.append(gaal.learner_acc_history)
        classifier_nsamples.append(gaal.n_samples)
    classifier_acc_cifar10_gaal_div[ll] = classifier_acc
    classifier_nsamples_cifar10_gaal_div[ll] = classifier_nsamples

    # Save this L's results to pickle. This sits inside the `for ll` loop so
    # each L gets its own file (`gaal` is the final run for this L).
    # result_gaal = [f'GAAL train and test on CIFAR10 with Diversity, L = {ll}', classifier_acc, classifier_nsamples, gaal.x_train_end, gaal.y_train_end]
    # with open(f'./results/result_gaal_cifar10_d_{ll}.pkl', 'wb') as file:
    #     pickle.dump(result_gaal, file)

## Active SVM

In [None]:
# Active SVM example: train and test on MNIST 5 & 7 (no generator involved).
# Ten independent repetitions; a named loop index keeps IPython's `_` intact.
classifier_acc_mnistmnist_activesvm = []
classifier_nsamples_mnistmnist_activesvm = []
for run in range(10):
    asvm = ActiveSVMTrainer(
        traindatasettype='mnist',
        testdatasettype='mnist',
        oraclepath='./oracles/mnist57.h5',
        start_samples=50,
        n_samples_end=350,
    )
    classifier_acc_mnistmnist_activesvm += [asvm.learner_acc_history]
    classifier_nsamples_mnistmnist_activesvm += [asvm.n_samples]

# Uncomment to save the results to pickle.
# result_activesvm_mnistmnist = [
#     'Active SVM train and test on mnist',
#     classifier_acc_mnistmnist_activesvm,
#     classifier_nsamples_mnistmnist_activesvm,
# ]
# with open('./results/result_activesvm_mnistmnist.pkl', 'wb') as file:
#     pickle.dump(result_activesvm_mnistmnist, file)

## Random Sampling

In [None]:
# Random sampling baseline: train and test on MNIST 5 & 7.
# Ten independent repetitions; a named loop index keeps IPython's `_` intact.
classifier_acc_mnistmnist_randomsampling = []
classifier_nsamples_mnistmnist_randomsampling = []
for run in range(10):
    rnd = RandomTrainer(
        traindatasettype='mnist',
        testdatasettype='mnist',
        start_samples=50,
        n_samples_end=350,
    )
    classifier_acc_mnistmnist_randomsampling += [rnd.learner_acc_history]
    classifier_nsamples_mnistmnist_randomsampling += [rnd.n_samples]

# Uncomment to save the results to pickle.
# result_random_mnistmnist = [
#     'Random Sampling train and test on mnist',
#     classifier_acc_mnistmnist_randomsampling,
#     classifier_nsamples_mnistmnist_randomsampling,
# ]
# with open('./results/result_random_mnistmnist.pkl', 'wb') as file:
#     pickle.dump(result_random_mnistmnist, file)

## Simple GAN

In [None]:
# Simple GAN example: train and test on MNIST 5 & 7, using the same
# generator and oracle as the GAAL MNIST example.
# Ten independent repetitions; a named loop index keeps IPython's `_` intact.
classifier_acc_mnistmnist_simplegan, classifier_nsamples_mnistmnist_simplegan = [], []
for run in range(10):
    simplegan = SimpleGANTrainer(
        traindatasettype='mnist',
        testdatasettype='mnist',
        generatorpath='./gans/mnist/generator_model_1000.h5',
        oraclepath='./oracles/mnist57.h5',
        latent_dim=100,
        threshold=1e-8,
        start_samples=50,
        n_samples_end=350,
    )
    classifier_acc_mnistmnist_simplegan += [simplegan.learner_acc_history]
    classifier_nsamples_mnistmnist_simplegan += [simplegan.n_samples]

# Uncomment to save the results to pickle (`simplegan` is the trainer of the final run).
# result_simplegan_mnistmnist = [
#     'Simple GAN train and test on mnist',
#     classifier_acc_mnistmnist_simplegan,
#     classifier_nsamples_mnistmnist_simplegan,
#     simplegan.x_train_end,
#     simplegan.y_train_end,
# ]
# with open('./results/result_simplegan_mnistmnist.pkl', 'wb') as file:
#     pickle.dump(result_simplegan_mnistmnist, file)


# Fully Supervised Training

In [None]:
# Fully supervised reference run on MNIST: a single training run whose test
# accuracy and training-set size are recorded for comparison with the
# active-learning methods.
fs_mnistmnist = FSTrainer(
    'mnist',  # presumably the training dataset type
    'mnist',  # presumably the test dataset type
)
fs_testacc, fs_nsamples = fs_mnistmnist.learner_acc_history, fs_mnistmnist.x_train.shape[0]

# Uncomment to save the results to pickle.
# result_fs_mnistmnist = [
#     'Fully Supervised train and test on mnist',
#     fs_testacc,
#     fs_nsamples,
# ]
# with open('./results/result_fs_mnistmnist.pkl', 'wb') as file:
#     pickle.dump(result_fs_mnistmnist, file)

# Train a GAN

## DC-GAN

In [None]:
# Train a DC-GAN on CIFAR-10; outputs (presumably generator checkpoints) go under `savepath`.
# NOTE(review): the GAAL cells load a CIFAR-10 DC-GAN generator from
# './gans/cifar10-dcgan/generator_model_1500.h5'. That is neither this directory
# nor an epoch reachable with n_epochs=1000, so confirm the intended path and epochs.
savepath = './gan/cifar10/'
gantrainer = GANTrainer(
    'cifar10',
    savepath,
    latent_dim=100,
    n_epochs=1000,
    batchsize=256,
    retries=5,
)

## Wasserstein GAN

In [None]:
# Train a Wasserstein GAN on CIFAR-10; outputs go under `savepath`.
# n_critic is presumably the number of critic updates per generator update.
# NOTE(review): the GAAL WGAN cell loads './gans/cifar10-wgan/generator_model_1500.h5'.
# That is neither this directory nor an epoch reachable with n_epochs=1000,
# so confirm the intended path and epochs.
savepath = './wgan/cifar10/'
wgan = WGANTrainer(
    'cifar10',
    savepath,
    latent_dim=100,
    n_epochs=1000,
    batchsize=256,
    retries=5,
    n_critic=5,
)