In [14]:
from pathlib import Path

dataset_path = Path("./lfw")

matchpairs_train_csv = dataset_path / "matchpairsDevTrain.csv"
assert matchpairs_train_csv.exists(), f"Expected {matchpairs_train_csv} to exist."


mismatches_train_csv = dataset_path / "mismatchpairsDevTrain.csv"
assert mismatches_train_csv.exists(), f"Expected {mismatches_train_csv} to exist."

matchpairs_test_csv = dataset_path / "matchpairsDevTest.csv"
assert matchpairs_test_csv.exists(), f"Expected {matchpairs_test_csv} to exist."

mismatches_test_csv = dataset_path / "mismatchpairsDevTest.csv"
assert mismatches_test_csv.exists(), f"Expected {mismatches_test_csv} to exist."

In [18]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt
import os
from sklearn.metrics import accuracy_score, classification_report
import random

In [20]:
# Configuration
IMG_SIZE = (224, 224)
RESIZED_IMG_SIZE = (224, 224)
BATCH_SIZE = 32


def get_image_path(person_name, image_num):
    """Get the path to an image given person name and image number"""
    img_name = f"{person_name}_{image_num:04d}.jpg"
    return (
        dataset_path / "lfw-deepfunneled" / "lfw-deepfunneled" / person_name / img_name
    )


def load_and_preprocess_image(image_path):
    """Load and preprocess an image"""
    try:
        # Load image
        img = load_img(image_path, target_size=IMG_SIZE)
        img_array = img_to_array(img)
        # Normalize to [0,1]
        img_array = img_array / 255.0
        # Resize to RESIZED_IMG_SIZE
        img_array = tf.image.resize(img_array, RESIZED_IMG_SIZE)
        return img_array
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        # Return a black image if loading fails
        return np.zeros((*IMG_SIZE, 3))


def load_pairs_data(csv_file, is_match=True):
    """Load pairs data from CSV file"""
    df = pd.read_csv(csv_file)
    pairs = []
    labels = []

    for _, row in df.iterrows():
        if is_match:
            # Match pairs: same person, different images
            name = row["name"]
            img1_path = get_image_path(name, row["imagenum1"])
            img2_path = get_image_path(name, row["imagenum2"])
            label = 1  # Same person
        else:
            # Mismatch pairs: different persons
            name1 = row.iloc[0]  # First name column
            name2 = row.iloc[2]  # Second name column
            img1_path = get_image_path(name1, row.iloc[1])  # First image num
            img2_path = get_image_path(name2, row.iloc[3])  # Second image num
            label = 0  # Different persons

        pairs.append((img1_path, img2_path))
        labels.append(label)

    return pairs, labels

In [21]:
# Load training data
print("Loading training data...")
match_pairs_train, match_labels_train = load_pairs_data(
    matchpairs_train_csv, is_match=True
)
mismatch_pairs_train, mismatch_labels_train = load_pairs_data(
    mismatches_train_csv, is_match=False
)

# Combine training data
train_pairs = match_pairs_train + mismatch_pairs_train
train_labels = match_labels_train + mismatch_labels_train

print(f"Training data: {len(train_pairs)} pairs")
print(f"Positive pairs (same person): {sum(train_labels)}")
print(f"Negative pairs (different person): {len(train_labels) - sum(train_labels)}")

# Load test data
print("\nLoading test data...")
match_pairs_test, match_labels_test = load_pairs_data(
    matchpairs_test_csv, is_match=True
)
mismatch_pairs_test, mismatch_labels_test = load_pairs_data(
    mismatches_test_csv, is_match=False
)

# Combine test data
test_pairs = match_pairs_test + mismatch_pairs_test
test_labels = match_labels_test + mismatch_labels_test

print(f"Test data: {len(test_pairs)} pairs")
print(f"Positive pairs (same person): {sum(test_labels)}")
print(f"Negative pairs (different person): {len(test_labels) - sum(test_labels)}")

