<a href="https://colab.research.google.com/github/souravraha/galaxy/blob/experimental/Lightning_Tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install packages and import

In [1]:
# Colab setup: installs the dependencies used by the rest of this notebook.
# Run this cell once before anything else; skip it (or comment it out) in an
# environment that already provides these packages.
# NOTE(review): none of the installs below are version-pinned, so a fresh run may
# pull releases whose APIs differ from the ones used further down — consider pinning.

print('Setting up colab environment')
# pyarrow is uninstalled before ray is installed — presumably to avoid a clash
# with the Colab-preinstalled build; TODO confirm this is still required.
!pip uninstall -y -q pyarrow
!pip install -q ray[debug] lightning-bolts
# -U upgrades ray (with the `tune` extra) if a version is already present.
!pip install -U -q ray[tune]
# Optional TPU support (disabled):
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

# # A hack to force the runtime to restart, needed to include the above dependencies.
# print('Done installing! Restarting via forced crash (this is not an issue).')
# import os
# os._exit(0)

Setting up colab environment
[K     |████████████████████████████████| 51.6MB 56kB/s 
[K     |████████████████████████████████| 256kB 48.3MB/s 
[K     |████████████████████████████████| 10.1MB 49.9MB/s 
[K     |████████████████████████████████| 133kB 59.0MB/s 
[K     |████████████████████████████████| 81kB 12.4MB/s 
[K     |████████████████████████████████| 1.3MB 34.5MB/s 
[K     |████████████████████████████████| 81kB 13.2MB/s 
[K     |████████████████████████████████| 71kB 10.8MB/s 
[K     |████████████████████████████████| 3.1MB 38.2MB/s 
[K     |████████████████████████████████| 819kB 31.3MB/s 
[K     |████████████████████████████████| 235kB 58.3MB/s 
[K     |████████████████████████████████| 143kB 57.3MB/s 
[K     |████████████████████████████████| 296kB 56.5MB/s 
[K     |████████████████████████████████| 92kB 12.2MB/s 
[K     |████████████████████████████████| 10.6MB 43.6MB/s 
[K     |████████████████████████████████| 829kB 50.3MB/s 
[K     |█████████████████████

In [1]:
# On Google Colab, select the TensorFlow 2.x runtime. Nothing is installed here and
# nothing needs uncommenting: the magic below simply switches the preloaded version.
# NOTE(review): no cell visible in this notebook imports TensorFlow directly —
# confirm whether this cell is still needed.

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  # Outside Colab the selection is skipped on a best-effort basis.
  pass

In [2]:
# __import_lightning_begin__
import math
# import shutil
import numpy as np              #
from matplotlib import pyplot as plt
from itertools import cycle
import torch
from torch import nn
import pytorch_lightning as pl
# from filelock import FileLock
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
import torchvision.datasets as datasets
from torchvision.models import resnet18
from pl_bolts.models.self_supervised.resnets import BasicBlock                  # problem with resnet18
from pl_bolts.models.gans import DCGAN
from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
import torchmetrics as tm
from torchvision import transforms
import os
from os.path import basename
# __import_lightning_end__

# __import_tune_begin__
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
    TuneReportCheckpointCallback
# __import_tune_end__



# Download and extract data

In [3]:
# Mount Google Drive at /content/drive (Colab only). Later cells read files from,
# and write Tune logs to, paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Google Drive file IDs of the dataset archives, keyed by model letter
# (entries with no ID after the colon have none recorded):
# 'a': 1Cjcw2EWorhdhJSGoWOdxsEUDxvl943dt, 'b': 15yXXC4h5VsytP3Ak1jfUSjQhdgP2s23K, 'c': 1vuQ-pLzoKT4Hd_V7949r9eND9E2fB_u_,
# 'd': , 'e': 1wFuasvb7PthxXtMUlsD13uzYHWlWt06H, 'f': 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ, 
# 'g': 1SxQVosWeEjY3Pyn8LRXA11rLnZ9HK_7B, 'h': 1Atau0RH4oyLAiYReW-G9a8l9pUNltglF, 'i': 15lEgsR1p00KSHieaT9a1nkbJ86pDxwgp, 
# 'j': 1m0EQUbqZZeyl76XsQIKWU5Qd7jGmmWhB, 'k': , 'l': 1meTDi4aeWfdChOiXeLtUOGhjVDVu000e

# Download the archive for model 'f' (the ID below matches the 'f' entry above) and
# unpack it into the current working directory.
# !rm -rf images
!gdown --id 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ
!tar zxf ./model_f.tgz

# Pure-Python alternative to the two shell commands above (disabled):
# def prepare_data(data_dir: str = '/content'):
#     gdown.download('https://drive.google.com/uc?id=17l6H61tLAu26zGuei38r_T5ssjbYUeaJ', data_dir+'/model_f.tgz', quiet=True)
    
#     temp = tarfile.open(data_dir+'/model_f.tgz', 'r|gz')
#     temp.extractall()
#     temp.close()

Downloading...
From: https://drive.google.com/uc?id=17l6H61tLAu26zGuei38r_T5ssjbYUeaJ
To: /content/model_f.tgz
2.34GB [00:35, 65.5MB/s]


# DataModule
The `NpyDataModule` below builds the training, validation and test dataloaders that are passed to the trainer to fit, validate or test our models.

In [3]:
class NpyDataModule(pl.LightningDataModule):
    """Serve single-channel ``.npy`` images stored in class-labelled folders.

    ``<data_dir>/train`` is split at random into a training and a validation
    subset; ``<data_dir>/val`` is used as the test set.

    Args:
        config: mapping that must contain ``'batch_size'``.
        img_width: size passed to ``transforms.Resize`` (nearest-neighbour).
        data_dir: root folder holding the ``train`` and ``val`` sub-folders.
        stats_path: ``.npz`` file whose ``'VALS'`` entry holds the mean (index 0)
            and standard deviation (index 1) used for normalisation.
        val_fraction: share of the ``train`` folder held out for validation.
            The default of 0.2 yields the 60000/15000 split that used to be
            hard-coded when the folder contains 75000 samples.

    NOTE(review): the same transform, including the random flips, is applied to
    the validation and test sets as well; confirm whether that is intended.
    """

    def __init__(self, config, img_width: int = 150, data_dir: str = '/content/images/',
                 stats_path: str = '/content/drive/MyDrive/git_repos/forging_new_worlds/GLOBAL_VALS_F.npz',
                 val_fraction: float = 0.2):
        super().__init__()
        # The original author notes that save_hyperparameters() is not
        # implemented for this class, so it is deliberately not called.
        self.batch_size = config['batch_size']
        self.data_dir = os.path.expanduser(data_dir)
        self.val_fraction = val_fraction

        # Read the normalisation statistics and close the archive right away.
        with np.load(stats_path) as stats:
            mean, std = stats['VALS'][0], stats['VALS'][1]

        self.transform = transforms.Compose([
            # transforms.ConvertImageDtype(torch.float32) can't be used here: it
            # divides values by dtype.max. npy_loader casts with float() instead.
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # Shift and scale the pixel values: N(mu, sigma) -> N(0, 1).
            transforms.Normalize(mean=(mean,), std=(std,)),
            transforms.Resize(img_width, transforms.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def npy_loader(path):
        """Load one ``.npy`` file as a float32 tensor with a leading channel axis."""
        # Convert to a tensor first and cast afterwards; per the original note,
        # the dtype would otherwise end up float64 and raise errors in conv layers.
        return torch.from_numpy(np.load(path)).unsqueeze(0).float()

    def _folder(self, subdir: str):
        """Build a ``DatasetFolder`` over ``<data_dir>/<subdir>`` for ``.npy`` files."""
        return datasets.DatasetFolder(os.path.join(self.data_dir, subdir),
                                      self.npy_loader,
                                      ('.npy',),  # a real one-element tuple of extensions
                                      self.transform,
                                      )

    def setup(self, stage: str = None):
        """Create the datasets required for ``stage`` ('fit', 'test' or None for both)."""
        if stage in ('fit', None):
            self.full_set = self._folder('train')
            # Derive the split from the actual dataset size so that folders with
            # a different number of samples do not make random_split raise.
            n_val = round(len(self.full_set) * self.val_fraction)
            n_train = len(self.full_set) - n_val
            self.train_set, self.val_set = random_split(self.full_set, [n_train, n_val])
            self.dims = tuple(self.train_set[0][0].shape)

        if stage in ('test', None):
            self.test_set = self._folder('val')
            # Only load a sample when 'fit' has not already provided the dims;
            # an empty or None placeholder is treated as "not set".
            if not getattr(self, 'dims', None):
                self.dims = tuple(self.test_set[0][0].shape)

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train_set, self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        """Loader over the validation subset; evaluation data is not shuffled."""
        return DataLoader(self.val_set, self.batch_size, shuffle=False, num_workers=2)

    def test_dataloader(self):
        """Loader over the test set; evaluation data is not shuffled."""
        return DataLoader(self.test_set, self.batch_size, shuffle=False, num_workers=2)


# ResNet:
We modify a ResNet slightly for our purpose.

In [4]:
class LensResnet(pl.LightningModule):
    """ResNet-18 classifier adapted to ``image_channels``-channel input images.

    Args:
        config: mapping that must contain ``'learning_rate'``.
        image_channels: number of channels of the input images.
        num_classes: number of target classes.
        **kwargs: accepted and stored as hyper-parameters; not used otherwise.

    ``forward`` returns class probabilities (softmax over the logits), while the
    losses are computed on the raw logits of ``self.backbone``.
    """

    def __init__(self, config, image_channels: int = 1, num_classes: int = 3, **kwargs):
        super().__init__()
        # `ignore` expects argument *names*; passing the config dict itself did
        # not exclude the `config` argument from the saved hyper-parameters.
        self.save_hyperparameters(ignore='config')
        self.learning_rate = config['learning_rate']

        # Freshly initialised ResNet-18: no pretrained weights are requested.
        self.backbone = resnet18(num_classes=self.hparams.num_classes)
        # The first convolution is replaced rather than edited in place, because
        # changing in_channels alone would leave weights of the wrong shape.
        self.backbone.conv1 = nn.Conv2d(self.hparams.image_channels, 64, kernel_size=(7, 7),
                                        stride=(2, 2), padding=(3, 3), bias=False)
        # Dropout in front of the final fully connected layer.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            self.backbone.fc
        )

    def configure_optimizers(self):
        """Adam over the backbone parameters with the configured learning rate."""
        return torch.optim.Adam(self.backbone.parameters(), self.learning_rate)

    def forward(self, x):
        """Return per-class probabilities for a batch of images."""
        return F.softmax(self.backbone(x), 1)

    def training_step(self, batch, batch_idx):
        """Log the weighted AUROC and the cross-entropy loss; return the loss."""
        imgs, labels = batch
        # One forward pass serves both the loss and the metric. Two separate
        # passes would use different dropout masks while training.
        logits = self.backbone(imgs)
        probs = F.softmax(logits.detach(), 1)
        self.log('ResNet/train/auroc',
                 tm.functional.auroc(probs, labels, average='weighted',
                                     num_classes=self.hparams.num_classes))
        loss = F.cross_entropy(logits, labels)
        # Only scalars are logged here (original author's note: avoids errors).
        self.log('ResNet/train/loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and return probabilities plus targets."""
        imgs, labels = batch
        logits = self.backbone(imgs)  # single forward pass for loss and predictions
        self.log('ResNet/val/loss', F.cross_entropy(logits, labels))
        return {'pred': F.softmax(logits, 1), 'target': labels}

    def validation_epoch_end(self, Listofdicts):
        """Log the minimum per-class AUROC and plot one ROC curve per class.

        The figure is written as a PDF into the Tune trial directory (or the
        current directory when not running inside a Tune session) and is also
        sent to the logger.
        """
        prediction = torch.cat([x['pred'] for x in Listofdicts])
        target = torch.cat([x['target'] for x in Listofdicts])
        num_classes = self.hparams.num_classes
        aurocTensor = tm.functional.auroc(prediction, target, num_classes=num_classes, average=None)
        self.log('ResNet/val/auroc', aurocTensor.min())
        fprList, tprList, _ = tm.functional.roc(prediction, target, num_classes=num_classes)

        f, ax = plt.subplots()
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
        for i, color in zip(range(num_classes), colors):
            ax.plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                    label='ROC curve of class {0} (area = {1:0.2f})'
                    ''.format(i, aurocTensor[i].cpu()))
        ax.plot([0, 1], [0, 1], 'k--')  # chance diagonal
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Multi-class ROC')
        ax.legend(loc='lower right')

        # os.path.join is correct whether or not the trial dir ends in a separator.
        out_dir = tune.get_trial_dir() or '.'
        f.savefig(os.path.join(str(out_dir), 'ROC_epoch_' + str(self.current_epoch) + '.pdf'))
        # Tag the figure with the epoch so every epoch keeps its own entry.
        self.logger.experiment.add_figure('ResNet/val/ROC', f, global_step=self.current_epoch)
        plt.close(f)  # make sure the figure is released whatever the logger does

Trying out Lightning's automatic learning-rate finder.

In [None]:
# Can't work with multiple optimizers
config = {
    'learning_rate': 1e-4, 'batch_size': 128, 'feature_maps': 64,
}
dm = NpyDataModule(config)
generator = StackGAN(config)

trainer = pl.Trainer(
    # logger=,
    # checkpoint_callback=,
    default_root_dir='./drive/MyDrive/Logs/', 
    gpus=1,
    auto_select_gpus=True, 
    # tpu_cores=
    progress_bar_refresh_rate=1,
    # fast_dev_run=,
    max_epochs=5,
    # max_time=,
    # limit_train_batches=,
    # flush_logs_every_n_steps=,
    # log_every_n_steps=,
    # resume_from_checkpoint='./drive/MyDrive/Logs/lr_find_temp_model.ckpt',
    auto_lr_find = True,
    # auto_scale_batch_size=True,
    # prepare_data_per_node=,
    )

# Run learning rate finder
lr_finder = trainer.tuner.lr_find(generator, dm)

# # Results can be found in
# # lr_finder.results

# Plot with
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick point based on plot, or get suggestion
new_lr = lr_finder.suggestion()

# # update hparams of the model
# model.hparams.lr = new_lr

# # Fit model
# trainer.fit(model)

# Tune ResNet hyperparameters:
Here we tune hyperparameters as we train our modified ResNet.

In [None]:
# __tune_train_checkpoint_begin
def train_LensResnet_tune_checkpoint(config,
                                    checkpoint_dir=None,
                                    num_epochs=10,
                                    num_gpus=1):
    """Ray Tune trainable: fit a ``LensResnet`` and report/checkpoint via Tune.

    Args:
        config: hyper-parameters of the trial; must provide ``'learning_rate'``
            and ``'batch_size'``.
        checkpoint_dir: directory holding a file named ``checkpoint`` to resume
            from, or None to start from scratch.
        num_epochs: maximum number of epochs to train.
        num_gpus: GPUs assigned to the trial; fractional values are rounded up.
    """
    data_dir = os.path.expanduser('/content/images/')

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        prepare_data_per_node = False,
        num_sanity_val_steps=0,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        # Write TensorBoard logs directly into the trial directory.
        logger=TensorBoardLogger(
            save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            # Report the validation metrics to Tune under short names and save
            # a checkpoint file named 'checkpoint' alongside each report.
            TuneReportCheckpointCallback(
                metrics={
                    'loss': 'ResNet/val/loss',
                    'auroc': 'ResNet/val/auroc',
                },
                filename='checkpoint',
            )
        ],
        stochastic_weight_avg=True,
    )

    # data_dir must be passed by keyword: the second positional parameter of
    # NpyDataModule is img_width, so passing it positionally bound the path to
    # the resize width and left the data directory at its default.
    dm = NpyDataModule(config, data_dir=data_dir)

    if checkpoint_dir:
        # Per the original note, LensResnet.load_from_checkpoint leads to errors
        # here, so the checkpoint is loaded manually and the model state restored.
        ckpt = pl_load(
            os.path.join(checkpoint_dir, 'checkpoint'),
            map_location=lambda storage, loc: storage)
        model = LensResnet._load_model_state(ckpt, config=config)
        trainer.current_epoch = ckpt['epoch']
    else:
        model = LensResnet(config)

    trainer.fit(model, dm)

# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_LensResnet_asha(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Run the LensResnet trainable under an ASHA scheduler and print the best config.

    The configuration is a single fixed point rather than a search space, and
    the run restores from a hard-coded checkpoint path on Drive.

    Args:
        num_samples: accepted for signature parity; not forwarded to ``tune.run``.
        num_epochs: ``max_t`` of the scheduler and epoch budget of each trial.
        gpus_per_trial: GPUs reserved for, and handed to, each trial.
    """
    # Previously explored search space, kept for reference:
    #     'learning_rate': tune.choice([1e-5, 1e-4, 1e-3, 1e-2]),
    #     'batch_size': tune.choice([128, 64, 32]),
    fixed_config = {'batch_size': 128, 'learning_rate': 0.0001}

    trainable = tune.with_parameters(
        train_LensResnet_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    run_options = dict(
        name='LensResNet_J',  # change together with the dataset
        metric='auroc',
        mode='max',
        config=fixed_config,
        resources_per_trial={'cpu': 2, 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        scheduler=ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2),
        progress_reporter=CLIReporter(
            parameter_columns=['learning_rate', 'batch_size'],
            metric_columns=['loss', 'auroc', 'training_iteration'],
        ),
        fail_fast=True,
        restore='/content/drive/MyDrive/Logs/LensResNet_J/train_LensResnet_tune_checkpoint_c74d9_00003_3_batch_size=128,learning_rate=0.0001_2021-07-08_01-03-43/checkpoint_epoch=2-step=1406',
    )
    analysis = tune.run(trainable, **run_options)

    print('Best hyperparameters found were: ', analysis.best_config)
# __tune_asha_end__


# __tune_pbt_begin__
def tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Run the LensResnet trainable with Population Based Training and print the best config.

    Args:
        num_samples: number of trials in the population.
        num_epochs: epoch budget handed to each trial.
        gpus_per_trial: GPUs reserved for, and handed to, each trial.
    """
    start_config = {'learning_rate': 1e-3, 'batch_size': 64}

    # Every 4 iterations the scheduler may perturb these two hyper-parameters.
    pbt = PopulationBasedTraining(
        perturbation_interval=4,
        hyperparam_mutations={
            'learning_rate': [1e-5, 1e-4, 1e-3, 1e-2],
            'batch_size': [32, 64, 128],
        },
    )

    progress = CLIReporter(
        parameter_columns=['learning_rate', 'batch_size'],
        metric_columns=['loss', 'auroc', 'training_iteration'],
    )

    trainable = tune.with_parameters(
        train_LensResnet_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    run_options = dict(
        metric='auroc',
        mode='max',
        resources_per_trial={'cpu': 2, 'gpu': gpus_per_trial},
        fail_fast=True,
        config=start_config,
        num_samples=num_samples,
        scheduler=pbt,
        progress_reporter=progress,
        local_dir='./drive/MyDrive/Logs',
        name='tune_LensResnet_pbt',
    )
    analysis = tune.run(trainable, **run_options)

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    # parse_known_args is used so that extra arguments injected by the notebook
    # kernel are ignored rather than rejected.
    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test',
                        help='Finish quickly for testing',
                        action='store_true')
    args, _ = parser.parse_known_args()

    if not args.smoke_test:
        # Full run: ASHA only. Population based training is currently disabled:
        # tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=1)
        tune_LensResnet_asha(num_samples=12, num_epochs=20, gpus_per_trial=1)
    else:
        # Smoke test: one short pass through each tuning entry point.
        tune_LensResnet_asha(num_samples=1, num_epochs=6, gpus_per_trial=1)
        tune_LensResnet_pbt(num_samples=1, num_epochs=6, gpus_per_trial=1)

# Experimental

Check if we can load checkpoints

In [5]:
def pretrained_LensResNets():
    """Restore the two trained ``LensResnet`` classifiers from their Drive checkpoints.

    Returns:
        A ``(f, j)`` tuple: the model restored from the ``LensResNet_F`` run
        followed by the one restored from the ``LensResNet_J`` run.
    """
    def _restore(trial_checkpoint_dir, hparams):
        # Each checkpoint directory holds a single file named 'checkpoint';
        # tensors are kept on the storage they are deserialised to.
        state = pl_load(os.path.join(trial_checkpoint_dir, 'checkpoint'),
                        map_location=lambda storage, loc: storage)
        return LensResnet._load_model_state(state, config=hparams)

    model_f = _restore(
        '/content/drive/MyDrive/Logs/LensResNet_F/train_LensResnet_tune_checkpoint_efb38_00000_0_2021-07-09_04-20-04/checkpoint_epoch=9-step=16407',
        {'batch_size': 32, 'learning_rate': 0.0001},
    )
    model_j = _restore(
        '/content/drive/MyDrive/Logs/LensResNet_J/train_LensResnet_tune_checkpoint_21355_00000_0_2021-07-09_16-17-17/checkpoint_epoch=27-step=4687',
        {'batch_size': 128, 'learning_rate': 0.0001},
    )
    return model_f, model_j

def post_plotting(ax):
    """Finish a ROC panel: draw the dashed chance diagonal and place the legend.

    Args:
        ax: axes-like object exposing ``plot`` and ``legend``.
    """
    diagonal_x, diagonal_y = [0, 1], [0, 1]
    ax.plot(diagonal_x, diagonal_y, 'k--')
    legend_options = {'loc': 'lower right'}
    ax.legend(**legend_options)

In [6]:
class Stage1(DCGAN):
    """Class-conditional DCGAN whose samples are scored by two pretrained classifiers.

    A label embedding is attached to the generator and multiplied into the
    noise vector. During validation, generated images are resized to 150 px and
    classified by the two ``LensResnet`` models from ``pretrained_LensResNets``.

    Args:
        config: mapping providing ``'ngf'``, ``'ndf'`` and ``'learning_rate'``.
        num_classes: number of label classes for the embedding.
        **kwargs: forwarded to ``DCGAN.__init__``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        # These three values always come from `config`. Drop any copies found
        # in kwargs (for instance when constructor arguments are re-supplied
        # from saved hyper-parameters) so that the call below does not fail
        # with a duplicate-keyword error.
        for managed in ('feature_maps_gen', 'feature_maps_disc', 'learning_rate'):
            kwargs.pop(managed, None)
        super().__init__(feature_maps_gen=config['ngf'], feature_maps_disc=config['ndf'],
                         learning_rate=config['learning_rate'], **kwargs)
        # `ignore` expects argument *names*; passing the config dict itself did
        # not exclude the `config` argument from the saved hyper-parameters.
        self.save_hyperparameters(ignore='config')

        # One latent_dim-sized embedding vector per class, registered on the generator.
        self.generator.add_module('emb', nn.Embedding(self.hparams.num_classes, self.hparams.latent_dim))

        self.modelF, self.modelJ = pretrained_LensResNets()

    def forward(self, noise, labels = None):
        """Generate images from ``noise`` conditioned on ``labels``.

        When ``labels`` is None, the labels stored by the latest training step
        are used if their shape matches the noise batch; otherwise random
        labels are drawn.
        """
        if labels is None:
            batch_shape = noise.shape[:-1]  # last dimension is the hidden dimension
            labels = getattr(self, 'labels', None)
            # Random labels are drawn only when needed, and stale labels of a
            # different batch size are never used.
            if labels is None or labels.shape != batch_shape:
                labels = torch.randint(self.hparams.num_classes, batch_shape, device=self.device)
        inp = noise.mul(self.generator.emb(labels))
        return self.generator(inp.view(-1, inp.shape[-1], 1, 1))

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run the discriminator (optimizer 0) or generator (optimizer 1) step."""
        # The labels are stored on the module so that forward() can pick them
        # up when it is called without explicit labels.
        real, self.labels = batch

        result = None
        # Train discriminator
        if optimizer_idx == 0:
            result = self._disc_step(real)

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(real)

        return result

    def _disc_step(self, real: torch.Tensor) -> torch.Tensor:
        """Compute and log the discriminator loss."""
        disc_loss = self._get_disc_loss(real)
        self.log('Stage1/D/train/loss', disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real: torch.Tensor) -> torch.Tensor:
        """Compute and log the generator loss."""
        gen_loss = self._get_gen_loss(real)
        self.log('Stage1/G/train/loss', gen_loss, on_epoch=True)
        return gen_loss

    def validation_step(self, batch, batch_idx):
        """Generate one image per validation label and classify it with both ResNets."""
        imgs, labels = batch
        out_64 = self(torch.randn(labels.shape[0], self.hparams.latent_dim).type_as(imgs), labels)
        out = F.interpolate(out_64, 150)  # resize the samples to 150 px for the classifiers
        return {'predF': self.modelF(out), 'predJ': self.modelJ(out), 'target': labels}

    def validation_epoch_end(self, listofDicts):
        """Log the minimum per-class AUROC of each classifier and plot both ROC panels.

        The figure is written as a PDF into the Tune trial directory (or the
        current directory when not running inside a Tune session) and is also
        sent to the logger.
        """
        target = torch.cat([x['target'] for x in listofDicts])
        num_classes = self.hparams.num_classes
        f, ax = plt.subplots(1, 2, subplot_kw={'xlim': [0, 1], 'xlabel': 'False Positive Rate',
                                               'ylim': [0, 1.05], 'ylabel': 'True Positive Rate'},
                             figsize=[11, 5])
        for panel, letter in enumerate(('F', 'J')):
            prediction = torch.cat([x['pred' + letter] for x in listofDicts])
            aurocTensor = tm.functional.auroc(prediction, target, num_classes=num_classes, average=None)
            self.log('Stage1/ResNet(' + letter + ')/val/auroc', aurocTensor.min())
            fprList, tprList, _ = tm.functional.roc(prediction, target, num_classes=num_classes)

            colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
            for i, color in zip(range(num_classes), colors):
                ax[panel].plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                               label='ROC curve of class {0} (area = {1:0.2f})'
                               ''.format(i, aurocTensor[i].cpu()))
            post_plotting(ax[panel])
            ax[panel].set_title('Multi-class ROC (' + letter + ')')

        f.tight_layout()
        # os.path.join is correct whether or not the trial dir ends in a separator.
        out_dir = tune.get_trial_dir() or '.'
        f.savefig(os.path.join(str(out_dir), 'ROC_epoch_' + str(self.current_epoch) + '.pdf'))
        # Tag the figure with the epoch so every epoch keeps its own entry.
        self.logger.experiment.add_figure('Stage1/ResNet/val/ROC', f, global_step=self.current_epoch)
        plt.close(f)  # make sure the figure is released whatever the logger does

Let's train Stage1 GAN

In [None]:
# __tune_train_checkpoint_begin
def train_Stage1_tune_checkpoint(config, checkpoint_dir=None, num_epochs=10, num_gpus=1):
    """Ray Tune trainable: fit a ``Stage1`` GAN on 64-pixel images.

    Args:
        config: hyper-parameters of the trial (used by ``NpyDataModule`` and ``Stage1``).
        checkpoint_dir: directory holding a file named ``checkpoint`` to resume
            from, or None to start from scratch.
        num_epochs: maximum number of epochs to train.
        num_gpus: GPUs assigned to the trial; fractional values are rounded up.
    """
    # Lightning metric names mapped to the short names reported to Tune.
    reported_metrics = {
        'lossG': 'Stage1/G/train/loss',
        'lossD': 'Stage1/D/train/loss',
        'auroc': 'Stage1/ResNet(F)/val/auroc',
        'auroc_cross': 'Stage1/ResNet(J)/val/auroc',
    }
    # TensorBoard events go straight into the trial directory.
    trial_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.')

    # stochastic_weight_avg is left off: per the original note it works with
    # only one optimizer.
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=math.ceil(num_gpus),  # if fractional GPUs passed in, convert to int
        logger=trial_logger,
        callbacks=[TuneReportCheckpointCallback(reported_metrics)],
    )

    dm = NpyDataModule(config, 64)

    if not checkpoint_dir:
        model = Stage1(config)
    else:
        # Per the original note, Stage1.load_from_checkpoint leads to errors
        # here, so the checkpoint is loaded manually and the model state restored.
        ckpt = pl_load(os.path.join(checkpoint_dir, 'checkpoint'),
                       map_location=lambda storage, loc: storage)
        model = Stage1._load_model_state(ckpt, config=config)
        trainer.current_epoch = ckpt['epoch']

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_Stage1_asha(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Search Stage1 hyper-parameters with an ASHA scheduler and print the best config.

    Args:
        num_samples: number of configurations sampled from the search space.
        num_epochs: ``max_t`` of the scheduler and epoch budget of each trial.
        gpus_per_trial: GPUs reserved for, and handed to, each trial.
    """
    search_space = {
        'learning_rate': tune.choice([1e-5, 1e-4, 1e-3]),
        'ngf': tune.choice([128, 64, 32]),
        'ndf': tune.choice([128, 64, 32]),
        'batch_size': tune.choice([128, 64, 32]),
    }

    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    # Notebook-friendly reporter that overwrites its table in place.
    progress = JupyterNotebookReporter(
        overwrite=True,
        parameter_columns=['learning_rate', 'ngf', 'ndf', 'batch_size'],
        metric_columns=['lossG', 'lossD', 'auroc', 'auroc_cross', 'training_iteration'],
    )

    trainable = tune.with_parameters(train_Stage1_tune_checkpoint,
                                     num_epochs=num_epochs,
                                     num_gpus=gpus_per_trial)

    analysis = tune.run(
        trainable,
        name='Stage1_F',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial={'cpu': 2, 'gpu': gpus_per_trial},
        num_samples=num_samples,
        local_dir='./drive/MyDrive/Logs',
        scheduler=asha,
        progress_reporter=progress,
        fail_fast=True,
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage1_pbt(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Tune Stage1 with Population Based Training and print the best config.

    The run uses ``resume='PROMPT'`` and the experiment name ``'Stage1_F'``.

    Args:
        num_samples: accepted for signature parity; not forwarded to ``tune.run``.
        num_epochs: epoch budget handed to each trial.
        gpus_per_trial: GPUs reserved for, and handed to, each trial.
    """
    search_space = {
        'learning_rate': tune.choice([1e-5, 1e-4, 1e-3]),
        'ngf': tune.choice([128, 64, 32]),
        'ndf': tune.choice([128, 64, 32]),
        'batch_size': tune.choice([128, 64, 32]),
    }

    # Only the learning rate and the batch size are mutated; the two feature-map
    # sizes ('ngf', 'ndf') are not part of the mutation set.
    pbt = PopulationBasedTraining(
        perturbation_interval=4,
        hyperparam_mutations={
            'learning_rate': tune.choice([1e-5, 1e-4, 1e-3]),
            'batch_size': tune.choice([128, 64, 32]),
        },
    )

    progress = JupyterNotebookReporter(
        overwrite=True,
        parameter_columns=['learning_rate', 'ngf', 'ndf', 'batch_size'],
        metric_columns=['lossG', 'lossD', 'auroc', 'auroc_cross', 'training_iteration'],
    )

    trainable = tune.with_parameters(train_Stage1_tune_checkpoint,
                                     num_epochs=num_epochs,
                                     num_gpus=gpus_per_trial)

    analysis = tune.run(
        trainable,
        name='Stage1_F',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial={'cpu': 2, 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        scheduler=pbt,
        progress_reporter=progress,
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    # parse_known_args is used so that extra arguments injected by the notebook
    # kernel are ignored rather than rejected.
    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test',
                        help='Finish quickly for testing',
                        action='store_true')
    args, _ = parser.parse_known_args()

    if not args.smoke_test:
        # Full run: ASHA only. Population based training is currently disabled:
        # tune_Stage1_pbt(num_samples=8, num_epochs=5, gpus_per_trial=1)
        tune_Stage1_asha(num_samples=12, num_epochs=5, gpus_per_trial=1)
    else:
        # Smoke test: one short pass through each tuning entry point.
        tune_Stage1_asha(num_samples=1, num_epochs=6, gpus_per_trial=1)
        tune_Stage1_pbt(num_samples=1, num_epochs=6, gpus_per_trial=1)

Trial name,status,loc,learning_rate,ngf,ndf,batch_size,lossG,lossD,auroc,auroc_cross,training_iteration
train_Stage1_tune_checkpoint_a5d5c_00000,RUNNING,172.28.0.2:2194,0.0001,128,64,128,6.60173,0.00808735,0.495684,0.495665,1.0
train_Stage1_tune_checkpoint_a5d5c_00001,PENDING,,0.001,128,128,64,,,,,
train_Stage1_tune_checkpoint_a5d5c_00002,PENDING,,0.0001,64,32,128,,,,,
train_Stage1_tune_checkpoint_a5d5c_00003,PENDING,,0.0001,32,64,128,,,,,
train_Stage1_tune_checkpoint_a5d5c_00004,PENDING,,0.0001,128,64,128,,,,,
train_Stage1_tune_checkpoint_a5d5c_00005,PENDING,,0.0001,64,32,128,,,,,
train_Stage1_tune_checkpoint_a5d5c_00006,PENDING,,0.001,64,128,128,,,,,
train_Stage1_tune_checkpoint_a5d5c_00007,PENDING,,0.0001,32,128,64,,,,,
train_Stage1_tune_checkpoint_a5d5c_00008,PENDING,,0.001,64,32,64,,,,,
train_Stage1_tune_checkpoint_a5d5c_00009,PENDING,,1e-05,64,128,32,,,,,


[2m[36m(pid=2194)[0m Epoch 0: 100%|██████████| 587/587 [08:38<00:00,  1.13it/s, loss=3.31, v_num=.]
                                                             [A
Epoch 1:   0%|          | 0/587 [00:00<?, ?it/s, loss=3.31, v_num=.]
Epoch 1:   3%|▎         | 20/587 [00:16<07:55,  1.19it/s, loss=3.28, v_num=.]
Epoch 1:   7%|▋         | 40/587 [00:36<08:25,  1.08it/s, loss=3.28, v_num=.]
Epoch 1:  10%|█         | 60/587 [00:56<08:17,  1.06it/s, loss=3.29, v_num=.]
Epoch 1:  14%|█▎        | 80/587 [01:16<08:02,  1.05it/s, loss=3.32, v_num=.]
Epoch 1:  17%|█▋        | 100/587 [01:35<07:46,  1.04it/s, loss=3.34, v_num=.]
Epoch 1:  20%|██        | 120/587 [01:54<07:27,  1.04it/s, loss=3.45, v_num=.]
Epoch 1:  24%|██▍       | 140/587 [02:13<07:05,  1.05it/s, loss=3.39, v_num=.]
Epoch 1:  27%|██▋       | 160/587 [02:31<06:45,  1.05it/s, loss=3.41, v_num=.]
Epoch 1:  31%|███       | 180/587 [02:50<06:25,  1.05it/s, loss=3.45, v_num=.]
Epoch 1:  34%|███▍      | 200/587 [03:08<06:05,  1.06it

In [None]:
def pretrained_Stage1s(ckpt_dir_f=None, config_f=None, ckpt_dir_j=None, config_j=None):
    """Restore the two trained ``Stage1`` GANs from their Tune checkpoints.

    The original cell was an unfinished template: the checkpoint paths stopped at
    ``/content/drive/MyDrive/Logs/Stage1_F/train_Stage1_tune_checkpoint_`` and
    ``/content/drive/MyDrive/Logs/Stage1_J/train_Stage1_tune_checkpoint_``, and
    the config values were left empty, which is a syntax error. The values are
    now parameters that must be filled in once the training runs exist.

    Args:
        ckpt_dir_f: checkpoint directory of the ``Stage1_F`` run; it must
            contain a file named ``checkpoint``.
        config_f: config of that run, with ``'ngf'``, ``'ndf'`` and
            ``'learning_rate'`` (``'batch_size'`` may be included as well).
        ckpt_dir_j: checkpoint directory of the ``Stage1_J`` run.
        config_j: config of the ``Stage1_J`` run.

    Returns:
        A ``(f, j)`` tuple of restored ``Stage1`` models.

    Raises:
        ValueError: if a checkpoint directory or config is missing, or a config
            lacks one of the keys read by ``Stage1.__init__``.
    """
    required_keys = ('ngf', 'ndf', 'learning_rate')
    runs = (('F', ckpt_dir_f, config_f), ('J', ckpt_dir_j, config_j))

    # Validate everything up front so a half-configured call fails fast with a
    # readable message instead of part-way through loading.
    for tag, ckpt_dir, config in runs:
        if ckpt_dir is None or config is None:
            raise ValueError(
                'pretrained_Stage1s: checkpoint directory and config for run '
                + tag + ' must both be provided')
        missing = [key for key in required_keys if key not in config]
        if missing:
            raise ValueError(
                'pretrained_Stage1s: config for run ' + tag
                + ' is missing keys: ' + ', '.join(missing))

    models = []
    for _, ckpt_dir, config in runs:
        ckpt = pl_load(os.path.join(ckpt_dir, 'checkpoint'),
                       map_location=lambda storage, loc: storage)
        models.append(Stage1._load_model_state(ckpt, config=config))
    f, j = models
    return f, j

In [None]:
class Generator2(nn.Module):
    """Refinement generator: low-resolution image in, ``image_width`` image out.

    Pipeline: a stride-2 convolution halves the input resolution, a stack of
    residual blocks plus a closing conv/BN/ReLU processes the features, a long
    skip connection adds the pre-residual features back, and two
    nearest-neighbour upsampling stages bring the result to ``image_width``.

    Args:
        ngf: number of feature maps in the residual trunk.
        image_channels: channels of both the input and the output image.
        image_width: side length of the generated image (previously fixed at 150).
        num_res_blocks: number of ``BasicBlock`` modules in the trunk
            (previously fixed at 6).
    """

    def __init__(self, ngf: int = 128, image_channels: int = 1,
                 image_width: int = 150, num_res_blocks: int = 6):
        super().__init__()

        ker, strd = 4, 2
        pad = (ker - 2) // 2  # padding that makes the 4x4 stride-2 conv halve the size
        res_ker, res_strd, res_pad = 3, 1, 1  # size-preserving 3x3 convolutions

        # Halve the spatial resolution (64 -> 32 for a 64-pixel input, per the
        # original comment; the input size itself is not fixed by this module).
        self.preprocessing = nn.Sequential(
            nn.Conv2d(image_channels, ngf, ker, strd, pad, bias=False),
            nn.ReLU(True)
        )
        # Residual trunk; resolution and channel count are unchanged.
        self.residual = nn.Sequential(
            *[BasicBlock(ngf, ngf) for _ in range(num_res_blocks)]
        )
        self.ending_residual = nn.Sequential(
            nn.Conv2d(ngf, ngf, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )

        # forward() adds the preprocessing output back in at this point.

        mode = 'nearest'  # upscaling method is nearest-neighbour
        self.main = nn.Sequential(
            # Upsample to half the target width (75 with the default of 150).
            nn.Upsample(image_width // 2, mode=mode),
            nn.Conv2d(ngf, ngf * 4, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # Upsample to the full target width and project to image channels.
            nn.Upsample(image_width, mode=mode),
            nn.Conv2d(ngf * 4, image_channels, res_ker, res_strd, res_pad, bias=False),
            nn.Tanh()
        )

    def forward(self, in_x):
        """Map a batch of low-resolution images to ``image_width`` images in [-1, 1]."""
        x_p = self.preprocessing(in_x)
        x_r = self.ending_residual(self.residual(x_p))
        # Long residual connection around the whole trunk.
        x_f = x_r + x_p
        return self.main(x_f)

In [None]:
class Stage2(DCGAN):
    """DCGAN variant whose generator refines the output of a pre-trained Stage-1 model.

    The stock generator is replaced by ``Generator2`` and the discriminator
    receives one additional block. During training the "noise" fed to the
    generator is an image produced by a Stage-1 model (see ``_get_noise``).
    """

    def __init__(self, checkpoint_dir: str, *args, **kwargs):
        """Build the Stage-2 networks on top of the base ``DCGAN``.

        Args:
            checkpoint_dir: Recorded through ``save_hyperparameters``; not
                otherwise read in this class.
            *args: Forwarded unchanged to ``DCGAN.__init__``.
            **kwargs: Forwarded unchanged to ``DCGAN.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()

        self.generator = Generator2(self.hparams.feature_maps_gen, self.hparams.image_channels)

        # Insert one extra block at index 2 of the discriminator stack, then
        # re-initialise the whole discriminator so the new block is covered too.
        extra_block = self.discriminator._make_disc_block(
            self.hparams.feature_maps_disc * 2, self.hparams.feature_maps_disc * 2)
        disc_layers = list(self.discriminator.disc)
        disc_layers.insert(2, extra_block)
        self.discriminator.disc = nn.Sequential(*disc_layers)
        self.discriminator.apply(self._weights_init)

    def forward(self, noise):
        """Run the Stage-2 generator on ``noise`` (a Stage-1 image batch during training)."""
        return self.generator(noise)

    def _disc_step(self, real: torch.Tensor) -> torch.Tensor:
        """Compute and log the discriminator loss for one batch of real images."""
        disc_loss = self._get_disc_loss(real)
        self.log('Stage2/D/train/loss', disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real: torch.Tensor) -> torch.Tensor:
        """Compute and log the generator loss for one batch of real images."""
        gen_loss = self._get_gen_loss(real)
        self.log('Stage2/G/train/loss', gen_loss, on_epoch=True)
        return gen_loss

    def _get_noise(self, n_samples: int, latent_dim: int):
        """Return Stage-1 images generated from ``n_samples`` random latent vectors."""
        # The previous `getattr(self, 'lowres', pretrained_Stage1s()[0])` evaluated
        # its default eagerly, reloading the Stage-1 checkpoints on every call.
        # Load once, on first use, and keep the module on the same device as the
        # latent vectors it is about to receive.
        if not hasattr(self, 'lowres'):
            # Choose the data this was pre-trained on
            self.lowres = pretrained_Stage1s()[0].to(self.device)
        return self.lowres(torch.randn(n_samples, latent_dim, device=self.device))

    def validation_step(self, batch, batch_idx):
        """Classify generated images with two external models.

        Returns a dict with both sets of predictions and the true labels.
        """
        imgs, labels = batch
        # NOTE(review): `forward` accepts only `noise`, so passing `labels` here
        # looks like it raises TypeError; `_get_noise` may be the intended input.
        # NOTE(review): `modelF` / `modelJ` are not assigned anywhere in this
        # class - confirm where they are meant to come from.
        out = self(torch.randn(labels.shape[0], self.hparams.latent_dim).type_as(imgs), labels)
        # out = F.interpolate(out_64, 150)
        return {'predF': self.modelF(out), 'predJ': self.modelJ(out), 'target': labels}

    def validation_epoch_end(self, listofDicts):
        """Log the worst per-class AUROC and plot ROC curves for both classifiers."""
        # NOTE(review): relies on `self.hparams.num_classes`, which this class
        # does not set itself - presumably supplied via **kwargs; confirm.
        target = torch.cat([x['target'] for x in listofDicts])
        f, ax = plt.subplots(1, 2, subplot_kw={'xlim': [0, 1], 'xlabel': 'False Positive Rate',
                                               'ylim': [0, 1.05], 'ylabel': 'True Positive Rate'},
                             figsize=[11, 5])
        letters = ['F', 'J']
        for axis, letter in zip(ax, letters):
            prediction = torch.cat([x['pred' + letter] for x in listofDicts])
            aurocTensor = tm.functional.auroc(prediction, target, num_classes=self.hparams.num_classes, average=None)
            # The reported metric is the AUROC of the worst class.
            self.log('Stage2/ResNet(' + letter + ')/val/auroc', aurocTensor.min())
            fprList, tprList, _ = tm.functional.roc(prediction, target, num_classes=self.hparams.num_classes)

            colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                axis.plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                          label='ROC curve of class {0} (area = {1:0.2f})'
                          ''.format(i, aurocTensor[i].cpu()))
            post_plotting(axis)
            axis.set_title('Multi-class ROC (' + letter + ')')

        f.tight_layout()
        self.logger.experiment.add_figure('Stage2/ResNet/val/ROC', f)
        f.savefig(str(tune.get_trial_dir()) + 'ROC_epoch_' + str(self.current_epoch) + '.pdf')

In [None]:
# __tune_train_checkpoint_begin
def train_Stage2_tune_checkpoint(config, checkpoint_dir=None, num_epochs=10, num_gpus=1):
    """Ray Tune trainable: train (or resume) one ``Stage2`` trial.

    Args:
        config: Sampled hyperparameters. ``'ngf'``, ``'ndf'`` and
            ``'learning_rate'`` are read here; the whole dict is also passed to
            ``NpyDataModule``.
        checkpoint_dir: Directory containing a ``checkpoint`` file to resume
            from, or ``None`` to start a fresh model.
        num_epochs: Maximum number of epochs for the trainer.
        num_gpus: GPUs for the trainer; fractional values are rounded up.
    """
    # data_dir = os.path.expanduser('/content/images/')
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            TuneReportCheckpointCallback(
                {'lossG': 'Stage2/G/train/loss',
                 'lossD': 'Stage2/D/train/loss',
                 'auroc': 'Stage2/ResNet(F)/val/auroc',
                 'auroc_cross': 'Stage2/ResNet(J)/val/auroc',
                 },
            ),
        ],
    )
    dm = NpyDataModule(config, 64)

    # `Stage2(config)` used to bind the whole dict to `checkpoint_dir`, so the
    # sampled values never reached the DCGAN constructor and every trial trained
    # with its defaults. Translate the search-space keys explicitly instead.
    model_kwargs = {
        'feature_maps_gen': config['ngf'],
        'feature_maps_disc': config['ndf'],
        'learning_rate': config['learning_rate'],
    }

    if checkpoint_dir:
        # Currently, this leads to errors:
        # model = Stage2.load_from_checkpoint(
        #     os.path.join(checkpoint, 'checkpoint'))
        # Workaround:
        ckpt = pl_load(os.path.join(checkpoint_dir, 'checkpoint'),
                       map_location=lambda storage, loc: storage)
        # Keyword overrides take effect on top of the hyperparameters stored in
        # the checkpoint, so a perturbed learning rate is applied on restore.
        model = Stage2._load_model_state(ckpt, **model_kwargs)
        trainer.current_epoch = ckpt['epoch']
    else:
        model = Stage2(checkpoint_dir, **model_kwargs)

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_Stage2_asha(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Search Stage-2 hyperparameters with the ASHA early-stopping scheduler.

    Args:
        num_samples: Number of configurations drawn from the search space.
        num_epochs: Epoch budget per trial; also the scheduler's ``max_t``.
        gpus_per_trial: GPUs reserved for every trial.
    """
    trainable = tune.with_parameters(
        train_Stage2_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    search_space = {
        'learning_rate': tune.choice([1e-5, 1e-4, 1e-3]),
        'ngf': tune.choice([128, 64, 32]),
        'ndf': tune.choice([128, 64, 32]),
        'batch_size': tune.choice([128, 64, 32]),
    }

    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    reporter = CLIReporter(
        # overwrite=True,
        parameter_columns=['learning_rate', 'ngf', 'ndf', 'batch_size'],
        metric_columns=['lossG', 'lossD', 'auroc', 'auroc_cross', 'training_iteration'],
    )

    analysis = tune.run(
        trainable,
        name='Stage2_F',
        # Trials are ranked by the highest reported 'auroc'.
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial={'cpu': 2, 'gpu': gpus_per_trial},
        num_samples=num_samples,
        local_dir='./drive/MyDrive/Logs',
        scheduler=asha,
        progress_reporter=reporter,
        fail_fast=True,
        # resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage2_pbt(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Search Stage-2 hyperparameters with Population Based Training.

    Only ``learning_rate`` and ``batch_size`` are mutated; ``ngf`` and ``ndf``
    are sampled once per trial and then left alone.

    Args:
        num_samples: Number of trials in the population.
        num_epochs: Epoch budget handed to the training function.
        gpus_per_trial: GPUs reserved for every trial.
    """
    trainable = tune.with_parameters(
        train_Stage2_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    search_space = {
        'learning_rate': tune.choice([1e-5, 1e-4, 1e-3]),
        'ngf': tune.choice([128, 64, 32]),
        'ndf': tune.choice([128, 64, 32]),
        'batch_size': tune.choice([128, 64, 32]),
    }

    mutations = {
        'learning_rate': tune.choice([1e-5, 1e-4, 1e-3]),
        # 'ngf': tune.choice([128, 64, 32]),
        # 'ndf': tune.choice([128, 64, 32]),
        'batch_size': tune.choice([128, 64, 32]),
    }
    pbt = PopulationBasedTraining(perturbation_interval=4,
                                  hyperparam_mutations=mutations)

    reporter = CLIReporter(
        # overwrite=True,
        parameter_columns=['learning_rate', 'ngf', 'ndf', 'batch_size'],
        metric_columns=['lossG', 'lossD', 'auroc', 'auroc_cross', 'training_iteration'],
    )

    analysis = tune.run(
        trainable,
        name='Stage2_F',
        # Trials are ranked by the highest reported 'auroc'.
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial={'cpu': 2, 'gpu': gpus_per_trial},
        num_samples=num_samples,
        local_dir='./drive/MyDrive/Logs',
        scheduler=pbt,
        progress_reporter=reporter,
        fail_fast=True,
        # resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test',
        help='Finish quickly for testing',
        action='store_true',
    )
    # Unrecognised command-line arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    if not args.smoke_test:
        # Full run: ASHA scheduler only.
        tune_Stage2_asha(num_samples=12, num_epochs=3, gpus_per_trial=1)
        # Population based training (disabled):
        # tune_Stage2_pbt(num_samples=8, num_epochs=5, gpus_per_trial=1)
    else:
        # Smoke test: a single short trial with each scheduler.
        tune_Stage2_asha(num_samples=1, num_epochs=6, gpus_per_trial=1)
        tune_Stage2_pbt(num_samples=1, num_epochs=6, gpus_per_trial=1)

# StackGAN:
Here we define the GAN module that we will use to generate representative images.

In [None]:
class StackGAN(pl.LightningModule):
    """Two-stage, label-conditioned GAN with an auxiliary classifier.

    Sub-networks:
        G1 / D1: first-stage generator and its discriminator.
        G2 / D2: second-stage (refinement) generator and its discriminator.
        R: classifier with ``num_classes + 1`` outputs; the extra class is the
            target assigned to generated images.
        pretrained: separately trained ``LensResnet`` used only to score
            generated images during validation.

    Args:
        config: Dict providing ``'feature_maps'`` and ``'learning_rate'``; it is
            also handed to ``LensResnet``.
        noise_size: Length of the latent vector (and of the label embedding).
        image_width: Size real images are resized to before being shown to D1.
        num_classes: Number of real classes.
        image_channels: Channels of the images.
        b1: First Adam beta used by every optimizer.
        **kwargs: Accepted but not used.
    """

    def __init__(self, config, noise_size: int = 100, image_width = 64,
                    num_classes: int = 3, image_channels: int = 1, b1: float = 0.5, **kwargs):
        super().__init__()
        # `ignore` expects argument *names*; passing the dict itself did not
        # exclude `config` from the saved hyperparameters.
        self.save_hyperparameters(ignore='config')
        self.feature_maps = config['feature_maps']
        self.lr = config['learning_rate']
        # -------------------------------------
        # G1: stock DCGAN generator with the second layer of its first block
        # removed, plus a label embedding used to condition the noise.
        self.G1 = DCGANGenerator(self.hparams.noise_size, self.feature_maps, self.hparams.image_channels).apply(self._weights_init)
        first_block = list(self.G1.gen[0])
        del first_block[1]
        self.G1.gen[0] = nn.Sequential(*first_block)
        self.G1.add_module('label_emb', nn.Embedding(self.hparams.num_classes, self.hparams.noise_size))
        # ------------------------------------
        self.D1 = DCGANDiscriminator(self.feature_maps, self.hparams.image_channels).apply(self._weights_init)
        # -------------------------------------
        # Generator2's signature is (ngf, image_channels); the two arguments
        # were previously passed in the opposite order.
        self.G2 = Generator2(self.feature_maps, self.hparams.image_channels).apply(self._weights_init)
        # -------------------------------------
        # D2: stock discriminator with one extra block inserted at index 2.
        # Only this instance is modified, not the class definition.
        self.D2 = DCGANDiscriminator(self.feature_maps, self.hparams.image_channels)
        extra_block = self.D2._make_disc_block(self.feature_maps * 2, self.feature_maps * 2)
        disc_layers = list(self.D2.disc)
        disc_layers.insert(2, extra_block)
        self.D2.disc = nn.Sequential(*disc_layers)
        self.D2.apply(self._weights_init)
        # -------------------------------------
        # One output per real class plus one for "generated" (see training_step).
        self.R = LensResnet(config, num_classes=self.hparams.num_classes + 1).apply(self._weights_init)
        # -------------------------------------
        self.pretrained = LensResnet(config)
        ckpt = pl_load(os.path.join(
            '/content/drive/MyDrive/Logs/tune_LensResnet_asha_model_j/train_LensResnet_tune_checkpoint_e38cb_00000_0_batch_size=128,learning_rate=0.001_2021-07-06_17-52-11/checkpoint_epoch=17-step=1406',
            # '/content/drive/MyDrive/Logs/tune_LensResnet_asha_model_f/train_LensResnet_tune_checkpoint_e32ba_00000_0_batch_size=64,learning_rate=0.0001_2021-07-06_03-33-10/checkpoint_epoch=14-step=4689',
            'checkpoint'),
            map_location=lambda storage, loc: storage)
        # `_load_model_state` builds and returns a new model; its result used to
        # be discarded, leaving `self.pretrained` randomly initialised. Copy the
        # checkpoint weights into the existing instance instead.
        self.pretrained.load_state_dict(ckpt['state_dict'])
        # -------------------------------------
        self.criterion1 = nn.BCELoss()
        self.criterion2 = nn.CrossEntropyLoss()

    @staticmethod
    def _weights_init(m):
        """DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for batch-norm."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    def forward(self, noise, labels = None):
        """Generate images from ``noise``, conditioned on ``labels``.

        Args:
            noise: Latent vectors; the last dimension is the latent size.
            labels: Class indices; drawn uniformly at random when omitted.

        Returns:
            ``(out2, out1)``: the G2 output and the G1 output. ``out2`` is
            computed from a detached ``out1``.
        """
        if labels is None:
            # Create the labels on the same device as the noise so the embedding
            # lookup does not mix devices.
            labels = torch.randint(self.hparams.num_classes, noise.shape[:-1], device=noise.device)
        inp = torch.mul(noise, self.G1.label_emb(labels))
        out1 = self.G1(inp.view(-1, inp.shape[-1], 1, 1))
        out2 = self.G2(out1.detach())
        return out2, out1

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the loss for the network owned by ``optimizer_idx``.

        The index follows the order returned by ``configure_optimizers``:
        0 = G1, 1 = D1, 2 = G2, 3 = D2, 4 = R.
        """
        imgs, labels = batch
        temp2, temp1 = self(torch.randn(labels.shape[0], self.hparams.noise_size).type_as(imgs), labels)

        if optimizer_idx == 0:
            loss = self.criterion1(self.D1(temp1), torch.ones_like(labels, dtype=torch.float32))
            self.log('G1/train/loss/disc', loss)
            # Out-of-place add: an in-place `add_` would also change the tensor
            # that was just handed to `self.log`.
            loss = loss + self.criterion2(self.R.backbone(self.G2(temp1)), labels)
            self.log('G1/train/loss/full', loss)

        elif optimizer_idx == 1:
            real, fake = self.D1(F.interpolate(imgs, self.hparams.image_width, mode='nearest')), self.D1(temp1.detach())
            prediction, target = torch.cat((real, fake)), torch.cat((torch.ones_like(real), torch.zeros_like(fake)))
            loss = self.criterion1(prediction, target)
            self.log('D1/train/loss', loss)

        elif optimizer_idx == 2:
            loss = self.criterion1(self.D2(temp2), torch.ones_like(labels, dtype=torch.float32))
            self.log('G2/train/loss/disc', loss)
            loss = loss + self.criterion2(self.R.backbone(temp2), labels)
            self.log('G2/train/loss/full', loss)

        elif optimizer_idx == 3:
            real, fake = self.D2(imgs), self.D2(temp2.detach())
            prediction, target = torch.cat((real, fake)), torch.cat((torch.ones_like(real), torch.zeros_like(fake)))
            loss = self.criterion1(prediction, target)
            self.log('D2/train/loss', loss)

        elif optimizer_idx == 4:
            real, fake = self.R.backbone(imgs), self.R.backbone(temp2.detach())
            # Generated images get the extra class index `num_classes`.
            prediction, target = torch.cat((real, fake)), torch.cat((labels, self.hparams.num_classes * torch.ones_like(labels)))
            loss = self.criterion2(prediction, target)
            self.log('R/train/loss', loss)

        return loss

    def configure_optimizers(self):
        """Create one Adam optimizer per sub-network, in the order G1, D1, G2, D2, R."""
        betas = (self.hparams.b1, 0.999)
        opt_g1 = torch.optim.Adam(self.G1.parameters(), self.lr, betas)
        opt_d1 = torch.optim.Adam(self.D1.parameters(), self.lr, betas)
        opt_g2 = torch.optim.Adam(self.G2.parameters(), self.lr, betas)
        opt_d2 = torch.optim.Adam(self.D2.parameters(), self.lr, betas)
        opt_r = torch.optim.Adam(self.R.parameters(), self.lr, betas)
        return opt_g1, opt_d1, opt_g2, opt_d2, opt_r

    def validation_step(self, batch, batch_idx):
        """Generate images for the batch labels and classify them with ``pretrained``."""
        imgs, labels = batch
        temp2, _ = self(torch.randn(labels.shape[0], self.hparams.noise_size).type_as(imgs), labels)
        return {'pred': self.pretrained(temp2.detach()), 'target': labels}

    def validation_epoch_end(self, listofDicts):
        """Log the worst per-class AUROC and save a multi-class ROC figure."""
        prediction, target = torch.cat([x['pred'] for x in listofDicts]), torch.cat([x['target'] for x in listofDicts])
        aurocTensor = tm.functional.auroc(prediction, target, num_classes=self.hparams.num_classes, average=None)
        # The reported metric is the AUROC of the worst class.
        self.log('Pre/val/auroc', aurocTensor.min())
        fprList, tprList, _ = tm.functional.roc(prediction, target, num_classes=self.hparams.num_classes)

        f = plt.figure()
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
        for i, color in zip(range(self.hparams.num_classes), colors):
            plt.plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                    label='ROC curve of class {0} (area = {1:0.2f})'
                    ''.format(i, aurocTensor[i].cpu()))
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Multi-class ROC')
        plt.legend(loc='lower right')

        self.logger.experiment.add_figure('StackGAN/val/ROC', f)
        f.savefig(str(tune.get_trial_dir())+'ROC_epoch_'+str(self.current_epoch)+'.pdf')

# Tune StackGAN:
Here we tune the hyperparameters so that the generated images resemble the input images.

In [None]:
# __tune_train_checkpoint_begin
def train_StackGAN_tune_checkpoint(config,
                                   checkpoint_dir=None,
                                   num_epochs=10,
                                   num_gpus=1):
    """Ray Tune trainable: train (or resume) one ``StackGAN`` trial.

    Args:
        config: Sampled hyperparameters, passed to both the data module and
            the model.
        checkpoint_dir: Directory containing a ``checkpoint`` file to resume
            from, or ``None`` to start a fresh model.
        num_epochs: Maximum number of epochs for the trainer.
        num_gpus: GPUs for the trainer; fractional values are rounded up.
    """
    data_dir = os.path.expanduser('/content/images/')

    # Lightning metric name -> name reported to Tune.
    reported_metrics = {
        'lossG1': 'G1/train/loss/full',
        'lossG2': 'G2/train/loss/full',
        'lossD1': 'D1/train/loss',
        'lossD2': 'D2/train/loss',
        'lossR': 'R/train/loss',
        'auroc': 'Pre/val/auroc',
    }
    tb_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.')
    report_and_checkpoint = TuneReportCheckpointCallback(
        metrics=reported_metrics,
        filename='checkpoint',
        # on='training_end'
    )

    trainer = pl.Trainer(
        num_sanity_val_steps=-1,
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        # Fractional GPU requests are rounded up to a whole device.
        gpus=math.ceil(num_gpus),
        logger=tb_logger,
        callbacks=[report_and_checkpoint],
    )
    dm = NpyDataModule(config, data_dir)

    if not checkpoint_dir:
        model = StackGAN(config)
    else:
        # `StackGAN.load_from_checkpoint` currently errors here, so the state is
        # loaded manually and the epoch counter restored by hand.
        ckpt = pl_load(
            os.path.join(checkpoint_dir, 'checkpoint'),
            map_location=lambda storage, loc: storage)
        model = StackGAN._load_model_state(ckpt, config=config)
        trainer.current_epoch = ckpt['epoch']

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_StackGAN_asha(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Search StackGAN hyperparameters with the ASHA early-stopping scheduler.

    Args:
        num_samples: Number of configurations drawn from the search space.
        num_epochs: Epoch budget per trial; also the scheduler's ``max_t``.
        gpus_per_trial: GPUs reserved for every trial.
    """
    trainable = tune.with_parameters(
        train_StackGAN_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial)

    run_options = dict(
        name='tune_StackGAN_asha_model_j',
        # Trials are ranked by the highest reported 'auroc'.
        metric='auroc',
        mode='max',
        config={
            'learning_rate': tune.choice([1e-4]),
            'feature_maps': tune.choice([64]),
            'batch_size': tune.choice([128, 64]),
        },
        resources_per_trial={
            'cpu': 2,
            'gpu': gpus_per_trial,
            # 'tpu': 8,
        },
        num_samples=num_samples,
        local_dir='./drive/MyDrive/Logs',
        scheduler=ASHAScheduler(
            max_t=num_epochs,
            grace_period=1,
            reduction_factor=2),
        progress_reporter=CLIReporter(
            # overwrite=True,
            parameter_columns=['learning_rate', 'feature_maps', 'batch_size'],
            metric_columns=['lossG1', 'lossG2', 'lossD1', 'lossD2', 'lossR', 'auroc', 'training_iteration'],
        ),
        # restore='/content/drive/MyDrive/Logs/tune_StackGAN_1_asha_model_j/train_StackGAN_tune_checkpoint_fa25b_00000_0_batch_size=64,feature_maps=64,learning_rate=0.0001_2021-07-06_20-23-13/checkpoint_epoch=0-step=937',
        fail_fast=True,
        resume='PROMPT',
    )
    analysis = tune.run(trainable, **run_options)

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_asha_end__


# __tune_pbt_begin__
def tune_StackGAN_pbt(num_samples=10, num_epochs=10, gpus_per_trial=1):
    """Search StackGAN hyperparameters with Population Based Training.

    Args:
        num_samples: Number of trials in the population.
        num_epochs: Epoch budget handed to the training function.
        gpus_per_trial: GPUs reserved for every trial.
    """
    config = {
        'learning_rate': 1e-4,
        'feature_maps': 64,
        'batch_size': 64,
    }

    # Only mutate hyperparameters that leave the network shapes untouched.
    # 'feature_maps' sets the layer widths in StackGAN, so a cloned checkpoint
    # restored with a perturbed value would no longer match the rebuilt model.
    scheduler = PopulationBasedTraining(
        perturbation_interval=4,
        hyperparam_mutations={
            'learning_rate': [1e-4, 1e-3],
            'batch_size': [32, 64, 128]
        })

    reporter = CLIReporter(
        # overwrite=True,
        parameter_columns=['learning_rate', 'feature_maps', 'batch_size'],
        metric_columns=['lossG1', 'lossG2', 'lossD1', 'lossD2', 'lossR', 'auroc', 'training_iteration'],
        )

    analysis = tune.run(
        # resume=True,
        tune.with_parameters(
            train_StackGAN_tune_checkpoint,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial),
        name='tune_StackGAN_pbt_model_j',
        # Trials are ranked by the highest reported 'auroc'.
        metric='auroc',
        mode='max',
        resources_per_trial={
            'cpu': 2,
            'gpu': gpus_per_trial,
            # 'tpu': 8,
        },
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        local_dir='./drive/MyDrive/Logs',
        # restore='/content/drive/MyDrive/Logs/tune_StackGAN_1_asha_model_j/train_StackGAN_tune_checkpoint_fa25b_00000_0_batch_size=64,feature_maps=64,learning_rate=0.0001_2021-07-06_20-23-13/checkpoint_epoch=0-step=937',
        fail_fast = True,
        # resume='PROMPT',
        )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test',
        help='Finish quickly for testing',
        action='store_true',
    )
    # Unrecognised command-line arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    if not args.smoke_test:
        # Full run: ASHA scheduler only.
        tune_StackGAN_asha(num_samples=2, num_epochs=1, gpus_per_trial=1)
        # Population based training (disabled):
        # tune_StackGAN_pbt(num_samples=8, num_epochs=5, gpus_per_trial=1)
    else:
        # Smoke test: a single short trial with each scheduler.
        tune_StackGAN_asha(num_samples=1, num_epochs=6, gpus_per_trial=1)
        tune_StackGAN_pbt(num_samples=1, num_epochs=6, gpus_per_trial=1)