<a href="https://colab.research.google.com/github/souravraha/galaxy/blob/pgan/Lightning_Tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google drive

In [None]:
# Mount Google Drive at /content/drive. Model files used later in this
# notebook (e.g. PRE_F_RESNET, BEST_F_RESNET) are read from under it.
# timeout_ms=600000 allows up to 10 minutes for the mount/authorisation.
from google.colab import drive
drive.mount('/content/drive', timeout_ms=600000)

# Prerequisites / shell commands

## Install/uninstall packages

In [None]:
# Install the extra dependencies this notebook needs (Colab only: the `!pip`
# line is IPython shell syntax, not plain Python). This line always runs;
# nothing below needs to be uncommented.

print('Setting up colab environment')
!pip install -q tf-estimator-nightly==2.8.0.dev2021122109 imgaug==0.2.6 matplotlib==3.4.3 torchmetrics[image] lightning-bolts GPy ray[tune]

# Kill this kernel process (signal 9) so Colab restarts the runtime and the
# freshly installed package versions are the ones imported afterwards.
print('Done installing! Restarting via forced crash (this is not an issue).')
import os
os.kill(os.getpid(), 9)

## Download and extract data

In [None]:
# Google Drive file IDs of the dataset archives, keyed by dataset letter
# (no IDs are recorded for 'd' and 'k'):
# 'a': 1Cjcw2EWorhdhJSGoWOdxsEUDxvl943dt, 'b': 15yXXC4h5VsytP3Ak1jfUSjQhdgP2s23K, 'c': 1vuQ-pLzoKT4Hd_V7949r9eND9E2fB_u_,
# 'd': , 'e': 1wFuasvb7PthxXtMUlsD13uzYHWlWt06H, 'f': 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ,
# 'g': 1SxQVosWeEjY3Pyn8LRXA11rLnZ9HK_7B, 'h': 1Atau0RH4oyLAiYReW-G9a8l9pUNltglF, 'i': 15lEgsR1p00KSHieaT9a1nkbJ86pDxwgp,
# 'j': 1m0EQUbqZZeyl76XsQIKWU5Qd7jGmmWhB, 'k': , 'l': 1meTDi4aeWfdChOiXeLtUOGhjVDVu000e

# fake data (presumably images generated by an earlier PGAN run)
# 'fpgan': 1-4o0yqSBA9WSY9gTYamIez_RAekwDsHV

# Start from a clean images/ directory, then stream the archive for dataset
# 'f' from Drive (`-O -` sends the download to stdout) straight into tar,
# which unpacks it into the current directory.
!rm -rf images/
# !pip install -q --upgrade --no-cache-dir gdown
!gdown --id 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ -O - --quiet | tar --skip-old-files -zxf -

# Line wrapping

In [None]:
from IPython.display import HTML, display

def set_css(*args, **kwargs):
    """Make long lines in cell output wrap instead of scrolling horizontally.

    Registered as an IPython ``pre_run_cell`` callback, so it runs before every
    cell. Any arguments IPython passes (newer versions hand ``pre_run_cell``
    callbacks an ``info`` object) are accepted and ignored.
    """
    display(HTML('''<style>pre{white-space: pre-wrap;}</style>'''))


_ipython = get_ipython()
# Re-running this cell defines a new ``set_css``; drop any copy registered by a
# previous run so the style is not injected several times per cell.
for _callback in list(_ipython.events.callbacks['pre_run_cell']):
    if getattr(_callback, '__name__', None) == 'set_css':
        _ipython.events.unregister('pre_run_cell', _callback)
_ipython.events.register('pre_run_cell', set_css)

# Import libraries

In [None]:
import os, copy, psutil, ray
import numpy as np
import pandas as pd
from itertools import cycle
%matplotlib inline
from matplotlib import pyplot as plt
# ------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
# ------------------------------------
from torchvision.models import resnet18
from torchvision.utils import save_image, make_grid
from torchvision import transforms, datasets
# ------------------------------------
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
# ------------------------------------
import torchmetrics as tm
from torchmetrics.image.fid import FrechetInceptionDistance as FID
from torchmetrics.image.kid import KernelInceptionDistance as KID
# ------------------------------------
import pytorch_lightning as pl

# from pl_bolts.models.autoencoders import VAE
from pl_bolts.models.gans import DCGAN, Pix2Pix
from pl_bolts.callbacks import ModuleDataMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.self_supervised.resnets import BasicBlock
from pytorch_lightning.callbacks import TQDMProgressBar #ModelCheckpoint
# from drive.MyDrive.ml.Callbacks.confused_logits import ConfusedLogitCallback
# from drive.MyDrive.ml.Callbacks.save_images import SaveImages
# ------------------------------------
from ray import tune
# from ray.tune.stopper import TrialPlateauStopper
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.schedulers.pb2 import PB2
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback
# ------------------------------------

# Class definitions

## DataModule
This DataModule builds the dataloaders used to train, validate and test the models defined below.

