## augmente for non-mask people

In [None]:
import os
import random
from PIL import Image, ImageEnhance

def random_scale(img, base_size, scale_range=(0.5, 1.5)):
    """
    随机缩放人物，保证不超出原始画布大小。
    """
    base_w, base_h = base_size
    w, h = img.size
    # 最大缩放因子，避免缩放后宽高超出原图
    max_scale = min(scale_range[1], base_w / w, base_h / h)
    if max_scale < scale_range[0]:
        return img  # 无法缩放，则跳过
    scale = random.uniform(scale_range[0], max_scale)
    new_w, new_h = int(w * scale), int(h * scale)
    return img.resize((new_w, new_h), Image.BICUBIC)

def random_rotate(img, base_size, angle_range=(-60, 60)):
    """
    随机旋转人物，并保证旋转后整个人物都能放入原始画布内。

    Args:
        img (PIL.Image): 当前已做完其它变换的人物图，模式为 "RGB"。
        base_size (tuple): 画布尺寸 (width, height)。
        angle_range (tuple): 随机旋转角度范围（度数）。

    Returns:
        PIL.Image: 旋转并（如有必要）等比例缩放后的图像。
    """
    base_w, base_h = base_size
    angle = random.uniform(*angle_range)
    # 1) 完整旋转，expand=True 能保留所有像素
    rotated = img.rotate(
        angle,
        resample=Image.BICUBIC,
        expand=True,
        fillcolor=(0, 0, 0)
    )

    rw, rh = rotated.size
    # 2) 如超出画布，按最小比例缩放
    if rw > base_w or rh > base_h:
        scale = min(base_w / rw, base_h / rh)
        new_w, new_h = int(rw * scale), int(rh * scale)
        rotated = rotated.resize((new_w, new_h), Image.BICUBIC)

    return rotated

def random_brightness(img, brightness_range=(0.5, 1.5)):
    """随机调整亮度。"""
    enhancer = ImageEnhance.Brightness(img)
    factor = random.uniform(*brightness_range)
    return enhancer.enhance(factor)

def random_flip(img):
    """50% 概率水平翻转。"""
    if random.random() < 0.5:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

def random_translate(img, base_size, translate_range=(-50, 50)):
    """
    在黑色画布上随机平移人物，保证不超出边界。
    
    Arguments:
        img (PIL.Image): 当前已变换的人物图。
        base_size (tuple): 原始画布大小 (width, height)。
        translate_range (tuple): 在 x 和 y 方向上允许的相对平移范围（可正可负）。
    """
    base_w, base_h = base_size
    w, h = img.size
    # 绝对偏移区间：确保人物不会跑出画布
    dx_min = max(-w, translate_range[0])
    dx_max = min(base_w - w, translate_range[1])
    dy_min = max(-h, translate_range[0])
    dy_max = min(base_h - h, translate_range[1])
    
    dx = random.randint(dx_min, dx_max)
    dy = random.randint(dy_min, dy_max)
    
    # # 新画布填充黑色
    # canvas = Image.new("RGB", base_size, (0, 0, 0))
    # 贴到新位置
    canvas.paste(img, (dx, dy))
    return canvas

def augment_image(img_path, output_dir, num_aug=5):
    """
    对单张图片执行所有增强，并保存。
    """
    original = Image.open(img_path)
    base_size = original.size
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    
    os.makedirs(output_dir, exist_ok=True)
    for i in range(num_aug):
        aug = original.copy()
        
        aug = random_scale(aug, base_size)
        aug = random_rotate(aug, base_size)
        aug = random_brightness(aug)
        aug = random_flip(aug)
        aug = random_translate(aug, base_size)
        
        out_path = os.path.join(output_dir, f"{base_name}_aug_{i}.png")
        aug.save(out_path, format="PNG")

def main(input_dir, output_dir, num_aug=5):
    """
    批量处理文件夹下所有 PNG 图像。
    """
    files = os.listdir(input_dir)
    for fname in files:
        if fname.lower().endswith('.png'):
            augment_image(
                img_path=os.path.join(input_dir, fname),
                output_dir=output_dir,
                num_aug=num_aug
            )

if __name__ == "__main__":
    # 请修改为你的实际路径和增强次数
    input_directory = '../autodl-tmp/train_data_360_540_with_mask/people'
    output_directory = '..//autodl-tmp/train_data_360_540_with_mask/people_aug'
    augmentations_per_image = 2
    
    main(input_directory, output_directory, augmentations_per_image)


In [None]:
print(len(os.listdir('autodl-tmp/train_data_360_540_augmented/people')))

In [None]:
for name in os.listdir('autodl-tmp/train_data_360_540_selected/people'):
    if name.lower().endswith('.png'):
        pass
    else:
        print(name)

In [None]:
import os
import random
from PIL import Image, ImageEnhance
from tqdm import tqdm

