<a href="https://colab.research.google.com/github/wizardoftrap/Referring-Video-Object-Segmentation/blob/main/RVOS9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ***RVOS-9: Referring Video Object Segmentation***
### ***Shiv Prakash Verma (2021eeb1030)***

In [None]:
# Install required packages
!pip install torch torchvision
!pip install opencv-python
!pip install transformers
!pip install matplotlib
!pip install tqdm
!pip install pycocotools
!pip install einops
!pip install albumentations

In [None]:
import subprocess

import torch

# Select the compute device once; later cells read the module-level `device`.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("CUDA not available, using CPU")
    device = torch.device('cpu')

print(f"Using device: {device}")

# nvidia-smi only exists on GPU runtimes, so running it unconditionally fails on CPU sessions.
# Its output is captured and printed so it appears in the cell output.
if device.type == 'cuda':
    try:
        smi = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)
        print(smi.stdout or smi.stderr)
    except FileNotFoundError:
        print("nvidia-smi not found on this runtime")

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!mkdir /content/drive/MyDrive/RVOS_VR_Project
!mkdir /content/drive/MyDrive/RVOS_VR_Project/data
!mkdir /content/drive/MyDrive/RVOS_VR_Project/data/ref-davis

%cd /content/drive/MyDrive/RVOS_VR_Project/data/ref-davis

#Download DAVIS dataset files
!wget -c https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-Unsupervised-trainval-480p.zip
!wget -c https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017_semantics-480p.zip

#Download text annotations
!wget -c https://www.mpi-inf.mpg.de/fileadmin/inf/d2/khoreva/davis_text_annotations.zip

!unzip -o davis_text_annotations.zip

#%cd /content/drive/MyDrive/RVOS_VR_Project

In [None]:
# Extract the DAVIS 2017 semantics archive into the current directory (data/ref-davis).
# -o overwrites existing files without prompting, so re-running the cell does not block on input.
!unzip -o DAVIS-2017_semantics-480p.zip
#%cd /content/drive/MyDrive/RVOS_VR_Project

In [None]:
# Extract the DAVIS 2017 unsupervised train/val 480p archive. Later cells read frames from
# DAVIS/JPEGImages/480p and masks from DAVIS/Annotations_unsupervised/480p under this directory.
!unzip -o DAVIS-2017-Unsupervised-trainval-480p.zip

In [None]:
# NOTE(review): disabled directory change; the working directory stays at data/ref-davis set by the download cell.
#%cd /content/drive/MyDrive/RVOS_VR_Project

In [None]:
# List the extracted DAVIS folder to confirm the archives were unpacked where later cells look for them
!ls /content/drive/MyDrive/RVOS_VR_Project/data/ref-davis/DAVIS

In [None]:
import os
import json
import random
import matplotlib.pyplot as plt
import cv2
import numpy as np

def parse_davis_text_annotations(annotation_file):
    """Parse a DAVIS text-annotation file into ``{video_name: [{'obj_id': str, 'expression': str}, ...]}``.

    Each data line has the form ``<video_name> <object_id> "<referring expression>"``.

    The following lines are skipped:
      - blank lines;
      - lines starting with ``%%%`` (format description);
      - lines with no double quote;
      - lines whose unquoted prefix lacks a video name or object ID.

    The expression is the text between the FIRST and LAST double quote. Expressions that themselves
    contain quotes are therefore kept whole; splitting on every quote would cut them at the first
    inner quote. A line with a single, unterminated quote keeps everything after that quote.
    ``obj_id`` is kept as the string read from the file.

    Args:
        annotation_file: path to the text annotation file.

    Returns:
        Dict mapping video name to its annotations, in file order.
    """
    annotations = {}

    with open(annotation_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('%%%'):  # skip empty lines and format description
                continue

            first_quote = line.find('"')
            if first_quote == -1:
                continue
            last_quote = line.rfind('"')
            end = last_quote if last_quote > first_quote else len(line)
            expression = line[first_quote + 1:end]

            prefix = line[:first_quote].split()
            if len(prefix) < 2:
                continue

            video_name, obj_id = prefix[0], prefix[1]
            annotations.setdefault(video_name, []).append({
                'obj_id': obj_id,
                'expression': expression
            })

    return annotations

# Parse the annotations from both annotation files
davis17_annot1 = parse_davis_text_annotations('/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis/davis_text_annotations/Davis17_annot1.txt')
davis17_annot2 = parse_davis_text_annotations('/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis/davis_text_annotations/Davis17_annot2.txt')

# Combine annotations per video. The first file's lists are copied, so extending them below does
# not also mutate davis17_annot1. setdefault gives videos that appear only in the second file a
# fresh list instead of sharing davis17_annot2's list object.
davis17_annotations = {video_name: list(annotations) for video_name, annotations in davis17_annot1.items()}
for video_name, annotations in davis17_annot2.items():
    davis17_annotations.setdefault(video_name, []).extend(annotations)

# Print statistics
print(f"Total videos: {len(davis17_annotations)}")
total_expressions = sum(len(annotations) for annotations in davis17_annotations.values())
print(f"Total expressions: {total_expressions}")
# Every expression from both files must be present exactly once in the combined dict
assert total_expressions == (sum(len(a) for a in davis17_annot1.values())
                             + sum(len(a) for a in davis17_annot2.values()))

# Save the combined annotations for easier access
with open('/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis/combined_annotations.json', 'w') as f:
    json.dump(davis17_annotations, f)

In [None]:
def _plot_object_overlays(frames_dir, masks_dir, frame_files, sample_indices, obj_id, expression_text):
    """Draw one row per sampled frame: the RGB frame, the binary mask of ``obj_id``, and a red overlay.

    A frame that fails to load is reported and its row is left empty. A missing or unreadable mask
    file is reported and drawn as an all-zero mask.
    """
    fig, axes = plt.subplots(len(sample_indices), 3, figsize=(15, 4 * len(sample_indices)), squeeze=False)

    for row, idx in enumerate(sample_indices):
        frame_file = frame_files[idx]

        # Load frame
        frame_path = os.path.join(frames_dir, frame_file)
        frame = cv2.imread(frame_path)
        if frame is None:
            print(f"Failed to load frame: {frame_path}")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Load mask; fall back to an empty mask so the row is still drawn
        mask_path = os.path.join(masks_dir, frame_file.replace('.jpg', '.png'))
        obj_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.float32)
        if not os.path.exists(mask_path):
            print(f"Mask file not found: {mask_path}")
        else:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                print(f"Failed to load mask: {mask_path}")
            else:
                obj_mask = (mask == obj_id).astype(np.float32)

        # Original frame
        axes[row, 0].imshow(frame)
        axes[row, 0].set_title(f"Frame {idx}")
        axes[row, 0].axis('off')

        # Mask only
        axes[row, 1].imshow(obj_mask, cmap='gray')
        axes[row, 1].set_title(f"Mask (Object ID: {obj_id})")
        axes[row, 1].axis('off')

        # Overlay: mask painted into channel 0, which is red because the frame was converted to RGB
        mask_colored = np.zeros_like(frame)
        mask_colored[:, :, 0] = obj_mask * 255
        overlay = cv2.addWeighted(frame, 1, mask_colored, 0.5, 0)

        axes[row, 2].imshow(overlay)
        axes[row, 2].set_title(f"Overlay: {expression_text}")
        axes[row, 2].axis('off')

    plt.tight_layout()
    plt.show()


