In [None]:
import torch
import torchvision.transforms.v2 as v2
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import numpy as np

class UnNormalize(object):
    """Invert a per-channel ``(x - mean) / std`` normalisation.

    Args:
        mean: sequence of per-channel means.
        std: sequence of per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return a de-normalised copy of ``image``; the input is not modified.

        Args:
            image: tensor laid out as (C, H, W), or (B, C, H, W) when batched.

        Returns:
            A new tensor with the same shape as ``image`` in which channel
            ``c`` holds ``image_c * std[c] + mean[c]``. As with ``zip``, only
            as many channels as there are ``mean``/``std`` entries are touched.

        Raises:
            ValueError: if ``image`` is neither 3-D nor 4-D.
        """
        if image.dim() not in (3, 4):
            raise ValueError(
                f"expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(image.shape)}"
            )
        restored = torch.clone(image)
        # The channel axis sits after the batch axis when one is present.
        channel_dim = 1 if restored.dim() == 4 else 0
        n_channels = restored.shape[channel_dim]
        for idx, m, s in zip(range(n_channels), self.mean, self.std):
            # select() yields a view, so the in-place ops write into `restored`.
            restored.select(channel_dim, idx).mul_(s).add_(m)
        return restored
    
# Matching normalise / de-normalise pair built from the same per-channel
# statistics (the values commonly used with ImageNet-pretrained models).
norm = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
unnorm = UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

