In [31]:
import torch
import pandas as pd
import numpy as np
import os
from torch.utils.data import Dataset
import random
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTokenizer
import torchvision.transforms as transforms
import open_clip
import pandas as pd
from torch.utils.data import DataLoader
from dataset import AlignedModalityDataset
from open_clip import image_transform


In [32]:
# Inputs for the aligned dataset: the dataset CSV, the image folder and the
# point-cloud folder.
dataset_path = "Data/ShapeNetSem/Datasets/subset_template_200.csv"
image_dir = "Data/ShapeNetSem/Images/subset_200"
pc_dir = "Data/ProcessedData/PointClouds"

dataset = AlignedModalityDataset(dataset_path, image_dir, pc_dir)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Sanity check: fetch a single batch and report the shape and type of every
# modality, then stop iterating.
for batch in dataloader:
    idx, token_text, image_tensor, point_cloud = batch
    print(idx)
    print(f"Tokenized Text: {token_text.shape}, {type(token_text)}")
    print(f"Image Tensor: {image_tensor.shape}, {type(image_tensor)}")
    print(f"Point Cloud: {point_cloud.shape}, {type(point_cloud)}")
    break

tensor([55])
Tokenized Text: torch.Size([1, 77]), <class 'torch.Tensor'>
Image Tensor: torch.Size([1, 3, 518, 518]), <class 'torch.Tensor'>
Point Cloud: torch.Size([1, 1024, 3]), <class 'torch.Tensor'>


In [34]:
import open3d as o3d

def point_cloud_to_depth_map(pc_data):
    """Render a 3D point cloud in a hidden Open3D window and return its depth map.

    Args:
        pc_data: Point coordinates, expected as an (N, 3) array. May be a
            ``torch.Tensor`` on any device (it is detached and copied to the
            CPU first) or anything ``numpy.asarray`` accepts.

    Returns:
        PIL.Image.Image: 8-bit single-channel image. The captured float depth
        buffer is divided by its largest value before quantisation, so the
        deepest captured pixel maps to 255 and pixels with depth 0 stay 0.

    Raises:
        RuntimeError: if Open3D cannot create the hidden render window.
    """
    if isinstance(pc_data, torch.Tensor):
        # Tensor.numpy() raises for CUDA tensors and for tensors that require
        # grad, so always go through detach().cpu() first.
        pc_data = pc_data.detach().cpu().numpy()
    points = np.asarray(pc_data, dtype=np.float64)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    vis = o3d.visualization.Visualizer()
    if not vis.create_window(visible=False):
        raise RuntimeError("Open3D could not create a hidden render window")

    try:
        vis.add_geometry(pcd)

        ctr = vis.get_view_control()
        ctr.set_zoom(0.8)
        vis.poll_events()
        vis.update_renderer()

        # Copy the buffer into NumPy while the window still exists.
        depth_np = np.array(vis.capture_depth_float_buffer(True), dtype=np.float32)
    finally:
        # Always release the window, even if rendering or capture failed.
        vis.destroy_window()

    # The float buffer is not guaranteed to lie in [0, 1]; multiplying raw
    # depths by 255 and casting to uint8 would wrap around for any value
    # above 1. Normalise by the largest depth before quantising.
    depth_np = np.nan_to_num(depth_np, nan=0.0, posinf=0.0, neginf=0.0)
    max_depth = float(depth_np.max()) if depth_np.size else 0.0
    if max_depth > 0.0:
        depth_np = depth_np / max_depth
    depth_u8 = (np.clip(depth_np, 0.0, 1.0) * 255).astype(np.uint8)

    return Image.fromarray(depth_u8)

In [None]:
import torch
from PIL import Image
import timm