def visualize_sample(data_root, annotations):
    """Plot a random annotated video with the mask of its first (and, if present, second) mask object.

    The text-annotation object ID is only reported. The plotted object is chosen from the non-zero
    values in the video's first mask, because annotation IDs and mask values are not aligned here.
    First, middle and last frames are shown.

    Args:
        data_root: Ref-DAVIS root that contains ``DAVIS/JPEGImages/480p`` and
            ``DAVIS/Annotations_unsupervised/480p``.
        annotations: ``{video_name: [{'obj_id': ..., 'expression': ...}, ...]}``.

    Returns:
        ``(video_name, expression_text, annotation_obj_id)``, or ``(None, None, None)`` when there is
        nothing to show: no annotations, missing directories or frames, or an unreadable or
        background-only first mask.
    """
    video_names = list(annotations.keys())
    if not video_names:
        print("No annotated videos to visualize")
        return None, None, None
    random_video = random.choice(video_names)

    # Get expressions for this video
    expressions = annotations[random_video]
    if not expressions:
        print(f"No expressions for video {random_video}")
        return None, None, None
    random_exp = random.choice(expressions)

    obj_id = random_exp['obj_id']
    expression_text = random_exp['expression']

    print(f"Selected video: {random_video}")
    print(f"Object ID from annotation: {obj_id}")
    print(f"Expression: {expression_text}")

    frames_dir = os.path.join(data_root, 'DAVIS', 'JPEGImages', '480p', random_video)
    masks_dir = os.path.join(data_root, 'DAVIS', 'Annotations_unsupervised', '480p', random_video)

    if not os.path.exists(frames_dir):
        print(f"Frames directory not found: {frames_dir}")
        return None, None, None

    if not os.path.exists(masks_dir):
        print(f"Masks directory not found: {masks_dir}")
        return None, None, None

    print(f"Using frames directory: {frames_dir}")
    print(f"Using masks directory: {masks_dir}")

    frame_files = sorted([f for f in os.listdir(frames_dir) if f.endswith('.jpg')])
    if not frame_files:
        print(f"No frame files found in {frames_dir}")
        return None, None, None

    sample_indices = [0, len(frame_files) // 2, len(frame_files) - 1]  # first, middle, last

    # Find the object values present in the first mask
    mask_path = os.path.join(masks_dir, frame_files[0].replace('.jpg', '.png'))
    if not os.path.exists(mask_path):
        print(f"Mask file not found: {mask_path}")
        return None, None, None
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print(f"Failed to load mask: {mask_path}")
        return None, None, None

    unique_values = np.unique(mask)
    print(f"Available object IDs in mask: {unique_values}")

    # Filter out background (0)
    object_ids = [value for value in unique_values if value > 0]
    if not object_ids:
        print("No object IDs found in mask!")
        return None, None, None

    # Annotation IDs and mask values are not mapped here, so show the first mask object
    actual_obj_id = object_ids[0]
    print(f"Using object ID {actual_obj_id} from mask (instead of {obj_id} from annotation)")
    _plot_object_overlays(frames_dir, masks_dir, frame_files, sample_indices, actual_obj_id, expression_text)

    # Second object ID, if available
    if len(object_ids) > 1:
        actual_obj_id = object_ids[1]
        print(f"\nTrying second object ID: {actual_obj_id}")
        _plot_object_overlays(frames_dir, masks_dir, frame_files, sample_indices, actual_obj_id, expression_text)

    return random_video, expression_text, obj_id

# Plot one random annotated sample. Any error is reported with its traceback instead of stopping the notebook.
try:
    sample_video, sample_expr, sample_obj_id = visualize_sample(
        data_root='/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis',
        annotations=davis17_annotations,
    )
    if sample_video:
        print(f"Video: {sample_video}, Expression: '{sample_expr}', Object ID: {sample_obj_id}")
except Exception as e:
    import traceback
    print(f"Error visualizing sample: {e}")
    traceback.print_exc()

In [None]:
# Mapping between annotation object IDs and mask object IDs
def create_object_id_mapping(annotations, davis_root='/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis'):
    """Map each video's annotation object IDs to pixel values found in its first mask.

    For every video that has frame and mask directories under ``davis_root``, the first frame's mask
    is read and its non-zero values are collected in ascending order.

    The distinct annotation ``obj_id`` strings are taken in first-appearance order. An object
    described by several expressions (e.g. by both annotators) is therefore counted once. The i-th
    distinct ID maps to the i-th mask value. IDs beyond the number of mask values fall back to the
    first mask value.

    Videos whose first mask contains only background (0) are left out of the mapping. Mapping them
    to 0 would make the background the segmentation target.

    NOTE(review): masks are read with cv2.IMREAD_GRAYSCALE. If the DAVIS PNGs are palette-indexed,
    the values are grey levels of the palette colours rather than object indices. That would explain
    why annotation IDs do not match mask values directly; confirm before relying on the ordering.

    Args:
        annotations: ``{video_name: [{'obj_id': str, 'expression': str}, ...]}``.
        davis_root: Ref-DAVIS root containing ``DAVIS/JPEGImages/480p`` and
            ``DAVIS/Annotations_unsupervised/480p``.

    Returns:
        ``{video_name: {annotation_obj_id: mask_value (int)}}``.
    """
    frames_dir = os.path.join(davis_root, 'DAVIS', 'JPEGImages', '480p')
    masks_dir = os.path.join(davis_root, 'DAVIS', 'Annotations_unsupervised', '480p')

    id_mapping = {}

    for video_name, expressions in annotations.items():
        video_frames_dir = os.path.join(frames_dir, video_name)
        video_masks_dir = os.path.join(masks_dir, video_name)
        if not os.path.exists(video_frames_dir) or not os.path.exists(video_masks_dir):
            continue

        frame_files = sorted([f for f in os.listdir(video_frames_dir) if f.endswith('.jpg')])
        if not frame_files:
            continue

        # Load the first mask to get the object values present in this video
        mask_path = os.path.join(video_masks_dir, frame_files[0].replace('.jpg', '.png'))
        if not os.path.exists(mask_path):
            continue

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue

        # Foreground values as plain Python ints (JSON-serialisable, unlike numpy scalars)
        unique_ids = [int(value) for value in np.unique(mask) if value > 0]
        if not unique_ids:
            continue

        # One entry per distinct annotation ID. Indexing by expression position would let repeated
        # IDs overwrite earlier assignments with the fallback value.
        distinct_anno_ids = list(dict.fromkeys(exp_data['obj_id'] for exp_data in expressions))
        id_mapping[video_name] = {
            anno_obj_id: unique_ids[i] if i < len(unique_ids) else unique_ids[0]
            for i, anno_obj_id in enumerate(distinct_anno_ids)
        }

    return id_mapping

def split_dataset(annotations, train_ratio=0.8, seed=42,
                  davis_root='/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis'):
    """Split annotated videos into training and validation sets at the video level.

    Only videos with both a frame directory and a mask directory under ``davis_root`` are kept.
    They are shuffled with ``seed``, and the first ``int(len * train_ratio)`` go to training.
    Splitting by video means no video contributes expressions to both sets.

    NOTE: this seeds Python's global ``random`` module via ``random.seed(seed)``. That keeps the
    split reproducible, but it also affects later calls to the module-level ``random`` functions.

    Args:
        annotations: ``{video_name: [annotation, ...]}``.
        train_ratio: fraction of available videos used for training, in [0, 1].
        seed: seed for the shuffle.
        davis_root: Ref-DAVIS root containing ``DAVIS/JPEGImages/480p`` and
            ``DAVIS/Annotations_unsupervised/480p``.

    Returns:
        ``(train_annotations, val_annotations)``, each a dict sharing the value lists of ``annotations``.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    random.seed(seed)

    frames_dir = os.path.join(davis_root, 'DAVIS', 'JPEGImages', '480p')
    masks_dir = os.path.join(davis_root, 'DAVIS', 'Annotations_unsupervised', '480p')

    # Keep only annotated videos that actually exist in the dataset
    available_videos = [
        video_name for video_name in annotations
        if os.path.exists(os.path.join(frames_dir, video_name))
        and os.path.exists(os.path.join(masks_dir, video_name))
    ]

    print(f"Total videos in annotations: {len(annotations)}")
    print(f"Available videos in dataset: {len(available_videos)}")

    random.shuffle(available_videos)
    split_idx = int(len(available_videos) * train_ratio)
    train_videos = available_videos[:split_idx]
    val_videos = available_videos[split_idx:]

    train_annotations = {video: annotations[video] for video in train_videos}
    val_annotations = {video: annotations[video] for video in val_videos}

    return train_annotations, val_annotations

# Split by video (train/val), then report and persist both splits
train_annotations, val_annotations = split_dataset(davis17_annotations)

print(f"Training videos: {len(train_annotations)}")
print(f"Validation videos: {len(val_annotations)}")

split_outputs = (
    ('/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis/train_annotations.json', train_annotations),
    ('/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis/val_annotations.json', val_annotations),
)
for split_path, split_data in split_outputs:
    with open(split_path, 'w') as f:
        json.dump(split_data, f)

In [None]:
def visualize_with_mapping(data_root, annotations, id_mapping, video_name=None):
    """Plot one random expression of a video, using ``id_mapping`` to select the referred object's mask.

    First, middle and last frames are shown as: frame, binary mask, red overlay.

    Args:
        data_root: Ref-DAVIS root containing the ``DAVIS`` folder.
        annotations: ``{video_name: [{'obj_id': ..., 'expression': ...}, ...]}`` for the split to sample from.
        id_mapping: ``{video_name: {annotation_obj_id: mask_value}}``. It may cover more videos than
            ``annotations``, e.g. when it was built from all videos.
        video_name: video to show. When None, a random video present in both ``annotations`` and
            ``id_mapping`` is chosen.

    Returns:
        ``(video_name, expression_text, annotation_obj_id)``, or ``(None, None, None)`` if nothing
        could be shown.
    """
    if video_name is None:
        # Restrict the random choice to videos that belong to this split AND have a mapping.
        # Drawing from id_mapping alone would often pick a video from the other split, which the
        # membership check below would then reject.
        mapped_videos = [name for name in id_mapping if name in annotations]
        if not mapped_videos:
            print("No videos with ID mapping available")
            return None, None, None
        video_name = random.choice(mapped_videos)

    if video_name not in annotations:
        print(f"Video {video_name} not in annotations")
        return None, None, None

    expressions = annotations[video_name]
    if not expressions:
        print(f"No expressions for video {video_name}")
        return None, None, None
    random_exp = random.choice(expressions)

    anno_obj_id = random_exp['obj_id']
    expression_text = random_exp['expression']

    # Mask value for the referred object
    if video_name in id_mapping and anno_obj_id in id_mapping[video_name]:
        actual_obj_id = id_mapping[video_name][anno_obj_id]
    else:
        print(f"No mapping found for video {video_name}, object ID {anno_obj_id}")
        return None, None, None

    print(f"Selected video: {video_name}")
    print(f"Annotation object ID: {anno_obj_id}")
    print(f"Mapped to actual object ID: {actual_obj_id}")
    print(f"Expression: {expression_text}")

    frames_dir = os.path.join(data_root, 'DAVIS', 'JPEGImages', '480p', video_name)
    masks_dir = os.path.join(data_root, 'DAVIS', 'Annotations_unsupervised', '480p', video_name)

    if not os.path.exists(frames_dir) or not os.path.exists(masks_dir):
        print(f"Directories not found for video {video_name}")
        return None, None, None

    frame_files = sorted([f for f in os.listdir(frames_dir) if f.endswith('.jpg')])
    if not frame_files:
        print(f"No frame files found for video {video_name}")
        return None, None, None

    sample_indices = [0, len(frame_files) // 2, len(frame_files) - 1]  # first, middle, last

    fig, axes = plt.subplots(len(sample_indices), 3, figsize=(15, 4 * len(sample_indices)), squeeze=False)

    for i, idx in enumerate(sample_indices):
        frame_file = frame_files[idx]

        frame_path = os.path.join(frames_dir, frame_file)
        frame = cv2.imread(frame_path)
        if frame is None:
            print(f"Failed to load frame: {frame_path}")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Missing/unreadable masks are drawn as empty masks so the row is still shown
        mask_path = os.path.join(masks_dir, frame_file.replace('.jpg', '.png'))
        if not os.path.exists(mask_path):
            print(f"Mask file not found: {mask_path}")
            obj_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.float32)
        else:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                print(f"Failed to load mask: {mask_path}")
                obj_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.float32)
            else:
                obj_mask = (mask == actual_obj_id).astype(np.float32)

        axes[i, 0].imshow(frame)
        axes[i, 0].set_title(f"Frame {idx}")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(obj_mask, cmap='gray')
        axes[i, 1].set_title(f"Mask (Object ID: {actual_obj_id})")
        axes[i, 1].axis('off')

        # Overlay: channel 0 is red because the frame was converted to RGB
        mask_colored = np.zeros_like(frame)
        mask_colored[:, :, 0] = obj_mask * 255
        masked_img = cv2.addWeighted(frame, 1, mask_colored, 0.5, 0)

        axes[i, 2].imshow(masked_img)
        axes[i, 2].set_title(f"Overlay: {expression_text}")
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

    return video_name, expression_text, anno_obj_id
# Build the annotation-ID -> mask-value mapping for every annotated video,
# then show one random sample from each split.
object_id_mapping = create_object_id_mapping(davis17_annotations)

print("\nVisualizing a sample from the training set:")
train_video, train_expr, train_obj_id = visualize_with_mapping(
    data_root='/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis',
    annotations=train_annotations,
    id_mapping=object_id_mapping,
)

print("\nVisualizing a sample from the validation set:")
val_video, val_expr, val_obj_id = visualize_with_mapping(
    data_root='/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis',
    annotations=val_annotations,
    id_mapping=object_id_mapping,
)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer
import albumentations as A
from albumentations.pytorch import ToTensorV2

class RefDAVISDataset(Dataset):
    """Ref-DAVIS samples: a fixed-length clip, the referred object's masks, and the tokenised expression.

    One sample is created per expression whose video and annotation ``obj_id`` both appear in
    ``id_mapping``. ``__getitem__`` returns a dict with these keys:
      - 'frames': ``max_seq_len`` evenly spaced frames stacked on dim 0, zero-padded for short videos.
      - 'masks': matching float masks, 1 where the mask equals the mapped object value.
      - 'text_ids' / 'text_mask': BERT input IDs and attention mask, padded/truncated to ``max_text_len``.
      - 'expression', 'video_id', 'obj_id' (annotation ID), and 'actual_obj_id' (mapped mask value).

    Every item, including the all-zero placeholder used when a video directory is missing, has the
    same keys, so items can be batched by the default collate function.

    Args:
        data_root: Ref-DAVIS root containing ``DAVIS/JPEGImages/480p`` and ``DAVIS/Annotations_unsupervised/480p``.
        annotations: ``{video_id: [{'obj_id': ..., 'expression': ...}, ...]}``.
        id_mapping: ``{video_id: {annotation_obj_id: mask_value}}``.
        transform: optional callable used as ``transform(image=frame, mask=mask)``, returning a dict
            with 'image' and 'mask'. When None, the frame goes through ``transforms.ToTensor()``.
        max_seq_len: number of frames per clip.
        image_size: side length of placeholder tensors (missing video, unreadable frame, padding when
            no frame was loaded). It should match the transform's output size.
        max_text_len: token length of the tokenised expression.
    """

    def __init__(self, data_root, annotations, id_mapping, transform=None, max_seq_len=5,
                 image_size=384, max_text_len=20):
        self.data_root = data_root
        self.annotations = annotations
        self.id_mapping = id_mapping
        self.transform = transform
        self.max_seq_len = max_seq_len
        self.image_size = image_size
        self.max_text_len = max_text_len
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Keep only expressions whose video and object ID have a mask-value mapping
        self.samples = []
        for video_id, expressions in self.annotations.items():
            if video_id not in self.id_mapping:
                continue
            for exp_data in expressions:
                obj_id = exp_data['obj_id']
                if obj_id not in self.id_mapping[video_id]:
                    continue
                self.samples.append({
                    'video_id': video_id,
                    'expression': exp_data['expression'],
                    'obj_id': obj_id
                })

        print(f"Created dataset with {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def _placeholder_item(self, expression, video_id, anno_obj_id, actual_obj_id):
        """All-zero sample used when a video's frame or mask directory is missing."""
        size = self.image_size
        return {
            'frames': torch.zeros(self.max_seq_len, 3, size, size),
            'masks': torch.zeros(self.max_seq_len, size, size),
            'text_ids': torch.zeros(self.max_text_len, dtype=torch.long),
            'text_mask': torch.zeros(self.max_text_len, dtype=torch.long),
            'expression': expression,
            'video_id': video_id,
            'obj_id': anno_obj_id,
            # Real samples carry this key too. Default collation needs the same keys in every item.
            'actual_obj_id': actual_obj_id
        }

    def __getitem__(self, idx):
        sample = self.samples[idx]
        video_id = sample['video_id']
        anno_obj_id = sample['obj_id']
        expression = sample['expression']

        # Mask value identifying the referred object
        actual_obj_id = self.id_mapping[video_id][anno_obj_id]

        frames_dir = os.path.join(self.data_root, 'DAVIS', 'JPEGImages', '480p', video_id)
        masks_dir = os.path.join(self.data_root, 'DAVIS', 'Annotations_unsupervised', '480p', video_id)

        if not os.path.exists(frames_dir) or not os.path.exists(masks_dir):
            return self._placeholder_item(expression, video_id, anno_obj_id, actual_obj_id)

        frame_files = sorted([f for f in os.listdir(frames_dir) if f.endswith('.jpg')])

        # Evenly spaced frames over the whole video; shorter videos keep all frames and are padded below
        if len(frame_files) > self.max_seq_len:
            indices = np.linspace(0, len(frame_files) - 1, self.max_seq_len, dtype=int)
            frame_files = [frame_files[i] for i in indices]

        size = self.image_size
        frames = []
        masks = []

        for frame_file in frame_files:
            frame_path = os.path.join(frames_dir, frame_file)
            frame = cv2.imread(frame_path)
            if frame is None:
                # Blank frame if loading fails
                frame = np.zeros((size, size, 3), dtype=np.uint8)
            else:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            mask_path = os.path.join(masks_dir, frame_file.replace('.jpg', '.png'))
            obj_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.float32)
            if os.path.exists(mask_path):
                # NOTE(review): if the DAVIS PNGs are palette-indexed, IMREAD_GRAYSCALE yields grey
                # levels of the palette colours rather than indices. id_mapping is built the same way,
                # so the comparison is consistent, but the values are not the annotation IDs.
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is not None:
                    obj_mask = (mask == actual_obj_id).astype(np.float32)

            if self.transform:
                # Same spatial transform for frame and mask
                augmented = self.transform(image=frame, mask=obj_mask)
                frame = augmented['image']
                obj_mask = augmented['mask']
            else:
                frame = transforms.ToTensor()(frame)
                obj_mask = torch.from_numpy(obj_mask)

            frames.append(frame)
            masks.append(obj_mask)

        # Pad the clip to max_seq_len
        while len(frames) < self.max_seq_len:
            frames.append(torch.zeros_like(frames[0]) if frames else torch.zeros(3, size, size))
            masks.append(torch.zeros_like(masks[0]) if masks else torch.zeros(size, size))

        encoded_text = self.tokenizer(
            expression,
            padding='max_length',
            max_length=self.max_text_len,
            truncation=True,
            return_tensors='pt'
        )

        return {
            'frames': torch.stack(frames),
            'masks': torch.stack(masks),
            'text_ids': encoded_text['input_ids'].squeeze(0),
            'text_mask': encoded_text['attention_mask'].squeeze(0),
            'expression': expression,
            'video_id': video_id,
            'obj_id': anno_obj_id,
            'actual_obj_id': actual_obj_id
        }

