# DATA COLLECTION AND CREATION

# In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Dataset locations (raw strings, so the Windows backslashes are literal)
BACKGROUND_DIR = r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\input\backgrounds"
WEED_DIR = r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\input\foregrounds\allweeds\weed"
TOMATO_PATH = r"C:\Users\BRIJESH KUMAR GHADEI\Downloads\Screenshot 2024-11-01 150908_processed.png"
OUTPUT_DIR = r"C:\IMMAGE ANALYSIS PROJECT DATASET\weed\weeds\output"

# Make sure every output sub-folder exists before anything is written to it.
for _subdir in ("images", "masks", "visualization"):
    os.makedirs(os.path.join(OUTPUT_DIR, _subdir), exist_ok=True)
del _subdir

# Overlay colours in OpenCV's BGR order. Insertion order is significant: the
# position of each entry is the mask channel index it is painted from
# (0 = background, 1 = weed, 2 = tomato) when the code enumerates
# class_colors.values().
class_colors = dict(
    background=[0, 0, 0],
    weed=[0, 255, 0],
    tomato=[0, 0, 255],
)

def _ensure_bgra(img):
    """Return *img* with four channels (BGRA), synthesising alpha if it is missing.

    A 2-D (grayscale) array is first expanded to BGR. When the image has no
    alpha channel, pixels with any channel value above 20 become opaque (255)
    and the rest transparent (0) -- the same rule load_and_process_tomato
    applies -- so callers can always read ``img[:, :, 3]``.
    """
    if img.ndim == 3 and img.shape[2] == 4:
        return img
    bgr = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img
    bgra = cv2.cvtColor(bgr, cv2.COLOR_BGR2BGRA)
    bgra[:, :, 3] = np.any(bgr > 20, axis=-1).astype(np.uint8) * 255
    return bgra


def load_images(directory, size=(256, 256), alpha=False, target_size_ratio=0.5):
    """Load and resize every .jpg/.jpeg/.png image in *directory* (non-recursive).

    Files are visited in sorted name order so the returned list is reproducible
    across runs, and extensions are matched case-insensitively. Files OpenCV
    cannot decode are skipped.

    Args:
        directory: Folder to scan.
        size: ``(width, height)``. Opaque images are resized to exactly this;
            for ``alpha=True`` it is the reference canvas the foreground is
            scaled against.
        alpha: If True, read with ``cv2.IMREAD_UNCHANGED``, guarantee a
            4-channel BGRA result (alpha is synthesised when the file has none)
            and keep the aspect ratio so the image fits within
            ``size * target_size_ratio``.
        target_size_ratio: Fraction of ``size`` a foreground may occupy.

    Returns:
        A list of images (possibly empty).
    """
    extensions = (".jpg", ".png", ".jpeg")
    read_flag = cv2.IMREAD_UNCHANGED if alpha else cv2.IMREAD_COLOR
    images = []
    for filename in sorted(os.listdir(directory)):
        # Case-insensitive so e.g. "IMG_01.JPG" is not silently skipped.
        if not filename.lower().endswith(extensions):
            continue
        img = cv2.imread(os.path.join(directory, filename), read_flag)
        if img is None:
            continue
        if alpha:
            # Callers index channel 3 unconditionally, so JPEG / grayscale
            # foregrounds need an alpha channel added.
            img = _ensure_bgra(img)
            h, w = img.shape[:2]
            target_w = int(size[0] * target_size_ratio)
            target_h = int(size[1] * target_size_ratio)
            ratio = min(target_w / w, target_h / h)
            # Clamp to 1 px: an extreme aspect ratio could otherwise yield a
            # zero-sized axis, which cv2.resize rejects.
            new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
            img = cv2.resize(img, new_size)
        else:
            img = cv2.resize(img, size)
        images.append(img)
    return images

