In [1]:
import os
from pathlib import Path
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
import shutil

# Paths
base_dir = Path("/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/processed_data/sentinel/sentinel_sr")

labels_dir = base_dir / "labels"
images_dir = base_dir / "thera_delhi_images"

# Output split folders
split_dir = base_dir / "split_data_swinir"
splits = ["train", "val", "test"]
for split in splits:
    (split_dir / split / "labels").mkdir(parents=True, exist_ok=True)
    (split_dir / split / "images").mkdir(parents=True, exist_ok=True)

# Collect image-label pairs and dominant class
data = []
class_counts = Counter()

for label_file in labels_dir.glob("*.txt"):
    with open(label_file, "r") as f:
        lines = f.readlines()
    
    if not lines:
        continue  # skip empty files

    class_ids = [int(line.split()[0]) for line in lines]
    dominant_class = Counter(class_ids).most_common(1)[0][0]

    image_file = images_dir / label_file.with_suffix(".jpg").name
    if not image_file.exists():
        image_file = image_file.with_suffix(".png")
    if not image_file.exists():
        print(f"Missing image for label: {label_file.name}")
        continue

    data.append((label_file, image_file, dominant_class))
    class_counts.update(class_ids)

print("Overall class counts:", dict(class_counts))

# Stratified split
X = [(str(lf), str(imgf)) for lf, imgf, _ in data]
y = [cls for _, _, cls in data]

X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

splits_data = {"train": X_train, "val": X_val, "test": X_test}

# Copy files and count classes per split
split_class_counts = defaultdict(Counter)

for split, items in splits_data.items():
    for label_path, img_path in items:
        label_file = Path(label_path)
        image_file = Path(img_path)

        # Copy label and image
        shutil.copy(label_file, split_dir / split / "labels" / label_file.name)
        shutil.copy(image_file, split_dir / split / "images" / image_file.name)

        # Count class instances
        with open(label_file, "r") as f:
            for line in f:
                class_id = int(line.split()[0])
                split_class_counts[split][class_id] += 1

# Display class distribution
for split in splits:
    print(f"\nClass distribution in {split}:")
    for class_id, count in split_class_counts[split].items():
        print(f"  Class {class_id}: {count}")


Overall class counts: {2: 5894, 1: 709, 0: 56}

Class distribution in train:
  Class 2: 3533
  Class 1: 419
  Class 0: 37

Class distribution in val:
  Class 2: 1183
  Class 1: 150
  Class 0: 9

Class distribution in test:
  Class 2: 1178
  Class 1: 140
  Class 0: 10