def load_dinov2(model_path="PretrainedModels/dinov2_vits14_pretrain.pth",
                model_name="vit_small_patch14_dinov2",
                device=None):
    """Build a timm DINOv2 ViT and load weights from a local checkpoint.

    Args:
        model_path: Path to a state-dict checkpoint on disk.
        model_name: timm architecture name passed to ``timm.create_model``.
        device: Target device. When ``None``, "cuda" is used if available,
            otherwise "cpu".

    Returns:
        The model in eval mode, moved to ``device``.

    Notes:
        The ``mask_token`` entry is dropped from the checkpoint before
        loading. Loading stays non-strict, but any remaining missing or
        unexpected keys are reported with a warning instead of being
        silently ignored.
    """
    import warnings

    model = timm.create_model(model_name, pretrained=False)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")
    checkpoint = {k: v for k, v in checkpoint.items() if k != "mask_token"}

    # strict=False tolerates mismatches, so surface them explicitly: a missing
    # key means that parameter keeps its random initialisation.
    load_result = model.load_state_dict(checkpoint, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "DINOv2 checkpoint does not fully match the model: "
            f"missing={list(load_result.missing_keys)}, "
            f"unexpected={list(load_result.unexpected_keys)}"
        )

    model.eval()  # inference mode

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    print('Dinov2 Loaded Successfully!')
    return model

def _load_open_clip(model_name, checkpoint_path, device=None):
    """Create an open_clip model, load a local state dict, set eval mode, move to device."""
    model = open_clip.create_model(model_name, pretrained=False)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model


def load_clip(model_name="ViT-L-14",
              save_path="PretrainedModels/clip_vitl14_pretrain.pth",
              device=None):
    """Load the CLIP model from a local checkpoint without downloading.

    Args:
        model_name: open_clip architecture name.
        save_path: Path to the saved state dict.
        device: Target device. When ``None``, "cuda" is used if available,
            otherwise "cpu".

    Returns:
        The open_clip model in eval mode on ``device``.
    """
    clip_model = _load_open_clip(model_name, save_path, device)
    print("CLIP Model Loaded Successfully!")
    return clip_model


def load_point_clip(model_name="ViT-L-14",
                    save_path="PretrainedModels/clip_vitl14_pretrain.pth",
                    device=None):
    """Load the CLIP model used for the point-cloud branch.

    By default this reads the same architecture and checkpoint as
    ``load_clip`` and returns a separate model instance.

    Args:
        model_name: open_clip architecture name.
        save_path: Path to the saved state dict.
        device: Target device. When ``None``, "cuda" is used if available,
            otherwise "cpu".

    Returns:
        The open_clip model in eval mode on ``device``.
    """
    clip_model = _load_open_clip(model_name, save_path, device)
    print("Point CLIP Model Loaded Successfully!")
    return clip_model

In [28]:
try:
    dinov2_encoder = load_dinov2()
    clip_encoder = load_clip()
    pclip_encoder = load_point_clip()
    # The loaders already call eval(); repeated here so the message below
    # holds even if a loader is changed later.
    dinov2_encoder.eval()
    clip_encoder.eval()
    pclip_encoder.eval()
    print('All Models loaded succesfully and set to eval mode')
except Exception:
    # Do not swallow the failure: later cells need all three encoders, and
    # hiding the traceback would only defer the error to an unrelated NameError.
    print('Error in Loading Models')
    raise

Dinov2 Loaded Successfully
CLIP Model Loaded Successfully!
Point CLIP Model Loaded Successfully!
All Models loaded succesfully and set to eval mode


In [61]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class NTXentLoss(torch.nn.Module):
    """Temperature-scaled cross-entropy over pairwise cosine similarities.

    Row ``i`` of ``z1`` is treated as the positive match of row ``i`` of
    ``z2``; every other row of ``z2`` acts as a negative for it.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        """
        Return the mean loss for a batch of paired embeddings.
        - z1: embeddings of the first modality, one row per sample
        - z2: embeddings of the second modality, rows aligned with z1
        """
        # Unit-length rows turn the dot products below into cosine similarities.
        anchors = F.normalize(z1, dim=-1)
        candidates = F.normalize(z2, dim=-1)

        # [batch, batch] similarity logits, sharpened by the temperature.
        logits = torch.mm(anchors, candidates.T) / self.temperature

        # The matching pair for row i sits on the diagonal, i.e. class index i.
        targets = torch.arange(z1.shape[0], device=anchors.device)

        return F.cross_entropy(logits, targets)



class MLP(nn.Module):
    """Two-layer perceptron: in_dim -> 2 * in_dim -> out_dim with a ReLU between."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_dim = 2 * in_dim
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.net(x)

