In [None]:
# Here the data from the LFW dataset is augmented using techniques such as flipping, dynamic rotation, color jittering,
# edge enhancement, and age progression with a pretrained generator. Only people with at least 20 images are loaded,
# and the augmentation techniques are applied to people with at most 30 images.

import sqlite3
from sklearn.datasets import fetch_lfw_people
import cv2
import os
import pickle
import imutils
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import traceback
from PIL import Image
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm
from mtcnn.mtcnn import MTCNN

# Fetch the colour LFW faces through scikit-learn (downloading them when they
# are missing), keeping only people with at least 20 pictures.
lfw_dataset = fetch_lfw_people(
    data_home='./LFW/',
    min_faces_per_person=20,
    download_if_missing=True,
    color=True,
)

# Directory below the data home that holds the funneled image files on disk.
raw_images_path = os.path.join('./LFW/', "lfw_home/lfw_funneled")

# Entries found in that directory, plus an (initially empty) list for raw images.
raw_names = os.listdir(raw_images_path)
raw_images = []

# Database holding the images exactly as scikit-learn returned them.
conn_original = sqlite3.connect('lfw_dataset.db')

# Database that receives the originals together with their augmented variants.
conn_augmented = sqlite3.connect('lfw_augmented_dataset.db')

# Database holding the image files read straight from disk.
conn_raw = sqlite3.connect('lfw_raw_dataset.db')

# One cursor per connection; these are used for every statement further down.
cursor_original = conn_original.cursor()
cursor_augmented = conn_augmented.cursor()
cursor_raw = conn_raw.cursor()

# Start every run from a clean slate: each of the three databases gets a fresh
# 'faces' table. The schema is identical everywhere except for the name of the
# BLOB column, which is 'raw_image' in the raw database and 'image' in the
# original and augmented databases.
for faces_cursor, blob_column in (
    (cursor_original, 'image'),
    (cursor_augmented, 'image'),
    (cursor_raw, 'raw_image'),
):
    # Remove whatever a previous run left behind ...
    faces_cursor.execute('DROP TABLE IF EXISTS faces')

    # ... and recreate the table. 'id' is assigned by SQLite; 'target' is the
    # numeric class label and 'name' the person's name.
    faces_cursor.execute(f'''
    CREATE TABLE IF NOT EXISTS faces (
        id INTEGER PRIMARY KEY,
        target INTEGER,
        name TEXT NOT NULL,
        {blob_column} BLOB NOT NULL
    )
''')


# Running totals that are reported once both import passes have finished.
total_raw_images = 0
total_images = 0
image_index = -1

insert_original_sql = "INSERT INTO faces (target, name, image) VALUES (?, ?, ?)"

# Store every scikit-learn image in the original database together with its
# numeric label and the matching person name. The 'id' column is left for
# SQLite to fill in.
for image_index, images in enumerate(lfw_dataset.images):
    # Label of this image and the person it stands for.
    target_index = lfw_dataset.target[image_index]
    name = lfw_dataset.target_names[target_index]

    # The array is stored as a pickled BLOB.
    image_bytes = pickle.dumps(images)

    cursor_original.execute(insert_original_sql, (int(target_index), name, image_bytes))
    total_images += 1

# Persist the rows written to the original database.
conn_original.commit()

# Copy the image files of every selected person from disk into the raw database.
for target_index, name in enumerate(lfw_dataset.target_names):
    # On disk the person's directory uses underscores where the name has spaces.
    raw_image_directory_name = name.replace(" ", "_")
    raw_image_path = os.path.join(raw_images_path, raw_image_directory_name)

    # Nothing to import when the person has no directory of their own.
    if not os.path.isdir(raw_image_path):
        print(f"No raw image directory found for {name}")
        continue

    # Sorted so that the insertion order does not depend on the file system.
    for raw_image_name in sorted(os.listdir(raw_image_path)):
        # Files whose name contains "original" are deliberately left out.
        if "original" in raw_image_name:
            continue

        raw_image_file = os.path.join(raw_image_path, raw_image_name)
        raw_image = cv2.imread(raw_image_file)

        # cv2.imread reports an unreadable or non-image file by returning None
        # instead of raising; storing that None would break later processing.
        if raw_image is None:
            print(f"Skipping unreadable raw image: {raw_image_file}")
            continue

        # The array is stored as a pickled BLOB.
        raw_image_bytes = pickle.dumps(raw_image, protocol=pickle.HIGHEST_PROTOCOL)

        # 'target' is the person's class index; 'id' is assigned by SQLite.
        cursor_raw.execute("INSERT INTO faces (target, name, raw_image) VALUES (?, ?, ?)",
                           (int(target_index), name, raw_image_bytes))
        total_raw_images += 1