def load_image_and_mask(img_path):
    """Load image and corresponding mask (Alpha channel)"""
    img = Image.open(img_path)
    if img.mode == 'RGBA':
        # Split the image into RGB and Alpha components
        rgb = Image.new("RGB", img.size, (0, 0, 0))
        rgb.paste(img, mask=img.split()[3])  # Use alpha as mask for proper compositing
        return rgb, img.split()[3]  # Return (RGB, Alpha)
    else:
        raise ValueError("Input image must contain Alpha channel!")

def random_scale(img, base_size, scale_range=(0.5, 1.5)):
    """Random scaling with constraints"""
    base_w, base_h = base_size
    w, h = img.size
    
    max_scale = min(scale_range[1], base_w/w, base_h/h)
    if max_scale < scale_range[0]:
        return img
    
    scale = random.uniform(scale_range[0], max_scale)
    new_size = (int(w*scale), int(h*scale))
    return img.resize(new_size, Image.BICUBIC)

def random_rotate(img, angle_range=(-60, 60)):
    """Rotate the image"""
    angle = random.uniform(*angle_range)
    fillcolor = 0 if img.mode == 'L' else (0, 0, 0)
    return img.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=fillcolor)

def random_brightness(img, brightness_range=(0.5, 1.5)):
    """Adjust RGB brightness only"""
    if img.mode == 'L':  # Skip for masks
        return img
    factor = random.uniform(*brightness_range)
    return ImageEnhance.Brightness(img).enhance(factor)

def random_flip(img):
    """Horizontal flip"""
    if random.random() < 0.5:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

