In [25]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import os
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

In [26]:
# Check GPU availability
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

Using GPU: NVIDIA L40


In [27]:
from collections import namedtuple

# Define the Label class
Label = namedtuple('Label', ['name', 'id', 'trainId', 'category', 'catId', 'hasInstances', 'ignoreInEval', 'color'])

# List of Cityscapes labels
labels = [
    Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
    Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
    Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
    Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
    Label('static', 4, 255, 'void', 0, False, True, (111, 74, 0)),
    Label('dynamic', 5, 255, 'void', 0, False, True, (81, 0, 81)),
    Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
    Label('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
    Label('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
    Label('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
    Label('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
    Label('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
    Label('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
    Label('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
    Label('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
    Label('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
    Label('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
    Label('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
    Label('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
    Label('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
    Label('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
    Label('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
    Label('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
    Label('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
    Label('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
    Label('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
    Label('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
    Label('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
    Label('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
    Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
    Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
    Label('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
    Label('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
    Label('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
    Label('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
]


In [28]:
# Create a mapping from class name to trainId
name_to_trainId = {label.name: label.trainId for label in labels}


In [29]:
# Load pre-trained CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)

In [30]:
current_directory = os.getcwd()
print("Current working directory:", current_directory)

Current working directory: /home/jovyan/shared/siglip_task


In [31]:
# Specify your directories
images_folder = "/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train"
captions_folder = "/home/jovyan/shared/dataset/data_seg/cityscapes/captions/train"

In [32]:
# Helper function to recursively collect file paths
def collect_files(folder, extension):
    file_paths = []
    for root, dirs, files in os.walk(folder):
        for file in files:
            if file.endswith(extension):
                file_paths.append(os.path.join(root, file))
    return sorted(file_paths)

In [33]:
# Collect image and caption files
image_files = collect_files(images_folder, '.png')[:1000]
caption_files = collect_files(captions_folder, '.txt')[:1000]

In [34]:
batch_size = 32  # Adjust based on your available memory
num_batches = len(image_files) // batch_size

In [35]:
'''# Process and encode the first 10,000 images and captions
batch_size = 32  # Adjust based on your available memory
image_files = sorted([f for f in os.listdir(images_folder) if os.path.isfile(os.path.join(images_folder, f)) and f.endswith('.png')])[:1000]
caption_files = sorted([f for f in os.listdir(captions_folder) if os.path.isfile(os.path.join(captions_folder, f)) and f.endswith('.txt')])[:1000]'''

if len(image_files) != len(caption_files):
    raise ValueError("The number of images and captions do not match.")

num_batches = len(image_files) // batch_size

In [36]:
# Initialize global lists for captions and image tensors
all_captions = []
all_image_tensors = []
all_image_files = []

image_features_list = []
text_features_list = []

In [37]:
# Define binary cross-entropy loss with logits
criterion = nn.BCEWithLogitsLoss()

In [39]:
for i in range(num_batches):
    batch_image_files = image_files[i * batch_size: (i + 1) * batch_size]
    batch_caption_files = caption_files[i * batch_size: (i + 1) * batch_size]

    # Debug: Print the current batch files
    print(f"Processing batch {i + 1}/{num_batches}")
    print(f"Batch image files: {batch_image_files}")
    print(f"Batch caption files: {batch_caption_files}")

    # Load and process images
    image_tensors = []
    for img_file in batch_image_files:
        try:
            image = Image.open(os.path.join(images_folder, img_file)).convert("RGB")
            image_tensors.append(image)
            all_image_files.append(img_file)  # Store the image file name
        except Exception as e:
            print(f"Error loading image {img_file}: {e}")

    # Load and process captions
    captions = []
    for cap_file in batch_caption_files:
        try:
            caption = open(os.path.join(captions_folder, cap_file)).read().strip()
            captions.append(caption)
        except Exception as e:
            print(f"Error loading caption {cap_file}: {e}")

    # Append to global lists
    all_captions.extend(captions)
    all_image_tensors.extend(image_tensors)
    
    # Print the features before encoding
    print(f"Before encoding - Batch {i + 1}:")
    print(f"Image tensors: {len(image_tensors)}, Captions: {len(captions)}")

    # Process batch if all files were loaded successfully
    if len(image_tensors) == len(batch_image_files) and len(captions) == len(batch_caption_files):
        inputs = processor(text=captions, images=image_tensors, return_tensors="pt", padding=True)
        
        # Move inputs to the device (GPU or CPU)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass through CLIP model
        with torch.no_grad():
            outputs = model(**inputs)

        # Print the features after encoding
        print(f"After encoding - Batch {i + 1}:")
        print(f"Image embeddings shape: {outputs.image_embeds.shape}")
        print(f"Text embeddings shape: {outputs.text_embeds.shape}")

        # Normalize the embeddings
        image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)

        # Print the features after normalization
        print(f"After normalization - Batch {i + 1}:")
        print(f"Normalized image embeddings shape: {image_embeds.shape}")
        print(f"Normalized text embeddings shape: {text_embeds.shape}")

        # Extract features
        image_features = outputs.image_embeds
        text_features = outputs.text_embeds
    

        # Move outputs to CPU and store them
        image_features_list.append(image_embeds.cpu())
        text_features_list.append(text_embeds.cpu())

        # Assume binary labels based on the presence of specific keywords in captions
        target_labels = []
        for caption in captions:
            # Initialize an all-zero label vector for each caption
            label_vector = torch.zeros(len(name_to_trainId))
            for name, trainId in name_to_trainId.items():
                if name in caption and trainId != 255:
                    label_vector[trainId] = 1
            target_labels.append(label_vector)
    
        target_labels = torch.stack(target_labels).to(device)

        # Compute logits and loss
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # Resize target labels to match the output logits shape
        if target_labels.size(1) != logits_per_image.size(1):
            raise ValueError(f"Mismatch in target label size ({target_labels.size(1)}) and logits size ({logits_per_image.size(1)})")

        loss_image = criterion(logits_per_image, target_labels)
        loss_text = criterion(logits_per_text, target_labels)
    
        # Backward pass (if training)
        loss_image.backward()
        loss_text.backward()
    
        print(f"Processed batch {i + 1}/{num_batches}")
    else:
        print("Skipping batch due to errors in loading files.")

Processing batch 1/31
Batch image files: ['/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/.ipynb_checkpoints/aachen_000000_000019_leftImg8bit-checkpoint.png', '/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/.ipynb_checkpoints/aachen_000001_000019_leftImg8bit-checkpoint.png', '/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/.ipynb_checkpoints/aachen_000003_000019_leftImg8bit-checkpoint.png', '/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/.ipynb_checkpoints/aachen_000008_000019_leftImg8bit-checkpoint.png', '/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/.ipynb_checkpoints/aachen_000015_000019_leftImg8bit-checkpoint.png', '/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png', '/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train/aachen/aachen_000001_000019_leftImg8bit.png', '/home/jovyan/sha

ValueError: Mismatch in target label size (35) and logits size (32)

In [None]:
# Concatenate features from all batches
image_features = torch.cat(image_features_list, dim=0)
text_features = torch.cat(text_features_list, dim=0)

In [None]:
# Print the concatenated features' dimensions
print(f"Concatenated image features shape: {image_features.shape}")
print(f"Concatenated text features shape: {text_features.shape}")

In [None]:
# Compute sigmoid logits and binary cross-entropy loss
logits = torch.matmul(image_features, text_features.T)
targets = torch.eye(len(image_features)).to(logits.device)  # Dummy targets for demonstration (identity matrix)
loss = criterion(logits, targets)

# Print the computed loss
print(f"Computed binary cross-entropy loss: {loss.item()}")

In [None]:
# Compute sigmoid probabilities
sigmoid_probs = torch.sigmoid(logits)

In [None]:
# Print shapes of logits and sigmoid probabilities
print(f"Logits shape: {logits.shape}")
print(f"Sigmoid probabilities shape: {sigmoid_probs.shape}")

In [None]:
# Compute cosine similarities for comparison
cosine_similarities = F.cosine_similarity(
    image_features.unsqueeze(1),  # (num_images, 1, feature_dim)
    text_features.unsqueeze(0),  # (1, num_captions, feature_dim)
    dim=-1  # Compute similarity along the feature dimension
)

# Print shape of cosine_similarities
print(f"Cosine similarities shape: {cosine_similarities.shape}")
print(f"Cosine similarities: {cosine_similarities}")

In [None]:
def show_image(image_path):
    image = Image.open(os.path.join(images_folder, image_path))
    plt.imshow(image)
    plt.axis('off')
    plt.show()


In [None]:
# Function to find the most similar caption for a given image index
def find_most_similar_caption(image_idx):
    similarity_scores = cosine_similarities[image_idx]
    most_similar_caption_idx = torch.argmax(similarity_scores).item()
    print(f"Image index: {image_idx}, Most similar caption index: {most_similar_caption_idx}, Total captions: {len(all_captions)}")
    return all_captions[most_similar_caption_idx]

# Example usage
image_idx = 450  # Index of the image you want to query

# Show the original image
print("Original image:")
show_image(all_image_files[image_idx])

# Print the original caption and the most similar caption
print("Original caption:", all_captions[image_idx])
print("len of all captions", len(all_captions))
print("Most similar caption:", find_most_similar_caption(image_idx))

In [None]:
def find_most_similar_image(caption_idx):
    similarity_scores = cosine_similarities[:, caption_idx]
    most_similar_image_idx = torch.argmax(similarity_scores).item()
    print(f"Caption index: {caption_idx}, Most similar image index: {most_similar_image_idx}, Total images: {len(all_image_files)}")
    return all_image_files[most_similar_image_idx], most_similar_image_idx

caption_idx = 450  # Index of the caption you want to query

# Print the caption first
print(f"Original caption: {all_captions[caption_idx]}")

# Find and show the most similar image for the given caption
most_similar_image, most_similar_image_idx = find_most_similar_image(caption_idx)
print("Most similar image for the caption:")
show_image(most_similar_image)