<a href="https://colab.research.google.com/github/zhuzihan728/Image-Restore/blob/main/image_gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
from PIL import Image
import random
import json

def listdir(path, list_name):
    """Recursively collect every file path under ``path`` into ``list_name``.

    Sub-directories are descended into; anything that is not a directory is
    appended as ``os.path.join(parent, entry)``. The list is modified in place
    and sorted once at the end, so entries that were already in ``list_name``
    before the call take part in the final ordering as well.

    :param path: directory to scan
    :param list_name: list that receives the file paths (mutated in place)
    :return: None
    """
    def _collect(directory):
        # Per-directory ordering is irrelevant because the complete list is
        # sorted after the walk has finished.
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            if os.path.isdir(entry_path):
                _collect(entry_path)
            else:
                list_name.append(entry_path)

    _collect(path)
    # Sort once for the whole tree instead of once per recursion level.
    list_name.sort()


class ImageFolderPreprocess:
    """Build a corrupted evaluation set by blending mask images over originals.

    Indexing the object (``dataset[idx]``) does not return anything: it writes
    one corrupted PNG per configured alpha range into ``<output_dir>/corrupted``
    and appends one metadata record per written file to ``self.metadata``.
    Call :meth:`save_metadata` afterwards to persist those records.
    """

    def __init__(self, root, root_bg, output_dir, mask_scale_range=(1.0, 1.2),
                 alpha_ranges=((0.5, 0.8), (0.8, 1.0))):
        """
        :param root: path to original images
        :param root_bg: path to background/mask images
        :param output_dir: where to save preprocessed images
        :param mask_scale_range: (min, max) scale factor for mask relative to image size;
            presumably expected to be >= 1.0 so that the resized mask covers the
            whole image (smaller masks are not cropped or padded here)
        :param alpha_ranges: sequence of (min, max) blending-strength intervals; one
            corrupted image is written per interval for every source image
        """
        self.img_root = root
        self.bg_root = root_bg
        self.output_dir = output_dir
        self.mask_scale_range = mask_scale_range
        self.frame, self.bg_frame = self._parse_frame()
        print(f"Total images: {len(self.frame)}")
        print(f"Total background images: {len(self.bg_frame)}")
        self.metadata = []

        # Copy into a fresh list so instances never share (or mutate) the
        # default value or the caller's sequence.
        self.alpha_ranges = list(alpha_ranges)
        os.makedirs(os.path.join(output_dir, 'corrupted'), exist_ok=True)

    def _parse_frame(self):
        """Return the sorted image file names found in the image and mask folders.

        Only direct children whose name ends in .jpg/.jpeg/.png (case-insensitive)
        are kept.

        :return: tuple ``(img_names, bg_names)`` of sorted file-name lists
        """
        extensions = ('.jpg', '.jpeg', '.png')
        img_names = sorted(
            fname for fname in os.listdir(self.img_root)
            if fname.lower().endswith(extensions)
        )
        bg_names = sorted(
            fname for fname in os.listdir(self.bg_root)
            if fname.lower().endswith(extensions)
        )
        return img_names, bg_names

    def __len__(self):
        """Number of original images that will be processed."""
        return len(self.frame)

    def _prepare_mask(self, img_forbg, w, h, scale_factor, crop_x, crop_y):
        """Resize the mask by ``scale_factor`` and crop it back to ``(w, h)``.

        :param img_forbg: PIL image used as the mask
        :param w: target width in pixels
        :param h: target height in pixels
        :param scale_factor: mask size relative to the target size
        :param crop_x: horizontal crop position as a fraction of the available slack
        :param crop_y: vertical crop position as a fraction of the available slack
        :return: float32 array holding the resized (and possibly cropped) mask
        """
        mask_w = int(w * scale_factor)
        mask_h = int(h * scale_factor)
        resized = img_forbg.resize((mask_w, mask_h), Image.Resampling.LANCZOS)
        if scale_factor > 1.0:
            # Slack between the enlarged mask and the target; the fractional
            # crop position selects an offset inside that slack.
            max_x = mask_w - w
            max_y = mask_h - h
            left = int(max_x * crop_x) if max_x > 0 else 0
            top = int(max_y * crop_y) if max_y > 0 else 0
            resized = resized.crop((left, top, left + w, top + h))
        return np.array(resized).astype(np.float32)

    @staticmethod
    def _apply_strength(mask_values, strength):
        """Blend mask values towards 255 and return them as a uint8 alpha channel.

        ``strength == 1`` keeps the mask values, ``strength == 0`` yields 255
        everywhere.
        """
        blended = mask_values * strength + 255 * (1 - strength)
        return blended.astype(np.uint8)

    def create_mask(self, img_forbg, w, h, scale_factor, crop_x, crop_y, strength):
        """Build the uint8 alpha channel for one image/mask/strength combination.

        :param img_forbg: PIL image used as the mask
        :param w: target width in pixels
        :param h: target height in pixels
        :param scale_factor: mask size relative to the target size
        :param crop_x: horizontal crop position as a fraction of the available slack
        :param crop_y: vertical crop position as a fraction of the available slack
        :param strength: blending strength; lower values push the alpha towards 255
        :return: uint8 array used as the alpha channel
        """
        mask_values = self._prepare_mask(img_forbg, w, h, scale_factor, crop_x, crop_y)
        return self._apply_strength(mask_values, strength)

    def __getitem__(self, idx):
        """Create and save the corrupted variants of image ``idx``.

        One randomly chosen mask, scale factor and crop position are shared by
        all alpha ranges; only the blending strength is drawn per range.

        :param idx: index into the sorted list of original images
        :return: None (files are written and ``self.metadata`` is extended)
        :raises ValueError: if the background folder contains no usable images
        """
        img_name = self.frame[idx]
        if not self.bg_frame:
            raise ValueError(f"No background/mask images found in {self.bg_root!r}")

        # convert() returns an independent copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(os.path.join(self.img_root, img_name)) as src:
            img = src.convert('RGB')

        random_idx = random.randint(0, len(self.bg_frame) - 1)
        mask_name = self.bg_frame[random_idx]
        with Image.open(os.path.join(self.bg_root, mask_name)) as src:
            img_forbg = src.convert('L')

        w, h = img.size

        # Random scale factor for mask
        scale_factor = random.uniform(self.mask_scale_range[0], self.mask_scale_range[1])

        # Crop position (only meaningful if the mask is larger than the image)
        crop_x, crop_y = 0, 0
        if scale_factor > 1.0:
            crop_x = random.random()
            crop_y = random.random()

        # The resized/cropped mask does not depend on the blending strength,
        # so it is computed once and reused for every alpha range.
        mask_values = self._prepare_mask(img_forbg, w, h, scale_factor, crop_x, crop_y)

        # Blending happens at the ORIGINAL image size; only the alpha channel
        # of this array is overwritten between iterations.
        rgba = np.array(img.convert("RGBA"))

        for alpha_range in self.alpha_ranges:
            # Random strength for blending
            strength = random.uniform(alpha_range[0], alpha_range[1])
            rgba[..., 3] = self._apply_strength(mask_values, strength)

            img_rgb = Image.fromarray(self.rgba_to_rgb(rgba))

            # Save only corrupted image
            print(f"Created corrupted image {idx}, alpha: {alpha_range}")
            save_name = f"{idx:05d}_alpha{str(alpha_range[0]).replace('.', '')}-{str(alpha_range[1]).replace('.', '')}.png"
            img_rgb.save(os.path.join(self.output_dir, 'corrupted', save_name))

            # Store metadata with original and mask image names
            meta = {
                "id": idx,
                "corrupted_image": save_name,
                "original_image": img_name,
                "mask_image": mask_name,
                "alpha_range": alpha_range,
                "strength": float(strength),
                "mask_scale": float(scale_factor),
                "crop_position": [float(crop_x), float(crop_y)],
                "image_size": [int(w), int(h)]
            }
            self.metadata.append(meta)

    def rgba_to_rgb(self, rgba_image):
        """Convert RGBA to RGB with white background

        :param rgba_image: array of shape (height, width, 4) with values in 0..255
        :return: uint8 array of shape (height, width, 3)
        """
        white = np.full((rgba_image.shape[0], rgba_image.shape[1], 3), 255, dtype=np.uint8)

        # Keep the trailing axis so alpha broadcasts over the colour channels.
        alpha = rgba_image[:, :, 3:4] / 255.0
        return (rgba_image[:, :, :3] * alpha + white * (1 - alpha)).astype(np.uint8)

    def save_metadata(self):
        """Save metadata to JSON file"""
        json_path = os.path.join(self.output_dir, 'metadata.json')
        with open(json_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)
        print(f"Metadata saved to {json_path}")


