<a href="https://colab.research.google.com/github/souravraha/galaxy/blob/experimental/Lightning_Tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prerequisites

## Install/uninstall packages

In [None]:
# Install the extra dependencies this notebook needs (written for Google Colab).
# The `!` lines are IPython shell escapes, so this cell only runs inside a notebook.

print('Setting up colab environment')
# !pip uninstall -y -q pyarrow
!pip install -q lightning-bolts GPy
!pip install -q ray[debug] ray[default]
!pip install -U -q ray[tune]
# Optional TPU support (PyTorch/XLA wheel built for torch 1.8 / CPython 3.7):
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

# # A hack to force the runtime to restart, needed to include the above dependencies.
# print('Done installing! Restarting via forced crash (this is not an issue).')
# import os
# os._exit(0)

In [None]:
# On Google Colab, switch the runtime to TensorFlow 2.x (nothing needs to be uncommented).

try:
  # %tensorflow_version only exists in Colab; elsewhere any error it raises is swallowed.
  %tensorflow_version 2.x
except Exception:
  pass

## Import libraries

In [None]:
import os
import math
import numpy as np
from itertools import cycle
from matplotlib import pyplot as plt
# ------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
# ------------------------------------
# from torchvision.models import resnet18
from torchvision.utils import save_image
from torchvision import transforms, datasets
# ------------------------------------
import torchmetrics as tm
# ------------------------------------
import pytorch_lightning as pl
from pl_bolts.models.gans import DCGAN
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.self_supervised.resnets import BasicBlock, resnet18
from pytorch_lightning.utilities.cloud_io import load as pl_load
# ------------------------------------
from ray import tune
from ray.tune.stopper import TrialPlateauStopper
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.schedulers.pb2 import PB2
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback
# ------------------------------------

# Download and extract data

In [None]:
# Mount Google Drive: it holds the GLOBAL_VALS file, the pretrained checkpoints and the Tune logs.
from google.colab import drive
# drive.flush_and_unmount()
drive.mount('/content/drive', True, 600000)  # presumably force_remount=True, timeout_ms=600000 — confirm

Choose the model (dataset) to use for training/testing here. When you switch models, also update:

1.   the normalisation constants (GLOBAL) in the `npyImageData` class definition;
2.   the metric keys assigned in the Tune training wrapper;
3.   the name of the experiment being started or resumed.

In [None]:
# Google Drive file IDs of the per-model image archives ('d' and 'k' are unknown).
# The active lines download and unpack the archive for model 'f' (presumably into ./images,
# the DataModule's default data_dir — confirm).
# 'a': 1Cjcw2EWorhdhJSGoWOdxsEUDxvl943dt, 'b': 15yXXC4h5VsytP3Ak1jfUSjQhdgP2s23K, 'c': 1vuQ-pLzoKT4Hd_V7949r9eND9E2fB_u_,
# 'd': , 'e': 1wFuasvb7PthxXtMUlsD13uzYHWlWt06H, 'f': 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ, 
# 'g': 1SxQVosWeEjY3Pyn8LRXA11rLnZ9HK_7B, 'h': 1Atau0RH4oyLAiYReW-G9a8l9pUNltglF, 'i': 15lEgsR1p00KSHieaT9a1nkbJ86pDxwgp, 
# 'j': 1m0EQUbqZZeyl76XsQIKWU5Qd7jGmmWhB, 'k': , 'l': 1meTDi4aeWfdChOiXeLtUOGhjVDVu000e

# !rm -rf images
!gdown --id 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ
!tar zxf ./model_f.tgz
!rm ./model_f.tgz

# Alternative (unused) pure-Python download guarded by a file lock:
# def prepare_data(data_dir: str = '/content'):
#     with FileLock(os.path.expanduser(data_dir+'.lock')):
#         gdown.download('https://drive.google.com/uc?id=17l6H61tLAu26zGuei38r_T5ssjbYUeaJ', data_dir+'/model_f.tgz', quiet=True)
        
#         temp = tarfile.open(data_dir+'/model_f.tgz', 'r|gz')
#         temp.extractall()
#         temp.close()

# Class definitions

In [None]:
# Load the stored 'VALS' array (presumably the dataset statistics used for normalisation) and show it.
# NOTE(review): npyImageData hard-codes its Normalize constants and does not read GLOBAL;
# keep the two in sync by hand.
GLOBAL = np.load('/content/drive/MyDrive/git_repos/forging_new_worlds/GLOBAL_VALS_J.npz')
GLOBAL['VALS'].tolist()

## DataModule
This creates the dataloaders used to train, validate and test our models.