# Persist the rows written to the raw database.
conn_raw.commit()

# Report how many images each import pass stored.
print(f"Total number of images: {total_images}")
print(f"Total number of raw images: {total_raw_images}")

# Every distinct person name stored in the original database.
unique_names = [
    person_name
    for (person_name,) in cursor_original.execute('SELECT DISTINCT name FROM faces')
]

# All (target, name, image) rows of the original database.
rows_original = cursor_original.execute('SELECT target, name, image FROM faces').fetchall()

def enhance_edges(image):
    """Sharpen an image by subtracting a scaled Laplacian of its grayscale version.

    Args:
        image: Either a three-channel array of shape (H, W, 3) in RGB order or
            a two-dimensional grayscale array. The caller in this file passes
            uint8 data; other dtypes accepted by OpenCV are presumably fine
            too but are not exercised here.

    Returns:
        A uint8 array with the same shape and the same channel order as the
        input.

    Raises:
        ValueError: If the array is neither two-dimensional nor (H, W, 3).
    """
    is_color = image.ndim == 3 and image.shape[-1] == 3
    if not is_color and image.ndim != 2:
        raise ValueError(
            f"enhance_edges expects an (H, W, 3) or (H, W) image, got shape {image.shape}"
        )

    # The Laplacian is computed on a single intensity channel. The image itself
    # is left in its original channel order so the result can be stored next to
    # the other RGB images without red and blue being swapped.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if is_color else image

    # 64-bit floats keep the negative responses of the filter.
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)

    if is_color:
        # The same edge response is subtracted from each of the three channels.
        sharpened = image - 0.7 * laplacian[:, :, np.newaxis]
    else:
        sharpened = gray - 0.7 * laplacian

    return np.clip(sharpened, 0, 255).astype('uint8')

class ResidualBlock(nn.Module):
    """Two padded 3x3 convolutions with batch norm, added back onto the input.

    The number of channels (``in_features``) is the same going in and out, and
    the reflection padding keeps the spatial size unchanged, so the block's
    output can be summed with its input.
    """

    def __init__(self, in_features):
        super().__init__()

        channels = in_features
        # First stage ends in a ReLU, the second one does not.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image network: stem, two downsampling steps, residual blocks,
    two upsampling steps and a Tanh output layer.

    Args:
        ngf: Number of feature maps produced by the first convolution. The
            width doubles at each downsampling step and halves again at each
            upsampling step.
        n_residual_blocks: How many ``ResidualBlock`` modules sit between the
            downsampling and the upsampling part.

    The network takes and returns three-channel images; the final Tanh bounds
    the output to [-1, 1].
    """

    def __init__(self, ngf, n_residual_blocks=9):
        super().__init__()

        # Stem: 7x7 convolution on the reflection-padded input.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, 7),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
        ]

        # Downsampling: two stride-2 convolutions, each doubling the width.
        channels = ngf
        for _ in range(2):
            wider = channels * 2
            layers.extend([
                nn.Conv2d(channels, wider, 3, stride=2, padding=1),
                nn.BatchNorm2d(wider),
                nn.ReLU(),
            ])
            channels = wider

        # Residual blocks at the lowest resolution.
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: two stride-2 transposed convolutions, each halving the width.
        for _ in range(2):
            narrower = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrower, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(narrower),
                nn.ReLU(),
            ])
            channels = narrower

        # Output layer: back to three channels.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole layer stack."""
        return self.model(x)

# Checkpoint used when the caller does not name one.
_DEFAULT_AGING_CHECKPOINT = '/Users/sadhanaanandan/Fast-AgingGAN/pretrained_model/state_dict.pth'

# Generators that have already been loaded, keyed by checkpoint path, so the
# weights are read from disk once per run instead of once per image.
_AGING_MODEL_CACHE = {}


def _load_aging_generator(checkpoint_path):
    """Return the eval-mode generator for *checkpoint_path*, loading it on first use."""
    model = _AGING_MODEL_CACHE.get(checkpoint_path)
    if model is None:
        model = Generator(ngf=32, n_residual_blocks=9)
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        _AGING_MODEL_CACHE[checkpoint_path] = model
    return model


