In [6]:
from datasets import load_from_disk

dataset = load_from_disk("/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/")

In [119]:
response = ""
for i in range(len(dataset["train"][0]["conversations"])):
    for j in range(len(dataset["train"][0]["conversations"][i])):
        if j % 2 == 1:
            response += dataset["train"][0]["conversations"][i][j] + "\n"
        else:
            response += dataset["train"][0]["conversations"][i][j]
print(temp)

What object is in row 2, column 2?A: bird
What object is in row 1, column 2?A: cat
What object is in row 0, column 0?A: deer
What object is in row 0, column 1?A: cat
What object is in row 1, column 0?A: deer
What object is in row 1, column 1?A: cat
What object is in row 0, column 2?A: bird
What object is in row 2, column 1?A: dog
What object is in row 2, column 0?A: bird



In [120]:
temp

'What object is in row 2, column 2?A: bird\nWhat object is in row 1, column 2?A: cat\nWhat object is in row 0, column 0?A: deer\nWhat object is in row 0, column 1?A: cat\nWhat object is in row 1, column 0?A: deer\nWhat object is in row 1, column 1?A: cat\nWhat object is in row 0, column 2?A: bird\nWhat object is in row 2, column 1?A: dog\nWhat object is in row 2, column 0?A: bird\n'

In [121]:
from typing import Dict, List

from datasets import load_from_disk
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import random

from lmm_synthetic.data.convert_to_multimodal import parse_grid_from_text

def find_text(text, char, index):
    count = 0
    for i in range(len(text)):
        if text[i] == char:
            count += 1
            if count == index:
                return i 


class LazySupervisedDataset(Dataset):
    """Dataset for multimodal supervised fine-tuning

    Args:
        data_path (str): Path to the dataset.
        split (str): Dataset split (e.g., 'train', 'test').
        max_data_size (int, optional): Maximum number of data samples to load. Defaults to -1 (load all).
        vision_token_ablation (bool, optional): Whether to perform vision token ablation. Defaults to False.
        debug (bool, optional): Whether to enable debug mode. Defaults to False.
        alignment (bool, optional): Whether to concatenate everything into reponse for alignment training
        image_grid (bool, optional): Whether to have dataset only include image and text grid
        sub_sampling (bool, optional): Whether to subsample the conversations
        num_samples (int, optional): Number of conversations to subsample
    """

    def __init__(
        self, 
        data_path: str, 
        split: str,
        max_data_size: int = -1,
        vision_token_ablation: bool = False,
        debug: bool = False,
        alignment: bool = False,
        image_grid: bool = False,
        sub_sampling: bool = False,
        num_samples: int = 3
    ) -> None:
        super(LazySupervisedDataset, self).__init__()
        self.debug = debug
        self.vision_token_ablation = vision_token_ablation

        # Load the dataset from disk
        hf_dataset = load_from_disk(data_path)[split]
        self.list_data_dict = []

        # Image grid is already suited for alignment training
        if image_grid == True:
            for sample in hf_dataset:
                prompt = sample.get("prompt", "")
                grid_index = find_text(sample.get("text", ""), "\n", 3)
                grid = sample.get("text", "")[0:grid_index]
                conversations = [["", grid]]
                data_dict = {
                    "image": sample.get("image", None),
                    "prompt": prompt,
                    "conversations": conversations
                }
                if self.debug:
                    data_dict["text"] = sample.get("text", "")
                if self.vision_token_ablation:
                    data_dict["grid"] = sample.get('grid', parse_grid_from_text(sample['text']))
                self.list_data_dict.append(data_dict)

        else:
            for sample in hf_dataset:
                prompt = sample.get("prompt", "")
                if alignment == True:
                    if sub_sampling == True:
                        conversations = []
                        response = ""
                        shuffled = random.sample(sample["conversations"], num_samples)
                        for entry in random.sample(sample["conversations"], num_samples):
                            for subentry in entry:
                                response += "" + subentry
                        response = ""
                        for i in range(len(shuffled)):
                            for j in range(len(shuffled[i])):
                                if j % 2 == 1:
                                    response += shuffled[i][j] + "\n"
                                else:
                                    response += shuffled[i][j]
                        conversations.append(["", response])
                    else:
                        conversations = []
                        response = ""
                        for i in range(len(sample["conversations"])):
                            for j in range(len(sample["conversations"][i])):
                                if j % 2 == 1:
                                    response += sample["conversations"][i][j] + "\n"
                                else:
                                    response += sample["conversations"][i][j]
                        conversations.append(["", response])
                else:
                    if sub_sampling == True:
                        conversations = random.sample(sample["conversations"], num_samples)
                    else:
                        conversations = sample.get("conversations", [])
                data_dict = {
                    "image": sample.get("image", None),
                    "prompt": prompt,
                    "conversations": conversations
                    }
                if self.debug:
                    data_dict["text"] = sample.get("text", "")
                if self.vision_token_ablation:
                    data_dict["grid"] = sample.get('grid', parse_grid_from_text(sample['text']))
                self.list_data_dict.append(data_dict)



        # Limit the dataset size if max_data_size is specified
        if max_data_size > 0:
            self.list_data_dict = self.list_data_dict[:max_data_size]

        logger.info(f"Dataset size: {len(self.list_data_dict)}")

        # Determine whether each sample is text-only
        self.is_text_only = [
            "image" not in source for source in self.list_data_dict
        ]

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, List]:
        """Retrieves the sample at index `i`.

        Args:
            i (int): Index of the sample to retrieve.

        Returns:
            Dict[str, List]: A dictionary containing the sample data.
        """
        sample = self.list_data_dict[i]
        item_dict = {
            "image": Image.open(sample["image"]).convert("RGB"),
            "prompt": sample["prompt"],
            "conversations": sample["conversations"]
        }
        if self.debug:
            item_dict["text"] = sample["text"]
        if self.vision_token_ablation:
            item_dict["grid"] = sample["grid"]
        return item_dict