# Data loaders with augmentation for training
def get_data_loaders(data_root, train_annotations, val_annotations, id_mapping, batch_size=4):
    """Return ``(train_loader, val_loader)`` over RefDAVISDataset.

    Both pipelines resize to 384x384 and end with ImageNet mean/std normalisation and tensor
    conversion. Only the training pipeline adds augmentation: horizontal flip, shift-scale-rotate
    and brightness/contrast jitter. Training batches are shuffled and validation batches are not;
    both loaders use 2 worker processes.
    """
    def build_pipeline(*augmentations):
        # resize -> optional augmentations -> normalise -> tensor
        return A.Compose([
            A.Resize(height=384, width=384),
            *augmentations,
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    train_transform = build_pipeline(
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.2),
        A.RandomBrightnessContrast(p=0.2),
    )
    val_transform = build_pipeline()

    train_dataset = RefDAVISDataset(data_root, train_annotations, id_mapping, transform=train_transform)
    val_dataset = RefDAVISDataset(data_root, val_annotations, id_mapping, transform=val_transform)

    loader_kwargs = dict(batch_size=batch_size, num_workers=2)
    return (
        DataLoader(train_dataset, shuffle=True, **loader_kwargs),
        DataLoader(val_dataset, shuffle=False, **loader_kwargs),
    )

# Build the loaders with a small batch size for the sanity checks below
train_loader, val_loader = get_data_loaders(
    data_root='/content/drive/MyDrive/RVOS_VR_Project/data/ref-davis',
    train_annotations=train_annotations,
    val_annotations=val_annotations,
    id_mapping=object_id_mapping,
    batch_size=2,
)

# Pull one training batch and report the shape of each tensor field
batch = next(iter(train_loader))
print(f"Batch shapes:")
print(f"  Frames: {batch['frames'].shape}")
print(f"  Masks: {batch['masks'].shape}")
print(f"  Text IDs: {batch['text_ids'].shape}")
print(f"  Text Mask: {batch['text_mask'].shape}")

# Visualize a sample from a batch
def visualize_batch_sample(batch, sample_idx=0, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show each frame of one batch sample next to the same frame with its mask overlaid.

    Args:
        batch: collated dict from the RefDAVISDataset loaders. 'frames' is indexed as
            [sample, T, C, H, W] and 'masks' as [sample, T, H, W]; the 'expression', 'video_id' and
            'obj_id' entries are printed, plus 'actual_obj_id' when present.
        sample_idx: which item of the batch to show.
        mean, std: per-channel normalisation to undo before display. The defaults are the ImageNet
            statistics used by get_data_loaders' transforms.
    """
    frames = batch['frames'][sample_idx].detach().cpu().float()  # [T, C, H, W]
    masks = batch['masks'][sample_idx].detach().cpu().float()    # [T, H, W]
    expression = batch['expression'][sample_idx]
    video_id = batch['video_id'][sample_idx]
    obj_id = batch['obj_id'][sample_idx]

    actual_ids = batch.get('actual_obj_id')
    actual_obj_id = actual_ids[sample_idx] if actual_ids is not None else None
    if isinstance(actual_obj_id, torch.Tensor):
        actual_obj_id = actual_obj_id.item()  # print 38, not tensor(38)

    print(f"Video: {video_id}")
    print(f"Expression: {expression}")
    print(f"Annotation Object ID: {obj_id}")
    print(f"Actual Object ID: {actual_obj_id}")

    # Undo normalisation and move channels last for imshow
    mean_t = torch.tensor(mean, dtype=frames.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=frames.dtype).view(1, -1, 1, 1)
    frames = (frames * std_t + mean_t).permute(0, 2, 3, 1).numpy()  # [T, H, W, C]
    frames = np.clip(frames, 0, 1)
    masks = masks.numpy()  # [T, H, W]

    num_frames = frames.shape[0]
    # squeeze=False keeps axes 2-D even for a single-frame clip
    fig, axes = plt.subplots(num_frames, 2, figsize=(10, 4 * num_frames), squeeze=False)

    for i in range(num_frames):
        axes[i, 0].imshow(frames[i])
        axes[i, 0].set_title(f"Frame {i}")
        axes[i, 0].axis('off')

        # Hide background pixels so only the object is tinted. Overlaying the raw 0/1 mask would
        # colour the whole frame.
        axes[i, 1].imshow(frames[i])
        axes[i, 1].imshow(np.ma.masked_where(masks[i] <= 0, masks[i]), alpha=0.5, cmap='cool', vmin=0, vmax=1)
        axes[i, 1].set_title(f"Mask: {expression}")
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Plot the first sample of the training batch fetched above
visualize_batch_sample(batch)

# Same shape check and plot for one validation batch
val_batch = next(iter(val_loader))
print(f"\nValidation batch shapes:")
print(f"  Frames: {val_batch['frames'].shape}")
print(f"  Masks: {val_batch['masks'].shape}")
print(f"  Text IDs: {val_batch['text_ids'].shape}")
print(f"  Text Mask: {val_batch['text_mask'].shape}")

visualize_batch_sample(val_batch)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from transformers import BertModel
from einops import rearrange

class VisualBackbone(nn.Module):
    """Per-frame ResNet feature extractor for clips.

    Input ``x`` is indexed as (b, t, c, h, w). Every frame goes through the ResNet stem and layers 1-4,
    with avgpool and fc dropped so the spatial map is kept. A 1x1 conv then halves the channel
    count. The output is indexed as (b, t, out_channels // 2, h', w').

    Args:
        backbone_type: 'resnet50' (2048 channels from layer4); any other value selects resnet18 (512).
        pretrained: load the ImageNet weights that torchvision's legacy ``pretrained=True`` loaded
            (IMAGENET1K_V1).

    NOTE(review): ``self.out_channels`` is the layer4 width, not the width of ``forward``'s output,
    which is ``out_channels // 2`` after ``feature_refine``. Callers must account for that.
    """

    def __init__(self, backbone_type='resnet50', pretrained=True):
        super(VisualBackbone, self).__init__()
        # `weights=` replaces the deprecated `pretrained=` flag; IMAGENET1K_V1 is the checkpoint
        # `pretrained=True` used to load, so the initialisation is unchanged.
        if backbone_type == 'resnet50':
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
            self.out_channels = 2048
        else:
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet18(weights=weights)
            self.out_channels = 512

        # Everything up to layer4 (avgpool and fc omitted) so the spatial feature map is kept
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

        # 1x1 channel reduction: out_channels -> out_channels // 2
        self.feature_refine = nn.Conv2d(self.out_channels, self.out_channels // 2, kernel_size=1)

    def forward(self, x):
        b, t, c, h, w = x.shape
        # Fold time into the batch so the 2-D backbone processes every frame independently
        features = self.backbone(x.reshape(b * t, c, h, w))
        features = self.feature_refine(features)
        return features.reshape(b, t, *features.shape[1:])

class TextEncoder(nn.Module):
    """Sentence encoder: bert-base-uncased, then pooling, then a Linear + LayerNorm + GELU projection (768-d).

    Args:
        pooling: 'cls' uses the hidden state of the first token. 'mean' averages hidden states over the
            positions where ``attention_mask`` is non-zero.

    Raises:
        ValueError: if ``pooling`` is not 'cls' or 'mean'. It is checked at construction, before BERT
            is loaded, and again in ``forward`` in case the attribute was changed later.
    """

    _POOLING_MODES = ('cls', 'mean')

    def __init__(self, pooling='cls'):
        super(TextEncoder, self).__init__()
        # Fail fast on a typo instead of an UnboundLocalError in forward
        if pooling not in self._POOLING_MODES:
            raise ValueError(f"pooling must be one of {self._POOLING_MODES}, got {pooling!r}")
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.out_channels = 768
        self.pooling = pooling

        # Projection applied to the pooled sentence vector
        self.text_projection = nn.Sequential(
            nn.Linear(self.out_channels, self.out_channels),
            nn.LayerNorm(self.out_channels),
            nn.GELU()
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # (batch, seq, dim)

        if self.pooling == 'cls':
            text_features = hidden[:, 0]
        elif self.pooling == 'mean':
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids 0/0 for a row whose attention mask is all zeros
            text_features = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        else:
            raise ValueError(f"pooling must be one of {self._POOLING_MODES}, got {self.pooling!r}")

        return self.text_projection(text_features)

class CrossModalFusion(nn.Module):
    """Fuse per-frame visual feature maps with one sentence-level text vector.

    Steps:
      - Visual features are projected to ``hidden_dim`` by a 1x1 conv (+BatchNorm, GELU).
      - The text vector is projected to ``hidden_dim`` by a Linear (+LayerNorm, GELU) and repeated
        for every frame.
      - Every spatial location attends to that text vector.
      - The result is computed as ``gamma * attention + projected visual + 1x1conv(projected visual)``
        and then LayerNorm-ed over channels.

    NOTE(review): the key/value sequence has length 1 (one text vector per frame). Softmax over a
    single key is always 1, so the attention output does not depend on the visual query: every
    pixel receives the same text-derived vector (up to attention dropout in training). Genuine
    cross-attention would need token-level text features as keys/values.
    """

    def __init__(self, visual_dim, text_dim, hidden_dim=256, num_heads=8):
        super(CrossModalFusion, self).__init__()
        self.visual_proj = nn.Sequential(
            nn.Conv2d(visual_dim, hidden_dim, kernel_size=1),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU()
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU()
        )
        # batch_first is not set, so inputs are laid out (sequence, batch, embed)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1)
        self.norm = nn.LayerNorm(hidden_dim)
        # learnable scale on the attention branch, initialised to 1
        self.gamma = nn.Parameter(torch.ones(1))

        #residual branch: 1x1 conv on the projected visual features, added alongside them
        self.residual_proj = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1)

    def forward(self, visual_feat, text_feat):
        """Return fused features indexed (b, t, hidden_dim, h, w).

        Args:
            visual_feat: indexed (b, t, visual_dim, h, w).
            text_feat: one vector per clip, indexed (b, text_dim).
        """
        b, t, _, h, w = visual_feat.shape
        visual_feat = rearrange(visual_feat, 'b t c h w -> (b t) c h w')
        visual_proj = self.visual_proj(visual_feat)
        text_proj = self.text_proj(text_feat)

        # repeat the sentence vector for every frame -> key/value sequence of length 1: (1, b*t, d)
        text_proj = text_proj.unsqueeze(1).expand(-1, t, -1)
        text_proj = rearrange(text_proj, 'b t d -> (b t) d').unsqueeze(0)

        # every spatial location is a query: (h*w, b*t, d)
        visual_proj_flat = rearrange(visual_proj, '(b t) d h w -> (h w) (b t) d', b=b, t=t)
        attn_output, _ = self.multihead_attn(visual_proj_flat, text_proj, text_proj)

        attn_output = rearrange(attn_output, '(h w) (b t) d -> b t d h w', b=b, t=t, h=h, w=w)
        visual_proj_reshaped = rearrange(visual_proj, '(b t) d h w -> b t d h w', b=b, t=t)

        residual = self.residual_proj(rearrange(visual_proj_reshaped, 'b t d h w -> (b t) d h w'))
        residual = rearrange(residual, '(b t) d h w -> b t d h w', b=b, t=t)

        fused_feat = self.gamma * attn_output + visual_proj_reshaped + residual
        # LayerNorm over the channel dim at every (frame, location)
        fused_feat = rearrange(fused_feat, 'b t d h w -> (b t h w) d')
        fused_feat = self.norm(fused_feat)
        fused_feat = rearrange(fused_feat, '(b t h w) d -> b t d h w', b=b, t=t, h=h, w=w)

        return fused_feat

class TemporalModule(nn.Module):
    """Spatio-temporal refinement: three 3-D conv blocks plus a 1x1x1 shortcut.

    Input and output are laid out (B, T, C, H, W). Channels go from
    ``in_channels`` to ``hidden_dim``. T, H and W are preserved (3x3x3 kernels
    with padding 1).
    """

    def __init__(self, in_channels, hidden_dim=256):
        super().__init__()
        # conv -> BN -> GELU three times: in -> hidden -> 2*hidden -> hidden.
        widths = [in_channels, hidden_dim, hidden_dim * 2, hidden_dim]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv3d(c_in, c_out, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
                nn.BatchNorm3d(c_out),
                nn.GELU(),
            ]
        self.temporal_conv = nn.Sequential(*layers)

        # Shortcut projecting the input to hidden_dim channels; always applied.
        self.residual = nn.Conv3d(in_channels, hidden_dim, kernel_size=1)

    def forward(self, x):
        # Conv3d wants channels before time: (B, C, T, H, W).
        x = rearrange(x, 'b t c h w -> b c t h w')
        out = self.temporal_conv(x) + self.residual(x)
        return rearrange(out, 'b c t h w -> b t c h w')

class SegmentationHead(nn.Module):
    """Per-frame mask predictor: two 3x3 conv blocks with dropout, then a 1x1 classifier.

    Takes (B, T, in_channels, H, W). Returns (B, T, out_channels, H, W)
    logits at the same spatial resolution.
    """

    def __init__(self, in_channels, mid_channels=256, out_channels=1):
        super().__init__()

        def conv_block(c_in):
            # 3x3 conv -> BN -> GELU -> Dropout(0.3), producing mid_channels.
            return [
                nn.Conv2d(c_in, mid_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(mid_channels),
                nn.GELU(),
                nn.Dropout(0.3),
            ]

        self.segmentation_head = nn.Sequential(
            *conv_block(in_channels),
            *conv_block(mid_channels),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        n_batch, n_frames, _, _, _ = x.shape
        # Fold time into the batch so the 2-D convs treat each frame independently.
        logits = self.segmentation_head(x.flatten(0, 1))
        return logits.unflatten(0, (n_batch, n_frames))

class RVOSModel(nn.Module):
    """Referring video object segmentation model.

    Pipeline: visual backbone and text encoder -> cross-modal fusion ->
    3-D temporal refinement -> per-frame segmentation head -> bilinear
    upsampling of the mask logits by a fixed factor.

    Args:
        hidden_dim: shared width of the fusion, temporal and head layers.
        backbone_type: forwarded to ``VisualBackbone``.
        upsample_factor: scale applied to the head's logits. Defaults to 32,
            the previously hard-coded value. Set it to the stride of the
            feature map the backbone actually returns so the output matches
            the input resolution.
    """

    def __init__(self, hidden_dim=256, backbone_type='resnet50', upsample_factor=32):
        super().__init__()
        self.visual_backbone = VisualBackbone(backbone_type=backbone_type)
        # NOTE(review): fusion is told the backbone emits out_channels // 2
        # channels; presumably VisualBackbone returns an intermediate stage.
        # Confirm this against its definition.
        visual_dim = self.visual_backbone.out_channels // 2

        self.text_encoder = TextEncoder(pooling='mean')
        text_dim = self.text_encoder.out_channels

        self.fusion = CrossModalFusion(visual_dim, text_dim, hidden_dim)
        self.temporal = TemporalModule(hidden_dim, hidden_dim)
        self.segmentation_head = SegmentationHead(hidden_dim, hidden_dim, out_channels=1)

        # Parameter-free, so changing the factor does not affect checkpoints.
        self.upsample = nn.Upsample(scale_factor=upsample_factor, mode='bilinear', align_corners=False)

    def forward(self, frames, text_ids, text_mask):
        """Return mask logits of shape (B, T, 1, h*upsample_factor, w*upsample_factor).

        Here h x w is the segmentation head's spatial grid.
        """
        visual_features = self.visual_backbone(frames)
        text_features = self.text_encoder(text_ids, text_mask)

        fused_features = self.fusion(visual_features, text_features)
        temporal_features = self.temporal(fused_features)
        masks = self.segmentation_head(temporal_features)

        # nn.Upsample in bilinear mode needs 4-D input, so fold T into the batch.
        b, t = masks.shape[:2]
        masks = rearrange(masks, 'b t c h w -> (b t) c h w')
        masks = self.upsample(masks)
        return rearrange(masks, '(b t) c h w -> b t c h w', b=b, t=t)

# Build the model with a ResNet-50 backbone (hidden width 256) and print its module tree.
model = RVOSModel(hidden_dim=256, backbone_type='resnet50')
print(model)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
import os

# Ask PyTorch's CUDA caching allocator for expandable segments to reduce fragmentation.
# NOTE(review): the allocator reads this variable when it is first initialised. If an
# earlier cell already allocated CUDA memory in this kernel, the assignment has no
# effect until the runtime restarts; consider moving it to the first cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Resolve the compute device used by the training utilities below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class RVOSLoss(nn.Module):
    """Weighted sum of BCE, soft-Dice and focal losses on per-frame mask logits.

    ``pred`` holds raw logits shaped (B, T, C, H, W). C is expected to be 1,
    because the target is reshaped to a single channel. ``target`` is
    (B, T, H, W) or (B, T, 1, H, W). If its spatial size differs from
    ``pred``, it is resized with nearest-neighbour interpolation.

    Args:
        bce_weight: weight of the mean binary cross-entropy term.
        dice_weight: weight of the per-frame soft-Dice term.
        focal_weight: weight of the focal term (alpha=0.25, gamma=2).
    """

    def __init__(self, bce_weight=0.3, dice_weight=0.6, focal_weight=0.1):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.bce_loss = nn.BCEWithLogitsLoss()

    @staticmethod
    def _match_target(target, b, t, h, w):
        """Give ``target`` an explicit channel axis and the prediction's (h, w) grid."""
        if target.dim() == 4:
            target = target.unsqueeze(2)
        src_h, src_w = target.shape[3], target.shape[4]
        if (src_h, src_w) == (h, w):
            return target
        resized = F.interpolate(
            target.reshape(b * t, 1, src_h, src_w),
            size=(h, w),
            mode='nearest',
        )
        return resized.reshape(b, t, 1, h, w)

    @staticmethod
    def _soft_dice(logits, labels, smooth=1e-6):
        """Mean over frames of 1 - Dice(sigmoid(logits), labels)."""
        probs = torch.sigmoid(logits)
        overlap = (probs * labels).sum(dim=(2, 3))
        denom = probs.sum(dim=(2, 3)) + labels.sum(dim=(2, 3))
        return (1 - ((2.0 * overlap + smooth) / (denom + smooth))).mean()

    def forward(self, pred, target):
        b, t, c, h, w = pred.shape
        target = self._match_target(target, b, t, h, w)

        # Fold time into the batch: (B*T, C, H, W) logits vs (B*T, 1, H, W) labels.
        logits = pred.reshape(b * t, c, h, w)
        labels = target.reshape(b * t, 1, h, w)

        bce_term = self.bce_loss(logits, labels)
        dice_term = self._soft_dice(logits, labels)
        focal_term = self.focal_loss(logits, labels)

        return (self.bce_weight * bce_term
                + self.dice_weight * dice_term
                + self.focal_weight * focal_term)

    def focal_loss(self, pred, target, alpha=0.25, gamma=2):
        """Mean focal loss on logits.

        Computes BCE * (1 - p_t)**gamma. When ``alpha >= 0``, the result is
        also weighted by alpha for positives and (1 - alpha) for negatives.
        """
        ce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        prob = torch.sigmoid(pred)
        p_t = prob * target + (1 - prob) * (1 - target)
        modulated = ce * ((1 - p_t) ** gamma)

        if alpha < 0:
            return modulated.mean()
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        return (alpha_t * modulated).mean()

def calculate_iou(pred, target, threshold=0.5):
    """Mean per-frame IoU between thresholded predictions and ground-truth masks.

    Args:
        pred: (B, T, C, H, W) scores (callers pass ``torch.sigmoid(logits)``),
            binarised with ``pred > threshold``.
        target: ground-truth masks, (B, T, H, W) or (B, T, 1, H, W), of any
            numeric or bool dtype. Resized with nearest-neighbour
            interpolation if its spatial size differs from ``pred``.
        threshold: binarisation threshold applied to ``pred``.

    Returns:
        0-dim tensor: IoU averaged over all B*T frames. A frame where both
        the prediction and the ground truth are empty scores 1, because
        "object absent" was predicted correctly. This matches the DAVIS
        region-similarity (J) convention; previously such frames scored 0.
    """
    b, t, c, h, w = pred.shape
    # Float target: F.interpolate rejects bool masks, and fp16 sums overflow on large masks.
    target = target.float()
    if target.dim() == 4:
        target = target.unsqueeze(2)
    if target.shape[3] != h or target.shape[4] != w:
        target = F.interpolate(
            target.reshape(b * t, 1, target.shape[3], target.shape[4]),
            size=(h, w),
            mode='nearest'
        ).reshape(b, t, 1, h, w)

    pred = (pred > threshold).float()

    intersection = (pred * target).sum((2, 3, 4))
    union = pred.sum((2, 3, 4)) + target.sum((2, 3, 4)) - intersection

    # Empty/empty frames (union == 0) are perfect agreement; the clamp only
    # keeps the unused branch of torch.where free of 0/0.
    iou = torch.where(
        union > 0,
        intersection / union.clamp(min=1e-7),
        torch.ones_like(union),
    )
    return iou.mean()

def _optimizer_step(model, optimizer, scaler, clip_grad_norm):
    """Apply one (possibly accumulated) AMP optimizer update and reset gradients.

    Gradients are unscaled first, so clipping sees their true magnitude.

    Returns:
        The total gradient norm measured by ``clip_grad_norm_`` as a float.
        Returns ``None`` when clipping is disabled or when the norm is inf/nan
        (an fp16 overflow step, which ``GradScaler.step`` skips).
    """
    scaler.unscale_(optimizer)
    grad_norm = None
    if clip_grad_norm > 0:
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        if torch.isfinite(norm):
            grad_norm = norm.item()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return grad_norm


def train_one_epoch(model, train_loader, criterion, optimizer, device, scaler,
                    accumulation_steps=4, clip_grad_norm=1.0):
    """Run one training epoch with float16 autocast and gradient accumulation.

    An optimizer update is applied every ``accumulation_steps`` batches, plus
    one for any leftover batches at the end. Each batch's loss is divided by
    the size of its accumulation group, so a shorter final group is averaged
    over its real length rather than over ``accumulation_steps``.

    Args:
        model: module called as ``model(frames, text_ids, text_mask)``.
        train_loader: yields dicts with 'frames', 'masks', 'text_ids', 'text_mask'.
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.
        scaler: AMP gradient scaler.
        accumulation_steps: batches per optimizer update (>= 1).
        clip_grad_norm: max total gradient norm; <= 0 disables clipping.

    Returns:
        (avg_loss, avg_iou, avg_grad_norm): the per-batch mean loss and IoU,
        and the mean pre-clip gradient norm over updates whose norm was
        finite (0.0 if none were recorded).

    Raises:
        ValueError: if ``accumulation_steps < 1`` or ``train_loader`` is empty.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty")

    model.train()
    epoch_loss = 0.0
    epoch_iou = 0.0
    optimizer.zero_grad()

    # Batches from index `short_group_start` on form a final, shorter group.
    remainder = num_batches % accumulation_steps
    short_group_start = num_batches - remainder

    # Only finite norms are averaged. fp16 overflow steps (common while
    # GradScaler calibrates) report inf/nan and would otherwise make the
    # epoch average inf.
    grad_norm_sum = 0.0
    grad_norm_count = 0

    with tqdm(total=num_batches, desc="Training") as pbar:
        for i, batch in enumerate(train_loader):
            frames = batch['frames'].to(device)
            masks = batch['masks'].to(device)
            text_ids = batch['text_ids'].to(device)
            text_mask = batch['text_mask'].to(device)

            with autocast('cuda', dtype=torch.float16):
                outputs = model(frames, text_ids, text_mask)
                loss = criterion(outputs, masks)

            group_size = accumulation_steps if i < short_group_start else remainder
            scaler.scale(loss / group_size).backward()

            if (i + 1) % accumulation_steps == 0:
                grad_norm = _optimizer_step(model, optimizer, scaler, clip_grad_norm)
                if grad_norm is not None:
                    grad_norm_sum += grad_norm
                    grad_norm_count += 1

            with torch.no_grad():
                iou = calculate_iou(torch.sigmoid(outputs), masks)

            loss_value = loss.item()
            iou_value = iou.item()
            epoch_loss += loss_value
            epoch_iou += iou_value

            pbar.update(1)
            pbar.set_postfix({
                "Loss": f"{loss_value:.4f}",
                "IoU": f"{iou_value:.4f}",
                "Avg Grad Norm": f"{grad_norm_sum / max(1, grad_norm_count):.4f}"
            })

            # Drop references to this batch's activations and release cached blocks.
            del outputs, loss
            torch.cuda.empty_cache()

        # Flush gradients accumulated by a trailing partial group.
        if remainder:
            grad_norm = _optimizer_step(model, optimizer, scaler, clip_grad_norm)
            if grad_norm is not None:
                grad_norm_sum += grad_norm
                grad_norm_count += 1

    avg_loss = epoch_loss / num_batches
    avg_iou = epoch_iou / num_batches
    avg_grad_norm = grad_norm_sum / grad_norm_count if grad_norm_count else 0.0

    return avg_loss, avg_iou, avg_grad_norm

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Each batch's forward pass and loss run under float16 CUDA autocast. IoU
    is computed on ``torch.sigmoid(outputs)``.

    Returns:
        (avg_loss, avg_iou): per-batch means over the loader.
    """
    model.eval()
    total_loss = 0
    total_iou = 0
    num_batches = len(val_loader)

    with torch.no_grad(), tqdm(total=num_batches, desc="Validation") as pbar:
        for batch in val_loader:
            inputs = {
                key: batch[key].to(device)
                for key in ('frames', 'masks', 'text_ids', 'text_mask')
            }

            with autocast('cuda', dtype=torch.float16):
                outputs = model(inputs['frames'], inputs['text_ids'], inputs['text_mask'])
                loss = criterion(outputs, inputs['masks'])

            iou = calculate_iou(torch.sigmoid(outputs), inputs['masks'])
            batch_loss = loss.item()
            batch_iou = iou.item()
            total_loss += batch_loss
            total_iou += batch_iou

            pbar.update(1)
            pbar.set_postfix({
                "Loss": f"{batch_loss:.4f}",
                "IoU": f"{batch_iou:.4f}"
            })

            # Release this batch's outputs and cached GPU blocks before the next one.
            del outputs, loss
            torch.cuda.empty_cache()

    return total_loss / num_batches, total_iou / num_batches

def train_model(model, train_loader, val_loader, num_epochs=30, lr=5e-4, device=device,
                best_model_path='/content/drive/MyDrive/RVOS_VR_Project/checkpoints/best_model.pth'):
    """Train ``model`` and return it with its best-validation-IoU weights loaded.

    Training setup:
        - AdamW (weight decay 1e-4).
        - CosineAnnealingWarmRestarts (T_0=5, T_mult=2, eta_min=1e-6), stepped once per epoch.
        - float16 AMP via ``GradScaler``.

    The model is validated after every epoch. Whenever validation IoU
    improves, a checkpoint (epoch, model/optimizer state, val IoU/loss) is
    written to ``best_model_path``. That checkpoint is reloaded at the end.

    Args:
        model: the RVOS model; moved to ``device``.
        train_loader: training batches (see ``train_one_epoch``).
        val_loader: validation batches (see ``validate``).
        num_epochs: number of epochs to run.
        lr: initial AdamW learning rate.
        device: device to train on.
        best_model_path: checkpoint file. Its parent directory is created if missing.

    Returns:
        The trained model. If no checkpoint was written during this call (no
        epochs ran, or validation IoU was never a real number, e.g. NaN), the
        last-epoch weights are returned. In that case the model does not
        reload a file that may be left over from an earlier run.
    """
    model = model.to(device)

    criterion = RVOSLoss()

    #optimizer with weight decay and adaptive learning
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )

    #cosine Annealing with Warm Restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,  #initial restart period
        T_mult=2,  #multiplicative factor for subsequent restart periods
        eta_min=1e-6  #minimum learning rate
    )

    #Mixed precision training
    scaler = GradScaler('cuda')

    # torch.save fails if the target directory does not exist.
    checkpoint_dir = os.path.dirname(best_model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Start at -inf so the first epoch with a finite IoU is always saved. With 0
    # as the start value, a run that never beat 0 saved nothing. It then either
    # crashed on load or silently reloaded a stale checkpoint.
    best_val_iou = float('-inf')
    saved_this_run = False

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        #learning rate logging
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current Learning Rate: {current_lr:.6f}")

        #training phase
        train_loss, train_iou, train_grad_norm = train_one_epoch(
            model, train_loader, criterion, optimizer, device, scaler
        )

        #validation phase
        val_loss, val_iou = validate(model, val_loader, criterion, device)

        #learning rate scheduling
        scheduler.step()

        #epoch metrics
        print(f"Train Loss: {train_loss:.4f}, Train IoU: {train_iou:.4f}, Train Grad Norm: {train_grad_norm:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}")

        #model checkpointing
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_iou': val_iou,
                'val_loss': val_loss,
            }, best_model_path)
            saved_this_run = True
            print(f"Saved best model with Val IoU: {val_iou:.4f}")

    if not saved_this_run:
        print("No checkpoint was saved during this run; returning the last-epoch model.")
        return model

    # map_location lets a checkpoint saved on one device be restored on another.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with Val IoU: {checkpoint['val_iou']:.4f}")

    return model