Loading training data...
Training data: 2200 pairs
Positive pairs (same person): 1100
Negative pairs (different person): 1100

Loading test data...
Test data: 1000 pairs
Positive pairs (same person): 500
Negative pairs (different person): 500


In [29]:
import random
import tensorflow as tf


def create_tf_dataset(
    pairs,
    labels,
    batch_size=32,
    shuffle=True,
    buffer_size=None,  # if None, will default to len(pairs) when shuffling
):
    """
    Create a tf.data.Dataset of ((img1, img2), label) where:
      - img1/img2: float32 tensors of shape (H, W, 3)
      - label: float32 tensor of shape (1,)
    """

    def generator():
        indices = list(range(len(pairs)))
        if shuffle:
            random.shuffle(indices)

        for i in indices:
            img1_path, img2_path = pairs[i]
            y = labels[i]

            # Must return float32 of shape (H, W, 3)
            img1 = load_and_preprocess_image(
                img1_path
            )  # ensure it outputs (H,W,3) float32
            img2 = load_and_preprocess_image(img2_path)

            # Ensure label is float32 and shape (1,)
            y = tf.cast(y, tf.float32)
            y = tf.reshape(y, (1,))  # or keep scalar if your loss expects that

            yield (img1, img2), y

    output_signature = (
        (
            tf.TensorSpec(shape=(*RESIZED_IMG_SIZE, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(*RESIZED_IMG_SIZE, 3), dtype=tf.float32),
        ),
        tf.TensorSpec(shape=(1,), dtype=tf.float32),
    )

    dataset = tf.data.Dataset.from_generator(
        generator, output_signature=output_signature
    )

    if shuffle:
        if buffer_size is None:
            buffer_size = len(pairs)
        dataset = dataset.shuffle(
            buffer_size=buffer_size, reshuffle_each_iteration=True
        )

    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


# Example usage:
# If your model resizes 256->96 internally, set img_size=(256,256)
train_dataset = create_tf_dataset(
    train_pairs, train_labels, batch_size=BATCH_SIZE, shuffle=True
)
test_dataset = create_tf_dataset(
    test_pairs, test_labels, batch_size=BATCH_SIZE, shuffle=False
)

In [6]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Dense,
    Lambda,
    BatchNormalization,
    Dropout,
    GlobalAveragePooling2D,
)
import tensorflow.keras.backend as K


# Recreate the exact model architecture from your training code
def cosine_similarity(vectors):
    x, y = vectors
    x_norm = K.l2_normalize(x, axis=1)
    y_norm = K.l2_normalize(y, axis=1)
    similarity = K.sum(x_norm * y_norm, axis=1, keepdims=True)
    return 1 - similarity


