# PyTorch metrics playground

Author: Jacob A Rose  
Created: Monday, April 19th, 2021

In [1]:
%load_ext autoreload 
%autoreload 2

%matplotlib inline


from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# InteractiveShell.ast_node_interactivity = "last_expr"

In [1]:
from pathlib import Path

In [3]:
path = 'fam/image.png'

Path(path).name

'image.png'

In [2]:
from IPython.core.debugger import set_trace

import os
import types
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import numpy as np
import pytorch_lightning as pl
from torchvision import models
import torchvision
import torch
import timm
from rich import print
import matplotlib.pyplot as plt
from contrastive_learning.data.pytorch.pnas import PNASLightningDataModule
from contrastive_learning.data.pytorch.extant import ExtantLightningDataModule
from contrastive_learning.data.pytorch.common import DataStageError, colorbar

from lightning_hydra_classifiers.callbacks.wandb_callbacks import WatchModelWithWandb, LogPerClassMetricsToWandb, WandbClassificationCallback # LogConfusionMatrixToWandb
from lightning_hydra_classifiers.models.resnet import ResNet, get_scalar_metrics
import lightning_hydra_classifiers
from torch import nn
import inspect

import wandb
pl.trainer.seed_everything(seed=9)

    
class Config:
    """Empty attribute container; experiment settings are attached to an instance below."""
    pass


# Single settings object shared by the rest of the notebook.
config = Config()

config.model_name = 'resnet50'
# config.dataset_name = 'PNAS_family_100_512'
# The dataset name is also what get_datamodule() inspects to pick a datamodule class.
config.dataset_name = 'Extant_family_10_512'
config.normalize = True
config.num_workers = 4
config.batch_size = 48
config.debug=False
########################################
def get_datamodule(config):
    """Build the Lightning datamodule selected by ``config.dataset_name``.

    The dataset family is chosen by substring match on the name:
    ``'Extant'`` selects ``ExtantLightningDataModule`` and ``'PNAS'`` selects
    ``PNASLightningDataModule`` (checked in that order).

    Parameters
    ----------
    config : object
        Must expose ``dataset_name``, ``batch_size``, ``debug``, ``normalize``
        and ``num_workers`` attributes.

    Returns
    -------
    The constructed datamodule instance.

    Raises
    ------
    ValueError
        If ``config.dataset_name`` matches neither supported family.
    """
    dataset_name = config.dataset_name

    if 'Extant' in dataset_name:
        datamodule_cls = ExtantLightningDataModule
    elif 'PNAS' in dataset_name:
        datamodule_cls = PNASLightningDataModule
    else:
        # Fail loudly instead of falling through to an unbound local variable.
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}: expected it to contain 'Extant' or 'PNAS'."
        )

    # Both datamodule classes are constructed with the same keyword arguments.
    return datamodule_cls(name=dataset_name,
                          batch_size=config.batch_size,
                          debug=config.debug,
                          normalize=config.normalize,
                          num_workers=config.num_workers)
    
# Instantiate the datamodule chosen by config.dataset_name.
datamodule = get_datamodule(config)
        
# Run setup for both stages up front so the later cells can request
# train/val and test data without calling setup again.
datamodule.setup('fit')
datamodule.setup('test')
########################################
# Record the class count on the config; the model cell below reads config.num_classes.
num_classes = len(datamodule.classes)
config.num_classes = num_classes

Global seed set to 9


In [3]:
classes = datamodule.classes

In [4]:
# Sanity check: only these two class counts are expected by this notebook.
# Uses a membership test (instead of bitwise `|` on booleans) and reports the
# offending value when the check fails.
assert num_classes in (19, 179), f"Unexpected number of classes: {num_classes}"

########################################
model = ResNet(model_name=config.model_name, num_classes=config.num_classes)
# Reset the classifier head for this dataset's class count, then unfreeze layer4.
model.reset_classifier(config.num_classes,'avg')
model.unfreeze(model.layer4)
####################################################

2048 100352


In [5]:
# NOTE(review): absolute, machine-specific path to a checkpoint from an earlier W&B run.
# Within this cell it is only referenced by the commented-out loading code below.
ckpt_path = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/wandb/run-20210508_041808-2dj05hbk/files/default/2dj05hbk/checkpoints/epoch=19-step=11259.ckpt"
# ckpt = torch.load(ckpt_path)
# loaded_model = model.load_from_checkpoint(ckpt_path)

# NOTE(review): later cells use `loaded_model` and `batch`; with the lines in this
# cell commented out, nothing visible in the notebook defines them.
# test_dataloader = datamodule.test_dataloader()
# num_batches = len(test_dataloader)

# print(num_batches, "batches")
# batch = next(iter(test_dataloader))
# print(f'batch_size = {batch[0].shape[0]}')

In [6]:

