In [None]:
from src.data import build_data_loader_train_detection, denormalize_transform

import torch
import numpy as np
from matplotlib import pyplot as plt
from torchvision.utils import make_grid

In [None]:
# Build the training loader in segmentation mode.
data_loader = build_data_loader_train_detection(
    'dataset_detection_pre',  # dataset identifier / path handed to the builder
    batch_size=4,
    max_objects=200,          # presumably the per-image object padding limit — TODO confirm in src.data
    max_poly_points=64,       # presumably the per-object polygon vertex limit — TODO confirm
    crop_size=640,
    mode='seg',
)

In [None]:
# Keep a single live iterator so each `next(idata)` pulls the following batch.
idata = iter(data_loader)

In [None]:
from torchvision.transforms import ToPILImage, ToTensor
from PIL import ImageDraw
from torchvision.utils import make_grid

def _draw_annotations(draw, labels, bounding_boxes, polygons, crop_w, crop_h):
    """Draw one sample's boxes (red) and closed polygon outlines (blue) onto ``draw``.

    Args:
        draw (PIL.ImageDraw.ImageDraw): drawing context of the sample image.
        labels: per-object labels; objects whose label is 0 are skipped.
        bounding_boxes: per-object normalized ``(x_min, y_min, width, height)`` boxes.
        polygons: per-object sequences of normalized ``(x, y)`` points.
        crop_w (int): image width in pixels.
        crop_h (int): image height in pixels.

    Returns:
        list[tuple[int, int]]: the ``(num_points, 2)`` shape of every processed polygon.
    """
    scale = np.array([crop_w, crop_h], dtype=np.float64)
    shapes = []
    for label, bbox, polygon in zip(labels, bounding_boxes, polygons):
        # Label 0 is skipped; presumably it marks padding slots (see `max_objects`) — TODO confirm.
        if label == 0:
            continue

        # Scale normalized (x, y) points to pixels. reshape(-1, 2) keeps even an
        # empty polygon 2-D, so every recorded shape has the form (k, 2).
        coords = np.asarray(polygon, dtype=np.float64).reshape(-1, 2) * scale
        shapes.append(coords.shape)

        # Box is read as normalized (x_min, y_min, width, height).
        x_min = float(bbox[0]) * crop_w
        y_min = float(bbox[1]) * crop_h
        x_max = (float(bbox[0]) + float(bbox[2])) * crop_w
        y_max = (float(bbox[1]) + float(bbox[3])) * crop_h
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)

        # A closed outline needs at least three vertices.
        if coords.shape[0] >= 3:
            points = [tuple(point) for point in coords.tolist()]
            draw.line(points + [points[0]], fill="blue", width=2)
    return shapes


def draww(data, grid_cols=4, denormalize=None):
    """Draw the annotations of one batch and tile the annotated images into a grid.

    Args:
        data: one batch from the detection loader, indexed as ``data[0]`` images,
            ``data[1]`` labels, ``data[2]`` normalized ``(x_min, y_min, width, height)``
            boxes, ``data[3]`` normalized polygons and ``data[4]`` a sequence of
            per-level localization batches (layout inferred from usage — TODO confirm).
        grid_cols (int): number of images per row in the output grid.
        denormalize (callable, optional): applied to each image tensor before
            drawing, e.g. ``denormalize_transform`` when the loader yields
            normalized images. ``None`` (default) draws the tensor as-is.

    Returns:
        grid_img (PIL.Image): the grid of annotated images.
        unique_shapes (np.ndarray): ``(m, 2)`` int array; each row is one distinct
            ``(num_points, 2)`` polygon shape seen in the batch.
        localizations (tuple[list, ...]): for every entry of ``data[4]``, the
            per-sample items of the batch, in sample order.
    """
    to_pil = ToPILImage()
    to_tensor = ToTensor()

    localizations = tuple([] for _ in data[4])
    drawn_images = []
    polygon_shapes = []
    for i in range(len(data[0])):
        for level, per_sample in zip(data[4], localizations):
            per_sample.append(level[i])

        img = data[0][i]
        if isinstance(img, torch.Tensor):
            if denormalize is not None:
                img = denormalize(img)
            if img.is_floating_point():
                # ToPILImage converts floats with `mul(255).byte()`, which wraps
                # values outside [0, 1] into garbage colours; clamp first.
                img = img.clamp(0.0, 1.0)
            # Assumes a (C, H, W) tensor.
            img = to_pil(img)

        # The crop size is taken to be the image size.
        crop_w, crop_h = img.size
        polygon_shapes.extend(
            _draw_annotations(ImageDraw.Draw(img), data[1][i], data[2][i], data[3][i], crop_w, crop_h)
        )
        drawn_images.append(to_tensor(img))

    grid_img = to_pil(make_grid(drawn_images, nrow=grid_cols, padding=4))

    # np.unique on the raw list would flatten the tuples into scalars (e.g. [2, 64]);
    # axis=0 keeps whole (num_points, 2) rows, i.e. genuinely distinct shapes.
    if polygon_shapes:
        unique_shapes = np.unique(np.asarray(polygon_shapes, dtype=np.int64), axis=0)
    else:
        unique_shapes = np.empty((0, 2), dtype=np.int64)

    return grid_img, unique_shapes, localizations