class AlignEncoder(nn.Module):
    """Three projection heads that map per-modality embeddings into one shared space.

    Args:
        embed_dim: Width of the shared output space for all three heads.
        text_dim: Width of the incoming text embeddings (default 768).
        img_dim: Width of the incoming image embeddings (default 384).
        pc_dim: Width of the incoming point-cloud embeddings (default 768).

    The defaults reproduce the previously hard-coded widths, so
    ``AlignEncoder(embed_dim)`` builds the same heads as before.
    """

    def __init__(self, embed_dim, text_dim=768, img_dim=384, pc_dim=768):
        super().__init__()
        self.text_proj_head = MLP(text_dim, embed_dim)
        self.img_proj_head = MLP(img_dim, embed_dim)
        self.pc_proj_head = MLP(pc_dim, embed_dim)

    def forward(self, text, img, pc):
        """Project each modality with its own head.

        Returns:
            Tuple ``(text_proj, img_proj, pc_proj)``.
        """
        text_proj = self.text_proj_head(text)
        img_proj = self.img_proj_head(img)
        pc_proj = self.pc_proj_head(pc)

        return text_proj, img_proj, pc_proj


In [None]:
epochs = 200

dataset_path = "Data/ShapeNetSem/Datasets/subset_template_200.csv"
image_dir = "Data/ShapeNetSem/Images/subset_200"
pc_dir = "Data/ProcessedData/PointClouds"

# Inference-time image preprocessing sized for the CLIP visual tower; applied
# to the depth maps rendered from the point clouds.
preprocess = image_transform(
    clip_encoder.visual.image_size,
    is_train=False
)

dataset = AlignedModalityDataset(dataset_path, image_dir, pc_dir)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
align_model = AlignEncoder(400)
align_model.to(device)

loss_fn = NTXentLoss(temperature=0.07)
# Only the projection heads are optimised; the three encoders run under no_grad.
optimizer = torch.optim.Adam(align_model.parameters(), lr=1e-4)

# NOTE(review): the frozen encoders and the depth rendering are re-run for
# every sample in every epoch. If the dataset returns the same sample for a
# given index each time, the embeddings could be computed once — confirm.
all_losses = []
for epoch in range(epochs):
    epoch_losses = []
    for batch in dataloader:
        idx, tokenized_text, image_tensor, point_cloud = batch
        tokenized_text = tokenized_text.to(device)  # (B, 77)
        image_tensor = image_tensor.to(device)  # (B, 3, 518, 518)

        # Depth rendering goes through NumPy, which cannot read CUDA tensors,
        # so the point clouds stay on the CPU. Shape: (B, 1024, 3).
        point_cloud = point_cloud.detach().cpu()
        batch_size = point_cloud.shape[0]

        # Render each point cloud to a depth image, preprocess it, and stack
        # the results into one batch tensor.
        depth_maps = torch.stack(
            [preprocess(point_cloud_to_depth_map(point_cloud[j])) for j in range(batch_size)],
            dim=0,
        ).to(device)

        with torch.no_grad():
            text_emb = clip_encoder.encode_text(tokenized_text)  # (B, 768)
            img_emb = dinov2_encoder(image_tensor)  # (B, 384)
            pc_emb = pclip_encoder.encode_image(depth_maps)  # (B, 768)

        text_proj, img_proj, pc_proj = align_model(text_emb, img_emb, pc_emb)

        # Pairwise contrastive losses between the three modalities.
        loss_text_point = loss_fn(text_proj, pc_proj)
        loss_text_image = loss_fn(text_proj, img_proj)
        loss_image_point = loss_fn(img_proj, pc_proj)

        avg_loss = (loss_text_point + loss_text_image + loss_image_point) / 3
        optimizer.zero_grad()
        avg_loss.backward()
        optimizer.step()
        epoch_losses.append(avg_loss.item())

    if not epoch_losses:
        raise RuntimeError("dataloader produced no batches; cannot compute an epoch loss")

    avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
    all_losses.append(avg_epoch_loss)
    # avg_epoch_loss is already a float; wrapping it in sum() raised TypeError.
    print(f'Epoch {epoch} loss: {avg_epoch_loss:.4f}')

tensor(2.0735, grad_fn=<DivBackward0>)
tensor(2.0871, grad_fn=<DivBackward0>)
tensor(2.0751, grad_fn=<DivBackward0>)


KeyboardInterrupt: 