class ImagePredictionLogger(pl.Callback):
    """Callback that logs a fixed batch of images with predictions to W&B each epoch.

    Parameters
    ----------
    val_samples : tuple
        ``(images, labels)`` pair captured once at construction time and
        reused on every epoch.
    num_samples : int
        Maximum number of images from that batch to log per epoch.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_epoch_end(self, trainer, pl_module):
        """Run the module on the cached batch and log captioned images.

        Images are logged under ``train/examples`` or ``val/examples``
        depending on ``pl_module.training`` at the time the hook fires.
        """
        # Move the cached batch onto the device the module currently lives on.
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)

        subset = 'train' if pl_module.training else 'val'

        # The predictions are only used for logging, so skip building the
        # autograd graph for this forward pass.
        with torch.no_grad():
            logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

        # Log the images as wandb Image objects, captioned with prediction and label.
        trainer.logger.experiment.log({
            f"{subset}/examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            })


# NOTE(review): this changes the datamodule's batch_size attribute for everything
# that follows (it was built with config.batch_size); confirm that is intended.
datamodule.batch_size = 32
val_dataloader = datamodule.val_dataloader()
        
# val_dataloader.batch_size = 32
# Use the first validation batch as the fixed sample set logged by the callback.
image_pred_logger_cb = ImagePredictionLogger(next(iter(val_dataloader)))

In [7]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pl_bolts.callbacks import ModuleDataMonitor, BatchGradientVerificationCallback
# verification = BatchGradientVerificationCallback()


# Start the W&B run explicitly; the WandbLogger created below is given the same run name.
wandb.init(name=f"{config.dataset_name}-timm-{config.model_name}",
           group='baselines',
           config=config)

monitor = ModuleDataMonitor()#log_every_n_steps=25)
per_class_metric_plots_cb = LogPerClassMetricsToWandb()
# Stop training once 'val_loss' has not decreased for 3 consecutive checks.
early_stop_callback = EarlyStopping(
                                    monitor='val_loss',
                                    patience=3,
                                    verbose=False,
                                    mode='min'
                                    )

logger=pl.loggers.wandb.WandbLogger(name=f"{config.dataset_name}-timm-{config.model_name}", config=config)
# NOTE(review): `filepath` is not referenced anywhere else in the visible notebook
# and no checkpoint callback is constructed here -- confirm whether it was meant
# to be passed to one.
filepath = wandb.run.dir + "/epoch={epoch:02d}-val_acc={val/acc/top1:.2f}.ckpt"

trainer = pl.Trainer(gpus=1,
                     logger=logger,
                     max_epochs=40,
                     weights_summary='top',#,
                     profiler="simple", #"advanced", #
                     callbacks=[per_class_metric_plots_cb,
                                image_pred_logger_cb,
                                monitor,
                                early_stop_callback])

[34m[1mwandb[0m: Currently logged in as: [33mjrose[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [8]:
# %%wandb
# Train, then evaluate on the test split with the same model object.
trainer.fit(model, datamodule)
results = trainer.test(model, datamodule=datamodule)
# trainer.callbacks[-1].best_model_path
# NOTE(review): the last callback passed to pl.Trainer above is the EarlyStopping
# instance; this line presumably relies on Lightning appending its own checkpoint
# callback at the end of trainer.callbacks -- verify for the installed version.
print(trainer.callbacks[-1].best_model_path,
trainer.callbacks[-1].best_model_score.cpu().numpy())
# ckpt_path = "/media/data/jacob/GitHub/prj_fossils_contrastive/notebooks/wandb/run-20210508_041808-2dj05hbk/files/default/2dj05hbk/checkpoints/epoch=19-step=11259.ckpt"
# Close the W&B run opened by wandb.init().
wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type              | Params | In sizes          | Out sizes        
---------------------------------------------------------------------------------------------
0  | relu          | ReLU              | 0      | [2, 64, 112, 112] | [2, 64, 112, 112]
1  | conv1         | Conv2d            | 9.4 K  | [2, 3, 224, 224]  | [2, 64, 112, 112]
2  | bn1           | BatchNorm2d       | 128    | [2, 64, 112, 112] | [2, 64, 112, 112]
3  | maxpool       | MaxPool2d         | 0      | [2, 64, 112, 112] | [2, 64, 56, 56]  
4  | layer1        | Sequential        | 215 K  | [2, 64, 56, 56]   | [2, 256, 56, 56] 
5  | layer2        | Sequential        | 1.2 M  | [2, 256, 56, 56]  | [2, 512, 28, 28] 
6  | layer3        | Sequential        | 7.1 M  | [2, 512, 28, 28]  | [2, 1024, 14, 14]
7  | layer4        | Sequential        | 15.0 M | [2, 1024, 14, 14] | [2, 2048, 7, 7]  
8  | stem          | Sequential        | 9.5 K  | [2, 3, 224, 224]  | [

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  3497.6         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  436.09         	|8              	|  3488.8         	|  99.746         	|
get_train_batch                    	|  0.70161        	|3376           	|  2368.6         	|  67.721         	|
run_training_batch                 	|  0.11024        	|3376           	|  372.17         	|  10.641         	|
model_forward                      	|  0.10999        	|3376           	|  371.34         	|  10.617         	|
training_step_end                  	|  0.

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc/top1': 0.5878283977508545,
 'test/acc/top3': 0.7602733969688416,
 'test/precision/top1': 0.4063551723957062,
 'test/recall/top1': 0.3104079067707062,
 'test_loss': 2.1013941764831543}
--------------------------------------------------------------------------------


VBox(children=(Label(value=' 174.04MB of 174.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=…

0,1
val/batch_idx,105.0
val/acc/top1,0.85498
val/acc/top3,0.94632
val/precision/top1,0.78202
val/recall/top1,0.74961
_runtime,3814.0
_timestamp,1621041055.0
_step,5572.0
epoch,7.0
train/batch_idx,421.0


0,1
val/batch_idx,▁▂▄▆▇▁▃▄▆▇▁▃▄▆▇▁▃▄▆▇▁▃▅▆▇▂▃▄▆█▂▃▅▆█▂▃▅▆█
val/acc/top1,▁▂▂▁▃▃▄▃▃▂▆▅▅▆▃▆▅▆▄▅▃▆▆▆▆▆▆▆▆▆▆▅▆█▇▅▆▆▆█
val/acc/top3,▂▁▂▁▄▄▄▃▃▃▇▆▅▇▆▅▇▇▆▆▅▆▇█▇▆▆▇▆▆▆▇▆▇▇▆▅▇▇▇
val/precision/top1,▁▁▁▂▂▃▂▃▂▁▅▅▄▁▃▆▅▆▄▄▄▅▄█▆▆▄▄▅▆▅▅▇▆▇▅▆▅▇▇
val/recall/top1,▁▁▁▂▂▃▃▃▂▂▆▅▅▁▄▇▅▅▄▅▄▅▄█▆▆▅▄▅▆▆▆▇▆▇▅▇▅▇▇
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
train/batch_idx,▁▂▄▆▇▁▃▄▆▇▁▃▄▆▇▁▃▄▆▇▁▃▄▆█▁▃▅▆█▂▃▅▆█▂▃▅▆█


## Extra: top-k metrics

In [None]:
def select_k_highest_loss_images(batch, model, topk=5):
    """Rank the samples of a batch by cross-entropy loss and collect top-k predictions.

    ``topk`` plays two roles: it is the number of highest-loss samples that
    are selected *and* the number of top-scoring classes kept per sample.

    The prediction outputs (``topk_pred_labels``, ``topk_logits``,
    ``topk_y_pred``, ``topk_y_prob``) cover *every* sample in the batch, in
    batch order; only ``topk_losses``, ``true_labels`` and ``y_true`` are
    restricted to the selected samples. Use ``select_idx`` to index the
    per-batch outputs and ``x`` when the selected rows are needed.

    Class indices are translated to label names through the notebook-level
    ``classes`` sequence.

    Parameters
    ----------
    batch : tuple
        ``(x, y)``; ``x`` is passed to ``model`` and ``y`` holds the integer targets.
    model : torch.nn.Module
        Switched to eval mode (side effect) and called as ``model(x)`` to obtain logits.
    topk : int
        See above.

    Returns
    -------
    tuple
        ``(topk_losses, topk_pred_labels, topk_logits, topk_y_pred,
        topk_y_prob, true_labels, y_true, x, select_idx)``.
        ``topk_pred_labels`` is a list of lists of label names (one inner
        list per sample); ``topk_y_prob`` is a list of arrays (one per sample);
        the remaining array-like outputs are NumPy arrays.
    """
    x, y = batch
    model.eval()

    loss_fn = nn.CrossEntropyLoss(reduction='none')

    # Pure inference: no autograd graph is needed for ranking/plotting.
    with torch.no_grad():
        logits = model(x)
        loss = loss_fn(logits, y.long())
        y_prob = logits.softmax(dim=-1)
        topk_logits, topk_y_pred = logits.topk(topk)

    # Highest-loss samples first.
    sorted_losses, sorted_idx = torch.sort(loss, descending=True)
    top_idx = sorted_idx[:topk].cpu()

    # `.cpu()` is a no-op for CPU tensors and makes `.numpy()` valid otherwise.
    topk_losses = sorted_losses[:topk].cpu().numpy()
    select_idx = top_idx.numpy()
    y_true = y.cpu()[top_idx].numpy()

    y_prob = y_prob.cpu()
    topk_logits = topk_logits.cpu().numpy()
    topk_y_pred = topk_y_pred.cpu().numpy()
    topk_y_prob = [y_prob[i, topk_y_pred[i, :]].numpy() for i in range(len(topk_y_pred))]

    # Real lists (not single-use generators) so callers can index, measure and
    # iterate the labels more than once.
    topk_pred_labels = [[classes[int(k)] for k in y_pred_i]
                        for y_pred_i in topk_y_pred]
    true_labels = [classes[int(k)] for k in y_true]

    return (topk_losses,
            topk_pred_labels,
            topk_logits,
            topk_y_pred,
            topk_y_prob,
            true_labels,
            y_true,
            x,
            select_idx)
    
    
# NOTE(review): `batch` and `loaded_model` are only assigned in commented-out code
# earlier in the notebook, so this cell depends on state from a previous session.
topk = 3
    
topk_losses, topk_pred_labels, topk_logits, topk_y_pred, topk_y_prob, true_labels, y_true, x, select_idx = \
                        select_k_highest_loss_images(batch=batch, model=loaded_model, topk=topk)

# Quick size check of each returned item.
print(len(topk_pred_labels),
      len(topk_logits),
      topk_y_pred.shape,
      len(topk_y_prob),
      len(true_labels),
      y_true.shape,
      x.shape,
     len(select_idx))
    
    
# for i, batch in enumerate(iter(test_dataloader)):
#     print(i)
#     print(batch[0].shape)
#     if i >= num_batches:
#         break

# logit = logits[i,:]

def plot_example_topk_preds(x, logit, y_true, topk, classes, cmap: str='cividis', grayscale=True):
    """Show one image next to a horizontal bar chart of its top-k class probabilities.

    Parameters
    ----------
    x : torch.Tensor
        Single image; permuted with ``(1, 2, 0)`` before display, so it is
        expected channel-first.
    logit : torch.Tensor
        1-D score vector for this image. Probabilities come from a softmax
        over exactly this vector, and the top-k positions in it are used as
        indices into ``classes``.
    y_true : torch.Tensor
        Scalar tensor with the true class index.
    topk : int
        Number of highest-probability classes to plot; the remaining
        probability mass is plotted as ``"other"``.
    classes : sequence
        Maps class index to label.
    cmap : str
        Matplotlib colormap used for the image panel.
    grayscale : bool
        If true, only the first channel of the image is displayed.

    Returns
    -------
    (fig, ax)
        The created figure and its array of two axes (image, bar chart).
    """
    true_label = classes[int(y_true.item())]

    y_prob = logit.softmax(dim=-1)
    _, topk_idx = logit.topk(topk)
    topk_y_prob = y_prob[topk_idx].detach().numpy()
    topk_pred_labels = [classes[k] for k in topk_idx.numpy()]

    fig, ax = plt.subplots(1,2, figsize=(15,10))

    img = x.permute(1,2,0)
    # Colour limits are taken over all channels, before any channel is dropped.
    img_min, img_max = img.min(), img.max()

    if grayscale and len(img.shape)==3:
        img = img[:,:,0]

    ax[0].imshow(img, cmap=cmap, vmin = img_min, vmax = img_max)

    # Reverse so the most probable class ends up at the top of the bar chart.
    bar_x = [*topk_pred_labels, "other"][::-1]
    bar_y = [*topk_y_prob, (1 - sum(topk_y_prob))][::-1]
    ax[1].barh(bar_x, bar_y)
    ax[1].set_xlim([0,1])

    # Title and layout are applied to the figure created here rather than to
    # whatever pyplot currently considers the active figure.
    fig.suptitle(f'True label: {true_label}')
    fig.tight_layout()
    return fig, ax

topk_y_pred

i = 0
topk = 3
num_predictions = 5
seed = 387
batch_size = x.shape[0]


# NOTE(review): `batch` and `loaded_model` are only assigned in commented-out code
# earlier in the notebook.
topk_losses, topk_pred_labels, topk_logits, topk_y_pred, topk_y_prob, true_labels, y_true, x, select_idx = select_k_highest_loss_images(batch=batch, model=loaded_model, topk=topk)

# plt.style.available
# plt.style.use('tableau-colorblind10')
plt.style.use('seaborn-pastel')

# Per-class sample counts of the test split.
class_idx, class_counts = np.unique(datamodule.test_dataset.targets, return_counts=True)
sorted_class_counts = np.argsort(class_counts)#, reverse=True)
# sorted_class_counts[class_counts[[10,11]]]

# NOTE(review): heights are the counts sorted in descending order while the x
# positions are still `class_idx`, so a bar's position does not identify its class.
plt.bar(class_idx, sorted(class_counts, reverse=True), log=True, color='blue')
# NOTE(review): indexing positions 23 and 56 requires at least 57 classes.
plt.bar([23,56],sorted(class_counts[[23,56]], reverse=True), log=True, color='red')

# NOTE(review): `x[i]` is the i-th sample of the batch, whereas `y_true[i]` is the
# target of sample `select_idx[i]`, so image and label only match by coincidence.
# Also, `topk_logits[i, :]` holds just the top-k logit values, so inside
# plot_example_topk_preds the softmax and the class lookup run over positions
# 0..topk-1 rather than over the real class indices.
for i in range(len(select_idx)):
    plot_example_topk_preds(x[i,...], logit=torch.Tensor(topk_logits[i,:]), y_true=torch.Tensor(y_true)[i], topk=topk, classes=classes)

# print(x[i,...].shape, topk_logits[i,:].shape, y_true[i].shape, topk,len(classes))
# print(type(x[i,...]), topk_logits[i,:], y_true[i], topk,len(classes))
i = 0
topk = 3
num_predictions = 3
seed = 387
batch_size = x.shape[0]
# Seeded generator so the randomly chosen batch indices are reproducible.
rng = np.random.default_rng(seed)
batch_idx = rng.choice(batch_size, num_predictions, replace=False)
# NOTE(review): the plotting loop that would use `batch_idx` is commented out below,
# so this message is printed without any plot being produced.
print(f'Plotting top {topk} predictions for randomly chosen images at batch indices: {batch_idx}')
# for i in select_idx:
#     plot_example_topk_preds(x[i,...], logits[i,:], y_true=y[i], topk=topk, classes=classes)

In [None]:
# # _, argsorted_logit = torch.sort(logit, descending=True)
# # false_negative_rank = argsorted_logit[y[0]]
# print(topk_logits.shape, topk_y_pred.shape)
# print(topk_logits, topk_y_pred)

# print(f"True: {true_labels}")
# print(f"Misclassified by ranking true label {false_negative_rank} out of {num_classes} classes")
# print(f"Pred Top {topk}: {topk_pred_labels}")
# print(f"Prob Top {topk}: {topk_y_prob}")


# print(f"True: {true_labels}")
# print(f"Misclassified by ranking true label {false_negative_rank} out of {num_classes} classes")
# print(f"Pred Top {topk}: {topk_pred_labels}")
# print(f"Prob Top {topk}: {topk_y_prob}")

# num_predictions = 10
# print(d[0].shape, d[1].shape)
# print(d[0][:num_predictions,:], d[1][:num_predictions,:])

# y.view(-1,1).dtype#.shape

# preds.float()

# x, y = batch
# # x, y = x[:1], y[:1]
# print(x.shape, y.shape)
# logits = loaded_model(x)
# preds = torch.argmax(logits, -1)
# loss_fn = nn.CrossEntropyLoss(reduction='none')
# print(logits.shape, preds.shape, y.shape, loss.shape)
# loss = loss_fn(logits, y.long())#view(-1,1).float())
# print(logits.shape, preds.shape, y.shape, loss.shape)

## Plot confused images

In [None]:
x, y = datamodule.get_batch()
plt.imshow(x[0,...].permute(1,2,0))
plt.suptitle(f'True label: {datamodule.classes[y[0]]}')

# The original cell called an undefined name `loaded`; every other cell refers
# to the checkpointed model as `loaded_model`.
# NOTE(review): `loaded_model` itself is only assigned in commented-out code
# earlier in the notebook.
with torch.no_grad():  # inference only, no autograd graph needed
    y_pred_loaded = loaded_model(x)

num_predictions = 5
topk = 3
y_prob_topk, y_idx_topk = y_pred_loaded.softmax(dim=1).topk(topk)

fig, ax = plt.subplots(3,3, figsize=(15,15))
ax = ax.flatten()

for i in range(num_predictions):
    x_i = x[i,...].permute(1,2,0)
    # Min-max rescale each image to [0, 1] for display.
    x_i = (x_i-x_i.min())/(x_i.max()-x_i.min())
    ax[i].imshow(x_i)
    title = f"True: {classes[y[i]]}\n"
    title += f'Prob: {[classes[idx] for idx in y_idx_topk[i,:]]}'
    ax[i].set_title(title)
    
plt.tight_layout()

In [None]:
y_idx_topk.shape

### Scratch

In [None]:
# pl.optim.LightningOptimizer

%debug

# x, y = datamodule.get_batch()
# opt = model.configure_optimizers()

# def get_stats(y: torch.Tensor):
#     try:
#         y = y.detach().numpy()
#     except AttributeError:
#         pass
#     print('shape:', y.shape,
#           f'mean: {np.mean(y):.2f}',
#           f'std: {np.std(y):.2f}',
#           f'min: {np.min(y):.2f}',
#           f'max: {np.max(y):.2f}')
import json


class StatTensor:
    """Thin wrapper around a tensor that reports rounded summary statistics.

    Parameters
    ----------
    data
        A tensor, or any object exposing a ``data`` attribute (for example
        another ``StatTensor``), in which case that attribute is wrapped.
    decimals : int
        Number of decimals used when rounding the reported statistics.
    *args, **kwargs
        Accepted for call compatibility and ignored.
    """

    def __init__(self, data, decimals=3, *args, **kwargs):
        # Unwrap anything that carries its payload in `.data`.
        self._data = getattr(data, 'data', data)
        self.decimals = decimals

    @property
    def data(self):
        """The wrapped tensor (read-only property)."""
        return self._data

    @property
    def min(self):
        """Smallest element as a Python number."""
        return torch.min(self.data).item()

    @property
    def max(self):
        """Largest element as a Python number."""
        return torch.max(self.data).item()

    def __add__(self, other):
        """Return a new ``StatTensor`` holding the element-wise sum.

        ``other`` may be a ``StatTensor`` or anything that can be added to the
        wrapped tensor. Neither operand is modified. The result keeps the
        larger ``decimals`` of the two operands.

        The previous implementation used ``self.data += ...``; because ``data``
        is a property without a setter, that mutated the wrapped tensor in
        place and then raised ``AttributeError``.
        """
        other_data = getattr(other, 'data', other)
        other_decimals = getattr(other, 'decimals', self.decimals)
        return type(self)(self.data + other_data,
                          decimals=max(self.decimals, other_decimals))

    def get_stats(self):
        """Return a dict with ``shape``, ``mean``, ``std``, ``min`` and ``max``.

        All entries except ``shape`` are rounded to ``self.decimals`` places.
        """
        data = self.data
        stats = {'shape':data.shape,
                 'mean':torch.mean(data).item(),
                 'std':torch.std(data).item(),
                 'min':torch.min(data).item(),
                 'max':torch.max(data).item()}

        return self._format_stats(stats)

    def _format_stats(self, stats):
        """Round every entry except ``shape`` in place and return the same dict."""
        for k in list(stats.keys()):
            if k == 'shape':
                continue
            stats[k] = np.round(stats[k], decimals=self.decimals)
        return stats

    def __repr__(self):
        return json.dumps(self.get_stats())
    
import torchmetrics as metrics
    
class AverageStatTensor:
    """Accumulates summary statistics over a stream of tensors.

    Every tensor passed to ``update`` is wrapped in a ``StatTensor`` and kept
    in ``data``. The per-tensor ``mean`` and ``std`` are fed into two
    ``metrics.AverageMeter`` instances, while ``min`` and ``max`` track the
    extremes seen so far. Note that the reported ``std`` is therefore an
    aggregate of per-tensor standard deviations, not a pooled one.
    """

    def __init__(self):
        self._data = []
        self.mean = metrics.AverageMeter()
        self.std = metrics.AverageMeter()
        # Neutral starting points for the running extremes.
        self.min = np.inf
        self.max = -np.inf
        self.decimals = 3

    @property
    def data(self):
        """List of every ``StatTensor`` received so far."""
        return self._data

    def _update_metrics(self, data):
        """Fold one ``StatTensor`` into the running meters and extremes."""
        stats = data.get_stats()
        for meter, key in ((self.mean, 'mean'), (self.std, 'std')):
            meter.update(stats[key])
        self.min = np.min([self.min, data.min])
        self.max = np.max([self.max, data.max])

    def update(self, data):
        """Record a new tensor and return the refreshed summary dict."""
        latest = StatTensor(data, 3)
        self._data.append(latest)
        self._update_metrics(latest)
        return self.compute()

    def _format_stats(self, stats):
        """Round every entry except ``shape`` in place and return the same dict."""
        for key in list(stats):
            if key != 'shape':
                stats[key] = np.round(stats[key], decimals=self.decimals)
        return stats

    def compute(self):
        """Return the current ``mean``/``std``/``min``/``max`` rounded to ``decimals``."""
        summary = {
            'mean': self.mean.compute().item(),
            'std': self.std.compute().item(),
            'min': self.min,
            'max': self.max,
        }
        return self._format_stats(summary)

In [None]:
data = datamodule.train_dataset

In [None]:
# x = torch.rand((3,3))
# print(x)
# x_avg = AverageStatTensor()
# x, y = data[0]
# x_avg.update(x)
# print(x_avg.compute())
# print(x_avg.update(x+1))
# # print(x_avg.compute())
# print(x_avg.update(x-1-1))

In [None]:
x, y = datamodule.get_batch()

In [None]:
x_avg = AverageStatTensor()
x, y = data[0]
print(x_avg.update(x))


x_avg1 = AverageStatTensor()
x1, y1 = data[1]
print(x_avg1.update(x1))



In [None]:
def collect_running_stats_from_dataset(data: torch.utils.data.Dataset):
    """Iterate a dataset and record the running statistics after every item.

    Each ``(x, y)`` item's ``x`` is pushed into a fresh ``AverageStatTensor``;
    the summary returned by that update is appended to the history.

    Returns
    -------
    dict
        Lists keyed by ``'batch_idx'`` (the item's position in the iteration),
        ``'mean'``, ``'std'``, ``'max'`` and ``'min'``.
    """
    running = AverageStatTensor()
    stat_keys = ('mean', 'std', 'max', 'min')

    history = {key: [] for key in ('batch_idx',) + stat_keys}

    for position, (x, y) in enumerate(data):
        summary = running.update(x)

        history['batch_idx'].append(position)
        for key in stat_keys:
            history[key].append(summary[key])

    return history

import matplotlib.pyplot as plt


def plot_history(history, label: str=None, title: str=None, ax=None, **kwargs):
    """Plot a running mean with a +/- one standard deviation band.

    Parameters
    ----------
    history : dict
        Must contain equally long sequences under ``'batch_idx'``, ``'mean'``
        and ``'std'``.
    label : str, optional
        Legend label for the mean curve.
    title : str, optional
        Figure-level title; left untouched when ``None``.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure with a single axes is created when omitted.
    **kwargs
        Only ``c`` is used: a single-character matplotlib colour code that is
        combined with the line styles ``'-'`` and ``':'``. Defaults to ``'m'``.

    Returns
    -------
    (fig, ax)
        The figure that owns ``ax``, and ``ax`` itself.
    """
    if ax is None:
        fig, ax = plt.subplots(1,1)
    else:
        # Use the figure that actually owns the axes, not whichever figure
        # pyplot currently treats as active.
        fig = ax.figure

    # `c` is optional; indexing kwargs directly raised KeyError when it was omitted.
    c = kwargs.get('c', 'm')

    idx = history['batch_idx']
    means, stds = history['mean'], history['std']
    ax.plot(idx, means, c+'-', label=label)
    ax.plot(idx, [m+std for m, std in zip(means, stds)], c+':')
    ax.plot(idx, [m-std for m, std in zip(means, stds)], c+':')

    if title is not None:
        fig.suptitle(title)
    return fig, ax

In [None]:
%%time

train_data = datamodule.train_dataset
history = collect_running_stats_from_dataset(train_data)
plot_history(history, title='train images -- mean pixel intensities')

In [None]:
%%time

val_data = datamodule.val_dataset
val_history = collect_running_stats_from_dataset(val_data)
plot_history(val_history, title='val images -- mean pixel intensities')

In [None]:
%%time
datamodule.setup('test')
test_data = datamodule.test_dataset
test_history = collect_running_stats_from_dataset(test_data)
plot_history(test_history, title='test images -- mean pixel intensities')

In [None]:
plt.style.available

In [None]:
plt.style.use('seaborn-darkgrid')
# plt.style.use('seaborn-talk')

In [None]:
fig, ax = plt.subplots(1,1, figsize=(15,10))

# One colour per split, all drawn on the same axes.
plot_history(history, label='train images', ax=ax, c='r')
plot_history(val_history, label='val images', ax=ax, c='b')
plot_history(test_history, label='test images', ax=ax, c='g')
ax.legend()
# Name the dataset that is actually configured instead of hard-coding "PNAS".
fig.suptitle(f'Running mean of batch pixel intensities in {config.dataset_name}', size='xx-large')

In [None]:
class MetadataTensor(object):
    """Tensor-like wrapper that carries arbitrary metadata through torch functions.

    The payload is stored in ``_t`` and the metadata in ``_metadata``. Torch
    functions called with an instance are dispatched to ``__torch_function__``,
    which unwraps positional arguments, runs the function, and re-wraps the
    result with this instance's metadata.
    """

    def __init__(self, data, metadata=None, **kwargs):
        # Extra keyword arguments are forwarded to torch.as_tensor.
        self._metadata = metadata
        self._t = torch.as_tensor(data, **kwargs)

    def __repr__(self):
        template = "Metadata:\n{}\n\ndata:\n{}"
        return template.format(self._metadata, self._t)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        # Only positional arguments are unwrapped; keyword arguments pass through as-is.
        unwrapped = [getattr(arg, '_t', arg) for arg in args]
        result = func(*unwrapped, **call_kwargs)
        return MetadataTensor(result, metadata=self._metadata)

In [None]:
def get_stats(y: torch.Tensor):
    """Return shape, mean, std, min and max of ``y`` as a dict.

    ``y`` is converted with ``detach().numpy()`` when it supports that;
    objects without those methods (for example NumPy arrays) are used as given.
    """
    try:
        values = y.detach().numpy()
    except AttributeError:
        # Not a tensor: assume it is already array-like.
        values = y
    return dict(shape=values.shape,
                mean=np.mean(values),
                std=np.std(values),
                min=np.min(values),
                max=np.max(values))


def fit_one_batch():
    """Run a single optimisation step using the notebook-level ``model``, ``x``, ``y`` and ``opt``.

    Changes from the previous version:

    * gradients are zeroed first, so repeated calls no longer accumulate
      gradients from earlier steps;
    * the gradient hook on the loss only *reports* statistics and always
      returns ``None``. Registering ``get_stats`` directly is unsafe because a
      hook's return value replaces the gradient, and one of the ``get_stats``
      definitions in this notebook returns a dict;
    * the unused probability / prediction locals were dropped.
    """
    def _report_loss_grad(grad):
        # Works with either `get_stats` variant defined in this notebook:
        # the one that prints (returns None) and the one that returns a dict.
        stats = get_stats(grad)
        if stats is not None:
            print(stats)
        # Falling through returns None, which leaves the gradient unchanged.

    opt.zero_grad()

    y_hat = model(x)
    loss = model.loss(y_hat, y)
    loss.register_hook(_report_loss_grad)

    loss.backward()
    opt.step()

    # norms = model.grad_norm(2)

In [None]:
for name, p in model.named_parameters():
    print(name);print(p.requires_grad)
#     if isinstance(p, (nn.Conv2d, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d)):
#         continue
#     else:
#         print(name);print(p.requires_grad)

In [None]:
from torch import nn, Tensor

class VerboseExecution(nn.Module):
    """Wrap a model and print shapes/statistics as data flows through its direct children.

    A forward hook is installed on every direct child of ``model``. On each
    forward pass the hook prints the child's name and output shape, reports
    statistics of its inputs and outputs via the notebook-level ``get_stats``,
    and makes sure every trainable parameter of the child has a gradient hook
    that reports gradient statistics during the backward pass.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each direct child; its name is stored on the
        # module so the hook can print it.
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(self.module_hook)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    @staticmethod
    def _report(tensor):
        """Report statistics for ``tensor`` with whichever ``get_stats`` is defined.

        One ``get_stats`` variant in this notebook prints and returns ``None``;
        the other returns a dict, which is printed here so it is not lost.
        """
        stats = get_stats(tensor)
        if stats is not None:
            print(stats)

    @staticmethod
    def module_hook(module: nn.Module, input: Tensor, output: Tensor):
        """Forward hook: print the output shape and report input/output statistics."""
        print(f"{module.__name__}: {output.shape}")

        for p in module.parameters():
            if p.requires_grad:
                print(f"requires_grad: {p.requires_grad}")
                # Attach the gradient hook once per parameter. Registering it on
                # every forward pass stacked duplicate hooks, so each backward
                # pass printed the same statistics once per earlier forward call.
                if not getattr(p, '_verbose_grad_hook', False):
                    p.register_hook(VerboseExecution.tensor_hook)
                    p._verbose_grad_hook = True

        if not isinstance(input, (tuple, list)):
            input = (input,)
        if not isinstance(output, (tuple, list)):
            output = (output,)

        for i in input:
            VerboseExecution._report(i)
        for o in output:
            VerboseExecution._report(o)

    @staticmethod
    def tensor_hook(grad: Tensor):
        """Gradient hook: runs only during the backward pass and leaves the gradient unchanged."""
        print("Gradient stats:")
        VerboseExecution._report(grad)
        print('='*20)
        


# import torch
# from torchvision.models import resnet50

verbose_resnet = VerboseExecution(model) #resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
result = verbose_resnet(dummy_input)

print(result.shape)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])