In [None]:
import torch.nn.functional as F

def resize_image(image_tensor, size, mode='bilinear'):
    """Resize a single un-batched image/map tensor with ``F.interpolate``.

    Args:
        image_tensor: tensor whose leading dim is channels, e.g. ``(C, H, W)``
            for the 2-D modes; a temporary batch dim is added and removed.
        size: target size; an int gives a square ``size x size`` output, a
            ``(height, width)`` pair gives a rectangular one.
        mode: interpolation mode passed to ``F.interpolate`` (default ``'bilinear'``).

    Returns:
        The resized tensor with the same leading (channel) dim as the input.
    """
    if isinstance(size, (tuple, list)):
        out_size = tuple(int(s) for s in size)
    else:
        out_size = (int(size), int(size))
    # `align_corners` may only be set for interpolating modes; 'nearest' etc. reject it.
    align_corners = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    return F.interpolate(image_tensor.unsqueeze(0), size=out_size,
                         mode=mode, align_corners=align_corners).squeeze(0)

In [None]:
data = next(idata)  # advance the loader by one batch
drawn_img, poligons_shapes, locs = draww(data, grid_cols=2)
print(f"Unique polygon shapes: {poligons_shapes}")

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(drawn_img)
ax.axis("off")
plt.show()

In [None]:
def overlay_mask(image, mask, color):
    """Blend a solid ``color`` into ``image`` where ``mask`` is low.

    Args:
        image: ``(C, H, W)`` image tensor; ``H`` is used as the square target
            size, so square images (e.g. the 640px crops) are assumed.
        mask: localization map, assumed 2-D ``(h, w)``; resized to ``H x H``
            with ``resize_image``.
        color: ``(C,)`` tensor of per-channel fill values.

    Returns:
        ``mask * image + (1 - mask) * color`` as a ``(C, H, W)`` tensor.
    """
    mask = resize_image(mask[None].float(), image.shape[1])
    # Match the image's device and dtype so the blend works off-CPU as well.
    fill = color.to(image)[:, None, None]
    return mask * image + (1 - mask) * fill


images = []
for sample in range(len(data[0])):
    # Loop-invariant: denormalize each sample once rather than once per level.
    base_image = denormalize_transform(data[0][sample])
    # One panel per localization level, each with its own random fill colour.
    panels = [
        overlay_mask(base_image, level[sample], torch.rand(base_image.shape[0]))
        for level in locs
    ]
    images.append(torch.cat(panels, dim=2))  # levels side by side (width axis)
grid = torch.cat(images, dim=1)  # samples stacked vertically (height axis)

fig, ax = plt.subplots(figsize=(16, 16))
# Clamp to [0, 1] so imshow does not clip (and warn about) out-of-range floats.
ax.imshow(grid.detach().permute(1, 2, 0).clamp(0.0, 1.0).cpu().numpy())
ax.axis("off")
plt.show()

In [None]:
from torch import nn  # needed for the `nn.Module` annotation and `nn.ReLU()` below

# Decoder hyper-parameters; names/comments mirror an RT-DETR-style decoder
# signature these values are presumably passed to — TODO confirm consumer.
nc: int = 25  # presumably number of classes
ch: tuple = (384, 384, 384)  # presumably channels of each input feature level
hd: int = 256  # hidden dim
nq: int = 300  # num queries
ndp: int = 4  # num decoder points
nh: int = 8  # num head
ndl: int = 6  # num decoder layers
d_ffn: int = 1024  # dim of feedforward (was left blank -> SyntaxError; 1024 is the RT-DETR decoder default)
dropout: float = 0.0
act: nn.Module = nn.ReLU()
eval_idx: int = -1
# Training args
learnt_init_query: bool = False