In [None]:
class npyImageData(pl.LightningDataModule):
    """LightningDataModule for single-channel ``.npy`` images stored in class sub-folders.

    Expected layout: ``<data_dir>/train/<class>/*.npy`` and ``<data_dir>/val/<class>/*.npy``.
    The ``val`` folder also serves as the test split.

    Args:
        config: Tune trial config; only ``config['bs']`` (batch size) is read.
        img_width: side length every image is resized to (nearest-neighbour interpolation).
        data_dir: root folder containing the ``train`` and ``val`` sub-folders; ``~`` is expanded.
    """

    def __init__(self, config, img_width: int = 150, data_dir: str = '/content/images/'):
        super().__init__()
        # save_hyperparameters() is not implemented for this DataModule, so store what we need directly.
        self.bs = config['bs']
        self.data_dir = os.path.expanduser(data_dir)

        # Change the normalisation constants below when changing dataset ---------------------------
        # Dataset statistics recorded by the author:
        #   F : [71.75926373866668, 96.139484964214, 5.0, 961.0]
        #   J : [50.271541595458984, 94.8838882446289]
        # NOTE(review): (x - 5) / 961 followed by (y - 0.5) / 0.5 maps [5, 961] to roughly [-1, 0.99];
        # an exact [-1, 1] min-max scaling would use std=(956,) (= 961 - 5). Left unchanged because the
        # saved checkpoints were presumably trained with the current constants.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Normalize(mean=(5,), std=(961,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
            transforms.Resize(img_width, transforms.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def npy_loader(path):
        """Load one ``.npy`` array as a float32 tensor with a leading channel axis of size 1."""
        # Convert to a tensor first and then to float32: numpy's default float64 would make the
        # conv layers raise a dtype error.
        return torch.from_numpy(np.load(path)).unsqueeze(0).float()

    def _folder(self, split: str):
        """Return a ``DatasetFolder`` over ``<data_dir>/<split>`` that loads ``.npy`` files."""
        # ('.npy',) is a real one-element tuple; ('.npy') would just be the string '.npy'.
        return datasets.DatasetFolder(
            os.path.join(self.data_dir, split), self.npy_loader, ('.npy',), self.transform,
        )

    def setup(self, stage: str = None):
        """Create the datasets for ``stage`` ('fit', 'test', or None for both) and record ``dims``."""
        if stage in ('fit', None):
            self.train_set = self._folder('train')
            self.val_set = self._folder('val')
            self.dims = tuple(self.train_set[0][0].shape)

        if stage in ('test', None):
            # There is no separate test split on disk; the validation folder is reused.
            self.test_set = self._folder('val')
            # LightningDataModule.__init__ may already define `dims` as None, in which case a plain
            # getattr default would never apply; fill it in whenever it is still unset.
            if getattr(self, 'dims', None) is None:
                self.dims = tuple(self.test_set[0][0].shape)

    def _loader(self, dataset):
        """Wrap ``dataset`` in a shuffled DataLoader using the configured batch size."""
        # shuffle=True on every split is presumably deliberate: training runs use limit_*_batches < 1,
        # and DatasetFolder orders samples by class, so an unshuffled subset could contain a single
        # class and break AUROC/ROC.
        return DataLoader(dataset, self.bs, shuffle=True, num_workers=os.cpu_count(), pin_memory=False)

    def train_dataloader(self):
        return self._loader(self.train_set)

    def val_dataloader(self):
        return self._loader(self.val_set)

    def test_dataloader(self):
        return self._loader(self.test_set)

## ResNet
We slightly modify a ResNet-18 so that it accepts single-channel images and applies dropout before its final layer.

In [32]:
class LensResnet(pl.LightningModule):
    """ResNet-18 (pl_bolts) classifier adapted to single-channel lens images.

    The stem conv is rebuilt for ``image_channels`` inputs and the final fully-connected layer is
    wrapped with dropout. Training logs weighted AUROC; validation logs the minimum per-class AUROC
    and a multi-class ROC figure.

    Args:
        config: Tune trial config; ``config['lr']`` is the Adam learning rate.
        image_channels: number of input channels.
        num_classes: number of target classes.
        **kwargs: forwarded to ``pl_bolts`` ``resnet18``.
    """

    def __init__(self, config, image_channels: int = 1, num_classes: int = 3, **kwargs):
        super().__init__()
        # NOTE(review): `ignore` takes argument *names*. Given the config dict it appears to ignore
        # nothing, so `config` ends up in the saved hparams. Stage1/Stage2 call
        # `LensResnet.load_from_checkpoint(path)` without a config and presumably rely on that,
        # so the call is intentionally left as is.
        self.save_hyperparameters(ignore=config)
        self.lr = config['lr']

        # ResNet-18 backbone (pretrained weights only if requested through kwargs).
        self.backbone = resnet18(num_classes=self.hparams.num_classes, **kwargs)
        # The stem conv must be rebuilt, not just have in_channels changed, because its weight shape changes.
        self.backbone.conv1 = nn.Conv2d(self.hparams.image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone.fc = nn.Sequential(nn.Dropout(0.5), self.backbone.fc)

        self.train_metrics = tm.MetricCollection(
            [tm.AUROC(self.hparams.num_classes, average='weighted')],
            prefix='LensResnet/train/',
        )
        self.val_metrics = tm.MetricCollection(
            [tm.AUROC(self.hparams.num_classes, average=None), tm.ROC(self.hparams.num_classes)],
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.backbone.parameters(), self.lr)

    def forward(self, x, prob=False):
        """Return class logits for ``x``, or softmax probabilities over dim 1 when ``prob`` is True."""
        # The backbone output is indexed with [-1] and the classifier head is applied explicitly here.
        logits = self.backbone.fc(self.backbone(x)[-1])
        return F.softmax(logits, 1) if prob else logits

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        # Keep only scalars here, to avoid logging errors.
        self.log('LensResnet/train/loss', loss)

        self.train_metrics.update(F.softmax(logits, 1), labels)
        # AUROC is undefined until every class has been seen. Use torch.cat rather than torch.stack
        # because the accumulated target batches can differ in size (e.g. a final partial batch).
        seen = torch.unique(torch.cat(self.train_metrics.AUROC.target))
        if len(seen) == self.hparams.num_classes:
            # log_dict of a MetricCollection resets the metric states automatically at epoch end.
            self.log_dict(self.train_metrics, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        # Keep only scalars here, to avoid logging errors.
        self.log('LensResnet/val/loss', loss)
        self.val_metrics.update(F.softmax(logits, 1), labels)

    def validation_epoch_end(self, Listofdicts):
        """Log the minimum per-class AUROC and a multi-class ROC figure, then reset the metrics."""
        try:
            scoresDict = self.val_metrics.compute()
        except Exception as err:
            # Best effort: AUROC/ROC cannot be computed e.g. when a class is missing from the
            # batches seen this epoch. Report it and skip the plot instead of crashing.
            targets = self.val_metrics.AUROC.target
            seen = torch.unique(torch.cat(targets)).tolist() if targets else []
            print(f'Skipping validation ROC at epoch {self.current_epoch}: {err} (classes seen: {seen})')
        else:
            self.log('LensResnet/val/auroc', scoresDict['AUROC'].min(), prog_bar=True)
            fprList, tprList, _ = scoresDict['ROC']

            fig = plt.figure()
            colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                plt.plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                         label='ROC curve of class {0} (area = {1:0.2f})'
                         ''.format(i, scoresDict['AUROC'][i]))
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Multi-class ROC')
            plt.legend(loc='lower right')

            self.logger.experiment.add_figure('LensResnet/val/ROC', fig)
            fig.savefig(os.path.join(str(tune.get_trial_dir()), 'ROC_epoch_' + str(self.current_epoch) + '.pdf'))
        finally:
            self.val_metrics.reset()

    def on_fit_end(self):
        delattr(self, 'train_metrics')
        delattr(self, 'val_metrics')

## Stage 1
Here we subclass a DCGAN to create our low resolution GAN.

In [None]:
# Drive paths of the best LensResnet Tune trial checkpoints for datasets F and J;
# Stage1/Stage2 load `<path>/checkpoint` from these.
BEST_RESNET_F = '/content/drive/MyDrive/Logs/LensResNet_F/train_LensResnet_tune_checkpoint_efb38_00000_0_2021-07-09_04-20-04/checkpoint_epoch=9-step=16407'
BEST_RESNET_J = '/content/drive/MyDrive/Logs/LensResNet_J/train_LensResnet_tune_checkpoint_21355_00000_0_2021-07-09_16-17-17/checkpoint_epoch=27-step=4687'

In [None]:
m = DCGAN()  # scratch: a default pl_bolts DCGAN; `m` is not referenced elsewhere in the visible notebook

In [None]:
inp = torch.rand(20, 100, 20, 30)  # scratch tensor; not referenced elsewhere in the visible notebook

In [None]:
class Stage1(DCGAN):
    """Low-resolution conditional DCGAN (first stage of the two-stage pipeline).

    Class conditioning multiplies the latent noise element-wise by a learned label embedding.
    Validation tracks FID measured in the feature spaces of two pretrained LensResnet backbones.

    Args:
        config: Tune trial config; reads ``n_fmaps`` (generator and discriminator feature maps)
            and ``learning_rate``.
        num_classes: number of image classes (rows of the label embedding).
        **kwargs: only recorded by ``save_hyperparameters``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], learning_rate=config['learning_rate'])
        self.save_hyperparameters(ignore=config)

        # Label embedding used by forward() to condition the noise.
        self.generator.add_module('emb', nn.Embedding(self.hparams.num_classes, self.hparams.latent_dim))

        # Pretrained classifiers kept as attributes (instead of being returned by a method), so they
        # are registered as submodules and follow this module's device.
        self.modelF = getattr(self, 'modelF', LensResnet.load_from_checkpoint(os.path.join(BEST_RESNET_F, 'checkpoint')).eval())
        self.modelF.backbone.fc = nn.Identity()

        self.modelJ = getattr(self, 'modelJ', LensResnet.load_from_checkpoint(os.path.join(BEST_RESNET_J, 'checkpoint')).eval())
        self.modelJ.backbone.fc = nn.Identity()

        # NOTE(review): LensResnet.forward indexes the backbone output with [-1]; confirm that tm.FID
        # accepts a feature extractor whose forward returns such a sequence rather than a tensor.
        self.val_metrics = tm.MetricCollection(
            {
                'ResNet(F)' : tm.FID(self.modelF.backbone),
                'ResNet(J)' : tm.FID(self.modelJ.backbone),
            },
            prefix='Stage1/',
            postfix='/val/FID',
        )

        # When True, forward() returns the list of intermediate generator activations.
        self.extract = False

    def forward(self, noise, labels = None):
        """Generate images from ``noise`` (last dim = latent_dim) conditioned on ``labels``.

        Falls back to ``self.labels`` (set during training_step) and otherwise to random labels.
        """
        if labels is None:
            labels = getattr(self, 'labels',
                             torch.randint(self.hparams.num_classes, noise.shape[:-1], device=self.device))  # last dimension is the hidden dimension
        inp = noise.mul(self.generator.emb(labels))
        feat = [inp.view(-1, inp.shape[-1], 1, 1)]
        if self.extract:
            for layer in self.generator.gen:
                feat.append(layer(feat[-1]))
            return feat
        return self.generator(feat[-1])

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run the discriminator (optimizer 0) or generator (optimizer 1) step on ``batch``."""
        # forward() reads self.labels, so the batch labels condition the generated images. The
        # attribute is removed afterwards (even on error) so it cannot leak into later calls.
        real, self.labels = batch
        try:
            result = None
            if optimizer_idx == 0:
                result = self._disc_step(real)
            if optimizer_idx == 1:
                result = self._gen_step(real)
            return result
        finally:
            del self.labels

    def _disc_step(self, real):
        disc_loss = self._get_disc_loss(real)
        self.log('Stage1/D/train/loss', disc_loss)
        return disc_loss

    def _gen_step(self, real):
        gen_loss = self._get_gen_loss(real)
        self.log('Stage1/G/train/loss', gen_loss)
        return gen_loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        self.val_metrics.update(imgs, real=True)
        # Upsample generated samples to 150 px so they are compared at the validation image size.
        out = F.interpolate(self(self._get_noise(labels.shape[0], self.hparams.latent_dim), labels), 150)
        self.val_metrics.update(out, real=False)

    def validation_epoch_end(self, listofDicts):
        """Log both FID scores, reset their states, and save one generated sample per class."""
        # The former `self.val_me` was a typo and raised AttributeError every epoch.
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

        labels = torch.arange(self.hparams.num_classes, device=self.device)
        fakes = self(self._get_noise(labels.shape[0], self.hparams.latent_dim), labels)
        save_image(F.interpolate(fakes, 150),
                   os.path.join(str(self.trainer.log_dir), 'Fake_epoch_{:02d}.pdf'.format(self.current_epoch)),
                   # kwargs for make_grid
                   normalize=True, value_range=(-1, 1))

    def on_fit_end(self):
        # Drop the helper models, metric states and any leftover labels. `labels` is normally
        # already gone (training_step deletes it), so every attribute is guarded.
        for name in ('modelF', 'modelJ', 'labels', 'val_metrics'):
            if hasattr(self, name):
                delattr(self, name)

## Stage 2
Here we subclass a DCGAN to create our high resolution GAN.

In [None]:
def post_plotting(ax):
    """Finish an ROC axes: draw the chance-level diagonal, then place the legend bottom-right."""
    diagonal = [0, 1]
    ax.plot(diagonal, diagonal, 'k--')  # random-classifier reference line
    ax.legend(loc='lower right')

In [None]:
class Generator2(nn.Module):
    """Stage-2 generator that refines a low-resolution image into an ``image_width`` x ``image_width`` one.

    Pipeline:
    - a strided conv halves the spatial size (e.g. 64 -> 32);
    - ``res_depth`` residual ``BasicBlock``s follow, then a conv/BN/ReLU block;
    - a long skip connection adds back the strided-conv output;
    - two nearest-neighbour upsample + conv stages go to ``image_width // 2`` and ``image_width``;
    - a final ``tanh`` produces the output.

    Args:
        ngf: feature maps in the residual trunk (the first upsampling conv widens to ``4 * ngf``).
        image_channels: channels of both the input and the output image.
        res_depth: number of residual blocks.
        image_width: output side length (default 150, the dataset image size).
    """

    def __init__(self, ngf: int = 128, image_channels: int = 1, res_depth: int = 6, image_width: int = 150):
        super().__init__()

        ker, strd = 4, 2
        pad = (ker - 2) // 2  # kernel 4, stride 2, padding 1 halves H and W
        res_ker, res_strd, res_pad = 3, 1, 1  # size-preserving 3x3 convs

        # 64 -> 32
        self.preprocessing = nn.Sequential(
            nn.Conv2d(image_channels, ngf, ker, strd, pad, bias=False),
            nn.ReLU(True)
        )
        # Residual trunk.
        self.residual = nn.Sequential(*[BasicBlock(ngf, ngf) for _ in range(res_depth)])

        self.ending_residual = nn.Sequential(
            nn.Conv2d(ngf, ngf, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )

        # Submodule names and order are kept unchanged so existing checkpoints still load.
        mode = 'nearest'  # upscaling method is nearest-neighbour
        self.main = nn.Sequential(
            # 32 -> image_width // 2
            nn.Upsample(image_width // 2, mode=mode),
            nn.Conv2d(ngf, ngf*4, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # -> image_width
            nn.Upsample(image_width, mode=mode),
            nn.Conv2d(ngf*4, image_channels, res_ker, res_strd, res_pad, bias=False),
            nn.Tanh()
        )

    def forward(self, in_x):
        x_p = self.preprocessing(in_x)
        x_r = self.ending_residual(self.residual(x_p))
        # Large residual connection around the whole residual trunk.
        return self.main(x_r + x_p)

In [None]:
# Best Stage1 (PBT) checkpoint for dataset F; Stage2 loads `<path>/checkpoint` as its low-resolution stage.
BEST_STAGE1_F = '/content/drive/MyDrive/Logs/F/Stage1/pbt/train_Stage1_fdd99_00003_3_n_fmaps=64_2021-08-24_17-28-51/checkpoint_epoch=3-step=37499/'

In [None]:
class Stage2(DCGAN):
    """High-resolution stage: refines Stage1 samples to full resolution with ``Generator2``.

    The generator's input ("noise") is a Stage1 (``lowres``) sample conditioned on the labels.
    Validation scores generated images with two pretrained LensResnet classifiers (per-class
    AUROC and ROC).

    Args:
        config: Tune trial config; reads ``n_fmaps``, ``learning_rate`` and ``res_depth``.
        num_classes: number of image classes.
        **kwargs: only recorded by ``save_hyperparameters``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], learning_rate=config['learning_rate'])
        self.save_hyperparameters(ignore=config)

        self.generator = Generator2(self.hparams.feature_maps_gen, self.hparams.image_channels, config['res_depth'])

        # Pretrained models kept as attributes (not returned by a method) so they follow this module's device.
        # NOTE(review): Lightning presumably switches the whole module back to train mode every
        # training epoch, which would undo these `.eval()` calls — confirm, and re-apply if needed.
        self.modelF = getattr(self, 'modelF', LensResnet.load_from_checkpoint(os.path.join(BEST_RESNET_F, 'checkpoint')).eval())
        self.modelJ = getattr(self, 'modelJ', LensResnet.load_from_checkpoint(os.path.join(BEST_RESNET_J, 'checkpoint')).eval())
        # Workaround: the trained Stage1 model supplies the low-resolution inputs.
        self.lowres = getattr(self, 'lowres', Stage1.load_from_checkpoint(os.path.join(BEST_STAGE1_F, 'checkpoint')).eval())

        metrics = tm.MetricCollection(
            [
             tm.AUROC(num_classes=self.hparams.num_classes, compute_on_step=False, average=None),
             tm.ROC(num_classes=self.hparams.num_classes, compute_on_step=False),
            ]
        )
        self.metricsF = metrics.clone()
        self.metricsJ = metrics.clone()

    def forward(self, noise):
        return self.generator(noise)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run the discriminator (optimizer 0) or generator (optimizer 1) step on ``batch``."""
        # _get_noise() conditions the Stage1 inputs on these labels.
        real, self.labels = batch

        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real)
        if optimizer_idx == 1:
            result = self._gen_step(real)
        return result

    def _disc_step(self, real):
        disc_loss = self._get_disc_loss(real)
        self.log('Stage2/D/train/loss', disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real):
        gen_loss = self._get_gen_loss(real)
        self.log('Stage2/G/train/loss', gen_loss, on_epoch=True)
        return gen_loss

    def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
        """Non-saturating generator loss: the discriminator should label fakes as real."""
        fake_pred = self._get_fake_pred(real)
        fake_gt = torch.ones_like(fake_pred)
        gen_loss = self.criterion(fake_pred, fake_gt)

        # class_pred =  self._get_class_pred(len(real))
        # gen_loss += F.cross_entropy(class_pred, self.labels)

        return gen_loss

    def _get_class_pred(self, batch_size) -> torch.Tensor:
        """Classify a freshly generated batch with modelF (currently unused; see _get_gen_loss)."""
        # modelF(...) applies the classifier head; modelF.backbone(...) alone returns the backbone
        # output that LensResnet.forward still indexes with [-1], which is not usable as logits.
        return self.modelF(self(self._get_noise(batch_size, self.hparams.latent_dim)))

    def _get_noise(self, n_samples: int, latent_dim: int, labels = None):
        """Return Stage1 samples for ``n_samples`` latent vectors, conditioned on ``labels``.

        ``labels`` defaults to the current batch labels stored by training_step.
        """
        if labels is None:
            labels = self.labels
        # lowres is a fixed pretrained stage (presumably covered by neither of DCGAN's optimizers),
        # so do not build an autograd graph through it; otherwise gradients would pile up on its
        # parameters at every step.
        with torch.no_grad():
            return self.lowres(super()._get_noise(n_samples, latent_dim), labels)

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        out = self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels))
        # AUROC/ROC expect class probabilities, as LensResnet's own validation feeds them.
        self.metricsF.update(self.modelF(out, prob=True), labels)
        self.metricsJ.update(self.modelJ(out, prob=True), labels)

    def validation_epoch_end(self, listofDicts):
        """Log min per-class AUROC for both classifiers, plot their ROC curves and save samples."""
        fig, ax = plt.subplots(1, 2,
            subplot_kw={'xlim': [0, 1], 'xlabel': 'False Positive Rate', 'ylim': [0, 1.05],
                        'ylabel': 'True Positive Rate',
            },
            figsize=[11, 5],
        )
        for j, letter in enumerate(['F', 'J']):
            metrics = getattr(self, 'metrics' + letter)
            output = metrics.compute()
            # The metrics are updated manually, so clear them or they accumulate across epochs.
            metrics.reset()
            self.log('Stage2/ResNet(' + letter + ')/val/auroc', output['AUROC'].min())
            fprList, tprList, _ = output['ROC']

            colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                ax[j].plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                           label='ROC curve of class {0} (area = {1:0.2f})'
                           ''.format(i, output['AUROC'][i]))
            post_plotting(ax[j])
            ax[j].set_title('Multi-class ROC (' + letter + ')')

        fig.tight_layout()
        self.logger.experiment.add_figure('Stage2/ResNet/val/ROC', fig)
        fig.savefig(os.path.join(str(self.trainer.log_dir), 'ROC_epoch_{:02d}.pdf'.format(self.current_epoch)))

        labels = torch.arange(self.hparams.num_classes, device=self.device)
        save_image(self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels)),
                   os.path.join(str(self.trainer.log_dir), 'Fake_epoch_{:02d}.pdf'.format(self.current_epoch)),
                   # kwargs for make_grid
                   normalize=True, value_range=(-1, 1))

    def on_fit_end(self):
        # Drop the helper models and the last batch labels; guarded in case an attribute is absent.
        for name in ('modelF', 'modelJ', 'labels', 'lowres'):
            if hasattr(self, name):
                delattr(self, name)

