# In [None]:
import sqlite3
from sklearn.datasets import fetch_lfw_people
import cv2
import os
import pickle
import imutils
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import dlib

# Fetch the colour LFW faces through scikit-learn (only people with at least
# 100 images), downloading into ./LFW/ when the data is not present yet
lfw_dataset = fetch_lfw_people(
    data_home='./LFW/',
    color=True,
    min_faces_per_person=100,
    download_if_missing=True,
)

# Directory that holds the raw images; it is later joined with a per-person
# sub-directory name
raw_images_path = os.path.join('./LFW/', "lfw_home/lfw_funneled")

# Container for raw images (not filled in by the code visible here)
raw_images = []

# Entries found directly inside the raw image directory
raw_names = os.listdir(raw_images_path)

# Open the SQLite database file and obtain a cursor for all later statements
conn = sqlite3.connect('lfw_dataset.db')
cursor = conn.cursor()

# Rebuild the 'faces' table on every run: drop whatever an earlier run left
# behind, then create the table again with the expected columns
for ddl in (
    'DROP TABLE IF EXISTS faces',
    '''
    CREATE TABLE IF NOT EXISTS faces (
        id INTEGER PRIMARY KEY,
        target INTEGER,
        name TEXT NOT NULL,
        image BLOB NOT NULL,
        raw_image BLOB NOT NULL
    )
''',
):
    cursor.execute(ddl)

# Insert one row per dataset image: target index, person name, the pickled
# dataset image and a pickled raw image taken from that person's directory.
#
# NOTE: every image of a person is paired with the SAME raw file (the first
# file of the person's directory in sorted order); raw files are not matched
# to individual dataset images.
raw_bytes_by_name = {}  # person name -> pickled raw image, or None if unavailable

for image, target_index in zip(lfw_dataset.images, lfw_dataset.target):
    # Get the corresponding name from target_names
    name = lfw_dataset.target_names[target_index]

    # Resolve the raw image once per person instead of once per dataset image
    if name not in raw_bytes_by_name:
        raw_bytes_by_name[name] = None

        # Directory names use underscores where the dataset names use spaces
        raw_image_directory_name = name.replace(" ", "_")
        raw_image_path = os.path.join(raw_images_path, raw_image_directory_name)

        if not os.path.isdir(raw_image_path):
            # Handle the case where the directory does not exist
            print(f"No raw image directory found for {name}")
        else:
            # Sort so that the chosen file does not depend on listing order
            raw_image_names = sorted(os.listdir(raw_image_path))
            if not raw_image_names:
                # Handle the case where no raw images are found
                print(f"No raw images found for {name}")
            else:
                raw_image_file = os.path.join(raw_image_path, raw_image_names[0])
                raw_image = cv2.imread(raw_image_file)
                if raw_image is None:
                    # cv2.imread signals an unreadable file by returning None
                    print(f"Could not read raw image {raw_image_file}")
                else:
                    raw_bytes_by_name[name] = pickle.dumps(
                        raw_image, protocol=pickle.HIGHEST_PROTOCOL
                    )

    raw_image_bytes = raw_bytes_by_name[name]
    if raw_image_bytes is None:
        # The raw_image column is declared NOT NULL, so inserting None would
        # raise sqlite3.IntegrityError; skip rows without a usable raw image.
        continue

    # Convert the dataset image to bytes and insert the record
    image_bytes = pickle.dumps(image)
    cursor.execute("INSERT INTO faces (target, name, image, raw_image) VALUES (?, ?, ?, ?)", (int(target_index), name, image_bytes, raw_image_bytes))

# Persist the rows inserted so far
conn.commit()

# Every distinct person name stored in the 'faces' table
unique_names = [record[0] for record in cursor.execute('SELECT DISTINCT name FROM faces').fetchall()]

