In [None]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import os
import torch.nn.functional as F
import gc
import logging

In [None]:
# Send log records of severity INFO and above to processing.log
logging.basicConfig(
    filename='processing.log',
    level=logging.INFO,
)

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# Load the pre-trained CLIP weights (placed on the selected device) and the matching processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# Per-dataset folder layout: where the images live, where the candidate
# captions live, and where the selected captions are written.
directories = {}

directories["gta"] = {
    "images": "/home/jovyan/shared/dataset/data_seg/gta/images",
    "captions": "/home/jovyan/shared/dataset/data_seg/gta/captions",
    "output": "/home/jovyan/shared/dataset/data_seg/gta/clip_captions",
}

directories["cityscapes_train"] = {
    "images": "/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/train",
    "captions": "/home/jovyan/shared/dataset/data_seg/cityscapes/captions/train",
    "output": "/home/jovyan/shared/dataset/data_seg/cityscapes/clip_captions/train",
}

directories["cityscapes_test"] = {
    "images": "/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/test",
    "captions": "/home/jovyan/shared/dataset/data_seg/cityscapes/captions/test",
    "output": "/home/jovyan/shared/dataset/data_seg/cityscapes/clip_captions/test",
}

directories["cityscapes_val"] = {
    "images": "/home/jovyan/shared/dataset/data_seg/cityscapes/leftImg8bit/val",
    "captions": "/home/jovyan/shared/dataset/data_seg/cityscapes/captions/val",
    "output": "/home/jovyan/shared/dataset/data_seg/cityscapes/clip_captions/val",
}

In [None]:
# Helper function to recursively collect file paths
def collect_files(folder, extension):
    """Return the sorted paths of all files under `folder` ending with `extension`.

    The tree is walked recursively. Any directory whose path contains
    '.ipynb_checkpoints' is ignored. The suffix comparison is case-sensitive.
    """
    matches = [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(folder)
        if '.ipynb_checkpoints' not in dirpath
        for filename in filenames
        if filename.endswith(extension)
    ]
    matches.sort()
    return matches

In [None]:
# Collect image and caption files
def process_folder(images_folder, captions_folder):
    """Gather the '.png' files under `images_folder` and the '.txt' files under `captions_folder`.

    Returns:
        A tuple (image_files, caption_files) of path lists as produced by collect_files.
    """
    return (
        collect_files(images_folder, '.png'),
        collect_files(captions_folder, '.txt'),
    )

In [None]:
# For every dataset, look up its (image_files, caption_files) path lists
dataset_files = {
    name: process_folder(paths['images'], paths['captions'])
    for name, paths in directories.items()
}

In [None]:
# Function to find mismatched files
def find_mismatched_files(image_files, caption_files):
    """Report files whose counterpart is absent, matching on extension-less basename.

    Returns:
        A tuple (images_without_caption, captions_without_image). Each list keeps
        the order of the corresponding input list.
    """
    def _stem(path):
        # File name without its directory part and without its extension
        return os.path.splitext(os.path.basename(path))[0]

    image_stems = set(map(_stem, image_files))
    caption_stems = set(map(_stem, caption_files))

    images_without_caption = [p for p in image_files if _stem(p) not in caption_stems]
    captions_without_image = [p for p in caption_files if _stem(p) not in image_stems]

    return images_without_caption, captions_without_image

In [None]:
# Report, per dataset, the images lacking a caption and the captions lacking an image
for key, value in directories.items():
    image_files, caption_files = dataset_files[key]
    missing_caption_files, missing_image_files = find_mismatched_files(image_files, caption_files)

    if len(missing_caption_files) > 0:
        print(f"Missing captions for {len(missing_caption_files)} images in {key}:")
        print("\n".join(f"  {f}" for f in missing_caption_files))

    if len(missing_image_files) > 0:
        print(f"Missing images for {len(missing_image_files)} captions in {key}:")
        print("\n".join(f"  {f}" for f in missing_image_files))

