### Adding a domain field to both AI2D and ScienceQA

In [None]:
import json
import os

# File paths
ai2d_in = "ai2d/ai2d_educational_captions_filtered.json"  # .json format
scienceqa_in = "scienceqa/llava_image_captions_only.jsonl"     # .jsonl format

ai2d_out = "ai2d_labeled.json"
scienceqa_out = "scienceqa_labeled.json"

# Processing AI2D JSON: the input is one JSON list of records with
# "image" and "caption" keys.
ai2d_result = []
with open(ai2d_in, "r", encoding="utf-8") as f_in:
    ai2d_data = json.load(f_in)

for item in ai2d_data:
    # A missing or null caption is treated like an empty one and skipped,
    # matching how the ScienceQA branch below handles absent captions.
    caption = (item.get("caption") or "").strip()
    if caption:
        ai2d_result.append({
            "image_path": os.path.join("ai2d", item["image"]),
            "caption": caption,
            "domain": "AI2D"
        })

with open(ai2d_out, "w", encoding="utf-8") as f_out:
    json.dump(ai2d_result, f_out, indent=2)

# Processing ScienceQA JSONL: one JSON object per line with
# "image_id" and "llava_caption" keys.
scienceqa_result = []
with open(scienceqa_in, "r", encoding="utf-8") as f_in:
    for line in f_in:
        # Blank lines (e.g. a trailing newline at EOF) are not valid JSON.
        if not line.strip():
            continue
        item = json.loads(line)
        caption = (item.get("llava_caption") or "").strip()
        if caption:
            scienceqa_result.append({
                "image_path": os.path.join("scienceqa/images", f"{item['image_id']}.png"),
                "caption": caption,
                "domain": "ScienceQA"
            })

with open(scienceqa_out, "w", encoding="utf-8") as f_out:
    json.dump(scienceqa_result, f_out, indent=2)

print("Step 1 complete: JSON files saved.")

Step 1 complete: JSON files saved.


### Combining AI2D and ScienceQA into a single JSONL file

In [None]:
import json

# Step 1 outputs to merge, and the JSONL file the merged records go to
ai2d_file = "ai2d_labeled.json"
scienceqa_file = "scienceqa_labeled.json"
output_file = "combined_dataset.jsonl"

# Each Step 1 file holds a single JSON list; read them in a fixed order
loaded_lists = []
for labeled_path in (ai2d_file, scienceqa_file):
    with open(labeled_path, "r") as f:
        loaded_lists.append(json.load(f))
ai2d_data, scienceqa_data = loaded_lists

# AI2D records come first, followed by the ScienceQA records
combined_data = ai2d_data + scienceqa_data

# Emit one JSON object per line
with open(output_file, "w") as f_out:
    f_out.writelines(json.dumps(entry) + "\n" for entry in combined_data)

print(f"Combined dataset saved to: {output_file}")

Combined dataset saved to: combined_dataset.jsonl


###  Defining the Domain Map


In [None]:
# Integer label for each source dataset, assigned in listing order:
# "AI2D" -> 0, "ScienceQA" -> 1.
domain_dict = {
    domain_name: label
    for label, domain_name in enumerate(("AI2D", "ScienceQA"))
}

### Defining Image Transformations (PyTorch)

This sets a standard transformation for all your images:
	•	Resize all to 256x256 (matches VAE input)
	•	Convert image to PyTorch Tensor (scales pixel values from 0–255 → 0–1)


In [None]:
from torchvision import transforms

# Shared preprocessing pipeline applied to every diagram image.
_preprocessing_steps = [
    transforms.Resize((256, 256)),  # force a fixed 256x256 spatial size
    transforms.ToTensor(),          # PIL image -> float tensor scaled to [0, 1]
]
image_transform = transforms.Compose(_preprocessing_steps)

  from .autonotebook import tqdm as notebook_tqdm


###  Implementing Text Tokenizer and Embedder

