# Task: `Vision-Language Model`

Given an image and a caption describing a target in that image, return a bounding box corresponding to the target’s location within the image.

Note that targets within a given image are not uniquely identified by their object class (e.g. ”airplane”, “helicopter”); multiple targets within an image may be members of the same object class. Instead, targets provided will correspond to a particular target description (e.g. “black and white drone”).

Not all possible target descriptions will be represented in the training dataset provided to participants. There will also be unseen targets and novel descriptions in the test data used in the hidden test cases of the Virtual Qualifiers, Semi-Finals / Finals. As such, Guardians will have to develop vision models capable of understanding **natural language** to identify the correct target from the scene.

For the **image datasets** provided to both Novice and Advanced Guardians, there will be no noise present. However, it is worth noting that your models will have to be adequately robust as the hidden test cases for the Virtual Qualifiers and the Semi-Finals/Finals will have increasing amounts of noise introduced. This is especially crucial for **Advanced Guardians**, due to the degradation of their robot sensors.

_Insert Code Here_

In [1]:
# !pip install -q torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# !pip install -q -U torchinfo albumentations # Image Augmentation

In [1]:
import albumentations
from PIL import Image
import IPython.display as display
import torch
import requests
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import json

import torchvision
from torchvision.transforms import functional as F
from torchvision import transforms
from torchinfo import summary
import urllib
import os

import numpy as np
from torch.utils.data import IterableDataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ProgressBar
import torch.nn.functional as F

import torchvision.models.detection as detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.image_list import ImageList

# import multiprocessing as mp
# mp.set_start_method('spawn', force=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Enable benchmark mode in cuDNN to find the best algorithm for your hardware
torch.backends.cudnn.benchmark = True
# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True

In [3]:
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# for name, param in clip_model.named_parameters():
#     print(name)

In [2]:
cur_dir = os.getcwd()
vlm_dir = os.path.dirname(cur_dir)
til_dir = os.path.dirname(vlm_dir)
home_dir = os.path.dirname(til_dir)
test_dir = os.path.join(home_dir, 'novice')
img_dir = os.path.join(test_dir, 'images')
metadata_path = os.path.join(test_dir, 'vlm.jsonl')
data_dir = os.path.join(cur_dir, 'data')

train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")
val_dir = os.path.join(data_dir, "val")

img_dir

'/home/jupyter/novice/images'

# Models: Faster RNN, CLIP

In [3]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

device = torch.device("cpu")
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
preprocess = weights.transforms()

In [4]:
class BatchDataLoader:
    def __init__(self, batch_path):
        self.batch_path = batch_path

    def load_data(self):
        # Load image paths
        with open(os.path.join(self.batch_path, "rcnn_img_paths.json"), 'r') as f:
            rcnn_image_paths = json.load(f)
        with open(os.path.join(self.batch_path, "clip_img_paths.json"), 'r') as f:
            clip_image_paths = json.load(f)

        rcnn_image_tensors = self.load_and_stack_images(rcnn_image_paths)
        clip_pixel_values = self.load_and_stack_images(clip_image_paths, img_type='clip')

        # Load text data
        with open(os.path.join(self.batch_path, "text_data.json"), 'r') as f:
            text_data = json.load(f)
        text_data_tensors = [self.convert_to_tensors(item) for item in text_data]

        # Load bounding boxes and labels
        bboxes_batch = self.load_bboxes(os.path.join(self.batch_path, "bboxes.npy"))
        labels_batch = self.load_labels(os.path.join(self.batch_path, "labels.npy"))

        return rcnn_image_tensors, clip_pixel_values, text_data_tensors, bboxes_batch, labels_batch
    
    def load_and_stack_images(self, image_paths, img_type='rcnn'):
        image_tensors = [self.load_image_to_tensor(image_path) for image_path in image_paths]
        # Filter out None values in case of invalid images
        image_tensors = [tensor for tensor in image_tensors if tensor is not None]
        if not image_tensors:
            return None
        if img_type == 'rcnn':
            return torch.stack(image_tensors)
        elif img_type == 'clip':
            # Convert all image tensors to a list of dictionaries
            clip_inputs = [{"pixel_values": tensor.unsqueeze(0)} for tensor in image_tensors]
            return clip_inputs
        else:
            raise ValueError("Invalid image type specified. Use 'rcnn' or 'clip'.")
    
    @staticmethod
    def load_image_to_tensor(image_path):
        # Load the image data as a memory-mapped array
        image_array = np.load(image_path, mmap_mode='r')
        if image_array is None or image_array.size == 0:
            print(f"Skipping invalid image: {image_path}")
            return None
        # Convert to a PyTorch tensor
        image_tensor = torch.from_numpy(image_array).type(torch.float32).to('cuda')
        return image_tensor

    @staticmethod
    def convert_to_tensors(data):
        converted_data = {key: torch.tensor(value).to('cuda') for key, value in data.items()}
        return converted_data

    @staticmethod
    def load_bboxes(bboxes_path):
        bboxes_batch = np.load(bboxes_path, mmap_mode='r')
        # Convert numpy arrays to torch tensors and move to GPU
        bboxes_batch = torch.stack([torch.tensor(b).view(-1, 4).to('cuda') for b in bboxes_batch])
        return bboxes_batch

    @staticmethod
    def load_labels(labels_path):
        labels_batch = np.load(labels_path, mmap_mode='r')
        # Convert numpy arrays to torch tensors and move to GPU
        labels_batch = torch.stack([torch.tensor([l]).to('cuda') for l in labels_batch])
        return labels_batch