def apply_pretrained_model(image, checkpoint_path=_DEFAULT_AGING_CHECKPOINT):
    """Run one image through the pretrained age-progression generator.

    Args:
        image: Array that is multiplied by 255 and cast to uint8 before being
            handed to PIL, so it is expected to hold values in [0, 1]
            (presumably an (H, W, 3) RGB image -- confirm against the caller).
        checkpoint_path: Location of the generator's state dict. Defaults to
            the path this script has always used.

    Returns:
        The generator output as a uint8 array in (height, width, channel)
        layout. The input is resized to 512x512 before the forward pass.
    """
    model = _load_aging_generator(checkpoint_path)

    # Resize, convert to a tensor and map [0, 1] to [-1, 1].
    trans = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    # Add the batch dimension the generator expects.
    image = trans(Image.fromarray((image * 255).astype('uint8')).convert('RGB')).unsqueeze(0)

    # Inference only; no gradients are needed.
    with torch.no_grad():
        result = model(image)

    # Drop the batch dimension and move channels last.
    result = result.squeeze(0).cpu().numpy().transpose((1, 2, 0))

    # Undo the normalization: the Tanh output in [-1, 1] becomes [0, 255].
    result = ((result + 1) / 2.0 * 255).clip(0, 255).astype('uint8')

    return result

def convert_to_rgb(image_data):
    """Return *image_data* converted from BGR to RGB channel order."""
    return cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

def save_image_to_database(cursor, target, name, image_bytes):
    """Insert one row into the ``faces`` table through *cursor*.

    ``target`` is converted with ``int()`` before it is bound; ``name`` and
    ``image_bytes`` are bound as given. Nothing is committed here.
    """
    record = (int(target), name, image_bytes)
    cursor.execute("INSERT INTO faces (target, name, image) VALUES (?, ?, ?)", record)

# MTCNN face detector; used further down to locate the face in each aged image.
detector = MTCNN()

# How many images of each kind have been written to the augmented database.
original_images_count = 0
flipped_images_count = 0
rotated_images_count = 0
jittered_images_count = 0
enhanced_images_count = 0
aged_images_count = 0

# People with at most this many images receive the full set of augmentations.
AUGMENT_MAX_IMAGES = 30
# People above that limit only have their originals copied, capped at this many.
COPY_ONLY_MAX_IMAGES = 50


def _color_jitter(image_rgb):
    """Return a randomly colour-jittered version of an RGB uint8 image.

    Brightness, saturation and hue are perturbed in OpenCV's 8-bit HSV space
    (hue in [0, 180), saturation and value in [0, 255]). Contrast is adjusted
    afterwards in RGB space by scaling each pixel's distance from the mean
    intensity, because HSV has no contrast channel.
    """
    hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)

    # Brightness: scale the value channel (cv2.multiply saturates at 255).
    brightness_factor = random.uniform(0.5, 1.5)
    hsv[..., 2] = cv2.multiply(hsv[..., 2], brightness_factor)

    # Saturation: scale the saturation channel.
    saturation_factor = random.uniform(0.5, 1.5)
    hsv[..., 1] = cv2.multiply(hsv[..., 1], saturation_factor)

    # Hue: shift by up to +/-10 and wrap around the 0..179 range.
    hue_factor = random.uniform(-10, 10)
    hsv[..., 0] = (hsv[..., 0] + hue_factor) % 180

    jittered_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    # Contrast: stretch or compress pixel values around the image mean.
    contrast_factor = random.uniform(0.5, 1.5)
    mean_intensity = jittered_rgb.mean()
    contrasted = (jittered_rgb.astype(np.float32) - mean_intensity) * contrast_factor + mean_intensity
    return np.clip(contrasted, 0, 255).astype('uint8')