def augment_with_mask(img_path, output_dir, num_aug=5):
    """Main function for synchronized mask augmentation"""
    # Load image and extract mask
    rgb, alpha = load_image_and_mask(img_path)
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    
    os.makedirs(output_dir, exist_ok=True)
    
    for i in range(num_aug):
        # Start with copies to avoid modifying originals
        current_rgb = rgb.copy()
        current_alpha = alpha.copy()
        
        # First pass: apply identical geometric transformations to both
        
        # 1. Scale - apply to both with same parameters
        scale_range = (0.7, 1.3)
        scale_factor = random.uniform(*scale_range)
        new_size = (int(current_rgb.width * scale_factor), int(current_rgb.height * scale_factor))
        current_rgb = current_rgb.resize(new_size, Image.BICUBIC)
        current_alpha = current_alpha.resize(new_size, Image.BICUBIC)
        
        # 2. Rotation - use same angle for both
        angle = random.uniform(-30, 30)
        rgb_fill = (0, 0, 0)
        alpha_fill = 0
        current_rgb = current_rgb.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=rgb_fill)
        current_alpha = current_alpha.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=alpha_fill)
        
        # 3. Flipping - apply same flip to both
        if random.random() < 0.5:
            current_rgb = current_rgb.transpose(Image.FLIP_LEFT_RIGHT)
            current_alpha = current_alpha.transpose(Image.FLIP_LEFT_RIGHT)
        
        # 4. Apply brightness only to RGB image (not to mask)
        if random.random() < 0.7:  # 70% chance of brightness adjustment
            brightness_factor = random.uniform(0.7, 1.3)
            current_rgb = ImageEnhance.Brightness(current_rgb).enhance(brightness_factor)
        
        # 5. Translation - create same-sized canvases and paste at identical positions
        base_size = (360, 540)  # Assuming this is your target size, adjust as needed
        dx = random.randint(-50, 50)
        dy = random.randint(-50, 50)
        
        # Create new canvases
        rgb_canvas = Image.new("RGB", base_size, (0, 0, 0))
        alpha_canvas = Image.new("L", base_size, 0)
        
        # Calculate paste coordinates (centered + offset)
        paste_x = max(0, (base_size[0] - current_rgb.width) // 2 + dx)
        paste_y = max(0, (base_size[1] - current_rgb.height) // 2 + dy)
        
        # Ensure we don't paste outside canvas boundaries
        if paste_x + current_rgb.width > base_size[0]:
            paste_x = base_size[0] - current_rgb.width
        if paste_y + current_rgb.height > base_size[1]:
            paste_y = base_size[1] - current_rgb.height
        
        # Paste at exact same position for both
        rgb_canvas.paste(current_rgb, (paste_x, paste_y))
        alpha_canvas.paste(current_alpha, (paste_x, paste_y))
        
        # Merge into RGBA
        result = Image.merge("RGBA", (*rgb_canvas.split(), alpha_canvas))
        
        # Save output
        output_path = os.path.join(output_dir, f"{base_name}_aug_{i}.png")
        result.save(output_path)
        print(f"Saved augmented image: {output_path}")

def batch_augment(input_dir, output_dir, num_aug=5):
    """Process all PNG images in directory"""
    os.makedirs(output_dir, exist_ok=True)
    
    for fname in tqdm([f for f in os.listdir(input_dir) if f.endswith('.png')]):
        try:
            img_path = os.path.join(input_dir, fname)
            # print(f"Processing: {img_path}")
            augment_with_mask(img_path, output_dir, num_aug)
        except Exception as e:
            
            print(f"Error processing {fname}: {e}")
            pass

if __name__ == "__main__":
    input_dir = "../autodl-tmp/train_data_360_540_selected/people_2_with_mask"
    output_dir = "../autodl-tmp/train_data_360_540_aug_2/people"
    batch_augment(input_dir, output_dir, num_aug=10)

In [None]:
import os
import random
import math
from PIL import Image, ImageEnhance

def load_image_and_mask(img_path):
    """Load image and corresponding mask (Alpha channel)"""
    img = Image.open(img_path)
    if img.mode == 'RGBA':
        # Split the image into RGB and Alpha components
        rgb = Image.new("RGB", img.size, (0, 0, 0))
        rgb.paste(img, mask=img.split()[3])  # Use alpha as mask for proper compositing
        return rgb, img.split()[3]  # Return (RGB, Alpha)
    else:
        raise ValueError("Input image must contain Alpha channel!")

def get_content_bounding_box(alpha):
    """Get the bounding box of non-transparent content"""
    # Convert to binary mask if needed
    if alpha.mode != '1':
        binary_mask = alpha.point(lambda p: 255 if p > 128 else 0)
    else:
        binary_mask = alpha
    
    # Get bounding box (left, upper, right, lower)
    bbox = binary_mask.getbbox()
    return bbox if bbox else (0, 0, alpha.width, alpha.height)

def calculate_max_rotation(alpha, target_size):
    """Calculate maximum rotation angle that keeps content within bounds"""
    # Get content bounding box
    bbox = get_content_bounding_box(alpha)
    content_width = bbox[2] - bbox[0]
    content_height = bbox[3] - bbox[1]
    
    # Calculate diagonal length of content
    diagonal = math.sqrt(content_width**2 + content_height**2)
    
    # Calculate safe margins
    max_width = target_size[0] * 0.95  # 95% of target width
    max_height = target_size[1] * 0.95  # 95% of target height
    
    # If diagonal is already too large, limit rotation
    if diagonal > min(max_width, max_height):
        return (-10, 10)  # Very limited rotation
    
    # Calculate maximum angle based on how much space is available
    space_ratio = min(max_width / content_width, max_height / content_height)
    
    if space_ratio > 1.5:
        return (-45, 45)  # Lots of space, allow large rotation
    elif space_ratio > 1.3:
        return (-30, 30)  # Moderate space
    elif space_ratio > 1.1:
        return (-15, 15)  # Limited space
    else:
        return (-5, 5)    # Very limited space

def calculate_safe_scale_range(alpha, target_size):
    """Calculate safe scaling range that keeps content within bounds"""
    # Get content bounding box
    bbox = get_content_bounding_box(alpha)
    content_width = bbox[2] - bbox[0]
    content_height = bbox[3] - bbox[1]
    
    # Calculate maximum scale that fits in target
    max_scale = min(
        target_size[0] * 0.95 / content_width,
        target_size[1] * 0.95 / content_height
    )
    
    # Calculate minimum scale (don't make too small)
    min_scale = max(0.5, min(
        target_size[0] * 0.3 / content_width,
        target_size[1] * 0.3 / content_height
    ))
    
    # Ensure min_scale is less than max_scale
    min_scale = min(min_scale, max_scale * 0.7)
    
    return (min_scale, max_scale)

def augment_with_mask(img_path, output_dir, num_aug=5, target_size=(360, 540)):
    """Main function for synchronized mask augmentation with boundary awareness"""
    # Load image and extract mask
    rgb, alpha = load_image_and_mask(img_path)
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    
    os.makedirs(output_dir, exist_ok=True)
    
    for i in range(num_aug):
        # Start with copies to avoid modifying originals
        current_rgb = rgb.copy()
        current_alpha = alpha.copy()
        
        # Calculate safe transformation parameters
        safe_scale_range = calculate_safe_scale_range(current_alpha, target_size)
        safe_rotation_range = calculate_max_rotation(current_alpha, target_size)
        
        # 1. Scale - apply to both with same parameters within safe range
        scale_factor = random.uniform(*safe_scale_range)
        new_size = (int(current_rgb.width * scale_factor), int(current_rgb.height * scale_factor))
        current_rgb = current_rgb.resize(new_size, Image.BICUBIC)
        current_alpha = current_alpha.resize(new_size, Image.BICUBIC)
        
        # 2. Rotation - use same angle for both within safe range
        angle = random.uniform(*safe_rotation_range)
        rgb_fill = (0, 0, 0)
        alpha_fill = 0
        current_rgb = current_rgb.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=rgb_fill)
        current_alpha = current_alpha.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=alpha_fill)
        
        # 3. Flipping - apply same flip to both
        if random.random() < 0.5:
            current_rgb = current_rgb.transpose(Image.FLIP_LEFT_RIGHT)
            current_alpha = current_alpha.transpose(Image.FLIP_LEFT_RIGHT)
        
        # 4. Apply brightness only to RGB image (not to mask)
        if random.random() < 0.7:  # 70% chance of brightness adjustment
            brightness_factor = random.uniform(0.7, 1.3)
            current_rgb = ImageEnhance.Brightness(current_rgb).enhance(brightness_factor)
        
        # Get new content bounding box after transformations
        content_bbox = get_content_bounding_box(current_alpha)
        content_width = content_bbox[2] - content_bbox[0]
        content_height = content_bbox[3] - content_bbox[1]
        
        # Verify content still fits within target size
        if content_width > target_size[0] or content_height > target_size[1]:
            # Scale down if too large
            scale_factor = min(
                target_size[0] / content_width,
                target_size[1] / content_height
            ) * 0.95  # Add a small safety margin
            
            new_size = (int(current_rgb.width * scale_factor), int(current_rgb.height * scale_factor))
            current_rgb = current_rgb.resize(new_size, Image.BICUBIC)
            current_alpha = current_alpha.resize(new_size, Image.BICUBIC)
            
            # Recalculate content bounding box
            content_bbox = get_content_bounding_box(current_alpha)
        
        # 5. Translation - create same-sized canvases and paste at identical positions
        # Calculate safe translation range to keep content within bounds
        safe_x_min = max(0, target_size[0] - current_rgb.width)
        safe_y_min = max(0, target_size[1] - current_rgb.height)
        
        # Generate random position within safe range
        paste_x = random.randint(0, safe_x_min) if safe_x_min > 0 else 0
        paste_y = random.randint(0, safe_y_min) if safe_y_min > 0 else 0
        
        # Create new canvases of target size
        rgb_canvas = Image.new("RGB", target_size, (0, 0, 0))
        alpha_canvas = Image.new("L", target_size, 0)
        
        # Paste at exact same position for both
        rgb_canvas.paste(current_rgb, (paste_x, paste_y))
        alpha_canvas.paste(current_alpha, (paste_x, paste_y))
        
        # Merge into RGBA
        result = Image.merge("RGBA", (*rgb_canvas.split(), alpha_canvas))
        
        # Save output
        output_path = os.path.join(output_dir, f"{base_name}_aug_{i}.png")
        result.save(output_path)
        # print(f"Saved augmented image: {output_path}")