## StackGAN:
Here we define the full GAN module that we use to generate representative images.

In [None]:
class StackGAN(pl.LightningModule):
    """End-to-end two-stage GAN with an auxiliary classifier.

    Components:
    - G1/D1 work at ``image_width`` px;
    - G2/D2 work at full resolution;
    - ``R`` is an auxiliary classifier with one extra "fake" class.

    Args:
        config: Tune trial config; reads ``feature_maps`` and ``learning_rate``.
            NOTE(review): LensResnet also reads ``config['lr']``, so the config must carry it too.
        noise_size: latent dimension of G1.
        image_width: side length of the stage-1 images (real images are resized to it for D1).
        num_classes: number of real image classes.
        image_channels: number of image channels.
        b1: Adam beta1 for every optimizer.
        **kwargs: only recorded by ``save_hyperparameters``.
    """

    def __init__(self, config, noise_size: int = 100, image_width = 64,
                    num_classes: int = 3, image_channels: int = 1, b1: float = 0.5, **kwargs):
        super().__init__()
        self.save_hyperparameters(ignore = config)
        self.feature_maps = config['feature_maps']
        self.lr = config['learning_rate']
        # The DCGAN building blocks are not imported at the top of the notebook, so import them here.
        from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator
        # -------------------------------------
        # G1: a DCGAN generator whose first block loses element 1 (presumably its BatchNorm),
        # plus a label embedding. The stock modules are edited in place instead of subclassed.
        self.G1 = DCGANGenerator(self.hparams.noise_size, self.feature_maps, self.hparams.image_channels).apply(self._weights_init)
        first_block = list(self.G1.gen[0])
        del first_block[1]
        self.G1.gen[0] = nn.Sequential(*first_block)
        self.G1.add_module('label_emb', nn.Embedding(self.hparams.num_classes, self.hparams.noise_size))
        # ------------------------------------
        self.D1 = DCGANDiscriminator(self.feature_maps, self.hparams.image_channels).apply(self._weights_init)
        # -------------------------------------
        # Keyword arguments: Generator2's signature is (ngf, image_channels, ...). The previous
        # positional call passed them in swapped order.
        self.G2 = Generator2(ngf=self.feature_maps, image_channels=self.hparams.image_channels).apply(self._weights_init)
        # -------------------------------------
        # D2: DCGAN discriminator with one extra block inserted at index 2 (mutates this instance,
        # not the class definition). Its forward method needs no change, so no subclass is needed.
        self.D2 = DCGANDiscriminator(self.feature_maps, self.hparams.image_channels)
        extra = self.D2._make_disc_block(self.feature_maps * 2, self.feature_maps * 2)
        blocks = list(self.D2.disc)
        blocks.insert(2, extra)
        self.D2.disc = nn.Sequential(*blocks)
        self.D2.apply(self._weights_init)
        # -------------------------------------
        # R: classifier over the real classes plus one "fake" class (label == num_classes).
        self.R = LensResnet(config, num_classes=self.hparams.num_classes + 1).apply(self._weights_init)
        # -------------------------------------
        # Frozen reference classifier used only for validation AUROC/ROC.
        self.pretrained = LensResnet(config)
        ckpt = pl_load(os.path.join(
            '/content/drive/MyDrive/Logs/tune_LensResnet_asha_model_j/train_LensResnet_tune_checkpoint_e38cb_00000_0_batch_size=128,learning_rate=0.001_2021-07-06_17-52-11/checkpoint_epoch=17-step=1406',
            # '/content/drive/MyDrive/Logs/tune_LensResnet_asha_model_f/train_LensResnet_tune_checkpoint_e32ba_00000_0_batch_size=64,learning_rate=0.0001_2021-07-06_03-33-10/checkpoint_epoch=14-step=4689',
            'checkpoint'),
            map_location=lambda storage, loc: storage)
        self.pretrained._load_model_state(ckpt)
        # -------------------------------------
        self.criterion1 = nn.BCELoss()
        self.criterion2 = nn.CrossEntropyLoss()

    @staticmethod
    def _weights_init(m):
        """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02), biases 0."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    def forward(self, noise, labels = None):
        """Return ``(stage-2 image, stage-1 image)`` generated from ``noise`` (last dim = noise_size).

        Random labels are drawn when none are given. G2 receives a detached copy of the stage-1
        output, so G2's losses do not reach G1 through this path.
        """
        if labels is None:
            # Last dimension is the hidden dimension; create the labels on the noise's device.
            labels = torch.randint(self.hparams.num_classes, noise.shape[:-1], device=noise.device)
        inp = torch.mul(noise, self.G1.label_emb(labels))
        out1 = self.G1(inp.view(-1, inp.shape[-1], 1, 1))
        out2 = self.G2(out1.detach())
        return out2, out1

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One step for optimizer ``optimizer_idx``: 0=G1, 1=D1, 2=G2, 3=D2, 4=R."""
        imgs, labels = batch
        temp2, temp1 = self(torch.randn(labels.shape[0], self.hparams.noise_size).type_as(imgs), labels)

        if optimizer_idx == 0:
            loss = self.criterion1(self.D1(temp1), torch.ones_like(labels, dtype=torch.float32))
            self.log('G1/train/loss/disc', loss)
            # Out-of-place add: an in-place add_ could alter the value just logged.
            # self.R(...) applies R's classifier head; R.backbone(...) alone does not give logits.
            loss = loss + self.criterion2(self.R(self.G2(temp1)), labels)
            self.log('G1/train/loss/full', loss)

        elif optimizer_idx == 1:
            real = self.D1(F.interpolate(imgs, self.hparams.image_width, mode='nearest'))
            fake = self.D1(temp1.detach())
            prediction = torch.cat((real, fake))
            target = torch.cat((torch.ones_like(real), torch.zeros_like(fake)))
            loss = self.criterion1(prediction, target)
            self.log('D1/train/loss', loss)

        elif optimizer_idx == 2:
            loss = self.criterion1(self.D2(temp2), torch.ones_like(labels, dtype=torch.float32))
            self.log('G2/train/loss/disc', loss)
            loss = loss + self.criterion2(self.R(temp2), labels)
            self.log('G2/train/loss/full', loss)

        elif optimizer_idx == 3:
            real, fake = self.D2(imgs), self.D2(temp2.detach())
            prediction = torch.cat((real, fake))
            target = torch.cat((torch.ones_like(real), torch.zeros_like(fake)))
            loss = self.criterion1(prediction, target)
            self.log('D2/train/loss', loss)

        elif optimizer_idx == 4:
            real, fake = self.R(imgs), self.R(temp2.detach())
            prediction = torch.cat((real, fake))
            # Generated images get the extra "fake" label, num_classes.
            target = torch.cat((labels, self.hparams.num_classes * torch.ones_like(labels)))
            loss = self.criterion2(prediction, target)
            self.log('R/train/loss', loss)

        else:
            raise ValueError(f'unexpected optimizer_idx {optimizer_idx}')

        return loss

    def configure_optimizers(self):
        """One Adam optimizer per component, in the order G1, D1, G2, D2, R."""
        betas = (self.hparams.b1, 0.999)
        return tuple(torch.optim.Adam(module.parameters(), self.lr, betas)
                     for module in (self.G1, self.D1, self.G2, self.D2, self.R))

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        temp2, _ = self(torch.randn(labels.shape[0], self.hparams.noise_size).type_as(imgs), labels)
        # AUROC/ROC below expect class probabilities, not logits.
        return {'pred': self.pretrained(temp2.detach(), prob=True), 'target': labels}

    def validation_epoch_end(self, listofDicts):
        """Log min per-class AUROC of the reference classifier and save a multi-class ROC figure."""
        prediction = torch.cat([x['pred'] for x in listofDicts])
        target = torch.cat([x['target'] for x in listofDicts])
        aurocTensor = tm.functional.auroc(prediction, target, num_classes=self.hparams.num_classes, average=None)
        self.log('Pre/val/auroc', aurocTensor.min())
        fprList, tprList, _ = tm.functional.roc(prediction, target, num_classes=self.hparams.num_classes)

        fig = plt.figure()
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
        for i, color in zip(range(self.hparams.num_classes), colors):
            plt.plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                     label='ROC curve of class {0} (area = {1:0.2f})'
                     ''.format(i, aurocTensor[i].cpu()))
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Multi-class ROC')
        plt.legend(loc='lower right')

        self.logger.experiment.add_figure('StackGAN/val/ROC', fig)
        fig.savefig(os.path.join(str(tune.get_trial_dir()), 'ROC_epoch_' + str(self.current_epoch) + '.pdf'))