In [None]:
class npyImageData(pl.LightningDataModule):
    """LightningDataModule over ``.npy`` images stored in per-class sub-folders.

    Expects ``<data_dir>/train/<class>/*.npy`` and ``<data_dir>/val/<class>/*.npy``
    (``torchvision.datasets.DatasetFolder`` layout). The ``val`` split is also
    used as the test split.

    Args:
        config: hyper-parameter dict; ``config['bs']`` is the base-2 logarithm
            of the batch size, rounded to the nearest integer (8 -> 256).
        img_width: size passed to ``transforms.Resize`` (nearest-neighbour).
        data_dir: dataset root; ``~`` is expanded.
        max_value: raw pixel value mapped to +1 (0 maps to -1). The default
            966 is the maximum of dataset F; dataset J peaks at 1007.
    """

    def __init__(self, config, img_width: int = 150, data_dir: str = '/content/images/',
                 max_value: float = 966.0):
        super().__init__()
        # self.save_hyperparameters() is not used for this DataModule.
        self.bs = int(2**np.rint(config['bs']))
        self.data_dir = os.path.expanduser(data_dir)
        self.prepare_data_per_node = False

        self.transform = transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomVerticalFlip(),
            # Dataset statistics (set ``max_value`` accordingly when switching):
            # F : [mean=71.75926373866668, std=96.139484964214, min=5.0, max=966.0]
            # J : [mean=50.271541595458984, std=94.8838882446289, min=0, max=1007.0]
            # x -> x / max_value -> (. - 0.5) / 0.5, i.e. [0, max_value] -> [-1, 1].
            transforms.Normalize(mean=(0,), std=(max_value,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
            transforms.Resize(img_width, transforms.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def npy_loader(path):
        """Load a ``.npy`` array as a float32 tensor with a leading channel axis."""
        # Convert to a tensor first and then to float32; a float64 tensor would
        # raise dtype errors in the conv layers.
        return torch.from_numpy(np.load(path)).unsqueeze(0).float()

    def _folder(self, split):
        """Build a ``DatasetFolder`` of ``.npy`` files for ``<data_dir>/<split>``."""
        return datasets.DatasetFolder(os.path.join(self.data_dir, split),
                                      self.npy_loader, ('.npy',), self.transform)

    def setup(self, stage: str = None):
        """Create the datasets needed for ``stage`` (all of them when ``None``)."""
        if stage in ('fit', None):
            self.train_set = self._folder('train')
            # self.train_set, self.val_set = random_split(self.full_set, [60000, 15000])
        # 'validate' (trainer.validate) needs val_set too, not only 'fit'.
        if stage in ('fit', 'validate', None):
            self.val_set = self._folder('val')
        if stage in ('test', None):
            self.test_set = self._folder('val')

    def _loader(self, dataset):
        """Shuffled, pinned-memory DataLoader using every available CPU core."""
        # os.cpu_count() may return None; DataLoader needs an int.
        return DataLoader(dataset, self.bs, shuffle=True,
                          num_workers=os.cpu_count() or 0, pin_memory=True)

    def train_dataloader(self):
        return self._loader(self.train_set)

    def val_dataloader(self):
        return self._loader(self.val_set)

    def test_dataloader(self):
        return self._loader(self.test_set)

In [None]:
# Smoke-test instantiation. npyImageData reads only config['bs'], the base-2
# log of the batch size (8 -> 256); 'lr' is ignored here.
dm = npyImageData({'lr': 0.001, 'bs': 8})
# model = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'), config={'lr': 0.001, 'bs': 8})
# trainer = pl.Trainer(gpus=1)
# trainer.predict(model, dm)

## ResNet
A ResNet classifier, slightly adapted for our lensing images.

In [None]:
# Pickled backbone modules on Drive, loaded whole with torch.load.
# LensResnet below loads PRE_F_RESNET; PRE_J_RESNET is presumably the
# counterpart for dataset J.
PRE_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/PRETRAINED.pth'
PRE_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/PRETRAINED.pth'

class LensResnet(pl.LightningModule):
    """ResNet classifier for the lensing images (one logit per class).

    Args:
        config: hyper-parameter dict; ``config['lr']`` is the base-10 exponent
            of the Adam learning rate (``-3`` -> ``1e-3``).
        image_channels: stored in ``hparams`` only.
        num_classes: number of substructure classes.
    """

    def __init__(self, config, image_channels: int = 1, num_classes: int = 3, **kwargs):
        super().__init__()
        # ``ignore`` takes argument *names*; the keys of ``config`` match no
        # argument, so ``config`` itself is still saved in hparams. Calls like
        # ``LensResnet.load_from_checkpoint(path)`` without a config rely on this.
        self.save_hyperparameters(ignore=config)
        self.lr = 10**config['lr']

        # --------------------------------------------------------------------------------------------------
        # NOTE(review): the F pretrained backbone is loaded for every instance.
        # torch >= 2.6 defaults to weights_only=True, which may refuse to unpickle
        # a whole nn.Module -- confirm against the installed torch version.
        self.backbone = torch.load(PRE_F_RESNET, map_location=self.device)
        # self.backbone = resnet18(num_classes=self.hparams.num_classes)
        # self.backbone.conv1 = nn.LazyConv2d(64, 7, 2, 3, bias=False)

        self.train_metrics = tm.MetricCollection([tm.AUROC(self.hparams.num_classes, average='weighted'),],
            prefix='LensResnet/train/'
        )
        # average=None keeps one score per class: validation_epoch_end indexes
        # AveragePrecision and AUROC by class.
        self.val_metrics = tm.MetricCollection(
            [tm.PrecisionRecallCurve(self.hparams.num_classes), tm.ROC(self.hparams.num_classes),
            tm.AveragePrecision(self.hparams.num_classes, average=None), tm.AUROC(self.hparams.num_classes, average=None)]
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.backbone.parameters(), self.lr)

    def forward(self, x, prob=False):
        """Return class logits, or softmax probabilities when ``prob`` is True."""
        logits = self.backbone(x)
        return logits.softmax(1) if prob else logits

    def training_step(self, batch, batch_idx):
        """Cross-entropy step; also logs the epoch-to-date weighted AUROC."""
        imgs, labels = batch
        self.last_logits = self(imgs)
        loss = F.cross_entropy(self.last_logits, labels)
        self.log('LensResnet/train/loss', loss)
        #  keep only scalars here, for no errors

        preds = self.last_logits.softmax(1)
        self.train_metrics.update(preds, labels)
        # Best effort: AUROC may fail to compute (presumably while a class is
        # still absent from the accumulated targets); report it and carry on.
        try:
            self.log_dict(self.train_metrics.compute(), prog_bar=True)
        except Exception as err:
            print(err)
        # Returned outside a ``finally`` so non-Exception errors (e.g.
        # KeyboardInterrupt) propagate instead of being silently swallowed.
        return loss

    def on_train_epoch_start(self):
        # log_dict() receives plain tensors from compute(), so Lightning does not
        # reset the metric itself; reset here so the AUROC covers only the
        # current epoch instead of accumulating over the whole run.
        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate probabilities for the curves."""
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        self.log('LensResnet/val/loss', loss)
        #  keep only scalars here, for no errors

        preds = logits.softmax(1)
        self.val_metrics.update(preds, labels)

    def validation_epoch_end(self, Listofdicts):
        """Plot one-vs-all PR and ROC curves and log the worst per-class AP / AUROC.

        The figure is sent to the logger and saved as ``PR_ROC_step_XXXXX.png``
        in ``trainer.log_dir``. ``Listofdicts`` is unused; the scores come from
        ``val_metrics``.
        """
        colors = cycle(['r', 'g', 'b'])
        fig, ax = plt.subplots(1,2, subplot_kw={'xlim': [-0.05, 1.05], 'ylim': [-0.05, 1.05], 'aspect': 1}, figsize=(10, 5))
        # ---------------------------------------------------------------------------------------------------------
        fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model J (orginal data)')

        Dict = self.val_metrics.compute()
        self.val_metrics.reset()

        # Assumes the collection yields [PR curve, ROC, AveragePrecision, AUROC]
        # in that order: curve b (0 = PR, 1 = ROC) pairs with its score at 2 + b.
        key, val = list(Dict.keys()), list(Dict.values())
        for b in range(2):
            if key[b] != 'ROC':
                # Chance level of a PR curve for balanced classes.
                ax[b].plot([0, 1], [1/self.hparams.num_classes, 1/self.hparams.num_classes], 'k--')
            else:
                ax[b].plot([0, 1], [0, 1], 'k--')

            prec_FPR, rec_TPR, _ = val[b]
            for i, color, cls in zip(range(self.hparams.num_classes), colors, self.trainer.datamodule.val_set.classes):
                # The sklearn display constructors take keyword arguments only.
                if key[b] != 'ROC':
                    PrecisionRecallDisplay(precision=prec_FPR[i].cpu(), recall=rec_TPR[i].cpu(),
                                           average_precision=val[2 + b][i], estimator_name=cls).plot(ax[b], c=color)
                else:
                    RocCurveDisplay(fpr=prec_FPR[i].cpu(), tpr=rec_TPR[i].cpu(),
                                    roc_auc=val[2 + b][i], estimator_name=cls).plot(ax[b], c=color)

            ax[b].legend(loc='best')
            self.log('LensResnet/val/' + key[2 + b], min(val[2 + b]))

        self.logger.experiment.add_figure('LensResnet/val/PR_ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/PR_ROC_step_{:05d}.png'.format(self.global_step))

In [None]:
# Smoke-test instantiation. config['lr'] is the base-10 exponent of the
# learning rate (LensResnet sets self.lr = 10**config['lr']), so -3 means 1e-3;
# the previous value 1e-3 gave a learning rate of about 1.0.
m = LensResnet({'lr': -3})

## LensGAN128
Here we subclass a DCGAN to create our low resolution GAN.

In [None]:
def getLayerNormalizationFactor(x):
    """Return the He-style scale ``sqrt(2 / fan_in)`` for ``x.weight``.

    ``fan_in`` is the product of all weight dimensions except the first, e.g.
    ``in_features`` for ``nn.Linear`` or ``in_channels * kH * kW`` for
    ``nn.Conv2d``.
    """
    weight_shape = tuple(x.weight.size())
    fan_in = np.prod(weight_shape[1:])
    return np.sqrt(2.0 / fan_in)


class ConstrainedLayer(nn.Module):
    """Wrap a layer with equalized-learning-rate weight scaling.

    With ``equalized=True`` the wrapped weights are re-initialised from
    N(0, 1) / lrMul, and the layer output is multiplied at run time by the
    constant ``getLayerNormalizationFactor(module) * lrMul``.

    Args:
        module: layer with a ``weight`` (and optionally a ``bias``) parameter.
        equalized: apply the re-initialisation and run-time scaling.
        lrMul: learning-rate multiplier folded into both.
        initBiasToZero: zero the bias, if the module has one.
    """

    def __init__(self, module, equalized=True, lrMul=1.0, initBiasToZero=True):
        super().__init__()

        self.module = module
        self.equalized = equalized

        # Layers built with bias=False have ``bias is None``; skip them instead
        # of failing with AttributeError.
        if initBiasToZero and getattr(self.module, 'bias', None) is not None:
            self.module.bias.data.fill_(0)
        if self.equalized:
            self.module.weight.data.normal_(0, 1)
            self.module.weight.data /= lrMul
            # Plain float (not a Parameter): the constant run-time scale.
            self.weight = getLayerNormalizationFactor(self.module) * lrMul

    def forward(self, x):
        x = self.module(x)
        if self.equalized:
            x *= self.weight
        return x

class SNLayer(nn.Module):
    """Wrap a layer with spectral normalisation of its weight.

    Args:
        module: layer to wrap (e.g. ``nn.Conv2d``, ``nn.ConvTranspose2d``,
            ``nn.Linear``).
        spec: apply ``torch.nn.utils.parametrizations.spectral_norm``; when
            False the module is used unchanged.
        initBiasToZero: zero the bias, if the module has one.
    """

    def __init__(self, module, spec=True, initBiasToZero=True):
        super().__init__()

        self.module = nn.utils.parametrizations.spectral_norm(module) if spec else module
        # Layers built with bias=False have ``bias is None``; nothing to zero.
        if initBiasToZero and getattr(self.module, 'bias', None) is not None:
            self.module.bias.data.fill_(0)

    def forward(self, x):
        return self.module(x)

class EqualizedConv2d(SNLayer):
    """``nn.Conv2d`` wrapped by ``SNLayer`` (spectral norm, zeroed bias).

    Despite the name, no equalized-learning-rate scaling is applied: the base
    class is ``SNLayer``. Extra keyword arguments are forwarded to ``SNLayer``.
    """

    def __init__(self, nChannelsPrev, nChannels, kernelSize=4, stride=2, padding=1, bias=True, **kwargs):
        conv = nn.Conv2d(
            nChannelsPrev,
            nChannels,
            kernelSize,
            stride,
            padding,
            bias=bias,
        )
        super().__init__(conv, **kwargs)

class EqualizedConvTranspose2d(SNLayer):
    """``nn.ConvTranspose2d`` wrapped by ``SNLayer`` (spectral norm, zeroed bias).

    Despite the name, no equalized-learning-rate scaling is applied: the base
    class is ``SNLayer``. Extra keyword arguments are forwarded to ``SNLayer``.
    """

    def __init__(self, nChannelsPrev, nChannels, kernelSize=4, stride=2, padding=1, bias=True, **kwargs):
        deconv = nn.ConvTranspose2d(
            nChannelsPrev,
            nChannels,
            kernelSize,
            stride,
            padding,
            bias=bias,
        )
        super().__init__(deconv, **kwargs)

class EqualizedLinear(SNLayer):
    """``nn.Linear`` wrapped by ``SNLayer`` (spectral norm, zeroed bias).

    Despite the name, no equalized-learning-rate scaling is applied: the base
    class is ``SNLayer``. ``nChannels`` defaults to a single output feature.
    Extra keyword arguments are forwarded to ``SNLayer``.
    """

    def __init__(self, nChannelsPrev, nChannels=1, bias=True, **kwargs):
        dense = nn.Linear(nChannelsPrev, nChannels, bias=bias)
        super().__init__(dense, **kwargs)

In [None]:
from torch.nn.modules.linear import Identity
# Trial directories (presumably Ray Tune trials) holding the best fine-tuned
# LensResnet classifiers for datasets F and J. LensGAN128 loads
# `<dir>/checkpoint` from each and uses the frozen backbones for FID and for
# PR/ROC evaluation of generated images.
BEST_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_eb619_00000_0_2021-09-02_19-42-34/checkpoint_epoch=2-step=28124'
BEST_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/pbt_tanh_fine/train_LensResnet_93609_00000_0_2021-09-02_21-06-00/checkpoint_epoch=2-step=28124'

class LensGAN128(DCGAN):
    """Progressively grown, class-conditional DCGAN with an auxiliary classifier.

    The DCGAN generator/discriminator stacks are replaced by per-resolution
    block lists (``gen`` / ``disc``) plus 1x1 ``toGrey`` / ``fromGrey``
    adapters. ``self.stage`` selects how many blocks are active and
    ``self.alpha`` fades a newly activated block in (see
    ``on_train_batch_start``). The discriminator also has an ``aux`` class head.
    Two frozen ``LensResnet`` backbones (models F and J) drive FID and the
    PR/ROC evaluation of generated images.

    Args:
        config: needs ``'n_fmaps'`` (feature maps) and ``'lr'`` (base-10
            exponent of the learning rate).
        num_classes: number of image classes.
        step_size: steps between image snapshots; also the number of steps over
            which ``alpha`` ramps from 0 to 1 after growing.
        **kwargs: forwarded to ``DCGAN``.
    """

    def __init__(self, config, num_classes: int = 3, step_size: int = 3000, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'],
                         learning_rate=10**config['lr'], **kwargs)
        # ``ignore`` takes argument names, so ``config`` itself is still saved
        # (``self.hparams.config`` is read below).
        self.save_hyperparameters(ignore=config)

        # Replace DCGAN's fixed stacks with per-resolution block lists plus 1x1
        # adapters to/from a single ("grey") channel for every resolution.
        del self.generator.gen, self.discriminator.disc
        self.discriminator.disc, self.generator.gen = nn.ModuleList(), nn.ModuleList()
        self.discriminator.fromGrey, self.generator.toGrey = nn.ModuleList(), nn.ModuleList()
        chnl = self.hparams.config['n_fmaps'] ** 2 // 4
        # The first block uses stride 1 / padding 0 and maps latent_dim -> chnl
        # (generator) and chnl -> 1 output channel (discriminator); later blocks
        # use the default stride-2 up/down-sampling.
        c, kw = True, {'s': 1, 'p' : 0}
        while chnl >= self.hparams.feature_maps_gen:
            self.generator.gen.append(
                self._make_gen_block(self.hparams.latent_dim if c else min(2 * chnl, 512), min(chnl, 512), **(kw if c else {}))
            )
            self.generator.toGrey.append(nn.Sequential(EqualizedConvTranspose2d(min(chnl, 512), 1, 1, 1, 0), nn.Tanh()))
            self.discriminator.fromGrey.append(EqualizedConv2d(1, min(chnl, 512), 1, 1, 0))
            self.discriminator.disc.append(self._make_disc_block(min(chnl, 512), 1 if c else min(2 * chnl, 512), **(kw if c else {})))
            chnl = chnl//2
            c = False
        # Remove dropout from last layer
        del self.discriminator.disc[0][-1]

        # Class-conditional latent: noise * sigma[label] + mu[label] (see forward).
        self.generator.add_module('mu', nn.Embedding(self.hparams.num_classes, self.hparams.latent_dim))
        self.generator.add_module('sigma', nn.Embedding(self.hparams.num_classes, self.hparams.latent_dim))
        self.generator.forward = self.forward
        # # Remove if ACGAN not needed
        self.discriminator.add_module(
            'aux', nn.Sequential(
                nn.Flatten(), EqualizedLinear(min(self.hparams.config['n_fmaps']**2 // 4, 512) * 4**2, self.hparams.config['n_fmaps']),
                nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), EqualizedLinear(self.hparams.config['n_fmaps'], self.hparams.num_classes)
            )
        )
        self.discriminator.forward = self.discriminator_forward

        # Progressive-growing state: index of the highest active block and the
        # fade-in weight of that block (1.0 = fully faded in).
        self.stage = 0
        self.alpha = 1.0
        self.ssimloss = 10

        # Not needed in WGAN architectures
        self.criterion = nn.BCEWithLogitsLoss()

        tmp = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'))
        tmp.freeze()
        self.modelF = tmp.backbone     # torch.load(PRE_F_RESNET, map_location=self.device).eval())
        # Keep the classification head separately; the backbone then outputs features.
        self.lastF = self.modelF.fc
        self.modelF.fc = nn.Identity()

        tmp = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint'))
        tmp.freeze()
        self.modelJ = tmp.backbone     # torch.load(PRE_J_RESNET, map_location=self.device).eval())
        # Keep the classification head separately; the backbone then outputs features.
        self.lastJ = self.modelJ.fc
        self.modelJ.fc = nn.Identity()

        # FID computed in the feature space of each frozen classifier.
        self.imgMetrics = tm.MetricCollection(
            {
                'FID_F' : FID(self.modelF),
                'FID_J' : FID(self.modelJ)
            },
            prefix='LensGAN128/val/',
        )

        self.metrics = tm.MetricCollection(
            [tm.PrecisionRecallCurve(self.hparams.num_classes), tm.ROC(self.hparams.num_classes),
            tm.AveragePrecision(self.hparams.num_classes, average=None), tm.AUROC(self.hparams.num_classes, average=None)],
        )

    @staticmethod
    def _make_gen_block(x : int, y : int, k :int = 4, s : int = 2, p : int = 1, b : bool = True) -> nn.Sequential:
        """Transposed conv (x -> y channels) + BatchNorm + LeakyReLU(0.2)."""
        return nn.Sequential(
            EqualizedConvTranspose2d(x, y, k, s, p, b),
            nn.BatchNorm2d(y),
            # nn.LocalResponseNorm(2*y, 2, 0.5, 1e-8),
            nn.LeakyReLU(0.2, inplace=True),
        )

    @staticmethod
    def _make_disc_block(x : int, y : int, k :int = 4, s : int = 2, p : int = 1, b : bool = True) -> nn.Sequential:
        """Conv (x -> y channels) + LeakyReLU(0.2) + Dropout(0.3)."""
        return nn.Sequential(
            EqualizedConv2d(x, y, k, s, p, b),
            # nn.LocalResponseNorm(s, 1, 0.5, 1e-8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3)
        )

    def discriminator_forward(self, inp):
        """Return ``(real/fake logits, class logits)`` for images at the current stage.

        While fading in (``alpha < 1``), the new top block's output is blended
        with the previous resolution's ``fromGrey`` path on a pooled input.
        """
        # Add noise to the data
        # inp = inp + torch.randn_like(inp) * np.exp(-self.global_step * len(inp) / self.hparams.step_size)
        x = self.discriminator.fromGrey[self.stage](inp)
        out = self.discriminator.disc[self.stage](x)
        if 0 <= self.alpha < 1.0:
            out = torch.lerp(self.discriminator.fromGrey[self.stage - 1](F.adaptive_avg_pool2d(inp, out.shape[-2:])), out, self.alpha)

        for layer in reversed(self.discriminator.disc[:self.stage]):
            # ``x`` must end up as the input of the last (lowest-resolution)
            # block: it feeds the auxiliary classifier. Keep these as two lines.
            x = out
            out = layer(out)

        # # Not needed if not using ACGAN
        # out5 = self.discriminator.disc(inp)
        return out.squeeze(), self.discriminator.aux(x)
        # return self.discriminator.disc(inp).squeeze()

    def forward(self, noise, labels, all_layers=False):
        """Generate images for ``labels`` from ``noise``.

        Returns the final single-channel image batch, or, with ``all_layers``,
        the list of every active block's output with the last entry replaced by
        the (possibly faded) image.
        """
        inp = noise * self.generator.sigma(labels) + self.generator.mu(labels)
        # inp = torch.cat((F.one_hot(labels, self.hparams.num_classes), noise), 1)
        out = [inp.view(*inp.shape, 1, 1)]

        for layer in self.generator.gen[:self.stage+1]:
            out.append(layer(out[-1]))

        # for i, layer in enumerate(self.generator.toGrey[:self.stage+1], 1):
        #     out[i] = layer(out[i])
        out[-1] = self.generator.toGrey[self.stage](out[-1])

        if 0 <= self.alpha < 1.0:
            out[-1] = torch.lerp(F.interpolate(self.generator.toGrey[self.stage - 1](out[-2]), out[-1].shape[-2:]), out[-1], self.alpha)

        return out[1:] if all_layers else out[-1]

    # Save layer-wise activations
    def plotting(self):
        """Save snapshot images of the current generator to ``trainer.log_dir``.

        * ``F_maps_step_XXXXX.png``: intermediate generator activations per class
          (written only once more than one block is active).
        * ``Fake_step_XXXXX.png``: one fake image per class followed by one
          randomly chosen real validation image per class, resized to 150.

        Runs in eval mode with gradients disabled; the previous train/grad
        modes are restored afterwards, even if saving fails.
        """
        was_training = self.training
        grad_was_enabled = torch.is_grad_enabled()
        self.eval()
        torch.set_grad_enabled(False)
        try:
            labels = torch.arange(self.hparams.num_classes, device=self.device)
            imgs = self._get_fake_data(labels, all_layers=True)
            for lyr, img4d in enumerate(imgs[:-1]):
                new = list(img4d)
                for cls, filters3d in enumerate(new):
                    new[cls] = make_grid(filters3d.unsqueeze(1), nrow=int(np.ceil(np.sqrt(len(filters3d)))), normalize=True)

                imgs[lyr] = F.interpolate(torch.stack(new), 150)

            if len(imgs) > 1:
                save_image(torch.cat(imgs[-2::-1]),
                    str(self.trainer.log_dir) + '/F_maps_step_{:05d}.png'.format(self.global_step),
                    #  kwargs for make_grid
                    nrow=self.hparams.num_classes, pad_value=0.5)

            # Pick one random real validation image per class from the dataset's
            # label list: no repeated random disk loads, and a clear error
            # (instead of an endless loop) if a class has no validation images.
            fakes = imgs[-1]
            val_set = self.trainer.datamodule.val_set
            targets = np.asarray(val_set.targets)
            reals = []
            for cls in range(self.hparams.num_classes):
                candidates = np.flatnonzero(targets == cls)
                if len(candidates) == 0:
                    raise ValueError('No validation image of class {} to plot.'.format(cls))
                x, _ = val_set[int(np.random.choice(candidates))]
                reals.append(F.adaptive_avg_pool2d(x.unsqueeze(0).type_as(fakes), fakes.shape[-2:]))
            imgs = torch.cat([fakes, *reals])

            save_image(F.interpolate(imgs, 150),
                    str(self.trainer.log_dir) + '/Fake_step_{:05d}.png'.format(self.global_step),
                    #  kwargs for make_grid
                    nrow=self.hparams.num_classes, normalize=True, value_range=(-1,1), pad_value=0.5)
        finally:
            self.train(was_training)
            torch.set_grad_enabled(grad_was_enabled)

    def on_train_batch_start(self, batch, batch_idx):
        """Snapshot every ``step_size`` steps and drive progressive growing.

        Once the current block is fully faded in and more blocks remain, the
        next block is activated when the last logged ``loss_G`` is below
        ``2 * exp(-stage)``; ``alpha`` then ramps back to 1 over ``step_size`` steps.
        """
        if self.global_step % self.hparams.step_size == 0:
            self.plotting()

        # Update stage and alpha
        if self.alpha == 1.0 and self.stage < len(self.generator.gen) - 1:
            s = self.trainer.logged_metrics.get('loss_G', None)
            # if self.ssimloss < 0.2: #np.exp(-self.stage):
            if s is not None and s < np.exp(-self.stage) * 2.0:
                self.plotting()
                self.stage += 1
                self.alpha = 0.0
        self.alpha = min(self.alpha + 1/self.hparams.step_size, 1.0)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Discriminator step (optimizer 0) or generator step (optimizer 1)."""
        real, self.labels = batch # Can remove this attribute when not using acgan
        fake = self._get_fake_data(self.labels).type_as(real)
        # ssim = tm.functional.multiscale_structural_similarity_index_measure(
        #   F.interpolate(fake, real.shape[-2:]), real, True, 0.4, 3,
        #   betas=tuple(np.diff(np.sort(np.random.rand(self.stage)), prepend=0, append=1)),
        # )
        # Downscale the real images to the current generator resolution.
        real = F.adaptive_avg_pool2d(real, fake.shape[-2:])
        # Random MS-SSIM scale weights (stage + 1 non-negative betas summing to 1).
        ssim = tm.functional.multiscale_structural_similarity_index_measure(
          fake, real.detach(), True, 0.4, 3,
          betas=tuple(np.diff(np.sort(np.random.rand(self.stage)), prepend=0, append=1)), normalize='relu',
        )
        # print(ssim)
        self.ssimloss = torch.clamp(-torch.log10(ssim), max=10)
        # self.mse = F.mse_loss(fake.detach(), real.detach())

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real, fake.detach())

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(fake)

        return result

    def _disc_step(self, real, fake):
        # # Not needed if using gradient penalty
        # for p in self.discriminator.parameters():
        #     p.data.clamp_(-0.01, 0.01)
        disc_loss = self._get_disc_loss(real, fake)
        self.log('loss_D', disc_loss, True)
        return disc_loss

    def _gen_step(self, fake):
        gen_loss = self._get_gen_loss(fake)
        self.log('loss_G', gen_loss, True)
        return gen_loss

    def _get_disc_loss(self, real, fake):
        """BCE real/fake loss plus cross-entropy of the aux head on real images."""
        # Train with real
        # realCritic_pred = self.discriminator(real)
        realCritic_pred, realAux_pred = self.discriminator(real)
        # real_loss = -realCritic_pred.mean()
        real_gt = torch.ones_like(realCritic_pred)
        real_loss = self.criterion(realCritic_pred, real_gt)
        self.log('loss_D_real', real_loss, True)

        # Train with fake
        # fakeCritic_pred = self.discriminator(fake)
        fakeCritic_pred, _ = self.discriminator(fake)
        # fake_loss = fakeCritic_pred.mean()
        fake_gt = torch.zeros_like(fakeCritic_pred)
        fake_loss = self.criterion(fakeCritic_pred, fake_gt)
        self.log('loss_D_fake', fake_loss, True)

        # # Classifier loss
        class_loss = nn.CrossEntropyLoss()(realAux_pred, self.labels)
        # + nn.CrossEntropyLoss()(fakeAux_pred, self.labels)
        self.log('loss_D_class', class_loss, True)

        # # Compute gradient penalty
        # gp = 10 * self._gradient_penalty(real, fake)
        # self.log('loss_D_gp', gp, True)

        return real_loss + fake_loss + class_loss #+ gp

    def _get_gen_loss(self, fake):
        """BCE "fool the critic" loss + aux class loss + the MS-SSIM loss from training_step."""
        # Train with fake
        # fakeCritic_pred = self.discriminator(fake)
        fakeCritic_pred, fakeAux_pred = self.discriminator(fake)
        # fake_loss = -fakeCritic_pred.mean()
        fake_gt = torch.ones_like(fakeCritic_pred)
        fake_loss = self.criterion(fakeCritic_pred, fake_gt)
        self.log('loss_G_fake', fake_loss, True)

        # Classifier loss
        class_loss = nn.CrossEntropyLoss()(fakeAux_pred, self.labels)
        self.log('loss_G_class', class_loss, True)

        # L2 loss
        # self.log('loss_G_mse', self.mse, True)

        self.log('loss_G_ssim', self.ssimloss, True)

        return fake_loss + class_loss + self.ssimloss

    def _get_fake_data(self, labels, **kwargs):
        """Generate one fake image per entry of ``labels`` (kwargs go to forward)."""
        batch_size = len(labels)
        noise = self._get_noise(batch_size, self.hparams.latent_dim)
        fake = self(noise, labels, **kwargs)

        return fake

    # def _gradient_penalty(self, real, fake):
    #     """Calculates the gradient penalty loss for WGAN GP"""
    #     # Random weight term for interpolation between real and fake samples
    #     alpha = torch.rand(len(real), 1, 1, 1)
    #     # Get random interpolation between real and fake samples
    #     mix = torch.lerp(real, fake, alpha.type_as(real)).requires_grad_(True)
    #     # # Remove the underscore if not using ACGAN
    #     d_mix ,_ = self.discriminator(mix)
    #     # Get gradient w.r.t. mix
    #     gradients = torch.autograd.grad(
    #         outputs=d_mix,
    #         inputs=mix,
    #         grad_outputs=torch.ones_like(d_mix),
    #         create_graph=True,
    #         retain_graph=True,
    #         only_inputs=True,
    #     )[0]
    #     gradients = gradients.view(len(mix), -1)
    #     gp = ((gradients.norm(2, dim=1) - 1) ** 2)
    #     return gp.mean()

    def validation_step(self, batch, batch_idx):
        """Update FID and return F/J classifier probabilities for real and fake images."""
        orig, labels = batch
        self.imgMetrics.update(orig, real=True)
        fake = self._get_fake_data(labels).type_as(orig)
        # ssim = tm.functional.multiscale_structural_similarity_index_measure(
        #   fake, F.adaptive_avg_pool2d(orig, fake.shape[-2:]), True, 0.4, 3,
        #   betas=tuple(np.diff(np.sort(np.random.rand(self.stage)), prepend=0, append=1)), reduction=None
        # )
        # Upscale fakes to the real resolution before feature extraction.
        fake = F.interpolate(fake, orig.shape[-2:])
        self.imgMetrics.update(fake, real=False)

        return {
            'y0' : labels,
            'y1' : self.lastF(self.modelF(orig)).softmax(1),
            'y2' : self.lastF(self.modelF(fake)).softmax(1),
            'y3' : self.lastJ(self.modelJ(orig)).softmax(1),
            'y4' : self.lastJ(self.modelJ(fake)).softmax(1),
            # 'y5' : ssim
        }

    def curve_plot(self, ax, score, b, src, colors, y0 = -1):
        """Draw PR (``b == 0``) or ROC (``b == 1``) curves from ``score[y0]`` on ``ax``.

        ``y0 % 4 == 3`` (the default -1, or index 3 = all-fake) logs the mean
        per-class AP/AUROC; other ``y0`` only set a "Class y0 is fake" title.
        """
        keys, items = list(score[y0].keys()), list(score[y0].values())
        score1, score2, _ = items[b]
        if b == 0:
            ax.plot([0, 1], [1/self.hparams.num_classes, 1/self.hparams.num_classes], 'k--')
        else:
            ax.plot([0, 1], [0, 1], 'k--')

        for i, color, cls in zip(range(self.hparams.num_classes), colors,
                                self.trainer.datamodule.val_set.classes):
            if b == 0:
                PrecisionRecallDisplay(precision=score1[i].cpu(), recall=score2[i].cpu(),
                                    average_precision=items[2 + b][i], estimator_name=cls).plot(ax, c=color)
            else:
                RocCurveDisplay(fpr=score1[i].cpu(), tpr=score2[i].cpu(), roc_auc=items[2 + b][i],
                                estimator_name=cls).plot(ax, c=color)

        ax.legend(loc='best')
        if y0 % 4 == 3:
            self.log('LensGAN128/LensResnet(' + src + ')/val/' + keys[2 + b], sum(items[2 + b])/len(items[2 + b]))
        else:
            ax.set_title('Class ' + str(y0) + ' is fake')

    def validation_epoch_end(self, ListofDicts):
        """Log FID and plot one-vs-all PR/ROC curves of both classifiers.

        At step 0 the curves use real images only. Afterwards, for each class c
        the probabilities of class-c samples are swapped for those of their
        fakes (one panel per class), plus a panel where every sample is fake.
        The figure is sent to the logger and saved as ``PR_ROC_step_XXXXX.png``.
        """
        # Classification scores
        Fid = self.imgMetrics.compute()
        # NOTE: an earlier comment warned that logging here may not be
        # compatible with val_check_interval.
        self.log_dict(Fid)
        self.imgMetrics.reset()

        # ssim = torch.cat([x['y5'] for x in ListofDicts]).mean()
        # self.log('LensGAN128/val/ssim', ssim)

        labels = torch.cat([x['y0'] for x in ListofDicts])
        orig = [torch.cat([x['y1'] for x in ListofDicts]), torch.cat([x['y3'] for x in ListofDicts])]
        fake = [torch.cat([x['y2'] for x in ListofDicts]), torch.cat([x['y4'] for x in ListofDicts])]

        fig = plt.figure(constrained_layout=True, figsize=(20, 18))
        char = 'original' if self.global_step == 0 else 'generated'
        fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model F ('
                     + char + ' data)'
                      # \nStructural Similarity Index Measure: {:0.4f}'.format(ssim)
                     )

        # Different datasets
        subfigs = fig.subfigures(2, 1)
        colors = cycle(['r', 'g', 'b'])

        # -----------------------------------------------------------------------------------------------------
        if self.global_step > 0:
            # cake[c]: real-image probabilities with class-c rows replaced by the
            # fake-image probabilities; cake[-1]: fake probabilities for all rows.
            cake = []
            for _ in range(self.hparams.num_classes):
                cake.append(copy.deepcopy(orig))

            for m, y0 in enumerate(labels):
                for a, yn in enumerate(cake[int(y0)]):
                    # Row assignment copies the values into the deep copy.
                    yn[m] = fake[a][m]

            cake.append(copy.deepcopy(fake))

        for a, src in enumerate(['F', 'J']):
            output=[]
            subfigs[a].suptitle('Classifier : model ' + src + ', Frechet Inception Distance : {:0.4f}'.format(Fid['LensGAN128/val/FID_'+ src]))
            if self.global_step == 0:
                ax = subfigs[a].subplots(1, 2, subplot_kw={'xlim': [-0.05,1.05], 'ylim': [-0.05,1.05], 'aspect': 1})
                self.metrics.reset()
                output.append(self.metrics(orig[a], labels))

            else:
                subsub = subfigs[a].subfigures(1, 2)
                for x in cake:
                    self.metrics.reset()
                    output.append(self.metrics(x[a], labels))

            # Different curve types
            for b in range(2):
                if self.global_step == 0:
                    self.curve_plot(ax[b], output, b, src, colors)
                else:
                    ax = subsub[b].subplots(2, 2, subplot_kw={'xlim': [-0.05,1.05], 'ylim': [-0.05,1.05], 'aspect': 1})
                    for x in range(2):
                        for y in range(2):
                            # Panel (x, y) shows cake entry 2*x + y.
                            self.curve_plot(ax[x, y], output, b, src, colors, int(str(x) + str(y), 2))

        self.logger.experiment.add_figure('LensGAN128/LensResnet/val/PR_ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/PR_ROC_step_{:05d}.png'.format(self.global_step))

In [None]:
# del m
# config['lr'] is the base-10 exponent of the learning rate
# (LensGAN128 passes learning_rate=10**config['lr']), so -3 means 1e-3; the
# previous value 0.001 gave a learning rate of about 1.0.
# NOTE(review): LensGAN128.__init__ does not read config['latent_dim'] (or
# config['bs']); the latent size comes from DCGAN's own `latent_dim` argument.
# Pass latent_dim=512 as a keyword argument if that size is intended.
m = LensGAN128({'lr': -3, 'n_fmaps': 64, 'bs': 8, 'latent_dim' : 512})
print(m.generator, m.discriminator)

## VAE
Here we subclass a VAE.

In [None]:
# Re-declared here (same paths as in the LensGAN128 cell) so this cell can run
# on its own: trial directories of the best LensResnet classifiers for
# datasets F and J, from which `<dir>/checkpoint` is loaded below.
BEST_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_eb619_00000_0_2021-09-02_19-42-34/checkpoint_epoch=2-step=28124'
BEST_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/pbt_tanh_fine/train_LensResnet_93609_00000_0_2021-09-02_21-06-00/checkpoint_epoch=2-step=28124'

class LensVAE(VAE):
    """Label-conditioned VAE evaluated with two frozen LensResnet classifiers.

    The encoder/decoder come from the ``VAE`` base class. The encoder stem, the
    decoder output conv and the decoder input linear layer are replaced with lazy
    modules, because the decoder now receives ``one_hot(label)`` concatenated with
    the latent code and emits a single-channel image. Two pre-trained LensResnet
    backbones (models F and J) are loaded from checkpoints, frozen, and used both
    as FID feature extractors and as classifiers for PR/ROC evaluation.

    Args:
        config: dict with at least ``'input_height'``, ``'latent_dim'`` and ``'lr'``.
        num_classes: number of label classes (width of the one-hot label input).
        **kwargs: forwarded to ``VAE.__init__``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(input_height=config['input_height'], latent_dim=config['latent_dim'],
                         lr=config['lr'], first_conv=True, maxpool1=True, **kwargs)
        # NOTE(review): `ignore` receives the config dict itself, so only its keys are
        # ignored; `num_classes` is saved and read back below as self.hparams.num_classes.
        self.save_hyperparameters(ignore=config)

        # 3x3 / stride-1 stem with lazily inferred input channels.
        self.encoder.conv1 = nn.LazyConv2d(self.encoder.conv1.out_channels, 3, 1, 1, bias=False)
        # Single-channel output image.
        self.decoder.conv1 = nn.LazyConv2d(1, 3, 1, 1, bias=False)
        # Input width is latent_dim + num_classes (see _run_step / forward).
        self.decoder.linear = nn.LazyLinear(self.decoder.linear.out_features, False)

        # Frozen classifier F: backbone (fc -> Identity) yields features, lastF the logits head.
        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelF = temp.backbone
        # Keep the original last layer before replacing it with Identity.
        self.lastF = self.modelF.fc
        self.modelF.fc = nn.Identity()

        # Frozen classifier J, prepared the same way.
        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelJ = temp.backbone
        self.lastJ = self.modelJ.fc
        self.modelJ.fc = nn.Identity()

        # FID computed in each classifier's feature space.
        self.imgMetrics = tm.MetricCollection(
            {
                'FID_F' : tm.FID(self.modelF),
                'FID_J' : tm.FID(self.modelJ),
            },
            prefix='LensVAE/val/',
        )

        # Per-class PR/ROC curves and their summary scores, one copy per classifier.
        metrics = tm.MetricCollection(
            [tm.PrecisionRecallCurve(self.hparams.num_classes), tm.ROC(self.hparams.num_classes),
            tm.AveragePrecision(self.hparams.num_classes, average=None), tm.AUROC(self.hparams.num_classes, average=None)],
        )
        self.FMetrics = metrics.clone()
        self.JMetrics = metrics.clone()

    def forward(self, y):
        """Generate images for labels ``y`` by decoding standard-normal latent samples.

        The decoder input is ``one_hot(y)`` concatenated with a fresh ``randn`` code
        of size ``self.latent_dim`` (the encoder is not used).
        """
        inp = torch.cat((F.one_hot(y, self.hparams.num_classes),
                         torch.randn(len(y), self.latent_dim, device=self.device)), 1)
        return self.decoder(inp)

    def _run_step(self, x, y):
        """Encode ``x``, sample z from the posterior, and decode ``one_hot(y) ++ z``.

        Returns:
            ``(z, reconstruction, p, q)`` where ``p``/``q`` come from ``self.sample``.
        """
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        p, q, z = self.sample(mu, log_var)
        inp = torch.cat((F.one_hot(y, self.hparams.num_classes), z), 1)
        return z, self.decoder(inp), p, q

    def step(self, batch, batch_idx):
        """Compute the VAE loss (mean MSE reconstruction + ``kl_coeff`` * mean KL).

        Returns:
            ``(loss, logs)`` where ``logs`` holds ``recon_loss``, ``kl`` and ``loss``.
        """
        x, y = batch
        z, x_hat, p, q = self._run_step(x, y)

        recon_loss = F.mse_loss(x_hat, x, reduction="mean")

        kl = torch.distributions.kl_divergence(q, p)
        kl = kl.mean()
        kl *= self.kl_coeff

        loss = kl + recon_loss

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "loss": loss,
        }
        return loss, logs

    def validation_step(self, batch, batch_idx):
        """Update FID with real/generated images and the classifier PR/ROC metrics.

        At global step 0 the classifiers score the real images (baseline);
        afterwards they score images generated for the same labels.
        """
        imgs, labels = batch
        self.imgMetrics.update(imgs, real=True)
        # Generated images resized to 150 px to match the real ones.
        fake = F.interpolate(self(labels), 150).type_as(imgs)
        self.imgMetrics.update(fake, real=False)

        if self.global_step == 0:
            self.FMetrics.update(self.lastF(self.modelF(imgs)).softmax(1), labels)
            self.JMetrics.update(self.lastJ(self.modelJ(imgs)).softmax(1), labels)
        else:
            self.FMetrics.update(self.lastF(self.modelF(fake)).softmax(1), labels)
            self.JMetrics.update(self.lastJ(self.modelJ(fake)).softmax(1), labels)

    def validation_epoch_end(self, ListofDicts):
        """Log FID and PR/ROC summaries, save the PR/ROC figure, and save a sample grid."""
        # Classification scores
        fid = self.imgMetrics.compute()
        # self.log_dict(fid) isn't compatible with val_check_interval
        self.log_dict(self.imgMetrics)

        fig = plt.figure(constrained_layout=True, figsize=(10, 9))
        data_kind = 'original data' if self.global_step == 0 else 'generated data'
        # Both classifiers are plotted (one sub-figure each), so the title names both.
        fig.suptitle('One vs. all PR & ROC curves for different types of substructures in models F & J (' + data_kind + ')')
        subfigs = fig.subfigures(2, 1)
        colors = cycle(['r', 'g', 'b'])

        for a, src in enumerate(['F', 'J']):
            if self.global_step == 0:
                subfigs[a].suptitle('Classifier : model ' + src)
            else:
                subfigs[a].suptitle('Classifier : model ' + src + ', FID : {:0.2f}'.format(fid['LensVAE/val/FID_'+ src]))
            ax = subfigs[a].subplots(1,2, subplot_kw={'xlim': [-0.05,1.05], 'ylim': [-0.05,1.05], 'aspect': 1})

            temp = getattr(self, src + 'Metrics')
            Dict = temp.compute()
            temp.reset()

            # Relies on the collection order: PrecisionRecallCurve, ROC, AveragePrecision, AUROC,
            # so val[2 + b] is the summary score belonging to curve val[b].
            key, val = list(Dict.keys()), list(Dict.values())
            for b in range(2):
                # Chance-level reference line.
                if key[b] != 'ROC':
                    ax[b].plot([0, 1], [1/self.hparams.num_classes, 1/self.hparams.num_classes], 'k--')
                else:
                    ax[b].plot([0, 1], [0, 1], 'k--')

                prec_FPR, rec_TPR, _ = val[b]
                for i, color, cls in zip(range(self.hparams.num_classes), colors, self.trainer.datamodule.val_set.classes):
                    if key[b] != 'ROC':
                        PrecisionRecallDisplay(prec_FPR[i].cpu(), rec_TPR[i].cpu(), val[2 + b][i], cls).plot(ax[b], c=color)
                    else:
                        RocCurveDisplay(prec_FPR[i].cpu(), rec_TPR[i].cpu(), val[2 + b][i], cls).plot(ax[b], c=color)

                ax[b].legend(loc='best')
                # Log the class-averaged summary score.
                self.log('LensVAE/LensResnet(' + src + ')/val/' + key[2 + b], sum(val[2 + b])/len(val[2 + b]))

        # Pass global_step so each validation's figure gets its own TensorBoard step.
        self.logger.experiment.add_figure('LensVAE/LensResnet/val/PR_ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/PR_ROC_step_{:05d}.png'.format(self.global_step))

        # Sample grid: first row = one generated image per class, second row = one
        # random real validation image per class (drawn until each label is found).
        labels = torch.arange(self.hparams.num_classes, device=self.device)
        imgs = self(labels)
        while len(labels) > 0:
            x, y = self.trainer.datamodule.val_set[np.random.randint(0, len(self.trainer.datamodule.val_set))]
            if y == labels[0]:
                try:
                    imgs = torch.cat((imgs, x.unsqueeze(0).type_as(imgs)))
                except Exception:
                    # Show the mismatching sizes, then re-raise with original type and traceback.
                    print(imgs.size(), x.size())
                    raise
                labels = labels[1:]

        save_image(F.interpolate(imgs, 150),
                str(self.trainer.log_dir) + '/Fake_step_{:05d}.png'.format(self.global_step),
                #  kwargs for make_grid
                nrow=self.hparams.num_classes, normalize=True, value_range=(-1,1), pad_value=0.5)

In [None]:
# Hyper-parameters for a quick LensVAE instantiation.
config = dict(
    input_height=128,
    latent_dim=100,
    lr=1e-3,
    bs=8,
)
m = LensVAE(config)

## Stage 2
Here we subclass a DCGAN to create our high-resolution (stage-2) GAN.

In [None]:
class Generator2(nn.Module):
    """Stage-2 generator: refine an input image with a residual trunk, then upsample.

    Pipeline: a stride-2 convolution halves the spatial size; ``res_depth``
    ``BasicBlock``s plus a conv/BN/ReLU form a residual trunk whose output is added
    back to the pre-processed features (long skip connection); two nearest-neighbour
    upsample + 3x3 conv stages resize the result to ``image_width // 2`` and then to
    ``image_width``, and ``tanh`` bounds the output to [-1, 1].

    Args:
        ngf: number of feature maps in the trunk.
        image_channels: number of channels of both the input and the output image.
        res_depth: number of residual ``BasicBlock``s.
        image_width: output height/width; defaults to 150 (the former hard-coded value).
    """

    def __init__(self, ngf: int = 128, image_channels: int = 1, res_depth: int = 6,
                 image_width: int = 150):
        super().__init__()

        ker, strd = 4, 2
        pad = (ker - 2) // 2  # 4x4 / stride 2 / pad 1 halves even sizes (e.g. 64 -> 32)
        res_ker, res_strd, res_pad = 3, 1, 1  # size-preserving 3x3 convolutions

        self.preprocessing = nn.Sequential(
            nn.Conv2d(image_channels, ngf, ker, strd, pad, bias=False),
            nn.ReLU(True)
        )
        # Residual trunk.
        self.residual = nn.Sequential(*[BasicBlock(ngf, ngf) for _ in range(res_depth)])

        self.ending_residual = nn.Sequential(
            nn.Conv2d(ngf, ngf, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )

        # forward() adds the pre-processed features onto the trunk output before `main`.

        mode = 'nearest'  # nearest-neighbour upsampling
        self.main = nn.Sequential(
            # Resize to image_width // 2 regardless of input size (e.g. 32 -> 75).
            nn.Upsample(image_width // 2, mode=mode),
            nn.Conv2d(ngf, ngf*4, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # Then to the final size (e.g. 75 -> 150).
            nn.Upsample(image_width, mode=mode),
            nn.Conv2d(ngf*4, image_channels, res_ker, res_strd, res_pad, bias=False),
            nn.Tanh()
        )

    def forward(self, in_x):
        """Map a batch ``(N, image_channels, H, W)`` to ``(N, image_channels, image_width, image_width)``."""
        x_p = self.preprocessing(in_x)
        x_r = self.ending_residual(self.residual(x_p))
        # Long residual connection from the pre-processed features.
        return self.main(x_r + x_p)

In [None]:
# Ray Tune trial-checkpoint directories of the selected LensGAN128 models, one
# trained on dataset F and one on dataset J (each path ends in '/', so joining it
# with 'checkpoint' yields the checkpoint file). Presumably these serve as the
# low-resolution first stage for Stage2 below — confirm which one is loaded there.
BEST_F_LensGAN128 = '/content/drive/MyDrive/Logs/F/LensGAN128/pbt_tanh_1/train_LensGAN128_90727_00001_1_n_fmaps=16_2021-08-30_08-02-05/checkpoint_epoch=4-step=1988/'
BEST_J_LensGAN128 ='/content/drive/MyDrive/Logs/J/LensGAN128/pbt_tanh/train_LensGAN128_28c03_00003_3_n_fmaps=64_2021-08-31_20-45-44/checkpoint_epoch=2-step=1403/'

class Stage2(DCGAN):
    """Second-stage (high-resolution) GAN built on ``DCGAN``.

    The generator is replaced with ``Generator2``, which refines images produced by
    a low-resolution ``LensGAN128`` (``self.lowres``, in eval mode) instead of raw
    noise. Two pre-trained LensResnet classifiers (``modelF``, ``modelJ``) score the
    generated images during validation with one-vs-all ROC / AUROC.

    Args:
        config: dict with ``'n_fmaps'``, ``'learning_rate'`` and ``'res_depth'``.
        num_classes: number of label classes.
        **kwargs: accepted but not forwarded to ``DCGAN``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], learning_rate=config['learning_rate'])
        self.save_hyperparameters(ignore=config)

        self.generator = Generator2(self.hparams.feature_maps_gen, self.hparams.image_channels, config['res_depth'])

        # Helper networks are kept as attributes (rather than returned by a method).
        # Each is loaded only when not already present; passing the loader as the
        # default of getattr() would run it unconditionally. Paths use the
        # BEST_<dataset>_<model> constants defined in this notebook.
        if not hasattr(self, 'modelF'):
            self.modelF = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint')).eval()
        if not hasattr(self, 'modelJ'):
            self.modelJ = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint')).eval()
        # Low-resolution first stage whose output Generator2 refines (see _get_noise).
        if not hasattr(self, 'lowres'):
            self.lowres = LensGAN128.load_from_checkpoint(os.path.join(BEST_F_LensGAN128, 'checkpoint')).eval()

        metrics = tm.MetricCollection(
            [
             tm.AUROC(num_classes=self.hparams.num_classes, compute_on_step=False, average=None),
             tm.ROC(num_classes=self.hparams.num_classes, compute_on_step=False),
            ]
        )
        self.metricsF = metrics.clone()
        self.metricsJ = metrics.clone()

    def forward(self, noise):
        """Run the stage-2 generator on ``noise`` (low-resolution images, see _get_noise)."""
        return self.generator(noise)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternate discriminator (optimizer 0) and generator (optimizer 1) updates.

        Stores the batch labels on ``self.labels`` so ``_get_noise`` can condition on them.
        """
        real, self.labels = batch

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real)

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(real)

        return result

    def _disc_step(self, real):
        """Compute and log the discriminator loss."""
        disc_loss = self._get_disc_loss(real)
        self.log('Stage2/D/train/loss', disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real):
        """Compute and log the generator loss."""
        gen_loss = self._get_gen_loss(real)
        self.log('Stage2/G/train/loss', gen_loss, on_epoch=True)
        return gen_loss

    def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
        """Adversarial generator loss: push fake predictions towards the 'real' target (1)."""
        fake_pred = self._get_fake_pred(real)
        fake_gt = torch.ones_like(fake_pred)
        gen_loss = self.criterion(fake_pred, fake_gt)
        # An auxiliary classification term (cross-entropy on _get_class_pred) was
        # tried here and is currently disabled.
        return gen_loss

    def _get_class_pred(self, batch_size) -> torch.Tensor:
        """Classifier-F backbone outputs for ``batch_size`` freshly generated images."""
        return self.modelF.backbone(self(self._get_noise(batch_size, self.hparams.latent_dim)))

    def _get_noise(self, n_samples: int, latent_dim: int, labels = None):
        """Return low-resolution images from ``self.lowres`` to feed the generator.

        ``labels`` defaults to ``self.labels`` (the current training batch labels;
        a default argument cannot refer to ``self``).
        """
        if labels is None:
            labels = self.labels
        return self.lowres(super()._get_noise(n_samples, latent_dim), labels)

    def validation_step(self, batch, batch_idx):
        """Generate images for the batch labels and update both classifiers' metrics."""
        imgs, labels = batch
        out = self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels))
        self.metricsF.update(self.modelF(out), labels)
        self.metricsJ.update(self.modelJ(out), labels)

    def validation_epoch_end(self, listofDicts):
        """Log the worst-class AUROC, plot per-class ROC curves, and save one sample per class."""
        fig, ax = plt.subplots(1,2,
            subplot_kw={'xlim': [0,1], 'xlabel': 'False Positive Rate', 'ylim': [0,1.05],
                        'ylabel': 'True Positive Rate',
            },
            figsize=[11, 5],
        )
        for j, letter in enumerate(['F', 'J']):
            metrics = getattr(self, 'metrics' + letter)
            output = metrics.compute()
            # These metrics are updated manually and not logged as Metric objects, so
            # nothing else clears them; reset so each validation run starts fresh.
            metrics.reset()
            self.log('Stage2/ResNet(' + letter + ')/val/auroc', output['AUROC'].min())
            fprList, tprList, _ = output['ROC']

            colors = cycle(['red', 'blue', 'green'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                ax[j].plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                        label='ROC curve of class {0} (area = {1:0.4f})'
                        ''.format(i, output['AUROC'][i]))
            post_plotting(ax[j])
            ax[j].set_title('One vs. all ROC curve (' + letter + ')')

        fig.tight_layout()
        # Pass global_step so each validation's figure gets its own TensorBoard step.
        self.logger.experiment.add_figure('Stage2/ResNet/val/ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/ROC_step_{:05d}.png'.format(self.global_step))

        labels = torch.arange(self.hparams.num_classes, device=self.device)
        save_image(self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels)),
                   str(self.trainer.log_dir) + '/Fake_step_{:05d}.png'.format(self.global_step),
                  #  kwargs for make_grid
                   normalize=True, value_range=(-1,1))

    def on_fit_end(self):
        """Drop helper networks and cached labels after fitting.

        Attributes that were never set (e.g. ``labels`` when no training step ran)
        are skipped instead of raising ``AttributeError``.
        """
        for name in ('modelF', 'modelJ', 'labels', 'lowres'):
            if hasattr(self, name):
                delattr(self, name)