def batch_augment(input_dir, output_dir, num_aug=5, target_size=(360, 540)):
    """Process all PNG images in directory"""
    os.makedirs(output_dir, exist_ok=True)
    
    for fname in [f for f in os.listdir(input_dir) if f.lower().endswith('.png')]:
        try:
            img_path = os.path.join(input_dir, fname)
            # print(f"Processing: {img_path}")
            augment_with_mask(img_path, output_dir, num_aug, target_size)
        except Exception as e:
            print(f"Error processing {fname}: {e}")

if __name__ == "__main__":
    input_dir = "../autodl-tmp/train_data_360_540_selected/people_2_with_mask"
    output_dir = "../autodl-tmp/train_data_360_540_aug_2/people"
    target_size = (360, 540)  # Target canvas size
    batch_augment(input_dir, output_dir, num_aug=4, target_size=target_size)

In [26]:
import os
print(len([f for f in os.listdir('../autodl-tmp/4_26_new_train_data_aug/masks') if f.endswith('.png')]))

1652


In [None]:
import os
import random
import math
from PIL import Image, ImageEnhance

def load_image_and_mask(people_path, mask_path):
    """Load people image and corresponding mask"""
    try:
        # Load people image
        people_img = Image.open(people_path)
        if people_img.mode != 'RGB':
            people_img = people_img.convert('RGB')
        
        # Load mask image
        mask_img = Image.open(mask_path)
        if mask_img.mode != 'L':
            mask_img = mask_img.convert('L')
            
        return people_img, mask_img
    except Exception as e:
        print(f"Error loading image pair: {e}")
        return None, None

def extract_person_from_background(people_img, mask_img):
    """Extract only the person from the image using the mask"""
    # Create a transparent image (RGBA) with same dimensions
    extracted = Image.new("RGBA", people_img.size, (0, 0, 0, 0))
    
    # Convert RGB people image to RGBA
    people_rgba = people_img.convert("RGBA")
    
    # Create an RGBA version of the mask (where mask becomes alpha)
    # This requires converting mask values (0-255) to RGBA where RGB is white and A is the mask value
    mask_rgba = Image.new("RGBA", mask_img.size)
    
    # For each pixel in the mask, create corresponding RGBA pixel
    for y in range(mask_img.height):
        for x in range(mask_img.width):
            mask_value = mask_img.getpixel((x, y))
            # Get the RGB value from original image
            r, g, b = people_img.getpixel((x, y))
            # Create RGBA with original RGB and mask as alpha
            mask_rgba.putpixel((x, y), (r, g, b, mask_value))
    
    # The result is the person extracted with transparency where mask was 0
    return mask_rgba

