# Task: `Vision-Language Model`

Given an image and a caption describing a target in that image, return a bounding box corresponding to the target’s location within the image.

Note that targets within a given image are not uniquely identified by their object class (e.g. ”airplane”, “helicopter”); multiple targets within an image may be members of the same object class. Instead, targets provided will correspond to a particular target description (e.g. “black and white drone”).

Not all possible target descriptions will be represented in the training dataset provided to participants. There will also be unseen targets and novel descriptions in the test data used in the hidden test cases of the Virtual Qualifiers, Semi-Finals / Finals. As such, Guardians will have to develop vision models capable of understanding **natural language** to identify the correct target from the scene.

For the **image datasets** provided to both Novice and Advanced Guardians, there will be no noise present. However, it is worth noting that your models will have to be adequately robust as the hidden test cases for the Virtual Qualifiers and the Semi-Finals/Finals will have increasing amounts of noise introduced. This is especially crucial for **Advanced Guardians**, due to the degradation of their robot sensors.

_Insert Code Here_

In [1]:
# !pip install -q torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# !pip install -q -U torchinfo albumentations # Image Augmentation

In [2]:
import albumentations
from PIL import Image
import IPython.display as display
import torch
import requests
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import json

import torchvision
from torchvision.transforms import functional as F
from torchvision import transforms
from torchinfo import summary
import urllib
import os

import numpy as np
from torch.utils.data import IterableDataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ProgressBar
import torch.nn.functional as F

import torchvision.models.detection as detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# import multiprocessing as mp
# mp.set_start_method('spawn', force=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Enable benchmark mode in cuDNN to find the best algorithm for your hardware
torch.backends.cudnn.benchmark = True
# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True

In [3]:
cur_dir = os.getcwd()
vlm_dir = os.path.dirname(cur_dir)
til_dir = os.path.dirname(vlm_dir)
home_dir = os.path.dirname(til_dir)
test_dir = os.path.join(home_dir, 'novice')
img_dir = os.path.join(test_dir, 'images')
metadata_path = os.path.join(test_dir, 'vlm.jsonl')
data_dir = os.path.join(cur_dir, 'data')

train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")
val_dir = os.path.join(data_dir, "val")

img_dir

'/home/jupyter/novice/images'

# Models: Faster RNN, CLIP

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [5]:
class MemmapIterableDataset(IterableDataset):
    def __init__(self, data, shuffle=False):
        self.type_dir, self.num_batches = data
        self.shuffle = shuffle

    def __iter__(self):
        for batch_idx in range(self.num_batches):
            batch_path = os.path.join(self.type_dir, f"batch_{batch_idx}")


            # Load file paths to images from JSON file
            image_paths_file = os.path.join(batch_path, "img_paths.json")
            with open(image_paths_file, 'r') as f:
                image_paths_list = json.load(f)

            # Unpack the list of image paths
            image_paths = [image_path for image_path in image_paths_list]

            # Load other batch data
            bboxes_path = os.path.join(batch_path, "bboxes.npy")
            labels_path = os.path.join(batch_path, "labels.npy")
            text_features_path = os.path.join(batch_path, "text_features.npy")

            # Load other batch data as numpy arrays
            # try:
            bboxes_batch = np.load(bboxes_path)
            labels_batch = np.load(labels_path)
            # except ValueError:
            #     # If there's an error suggesting that you need to allow pickling, use allow_pickle=True
            #     bboxes_batch = np.load(bboxes_path, allow_pickle=True)
            #     labels_batch = np.load(labels_path, allow_pickle=True)
            text_features_batch = np.load(text_features_path, mmap_mode='r')

            # Convert numpy arrays to torch tensors
            bboxes_batch = torch.stack([torch.tensor(b).view(-1, 4) for b in bboxes_batch])
            labels_batch = torch.stack([torch.tensor([l]) for l in labels_batch])
            # print(bboxes_batch)
            # print(labels_batch)
            text_features_batch = torch.tensor(text_features_batch) # TO CHECK IF DIM IS TOO HIGH
            # print(text_features_batch)

            image_tensors = []
            for image_path in image_paths:
                # Load the image data as a memory-mapped array
                image_array = np.load(image_path, mmap_mode='r')

                if image_array is None or image_array.shape[0] == 0 or image_array.shape[1] == 0:
                    print(f"Skipping invalid image: {image_path}")
                    continue

                image_tensor = torch.from_numpy(image_array).type(torch.float32)
    
                # Transfer to GPU (if CUDA is available)
                image_tensor = image_tensor.to(device)

                # Append the tensor to the list
                image_tensors.append(image_tensor)

            # Now concatenate the list of tensors along dimension 0 to create a batch
            if not image_tensors:
                print("No images to process.")
                continue
                
            # Stack the list of tensors along dimension 0 to create a batch tensor
            image_tensors_batch = torch.stack(image_tensors, dim=0)
                
            # bboxes_batch = bboxes_batch.to('cpu')
            # labels_batch = labels_batch.to('cpu')
            # text_features_batch = text_features_batch.to('cpu')

            yield image_tensors_batch, bboxes_batch, labels_batch, text_features_batch

