In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import argparse
from tqdm import tqdm

In [3]:
class FingerprintDataset(Dataset):
    def __init__(self, data_dir, input_size=(105, 105)):
        self.data_dir = data_dir
        self.image_paths = []
        self.input_size = input_size

        # Collect all image paths
        for class_dir in os.listdir(data_dir):
            class_path = os.path.join(data_dir, class_dir)
            if os.path.isdir(class_path):
                for image_name in os.listdir(class_path):
                    self.image_paths.append(os.path.join(class_path, image_name))

        # Define two different transform pipelines
        self.transform1 = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(self.input_size),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        self.transform2 = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(self.input_size),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img1_path = self.image_paths[idx]
        img1 = Image.open(img1_path).convert('RGB')

        is_positive = random.choice([True, False])

        if is_positive:
            img2_path = img1_path
            img2 = Image.open(img2_path).convert('RGB')

            # Apply different augmentations for each image in the same pair
            img1 = self.transform1(img1)
            img2 = self.transform2(img2)
        else:
            img2_path = random.choice(self.image_paths)
            while img2_path == img1_path:
                img2_path = random.choice(self.image_paths)
            img2 = Image.open(img2_path).convert('RGB')

            img1 = self.transform1(img1)
            img2 = self.transform2(img2)

        return img1, img2, int(is_positive)

In [4]:
class SiameseNetwork(nn.Module):
    def __init__(self, embedding_dim=256):
        super(SiameseNetwork, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7),  # Output: [64, 99, 99]
            nn.ReLU(),
            nn.MaxPool2d(2),                 # -> [64, 49, 49]

            nn.Conv2d(64, 128, kernel_size=5),  # -> [128, 45, 45]
            nn.ReLU(),
            nn.MaxPool2d(2),                   # -> [128, 22, 22]

            nn.Conv2d(128, 128, kernel_size=3),  # -> [128, 20, 20]
            nn.ReLU(),
            nn.MaxPool2d(2)                     # -> [128, 10, 10]
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 10 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_dim)  # Final output is the embedding
        )

    def forward_one(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, input1, input2):
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

In [5]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=50.0, reduction='mean'):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, output1, output2, label):
        # Ensure label is float for proper math
        label = label.float()

        euclidean_distance = F.pairwise_distance(output1, output2)
        positive_loss = label * torch.pow(euclidean_distance, 2)
        negative_loss = (1 - label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss = 0.5 * (positive_loss + negative_loss)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss  # no reduction


In [14]:
def train(input_directory, model_file, embedding_size, num_epochs, batch_size, learning_rate):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    train_dataset = FingerprintDataset(input_directory)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = SiameseNetwork(embedding_dim=embedding_size).to(device)
    loss_function = ContrastiveLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    min_loss = float('inf')  # Set to infinity initially

# Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch_idx, (img1, img2, label) in enumerate(train_loader, start=1):  # Use enumerate to get batch index
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            optimizer.zero_grad()
            output1, output2 = model(img1, img2)
            loss = loss_function(output1, output2, label)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            print(f"Batch [{batch_idx}/{len(train_loader)}] Loss: {loss.item()}")  # Include batch number

        epoch_loss = total_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')

        # Save the model checkpoint for the current epoch
        # torch.save(model.state_dict(), f'siamese_model_epoch_{epoch+1}.pth')

        # Check if this is the minimum loss so far
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            torch.save(model.state_dict(), model_file)  # Save the model with minimum loss
            print(f"New minimum loss achieved: {min_loss}. Model saved as 'siamese_model_min_loss.pth'.")

    print("Training finished!")


In [11]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [15]:
input_directory = '/content/drive/MyDrive/Facultate Informatica/Profesor/2024-2025/Sisteme expert si metode biometrice in securitatea informatiei/Curs/ColabMount/DATA/NISTDB4_RAW/train_set'
model_file = '/content/drive/MyDrive/Facultate Informatica/Profesor/2024-2025/Sisteme expert si metode biometrice in securitatea informatiei/Curs/ColabMount/Models/model_fingerprints.pth'
embedding_size = 128
epochs = 50
batch_size = 16
learning_rate = 0.0001

In [16]:
train(input_directory, model_file, embedding_size, epochs, batch_size, learning_rate)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Batch [1/105] Loss: 44.66823959350586
Batch [2/105] Loss: 14.511571884155273
Batch [3/105] Loss: 51.18972396850586
Batch [4/105] Loss: 31.339542388916016
Batch [5/105] Loss: 33.11433792114258
Batch [6/105] Loss: 29.36878204345703
Batch [7/105] Loss: 17.933698654174805
Batch [8/105] Loss: 55.77229309082031
Batch [9/105] Loss: 22.054370880126953
Batch [10/105] Loss: 70.9734878540039
Batch [11/105] Loss: 100.42759704589844
Batch [12/105] Loss: 27.30337142944336
Batch [13/105] Loss: 36.618385314941406
Batch [14/105] Loss: 29.474803924560547
Batch [15/105] Loss: 51.8868408203125
Batch [16/105] Loss: 33.62506866455078
Batch [17/105] Loss: 36.52508544921875
Batch [18/105] Loss: 75.40165710449219
Batch [19/105] Loss: 28.508760452270508
Batch [20/105] Loss: 31.573711395263672
Batch [21/105] Loss: 30.55251693725586
Batch [22/105] Loss: 12.240350723266602
Batch [23/105] Loss: 21.898818969726562
Batch [24/105] Loss: 41.61586761474609