In [1]:
%cd ..

/home/shpotes/Projects/safe-regions


In [2]:
from typing import Tuple
import pickle
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms
from tqdm import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from safe_regions.pl_module import ResNet
from safe_regions.layers import ReLU
from safe_regions.region import MinMaxRegion, Region

In [3]:
# Restore the trained Lightning module from its checkpoint; the path is
# relative to the current working directory.
model = ResNet.load_from_checkpoint(checkpoint_path='weights/colab.ckpt')

In [4]:
# CIFAR-10 data module rooted at ~/.pytorch/cifar10/, loading 512 images per
# batch with 8 worker processes.
dm = CIFAR10DataModule('~/.pytorch/cifar10/', batch_size=512, num_workers=8)

# Training pipeline: random 32x32 crops of a 4-pixel-padded image and random
# horizontal flips, then tensor conversion and CIFAR-10 normalization.
dm.train_transforms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

# Validation pipeline: no augmentation, only conversion and normalization.
dm.val_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

dm.setup()

In [5]:
class RegionMembership:
    """Record how often the input of each ``ReLU`` layer lies in its region.

    Forward hooks are attached to every ``ReLU`` (from ``safe_regions.layers``)
    found recursively inside ``model``. On each forward call of such a layer
    the hook asks ``layer.region.evaluate_membership`` about the layer input,
    converts the answer to float, reduces it with :meth:`reduction_ops`
    (``torch.mean`` by default) and appends the result to
    ``self.state[layer_id]``.

    ``layer_id`` is ``(child_index, *parent_path)``: the ReLU's position among
    its parent's children, followed by the names of the enclosing modules
    (the root contributes ``''``).

    Args:
        model: network to inspect; it is switched to eval mode.
    """

    def __init__(self, model: nn.Module):
        self.model = model.eval()
        self.state = {}
        self.hooks = []

    def _membership_tracker(
        self,
        module: nn.Module,
        parent_name: Tuple[str, ...] = ('',)
    ):
        """Recursively register a membership hook on every ReLU under ``module``."""
        for idx, (layer_name, layer) in enumerate(module.named_children()):
            if isinstance(layer, ReLU):
                layer_id = (idx, *parent_name)
                handle = layer.register_forward_hook(
                    self.evaluate_membership(layer_id)
                )
                self.hooks.append(handle)
            elif list(layer.children()):
                # Descend only into containers; a childless non-ReLU module
                # cannot contain a ReLU.
                self._membership_tracker(layer, (*parent_name, layer_name))

    def reduction_ops(self):
        """Return the callable that reduces a float membership tensor.

        Defaults to ``torch.mean`` (when membership is boolean this is the
        fraction of member elements in the batch). Override in a subclass to
        use another reduction.
        """
        return torch.mean

    def evaluate_membership(self, layer_id):
        """Build the forward hook that records membership for ``layer_id``.

        The hook binds the ``self.state`` dict that exists when it is created;
        ``reset_state`` replaces ``self.state`` before registering hooks, so
        hooks always write into the current evaluation's dict.
        """
        state = self.state
        reduce = self.reduction_ops()
        state.setdefault(layer_id, [])

        def _evaluate_membership(layer, inputs, _output):
            # Forward hooks receive the positional inputs as a tuple; a ReLU
            # is expected to be called with exactly one tensor.
            (input_tensor,) = inputs
            membership = layer.region.evaluate_membership(
                input_tensor.detach().cpu()
            )
            state[layer_id].append(reduce(membership.float()))

        return _evaluate_membership

    def remove_hooks(self):
        """Detach every hook this tracker registered on the model."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

    def reset_state(self):
        """Drop previous results and register fresh hooks on the model."""
        self.remove_hooks()
        self.state = {}
        self._membership_tracker(self.model)

    def evaluate(self, dloader: data.DataLoader):
        """Run every batch of ``dloader`` through the model and collect membership.

        Args:
            dloader: iterable of ``(inputs, targets)`` batches; targets are
                ignored.

        Returns:
            dict mapping each ``layer_id`` to a tensor holding one reduced
            membership value per hook call, stacked along a new first dimension.
            A ReLU module called several times per forward pass contributes
            several values per batch. Layers that never fired map to an
            empty tensor.
        """
        self.reset_state()
        try:
            # Pure inference: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                for input_tensor, _ in tqdm(dloader):
                    self.model(input_tensor)
        finally:
            # Never leave hooks attached: they would keep firing on every
            # later forward pass (including passes driven by another tracker
            # on the same model), wasting compute and growing stale buffers.
            self.remove_hooks()

        self.state = {
            layer_id: (
                torch.stack([torch.as_tensor(v) for v in values])
                if values else torch.empty(0)
            )
            for layer_id, values in self.state.items()
        }
        return self.state

In [6]:
class RandomData(data.Dataset):
    """Synthetic out-of-distribution dataset of standard-normal noise.

    Every item is ``(noise, 0)``: a tensor of shape ``shape`` drawn from
    N(0, 1) with a constant dummy label ``0``.

    Args:
        length: number of items reported by ``len()``.
        shape: shape of each generated tensor.
        seed: optional integer seed. When ``None`` (the default) each access
            draws fresh noise from the global torch RNG, so reading the same
            index twice gives different tensors. When set, item ``idx`` is
            drawn from a private ``torch.Generator`` seeded with
            ``seed + idx``, so the dataset is reproducible regardless of
            DataLoader worker seeding.
    """

    def __init__(self, length=1000, shape=(3, 32, 32), seed=None):
        self.shape = shape
        self.length = length
        self.seed = seed

    def __getitem__(self, idx):
        if self.seed is None:
            return torch.randn(*self.shape), 0
        generator = torch.Generator().manual_seed(self.seed + int(idx))
        return torch.randn(*self.shape, generator=generator), 0

    def __len__(self):
        return self.length

In [7]:
# In-distribution check: region membership on CIFAR-10 training batches.
# NOTE(review): dm.train_dataloader() applies the RandomCrop /
# RandomHorizontalFlip augmentation configured for training, so these values
# can vary between runs.
tracker = RegionMembership(model=model.model)
tracker.evaluate(dloader=dm.train_dataloader())

100%|██████████| 79/79 [00:44<00:00,  1.79it/s]


{(2,
  ''): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]),
 (2,
  '',
  'layer1',
  '0'): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
 

In [8]:
# Out-of-distribution probe: 10,240 standard-normal noise images
# (20 batches of 512). A fixed-seed generator makes the noise reproducible.
# With num_workers > 0, DataLoader derives each worker's torch seed from this
# generator, and RandomData draws its noise from the worker's torch RNG.
ood_loader = data.DataLoader(
    RandomData(10_240),
    batch_size=512,
    num_workers=8,
    generator=torch.Generator().manual_seed(0),
)

tracker = RegionMembership(model.model)
tracker.evaluate(ood_loader)

100%|██████████| 20/20 [00:12<00:00,  1.62it/s]


{(2,
  ''): tensor([0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998,
         0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998, 0.9998,
         0.9998, 0.9998]),
 (2,
  '',
  'layer1',
  '0'): tensor([1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000,
         0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999,
         1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000,
         0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999,
         1.0000, 0.9999, 1.0000, 1.0000]),
 (2,
  '',
  'layer1',
  '1'): tensor([1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000,
         0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999,
         1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000,
         0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999, 1.0000, 0.9999,
         1.0000, 0.9999, 1.0000, 0.9999]),
 (2,
 