# Usage:
# dataset = ImageFolderPreprocess(
#     root='images/',
#     root_bg='backgrounds/',
#     output_dir='eval_dataset/',
#     mask_scale_range=(1.0, 1.5)
# )

# for i in range(len(dataset)):
#     dataset[i]
#     if i % 100 == 0:
#         print(f"Processed {i}/{len(dataset)} images")

# dataset.save_metadata()

In [None]:
class EvalDataset:
    """Paired (corrupted, original) images described by a metadata JSON file."""

    def __init__(self, corrupted_dir, original_dir, mask_dir, metadata_path,
                 im_size=None, transform=None):
        """
        :param corrupted_dir: path to corrupted images folder
        :param original_dir: path to original images folder
        :param mask_dir: path to mask images folder (stored, not read by this class)
        :param metadata_path: path to metadata.json
        :param im_size: target size handed unchanged to ``Image.resize``, i.e.
            (width, height), or None to keep original
        :param transform: optional callable applied to both images after the
            optional resize
        """
        self.corrupted_dir = corrupted_dir
        self.original_dir = original_dir
        self.mask_dir = mask_dir
        self.im_size = im_size
        self.transform = transform

        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

    def __len__(self):
        """Number of metadata records (one per corrupted image)."""
        return len(self.metadata)

    @staticmethod
    def _load_image(path):
        """Open ``path``, read the pixel data and release the file handle.

        Pillow opens files lazily; loading inside the context manager means the
        returned image stays usable while the underlying file is closed.
        """
        with Image.open(path) as img:
            img.load()
        return img

    def __getitem__(self, idx):
        """Return ``(corrupted, original, meta)`` for record ``idx``.

        :param idx: index into the metadata list
        :return: the corrupted image, the matching original image (both resized
            and transformed if configured) and the raw metadata record
        """
        meta = self.metadata[idx]

        # Load images using names from metadata
        corrupted = self._load_image(os.path.join(self.corrupted_dir, meta['corrupted_image']))
        original = self._load_image(os.path.join(self.original_dir, meta['original_image']))

        # Resize if needed
        if self.im_size:
            corrupted = corrupted.resize(self.im_size, Image.Resampling.LANCZOS)
            original = original.resize(self.im_size, Image.Resampling.LANCZOS)

        if self.transform:
            corrupted = self.transform(corrupted)
            original = self.transform(original)

        return corrupted, original, meta

