In [7]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from typing import List, Dict
from kornia.feature import DenseSIFTDescriptor
from kornia.filters import spatial_gradient


In [2]:
# VGG-16 Layer Names and Channels
vgg16_layers = {
    "conv1_1": 64,
    "relu1_1": 64,
    "conv1_2": 64,
    "relu1_2": 64,
    "pool1": 64,
    "conv2_1": 128,
    "relu2_1": 128,
    "conv2_2": 128,
    "relu2_2": 128,
    "pool2": 128,
    "conv3_1": 256,
    "relu3_1": 256,
    "conv3_2": 256,
    "relu3_2": 256,
    "conv3_3": 256,
    "relu3_3": 256,
    "pool3": 256,
    "conv4_1": 512,
    "relu4_1": 512,
    "conv4_2": 512,
    "relu4_2": 512,
    "conv4_3": 512,
    "relu4_3": 512,
    "pool4": 512,
    "conv5_1": 512,
    "relu5_1": 512,
    "conv5_2": 512,
    "relu5_2": 512,
    "conv5_3": 512,
    "relu5_3": 512,
    "pool5": 512,
}

In [29]:
class AdapLayers(nn.Module):
    """Small adaptation layers."""
    def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128):
        super(AdapLayers, self).__init__()
        self.layers = []
        channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers]
        for i, l in enumerate(channel_sizes):
            layer = nn.Sequential(
                nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(output_dim),
            )
            self.layers.append(layer)
            self.add_module("adap_layer_{}".format(i), layer)

    def forward(self, features: List[torch.tensor]):
        for i, _ in enumerate(features):
            features[i] = getattr(self, "adap_layer_{}".format(i))(features[i])
        return features

class S2DNet(nn.Module):
    def __init__(self, device: torch.device, hypercolumn_layers: List[str], checkpoint_path: str = None):
        super(S2DNet, self).__init__()
        self._device = device
        self._hypercolumn_layers = hypercolumn_layers
        vgg16 = models.vgg16(pretrained=False)
        self.encoder = nn.Sequential(*list(vgg16.features.children())[:-2]).to(device)
        self.adaptation_layers = AdapLayers(hypercolumn_layers).to(device)
        self.layer_indices = {name: idx for idx, name in enumerate(vgg16_layers) if name in hypercolumn_layers}
        
        if checkpoint_path:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            self.load_state_dict(checkpoint['state_dict'])
        
        self.dense_sift_descriptor = DenseSIFTDescriptor(
            num_ang_bins=8,
            num_spatial_bins=4,
            spatial_bin_size=4,
            clipval=0.2,
            stride=1,
            padding=1
        ).to(device)

    def forward(self, image_tensor: torch.FloatTensor):
        feature_maps = []
        x = image_tensor.to(self._device)
        for idx, layer in enumerate(self.encoder):
            x = layer(x)
            layer_name = list(vgg16_layers.keys())[idx]
            if layer_name in self._hypercolumn_layers:
                feature_maps.append(x)
        adapted_features = self.adaptation_layers(feature_maps)


        dense_descriptors = self.dense_sift_descriptor(adapted_features)
        print(f"Dense descriptors type and shape: {type(dense_descriptors)}, {dense_descriptors.shape}")



        keypoints, descriptors = self.extract_keypoints_and_descriptors(dense_descriptors)
        print(f"Keypoints type and contents: {type(keypoints)}, {keypoints}")
        print(f"Descriptors type and shape: {type(descriptors)}, {descriptors.shape}")
        return keypoints, descriptors




    def extract_keypoints_and_descriptors(self, dense_descriptors):
    # This is a placeholder for whatever method you use to extract keypoints
    # For example, you might use a peak detection in the feature strength across the descriptor maps
    # Let's assume a simple thresholding for illustration
        strength = torch.norm(dense_descriptors, dim=1)  # Compute strength of features
        threshold = strength.mean() + strength.std()
        keypoints_mask = strength > threshold
        keypoints = keypoints_mask.nonzero(as_tuple=True)  # Get indices of keypoints
        descriptors = dense_descriptors[:, keypoints[0], keypoints[1]]  # Extract descriptors for keypoints
        return keypoints, descriptors

In [30]:
def load_image(image_path, device, size=(256, 256), maintain_aspect_ratio=False):
    transform = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    image = Image.open(image_path).convert('RGB')
    return transform(image).unsqueeze(0).to(device)

In [31]:
import matplotlib.pyplot as plt

# Function to visualize keypoints on an image
def visualize_keypoints(image_tensor, keypoints):
    # Convert CHW to HWC and normalize
    image_np = image_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])).clip(0, 1)
    plt.imshow(image_np)
    # Ensure keypoints are in (x, y) format for plotting
    if keypoints.dim() > 2:
        keypoints = keypoints.squeeze()
    plt.scatter(keypoints[:, 1], keypoints[:, 0], s=10, color='red', marker='.')
    plt.title("Visualized Keypoints")
    plt.show()


# Load model and image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = S2DNet(device, ["conv1_2", "conv3_3", "conv5_3"], 's2dnet_weights.pth')
model.eval()