# Output root for the augmented images, plus one sub-folder per person
augmented_images_folder = './Augmented_images'
os.makedirs(augmented_images_folder, exist_ok=True)
for name in unique_names:
    person_folder = os.path.join(augmented_images_folder, name)
    os.makedirs(person_folder, exist_ok=True)

# Target, name and pickled image of every stored face
rows = cursor.execute('SELECT target, name, image FROM faces').fetchall()


class ResidualBlock(nn.Module):
    """Residual unit made of two reflection-padded 3x3 conv + batch-norm stages.

    The branch keeps the number of channels (``in_features``) and, thanks to
    the padding of 1 before each 3x3 convolution, the spatial size, so its
    output can be added element-wise to the input.
    """

    def __init__(self, in_features):
        super().__init__()

        # Layer order matters: it fixes the parameter names in the state dict
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.BatchNorm2d(in_features),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.BatchNorm2d(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.conv_block(x)
        return x + branch
    
class Generator(nn.Module):
    """Image-to-image generator: conv stem, 2x downsampling, residual blocks, 2x upsampling.

    Takes a 3-channel image batch and returns a 3-channel batch squashed by
    ``tanh`` into [-1, 1]. Height and width are preserved when they are
    multiples of 4 (two stride-2 convolutions followed by two stride-2
    transposed convolutions).

    Parameters:
    - ngf: number of feature maps produced by the first convolution
    - n_residual_blocks: number of ResidualBlock modules in the bottleneck
    """

    def __init__(self, ngf, n_residual_blocks=9):
        super().__init__()

        # Stem: reflection-padded 7x7 convolution, 3 -> ngf channels
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, 7),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
        ]

        # Encoder: two stride-2 convolutions, doubling the channels each time
        channels = ngf
        for _ in range(2):
            layers.extend([
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                nn.BatchNorm2d(channels * 2),
                nn.ReLU(),
            ])
            channels = channels * 2

        # Bottleneck: residual blocks at the widest channel count
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Decoder: two stride-2 transposed convolutions, halving the channels
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(channels // 2),
                nn.ReLU(),
            ])
            channels = channels // 2

        # Head: reflection-padded 7x7 convolution back to 3 channels + tanh
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, 7),
            nn.Tanh(),
        ])

        # The position of each layer in this Sequential defines its state-dict key
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack."""
        return self.model(x)

def apply_pretrained_model(image_dir, output_dir):
    """Make sure ``output_dir`` exists, creating missing parent directories.

    ``image_dir`` is accepted but not used by the current body.
    """
    os.makedirs(name=output_dir, exist_ok=True)

# Build the generator used for age progression and restore its pretrained
# weights, mapping every tensor of the checkpoint onto the CPU
model = Generator(n_residual_blocks=9, ngf=32)
ckpt = torch.load(
    '/Users/sadhanaanandan/Fast-AgingGAN/pretrained_model/state_dict.pth',
    map_location='cpu',
)
model.load_state_dict(ckpt)
# Switch to evaluation mode before running inference
model.eval()


# Preprocessing applied to a PIL image before it is fed to the model:
# resize to 512x512, convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5
trans = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
    ]
)

def save_original_image(name, idx, original_image_bgr, person_folder, target=None):
    """
    Save an original image into the person's folder and print progress
    information for debugging.

    Parameters:
    - name: The name of the person (used in the file name)
    - idx: Zero-based index of the image; the file name uses idx + 1
    - original_image_bgr: Image array in BGR channel order, written with cv2.imwrite
    - person_folder: Folder to save the image in
    - target: Optional target index, used only in the progress message. When
      omitted, a module-level ``target`` variable is used if one exists (the
      function previously relied on that global implicitly and raised
      NameError without it); otherwise "unknown" is printed.
    """
    if target is None:
        # Backward-compatible fallback for callers that do not pass the target
        target = globals().get("target", "unknown")

    # Print information for debugging
    print(f"Processing image {idx + 1} for {name}, target: {target}")

    # Save the original image
    original_image_path = os.path.join(person_folder, f'{name}_original_{idx + 1}.jpg')
    # cv2.imwrite reports failure through its boolean return value
    if cv2.imwrite(original_image_path, original_image_bgr):
        print(f"Original image saved: {original_image_path}")
    else:
        print(f"Failed to save original image: {original_image_path}")
    
# Example usage:
# Assuming you have the necessary variables (name, idx, original_image_bgr, person_folder)
# save_original_image(name, idx, original_image_bgr, person_folder)

# dlib frontal face detector, created once and reused by get_face_bbox
face_detector = dlib.get_frontal_face_detector()

def get_face_bbox(image_bgr):
    """Return the bounding box of the first face detected in a BGR image.

    The image is converted to grayscale and passed to the module-level
    ``face_detector``.

    Returns:
    - (left, top, right, bottom) of the first detection, or None when the
      detector finds no face.
    """
    grayscale = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    detections = face_detector(grayscale)

    # Nothing detected: signal it with None
    if len(detections) == 0:
        return None

    first = detections[0]
    return (first.left(), first.top(), first.right(), first.bottom())



def save_aged_image(name, idx, raw_image_bgr, model, trans, person_folder):
    """
    Run the age-progression model on a raw image, crop the result to the
    detected face, save it and show it next to the original.

    Parameters:
    - name: The name of the person (used in the file name)
    - idx: Zero-based index of the image; the file name uses idx + 1
    - raw_image_bgr: Input image array in BGR channel order
    - model: Callable model taking the transformed batch; assumed to return a
      (1, 3, H, W) tensor with values in [-1, 1], like the Generator in this
      file -- TODO confirm for other models
    - trans: Transform applied to the PIL image to produce the model input
    - person_folder: Folder to save the aged image in

    Nothing is saved when no face is detected in the aged image or when the
    detected box does not overlap the image.
    """
    # Convert BGR image to RGB
    raw_image_rgb = cv2.cvtColor(raw_image_bgr, cv2.COLOR_BGR2RGB)

    # Build the model input directly from the 8-bit image (no lossy
    # divide-by-255 / multiply-by-255 round trip) and add the batch dimension
    raw_image_pil = Image.fromarray(raw_image_rgb.astype('uint8')).convert('RGB')
    raw_image_tensor = trans(raw_image_pil).unsqueeze(0)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        aged_face = model(raw_image_tensor)

    # Drop the batch dimension and map [-1, 1] to [0, 1]
    aged_face_normalized = (aged_face.squeeze(0).detach().cpu().numpy() + 1.0) / 2.0

    # Denormalize aged image to [0, 255]
    aged_face_save = (aged_face_normalized * 255).clip(0, 255).astype('uint8')

    # The model output is channel-first (C, H, W) whatever the size of the raw
    # input, so it always needs the same transpose to (H, W, C)
    aged_face_save_rgb = np.ascontiguousarray(np.transpose(aged_face_save, (1, 2, 0)))

    # Convert the aged image back to BGR
    aged_face_save_bgr = cv2.cvtColor(aged_face_save_rgb, cv2.COLOR_RGB2BGR)

    # Get the bounding box of the face
    face_bbox = get_face_bbox(aged_face_save_bgr)
    if face_bbox is None:
        print("No face detected in the aged image. Skipping...")
        return

    # Clamp the box to the image: a detector box may extend past the borders,
    # and negative slice indices would silently select the wrong region
    height, width = aged_face_save_bgr.shape[:2]
    left, top, right, bottom = face_bbox
    left, right = max(left, 0), min(right, width)
    top, bottom = max(top, 0), min(bottom, height)
    if right <= left or bottom <= top:
        print("Detected face box lies outside the aged image. Skipping...")
        return

    # Crop the image based on the face bounding box
    cropped_face = aged_face_save_bgr[top:bottom, left:right]

    # Save age-progressed and cropped image in BGR format
    aged_image_path = os.path.join(person_folder, f'{name}_aged_{idx + 1}.jpg')
    # cv2.imwrite reports failure through its boolean return value
    if cv2.imwrite(aged_image_path, cropped_face):
        print(f"Aged and cropped image saved: {aged_image_path}")
    else:
        print(f"Failed to save aged image: {aged_image_path}")

    # Visualize both original and aged images side by side
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(raw_image_rgb)
    plt.title('Original Image')

    plt.subplot(1, 2, 2)
    plt.imshow(cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB))
    plt.title('Aged and Cropped Image')

    plt.show()



# Function to match unique names with names from raw images
def match_names(unique_names, raw_names):
    """Return the names in ``unique_names`` that have a matching raw entry.

    A name matches when it appears in ``raw_names`` either verbatim or with
    its spaces replaced by underscores (the convention this file uses when it
    looks up a person's raw image directory).

    Parameters:
    - unique_names: iterable of person names
    - raw_names: iterable of raw image entry names

    Returns:
    - list of matching names in their original form, without duplicates, in
      the order they first appear in ``unique_names``
    """
    raw_lookup = set(raw_names)
    matched = []
    seen = set()
    for name in unique_names:
        if name in seen:
            continue
        if name in raw_lookup or name.replace(" ", "_") in raw_lookup:
            seen.add(name)
            matched.append(name)
    return matched

# Get the matched names
matched_names = match_names(unique_names, raw_names)

# For every person: export the first dataset images as "original" files and
# an aged, face-cropped version of the stored raw image next to them
for name in unique_names:

    # Folder for this person (using name as subfolder name), created on demand
    person_folder = os.path.join(augmented_images_folder, name)
    os.makedirs(person_folder, exist_ok=True)

    # Retrieve data for the current person
    cursor.execute('SELECT target, name, image FROM faces WHERE name = ?', (name,))
    person_images = cursor.fetchall()

    # Determine the number of images to process for this person (at most 2)
    max_images_to_process = min(2, len(person_images))

    # Process each selected image for the person
    for idx, row in enumerate(person_images[:max_images_to_process]):
        # `target` stays a module-level name on purpose: save_original_image
        # may read it for its progress message
        target, _, original_image_bytes = row

        # Convert the image bytes back to a NumPy array.
        # NOTE: pickle.loads is acceptable only because these blobs come from
        # the database written earlier in this file; never use it on blobs
        # from an untrusted source.
        original_image = pickle.loads(original_image_bytes)

        # Convert RGB to BGR for saving with cv2.imwrite
        # (assumes the dataset pixels are floats in [0, 1] -- TODO confirm for
        # the installed scikit-learn version)
        original_image_bgr = cv2.cvtColor((original_image * 255).astype('uint8'), cv2.COLOR_RGB2BGR)

        save_original_image(name, idx, original_image_bgr, person_folder)

        # Retrieve the corresponding raw image data for the current person
        raw_image_row = cursor.execute('SELECT raw_image FROM faces WHERE target = ? AND name = ?', (target, name)).fetchone()

        if not raw_image_row or raw_image_row[0] is None:
            # Handle the case where no raw image is found for the current person and image
            print(f"No raw image found for {name} - target: {target}")
            continue

        # Convert the image bytes back to a NumPy array
        raw_image = pickle.loads(raw_image_row[0])
        if raw_image is None:
            # A failed cv2.imread is stored as a pickled None
            print(f"No raw image found for {name} - target: {target}")
            continue

        # The raw image was stored exactly as returned by cv2.imread, i.e. it
        # is already an 8-bit BGR array. Multiplying it by 255 would wrap
        # around in uint8 and an RGB->BGR conversion would swap the channels
        # of an image that is already BGR, so it is passed on unchanged.
        raw_image_bgr = raw_image

        save_aged_image(name, idx, raw_image_bgr, model, trans, person_folder)

# Close the connection
conn.close()