In [122]:
original = LazySupervisedDataset(r"/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/", "train", 10, False, False)
image_grid = LazySupervisedDataset(r"/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/", "train", 10, False, False, False, True)
sub_sample = LazySupervisedDataset(r"/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/", "train", 10, False, False, False, False, True, 3)

[32m2025-01-22 13:51:06.155[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m125[0m - [1mDataset size: 10[0m
[32m2025-01-22 13:51:12.690[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m125[0m - [1mDataset size: 10[0m
[32m2025-01-22 13:51:19.694[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m125[0m - [1mDataset size: 10[0m


In [123]:
original[0]

{'image': <PIL.Image.Image image mode=RGB size=256x256>,
 'prompt': "The grid above is size 3 by 3. Each cell contains an object from ['deer', 'bird', 'dog', 'cat'].",
 'conversations': [['What object is in row 2, column 2?', 'A: bird'],
  ['What object is in row 1, column 2?', 'A: cat'],
  ['What object is in row 0, column 0?', 'A: deer'],
  ['What object is in row 0, column 1?', 'A: cat'],
  ['What object is in row 1, column 0?', 'A: deer'],
  ['What object is in row 1, column 1?', 'A: cat'],
  ['What object is in row 0, column 2?', 'A: bird'],
  ['What object is in row 2, column 1?', 'A: dog'],
  ['What object is in row 2, column 0?', 'A: bird']]}

In [124]:
image_grid[0]

{'image': <PIL.Image.Image image mode=RGB size=256x256>,
 'prompt': "The grid above is size 3 by 3. Each cell contains an object from ['deer', 'bird', 'dog', 'cat'].",
 'conversations': ['',
  '| deer | cat | bird |\n| deer | cat | cat |\n| bird | dog | bird |']}

In [127]:
sub_sample[0]

{'image': <PIL.Image.Image image mode=RGB size=256x256>,
 'prompt': "The grid above is size 3 by 3. Each cell contains an object from ['deer', 'bird', 'dog', 'cat'].",
 'conversations': [['What object is in row 1, column 0?', 'A: deer'],
  ['What object is in row 2, column 0?', 'A: bird'],
  ['What object is in row 0, column 1?', 'A: cat']]}

In [128]:
original_alignment = LazySupervisedDataset(r"/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/", "train", 10, False, False, True)
image_grid_alignment = LazySupervisedDataset(r"/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/", "train", 10, False, False, True, False, True)
sub_sample_alignment = LazySupervisedDataset(r"/home/allanz/data/datasets/v3.1_spatial_grid_multimodal/", "train", 10, False, False, True, False, False, 3)

[32m2025-01-22 13:52:22.686[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m125[0m - [1mDataset size: 10[0m
[32m2025-01-22 13:52:30.153[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m125[0m - [1mDataset size: 10[0m
[32m2025-01-22 13:52:36.896[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m125[0m - [1mDataset size: 10[0m


In [129]:
original_alignment[0]

{'image': <PIL.Image.Image image mode=RGB size=256x256>,
 'prompt': "The grid above is size 3 by 3. Each cell contains an object from ['deer', 'bird', 'dog', 'cat'].",
 'conversations': [['',
   'What object is in row 2, column 2?A: bird\nWhat object is in row 1, column 2?A: cat\nWhat object is in row 0, column 0?A: deer\nWhat object is in row 0, column 1?A: cat\nWhat object is in row 1, column 0?A: deer\nWhat object is in row 1, column 1?A: cat\nWhat object is in row 0, column 2?A: bird\nWhat object is in row 2, column 1?A: dog\nWhat object is in row 2, column 0?A: bird\n']]}

In [130]:
image_grid_alignment[0]

{'image': <PIL.Image.Image image mode=RGB size=256x256>,
 'prompt': "The grid above is size 3 by 3. Each cell contains an object from ['deer', 'bird', 'dog', 'cat'].",
 'conversations': [['',
   'What object is in row 1, column 1?A: cat\nWhat object is in row 0, column 1?A: cat\nWhat object is in row 2, column 0?A: bird\n']]}

In [131]:
sub_sample_alignment[0]

{'image': <PIL.Image.Image image mode=RGB size=256x256>,
 'prompt': "The grid above is size 3 by 3. Each cell contains an object from ['deer', 'bird', 'dog', 'cat'].",
 'conversations': [['',
   'What object is in row 2, column 2?A: bird\nWhat object is in row 1, column 2?A: cat\nWhat object is in row 0, column 0?A: deer\nWhat object is in row 0, column 1?A: cat\nWhat object is in row 1, column 0?A: deer\nWhat object is in row 1, column 1?A: cat\nWhat object is in row 0, column 2?A: bird\nWhat object is in row 2, column 1?A: dog\nWhat object is in row 2, column 0?A: bird\n']]}