<a href="https://colab.research.google.com/github/souravraha/galaxy/blob/experimental/Lightning_Tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prerequisites / shell commands

## Install/uninstall packages

In [None]:
# Install the extra dependencies needed on Google Colab before running the rest of
# the notebook: lightning-bolts and GPy (GPy is required by the PB2 scheduler imported
# below), plus Ray with its debug/default extras and an upgraded Ray Tune.

print('Setting up colab environment')
# !pip uninstall -y -q pyarrow
!pip install -q lightning-bolts GPy
!pip install -q ray[debug] ray[default]
!pip install -U -q ray[tune]
# Optional: torch_xla wheel for Colab TPU runtimes (cp37 / torch 1.8 build).
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

# # A hack to force the runtime to restart, needed to include the above dependencies.
# print('Done installing! Restarting via forced crash (this is not an issue).')
# import os
# os._exit(0)

## Download and extract data

Choose here the model (dataset) you wish to use for training/testing. When you switch datasets, don't forget to update the following sections:

1.   The normalization constants (mean/std) in the class definition of `npyImageData`.
2.   The assignment of metric keys in the Tune training wrapper.
3.   The name of the experiment being started/resumed.

In [None]:
# Google-Drive file IDs of the dataset archives, keyed by model letter
# (no ID is recorded for 'd' and 'k'):
# 'a': 1Cjcw2EWorhdhJSGoWOdxsEUDxvl943dt, 'b': 15yXXC4h5VsytP3Ak1jfUSjQhdgP2s23K, 'c': 1vuQ-pLzoKT4Hd_V7949r9eND9E2fB_u_,
# 'd': , 'e': 1wFuasvb7PthxXtMUlsD13uzYHWlWt06H, 'f': 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ, 
# 'g': 1SxQVosWeEjY3Pyn8LRXA11rLnZ9HK_7B, 'h': 1Atau0RH4oyLAiYReW-G9a8l9pUNltglF, 'i': 15lEgsR1p00KSHieaT9a1nkbJ86pDxwgp, 
# 'j': 1m0EQUbqZZeyl76XsQIKWU5Qd7jGmmWhB, 'k': , 'l': 1meTDi4aeWfdChOiXeLtUOGhjVDVu000e

# Generated ("fake") data
# 'fpgan': 1-4o0yqSBA9WSY9gTYamIez_RAekwDsHV

# !rm -rf fakes/
# Stream the archive for model 'f' from Google Drive straight into tar (no .tgz is
# written to disk); --skip-old-files leaves already-extracted files untouched.
!gdown --id 17l6H61tLAu26zGuei38r_T5ssjbYUeaJ -O - --quiet | tar --skip-old-files -zxf -
# !rm ./model_f.tgz

# Pure-Python alternative to the shell line above (currently unused):
# def prepare_data(data_dir: str = '/content'):
#     with FileLock(os.path.expanduser(data_dir+'.lock')):
#         gdown.download('https://drive.google.com/uc?id=17l6H61tLAu26zGuei38r_T5ssjbYUeaJ', data_dir+'/model_j.tgz', quiet=True)
        
#         temp = tarfile.open(data_dir+'/model_j.tgz', 'r|gz')
#         temp.extractall()
#         temp.close()

In [None]:
# Mount Google Drive at /content/drive: later cells import callbacks from
# drive.MyDrive.ml and read/write checkpoints and logs under drive/MyDrive/Logs.
from google.colab import drive
drive.mount('/content/drive')

# Import libraries

In [1]:
import os
import math
import numpy as np
from pathlib import Path
from itertools import cycle
%matplotlib inline
from matplotlib import pyplot as plt
# ------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
# ------------------------------------
from torchvision.models import resnet18
from torchvision.utils import save_image, make_grid
from torchvision import transforms, datasets
# ------------------------------------
import torchmetrics as tm
# ------------------------------------
import pytorch_lightning as pl
from pl_bolts.models.gans import DCGAN
from pl_bolts.callbacks import ModuleDataMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.self_supervised.resnets import BasicBlock
from pytorch_lightning.utilities.cloud_io import load as pl_load
from drive.MyDrive.ml.Callbacks.confused_logits import ConfusedLogitCallback
# ------------------------------------
from ray import tune
from ray.tune.stopper import TrialPlateauStopper
from ray.tune import CLIReporter, JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.schedulers.pb2 import PB2
from ray.tune.integration.pytorch_lightning import TuneReportCallback, TuneReportCheckpointCallback
# ------------------------------------

# Class definitions

## DataModule
This creates the dataloaders that are supplied when training, validating or testing our model.

In [3]:
class npyImageData(pl.LightningDataModule):
    """LightningDataModule serving single-channel images stored as ``.npy`` files.

    Expects the ``torchvision.datasets.DatasetFolder`` layout
    ``<data_dir>/train/<class>/*.npy`` and ``<data_dir>/val/<class>/*.npy``.
    The ``val`` split is used for validation and also for testing.

    Args:
        config: hyper-parameter dict; only ``config['bs']`` (batch size) is read.
        img_width: size passed to ``transforms.Resize`` (nearest-neighbour).
        data_dir: root directory holding the ``train`` and ``val`` folders.
    """

    def __init__(self, config, img_width: int = 150, data_dir: str = '/content/images/'):
        super().__init__()
        # save_hyperparameters() is not used here (noted as not implemented for DataModules).
        self.bs = config['bs']
        self.data_dir = os.path.expanduser(data_dir)

        # Update the normalization constants below when changing dataset ----------------------------------------------
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # F : [mean=71.75926373866668, std=96.139484964214, min=5.0, range=961.0]
            # J : [mean=50.271541595458984, std=94.8838882446289, min=0, range=1007.0]
            transforms.Normalize(mean=(0,), std=(961,)),
            # Shift-scale [0, 1] -> [-1, 1].
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
            transforms.Resize(img_width, transforms.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def npy_loader(path):
        """Load an ``.npy`` array as a float32 tensor with a new leading channel axis.

        Assumes the stored array is 2-D (H, W) — TODO confirm for every dataset.
        """
        # Convert to a tensor first and then to float32; a float64 tensor would make
        # the conv layers raise a dtype error.
        return torch.from_numpy(np.load(path)).unsqueeze(0).float()

    def _make_dataset(self, split):
        """Build a DatasetFolder over ``<data_dir>/<split>`` reading ``.npy`` files."""
        return datasets.DatasetFolder(
            os.path.join(self.data_dir, split), self.npy_loader, ('.npy',), self.transform,
        )

    def setup(self, stage: str = None):
        """Create the datasets needed for ``stage`` ('fit', 'validate', 'test' or None = all)."""
        # 'validate' is included so trainer.validate() also finds `val_set`.
        if stage in ('fit', 'validate', None):
            self.train_set = self._make_dataset('train')
            # self.train_set, self.val_set = random_split(self.full_set, [60000, 15000])
            self.val_set = self._make_dataset('val')
            self.dims = tuple(self.train_set[0][0].shape)

        if stage in ('test', None):
            self.test_set = self._make_dataset('val')
            # The base class may already define `dims` as an empty default, so test for
            # emptiness rather than existence; only then load a sample to read its shape.
            if not getattr(self, 'dims', None):
                self.dims = tuple(self.test_set[0][0].shape)

    def _loader(self, dataset):
        """Wrap ``dataset`` in a shuffled, pinned-memory DataLoader using all CPU cores."""
        # os.cpu_count() can return None; fall back to loading in the main process.
        # Validation/test loaders are shuffled as well (kept as-is: LensGAN128 validates on a
        # partial pass and checks that every class appears in it).
        return DataLoader(dataset, self.bs, shuffle=True, num_workers=os.cpu_count() or 0, pin_memory=True)

    def train_dataloader(self):
        return self._loader(self.train_set)

    def val_dataloader(self):
        return self._loader(self.val_set)

    def test_dataloader(self):
        return self._loader(self.test_set)

In [None]:
class ImageData(pl.LightningDataModule):
    """LightningDataModule serving ordinary image files through ``ImageFolder``.

    Expects ``<data_dir>/train/<class>/`` and ``<data_dir>/val/<class>/``; the ``val``
    split is also used as the prediction set. Images are converted to grayscale tensors.

    Args:
        config: hyper-parameter dict; only ``config['bs']`` (batch size) is read.
        img_width: size passed to ``transforms.Resize`` (nearest-neighbour).
        data_dir: root directory holding the ``train`` and ``val`` folders.
    """

    def __init__(self, config, img_width: int = 150, data_dir: str = '/content/fakes'):
        super().__init__()
        # save_hyperparameters() is not used here (noted as not implemented for DataModules).
        self.bs = config['bs']
        self.data_dir = os.path.expanduser(data_dir)

        # Update the normalization constants below when changing dataset ----------------------------------------------
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Grayscale(),
            transforms.ToTensor(),
            # F : [mean=71.75926373866668, std=96.139484964214, min=5.0, range=0.88235295]
            # J : [mean=50.271541595458984, std=94.8838882446289, min=0, range=1007.0]
            transforms.Normalize(mean=(0,), std=(0.88235295,)),
            # Shift-scale [0, 1] -> [-1, 1].
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
            transforms.Resize(img_width, transforms.InterpolationMode.NEAREST),
        ])

    def setup(self, stage: str = None):
        """Create the datasets needed for ``stage`` ('fit', 'validate', 'predict' or None = all)."""
        # 'validate' is included so trainer.validate() also finds `val_set`.
        if stage in ('fit', 'validate', None):
            self.train_set = datasets.ImageFolder(os.path.join(self.data_dir, 'train'), self.transform)
            # self.train_set, self.val_set = random_split(self.full_set, [60000, 15000])
            self.val_set = datasets.ImageFolder(os.path.join(self.data_dir, 'val'), self.transform)
            self.dims = tuple(self.train_set[0][0].shape)

        if stage in ('predict', None):
            self.predict_set = datasets.ImageFolder(os.path.join(self.data_dir, 'val'), self.transform)
            # The base class may already define `dims` as an empty default, so test for
            # emptiness rather than existence; only then load a sample to read its shape.
            if not getattr(self, 'dims', None):
                self.dims = tuple(self.predict_set[0][0].shape)

    def _loader(self, dataset):
        """Wrap ``dataset`` in a shuffled DataLoader (no pinned memory) using all CPU cores."""
        # os.cpu_count() can return None; fall back to loading in the main process.
        return DataLoader(dataset, self.bs, shuffle=True, num_workers=os.cpu_count() or 0, pin_memory=False)

    def train_dataloader(self):
        return self._loader(self.train_set)

    def val_dataloader(self):
        return self._loader(self.val_set)

    def predict_dataloader(self):
        return self._loader(self.predict_set)

In [None]:
# Manual check: build the npyImageData DataModule (it only reads 'bs' from the config).
dm = npyImageData({'lr': 0.001, 'bs': 8})
# Disabled: reload a tuned LensResnet and run prediction with it.
# NOTE(review): npyImageData defines no predict_dataloader — confirm before re-enabling.
# model = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'), config={'lr': 0.001, 'bs': 8})
# trainer = pl.Trainer(gpus=1)
# trainer.predict(model, dm)

## ResNet
We modify a ResNet slightly for our purpose.

In [None]:
# Smoke test: construct a LensResnet. Requires the LensResnet definition cell (below)
# to have run; construction loads the PRE_F_RESNET backbone from Drive.
m = LensResnet({'lr': 1e-3})

In [4]:
PRE_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/PRETRAINED.pth'
PRE_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/PRETRAINED.pth'