# try:
image_path1 = 'buddha1.jpeg'
image1 = load_image(image_path1, device)
keypoints1, descriptors1 = model(image1)  # Adjust based on actual model output
print("Output type:", type(keypoints1))
print("Output contents:", keypoints1)

#     print("Keypoints from Image 1:", keypoints1.size())
#     visualize_keypoints(image1, keypoints1)

# except Exception as e:
#     print(f"An error occurred: {e}")



AttributeError: 'list' object has no attribute 'shape'

In [12]:
import torch.nn.functional as F

def match_features(features1, features2):
    # Flatten the C, H, and W dimensions so each feature map is a single vector
    f1_flat = features1.view(features1.shape[0], -1)
    f2_flat = features2.view(features2.shape[0], -1)
    
    # Normalize these vectors to have a unit norm
    f1_norm = F.normalize(f1_flat, p=2, dim=1)
    f2_norm = F.normalize(f2_flat, p=2, dim=1)
    
    # Use PyTorch's cosine similarity function across the batch dimension
    similarity = F.cosine_similarity(f1_norm.unsqueeze(1), f2_norm.unsqueeze(0), dim=2)
    
    # Find the best matches for each feature in image1
    max_similarity, best_matches = torch.max(similarity, dim=1)

    print("Similarity scores:", similarity)  # Debug: print similarity scores to check behavior
    return best_matches

# Assuming features1 and features2 are the feature tensors from your model
best_matches = match_features(features1[0], features2[0])
print("Best matches:", best_matches)


Similarity scores: tensor([[0.4755]], device='cuda:0', grad_fn=<SumBackward1>)
Best matches: tensor([0], device='cuda:0')


In [7]:
# import cv2
# def flatten_features(feature_maps):
#     batch_size, channels, height, width = feature_maps.size()
#     return feature_maps.view(batch_size, channels, height * width).permute(0, 2, 1).reshape(-1, channels)

# # Function to compute cosine similarity
# def cosine_similarity(feat1, feat2, batch_size=128):
#     feat1_norm = torch.nn.functional.normalize(feat1, p=2, dim=1)
#     feat2_norm = torch.nn.functional.normalize(feat2, p=2, dim=1)
#     num_rows = feat1_norm.shape[0]
#     similarities = []
    
#     for start in range(0, num_rows, batch_size):
#         end = min(start + batch_size, num_rows)
#         similarities.append(torch.mm(feat1_norm[start:end], feat2_norm.t()))
    
#     return torch.cat(similarities, dim=0)


# # Flatten the features from both images
# features1_flat = flatten_features(features1[0])  # Assuming features1 is a list of tensors
# features2_flat = flatten_features(features2[0])

# # Calculate similarities
# similarities = cosine_similarity(features1_flat, features2_flat)

# # Find the best matches
# top_matches = torch.topk(similarities, k=1, dim=1)[1].squeeze()

# # Load images for visualization
# img1 = cv2.imread(image_path1)
# img2 = cv2.imread(image_path2)

# # Assuming the feature maps are of size (1, C, H, W)
# h1, w1 = features1[0].size(2), features1[0].size(3)
# h2, w2 = features2[0].size(2), features2[0].size(3)

# # Create keypoints for each feature in both images
# keypoints1 = [cv2.KeyPoint(x=float(x % w1), y=float(x // w1), _size=1) for x in range(h1 * w1)]
# keypoints2 = [cv2.KeyPoint(x=float(x % w2), y=float(x // w2), _size=1) for x in range(h2 * w2)]

# # Create matches using the indices found in top_matches
# matches = [cv2.DMatch(_queryIdx=int(i), _trainIdx=int(top_matches[i]), _distance=1.0 - float(similarities[i, top_matches[i]])) for i in range(h1 * w1)]

# # Draw matches
# matched_image = cv2.drawMatches(img1, keypoints1, img2, keypoints2, matches, None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)

# # Display the result
# cv2.imshow('Feature Matches', matched_image)
# cv2.waitKey(0)
# cv2.destroyAllWindows()

In [8]:
# import torchvision.transforms.functional as TF
# import matplotlib.pyplot as plt

# # Define a common size, for example, the smallest dimensions in your list
# common_size = (64, 64)  # Adjust as necessary

# # Resize tensors
# resized_tensors = [TF.resize(t, common_size) if t.size()[2:] != common_size else t for t in features]

# # Stack the resized tensors
# features_tensor = torch.stack(resized_tensors)

# # Move the tensor to CPU, detach it from the graph, and convert to numpy
# output_numpy = features_tensor.cpu().detach().numpy()

# # Print shapes to debug
# print("Shape of features_tensor:", features_tensor.shape)
# print("Shape of the selected image slice:", output_numpy[0, 0, 0, :, :].shape)

# # Select the first tensor, first channel, first feature map for visualization
# image = output_numpy[0, 0, 0, :, :]  # Ensure this index points to a single (64, 64) feature map

# plt.imshow(image, cmap='gray')  # Ensure image is a 2D array
# plt.colorbar()
# plt.title('Visualization of the Resized Output Tensor')
# plt.show()