In [None]:
# `result` is a batch of outputs, not a scalar, and backward() without an explicit
# gradient argument only works on scalars. Reduce it first so the gradient hooks
# installed by VerboseExecution can fire.
result.sum().backward()

In [None]:

from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    """Wrap a model and capture the outputs of selected sub-modules.

    Parameters
    ----------
    model : nn.Module
        The model to run.
    layers : iterable of str
        Names of sub-modules, as reported by ``model.named_modules()``, whose
        outputs should be captured. Any iterable is accepted, including
        one-shot generators.

    Raises
    ------
    KeyError
        If a requested name is not a sub-module of ``model``.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialise once: the names are iterated twice below, which would
        # silently skip hook registration for a one-shot iterable.
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}

        # Build the name -> module lookup a single time instead of once per layer.
        named_modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            if layer_id not in named_modules:
                raise KeyError(f"{layer_id!r} is not a sub-module of the wrapped model")
            named_modules[layer_id].register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores the module output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the model and return the captured outputs keyed by layer name.

        The same dict object is returned (and overwritten) on every call.
        """
        _ = self.model(x)
        return self._features
    
    
    

# `resnet50` is never imported in this notebook (its import is commented out above);
# build it through the already-imported `torchvision.models` namespace instead.
resnet_features = FeatureExtractor(models.resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)