# Usage in evaluation:
# eval_dataset = EvalDataset(
#     corrupted_dir='eval_dataset/corrupted/',
#     original_dir='images/',  # Original folder
#     mask_dir='backgrounds/',  # Mask folder
#     metadata_path='eval_dataset/metadata.json',
#     im_size=(256, 256)
# )

In [None]:
import shutil

all_images_ls = []
train_images_ls = []
test_images_ls = []

listdir('/content/drive/MyDrive/image/', all_images_ls)

num_images = len(all_images_ls)
print(f"Total images: {num_images}")

# Hold out 100 images for testing (or every image when fewer are available).
# The counts are computed with integers: deriving the train size from a float
# ratio can lose one image to rounding and divides by zero on an empty folder.
num_test = min(100, num_images)
num_train = num_images - num_test
ratio = num_train / num_images if num_images else 0.0

train_indices = random.sample(range(num_images), num_train)
# Set membership keeps the complement computation linear in the image count.
train_index_set = set(train_indices)
test_indices = [i for i in range(num_images) if i not in train_index_set]
print(f"Train images: {len(train_indices)}")
print(f"Test images: {len(test_indices)}")

assert len(train_indices) + len(test_indices) == len(all_images_ls)



Total images: 750
Train images: 650
Test images: 100


In [None]:
# Rebuild the path lists from the indices so that re-running this cell does
# not append the same paths a second time.
train_images_ls = [all_images_ls[i] for i in train_indices]
test_images_ls = [all_images_ls[i] for i in test_indices]