In [None]:
# Now, only process files that have matching pairs
def get_matching_files(image_files, caption_files):
    """Pair images and captions that share the same extension-less basename.

    Args:
        image_files: iterable of image paths.
        caption_files: iterable of caption paths.

    Returns:
        A tuple (matched_image_files, matched_caption_files) of equal-length
        lists in which position i of both lists refers to the same basename.
        Pairs are ordered by basename so the result is reproducible between
        runs (iterating the raw set intersection gives an arbitrary order).

    Note:
        Matching is by basename only; if several paths in one input share a
        basename, the last one seen is the one that gets paired.
    """
    def _stem(path):
        # File name without its directory part and without its extension
        return os.path.splitext(os.path.basename(path))[0]

    images_by_stem = {_stem(path): path for path in image_files}
    captions_by_stem = {_stem(path): path for path in caption_files}

    # sorted() fixes the pair order; both lists are built from the same
    # sequence so they stay aligned position by position.
    shared_stems = sorted(images_by_stem.keys() & captions_by_stem.keys())

    matched_image_files = [images_by_stem[stem] for stem in shared_stems]
    matched_caption_files = [captions_by_stem[stem] for stem in shared_stems]

    return matched_image_files, matched_caption_files

In [None]:
# Create directories to store the best captions
def save_best_captions(best_captions, image_files, save_folder):
    """Write each caption to `<save_folder>/<image basename without extension>.txt`.

    `save_folder` is created if it does not exist. Captions and images are
    paired by position; pairing stops at the end of the shorter sequence.
    Existing files with the same name are overwritten.
    """
    os.makedirs(save_folder, exist_ok=True)
    for caption, img_file in zip(best_captions, image_files):
        stem, _extension = os.path.splitext(os.path.basename(img_file))
        target_path = os.path.join(save_folder, stem + ".txt")
        with open(target_path, 'w') as out:
            out.write(caption)

In [None]:
# Function to process batches of images and captions
def process_batches(image_files, caption_files, batch_size):
    """Embed paired images and captions with the global CLIP model, batch by batch.

    Uses the notebook-level `processor`, `model` and `device`.

    Args:
        image_files: sequence of image paths; entry i is paired with caption_files[i].
        caption_files: sequence of caption text-file paths, same length as image_files.
        batch_size: number of pairs per forward pass; must be positive.

    Returns:
        A tuple (image_features, text_features) of CPU tensors holding the
        L2-normalised embeddings, one row per successfully processed pair.
        When nothing was embedded (no inputs, or every batch skipped) both are
        empty tensors of shape (0, 0).

    Raises:
        ValueError: if the two sequences differ in length or batch_size is not positive.

    Note:
        A batch in which any image or caption fails to load is skipped as a
        whole. The returned tensors then have fewer rows than there are input
        files, and row k no longer corresponds to image_files[k] /
        caption_files[k]. Callers that index the file lists by row number must
        check the row count first. A warning is logged when this happens.
    """
    if len(image_files) != len(caption_files):
        raise ValueError("The number of images and captions do not match.")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    num_batches = (len(image_files) + batch_size - 1) // batch_size
    image_features_list = []
    text_features_list = []
    skipped_batches = 0

    for i in range(num_batches):
        batch_image_files = image_files[i * batch_size: (i + 1) * batch_size]
        batch_caption_files = caption_files[i * batch_size: (i + 1) * batch_size]

        print(f"Processing batch {i + 1}/{num_batches}")
        print(f"Batch image files: {batch_image_files}")
        print(f"Batch caption files: {batch_caption_files}")

        # Load and process images
        image_tensors = []
        for img_file in batch_image_files:
            try:
                # convert() returns a fully loaded copy, so the underlying
                # file can be closed as soon as the with-block exits.
                with Image.open(img_file) as raw_image:
                    image_tensors.append(raw_image.convert("RGB"))
            except Exception as e:
                print(f"Error loading image {img_file}: {e}")

        # Load and process captions
        captions = []
        for cap_file in batch_caption_files:
            try:
                with open(cap_file, 'r') as file:
                    caption = file.read().strip()
                captions.append(caption)
            except Exception as e:
                print(f"Error loading caption {cap_file}: {e}")

        if len(image_tensors) == len(batch_image_files) and len(captions) == len(batch_caption_files):
            # truncation=True: the CLIP text encoder has a fixed maximum
            # sequence length, and an over-long caption would otherwise make
            # the forward pass fail.
            inputs = processor(
                text=captions,
                images=image_tensors,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)

            image_embeds = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)

            image_features_list.append(image_embeds.cpu())
            text_features_list.append(text_embeds.cpu())
        else:
            skipped_batches += 1
            print("Skipping batch due to errors in loading files.")

        # Collect garbage to free up memory
        gc.collect()

    if skipped_batches:
        logging.warning(
            "process_batches skipped %d of %d batches; returned feature rows "
            "are no longer aligned with the input file lists.",
            skipped_batches,
            num_batches,
        )

    # torch.cat rejects an empty list, so return explicit empty tensors instead.
    if not image_features_list:
        return torch.empty((0, 0)), torch.empty((0, 0))

    image_features = torch.cat(image_features_list, dim=0)
    text_features = torch.cat(text_features_list, dim=0)
    return image_features, text_features

