In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from tqdm import tqdm  # Import tqdm for progress tracking
from concurrent.futures import ThreadPoolExecutor
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load all files in the dataset directory
dataset_dir = 'dataset/'

In [2]:

# Define the target shape for images
TARGET_SHAPE = (64, 64, 3)

def resize_image(image):
    return resize(image, TARGET_SHAPE, preserve_range=True, anti_aliasing=True)

def load_npz_file(file_path):
    data = np.load(file_path, mmap_mode='r')
    color_images = data['colorImages'] / 255.0  # Normalize images
    # Resize images one at a time to minimize memory usage
    color_images_resized = np.array([resize_image(img) for img in color_images.transpose(3, 0, 1, 2)])
    
    bounding_boxes = data['boundingBox']
    landmarks_2d = data['landmarks2D']
    landmarks_3d = data['landmarks3D']
    
    return color_images_resized, bounding_boxes, landmarks_2d, landmarks_3d


In [3]:
import random

sampling_fraction = 0.1  # 10% sampling for faster training
file_list = [os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir) if f.endswith('.npz')]
sampled_file_list = random.sample(file_list, int(len(file_list) * sampling_fraction))


# Data augmentation
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Data generator with very small batch size
def data_generator(file_list, batch_size=2):
    while True:
        for i in range(0, len(file_list), batch_size):
            batch_files = file_list[i:i + batch_size]
            images = []
            
            with ThreadPoolExecutor(max_workers=1) as executor:  # Single-threaded for stability
                results = list(tqdm(executor.map(load_npz_file, batch_files), total=len(batch_files), desc="Loading batch"))
            
            for result in results:
                color_images, _, _, _ = result
                images.append(color_images)
            
            images = np.concatenate(images, axis=0)  # Free memory immediately
            augmented_images = np.array([datagen.random_transform(img) for img in images])  # Augment
            
            yield augmented_images, augmented_images  # Return augmented images and originals



In [4]:
# Split file list into training and validation sets
train_files, val_files = train_test_split(file_list, test_size=0.2, random_state=42)


In [5]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, MaxPooling2D,Dropout, concatenate


In [6]:
def build_unet_model(input_shape):
    inputs = Input(input_shape)
    
    # Encoder
    conv1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)  # Reduced filters from 32 to 16
    conv1 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D((2, 2))(conv1)
    
    conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(pool1)  # Reduced filters from 64 to 32
    conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D((2, 2))(conv2)
    
    # Bottleneck
    conv3 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool2)  # Reduced filters from 128 to 64
    conv3 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv3)
    
    # Decoder
    up4 = UpSampling2D((2, 2))(conv3)
    concat4 = concatenate([up4, conv2])
    conv4 = Conv2D(32, (3, 3), activation='relu', padding='same')(concat4)
    conv4 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv4)
    
    up5 = UpSampling2D((2, 2))(conv4)
    concat5 = concatenate([up5, conv1])
    conv5 = Conv2D(16, (3, 3), activation='relu', padding='same')(concat5)
    conv5 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv5)
    
    # Output layer
    outputs = Conv2D(3, (1, 1), activation='sigmoid', padding='same')(conv5)
    
    model = Model(inputs, outputs)
    return model

input_shape = TARGET_SHAPE
model = build_unet_model(input_shape)
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])
model.summary()


In [8]:
# Callbacks
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),  # Stop early if no improvement
    tf.keras.callbacks.ModelCheckpoint('models/unet_model_ltwt.keras', save_best_only=True, save_freq='epoch')  # Save only the best model
]

# Train model
train_gen = data_generator(train_files, batch_size=8)
val_gen = data_generator(val_files, batch_size=8)

history = model.fit(
    train_gen,
    epochs=20,  # Reduced epochs for initial training
    steps_per_epoch=len(train_files) // 8,  # Adjust to smaller batch sizes
    validation_data=val_gen,
    validation_steps=len(val_files) // 8,
    callbacks=callbacks
)

