# Few-shot classification
This notebook demonstrates how to use the few-shot classification baselines.

There are two baselines:
1. Centroid-based classifier
2. Nearest neighbour classifier

The centroid-based classifier computes one prototype per class by averaging the features of that class's support samples. Each query sample is then assigned to the class of its nearest prototype.
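
As an illustration of the centroid idea, here is a minimal NumPy sketch. It is not the implementation used by `test_fungi`. The function name `centroid_predict`, the toy features, and the choice of cosine similarity are assumptions made for this example.

```python
import numpy as np


def centroid_predict(support_feats, support_labels, query_feats):
    """Assign each query to the class whose prototype is most similar."""
    classes = np.unique(support_labels)
    # One prototype per class: the mean of that class's support features.
    prototypes = np.stack(
        [support_feats[support_labels == c].mean(axis=0) for c in classes]
    )
    # Assumption: cosine similarity, obtained by L2-normalising both sides.
    prototypes = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    queries = query_feats / np.linalg.norm(query_feats, axis=1, keepdims=True)
    sims = queries @ prototypes.T  # shape: (num_queries, num_classes)
    return classes[sims.argmax(axis=1)]


# Toy 2-D features: two support samples for each of two classes.
support_feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
support_labels = np.array([0, 0, 1, 1])
query_feats = np.array([[0.8, 0.2], [0.2, 0.8]])

centroid_preds = centroid_predict(support_feats, support_labels, query_feats)
print(centroid_preds)
```

With these toy values the first query lands closest to the class 0 prototype and the second to the class 1 prototype.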

The nearest neighbour classifier assigns each query sample the label of its nearest neighbour in the support set.
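
The nearest neighbour rule can be sketched in a few lines of NumPy. This is an illustrative example only, not the code behind `test_fungi`. The name `nn_predict`, the toy data, and the use of cosine similarity as the distance measure are assumptions.

```python
import numpy as np


def nn_predict(support_feats, support_labels, query_feats):
    """Give each query the label of its most similar support sample."""
    # Assumption: cosine similarity, obtained by L2-normalising both sides.
    support = support_feats / np.linalg.norm(support_feats, axis=1, keepdims=True)
    queries = query_feats / np.linalg.norm(query_feats, axis=1, keepdims=True)
    sims = queries @ support.T  # shape: (num_queries, num_support)
    return support_labels[sims.argmax(axis=1)]


# Toy 2-D features: one support sample for class 0, two for class 1.
nn_support_feats = np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]])
nn_support_labels = np.array([0, 1, 1])
nn_query_feats = np.array([[0.9, 0.1], [0.5, 0.9]])

nn_preds = nn_predict(nn_support_feats, nn_support_labels, nn_query_feats)
print(nn_preds)
```

Unlike the centroid classifier, no averaging takes place: every support sample is kept and compared against each query, so the cost grows with the size of the support set.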

The notebook consists of the following steps:
1. Precompute features
2. Load features and evaluate the baselines

The few-shot baselines require pre-extracted features. The features can be extracted using the feature extraction script.

In [3]:
# Add the parent directory to Python path
import os
import sys
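# Get the directory of the notebook. "__file__" is a literal string here
# (notebooks do not define __file__), so this resolves to the current working directory
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
# Move three levels up to the project root (the directory that contains `scripts`)
root_dir = os.path.abspath(os.path.join(notebook_dir, "../../.."))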
# Add to sys.path
sys.path.insert(0, root_dir)


from scripts.baselines.few_shot.feature_generation import generate_embeddings

# ['clip', 'dinov2', 'bioclip']
model_name = 'bioclip'
# ['centroid', 'nn']
classifier_name = 'centroid'
split = 'val'

# General settings (set data_path and feature_path to your local directories before running)
data_path = ''
feature_path = ''
path_out = "out"

## 1. Precompute features

In [4]:
# precompute features for the fungi dataset
generate_embeddings(data_path=data_path, model_name=model_name, data_split='train', feature_path=feature_path)
generate_embeddings(data_path=data_path, model_name=model_name, data_split='val', feature_path=feature_path) # not needed
# generate_embeddings(data_path=data_path, model_name=model_name, data_split='test', feature_path=feature_path)

Skipping /mnt/datagrid/personal/janoukl1/out/FungiTastic_public/features_old/bioclip/224x224_train.h5 because it already exists
Skipping /mnt/datagrid/personal/janoukl1/out/FungiTastic_public/features_old/bioclip/224x224_val.h5 because it already exists


## 2. Load features and evaluate the baselines

In [5]:
from scripts.baselines.few_shot.eval import test_fungi

# Signature: test_fungi(path_out, data_path, feature_path, feature_model, classifier_name, split, debug=False)
test_fungi(path_out=path_out, data_path=data_path, feature_path=feature_path, feature_model=model_name, classifier_name=classifier_name, split=split, debug=False)

Evaluating eval_bioclip_val_centroid


9it [00:08,  1.06it/s]


{'Accuracy': 0.15929978118161925, 'Recall@3': 0.2485776805251641, 'F1': 0.04096092101065624}
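
To make the reported metrics easier to interpret, the sketch below shows one common way to compute top-k accuracy from per-class scores. It assumes that `Recall@3` means the fraction of queries whose true label is among the three highest-scoring classes. The evaluation code in `scripts.baselines.few_shot.eval` may define it differently, and the helper `top_k_accuracy` and the toy scores are hypothetical.

```python
import numpy as np


def top_k_accuracy(scores, labels, k):
    """Fraction of rows whose true label is among the k highest scores."""
    # Class indices sorted by descending score, truncated to the top k.
    top_k = np.argsort(-scores, axis=1)[:, :k]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())


# Toy scores for 4 queries over 4 classes (rows: queries, columns: classes).
scores = np.array([
    [0.10, 0.60, 0.20, 0.10],  # true class 1: ranked first
    [0.40, 0.30, 0.20, 0.10],  # true class 2: ranked third
    [0.50, 0.30, 0.15, 0.05],  # true class 3: ranked last
    [0.70, 0.15, 0.10, 0.05],  # true class 0: ranked first
])
labels = np.array([1, 2, 3, 0])

accuracy = top_k_accuracy(scores, labels, k=1)
recall_at_3 = top_k_accuracy(scores, labels, k=3)
print({"Accuracy": accuracy, "Recall@3": recall_at_3})
```

Under this definition `Recall@3` can never be lower than `Accuracy`, which matches the ordering of the two values reported above.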