print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}

In [None]:
def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    """Clamp every parameter gradient of ``model`` to ``[-val, val]``.

    A gradient hook is registered on each parameter, so clipping happens
    automatically during every subsequent backward pass. The same ``model``
    object is returned for convenience.

    The hook returns a *new* clamped tensor instead of clamping in place:
    autograd hooks are expected not to modify their argument, and a returned
    tensor is used in place of the original gradient.
    """
    for parameter in model.parameters():
        parameter.register_hook(lambda grad: grad.clamp(-val, val))

    return model

In [None]:

# `resnet50` is never imported in this notebook; use the already-imported
# `torchvision.models` namespace.
clipped_resnet = gradient_clipper(models.resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
# NOTE(review): `pred` holds raw model outputs; log() of any negative entry is NaN,
# so this loss is only meaningful as a way to trigger a backward pass.
loss = pred.log().mean()
loss.backward()

print(clipped_resnet.fc.bias.grad[:25])

In [None]:
# for i, p in enumerate(model.parameters()):
#     print(i, type(p))


for l in model.forward_features:
    print(type(l), l.name)

In [None]:
model.forward_features[0]

In [None]:
dir(l)

In [None]:
for i, p in enumerate(model.children()):
    print(i, type(p))

for name, p in model.named_parameters():
    print(name, type(p))

In [None]:
dir(trainer)

dir(model)

%debug

## Inspect single backward pass gradients manually

In [None]:
opt = model.configure_optimizers()


x, y = datamodule.get_batch(stage='train', batch_idx=0)
y_hat = model(x)

loss = model.loss(y_hat, y)
y_prob = model.probs(y_hat)
y_hat_int, y_pred = torch.max(y_prob, dim=1)
loss.backward()
opt.step()
norms = model.grad_norm(2)
print(norms)

datamodule.show_batch()

### Model testing utils playground

In [None]:
def display_model_params_requires_grad(model, prefix=''):
    """Print one line per named parameter: name, type, shape and trainable flag.

    ``prefix`` is prepended to each parameter name (used for indentation by
    callers that print nested modules).
    """
    for param_name, param in model.named_parameters():
        trainable_flag = f'requires_grad=={param.requires_grad}'
        print(prefix + param_name, type(param), param.shape, trainable_flag)

from torch.nn.modules.container import Sequential
def recursive_print_model_children(model, prefix='', lw=15):
    """Print the module tree of ``model``, descending into ``Sequential`` containers.

    For every direct child a separator line of ``lw`` ``'='`` characters and
    the child's type are printed. ``Sequential`` children are expanded
    recursively with one extra tab of indentation; for any other child its
    parameters are listed via ``display_model_params_requires_grad``.

    Parameters
    ----------
    model : nn.Module
        Module whose children are printed.
    prefix : str
        Text prepended to every printed line (grows by one tab per level).
    lw : int
        Width of the separator line; now forwarded to nested levels as well
        (previously nested levels always fell back to the default).
    """
    for child in model.children():
        print(prefix + '='*lw)
        print(prefix + str(type(child)))
        if isinstance(child, Sequential):
            recursive_print_model_children(child, prefix='\t'+prefix, lw=lw)
        elif hasattr(child, 'named_parameters'):
            display_model_params_requires_grad(child, prefix='\t'+prefix)

# from torchsummary import summary
# summary(model.cuda(), input_size=(3, 224, 224))
                    
# display_model_params_requires_grad(model)
# recursive_print_model_children(model, prefix='\n')




backbone_name = 'resnet18'
model = timm.create_model(backbone_name, pretrained=True)

display_model_params_requires_grad(model)

from pl_bolts.utils import BatchGradientVerification

def perform_batch_gradient_verification(model, input_size=(3, 224, 224), input_array=None):
    """
    Checks if a model mixes data across the batch dimension.
    
    This can happen if reshape- and/or permutation operations are carried out in the wrong order or
    on the wrong tensor dimensions.

    The input used for the check is chosen in this order: the explicit
    ``input_array`` argument, then the model's ``example_input_array``
    attribute if it has one that is not ``None``, then a random batch of two
    samples shaped ``(2, *input_size)``.

    Returns the result of ``BatchGradientVerification.check`` and prints a
    pass/fail summary.
    
    Examples:
    ========
    perform_batch_gradient_verification(model)

    perform_batch_gradient_verification(model, input_size=(3, 224, 224))

    perform_batch_gradient_verification(model, input_array=torch.rand(2,3,224,224))
    
    """
    if input_array is None:
        # Plain nn.Module instances have no `example_input_array`; direct
        # attribute access raised AttributeError for them.
        input_array = getattr(model, 'example_input_array', None)
    if input_array is None:
        input_array = torch.rand(2,*input_size)
    verification = BatchGradientVerification(model)
    valid = verification.check(input_array=input_array, sample_idx=1)
    
    if valid:
        print('Test: [PASSED]',
              '\n==============\n',
              'Model passed batch gradient verification test!\n',
              'Confirmed no data mixing occurs across batch dimension')
    else:
        # Make a failing check visible instead of returning silently.
        print('Test: [FAILED]',
              '\n==============\n',
              'Model failed batch gradient verification test!\n',
              'Data appears to be mixed across the batch dimension')
    
    return valid


valid = perform_batch_gradient_verification(model)

valid

## Miscellaneous -- End

In [None]:
x, y = next(iter(datamodule.train_dataloader()))
y = y.detach().numpy()

In [None]:
y_hat = model(x)

In [None]:
y_pred = y_hat.argmax(dim=1).detach().numpy()
y_hat_max = torch.max(y_hat, dim=1).values.detach().numpy()

# list(y_hat_max.detach().numpy())

In [None]:
datamodule.mean
datamodule.std

In [None]:
import numpy as np

def get_stats(y: torch.Tensor):
    """Print shape, mean, std, min and max of ``y`` on one line.

    ``y`` is converted with ``detach().numpy()`` when it supports that;
    objects without those methods (for example NumPy arrays) are used as given.
    Returns ``None``.
    """
    try:
        values = y.detach().numpy()
    except AttributeError:
        # Not a tensor: assume it is already array-like.
        values = y

    reducers = (('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max))
    summary = [f'{name}: {reduce(values):.2f}' for name, reduce in reducers]
    print('shape:', values.shape, *summary)
    
def unnormalize(img, mean=None, std=None):
    """Undo normalisation of an image tensor and return it as a channel-last NumPy array.

    Without ``mean`` the image is mapped with ``img / 2 + 0.5``; otherwise
    ``img * std + mean`` is applied. The result is converted to NumPy and its
    axes are reordered with ``(1, 2, 0)``.
    """
    restored = img / 2 + 0.5 if mean is None else img * std + mean
    return np.transpose(restored.numpy(), (1, 2, 0))

    
    
    
print('y_hat:')
get_stats(y_hat)

print('y_pred:')
get_stats(y_pred)

print('y_true:')
get_stats(y)

fig, ax = plt.subplots(1,1, figsize=(10,10))
grid_img = torchvision.utils.make_grid(x, nrow=7)
c, h, w = grid_img.shape
grid_img = unnormalize(grid_img, torch.Tensor(datamodule.mean).view(c,1,1), torch.Tensor(datamodule.std).view(c,1,1))
# plt.style('presentation')
# pos = ax.imshow(grid_img.permute(1,2,0), cmap='gray')
# fig.colorbar(pos, ax=ax)
get_stats(grid_img)
plt.imshow(grid_img, cmap='gray')
plt.colorbar()

# plt.style.use('seaborn-talk')
x.shape

plt.imshow(x[-1,...].permute(1,2,0))

In [None]:
import matplotlib.pyplot as plt

# plt.imshow?

fig, ax = plt.subplots(1,1)

plt.colorbar?

In [None]:
# project the digits into 2 dimensions using IsoMap
from sklearn.manifold import Isomap
iso = Isomap(n_components=2)
# x = x.numpy()
projection = iso.fit_transform([x[i,0,:,:].ravel() for i in range(x.shape[0])])


In [None]:
get_stats(projection[:,0])
get_stats(projection[:,1])

In [None]:
get_stats(projection[:,0])
get_stats(projection[:,1])

stats = np.min(projection[:,0]), np.max(projection[:,0])
projection[:,0] = (projection[:,0]-stats[0])/(stats[1]-stats[0])

stats = np.min(projection[:,1]), np.max(projection[:,1])
projection[:,1] = (projection[:,1]-stats[0])/(stats[1]-stats[0])

# plot the results
plt.scatter(projection[:, 0], projection[:, 1],# lw=0.1,
            c=y, cmap=plt.cm.get_cmap('cubehelix', 19))
plt.colorbar(ticks=range(19), label='leaf family')
# plt.clim(0.0, 1.0)
# plt.clim(-0.5, 5.5)

# plot the results
plt.scatter(projection[:, 0], projection[:, 1],# lw=0.1,
            c=y_pred, cmap=plt.cm.get_cmap('cubehelix', 19))
plt.colorbar(ticks=range(19), label='leaf family')
# plt.clim(0.0, 1.0)
# plt.clim(-0.5, 5.5)

projection

# Iterate over the actual batch size rather than a hard-coded 48, so a batch of
# a different size (e.g. the last, smaller one) does not raise an IndexError.
for i in range(x.shape[0]):
    get_stats(x[i,...])

In [None]:

print([('y_hat','y_pred','y_true'), *list(zip(list(y_hat_max), list(y_pred), list(y)))])



list(a.values.detach().numpy())

print(y_hat.shape)

%debug

child_counter = 0
for child in model.children():
    print(" child", child_counter, "is -")
    print(child)
    child_counter += 1

Log histogram of losses

fig = plt.figure()
losses = np.stack([x['val_loss'].numpy() for x in outputs])
plt.hist(losses)

import torchmetrics

# `__version__` is a plain string attribute; calling it (`__version__()`) raises TypeError.
torchmetrics.__version__

In [None]:
import numpy as np

np.log(1/19)



datamodule.setup('fit')
train_dataloader = datamodule.train_dataloader()    
    
classes = datamodule.classes
num_classes = len(classes)
# Get training images
dataiter = iter(train_dataloader)
images, labels = next(dataiter)#.next()


# def imshow(img):
#     img = img / 2 + 0.5 # unnormalise
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()


# imshow(torchvision.utils.make_grid(images))
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

train_metrics = get_metrics(num_classes=num_classes, prefix='train')

train_metrics

def eval_one_batch(model, batch):
    """Run ``model`` on one ``(x, y)`` batch and collect the loss and predictions.

    Args:
        model: Callable that also provides ``loss(y_hat, y)`` and ``probs(y_hat)``.
        batch: Pair ``(x, y)`` of inputs and targets.

    Returns:
        dict with keys ``'loss'``, ``'y_hat'`` (raw model output), ``'y_prob'``
        (output of ``model.probs``), ``'y_pred'`` (argmax of ``y_prob`` over
        dim 1) and ``'y_true'`` (the targets, passed through untouched).
    """
    inputs, targets = batch

    logits = model(inputs)
    batch_loss = model.loss(logits, targets)
    probabilities = model.probs(logits)

    return {
        'loss': batch_loss,
        'y_hat': logits,
        'y_prob': probabilities,
        'y_pred': probabilities.argmax(dim=1),
        'y_true': targets,
    }



results = eval_one_batch(model, batch=(images, labels))

metric_result = train_metrics(results['y_prob'], results['y_true'])

def TP_topk(y_prob, y_true, k: int=1):
    """Flag, per sample, whether the true label is among the ``k`` highest-scoring classes.

    Args:
        y_prob: Per-sample class scores; ``torch.topk`` is taken over the last dim.
        y_true: Per-sample true labels, indexed in step with the rows of ``y_prob``.
        k: How many top-scoring classes count as a hit.

    Returns:
        list with one boolean per entry of ``y_true``.
    """
    top_indices = torch.topk(y_prob, k=k).indices
    hits = []
    for row, label in enumerate(y_true):
        hits.append(label in top_indices[row])
    return hits
    

y_prob, y_true = results['y_prob'], results['y_true']


tp_top1 = TP_topk(y_prob, y_true, k=1)
tp_top3 = TP_topk(y_prob, y_true, k=3)
tp_top5 = TP_topk(y_prob, y_true, k=5)

print(list(zip(tp_top1, tp_top3, tp_top5)))

# %debug

# %debug
# torch.topk(results['y_prob'], k=5)[1]

# for k, v in results.items():
#     if len(v.size())==0:
#         print(k, v.size(), v.detach().numpy())
#     elif len(v.size())==1:
#         print(k, v.size(), v.detach().numpy())
#     else:
#         print(k, v.size(), f'Min: {v.min().detach():.2f}, Max: {v.max().detach():.2f}')
#         print("per-sample argmax:", v.argmax(dim=1).detach().numpy())
        
# print(list(zip(y_pred.numpy(), y.numpy())))

# plt.imshow(x.cpu()[0,...].permute(1,2,0))

backbone_name = 'resnet50' #'xception41'
# backbone_name = 'resnet101'
# backbone_name = 'resnet152'
# model = timm.create_model(backbone_name, pretrained=True)
model.reset_classifier(19,'avg')
# model = Classifier(backbone,
#            head_source=None,
#            head_target=None,
#            num_classes=19,
#            finetune=True)

# helper function to show an image
# (used in the `plot_classes_preds` function below)
def matplotlib_imshow(img, one_channel=False):
    """Show a normalised image tensor with matplotlib.

    Args:
        img: Image tensor; the first axis is averaged away (``one_channel``)
            or moved last before plotting.
        one_channel: When true, average over dim 0 and draw with the
            ``"Greys"`` colormap; otherwise draw the image with its first axis
            moved to the end.
    """
    if not one_channel:
        pixels = (img / 2 + 0.5).numpy()     # unnormalize
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
        return

    # Collapse dim 0 first, then unnormalize the single remaining plane.
    plane = img.mean(dim=0)
    plt.imshow((plane / 2 + 0.5).numpy(), cmap="Greys")

img_grid = torchvision.utils.make_grid(images)

# show images
matplotlib_imshow(img_grid, one_channel=True)


def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)


modules = list(model.modules())
len(modules)    
#         def _initialize_weights(self) -> None:
#         for m in self.modules():
#             if isinstance(m, torch.nn.Conv2d):
#                 torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                 if m.bias is not None:
#                     torch.nn.init.constant_(m.bias, 0)
#             elif isinstance(m, torch.nn.BatchNorm2d):
#                 torch.nn.init.constant_(m.weight, 1)
#                 torch.nn.init.constant_(m.bias, 0)
#             elif isinstance(m, torch.nn.Linear):
#                 torch.nn.init.normal_(m.weight, 0, 0.01)
#                 torch.nn.init.constant_(m.bias, 0)

In [None]:
import inspect
# print(inspect.getsource(model.reset_classifier))
# inspect.__dir__()
model.__dir__()

import inspect
# print(inspect.getsource(model.reset_classifier))
# inspect.__dir__()
model.num_features

state = model.bn1.state_dict()

state['bias'].size()

model.bn1.state_dict()

model.__dict__

import inspect
print(inspect.getsource(model.num_features))

## ResNet50
for i, child in enumerate(model.children()):
    if isinstance(child, torch.nn.modules.container.Sequential):
#         break
        print(f'Found conv block at module child index: {i}')
        print('='*20)
    else:
        print(i, type(child))

for i, (name,child) in enumerate(model.named_children()):
    if isinstance(child, torch.nn.modules.container.Sequential):
#         break
        print(f'Found conv block at module child index: {i}')
        print(f'name is: {name}')
        print('='*20)
    else:
        print(i, type(child))
        print(f'name is: {name}')

from collections import OrderedDict

named_layers = OrderedDict(model.named_children())
layers = list(named_layers.values())
layer_names = list(named_layers.keys())
layer_names

stem = torch.nn.modules.container.Sequential(
    OrderedDict({l:named_layers[l] for l in ['conv1','bn1','act1', 'maxpool']}
))
stem

In [None]:
residual_group_names = layer_names[4:8]
print(residual_group_names)

learner = torch.nn.modules.container.Sequential(
    OrderedDict({f"res_group{i}":named_layers[l] for i,l in enumerate(residual_group_names)}
))
learner

classifier_layer_names = layer_names[8:]
print(classifier_layer_names)

classifier = torch.nn.modules.container.Sequential(
    OrderedDict({l:named_layers[l] for l in classifier_layer_names}
))
classifier

from torch.nn.modules.container import Sequential

macro_model = Sequential(
    OrderedDict({'stem':stem,
     'learner':learner,
     'classifier':classifier})
)

for n, mod in macro_model.named_children():
    print(n, len(mod))

# Unpack the first four layers as positional modules; the previous
# `Sequential*(layers[:4])` multiplied the class by a list and raised TypeError.
torch.nn.modules.container.Sequential(*layers[:4])

from rich import print
import inspect

In [None]:
print(inspect.getsource(backbone.reset_classifier))

# list(backbone.children())
# backbone.__dir__()

trainer = pl.Trainer(gpus=1, logger=pl.loggers.wandb.WandbLogger(name="default-timm-resnet50"))
trainer.fit(model, datamodule)





model_name = 'resnet50' #'xception41'
model = timm.create_model(model_name, pretrained=True)

model.forward_features

model.__dir__()

import sys
import os

# NOTE(review): hard-coded absolute path; consider making this configurable.
MODELS_DIR = "/media/data/jacob/GitHub/lightning-hydra-classifiers/src/models"
os.chdir(MODELS_DIR) #/pnas_model.py")

# A relative import (`from .pnas_model import ...`) cannot work in a notebook cell,
# because the cell has no parent package. Put the directory on sys.path and import
# the module by its plain name instead.
# NOTE(review): if pnas_model itself uses relative imports, import it through its
# installed package path instead -- confirm.
if MODELS_DIR not in sys.path:
    sys.path.insert(0, MODELS_DIR)

from pnas_model import Classifier
# from contrastive_learning.data.pytorch.

In [None]:
from typing import Any, List, Union


class TimmLightningModel(pl.LightningModule):
    """
    LightningModule that builds a timm backbone and trains a classifier with cross-entropy.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    NOTE(review): `forward` runs `self.model` (a `SimpleDenseNet` built from the saved
    hyperparameters); `self.backbone` is created but never used in the forward pass --
    confirm whether the backbone was meant to replace it.
    NOTE(review): `SimpleDenseNet` and `Accuracy` are not defined or imported in this
    cell; they must be in scope before the class is instantiated.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        backbone: Union[str,nn.Module] = "resnet50",
        input_size: int = 784,
        lin1_size: int = 256,
        lin2_size: int = 256,
        lin3_size: int = 256,
        output_size: int = 10,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
        **kwargs
    ):
        """Store hyperparameters and build the backbone, model, loss and metrics.

        Args:
            backbone: timm model name, or an already constructed ``nn.Module``.
            input_size, lin1_size, lin2_size, lin3_size, output_size: Saved as
                hyperparameters and handed to ``SimpleDenseNet`` via ``self.hparams``.
            lr: Learning rate used by ``configure_optimizers``.
            weight_decay: Weight decay used by ``configure_optimizers``.
            **kwargs: Extra keyword arguments; accepted but not used directly here.
        """
        super().__init__()

        # this line ensures params passed to LightningModule will be saved to ckpt
        # it also allows to access params with 'self.hparams' attribute
        self.save_hyperparameters()

        # The annotation allows either a model name or a ready module; only names
        # can be handed to timm.create_model.
        if isinstance(backbone, str):
            self.backbone = timm.create_model(backbone, pretrained=True)
        else:
            self.backbone = backbone

        self.model = SimpleDenseNet(hparams=self.hparams)

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

        # per-epoch history, used to log the best value seen so far
        self.metric_hist = {
            "train/acc": [],
            "val/acc": [],
            "train/loss": [],
            "val/loss": [],
        }

    def forward(self, x: torch.Tensor):
        """Return the output of ``self.model`` for input ``x``."""
        return self.model(x)

    def step(self, batch: Any):
        """Shared train/val/test step: returns ``(loss, preds, targets)`` for one batch."""
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss and accuracy for one batch."""
        loss, preds, targets = self.step(batch)

        # log train metrics
        acc = self.train_accuracy(preds, targets)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        # we can return here dict with any tensors
        # and then read it in some callback or in training_epoch_end() below
        # remember to always return loss from training_step, or else backpropagation will fail!
        return {"loss": loss, "preds": preds, "targets": targets}

    def training_epoch_end(self, outputs: List[Any]):
        """Record this epoch's train metrics and log the best values so far."""
        # log best so far train acc and train loss
        self.metric_hist["train/acc"].append(self.trainer.callback_metrics["train/acc"])
        self.metric_hist["train/loss"].append(self.trainer.callback_metrics["train/loss"])
        self.log("train/acc_best", max(self.metric_hist["train/acc"]), prog_bar=False)
        self.log("train/loss_best", min(self.metric_hist["train/loss"]), prog_bar=False)

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log the validation loss and accuracy for one batch."""
        loss, preds, targets = self.step(batch)

        # log val metrics
        acc = self.val_accuracy(preds, targets)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs: List[Any]):
        """Record this epoch's validation metrics and log the best values so far."""
        # log best so far val acc and val loss
        self.metric_hist["val/acc"].append(self.trainer.callback_metrics["val/acc"])
        self.metric_hist["val/loss"].append(self.trainer.callback_metrics["val/loss"])
        self.log("val/acc_best", max(self.metric_hist["val/acc"]), prog_bar=False)
        self.log("val/loss_best", min(self.metric_hist["val/loss"]), prog_bar=False)

    def test_step(self, batch: Any, batch_idx: int):
        """Compute and log the test loss and accuracy for one batch."""
        loss, preds, targets = self.step(batch)

        # log test metrics
        acc = self.test_accuracy(preds, targets)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", acc, on_step=False, on_epoch=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs: List[Any]):
        """No end-of-epoch processing for the test loop."""
        pass

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Returns an Adam optimizer over all parameters, using the saved ``lr`` and
        ``weight_decay`` hyperparameters.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )


In [None]:
pretrained=True
model = models.resnet50(pretrained=pretrained)
num_classes = 19
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
print(model.fc)


model.conv1

model.layer1

for name, child in model.named_children():
    print(name)
# if "layer1" in ct:
#     for params in child.parameters():
#         params.requires_grad = True
# ct.append(name)

from rich import print


print(dir(datamodule))



# name=
#                  batch_size: int=32,
#                  val_split: float=0.2,
#                  num_workers=0,
#                  seed: int=None

In [None]:
import torch
from torchvision import models
from contrastive_learning.data.pytorch.pnas import PNASLightningDataModule
import inspect
from pprint import pprint as pp


data = PNASLightningDataModule()

data.setup(stage='fit')
train_dataset = ImageFolder(root=data.train_dir)

dir(val_data)

len(val_data)

inspect.getsource(val_data.__len__)
# getmembers(val_data.__len__)

# data.train_dataset

pp(inspect.getmembers(data, inspect.isfunction))

from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder
from torch import Generator
import numpy as np
# from .common import SubsetImageDataset, seed_worker


self = data


self.train_dataset = ImageFolder(root=self.train_dir)#, transform=train_transform, target_transform=target_transform)
self.classes = self.train_dataset.classes

val_split = self.val_split
num_train = len(self.train_dataset)
# Flooring both fractions independently can lose a sample (e.g. 11 samples at 0.2
# gives 8 + 2 = 10), and random_split requires the lengths to sum to the dataset
# size. Take the validation length as the remainder instead.
num_train_split = int(np.floor((1-val_split) * num_train))
split_idx = (num_train_split, num_train - num_train_split)
if self.seed is None:
    generator = None
else:
    generator = Generator().manual_seed(self.seed)

train_data, val_data = random_split(self.train_dataset,
                                    [split_idx[0], split_idx[1]],
                                    generator=generator)







In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder
from torch import Generator
import numpy as np
from typing import List, Callable, Tuple
from torchvision.datasets import VisionDataset

# class TrainValSplitDataset(VisionDataset):
class TrainValSplitDataset(ImageFolder):
    """ImageFolder that can be copied from another dataset and sub-selected by index.

    Unlike the ``Subset`` objects returned by ``random_split``, the train/val
    datasets produced here are instances of this class, with their own
    ``imgs`` / ``samples`` / ``targets`` lists restricted to the chosen indices.
    """

    # Attributes copied over from a source dataset in `from_dataset`.
    all_params: List[str]= [
                            'class_to_idx',
                            'classes',
                            'extensions',
                            'imgs',
                            'loader',
                            'root',
                            'samples',
                            'target_transform',
                            'targets',
                            'transform',
                            'transforms'
                            ]
    # Per-sample attributes that are re-indexed in `select`.
    sample_params: List[str] = ['imgs',
                                'targets',
                                'samples']

    @classmethod
    def train_val_split(cls, full_dataset, val_split: float=0.2, seed: int=None) -> Tuple["TrainValSplitDataset", "TrainValSplitDataset"]:
        """Randomly split ``full_dataset`` into a train and a validation dataset.

        Args:
            full_dataset: Dataset to split; must expose ``root`` and the
                attributes named in ``sample_params``.
            val_split: Fraction of the samples assigned to validation, in [0, 1].
            seed: Seed for the split's random generator; ``None`` passes no
                generator to ``random_split``.

        Returns:
            ``(train_dataset, val_dataset)``, both instances of ``cls``.

        Raises:
            ValueError: If ``val_split`` is outside [0, 1].
        """
        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be in [0, 1], got {val_split}")

        num_samples = len(full_dataset)
        # Flooring both fractions independently can drop a sample (e.g. 11 samples
        # at 0.2 gives 8 + 2 = 10), which makes random_split raise. Give the
        # remainder to the validation split so the lengths always sum up.
        num_train = int(np.floor((1-val_split) * num_samples))
        split_idx = (num_train, num_samples - num_train)
        if seed is None:
            generator = None
        else:
            generator = Generator().manual_seed(seed)

        train_data, val_data = random_split(full_dataset,
                                            split_idx,
                                            generator=generator)

        train_dataset = cls.select_from_dataset(full_dataset, indices=train_data.indices)
        val_dataset = cls.select_from_dataset(full_dataset, indices=val_data.indices)

        return train_dataset, val_dataset

    @classmethod
    def from_dataset(cls, dataset):
        """Build a ``cls`` instance from ``dataset.root`` and copy over its attributes.

        Every name in ``all_params`` that ``dataset`` has is assigned onto the
        new instance by reference (the lists are shared, not copied).
        """
        new_dataset = cls(root=dataset.root)

        for key in cls.all_params:
            if hasattr(dataset, key):
                setattr(new_dataset, key, getattr(dataset, key))

        return new_dataset


    @classmethod
    def select_from_dataset(cls, dataset, indices=None):
        """Convert ``dataset`` to ``cls`` and return the subset given by ``indices``."""
        upgraded_dataset = cls.from_dataset(dataset)
        return upgraded_dataset.select(indices)


    def select(self, indices):
        """Return a new dataset holding only the samples at ``indices``.

        The per-sample attributes in ``sample_params`` are rebuilt as fresh
        lists on the new instance; ``self`` is left unchanged.
        """
        new_subset = self.from_dataset(self) #, indices)

        for key in self.sample_params:
            old_attr = getattr(self, key)
            new_attr = [old_attr[idx] for idx in indices]
            setattr(new_subset, key, new_attr)
        return new_subset


    def __repr__(self) -> str:
        """Multi-line summary: class name, number of datapoints, root and transforms."""
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Format ``repr(transform)`` with ``head`` as prefix and aligned continuation lines."""
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
    

In [None]:
import torch
from torchvision import models
from contrastive_learning.data.pytorch.pnas import PNASLightningDataModule
import inspect
from pprint import pprint as pp


data = PNASLightningDataModule()

data.setup(stage='fit')
train_dataset = ImageFolder(root=data.train_dir)

In [None]:
train_data, val_data = TrainValSplitDataset.train_val_split(train_dataset, val_split=0.2, seed=0)

In [None]:
print(len(train_data), len(val_data))

In [None]:
train_dataset = TrainValSplitDataset.select_from_dataset(train_data.dataset, train_data.indices)
train_dataset

In [None]:
train_dataset[2]

In [None]:
val_dataset = TrainValSplitDataset.select_from_dataset(val_data.dataset, val_data.indices)
val_dataset

In [None]:
# `from_dataset` takes only the dataset; passing the indices as a second argument
# raised TypeError. The selection by index happens in `.select(...)` below.
# NOTE(review): assumes `train_data` is a random_split Subset (has .dataset/.indices) -- confirm.
train_dataset = TrainValSplitDataset.from_dataset(train_data.dataset)
train_dataset = train_dataset.select(train_data.indices)
train_dataset



# len(train_dataset)
len(val_dataset.samples)

# `from_dataset` does not accept indices (TypeError); `select_from_dataset`
# converts the dataset and applies the index selection in one call.
# NOTE(review): assumes `val_data` is a random_split Subset (has .dataset/.indices) -- confirm.
val_dataset = TrainValSplitDataset.select_from_dataset(val_data.dataset, val_data.indices)
val_dataset

val_data.dataset

type(train_dataset)

getattr?

In [None]:
hasattr?

from torchvision import get_image_backend
get_image_backend()

inspect.getsource(train_data.dataset.loader)

(train_data.dataset.imgs[0])

dir(train_data.dataset)
#, val_data

self.train_dataset = SubsetImageDataset(train_data.dataset, train_data.indices)
self.val_dataset = SubsetImageDataset(val_data.dataset, val_data.indices)

class IndexDataset(Dataset):
    """View of ``dataset`` restricted to, and ordered by, ``indices``.

    NOTE(review): the original ``__init__`` had no body (a syntax error), so the
    intended behavior is inferred from the name and signature -- confirm.
    """

    def __init__(self, dataset, indices):
        """Keep a reference to ``dataset`` and a list copy of ``indices``."""
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        """Number of selected samples."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return the item of the wrapped dataset at position ``indices[index]``."""
        return self.dataset[self.indices[index]]



In [None]:

class RAMDataset(Dataset):
    """Dataset that reads every image file's raw bytes into memory up front.

    Files are kept encoded in ``self.images`` and decoded with OpenCV on each
    ``__getitem__`` call.
    """

    def __init__(self, image_fnames, targets):
        """Read all files in ``image_fnames`` into memory.

        Args:
            image_fnames: Iterable of paths to the image files.
            targets: Indexable of labels, one per file; also defines ``len()``.
        """
        self.targets = targets
        self.images = []
        for fname in tqdm(image_fnames, desc="Loading files in RAM"):
            with open(fname, "rb") as f:
                self.images.append(f.read())

    def __len__(self):
        """Number of samples, taken from ``targets``."""
        return len(self.targets)

    def __getitem__(self, index):
        """Decode and return ``(image, target)`` for the sample at ``index``.

        Raises:
            ValueError: If OpenCV cannot decode the stored bytes.
        """
        target = self.targets[index]
        # cv2.imdecode expects a 1-D uint8 array (not a bytes object) and returns
        # a single image, or None when decoding fails.
        encoded = np.frombuffer(self.images[index], dtype=np.uint8)
        image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"Could not decode image at index {index}")
        return image, target

