In [1]:
import os
import glob
import shutil
import random
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
import cv2
from collections import Counter
import matplotlib.pyplot as plt

In [2]:
# Root of the YOLO-format dataset export (contains train/ and data.yaml).
# Set the AFFECTNET_YOLO_ROOT environment variable to run the notebook against
# a copy stored elsewhere; the original location remains the default.
dataset_root = os.environ.get(
    "AFFECTNET_YOLO_ROOT",
    r"c:\Users\Pojesh\Documents\OfficialWorks\MV_Project\Dataset\affectnet\YOLO_format",
)

In [3]:
# Source split and the destination that will hold the balanced copy
train_folder, augmented_folder = (
    os.path.join(dataset_root, split_name)
    for split_name in ("train", "train_augmented")
)

# Emotion labels; a name's position in the list is the class ID it is looked
# up by when counts are reported (class_names[class_id])
class_names = ["Anger", "Contempt", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

In [4]:
# Make sure the output tree exists before anything is written into it
for subfolder in ("images", "labels"):
    os.makedirs(os.path.join(augmented_folder, subfolder), exist_ok=True)

In [5]:
def count_classes():
    """Tally label files in the train split by class ID.

    Only the first line of each ``labels/*.txt`` file is inspected; its
    leading token is parsed as the integer class ID. Files whose first line
    is blank are ignored, and files that cannot be read or parsed are
    reported on stdout and skipped.

    Returns:
        collections.Counter mapping class ID -> number of label files.
    """
    tally = Counter()
    pattern = os.path.join(train_folder, "labels", "*.txt")

    for path in glob.glob(pattern):
        try:
            with open(path, 'r') as handle:
                header = handle.readline().strip()
            if not header:
                continue
            # Leading token of the first line is the class ID
            tally[int(header.split()[0])] += 1
        except Exception as e:
            print(f"Error reading {path}: {e}")

    return tally

In [6]:
def apply_augmentation(image, augmentation_type):
    """Return ``image`` transformed according to ``augmentation_type``.

    Codes:
        0 - horizontal mirror
        1 - brightness scaled by a random factor in [0.8, 1.2]
        2 - contrast scaled by a random factor in [0.8, 1.2]
        3 - rotation by a random angle in [-15, 15] (canvas size kept)
        4 - mirror followed by brightness
        5 - rotation followed by contrast

    Any other code hands the input image back unchanged.
    """
    def flip(img):
        return ImageOps.mirror(img)

    def brighten(img):
        return ImageEnhance.Brightness(img).enhance(random.uniform(0.8, 1.2))

    def recontrast(img):
        return ImageEnhance.Contrast(img).enhance(random.uniform(0.8, 1.2))

    def tilt(img):
        return img.rotate(random.uniform(-15, 15), expand=False)

    # Each code maps to the ordered steps applied to the image
    recipes = (
        (0, (flip,)),
        (1, (brighten,)),
        (2, (recontrast,)),
        (3, (tilt,)),
        (4, (flip, brighten)),
        (5, (tilt, recontrast)),
    )

    for code, steps in recipes:
        if augmentation_type == code:
            result = image
            for step in steps:
                result = step(result)
            return result

    return image


In [7]:
# Augmentation codes understood by apply_augmentation() that mirror the image
# left-to-right (0 = flip, 4 = flip + brightness). Labels written for these
# must be mirrored as well.
_MIRRORING_AUG_TYPES = (0, 4)

# Image extensions probed for each label file, in priority order.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _read_class_id(label_path):
    """Return the class ID on the first line of a label file, or None if blank/unreadable."""
    try:
        with open(label_path, 'r') as f:
            first_line = f.readline().strip()
        return int(first_line.split()[0]) if first_line else None
    except Exception as e:
        print(f"Error reading {label_path}: {e}")
        return None


def _class_label(class_id):
    """Return a printable name for a class ID, tolerating IDs outside class_names."""
    if 0 <= class_id < len(class_names):
        return class_names[class_id]
    return f"class {class_id}"


def _find_source_image(label_file):
    """Return the train image path matching a label file name, or None if there is none."""
    stem = os.path.splitext(label_file)[0]
    for extension in _IMAGE_EXTENSIONS:
        candidate = os.path.join(train_folder, "images", stem + extension)
        if os.path.exists(candidate):
            return candidate
    return None


def _mirror_label_row(row):
    """Mirror the x-centre of one 'class x y w h' row; return other rows unchanged."""
    # Assumes normalised YOLO detection rows (class x_center y_center width
    # height). Rows with a different field count, a non-numeric x or an x
    # outside [0, 1] are left alone rather than guessed at.
    fields = row.split()
    if len(fields) != 5:
        return row
    try:
        x_center = float(fields[1])
    except ValueError:
        return row
    if not 0.0 <= x_center <= 1.0:
        return row
    fields[1] = f"{1.0 - x_center:.6f}"
    return " ".join(fields)


def _write_augmented_label(source_path, target_path, mirrored):
    """Write the label for an augmented image, mirroring box rows when the image was flipped."""
    if not mirrored:
        shutil.copy2(source_path, target_path)
        return
    with open(source_path, 'r') as f:
        rows = f.read().splitlines()
    with open(target_path, 'w') as f:
        for row in rows:
            f.write(_mirror_label_row(row) + "\n")


def augment_dataset(seed=None):
    """Build a class-balanced copy of the train split in ``augmented_folder``.

    Every file of ``train_folder`` is copied over, then each class with fewer
    label files than the largest class is topped up with augmented copies of
    randomly chosen images of that class. Afterwards the resulting class
    distribution is printed, a before/after bar chart is saved as
    ``class_distribution.png`` under ``dataset_root`` and ``update_data_yaml``
    is called.

    Labels whose first line is blank or unparsable are ignored for counting.
    Label files without a matching image, and images that fail to load or
    save, are reported and dropped from the pool, so a class may end up short
    of the target instead of the loop running forever. Nothing in the output
    folder is cleared first, so files left by an earlier run remain and are
    included in the final counts.

    Args:
        seed: optional value passed to ``random.seed`` so that source
            selection and augmentation draws are repeatable. ``None`` (the
            default) leaves the random generator untouched.
    """
    if seed is not None:
        random.seed(seed)

    # Do not rely on another cell having created the output tree
    for file_type in ("images", "labels"):
        os.makedirs(os.path.join(augmented_folder, file_type), exist_ok=True)

    # One pass over the labels gives both the per-file class and the counts.
    # Sorted so that a seeded run picks the same sources on every machine.
    label_files = sorted(glob.glob(os.path.join(train_folder, "labels", "*.txt")))
    file_class_map = {}
    for label_path in label_files:
        class_id = _read_class_id(label_path)
        if class_id is not None:
            file_class_map[os.path.basename(label_path)] = class_id
    class_counts = Counter(file_class_map.values())

    print("Original class distribution:")
    for class_id, count in sorted(class_counts.items()):
        print(f"  {_class_label(class_id)}: {count} images")

    if not class_counts:
        print(f"\nNo labelled images found in {train_folder}; nothing to augment.")
        return

    # Find the target count (the highest class count)
    target_count = max(class_counts.values())
    print(f"\nTarget count per class: {target_count}")

    # Copy all original files to the augmented folder first
    print("\nCopying original files...")
    for file_type in ("images", "labels"):
        for file_path in glob.glob(os.path.join(train_folder, file_type, "*.*")):
            if not os.path.isfile(file_path):
                continue  # "*.*" can also match directories with a dot in the name
            shutil.copy2(
                file_path,
                os.path.join(augmented_folder, file_type, os.path.basename(file_path)),
            )

    # Augment underrepresented classes
    print("\nAugmenting underrepresented classes...")
    for class_id in range(len(class_names)):
        if class_id not in class_counts or class_counts[class_id] >= target_count:
            continue

        needed = target_count - class_counts[class_id]
        print(f"  {class_names[class_id]}: need to generate {needed} more images")

        # Resolve each label's image once; labels without one can never be augmented
        sources = []
        for label_file, file_class in file_class_map.items():
            if file_class != class_id:
                continue
            image_path = _find_source_image(label_file)
            if image_path is None:
                print(f"    Warning: Could not find image for {label_file}")
            else:
                sources.append((label_file, image_path))

        augmented_count = 0
        while augmented_count < needed and sources:
            source = random.choice(sources)
            label_file, image_path = source
            stem, extension = os.path.splitext(os.path.basename(image_path))
            aug_suffix = f"_aug_{augmented_count}"

            try:
                # Context manager releases the file handle as soon as the copy is saved
                with Image.open(image_path) as image:
                    aug_type = random.randint(0, 5)  # 6 different augmentation types
                    augmented_image = apply_augmentation(image, aug_type)
                    augmented_image.save(
                        os.path.join(augmented_folder, "images", stem + aug_suffix + extension)
                    )
                # NOTE(review): rotations (types 3 and 5) still reuse the
                # original box; the angle is drawn inside apply_augmentation
                # and is not available here, so the box is only approximate.
                _write_augmented_label(
                    os.path.join(train_folder, "labels", label_file),
                    os.path.join(augmented_folder, "labels", stem + aug_suffix + ".txt"),
                    mirrored=aug_type in _MIRRORING_AUG_TYPES,
                )
            except Exception as e:
                print(f"    Error processing {image_path}: {e}")
                # Drop the failing source so the loop is guaranteed to end
                sources.remove(source)
                continue

            augmented_count += 1
            if augmented_count % 100 == 0:
                print(f"    Generated {augmented_count}/{needed} augmented images for {class_names[class_id]}")

        if augmented_count < needed:
            print(
                f"    Warning: only generated {augmented_count}/{needed} augmented images "
                f"for {class_names[class_id]} (no usable source images left)"
            )

    # Count the new class distribution
    print("\nVerifying augmented dataset...")
    augmented_class_counts = Counter()
    for label_path in glob.glob(os.path.join(augmented_folder, "labels", "*.txt")):
        class_id = _read_class_id(label_path)
        if class_id is not None:
            augmented_class_counts[class_id] += 1

    print("\nFinal class distribution:")
    for class_id in range(len(class_names)):
        print(f"  {class_names[class_id]}: {augmented_class_counts.get(class_id, 0)} images")

    # Plot the class distribution before and after augmentation
    plt.figure(figsize=(12, 6))
    positions = range(len(class_names))
    panels = (
        ("Original Class Distribution", class_counts),
        ("Augmented Class Distribution", augmented_class_counts),
    )
    for panel_index, (title, counts) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_index)
        plt.bar(positions, [counts.get(i, 0) for i in positions])
        plt.xticks(positions, class_names, rotation=45)
        plt.title(title)
        plt.ylabel("Number of Images")

    plt.tight_layout()
    plt.savefig(os.path.join(dataset_root, "class_distribution.png"))
    plt.close()

    print(f"\nAugmentation complete! Augmented dataset saved to {augmented_folder}")
    print(f"Class distribution visualization saved to {os.path.join(dataset_root, 'class_distribution.png')}")

    # Update data.yaml to include the augmented dataset
    update_data_yaml()