def load_and_process_tomato(path, size=(256, 256), target_size_ratio=0.5):
    """Load the single tomato cut-out and return it as a resized BGRA image.

    Args:
        path: Image file, read with ``cv2.IMREAD_UNCHANGED``.
        size: Reference ``(width, height)`` canvas. The result keeps the
            source aspect ratio and fits within ``size * target_size_ratio``.
        target_size_ratio: Fraction of ``size`` the tomato may occupy. The
            default 0.5 reproduces the previously hard-coded value and matches
            ``load_images``' default.

    Returns:
        A 4-channel image. If the file had no alpha channel, one is built from
        the pixels: any channel value above 20 -> opaque, otherwise transparent.

    Raises:
        ValueError: if OpenCV cannot read *path*.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f"Could not load tomato image from {path}")

    # Grayscale files come back 2-D: ``shape[-1]`` would then be the image
    # width and the BGR->BGRA conversion would fail on a single channel.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] != 4:
        bgra = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
        # Treat near-black pixels (every channel <= 20) as transparent background.
        bgra[:, :, 3] = np.any(img > 20, axis=-1).astype(np.uint8) * 255
        img = bgra

    # Resize while maintaining aspect ratio.
    h, w = img.shape[:2]
    target_w = int(size[0] * target_size_ratio)
    target_h = int(size[1] * target_size_ratio)
    ratio = min(target_w / w, target_h / h)
    # Clamp to 1 px so a very elongated source cannot produce a zero-sized
    # dimension, which cv2.resize rejects.
    new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
    return cv2.resize(img, new_size)

def apply_random_transform(image):
    """Return a randomly scaled, rotated and (sometimes) re-lit copy of *image*.

    Transformations, in order:
      * scale by a uniform factor in [0.8, 1.2);
      * rotate about the centre by a uniform angle in [-30, 30) degrees. The
        output canvas is enlarged to the rotated bounding box so the object's
        corners are not clipped. Newly exposed pixels are 0, i.e. fully
        transparent for BGRA input;
      * with probability 0.5, apply contrast in [0.8, 1.2) and a brightness
        offset in [-20, 20) to the colour channels only. A 4th (alpha)
        channel is left untouched, so transparency is not altered.

    The input array is not modified.
    """
    h, w = image.shape[:2]

    # Random scale (0.8 to 1.2); keep at least 1 px per axis for cv2.resize.
    scale = np.random.uniform(0.8, 1.2)
    new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
    transformed = cv2.resize(image, (new_w, new_h))

    # Random rotation (-30 to 30 degrees) onto a canvas large enough to hold
    # the whole rotated image.
    angle = np.random.uniform(-30, 30)
    center = (new_w / 2, new_h / 2)
    matrix = cv2.getRotationMatrix2D(center, angle, 1)
    cos_a, sin_a = abs(matrix[0, 0]), abs(matrix[0, 1])
    bound_w = int(np.ceil(new_h * sin_a + new_w * cos_a))
    bound_h = int(np.ceil(new_h * cos_a + new_w * sin_a))
    # Shift so the rotated image is centred in the enlarged canvas.
    matrix[0, 2] += bound_w / 2 - center[0]
    matrix[1, 2] += bound_h / 2 - center[1]
    transformed = cv2.warpAffine(transformed, matrix, (bound_w, bound_h))

    # Random brightness/contrast variation, colour channels only.
    if np.random.random() > 0.5:
        contrast = np.random.uniform(0.8, 1.2)
        brightness = np.random.uniform(-20, 20)
        if transformed.ndim == 3 and transformed.shape[2] == 4:
            color = np.ascontiguousarray(transformed[:, :, :3])
            transformed[:, :, :3] = cv2.convertScaleAbs(color, alpha=contrast, beta=brightness)
        else:
            transformed = cv2.convertScaleAbs(transformed, alpha=contrast, beta=brightness)

    return transformed

def create_multi_class_mask(shape, num_classes=3):
    """Build an all-zero uint8 mask of shape ``(rows, cols, num_classes)``.

    Only the first two entries of *shape* are used, so an image's ``.shape``
    (with or without a channel axis) can be passed in directly.
    """
    rows = shape[0]
    cols = shape[1]
    return np.zeros(shape=(rows, cols, num_classes), dtype="uint8")

def save_batch(images, masks, start_idx):
    """Write one batch of samples under OUTPUT_DIR.

    For each sample, numbered ``k = start_idx + position``, three PNGs are written:
      * ``images/synthetic_k.png``: the image as given;
      * ``masks/mask_k.png``: a single-channel label map. Pixels are 0 by
        default, and each channel index ``i >= 1`` is written wherever
        ``mask[:, :, i] > 0``. Higher indices overwrite lower ones where
        channels overlap;
      * ``visualization/vis_k.png``: the image with every non-background class
        painted in its ``class_colors`` colour (the dict's position = channel).

    Raises:
        OSError: if ``cv2.imwrite`` reports a failed write. It signals failure
            by returning False, which was previously ignored and silently
            dropped files from the dataset.
    """
    def _write(path, array):
        # Fail loudly instead of producing an incomplete dataset.
        if not cv2.imwrite(path, array):
            raise OSError(f"cv2.imwrite failed to write {path}")

    colors = list(class_colors.values())
    for idx, (image, mask) in enumerate(zip(images, masks)):
        sample_id = start_idx + idx

        # Save synthetic image
        _write(os.path.join(OUTPUT_DIR, "images", f"synthetic_{sample_id}.png"), image)

        # Save combined mask as a single PNG with different values for each class
        combined_mask = np.zeros(mask.shape[:2], dtype=np.uint8)
        for i in range(1, mask.shape[-1]):  # Skip background
            combined_mask[mask[:, :, i] > 0] = i
        _write(os.path.join(OUTPUT_DIR, "masks", f"mask_{sample_id}.png"), combined_mask)

        # Create visualization
        vis_img = image.copy()
        for class_idx, color in enumerate(colors):
            if class_idx > 0:  # Skip background
                vis_img[mask[:, :, class_idx] > 0] = color
        _write(os.path.join(OUTPUT_DIR, "visualization", f"vis_{sample_id}.png"), vis_img)

def display_samples(images, masks, num_samples=5):
    """Show random samples as a 4-row grid: image, weed mask, tomato mask, overlay.

    Args:
        images: Sequence of BGR images (converted to RGB for Matplotlib).
        masks: Matching sequence of per-class masks. Channel 1 is shown as the
            weed mask and channel 2 as the tomato mask.
        num_samples: Number of columns. Clamped to ``len(images)`` because
            samples are drawn without replacement (``np.random.choice`` would
            otherwise raise ValueError for a small batch). Nothing is shown
            for an empty batch.
    """
    num_samples = min(num_samples, len(images))
    if num_samples == 0:
        return

    plt.figure(figsize=(15, 8))
    indices = np.random.choice(len(images), num_samples, replace=False)

    for idx, i in enumerate(indices):
        # Original image
        plt.subplot(4, num_samples, idx + 1)
        plt.imshow(cv2.cvtColor(images[i], cv2.COLOR_BGR2RGB))
        plt.axis('off')
        if idx == 0:
            plt.title('Synthetic Image')

        # Weed mask
        plt.subplot(4, num_samples, num_samples + idx + 1)
        plt.imshow(masks[i][:, :, 1], cmap='gray')
        plt.axis('off')
        if idx == 0:
            plt.title('Weed Mask')

        # Tomato mask
        plt.subplot(4, num_samples, 2*num_samples + idx + 1)
        plt.imshow(masks[i][:, :, 2], cmap='gray')
        plt.axis('off')
        if idx == 0:
            plt.title('Tomato Mask')

        # Combined visualization: paint each non-background class in its
        # class_colors colour (dict position = mask channel index).
        vis_img = images[i].copy()
        for class_idx, color in enumerate(class_colors.values()):
            if class_idx > 0:  # Skip background
                vis_img[masks[i][:, :, class_idx] > 0] = color
        plt.subplot(4, num_samples, 3*num_samples + idx + 1)
        plt.imshow(cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        if idx == 0:
            plt.title('Combined Visualization')

    plt.tight_layout()
    plt.show()

def _paste_object(canvas, mask, obj, class_idx):
    """Alpha-blend *obj* onto *canvas* at a random position and label it in *mask*.

    *obj* may be BGR or BGRA. Without an alpha channel it is treated as fully
    opaque. If *obj* is larger than *canvas* along either axis, nothing is
    changed. Only pixels whose alpha exceeds 0.5 are labelled: every channel of
    *mask* is cleared there and channel *class_idx* is set to 255. Labels of
    earlier objects outside this object's shape are therefore preserved, and
    an object pasted later replaces earlier labels where it covers them, which
    matches what is visible in the image.
    """
    obj_h, obj_w = obj.shape[:2]
    max_x = canvas.shape[1] - obj_w
    max_y = canvas.shape[0] - obj_h
    if max_x < 0 or max_y < 0:
        return
    # randint's upper bound is exclusive; +1 lets objects touch the
    # right/bottom edges and lets exact-fit objects be placed at all.
    x = np.random.randint(0, max_x + 1)
    y = np.random.randint(0, max_y + 1)

    if obj.ndim == 3 and obj.shape[2] == 4:
        alpha = obj[:, :, 3] / 255.0
    else:
        alpha = np.ones((obj_h, obj_w))
    alpha3 = alpha[..., None]

    # Basic slices are views, so writing into them updates canvas/mask.
    region = canvas[y:y + obj_h, x:x + obj_w]
    region[:, :, :3] = alpha3 * obj[:, :, :3] + (1 - alpha3) * region[:, :, :3]

    covered = alpha > 0.5
    mask_region = mask[y:y + obj_h, x:x + obj_w]
    mask_region[covered] = 0
    mask_region[covered, class_idx] = 255


def generate_synthetic_data(backgrounds, weeds, tomato_img, num_samples=1000, batch_size=100):
    """Generate, save and report synthetic images containing tomatoes and weeds.

    Each sample is composed as follows:
      * a random background from *backgrounds* (copied, so the input is untouched);
      * 1-2 tomato plants, each a fresh ``apply_random_transform`` variation
        of *tomato_img*, labelled as class 2;
      * 1-3 weeds chosen from *weeds*, pasted last (so they sit on top) and
        labelled as class 1.

    Mask channel 0 (background) starts at 255 and is cleared only under pasted
    object pixels. Samples are produced in batches of *batch_size* and passed
    to ``save_batch``. The first batch is also shown with ``display_samples``.
    All backgrounds are assumed to share one shape, since each batch is
    stacked with ``np.array``.
    """
    for batch_start in range(0, num_samples, batch_size):
        batch_samples = min(batch_size, num_samples - batch_start)
        synthetic_images, masks = [], []

        for _ in range(batch_samples):
            background = backgrounds[np.random.randint(len(backgrounds))].copy()
            mask = create_multi_class_mask(background.shape)
            mask[:, :, 0] = 255  # everything is background until an object covers it

            # Add tomatoes (1-2 plants with variations), class index 2
            for _ in range(np.random.randint(1, 3)):
                tomato = apply_random_transform(tomato_img.copy())
                _paste_object(background, mask, tomato, 2)

            # Add weeds (1-3 plants), class index 1
            for _ in range(np.random.randint(1, 4)):
                weed = weeds[np.random.randint(len(weeds))]
                _paste_object(background, mask, weed, 1)

            synthetic_images.append(background)
            masks.append(mask)

        synthetic_images = np.array(synthetic_images)
        masks = np.array(masks)

        save_batch(synthetic_images, masks, batch_start)

        if batch_start == 0:
            display_samples(synthetic_images, masks)

        print(f"Generated and saved images {batch_start} to {batch_start + batch_samples - 1}")

def main(num_samples=1000):
    """Load the source images, validate them and generate the synthetic dataset.

    Problems with the inputs are reported with a printed message and an early
    return rather than a traceback. This covers an unreadable or missing
    directory, an unloadable tomato image, and no usable backgrounds or weeds.

    Args:
        num_samples: Number of synthetic image/mask pairs to generate
            (default 1000, the previously hard-coded value).
    """
    print("Loading images...")
    try:
        backgrounds = load_images(BACKGROUND_DIR)
        weeds = load_images(WEED_DIR, alpha=True)
    except OSError as e:
        # os.listdir raises FileNotFoundError / NotADirectoryError /
        # PermissionError for a bad dataset path.
        print(f"Error reading image directory: {e}")
        return

    try:
        tomato_img = load_and_process_tomato(TOMATO_PATH)
        print("Successfully loaded tomato image")
    except Exception as e:
        # Top-level boundary: report any load/processing failure and stop.
        print(f"Error loading tomato image: {e}")
        return

    print(f"Loaded {len(backgrounds)} backgrounds and {len(weeds)} weeds")

    # tomato_img is guaranteed to be set here (failures returned above), so
    # only the directory scans can still be empty.
    if not backgrounds or not weeds:
        print("Error: No images found in one or more directories")
        return

    print("Generating synthetic images...")
    generate_synthetic_data(backgrounds, weeds, tomato_img, num_samples=num_samples)
    print("Done!")

# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()