Before running this notebook, download the dataset from https://www.kaggle.com/datasets/pranithchowdary/icpr-2022?resource=download-directory and extract its train images into `./charteye/train`.

In [9]:
import numpy as np
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder
from PIL import Image

train_path = "./charteye/train"
save_folder = "./dataset"
os.makedirs(save_folder, exist_ok=True)

In [4]:
# Constants
GENERAL_CLASS_SAMPLE_SIZE = 30000
CHART_CLASS = 1
GENERAL_CLASS = 0

# 1. Load PACS, DomainNet, and Charteye datasets
print("Loading datasets...")
pacs_dataset = load_dataset("flwrlabs/pacs", split="train")['image']
domainnet_dataset = load_dataset("wltjr1007/DomainNet", split="train")
charteye_dataset = ImageFolder("./charteye/train")

# Determine sizes for sampling
domainnet_sample_size = GENERAL_CLASS_SAMPLE_SIZE - len(pacs_dataset)

# 2. Stratified sampling from DomainNet to get balanced final classes
print(f"Sampling {domainnet_sample_size} examples from DomainNet...")
sampled_indices, _ = train_test_split(
    range(len(domainnet_dataset)),
    train_size=domainnet_sample_size,
    stratify=domainnet_dataset['domain'],
    random_state=42
)
domainnet_sampled = domainnet_dataset[sorted(sampled_indices)]['image']

Loading datasets...
Sampling 20009 examples from DomainNet...


In [12]:
print("Saving ChartEye dataset...")
save_dir = os.path.join(save_folder, str(CHART_CLASS))
os.makedirs(save_dir, exist_ok=True)
for img, _ in charteye_dataset:
    img.save(os.path.join(save_dir, str(len(os.listdir(save_dir)))+'.jpg'))

print("Saving DomainNet dataset...")
save_dir = os.path.join(save_folder, str(GENERAL_CLASS))
os.makedirs(save_dir, exist_ok=True)
for img in domainnet_sampled:
    img.save(os.path.join(save_dir, str(len(os.listdir(save_dir)))+'.jpg'))

print("Saving PACS dataset...")
for img in pacs_dataset:
    img.save(os.path.join(save_dir, str(len(os.listdir(save_dir)))+'.jpg'))

Saving ChartEye dataset...
Saving DomainNet dataset...
Saving PACS dataset...