def get_content_bounds(mask):
    """Get the content bounds from mask"""
    # Convert mask to binary for bbox detection if needed
    if mask.mode != '1':
        binary = mask.point(lambda p: p > 0)
    else:
        binary = mask
        
    # Get bounding box (left, upper, right, lower)
    bbox = binary.getbbox()
    if not bbox:
        # Return full image bounds if no content detected
        return (0, 0, mask.width, mask.height)
    return bbox

def crop_to_content(image, mask):
    """Crop both image and mask to just the content area"""
    # Get bounding box of content
    bbox = get_content_bounds(mask)
    
    # Crop both image and mask to that bounding box
    cropped_image = image.crop(bbox)
    cropped_mask = mask.crop(bbox)
    
    return cropped_image, cropped_mask, bbox

def calculate_safe_transformations(width, height, target_size):
    """Calculate safe transformation parameters to keep content in bounds"""
    # Calculate safe margins (95% of target)
    safe_width = target_size[0] * 0.95
    safe_height = target_size[1] * 0.95
    
    # Calculate scale ranges
    max_scale = min(
        safe_width / width,
        safe_height / height
    )
    
    # Don't make images too small (min 30% of target)
    min_scale = max(0.5, min(
        target_size[0] * 0.3 / width,
        target_size[1] * 0.3 / height
    ))
    
    # Ensure min_scale doesn't exceed max_scale
    min_scale = min(min_scale, max_scale * 0.8)
    
    # Calculate rotation range based on available space
    space_ratio = min(safe_width / width, safe_height / height)
    
    if space_ratio > 1.5:
        rotation_range = (-45, 45)
    elif space_ratio > 1.3:
        rotation_range = (-30, 30)
    elif space_ratio > 1.1:
        rotation_range = (-15, 15)
    else:
        rotation_range = (-5, 5)
    
    return {
        'scale_range': (min_scale, max_scale),
        'rotation_range': rotation_range
    }

def augment_person(people_img, mask_img, output_people_dir, output_mask_dir, base_name, index, target_size):
    """Augment a person image and its mask, preserving only the masked area"""
    # Set random seed for consistent transformations
    seed = random.randint(0, 10000)
    random.seed(seed)
    
    # Step 1: Extract just the person using the mask
    extracted_person = extract_person_from_background(people_img, mask_img)
    
    # Step 2: Crop to content area
    content_width = extracted_person.width
    content_height = extracted_person.height
    
    # Calculate safe transformations
    transforms = calculate_safe_transformations(content_width, content_height, target_size)
    
    # Step 3: Apply transformations to the extracted person
    # 3.1 Scale
    scale = random.uniform(*transforms['scale_range'])
    new_width = int(extracted_person.width * scale)
    new_height = int(extracted_person.height * scale)
    transformed_person = extracted_person.resize((new_width, new_height), Image.BICUBIC)
    
    # Reset seed for consistency
    random.seed(seed)
    
    # 3.2 Rotate
    angle = random.uniform(*transforms['rotation_range'])
    transformed_person = transformed_person.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=(0, 0, 0, 0))
    
    # Reset seed
    random.seed(seed)
    
    # 3.3 Flip (50% chance)
    if random.random() < 0.5:
        transformed_person = transformed_person.transpose(Image.FLIP_LEFT_RIGHT)
    
    # Step 4: Place on target canvas
    # Create new RGBA canvas
    output_canvas = Image.new("RGBA", target_size, (0, 0, 0, 0))
    
    # Calculate centering position
    paste_x = (target_size[0] - transformed_person.width) // 2
    paste_y = (target_size[1] - transformed_person.height) // 2
    
    # Random offset from center (within safe range)
    max_x_offset = min(paste_x, (target_size[0] - paste_x - transformed_person.width))
    max_y_offset = min(paste_y, (target_size[1] - paste_y - transformed_person.height))
    
    x_offset = random.randint(-max_x_offset if max_x_offset > 0 else 0, 
                             max_x_offset if max_x_offset > 0 else 0)
    y_offset = random.randint(-max_y_offset if max_y_offset > 0 else 0,
                             max_y_offset if max_y_offset > 0 else 0)
    
    # Adjust paste position with offset
    paste_x += x_offset
    paste_y += y_offset
    
    # Ensure we don't go out of bounds
    paste_x = max(0, min(paste_x, target_size[0] - transformed_person.width))
    paste_y = max(0, min(paste_y, target_size[1] - transformed_person.height))
    
    # Paste the transformed person
    output_canvas.paste(transformed_person, (paste_x, paste_y), transformed_person.split()[3])
    
    # Step 5: Separate RGB and Alpha for saving
    rgb_output = Image.new("RGB", target_size, (0, 0, 0))
    alpha_output = Image.new("L", target_size, 0)
    
    # Copy RGB channels from output canvas to RGB output
    rgb_output.paste(output_canvas.convert("RGB"), (0, 0))
    
    # Copy Alpha channel from output canvas to Alpha output
    alpha_output.paste(output_canvas.split()[3], (0, 0))
    
    # Save outputs
    people_output_path = os.path.join(output_people_dir, f"{base_name}_aug_{index}.png")
    mask_output_path = os.path.join(output_mask_dir, f"{base_name}_aug_{index}.png")
    # combined_output_path = os.path.join(output_people_dir, f"{base_name}_aug_{index}_preview.png")
    
    rgb_output.save(people_output_path)
    alpha_output.save(mask_output_path)
    # output_canvas.save(combined_output_path)  # Save combined RGBA for preview
    
    return people_output_path, mask_output_path