# Build the augmented database person by person.
for name in tqdm(unique_names, desc='Processing images'):
    # All images of the current person from the original database.
    cursor_original.execute('SELECT target, name, image FROM faces WHERE name = ?', (name,))
    person_images_original = cursor_original.fetchall()

    # Sparsely represented people are augmented; for everyone else only the
    # originals are copied, and at most COPY_ONLY_MAX_IMAGES of them.
    augment_person = len(person_images_original) <= AUGMENT_MAX_IMAGES
    if not augment_person:
        person_images_original = person_images_original[:COPY_ONLY_MAX_IMAGES]

    for target, _, original_image_bytes in person_images_original:
        # Rows written by this script hold pickled NumPy arrays.
        original_image = pickle.loads(original_image_bytes)

        # Assumes the stored arrays are floats in [0, 1]; scale them to 8-bit
        # so the OpenCV routines below can work on them.
        original_image_rgb = (original_image * 255).astype('uint8')
        save_image_to_database(cursor_augmented, target, name, pickle.dumps(original_image_rgb))
        original_images_count += 1

        if not augment_person:
            continue

        # Horizontal mirror image.
        flipped_image = cv2.flip(original_image_rgb, 1)
        save_image_to_database(cursor_augmented, target, name, pickle.dumps(flipped_image))
        flipped_images_count += 1

        # Rotation by a random angle between -30 and 30 degrees.
        rotation_angle = random.uniform(-30, 30)
        rotated_image = imutils.rotate(original_image_rgb, angle=rotation_angle)
        save_image_to_database(cursor_augmented, target, name, pickle.dumps(rotated_image))
        rotated_images_count += 1

        # Random brightness / saturation / hue / contrast changes.
        color_jittered_image_rgb = _color_jitter(original_image_rgb)
        save_image_to_database(cursor_augmented, target, name, pickle.dumps(color_jittered_image_rgb))
        jittered_images_count += 1

        # Laplacian-based edge enhancement.
        enhanced_image = enhance_edges(original_image_rgb)
        save_image_to_database(cursor_augmented, target, name, pickle.dumps(enhanced_image))
        enhanced_images_count += 1

# Aged variants are generated only for people with at most this many raw images.
max_raw_images_for_aging = 30
# A detected face must be wider and taller than this (in pixels) to be kept.
min_face_size = 100

# Age every raw image of the sparsely represented people and store the cropped
# face in the augmented database.
for name in tqdm(unique_names, desc='Processing raw images'):
    # All raw images of the current person from the raw database.
    cursor_raw.execute('SELECT target, name, raw_image FROM faces WHERE name = ?', (name,))
    person_raw_images = cursor_raw.fetchall()

    if len(person_raw_images) > max_raw_images_for_aging:
        continue

    for target, _, raw_image_bytes in person_raw_images:
        # Rows written by this script hold pickled NumPy arrays.
        raw_image = pickle.loads(raw_image_bytes)

        # A file that could not be decoded on import is stored as None.
        if raw_image is None:
            print(f"Skipping unreadable raw image for {name}")
            continue

        # The raw images were read with cv2.imread (BGR order); swap the
        # channels so the generator receives RGB.
        raw_image_rgb = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)

        # The aging helper expects pixel values in [0, 1].
        raw_image_normal = raw_image_rgb / 255.0
        aged_face = apply_pretrained_model(raw_image_normal)

        # Locate faces in the aged image and drop the ones that are too small.
        faces = detector.detect_faces(aged_face)
        filtered_faces = [
            face for face in faces
            if face['box'][2] > min_face_size and face['box'][3] > min_face_size
        ]
        if not filtered_faces:
            print("No faces detected in the aged image.")
            continue

        # Keep the detection with the largest area.
        main_face = max(filtered_faces, key=lambda face: face['box'][2] * face['box'][3])
        x, y, w, h = main_face['box']

        # A box may start outside the image (negative x or y). A negative slice
        # start would count from the far edge and yield a wrong or empty crop,
        # so the top-left corner is clamped to the image.
        left, top = max(0, x), max(0, y)
        cropped_face = aged_face[top:y + h, left:x + w]

        if cropped_face.size == 0:
            print("Invalid face coordinates or empty face region.")
            continue

        # Store the cropped aged face next to the other augmented images.
        aged_face_bytes = pickle.dumps(cropped_face)
        save_image_to_database(cursor_augmented, target, name, aged_face_bytes)
        aged_images_count += 1


# Persist everything that was written to the augmented database.
conn_augmented.commit()

# Read the augmented rows back; their number is the size of the new data set.
rows_augmented = cursor_augmented.execute('SELECT target, name, image FROM faces').fetchall()
total_images_augmented = len(rows_augmented)

# Summary: overall size first, then one line per kind of image.
for summary_line in (
    f'Total number of images in the augmented dataset: {total_images_augmented}',
    f'Total number of original images in the augmented dataset: {original_images_count}',
    f'Total number of flipped images in the augmented dataset: {flipped_images_count}',
    f'Total number of rotated images in the augmented dataset: {rotated_images_count}',
    f'Total number of color jittered images in the augmented dataset: {jittered_images_count}',
    f'Total number of edge enhanced images in the augmented dataset: {enhanced_images_count}',
    f'Total number of aged images in the augmented dataset: {aged_images_count}',
):
    print(summary_line)

# Release all three database connections.
for open_connection in (conn_original, conn_augmented, conn_raw):
    open_connection.close()