In [10]:
class ObjectDetectionModule(pl.LightningModule):
    def __init__(self, num_classes, train_data, val_data, test_data, learning_rate=1e-3, num_workers=4):
        super().__init__()
        # Account for an additional dummy class for padding
        self.num_classes = num_classes + 1  # Increase number of classes to include a dummy class
        # Initialize the Faster R-CNN and CLIP models
        self.faster_rcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Replace the classifier in Faster R-CNN
        in_features = self.faster_rcnn.roi_heads.box_predictor.cls_score.in_features
        self.faster_rcnn.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.num_classes)
        
        clip_out_dim = 512  # Adjust according to your specific model output
        # self.projection = torch.nn.Linear(clip_out_dim, in_features)
        
        self.text_fc = nn.Linear(clip_out_dim, in_features)  # Assuming text features have 512 dimensions
        self.combined_fc = nn.Linear(in_features * 2, in_features)

        # Define the fusion layer
        # self.fusion_layer = FusionLayer(visual_feature_dim=in_features, text_feature_dim=512, output_dim=num_classes)
        self.learning_rate = learning_rate
        
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        
        self.num_workers = num_workers
        
    def setup(self, stage=None):
        self.train_dataset = MemmapIterableDataset(self.train_data, shuffle=True) # need transform?
        self.val_dataset = MemmapIterableDataset(self.val_data)
        self.test_dataset = MemmapIterableDataset(self.test_data)

    def forward(self, images, text_features):
        # Run images through Faster R-CNN
        detections = self.faster_rcnn(images)
        
        # Process text features
        text_features = self.text_fc(text_features)
        
        # Combine image and text features
        combined_features = torch.cat((detections, text_features), dim=1)
        combined_features = self.combined_fc(combined_features)
        
        return combined_features
    
    def training_step(self, batch, batch_idx):
        images, bboxes, labels, text_features = batch
        
        outputs = self(images, text_features)
        
        print(isinstance(images, torch.Tensor))
        # Filter out dummy data based on labels
        labels_1d = labels.squeeze()
        valid_indices = labels_1d.nonzero().squeeze()
        
        if valid_indices.numel() == 0:
            # Handle the case where all data might be dummy
            return torch.tensor(0.0, device=self.device)

        valid_images = images[valid_indices]
        valid_bboxes = bboxes[valid_indices]
        valid_labels = labels[valid_indices]
        valid_text_features = text_features[valid_indices]

        targets = [{'boxes': bbox, 'labels': label} for bbox, label in zip(valid_bboxes, valid_labels)]

        # classification_losses = [F.cross_entropy(output['logits'], target['labels']) for output, target in zip(outputs, targets)]
        # bbox_losses = [F.mse_loss(output['bbox'], target['boxes']) for output, target in zip(outputs, targets)]
        
        if self.training and valid_text_features is not None:
            outputs = self.augment_losses_with_text(outputs, targets, valid_text_features)

        total_loss = sum(outputs.values())
        average_loss = total_loss / len(valid_images)  # Normalize by the actual valid batch size
        self.log('train_loss', average_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return average_loss
    
    def validation_step(self, batch, batch_idx):
        images, bboxes, labels, text_features = batch
        # print(isinstance(images, torch.Tensor))
        # print(isinstance(labels, torch.Tensor))
        # print(isinstance(text_features, torch.Tensor))

        # Filter out dummy data based on labels
        labels_1d = labels.squeeze()
        valid_indices = labels_1d.nonzero().squeeze()
        
        print(valid_indices)
        if valid_indices.numel() == 0:
            # Handle the case where all data might be dummy
            return torch.tensor(0.0, device=self.device)

        valid_images = images[valid_indices]
        valid_bboxes = bboxes[valid_indices]
        valid_labels = labels[valid_indices]
        valid_text_features = text_features[valid_indices]

        targets = [{'boxes': bbox, 'labels': label} for bbox, label in zip(valid_bboxes, valid_labels)]
        outputs = self(valid_images, valid_text_features, targets=targets)

        # classification_losses = [F.cross_entropy(output['logits'], target['labels']) for output, target in zip(outputs, targets)]
        # bbox_losses = [F.mse_loss(output['bbox'], target['boxes']) for output, target in zip(outputs, targets)]

        predictions = self.faster_rcnn(valid_images)

        # Manually compute the losses from predictions and targets
        loss = self.compute_validation_loss(predictions, targets)
        
        self.log('val_loss', average_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return average_loss
    
    def compute_validation_loss(self, predictions, targets):
        classification_loss_fn = torch.nn.CrossEntropyLoss()
        bbox_loss_fn = torch.nn.MSELoss()  # Or whatever loss function you use for bbox regression

        all_losses = []
        for prediction, target in zip(predictions, targets):
            # Assuming prediction['scores'] contains the confidence scores
            scores = prediction['scores']
            pred_labels = prediction['labels']
            pred_boxes = prediction['boxes']

            # Get the index of the highest scoring prediction
            max_score_index = scores.argmax()

            # Extract the highest scoring prediction
            pred_label = pred_labels[max_score_index].unsqueeze(0)
            pred_box = pred_boxes[max_score_index].unsqueeze(0)

            true_label = target['labels'].unsqueeze(0)
            true_box = target['boxes'].unsqueeze(0)

            # Ensure logits (pred_label) are float and targets (true_label) are long
            pred_label = pred_label.float()
            true_label = true_label.long()

            # Calculate classification loss
            classification_loss = classification_loss_fn(pred_label, true_label)

            # Ensure pred_box and true_box are float
            pred_box = pred_box.float()
            true_box = true_box.float()

            # Calculate bounding box regression loss
            bbox_loss = bbox_loss_fn(pred_box, true_box)

            # Aggregate losses
            total_loss = classification_loss + bbox_loss
            all_losses.append(total_loss)

        # Return the mean loss across all predictions
        total_loss = torch.stack(all_losses).mean()
        return total_loss

    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=None, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, num_workers=self.num_workers)
    
    # def validation_epoch_end(self, validation_step_outputs):
    #     # Collect all batch losses from the validation_step outputs
    #     if validation_step_outputs:
    #         avg_loss = torch.stack([x for x in validation_step_outputs]).mean()
    #         self.log('avg_val_loss', avg_loss, on_epoch=True, prog_bar=True, logger=True)
    #         print(f"Average Validation Loss: {avg_loss.item()}")
    #     else:
    #         print("No validation data provided or validation_step did not return any outputs.")

In [11]:
# class CustomProgressBar(ProgressBar):
#     def __init__(self):
#         super().__init__()  # Initialize the ProgressBar base class

#     def init_train_tqdm(self):
#         """
#         This method initializes the tqdm progress bar for training.
#         """
#         bar = super().init_train_tqdm()
#         # Adding custom set of metrics to display in the progress bar, e.g., 'train_loss'
#         bar.set_description('Training')
#         return bar

#     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
#         """
#         Called when the train batch ends. Updates the progress bar with current loss.
#         """
#         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # Call the base method
#         # Here we assume `train_loss` is logged using `self.log('train_loss', loss)` in your training step
#         loss = trainer.logged_metrics.get('train_loss', None)
#         if loss is not None:
#             self.main_progress_bar.set_postfix({'train_loss': f'{loss:.3f}'})

In [12]:
import multiprocessing as mp
mp.set_start_method('spawn', force=True)

data_module = ObjectDetectionModule(
    num_classes=8,
    train_data=(train_dir, 1709),
    val_data=(val_dir, 214),
    test_data=(test_dir, 214),
    learning_rate=1e-3,
    num_workers=0,
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',  # metric to monitor
    patience=3,          # no of epochs with no improvement to wait before stopping
    verbose=True,        # logging
    mode='min'           # minimize or maximize the monitored metric
)

# Initialize Trainer with model checkpointing
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='model_checkpoints',
    filename='asr_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

trainer = pl.Trainer(
    max_steps=1709*10,  # Maximum number of steps (batches) to train for
    callbacks=[checkpoint_callback, early_stopping_callback], # CustomProgressBar()
    val_check_interval=1709,
    limit_val_batches=214,  # Limit the number of validation batches
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [13]:
# Train the model
trainer.fit(data_module) # pl.LightningDataModule can be 2nd parameter

# Test the model
trainer.test(data_module)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type       | Params
-------------------------------------------
0 | faster_rcnn | FasterRCNN | 41.3 M
1 | clip_model  | CLIPModel  | 151 M 
2 | projection  | Linear     | 525 K 
-------------------------------------------
192 M     Trainable params
222 K     Non-trainable params
193 M     Total params
772.551   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:0')
[{'boxes': tensor([[708, 252, 852, 316]], device='cuda:0'), 'labels': tensor([4], device='cuda:0')}, {'boxes': tensor([[888,  88, 932, 112]], device='cuda:0'), 'labels': tensor([5], device='cuda:0')}, {'boxes': tensor([[400, 236, 436, 272]], device='cuda:0'), 'labels': tensor([4], device='cuda:0')}, {'boxes': tensor([[580, 104, 680, 148]], device='cuda:0'), 'labels': tensor([6], device='cuda:0')}, {'boxes': tensor([[1016,  420, 1084,  452]], device='cuda:0'), 'labels': tensor([2], device='cuda:0')}, {'boxes': tensor([[1052,   76, 1100,  104]], device='cuda:0'), 'labels': tensor([6], device='cuda:0')}, {'boxes': tensor([[428, 376, 548, 436]], device='cuda:0'), 'labels': tensor([2], device='cuda:0')}]


RuntimeError: size mismatch (got input: [100], target: [1])

# No of batches (batch_size=8)

- Train: 1496
- Val: 187
- Test: 187

In [None]:
import numpy as np

# file_path = "/home/jupyter/til-24-base/vlm/src/data/train/batch_0/labels.npy"
file_path = "/home/jupyter/til-24-base/vlm/src/data/val/batch_46/bboxes.npy"

data = np.load(file_path)

# Print the length of the array
print("Length of the array:", len(data))

# If the array is multidimensional and you want to check the size of each dimension
print("Shape of the array:", data.shape)

In [None]:
# import matplotlib.pyplot as plt
# # Extract bounding boxes, labels, and scores
# boxes = prediction[0]['boxes']
# labels = prediction[0]['labels']
# scores = prediction[0]['scores']

# # Visualize the results
# plt.imshow(image)
# for box, label, score in zip(boxes, labels, scores):
#   if score > 0.1:
#     print(id_2_label[label.item()], score.item())
#     plt.gca().add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
#                                       fill=False, edgecolor='red', linewidth=2))
#     plt.text(box[0], box[1], f"Class {label.item()} ({score:.2f})", color='red', fontsize=10,
#              bbox=dict(facecolor='white', alpha=0.7))

# plt.axis('off')
# plt.show()
# print(f"{id_2_label[labels.item()] =}")

In [None]:
# def preprocess_and_save_batches(dataset, augmentations, data_dir='data', batch_size=32):
#     clip_model, preprocess_clip = clip.load("ViT-B/32", device=device)
#     images = dataset['image']
#     annotations = dataset['annotations']
#     num_batches = (len(images) + batch_size - 1) // batch_size

#     for batch_idx in tqdm(range(num_batches), desc="Processing Batches"):
#         batch_images = images[batch_idx * batch_size:(batch_idx + 1) * batch_size]
#         batch_annotations = annotations[batch_idx * batch_size:(batch_idx + 1) * batch_size]
#         batch_data = list(zip(batch_images, batch_annotations))
#         image_tensors = []
#         all_bboxes = []
#         all_labels = []
#         image_features = []
#         text_features = []

#         for image_path, image_annotations in batch_data:
#             # Load image
#             image = cv2.imread(image_path)
#             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

#             # Apply augmentations
#             augmented = augmentations(image=image)
#             augmented_image = augmented['image'].permute(1, 2, 0).numpy()
#             augmented_image = (augmented_image * 255).astype(np.uint8)
            
#             bboxes = []
#             labels = []

#             for annotation in image_annotations:
#                 caption = annotation['caption']
#                 bbox = annotation['bbox']
#                 bboxes.append(bbox)
#                 labels.append(caption)

#                 # Extract the cropped image
#                 cropped_image = augmented_image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]]
#                 cropped_pil_image = Image.fromarray(cropped_image.astype('uint8'))

#                 # Preprocess the image for CLIP
#                 cropped_preprocessed = preprocess_clip(cropped_pil_image).unsqueeze(0).to(device)
                
#                 # Encode features using CLIP
#                 with torch.no_grad():
#                     if cropped_preprocessed.shape[1] == 3:
#                         image_feature = clip_model.encode_image(cropped_preprocessed).cpu().numpy()
#                         text_feature = clip_model.encode_text(clip.tokenize([caption]).to(device)).cpu().numpy()
#                         image_features.append(image_feature)
#                         text_features.append(text_feature)
#                     else:
#                         print(f"Skipping encoding due to incorrect shape: {cropped_preprocessed.shape}")

#             all_bboxes.append(bboxes)
#             all_labels.append(labels)
        
#         # Save batch to memmap files
#         image_batch_memmap_path = os.path.join(data_dir, f"image_batch_{batch_idx}.npy")
#         np.save(image_batch_memmap_path, np.array(image_tensors))
        
#         # Save each list of bounding boxes separately
#         for i, bboxes in enumerate(all_bboxes):
#             bboxes_path = os.path.join(data_dir, f"bboxes_batch_{batch_idx}_image_{i}.npy")
#             np.save(bboxes_path, np.array(bboxes))
        
#         labels_path = os.path.join(data_dir, f"labels_batch_{batch_idx}.npy")
#         np.save(labels_path, np.array(all_labels, dtype=object))
        
#         image_features_path = os.path.join(data_dir, f"image_features_batch_{batch_idx}.npy")
#         np.save(image_features_path, np.array(image_features))
        
#         text_features_path = os.path.join(data_dir, f"text_features_batch_{batch_idx}.npy")
#         np.save(text_features_path, np.array(text_features))
        
# def load_image(image_path):
#     with Image.open(image_path) as img:
#         img = img.convert('RGB')  # Ensure image is in RGB format
#         image_array = np.array(img)
#     return image_array

In [None]:
import os
directory = "/home/jupyter/til-24-base/vlm/src/data/imgs"

count = 0
for filename in os.listdir(directory):
    if filename.endswith('.jpg.npy'):
        # Construct the full path of the old file
        old_file = os.path.join(directory, filename)
        
        # Create the new file name by replacing '.jpg.npy' with '.npy'
        new_file = os.path.join(directory, filename.replace('.jpg.npy', '.npy'))
        
        # Rename the file
        os.rename(old_file, new_file)
        count += 1

print(f'Renamed {count} files')

In [None]:
# def forward(self, images, bboxes=None, labels=None, text_features=None):
#         # Check if we are in training or inference mode based on if bboxes and labels are provided
#         if bboxes is not None and labels is not None:
#             targets = []
#             for bbox, label in zip(bboxes, labels):
#                 targets.append({
#                     'boxes': bbox,   # Tensor of shape [num_objs, 4]
#                     'labels': label  # Tensor of shape [num_objs]
#                 })
#             outputs = self.faster_rcnn(images, targets)
#         else:
#             outputs = self.faster_rcnn(images)  # Inference mode: No targets are passed

#         # Assuming we have a function to extract features from outputs
#         if self.training:
#             box_features = self.extract_features(outputs, targets)  # Define this method based on your model architecture
#             fused_logits = self.fusion_layer(box_features, text_features)
#             for output, logits in zip(outputs, fused_logits):
#                 output['logits'] = logits

#         return outputs

In [None]:
# from torchvision.models.detection import FasterRCNN
# from torchvision.models.detection.roi_heads import RoIHeads

# class FusionLayer(nn.Module):
#     def __init__(self, visual_feature_dim, text_feature_dim, output_dim):
#         super(FusionLayer, self).__init__()
#         self.fc = nn.Linear(visual_feature_dim + text_feature_dim, output_dim)

#     def forward(self, visual_features, text_features):
#         combined_features = torch.cat((visual_features, text_features), dim=1)
#         output = self.fc(combined_features)
#         return output
    
# class ObjectDetectionModule(pl.LightningModule):
#     def __init__(self, num_classes, train_data, val_data, test_data, learning_rate=1e-3, num_workers=4):
#         super().__init__()
#         # Account for an additional dummy class for padding
#         self.num_classes = num_classes + 1  # Increase number of classes to include a dummy class
#         # Initialize the Faster R-CNN and CLIP models
#         self.faster_rcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
#         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#         self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

#         # Replace the classifier in Faster R-CNN
#         in_features = self.faster_rcnn.roi_heads.box_predictor.cls_score.in_features
#         self.faster_rcnn.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

#         # Define the fusion layer
#         self.fusion_layer = FusionLayer(visual_feature_dim=in_features, text_feature_dim=512, output_dim=num_classes)
#         self.learning_rate = learning_rate
        
#         self.train_data = train_data
#         self.val_data = val_data
#         self.test_data = test_data
        
#         self.num_workers = num_workers
        
#     def setup(self, stage=None):
#         self.train_dataset = MemmapIterableDataset(self.train_data, shuffle=True) # need transform?
#         self.val_dataset = MemmapIterableDataset(self.val_data)
#         self.test_dataset = MemmapIterableDataset(self.test_data)

#     def forward(self, images, bboxes=None, labels=None, text_features=None):
#         if bboxes is not None and labels is not None:
#             # Targets are provided, so we are in training mode
#             targets = [{'boxes': bbox, 'labels': label} for bbox, label in zip(bboxes, labels)]
#             loss_dict = self.faster_rcnn(images, targets)  # This will return a dictionary of losses
#             predictions = {'loss_dict': loss_dict}  # Encapsulate loss dict in predictions for compatibility
#         else:
#             # Inference mode: No targets are passed
#             predictions = self.faster_rcnn(images)

#         if not self.training and text_features is not None:
#             # Extract box features for inference mode; adjust this method for training if needed
#             box_features = self.extract_box_features(predictions)
#             fused_features = self.fusion_layer(box_features, text_features)
#             for pred, fused_feature in zip(predictions, fused_features):
#                 pred['scores'] = torch.sigmoid(fused_feature)  # Modifying scores based on fused features

#         return predictions
    
#     def training_step(self, batch, batch_idx):
#         images, bboxes, labels, text_features = batch
#         outputs = self(images, bboxes, labels, text_features)
#         loss = self.faster_rcnn.compute_alignment_loss(outputs, bboxes, labels, text_features)  # Adjust the loss computation accordingly
#         self.log('train_loss', loss)
#         return loss
    
#     def validation_step(self, batch, batch_idx):
#         images, bboxes, labels, text_features = batch
#         outputs = self(images, bboxes, labels, text_features)
#         loss = self.faster_rcnn.compute_alignment_loss(outputs, bboxes, labels, text_features)  # Adjust the loss computation accordingly
#         self.log('val_loss', loss)
#         return loss
    
#     def compute_alignment_loss(self, predictions, bboxes, labels, text_features, alpha=0.5):
#         # Assuming predictions['loss_dict'] contains Faster R-CNN's native loss components
#         ce_loss = predictions['loss_dict']['loss_classifier']
#         mse_loss = predictions['loss_dict']['loss_box_reg']

#         # Assuming you have a way to get 'fused_features' and it contains logits
#         logits = predictions['fused_logits']
#         valid_idx = labels > 0  # Assuming labels is a tensor indicating valid data
#         valid_logits = logits[valid_idx]
#         valid_labels = labels[valid_idx]

#         # Cross-entropy loss from logits
#         additional_ce_loss = F.cross_entropy(valid_logits, valid_labels)

#         # Calculate cosine similarity for alignment loss
#         if text_features is not None:
#             valid_text_features = text_features[valid_idx]
#             cosine_similarity = (valid_logits * valid_text_features).sum(1) / \
#                                 (valid_logits.norm(dim=1) * valid_text_features.norm(dim=1))
#             alignment_loss = 1 - cosine_similarity.mean()
#         else:
#             alignment_loss = 0

#         # Combine all losses
#         return alpha * (ce_loss + mse_loss + additional_ce_loss) + (1 - alpha) * alignment_loss

#     def configure_optimizers(self):
#         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
#         return optimizer

#     def train_dataloader(self):
#         return DataLoader(self.train_dataset, batch_size=None, num_workers=self.num_workers)

#     def val_dataloader(self):
#         return DataLoader(self.val_dataset, batch_size=None, num_workers=self.num_workers)
    
#     def test_dataloader(self):
#         return DataLoader(self.test_dataset, num_workers=self.num_workers)

In [None]:
print(pl.__version__)

In [None]:
# def validation_step(self, batch, batch_idx):
#         images, bboxes, labels, text_features = batch
#         valid_indices = [i for i, lbl in enumerate(labels) if lbl != 0]
#         if not valid_indices:
#             # Handle the case where all data might be dummy
#             return 0

#         valid_images = images[valid_indices]
#         valid_bboxes = [bboxes[i] for i in valid_indices]
#         valid_labels = [labels[i] for i in valid_indices]
#         valid_text_features = text_features[valid_indices]

#         targets = [{'boxes': bbox, 'labels': label} for bbox, label in zip(valid_bboxes, valid_labels)]
#         outputs = self(valid_images, valid_text_features, targets=targets)

#         losses = []
#         for output, target in zip(outputs, targets):
#             if 'scores' in output:
#                 max_conf_index = output['scores'].argmax()
#                 pred_box = output['boxes'][max_conf_index].unsqueeze(0)
#                 true_box = target['boxes']
#                 loss = F.mse_loss(pred_box, true_box)
#                 losses.append(loss)

#         average_loss = torch.mean(torch.stack(losses)) if losses else 0
#         self.log('val_loss', total_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
#         return average_loss