In [None]:
# Create the checkpoint directory that train_model writes to. exist_ok makes the cell
# safe to re-run; a shell `mkdir` errors when the directory already exists.
os.makedirs('/content/drive/MyDrive/RVOS_VR_Project/checkpoints', exist_ok=True)

In [None]:
# Train for 50 epochs at lr 1e-4. train_model returns the model with its
# best-validation checkpoint loaded.
model = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=50,
    lr=1e-4,
    device=device,
)

In [None]:
# Reload the best checkpoint written by train_model for inference. map_location lets
# a checkpoint saved during a GPU session be restored on a CPU-only runtime.
best_model_path = '/content/drive/MyDrive/RVOS_VR_Project/checkpoints/best_model.pth'
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Loaded best model from epoch {checkpoint['epoch']+1} with Val IoU: {checkpoint['val_iou']:.4f}")

def _collect_video_batches(dataloader, video_names):
    """Return {video_id: single-sample batch} for the first sample of each requested video.

    Iteration over ``dataloader`` stops once every distinct requested id has
    been found. The dict follows the order in which videos appear in the loader.
    """
    wanted = set(video_names)
    found = {}
    if not wanted:
        return found
    for batch in dataloader:
        for i, vid in enumerate(batch['video_id']):
            if vid in wanted and vid not in found:
                found[vid] = {
                    'frames': batch['frames'][i:i+1],
                    'masks': batch['masks'][i:i+1],
                    'text_ids': batch['text_ids'][i:i+1],
                    'text_mask': batch['text_mask'][i:i+1],
                    'expression': [batch['expression'][i]],
                    'video_id': [batch['video_id'][i]],
                }
        # Comparing against the *set* means duplicate names cannot block the early stop.
        if len(found) == len(wanted):
            break
    return found


