In [None]:
# Colab-only setup: mount Google Drive at /content/drive.
# NOTE(review): presumably the input/output folders below are meant to live on
# the mounted drive — confirm once the placeholder paths are filled in.
from google.colab import drive
drive.mount('/content/drive')

import os
import cv2
import random
import numpy as np
from PIL import Image, ImageEnhance
from tqdm import tqdm
from pathlib import Path
import albumentations as A
from imblearn.over_sampling import SMOTE
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from scipy import ndimage
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import re

# Source folder scanned for images and destination folder for augmented copies.
# NOTE(review): both are the same placeholder string "path" — replace with the
# real locations and confirm whether they are meant to be different folders.
input_dir = Path("path")
output_dir = Path("path")
# Create the destination (and any missing parents); no error if it already exists.
output_dir.mkdir(parents=True, exist_ok=True)

def flip_image(image):
    """Return ``image`` mirrored around its vertical axis (left/right swap)."""
    # OpenCV flip code 1 selects a horizontal flip.
    horizontal_flip_code = 1
    return cv2.flip(image, horizontal_flip_code)

def rotate_image(image):
    """Rotate ``image`` about its centre by a random whole-degree angle.

    The angle is drawn uniformly from the integers in [-45, 45]. The output
    keeps the input's width and height, with no rescaling applied.
    """
    degrees = random.randint(-45, 45)
    height, width = image.shape[:2]
    centre = (width / 2, height / 2)
    rotation = cv2.getRotationMatrix2D(centre, degrees, 1)
    return cv2.warpAffine(image, rotation, (width, height))

def scale_image(image):
    """Resize ``image`` by one random factor in [0.7, 1.3] applied to both axes.

    The output size differs from the input size; callers that need a fixed
    size must resize afterwards.
    """
    factor = random.uniform(0.7, 1.3)
    # dsize is None so the output size is derived from fx / fy.
    return cv2.resize(image, None, fx=factor, fy=factor)

def translate_image(image):
    """Shift ``image`` by random integer offsets in [-20, 20] px along x and y.

    The canvas size is unchanged.
    """
    dx = random.randint(-20, 20)
    dy = random.randint(-20, 20)
    height, width = image.shape[:2]
    # 2x3 affine matrix: identity plus a translation column.
    shift = np.float32([[1, 0, dx], [0, 1, dy]])
    return cv2.warpAffine(image, shift, (width, height))

def shear_image(image):
    """Apply a horizontal shear with a random factor in [-0.2, 0.2].

    The output keeps the input's width and height.
    """
    factor = random.uniform(-0.2, 0.2)
    height, width = image.shape[:2]
    # x' = x + factor * y, y' = y
    shear = np.float32([[1, factor, 0], [0, 1, 0]])
    return cv2.warpAffine(image, shear, (width, height))

def crop_image(image):
    """Trim a random margin of at most one tenth of each dimension per side.

    Returns the slice ``image[top:bottom, left:right]``; the result is
    therefore usually smaller than the input.
    """
    height, width = image.shape[:2]
    max_dx = width // 10
    max_dy = height // 10
    # Draw order (left, top, right, bottom) is kept so seeded runs reproduce.
    left = random.randint(0, max_dx)
    top = random.randint(0, max_dy)
    right = width - random.randint(0, max_dx)
    bottom = height - random.randint(0, max_dy)
    return image[top:bottom, left:right]

def add_noise(image):
    """Add zero-mean Gaussian noise (sigma = 25) to an 8-bit image.

    The noise is added in floating point and the sum is rounded and clipped
    to [0, 255] before converting back to ``uint8``. Casting the noise itself
    to ``uint8`` first would wrap every negative sample to a large positive
    value, so the image could only ever get brighter.

    Args:
        image: array of pixel values in the 0-255 range.

    Returns:
        A new ``uint8`` array of the same shape as ``image``.
    """
    noise = np.random.normal(0, 25, image.shape)
    noisy = image.astype(np.float32) + noise
    return np.clip(np.rint(noisy), 0, 255).astype(np.uint8)

def adjust_brightness(image):
    """Scale brightness by a random factor in [0.7, 1.3] using PIL.

    Returns the enhanced image converted back to a NumPy array.
    """
    pil_image = Image.fromarray(image)
    enhancer = ImageEnhance.Brightness(pil_image)
    factor = random.uniform(0.7, 1.3)
    brightened = enhancer.enhance(factor)
    return np.array(brightened)