# Tune Hyperparameters


## ResNet
Here we tune hyperparameters as we train our modified ResNet.

In [None]:
%rm -rf /content/drive/MyDrive/Logs/fakeF/PGAN/LensResnet/pbt_tanh_validate

In [None]:
# __tune_train_checkpoint_begin
def train_LensResnet(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable: fit one LensResnet on ``npyImageData`` for a single trial.

    Args:
        config: trial hyper-parameters; passed whole to ``npyImageData`` and
            ``LensResnet``, and ``config['bs']`` also sizes the progress-bar refresh.
        checkpoint_dir: Tune checkpoint directory to resume from, or ``None``.
        num_epochs: maximum number of training epochs.
        num_gpus: GPUs for this trial; fractional values are truncated to ``int``.
    """
    # NOTE(review): the refresh rate treats config['bs'] as a log2 batch size; with
    # the literal sizes 8..128 used by tune_LensResnet_pbt it is small or 0 — confirm.
    refresh_rate = int(8250//int(2**np.rint(config['bs'])))

    # Tune metric name -> metric name logged by LensResnet.
    reported = {
        'loss': 'LensResnet/val/loss',
        'auroc': 'LensResnet/val/AUROC',
        'ap' : 'LensResnet/val/AveragePrecision',
    }

    trainer_kwargs = dict(
        progress_bar_refresh_rate=refresh_rate,
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=int(num_gpus),  # fractional GPU requests become whole numbers
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            TuneReportCheckpointCallback(reported),
            ModuleDataMonitor(['backbone.layer2', 'backbone.layer4', 'backbone.fc']),
            ConfusedLogitCallback(5),
        ],
        stochastic_weight_avg=True,  # only works with a single optimizer
        benchmark=True,
        precision=16,  # mixed precision; can't be used on CPU
    )

    datamodule = npyImageData(config)
    if checkpoint_dir is not None:
        # Let Lightning restore the full training state from the Tune checkpoint.
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = LensResnet(config)
    pl.Trainer(**trainer_kwargs).fit(model, datamodule)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Run a Population Based Training sweep over LensResnet's ``lr`` and ``bs``.

    Results are written under ``./drive/MyDrive/Logs/J/LensResnet/pbt_tanh_fine``
    (rename the experiment when switching datasets); trials are ranked by minimum
    ``loss``.

    Args:
        num_samples: number of trials in the population.
        num_epochs: maximum epochs per trial.
        gpus_per_trial: GPUs requested per trial.

    Returns:
        ``analysis.best_checkpoint``. This function does not modify the module-level
        ``BEST_J_RESNET``; assign the return value to it if that is wanted.
    """
    analysis = tune.run(
        tune.with_parameters(
            train_LensResnet,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial
        ),
        # Change the folder name when changing dataset
        name='J/LensResnet/pbt_tanh_fine',
        metric='loss',
        mode='min',
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        config={'lr': 1e-5,
                'bs': 8,
                # RuntimeError: stack expects each tensor to be equal size, but got [128] at entry 0 and [120] at entry 585
                },
        # RB2 can't be used here because it requires continuous mutations.
        scheduler = PopulationBasedTraining(time_attr='training_iteration', quantile_fraction=0.4,
                                            resample_probability=0.2,  perturbation_interval=1,
                                            hyperparam_mutations={
                                                'lr': tune.loguniform(1e-6, 1e-4),
                                                'bs': [8, 16, 32, 64, 128],
                                            },
        ),
        progress_reporter=JupyterNotebookReporter(
            overwrite=False,
            parameter_columns=['lr', 'bs'],
            metric_columns=['loss', 'auroc', 'ap', 'training_iteration'],
        ),
        fail_fast = True,
        num_samples=num_samples,
    )
    best_checkpoint = analysis.best_checkpoint
    print('Best checkpoint path found is: ', best_checkpoint)
    return best_checkpoint

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args tolerates unrelated arguments (presumably the kernel's own
    # argv when run inside Jupyter/Colab).
    args, _ = parser.parse_known_args()

    # The smoke test must be the shorter run: 1 epoch instead of the full 5.
    num_epochs = 1 if args.smoke_test else 5
    tune_LensResnet_pbt(num_samples=1, num_epochs=num_epochs, gpus_per_trial=torch.cuda.device_count())

## LensGAN128
Here we tune hyperparameters as we train our modified DCGAN.

In [None]:
%rm -rf drive/MyDrive/Logs/F/LensGAN128/pgan_variable_ssim_dropout_emb_lr3_128

In [None]:
# __tune_train_checkpoint_begin
def train_LensGAN128(config, checkpoint_dir=None, num_steps=2, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable: train one LensGAN128 (``latent_dim=512``) for a single trial.

    Args:
        config: trial hyper-parameters, passed to ``npyImageData`` and ``LensGAN128``;
            ``config['bs']`` is used as a log2 batch size when sizing the progress bar.
        checkpoint_dir: Tune checkpoint directory to resume from, or ``None``.
        num_steps: despite its name, the maximum number of training *epochs*.
        num_gpus: GPUs for this trial; fractional values are truncated to ``int``.
    """
    datamodule = npyImageData(config)
    model = LensGAN128(config, latent_dim=512)

    # Tune metric name -> metric name logged by LensGAN128 during validation.
    # Switch the FID entries when training on a different dataset.
    # (Training losses such as 'LensGAN128/G/train/loss' or 'LensGAN128/D/train/loss/gp'
    # can be added here as well.)
    tune_metrics = {
        'FID_F': 'LensGAN128/val/FID_F',
        'FID_J': 'LensGAN128/val/FID_J',
        'auroc_F': 'LensGAN128/LensResnet(F)/val/AUROC',
        'auroc_J': 'LensGAN128/LensResnet(J)/val/AUROC',
        'ap_F': 'LensGAN128/LensResnet(F)/val/AveragePrecision',
        'ap_J': 'LensGAN128/LensResnet(J)/val/AveragePrecision',
    }

    trainer = pl.Trainer(
        track_grad_norm=2,
        num_sanity_val_steps=10,
        max_epochs=num_steps,
        gpus=int(num_gpus),  # fractional GPU requests become whole numbers
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        callbacks=[
            # Default reporting point (validation end) is kept so a resumed trial
            # picks up the updated checkpoint.
            TuneReportCheckpointCallback(tune_metrics),
            ModuleDataMonitor(True),
            TQDMProgressBar(refresh_rate=int(7500//int(2**np.rint(config['bs'])))),
        ],
        benchmark=True,
        precision=16,
    )

    resume_from = None if checkpoint_dir is None else os.path.join(checkpoint_dir, 'checkpoint')
    trainer.fit(model, datamodule, ckpt_path=resume_from)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensGAN128_pbt(num_samples=10, num_steps=2, gpus_per_trial=torch.cuda.device_count()):
    """Run (or resume) a PB2 sweep over LensGAN128's exponent-encoded ``lr`` and ``bs``.

    ``lr`` is searched in [-6, -3] and ``bs`` in [3, 7] (the inline comments give
    the intended 1e-6..1e-3 and 8..128 ranges). Trials are ranked by maximum
    ``auroc_F``; the existing experiment is resumed (``resume=True``).

    Args:
        num_samples: number of trials.
        num_steps: maximum epochs per trial (forwarded to ``train_LensGAN128``).
        gpus_per_trial: GPUs requested per trial.
    """
    trainable = tune.with_parameters(
        train_LensGAN128,
        num_steps=num_steps,
        num_gpus=gpus_per_trial
    )
    scheduler = PB2(
        'training_iteration',
        quantile_fraction=0.25,
        perturbation_interval=1,
        hyperparam_bounds={
            'lr': [-6, -3],  # tune.loguniform(1e-5, 1e-3)
            'bs': [3, 7],    # [8, 16, 32, 64, 128]
        },
    )
    reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['lr', 'n_fmaps', 'bs'],
        # The *_J metrics are reported too but not shown in the table.
        metric_columns=['FID_F', 'auroc_F', 'ap_F', 'training_iteration'],
        max_report_frequency=300,
    )

    analysis = tune.run(
        trainable,
        # Change the folder name when changing dataset
        name='F/LensGAN128/pgan_variable_ssim_dropout_emb_lr3_128',
        metric='auroc_F',
        mode='max',
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        config={'lr': -3, 'n_fmaps': 128, 'bs': 3},
        scheduler=scheduler,
        progress_reporter=reporter,
        fail_fast=True,
        reuse_actors=True,
        num_samples=num_samples,
        resume=True,
    )
    print('Best checkpoint path found is: ', analysis.best_checkpoint)

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    args, _ = parser.parse_known_args()

    # Workaround kept from the notebook: report total physical memory to Ray.
    ray._private.utils.get_system_memory = lambda: psutil.virtual_memory().total
    # Lightning reads this flag from the environment; a bare Python assignment
    # (`PL_FAULT_TOLERANT_TRAINING=1`) only creates a local name and has no effect.
    # NOTE(review): set before tune.run so Ray workers started afterwards inherit it.
    os.environ['PL_FAULT_TOLERANT_TRAINING'] = '1'

    # The smoke test must actually be shorter than the full run.
    num_steps = 1 if args.smoke_test else 10
    tune_LensGAN128_pbt(num_samples=1, num_steps=num_steps, gpus_per_trial=torch.cuda.device_count())

2022-05-24 01:36:08,495	INFO tune.py:605 -- TrialRunner resumed, ignoring new add_experiment but updating trial resources.
2022-05-24 01:36:08,684	INFO trial_runner.py:803 -- starting train_LensGAN128_52f0b_00000
2022-05-24 01:36:11,085	INFO trainable.py:93 -- Checkpoint size is 274857457 bytes


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m 2022-05-24 01:36:17,780	INFO trainable.py:535 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/pgan_variable_ssim_dropout_emb_lr3_128/train_LensGAN128_52f0b_00000_0_2022-05-24_00-27-42/checkpoint_tmp0a4be0/./
[2m[36m(train_LensGAN128 pid=3261)[0m 2022-05-24 01:36:17,781	INFO trainable.py:543 -- Current state after restoring: {'_iteration': 3, '_timesteps_total': None, '_time_total': 3748.126875400543, '_episodes_total': None}
[2m[36m(train_LensGAN128 pid=3261)[0m Using 16bit native Automatic Mixed Precision (AMP)
[2m[36m(train_LensGAN128 pid=3261)[0m GPU available: True, used: True
[2m[36m(train_LensGAN128 pid=3261)[0m TPU available: False, using: 0 TPU cores
[2m[36m(train_LensGAN128 pid=3261)[0m IPU available: False, using: 0 IPUs
[2m[36m(train_LensGAN128 pid=3261)[0m HPU available: False, using: 0 HPUs
[2m[36m(train_LensGAN128 pid=3261)[0m   "The `on_keyboard_interrupt` callback hook was 

Epoch 2:   0%|          | 0/10313 [00:00<00:00, -45616705.34it/s]  
Epoch 2:   0%|          | 0/10313 [00:00<00:00, -49088.19it/s, loss=nan, v_num=.]
Epoch 3:   0%|          | 0/10313 [00:00<00:00, -28976860.72it/s, loss=nan, v_num=.]


[2m[36m(train_LensGAN128 pid=3261)[0m E0524 01:38:22.280083813    3301 chttp2_transport.cc:1103]   Received a GOAWAY with error code ENHANCE_YOUR_CALM and debug data equal to "too_many_pings"


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:   9%|▉         | 937/10313 [04:09<-1:55:23, -33.80it/s, loss=nan, v_num=.] Epoch 3:   9%|▉         | 937/10313 [04:09<-1:55:23, -33.80it/s, loss=nan, v_num=.]Epoch 3:   9%|▉         | 937/10313 [04:09<-1:55:23, -33.80it/s, loss=2.13, v_num=., loss_G_fake=0.694, loss_G_class=1.080, loss_G_ssim=0.274, loss_G=2.050, loss_D_real=0.429, loss_D_fake=0.693, loss_D_class=1.120, loss_D=2.240]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:  18%|█▊        | 1874/10313 [08:34<-1:50:22, -14.59it/s, loss=2.13, v_num=., loss_G_fake=0.694, loss_G_class=1.080, loss_G_ssim=0.274, loss_G=2.050, loss_D_real=0.429, loss_D_fake=0.693, loss_D_class=1.120, loss_D=2.240]Epoch 3:  18%|█▊        | 1874/10313 [08:34<-1:50:22, -14.59it/s, loss=2.13, v_num=., loss_G_fake=0.694, loss_G_class=1.080, loss_G_ssim=0.274, loss_G=2.050, loss_D_real=0.429, loss_D_fake=0.693, loss_D_class=1.120, loss_D=2.240]Epoch 3:  18%|█▊        | 1874/10313 [08:34<-1:50:22, -14.59it/s, loss=2.03, v_num=., loss_G_fake=0.730, loss_G_class=1.140, loss_G_ssim=0.452, loss_G=2.320, loss_D_real=0.163, loss_D_fake=0.687, loss_D_class=1.160, loss_D=2.010]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:  27%|██▋       | 2811/10313 [12:46<-1:45:24, -8.56it/s, loss=2.03, v_num=., loss_G_fake=0.730, loss_G_class=1.140, loss_G_ssim=0.452, loss_G=2.320, loss_D_real=0.163, loss_D_fake=0.687, loss_D_class=1.160, loss_D=2.010] Epoch 3:  27%|██▋       | 2811/10313 [12:46<-1:45:24, -8.56it/s, loss=2.03, v_num=., loss_G_fake=0.730, loss_G_class=1.140, loss_G_ssim=0.452, loss_G=2.320, loss_D_real=0.163, loss_D_fake=0.687, loss_D_class=1.160, loss_D=2.010]Epoch 3:  27%|██▋       | 2811/10313 [12:46<-1:45:24, -8.56it/s, loss=2.13, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.207, loss_G=1.950, loss_D_real=0.415, loss_D_fake=0.769, loss_D_class=1.100, loss_D=2.290]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:  36%|███▋      | 3748/10313 [16:44<-1:40:28, -5.60it/s, loss=2.13, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.207, loss_G=1.950, loss_D_real=0.415, loss_D_fake=0.769, loss_D_class=1.100, loss_D=2.290]Epoch 3:  36%|███▋      | 3748/10313 [16:44<-1:40:28, -5.60it/s, loss=2.13, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.207, loss_G=1.950, loss_D_real=0.415, loss_D_fake=0.769, loss_D_class=1.100, loss_D=2.290]Epoch 3:  36%|███▋      | 3748/10313 [16:44<-1:40:28, -5.60it/s, loss=2.06, v_num=., loss_G_fake=0.715, loss_G_class=1.090, loss_G_ssim=0.171, loss_G=1.980, loss_D_real=0.381, loss_D_fake=0.665, loss_D_class=1.090, loss_D=2.140]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:  45%|████▌     | 4685/10313 [20:48<-1:35:02, -3.76it/s, loss=2.06, v_num=., loss_G_fake=0.715, loss_G_class=1.090, loss_G_ssim=0.171, loss_G=1.980, loss_D_real=0.381, loss_D_fake=0.665, loss_D_class=1.090, loss_D=2.140]Epoch 3:  45%|████▌     | 4685/10313 [20:48<-1:35:02, -3.76it/s, loss=2.06, v_num=., loss_G_fake=0.715, loss_G_class=1.090, loss_G_ssim=0.171, loss_G=1.980, loss_D_real=0.381, loss_D_fake=0.665, loss_D_class=1.090, loss_D=2.140]Epoch 3:  45%|████▌     | 4685/10313 [20:48<-1:35:02, -3.76it/s, loss=2.02, v_num=., loss_G_fake=0.717, loss_G_class=1.150, loss_G_ssim=0.188, loss_G=2.050, loss_D_real=0.212, loss_D_fake=0.671, loss_D_class=1.150, loss_D=2.030]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


Epoch 3:  55%|█████▍    | 5622/10313 [24:52<-1:28:55, -2.51it/s, loss=2.13, v_num=., loss_G_fake=0.683, loss_G_class=1.030, loss_G_ssim=0.236, loss_G=1.950, loss_D_real=0.433, loss_D_fake=0.706, loss_D_class=1.170, loss_D=2.310]
Epoch 3:  64%|██████▎   | 6559/10313 [28:55<-1:21:26, -1.62it/s, loss=2.12, v_num=., loss_G_fake=0.655, loss_G_class=1.090, loss_G_ssim=0.201, loss_G=1.940, loss_D_real=0.441, loss_D_fake=0.713, loss_D_class=1.100, loss_D=2.250]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:  73%|███████▎  | 7496/10313 [32:49<-1:10:47, -0.95it/s, loss=2.12, v_num=., loss_G_fake=0.655, loss_G_class=1.090, loss_G_ssim=0.201, loss_G=1.940, loss_D_real=0.441, loss_D_fake=0.713, loss_D_class=1.100, loss_D=2.250]Epoch 3:  73%|███████▎  | 7496/10313 [32:49<-1:10:47, -0.95it/s, loss=2.12, v_num=., loss_G_fake=0.655, loss_G_class=1.090, loss_G_ssim=0.201, loss_G=1.940, loss_D_real=0.441, loss_D_fake=0.713, loss_D_class=1.100, loss_D=2.250]Epoch 3:  73%|███████▎  | 7496/10313 [32:49<-1:10:47, -0.95it/s, loss=2.12, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.138, loss_G=1.880, loss_D_real=0.477, loss_D_fake=0.758, loss_D_class=1.100, loss_D=2.330]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 3:  82%|████████▏ | 8433/10313 [36:53<-2:46:23, -0.43it/s, loss=2.12, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.138, loss_G=1.880, loss_D_real=0.477, loss_D_fake=0.758, loss_D_class=1.100, loss_D=2.330]Epoch 3:  82%|████████▏ | 8433/10313 [36:53<-2:46:23, -0.43it/s, loss=2.12, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.138, loss_G=1.880, loss_D_real=0.477, loss_D_fake=0.758, loss_D_class=1.100, loss_D=2.330]Epoch 3:  82%|████████▏ | 8433/10313 [36:53<-2:46:23, -0.43it/s, loss=2.11, v_num=., loss_G_fake=0.643, loss_G_class=1.100, loss_G_ssim=0.107, loss_G=1.850, loss_D_real=0.470, loss_D_fake=0.716, loss_D_class=1.110, loss_D=2.290]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:1455,-3,128,3,191.25,0.59189,0.51623,3


Epoch 3:  91%|█████████ | 9370/10313 [40:56<-129:18:01, -0.00it/s, loss=2.09, v_num=., loss_G_fake=0.618, loss_G_class=1.010, loss_G_ssim=0.197, loss_G=1.830, loss_D_real=0.444, loss_D_fake=0.791, loss_D_class=1.060, loss_D=2.300]
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation: 0it [00:00, ?it/s][A
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation:   0%|          | 0/938 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/938 [00:00<?, ?it/s][A
Epoch 3: 100%|█████████▉| 10307/10313 [41:52<00:16,  2.70s/it, loss=2.09, v_num=., loss_G_fake=0.618, loss_G_class=1.010, loss_G_ssim=0.197, loss_G=1.830, loss_D_real=0.444, loss_D_fake=0.791, loss_D_class=1.060, loss_D=2.300]
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation DataLoader 0: 100%|█████████▉| 937/938 [00:55<00:00, 16.83it/s][A
Validation DataLoader 0: 100%|█████████▉| 937/938 [00:55<00:00, 16.83it/s][A
[2m[36m(train_LensGAN128 pid=3261)[0m 
Epoch 3: 100%|██████████| 10313/10313 [41:53<00:00,  2.68s/it

Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 4:   9%|▉         | 937/10313 [03:57<-1:55:37, -35.57it/s, loss=2.11, v_num=., loss_G_fake=0.686, loss_G_class=1.080, loss_G_ssim=0.195, loss_G=1.960, loss_D_real=0.478, loss_D_fake=0.695, loss_D_class=1.100, loss_D=2.280]Epoch 4:   9%|▉         | 937/10313 [03:57<-1:55:37, -35.57it/s, loss=2.11, v_num=., loss_G_fake=0.686, loss_G_class=1.080, loss_G_ssim=0.195, loss_G=1.960, loss_D_real=0.478, loss_D_fake=0.695, loss_D_class=1.100, loss_D=2.280]Epoch 4:   9%|▉         | 937/10313 [03:57<-1:55:37, -35.57it/s, loss=2.07, v_num=., loss_G_fake=0.586, loss_G_class=1.130, loss_G_ssim=0.172, loss_G=1.890, loss_D_real=0.381, loss_D_fake=0.881, loss_D_class=1.170, loss_D=2.430]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


Epoch 4:  18%|█▊        | 1874/10313 [08:01<-1:50:59, -15.58it/s, loss=2.1, v_num=., loss_G_fake=0.652, loss_G_class=1.110, loss_G_ssim=0.114, loss_G=1.880, loss_D_real=0.421, loss_D_fake=0.776, loss_D_class=1.130, loss_D=2.320] 
Epoch 4:  27%|██▋       | 2811/10313 [12:06<-1:46:10, -9.04it/s, loss=2.07, v_num=., loss_G_fake=0.609, loss_G_class=1.110, loss_G_ssim=0.177, loss_G=1.900, loss_D_real=0.312, loss_D_fake=0.883, loss_D_class=1.110, loss_D=2.310]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 4:  36%|███▋      | 3748/10313 [16:00<-1:41:20, -5.86it/s, loss=2.07, v_num=., loss_G_fake=0.609, loss_G_class=1.110, loss_G_ssim=0.177, loss_G=1.900, loss_D_real=0.312, loss_D_fake=0.883, loss_D_class=1.110, loss_D=2.310]Epoch 4:  36%|███▋      | 3748/10313 [16:00<-1:41:20, -5.86it/s, loss=2.07, v_num=., loss_G_fake=0.609, loss_G_class=1.110, loss_G_ssim=0.177, loss_G=1.900, loss_D_real=0.312, loss_D_fake=0.883, loss_D_class=1.110, loss_D=2.310]Epoch 4:  36%|███▋      | 3748/10313 [16:00<-1:41:20, -5.86it/s, loss=2.07, v_num=., loss_G_fake=0.669, loss_G_class=1.100, loss_G_ssim=0.111, loss_G=1.880, loss_D_real=0.371, loss_D_fake=0.740, loss_D_class=1.100, loss_D=2.210]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 4:  45%|████▌     | 4685/10313 [20:05<-1:35:54, -3.89it/s, loss=2.07, v_num=., loss_G_fake=0.669, loss_G_class=1.100, loss_G_ssim=0.111, loss_G=1.880, loss_D_real=0.371, loss_D_fake=0.740, loss_D_class=1.100, loss_D=2.210]Epoch 4:  45%|████▌     | 4685/10313 [20:05<-1:35:54, -3.89it/s, loss=2.07, v_num=., loss_G_fake=0.669, loss_G_class=1.100, loss_G_ssim=0.111, loss_G=1.880, loss_D_real=0.371, loss_D_fake=0.740, loss_D_class=1.100, loss_D=2.210]Epoch 4:  45%|████▌     | 4685/10313 [20:05<-1:35:54, -3.89it/s, loss=2.06, v_num=., loss_G_fake=0.607, loss_G_class=1.080, loss_G_ssim=0.0902, loss_G=1.780, loss_D_real=0.389, loss_D_fake=0.809, loss_D_class=1.100, loss_D=2.300]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 4:  55%|█████▍    | 5622/10313 [24:10<-1:29:48, -2.59it/s, loss=2.06, v_num=., loss_G_fake=0.607, loss_G_class=1.080, loss_G_ssim=0.0902, loss_G=1.780, loss_D_real=0.389, loss_D_fake=0.809, loss_D_class=1.100, loss_D=2.300]Epoch 4:  55%|█████▍    | 5622/10313 [24:10<-1:29:48, -2.59it/s, loss=2.06, v_num=., loss_G_fake=0.607, loss_G_class=1.080, loss_G_ssim=0.0902, loss_G=1.780, loss_D_real=0.389, loss_D_fake=0.809, loss_D_class=1.100, loss_D=2.300]Epoch 4:  55%|█████▍    | 5622/10313 [24:10<-1:29:48, -2.59it/s, loss=2.05, v_num=., loss_G_fake=0.554, loss_G_class=1.090, loss_G_ssim=0.158, loss_G=1.800, loss_D_real=0.320, loss_D_fake=0.914, loss_D_class=1.110, loss_D=2.350] 


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


Epoch 4:  64%|██████▎   | 6559/10313 [28:15<-1:22:20, -1.66it/s, loss=2.12, v_num=., loss_G_fake=0.741, loss_G_class=1.070, loss_G_ssim=0.0995, loss_G=1.910, loss_D_real=0.509, loss_D_fake=0.675, loss_D_class=1.140, loss_D=2.320]
Epoch 4:  73%|███████▎  | 7496/10313 [32:09<-1:11:47, -0.97it/s, loss=2.12, v_num=., loss_G_fake=0.665, loss_G_class=1.070, loss_G_ssim=0.128, loss_G=1.870, loss_D_real=0.502, loss_D_fake=0.758, loss_D_class=1.090, loss_D=2.350] 


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 4:  82%|████████▏ | 8433/10313 [36:14<-2:47:41, -0.43it/s, loss=2.12, v_num=., loss_G_fake=0.665, loss_G_class=1.070, loss_G_ssim=0.128, loss_G=1.870, loss_D_real=0.502, loss_D_fake=0.758, loss_D_class=1.090, loss_D=2.350]Epoch 4:  82%|████████▏ | 8433/10313 [36:14<-2:47:41, -0.43it/s, loss=2.12, v_num=., loss_G_fake=0.665, loss_G_class=1.070, loss_G_ssim=0.128, loss_G=1.870, loss_D_real=0.502, loss_D_fake=0.758, loss_D_class=1.090, loss_D=2.350]Epoch 4:  82%|████████▏ | 8433/10313 [36:14<-2:47:41, -0.43it/s, loss=2.04, v_num=., loss_G_fake=0.693, loss_G_class=1.070, loss_G_ssim=0.177, loss_G=1.940, loss_D_real=0.323, loss_D_fake=0.706, loss_D_class=1.090, loss_D=2.120]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,74.625,0.0755425,0.197507,4


Epoch 4:  91%|█████████ | 9370/10313 [40:19<-127:15:30, -0.00it/s, loss=2.13, v_num=., loss_G_fake=0.705, loss_G_class=1.070, loss_G_ssim=0.237, loss_G=2.010, loss_D_real=0.397, loss_D_fake=0.670, loss_D_class=1.090, loss_D=2.160]
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation: 0it [00:00, ?it/s][A
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation:   0%|          | 0/938 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/938 [00:00<?, ?it/s][A
Epoch 4: 100%|█████████▉| 10307/10313 [41:18<00:15,  2.66s/it, loss=2.13, v_num=., loss_G_fake=0.705, loss_G_class=1.070, loss_G_ssim=0.237, loss_G=2.010, loss_D_real=0.397, loss_D_fake=0.670, loss_D_class=1.090, loss_D=2.160]
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation DataLoader 0: 100%|█████████▉| 937/938 [00:58<00:00, 15.94it/s][A
Validation DataLoader 0: 100%|█████████▉| 937/938 [00:58<00:00, 15.94it/s][A
[2m[36m(train_LensGAN128 pid=3261)[0m 
Epoch 4: 100%|██████████| 10313/10313 [41:18<00:00,  2.64s/it

Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 5:   9%|▉         | 937/10313 [03:58<-1:55:35, -35.36it/s, loss=2.13, v_num=., loss_G_fake=0.507, loss_G_class=1.070, loss_G_ssim=0.227, loss_G=1.800, loss_D_real=0.381, loss_D_fake=1.040, loss_D_class=1.090, loss_D=2.510] Epoch 5:   9%|▉         | 937/10313 [03:58<-1:55:35, -35.36it/s, loss=2.13, v_num=., loss_G_fake=0.507, loss_G_class=1.070, loss_G_ssim=0.227, loss_G=1.800, loss_D_real=0.381, loss_D_fake=1.040, loss_D_class=1.090, loss_D=2.510]Epoch 5:   9%|▉         | 937/10313 [03:58<-1:55:35, -35.36it/s, loss=2.07, v_num=., loss_G_fake=0.637, loss_G_class=1.070, loss_G_ssim=0.189, loss_G=1.900, loss_D_real=0.391, loss_D_fake=0.751, loss_D_class=1.090, loss_D=2.240]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 5:  18%|█▊        | 1874/10313 [08:04<-1:50:56, -15.48it/s, loss=2.07, v_num=., loss_G_fake=0.637, loss_G_class=1.070, loss_G_ssim=0.189, loss_G=1.900, loss_D_real=0.391, loss_D_fake=0.751, loss_D_class=1.090, loss_D=2.240]Epoch 5:  18%|█▊        | 1874/10313 [08:04<-1:50:56, -15.48it/s, loss=2.07, v_num=., loss_G_fake=0.637, loss_G_class=1.070, loss_G_ssim=0.189, loss_G=1.900, loss_D_real=0.391, loss_D_fake=0.751, loss_D_class=1.090, loss_D=2.240]Epoch 5:  18%|█▊        | 1874/10313 [08:04<-1:50:56, -15.48it/s, loss=2.11, v_num=., loss_G_fake=0.661, loss_G_class=1.020, loss_G_ssim=0.171, loss_G=1.860, loss_D_real=0.502, loss_D_fake=0.713, loss_D_class=1.110, loss_D=2.320]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


Epoch 5:  27%|██▋       | 2811/10313 [12:08<-1:46:08, -9.01it/s, loss=2.06, v_num=., loss_G_fake=0.652, loss_G_class=1.070, loss_G_ssim=0.127, loss_G=1.850, loss_D_real=0.407, loss_D_fake=0.717, loss_D_class=1.080, loss_D=2.200]
Epoch 5:  36%|███▋      | 3748/10313 [16:02<-1:41:17, -5.85it/s, loss=2.02, v_num=., loss_G_fake=0.600, loss_G_class=1.030, loss_G_ssim=0.128, loss_G=1.760, loss_D_real=0.360, loss_D_fake=0.847, loss_D_class=1.100, loss_D=2.310]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 5:  45%|████▌     | 4685/10313 [20:05<-1:35:53, -3.89it/s, loss=2.02, v_num=., loss_G_fake=0.600, loss_G_class=1.030, loss_G_ssim=0.128, loss_G=1.760, loss_D_real=0.360, loss_D_fake=0.847, loss_D_class=1.100, loss_D=2.310]Epoch 5:  45%|████▌     | 4685/10313 [20:05<-1:35:53, -3.89it/s, loss=2.02, v_num=., loss_G_fake=0.600, loss_G_class=1.030, loss_G_ssim=0.128, loss_G=1.760, loss_D_real=0.360, loss_D_fake=0.847, loss_D_class=1.100, loss_D=2.310]Epoch 5:  45%|████▌     | 4685/10313 [20:05<-1:35:53, -3.89it/s, loss=2.05, v_num=., loss_G_fake=0.688, loss_G_class=1.080, loss_G_ssim=0.144, loss_G=1.910, loss_D_real=0.434, loss_D_fake=0.706, loss_D_class=1.120, loss_D=2.260]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 5:  55%|█████▍    | 5622/10313 [24:09<-1:29:49, -2.59it/s, loss=2.05, v_num=., loss_G_fake=0.688, loss_G_class=1.080, loss_G_ssim=0.144, loss_G=1.910, loss_D_real=0.434, loss_D_fake=0.706, loss_D_class=1.120, loss_D=2.260]Epoch 5:  55%|█████▍    | 5622/10313 [24:09<-1:29:49, -2.59it/s, loss=2.05, v_num=., loss_G_fake=0.688, loss_G_class=1.080, loss_G_ssim=0.144, loss_G=1.910, loss_D_real=0.434, loss_D_fake=0.706, loss_D_class=1.120, loss_D=2.260]Epoch 5:  55%|█████▍    | 5622/10313 [24:09<-1:29:49, -2.59it/s, loss=2.04, v_num=., loss_G_fake=0.622, loss_G_class=0.963, loss_G_ssim=0.122, loss_G=1.710, loss_D_real=0.416, loss_D_fake=0.797, loss_D_class=1.110, loss_D=2.330]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 5:  64%|██████▎   | 6559/10313 [28:14<-1:22:22, -1.66it/s, loss=2.04, v_num=., loss_G_fake=0.622, loss_G_class=0.963, loss_G_ssim=0.122, loss_G=1.710, loss_D_real=0.416, loss_D_fake=0.797, loss_D_class=1.110, loss_D=2.330]Epoch 5:  64%|██████▎   | 6559/10313 [28:14<-1:22:22, -1.66it/s, loss=2.04, v_num=., loss_G_fake=0.622, loss_G_class=0.963, loss_G_ssim=0.122, loss_G=1.710, loss_D_real=0.416, loss_D_fake=0.797, loss_D_class=1.110, loss_D=2.330]Epoch 5:  64%|██████▎   | 6559/10313 [28:14<-1:22:22, -1.66it/s, loss=2, v_num=., loss_G_fake=0.583, loss_G_class=0.981, loss_G_ssim=0.101, loss_G=1.660, loss_D_real=0.400, loss_D_fake=0.871, loss_D_class=1.110, loss_D=2.390]   


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


Epoch 5:  73%|███████▎  | 7496/10313 [32:08<-1:11:49, -0.97it/s, loss=2, v_num=., loss_G_fake=0.637, loss_G_class=0.957, loss_G_ssim=0.110, loss_G=1.700, loss_D_real=0.439, loss_D_fake=0.767, loss_D_class=1.110, loss_D=2.320]
Epoch 5:  82%|████████▏ | 8433/10313 [36:12<-2:47:45, -0.43it/s, loss=2, v_num=., loss_G_fake=0.620, loss_G_class=0.845, loss_G_ssim=0.176, loss_G=1.640, loss_D_real=0.363, loss_D_fake=0.837, loss_D_class=1.090, loss_D=2.290]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


Epoch 5:  91%|█████████ | 9370/10313 [40:16<-127:23:02, -0.00it/s, loss=1.96, v_num=., loss_G_fake=0.638, loss_G_class=0.897, loss_G_ssim=0.131, loss_G=1.670, loss_D_real=0.357, loss_D_fake=0.814, loss_D_class=1.100, loss_D=2.270]
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation: 0it [00:00, ?it/s][A
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation:   0%|          | 0/938 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/938 [00:00<?, ?it/s][A
Epoch 5: 100%|█████████▉| 10307/10313 [41:17<00:15,  2.66s/it, loss=1.96, v_num=., loss_G_fake=0.638, loss_G_class=0.897, loss_G_ssim=0.131, loss_G=1.670, loss_D_real=0.357, loss_D_fake=0.814, loss_D_class=1.100, loss_D=2.270]
[2m[36m(train_LensGAN128 pid=3261)[0m 
Validation DataLoader 0: 100%|█████████▉| 937/938 [00:59<00:00, 15.62it/s][A
Validation DataLoader 0: 100%|█████████▉| 937/938 [00:59<00:00, 15.62it/s][A
Epoch 5: 100%|██████████| 10313/10313 [41:17<00:00,  2.64s/it, loss=1.96, v_num=., loss_G_fake=0.638, l

Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,39.5,0.674615,0.566132,5


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 5: 100%|██████████| 10313/10313 [41:28<00:00,  2.65s/it, loss=1.96, v_num=., loss_G_fake=0.638, loss_G_class=0.897, loss_G_ssim=0.131, loss_G=1.670, loss_D_real=0.357, loss_D_fake=0.814, loss_D_class=1.100, loss_D=2.270]
[2m[36m(train_LensGAN128 pid=3261)[0m Validation DataLoader 0: 100%|██████████| 938/938 [01:11<00:00, 15.62it/s][A
Result for train_LensGAN128_52f0b_00000:
  FID_F: 70.3125
  FID_J: 68.25
  ap_F: 0.32872846722602844
  ap_J: 0.23333387076854706
  auroc_F: 0.43105876445770264
  auroc_J: 0.21448326110839844
  date: 2022-05-24_03-41-46
  done: false
  experiment_id: 152d034e835a4eb783ed5470faa95a27
  hostname: c8992defe8a4
  iterations_since_restore: 3
  node_ip: 172.28.0.2
  pid: 3261
  should_checkpoint: true
  time_since_restore: 7529.109290599823
  time_this_iter_s: 2491.2888629436493
  time_total_s: 11277.236166000366
  timestamp: 1653363706
  timesteps_since_restore: 0
  training_iteration: 6
  trial_id: 52f0b_0000

Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,70.3125,0.431059,0.328728,6


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 6:  18%|█▊        | 1874/10313 [08:03<-1:50:57, -15.52it/s, loss=1.99, v_num=., loss_G_fake=0.563, loss_G_class=0.958, loss_G_ssim=0.209, loss_G=1.730, loss_D_real=0.364, loss_D_fake=0.908, loss_D_class=1.090, loss_D=2.370]Epoch 6:  18%|█▊        | 1874/10313 [08:03<-1:50:57, -15.52it/s, loss=1.99, v_num=., loss_G_fake=0.563, loss_G_class=0.958, loss_G_ssim=0.209, loss_G=1.730, loss_D_real=0.364, loss_D_fake=0.908, loss_D_class=1.090, loss_D=2.370]Epoch 6:  18%|█▊        | 1874/10313 [08:03<-1:50:57, -15.52it/s, loss=2.02, v_num=., loss_G_fake=0.590, loss_G_class=0.934, loss_G_ssim=0.146, loss_G=1.670, loss_D_real=0.440, loss_D_fake=0.861, loss_D_class=1.100, loss_D=2.410]


Trial name,status,loc,lr,n_fmaps,bs,FID_F,auroc_F,ap_F,training_iteration
train_LensGAN128_52f0b_00000,RUNNING,172.28.0.2:3261,-3,128,3,70.3125,0.431059,0.328728,6


[2m[36m(train_LensGAN128 pid=3261)[0m Epoch 6:  27%|██▋       | 2811/10313 [12:07<-1:46:09, -9.02it/s, loss=2.02, v_num=., loss_G_fake=0.590, loss_G_class=0.934, loss_G_ssim=0.146, loss_G=1.670, loss_D_real=0.440, loss_D_fake=0.861, loss_D_class=1.100, loss_D=2.410] Epoch 6:  27%|██▋       | 2811/10313 [12:07<-1:46:09, -9.02it/s, loss=2.02, v_num=., loss_G_fake=0.590, loss_G_class=0.934, loss_G_ssim=0.146, loss_G=1.670, loss_D_real=0.440, loss_D_fake=0.861, loss_D_class=1.100, loss_D=2.410]Epoch 6:  27%|██▋       | 2811/10313 [12:07<-1:46:09, -9.02it/s, loss=2, v_num=., loss_G_fake=0.655, loss_G_class=0.850, loss_G_ssim=0.148, loss_G=1.650, loss_D_real=0.527, loss_D_fake=0.751, loss_D_class=1.110, loss_D=2.390]   


In [None]:
# Inspect the LensGAN128 "test" run logs stored on Google Drive with TensorBoard.
# NOTE(review): `!tensorboard` launches the TensorBoard CLI as a shell command, so this
# cell stays busy until interrupted; the commented-out `%reload_ext tensorboard` suggests
# the inline `%tensorboard` magic may have been the intended alternative — confirm.
# %reload_ext tensorboard
!tensorboard --logdir=./drive/MyDrive/Logs/F/LensGAN128/test

In [23]:
# NOTE(review): unfinished statement — `drive.` on its own is a SyntaxError. `drive` is the
# google.colab module imported at the top of the notebook; presumably a call on it was
# intended here — complete it or delete this cell.
drive.

## VAE

In [None]:
def train_mnist(config, checkpoint=None, num_steps=10000, num_gpus=torch.cuda.device_count()):
    """Train a LensVAE with PyTorch Lightning (no Ray Tune involved).

    Args:
        config: hyper-parameter dict. Its ``'bs'`` entry is used as a base-2
            exponent (rounded to the nearest integer) to derive the checkpoint
            interval, and every ``key, value`` pair is folded into the
            TensorBoard sub-directory name. The dict is also handed to
            ``npyImageData`` and ``LensVAE`` (what they read from it is not
            visible here — confirm there).
        checkpoint: optional checkpoint path forwarded to ``trainer.fit`` as
            ``ckpt_path`` so training can resume.
        num_steps: ``max_steps`` for the Lightning trainer.
        num_gpus: GPU count for the trainer. The default is evaluated once,
            when this function is defined.
    """
    # The image width (128) is hard-coded for the data module here.
    datamodule = npyImageData(config, 128)
    vae = LensVAE(config)

    # Each config entry becomes "key_value"; the entries are joined with "_".
    run_tag = '_'.join(f'{key}_{value}' for key, value in config.items())
    tb_logger = TensorBoardLogger(
        save_dir=os.path.join('drive/MyDrive/Logs/', 'F/LensVAE'),
        name='delete',
        sub_dir=run_tag,
    )

    # config['bs'] acts as log2(batch size): e.g. 3 -> 8 -> checkpoint every 937 steps.
    batch_size = int(2 ** np.rint(config['bs']))
    checkpoint_every = int(7500 // batch_size)
    callbacks = [
        ModelCheckpoint(
            monitor='train_loss',
            verbose=True,
            save_last=True,
            save_top_k=5,
            mode='min',
            auto_insert_metric_name=False,
            every_n_train_steps=checkpoint_every,
        ),
        ModuleDataMonitor(True),
    ]

    trainer = pl.Trainer(
        val_check_interval=0.10,
        max_steps=num_steps,
        gpus=num_gpus,
        logger=tb_logger,
        callbacks=callbacks,
        benchmark=True,
        precision=16,
    )
    trainer.fit(vae, datamodule, ckpt_path=checkpoint)
# __lightning_end__

# __no_tune_train_begin__
def train_mnist_no_tune(*args, **kwargs):
    """Run ``train_mnist`` once with a fixed, hand-picked VAE configuration.

    Extra positional/keyword arguments (``checkpoint``, ``num_steps``,
    ``num_gpus``) are forwarded to ``train_mnist`` unchanged.
    """
    fixed_config = dict(
        input_height=128,
        latent_dim=100,
        lr=1e-3,
        bs=3,  # used by train_mnist as log2 of the batch size (2**3 = 8)
    )
    train_mnist(fixed_config, *args, **kwargs)
# __no_tune_train_end__

if __name__ == '__main__':
    import argparse

    # Step budget for --smoke-test: just enough to prove the pipeline wires up.
    SMOKE_TEST_STEPS = 100

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args ignores unrecognised argv entries (e.g. ones a notebook
    # kernel launcher may pass) instead of exiting with an error.
    args, _ = parser.parse_known_args()

    # os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"
    if args.smoke_test:
        # Previously identical to the full run, so the flag had no effect;
        # honour its help text with a short run.
        train_mnist_no_tune(num_steps=SMOKE_TEST_STEPS)
    else:
        # Full run with train_mnist's default step budget (plain Lightning, no Tune scheduler).
        train_mnist_no_tune()

## Stage 2
Here we tune hyperparameters as we train our modified DCGAN.

In [None]:
# __tune_train_checkpoint_begin
def train_Stage2(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Train one Stage2 model inside a Ray Tune trial.

    Metrics are reported to Tune, and checkpoints written, through
    ``TuneReportCheckpointCallback``. When Tune hands back a
    ``checkpoint_dir`` (trial restore or PBT exploit), training resumes from
    the Lightning checkpoint stored there.

    Args:
        config: trial hyper-parameters. ``'batch_size'`` sets the progress-bar
            refresh rate here; the whole dict is passed to ``npyImageData``
            and ``Stage2``.
        checkpoint_dir: directory supplied by Tune on restore; the checkpoint
            file inside it is expected to be named ``'checkpoint'``.
        num_epochs: ``max_epochs`` for the Lightning trainer.
        num_gpus: GPUs allotted to this trial; may be fractional.
    """
    import math

    kwargs = {
        # 'limit_train_batches' : 0.05,
        # 'limit_val_batches' : 0.05,
        'progress_bar_refresh_rate': int(8250 // config['batch_size']),
        'max_epochs': num_epochs,
        'prepare_data_per_node': False,
        # Lightning needs an integer GPU count. Round fractional shares *up*:
        # int(0.5) would be 0 and silently train the trial on the CPU.
        'gpus': math.ceil(num_gpus),
        'logger': TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        'callbacks': [
            TuneReportCheckpointCallback(
                {
                    'loss_G': 'Stage2/G/train/loss',
                    'loss_D': 'Stage2/D/train/loss',
                    # Switch the auroc keys around when training on a different dataset.
                    'auroc': 'Stage2/ResNet(F)/val/auroc',
                    'auroc_cross': 'Stage2/ResNet(J)/val/auroc',
                },
            ),
        ],
        # 'stochastic_weight_avg' : True,
        # works with only one optimizer
        # 'benchmark' : True,
    }

    dm = npyImageData(config)
    if checkpoint_dir is not None:
        kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    # Built from the (possibly PBT-mutated) config; Lightning restores training
    # state from 'resume_from_checkpoint' when it is set.
    # NOTE(review): a restored optimizer state may override a mutated
    # learning_rate — confirm how Stage2 configures its optimizers.
    model = Stage2(config)
    trainer = pl.Trainer(**kwargs)

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__


# # # __tune_asha_begin__
# def tune_Stage2_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
#     # print(os.cpu_count(), torch.cuda.device_count())
#     analysis = tune.run(
#         tune.with_parameters(
#             train_Stage2,
#             num_epochs=num_epochs,
#             num_gpus=gpus_per_trial
#         ),
#         # Change the folder name when changing dataset--------------------------------------------------------------------------
#         name='Stage2/pbt/J',
#         metric='auroc',
#         mode='max',
#         config={'learning_rate': 1e-4,
#                 'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
#                 'batch_size': 8,
#                 },
#         # config={'learning_rate': 0.01,
#         #         'n_fmaps': 32,
#         #         'batch_size': 32,
#         #         },
#         # stop=TrialPlateauStopper('loss_G'),
#         resources_per_trial={'cpu': os.cpu_count(),
#                              'gpu': gpus_per_trial,
#                             },
#         local_dir='./drive/MyDrive/Logs',
#         scheduler = ASHAScheduler(max_t=num_epochs, grace_period=2,  reduction_factor=2),
#         progress_reporter=JupyterNotebookReporter(
#             overwrite=True,
#             parameter_columns=['learning_rate', 'n_fmaps', 'batch_size'],
#             metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
#             sort_by_metric=True,
#         ),
#         fail_fast = True,
#         # reuse_actors=True,
#         # num_samples=num_samples,
#         resume='PROMPT',
# #         restore='/content/drive/MyDrive/Logs/delete/train_Stage2_e42ac_00025_25_batch_size=8,learning_rate=0.01,n_fmaps=8_2021-07-28_21-16-18/checkpoint_epoch=4-step=2339',
#     )

# #     print('Best hyperparameters found were: ', analysis.best_config)

# # # __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage2_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune Stage2 with Population Based Training and print the best config.

    A grid over ``n_fmaps`` (five values) is combined with ``res_depth`` drawn
    via ``tune.choice``. PBT may perturb ``learning_rate`` and ``batch_size``
    every training iteration. The objective is to maximise the reported
    ``'auroc'``.

    Args:
        num_samples: accepted for interface compatibility but NOT forwarded to
            ``tune.run``; tune.run's own default applies.
        num_epochs: epochs per trial, forwarded to ``train_Stage2``.
        gpus_per_trial: GPUs reserved per trial and forwarded to
            ``train_Stage2``. The default (all visible CUDA devices) is
            evaluated when this function is defined.
    """
    trainable = tune.with_parameters(
        train_Stage2, num_epochs=num_epochs, num_gpus=gpus_per_trial)

    search_space = {
        'learning_rate': 1e-4,
        'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
        'res_depth': tune.choice([1, 2, 3, 4]),
        'batch_size': 8,
    }
    resources = {'cpu': os.cpu_count(), 'gpu': gpus_per_trial}

    pbt = PopulationBasedTraining(
        time_attr='training_iteration',
        quantile_fraction=0.5,
        resample_probability=0.8,
        perturbation_interval=1,
        hyperparam_mutations={
            'learning_rate': tune.loguniform(1e-7, 1e-1),
            'batch_size': [8, 16, 32, 64, 128],
        },
    )
    reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['learning_rate', 'n_fmaps', 'res_depth', 'batch_size'],
        metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
        sort_by_metric=True,
    )

    analysis = tune.run(
        trainable,
        # Change the experiment folder name when changing dataset.
        name='Stage2/pbt/F',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial=resources,
        local_dir='./drive/MyDrive/Logs',
        scheduler=pbt,
        progress_reporter=reporter,
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args ignores unrecognised argv entries (e.g. ones a notebook
    # kernel launcher may pass) instead of exiting with an error.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # tune_Stage2_asha is commented out in this cell, so calling it here raised
        # NameError before the PBT smoke run could start. Only PBT is exercised.
        tune_Stage2_pbt(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
    else:
        # ASHA scheduler (disabled: its definition is commented out)
        # tune_Stage2_asha(num_samples=1, num_epochs=10, gpus_per_trial=torch.cuda.device_count())
        # Population based training
        tune_Stage2_pbt(num_samples=1, num_epochs=30, gpus_per_trial=torch.cuda.device_count())