In [None]:
import torch
from torchvision.utils import make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dataclasses import dataclass

from src.models import PerceptionModel
from src.data import build_data_loader

from sklearn.mixture import GaussianMixture
from sklearn.cluster import HDBSCAN, KMeans
from sklearn.preprocessing import normalize

import numpy as np
np.set_printoptions(suppress=True)

In [None]:
def resize_image(image_tensor, size):
    """
    Bilinearly resample an image to a square of side ``size``.

    ``F.interpolate`` in bilinear mode works on batched 4-D input, so a
    leading batch axis is added for the call and removed from the result.

    Parameters:
    - image_tensor: torch.Tensor of shape [C, H, W]
    - size: side length (in pixels) of the square output

    Returns:
    - torch.Tensor of shape [C, size, size]
    """
    batched = image_tensor[None]
    resized = F.interpolate( batched, size = ( size, size ),
                             mode = 'bilinear', align_corners = False )
    return resized.squeeze(0)

def pad_image(image_tensor, target_size=64):
    """
    Pads the input image tensor with a black border so that its last two
    dimensions become (target_size, target_size). The image is centered;
    when the required padding is odd, the extra pixel goes to the right /
    bottom side.

    Only the trailing (H, W) axes are padded, so any number of leading
    dimensions is accepted (e.g. [C, H, W] or a batch [N, C, H, W]).

    Parameters:
    - image_tensor: torch.Tensor of shape [..., H, W]
    - target_size: Desired output size (both height and width)

    Returns:
    - padded_image: torch.Tensor of shape [..., target_size, target_size]

    Raises:
    - ValueError: if H or W is larger than target_size
    """
    # Only the spatial (last two) dimensions matter for padding
    H, W = image_tensor.shape[-2:]

    # Total padding needed along each spatial axis
    pad_height = target_size - H
    pad_width = target_size - W

    # The image must be no larger than the target size
    if pad_height < 0 or pad_width < 0:
        raise ValueError("The image dimensions are larger than the target size.")

    # Split the padding between both sides; the second side absorbs the remainder
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top

    # F.pad takes the padding for the last axis first: (left, right, top, bottom)
    padded_image = F.pad(image_tensor, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
    return padded_image

In [None]:
@dataclass
class BackboneConfig:
    """
    Backbone settings; the notebook passes them to ``PerceptionModel`` as a
    plain dict via ``vars(BackboneConfig())``.

    Every numeric default is a whole number, so the fields are annotated
    ``int`` (dataclasses do not enforce annotations; the runtime values are
    unchanged).
    """
    # Field meanings below are inferred from the names — confirm against PerceptionModel.
    in_channels: int = 3     # presumably the number of input image channels
    embed_dim: int = 384     # presumably the embedding width
    num_heads: int = 8       # presumably the number of attention heads
    depth: int = 4           # presumably the number of stacked blocks
    num_tokens: int = 4096   # presumably the token count / vocabulary size
    model: str = "linear"    # backbone variant selector

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # echo the selected device as the cell output

In [None]:
# Build the perception model from the default backbone settings (passed as a dict)
# and move it to the selected device.
# NOTE(review): "./models/" and 'encoder_v1' look like a checkpoint directory and
# checkpoint name — confirm against PerceptionModel's signature.
model = PerceptionModel( vars( BackboneConfig() ), "./models/", 'encoder_v1', device ).to(device)
# presumably restores saved weights from the location given above — verify in PerceptionModel.load
model.load()

In [None]:
# Switch to evaluation mode before running inference
# (assumes PerceptionModel follows the standard nn.Module.eval() semantics).
model.eval()

In [None]:
import cv2

image_path = 'front_1706545937.png'
image = cv2.imread( image_path )
# cv2.imread reports a missing or undecodable file by returning None instead of
# raising, so fail here with a clear message rather than inside cvtColor.
if image is None:
    raise FileNotFoundError( f"Could not read image: {image_path}" )

# OpenCV loads images in BGR channel order; convert to RGB
image = cv2.cvtColor( image, cv2.COLOR_BGR2RGB )
# Move channels first (HWC -> CHW), cast to float and scale by 1/255
torch_image = torch.from_numpy( image ).permute( 2, 0, 1 ).float() / 255.0
# Add a leading batch axis and move the tensor to the model's device
torch_image = torch_image.unsqueeze( 0 ).to( device )

In [None]:
# Most recent activations captured by the forward hook (None until the hook fires).
feature_maps = None


def hook_fn(module, input, output):
    """
    Forward hook that keeps a detached CPU copy of a layer's output.

    The copy is published through the module-level ``feature_maps`` variable.
    ``module`` and ``input`` are part of the forward-hook signature but are
    not used. To capture the layer's input instead, store
    ``input[0].detach().cpu()``.
    """
    global feature_maps
    captured = output.detach()
    feature_maps = captured.cpu()

In [None]:
# Re-running this cell would otherwise stack a second hook on the same layer,
# so drop the handle left over from a previous run before registering again.
_previous_hook_handle = globals().get( "hook_handle" )
if _previous_hook_handle is not None:
    _previous_hook_handle.remove()

# Capture the output of backbone layer 15 on every forward pass.
# NOTE(review): the index 15 is hard-coded — confirm it is the intended layer.
# Keep the handle so the hook can be detached later with hook_handle.remove().
hook_handle = model.backbone.net.model.layers[15].register_forward_hook( hook_fn )

In [None]:
# Inference only: disable autograd so no graph is built or kept alive.
# The registered forward hook stores the layer output in `feature_maps` as a side effect.
with torch.no_grad():
    r = model( torch_image )

In [None]:
# Inspect the shape of the activations captured by the forward hook
feature_maps.shape

In [None]:
# Index of the feature channel to visualise (it selects along the last axis of t_).
layer_index = 3

# First sample of the batch, rearranged so the channel axis comes last.
t_ = feature_maps[0].permute( 1, 2, 0 )

# Show the selected channel as a heat map.
plt.figure( figsize = ( 8, 8 ) )
plt.imshow( t_[..., layer_index], cmap = 'magma' )
plt.show()

In [None]:
# data_loader = build_data_loader( 'dataset', batch_size = 4, global_crops_size = 640 )

In [None]:
# data_iter = iter(data_loader)

In [None]:
# denormalize = lambda x: 0.5 * x + 0.5

In [None]:
# with torch.no_grad():
#     with torch.amp.autocast( enabled = True, device_type = "cuda", dtype = torch.bfloat16 ):
#         data = next(data_iter)
#         images = data['collated_global_crops'].cuda( non_blocking = True )
#         tokens = model( images )

# tokens = [ 
#     tokens[0].view( 8, 64, 64, -1 ).to(torch.float32),
#     tokens[1].view( 8, 32, 32, -1 ).to(torch.float32),
#     tokens[2].view( 8, 16, 16, -1 ).to(torch.float32),
# ]

# # plot the image
# grid = make_grid( images, nrow = 4, normalize = True ).cpu()
# # grid = make_grid( denormalize( images ), nrow = 4, normalize = True ).cpu()

# plt.figure( figsize = ( 10, 10 ) )
# plt.imshow( grid.permute(1, 2, 0) )
# plt.show()

In [None]:
# index = 2
# layer_index = 70
# t_ = [ resize_image( x[index][:,:,layer_index][None], 80 ) for x in tokens ]
# t_ = [ ( x - x.mean() ) / ( x.std() + 1e-6 ) for x in t_ ]
# grid = make_grid( t_, nrow = 3 )
# plt.figure( figsize = ( 12, 12 ) )
# plt.imshow( grid.cpu().permute(1, 2, 0)[:,:,0], cmap = 'magma' )
# plt.show()

In [None]:
# image = tokens[0][index].permute(2, 0, 1).cpu().numpy()

# channels, s, _ = image.shape

# # Reshape the image to have one pixel per row (each pixel is a vector of length 'channels')
# pixels = image.reshape(channels, s * s).T  # shape: [s*s, channels]

# # Define the number of clusters (for example, 20)
# n_clusters = 20
# random_colors = np.random.rand(n_clusters, 3)

# # Initialize and fit the clustering model (KMeans is active; GMM and HDBSCAN are alternatives)
# # cluster = GaussianMixture(n_components=n_clusters, random_state=0)
# # cluster = HDBSCAN( min_cluster_size = 16, min_samples = 16 )
# cluster = KMeans( n_clusters = n_clusters )
# labels = cluster.fit_predict( pixels )

# # Reshape the labels back to the original spatial dimensions
# clustered_image = labels.reshape(s, s)
# clustered_image = random_colors[clustered_image].transpose(2, 0, 1)
# clustered_image = torch.tensor( clustered_image ).to(torch.float32).to(device)

# image = make_grid( 
#     [ 
#         resize_image( clustered_image, 640 ), 
#         images[index], 
#         resize_image( resize_image( images[index], s ), 640 ) 
#     ] 
# ).cpu().numpy()

# # Plot the cluster map, the original image, and its down/up-sampled version side by side
# plt.figure(figsize=(18, 12))
# plt.imshow(image.transpose(1, 2, 0))
# plt.axis('off')
# plt.show()