Loading batch: 100%|██████████| 8/8 [00:15<00:00,  1.89s/it]
Loading batch: 100%|██████████| 8/8 [00:10<00:00,  1.35s/it]


Epoch 1/20
[1m  1/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:54:50[0m 32s/step - accuracy: 0.1745 - loss: 0.0827

Loading batch:  12%|█▎        | 1/8 [00:03<00:25,  3.63s/it]

[1m  2/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m31:25[0m 9s/step - accuracy: 0.1863 - loss: 0.0809   

Loading batch: 100%|██████████| 8/8 [00:20<00:00,  2.57s/it]
Loading batch:  88%|████████▊ | 7/8 [00:24<00:04,  4.20s/it]

[1m  3/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:31:17[0m 25s/step - accuracy: 0.1995 - loss: 0.0805

Loading batch: 100%|██████████| 8/8 [00:32<00:00,  4.01s/it]
Loading batch:  75%|███████▌  | 6/8 [00:20<00:07,  3.54s/it]

[1m  4/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:42:00[0m 28s/step - accuracy: 0.2085 - loss: 0.0802

Loading batch: 100%|██████████| 8/8 [00:45<00:00,  5.72s/it]
Loading batch:  50%|█████     | 4/8 [00:08<00:06,  1.55s/it]

[1m  5/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:49:56[0m 31s/step - accuracy: 0.2063 - loss: 0.0801

Loading batch: 100%|██████████| 8/8 [00:29<00:00,  3.74s/it]
Loading batch:  88%|████████▊ | 7/8 [00:13<00:01,  1.83s/it]

[1m  6/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:50:09[0m 31s/step - accuracy: 0.2000 - loss: 0.0800

Loading batch: 100%|██████████| 8/8 [00:16<00:00,  2.10s/it]
Loading batch:  12%|█▎        | 1/8 [00:02<00:18,  2.63s/it]

[1m  7/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:36:51[0m 27s/step - accuracy: 0.1945 - loss: 0.0800

Loading batch: 100%|██████████| 8/8 [00:14<00:00,  1.87s/it]
Loading batch:  50%|█████     | 4/8 [00:06<00:05,  1.47s/it]

[1m  8/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:31:53[0m 26s/step - accuracy: 0.1922 - loss: 0.0797

Loading batch: 100%|██████████| 8/8 [00:13<00:00,  1.63s/it]
Loading batch:  12%|█▎        | 1/8 [00:03<00:23,  3.40s/it]

[1m  9/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:24:49[0m 24s/step - accuracy: 0.1907 - loss: 0.0797

Loading batch: 100%|██████████| 8/8 [00:18<00:00,  2.34s/it]
Loading batch: 100%|██████████| 8/8 [00:44<00:00,  5.55s/it]


[1m 10/219[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:39:43[0m 29s/step - accuracy: 0.1883 - loss: 0.0795

Loading batch: 100%|██████████| 8/8 [00:43<00:00,  5.39s/it]


[1m 11/219[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m1:48:01[0m 31s/step - accuracy: 0.1859 - loss: 0.0797

In [None]:
# Load the best model
model.load_weights('models/unet_model_ltwt.keras')

# Evaluate the model
val_gen = data_generator(val_files, batch_size=32)
val_loss, val_accuracy = model.evaluate(val_gen, steps=len(val_files) // 32)
print(f'Validation Loss: {val_loss}')
print(f'Validation Accuracy: {val_accuracy}')


In [None]:
def display_images(original, enhanced, n=5):
    plt.figure(figsize=(20, 10))
    for i in range(n):
        # Original images
        plt.subplot(2, n, i + 1)
        plt.imshow(original[i])
        plt.title('Original')
        plt.axis('off')

        # Enhanced images
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(enhanced[i])
        plt.title('Enhanced')
        plt.axis('off')
    plt.show()

# Predict enhanced images
val_gen = data_generator(val_files, batch_size=32)
original_images, _ = next(val_gen)
enhanced_images = model.predict(original_images)

# Display images
display_images(original_images, enhanced_images)