In [8]:
def _augmented_train_value(raw_value):
    """Turn the value of a data.yaml ``train:`` entry into its train_augmented/images sibling.

    Returns None when the entry has no inline value (nothing safe to rewrite).
    """
    value = raw_value.strip()
    if value[:1] in ('"', "'"):
        # Quoted scalar: keep what is between the quotes
        closing = value.find(value[0], 1)
        value = value[1:closing] if closing != -1 else value[1:]
    else:
        # Plain scalar: drop a trailing inline comment
        value = value.split(" #", 1)[0].strip()

    # Forward slashes are safe inside the double-quoted value written back
    value = value.replace("\\", "/").rstrip("/")
    if not value:
        return None

    # Replace the last two components (<split>/images), keeping everything before them
    pieces = value.rsplit("/", 2)
    if len(pieces) == 3:
        return f"{pieces[0]}/train_augmented/images"
    return "train_augmented/images"


def update_data_yaml():
    """Point the ``train:`` entry of data.yaml at the augmented dataset.

    Reads ``data.yaml`` under ``dataset_root`` and rewrites every line whose
    stripped text starts with ``train:`` so that the last two path components
    become ``train_augmented/images``. The rest of the path, the line's
    indentation and all other lines are preserved. Surrounding quotes, inline
    comments, Windows drive letters and backslashes in the existing value are
    handled, and running the function again leaves the file unchanged.

    Prints a note and does nothing when the file is missing or contains no
    rewritable ``train:`` entry.
    """
    yaml_path = os.path.join(dataset_root, "data.yaml")

    if not os.path.exists(yaml_path):
        print(f"\nNo data.yaml found at {yaml_path}; nothing to update")
        return

    with open(yaml_path, 'r') as f:
        lines = f.readlines()

    updated = False
    for i, line in enumerate(lines):
        stripped = line.strip()
        if not stripped.startswith("train:"):
            continue
        # Slice after the key instead of splitting on ":" so that values that
        # contain a colon themselves (e.g. "C:/...") survive intact
        new_path = _augmented_train_value(stripped[len("train:"):])
        if new_path is None:
            continue
        indent = line[:len(line) - len(line.lstrip())]
        lines[i] = f'{indent}train: "{new_path}"\n'
        updated = True

    if not updated:
        print(f"\nNo train entry found in {yaml_path}; file left unchanged")
        return

    # Write the updated yaml file
    with open(yaml_path, 'w') as f:
        f.writelines(lines)

    print(f"\nUpdated {yaml_path} to use the augmented dataset")

In [9]:
# Entry point: run the whole pipeline (copy, augment, verify, plot, update
# data.yaml). When this cell is executed in a notebook kernel __name__ is
# "__main__", so the guard passes and the call runs immediately.
if __name__ == "__main__":
    augment_dataset()

Original class distribution:
  Anger: 2339 images
  Contempt: 1996 images
  Disgust: 2242 images
  Fear: 2021 images
  Happy: 2154 images
  Neutral: 1616 images
  Sad: 1914 images
  Surprise: 2819 images

Target count per class: 2819

Copying original files...

Augmenting underrepresented classes...
  Anger: need to generate 480 more images
    Generated 100/480 augmented images for Anger
    Generated 200/480 augmented images for Anger
    Generated 300/480 augmented images for Anger
    Generated 400/480 augmented images for Anger
  Contempt: need to generate 823 more images
    Generated 100/823 augmented images for Contempt
    Generated 200/823 augmented images for Contempt
    Generated 300/823 augmented images for Contempt
    Generated 400/823 augmented images for Contempt
    Generated 500/823 augmented images for Contempt
    Generated 600/823 augmented images for Contempt
    Generated 700/823 augmented images for Contempt
    Generated 800/823 augmented images for Contempt