def batch_augment(people_dir, mask_dir, output_people_dir, output_mask_dir, num_augmentations=5, target_size=(360, 540)):
    """Process all image pairs in the input directories"""
    # Create output directories if they don't exist
    os.makedirs(output_people_dir, exist_ok=True)
    os.makedirs(output_mask_dir, exist_ok=True)
    
    # Find all image files in people directory
    valid_extensions = ('.png', '.jpg', '.jpeg')
    image_files = [f for f in os.listdir(people_dir) 
                  if os.path.isfile(os.path.join(people_dir, f)) 
                  and f.lower().endswith(valid_extensions)]
    print(image_files)
    processed_count = 0
    errors_count = 0
    
    # Process each image pair
    for img_file in image_files:
        try:
            # Check if corresponding mask exists
            img_file = os.path.splitext(img_file)[0] + '.JPG'
            mask_file = os.path.splitext(img_file)[0] + '.png'  # Assuming same filename in mask directory
            if not os.path.exists(os.path.join(mask_dir, mask_file)):
                # print(f"Warning: No matching mask found for {img_file}")
                continue
            
            base_name = os.path.splitext(img_file)[0]
            
            # Load the image pair
            people_path = os.path.join(people_dir, img_file)
            mask_path = os.path.join(mask_dir, mask_file)
            
            people_img, mask_img = load_image_and_mask(people_path, mask_path)
            if people_img is None or mask_img is None:
                print(f"Skipping {img_file} due to loading error")
                errors_count += 1
                continue
            
            # Generate augmentations
            for i in range(num_augmentations):
                try:
                    people_out, mask_out = augment_person(
                        people_img, mask_img, 
                        output_people_dir, output_mask_dir,
                        base_name, i, target_size
                    )
                    print(f"Created augmentation {i+1}/{num_augmentations} for {img_file}")
                except Exception as e:
                    print(f"Error creating augmentation {i+1} for {img_file}: {e}")
                    import traceback
                    traceback.print_exc()
            
            processed_count += 1
            print(f"Processed: {img_file} - created {num_augmentations} augmentations")
            
        except Exception as e:
            print(f"Error processing {img_file}: {e}")
            import traceback
            traceback.print_exc()
            errors_count += 1
    
    print(f"Augmentation complete. Processed {processed_count} images with {errors_count} errors.")
    print(f"Generated {processed_count * num_augmentations} augmented images.")

if __name__ == "__main__":
    # Configuration
    people_dir = "../autodl-tmp/4_27_分类数据/道扬24/1"  # Directory containing people images
    mask_dir = "../autodl-tmp/4_27_分类数据/道扬24/1_mask"     # Directory containing mask images
    output_people_dir = ".../autodl-tmp/4_27_aug/people"  # Output directory for augmented people images
    output_mask_dir = "../autodl-tmp/4_27_aug/masks"     # Output directory for augmented masks
    
    num_augmentations = 4
    target_size = (360, 360)  # Fixed size for all outputs
    
    # Run the augmentation
    batch_augment(people_dir, mask_dir, output_people_dir, output_mask_dir, num_augmentations, target_size)

In [21]:
import os
import random
import math
from PIL import Image, ImageEnhance

def load_image_and_mask(people_path, mask_path):
    """Load people image and corresponding mask"""
    try:
        # Load people image
        people_img = Image.open(people_path)
        if people_img.mode != 'RGB':
            people_img = people_img.convert('RGB')
        
        # Load mask image
        mask_img = Image.open(mask_path)
        if mask_img.mode != 'L':
            mask_img = mask_img.convert('L')
            
        return people_img, mask_img
    except Exception as e:
        print(f"Error loading image pair: {e}")
        return None, None