In [None]:
def normalize_features(features):
    """Scale `features` to unit L2 norm along the last dimension."""
    return F.normalize(features, p=2, dim=-1)

def find_best_captions(image_features, text_features, caption_files, batch_size=128):
    """For every image, pick the caption with the highest cosine similarity.

    Each image row is compared against all text rows. The work is done in
    chunks of `batch_size` images by `batch_size` texts so the full similarity
    matrix is never held in memory at once.

    Args:
        image_features: 2-D tensor, one embedding per row.
        text_features: 2-D tensor, one embedding per row; row i must belong
            to caption_files[i].
        caption_files: sequence of caption file paths, indexed by text row.
        batch_size: chunk size used for both the image and the text dimension.

    Returns:
        A list with one caption string per image row, in row order. Each
        string is the stripped content of the selected caption file.

    Raises:
        ValueError: if there are image rows but no text rows to choose from.
    """
    # Normalize features so a plain dot product equals the cosine similarity
    image_features = normalize_features(image_features)
    text_features = normalize_features(text_features)

    num_images = image_features.size(0)
    num_texts = text_features.size(0)
    if num_images > 0 and num_texts == 0:
        raise ValueError("Cannot select captions: no text features were provided.")

    caption_cache = {}

    def _read_caption(idx):
        # Read each selected caption file once and close it deterministically
        if idx not in caption_cache:
            with open(caption_files[idx], 'r') as file:
                caption_cache[idx] = file.read().strip()
        return caption_cache[idx]

    best_captions = []
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            image_batch = image_features[start:start + batch_size]
            rows = image_batch.size(0)

            # Running per-image maximum across text chunks. Both tensors stay
            # 1-D with one entry per image in this batch.
            best_similarities = image_batch.new_full((rows,), float("-inf"))
            best_indices = torch.zeros(rows, dtype=torch.long, device=image_batch.device)

            for text_start in range(0, num_texts, batch_size):
                text_batch = text_features[text_start:text_start + batch_size]

                # (rows, chunk) cosine similarities
                similarities = image_batch @ text_batch.T
                chunk_best_similarities, chunk_best_indices = similarities.max(dim=1)

                # Strict '>' keeps the earlier caption when scores tie
                improved = chunk_best_similarities > best_similarities
                best_indices = torch.where(improved, chunk_best_indices + text_start, best_indices)
                best_similarities = torch.where(improved, chunk_best_similarities, best_similarities)

            best_captions.extend(_read_caption(idx) for idx in best_indices.cpu().tolist())

    return best_captions

In [None]:
# Set batch size for processing (used for both the CLIP forward pass and the
# similarity search chunking)
batch_size = 8

# Process each dataset and save the best captions
failed_datasets = []
for key, value in directories.items():
    image_files, caption_files = dataset_files[key]
    print(f"Processing {key}: {len(image_files)} images, {len(caption_files)} captions")

    matched_image_files, matched_caption_files = get_matching_files(image_files, caption_files)
    print(f"Matched pairs: {len(matched_image_files)} images and {len(matched_caption_files)} captions")

    # Nothing to embed for this dataset; move on instead of running the model on no input
    if not matched_image_files:
        print(f"No matched pairs for {key}; skipping.")
        continue

    image_features, text_features = process_batches(matched_image_files, matched_caption_files, batch_size)
    best_captions = find_best_captions(image_features, text_features, matched_caption_files, batch_size)

    # save_best_captions pairs captions with images purely by position, so a
    # length mismatch (e.g. a batch was skipped while embedding) would write
    # captions under the wrong image names. Refuse to save in that case.
    if len(best_captions) != len(matched_image_files):
        message = (
            f"Caption/image count mismatch for {key}: {len(best_captions)} captions "
            f"for {len(matched_image_files)} images; not saving."
        )
        logging.error(message)
        print(message)
        failed_datasets.append(key)
        continue

    save_best_captions(best_captions, matched_image_files, value['output'])

if failed_datasets:
    print(f"Best captions were NOT saved for: {', '.join(failed_datasets)}")
else:
    print("Best captions saved successfully.")