<a href="https://colab.research.google.com/github/souravraha/galaxy/blob/pbt2/Lightning_Tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google drive

In [None]:
# Mount the user's Google Drive at /content/drive (Colab-only API).
# The 600000 ms timeout is presumably there to leave time for the
# authorisation prompt - TODO confirm.
from google.colab import drive
drive.mount('/content/drive', timeout_ms=600000)

# Prerequisites/ shell commands

## Install/uninstall packages

In [None]:
# Installs the extra packages this notebook needs and then restarts the runtime.
# Run this cell on Google Colab; skip it (or comment out the `!pip` lines) in an
# environment where the dependencies are already installed.

print('Setting up colab environment')
# !pip uninstall -y -q pyarrow
!pip install -q lightning-bolts GPy
!pip install -q ray[debug] ray[default]
!pip install -q ray[tune] matplotlib==3.4.3
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

# A hack to force the runtime to restart, needed to include the above dependencies.
# os._exit(0) terminates the kernel process immediately, so nothing after this
# cell runs until the runtime has restarted.
print('Done installing! Restarting via forced crash (this is not an issue).')
import os
os._exit(0)

## Download and extract data

Here choose the model you wish to use for training/testing. Don't forget to make modifications in the following sections:

1.   GLOBAL in class definition of npyImageData.
2.   correct assignment of metric keys while defining the training wrapper for Tune.
3.   name of the experiment initiated/resumed.

In [1]:
# Google Drive file ids of the per-model archives, keyed by model letter
# ('d' and 'k' have no id listed):
# 'a': 1Cjcw2EWorhdhJSGoWOdxsEUDxvl943dt, 'b': 15yXXC4h5VsytP3Ak1jfUSjQhdgP2s23K, 'c': 1vuQ-pLzoKT4Hd_V7949r9eND9E2fB_u_,
# 'd': , 'e': 1wFuasvb7PthxXtMUlsD13uzYHWlWt06H, 'f': 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ, 
# 'g': 1SxQVosWeEjY3Pyn8LRXA11rLnZ9HK_7B, 'h': 1Atau0RH4oyLAiYReW-G9a8l9pUNltglF, 'i': 15lEgsR1p00KSHieaT9a1nkbJ86pDxwgp, 
# 'j': 1m0EQUbqZZeyl76XsQIKWU5Qd7jGmmWhB, 'k': , 'l': 1meTDi4aeWfdChOiXeLtUOGhjVDVu000e

# fake data
# 'fpgan': 1-4o0yqSBA9WSY9gTYamIez_RAekwDsHV

# Remove any previously extracted images/ folder, then stream the archive whose
# id matches entry 'f' above straight into tar (nothing is written to disk
# except the extracted files). Change the id to switch datasets.
!rm -rf images/
!gdown --id 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ -O - --quiet | tar --skip-old-files -zxf -
# !rm ./model_f.tgz

# Pure-Python alternative to the shell commands above, kept for reference.
# NOTE(review): it saves to 'model_j.tgz' while using the id listed for 'f' - confirm before re-enabling.
# def prepare_data(data_dir: str = '/content'):
#     with FileLock(os.path.expanduser(data_dir+'.lock')):
#         gdown.download('https://drive.google.com/uc?id=17l6H61tLAu26zGuei38r_T5ssjbYUeaJ', data_dir+'/model_j.tgz', quiet=True)
        
#         temp = tarfile.open(data_dir+'/model_j.tgz', 'r|gz')
#         temp.extractall()
#         temp.close()

# Line wrapping

In [2]:
from IPython.display import HTML, display


def set_css(*_event_args):
    """Inject CSS that makes long lines wrap inside ``<pre>`` output blocks.

    Registered below as a ``pre_run_cell`` callback, so the style is
    re-emitted before every cell execution.

    Args:
        *_event_args: Ignored. IPython documents ``pre_run_cell`` callbacks as
            receiving an ``info`` argument; accepting (and discarding) any
            positional arguments keeps the callback valid whether or not the
            running IPython version passes it.
    """
    display(HTML('''<style>pre{white-space: pre-wrap;}</style>'''))


get_ipython().events.register('pre_run_cell', set_css)

# Import libraries

In [None]:
import os
import numpy as np
# from pathlib import Path
from itertools import cycle
%matplotlib inline
from matplotlib import pyplot as plt
# ------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
# ------------------------------------
from torchvision.models import resnet18
from torchvision.utils import save_image, make_grid
from torchvision import transforms, datasets
# ------------------------------------
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
# ------------------------------------
import torchmetrics as tm
# ------------------------------------
import pytorch_lightning as pl
from pl_bolts.models.gans import DCGAN
from pl_bolts.callbacks import ModuleDataMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.self_supervised.resnets import BasicBlock
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from drive.MyDrive.ml.Callbacks.confused_logits import ConfusedLogitCallback
from drive.MyDrive.ml.Callbacks.save_images import SaveImages
# ------------------------------------
from ray import tune
# from ray.tune.stopper import TrialPlateauStopper
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.schedulers.pb2 import PB2
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback
# ------------------------------------

# Class definitions

## DataModule
This creates the dataloaders that must be supplied to train, validate or test our modules.

In [None]:
class npyImageData(pl.LightningDataModule):
    """DataModule serving single-channel ``.npy`` images from class folders.

    Samples are read from ``<data_dir>/train`` and ``<data_dir>/val`` with
    ``torchvision.datasets.DatasetFolder``, divided by ``max_pixel``, shifted
    and scaled by 0.5, and resized to ``img_width``.

    Args:
        config: Hyperparameter dict. Only ``config['bs']`` is read; it is the
            log2 of the batch size and is rounded to the nearest integer.
        img_width: Size passed to ``transforms.Resize``.
        data_dir: Root folder holding the ``train`` and ``val`` sub-folders
            (``~`` is expanded).
        max_pixel: Value the raw pixels are divided by before the shift to
            ``[-1, 1]``. Defaults to 966, the maximum recorded for dataset F;
            pass the maximum of another dataset when switching.
    """

    def __init__(self, config, img_width: int = 150, data_dir: str = '/content/images/',
                 max_pixel: float = 966.0):
        super().__init__()
        # This method is not implemented
        # self.save_hyperparameters()
        self.bs = int(2**np.rint(config['bs']))
        self.data_dir = os.path.expanduser(data_dir)
        self.prepare_data_per_node = False

        self.transform = transforms.Compose([
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomVerticalFlip(),
            # F : [mean=71.75926373866668, std=96.139484964214, min=5.0, max=966.0]
            # J : [mean=50.271541595458984, std=94.8838882446289, min=0, max=1007.0]
            # Scale raw pixel values by the dataset maximum ...
            transforms.Normalize(mean=(0,), std=(max_pixel,)),
            # ... then shift-scale them: values in [0, max_pixel] end up in [-1, 1].
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
            transforms.Resize(img_width, transforms.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def npy_loader(path):
        """Load one ``.npy`` file as a float32 tensor with a leading channel axis."""
        # Convert to a tensor first and only then to float: the conv layers
        # need float32 input, so the dtype is forced explicitly here.
        return torch.from_numpy(np.load(path)).unsqueeze(0).float()

    def _folder(self, split: str):
        """Build the ``DatasetFolder`` for the sub-folder ``split`` of ``data_dir``."""
        return datasets.DatasetFolder(
            os.path.join(self.data_dir, split), self.npy_loader, ('.npy',), self.transform)

    def _loader(self, dataset):
        """Wrap ``dataset`` in a shuffling ``DataLoader`` with this module's batch size."""
        # os.cpu_count() may return None, which DataLoader rejects; fall back to
        # loading in the main process in that case.
        return DataLoader(dataset, self.bs, shuffle=True,
                          num_workers=os.cpu_count() or 0, pin_memory=True)

    def setup(self, stage: str = None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``'fit'``, ``'validate'``, ``'test'`` or ``None`` (all).
                ``'validate'`` is handled like ``'fit'`` so that a standalone
                validation run also gets ``val_set``.
        """
        if stage in ('fit', 'validate', None):
            self.train_set = self._folder('train')
            # self.train_set, self.val_set = random_split(self.full_set, [60000, 15000])
            self.val_set = self._folder('val')
            self.dims = tuple(self.train_set[0][0].shape)

        if stage in ('test', None):
            # No separate test split is read: the 'val' folder is reused.
            self.test_set = self._folder('val')
            self.dims = getattr(self, 'dims', self.test_set[0][0].shape)

    def train_dataloader(self):
        """Return the shuffled training ``DataLoader``."""
        return self._loader(self.train_set)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` (shuffled, as in training)."""
        return self._loader(self.val_set)

    def test_dataloader(self):
        """Return the test ``DataLoader`` (shuffled, as in training)."""
        return self._loader(self.test_set)

In [None]:
# Smoke-test instantiation of the DataModule. 'bs' is read as the log2 of the
# batch size (8 -> 256); 'lr' is not read by npyImageData.
dm = npyImageData({'lr': 0.001, 'bs': 8})
# model = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'), config={'lr': 0.001, 'bs': 8})
# trainer = pl.Trainer(gpus=1)
# trainer.predict(model, dm)

## ResNet
We modify a ResNet slightly for our purpose.

In [7]:
# Pickled, pre-trained backbone networks for datasets F and J.
PRE_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/PRETRAINED.pth'
PRE_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/PRETRAINED.pth'


class LensResnet(pl.LightningModule):
    """Classifier that fine-tunes a pre-trained backbone loaded from disk.

    Training logs the cross-entropy loss and a weighted AUROC; validation
    accumulates per-class PR/ROC curves and plots them at the end of each
    validation epoch.

    Args:
        config: Hyperparameter dict. ``config['lr']`` is the log10 of the
            learning rate (the optimiser uses ``10 ** config['lr']``).
        image_channels: Stored as a hyperparameter; not otherwise read here.
        num_classes: Number of target classes used by all metrics.
        **kwargs: Accepted for compatibility; not read here.
    """

    def __init__(self, config, image_channels: int = 1, num_classes: int = 3, **kwargs):
        super().__init__()
        # NOTE(review): `ignore` is documented to take argument *names*; passing
        # the dict itself may not actually exclude `config` from the saved
        # hparams. Confirm before changing - checkpoints of this class are
        # loaded elsewhere in this file without a `config` argument.
        self.save_hyperparameters(ignore=config)
        self.lr = 10**config['lr']

        # --------------------------------------------------------------------------------------------------
        # NOTE: torch.load unpickles a whole module; only load files you trust.
        self.backbone = torch.load(PRE_F_RESNET, map_location=self.device)
        # self.backbone = resnet18(num_classes=self.hparams.num_classes)
        # self.backbone.conv1 = nn.LazyConv2d(64, 7, 2, 3, bias=False)

        self.train_metrics = tm.MetricCollection(
            [tm.AUROC(self.hparams.num_classes, average='weighted')],
            prefix='LensResnet/train/',
        )
        # Order matters: validation_epoch_end indexes these metrics positionally
        # (0: PR curve, 1: ROC curve, 2: average precision, 3: AUROC).
        self.val_metrics = tm.MetricCollection([
            tm.PrecisionRecallCurve(self.hparams.num_classes),
            tm.ROC(self.hparams.num_classes),
            tm.AveragePrecision(self.hparams.num_classes, average=None),
            tm.AUROC(self.hparams.num_classes, average=None),
        ])

    def configure_optimizers(self):
        """Return an Adam optimiser over the backbone parameters."""
        return torch.optim.Adam(self.backbone.parameters(), self.lr)

    def forward(self, x, prob=False):
        """Return backbone logits, or class probabilities when ``prob`` is true."""
        logits = self.backbone(x)
        return logits.softmax(1) if prob else logits

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and running AUROC; return the loss."""
        imgs, labels = batch
        # Kept on the instance (rather than as a local) so it stays readable
        # from outside the step after it returns.
        self.last_logits = self(imgs)
        loss = F.cross_entropy(self.last_logits, labels)
        #  keep only scalars here, for no errors
        self.log('LensResnet/train/loss', loss)

        preds = self.last_logits.softmax(1)
        self.train_metrics.update(preds, labels)
        try:
            self.log_dict(self.train_metrics.compute(), prog_bar=True)
        except Exception as err:
            # Best effort: a failing metric computation must not abort the
            # training step, so report it and carry on.
            print(err)
        # Returned after (not inside a `finally` of) the try block, so that
        # KeyboardInterrupt/SystemExit raised above still propagate.
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate predictions for the curves."""
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        #  keep only scalars here, for no errors
        self.log('LensResnet/val/loss', loss)

        preds = logits.softmax(1)
        self.val_metrics.update(preds, labels)

    def validation_epoch_end(self, Listofdicts):
        """Plot one-vs-all PR and ROC curves and log the worst per-class scores.

        The figure is sent to the logger and saved as
        ``PR_ROC_step_<global_step>.png`` in the trainer's log directory.
        """
        colors = cycle(['r', 'g', 'b'])
        fig, ax = plt.subplots(1, 2, subplot_kw={'xlim': [-0.05, 1.05], 'ylim': [-0.05, 1.05], 'aspect': 1}, figsize=(10, 5))
        try:
            # ---------------------------------------------------------------------------------------------------------
            # NOTE(review): the title names model J while the backbone above is
            # loaded from PRE_F_RESNET; update it when switching datasets.
            fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model J (orginal data)')

            Dict = self.val_metrics.compute()
            self.val_metrics.reset()

            key, val = list(Dict.keys()), list(Dict.values())
            class_names = self.trainer.datamodule.val_set.classes
            for b in range(2):
                is_roc = key[b] == 'ROC'
                # Chance-level reference line.
                if is_roc:
                    ax[b].plot([0, 1], [0, 1], 'k--')
                else:
                    ax[b].plot([0, 1], [1/self.hparams.num_classes, 1/self.hparams.num_classes], 'k--')

                # First two entries of the curve output: (precision, recall)
                # for the PR curve, (fpr, tpr) for the ROC curve.
                prec_FPR, rec_TPR, _ = val[b]
                for i, color, cls in zip(range(self.hparams.num_classes), colors, class_names):
                    # Keyword arguments: recent scikit-learn releases make the
                    # display constructors keyword-only.
                    if is_roc:
                        curve = RocCurveDisplay(fpr=prec_FPR[i].cpu(), tpr=rec_TPR[i].cpu(),
                                                roc_auc=val[2 + b][i], estimator_name=cls)
                    else:
                        curve = PrecisionRecallDisplay(precision=prec_FPR[i].cpu(), recall=rec_TPR[i].cpu(),
                                                       average_precision=val[2 + b][i], estimator_name=cls)
                    curve.plot(ax[b], c=color)

                ax[b].legend(loc='best')
                # Log the worst class as the headline score.
                self.log('LensResnet/val/' + key[2 + b], min(val[2 + b]))

            # Tag the figure with the step so successive epochs are distinguishable.
            self.logger.experiment.add_figure('LensResnet/val/PR_ROC', fig, global_step=self.global_step)
            fig.savefig(os.path.join(str(self.trainer.log_dir), 'PR_ROC_step_{:05d}.png'.format(self.global_step)))
        finally:
            # Release the figure even if plotting or logging failed.
            plt.close(fig)

In [None]:
# Smoke-test instantiation. LensResnet computes the learning rate as
# 10 ** config['lr'], so a rate of 1e-3 is passed as its log10, -3.
m = LensResnet({'lr': -3})

## LensGAN128
Here we subclass a DCGAN to create our low resolution GAN.

In [14]:
# Checkpoint directories of the trained LensResnet classifiers for datasets F
# and J; the file name 'checkpoint' is appended when they are loaded.
BEST_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_eb619_00000_0_2021-09-02_19-42-34/checkpoint_epoch=2-step=28124'
BEST_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/pbt_tanh_fine/train_LensResnet_93609_00000_0_2021-09-02_21-06-00/checkpoint_epoch=2-step=28124'


class LensGAN128(DCGAN):
    """Class-conditional DCGAN with an auxiliary classifier head.

    The bolts DCGAN is extended with one extra generator block at the front and
    one extra discriminator block at the back. The discriminator returns a
    real/fake score and class logits. Two frozen LensResnet classifiers
    (F and J) are used for FID and for PR/ROC evaluation during validation.

    Args:
        config: Hyperparameter dict. ``config['n_fmaps']`` sets the feature-map
            width of both networks and ``config['lr']`` is the log10 of the
            learning rate.
        num_classes: Number of classes used for conditioning and metrics.
        **kwargs: Forwarded to ``DCGAN.__init__``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'],
                         learning_rate=10**config['lr'], **kwargs)
        # NOTE(review): `ignore` is documented to take argument names, not the
        # dict itself - confirm this excludes what is intended.
        self.save_hyperparameters(ignore=config)

        # Batch-norm needs depends upon out channels of the previous layer
        first = self.generator._make_gen_block(1, self.hparams.feature_maps_gen * 16)
        # Accepts inputs of any channels
        first[0] = nn.LazyConvTranspose2d(self.hparams.feature_maps_gen * 16, kernel_size=4, stride=1, padding=0)
        # Turn this into the second layer
        self.generator.gen[0] = self.generator._make_gen_block(self.hparams.feature_maps_gen * 16, self.hparams.feature_maps_gen * 8)
        self.generator.gen = nn.Sequential(first, *list(self.generator.gen))

        # Turn this into the penultimate layer
        self.discriminator.disc[-1] = self.discriminator._make_disc_block(self.hparams.feature_maps_disc * 8, self.hparams.feature_maps_disc * 16)
        # self.discriminator.disc = nn.Sequential(*list(self.discriminator.disc), nn.LazyConv2d(1, kernel_size=4, stride=1, padding=0, bias=False))
        # # Necessary if using ACGAN, else could be appended to disc module
        self.discriminator.add_module('critic', nn.LazyConv2d(1, kernel_size=4, stride=1, padding=0, bias=False))
        # # Remove if ACGAN not needed
        self.discriminator.add_module('aux', nn.Sequential(nn.Flatten(), nn.LazyLinear(self.hparams.num_classes)))
        # Route discriminator calls through the two-headed forward defined below.
        self.discriminator.forward = self.discriminator_forward

        for i in range(5):
            # Necessary for implementing gradient penalty, else remove
            # if i != 0:
                # self.discriminator.disc[i][1] = nn.LayerNorm([self.discriminator.disc[i][0].out_channels, 2 ** (6 - i), 2 ** (6 - i)])
            # Remove if LeakyRelu not needed in generator
            # (reuses the discriminator block's final module as the generator activation)
            self.generator.gen[i][-1] = self.discriminator.disc[i][-1]
            # Implement dropouts
            # self.generator.gen[i] = nn.Sequential(*list(self.generator.gen[i]), nn.Dropout())
            # self.discriminator.disc[i] = nn.Sequential(*list(self.discriminator.disc[i]), nn.Dropout())

        # Not needed in WGAN architectures
        self.criterion = nn.BCEWithLogitsLoss()

        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelF = temp.backbone     # torch.load(PRE_F_RESNET, map_location=self.device).eval())
        # Keep the classification layer separately and make the backbone a
        # pure feature extractor.
        self.lastF = self.modelF.fc
        self.modelF.fc = nn.Identity()

        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelJ = temp.backbone     # torch.load(PRE_J_RESNET, map_location=self.device).eval())
        self.lastJ = self.modelJ.fc
        self.modelJ.fc = nn.Identity()

        self.imgMetrics = tm.MetricCollection(
            {
                'FID_F': tm.FID(self.modelF),
                'FID_J': tm.FID(self.modelJ),
            },
            prefix='LensGAN128/val/',
        )

        # Order matters: validation_epoch_end indexes these metrics positionally
        # (0: PR curve, 1: ROC curve, 2: average precision, 3: AUROC).
        metrics = tm.MetricCollection([
            tm.PrecisionRecallCurve(self.hparams.num_classes),
            tm.ROC(self.hparams.num_classes),
            tm.AveragePrecision(self.hparams.num_classes, average=None),
            tm.AUROC(self.hparams.num_classes, average=None),
        ])
        self.FMetrics = metrics.clone()
        self.JMetrics = metrics.clone()

    def discriminator_forward(self, x):
        """Return ``(critic score, class logits)`` for a batch of images.

        Gaussian noise is added to the input; its amplitude decays
        exponentially with ``global_step * batch size``.
        """
        # Add noise to the data
        inp = x + torch.randn_like(x) * np.exp(-self.global_step * len(x) / 7500)
        # # Not needed if not using ACGAN
        out5 = self.discriminator.disc(inp)
        return self.discriminator.critic(out5).squeeze(), self.discriminator.aux(out5)
        # return self.discriminator.disc(inp).squeeze()

    def forward(self, noise, labels, all_layers=False):
        """Generate images conditioned on ``labels``.

        Args:
            noise: Latent batch, concatenated after the one-hot labels.
            labels: Integer class labels, one per sample.
            all_layers: When true, return the list of outputs of every
                generator block instead of only the final image batch.
        """
        inp = torch.cat((F.one_hot(labels, self.hparams.num_classes), noise), 1)
        if not all_layers:
            return super().forward(inp)

        out = [inp.view(*inp.shape, 1, 1)]
        for layer in self.generator.gen:
            out.append(layer(out[-1]))
        return out[1:]

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run the discriminator (``optimizer_idx == 0``) or generator (``1``) step."""
        # Labels are parked on the instance so the loss helpers can read them.
        real, self.labels = batch
        fake = self._get_fake_data(self.labels).type_as(real)

        result = None
        # Train discriminator
        if optimizer_idx == 0:
            result = self._disc_step(real, fake.detach())

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(fake)

        del self.labels
        return result

    def _disc_step(self, real, fake):
        """Compute, log and return the discriminator loss."""
        # # Not needed if using gradient penalty
        # for p in self.discriminator.parameters():
        #     p.data.clamp_(-0.01, 0.01)
        disc_loss = self._get_disc_loss(real, fake)
        self.log('LensGAN128/D/train/loss', disc_loss)
        return disc_loss

    def _gen_step(self, fake):
        """Compute, log and return the generator loss."""
        gen_loss = self._get_gen_loss(fake)
        self.log('LensGAN128/G/train/loss', gen_loss)
        return gen_loss

    def _get_disc_loss(self, real, fake):
        """Discriminator loss: BCE on real and fake plus cross-entropy on real labels."""
        # Train with real
        # realCritic_pred = self.discriminator(real)
        realCritic_pred, realAux_pred = self.discriminator(real)
        # real_loss = -realCritic_pred.mean()
        real_loss = self.criterion(realCritic_pred, torch.ones_like(realCritic_pred))
        self.log('LensGAN128/D/train/loss/real', real_loss)

        # Train with fake
        # fakeCritic_pred = self.discriminator(fake)
        fakeCritic_pred, _ = self.discriminator(fake)
        # fake_loss = fakeCritic_pred.mean()
        fake_loss = self.criterion(fakeCritic_pred, torch.zeros_like(fakeCritic_pred))
        self.log('LensGAN128/D/train/loss/fake', fake_loss)

        # # Classifier loss (real samples only)
        class_loss = F.cross_entropy(realAux_pred, self.labels)
        # + nn.CrossEntropyLoss()(fakeAux_pred, self.labels)
        self.log('LensGAN128/D/train/loss/class', class_loss)

        # # Compute gradient penalty
        # gp = 10 * self._gradient_penalty(real, fake)
        # self.log('LensGAN128/D/train/loss/gp', gp)

        return real_loss + fake_loss + class_loss
        # + gp

    def _get_gen_loss(self, fake):
        """Generator loss: BCE towards "real" plus cross-entropy on the target labels."""
        # Train with fake
        # fakeCritic_pred = self.discriminator(fake)
        fakeCritic_pred, fakeAux_pred = self.discriminator(fake)
        # fake_loss = -fakeCritic_pred.mean()
        fake_loss = self.criterion(fakeCritic_pred, torch.ones_like(fakeCritic_pred))
        self.log('LensGAN128/G/train/loss/fake', fake_loss)

        # Classifier loss
        class_loss = F.cross_entropy(fakeAux_pred, self.labels)
        self.log('LensGAN128/G/train/loss/class', class_loss)

        return fake_loss + class_loss

    def _get_fake_data(self, labels, **kwargs):
        """Sample one latent vector per label and run the generator on them."""
        noise = self._get_noise(len(labels), self.hparams.latent_dim)
        return self(noise, labels, **kwargs)

    # def _gradient_penalty(self, real, fake):
    #     """Calculates the gradient penalty loss for WGAN GP"""
    #     # Random weight term for interpolation between real and fake samples
    #     alpha = torch.rand(len(real), 1, 1, 1)
    #     # Get random interpolation between real and fake samples
    #     mix = torch.lerp(real, fake, alpha.type_as(real)).requires_grad_(True)
    #     # # Remove the underscore if not using ACGAN
    #     d_mix ,_ = self.discriminator(mix)
    #     # Get gradient w.r.t. mix
    #     gradients = torch.autograd.grad(
    #         outputs=d_mix,
    #         inputs=mix,
    #         grad_outputs=torch.ones_like(d_mix),
    #         create_graph=True,
    #         retain_graph=True,
    #         only_inputs=True,
    #     )[0]
    #     gradients = gradients.view(len(mix), -1)
    #     gp = ((gradients.norm(2, dim=1) - 1) ** 2)
    #     return gp.mean()

    def validation_step(self, batch, batch_idx):
        """Update the FID metrics and the F/J classifier metrics for one batch."""
        # print(self.global_step, batch_idx, len(self.imgMetrics.FID_J.fake_features))
        imgs128, labels = batch
        # Resize once; the same tensor feeds FID and (at step 0) the classifiers.
        real = F.interpolate(imgs128, 150)
        self.imgMetrics.update(real, real=True)
        fake = F.interpolate(self._get_fake_data(labels), 150).type_as(imgs128)
        self.imgMetrics.update(fake, real=False)

        # At global_step 0 the classifiers are scored on the real images;
        # afterwards they are scored on generated images.
        scored = real if self.global_step == 0 else fake
        self.FMetrics.update(self.lastF(self.modelF(scored)).softmax(1), labels)
        self.JMetrics.update(self.lastJ(self.modelJ(scored)).softmax(1), labels)

    def validation_epoch_end(self, ListofDicts):
        """Log FID, plot PR/ROC curves for both classifiers and save sample grids.

        Three PNG files are written to the trainer's log directory, each tagged
        with the current global step: the PR/ROC figure, a grid of per-layer
        generator activations, and a grid of generated images followed by one
        real validation image per class.
        """
        # Classification scores
        fid = self.imgMetrics.compute()
        # self.log_dict(fid) isn't compatible with val_check_interval
        self.log_dict(self.imgMetrics)

        log_dir = str(self.trainer.log_dir)
        step_tag = '{:05d}'.format(self.global_step)
        val_set = self.trainer.datamodule.val_set

        fig = plt.figure(constrained_layout=True, figsize=(10, 9))
        try:
            # -----------------------------------------------------------------------------------------------------
            if self.global_step == 0:
                fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model F (original data)')
            else:
                fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model F (generated data)')
            subfigs = fig.subfigures(2, 1)
            colors = cycle(['r', 'g', 'b'])

            for a, src in enumerate(['F', 'J']):
                if self.global_step == 0:
                    subfigs[a].suptitle('Classifier : model ' + src)
                else:
                    subfigs[a].suptitle('Classifier : model ' + src + ', FID : {:0.2f}'.format(fid['LensGAN128/val/FID_' + src]))
                ax = subfigs[a].subplots(1, 2, subplot_kw={'xlim': [-0.05, 1.05], 'ylim': [-0.05, 1.05], 'aspect': 1})

                temp = getattr(self, src + 'Metrics')
                Dict = temp.compute()
                temp.reset()

                key, val = list(Dict.keys()), list(Dict.values())
                for b in range(2):
                    is_roc = key[b] == 'ROC'
                    # Chance-level reference line.
                    if is_roc:
                        ax[b].plot([0, 1], [0, 1], 'k--')
                    else:
                        ax[b].plot([0, 1], [1/self.hparams.num_classes, 1/self.hparams.num_classes], 'k--')

                    prec_FPR, rec_TPR, _ = val[b]
                    for i, color, cls in zip(range(self.hparams.num_classes), colors, val_set.classes):
                        # Keyword arguments: recent scikit-learn releases make
                        # the display constructors keyword-only.
                        if is_roc:
                            RocCurveDisplay(fpr=prec_FPR[i].cpu(), tpr=rec_TPR[i].cpu(),
                                            roc_auc=val[2 + b][i], estimator_name=cls).plot(ax[b], c=color)
                        else:
                            try:
                                PrecisionRecallDisplay(precision=prec_FPR[i].cpu(), recall=rec_TPR[i].cpu(),
                                                       average_precision=val[2 + b][i],
                                                       estimator_name=cls).plot(ax[b], c=color)
                            except Exception:
                                # Dump the offending inputs, then re-raise the
                                # original exception with its traceback intact.
                                print(prec_FPR)
                                print(rec_TPR)
                                print(val[2+b])
                                raise

                    ax[b].legend(loc='best')
                    # Mean over classes is logged as the headline score.
                    self.log('LensGAN128/LensResnet(' + src + ')/val/' + key[2 + b], sum(val[2 + b])/len(val[2 + b]))

            # Tag the figure with the step so successive runs are distinguishable.
            self.logger.experiment.add_figure('LensGAN128/LensResnet/val/PR_ROC', fig, global_step=self.global_step)
            fig.savefig(os.path.join(log_dir, 'PR_ROC_step_' + step_tag + '.png'))
        finally:
            # Release the figure even if plotting or logging failed.
            plt.close(fig)

        # Save layer-wise activations: one generated sample per class, and for
        # every generator block except the last a grid of its feature maps.
        labels = torch.arange(self.hparams.num_classes, device=self.device)
        imgs = self._get_fake_data(labels, all_layers=True)
        for lyr, img4d in enumerate(imgs[:-1]):
            grids = [
                make_grid(filters3d.unsqueeze(1), nrow=int(np.ceil(np.sqrt(len(filters3d)))), normalize=True)
                for filters3d in img4d
            ]
            imgs[lyr] = F.interpolate(torch.stack(grids), 150)

        # Deepest intermediate layer first, final images excluded.
        save_image(torch.cat(imgs[-2::-1]),
                   os.path.join(log_dir, 'F_maps_step_' + step_tag + '.png'),
                   #  kwargs for make_grid
                   nrow=self.hparams.num_classes, pad_value=0.5)

        # Generated images followed by one randomly chosen real validation image
        # per class. Picking directly among each class's indices (DatasetFolder
        # exposes them via `targets`) always terminates; a class absent from the
        # validation set is skipped.
        fakes = imgs[-1]
        targets = np.asarray(val_set.targets)
        reals = []
        for cls in labels.tolist():
            candidates = np.flatnonzero(targets == cls)
            if candidates.size == 0:
                continue
            x, _ = val_set[int(np.random.choice(candidates))]
            reals.append(x.unsqueeze(0).type_as(fakes))

        save_image(F.interpolate(torch.cat((fakes, *reals)), 150),
                   os.path.join(log_dir, 'Fake_step_' + step_tag + '.png'),
                   #  kwargs for make_grid
                   nrow=self.hparams.num_classes, normalize=True, value_range=(-1, 1), pad_value=0.5)

In [None]:
# Smoke-test instantiation. LensGAN128 computes the learning rate as
# 10 ** config['lr'], so a rate of 1e-3 is passed as its log10, -3.
m = LensGAN128({'lr': -3, 'n_fmaps': 128, 'bs': 8})

## VAE
Here we subclass a VAE.

In [None]:
# Checkpoint directories of the trained LensResnet classifiers for datasets F
# and J. The same values are already assigned to these names in the LensGAN128
# cell; they are repeated here so this cell can run on its own.
BEST_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_eb619_00000_0_2021-09-02_19-42-34/checkpoint_epoch=2-step=28124'
BEST_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/pbt_tanh_fine/train_LensResnet_93609_00000_0_2021-09-02_21-06-00/checkpoint_epoch=2-step=28124'

from pl_bolts.models.autoencoders import VAE

class LensVAE(VAE):
    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], 
                         lr=10**config['lr'], **kwargs)
        self.save_hyperparameters(ignore=config)

        # Batch-norm needs depends upon out channels of the previous layer
        first = self.generator._make_gen_block(1, self.hparams.feature_maps_gen * 16)
        # Accepts inputs of any channels
        first[0] = nn.LazyConvTranspose2d(self.hparams.feature_maps_gen * 16, kernel_size=4, stride=1, padding=0)
        # Turn this into the second layer
        self.generator.gen[0] = self.generator._make_gen_block(self.hparams.feature_maps_gen * 16, self.hparams.feature_maps_gen * 8)
        self.generator.gen = nn.Sequential(first, *list(self.generator.gen))
        
        # Turn this into the penultimate layer
        self.discriminator.disc[-1] = self.discriminator._make_disc_block(self.hparams.feature_maps_disc * 8, self.hparams.feature_maps_disc * 16)
        # self.discriminator.disc = nn.Sequential(*list(self.discriminator.disc), nn.LazyConv2d(1, kernel_size=4, stride=1, padding=0, bias=False))
        # # Necessary if using ACGAN, else could be appended to disc module
        self.discriminator.add_module('critic', nn.LazyConv2d(1, kernel_size=4, stride=1, padding=0, bias=False))
        # # Remove if ACGAN not needed 
        self.discriminator.add_module('aux', nn.Sequential(nn.Flatten(), nn.LazyLinear(self.hparams.num_classes)))
        self.discriminator.forward = self.discriminator_forward

        for i in range(5):
            # Necessary for implementing gradient penalty, else remove
            # if i != 0:
                # self.discriminator.disc[i][1] = nn.LayerNorm([self.discriminator.disc[i][0].out_channels, 2 ** (6 - i), 2 ** (6 - i)])
            # Remove if LeakyRelu not needed in generator
            self.generator.gen[i][-1] = self.discriminator.disc[i][-1]
            # Implement dropouts
            # self.generator.gen[i] = nn.Sequential(*list(self.generator.gen[i]), nn.Dropout())
            # self.discriminator.disc[i] = nn.Sequential(*list(self.discriminator.disc[i]), nn.Dropout())

        # Not needed in WGAN architectures
        self.criterion = nn.BCEWithLogitsLoss()

        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelF = temp.backbone     # torch.load(PRE_F_RESNET, map_location=self.device).eval())
        # Proper way to copy the last layer
        self.lastF = self.modelF.fc
        self.modelF.fc = nn.Identity()
        
        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelJ = temp.backbone     # torch.load(PRE_J_RESNET, map_location=self.device).eval())
        # Proper way to copy the last layer
        self.lastJ = self.modelJ.fc
        self.modelJ.fc = nn.Identity()

        self.imgMetrics = tm.MetricCollection(
            {
                'FID_F' : tm.FID(self.modelF),
                'FID_J' : tm.FID(self.modelJ),
            },
            prefix='LensGAN128/val/',
        )

        metrics = tm.MetricCollection(
            [tm.PrecisionRecallCurve(self.hparams.num_classes), tm.ROC(self.hparams.num_classes),
            tm.AveragePrecision(self.hparams.num_classes), tm.AUROC(self.hparams.num_classes, average=None)],
        )
        self.FMetrics = metrics.clone()
        self.JMetrics = metrics.clone()
    
    def discriminator_forward(self, x):
        # Add noise to the data
        inp = x + torch.randn_like(x) * np.exp(-self.global_step * len(x) / 7500)
        # # Not needed if not using ACGAN
        out5 = self.discriminator.disc(inp)
        return self.discriminator.critic(out5).squeeze(), self.discriminator.aux(out5)
        # return self.discriminator.disc(inp).squeeze()

    def forward(self, noise, labels, all_layers=False):
        inp = torch.cat((F.one_hot(labels, self.hparams.num_classes), noise), 1)
        if all_layers:
            out = [inp.view(*inp.shape, 1, 1)]
            for layer in self.generator.gen:
                out.append(layer(out[-1]))
            return out[1:]
        else:
            return super().forward(inp)

    def training_step(self, batch, batch_idx, optimizer_idx):
        real, self.labels = batch
        fake = self._get_fake_data(self.labels).type_as(real)

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real, fake.detach())

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(fake)

        del self.labels
        return result

    def _disc_step(self, real, fake):
        # # Not needed if using gradient penalty
        # for p in self.discriminator.parameters():
        #     p.data.clamp_(-0.01, 0.01)
        disc_loss = self._get_disc_loss(real, fake)
        self.log('LensGAN128/D/train/loss', disc_loss)
        return disc_loss

    def _gen_step(self, fake):
        gen_loss = self._get_gen_loss(fake)
        self.log('LensGAN128/G/train/loss', gen_loss)
        return gen_loss

    def _get_disc_loss(self, real, fake):
        # Train with real
        # realCritic_pred = self.discriminator(real)
        realCritic_pred, realAux_pred = self.discriminator(real)
        # real_loss = -realCritic_pred.mean()
        real_gt = torch.ones_like(realCritic_pred)
        real_loss = self.criterion(realCritic_pred, real_gt)
        self.log('LensGAN128/D/train/loss/real', real_loss)

        # Train with fake
        # fakeCritic_pred = self.discriminator(fake)
        fakeCritic_pred, fakeAux_pred = self.discriminator(fake)
        # fake_loss = fakeCritic_pred.mean()
        fake_gt = torch.zeros_like(fakeCritic_pred)
        fake_loss = self.criterion(fakeCritic_pred, fake_gt)
        self.log('LensGAN128/D/train/loss/fake', fake_loss)

        # # Classifier loss
        class_loss = nn.CrossEntropyLoss()(realAux_pred, self.labels) 
        # + nn.CrossEntropyLoss()(fakeAux_pred, self.labels)
        self.log('LensGAN128/D/train/loss/class', class_loss)

        # # Compute gradient penalty
        # gp = 10 * self._gradient_penalty(real, fake)
        # self.log('LensGAN128/D/train/loss/gp', gp)

        # Modi
        return real_loss + fake_loss + class_loss 
        # + gp 

    def _get_gen_loss(self, fake):
        # Train with fake
        # fakeCritic_pred = self.discriminator(fake)
        fakeCritic_pred, fakeAux_pred = self.discriminator(fake)
        # fake_loss = -fakeCritic_pred.mean()
        fake_gt = torch.ones_like(fakeCritic_pred)
        fake_loss = self.criterion(fakeCritic_pred, fake_gt)
        self.log('LensGAN128/G/train/loss/fake', fake_loss)

        # Classifier loss
        class_loss = nn.CrossEntropyLoss()(fakeAux_pred, self.labels)
        self.log('LensGAN128/G/train/loss/class', class_loss)

        return fake_loss + class_loss

    def _get_fake_data(self, labels, **kwargs):
        batch_size = len(labels)
        noise = self._get_noise(batch_size, self.hparams.latent_dim)
        fake = self(noise, labels, **kwargs)

        return fake

    # def _gradient_penalty(self, real, fake):
    #     """Calculates the gradient penalty loss for WGAN GP"""
    #     # Random weight term for interpolation between real and fake samples
    #     alpha = torch.rand(len(real), 1, 1, 1)
    #     # Get random interpolation between real and fake samples
    #     mix = torch.lerp(real, fake, alpha.type_as(real)).requires_grad_(True)
    #     # # Remove the underscore if not using ACGAN
    #     d_mix ,_ = self.discriminator(mix)
    #     # Get gradient w.r.t. mix
    #     gradients = torch.autograd.grad(
    #         outputs=d_mix,
    #         inputs=mix,
    #         grad_outputs=torch.ones_like(d_mix),
    #         create_graph=True,
    #         retain_graph=True,
    #         only_inputs=True,
    #     )[0]
    #     gradients = gradients.view(len(mix), -1)
    #     gp = ((gradients.norm(2, dim=1) - 1) ** 2)
    #     return gp.mean()

    def validation_step(self, batch, batch_idx):
        """Accumulate image and classifier metrics for one validation batch.

        Real and generated images are both resized to 150 px before being fed
        to ``imgMetrics`` (tagged ``real=True`` / ``real=False``). The two
        classifier metric collections receive softmax scores computed on the
        real images while ``global_step == 0`` (i.e. before any optimisation
        step) and on the generated images afterwards.

        Args:
            batch: pair ``(imgs128, labels)``.
            batch_idx: index of the batch; unused.
        """
        imgs128, labels = batch

        # Upsample the real batch once and reuse it for every consumer below.
        real150 = F.interpolate(imgs128, 150)
        self.imgMetrics.update(real150, real=True)

        fake = F.interpolate(self._get_fake_data(labels), 150).type_as(imgs128)
        self.imgMetrics.update(fake, real=False)

        # Step 0 gives the baseline of the classifiers on original data.
        classifier_input = real150 if self.global_step == 0 else fake
        self.FMetrics.update(self.lastF(self.modelF(classifier_input)).softmax(1), labels)
        self.JMetrics.update(self.lastJ(self.modelJ(classifier_input)).softmax(1), labels)

    def validation_epoch_end(self, ListofDicts):
        """Summarise a validation run: log metrics, plot PR/ROC curves, save sample images.

        In order, this hook:

        1. computes the image metrics (FID values, used in the subplot titles)
           and logs the ``imgMetrics`` collection;
        2. for each classifier (``F`` and ``J``) computes and resets its metric
           collection, draws one-vs-all curves per class and logs the
           class-averaged scalar scores;
        3. sends the figure to the experiment logger (tagged with the current
           global step) and saves it as a PNG in the trainer's log directory;
        4. generates one sample per class with ``all_layers=True`` and saves a
           grid of the intermediate outputs (``F_maps_step_*.png``);
        5. saves the generated images followed by one randomly drawn validation
           image per class (``Fake_step_*.png``).

        Args:
            ListofDicts: outputs gathered from ``validation_step``; unused.
        """
        # Classification scores
        fid = self.imgMetrics.compute()
        # self.log_dict(fid) isn't compatible with val_check_interval
        self.log_dict(self.imgMetrics)

        fig = plt.figure(constrained_layout=True, figsize=(10, 9))
        # -----------------------------------------------------------------------------------------------------
        if self.global_step == 0:
            fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model F (original data)')
        else:
            fig.suptitle('One vs. all PR & ROC curves for different types of substructures in model F (generated data)')
        subfigs = fig.subfigures(2, 1)
        # Shared colour cycle; it stays aligned across plots only while the
        # number of classes equals the number of colours.
        colors = cycle(['r', 'g', 'b'])

        for a, src in enumerate(['F', 'J']):
            if self.global_step == 0:
                subfigs[a].suptitle('Classifier : model ' + src)
            else:
                subfigs[a].suptitle('Classifier : model ' + src + ', FID : {:0.2f}'.format(fid['LensGAN128/val/FID_'+ src]))
            ax = subfigs[a].subplots(1,2, subplot_kw={'xlim': [-0.05,1.05], 'ylim': [-0.05,1.05], 'aspect': 1})

            temp = getattr(self, src + 'Metrics')
            Dict = temp.compute()
            # Reset so the next validation run starts from empty metric state.
            temp.reset()

            # NOTE(review): the indexing below assumes the collection yields two
            # curve entries (index b) followed by their two per-class scalar
            # entries (index 2 + b) — confirm against the metric definitions.
            key, val = list(Dict.keys()), list(Dict.values())
            for b in range(2):
                # Chance-level reference line: 1/num_classes for the non-ROC
                # plot, the diagonal for ROC.
                if key[b] != 'ROC':
                    ax[b].plot([0, 1], [1/self.hparams.num_classes, 1/self.hparams.num_classes], 'k--')
                else:
                    ax[b].plot([0, 1], [0, 1], 'k--')

                prec_FPR, rec_TPR, _ = val[b]
                for i, color, cls in zip(range(self.hparams.num_classes), colors, self.trainer.datamodule.val_set.classes):
                    if key[b] != 'ROC':
                        PrecisionRecallDisplay(prec_FPR[i].cpu(), rec_TPR[i].cpu(), val[2 + b][i], cls).plot(ax[b], c=color)
                    else:
                        RocCurveDisplay(prec_FPR[i].cpu(), rec_TPR[i].cpu(), val[2 + b][i], cls).plot(ax[b], c=color)

                ax[b].legend(loc='best')
                # Log the mean of the per-class scores as a single scalar.
                self.log('LensGAN128/LensResnet(' + src + ')/val/' + key[2 + b], sum(val[2 + b])/len(val[2 + b]))

        # Tag the figure with the current step; without it every figure is
        # written at step 0 and the logger cannot show their history.
        self.logger.experiment.add_figure('LensGAN128/LensResnet/val/PR_ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/PR_ROC_step_{:05d}.png'.format(self.global_step))

        # Save layer-wise activations
        labels = torch.arange(self.hparams.num_classes, device=self.device)
        imgs = self._get_fake_data(labels, all_layers=True)
        for lyr, img4d in enumerate(imgs[:-1]):
            new = list(img4d)
            for cls, filters3d in enumerate(new):
                # Tile every channel of this sample into one roughly square grid.
                new[cls] = make_grid(filters3d.unsqueeze(1), nrow=int(np.ceil(np.sqrt(len(filters3d)))), normalize=True)

            imgs[lyr] = F.interpolate(torch.stack(new), 150)

        # All intermediate outputs, last-but-one first, one column per class.
        save_image(torch.cat(imgs[-2::-1]),
                str(self.trainer.log_dir) + '/F_maps_step_{:05d}.png'.format(self.global_step),
                #  kwargs for make_grid
                nrow=self.hparams.num_classes, pad_value=0.5)

        # Append one real validation image per class beneath the generated ones.
        # Samples are drawn at random until each class has been found, so this
        # loop never ends if a class is missing from the validation set.
        imgs = imgs[-1]
        while len(labels) > 0:
            x, y = self.trainer.datamodule.val_set[np.random.randint(0, len(self.trainer.datamodule.val_set))]
            if y == labels[0]:
                imgs = torch.cat((imgs, x.unsqueeze(0).type_as(imgs)))
                labels = labels[1:]

        save_image(F.interpolate(imgs, 150),
                str(self.trainer.log_dir) + '/Fake_step_{:05d}.png'.format(self.global_step),
                #  kwargs for make_grid
                nrow=self.hparams.num_classes, normalize=True, value_range=(-1,1), pad_value=0.5)

## Stage 2
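Here we subclass DCGAN to build the high-resolution (Stage 2) GAN: its generator takes the output of the low-resolution LensGAN128 and upsamples it to 150×150 images.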

In [None]:
class Generator2(nn.Module):
    """Image-to-image generator: strided conv, residual trunk, then upsampling to 150 px.

    Layout:
        * ``preprocessing`` – a 4x4 stride-2 convolution (halves the spatial size) + ReLU;
        * ``residual`` – ``res_depth`` ``BasicBlock(ngf, ngf)`` modules;
        * ``ending_residual`` – 3x3 conv + batch norm + ReLU;
        * ``main`` – nearest-neighbour upsampling to 75 and then 150 pixels,
          each followed by a 3x3 conv, ending in ``Tanh``.

    The output of ``preprocessing`` is added back to the output of
    ``ending_residual`` (one long skip connection) before ``main``.

    Args:
        ngf: number of feature maps in the trunk.
        image_channels: channels of both the input and the output image.
        res_depth: number of residual blocks.
    """

    def __init__(self, ngf: int = 128, image_channels: int = 1, res_depth: int = 6):
        super().__init__()

        down_kernel, down_stride = 4, 2
        down_pad = (down_kernel - 2) // 2
        # Shared settings of every size-preserving 3x3 convolution.
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=False)
        out_size = 150
        upsample_mode = 'nearest'

        # Strided convolution: spatial size is halved.
        self.preprocessing = nn.Sequential(
            nn.Conv2d(image_channels, ngf, down_kernel, down_stride, down_pad, bias=False),
            nn.ReLU(True),
        )

        # Residual trunk.
        self.residual = nn.Sequential(*(BasicBlock(ngf, ngf) for _ in range(res_depth)))

        self.ending_residual = nn.Sequential(
            nn.Conv2d(ngf, ngf, **conv3x3),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
        )

        # The long skip connection is applied in forward(), right before this head.
        self.main = nn.Sequential(
            # -> 75 x 75
            nn.Upsample(out_size // 2, mode=upsample_mode),
            nn.Conv2d(ngf, ngf * 4, **conv3x3),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 75 -> 150
            nn.Upsample(out_size, mode=upsample_mode),
            nn.Conv2d(ngf * 4, image_channels, **conv3x3),
            nn.Tanh(),
        )

    def forward(self, in_x):
        """Map ``in_x`` to a 150x150 image with values in [-1, 1] (Tanh output)."""
        skip = self.preprocessing(in_x)
        trunk = self.ending_residual(self.residual(skip))
        # Long residual connection around the whole trunk.
        return self.main(trunk + skip)

In [None]:
# Ray Tune checkpoint directories of the selected LensGAN128 runs (judging by the
# paths, one trained on dataset F and one on dataset J).
# NOTE(review): the Stage2 class below reads BEST_LensGAN128_F, BEST_RESNET_F and
# BEST_RESNET_J, none of which is defined in this cell — confirm the intended names.
BEST_F_LensGAN128 = '/content/drive/MyDrive/Logs/F/LensGAN128/pbt_tanh_1/train_LensGAN128_90727_00001_1_n_fmaps=16_2021-08-30_08-02-05/checkpoint_epoch=4-step=1988/'
BEST_J_LensGAN128 ='/content/drive/MyDrive/Logs/J/LensGAN128/pbt_tanh/train_LensGAN128_28c03_00003_3_n_fmaps=64_2021-08-31_20-45-44/checkpoint_epoch=2-step=1403/'

class Stage2(DCGAN):
    """Second-stage GAN: refines the output of a frozen LensGAN128 with ``Generator2``.

    The "noise" fed to the generator is not random noise but the image produced
    by the pretrained low-resolution GAN (``self.lowres``) for the requested
    labels. Two pretrained ``LensResnet`` classifiers (``modelF``, ``modelJ``)
    score the generated images during validation.

    Args:
        config: dict with keys ``n_fmaps``, ``learning_rate`` and ``res_depth``.
        num_classes: number of classes, used for metrics and sample generation.
        **kwargs: accepted for compatibility; unused.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], learning_rate=config['learning_rate'])
        self.save_hyperparameters(ignore=config)

        # Replace the stock DCGAN generator with the residual up-sampler.
        self.generator = Generator2(self.hparams.feature_maps_gen, self.hparams.image_channels, config['res_depth'])

        # These are better as attributes, instead of being returned by a method
        # NOTE(review): relies on module-level BEST_RESNET_F, BEST_RESNET_J and
        # BEST_LensGAN128_F; the preceding cell defines BEST_F_LensGAN128 /
        # BEST_J_LensGAN128 instead — confirm the intended names.
        self.modelF = getattr(self, 'modelF', LensResnet.load_from_checkpoint(os.path.join(BEST_RESNET_F, 'checkpoint')).eval())
        self.modelJ = getattr(self, 'modelJ', LensResnet.load_from_checkpoint(os.path.join(BEST_RESNET_J, 'checkpoint')).eval())
        # Workaround:
        self.lowres = getattr(self, 'lowres', LensGAN128.load_from_checkpoint(os.path.join(BEST_LensGAN128_F, 'checkpoint')).eval())

        metrics = tm.MetricCollection(
            [
             tm.AUROC(num_classes=self.hparams.num_classes, compute_on_step=False, average=None),
             tm.ROC(num_classes=self.hparams.num_classes, compute_on_step=False),
            ]
        )
        # Independent copies so F and J accumulate separately.
        self.metricsF = metrics.clone()
        self.metricsJ = metrics.clone()

    def forward(self, noise):
        """Run the generator on ``noise`` (here: a low-resolution generated image)."""
        return self.generator(noise)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Dispatch to the discriminator (optimizer 0) or generator (optimizer 1) step.

        The batch labels are stored on ``self.labels`` so that ``_get_noise``
        can condition the low-resolution GAN on them.
        """
        real, self.labels = batch

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real)

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(real)

        return result

    def _disc_step(self, real):
        """Compute and log the discriminator loss."""
        disc_loss = self._get_disc_loss(real)
        self.log('Stage2/D/train/loss', disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real):
        """Compute and log the generator loss."""
        gen_loss = self._get_gen_loss(real)
        self.log('Stage2/G/train/loss', gen_loss, on_epoch=True)
        return gen_loss

    def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
        """Adversarial generator loss: generated samples should be judged real."""
        # Train with fake
        fake_pred = self._get_fake_pred(real)
        fake_gt = torch.ones_like(fake_pred)
        gen_loss = self.criterion(fake_pred, fake_gt)

        # class_pred =  self._get_class_pred(len(real))
        # gen_loss += F.cross_entropy(class_pred, self.labels)

        return gen_loss

    def _get_class_pred(self, batch_size) -> torch.Tensor:
        """Return ``modelF``'s backbone output on a freshly generated batch.

        Only referenced from the commented-out classification term above.
        """
        # ----------------------------------------------------------------------------------------------------------------
        return self.modelF.backbone(self(self._get_noise(batch_size, self.hparams.latent_dim)))

    def _get_noise(self, n_samples: int, latent_dim: int, labels = None):
        """Return the low-resolution GAN's output, used as this generator's input.

        Args:
            n_samples: number of samples to draw.
            latent_dim: latent dimension passed to the parent's noise sampler.
            labels: class labels; defaults to the labels of the current
                training batch (``self.labels``).
        """
        # can't use self in function definition
        if labels is None:
            labels = self.labels
        return self.lowres(super()._get_noise(n_samples, latent_dim), labels)

    def validation_step(self, batch, batch_idx):
        """Generate one image per validation label and update both classifier metrics."""
        imgs, labels = batch
        out = self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels))
        self.metricsF.update(self.modelF(out), labels)
        self.metricsJ.update(self.modelJ(out), labels)

    def validation_epoch_end(self, listofDicts):
        """Log per-classifier AUROC, plot one-vs-all ROC curves and save sample images.

        Args:
            listofDicts: outputs gathered from ``validation_step``; unused.
        """
        fig, ax = plt.subplots(1,2,
            subplot_kw={'xlim': [0,1], 'xlabel': 'False Positive Rate', 'ylim': [0,1.05],
                        'ylabel': 'True Positive Rate',
            },
            figsize=[11, 5],
        )
        for j, letter in enumerate(['F', 'J']):
            collection = getattr(self, 'metrics' + letter)
            output = collection.compute()
            # The metrics are computed by hand (not logged as Metric objects),
            # so they must be reset by hand too; otherwise every validation run
            # would also include the samples of all previous runs.
            collection.reset()

            # The worst per-class AUROC is the reported score.
            self.log('Stage2/ResNet(' + letter + ')/val/auroc', output['AUROC'].min())
            fprList, tprList, _ = output['ROC']

            colors = cycle(['red', 'blue', 'green'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                ax[j].plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                        label='ROC curve of class {0} (area = {1:0.4f})'.format(i, output['AUROC'][i]))
            post_plotting(ax[j])
            ax[j].set_title('One vs. all ROC curve (' + letter + ')')

        fig.tight_layout()
        # Tag the figure with the current step; without it every figure is
        # written at step 0 and the logger cannot show their history.
        self.logger.experiment.add_figure('Stage2/ResNet/val/ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/ROC_step_{:05d}.png'.format(self.global_step))

        # One generated sample per class.
        labels = torch.arange(self.hparams.num_classes, device=self.device)
        save_image(self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels)),
                   str(self.trainer.log_dir) + '/Fake_step_{:05d}.png'.format(self.global_step),
                  #  kwargs for make_grid
                   normalize=True, value_range=(-1,1))

    def on_fit_end(self):
        """Drop the helper models and cached labels once fitting is over.

        Attributes that were never set (e.g. ``labels`` when no training step
        ran) are skipped instead of raising ``AttributeError``.
        """
        for name in ('modelF', 'modelJ', 'labels', 'lowres'):
            if hasattr(self, name):
                delattr(self, name)

# Tune Hyperparameters


## ResNet
Here we tune hyperparameters as we train our modified ResNet.

In [None]:
# WARNING: irreversibly deletes this experiment's log directory on Google Drive.
%rm -rf /content/drive/MyDrive/Logs/fakeF/PGAN/LensResnet/pbt_tanh_validate

In [None]:
# __tune_train_checkpoint_begin
def train_LensResnet(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable that fits a ``LensResnet`` for one trial.

    Args:
        config: trial hyperparameters; ``config['bs']`` is read here for the
            progress-bar refresh rate, and the whole dict is handed to the
            data module and the model.
        checkpoint_dir: directory supplied by Tune when a trial is restored;
            training resumes from the file ``checkpoint`` inside it.
        num_epochs: maximum number of epochs.
        num_gpus: GPUs for the trainer (truncated to an int).
    """
    refresh_rate = int(8250 // int(2 ** np.rint(config['bs'])))
    # If fractional GPUs passed in, convert to int.
    gpu_count = int(num_gpus)
    # Write TensorBoard files straight into the trial directory.
    tb_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.')

    # Keys are the names reported to Tune, values the Lightning metric names.
    reported_metrics = {
        'loss': 'LensResnet/val/loss',
        'auroc': 'LensResnet/val/AUROC',
        'ap': 'LensResnet/val/AveragePrecision',
    }
    callbacks = [
        TuneReportCheckpointCallback(reported_metrics),
        ModuleDataMonitor(['backbone.layer2', 'backbone.layer4', 'backbone.fc']),
        ConfusedLogitCallback(5),
    ]

    trainer_kwargs = dict(
        progress_bar_refresh_rate=refresh_rate,
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=gpu_count,
        logger=tb_logger,
        callbacks=callbacks,
        # works with only one optimizer
        stochastic_weight_avg=True,
        benchmark=True,
        # can't use on cpu
        precision=16,
    )

    dm = npyImageData(config)
    if checkpoint_dir is not None:
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = LensResnet(config)
    pl.Trainer(**trainer_kwargs).fit(model, dm)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune ``LensResnet`` with Population Based Training and print the best checkpoint.

    Args:
        num_samples: number of trials sampled by Tune.
        num_epochs: maximum epochs per trial, forwarded to ``train_LensResnet``.
        gpus_per_trial: GPUs reserved for (and forwarded to) every trial.
    """
    trainable = tune.with_parameters(
        train_LensResnet,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    # Every trial claims all CPUs of the machine.
    trial_resources = {'cpu': os.cpu_count(), 'gpu': gpus_per_trial}
    # Starting point; PBT mutates both entries afterwards.
    # (PB2 is not an option here because it requires continuous mutations.)
    start_config = {'lr': 1e-5, 'bs': 8}

    pbt = PopulationBasedTraining(
        time_attr='training_iteration',
        quantile_fraction=0.4,
        resample_probability=0.2,
        perturbation_interval=1,
        hyperparam_mutations={
            'lr': tune.loguniform(1e-6, 1e-4),
            'bs': [8, 16, 32, 64, 128],
        },
    )
    reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['lr', 'bs'],
        metric_columns=['loss', 'auroc', 'ap', 'training_iteration'],
    )

    analysis = tune.run(
        trainable,
        # Change the folder name when changing dataset
        name='J/LensResnet/pbt_tanh_fine',
        metric='loss',
        mode='min',
        resources_per_trial=trial_resources,
        local_dir='./drive/MyDrive/Logs',
        config=start_config,
        scheduler=pbt,
        progress_reporter=reporter,
        fail_fast=True,
        num_samples=num_samples,
    )

    best_checkpoint = analysis.best_checkpoint
    print('Best checkpoint path found is: ', best_checkpoint)

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test', action='store_true',
                        help='Finish quickly for testing')
    # Unknown arguments (e.g. those injected by the notebook kernel) are ignored.
    args, _ = parser.parse_known_args()

    # Smoke test: 6 epochs; regular PBT run: 5 epochs.
    tune_LensResnet_pbt(
        num_samples=1,
        num_epochs=6 if args.smoke_test else 5,
        gpus_per_trial=torch.cuda.device_count(),
    )

## LensGAN128
Here we tune hyperparameters as we train our modified DCGAN.

In [None]:
def train_mnist(config, checkpoint=None, num_steps=10000, num_gpus=torch.cuda.device_count()):
    """Fit a ``LensGAN128`` once, outside Ray Tune, logging to Google Drive.

    Args:
        config: hyperparameter dict; ``config['bs']`` is read here (as a power
            of two) for the checkpoint interval, and the whole dict is handed
            to the data module and the model.
        checkpoint: optional checkpoint path to resume from.
        num_steps: maximum number of training steps.
        num_gpus: GPUs handed to the trainer.
    """
    dm = npyImageData(config, 128)  # 128 = image width
    model = LensGAN128(config)

    # One sub-directory per configuration, e.g. "lr_-3_n_fmaps_32_bs_3".
    run_tag = '_'.join('{}_{}'.format(key, value) for key, value in config.items())
    tb_logger = TensorBoardLogger(
        save_dir=os.path.join('drive/MyDrive/Logs/', 'F/LensGAN128'),
        name='test',
        sub_dir=run_tag,
    )

    steps_between_checkpoints = int(7500 // int(2 ** np.rint(config['bs'])))
    checkpointer = ModelCheckpoint(
        monitor='LensGAN128/LensResnet(F)/val/AUROC',
        verbose=True,
        save_last=True,
        save_top_k=5,
        mode='max',
        auto_insert_metric_name=False,
        every_n_train_steps=steps_between_checkpoints,
    )

    trainer = pl.Trainer(
        val_check_interval=0.10,
        max_steps=num_steps,
        gpus=num_gpus,
        logger=tb_logger,
        callbacks=[checkpointer, ModuleDataMonitor(True)],
        benchmark=True,
        precision=16,
    )
    trainer.fit(model, dm, ckpt_path=checkpoint)
# __lightning_end__

# __no_tune_train_begin__
def train_mnist_no_tune(*args, **kwargs):
    """Call ``train_mnist`` with a fixed configuration instead of one chosen by Tune.

    Extra positional and keyword arguments are forwarded to ``train_mnist``.
    """
    # NOTE(review): 'bs' is used as a power-of-two exponent by train_mnist;
    # 'lr' presumably is a log10 exponent too — confirm in the model.
    fixed_config = dict(lr=-3, n_fmaps=32, bs=3)
    train_mnist(fixed_config, *args, **kwargs)
# __no_tune_train_end__

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test', action='store_true',
                        help='Finish quickly for testing')
    # Unknown arguments (e.g. those injected by the notebook kernel) are ignored.
    args, _ = parser.parse_known_args()

    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"
    # --smoke-test is accepted but currently launches the very same run.
    train_mnist_no_tune()

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | generator     | DCGANGenerator     | 2.8 M 
1 | discriminator | DCGANDiscriminator | 2.8 M 
2 | criterion     | BCEWithLogitsLoss  | 0     
3 | modelF        | ResNet             | 11.2 M
4 | lastF         | Sequential         | 1.5 K 
5 | modelJ        | ResNet             | 11.2 M
6 | lastJ         | Sequential         | 1.5 K 
7 | imgMetrics    | MetricCollection   | 22.3 M
8 | FMetrics      | MetricCollection   | 0     
9 | JMetrics      | MetricCollection   | 0     
-----------------------------------------------------
5.6 M     Trainable params
22.3 M    Non-trainable params
27.9 M    Total params
55.838    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Epoch 0, global step 936: LensGAN128/LensResnet(F)/val/AUROC was not in top 5


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 1873: LensGAN128/LensResnet(F)/val/AUROC reached 0.49840 (best 0.49840), saving model to "drive/MyDrive/Logs/F/LensGAN128/test/version_2/checkpoints/0-1873.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 2810: LensGAN128/LensResnet(F)/val/AUROC reached 0.49875 (best 0.49875), saving model to "drive/MyDrive/Logs/F/LensGAN128/test/version_2/checkpoints/0-2810.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 3747: LensGAN128/LensResnet(F)/val/AUROC reached 0.48109 (best 0.49875), saving model to "drive/MyDrive/Logs/F/LensGAN128/test/version_2/checkpoints/0-3747.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 4684: LensGAN128/LensResnet(F)/val/AUROC reached 0.49019 (best 0.49875), saving model to "drive/MyDrive/Logs/F/LensGAN128/test/version_2/checkpoints/0-4684.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

In [None]:
# WARNING: irreversibly deletes this experiment's log directory on Google Drive.
%rm -rf drive/MyDrive/Logs/F/LensGAN128/acgan_nodiscfake_leaky_discnoise_full_pb2

In [None]:

# __tune_train_checkpoint_begin
def train_LensGAN128(config, checkpoint_dir=None, num_steps=10000, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable that fits a ``LensGAN128`` for one trial.

    Args:
        config: trial hyperparameters; ``config['bs']`` is read here (as a
            power of two) for the progress-bar refresh rate, and the whole
            dict is handed to the data module and the model.
        checkpoint_dir: directory supplied by Tune when a trial is restored;
            training resumes from the file ``checkpoint`` inside it.
        num_steps: maximum number of training steps.
        num_gpus: GPUs for the trainer (truncated to an int).
    """
    # If fractional GPUs passed in, convert to int.
    gpu_count = int(num_gpus)
    # Write TensorBoard files straight into the trial directory.
    tb_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.')

    # Keys are the names reported to Tune, values the Lightning metric names.
    # Switch the F/J entries when training on the other dataset.
    reported_metrics = {
        'loss_G': 'LensGAN128/G/train/loss',
        'loss_D': 'LensGAN128/D/train/loss',
        'FID': 'LensGAN128/val/FID_F',
        'FID_cross': 'LensGAN128/val/FID_J',
        'auroc': 'LensGAN128/LensResnet(F)/val/AUROC',
        'auroc_cross': 'LensGAN128/LensResnet(J)/val/AUROC',
        'ap': 'LensGAN128/LensResnet(F)/val/AveragePrecision',
        'ap_cross': 'LensGAN128/LensResnet(J)/val/AveragePrecision',
    }
    callbacks = [
        # Reporting at validation end (the default) resumes with an up-to-date checkpoint.
        TuneReportCheckpointCallback(reported_metrics),
        ModuleDataMonitor(True),
        TQDMProgressBar(refresh_rate=int(7500 // int(2 ** np.rint(config['bs'])))),
    ]

    trainer_kwargs = dict(
        val_check_interval=0.10,
        max_steps=num_steps,
        gpus=gpu_count,
        logger=tb_logger,
        callbacks=callbacks,
        benchmark=True,
        precision=16,
    )

    dm = npyImageData(config, 128)  # 128 = image width
    if checkpoint_dir is not None:
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = LensGAN128(config)
    pl.Trainer(**trainer_kwargs).fit(model, dm)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensGAN128_pbt(num_samples=10, num_steps=10000, gpus_per_trial=torch.cuda.device_count()):
    """Tune ``LensGAN128`` with the PB2 scheduler and print the best checkpoint.

    ``n_fmaps`` is grid-searched; ``lr`` and ``bs`` start at fixed values and
    are then perturbed by PB2 within the given bounds.

    Args:
        num_samples: number of times the grid is sampled by Tune.
        num_steps: maximum training steps per trial, forwarded to ``train_LensGAN128``.
        gpus_per_trial: GPUs reserved for (and forwarded to) every trial.
    """
    trainable = tune.with_parameters(
        train_LensGAN128,
        num_steps=num_steps,
        num_gpus=gpus_per_trial,
    )
    # Every trial claims all CPUs of the machine.
    trial_resources = {'cpu': os.cpu_count(), 'gpu': gpus_per_trial}
    search_space = {
        'lr': -3,
        'n_fmaps': tune.grid_search([8, 16, 32, 64]),
        'bs': 3,
    }

    # Bounds are given as exponents; the original notes map them to
    # loguniform(1e-5, 1e-3) for 'lr' and batch sizes 8..128 for 'bs'.
    pb2 = PB2(
        time_attr='training_iteration',
        quantile_fraction=0.25,
        perturbation_interval=2,
        hyperparam_bounds={
            'lr': [-6, -3],
            'bs': [3, 7],
        },
    )
    reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['lr', 'n_fmaps', 'bs'],
        # The *_cross metrics are reported to Tune but not shown in the table.
        metric_columns=['loss_G', 'loss_D', 'FID', 'auroc', 'ap', 'training_iteration'],
        max_report_frequency=5 * 60,
    )

    analysis = tune.run(
        trainable,
        # Change the folder name when changing dataset
        name='F/LensGAN128/test',
        metric='auroc',
        mode='max',
        resources_per_trial=trial_resources,
        local_dir='./drive/MyDrive/Logs',
        config=search_space,
        scheduler=pb2,
        progress_reporter=reporter,
        fail_fast=True,
        num_samples=num_samples,
    )

    # Per-trial curves can be plotted from analysis.fetch_trial_dataframes().
    print('Best checkpoint path found is: ', analysis.best_checkpoint)

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--smoke-test', action='store_true',
                        help='Finish quickly for testing')
    # Unknown arguments (e.g. those injected by the notebook kernel) are ignored.
    args, _ = parser.parse_known_args()

    # --smoke-test is accepted but currently launches the very same PB2 run.
    tune_LensGAN128_pbt(
        num_samples=1,
        num_steps=10000,
        gpus_per_trial=torch.cuda.device_count(),
    )

In [None]:
# Flush pending writes to Google Drive and unmount it, so that the logs
# written under drive/MyDrive are persisted before the runtime goes away.
drive.flush_and_unmount()

## Stage 2
Here we tune hyperparameters as we train the Stage 2 (high-resolution) GAN.

In [None]:
# __tune_train_checkpoint_begin
def train_Stage2(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable that fits a ``Stage2`` GAN for one trial.

    Args:
        config: trial hyperparameters; ``config['batch_size']`` is read here
            for the progress-bar refresh rate, and the whole dict is handed to
            the data module and the model.
        checkpoint_dir: directory supplied by Tune when a trial is restored;
            training resumes from the file ``checkpoint`` inside it.
        num_epochs: maximum number of epochs.
        num_gpus: GPUs for the trainer (truncated to an int).
    """
    refresh_rate = int(8250 // config['batch_size'])
    # If fractional GPUs passed in, convert to int.
    gpu_count = int(num_gpus)
    # Write TensorBoard files straight into the trial directory.
    tb_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.')

    # Keys are the names reported to Tune, values the Lightning metric names.
    # Switch the F/J auroc entries when training on the other dataset.
    reported_metrics = {
        'loss_G': 'Stage2/G/train/loss',
        'loss_D': 'Stage2/D/train/loss',
        'auroc': 'Stage2/ResNet(F)/val/auroc',
        'auroc_cross': 'Stage2/ResNet(J)/val/auroc',
    }

    trainer_kwargs = dict(
        progress_bar_refresh_rate=refresh_rate,
        max_epochs=num_epochs,
        prepare_data_per_node=False,
        gpus=gpu_count,
        logger=tb_logger,
        callbacks=[TuneReportCheckpointCallback(reported_metrics)],
    )

    dm = npyImageData(config)
    if checkpoint_dir is not None:
        trainer_kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = Stage2(config)
    pl.Trainer(**trainer_kwargs).fit(model, dm)
# __tune_train_checkpoint_end__


# # # __tune_asha_begin__
# def tune_Stage2_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
#     # print(os.cpu_count(), torch.cuda.device_count())
#     analysis = tune.run(
#         tune.with_parameters(
#             train_Stage2,
#             num_epochs=num_epochs,
#             num_gpus=gpus_per_trial
#         ),
#         # Change the folder name when changing dataset--------------------------------------------------------------------------
#         name='Stage2/pbt/J',
#         metric='auroc',
#         mode='max',
#         config={'learning_rate': 1e-4,
#                 'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
#                 'batch_size': 8,
#                 },
#         # config={'learning_rate': 0.01,
#         #         'n_fmaps': 32,
#         #         'batch_size': 32,
#         #         },
#         # stop=TrialPlateauStopper('loss_G'),
#         resources_per_trial={'cpu': os.cpu_count(),
#                              'gpu': gpus_per_trial,
#                             },
#         local_dir='./drive/MyDrive/Logs',
#         scheduler = ASHAScheduler(max_t=num_epochs, grace_period=2,  reduction_factor=2),
#         progress_reporter=JupyterNotebookReporter(
#             overwrite=True,
#             parameter_columns=['learning_rate', 'n_fmaps', 'batch_size'],
#             metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
#             sort_by_metric=True,
#         ),
#         fail_fast = True,
#         # reuse_actors=True,
#         # num_samples=num_samples,
#         resume='PROMPT',
# #         restore='/content/drive/MyDrive/Logs/delete/train_Stage2_e42ac_00025_25_batch_size=8,learning_rate=0.01,n_fmaps=8_2021-07-28_21-16-18/checkpoint_epoch=4-step=2339',
#     )

# #     print('Best hyperparameters found were: ', analysis.best_config)

# # # __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage2_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Run the Stage-2 hyperparameter search with Population Based Training.

    Builds the trainable, search space, scheduler and progress reporter, hands
    them to ``tune.run`` and prints the best configuration of the resulting
    analysis object.

    Args:
        num_samples: Accepted for interface compatibility; it is currently NOT
            forwarded to ``tune.run`` (that keyword is intentionally disabled).
        num_epochs: Forwarded to ``train_Stage2`` as ``num_epochs``.
        gpus_per_trial: Forwarded to ``train_Stage2`` as ``num_gpus`` and used
            as the per-trial GPU request.

    Returns:
        None. The best configuration is printed, not returned.
    """
    # Bind the fixed (non-searched) arguments to the training function.
    trainable = tune.with_parameters(
        train_Stage2,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    # Initial configuration: 'n_fmaps' is grid-searched, 'res_depth' is sampled,
    # while 'learning_rate' and 'batch_size' start fixed and are mutated by the
    # scheduler below.
    search_space = {
        'learning_rate': 1e-4,
        'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
        'res_depth': tune.choice([1, 2, 3, 4]),
        'batch_size': 8,
    }

    # Every trial requests all CPUs reported by the OS plus the given GPU count.
    trial_resources = {
        'cpu': os.cpu_count(),
        'gpu': gpus_per_trial,
    }

    pbt_scheduler = PopulationBasedTraining(
        time_attr='training_iteration',
        quantile_fraction=0.5,
        resample_probability=0.8,
        perturbation_interval=1,
        hyperparam_mutations={
            'learning_rate': tune.loguniform(1e-7, 1e-1),
            'batch_size': [8, 16, 32, 64, 128],
        },
    )

    notebook_reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['learning_rate', 'n_fmaps', 'res_depth', 'batch_size'],
        metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
        sort_by_metric=True,
    )

    # Deliberately disabled options, kept here for reference:
    #   stop=TrialPlateauStopper('loss_G'), reuse_actors=True, num_samples=num_samples
    result = tune.run(
        trainable,
        # Change the experiment folder name below whenever the dataset changes.
        name='Stage2/pbt/F',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial=trial_resources,
        local_dir='./drive/MyDrive/Logs',
        scheduler=pbt_scheduler,
        progress_reporter=notebook_reporter,
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best hyperparameters found were: ', result.best_config)

# __tune_pbt_end__


# Script entry point: parse the optional --smoke-test flag and start tuning.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args tolerates the extra argv entries injected by notebook kernels.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # The ASHA tuner definition is currently commented out in this file, so
        # calling it unconditionally raises NameError and the PBT smoke run is
        # never reached. Only invoke it when it has actually been defined.
        if 'tune_Stage2_asha' in globals():
            tune_Stage2_asha(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
        tune_Stage2_pbt(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
    else:
        # ASHA scheduler
        # tune_Stage2_asha(num_samples=1, num_epochs=10, gpus_per_trial=torch.cuda.device_count())
        # Population based training
        tune_Stage2_pbt(num_samples=1, num_epochs=30, gpus_per_trial=torch.cuda.device_count())