def create_pretrained_base_network(
    input_shape, backbone="resnet50", trainable_layers=10
):
    input_layer = Input(shape=input_shape)
    base_model = ResNet50(
        weights="imagenet", include_top=False, input_tensor=input_layer
    )

    for layer in base_model.layers[:-trainable_layers]:
        layer.trainable = False
    for layer in base_model.layers[-trainable_layers:]:
        layer.trainable = True

    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Dense(512, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    x = Dense(128, activation="relu", name="embedding")(x)

    return Model(input_layer, x, name=f"{backbone}_base_network")


def create_cosine_siamese_network(input_shape, backbone="resnet50"):
    base_network = create_pretrained_base_network(input_shape, backbone)

    input_a = Input(shape=input_shape, name="left_input")
    input_b = Input(shape=input_shape, name="right_input")

    processed_a = base_network(input_a)
    processed_b = base_network(input_b)

    distance = Lambda(
        cosine_similarity,
        output_shape=lambda shapes: (shapes[0][0], 1),
        name="cosine_distance",
    )([processed_a, processed_b])

    return Model([input_a, input_b], distance, name="cosine_siamese_network")


# Recreate the model
input_shape = (224, 224, 3)  # Adjust based on your training
cosine_siamese_model = create_cosine_siamese_network(input_shape, backbone="resnet50")

# Try to load weights
try:
    cosine_siamese_model.load_weights("models/best_siamese_model.keras")
    print("Weights loaded successfully!")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("You may need to retrain the model or check for a weights-only file (.h5)")

Weights loaded successfully!


In [7]:
# Extract the pre-trained encoder as teacher
def extract_teacher_encoder(siamese_model):
    """Extract the encoder part from trained Siamese model"""
    # The base network is the encoder we want to distill
    for layer in siamese_model.layers:
        if "base_network" in layer.name:
            teacher_encoder = layer
            break
    else:
        # If not found, extract from the model structure
        teacher_encoder = siamese_model.get_layer("resnet50_base_network")

    print(f"Teacher encoder: {teacher_encoder.name}")
    print(f"Teacher parameters: {teacher_encoder.count_params():,}")
    return teacher_encoder


# Extract teacher from your trained model
teacher_encoder = extract_teacher_encoder(cosine_siamese_model)

Teacher encoder: resnet50_base_network
Teacher parameters: 24,712,704


In [25]:
teacher_encoder.trainable = False
def l2n(x):
    return tf.math.l2_normalize(x, axis=-1)


teacher = tf.keras.Sequential([teacher_encoder, tf.keras.layers.Lambda(l2n)], name="teacher_norm")

embedding_dim = teacher.output_shape[-1]
input_shape = teacher.input_shape[1:]

In [26]:
def build_student(input_shape, embedding_dim):
    inp = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Rescaling(1.0 / 255)(inp)
    # Example lightweight CNN; swap for MobileNetV2/TinyViT/etc. if you like
    for f in [32, 64, 128]:
        x = tf.keras.layers.SeparableConv2D(f, 3, padding="same", activation="relu")(x)
        x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(embedding_dim)(x)
    out = tf.keras.layers.Lambda(lambda t: tf.math.l2_normalize(t, axis=-1))(x)
    return tf.keras.Model(inp, out, name="student_encoder")


student = build_student(input_shape, embedding_dim)

In [27]:
class EmbeddingDistiller(tf.keras.Model):
    def __init__(
        self, student, teacher, alpha=1.0, beta=0.0, margin=1.0, use_cosine=False
    ):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.alpha = alpha  # weight for distillation loss
        self.beta = beta  # weight for task loss (if pairs provided)
        self.margin = margin
        self.use_cosine = use_cosine
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.distill_tracker = tf.keras.metrics.Mean(name="distill_loss")
        self.task_tracker = tf.keras.metrics.Mean(name="task_loss")

    @property
    def metrics(self):
        return [self.loss_tracker, self.distill_tracker, self.task_tracker]

    def contrastive_loss(self, y, z1, z2):
        # y: 1 = positive, 0 = negative; embeddings already L2-normalized
        d = tf.norm(z1 - z2, axis=-1)
        pos = y * tf.square(d)
        neg = (1.0 - y) * tf.square(tf.nn.relu(self.margin - d))
        return tf.reduce_mean(pos + neg)

    def distill_loss(self, z_s, z_t):
        if self.use_cosine:
            # CosineSimilarity returns negative cosine; negate to make it a loss we minimize
            cos = tf.keras.losses.CosineSimilarity(axis=-1)(z_t, z_s)
            return tf.reduce_mean(-cos)
        else:
            # MSE between normalized embeddings
            return tf.reduce_mean(tf.reduce_sum(tf.square(z_s - z_t), axis=-1))

    def train_step(self, data):
        # Accept (x) or ((x1,x2), y)
        supervised_pairs = False
        if isinstance(data, tuple):
            x, y = data
            if isinstance(x, (tuple, list)):
                supervised_pairs = True
                x1, x2 = x
            else:
                # Labeled singles (rare here) – we’ll ignore labels and just distill
                x1 = x
        else:
            x1 = data

        with tf.GradientTape() as tape:
            if supervised_pairs:
                z_s1 = self.student(x1, training=True)
                z_s2 = self.student(x2, training=True)
                z_t1 = tf.stop_gradient(self.teacher(x1, training=False))
                z_t2 = tf.stop_gradient(self.teacher(x2, training=False))

                distill = 0.5 * (
                    self.distill_loss(z_s1, z_t1) + self.distill_loss(z_s2, z_t2)
                )
                task = self.contrastive_loss(tf.cast(y, z_s1.dtype), z_s1, z_s2)
                loss = self.alpha * distill + self.beta * task
            else:
                z_s = self.student(x1, training=True)
                z_t = tf.stop_gradient(self.teacher(x1, training=False))
                distill = self.distill_loss(z_s, z_t)
                task = 0.0
                loss = self.alpha * distill

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))

        self.loss_tracker.update_state(loss)
        self.distill_tracker.update_state(distill)
        self.task_tracker.update_state(task)
        return {
            "loss": self.loss_tracker.result(),
            "distill_loss": self.distill_tracker.result(),
            "task_loss": self.task_tracker.result(),
        }

    def test_step(self, data):
        # Mirror train_step for validation metrics
        supervised_pairs = False
        if isinstance(data, tuple):
            x, y = data
            if isinstance(x, (tuple, list)):
                supervised_pairs = True
                x1, x2 = x
            else:
                x1 = x
        else:
            x1 = data

        if supervised_pairs:
            z_s1 = self.student(x1, training=False)
            z_s2 = self.student(x2, training=False)
            z_t1 = self.teacher(x1, training=False)
            z_t2 = self.teacher(x2, training=False)
            distill = 0.5 * (
                self.distill_loss(z_s1, z_t1) + self.distill_loss(z_s2, z_t2)
            )
            task = self.contrastive_loss(tf.cast(y, z_s1.dtype), z_s1, z_s2)
            loss = self.alpha * distill + self.beta * task
        else:
            z_s = self.student(x1, training=False)
            z_t = self.teacher(x1, training=False)
            distill = self.distill_loss(z_s, z_t)
            task = 0.0
            loss = self.alpha * distill

        return {"loss": loss, "distill_loss": distill, "task_loss": task}