In [5]:
class MemmapIterableDataset(IterableDataset):
    def __init__(self, data, shuffle=False):
        self.type_dir, self.num_batches = data
        self.shuffle = shuffle

    def __iter__(self):
        for batch_idx in range(self.num_batches):
            batch_path = os.path.join(self.type_dir, f"batch_{batch_idx}")
            
            dataloader = BatchDataLoader(batch_path)
            rcnn_image_tensors, clip_pixel_values, text_data_tensors, bboxes_batch, labels_batch = dataloader.load_data()

            # Concatenate the list of tensors along dimension 0 to create a batch
            if rcnn_image_tensors.size(0) == 0 or len(clip_pixel_values) == 0:
                print("No images to process.")
                continue

            yield rcnn_image_tensors, clip_pixel_values, text_data_tensors, bboxes_batch, labels_batch

In [6]:
class ImageDetectionModel(pl.LightningModule):
    def __init__(self, train_data, val_data, test_data, num_classes, num_workers):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.rcnn = fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
        in_features = self.rcnn.roi_heads.box_predictor.cls_score.in_features
        self.rcnn.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)
        
        self.embedding_transform = nn.Linear(512, 256)
        
        # Allow CLIP model parameters to be trainable (fine-tuning)
        for name, param in self.clip_model.named_parameters():
        # Freeze all parameters first
            param.requires_grad = False

        # Unfreeze parameters in the last layers of the text model
        if 'text_model.encoder.layers.11' in name:
            param.requires_grad = True

        # Unfreeze parameters in the last layers of the vision model
        if 'vision_model.encoder.layers.11' in name:
            param.requires_grad = True

        # Optionally, adjust parameters related to the output projections if fine-tuning the head is desired
        if 'visual_projection.weight' in name or 'text_projection.weight' in name:
            param.requires_grad = True
            
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.num_workers = num_workers
        
    def setup(self, stage=None):
        self.train_dataset = MemmapIterableDataset(self.train_data)
        self.val_dataset = MemmapIterableDataset(self.val_data)
        self.test_dataset = MemmapIterableDataset(self.test_data)

    def forward(self, rcnn_imgs_preprocessed, clip_imgs_preprocessed, clip_texts_preprocessed, targets=None):
        batch_feature_maps = []
        # losses =[]

        for rcnn_img, clip_img, clip_text in zip(rcnn_imgs_preprocessed, clip_imgs_preprocessed, clip_texts_preprocessed):
            # Generate embeddings for both image and text from CLIP
            image_embeddings, text_embeddings = self.generate_embeddings(clip_img, clip_text)

            # Ensure image is in [C, H, W] format and transfer to device
            image_tensor = rcnn_img.permute(2, 0, 1).to(self.device)

            # Extract feature maps using the RCNN backbone
            feature_maps = self.get_feature_maps(image_tensor.unsqueeze(0))['0']

            # Modulate feature maps using both image and text embeddings
            modulated_feature_maps = self.modulate_features_with_embeddings(feature_maps, image_embeddings, text_embeddings)
    
            #to remove the batch number
            modulated_feature_maps = modulated_feature_maps.squeeze(0)
            
            # Resize modulated_feature_maps to have size [3, H, W]
            modulated_feature_maps = modulated_feature_maps[:3]  # Take the first 3 channels
            
            # Store processed feature maps
            batch_feature_maps.append(modulated_feature_maps)
            
            # # Compute loss
            # loss = compute_loss(modulated_feature_maps, targeted_feature_maps)
            # losses.append(loss.item())
            
        # print("Dimensions of batch_image_tensors:", [t.shape for t in batch_feature_maps])

        # Since all operations should be on feature maps post backbone processing
        #integrated_features = torch.stack(batch_feature_maps,dim=0)
        
        integrated_features = batch_feature_maps
        
        # print(f"training status: {self.training}")
        
        return self.rcnn(integrated_features, targets)
    
    def generate_embeddings(self, clip_img_preprocessed, clip_text_preprocessed):
        with torch.no_grad():
            image_embeddings = self.clip_model.get_image_features(**clip_img_preprocessed).to(self.device)
            text_embeddings = self.clip_model.get_text_features(**clip_text_preprocessed).to(self.device)
        return image_embeddings, text_embeddings
    
    def get_feature_maps(self, image_tensor):
        backbone = self.rcnn.backbone
        backbone.eval()

        with torch.no_grad():
            feature_maps = backbone(image_tensor)

        return feature_maps  # This now returns a dictionary of feature maps
    
    def modulate_features_with_embeddings(self, feature_maps, image_embeddings, text_embeddings):
        # Assuming feature_maps is a batch of feature maps with shape [batch_size, channels, height, width]
        # Both image_embeddings and text_embeddings are [batch_size, 512]
        
        # print(feature_maps.shape)
        
        image_embeddings_transformed = self.embedding_transform(image_embeddings)  # [batch_size, 256]
        text_embeddings_transformed = self.embedding_transform(text_embeddings)    # [batch_size, 256]

        # Expand embeddings to match the spatial dimensions of the feature maps
        image_embeddings_expanded = image_embeddings_transformed.unsqueeze(-1).unsqueeze(-1)  # [batch_size, 256, 1, 1]
        text_embeddings_expanded = text_embeddings_transformed.unsqueeze(-1).unsqueeze(-1)    # [batch_size, 256, 1, 1]
        
        # print(image_embeddings_expanded.shape)

        # Broadcast the embeddings across the spatial dimensions
        image_embeddings_expanded = image_embeddings_expanded.expand_as(feature_maps)  # [batch_size, 256, height, width]
        text_embeddings_expanded = text_embeddings_expanded.expand_as(feature_maps)    # [batch_size, 256, height, width]

        # Concatenate or add embeddings to the feature maps
        # Here we choose concatenation for demonstration; dimension 1 is the channel dimension
        modulated_feature_maps = torch.cat([feature_maps, image_embeddings_expanded, text_embeddings_expanded], dim=1)

        return modulated_feature_maps


    def training_step(self, batch, batch_idx):
        rcnn_image_tensors, clip_image_tensors, clip_text_data, bboxes_batch, labels_batch = batch

        clip_texts_preprocessed = [{key: torch.tensor(value).to(self.device) for key, value in item.items()} for item in clip_text_data]

        targets = []
        for bboxes, labels in zip(bboxes_batch, labels_batch):
            mask = labels != 0
            if mask.any():
                filtered_bboxes = bboxes[mask]
                filtered_labels = labels[mask]

                # Construct the target dictionary
                target = {
                    'boxes': filtered_bboxes.to(self.device),
                    'labels': filtered_labels.to(self.device)
                }
            else:
                # Create an empty target dictionary with correct shape and on the correct device
                target = {
                    'boxes': torch.zeros((0, 4), dtype=torch.float, device=self.device),
                    'labels': torch.zeros(0, dtype=torch.int64, device=self.device)
                }

            targets.append(target)

        outputs = self(rcnn_image_tensors, clip_image_tensors, clip_texts_preprocessed, targets)

        # print("Keys in training")
        # print(outputs.keys())  # Assuming outputs is a dictionary, not a list

        # Calculate total loss from various components
        if any(t['labels'].numel() > 0 for t in targets):
            total_loss = sum(outputs[key] for key in outputs.keys() if 'loss' in key)
            self.log('train_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            return total_loss
        else:
            self.log('train_loss', 0, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            return torch.tensor(0.0, requires_grad=True).to(self.device)
    
    def validation_step(self, batch, batch_idx):
        rcnn_image_tensors, clip_image_tensors, clip_text_data, bboxes_batch, labels_batch = batch

        clip_texts_preprocessed = [{key: torch.tensor(value).to(self.device) for key, value in item.items()} for item in clip_text_data]

        targets = []
        for bboxes, labels in zip(bboxes_batch, labels_batch):
            # Filter out entries where labels are 0 (masking background or padded elements)
            mask = labels != 0
            if mask.any():
                filtered_bboxes = bboxes[mask]
                filtered_labels = labels[mask]

                # Construct the target dictionary
                target = {
                    'boxes': filtered_bboxes.to(self.device),  # Ensure tensors are on the correct device
                    'labels': filtered_labels.to(self.device)
                }
            else:
                # Create an empty target dictionary with correct shape and on the correct device
                target = {
                    'boxes': torch.zeros((0, 4), dtype=torch.float, device=self.device),
                    'labels': torch.zeros(0, dtype=torch.int64, device=self.device)
                }

            targets.append(target)

        self.rcnn.train()

        with torch.no_grad():
            outputs = self(rcnn_image_tensors, clip_image_tensors, clip_texts_preprocessed, targets)

        self.rcnn.eval()

#             print("Keys in validation")
#             print(outputs.keys())

#             print(outputs['loss_classifier'])

        if any(t['labels'].numel() > 0 for t in targets):
            # Calculate the total validation loss by summing individual components
            total_loss = sum(outputs[key] for key in outputs.keys())
            self.log('val_loss', total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            return total_loss
        else:
            self.log('val_loss', 0, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            return torch.tensor(0.0, device=self.device)
            
    def test_step(self, batch, batch_idx):
        rcnn_image_tensors, clip_image_tensors, clip_text_data, _, _ = batch
        # Assuming test data might not always have labels available

        clip_texts_preprocessed = [{key: torch.tensor(value).to(self.device) for key, value in item.items()} for item in clip_text_data]

        # Put model in evaluation mode
        self.rcnn.eval()

        # Disable gradient computation explicitly for safety
        with torch.no_grad():
            outputs = self(rcnn_image_tensors, clip_image_tensors, clip_texts_preprocessed)

        # Extract relevant output details, e.g., predicted boxes, labels, and scores
        predictions = {
            'boxes': outputs['boxes'],
            'labels': outputs['labels'],
            'scores': outputs['scores']
        }
        
        return predictions
                
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=None, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, num_workers=self.num_workers)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [9]:
NUM_CLASSES=8
TRAIN_NUM_BATCHES = 2991
TEST_NUM_BATCHES = 374
VAL_NUM_BATCHES = 374

early_stopping_callback = EarlyStopping(
    monitor='val_loss',  # metric to monitor
    patience=3,          # no of epochs with no improvement to wait before stopping
    verbose=True,        # logging
    mode='min'           # minimize or maximize the monitored metric
)

# Initialize Trainer with model checkpointing
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='model_checkpoints',
    filename='nlp_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

vlm_model = ImageDetectionModel(
    train_data=(train_dir, TRAIN_NUM_BATCHES), 
    val_data=(val_dir, TEST_NUM_BATCHES), 
    test_data=(test_dir, VAL_NUM_BATCHES), 
    num_classes=NUM_CLASSES,
    num_workers=0
)

trainer = pl.Trainer(
    max_steps=TRAIN_NUM_BATCHES*10,  # Maximum number of steps (batches) to train for
    callbacks=[checkpoint_callback, early_stopping_callback], # CustomProgressBar()
    val_check_interval=TRAIN_NUM_BATCHES,
    limit_val_batches=VAL_NUM_BATCHES,  # Limit the number of validation batches
    accelerator="gpu",
    devices=1
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [8]:
# Train the model
trainer.fit(vlm_model)

# Test the model
trainer.test(vlm_model)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                | Type       | Params
---------------------------------------------------
0 | clip_model          | CLIPModel  | 151 M 
1 | rcnn                | FasterRCNN | 43.3 M
2 | embedding_transform | Linear     | 131 K 
---------------------------------------------------
43.5 M    Trainable params
151 M     Non-trainable params
194 M     Total params
778.803   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  image_tensor = torch.from_numpy(image_array).type(torch.float32).to('cuda')
  clip_texts_preprocessed = [{key: torch.tensor(value).to(self.device) for key, value in item.items()} for item in clip_text_data]
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  clip_texts_preprocessed = [{key: torch.tensor(value).to(self.device) for key, value in item.items()} for item in clip_text_data]
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

ValueError: too many values to unpack (expected 3)

# No of batches (batch_size=8)

- Train: 1496
- Val: 187
- Test: 187

In [None]:
import numpy as np

# file_path = "/home/jupyter/til-24-base/vlm/src/data/train/batch_0/labels.npy"
file_path = "/home/jupyter/til-24-base/vlm/src/data/val/batch_46/bboxes.npy"

data = np.load(file_path)

# Print the length of the array
print("Length of the array:", len(data))

# If the array is multidimensional and you want to check the size of each dimension
print("Shape of the array:", data.shape)

In [None]:
# import matplotlib.pyplot as plt
# # Extract bounding boxes, labels, and scores
# boxes = prediction[0]['boxes']
# labels = prediction[0]['labels']
# scores = prediction[0]['scores']

# # Visualize the results
# plt.imshow(image)
# for box, label, score in zip(boxes, labels, scores):
#   if score > 0.1:
#     print(id_2_label[label.item()], score.item())
#     plt.gca().add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
#                                       fill=False, edgecolor='red', linewidth=2))
#     plt.text(box[0], box[1], f"Class {label.item()} ({score:.2f})", color='red', fontsize=10,
#              bbox=dict(facecolor='white', alpha=0.7))

# plt.axis('off')
# plt.show()
# print(f"{id_2_label[labels.item()] =}")

In [None]:
# import os
# directory = "/home/jupyter/til-24-base/vlm/src/data/imgs"

# count = 0
# for filename in os.listdir(directory):
#     if filename.endswith('.jpg.npy'):
#         # Construct the full path of the old file
#         old_file = os.path.join(directory, filename)
        
#         # Create the new file name by replacing '.jpg.npy' with '.npy'
#         new_file = os.path.join(directory, filename.replace('.jpg.npy', '.npy'))
        
#         # Rename the file
#         os.rename(old_file, new_file)
#         count += 1

# print(f'Renamed {count} files')

In [None]:
# class CustomDataModule(pl.LightningModule):
#     def __init__(self, train_data, val_data, test_data, num_classes, num_workers):
#         super(CustomDataModule, self).__init__()
#         self.model = MultimodalFasterRCNN(num_classes)
#         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#         self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        
#         # Allow CLIP model parameters to be trainable (fine-tuning)
#         for param in self.clip_model.parameters():
#             param.requires_grad = True
            
#         self.train_data = train_data
#         self.val_data = val_data
#         self.test_data = test_data
#         self.num_workers = num_workers

#     def setup(self, stage=None):
#         self.train_dataset = MemmapIterableDataset(self.train_data)
#         self.val_dataset = MemmapIterableDataset(self.val_data)
#         self.test_dataset = MemmapIterableDataset(self.test_data)

#     def training_step(self, batch, batch_idx):
#         print(self.device)
#         images, bboxes, labels, captions = batch
#         text_features = self.extract_text_features(captions)
        
#         outputs = self.model(images, text_features)

#         valid_indices = labels.view(-1) != 0  # Flatten labels for masking

#         masked_labels = outputs['labels'].view(-1, outputs['labels'].size(-1))[valid_indices]
#         masked_boxes = outputs['boxes'].view(-1, outputs['boxes'].size(-1))[valid_indices]
        
#         targets_labels = labels.view(-1)[valid_indices]
#         targets_boxes = bboxes.view(-1, bboxes.size(-1))[valid_indices]

#         # Compute loss based on masked outputs and targets
#         loss = self.compute_loss(masked_labels, masked_boxes, targets_labels, targets_boxes)
#         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
#         return loss

#     def validation_step(self, batch, batch_idx):
#         images, bboxes, labels, captions = batch
#         text_features = self.extract_text_features(captions)
        
#         outputs = self.model(images, text_features)

#         valid_indices = labels.view(-1) != 0  # Flatten labels for masking

#         masked_labels = outputs['labels'].view(-1, outputs['labels'].size(-1))[valid_indices]
#         masked_boxes = outputs['boxes'].view(-1, outputs['boxes'].size(-1))[valid_indices]
        
#         targets_labels = labels.view(-1)[valid_indices]
#         targets_boxes = bboxes.view(-1, bboxes.size(-1))[valid_indices]

#         # Compute loss based on masked outputs and targets
#         loss = self.compute_loss(masked_labels, masked_boxes, targets_labels, targets_boxes)
#         self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
#         return loss
    
#     def compute_loss(self, masked_labels, masked_boxes, targets_labels, targets_boxes):
#         # Custom loss computation using masked outputs and targets
#         classification_loss = torch.nn.CrossEntropyLoss()(masked_labels, targets_labels)
#         bbox_loss = torch.nn.MSELoss()(masked_boxes, targets_boxes)
#         total_loss = classification_loss + bbox_loss
#         return total_loss

#     def configure_optimizers(self):
#         optimizer = torch.optim.Adam(
#             list(self.model.parameters()) + list(self.clip_model.parameters()),
#             lr=1e-4
#         )
#         return optimizer
    
#     def train_dataloader(self):
#         return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers)

#     def val_dataloader(self):
#         return DataLoader(self.val_dataset, batch_size=None, num_workers=self.num_workers)
    
#     def test_dataloader(self):
#         return DataLoader(self.test_dataset, num_workers=self.num_workers)
    
#     def extract_text_features(self, captions):
#         inputs = self.clip_processor(text=captions, return_tensors="pt", padding=True)
#         for key, value in inputs.items():
#             print(f"{key} is on device: {value.device}")
#         text_features = self.clip_model.get_text_features(**inputs)
#         print(f"Text features are on device: {text_features.device}")
#         return text_features

In [None]:
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# # Get the dimension of the text features
# text_feature_dim = clip_model.config.text_config.hidden_size
# print(f"Text feature dimension: {text_feature_dim}")

# print(pl.__version__)

In [None]:
#  def modulate_features_with_embeddings(self, feature_maps, image_embeddings, text_embeddings):
#         # Assuming feature_maps is a batch of feature maps with shape [batch_size, channels, height, width]
#         # Both image_embeddings and text_embeddings are [batch_size, 512]
        
#         image_embeddings_transformed = self.embedding_transform(image_embeddings)  # [batch_size, 256]
#         text_embeddings_transformed = self.embedding_transform(text_embeddings)    # [batch_size, 256]

#         # Assume image_embeddings and text_embeddings are already transformed to match the channel dimension, i.e., [batch_size, 256]
#         image_embeddings_expanded = image_embeddings.unsqueeze(-1).unsqueeze(-1)  # [batch_size, 256, 1, 1]
#         text_embeddings_expanded = text_embeddings.unsqueeze(-1).unsqueeze(-1)    # [batch_size, 256, 1, 1]

#         # Calculate the repeat factors for height and width
#         height_repeat = feature_maps.size(2)
#         width_repeat = feature_maps.size(3)

#         # Use repeat to manually expand to the required dimensions
#         image_embeddings_expanded = image_embeddings_expanded.repeat(1, 1, height_repeat, width_repeat)
#         text_embeddings_expanded = text_embeddings_expanded.repeat(1, 1, height_repeat, width_repeat)

#         # Now you can concatenate them with feature maps
#         modulated_feature_maps = torch.cat([feature_maps, image_embeddings_expanded, text_embeddings_expanded], dim=1)
        
#         return modulated_feature_maps

In [4]:
import torch

# Load the checkpoint file
checkpoint = torch.load('./model_checkpoints/vlm_model-epoch=00-val_loss=32.70.ckpt', map_location=torch.device('cpu'))

# Print the keys and explore the structure
print(checkpoint.keys())

print(checkpoint['state_dict'])

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])
OrderedDict([('clip_model.logit_scale', tensor(4.6052)), ('clip_model.text_model.embeddings.token_embedding.weight', tensor([[-3.9053e-03, -6.3254e-03,  7.3507e-03,  ..., -1.0660e-02,
         -2.2764e-02, -1.0908e-02],
        [-2.6081e-02,  8.7953e-03, -1.1737e-02,  ..., -1.2019e-02,
         -2.4059e-02, -2.1929e-02],
        [-1.9648e-02, -6.6711e-03, -9.0593e-03,  ...,  4.5782e-03,
         -2.0692e-02, -8.7150e-03],
        ...,
        [ 8.5028e-03,  1.0219e-03,  2.0366e-02,  ...,  1.4868e-02,
          1.7627e-02, -1.4752e-03],
        [-1.6741e-03,  7.3048e-05, -4.1996e-03,  ..., -3.4096e-03,
         -3.9295e-03, -5.5289e-05],
        [-6.0260e-03,  2.0210e-03,  4.9674e-04,  ..., -3.3459e-03,
         -9.8587e-03, -2.3390e-04]])), ('clip_model.text_model.embeddings.position_embedding.weight', tensor([[-1.4361e-03,  1.9820e-04, -4.1244e-03, 