train_image_folder = '/content/drive/MyDrive/image_train/'
test_image_folder = '/content/drive/MyDrive/image_test/'

# save train and test images to two folders
os.makedirs(train_image_folder, exist_ok=True)
os.makedirs(test_image_folder, exist_ok=True)

# NOTE: files already present in the target folders are left in place. Clear
# the folders before copying a different random split, otherwise images from
# the earlier split stay mixed in.
for image_path in train_images_ls:
    shutil.copy(image_path, train_image_folder)

for image_path in test_images_ls:
    shutil.copy(image_path, test_image_folder)


In [None]:
from IPython.testing import test
train_copy_ls = []
test_copy_ls = []

listdir('/content/drive/MyDrive/image_train/', train_copy_ls)
listdir('/content/drive/MyDrive/image_test/', test_copy_ls)

all_image_file_name = [os.path.basename(i) for i in all_images_ls]
train_image_file_name = [os.path.basename(i) for i in train_copy_ls]
test_image_file_name = [os.path.basename(i) for i in test_copy_ls]

# Compare by file name using sets: each check is a single set operation
# instead of a list scan per copied file.
all_name_set = set(all_image_file_name)
train_name_set = set(train_image_file_name)
test_name_set = set(test_image_file_name)

# verify train and test are exclusive, and are in all_images_ls
assert train_name_set <= all_name_set
assert test_name_set <= all_name_set
assert train_name_set.isdisjoint(test_name_set)

# Last expression of the cell: the notebook displays both counts.
len(train_copy_ls), len(test_copy_ls)

(650, 100)

In [None]:
# Step 1: make Google Drive available under /content/drive
from google.colab import drive
drive.mount('/content/drive')

# Step 2: configure the generator so that every output lands directly on Drive
preprocess_config = {
    'root': '/content/drive/MyDrive/image_test/',
    'root_bg': '/content/drive/MyDrive/mask/',
    'output_dir': '/content/drive/MyDrive/eval_dataset/',  # Direct to Drive
    'mask_scale_range': (1.0, 1.2),
    'alpha_ranges': [(0.1, 0.5), (0.5, 0.8), (0.8, 1.0)],
}
dataset = ImageFolderPreprocess(**preprocess_config)

# Step 3: indexing the dataset writes the corrupted variants of that image;
# a progress line is printed for every 100th index.
total_images = len(dataset)
for i in range(total_images):
    dataset[i]
    if i % 100 != 0:
        continue
    print(f"Processed {i}/{total_images} images")

# Step 4: persist the collected metadata records
dataset.save_metadata()
print("✓ All files saved to Google Drive!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Total images: 100
Total background images: 750
Created corrupted image 0, alpha: (0.1, 0.5)
Created corrupted image 0, alpha: (0.5, 0.8)
Created corrupted image 0, alpha: (0.8, 1.0)
Processed 0/100 images
Created corrupted image 1, alpha: (0.1, 0.5)
Created corrupted image 1, alpha: (0.5, 0.8)
Created corrupted image 1, alpha: (0.8, 1.0)
Created corrupted image 2, alpha: (0.1, 0.5)
Created corrupted image 2, alpha: (0.5, 0.8)
Created corrupted image 2, alpha: (0.8, 1.0)
Created corrupted image 3, alpha: (0.1, 0.5)
Created corrupted image 3, alpha: (0.5, 0.8)
Created corrupted image 3, alpha: (0.8, 1.0)
Created corrupted image 4, alpha: (0.1, 0.5)
Created corrupted image 4, alpha: (0.5, 0.8)
Created corrupted image 4, alpha: (0.8, 1.0)
Created corrupted image 5, alpha: (0.1, 0.5)
Created corrupted image 5, alpha: (0.5, 0.8)
Created corrupted image 5, alpha: (0