def extract_person_from_background(people_img, mask_img):
    """Extract only the person from the image using the mask"""
    # Create a transparent image (RGBA) with same dimensions
    extracted = Image.new("RGBA", people_img.size, (0, 0, 0, 0))
    
    # Convert RGB people image to RGBA
    people_rgba = people_img.convert("RGBA")
    
    # Create an RGBA version of the mask (where mask becomes alpha)
    # This requires converting mask values (0-255) to RGBA where RGB is white and A is the mask value
    mask_rgba = Image.new("RGBA", mask_img.size)
    
    # For each pixel in the mask, create corresponding RGBA pixel
    for y in range(mask_img.height):
        for x in range(mask_img.width):
            mask_value = mask_img.getpixel((x, y))
            # Get the RGB value from original image
            r, g, b = people_img.getpixel((x, y))
            # Create RGBA with original RGB and mask as alpha
            mask_rgba.putpixel((x, y), (r, g, b, mask_value))
    
    # The result is the person extracted with transparency where mask was 0
    return mask_rgba

def get_content_bounds(mask):
    """Get the content bounds from mask"""
    # Convert mask to binary for bbox detection if needed
    if mask.mode != '1':
        binary = mask.point(lambda p: p > 0)
    else:
        binary = mask
        
    # Get bounding box (left, upper, right, lower)
    bbox = binary.getbbox()
    if not bbox:
        # Return full image bounds if no content detected
        return (0, 0, mask.width, mask.height)
    return bbox

def crop_to_content(image, mask):
    """Crop both image and mask to just the content area"""
    # Get bounding box of content
    bbox = get_content_bounds(mask)
    
    # Crop both image and mask to that bounding box
    cropped_image = image.crop(bbox)
    cropped_mask = mask.crop(bbox)
    
    return cropped_image, cropped_mask, bbox

def calculate_safe_transformations(width, height, target_size):
    """Calculate safe transformation parameters to keep content in bounds"""
    # Calculate safe margins (95% of target)
    safe_width = target_size[0] * 0.95
    safe_height = target_size[1] * 0.95
    
    # Calculate scale ranges
    max_scale = min(
        safe_width / width,
        safe_height / height
    )
    
    # Don't make images too small (min 30% of target)
    min_scale = max(0.5, min(
        target_size[0] * 0.3 / width,
        target_size[1] * 0.3 / height
    ))
    
    # Ensure min_scale doesn't exceed max_scale
    min_scale = min(min_scale, max_scale * 0.8)
    
    # Calculate rotation range based on available space
    space_ratio = min(safe_width / width, safe_height / height)
    
    if space_ratio > 1.5:
        rotation_range = (-45, 45)
    elif space_ratio > 1.3:
        rotation_range = (-30, 30)
    elif space_ratio > 1.1:
        rotation_range = (-15, 15)
    else:
        rotation_range = (-5, 5)
    
    return {
        'scale_range': (min_scale, max_scale),
        'rotation_range': rotation_range
    }

def augment_person(people_img, mask_img, output_people_dir, output_mask_dir, base_name, index):
    """Augment a person image and its mask, preserving original size"""
    # Set random seed for consistent transformations
    seed = random.randint(0, 10000)
    random.seed(seed)
    
    # Step 1: Extract just the person using the mask
    extracted_person = extract_person_from_background(people_img, mask_img)
    
    # Step 2: Crop to content area (optional, keeps only the person)
    cropped_person, cropped_mask, bbox = crop_to_content(extracted_person, mask_img)
    
    # Step 3: Apply transformations to the extracted person
    # 3.1 Scale (randomly resize within a reasonable range)
    scale = random.uniform(0.8, 1.2)  # Example: 80% to 120% of original size
    new_width = int(cropped_person.width * scale)
    new_height = int(cropped_person.height * scale)
    transformed_person = cropped_person.resize((new_width, new_height), Image.BICUBIC)
    
    # Reset seed for consistency
    random.seed(seed)
    
    # 3.2 Rotate (smaller range to avoid excessive empty space)
    angle = random.uniform(-15, 15)  # Reduced rotation range
    transformed_person = transformed_person.rotate(angle, resample=Image.BICUBIC, expand=True, fillcolor=(0, 0, 0, 0))
    
    # Reset seed
    random.seed(seed)
    
    # 3.3 Flip (50% chance)
    if random.random() < 0.5:
        transformed_person = transformed_person.transpose(Image.FLIP_LEFT_RIGHT)
    
    # Step 4: Place on canvas with ORIGINAL DIMENSIONS
    original_width, original_height = people_img.size
    output_canvas = Image.new("RGBA", (original_width, original_height), (0, 0, 0, 0))
    
    # Center the transformed person
    paste_x = (original_width - transformed_person.width) // 2
    paste_y = (original_height - transformed_person.height) // 2
    
    # Paste the transformed person
    output_canvas.paste(transformed_person, (paste_x, paste_y), transformed_person.split()[3])
    
    # Step 5: Separate RGB and Alpha for saving
    rgb_output = output_canvas.convert("RGB")
    alpha_output = output_canvas.split()[3]  # Alpha channel
    
    # Save outputs
    people_output_path = os.path.join(output_people_dir, f"{base_name}_aug_{index}.png")
    mask_output_path = os.path.join(output_mask_dir, f"{base_name}_aug_{index}.png")
    
    rgb_output.save(people_output_path)
    alpha_output.save(mask_output_path)
    
    return people_output_path, mask_output_path