# Tune Hyperparameters


## ResNet
Here we tune hyperparameters as we train our modified ResNet.

In [None]:
# Delete the logs of a previous run before starting a fresh one.
# NOTE(review): this removes 'F/LensResnet/asha_tanh', while tune_LensResnet_pbt uses the experiment
# name 'F/LensResnet/pbt_tanh' — confirm which folder is meant.
%rm -rf ./drive/MyDrive/Logs/F/LensResnet/asha_tanh/

In [33]:
# __tune_train_checkpoint_begin
def train_LensResnet(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable that fits one LensResnet trial on 5% of the train/val batches per epoch.

    Args:
        config: trial hyperparameters; the DataModule reads ``config['bs']`` and the model ``config['lr']``.
        checkpoint_dir: set by Tune when restoring a trial; training then resumes from
            ``<checkpoint_dir>/checkpoint``.
        num_epochs: maximum number of epochs.
        num_gpus: GPUs allotted to the trial (fractional values are rounded up).
    """
    batch_fraction = 0.05
    trainer_kwargs = dict(
        limit_train_batches=batch_fraction,
        limit_val_batches=batch_fraction,
        # Refresh roughly once per (limited) epoch; 8250 is presumably the training-set size.
        progress_bar_refresh_rate=math.ceil(8250 * batch_fraction // config['bs']),
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=math.ceil(num_gpus),  # Tune may hand out fractional GPUs; the Trainer needs an int
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            # Report val loss/AUROC to Tune under the names 'loss'/'auroc' and save a checkpoint.
            TuneReportCheckpointCallback({
                'loss': 'LensResnet/val/loss',
                'auroc': 'LensResnet/val/auroc',
            }),
        ],
        stochastic_weight_avg=True,  # works with only one optimizer
        # Other options tried: benchmark=True, precision=16,
        # gradient_clip_val=0.5, gradient_clip_algorithm='value'.
    )

    datamodule = npyImageData(config)  # image width can be specified here
    if checkpoint_dir is not None:
        # Resume the Lightning run from the checkpoint Tune handed back.
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = LensResnet(config)
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, datamodule)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune LensResnet's ``lr`` and ``bs`` with Population Based Training.

    Every trial starts from ``lr=1e-4`` and ``bs=8``. PBT may perturb trials
    after every training iteration (``perturbation_interval=1``), sampling
    ``lr`` log-uniformly in [1e-7, 1e-1] and ``bs`` from 8..128. Trials are
    ranked by lowest ``loss``, and the best checkpoint path is printed at the end.

    Args:
        num_samples: Number of trials in the population.
        num_epochs: Forwarded to ``train_LensResnet``.
        gpus_per_trial: GPUs reserved per trial; every trial also reserves all CPUs.
    """
    trainable = tune.with_parameters(
        train_LensResnet,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    # RB2 is not an option: it requires continuous mutations, and 'bs' is a list.
    pbt = PopulationBasedTraining(
        time_attr='training_iteration',
        quantile_fraction=0.2,
        resample_probability=0.2,
        perturbation_interval=1,
        hyperparam_mutations={
            'lr': tune.loguniform(1e-7, 1e-1),
            'bs': [8, 16, 32, 64, 128],
        },
    )
    notebook_reporter = JupyterNotebookReporter(
        overwrite=True,
        parameter_columns=['lr', 'bs'],
        metric_columns=['loss', 'auroc', 'training_iteration'],
    )

    analysis = tune.run(
        trainable,
        # Change the folder name when changing dataset.
        name='F/LensResnet/pbt_tanh',
        metric='loss',
        mode='min',
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        config={'lr': 1e-4, 'bs': 8},
        scheduler=pbt,
        progress_reporter=notebook_reporter,
        fail_fast=True,
        num_samples=num_samples,
        resume='PROMPT',
    )
    print('Best checkpoint path found is: ', analysis.best_checkpoint)

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args skips argv entries it does not recognise instead of failing.
    args, _ = parser.parse_known_args()

    # Population based training: a tiny run for the smoke test, the real one otherwise.
    if args.smoke_test:
        pbt_run = dict(num_samples=1, num_epochs=6)
    else:
        pbt_run = dict(num_samples=5, num_epochs=5)
    tune_LensResnet_pbt(**pbt_run, gpus_per_trial=torch.cuda.device_count())

Trial name,status,loc,lr,bs,loss,auroc,training_iteration
train_LensResnet_783dc_00001,TERMINATED,,6.4e-05,32,1.10258,0.501559,4
train_LensResnet_783dc_00000,TERMINATED,,3.89287e-07,8,1.11172,0.498633,4
train_LensResnet_783dc_00004,TERMINATED,,8e-05,16,1.07906,0.524675,4
train_LensResnet_783dc_00003,TERMINATED,,9.6e-05,128,1.11952,0.512718,4
train_LensResnet_783dc_00002,ERROR,,6.4e-05,8,1.11477,0.500907,4

Trial name,# failures,error file
train_LensResnet_783dc_00002,1,/content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_783dc_00002_2_2021-08-28_20-30-00/error.txt


TuneError: ignored

In [34]:
!cat /content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_783dc_00002_2_2021-08-28_20-30-00/error.txt

Failure # 1 (occurred at 2021-08-28_22-29-51)
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/trial_runner.py", line 739, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/ray_trial_executor.py", line 746, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/usr/local/lib/python3.7/dist-packages/ray/_private/client_mode_hook.py", line 82, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ray/worker.py", line 1621, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): [36mray::ImplicitFunc.train_buffered()[39m (pid=6891, ip=172.28.0.2, repr=<types.ImplicitFunc object at 0x7fcda3945990>)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/trainable.py", line 178, in train_buffered
    result = self.train()
  File "/usr/local/lib/python3.7/dist-packages/r

In [None]:
# __tune_train_checkpoint_begin
def train_LensResnet_tune_checkpoint(config, checkpoint_dir=None,
                                     num_epochs=10,
                                     num_gpus=torch.cuda.device_count()
                                     ):
    """Train one ``LensResnet`` trial under Ray Tune, warm-starting from a checkpoint.

    When Tune supplies a checkpoint, it is loaded manually through
    ``LensResnet._load_model_state`` (a workaround, see below) with the
    *current* ``config``. The Trainer's epoch counter is then set from the
    checkpoint.

    Args:
        config: Trial hyperparameters, forwarded to ``npyImageData`` and ``LensResnet``.
        checkpoint_dir: Directory provided by Tune when the trial restarts from a
            checkpoint; ``<checkpoint_dir>/checkpoint`` is loaded.
        num_epochs: Forwarded to the Trainer as ``max_epochs``.
        num_gpus: GPUs allotted to the trial; rounded up because Tune may pass a
            fractional count while the Trainer needs an integer.
    """
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        num_sanity_val_steps=0,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(
            save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            TuneReportCheckpointCallback(
                metrics={
                    # NOTE(review): train_LensResnet reports 'LensResnet/val/...';
                    # confirm which prefix the model currently logs.
                    'loss': 'ResNet/val/loss',
                    'auroc': 'ResNet/val/auroc',
                },
            )
        ],
        stochastic_weight_avg=True,
    )

    # Build the datamodule from `config` alone, as train_LensResnet does; this
    # function no longer defines a `data_dir` (passing it raised NameError).
    dm = npyImageData(config)

    if checkpoint_dir:
        # LensResnet.load_from_checkpoint(...) currently errors, so load the
        # checkpoint file directly and rebuild the model with the current config.
        ckpt = pl_load(
            os.path.join(checkpoint_dir, 'checkpoint'),
            # Deserialize every storage onto the CPU.
            map_location=lambda storage, loc: storage)
        model = LensResnet._load_model_state(ckpt, config=config)
        # Continue the epoch count from where the checkpoint left off.
        trainer.current_epoch = ckpt['epoch']
    else:
        model = LensResnet(config)

    trainer.fit(model, dm)

# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_LensResnet_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Continue LensResnet training under the ASHA scheduler with a fixed config.

    The learning rate and batch size are pinned, and the run is restored from
    a hard-coded checkpoint of the ``LensResnet_J`` experiment. Trials are ranked
    by highest ``auroc``, and the best config is printed when Tune finishes.

    Args:
        num_samples: NOTE(review): accepted but not forwarded to ``tune.run``.
        num_epochs: ASHA's ``max_t`` and the trainable's ``max_epochs``.
        gpus_per_trial: GPUs reserved per trial; all CPUs are reserved as well.
    """
    # Search space used before these values were pinned:
    #   learning_rate in {1e-5, 1e-4, 1e-3, 1e-2}, batch_size in {128, 64, 32}
    fixed_config = {'batch_size': 128, 'learning_rate': 0.0001}

    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
    cli_reporter = CLIReporter(
        parameter_columns=['learning_rate', 'batch_size'],
        metric_columns=['loss', 'auroc', 'training_iteration'],
    )

    trainable = tune.with_parameters(
        train_LensResnet_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    analysis = tune.run(
        trainable,
        name='LensResnet_J',  # change with dataset change
        metric='auroc',
        mode='max',
        config=fixed_config,
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        scheduler=asha,
        progress_reporter=cli_reporter,
        fail_fast=True,
        restore='/content/drive/MyDrive/Logs/LensResnet_J/train_LensResnet_tune_checkpoint_c74d9_00003_3_batch_size=128,learning_rate=0.0001_2021-07-08_01-03-43/checkpoint_epoch=2-step=1406',
    )

    print('Best hyperparameters found were: ', analysis.best_config)
# __tune_asha_end__


# __tune_pbt_begin__
def tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune LensResnet's learning rate and batch size with Population Based Training.

    Every trial starts at ``learning_rate=1e-3`` and ``batch_size=64``. PBT may
    resample these from the listed choices every 4 training iterations. Trials
    are ranked by highest ``auroc``, and the best config is printed at the end.

    Args:
        num_samples: Number of trials in the population.
        num_epochs: Forwarded to ``train_LensResnet_tune_checkpoint`` (``max_epochs``).
        gpus_per_trial: GPUs reserved per trial; all CPUs are reserved as well.
    """
    config = {
        'learning_rate': 1e-3,
        'batch_size': 64,
    }

    scheduler = PopulationBasedTraining(
        # Count the interval in training iterations (one per Tune report), like the
        # other PBT schedulers in this notebook. PBT's default `time_attr` is
        # wall-clock seconds, which turned `perturbation_interval=4` into
        # "every 4 seconds".
        time_attr='training_iteration',
        perturbation_interval=4,
        hyperparam_mutations={
            'learning_rate': [1e-5, 1e-4, 1e-3, 1e-2],
            'batch_size': [32, 64, 128]
        })

    reporter = CLIReporter(
        parameter_columns=['learning_rate', 'batch_size'],
        metric_columns=['loss', 'auroc', 'training_iteration'])

    analysis = tune.run(
        tune.with_parameters(
            train_LensResnet_tune_checkpoint,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial),
        metric='auroc',
        mode='max',
        resources_per_trial={
            'cpu': os.cpu_count(),
            'gpu': gpus_per_trial,
        },
        fail_fast=True,
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        local_dir='./drive/MyDrive/Logs',
        name='tune_LensResnet_pbt')

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args skips argv entries it does not recognise instead of failing.
    args, _ = parser.parse_known_args()

    if not args.smoke_test:
        # Full run: ASHA scheduler only. Population based training is disabled:
        # tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count())
        tune_LensResnet_asha(num_samples=12, num_epochs=20, gpus_per_trial=torch.cuda.device_count())
    else:
        # Smoke test: a single short trial with each scheduler, ASHA first.
        for tune_fn in (tune_LensResnet_asha, tune_LensResnet_pbt):
            tune_fn(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())

Trying out automatic learning-rate tuning with Lightning's learning-rate finder.

In [None]:
# Learning-rate finder experiment.
# The lr finder can't work with multiple optimizers. NOTE(review): StackGAN
# trains several networks, so this presumably fails on it -- confirm.
config = {
    'learning_rate': 1e-4, 'batch_size': 128, 'feature_maps': 64,
}
dm = npyImageData(config)
generator = StackGAN(config)

trainer = pl.Trainer(
    default_root_dir='./drive/MyDrive/Logs/',
    gpus=1,
    auto_select_gpus=True,
    progress_bar_refresh_rate=1,
    max_epochs=5,
    auto_lr_find=True,
    )

# Run the learning-rate finder. The DataModule is passed by keyword because in
# some Lightning versions the second positional parameter of lr_find is a
# train dataloader, not a DataModule.
lr_finder = trainer.tuner.lr_find(generator, datamodule=dm)

# Raw results are available in lr_finder.results.
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick a point based on the plot, or take the finder's suggestion.
new_lr = lr_finder.suggestion()

# Not done yet: apply the suggestion and train, e.g.
#   generator.hparams.lr = new_lr
#   trainer.fit(generator, dm)

## Stage 1
Here we tune hyperparameters as we train our modified DCGAN.

In [None]:
%rm -rf ./drive/MyDrive/Logs/F/Stage1/pbt/

In [None]:
# __tune_train_checkpoint_begin
def train_Stage1(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Train a single ``Stage1`` GAN trial inside a Ray Tune experiment.

    Args:
        config: Trial hyperparameters. ``config['batch_size']`` sets the
            progress-bar refresh rate; the dict is forwarded to ``npyImageData``
            and ``Stage1``.
        checkpoint_dir: Set by Tune when a trial is (re)started from a
            checkpoint; Lightning then resumes from ``<checkpoint_dir>/checkpoint``.
        num_epochs: Forwarded to the Trainer as ``max_epochs``.
        num_gpus: GPUs assigned to the trial, rounded up to an integer because
            Tune may hand out fractional GPUs.
    """
    # Model log keys -> metric names reported to Tune. Swap the ResNet(F) and
    # ResNet(J) FID entries when training on the other dataset.
    reported_metrics = {
        'loss_G': 'Stage1/G/train/loss',
        'loss_D': 'Stage1/D/train/loss',
        'FID': 'Stage1/ResNet(F)/val/FID',
        'FID_cross': 'Stage1/ResNet(J)/val/FID',
    }
    trainer_kwargs = dict(
        # NOTE(review): 8250 is presumably the dataset size -- confirm.
        progress_bar_refresh_rate=math.ceil(8250 // config['batch_size']),
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[TuneReportCheckpointCallback(reported_metrics)],
        # Tried and left disabled: limit_{train,val}_batches=0.05, SWA (works with
        # only one optimizer), benchmark=True, value gradient clipping at 0.5.
    )

    datamodule = npyImageData(config, 64)  # 64 = image width
    if checkpoint_dir is not None:
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = Stage1(config)
    pl.Trainer(**trainer_kwargs).fit(model, datamodule)
# __tune_train_checkpoint_end__


# # # __tune_asha_begin__
# def tune_Stage1_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
#     # print(os.cpu_count(), torch.cuda.device_count())
#     analysis = tune.run(
#         tune.with_parameters(
#             train_Stage1,
#             num_epochs=num_epochs,
#             num_gpus=gpus_per_trial
#         ),
#         # Change the folder name when changing dataset--------------------------------------------------------------------------
#         name='Stage1/pbt/J',
#         metric='FID',
#         mode='min',
#         config={'learning_rate': 1e-4,
#                 'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
#                 'batch_size': 8,
#                 },
#         # config={'learning_rate': 0.01,
#         #         'n_fmaps': 32,
#         #         'batch_size': 32,
#         #         },
#         # stop=TrialPlateauStopper('loss_G'),
#         resources_per_trial={'cpu': os.cpu_count(),
#                              'gpu': gpus_per_trial,
#                             },
#         local_dir='./drive/MyDrive/Logs',
#         scheduler = ASHAScheduler(max_t=num_epochs, grace_period=2,  reduction_factor=2),
#         progress_reporter=JupyterNotebookReporter(
#             overwrite=True,
#             parameter_columns=['learning_rate', 'n_fmaps', 'batch_size'],
#             metric_columns=['loss_G', 'loss_D', 'FID', 'FID_cross', 'training_iteration'],
#             sort_by_metric=True,
#         ),
#         fail_fast = True,
#         # reuse_actors=True,
#         # num_samples=num_samples,
#         resume='PROMPT',
# #         restore='/content/drive/MyDrive/Logs/delete/train_Stage1_e42ac_00025_25_batch_size=8,learning_rate=0.01,n_fmaps=8_2021-07-28_21-16-18/checkpoint_epoch=4-step=2339',
#     )

# #     print('Best hyperparameters found were: ', analysis.best_config)

# # # __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage1_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune ``Stage1`` with Population Based Training, one trial per ``n_fmaps`` value.

    ``n_fmaps`` is grid-searched over 8/16/32/64/128. Every trial starts at
    ``learning_rate=1e-4`` and ``batch_size=8``, which PBT may mutate every
    2 training iterations. Trials are compared on ``FID`` (lower is better), and
    the best checkpoint path is printed at the end.

    Args:
        num_samples: NOTE(review): accepted but not passed to ``tune.run``; the
            grid search alone sets the number of trials.
        num_epochs: Forwarded to ``train_Stage1`` (``max_epochs``).
        gpus_per_trial: GPUs reserved per trial; every trial also reserves all CPUs.
    """

    def _generator_loss_reached_zero(trial_id, result):
        # Tune evaluates a stop *dict* such as {'loss_G': 0} as
        # `result['loss_G'] >= 0`. Any non-negative loss satisfies that on the
        # first report, so every trial stopped after one iteration. Stop only
        # once the generator loss has actually dropped to zero.
        return result['loss_G'] <= 0

    analysis = tune.run(
        tune.with_parameters(
            train_Stage1,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial
        ),
        # Change the folder name when changing dataset.
        name='F/Stage1/pbt',
        metric='FID',
        mode='min',
        config={'learning_rate': 1e-4,
                'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
                'batch_size': 8,
                },
        # Alternative tried earlier: stop=TrialPlateauStopper('loss_G').
        stop=_generator_loss_reached_zero,
        resources_per_trial={'cpu': os.cpu_count(),
                             'gpu': gpus_per_trial,
                             },
        local_dir='./drive/MyDrive/Logs',
        # Can't use RB2 as it requires mutations to be continuous.
        scheduler=PopulationBasedTraining(
            time_attr='training_iteration',
            quantile_fraction=0.4,
            resample_probability=0.5,
            perturbation_interval=2,
            hyperparam_mutations={
                'learning_rate': tune.loguniform(1e-7, 1e-1),
                'batch_size': [8, 16, 32, 64, 128],
            },
        ),
        progress_reporter=JupyterNotebookReporter(
            overwrite=True,
            parameter_columns=['learning_rate', 'n_fmaps', 'batch_size'],
            metric_columns=['loss_G', 'loss_D', 'FID', 'FID_cross', 'training_iteration'],
        ),
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best checkpoint path found is: ', analysis.best_checkpoint)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args skips argv entries it does not recognise instead of failing.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # tune_Stage1_asha is commented out in this cell, so calling it raised
        # NameError before the PBT smoke test could run. Re-enable it together
        # with its definition:
        # tune_Stage1_asha(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
        tune_Stage1_pbt(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
    else:
        # ASHA scheduler (disabled; its definition is commented out):
        # tune_Stage1_asha(num_samples=1, num_epochs=10, gpus_per_trial=torch.cuda.device_count())
        # Population based training
        tune_Stage1_pbt(num_samples=1, num_epochs=30, gpus_per_trial=torch.cuda.device_count())

In [None]:
# Flush pending writes to the mounted Google Drive and unmount it.
# NOTE(review): assumes `drive` is google.colab.drive, mounted in an earlier cell -- confirm.
drive.flush_and_unmount()

In [None]:
!cat /content/drive/MyDrive/Logs/Stage1/pbt/F/train_Stage1_0711f_00001_1_n_fmaps=16_2021-08-08_18-54-16/error.txt

## Stage 2
Here we tune hyperparameters as we train our modified DCGAN.

In [None]:
# __tune_train_checkpoint_begin
def train_Stage2(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Train a single ``Stage2`` trial inside a Ray Tune experiment.

    Args:
        config: Trial hyperparameters. ``config['batch_size']`` sets the
            progress-bar refresh rate; the dict is forwarded to ``npyImageData``
            and ``Stage2``.
        checkpoint_dir: Set by Tune when a trial is (re)started from a
            checkpoint; Lightning then resumes from ``<checkpoint_dir>/checkpoint``.
        num_epochs: Forwarded to the Trainer as ``max_epochs``.
        num_gpus: GPUs assigned to the trial, rounded up to an integer because
            Tune may hand out fractional GPUs.
    """
    # Model log keys -> metric names reported to Tune. Swap the ResNet(F) and
    # ResNet(J) auroc entries when training on the other dataset.
    reported_metrics = {
        'loss_G': 'Stage2/G/train/loss',
        'loss_D': 'Stage2/D/train/loss',
        'auroc': 'Stage2/ResNet(F)/val/auroc',
        'auroc_cross': 'Stage2/ResNet(J)/val/auroc',
    }
    trainer_kwargs = dict(
        # NOTE(review): 8250 is presumably the dataset size -- confirm.
        progress_bar_refresh_rate=math.ceil(8250 // config['batch_size']),
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[TuneReportCheckpointCallback(reported_metrics)],
        # Tried and left disabled: limit_{train,val}_batches=0.05, SWA (works with
        # only one optimizer), benchmark=True.
    )

    datamodule = npyImageData(config)  # image width could be passed as a 2nd argument
    if checkpoint_dir is not None:
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = Stage2(config)
    pl.Trainer(**trainer_kwargs).fit(model, datamodule)
# __tune_train_checkpoint_end__


# # # __tune_asha_begin__
# def tune_Stage2_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
#     # print(os.cpu_count(), torch.cuda.device_count())
#     analysis = tune.run(
#         tune.with_parameters(
#             train_Stage2,
#             num_epochs=num_epochs,
#             num_gpus=gpus_per_trial
#         ),
#         # Change the folder name when changing dataset--------------------------------------------------------------------------
#         name='Stage2/pbt/J',
#         metric='auroc',
#         mode='max',
#         config={'learning_rate': 1e-4,
#                 'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
#                 'batch_size': 8,
#                 },
#         # config={'learning_rate': 0.01,
#         #         'n_fmaps': 32,
#         #         'batch_size': 32,
#         #         },
#         # stop=TrialPlateauStopper('loss_G'),
#         resources_per_trial={'cpu': os.cpu_count(),
#                              'gpu': gpus_per_trial,
#                             },
#         local_dir='./drive/MyDrive/Logs',
#         scheduler = ASHAScheduler(max_t=num_epochs, grace_period=2,  reduction_factor=2),
#         progress_reporter=JupyterNotebookReporter(
#             overwrite=True,
#             parameter_columns=['learning_rate', 'n_fmaps', 'batch_size'],
#             metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
#             sort_by_metric=True,
#         ),
#         fail_fast = True,
#         # reuse_actors=True,
#         # num_samples=num_samples,
#         resume='PROMPT',
# #         restore='/content/drive/MyDrive/Logs/delete/train_Stage2_e42ac_00025_25_batch_size=8,learning_rate=0.01,n_fmaps=8_2021-07-28_21-16-18/checkpoint_epoch=4-step=2339',
#     )

# #     print('Best hyperparameters found were: ', analysis.best_config)

# # # __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage2_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune ``Stage2`` with Population Based Training, one trial per ``n_fmaps`` value.

    ``n_fmaps`` is grid-searched over 8/16/32/64/128, and ``res_depth`` is drawn
    from 1..4 for each trial. All trials start at ``learning_rate=1e-4`` and
    ``batch_size=8``, which PBT may mutate after every training iteration.
    Trials are ranked by highest ``auroc``, and the best config is printed.

    Args:
        num_samples: NOTE(review): accepted but not passed to ``tune.run``; the
            grid search alone sets the number of trials.
        num_epochs: Forwarded to ``train_Stage2`` (``max_epochs``).
        gpus_per_trial: GPUs reserved per trial; every trial also reserves all CPUs.
    """
    trainable = tune.with_parameters(
        train_Stage2,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    search_space = {
        'learning_rate': 1e-4,
        'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
        'res_depth': tune.choice([1, 2, 3, 4]),
        'batch_size': 8,
    }
    pbt = PopulationBasedTraining(
        time_attr='training_iteration',
        quantile_fraction=0.5,
        resample_probability=0.8,
        perturbation_interval=1,
        hyperparam_mutations={
            'learning_rate': tune.loguniform(1e-7, 1e-1),
            'batch_size': [8, 16, 32, 64, 128],
        },
    )
    notebook_reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['learning_rate', 'n_fmaps', 'res_depth', 'batch_size'],
        metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
        sort_by_metric=True,
    )

    analysis = tune.run(
        trainable,
        # Change the folder name when changing dataset.
        name='Stage2/pbt/F',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        scheduler=pbt,
        progress_reporter=notebook_reporter,
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args skips argv entries it does not recognise instead of failing.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # tune_Stage2_asha is commented out in this cell, so calling it raised
        # NameError before the PBT smoke test could run. Re-enable it together
        # with its definition:
        # tune_Stage2_asha(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
        tune_Stage2_pbt(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
    else:
        # ASHA scheduler (disabled; its definition is commented out):
        # tune_Stage2_asha(num_samples=1, num_epochs=10, gpus_per_trial=torch.cuda.device_count())
        # Population based training
        tune_Stage2_pbt(num_samples=1, num_epochs=30, gpus_per_trial=torch.cuda.device_count())

In [None]:
# Flush pending writes to the mounted Google Drive and unmount it.
# NOTE(review): assumes `drive` is google.colab.drive, mounted in an earlier cell -- confirm.
drive.flush_and_unmount()

In [None]:
!cat /content/drive/MyDrive/Logs/Stage2/pbt/F/train_Stage2_6f508_00000_0_n_fmaps=8_2021-08-13_14-04-54/error.txt

## StackGAN
Here we tune StackGAN's hyperparameters so that the images it generates resemble the input images.

In [None]:
# __tune_train_checkpoint_begin
def train_StackGAN_tune_checkpoint(config,
                                   checkpoint_dir=None,
                                   num_epochs=10,
                                   num_gpus=torch.cuda.device_count()):
    """Train one StackGAN trial under Ray Tune, warm-starting from a Tune checkpoint.

    Args:
        config: Trial hyperparameters, forwarded to ``npyImageData`` and ``StackGAN``.
        checkpoint_dir: Directory provided by Tune when the trial restarts from a
            checkpoint; ``<checkpoint_dir>/checkpoint`` is loaded manually.
        num_epochs: Forwarded to the Trainer as ``max_epochs``.
        num_gpus: GPUs allotted to the trial, rounded up to an integer because Tune
            may pass a fractional count.
    """
    data_dir = os.path.expanduser('/content/images/')

    # Model log keys -> metric names reported to Tune.
    reported_metrics = {
        'loss_G1': 'G1/train/loss/full',
        'loss_G2': 'G2/train/loss/full',
        'loss_D1': 'D1/train/loss',
        'loss_D2': 'D2/train/loss',
        'lossR': 'R/train/loss',
        'auroc': 'Pre/val/auroc',
    }
    trainer = pl.Trainer(
        num_sanity_val_steps=-1,  # sanity-check over all validation batches
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            TuneReportCheckpointCallback(metrics=reported_metrics, filename='checkpoint'),
        ],
        # Tried and left disabled: accumulate_grad_batches=2,
        # limit_{train,val}_batches=0.20, SWA (works with only one optimizer).
    )
    # NOTE(review): train_Stage1 passes an image width (64) as npyImageData's
    # 2nd argument; confirm it still accepts a data directory here.
    dm = npyImageData(config, data_dir)

    if not checkpoint_dir:
        model = StackGAN(config)
    else:
        # StackGAN.load_from_checkpoint(...) currently errors, so load the
        # checkpoint file directly (all storages onto the CPU) and rebuild the
        # model with the current config.
        checkpoint = pl_load(
            os.path.join(checkpoint_dir, 'checkpoint'),
            map_location=lambda storage, loc: storage,
        )
        model = StackGAN._load_model_state(checkpoint, config=config)
        trainer.current_epoch = checkpoint['epoch']

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_StackGAN_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Search StackGAN's batch size with the ASHA scheduler.

    ``learning_rate`` (1e-4) and ``feature_maps`` (64) have a single choice;
    ``batch_size`` is drawn from {128, 64}. Trials are ranked by highest
    ``auroc``, and the best config is printed when Tune finishes.

    Args:
        num_samples: Number of configurations sampled from the search space.
        num_epochs: ASHA's ``max_t`` and the trainable's ``max_epochs``.
        gpus_per_trial: GPUs reserved per trial; all CPUs are reserved as well.
    """
    search_space = dict(
        learning_rate=tune.choice([1e-4]),
        feature_maps=tune.choice([64]),
        batch_size=tune.choice([128, 64]),
    )
    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
    cli_reporter = CLIReporter(
        parameter_columns=['learning_rate', 'feature_maps', 'batch_size'],
        metric_columns=['loss_G1', 'loss_G2', 'loss_D1', 'loss_D2', 'lossR', 'auroc', 'training_iteration'],
    )

    trainable = tune.with_parameters(
        train_StackGAN_tune_checkpoint,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    analysis = tune.run(
        trainable,
        name='tune_StackGAN_asha_model_j',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        num_samples=num_samples,
        local_dir='./drive/MyDrive/Logs',
        scheduler=asha,
        progress_reporter=cli_reporter,
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_asha_end__


# __tune_pbt_begin__
def tune_StackGAN_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count(),
                      perturbation_interval=4, time_attr='training_iteration',
                      local_dir='./drive/MyDrive/Logs'):
    """Search StackGAN hyperparameters with Ray Tune Population Based Training.

    Every trial starts from the same fixed ``config``. PBT periodically
    replaces poorly performing trials with checkpoints of better ones
    (ranked on ``auroc``, higher is better) and perturbs the hyperparameters
    listed in ``hyperparam_mutations``.

    Args:
        num_samples: Population size (number of concurrent trials).
        num_epochs: Forwarded to the trainable as ``num_epochs``.
        gpus_per_trial: GPUs reserved for each trial. The default is the
            number of visible CUDA devices, evaluated once when this function
            is defined.
        perturbation_interval: How often PBT exploits/explores, measured in
            units of ``time_attr``.
        time_attr: Result attribute PBT uses as its clock. Defaults to
            ``'training_iteration'`` (number of Tune reports). Ray's own
            default, ``'time_total_s'``, would make ``perturbation_interval``
            a number of seconds.
        local_dir: Directory where Ray Tune writes results and checkpoints.

    Returns:
        The analysis object returned by ``tune.run``.
    """
    config = {
        'learning_rate': 1e-4,
        'feature_maps': 64,
        'batch_size': 64,
    }

    scheduler = PopulationBasedTraining(
        time_attr=time_attr,
        perturbation_interval=perturbation_interval,
        hyperparam_mutations={
            'learning_rate': [1e-4, 1e-3],
            # 'feature_maps' is intentionally not mutated. It presumably sets
            # layer widths, and PBT's exploit step reloads the donor checkpoint
            # into a model built from the *mutated* config, so a width change
            # would make the (strict) state-dict load fail.
            'batch_size': [32, 64, 128],
        },
    )

    reporter = CLIReporter(
        parameter_columns=['learning_rate', 'feature_maps', 'batch_size'],
        metric_columns=['loss_G1', 'loss_G2', 'loss_D1', 'loss_D2', 'lossR', 'auroc', 'training_iteration'],
    )

    analysis = tune.run(
        tune.with_parameters(
            train_StackGAN_tune_checkpoint,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial,
        ),
        name='tune_StackGAN_pbt_model_j',
        metric='auroc',
        mode='max',
        resources_per_trial={
            # os.cpu_count() returns None when the count cannot be determined,
            # but Ray needs a number.
            'cpu': os.cpu_count() or 1,
            'gpu': gpus_per_trial,
        },
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        local_dir=local_dir,
        fail_fast=True,  # abort the whole experiment on the first trial error
    )

    print('Best hyperparameters found were: ', analysis.best_config)
    return analysis

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args ignores unrecognised arguments instead of exiting,
    # presumably so this also runs from a notebook kernel that passes its own flags.
    known, _extra = cli.parse_known_args()

    n_gpus = torch.cuda.device_count()
    if not known.smoke_test:
        # ASHA scheduler
        tune_StackGAN_asha(num_samples=2, num_epochs=1, gpus_per_trial=n_gpus)
        # Population based training (currently disabled):
        # tune_StackGAN_pbt(num_samples=8, num_epochs=5, gpus_per_trial=n_gpus)
    else:
        # NOTE(review): the smoke test trains 6 epochs per tuner, more than the
        # 1-epoch "full" run above -- confirm these values are intended.
        tune_StackGAN_asha(num_samples=1, num_epochs=6, gpus_per_trial=n_gpus)
        tune_StackGAN_pbt(num_samples=1, num_epochs=6, gpus_per_trial=n_gpus)