In [30]:
distiller = EmbeddingDistiller(
    student, teacher, alpha=1.0, beta=0.5, margin=1.0, use_cosine=False
)
distiller.compile(optimizer=tf.keras.optimizers.Adam(3e-4))

# `train_pairs` yields ((x1, x2), y)  with y in {0,1}
# `val_pairs`   yields ((x1, x2), y)
history = distiller.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=50,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            patience=5, restore_best_weights=True, monitor="val_loss"
        )
    ],
)

Epoch 1/50


2025-09-08 17:48:55.440062: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Error loading image lfw/lfw-deepfunneled/lfw-deepfunneled/Harriet_Lessy/Harriet_Lessy_0001.jpg: Could not import PIL.Image. The use of `load_img` requires PIL.
Error loading image lfw/lfw-deepfunneled/lfw-deepfunneled/Petria_Thomas/Petria_Thomas_0003.jpg: Could not import PIL.Image. The use of `load_img` requires PIL.
Error loading image lfw/lfw-deepfunneled/lfw-deepfunneled/Adolfo_Aguilar_Zinser/Adolfo_Aguilar_Zinser_0003.jpg: Could not import PIL.Image. The use of `load_img` requires PIL.
Error loading image lfw/lfw-deepfunneled/lfw-deepfunneled/Soon_Yi/Soon_Yi_0001.jpg: Could not import PIL.Image. The use of `load_img` requires PIL.
Error loading image lfw/lfw-deepfunneled/lfw-deepfunneled/Valentino_Rossi/Valentino_Rossi_0004.jpg: Could not import PIL.Image. The use of `load_img` requires PIL.
Error loading image lfw/lfw-deepfunneled/lfw-deepfunneled/Valentino_Rossi/Valentino_Rossi_0005.jpg: Could not import PIL.Image. The use of `load_img` requires PIL.
Error loading image lfw/lfw-

KeyboardInterrupt: 