In this tutorial we will learn to:
- Instantiate a DeepPrintExtractor
- Train a DeepPrintExtractor
- Extract DeepPrint features from fingerprint images
- Evaluate the performance of the extracted fixed-length representations

## Instantiate a DeepPrintExtractor

This package implements a number of variants of the DeepPrint architecture. The wrapper class for all these variants is called `DeepPrintExtractor`.
It has a `fit` method to train (and save) the model as well as an `extract` method to extract the DeepPrint features for fingerprint images. 

You can also try to implement your own models, but currently this is not directly supported by the package.

In [None]:
from flx.data.dataset import IdentifierSet, Identifier
from flx.extractor.fixed_length_extractor import get_DeepPrint_Tex, DeepPrintExtractor

# We will use the example dataset with 10 subjects and 10 impressions per subject
training_ids: IdentifierSet = IdentifierSet([Identifier(i, j) for i in range(10) for j in range(10)])

# We choose a dimension of 128 for the fixed-length representation
extractor: DeepPrintExtractor = get_DeepPrint_Tex(num_training_subjects=training_ids.num_subjects, num_texture_dims=128)

## Training the model

Instantiating the model was easy. To train it, first we will load the training data (see the [data tutorial](./dataset_tutorial.ipynb) for how to implement your own dataset).

Besides the fingerprint images, we also need a mapping from subjects to integer labels (for PyTorch). Some variants also need minutiae data. To see how a more complex dataset can be loaded, have a look at `flx/setup/datasets.py`.
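
The subject-to-label mapping mentioned above can be illustrated with a small plain-Python sketch. The helper `build_label_index` is hypothetical and is not the package's `LabelIndex` class; it only shows the idea of assigning each distinct subject a label in the range 0 to num_subjects-1, even when the subject ids themselves are not contiguous.

```python
# Hypothetical stand-in for illustration; this is NOT the flx LabelIndex class.
def build_label_index(subject_ids):
    """Map each distinct subject id to an integer label in [0, num_subjects - 1]."""
    return {subject: label for label, subject in enumerate(sorted(set(subject_ids)))}


# Subject ids need not be contiguous or start at zero.
subjects = [1003, 17, 17, 250, 1003]
label_index = build_label_index(subjects)
labels = [label_index[s] for s in subjects]

print(label_index)  # {17: 0, 250: 1, 1003: 2}
print(labels)       # [2, 0, 0, 1, 2]
```

In the example dataset the subjects are already numbered 0 to 9, so such a mapping would be the identity there.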

Finally, we call the `fit` method, which trains the model and saves it to the specified path.

There is also the option to add a validation set, which will be used to evaluate the embeddings during training. This is useful to monitor the training progress and to avoid overfitting.
In this example we will not use a validation set for simplicity.

In [None]:
import os

import torch 

from flx.data.dataset import Dataset
from flx.data.image_loader import SFingeLoader
from flx.data.label_index import LabelIndex
from flx.data.transformed_image_loader import TransformedImageLoader
from flx.image_processing.binarization import LazilyAllocatedBinarizer
from flx.data.image_helpers import pad_and_resize_to_deepprint_input_size

# NOTE: If this does not work, enter the absolute path to the notebooks/example-dataset directory here! 
example_dataset_path = os.path.abspath("example-dataset")
outdir = os.path.join(os.path.dirname(example_dataset_path), "output")

# We will use the SFingeLoader to load the images from the dataset
image_loader = TransformedImageLoader(
    images=SFingeLoader(example_dataset_path),
    poses=None,
    transforms=[
        LazilyAllocatedBinarizer(5.0),
        pad_and_resize_to_deepprint_input_size,
    ],
)

image_dataset = Dataset(image_loader, training_ids)

# For PyTorch, we need to map the subjects to integer labels in the range 0 ... num_subjects-1
label_dataset = Dataset(LabelIndex(training_ids), training_ids)

model_outdir = os.path.join(outdir, "training")
extractor.fit(
    fingerprints=image_dataset,
    minutia_maps=None,
    labels=label_dataset,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=20,
    out_dir=model_outdir
)

## Embedding extraction

After training the model, we can extract the DeepPrint features for the fingerprint images. This is done by calling the `extract` method of the `DeepPrintExtractor` class.

In [None]:
# If you just trained the model, it should already be loaded.
# Otherwise, load the best model with the following line:
# extractor.load_best_model(model_outdir)

# The second value is for the minutiae branch, which we do not have in this example
texture_embeddings, _ = extractor.extract(image_dataset)

## Benchmarking

To evaluate the embeddings, we want to run a benchmark on them. For this, we must first specify the type of benchmark, and which comparisons should be run.

In [None]:
from flx.scripts.generate_benchmarks import create_verification_benchmark

benchmark = create_verification_benchmark(subjects=list(range(10)), impressions_per_subject=list(range(10)))
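
To get a feeling for what "which comparisons should be run" means, the following sketch enumerates all pairwise comparisons for 10 subjects with 10 impressions each and splits them into genuine (same subject) and impostor (different subjects) comparisons. This is only an illustration in plain Python with tuples as identifiers; it is an assumption that the benchmark uses all pairs, and `create_verification_benchmark` may select comparisons differently.

```python
from itertools import combinations

# Each sample is identified by a (subject, impression) tuple.
samples = [(subject, impression) for subject in range(10) for impression in range(10)]

genuine = []   # both samples come from the same subject
impostor = []  # the samples come from different subjects
for a, b in combinations(samples, 2):
    if a[0] == b[0]:
        genuine.append((a, b))
    else:
        impostor.append((a, b))

print(len(genuine), len(impostor))  # 450 4500
```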

Now we can run the benchmark. To do this, we must first specify the matcher (in our case, the cosine similarity of the embeddings).

In [None]:
from flx.benchmarks.matchers import CosineSimilarityMatcher

matcher = CosineSimilarityMatcher(texture_embeddings)

results = benchmark.run(matcher)
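
For reference, the cosine similarity of two embeddings is their dot product divided by the product of their norms. The sketch below computes it in plain Python for small example vectors; the function `cosine_similarity` is written here for illustration and is not the implementation used by `CosineSimilarityMatcher`.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equally long, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


print(cosine_similarity([1.0, 0.0, 0.0], [2.0, 0.0, 0.0]))  # 1.0 (same direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))            # 0.0 (orthogonal)
```

The score depends only on the direction of the embeddings and not on their length, so scaling an embedding does not change the comparison result.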

To visualize the results, we can plot a DET curve.

In [None]:
from flx.visualization.plot_DET_curve import plot_verification_results

figure_path = os.path.join(outdir, "DET_curve.png")

# Lists are used to allow for multiple models to be plotted in the same figure
plot_verification_results(figure_path, results=[results], model_labels=["DeepPrint_Tex"], plot_title="example-dataset - verification")
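
A DET curve plots the false non-match rate (FNMR) against the false match rate (FMR) as the decision threshold varies. The sketch below shows how one point of such a curve can be computed from similarity scores. The scores are made up for illustration, and `error_rates` is a hypothetical helper, not a function of this package.

```python
def error_rates(genuine_scores, impostor_scores, threshold):
    """Return (FMR, FNMR) when scores >= threshold are accepted as matches."""
    false_matches = sum(score >= threshold for score in impostor_scores)
    false_non_matches = sum(score < threshold for score in genuine_scores)
    return false_matches / len(impostor_scores), false_non_matches / len(genuine_scores)


# Made-up similarity scores, for illustration only.
genuine_scores = [0.9, 0.8, 0.7, 0.4]
impostor_scores = [0.6, 0.3, 0.2, 0.1]

for threshold in (0.25, 0.5, 0.75):
    fmr, fnmr = error_rates(genuine_scores, impostor_scores, threshold)
    print(f"threshold={threshold:.2f} FMR={fmr:.2f} FNMR={fnmr:.2f}")
# threshold=0.25 FMR=0.50 FNMR=0.00
# threshold=0.50 FMR=0.25 FNMR=0.25
# threshold=0.75 FMR=0.00 FNMR=0.50
```

Raising the threshold lowers the FMR at the cost of a higher FNMR; the DET curve shows this trade-off over all thresholds.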