def adjust_contrast(image):
    """Scale contrast by a random factor in [0.7, 1.3] using PIL.

    Returns the enhanced image converted back to a NumPy array.
    """
    pil_image = Image.fromarray(image)
    enhancer = ImageEnhance.Contrast(pil_image)
    factor = random.uniform(0.7, 1.3)
    adjusted = enhancer.enhance(factor)
    return np.array(adjusted)

def adjust_hue(image):
    """Shift the hue of an RGB image by a random integer in [-10, 10].

    The hue channel from ``cv2.COLOR_RGB2HSV`` is wrapped modulo 180, as in
    the original expression. The shift is applied in ``int16``: adding a
    negative Python int directly to a ``uint8`` array raises ``OverflowError``
    under NumPy 2 promotion rules, which would silently disable this
    augmentation for every negative shift when the caller swallows exceptions.

    Args:
        image: RGB image array accepted by ``cv2.cvtColor``.

    Returns:
        The hue-shifted image converted back to RGB.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    shift = random.randint(-10, 10)
    hue = hsv[:, :, 0].astype(np.int16)
    # Python-style modulo keeps negative results inside [0, 180).
    hsv[:, :, 0] = ((hue + shift) % 180).astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

def adjust_saturation(image):
    """Scale colour saturation by a random factor in [0.7, 1.3] using PIL.

    Returns the enhanced image converted back to a NumPy array.
    """
    pil_image = Image.fromarray(image)
    enhancer = ImageEnhance.Color(pil_image)
    factor = random.uniform(0.7, 1.3)
    recoloured = enhancer.enhance(factor)
    return np.array(recoloured)

def blur_image(image):
    """Gaussian-blur ``image`` with a square kernel of side 3 or 5, chosen at random.

    Sigma is passed as 0, so OpenCV derives it from the kernel size.
    """
    side = random.choice([3, 5])
    kernel_shape = (side, side)
    return cv2.GaussianBlur(image, kernel_shape, 0)

def sharpen_image(image):
    """Sharpen ``image`` by convolving it with a fixed 3x3 kernel.

    The kernel is -1 everywhere with 9 at the centre, so its weights sum to 1.
    The output depth matches the input (``ddepth=-1``).
    """
    sharpening_kernel = np.full((3, 3), -1)
    sharpening_kernel[1, 1] = 9
    return cv2.filter2D(image, -1, sharpening_kernel)

def cutout(image, mask_size=50):
    """Zero out one square patch centred on a random pixel.

    The patch extends ``mask_size // 2`` pixels either side of the centre and
    is clipped to the image bounds, so patches near an edge are smaller.

    The input is no longer modified: the patch is written into a copy, so a
    caller that keeps a reference to the original array still sees clean data.

    Args:
        image: NumPy image array whose first two axes are height and width.
        mask_size: nominal side length of the erased square, in pixels.

    Returns:
        A new array equal to ``image`` except for the zeroed patch.
    """
    result = image.copy()
    height, width = result.shape[:2]
    # Row is drawn before column to keep the RNG order of the original.
    centre_y, centre_x = np.random.randint(height), np.random.randint(width)
    half = mask_size // 2
    top = np.clip(centre_y - half, 0, height)
    bottom = np.clip(centre_y + half, 0, height)
    left = np.clip(centre_x - half, 0, width)
    right = np.clip(centre_x + half, 0, width)
    result[top:bottom, left:right] = 0
    return result

def gridmask(image, grid_size=100, mask_ratio=0.5):
    """Randomly zero the top-left quadrant of cells in a regular grid.

    The image is tiled into ``grid_size`` x ``grid_size`` cells. For each
    cell, with probability ``mask_ratio``, the top-left
    ``grid_size // 2`` x ``grid_size // 2`` region is set to 0.

    The input is no longer modified: masking is applied to a copy, so a
    caller that keeps a reference to the original array still sees clean data.

    Args:
        image: NumPy image array whose first two axes are height and width.
        grid_size: cell side length in pixels (must be positive).
        mask_ratio: per-cell probability of being masked.

    Returns:
        A new array equal to ``image`` except for the zeroed regions.
    """
    result = image.copy()
    height, width = result.shape[:2]
    half = grid_size // 2
    for top in range(0, height, grid_size):
        for left in range(0, width, grid_size):
            if random.random() < mask_ratio:
                result[top:top + half, left:left + half] = 0
    return result

def mixup(image1, image2):
    """Blend two images with a Beta(0.4, 0.4)-distributed mixing weight.

    ``image1`` is first resized to the width and height of ``image2``. The
    result is ``lam * image1 + (1 - lam) * image2`` cast to ``uint8``, so it
    has the size of ``image2``.
    """
    lam = np.random.beta(0.4, 0.4)
    target_height, target_width = image2.shape[:2]
    resized_first = cv2.resize(image1, (target_width, target_height))
    blended = lam * resized_first + (1 - lam) * image2
    return blended.astype(np.uint8)

def albumentations_aug(image):
    """Run ``image`` through a fixed albumentations pipeline and return the result.

    The pipeline is rebuilt on every call. Each transform fires independently
    with the probability given by its ``p`` argument.
    """
    steps = [
        # Geometric transforms.
        A.Rotate(limit=45, p=0.7),
        A.Affine(shear=(-25, 25), p=0.5),
        A.RandomScale(scale_limit=(-0.5, 1.0), p=0.5),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.2,
            scale_limit=0.2,
            rotate_limit=45,
            p=0.7
        ),
        # Photometric transforms.
        A.Blur(blur_limit=3, p=0.3),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(p=0.3),
        A.GaussNoise(p=0.3),
    ]
    pipeline = A.Compose(steps)
    return pipeline(image=image)['image']

def keras_aug(image):
    """Draw one randomly transformed copy of ``image`` from a Keras ImageDataGenerator.

    The generator is configured with rotation, shifts, shear, zoom and
    horizontal flip, filling exposed borders by reflection. The sample is cast
    to ``uint8`` before being returned.
    """
    generator = ImageDataGenerator(
        rotation_range=45,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.3,
        horizontal_flip=True,
        fill_mode='reflect'
    )
    # flow() expects a batch, so add a leading batch axis of length 1.
    single_batch = np.expand_dims(image, axis=0)
    batches = generator.flow(single_batch, batch_size=1)
    first_batch = next(batches)
    return first_batch[0].astype(np.uint8)

def geometric_augmentations(image):
    """Apply a random rotation, translation, zoom and optional flip in sequence.

    Steps, in order:
      1. rotate by a uniform angle in [-45, 45) degrees (SciPy, reflect mode,
         shape preserved);
      2. translate by integer offsets drawn from ``np.random.randint(-20, 20)``;
      3. zoom about the centre by a uniform factor in [0.7, 1.3);
      4. flip horizontally when a uniform draw exceeds 0.5.
    The canvas size stays the same throughout.
    """
    # 1. Rotation.
    rotation_deg = np.random.uniform(-45, 45)
    result = ndimage.rotate(image, rotation_deg, reshape=False, mode='reflect')

    # 2. Translation (canvas size is fixed from here on).
    shift_x, shift_y = np.random.randint(-20, 20, 2)
    height, width = result.shape[:2]
    canvas = (width, height)
    shift = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
    result = cv2.warpAffine(result, shift, canvas)

    # 3. Zoom: a rotation matrix with angle 0 is a pure scale about the centre.
    zoom_factor = np.random.uniform(0.7, 1.3)
    zoom = cv2.getRotationMatrix2D((width // 2, height // 2), 0, zoom_factor)
    result = cv2.warpAffine(result, zoom, canvas)

    # 4. Optional horizontal flip.
    if np.random.random() > 0.5:
        return cv2.flip(result, 1)
    return result

def photometric_augmentations(image):
    """Randomly adjust brightness and contrast, then optionally blur.

    Steps, in order:
      1. multiply pixel values by a brightness factor in [0.7, 1.3);
      2. stretch the result about its mean by a contrast factor in [0.7, 1.5);
      3. with probability 0.5, Gaussian-blur with a 3x3 or 5x5 kernel.

    Previously the brightness result was computed and then discarded, and the
    "contrast" step was a second plain multiplication of the original image.
    The two steps are now chained, and contrast is applied about the mean
    intensity so it no longer acts as a brightness change.

    Args:
        image: array of pixel values in the 0-255 range.

    Returns:
        A new ``uint8`` array of the same shape as ``image``.
    """
    # Both factors are drawn up front, in the same RNG order as before.
    brightness = np.random.uniform(0.7, 1.3)
    contrast = np.random.uniform(0.7, 1.5)

    adjusted = np.clip(image.astype(np.float64) * brightness, 0, 255)
    mean_intensity = adjusted.mean()
    adjusted = (adjusted - mean_intensity) * contrast + mean_intensity
    result = np.clip(adjusted, 0, 255).astype(np.uint8)

    if np.random.random() > 0.5:
        # Plain int: OpenCV expects Python integers in the ksize tuple.
        kernel_size = int(np.random.choice([3, 5]))
        result = cv2.GaussianBlur(result, (kernel_size, kernel_size), 0)
    return result

def fuzzy_augmentation(image):
    """Randomly scale saturation and value of an RGB image in HSV space.

    Each of the S and V channels is multiplied by its own factor in
    [0.7, 1.3). The product is clipped to [0, 255] while still in floating
    point and only then stored back. Storing the unclipped float product into
    the ``uint8`` array first lets values above 255 wrap around, after which
    clipping has no effect.

    Args:
        image: RGB image array accepted by ``cv2.cvtColor``.

    Returns:
        The adjusted image converted back to RGB.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    # Saturation is drawn before value to keep the RNG order of the original.
    saturation_factor = np.random.uniform(0.7, 1.3)
    value_factor = np.random.uniform(0.7, 1.3)
    for channel, factor in ((1, saturation_factor), (2, value_factor)):
        scaled = hsv[:, :, channel].astype(np.float32) * factor
        hsv[:, :, channel] = np.clip(scaled, 0, 255).astype(hsv.dtype)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