def batch_augment(people_dir, mask_dir, output_people_dir, output_mask_dir, num_augmentations=5):
    """Process all image pairs in the input directories"""
    # Create output directories if they don't exist
    os.makedirs(output_people_dir, exist_ok=True)
    os.makedirs(output_mask_dir, exist_ok=True)
    
    # Find all image files in people directory
    valid_extensions = ('.png', '.jpg', '.jpeg')
    image_files = [f for f in os.listdir(people_dir) 
                  if os.path.isfile(os.path.join(people_dir, f)) 
                  and f.lower().endswith(valid_extensions)]
    print(image_files)
    processed_count = 0
    errors_count = 0
    
    # Process each image pair
    for img_file in image_files:
        try:
            # Check if corresponding mask exists
            img_file = os.path.splitext(img_file)[0] + '.JPG'
            mask_file = os.path.splitext(img_file)[0] + '.png'  # Assuming same filename in mask directory
            if not os.path.exists(os.path.join(mask_dir, mask_file)):
                # print(f"Warning: No matching mask found for {img_file}")
                continue
            
            base_name = os.path.splitext(img_file)[0]
            
            # Load the image pair
            people_path = os.path.join(people_dir, img_file)
            mask_path = os.path.join(mask_dir, mask_file)
            
            people_img, mask_img = load_image_and_mask(people_path, mask_path)
            if people_img is None or mask_img is None:
                print(f"Skipping {img_file} due to loading error")
                errors_count += 1
                continue
            
            # Generate augmentations
            for i in range(num_augmentations):
                try:
                    people_out, mask_out = augment_person(
                        people_img, mask_img, 
                        output_people_dir, output_mask_dir,
                        base_name, i  # No target_size passed
                    )

                    print(f"Created augmentation {i+1}/{num_augmentations} for {img_file}")
                except Exception as e:
                    print(f"Error creating augmentation {i+1} for {img_file}: {e}")
                    import traceback
                    traceback.print_exc()
            
            processed_count += 1
            print(f"Processed: {img_file} - created {num_augmentations} augmentations")
            
        except Exception as e:
            print(f"Error processing {img_file}: {e}")
            import traceback
            traceback.print_exc()
            errors_count += 1
    
    print(f"Augmentation complete. Processed {processed_count} images with {errors_count} errors.")
    print(f"Generated {processed_count * num_augmentations} augmented images.")

if __name__ == "__main__":
    # Configuration
    people_dir = "../autodl-tmp/4_27_分类数据/道扬24/2"  # Directory containing people images
    mask_dir = "../autodl-tmp/4_27_分类数据/道扬24/2_mask"     # Directory containing mask images
    output_people_dir = "../autodl-tmp/4_27_aug/people"  # Output directory for augmented people images
    output_mask_dir = "../autodl-tmp/4_27_aug/masks"     # Output directory for augmented masks
    
    num_augmentations = 4
    # target_size = (360, 360)  # Fixed size for all outputs
    
    # Run the augmentation
    batch_augment(people_dir, mask_dir, output_people_dir, output_mask_dir, num_augmentations)

['KIT_9007.JPG', 'KIT_9034.JPG', 'KIT_9048.JPG', 'KIT_9051.JPG', 'KIT_9087.JPG', 'KIT_9060.JPG', 'KIT_9105.JPG', 'KIT_9098.JPG', 'KIT_9113.JPG', 'KIT_9128.JPG', 'KIT_9116.JPG', 'KIT_9141.JPG', 'KIT_9136.JPG', 'KIT_9147.JPG', 'KIT_9166.JPG', 'KIT_9178.JPG', 'KIT_9168.JPG', '_DSC4492.JPG', '_DSC4513.JPG', '_DSC4548.JPG', '_DSC4552.JPG', '_DSC4555.JPG', '_DSC4560.JPG', '_DSC4565.JPG', '_DSC4585.JPG', '_DSC4593.JPG', '_DSC4612.JPG', '_DSC4601.JPG', '_DSC4627.JPG', '_DSC4616.JPG', '_DSC4643.JPG', '_DSC4669.JPG', '_DSC4640.JPG', '_DSC4678.JPG', '_DSC4674.JPG', '_DSC4679.JPG', '_DSC4681.JPG']
Created augmentation 1/4 for KIT_9007.JPG
Created augmentation 2/4 for KIT_9007.JPG
Created augmentation 3/4 for KIT_9007.JPG
Created augmentation 4/4 for KIT_9007.JPG
Processed: KIT_9007.JPG - created 4 augmentations
Created augmentation 1/4 for KIT_9034.JPG
Created augmentation 2/4 for KIT_9034.JPG
Created augmentation 3/4 for KIT_9034.JPG
Created augmentation 4/4 for KIT_9034.JPG
Processed: KIT_9034.J