In [None]:
class MapDataset(torch.utils.data.Dataset):
    """
    Wrap a dataset and apply optional mapping functions to each ``(x, y)`` item.

    The functions run lazily, at the moment an item is requested; the wrapped
    dataset is referenced, not cloned or copied.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, transform =None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _maybe_apply(fn, value):
        # A falsy `fn` (e.g. None) leaves the value untouched.
        return fn(value) if fn else value

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        sample = self._maybe_apply(self.transform, sample)
        label = self._maybe_apply(self.target_transform, label)
        return sample, label

    def __len__(self):
        return len(self.dataset)

In [None]:
dir(data.train_dataset)

type(data)

# data.val_dataset.class_to_idx
len(data.val_dataset.indices)

data.val_dataset.transforms

len(data.val_dataset.targets)

len(data.val_dataset.samples)

dir(data.val_dataset)

data.train_dataset

pretrained=True
model = models.resnet50(pretrained=pretrained)

# dir(model.fc)

in_features = model.fc.in_features
out_features = model.fc.out_features
print(in_features, out_features)

model.__dir__()

pretrained=True
model = models.resnet50(pretrained=pretrained)

num_classes = 10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

print(model.fc)

dir(model.avgpool)

model.avgpool.T_destination

In [None]:
import torch
from pytorch_lightning.metrics.classification import StatScores
preds  = torch.tensor([1, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])
stat_scores = StatScores(reduce='macro', num_classes=3)
stat_scores(preds, target)
# tensor([[0, 1, 2, 1, 1],
#         [1, 1, 1, 1, 2],
#         [1, 0, 3, 0, 1]])

In [None]:
stat_scores = StatScores(reduce='micro')
stat_scores(preds, target)
# tensor([2, 2, 6, 2, 4])

stat_scores = StatScores(reduce='samples')
stat_scores(preds, target)
# tensor([2, 2, 6, 2, 4])

In [None]:
seed = 5
torch.manual_seed(seed)

print(torch.randint(0,9, size=(4,)))
print(torch.randint(0,9, size=(4,)))

seed = 6
torch.manual_seed(seed)

print(torch.randint(0,9, size=(4,)))
print(torch.randint(0,9, size=(4,)))

In [None]:
from pytorch_lightning.metrics.classification import StatScores
from pprint import pprint as pp


def generate_preds_target_pairs(num_classes=19, num_samples=100, min_percent_true=0.5):
    """
    Generate 2 random torch tensors of shape (num_samples,) each,
    where a fraction at least as large as min_percent_true are identical between the two.

    The first ``int(num_samples * min_percent_true)`` predictions are copies of
    the targets; the remaining ones are drawn at random and may match by chance.

    Args:
        num_classes: Labels are drawn from ``[0, num_classes)``.
        num_samples: Length of both returned tensors.
        min_percent_true: Fraction in [0, 1] of predictions forced to equal the target.

    Returns:
        ``(preds, target)`` tensors of shape ``(num_samples,)``.

    Raises:
        ValueError: If ``min_percent_true`` is outside [0, 1].
    """
    if not 0.0 <= min_percent_true <= 1.0:
        raise ValueError(f"min_percent_true must be in [0, 1], got {min_percent_true}")

    target = torch.randint(0, num_classes, size=(num_samples,))
    preds = torch.clone(target.detach())

    lock_true_pairs = int(num_samples*min_percent_true)
    # The random tail must be as long as the slice it fills. Sizing it by
    # `lock_true_pairs` only worked when min_percent_true was exactly 0.5.
    num_random = num_samples - lock_true_pairs
    preds[lock_true_pairs:] = torch.randint(0, num_classes, size=(num_random,))

    assert preds.size() == target.size()
    assert preds.size() == (num_samples,)
    assert (preds==target).sum() >= lock_true_pairs

    return preds, target

In [None]:
num_classes=19
num_samples=100
min_percent_true=0.5

preds, target = generate_preds_target_pairs(num_classes=num_classes, num_samples=num_samples, min_percent_true=min_percent_true)

In [None]:
macro_stats = StatScores(reduce='macro', num_classes=num_classes)
micro_stats = StatScores(reduce='micro', num_classes=num_classes)
sample_stats = StatScores(reduce='samples', num_classes=num_classes)

# print(macro_stats.compute(), micro_stats.compute(), sample_stats.compute())

In [None]:
from rich import print
import pdbr

In [None]:
from sklearn import metrics
import numpy as np
# dir(metrics)

In [None]:
np.unique([0,4, 3,2,2,2,3,5])#, return_counts=True)

macro_stats.update(preds, target)
micro_stats.update(preds, target)
sample_stats.update(preds, target)

print("macro: ", macro_stats.compute())
print("micro: ", micro_stats.compute())
print("Samples: ", sample_stats.compute())
# print(macro_stats, micro_stats, sample_stats)

print(f'preds.size = {preds.size()}', f'targets.size = {target.size()}')
print(f'# correct = {(preds==target).sum()}')
# target = torch.randint(0,9, size=(100,))

pp(list(zip(preds.tolist(), target.tolist())))