In [None]:
import torch

	•	Loads the pretrained tokenizer and text encoder.
	•	tokenizer turns your caption into token IDs.
	•	text_encoder converts token IDs into embeddings.

In [None]:
# Tokenizer initialization: load the pretrained CLIP tokenizer that
# embed_text uses to turn a caption string into model inputs.

from transformers import CLIPTokenizer, CLIPTextModel

# Same checkpoint id as the text encoder loaded in the following cell;
# keep the two in sync if the checkpoint is ever changed.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# Text encoder initialization: load the pretrained CLIP text model that
# embed_text runs on the tokenized caption.
# NOTE(review): the recorded run further down shows 512-dimensional text
# embeddings for this checkpoint (not 768).
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

  return self.fget.__get__(instance, owner)()


	•	Input: text – a single string (caption like “Photosynthesis diagram”).
	•	Step 1: Tokenizes it and turns it into tensors (input_ids, attention_mask).
	•	Step 2: Passes it through the CLIP text encoder.
	•	Step 3: Takes the mean over the token embeddings to get a single [512]-dim vector (the size shown in the recorded output for this checkpoint).
	•	Returns: A vector that represents the text meaning in CLIP’s latent space.

In [None]:
## Embedding Function
def embed_text(text):
    """Embed caption text with the CLIP text encoder using mean pooling.

    The pooled vector is the average of the encoder's last hidden states over
    the real (non-padding) tokens only, so captions embedded together in a
    batch get the same vector they would get when embedded one at a time.

    Args:
        text: A single caption string, or a list of caption strings.

    Returns:
        For a single string, a 1-D tensor of size ``hidden_dim``.
        For a list of strings, a ``[len(text), hidden_dim]`` tensor.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        hidden_states = text_encoder(**inputs).last_hidden_state

    # Zero out padded positions before averaging; for a single caption the
    # mask is all ones and this reduces to a plain mean over tokens.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    embedding = (hidden_states * mask).sum(dim=1) / token_counts

    # Drop the batch axis only for a single-string input; a one-element list
    # keeps its leading batch dimension.
    if isinstance(text, str):
        return embedding.squeeze(0)
    return embedding

### Unified Dataset Loader

You are creating a custom PyTorch Dataset class to:
	1.	Load and transform educational diagram images.
	2.	Convert captions into CLIP embeddings.
	3.	Encode the data source (AI2D or ScienceQA) for multi-domain handling.

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import json
import torch
import os

class EducationalDiagramDataset(Dataset):
    """Image / caption-embedding / domain-label dataset backed by a JSONL file.

    Each non-blank line of ``metadata_path`` must be a JSON object with the
    keys ``image_path``, ``caption`` and ``domain``. Records whose image path
    is not an existing regular file are dropped at construction time and
    counted in ``skipped_count``.

    Args:
        metadata_path: Path to the JSONL metadata file.
        text_embedder: Callable applied to the caption string in ``__getitem__``.
        domain_dict: Mapping from a record's ``domain`` value to the label
            returned by ``__getitem__``.
        transform: Callable applied to the RGB PIL image in ``__getitem__``.
    """

    def __init__(self, metadata_path, text_embedder, domain_dict, transform):
        with open(metadata_path, "r", encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            samples = [json.loads(line) for line in f if line.strip()]

        self.samples = []
        self.skipped_count = 0
        self.text_embedder = text_embedder
        self.domain_dict = domain_dict
        self.transform = transform

        # Filtering out missing image files during initialization.
        # isfile() also rejects directories, which could never be opened as images.
        for sample in samples:
            if os.path.isfile(sample["image_path"]):
                self.samples.append(sample)
            else:
                self.skipped_count += 1

        print(f"Loaded {len(self.samples)} valid samples")
        print(f"Skipped {self.skipped_count} missing or invalid image files")

    def __getitem__(self, idx):
        """Return ``(transformed_image, text_embedding, domain_idx)`` for ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range for the retained samples.
            KeyError: If the record's ``domain`` is not a key of ``domain_dict``.
        """
        sample = self.samples[idx]

        # The context manager closes the file handle as soon as the pixel data
        # has been copied into the converted RGB image.
        with Image.open(sample["image_path"]) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image)

        caption = sample["caption"]
        text_embedding = self.text_embedder(caption)

        domain_idx = self.domain_dict[sample["domain"]]

        return image, text_embedding, domain_idx

    def __len__(self):
        """Number of samples whose image file was found at construction time."""
        return len(self.samples)

In [None]:
from torch.utils.data import DataLoader

# Build the dataset from the combined metadata, reusing the notebook-level
# domain_dict so the label mapping is defined in exactly one place.
dataset = EducationalDiagramDataset(
    metadata_path="combined_dataset.jsonl",
    text_embedder=embed_text,
    domain_dict=domain_dict,
    transform=image_transform
)

dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Sanity check: print the shapes of one shuffled batch, then stop.
for images, text_embeddings, domain_labels in dataloader:
    print(images.shape)          # [8, 3, 256, 256]
    print(text_embeddings.shape) # [8, 512] in the recorded run
    print(domain_labels)         # tensor([0, 1, 0, ...])
    break

Loaded 9573 valid samples
Skipped 468 missing or invalid image files
torch.Size([8, 3, 256, 256])
torch.Size([8, 512])
tensor([1, 1, 0, 1, 0, 1, 1, 0])


### Saving the file

You are converting your full dataset into 3 tensors:
	1.	Images [N, 3, 256, 256]
	2.	Text Embeddings [N, 512] (from CLIP)
	3.	Labels [N] (0 for AI2D, 1 for ScienceQA)

Then, you’re saving it as a .pt file (educational_diagram_data.pt) — this makes it quick to load later during VAE or LDM training without recomputing embeddings.

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Rebuild the dataset, reusing the notebook-level domain_dict so the saved
# labels follow the same mapping as the rest of the notebook.
dataset = EducationalDiagramDataset(
    metadata_path="combined_dataset.jsonl",
    text_embedder=embed_text,
    domain_dict=domain_dict,
    transform=image_transform
)

# batch_size=1 with shuffle=False visits every sample exactly once, in
# metadata order, so row i of each saved tensor refers to the same sample.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Collecting all into lists
all_images = []
all_text_embeddings = []
all_domain_labels = []

print("Saving embeddings...")

for image, text_emb, domain in tqdm(dataloader):
    # Each item arrives with a leading batch axis of size 1; strip it.
    all_images.append(image.squeeze(0))              # shape: [3, 256, 256]
    all_text_embeddings.append(text_emb.squeeze(0))  # shape: [512] in the recorded run
    all_domain_labels.append(domain.item())          # single int

# torch.stack rejects an empty list with an opaque message; fail early with
# an actionable one instead.
if not all_images:
    raise RuntimeError(
        "No valid samples were loaded from 'combined_dataset.jsonl'; nothing to save."
    )

# Converting to tensors.
# NOTE: the per-sample lists and the stacked tensors coexist while stacking,
# so peak memory is roughly twice the size of the final tensors.
all_images = torch.stack(all_images)
all_text_embeddings = torch.stack(all_text_embeddings)
all_domain_labels = torch.tensor(all_domain_labels)

# Saving to .pt file
torch.save({
    "images": all_images,
    "text_embeddings": all_text_embeddings,
    "labels": all_domain_labels
}, "educational_diagram_data.pt")

print("Saved all tensors to 'educational_diagram_data.pt'")

Loaded 9573 valid samples
Skipped 468 missing or invalid image files
Saving embeddings...


100%|██████████| 9573/9573 [05:31<00:00, 28.86it/s]


Saved all tensors to 'educational_diagram_data.pt'