def _read_rgb(path):
    """Load ``path`` as an RGB array, or return None if OpenCV cannot decode it."""
    bgr = cv2.imread(str(path))
    if bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# Every entry with a dot in its name is treated as a candidate source image;
# entries OpenCV cannot decode are skipped inside the loop.
image_files = list(input_dir.glob("*.*"))
target_count = 5000
current_count = len(image_files)
# Never negative: nothing to do when the folder already holds enough images.
augmentation_needed = max(0, target_count - current_count)

augmentations = [
    flip_image, rotate_image, scale_image, translate_image, shear_image,
    crop_image, add_noise, adjust_brightness, adjust_contrast, adjust_hue,
    adjust_saturation, blur_image, sharpen_image, cutout, gridmask,
    albumentations_aug, keras_aug, geometric_augmentations,
    photometric_augmentations, fuzzy_augmentation
]

if augmentation_needed and not image_files:
    # random.choice() on an empty list would fail with an unhelpful IndexError.
    raise FileNotFoundError(f"No input files found in {input_dir}")

print(f"Starting augmentation. Need to create {augmentation_needed} additional images.")

created_count = 0
for i in tqdm(range(augmentation_needed)):
    img_path = random.choice(image_files)
    img = _read_rgb(img_path)
    if img is None:
        tqdm.write(f"Skipping unreadable file {img_path}")
        continue

    aug_img = None
    if random.random() < 0.1:
        # 10% of samples are a mixup of two source images.
        img2 = _read_rgb(random.choice(image_files))
        if img2 is not None:
            aug_img = mixup(img, img2)

    if aug_img is None:
        # Otherwise (or when the mixup partner is unreadable) chain the
        # individual augmentations, each applied with probability 0.7.
        aug_img = img.copy()
        for aug_fn in augmentations:
            if random.random() < 0.7:
                try:
                    aug_img = aug_fn(aug_img)
                except Exception as exc:
                    # Best effort: one failing augmentation must not abort the
                    # run, but the failure is reported instead of hidden.
                    tqdm.write(f"Skipping {aug_fn.__name__} for {img_path.name}: {exc}")

    aug_img = cv2.resize(aug_img, (224, 224))
    pil_img = Image.fromarray(aug_img)
    original_filename = img_path.name
    cleaned_filename = re.sub(r'[^\w\s.-]', '', original_filename)
    # The payload is always JPEG, so give the file a matching extension
    # instead of inheriting the source's (e.g. ".png").
    save_path = output_dir / f"aug_{i}_{Path(cleaned_filename).stem}.jpg"

    # Lower the JPEG quality in steps of 5 until the file is at most 200 KiB
    # or the quality floor of 60 is reached.
    quality = 95
    saved = False
    while True:
        try:
            pil_img.save(save_path, "JPEG", quality=quality, optimize=True)
        except (OSError, ValueError) as e:
            print(f"Error saving file {save_path}: {e}")
            break
        saved = True
        if os.path.getsize(save_path) <= 200 * 1024 or quality <= 60:
            break
        quality -= 5

    if saved:
        created_count += 1

print(f"Augmentation complete. Created {created_count} new images.")
print(f"Total images in output directory: {len(list(output_dir.glob('*.*')))}")