class LensResnet(pl.LightningModule):
    """Multi-class lens classifier built around a pickled, pre-trained ResNet backbone.

    Args:
        config: hyper-parameter dict; ``config['lr']`` is the Adam learning rate.
        image_channels: kept for interface compatibility; not used by this class.
        num_classes: number of classes (sizes the AUROC/ROC metrics and plots).
        **kwargs: accepted and ignored.
    """

    def __init__(self, config, image_channels: int = 1, num_classes: int = 3, **kwargs):
        super().__init__()
        # NOTE(review): `ignore` expects argument *names*; iterating `config` yields its keys,
        # so `config` itself still appears to be saved. The argument-less
        # LensResnet.load_from_checkpoint(...) calls in this notebook depend on that.
        self.save_hyperparameters(ignore=config)
        self.lr = config['lr']

        # --------------------------------------------------------------------------
        # Whole pickled nn.Module. NOTE(review): always the F weights, even when tuning on J.
        self.backbone = torch.load(PRE_F_RESNET, map_location=self.device)

        self.train_metrics = tm.MetricCollection(
            [tm.AUROC(self.hparams.num_classes, average='weighted')],
            prefix='LensResnet/train/',
        )
        self.val_metrics = tm.MetricCollection(
            [tm.AUROC(self.hparams.num_classes, average=None), tm.ROC(self.hparams.num_classes)],
        )

    def configure_optimizers(self):
        # Only the backbone's parameters are optimized.
        return torch.optim.Adam(self.backbone.parameters(), self.lr)

    def forward(self, x, prob=False):
        """Return backbone logits, or softmax probabilities over dim 1 when ``prob`` is true."""
        logits = self.backbone(x)
        return logits.softmax(1) if prob else logits

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        # Kept on the module; presumably read by ConfusedLogitCallback (used in train_LensResnet).
        self.last_logits = self(imgs)
        loss = F.cross_entropy(self.last_logits, labels)
        # Keep only scalars here, for no errors.
        self.log('LensResnet/train/loss', loss)

        self.train_metrics.update(self.last_logits.softmax(1), labels)
        try:
            # Running AUROC over the epoch so far; may raise, e.g. while a class is still unseen.
            self.log_dict(self.train_metrics.compute(), prog_bar=True)
        except Exception as err:  # best-effort logging: never abort the training step
            print(err)
        # Returned outside any `finally` so real errors (and KeyboardInterrupt) propagate.
        return loss

    def on_train_epoch_start(self):
        # Logging *computed* values does not reset metric state, so clear the
        # accumulated training predictions at the start of every epoch.
        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        # Keep only scalars here, for no errors.
        self.log('LensResnet/val/loss', loss)

        self.val_metrics.update(logits.softmax(1), labels)

    def validation_epoch_end(self, Listofdicts):
        """Log the worst per-class AUROC and a one-vs-all ROC figure, then reset the metrics."""
        try:
            scoresDict = self.val_metrics.compute()
        except Exception as err:
            # E.g. a class missing from this epoch's targets: report which labels were seen.
            print(err)
            seen = self.val_metrics.AUROC.target
            if len(seen) > 0:
                # torch.cat, not torch.stack: batches can differ in size (e.g. a smaller last batch).
                print(torch.unique(torch.cat(seen)).tolist())
        else:
            self.log('LensResnet/val/auroc', scoresDict['AUROC'].min(), prog_bar=True)
            fprList, tprList, _ = scoresDict['ROC']

            fig = plt.figure()
            colors = cycle(['red', 'blue', 'green'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                plt.plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                         label='ROC curve of class {0} (area = {1:0.2f})'
                         ''.format(i, scoresDict['AUROC'][i]))
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('One vs. all ROC curve')
            plt.legend(loc='lower right')

            # Tag the figure with the step so TensorBoard keeps one image per validation run.
            self.logger.experiment.add_figure('LensResnet/val/ROC', fig, global_step=self.global_step)
            fig.savefig(str(self.trainer.log_dir) + '/ROC_step_{:05d}.pdf'.format(self.global_step))
        finally:
            self.val_metrics.reset()

## LensGAN128
Here we subclass a DCGAN to create our low resolution GAN.

In [7]:
# Smoke test: construct a LensGAN128. Requires the LensGAN128 definition cell (below)
# to have run; construction loads both tuned LensResnet checkpoints from Drive.
m = LensGAN128({'lr':0.001, 'n_fmaps': 128, 'bs': 4})



In [None]:
# Inspect the modified (label-conditioned, extended) generator architecture.
m.generator

In [12]:
BEST_F_RESNET = '/content/drive/MyDrive/Logs/F/LensResnet/pbt_tanh/train_LensResnet_eb619_00000_0_2021-09-02_19-42-34/checkpoint_epoch=2-step=28124'
BEST_J_RESNET = '/content/drive/MyDrive/Logs/J/LensResnet/pbt_tanh_fine/train_LensResnet_93609_00000_0_2021-09-02_21-06-00/checkpoint_epoch=2-step=28124'

def post_plotting(ax):
    """Finish an ROC axes: draw the dashed chance diagonal, then a bottom-right legend."""
    diagonal = [0, 1]
    ax.plot(diagonal, diagonal, 'k--')
    ax.legend(loc='lower right')

class LensGAN128(DCGAN):
    """Class-conditional DCGAN for low-resolution (presumably 128x128) lens images.

    Adds one block to each end of pl_bolts' DCGAN and conditions the generator on a
    one-hot label concatenated to the noise vector. Two frozen LensResnets (tuned on
    datasets F and J) supply FID features and classifier-based ROC/AUROC diagnostics.

    Args:
        config: hyper-parameter dict; reads ``n_fmaps`` (feature maps of both networks)
            and ``lr`` (learning rate).
        num_classes: number of class labels used for the one-hot conditioning.
        **kwargs: forwarded to ``DCGAN.__init__``.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], learning_rate=config['lr'], **kwargs)
        # NOTE(review): `ignore` expects argument names; passing the dict appears to ignore nothing.
        self.save_hyperparameters(ignore=config)

        # self.generator.add_module('emb', nn.Embedding(self.hparams.num_classes, self.hparams.latent_dim))

        # Extend the generator: replace its first block with an ngf*16 -> ngf*8 block and
        # prepend a block mapping (one-hot label + noise) channels to ngf*16.
        self.generator.gen[0] = self.generator._make_gen_block(self.hparams.feature_maps_gen * 16, self.hparams.feature_maps_gen * 8)
        first = self.generator._make_gen_block(self.hparams.num_classes + self.hparams.latent_dim, self.hparams.feature_maps_gen * 16, kernel_size=4, stride=1, padding=0)
        self.generator.gen = nn.Sequential(first, *list(self.generator.gen))

        # Extend the discriminator symmetrically: replace its last block with ndf*8 -> ndf*16
        # and append an ndf*16 -> 1 block.
        self.discriminator.disc[-1] = self.discriminator._make_disc_block(self.hparams.feature_maps_disc * 8, self.hparams.feature_maps_disc * 16)
        last = self.discriminator._make_disc_block(self.hparams.feature_maps_disc * 16, 1, kernel_size=4, stride=1, padding=0, last_block=True)
        # Drop the block's final activation: BCEWithLogitsLoss (below) expects raw logits.
        del last[-1]
        self.discriminator.disc = nn.Sequential(*list(self.discriminator.disc), last)

        # For the first five blocks: give the generator block the discriminator block's final
        # activation module, and append dropout to both networks.
        for i in range(5):
            self.generator.gen[i][-1] = self.discriminator.disc[i][-1]
            self.generator.gen[i] = nn.Sequential(*list(self.generator.gen[i]), nn.Dropout())
            self.discriminator.disc[i] = nn.Sequential(*list(self.discriminator.disc[i]), nn.Dropout())

        self.criterion = nn.BCEWithLogitsLoss()

        # Frozen classifiers tuned on F and J. Their backbones (with `fc` swapped for Identity)
        # are the FID feature extractors; the detached `fc` heads turn features into logits.
        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelF = temp.backbone
        # self.modelF = torch.load(PRE_F_RESNET, map_location=self.device).eval())
        self.lastF = self.modelF.fc
        self.modelF.fc = nn.Identity()

        temp = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint'))
        temp.freeze()
        self.modelJ = temp.backbone
        # self.modelJ = torch.load(PRE_J_RESNET, map_location=self.device).eval())
        self.lastJ = self.modelJ.fc
        self.modelJ.fc = nn.Identity()

        self.imgMetrics = tm.MetricCollection(
            {
                'FID_F' : tm.FID(self.modelF),
                'FID_J' : tm.FID(self.modelJ),
            },
            prefix='LensGAN128/val/',
        )

    def forward(self, noise, labels):
        """Generate images for ``labels`` from ``noise`` (one-hot label prepended along dim 1)."""
        # Cast the integer one-hot explicitly to the noise dtype before concatenating.
        one_hot = F.one_hot(labels, self.hparams.num_classes).to(noise.dtype)
        return super().forward(torch.cat((one_hot, noise), 1))

    def training_step(self, batch, batch_idx, optimizer_idx):
        real, labels = batch
        fake = self._get_fake_pred(labels).type_as(real)

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real, fake.detach())

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(fake, labels)

        return result

    def _disc_step(self, real, fake):
        disc_loss = self._get_disc_loss(real, fake)
        self.log('LensGAN128/D/train/loss', disc_loss)
        return disc_loss

    def _gen_step(self, fake, labels):
        gen_loss = self._get_gen_loss(fake, labels)
        self.log('LensGAN128/G/train/loss', gen_loss)
        return gen_loss

    def _get_disc_loss(self, real, fake):
        """Standard GAN discriminator loss: real -> 1, fake -> 0 (BCE on logits)."""
        real_pred = self.discriminator(real)
        real_loss = self.criterion(real_pred, torch.ones_like(real_pred))

        fake_pred = self.discriminator(fake)
        fake_loss = self.criterion(fake_pred, torch.zeros_like(fake_pred))

        return real_loss + fake_loss

    def _get_gen_loss(self, fake, labels):
        """Non-saturating generator loss: push the discriminator to call fakes real."""
        fake_pred = self.discriminator(fake)
        gen_loss = self.criterion(fake_pred, torch.ones_like(fake_pred))

        # -----------------------------------------------------------------------------
        # class_pred = self.lastF(self.modelF(fake))
        # gen_loss += np.random.rand() * nn.CrossEntropyLoss()(class_pred, labels)

        return gen_loss

    def _get_fake_pred(self, labels):
        """Generate one fake image per entry of ``labels``."""
        noise = self._get_noise(len(labels), self.hparams.latent_dim)
        return self(noise, labels)

    def validation_step(self, batch, batch_idx):
        imgs64, labels = batch
        # Compare at width 150 (npyImageData's default width, used when training LensResnet).
        self.imgMetrics.update(F.interpolate(imgs64, 150), real=True)

        fake = F.interpolate(self._get_fake_pred(labels), 150).type_as(imgs64)
        self.imgMetrics.update(fake, real=False)
        # np.savez_compressed(self.trainer.log_dir + '/samples_{:05d}'.format(self.global_step + batch_idx),
        #                     x = fake.cpu(), y = labels.cpu())
        return {
            'predF': self.lastF(self.modelF(fake)).softmax(1),
            'predJ': self.lastJ(self.modelJ(fake)).softmax(1),
            'target': labels
        }

    def validation_epoch_end(self, ListofDicts):
        """Log FID and classifier AUROC of the fakes, plot ROC curves, and save a sample grid."""
        fid = self.imgMetrics.compute()
        # Logging computed values does not reset metric state: clear the accumulated real/fake
        # features so each epoch's FID reflects only that epoch's samples.
        self.imgMetrics.reset()
        self.log_dict(fid)

        target = torch.cat([x['target'] for x in ListofDicts])
        if len(torch.unique(target)) == self.hparams.num_classes:
            fig, ax = plt.subplots(1,2,
                subplot_kw={'xlim': [0,1], 'xlabel': 'False Positive Rate', 'ylim': [0,1.05],
                            'ylabel': 'True Positive Rate',
                },
                figsize=[11, 5],
            )
            for j, src in enumerate(['F', 'J']):
                prediction = torch.cat([x['pred' + src] for x in ListofDicts])
                aurocTensor = tm.functional.auroc(prediction, target, num_classes=self.hparams.num_classes, average=None)
                self.log('LensGAN128/LensResnet(' + src + ')/val/auroc', aurocTensor.min())
                fprList, tprList, _ = tm.functional.roc(prediction, target, num_classes=self.hparams.num_classes)

                colors = cycle(['red', 'blue', 'green'])
                for i, color in zip(range(self.hparams.num_classes), colors):
                    ax[j].plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                            label='ROC curve of class {0} (area = {1:0.2f})'
                            ''.format(i, aurocTensor[i]))
                post_plotting(ax[j])
                ax[j].set_title('One vs. all ROC curve (FID : {:0.2f})'.format(fid['LensGAN128/val/FID_'+ src]))

            fig.tight_layout()
            # Tag the figure with the step so TensorBoard keeps one image per validation run.
            self.logger.experiment.add_figure('LensGAN128/LensResnet/val/ROC', fig, global_step=self.global_step)
            fig.savefig(str(self.trainer.log_dir) + '/ROC_step_{:05d}.pdf'.format(self.global_step))
        else:
            print('Only these labels were encountered : ', torch.unique(target).tolist())

        self._save_sample_grid()

    def _save_sample_grid(self):
        """Save a grid with one fake per class (first row) and one random real per class (second row)."""
        labels = torch.arange(self.hparams.num_classes, device=self.device)
        imgs = self._get_fake_pred(labels)

        val_set = self.trainer.datamodule.val_set
        # DatasetFolder records each sample's class index in `targets`, so a real example of
        # every class is drawn directly instead of via a random retry loop (which never
        # terminated when a class was absent from the validation set).
        targets = np.asarray(val_set.targets)
        for label in range(self.hparams.num_classes):
            candidates = np.flatnonzero(targets == label)
            if candidates.size == 0:
                print('No validation sample found for label', label)
                continue
            x, _ = val_set[int(np.random.choice(candidates))]
            imgs = torch.cat((imgs, x.unsqueeze(0).type_as(imgs)))

        save_image(F.interpolate(imgs, 150),
                str(self.trainer.log_dir) + '/Fake_step_{:05d}.pdf'.format(self.global_step),
                #  kwargs for make_grid
                nrow=self.hparams.num_classes, normalize=True, value_range=(-1,1))

## Stage 2
Here we subclass a DCGAN to create our high resolution GAN.

In [None]:
class Generator2(nn.Module):
    """Second-stage (upscaling) generator: takes an image, returns a larger image in [-1, 1].

    Pipeline: strided conv (halves the spatial size) -> ``res_depth`` residual
    ``BasicBlock``s -> conv/BN/ReLU -> long skip connection back to the strided-conv
    output -> nearest-neighbour upsampling to ``image_width // 2`` and then to
    ``image_width``, each followed by a 3x3 conv -> tanh.

    Args:
        ngf: number of feature maps in the residual trunk.
        image_channels: channels of both the input and the output image.
        res_depth: number of residual blocks.
        image_width: spatial size of the output image (default 150).
    """

    def __init__(self, ngf: int = 128, image_channels: int = 1, res_depth: int = 6, image_width: int = 150):
        super().__init__()

        ker, strd = 4, 2
        pad = (ker - 2) // 2
        res_ker, res_strd, res_pad = 3, 1, 1

        # Strided conv halves the spatial size (e.g. 64 -> 32).
        self.preprocessing = nn.Sequential(
            nn.Conv2d(image_channels, ngf, ker, strd, pad, bias=False),
            nn.ReLU(True)
        )
        # Residual trunk.
        self.residual = nn.Sequential(*[BasicBlock(ngf, ngf) for _ in range(res_depth)])

        self.ending_residual = nn.Sequential(
            nn.Conv2d(ngf, ngf, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True)
        )

        # Upscale in two nearest-neighbour stages: -> image_width // 2 -> image_width.
        mode = 'nearest'
        self.main = nn.Sequential(
            nn.Upsample(image_width // 2, mode=mode),
            nn.Conv2d(ngf, ngf * 4, res_ker, res_strd, res_pad, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.Upsample(image_width, mode=mode),
            nn.Conv2d(ngf * 4, image_channels, res_ker, res_strd, res_pad, bias=False),
            nn.Tanh()
        )

    def forward(self, in_x):
        """Upscale ``in_x`` to ``image_width`` x ``image_width``."""
        x_p = self.preprocessing(in_x)
        x_r = self.ending_residual(self.residual(x_p))
        # Long skip connection around the whole residual trunk.
        return self.main(x_r + x_p)

In [None]:
# Checkpoint directories of the selected LensGAN128 Tune trials for datasets F and J.
BEST_F_LensGAN128 = '/content/drive/MyDrive/Logs/F/LensGAN128/pbt_tanh_1/train_LensGAN128_90727_00001_1_n_fmaps=16_2021-08-30_08-02-05/checkpoint_epoch=4-step=1988/'
BEST_J_LensGAN128 ='/content/drive/MyDrive/Logs/J/LensGAN128/pbt_tanh/train_LensGAN128_28c03_00003_3_n_fmaps=64_2021-08-31_20-45-44/checkpoint_epoch=2-step=1403/'

In [None]:
class Stage2(DCGAN):
    """Second-stage GAN that refines LensGAN128 samples with the ``Generator2`` upscaler.

    The generator's "noise" is itself an image: the output of a frozen LensGAN128
    (``self.lowres``) conditioned on the current labels. Frozen F/J LensResnets score
    the generated images, and their ROC/AUROC are logged every validation epoch.

    Args:
        config: reads ``n_fmaps``, ``res_depth`` and the learning rate, taken from
            ``learning_rate`` or, if absent, ``lr`` (the key the other modules use).
        num_classes: number of class labels.
        **kwargs: accepted but not forwarded.
    """

    def __init__(self, config, num_classes: int = 3, **kwargs):
        learning_rate = config['learning_rate'] if 'learning_rate' in config else config['lr']
        super().__init__(feature_maps_gen=config['n_fmaps'], feature_maps_disc=config['n_fmaps'], learning_rate=learning_rate)
        # NOTE(review): `ignore` expects argument names; passing the dict appears to ignore nothing.
        self.save_hyperparameters(ignore=config)

        self.generator = Generator2(self.hparams.feature_maps_gen, self.hparams.image_channels, config['res_depth'])

        # Frozen helper networks (freeze(): eval mode + no gradients). These are the
        # checkpoint constants defined earlier in the notebook.
        self.modelF = LensResnet.load_from_checkpoint(os.path.join(BEST_F_RESNET, 'checkpoint'))
        self.modelF.freeze()
        self.modelJ = LensResnet.load_from_checkpoint(os.path.join(BEST_J_RESNET, 'checkpoint'))
        self.modelJ.freeze()
        self.lowres = LensGAN128.load_from_checkpoint(os.path.join(BEST_F_LensGAN128, 'checkpoint'))
        self.lowres.freeze()

        metrics = tm.MetricCollection(
            [
             tm.AUROC(num_classes=self.hparams.num_classes, compute_on_step=False, average=None),
             tm.ROC(num_classes=self.hparams.num_classes, compute_on_step=False),
            ]
        )
        self.metricsF = metrics.clone()
        self.metricsJ = metrics.clone()

    def forward(self, noise):
        return self.generator(noise)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # Labels are cached so that _get_noise can condition the low-res generator on them.
        real, self.labels = batch

        # Train discriminator
        result = None
        if optimizer_idx == 0:
            result = self._disc_step(real)

        # Train generator
        if optimizer_idx == 1:
            result = self._gen_step(real)

        return result

    def _disc_step(self, real):
        disc_loss = self._get_disc_loss(real)
        self.log('Stage2/D/train/loss', disc_loss, on_epoch=True)
        return disc_loss

    def _gen_step(self, real):
        gen_loss = self._get_gen_loss(real)
        self.log('Stage2/G/train/loss', gen_loss, on_epoch=True)
        return gen_loss

    def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor:
        """Generator loss: push the discriminator to label generated images as real."""
        fake_pred = self._get_fake_pred(real)
        fake_gt = torch.ones_like(fake_pred)
        gen_loss = self.criterion(fake_pred, fake_gt)

        # class_pred =  self._get_class_pred(len(real))
        # gen_loss += F.cross_entropy(class_pred, self.labels)

        return gen_loss

    def _get_class_pred(self, batch_size) -> torch.Tensor:
        # ----------------------------------------------------------------------------------------------------------------
        return self.modelF.backbone(self(self._get_noise(batch_size, self.hparams.latent_dim)))

    def _get_noise(self, n_samples: int, latent_dim: int, labels = None):
        """Return low-resolution LensGAN128 images for ``labels`` (default: the cached batch labels)."""
        # A default of `self.labels` is impossible in the signature, so resolve it here.
        if labels is None:
            labels = self.labels
        return self.lowres(super()._get_noise(n_samples, latent_dim), labels)

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        out = self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels))
        # Feed class probabilities (not logits) to AUROC/ROC, as the other modules do.
        self.metricsF.update(self.modelF(out, prob=True), labels)
        self.metricsJ.update(self.modelJ(out, prob=True), labels)

    def validation_epoch_end(self, listofDicts):
        """Log per-classifier worst-class AUROC, plot ROC curves, and save one fake per class."""
        fig, ax = plt.subplots(1,2,
            subplot_kw={'xlim': [0,1], 'xlabel': 'False Positive Rate', 'ylim': [0,1.05],
                        'ylabel': 'True Positive Rate',
            },
            figsize=[11, 5],
        )
        for j, letter in enumerate(['F', 'J']):
            metrics = getattr(self, 'metrics' + letter)
            output = metrics.compute()
            # Computed values do not reset the metric; clear it so epochs do not accumulate.
            metrics.reset()
            self.log('Stage2/ResNet(' + letter + ')/val/auroc', output['AUROC'].min())
            fprList, tprList, _ = output['ROC']

            colors = cycle(['red', 'blue', 'green'])
            for i, color in zip(range(self.hparams.num_classes), colors):
                ax[j].plot(fprList[i].cpu(), tprList[i].cpu(), color=color,
                        label='ROC curve of class {0} (area = {1:0.2f})'
                        ''.format(i, output['AUROC'][i]))
            post_plotting(ax[j])
            ax[j].set_title('One vs. all ROC curve (' + letter + ')')

        fig.tight_layout()
        # Tag the figure with the step so TensorBoard keeps one image per validation run.
        self.logger.experiment.add_figure('Stage2/ResNet/val/ROC', fig, global_step=self.global_step)
        fig.savefig(str(self.trainer.log_dir) + '/ROC_step_{:05d}.pdf'.format(self.global_step))

        labels = torch.arange(self.hparams.num_classes, device=self.device)
        save_image(self(self._get_noise(labels.shape[0], self.hparams.latent_dim, labels)),
                   str(self.trainer.log_dir) + '/Fake_step_{:05d}.pdf'.format(self.global_step),
                  #  kwargs for make_grid
                   normalize=True, value_range=(-1,1))

    def on_fit_end(self):
        # Drop the frozen helpers and the cached labels once fitting ends. Guarded because
        # `labels` only exists if at least one training step ran.
        for name in ('modelF', 'modelJ', 'labels', 'lowres'):
            if hasattr(self, name):
                delattr(self, name)

# Tune Hyperparameters


## ResNet
Here we tune hyperparameters as we train our modified ResNet.

In [None]:
# Delete this experiment's previous log directory so the run starts fresh.
%rm -rf /content/drive/MyDrive/Logs/fakeF/PGAN/LensResnet/pbt_tanh_validate

In [None]:
# __tune_train_checkpoint_begin
def train_LensResnet(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Ray Tune trainable: fit one LensResnet trial on npyImageData.

    Args:
        config: trial hyper-parameters; ``config['lr']`` and ``config['bs']`` are read.
        checkpoint_dir: when not None, training resumes from
            ``<checkpoint_dir>/checkpoint`` (as handed back by Tune, e.g. under PBT).
        num_epochs: maximum number of epochs.
        num_gpus: GPUs for this trial; fractional values are rounded up.
    """
    # If fractional GPUs are passed in, convert to an int.
    gpus = math.ceil(num_gpus)
    kwargs = {
        # 'limit_train_batches' : 0.005,
        # 'limit_val_batches' : 0.005,
        # At least 1: a refresh rate of 0 would silently disable the progress bar.
        'progress_bar_refresh_rate' : max(1, 8250 // config['bs']),
        'max_epochs' : num_epochs,
        'prepare_data_per_node' : False,
        'gpus' : gpus,
        'logger' : TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        'callbacks' : [
            # Report these logged metrics to Tune and save a checkpoint with them.
            TuneReportCheckpointCallback(
                {
                    'loss': 'LensResnet/val/loss',
                    'auroc': 'LensResnet/val/auroc',
                },
            ),
            ModuleDataMonitor(['backbone.layer2', 'backbone.layer4', 'backbone.fc']),
            ConfusedLogitCallback(5),
        ],
        # Stochastic weight averaging works with only one optimizer.
        'stochastic_weight_avg' : True,
        'benchmark' : True,
        # 16-bit precision is GPU-only; fall back to full precision on CPU.
        'precision' : 16 if gpus > 0 else 32,
        # 'track_grad_norm': 2,
        # 'gradient_clip_val' : 0.5,
        # 'gradient_clip_algorithm' : 'value',
    }

    dm = npyImageData(config)  # Image width: npyImageData default (150).
    if checkpoint_dir is not None:
        kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = LensResnet(config)
    trainer = pl.Trainer(**kwargs)

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensResnet_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune LensResnet with Population Based Training and return the best checkpoint.

    Trials start at lr=1e-5, bs=8; PBT perturbs ``lr`` (log-uniform in [1e-6, 1e-4]) and
    ``bs`` (one of 8..128) every training iteration. Trials are ranked by ``loss`` (min).

    Args:
        num_samples: number of trials in the population.
        num_epochs: maximum epochs per trial.
        gpus_per_trial: GPUs reserved for each trial (all CPUs are reserved too).

    Returns:
        ``analysis.best_checkpoint`` for the experiment.
    """
    analysis = tune.run(
        tune.with_parameters(
            train_LensResnet,
            num_epochs=num_epochs,
            num_gpus=gpus_per_trial
        ),
        # Change the folder name when changing dataset--------------------------------------------------------------------------
        name='J/LensResnet/pbt_tanh_fine',
        metric='loss',
        mode='min',
        # stop=TrialPlateauStopper('auroc'),
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        # config={'lr': tune.choice([1e-4, 1e-3, 1e-5, 1e-2, 1e-6, 1e-1, 1e-7]),
        #         'bs': tune.grid_search([8, 16, 32, 64, 128]),
        #         },
        # scheduler = pbtScheduler(max_t=num_epochs, grace_period=2, reduction_factor=2),
        # PB2 is not used: it requires continuous mutation ranges, while 'bs' is discrete.
        config={'lr': 1e-5,
                'bs': 8,
                # RuntimeError: stack expects each tensor to be equal size, but got [128] at entry 0 and [120] at entry 585
                },
        scheduler = PopulationBasedTraining(time_attr='training_iteration', quantile_fraction=0.4,
                                            resample_probability=0.2,  perturbation_interval=1,
                                            hyperparam_mutations={
                                                'lr': tune.loguniform(1e-6, 1e-4),
                                                'bs': [8, 16, 32, 64, 128],
                                            },
        ),
        progress_reporter=JupyterNotebookReporter(
            overwrite=False,
            parameter_columns=['lr', 'bs'],
            metric_columns=['loss', 'auroc', 'training_iteration'],
        ),
        fail_fast = True,
        # reuse_actors=True,
        num_samples=num_samples,
        # resume='PROMPT',
    )
    # Returned to the caller: assigning to BEST_J_RESNET here only bound a local name,
    # so the result used to be lost once the function returned.
    best_checkpoint = analysis.best_checkpoint
    print('Best checkpoint path found is: ', best_checkpoint)
    return best_checkpoint

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args tolerates the extra arguments Jupyter/Colab pass to the kernel.
    cli_args, _unknown = cli.parse_known_args()

    # Both modes run a single PBT sample on every visible GPU; only the epoch budget differs.
    # NOTE(review): the smoke test runs *more* epochs (6) than the full run (5), which
    # contradicts its help text — confirm the intended values.
    epochs = 6 if cli_args.smoke_test else 5
    tune_LensResnet_pbt(num_samples=1, num_epochs=epochs, gpus_per_trial=torch.cuda.device_count())

## LensGAN128
Here we tune hyperparameters as we train our modified DCGAN.

In [10]:
# Delete this experiment's previous log directory so the run starts fresh.
%rm -rf drive/MyDrive/Logs/F/LensGAN128/onehot_leaky/

In [None]:
# __tune_train_checkpoint_begin
def train_LensGAN128(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Train a single LensGAN128 trial inside Ray Tune.

    Builds a PyTorch Lightning ``Trainer`` with the following behaviour:
    - trains and validates on a fraction of the batches;
    - logs to TensorBoard inside the current Tune trial directory;
    - reports metrics (plus a checkpoint) back to Tune through
      ``TuneReportCheckpointCallback``.

    Args:
        config: Tune hyper-parameter dict. ``config['bs']`` is read here; the
            whole dict is forwarded to ``npyImageData`` and ``LensGAN128``.
        checkpoint_dir: Directory handed over by Tune when a trial is restored
            (e.g. after a PBT exploit). If given, training resumes from the
            file named ``checkpoint`` inside it.
        num_epochs: Used as the trainer's ``max_epochs``.
        num_gpus: GPUs available to the trial; fractional values are rounded
            up. The default is evaluated once, when the function is defined.
    """
    # print(os.cpu_count(), torch.cuda.device_count())
    # Fraction of train/val batches used per epoch. It is shared by both limits
    # and the progress-bar refresh rate below so they cannot drift apart.
    limit_batches = 0.05
    # NOTE(review): 8250 is the sample count assumed by the refresh-rate
    # formula -- confirm it matches the dataset size.
    num_images = 8250
    kwargs = {
        'limit_train_batches' : limit_batches,
        'limit_val_batches' : limit_batches,
        # Use true ceiling division so partial windows round up. Clamp to >= 1
        # because Lightning treats a refresh rate of 0 as "disable the bar".
        'progress_bar_refresh_rate' : max(1, math.ceil(num_images * limit_batches / config['bs'])),
        'max_epochs' : num_epochs,
        'prepare_data_per_node' : False,
        # If fractional GPUs passed in, convert to int.
        'gpus' : math.ceil(num_gpus),
        # name='' and version='.' keep the TensorBoard files directly in the
        # Tune trial directory.
        'logger' : TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        'callbacks' : [
            # Maps Tune metric names -> keys logged by LensGAN128. Each report
            # also saves a checkpoint for Tune/PBT.
            TuneReportCheckpointCallback(
                {
                    'loss_G': 'LensGAN128/G/train/loss',
                    'loss_D': 'LensGAN128/D/train/loss',
                    # Switch up the FID values when training on a different dataset -----------------------------------------------
                    'FID': 'LensGAN128/val/FID_F',
                    'FID_cross': 'LensGAN128/val/FID_J',
                    'auroc': 'LensGAN128/LensResnet(F)/val/auroc',
                    'auroc_cross': 'LensGAN128/LensResnet(J)/val/auroc',
                },
            ),
            # ModuleDataMonitor(True),
        ],
        # 'stochastic_weight_avg' : True,  # works with only one optimizer
        'benchmark' : True,
        'precision' : 16,
        # 'gradient_clip_val' : 0.5,
        # 'gradient_clip_algorithm' : 'value',
    }

    dm = npyImageData(config, 128)                                              # Specify image width here
    if checkpoint_dir is not None:
        # Tune restores a trial by handing us the directory of a saved checkpoint.
        kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')

    model = LensGAN128(config)
    trainer = pl.Trainer(**kwargs)

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__

# __tune_pbt_begin__
def tune_LensGAN128_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Tune LensGAN128 with Population Based Training and report the best checkpoint.

    ``n_fmaps`` is grid-searched. PBT then perturbs/resamples ``lr`` and ``bs``
    every training iteration and ranks trials by the lowest ``FID``.

    Args:
        num_samples: Tune ``num_samples``; the ``n_fmaps`` grid is repeated
            this many times.
        num_epochs: Maximum epochs per trial, forwarded to ``train_LensGAN128``.
        gpus_per_trial: GPUs reserved for each trial (also given to the trainer).
            The default is evaluated once, when the function is defined.

    Returns:
        The ``ExperimentAnalysis`` produced by ``tune.run``. Callers can reuse
        e.g. ``analysis.best_checkpoint`` / ``analysis.best_config`` instead of
        only reading the printed path.
    """
    # print(os.cpu_count(), torch.cuda.device_count())
    # Bind the fixed (non-tuned) arguments; Tune supplies config/checkpoint_dir.
    trainable = tune.with_parameters(
        train_LensGAN128,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    # Earlier search space, kept for reference:
    # config={'lr': tune.choice([1e-4, 1e-3, 1e-5, 1e-2, 1e-6, 1e-1, 1e-7]),
    #         'bs': tune.grid_search([8, 16, 32, 64, 128]),
    #         },
    # scheduler = pbtScheduler(max_t=num_epochs, grace_period=2, reduction_factor=2),
    # Can't use RB2 as it requires mutations to be continuous
    search_config = {
        'lr': 1e-3,
        'n_fmaps': tune.grid_search([8, 16, 32, 64]),
        'bs': 8,
    }
    # search_config = {'lr': 2.340983544823817e-05, 'n_fmaps': 32, 'bs': 8}

    scheduler = PopulationBasedTraining(
        time_attr='training_iteration',
        quantile_fraction=0.25,
        resample_probability=0.25,
        perturbation_interval=1,
        hyperparam_mutations={
            'lr': tune.loguniform(1e-5, 1e-3),
            'bs': [8, 16, 32, 64, 128],
        },
    )
    reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['lr', 'n_fmaps', 'bs'],
        metric_columns=['loss_G', 'loss_D', 'FID', 'auroc', 'FID_cross',
                        'auroc_cross', 'training_iteration'],
    )

    analysis = tune.run(
        trainable,
        # Change the folder name when changing dataset--------------------------------------------------------------------------
        name='F/LensGAN128/onehot_leaky_dropout',
        metric='FID',
        mode='min',
        # stop=TrialPlateauStopper('FID'),
        resources_per_trial={'cpu': os.cpu_count(), 'gpu': gpus_per_trial},
        local_dir='./drive/MyDrive/Logs',
        config=search_config,
        scheduler=scheduler,
        progress_reporter=reporter,
        fail_fast=True,
        # reuse_actors=True,
        num_samples=num_samples,
        # resume='PROMPT',
        # restore=BEST_J_LensGAN128,
    )

    print('Best checkpoint path found is: ', analysis.best_checkpoint)
    # Return the analysis so the tuning result is not lost after printing.
    return analysis

# __tune_pbt_end__

if __name__ == '__main__':
    import argparse

    # parse_known_args ignores argv entries we did not declare (presumably
    # whatever the notebook kernel injects), so this also works in Jupyter.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    args, _ = parser.parse_known_args()

    # Both modes use the PBT scheduler with a single sample; only the epoch
    # budget differs.
    # NOTE(review): the smoke test runs MORE epochs (6) than the regular run
    # (5), which contradicts its help text -- confirm the intended values.
    epochs = 6 if args.smoke_test else 5
    tune_LensGAN128_pbt(num_samples=1, num_epochs=epochs,
                        gpus_per_trial=torch.cuda.device_count())

Trial name,status,loc,lr,n_fmaps,bs
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8


[2m[36m(pid=7995)[0m Using native 16bit precision.
[2m[36m(pid=7995)[0m GPU available: True, used: True
[2m[36m(pid=7995)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=7995)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=7995)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8


[2m[36m(pid=7995)[0m 
[2m[36m(pid=7995)[0m   | Name          | Type               | Params
[2m[36m(pid=7995)[0m -----------------------------------------------------
[2m[36m(pid=7995)[0m 0 | generator     | DCGANGenerator     | 385 K 
[2m[36m(pid=7995)[0m 1 | discriminator | DCGANDiscriminator | 176 K 
[2m[36m(pid=7995)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=7995)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=7995)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=7995)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=7995)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=7995)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=7995)[0m -----------------------------------------------------
[2m[36m(pid=7995)[0m 562 K     Trainable params
[2m[36m(pid=7995)[0m 22.3 M    Non-trainable params
[2m[36m(pid=7995)[0m 22.9 M    Total params
[2m[36m(pid=7995

[2m[36m(pid=7995)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=7995)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Trial name,status,loc,lr,n_fmaps,bs
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8


[2m[36m(pid=7995)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 0:   0%|          | 0/514 [00:00<00:00, 2777.68it/s]  
Epoch 0:  10%|▉         | 51/514 [00:03<00:29, 15.86it/s, loss=1.49, v_num=.]
Epoch 0:  20%|█▉        | 102/514 [00:06<00:25, 16.44it/s, loss=1.91, v_num=.]
Epoch 0:  30%|██▉       | 153/514 [00:09<00:21, 16.52it/s, loss=1.65, v_num=.]
Epoch 0:  40%|███▉      | 204/514 [00:12<00:18, 16.54it/s, loss=1.66, v_num=.]
Epoch 0:  50%|████▉     | 255/514 [00:15<00:15, 16.63it/s, loss=2.52, v_num=.]
Epoch 0:  60%|█████▉    | 306/514 [00:18<00:12, 16.72it/s, loss=2.3, v_num=.] 
Epoch 0:  69%|██████▉   | 357/514 [00:21<00:09, 16.83it/s, loss=2.17, v_num=.]
Epoch 0:  79%|███████▉  | 408/514 [00:24<00:06, 16.89it/s, loss=1.53, v_num=.]
Epoch 0:  89%|████████▉ | 459/514 [00:27<00:03, 16.98it/s, loss=1.96, v_num=.]
Epoch 0:  99%|█████████▉| 510/514 [00:27<00:00, 18.51it/s, loss=1.96, v_num=.]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A




[2m[36m(pid=7995)[0m 
[2m[36m(pid=7995)[0m Validating: 100%|██████████| 46/46 [00:04<00:00, 10.89it/s][A


[2m[36m(pid=7995)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Result for train_LensGAN128_bf0ed_00000:
  FID: 88.5
  FID_cross: 94.3125
  auroc: 0.45989441871643066
  auroc_cross: 0.4412270188331604
  date: 2021-09-13_20-47-46
  done: false
  experiment_id: 7ee666093f6043f9bc108680723b15b4
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.3548980951309204
  loss_G: 3.650420665740967
  node_ip: 172.28.0.2
  pid: 7995
  should_checkpoint: true
  time_since_restore: 41.830915451049805
  time_this_iter_s: 41.830915451049805
  time_total_s: 41.830915451049805
  timestamp: 1631566066
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: bf0ed_00000
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,172.28.0.2:7995,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,,,,,,,
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,,,,,,,
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,,,,,,,


[2m[36m(pid=7995)[0m 2021-09-13 20:47:47,405	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes
[2m[36m(pid=8067)[0m Using native 16bit precision.
[2m[36m(pid=8067)[0m GPU available: True, used: True
[2m[36m(pid=8067)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8067)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8067)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2m[36m(pid=8067)[0m 
[2m[36m(pid=8067)[0m   | Name          | Type               | Params
[2m[36m(pid=8067)[0m -----------------------------------------------------
[2m[36m(pid=8067)[0m 0 | generator     | DCGANGenerator     | 1.1 M 
[2m[36m(pid=8067)[0m 1 | discriminator | DCGANDiscriminator | 701 K 
[2m[36m(pid=8067)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8067)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8067)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8067)[0m 5 | modelJ        | ResNet        

[2m[36m(pid=8067)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8067)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8067)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 0:   0%|          | 0/514 [00:00<00:00, 3297.41it/s]  
Epoch 0:  10%|▉         | 51/514 [00:03<00:29, 15.75it/s, loss=1.9, v_num=.]
Epoch 0:  20%|█▉        | 102/514 [00:06<00:25, 16.29it/s, loss=3.21, v_num=.]
Epoch 0:  30%|██▉       | 153/514 [00:09<00:21, 16.49it/s, loss=1.91, v_num=.]
Epoch 0:  40%|███▉      | 204/514 [00:12<00:18, 16.42it/s, loss=1.88, v_num=.]
Epoch 0:  50%|████▉     | 255/514 [00:15<00:15, 16.37it/s, loss=2.29, v_num=.]
Epoch 0:  60%|█████▉    | 306/514 [00:18<00:12, 16.39it/s, loss=1.42, v_num=.]
Epoch 0:  69%|██████▉   | 357/514 [00:21<00:09, 16.40it/s, loss=2.13, v_num=.]
Epoch 0:  79%|███████▉  | 408/514 [00:25<00:06, 16.30it/s, loss=2.29, v_num=.]
Epoch 0:  89%|████████▉ | 459/514 [00:28<00:03, 16.34it/s, loss=2.6, v_num=.] 




[2m[36m(pid=8067)[0m Epoch 0:  99%|█████████▉| 510/514 [00:28<00:00, 17.79it/s, loss=2.6, v_num=.]
[2m[36m(pid=8067)[0m Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A
[2m[36m(pid=8067)[0m 
Validating: 100%|██████████| 46/46 [00:04<00:00, 10.82it/s][A


[2m[36m(pid=8067)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."
2021-09-13 20:48:33,741	INFO pbt.py:490 -- [pbt]: no checkpoint for trial. Skip exploit for Trial train_LensGAN128_bf0ed_00001


Result for train_LensGAN128_bf0ed_00001:
  FID: 91.8125
  FID_cross: 64.9375
  auroc: 0.4725712537765503
  auroc_cross: 0.48893314599990845
  date: 2021-09-13_20-48-33
  done: false
  experiment_id: ba54eed17dca4ff4ba799df668858677
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.043915458023548126
  loss_G: 4.852165222167969
  node_ip: 172.28.0.2
  pid: 8067
  should_checkpoint: true
  time_since_restore: 43.23782467842102
  time_this_iter_s: 43.23782467842102
  time_total_s: 43.23782467842102
  timestamp: 1631566113
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: bf0ed_00001
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,172.28.0.2:8067,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1.0
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,,,,,,,
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,,,,,,,


[2m[36m(pid=8067)[0m 2021-09-13 20:48:34,254	INFO trainable.py:76 -- Checkpoint size is 111478742 bytes
[2m[36m(pid=8141)[0m Using native 16bit precision.
[2m[36m(pid=8141)[0m GPU available: True, used: True
[2m[36m(pid=8141)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8141)[0m IPU available: False, using: 0 IPUs


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,,0.001,32,8,,,,,,,
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1.0
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,,,,,,,


[2m[36m(pid=8141)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2m[36m(pid=8141)[0m 
[2m[36m(pid=8141)[0m   | Name          | Type               | Params
[2m[36m(pid=8141)[0m -----------------------------------------------------
[2m[36m(pid=8141)[0m 0 | generator     | DCGANGenerator     | 3.6 M 
[2m[36m(pid=8141)[0m 1 | discriminator | DCGANDiscriminator | 2.8 M 
[2m[36m(pid=8141)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8141)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8141)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8141)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8141)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8141)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8141)[0m -----------------------------------------------------
[2m[36m(pid=8141)[0m 6.4 M     Trainable params
[2m[36m(pid=8141)[0m 22.3 M    Non-trainable params

[2m[36m(pid=8141)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8141)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,,0.001,32,8,,,,,,,
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1.0
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,,,,,,,


[2m[36m(pid=8141)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 0:   0%|          | 0/514 [00:00<00:00, 3705.22it/s]  
Epoch 0:  10%|▉         | 51/514 [00:03<00:33, 13.79it/s, loss=3.34, v_num=.]
Epoch 0:  20%|█▉        | 102/514 [00:07<00:29, 14.14it/s, loss=2.83, v_num=.]
Epoch 0:  30%|██▉       | 153/514 [00:10<00:25, 14.18it/s, loss=1.36, v_num=.]
Epoch 0:  40%|███▉      | 204/514 [00:14<00:21, 14.16it/s, loss=1.85, v_num=.]
Epoch 0:  50%|████▉     | 255/514 [00:18<00:18, 14.15it/s, loss=1.76, v_num=.]
Epoch 0:  60%|█████▉    | 306/514 [00:21<00:14, 14.27it/s, loss=1.57, v_num=.]
Epoch 0:  69%|██████▉   | 357/514 [00:25<00:10, 14.30it/s, loss=2.04, v_num=.]
Epoch 0:  79%|███████▉  | 408/514 [00:28<00:07, 14.31it/s, loss=2.37, v_num=.]
Epoch 0:  89%|████████▉ | 459/514 [00:32<00:03, 14.31it/s, loss=1.56, v_num=.]
Epoch 0:  99%|█████████▉| 510/514 [00:32<00:00, 15.61it/s, loss=1.56, v_num=.]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A




[2m[36m(pid=8141)[0m 
[2m[36m(pid=8141)[0m Validating: 100%|██████████| 46/46 [00:04<00:00, 10.92it/s][A


[2m[36m(pid=8141)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."
2021-09-13 20:49:25,629	INFO pbt.py:490 -- [pbt]: no checkpoint for trial. Skip exploit for Trial train_LensGAN128_bf0ed_00002


Result for train_LensGAN128_bf0ed_00002:
  FID: 95.8125
  FID_cross: 68.25
  auroc: 0.5082958936691284
  auroc_cross: 0.435516893863678
  date: 2021-09-13_20-49-25
  done: false
  experiment_id: 2ee24e63b5e64a87974d0d649b2dacba
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.8265526294708252
  loss_G: 1.6876806020736694
  node_ip: 172.28.0.2
  pid: 8141
  should_checkpoint: true
  time_since_restore: 48.200762033462524
  time_this_iter_s: 48.200762033462524
  time_total_s: 48.200762033462524
  timestamp: 1631566165
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: bf0ed_00002
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,172.28.0.2:8141,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1.0
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1.0
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,,,,,,,


[2m[36m(pid=8141)[0m 2021-09-13 20:49:26,542	INFO trainable.py:76 -- Checkpoint size is 166763030 bytes
[2m[36m(pid=8214)[0m Using native 16bit precision.
[2m[36m(pid=8214)[0m GPU available: True, used: True
[2m[36m(pid=8214)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8214)[0m IPU available: False, using: 0 IPUs


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00003,RUNNING,,0.001,64,8,,,,,,,
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1.0
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1.0
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0


[2m[36m(pid=8214)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2m[36m(pid=8214)[0m 
[2m[36m(pid=8214)[0m   | Name          | Type               | Params
[2m[36m(pid=8214)[0m -----------------------------------------------------
[2m[36m(pid=8214)[0m 0 | generator     | DCGANGenerator     | 12.8 M
[2m[36m(pid=8214)[0m 1 | discriminator | DCGANDiscriminator | 11.2 M
[2m[36m(pid=8214)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8214)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8214)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8214)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8214)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8214)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8214)[0m -----------------------------------------------------
[2m[36m(pid=8214)[0m 24.0 M    Trainable params
[2m[36m(pid=8214)[0m 22.3 M    Non-trainable params

[2m[36m(pid=8214)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8214)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00003,RUNNING,,0.001,64,8,,,,,,,
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1.0
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1.0
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1.0


[2m[36m(pid=8214)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 0:   0%|          | 0/514 [00:00<00:00, 3231.36it/s]  
Epoch 0:  10%|▉         | 51/514 [00:08<01:11,  6.50it/s, loss=3.61, v_num=.]
Epoch 0:  20%|█▉        | 102/514 [00:15<01:02,  6.63it/s, loss=2.21, v_num=.]
Epoch 0:  30%|██▉       | 153/514 [00:23<00:54,  6.59it/s, loss=2.11, v_num=.]
Epoch 0:  40%|███▉      | 204/514 [00:30<00:46,  6.63it/s, loss=2.07, v_num=.]
Epoch 0:  50%|████▉     | 255/514 [00:38<00:38,  6.65it/s, loss=1.56, v_num=.]
Epoch 0:  60%|█████▉    | 306/514 [00:45<00:31,  6.68it/s, loss=2.08, v_num=.]
Epoch 0:  69%|██████▉   | 357/514 [00:53<00:23,  6.69it/s, loss=1.81, v_num=.]
Epoch 0:  79%|███████▉  | 408/514 [01:01<00:15,  6.70it/s, loss=2.12, v_num=.]
Epoch 0:  89%|████████▉ | 459/514 [01:08<00:08,  6.70it/s, loss=1.64, v_num=.]
Epoch 0:  99%|█████████▉| 510/514 [01:09<00:00,  7.30it/s, loss=1.64, v_num=.]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A




[2m[36m(pid=8214)[0m 
[2m[36m(pid=8214)[0m Validating: 100%|██████████| 46/46 [00:05<00:00,  8.85it/s][A


[2m[36m(pid=8214)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."
2021-09-13 20:50:58,728	INFO pbt.py:490 -- [pbt]: no checkpoint for trial. Skip exploit for Trial train_LensGAN128_bf0ed_00003


Result for train_LensGAN128_bf0ed_00003:
  FID: 115.6875
  FID_cross: 79.5625
  auroc: 0.5020161271095276
  auroc_cross: 0.48553910851478577
  date: 2021-09-13_20-50-58
  done: false
  experiment_id: 9bba428635cb45d99b37e10935dd969a
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 2.1142096519470215
  loss_G: 1.4121789932250977
  node_ip: 172.28.0.2
  pid: 8214
  should_checkpoint: true
  time_since_restore: 88.97758150100708
  time_this_iter_s: 88.97758150100708
  time_total_s: 88.97758150100708
  timestamp: 1631566258
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: bf0ed_00003
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00003,RUNNING,172.28.0.2:8214,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1


2021-09-13 20:50:59,685	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes
[2m[36m(pid=8214)[0m 2021-09-13 20:51:00,448	INFO trainable.py:76 -- Checkpoint size is 377601622 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1
train_LensGAN128_bf0ed_00003,PAUSED,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1


[2m[36m(pid=8294)[0m 2021-09-13 20:51:04,566	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00000_0_n_fmaps=8_2021-09-13_20-47-01/checkpoint_tmp8cc0ef/./
[2m[36m(pid=8294)[0m 2021-09-13 20:51:04,566	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 1, '_timesteps_total': None, '_time_total': 41.830915451049805, '_episodes_total': None}
[2m[36m(pid=8294)[0m Using native 16bit precision.
[2m[36m(pid=8294)[0m GPU available: True, used: True
[2m[36m(pid=8294)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8294)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8294)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00000_0_n_fmaps=8_2021-09-13_20-47-01/checkpoint_tmp8cc0ef/./checkpoint
[2m[36m(pid=8294)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8,3.65042,0.354898,88.5,0.459894,94.3125,0.441227,1
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1
train_LensGAN128_bf0ed_00003,PAUSED,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1


[2m[36m(pid=8294)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00000_0_n_fmaps=8_2021-09-13_20-47-01/checkpoint_tmp8cc0ef/./checkpoint
[2m[36m(pid=8294)[0m 
[2m[36m(pid=8294)[0m   | Name          | Type               | Params
[2m[36m(pid=8294)[0m -----------------------------------------------------
[2m[36m(pid=8294)[0m 0 | generator     | DCGANGenerator     | 385 K 
[2m[36m(pid=8294)[0m 1 | discriminator | DCGANDiscriminator | 176 K 
[2m[36m(pid=8294)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8294)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8294)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8294)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8294)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8294)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8294)

[2m[36m(pid=8294)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8294)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8294)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 1:   0%|          | 0/514 [00:00<00:00, 2978.91it/s]  
Epoch 1:  10%|▉         | 51/514 [00:03<00:27, 17.08it/s, loss=2.96, v_num=.]
Epoch 1:  20%|█▉        | 102/514 [00:05<00:23, 17.68it/s, loss=2.57, v_num=.]
Epoch 1:  30%|██▉       | 153/514 [00:08<00:20, 17.50it/s, loss=2.41, v_num=.]
Epoch 1:  40%|███▉      | 204/514 [00:11<00:17, 17.62it/s, loss=2.71, v_num=.]
Epoch 1:  50%|████▉     | 255/514 [00:14<00:14, 17.83it/s, loss=2.72, v_num=.]
Epoch 1:  60%|█████▉    | 306/514 [00:17<00:11, 17.86it/s, loss=2.22, v_num=.]
Epoch 1:  69%|██████▉   | 357/514 [00:20<00:08, 17.85it/s, loss=2.66, v_num=.]
Epoch 1:  79%|███████▉  | 408/514 [00:22<00:05, 17.95it/s, loss=2.65, v_num=.]
Epoch 1:  89%|████████▉ | 459/514 [00:25<00:03, 17.85it/s, loss=2.76, v_num=.]
Epoch 1:  99%|█████████▉| 510/514 [00:26<00:00, 19.46it/s, loss=2.76, v_num=.]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A




[2m[36m(pid=8294)[0m 
[2m[36m(pid=8294)[0m Validating: 100%|██████████| 46/46 [00:04<00:00, 11.07it/s][A


[2m[36m(pid=8294)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Result for train_LensGAN128_bf0ed_00000:
  FID: 14.7265625
  FID_cross: 17.09375
  auroc: 0.45977500081062317
  auroc_cross: 0.47227221727371216
  date: 2021-09-13_20-51-47
  done: false
  experiment_id: 7ee666093f6043f9bc108680723b15b4
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.030693329870700836
  loss_G: 6.504297256469727
  node_ip: 172.28.0.2
  pid: 8294
  should_checkpoint: true
  time_since_restore: 43.19827151298523
  time_this_iter_s: 43.19827151298523
  time_total_s: 85.02918696403503
  timestamp: 1631566307
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: bf0ed_00000
  


[2m[36m(pid=8294)[0m 2021-09-13 20:51:48,149	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,172.28.0.2:8294,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1
train_LensGAN128_bf0ed_00003,PAUSED,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1


[2m[36m(pid=8294)[0m 2021-09-13 20:51:48,664	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes
2021-09-13 20:51:49,175	INFO trainable.py:76 -- Checkpoint size is 111478742 bytes
[2m[36m(pid=8372)[0m 2021-09-13 20:51:52,965	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00001_1_n_fmaps=16_2021-09-13_20-47-02/checkpoint_tmpe8cae0/./
[2m[36m(pid=8372)[0m 2021-09-13 20:51:52,965	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 1, '_timesteps_total': None, '_time_total': 43.23782467842102, '_episodes_total': None}


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1


[2m[36m(pid=8372)[0m Using native 16bit precision.
[2m[36m(pid=8372)[0m GPU available: True, used: True
[2m[36m(pid=8372)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8372)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8372)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00001_1_n_fmaps=16_2021-09-13_20-47-02/checkpoint_tmpe8cae0/./checkpoint
[2m[36m(pid=8372)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,,0.001,16,8,4.85217,0.0439155,91.8125,0.472571,64.9375,0.488933,1
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1


[2m[36m(pid=8372)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00001_1_n_fmaps=16_2021-09-13_20-47-02/checkpoint_tmpe8cae0/./checkpoint
[2m[36m(pid=8372)[0m 
[2m[36m(pid=8372)[0m   | Name          | Type               | Params
[2m[36m(pid=8372)[0m -----------------------------------------------------
[2m[36m(pid=8372)[0m 0 | generator     | DCGANGenerator     | 1.1 M 
[2m[36m(pid=8372)[0m 1 | discriminator | DCGANDiscriminator | 701 K 
[2m[36m(pid=8372)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8372)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8372)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8372)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8372)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8372)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8372)

[2m[36m(pid=8372)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8372)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8372)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 1:   0%|          | 0/514 [00:00<00:00, 3492.34it/s]  
Epoch 1:  10%|▉         | 51/514 [00:03<00:27, 16.71it/s, loss=2, v_num=.]
Epoch 1:  20%|█▉        | 102/514 [00:06<00:24, 16.51it/s, loss=2.22, v_num=.]
Epoch 1:  30%|██▉       | 153/514 [00:09<00:21, 16.53it/s, loss=2.85, v_num=.]
Epoch 1:  40%|███▉      | 204/514 [00:12<00:18, 16.62it/s, loss=2.39, v_num=.]
Epoch 1:  50%|████▉     | 255/514 [00:15<00:15, 16.82it/s, loss=2.32, v_num=.]
Epoch 1:  60%|█████▉    | 306/514 [00:18<00:12, 16.94it/s, loss=3.11, v_num=.]
Epoch 1:  69%|██████▉   | 357/514 [00:21<00:09, 16.96it/s, loss=2.47, v_num=.]
Epoch 1:  79%|███████▉  | 408/514 [00:24<00:06, 16.93it/s, loss=3.6, v_num=.] 
Epoch 1:  89%|████████▉ | 459/514 [00:27<00:03, 16.86it/s, loss=3.7, v_num=.]




[2m[36m(pid=8372)[0m Epoch 1:  99%|█████████▉| 510/514 [00:27<00:00, 18.39it/s, loss=3.7, v_num=.]
[2m[36m(pid=8372)[0m Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A
[2m[36m(pid=8372)[0m 
Validating: 100%|██████████| 46/46 [00:04<00:00,  9.84it/s][A


[2m[36m(pid=8372)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Result for train_LensGAN128_bf0ed_00001:
  FID: 19.46875
  FID_cross: 58.0
  auroc: 0.4394567608833313
  auroc_cross: 0.4274238646030426
  date: 2021-09-13_20-52-36
  done: false
  experiment_id: ba54eed17dca4ff4ba799df668858677
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.29861387610435486
  loss_G: 8.073526382446289
  node_ip: 172.28.0.2
  pid: 8372
  should_checkpoint: true
  time_since_restore: 43.815224409103394
  time_this_iter_s: 43.815224409103394
  time_total_s: 87.05304908752441
  timestamp: 1631566356
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: bf0ed_00001
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,172.28.0.2:8372,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1


[2m[36m(pid=8372)[0m 2021-09-13 20:52:37,419	INFO trainable.py:76 -- Checkpoint size is 111478742 bytes
2021-09-13 20:52:38,168	INFO trainable.py:76 -- Checkpoint size is 166763030 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1


[2m[36m(pid=8451)[0m 2021-09-13 20:52:42,595	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00002_2_n_fmaps=32_2021-09-13_20-47-47/checkpoint_tmpef9ffb/./
[2m[36m(pid=8451)[0m 2021-09-13 20:52:42,595	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 1, '_timesteps_total': None, '_time_total': 48.200762033462524, '_episodes_total': None}
[2m[36m(pid=8451)[0m Using native 16bit precision.
[2m[36m(pid=8451)[0m GPU available: True, used: True
[2m[36m(pid=8451)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8451)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8451)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00002_2_n_fmaps=32_2021-09-13_20-47-47/checkpoint_tmpef9ffb/./checkpoint
[2m[36m(pid=8451)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,,0.001,32,8,1.68768,0.826553,95.8125,0.508296,68.25,0.435517,1
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1


[2m[36m(pid=8451)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00002_2_n_fmaps=32_2021-09-13_20-47-47/checkpoint_tmpef9ffb/./checkpoint
[2m[36m(pid=8451)[0m 
[2m[36m(pid=8451)[0m   | Name          | Type               | Params
[2m[36m(pid=8451)[0m -----------------------------------------------------
[2m[36m(pid=8451)[0m 0 | generator     | DCGANGenerator     | 3.6 M 
[2m[36m(pid=8451)[0m 1 | discriminator | DCGANDiscriminator | 2.8 M 
[2m[36m(pid=8451)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8451)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8451)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8451)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8451)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8451)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8451)

[2m[36m(pid=8451)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8451)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8451)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 1:   0%|          | 0/514 [00:00<00:00, 4029.11it/s]  
Epoch 1:  10%|▉         | 51/514 [00:03<00:33, 13.98it/s, loss=2.02, v_num=.]
Epoch 1:  20%|█▉        | 102/514 [00:07<00:28, 14.37it/s, loss=1.72, v_num=.]
Epoch 1:  30%|██▉       | 153/514 [00:10<00:24, 14.52it/s, loss=2.1, v_num=.] 
Epoch 1:  40%|███▉      | 204/514 [00:14<00:21, 14.62it/s, loss=2.39, v_num=.]
Epoch 1:  50%|████▉     | 255/514 [00:17<00:17, 14.63it/s, loss=3.18, v_num=.]
Epoch 1:  60%|█████▉    | 306/514 [00:20<00:14, 14.63it/s, loss=3.11, v_num=.]
Epoch 1:  69%|██████▉   | 357/514 [00:24<00:10, 14.61it/s, loss=3.2, v_num=.] 
Epoch 1:  79%|███████▉  | 408/514 [00:28<00:07, 14.57it/s, loss=2.94, v_num=.]
Epoch 1:  89%|████████▉ | 459/514 [00:31<00:03, 14.57it/s, loss=2.56, v_num=.]




[2m[36m(pid=8451)[0m Epoch 1:  99%|█████████▉| 510/514 [00:32<00:00, 15.86it/s, loss=2.56, v_num=.]
[2m[36m(pid=8451)[0m Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A
[2m[36m(pid=8451)[0m 
Validating: 100%|██████████| 46/46 [00:04<00:00, 10.89it/s][A


[2m[36m(pid=8451)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Result for train_LensGAN128_bf0ed_00002:
  FID: 11.2734375
  FID_cross: 21.265625
  auroc: 0.5028094053268433
  auroc_cross: 0.44413959980010986
  date: 2021-09-13_20-53-31
  done: false
  experiment_id: 2ee24e63b5e64a87974d0d649b2dacba
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.4128419756889343
  loss_G: 1.973191499710083
  node_ip: 172.28.0.2
  pid: 8451
  should_checkpoint: true
  time_since_restore: 48.39972639083862
  time_this_iter_s: 48.39972639083862
  time_total_s: 96.60048842430115
  timestamp: 1631566411
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: bf0ed_00002
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,172.28.0.2:8451,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00003,PENDING,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1


[2m[36m(pid=8451)[0m 2021-09-13 20:53:31,597	INFO trainable.py:76 -- Checkpoint size is 166763030 bytes
[2m[36m(pid=8451)[0m 2021-09-13 20:53:32,373	INFO trainable.py:76 -- Checkpoint size is 166763030 bytes
2021-09-13 20:53:34,326	INFO trainable.py:76 -- Checkpoint size is 377601622 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00003,RUNNING,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2


[2m[36m(pid=8536)[0m 2021-09-13 20:53:38,855	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00003_3_n_fmaps=64_2021-09-13_20-48-34/checkpoint_tmp48c07f/./
[2m[36m(pid=8536)[0m 2021-09-13 20:53:38,855	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 1, '_timesteps_total': None, '_time_total': 88.97758150100708, '_episodes_total': None}
[2m[36m(pid=8536)[0m Using native 16bit precision.
[2m[36m(pid=8536)[0m GPU available: True, used: True
[2m[36m(pid=8536)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8536)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8536)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00003_3_n_fmaps=64_2021-09-13_20-48-34/checkpoint_tmp48c07f/./checkpoint


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00003,RUNNING,,0.001,64,8,1.41218,2.11421,115.688,0.502016,79.5625,0.485539,1
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2


[2m[36m(pid=8536)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[2m[36m(pid=8536)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00003_3_n_fmaps=64_2021-09-13_20-48-34/checkpoint_tmp48c07f/./checkpoint
[2m[36m(pid=8536)[0m 
[2m[36m(pid=8536)[0m   | Name          | Type               | Params
[2m[36m(pid=8536)[0m -----------------------------------------------------
[2m[36m(pid=8536)[0m 0 | generator     | DCGANGenerator     | 12.8 M
[2m[36m(pid=8536)[0m 1 | discriminator | DCGANDiscriminator | 11.2 M
[2m[36m(pid=8536)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8536)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8536)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8536)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8536)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8536)[0m 7

[2m[36m(pid=8536)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8536)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8536)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 1:   0%|          | 0/514 [00:00<00:00, 1500.11it/s]  
Epoch 1:  10%|▉         | 51/514 [00:08<01:12,  6.36it/s, loss=1.9, v_num=.]
Epoch 1:  20%|█▉        | 102/514 [00:15<01:02,  6.58it/s, loss=1.91, v_num=.]
Epoch 1:  30%|██▉       | 153/514 [00:23<00:54,  6.66it/s, loss=1.61, v_num=.]
Epoch 1:  40%|███▉      | 204/514 [00:30<00:46,  6.69it/s, loss=2.07, v_num=.]
Epoch 1:  50%|████▉     | 255/514 [00:38<00:38,  6.72it/s, loss=1.82, v_num=.]
Epoch 1:  60%|█████▉    | 306/514 [00:45<00:30,  6.74it/s, loss=1.92, v_num=.]
Epoch 1:  69%|██████▉   | 357/514 [00:53<00:23,  6.75it/s, loss=1.94, v_num=.]
Epoch 1:  79%|███████▉  | 408/514 [01:00<00:15,  6.76it/s, loss=1.67, v_num=.]
Epoch 1:  89%|████████▉ | 459/514 [01:07<00:08,  6.77it/s, loss=1.4, v_num=.] 




[2m[36m(pid=8536)[0m Epoch 1:  99%|█████████▉| 510/514 [01:09<00:00,  7.37it/s, loss=1.4, v_num=.]
[2m[36m(pid=8536)[0m Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A
[2m[36m(pid=8536)[0m 
Validating: 100%|██████████| 46/46 [00:04<00:00,  9.92it/s][A


[2m[36m(pid=8536)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."
2021-09-13 20:55:07,712	INFO pbt.py:543 -- [exploit] transferring weights from trial train_LensGAN128_bf0ed_00002 (score -11.2734375) -> train_LensGAN128_bf0ed_00003 (score -24.59375)
2021-09-13 20:55:07,715	INFO pbt.py:558 -- [explore] perturbed config from {'lr': 0.001, 'bs': 8} -> {'lr': 0.0012, 'bs': 64}


Result for train_LensGAN128_bf0ed_00003:
  FID: 24.59375
  FID_cross: 30.53125
  auroc: 0.47018229961395264
  auroc_cross: 0.48548173904418945
  date: 2021-09-13_20-55-07
  done: false
  experiment_id: 9bba428635cb45d99b37e10935dd969a
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 1.3658102750778198
  loss_G: 1.6765936613082886
  node_ip: 172.28.0.2
  pid: 8536
  should_checkpoint: true
  time_since_restore: 88.81802535057068
  time_this_iter_s: 88.81802535057068
  time_total_s: 177.79560685157776
  timestamp: 1631566507
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: bf0ed_00003
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00003,RUNNING,172.28.0.2:8536,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00000,PENDING,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2


[2m[36m(pid=8536)[0m 2021-09-13 20:55:08,771	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00003_3_n_fmaps=64_2021-09-13_20-48-34/checkpoint_tmp5ed48e/./
[2m[36m(pid=8536)[0m 2021-09-13 20:55:08,771	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 2, '_timesteps_total': None, '_time_total': 96.60048842430115, '_episodes_total': None}
[2m[36m(pid=8536)[0m 2021-09-13 20:55:10,781	INFO trainable.py:76 -- Checkpoint size is 166763030 bytes
2021-09-13 20:55:13,198	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2


[2m[36m(pid=8612)[0m 2021-09-13 20:55:15,571	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00000_0_n_fmaps=8_2021-09-13_20-47-01/checkpoint_tmpf0ae6e/./
[2m[36m(pid=8612)[0m 2021-09-13 20:55:15,571	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 2, '_timesteps_total': None, '_time_total': 85.02918696403503, '_episodes_total': None}
[2m[36m(pid=8612)[0m Using native 16bit precision.
[2m[36m(pid=8612)[0m GPU available: True, used: True
[2m[36m(pid=8612)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8612)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8612)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00000_0_n_fmaps=8_2021-09-13_20-47-01/checkpoint_tmpf0ae6e/./checkpoint
[2m[36m(pid=8612)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2


[2m[36m(pid=8612)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00000_0_n_fmaps=8_2021-09-13_20-47-01/checkpoint_tmpf0ae6e/./checkpoint
[2m[36m(pid=8612)[0m 
[2m[36m(pid=8612)[0m   | Name          | Type               | Params
[2m[36m(pid=8612)[0m -----------------------------------------------------
[2m[36m(pid=8612)[0m 0 | generator     | DCGANGenerator     | 385 K 
[2m[36m(pid=8612)[0m 1 | discriminator | DCGANDiscriminator | 176 K 
[2m[36m(pid=8612)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8612)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8612)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8612)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8612)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8612)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8612)

[2m[36m(pid=8612)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8612)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,,0.001,8,8,6.5043,0.0306933,14.7266,0.459775,17.0938,0.472272,2
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2


[2m[36m(pid=8612)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 2:   0%|          | 0/514 [00:00<00:00, 4686.37it/s]  
Epoch 2:  10%|▉         | 51/514 [00:03<00:26, 17.24it/s, loss=2.93, v_num=.]
Epoch 2:  20%|█▉        | 102/514 [00:10<00:41,  9.92it/s, loss=3.3, v_num=.] 
Epoch 2:  30%|██▉       | 153/514 [00:13<00:31, 11.57it/s, loss=3.07, v_num=.]
Epoch 2:  40%|███▉      | 204/514 [00:16<00:24, 12.78it/s, loss=3.06, v_num=.]
Epoch 2:  50%|████▉     | 255/514 [00:18<00:19, 13.54it/s, loss=3.24, v_num=.]
Epoch 2:  60%|█████▉    | 306/514 [00:21<00:14, 14.13it/s, loss=2.58, v_num=.]
Epoch 2:  69%|██████▉   | 357/514 [00:24<00:10, 14.58it/s, loss=3.01, v_num=.]
Epoch 2:  79%|███████▉  | 408/514 [00:27<00:07, 14.85it/s, loss=3.22, v_num=.]
Epoch 2:  89%|████████▉ | 459/514 [00:30<00:03, 15.07it/s, loss=2.98, v_num=.]
Epoch 2:  99%|█████████▉| 510/514 [00:31<00:00, 16.42it/s, loss=2.98, v_num=.]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A




[2m[36m(pid=8612)[0m 
[2m[36m(pid=8612)[0m Validating: 100%|██████████| 46/46 [00:04<00:00, 10.85it/s][A


[2m[36m(pid=8612)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Result for train_LensGAN128_bf0ed_00000:
  FID: 5.57421875
  FID_cross: 13.3203125
  auroc: 0.47783249616622925
  auroc_cross: 0.48905321955680847
  date: 2021-09-13_20-56-01
  done: false
  experiment_id: 7ee666093f6043f9bc108680723b15b4
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.020195264369249344
  loss_G: 10.664708137512207
  node_ip: 172.28.0.2
  pid: 8612
  should_checkpoint: true
  time_since_restore: 46.179887771606445
  time_this_iter_s: 46.179887771606445
  time_total_s: 131.20907473564148
  timestamp: 1631566561
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: bf0ed_00000
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00000,RUNNING,172.28.0.2:8612,0.001,8,8,10.6647,0.0201953,5.57422,0.477832,13.3203,0.489053,3
train_LensGAN128_bf0ed_00002,PAUSED,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00001,PENDING,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2


[2m[36m(pid=8612)[0m 2021-09-13 20:56:02,160	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes
[2m[36m(pid=8612)[0m 2021-09-13 20:56:02,535	INFO trainable.py:76 -- Checkpoint size is 96370070 bytes
2021-09-13 20:56:03,105	INFO trainable.py:76 -- Checkpoint size is 111478742 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,10.6647,0.0201953,5.57422,0.477832,13.3203,0.489053,3
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2


[2m[36m(pid=8697)[0m 2021-09-13 20:56:07,264	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00001_1_n_fmaps=16_2021-09-13_20-47-02/checkpoint_tmpa4b718/./
[2m[36m(pid=8697)[0m 2021-09-13 20:56:07,264	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 2, '_timesteps_total': None, '_time_total': 87.05304908752441, '_episodes_total': None}
[2m[36m(pid=8697)[0m Using native 16bit precision.
[2m[36m(pid=8697)[0m GPU available: True, used: True
[2m[36m(pid=8697)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8697)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8697)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00001_1_n_fmaps=16_2021-09-13_20-47-02/checkpoint_tmpa4b718/./checkpoint
[2m[36m(pid=8697)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,,0.001,16,8,8.07353,0.298614,19.4688,0.439457,58.0,0.427424,2
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,10.6647,0.0201953,5.57422,0.477832,13.3203,0.489053,3
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2


[2m[36m(pid=8697)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00001_1_n_fmaps=16_2021-09-13_20-47-02/checkpoint_tmpa4b718/./checkpoint
[2m[36m(pid=8697)[0m 
[2m[36m(pid=8697)[0m   | Name          | Type               | Params
[2m[36m(pid=8697)[0m -----------------------------------------------------
[2m[36m(pid=8697)[0m 0 | generator     | DCGANGenerator     | 1.1 M 
[2m[36m(pid=8697)[0m 1 | discriminator | DCGANDiscriminator | 701 K 
[2m[36m(pid=8697)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8697)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8697)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8697)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8697)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8697)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8697)

[2m[36m(pid=8697)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8697)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8697)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 2:   0%|          | 0/514 [00:00<00:00, 2435.72it/s] 
Epoch 2:  10%|▉         | 51/514 [00:04<00:42, 10.79it/s, loss=3, v_num=.]
Epoch 2:  20%|█▉        | 102/514 [00:07<00:31, 13.28it/s, loss=3.69, v_num=.]
Epoch 2:  30%|██▉       | 153/514 [00:10<00:25, 14.26it/s, loss=2.93, v_num=.]
Epoch 2:  40%|███▉      | 204/514 [00:13<00:20, 14.77it/s, loss=3.15, v_num=.]
Epoch 2:  50%|████▉     | 255/514 [00:16<00:17, 15.21it/s, loss=3.41, v_num=.]
Epoch 2:  60%|█████▉    | 306/514 [00:19<00:13, 15.48it/s, loss=3.42, v_num=.]
Epoch 2:  69%|██████▉   | 357/514 [00:22<00:10, 15.68it/s, loss=3.38, v_num=.]
Epoch 2:  79%|███████▉  | 408/514 [00:25<00:06, 15.73it/s, loss=3.65, v_num=.]
Epoch 2:  89%|████████▉ | 459/514 [00:29<00:03, 15.72it/s, loss=3.04, v_num=.]




[2m[36m(pid=8697)[0m Epoch 2:  99%|█████████▉| 510/514 [00:29<00:00, 17.15it/s, loss=3.04, v_num=.]
[2m[36m(pid=8697)[0m Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/46 [00:00<?, ?it/s][A
[2m[36m(pid=8697)[0m 
Validating: 100%|██████████| 46/46 [00:04<00:00, 10.90it/s][A


[2m[36m(pid=8697)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Result for train_LensGAN128_bf0ed_00001:
  FID: 22.671875
  FID_cross: 23.3125
  auroc: 0.42622601985931396
  auroc_cross: 0.4664292335510254
  date: 2021-09-13_20-56-53
  done: false
  experiment_id: ba54eed17dca4ff4ba799df668858677
  hostname: af2913abb491
  iterations_since_restore: 1
  loss_D: 0.05917145311832428
  loss_G: 8.371068954467773
  node_ip: 172.28.0.2
  pid: 8697
  should_checkpoint: true
  time_since_restore: 46.025771379470825
  time_this_iter_s: 46.025771379470825
  time_total_s: 133.07882046699524
  timestamp: 1631566613
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: bf0ed_00001
  


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00001,RUNNING,172.28.0.2:8697,0.001,16,8,8.37107,0.0591715,22.6719,0.426226,23.3125,0.466429,3
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,10.6647,0.0201953,5.57422,0.477832,13.3203,0.489053,3
train_LensGAN128_bf0ed_00003,PAUSED,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2
train_LensGAN128_bf0ed_00002,PENDING,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2


[2m[36m(pid=8697)[0m 2021-09-13 20:56:54,187	INFO trainable.py:76 -- Checkpoint size is 111478742 bytes
2021-09-13 20:56:54,956	INFO trainable.py:76 -- Checkpoint size is 166763030 bytes


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,10.6647,0.0201953,5.57422,0.477832,13.3203,0.489053,3
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.37107,0.0591715,22.6719,0.426226,23.3125,0.466429,3
train_LensGAN128_bf0ed_00003,PENDING,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2


[2m[36m(pid=8782)[0m 2021-09-13 20:56:59,175	INFO trainable.py:383 -- Restored on 172.28.0.2 from checkpoint: /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00002_2_n_fmaps=32_2021-09-13_20-47-47/checkpoint_tmp01d4fe/./
[2m[36m(pid=8782)[0m 2021-09-13 20:56:59,175	INFO trainable.py:390 -- Current state after restoring: {'_iteration': 2, '_timesteps_total': None, '_time_total': 96.60048842430115, '_episodes_total': None}
[2m[36m(pid=8782)[0m Using native 16bit precision.
[2m[36m(pid=8782)[0m GPU available: True, used: True
[2m[36m(pid=8782)[0m TPU available: False, using: 0 TPU cores
[2m[36m(pid=8782)[0m IPU available: False, using: 0 IPUs
[2m[36m(pid=8782)[0m Restoring states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00002_2_n_fmaps=32_2021-09-13_20-47-47/checkpoint_tmp01d4fe/./checkpoint
[2m[36m(pid=8782)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Trial name,status,loc,lr,n_fmaps,bs,loss_G,loss_D,FID,auroc,FID_cross,auroc_cross,training_iteration
train_LensGAN128_bf0ed_00002,RUNNING,,0.001,32,8,1.97319,0.412842,11.2734,0.502809,21.2656,0.44414,2
train_LensGAN128_bf0ed_00000,PAUSED,,0.001,8,8,10.6647,0.0201953,5.57422,0.477832,13.3203,0.489053,3
train_LensGAN128_bf0ed_00001,PAUSED,,0.001,16,8,8.37107,0.0591715,22.6719,0.426226,23.3125,0.466429,3
train_LensGAN128_bf0ed_00003,PENDING,,0.0012,32,64,1.67659,1.36581,24.5938,0.470182,30.5312,0.485482,2


[2m[36m(pid=8782)[0m Restored all states from the checkpoint file at /content/drive/MyDrive/Logs/F/LensGAN128/onehot_leaky_dropout/train_LensGAN128_bf0ed_00002_2_n_fmaps=32_2021-09-13_20-47-47/checkpoint_tmp01d4fe/./checkpoint
[2m[36m(pid=8782)[0m 
[2m[36m(pid=8782)[0m   | Name          | Type               | Params
[2m[36m(pid=8782)[0m -----------------------------------------------------
[2m[36m(pid=8782)[0m 0 | generator     | DCGANGenerator     | 3.6 M 
[2m[36m(pid=8782)[0m 1 | discriminator | DCGANDiscriminator | 2.8 M 
[2m[36m(pid=8782)[0m 2 | criterion     | BCEWithLogitsLoss  | 0     
[2m[36m(pid=8782)[0m 3 | modelF        | ResNet             | 11.2 M
[2m[36m(pid=8782)[0m 4 | lastF         | Sequential         | 1.5 K 
[2m[36m(pid=8782)[0m 5 | modelJ        | ResNet             | 11.2 M
[2m[36m(pid=8782)[0m 6 | lastJ         | Sequential         | 1.5 K 
[2m[36m(pid=8782)[0m 7 | imgMetrics    | MetricCollection   | 22.3 M
[2m[36m(pid=8782)

[2m[36m(pid=8782)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]


[2m[36m(pid=8782)[0m   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[2m[36m(pid=8782)[0m   "`Trainer.running_sanity_check` has been renamed to `Trainer.sanity_checking` and will be removed in v1.5."


Epoch 2:   0%|          | 0/514 [00:00<00:00, 1927.53it/s]  
Epoch 2:  10%|▉         | 51/514 [00:10<01:33,  4.96it/s, loss=2.57, v_num=.]
Epoch 2:  20%|█▉        | 102/514 [00:19<01:19,  5.16it/s, loss=3.35, v_num=.]
Epoch 2:  30%|██▉       | 153/514 [00:28<01:06,  5.41it/s, loss=3.11, v_num=.]
Epoch 2:  40%|███▉      | 204/514 [00:33<00:50,  6.17it/s, loss=2.65, v_num=.]
Epoch 2:  50%|████▉     | 255/514 [00:36<00:37,  6.98it/s, loss=2.96, v_num=.]
Epoch 2:  60%|█████▉    | 306/514 [00:40<00:27,  7.62it/s, loss=3.2, v_num=.] 


In [None]:
# Flush any pending writes and unmount the Google Drive mounted under
# /content/drive (where the Tune logs/checkpoints above are written).
# NOTE(review): `drive` is presumably `google.colab.drive`, imported in an
# earlier cell -- confirm.
drive.flush_and_unmount()

## Stage 2
Here we tune hyperparameters as we train our modified DCGAN.

In [None]:
# __tune_train_checkpoint_begin__
def train_Stage2(config, checkpoint_dir=None, num_epochs=10, num_gpus=torch.cuda.device_count()):
    """Train a single Stage2 model as one Ray Tune trial.

    Builds a PyTorch Lightning ``Trainer`` that logs TensorBoard events into the
    Tune trial directory and reports metrics + checkpoints back to Tune, then
    fits ``Stage2(config)`` on ``npyImageData(config)``.

    Args:
        config: Trial hyperparameters. ``config['batch_size']`` is read here;
            the whole dict is forwarded to ``npyImageData`` and ``Stage2``.
        checkpoint_dir: Directory handed in by Tune when the trial is restored
            from a checkpoint. If given, Lightning resumes from
            ``<checkpoint_dir>/checkpoint``.
        num_epochs: Upper bound on training epochs (``max_epochs``).
        num_gpus: GPUs allotted to the trial; a fractional value is rounded
            up. NOTE: the default is evaluated once, when this cell runs.
    """
    # print(os.cpu_count(), torch.cuda.device_count())

    # Refresh the progress bar roughly once per 8250 samples (meaning of the
    # constant not shown here -- presumably a dataset-size heuristic; confirm).
    # True division is required: ``math.ceil(8250 // bs)`` floored first, which
    # made ``ceil`` a no-op and produced 0 (progress bar disabled) whenever
    # batch_size > 8250.
    refresh_rate = math.ceil(8250 / config['batch_size'])

    kwargs = {
        # 'limit_train_batches' : 0.05,
        # 'limit_val_batches' : 0.05,
        'progress_bar_refresh_rate' : refresh_rate,
        'max_epochs' : num_epochs,
        'prepare_data_per_node' : False,
        # If fractional GPUs passed in, convert to int.
        'gpus' : math.ceil(num_gpus),
        # Log straight into the Tune trial directory (no extra version subdir).
        'logger' : TensorBoardLogger(save_dir=tune.get_trial_dir(), name='', version='.'),
        'callbacks' : [
            # Maps Tune metric names -> metric names logged by the LightningModule,
            # and saves a checkpoint alongside each report.
            TuneReportCheckpointCallback(
                {
                    'loss_G': 'Stage2/G/train/loss',
                    'loss_D': 'Stage2/D/train/loss',
                    # Switch up the auroc values when training on a different dataset ---------------------------------------
                    'auroc': 'Stage2/ResNet(F)/val/auroc',
                    'auroc_cross': 'Stage2/ResNet(J)/val/auroc',
                },
            ),
        ],
        # 'stochastic_weight_avg' : True,
        # works with only one optimizer
        # 'benchmark' : True,
    }

    dm = npyImageData(config)                                              # Specify image width here
    if checkpoint_dir is not None:
        # A fresh model is built below; the Trainer restores the full training
        # state from this file inside fit(). 'checkpoint' is the default file
        # name written by TuneReportCheckpointCallback.
        kwargs['resume_from_checkpoint'] = os.path.join(checkpoint_dir, 'checkpoint')
    model = Stage2(config)
    trainer = pl.Trainer(**kwargs)

    trainer.fit(model, dm)
# __tune_train_checkpoint_end__


# # # __tune_asha_begin__
# def tune_Stage2_asha(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
#     # print(os.cpu_count(), torch.cuda.device_count())
#     analysis = tune.run(
#         tune.with_parameters(
#             train_Stage2,
#             num_epochs=num_epochs,
#             num_gpus=gpus_per_trial
#         ),
#         # Change the folder name when changing dataset--------------------------------------------------------------------------
#         name='Stage2/pbt/J',
#         metric='auroc',
#         mode='max',
#         config={'learning_rate': 1e-4,
#                 'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
#                 'batch_size': 8,
#                 },
#         # config={'learning_rate': 0.01,
#         #         'n_fmaps': 32,
#         #         'batch_size': 32,
#         #         },
#         # stop=TrialPlateauStopper('loss_G'),
#         resources_per_trial={'cpu': os.cpu_count(),
#                              'gpu': gpus_per_trial,
#                             },
#         local_dir='./drive/MyDrive/Logs',
#         scheduler = ASHAScheduler(max_t=num_epochs, grace_period=2,  reduction_factor=2),
#         progress_reporter=JupyterNotebookReporter(
#             overwrite=True,
#             parameter_columns=['learning_rate', 'n_fmaps', 'batch_size'],
#             metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
#             sort_by_metric=True,
#         ),
#         fail_fast = True,
#         # reuse_actors=True,
#         # num_samples=num_samples,
#         resume='PROMPT',
# #         restore='/content/drive/MyDrive/Logs/delete/train_Stage2_e42ac_00025_25_batch_size=8,learning_rate=0.01,n_fmaps=8_2021-07-28_21-16-18/checkpoint_epoch=4-step=2339',
#     )

# #     print('Best hyperparameters found were: ', analysis.best_config)

# # # __tune_asha_end__


# __tune_pbt_begin__
def tune_Stage2_pbt(num_samples=10, num_epochs=10, gpus_per_trial=torch.cuda.device_count()):
    """Run a Population Based Training search over Stage2 hyperparameters.

    Launches ``train_Stage2`` trials through Ray Tune, maximising the
    ``'auroc'`` metric, and prints the best configuration once the run ends.

    Args:
        num_samples: Currently unused -- ``tune.run`` is not given a
            ``num_samples`` argument, so one trial runs per grid point.
        num_epochs: Forwarded to ``train_Stage2`` as ``num_epochs``.
        gpus_per_trial: GPUs reserved per trial, also forwarded to
            ``train_Stage2`` as ``num_gpus``. NOTE: the default is evaluated
            once, when this cell runs.
    """
    trainable = tune.with_parameters(
        train_Stage2,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    # Starting point of the search: a grid over n_fmaps, with res_depth drawn
    # at random for each grid point.
    search_space = {
        'learning_rate': 1e-4,
        'n_fmaps': tune.grid_search([8, 16, 32, 64, 128]),
        'res_depth': tune.choice([1, 2, 3, 4]),
        'batch_size': 8,
    }

    # Hyperparameters PBT may perturb/resample; perturbation_interval=1 makes
    # that decision after every training_iteration.
    mutations = {
        'learning_rate': tune.loguniform(1e-7, 1e-1),
        'batch_size': [8, 16, 32, 64, 128],
    }
    pbt_scheduler = PopulationBasedTraining(
        time_attr='training_iteration',
        perturbation_interval=1,
        quantile_fraction=0.5,
        resample_probability=0.8,
        hyperparam_mutations=mutations,
    )

    reporter = JupyterNotebookReporter(
        overwrite=False,
        parameter_columns=['learning_rate', 'n_fmaps', 'res_depth', 'batch_size'],
        metric_columns=['loss_G', 'loss_D', 'auroc', 'auroc_cross', 'training_iteration'],
        sort_by_metric=True,
    )

    per_trial_resources = {'cpu': os.cpu_count(), 'gpu': gpus_per_trial}

    analysis = tune.run(
        trainable,
        # Change the experiment folder name when switching datasets.
        name='Stage2/pbt/F',
        metric='auroc',
        mode='max',
        config=search_space,
        resources_per_trial=per_trial_resources,
        local_dir='./drive/MyDrive/Logs',
        scheduler=pbt_scheduler,
        progress_reporter=reporter,
        fail_fast=True,
        resume='PROMPT',
    )

    print('Best hyperparameters found were: ', analysis.best_config)

# __tune_pbt_end__


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    # parse_known_args (not parse_args) tolerates unrecognised argv entries,
    # presumably so extra arguments passed by a notebook kernel don't error out.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # The ASHA variant (tune_Stage2_asha) is commented out in this cell, so
        # calling it here raised NameError before the PBT smoke test could run.
        # Re-enable this call together with its definition.
        # tune_Stage2_asha(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
        tune_Stage2_pbt(num_samples=1, num_epochs=6, gpus_per_trial=torch.cuda.device_count())
    else:
        # ASHA scheduler
        # tune_Stage2_asha(num_samples=1, num_epochs=10, gpus_per_trial=torch.cuda.device_count())
        # Population based training
        tune_Stage2_pbt(num_samples=1, num_epochs=30, gpus_per_trial=torch.cuda.device_count())