class TorchPCA(object):
    """Small PCA wrapper around ``torch.pca_lowrank``.

    After ``fit`` the instance exposes ``mean_``, ``components_`` (one row per
    component) and ``singular_values_``.
    """

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Estimate the feature mean and the leading components of ``X``.

        ``X`` is centred here explicitly, so ``pca_lowrank`` is called with
        ``center=False``. Returns ``self`` so calls can be chained.
        """
        feature_mean = X.mean(dim=0)
        self.mean_ = feature_mean
        centered = X - feature_mean.unsqueeze(0)
        _, singular_values, right_vectors = torch.pca_lowrank(
            centered,
            q=self.n_components,
            center=False,
            niter=4,
        )
        # Store components row-wise: (n_components, n_features).
        self.components_ = right_vectors.T
        self.singular_values_ = singular_values
        return self

    def transform(self, X):
        """Centre ``X`` with the fitted mean and project it onto the components."""
        return (X - self.mean_.unsqueeze(0)) @ self.components_.T


def pca(image_feats_list, dim=3, fit_pca=None, max_samples=None):
    """Project feature maps onto a few components and rescale them to [0, 1].

    Args:
        image_feats_list: non-empty sequence of 4-D tensors, each unpacked as
            (B, C, H, W).
        dim: number of components used when a new ``TorchPCA`` is fitted.
        fit_pca: optional already-fitted object exposing ``transform(X)`` for
            an (N, C) CPU tensor; it may return a tensor or a NumPy array.
            When given, no fitting is performed.
        max_samples: optional cap on the number of rows used for fitting; rows
            are drawn with ``torch.randperm``.

    Returns:
        ``(reduced_feats, fit_pca)``: a list with one (B, K, H, W) tensor per
        input, on the device of the first input, plus the fitted (or supplied)
        PCA object. Each component is min-max scaled independently per list
        entry; a component that is constant over an entry maps to 0.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # (B, C, H, W) -> (B*H*W, C) on CPU, optionally resizing first.
        if target_size is not None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if fit_pca is None:
        # The fitting matrix is only needed when no fitted PCA was supplied.
        # With several inputs, resize all of them to the first one's height.
        target_size = image_feats_list[0].shape[2] if len(image_feats_list) > 1 else None
        x = torch.cat([flatten(feats, target_size) for feats in image_feats_list], dim=0)

        # Subsample the data if max_samples is set and the number of samples exceeds max_samples
        if max_samples is not None and x.shape[0] > max_samples:
            indices = torch.randperm(x.shape[0])[:max_samples]
            x = x[indices]

        fit_pca = TorchPCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        x_red -= x_red.min(dim=0, keepdim=True).values
        span = x_red.max(dim=0, keepdim=True).values
        # A constant component has zero span; dividing by it would yield NaN.
        x_red /= torch.where(span > 0, span, torch.ones_like(span))
        B, C, H, W = feats.shape
        # Take the component count from the projection itself so a supplied
        # fit_pca with a different number of components still reshapes.
        n_out = x_red.shape[-1]
        reduced_feats.append(x_red.reshape(B, H, W, n_out).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca

from pytorch_lightning import seed_everything

def _remove_axes(ax):
    """Hide tick labels and tick marks on both axes of a single Axes."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    """Strip ticks and tick labels from every Axes in ``axes``.

    Args:
        axes: a single Axes, or an array / nested sequence of Axes of any
            dimensionality (for example whatever ``plt.subplots`` returned).
    """
    # plt.subplots() hands back a bare Axes for a 1x1 grid, which has no
    # `.shape`; atleast_1d turns every accepted input into an array that can
    # be walked in row-major order.
    for ax in np.atleast_1d(axes).flat:
        _remove_axes(ax)
            
#from dl_toolbox.datasets import Rellis3d
#from torchvision import tv_tensors

#tf = v2.Compose([
#    v2.RandomCrop(size=(672, 672)),
#    v2.ToDtype(
#        dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}, 
#        scale=True
#    ),
#    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#])
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#rellis = '/data/Rellis-3D'
#imgs = [rellis+'/00000/pylon_camera_node/frame000000-1581624652_750.jpg']
#msks = [rellis+'/00000/pylon_camera_node_label_id/frame000000-1581624652_750.png']
#dataset = Rellis3d(
#    imgs=imgs,
#    msks=msks,
#    merge='all19',
#    transforms=tf
#)
#elem = dataset[0]
#image, mask = elem['image'].to(device).unsqueeze(0), elem['label']
#h = 672 // 14
#w = 672 // 14
#encoder.to(device)
#lr_feats = encoder.forward_features(image)
#lr_feats = lr_feats[:,encoder.num_prefix_tokens:,...]
#lr_feats = lr_feats.reshape(-1, h, w, 384).permute(0,3,1,2).detach().cpu()
#hr_feats_bili = v2.functional.resize(lr_feats, (672, 672), Image.BILINEAR)

In [1]:
import timm

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from timm.models.layers import trunc_normal_

def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    * ``nn.Linear``: truncated-normal weights (std 0.02), zero bias.
    * ``nn.LayerNorm``: zero bias, unit weight.
    * Any other module is left untouched.

    Parameters that are ``None`` (a Linear built without bias, or a LayerNorm
    built without affine parameters) are skipped instead of raising.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

class DecoderLinear(nn.Module):
    """Linear per-token classification head.

    Every token of width ``d_encoder`` is mapped to ``n_cls`` scores by a
    single ``nn.Linear``, and the token sequence is folded back into a 2-D
    grid.
    """

    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        self.n_cls = n_cls
        self.patch_size = patch_size
        self.d_encoder = d_encoder

        self.head = nn.Linear(d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay (none here)."""
        return set()

    def forward(self, x, im_size):
        """Turn tokens ``(b, h*w, c)`` into a score map ``(b, n_cls, h, w)``.

        ``im_size`` is the ``(H, W)`` of the input image; the grid height is
        ``H // patch_size`` and the grid width is inferred from the token count.
        """
        height, _ = im_size
        logits = self.head(x)
        return rearrange(
            logits,
            "b (h w) c -> b c h w",
            h=height // self.patch_size,
        )
    
import torch
import torch.nn as nn
import torch.nn.functional as F

class Segmenter(nn.Module):
    """Encoder/decoder segmentation model.

    The encoder must expose ``patch_size``, ``num_prefix_tokens``,
    ``forward_features`` and ``no_weight_decay``; the decoder is called as
    ``decoder(tokens, (H, W))`` and must expose ``no_weight_decay``.
    """

    def __init__(self, encoder, decoder, n_cls):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        """Collect the sub-modules' no-weight-decay names, prefixed by owner."""
        names = set()
        for prefix, module in (("encoder.", self.encoder), ("decoder.", self.decoder)):
            names.update(prefix + name for name in module.no_weight_decay())
        return names

    def forward(self, im):
        """Return per-class maps at the spatial resolution of ``im``.

        No padding to a multiple of the patch size is applied here.
        """
        height, width = im.size(2), im.size(3)

        tokens = self.encoder.forward_features(im)
        # Drop the prefix (e.g. class) tokens; only patch tokens are decoded.
        patch_tokens = tokens[:, self.encoder.num_prefix_tokens:, ...]

        coarse = self.decoder(patch_tokens, (height, width))
        # Upsample the patch-level maps back to the input size.
        return F.interpolate(coarse, size=(height, width), mode="bilinear")

#    def get_attention_map_enc(self, im, layer_id):
#        return self.encoder.get_attention_map(im, layer_id)
#
#    def get_attention_map_dec(self, im, layer_id):
#        x = self.encoder(im, return_features=True)
#
#        # remove CLS/DIST tokens for decoding
#        num_extra_tokens = 1 + self.encoder.distilled
#        x = x[:, num_extra_tokens:]
#
#        return self.decoder.get_attention_map(x, layer_id)
    
# Assemble the demo model: a pretrained timm ViT encoder with a linear
# 10-class decoder on top.
encoder = timm.create_model("vit_base_patch8_224_dino", pretrained=True)
encoder.patch_size = 8  # Segmenter reads this attribute from the encoder
decoder = DecoderLinear(n_cls=10, patch_size=8, d_encoder=encoder.embed_dim)
model = Segmenter(encoder=encoder, decoder=decoder, n_cls=10)

In [17]:
# Smoke test: push one random 224x224 RGB image through the model and show
# the shape of the result.
x = torch.randn(1, 3, 224, 224)
model(x).shape

torch.Size([1, 10, 224, 224])