def _plot_prediction(frames, pred_masks, video_id, expression):
    """Show each frame beside the same frame with its predicted mask overlaid.

    Args:
        frames: [T, 3, H, W] CPU tensor, normalised with the ImageNet mean/std
            below. NOTE(review): confirm this matches the dataset transform.
        pred_masks: [T, H, W] boolean CPU tensor at the frames' resolution.
        video_id: shown in the figure title.
        expression: referring expression shown in the figure title.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    images = np.clip((frames * std + mean).permute(0, 2, 3, 1).numpy(), 0, 1)  # [T, H, W, C]
    preds = pred_masks.numpy()
    num_frames = images.shape[0]

    # squeeze=False keeps `axes` 2-D even for a single frame.
    fig, axes = plt.subplots(num_frames, 2, figsize=(15, 4 * num_frames), squeeze=False)
    for t in range(num_frames):
        axes[t, 0].imshow(images[t])
        axes[t, 0].set_title(f"Frame {t}")
        axes[t, 0].axis('off')

        axes[t, 1].imshow(images[t])
        axes[t, 1].imshow(preds[t], alpha=0.5, cmap='cool')
        axes[t, 1].set_title("Prediction")
        axes[t, 1].axis('off')

    plt.suptitle(f"Video: {video_id}, Expression: '{expression}'", fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    plt.show()
    plt.close(fig)


def visualize_specific_videos(model, dataloader, device, video_names):
    """Visualize model predictions on specific videos.

    For the first sample of each video id in ``video_names`` found in
    ``dataloader``, runs ``model`` and plots every frame next to its
    thresholded (sigmoid > 0.5) predicted mask.
    """
    model.eval()
    for video_id, batch in _collect_video_batches(dataloader, video_names).items():
        print(f"\nProcessing video: {video_id}")

        frames = batch['frames'].to(device)
        text_ids = batch['text_ids'].to(device)
        text_mask = batch['text_mask'].to(device)

        with torch.no_grad():
            logits = model(frames, text_ids, text_mask)  # [1, T, 1, h, w]
            # The model upsamples by a fixed factor, so its output grid need not
            # equal the frame size. Resample the logits to the frame size so the
            # overlay lines up. This is a no-op when the sizes already match.
            frame_hw = tuple(frames.shape[-2:])
            if tuple(logits.shape[-2:]) != frame_hw:
                n, t, c, h, w = logits.shape
                logits = F.interpolate(
                    logits.reshape(n * t, c, h, w),
                    size=frame_hw,
                    mode='bilinear',
                    align_corners=False,
                ).reshape(n, t, c, *frame_hw)
            pred_masks = torch.sigmoid(logits) > 0.5

        _plot_prediction(frames[0].cpu(), pred_masks[0, :, 0].cpu(), video_id, batch['expression'][0])
available_videos = list(val_annotations.keys())
print("Available videos in validation set:")
for i, video in enumerate(available_videos, start=1):
    print(f"{i}. {video}")

print("\nSelect videos to visualize (comma-separated numbers, e.g., '1,3,5'):")
selection = input(f"Enter video numbers (1-{len(available_videos)}): ")

# Parse leniently. Blank tokens (e.g. a trailing comma) are skipped. Non-numeric
# and out-of-range entries are reported and skipped rather than raising
# ValueError. Duplicates are dropped so each video is processed once.
selected_videos = []
for token in selection.split(','):
    token = token.strip()
    if not token:
        continue
    if not token.isdecimal():
        print(f"Ignoring invalid entry: {token!r}")
        continue
    idx = int(token) - 1
    if not 0 <= idx < len(available_videos):
        print(f"Ignoring out-of-range entry: {token}")
        continue
    if available_videos[idx] not in selected_videos:
        selected_videos.append(available_videos[idx])

print(f"\nSelected videos: {selected_videos}")
if selected_videos:
    visualize_specific_videos(model, val_loader, device, selected_videos)
else:
    print("No valid videos selected; nothing to visualize.")