In [9]:
import timm
import torch
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names

model = timm.create_model(
    'vit_small_patch14_dinov2',
    pretrained=True,
    dynamic_img_size=False,
)

#print(get_graph_node_names(model))
encoder = create_feature_extractor(model, {'norm': 'features'})
out = encoder(torch.rand(1, 3, 518, 518))
print([(k, v.shape) for k, v in out.items()])

[('features', torch.Size([1, 1370, 384]))]


In [13]:
import torchvision.transforms.v2 as v2
from dl_toolbox import datamodules
from pathlib import Path

tf = v2.Compose(
    [
        v2.CenterCrop(280),
    ]
)
  
cityscapes = datamodules.Cityscapes(
    data_path='/data',
    merge='all19',
    train_tf=tf,
    test_tf=tf,
    batch_size=2,
    num_workers=0,
    pin_memory=False
)

In [2]:
import timm
import torch

class ViTExtractor:
    """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.

    We use the following notation in the documentation of the module's methods:
    B - batch size
    h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
    p - patch size of the ViT. either 8 or 16.
    t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W are the height and width
    of the input image.
    d - the embedding dimension in the ViT.
    """

    def __init__(self, model_type='dino_vits8', device='cuda'):
        """
        :param model_type: A string specifying the type of model to extract from.
                          [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
                          vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
        :param stride: stride of first convolution layer. small stride -> higher resolution.
        :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
                      should be compatible with model_type.
        """
        self.model_type = model_type
        self.model = timm.create_model(
            model_type,
            pretrained=True,
            dynamic_img_size=True,
        )
        self.model_type = model_type
        self.device = device
        self.model.eval()
        self.model.to(self.device)
        self.p = self.model.patch_embed.patch_size[0]
        self.stride = self.model.patch_embed.proj.stride
        self._feats = []
        self.hook_handlers = []

    def _get_hook(self, facet):
        """
        generate a hook method for a specific block and facet.
        """
        if facet in ['attn', 'token']:
            def _hook(model, input, output):
                self._feats.append(output)
            return _hook

        if facet == 'query':
            facet_idx = 0
        elif facet == 'key':
            facet_idx = 1
        elif facet == 'value':
            facet_idx = 2
        else:
            raise TypeError(f"{facet} is not a supported facet.")

        def _inner_hook(module, input, output):
            input = input[0]
            B, N, C = input.shape
            qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
            self._feats.append(qkv[facet_idx]) #Bxhxtxd
        return _inner_hook

    def _register_hooks(self, layers, facet):
        """
        register hook to extract features.
        :param layers: layers from which to extract features.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
        """
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in layers:
                if facet == 'token':
                    self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
                elif facet == 'attn':
                    self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
                elif facet in ['key', 'query', 'value']:
                    self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
                else:
                    raise TypeError(f"{facet} is not a supported facet.")

    def _unregister_hooks(self) -> None:
        """
        unregisters the hooks. should be called after feature extraction.
        """
        for handle in self.hook_handlers:
            handle.remove()
        self.hook_handlers = []

    def _extract_features(self, batch, layers, facet):
        """
        extract features from the model
        :param batch: batch to extract features for. Has shape BxCxHxW.
        :param layers: layer to extract. A number between 0 to 11.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
        :return : tensor of features.
                  if facet is 'key' | 'query' | 'value' has shape Bxhxtxd
                  if facet is 'attn' has shape Bxhxtxt
                  if facet is 'token' has shape Bxtxd
        """
        B, C, H, W = batch.shape
        self._feats = []
        self._register_hooks(layers, facet)
        _ = self.model(batch)
        self._unregister_hooks()
        return self._feats

    def extract_descriptors(self, batch, layer, facet, bin, include_cls):
        """
        extract descriptors from the model
        :param batch: batch to extract descriptors for. Has shape BxCxHxW.
        :param layers: layer to extract. A number between 0 to 11.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
        :param bin: apply log binning to the descriptor. default is False.
        :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
        """
        assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors. 
                                                             choose from ['key' | 'query' | 'value' | 'token'] """
        self._extract_features(batch, [layer], facet)
        x = self._feats[0]
        if facet == 'token':
            x.unsqueeze_(dim=1) #Bx1xtxd
        if not include_cls:
            x = x[:, :, 1:, :]  # remove cls token
        else:
            assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
        if not bin:
            desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
        else:
            desc = self._log_bin(x)
        return desc

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractor = ViTExtractor('vit_small_patch14_dinov2', device=device)
cityscapes.setup(stage='fit')
for batch in cityscapes.val_dataloader():
    with torch.no_grad():
        #image_batch, image_pil = extractor.preprocess(args.image_path, args.load_size)
        #print(f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}.")
        descriptors1 = extractor.extract_descriptors(batch['image'].to(device), 11, 'token', bin=False, include_cls=False)
        print(f"Descriptors are of size: {descriptors1.shape}")
        #torch.save(descriptors, args.output_path)
        #print(f"Descriptors saved to: {args.output_path}")
    break

Descriptors are of size: torch.Size([2, 1, 400, 384])


In [18]:
from dl_toolbox.modules import FeatureExtractor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
feature_extractor = FeatureExtractor('vit_small_patch14_dinov2', fc_norm=False).to(device)
feature_extractor.encoder.prune_intermediate_layers(11,prune_norm=True,prune_head=False)
cityscapes.setup(stage='fit')
for batch in cityscapes.val_dataloader():
    with torch.no_grad():
        #image_batch, image_pil = extractor.preprocess(args.image_path, args.load_size)
        #print(f"Image {args.image_path} is preprocessed to tensor of size {image_batch.shape}.")
        descriptors2 = feature_extractor(batch['image'].to(device))[:,1:]
        print(f"Descriptors are of size: {descriptors2.shape}")
        #torch.save(descriptors, args.output_path)
        #print(f"Descriptors saved to: {args.output_path}")
    break

Descriptors are of size: torch.Size([2, 400, 384])


In [19]:
torch.allclose(descriptors1.squeeze(), descriptors2)

True

In [1]:
import torch
from functools import partial

import torchvision.transforms.v2 as v2
import pytorch_lightning as pl

from dl_toolbox import datamodules
from dl_toolbox import modules
from dl_toolbox import networks

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tf = v2.Compose(
    [
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

datamodule = datamodules.Resisc(
    data_path='/data',
    train_tf=tf,
    test_tf=tf,
    merge='all45',
    batch_size=16,
    num_workers=5,
    pin_memory=False
)

In [3]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics as M
import matplotlib.pyplot as plt
from dl_toolbox.utils import plot_confusion_matrix
from pytorch_lightning.utilities import rank_zero_info
import torch.nn as nn
import math
from dl_toolbox.transforms import Mixup
import timm


class Classifier(pl.LightningModule):
    def __init__(
        self,
        encoder,
        in_channels,
        class_list,
        optimizer,
        scheduler,
        metric_ignore_index,
        one_hot,
        tta=None,
        sliding=None,
        *args,
        **kwargs
    ):
        super().__init__()
        self.class_list = class_list
        self.num_classes = len(class_list)
        
        self.network = timm.create_model(
            encoder,
            pretrained=True,
            num_classes=self.num_classes
        )
        self.encoder = self.network.blocks
        self.loss = torch.nn.CrossEntropyLoss(
            ignore_index=-1,
            reduction='mean',
            weight=None,
            label_smoothing=0.
        )
        self.optimizer = optimizer
        self.scheduler = scheduler
        
        self.tta = tta
        self.sliding = sliding
        self.one_hot = one_hot
        
        metric_args = {
            'task': 'multiclass',
            'num_classes': self.num_classes,
            'ignore_index': metric_ignore_index
        }
        self.val_accuracy = M.Accuracy(**metric_args)
        self.test_accuracy = M.Accuracy(**metric_args)
        self.val_cm = M.ConfusionMatrix(**metric_args, normalize='true')
        self.test_cm = M.ConfusionMatrix(**metric_args, normalize='true')
    
    def configure_optimizers(self):
        parameters = list(self.parameters())
        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
        rank_zero_info(
            f"The model will start training with only {sum([int(torch.numel(p)) for p in trainable_parameters])} "
            f"trainable parameters out of {sum([int(torch.numel(p)) for p in parameters])}."
        )
        optimizer = self.optimizer(params=trainable_parameters)
        scheduler = self.scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch"
            },
        }

    def forward(self, x, sliding=None, tta=None):
        if sliding is not None:
            auxs = [self.forward(aux, tta=tta) for aux in sliding(x)]
            return sliding.merge(auxs)
        elif tta is not None:
            auxs = [self.forward(aux) for aux in tta(x)]
            logits = self.forward(x)
            return torch.stack([logits] + self.tta.revert(auxs)).sum(dim=0)
        else:
            return self.network.forward(x)
    
    def to_one_hot(self, y):
        return torch.movedim(F.one_hot(y, self.num_classes),-1,1).float()
    
    def training_step(self, batch, batch_idx):
        batch = batch["sup"]
        x = batch["image"]
        y = batch["target"]
        if self.one_hot: y = self.to_one_hot(y)
        logits_x = self.forward(x, sliding=None, tta=None)
        loss = self.loss(logits_x, y)
        self.log(f"ce/train", loss)
        return loss
        
    def validation_step(self, batch, batch_idx):
        x = batch["image"]
        y = batch["target"]
        if self.one_hot: y = self.to_one_hot(y)
        logits_x = self.forward(x, sliding=self.sliding)                    
        loss = self.loss(logits_x, y)
        self.log(f"ce/val", loss)
        probs = logits_x.softmax(dim=1)
        pred_probs, preds = torch.max(probs, dim=1)
        self.val_accuracy.update(preds, y)
        self.val_cm.update(preds, y)
        
    def on_validation_epoch_end(self):
        val_accuracy = self.val_accuracy.compute()
        print(val_accuracy)
        self.log("accuracy/val", val_accuracy)
        confmat = self.val_cm.compute().detach().cpu()
        self.val_accuracy.reset()
        self.val_cm.reset()
        class_names = [l.name for l in self.class_list]
        logger = self.trainer.logger
        fs = 12 - 2*(self.num_classes//10)
        fig = plot_confusion_matrix(confmat, class_names, norm=None, fontsize=fs)
        logger.experiment.add_figure("confmat/val", fig, global_step=self.trainer.global_step)

    def test_step(self, batch, batch_idx):
        x = batch["image"]
        y = batch["target"]
        if self.one_hot: y = self.to_one_hot(y)
        logits_x = self.forward(x, sliding=self.sliding, tta=self.tta)
        loss = self.loss(logits_x, y)
        self.log(f"ce/test", loss)
        probs = logits_x.softmax(dim=1)
        pred_probs, preds = torch.max(probs, dim=1)
        self.test_accuracy.update(preds, y)
        self.test_cm.update(preds, y)
        
    def on_test_epoch_end(self):
        self.log("accuracy/test", self.test_accuracy.compute())
        confmat = self.test_cm.compute().detach().cpu()
        self.test_accuracy.reset()
        self.test_cm.reset()
        class_names = [l.name for l in self.class_list]
        logger = self.trainer.logger
        fs = 12 - 2*(self.num_classes//10)
        fig = plot_confusion_matrix(confmat, class_names, norm=None, fontsize=fs)
        logger.experiment.add_figure("confmat/test", fig, global_step=self.trainer.global_step)

    def predict_step(self, batch, batch_idx):
        x = batch["image"]
        logits_x = self.forward(x, sliding=self.sliding, tta=self.tta)
        return logits_x.softmax(dim=1)
    
module = Classifier(
    encoder='efficientnet_b0',
    class_list=datamodule.class_list,
    optimizer=partial(torch.optim.SGD, lr=0.01, momentum=0.9, weight_decay=0.0001),
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR, factor=1),
    in_channels=3,
    metric_ignore_index=None,
    one_hot=False,
    tta=None,
    sliding=None,
)

In [None]:
from dl_toolbox.callbacks import ProgressBar, Finetuning, Lora, TiffPredsWriter, CalibrationLogger

lora = Lora('encoder', 4, True)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_steps=100000,
    limit_train_batches=1.,
    limit_val_batches=1.,
    callbacks=[lora, ProgressBar()]
)

trainer.fit(
    module,
    datamodule=datamodule,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
Missing logger folder: /d/pfournie/dl_toolbox/dl_toolbox/modules/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
The model will start training with only 510125 trainable parameters out of 4065193.

  | Name          | Type                      | Params
------------------------------------------------------------
0 | network       | EfficientNet              | 4.1 M 
1 | encoder       | Sequential                | 3.6 M 
2 | loss          | CrossEntropyLoss          | 0     
3 | val_accuracy  | MulticlassAccuracy        | 0     
4 | test_accuracy | MulticlassAccuracy        | 0     
5 | val_cm        | MulticlassConfusionMatrix | 0     
6 | test_cm

Sanity Checking DataLoader 0:  50%|██████████████████████████████████████████████▌                                              | 1/2 [00:00<00:00,  2.52it/s]

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Sanity Checking DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.81it/s]tensor(0., device='cuda:0')




Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1575/1575 [01:39<00:00, 15.84it/s, v_num=0]
Validation: |                                                                                                                           | 0/? [00:00<?, ?it/s][A

/d/pfournie/dl_toolbox/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 12. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


tensor(0.8890, device='cuda:0')
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1575/1575 [01:10<00:00, 22.43it/s, v_num=0]
Validation: |                                                                                                                           | 0/? [00:00<?, ?it/s][Atensor(0.9230, device='cuda:0')
Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1575/1575 [01:10<00:00, 22.40it/s, v_num=0]
Validation: |                                                                                                                           | 0/? [00:00<?, ?it/s][Atensor(0.9259, device='cuda:0')
Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1575/1575 [01:10<00:00, 22.38it/s, v_num=0]
Validation: |                                                                                            