In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

# First make sure to install timm
!pip install timm

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
from tqdm import tqdm
import numpy as np
import timm
import torchvision
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the read-only Kaggle input directory.
for folder, _, names in os.walk('/kaggle/input'):
    for full_path in (os.path.join(folder, name) for name in names):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session



## Device Selection

We will use the GPU if one is available; otherwise we fall back to the CPU.

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f'Using device: {device}')

Using device: cpu


## Download Dataset

We can use CIFAR10 or CIFAR100 as our dataset. Since both are really common datasets we will use the `torchvision.datasets` class to load them.

To select the dataset, modify the `use_CIFAR10` boolean variable accordingly.

We also apply some basic __preprocessing__:

1. Resize every image to 224x224, the input size of the `vit_tiny_patch16_224` model used later
2. Normalize the dataset by subtracting the _mean_ and dividing by the _std_

In [3]:
from torch.utils.data import DataLoader, random_split

use_CIFAR10 = True

# Fixed seed for the train / validation split so that every
# "Restart & Run All" produces exactly the same split
split_seed = 42

# Get the precomputed per-channel mean and std
# Those are needed to normalize the dataset
# NOTE: To calculate the mean and std we have to
# 1. calculate the sum for each channel
# 2. implement mean and variance formulas
if use_CIFAR10:
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
else:
    raise NotImplementedError("Please compute mean, std for CIFAR100")

# dataset directory remains the same for both cases
dataset_directory = "/kaggle/working/"

# the transformation also remains the same:
# convert to tensor, resize to 224x224 and normalize each channel
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(mean, std)
])

# download the dataset (both datasets share the same constructor signature)
dataset_class = torchvision.datasets.CIFAR10 if use_CIFAR10 else torchvision.datasets.CIFAR100
cifar_dataset = dataset_class(root=dataset_directory, train=True, download=True, transform=transform)
test_dataset = dataset_class(root=dataset_directory, train=False, download=True, transform=transform)

print(f'Dataset downloaded. Total images: {len(cifar_dataset)}')

# Split the dataset into train / validation sets (90% / 10%)
train_size = int(0.9 * len(cifar_dataset))
val_size = len(cifar_dataset) - train_size

# A seeded generator makes the random split reproducible across runs
train_dataset, val_dataset = random_split(
    cifar_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# set the batch size to 64
batch_size = 64

# Create the dataloaders
# Only the train loader is shuffled; evaluation loaders keep a fixed order
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f'Train images: {train_size}')
print(f'Validation images: {val_size}')
print(f'Test images: {len(test_dataset)}')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /kaggle/working/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 49.6MB/s] 


Extracting /kaggle/working/cifar-10-python.tar.gz to /kaggle/working/
Files already downloaded and verified
Dataset downloaded. Total images: 50000
Train images: 45000
Validation images: 5000
Test images: 10000


## Useful Methods

Below we define some helper functions to keep the rest of the code simpler.

### Most Common List Item

In [34]:
from collections import Counter

def most_common(lst):
    """Return the element that appears most often in ``lst``.

    When several elements share the highest count, the one that appears
    first in ``lst`` is returned. An empty ``lst`` raises ``ValueError``.
    """
    frequency = Counter(lst)
    # A Counter iterates its keys in first-seen order and max() keeps the
    # first maximum, so ties resolve to the earliest element of the list.
    return max(frequency, key=frequency.get)

### Get CLS Token

By taking a look at the [documentation](https://huggingface.co/docs/timm/en/feature_extraction) for the timm library and specifically the __Feature Extraction__ section we can see that in order to get the __last hidden state__ of the model we have to use the `forward_features` method.

Specifically, this method returns the patch embeddings at the last hidden state, __before pooling is applied__. The returned tensor has shape

```
(batch_size, num_patches + 1, hidden_size)
```

By design, the __CLS Token__ is the __first of the patch embeddings__.

For example to get the CLS Token of the first image in the batch we would have to do:

```py
model_output[0, 0, :]
```

In [4]:
def get_cls_token(model: torch.nn.Module, images: torch.Tensor) -> np.ndarray:
    """Extract the CLS token of every image in a batch.

    Args:
        model: A model exposing ``forward_features`` (e.g. a timm
            VisionTransformer) whose output has the CLS token at index 0 of
            the second dimension.
        images: The batch of images to feed to ``forward_features``.

    Returns:
        A NumPy array holding the first token of each image, i.e.
        ``output[:, 0, :]`` of the last hidden state.
    """
    # Feed the images on the same device as the model weights, so the call
    # also works when the model has been moved to the GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        images = images.to(first_param.device)

    # Pure inference: do not build an autograd graph. This keeps memory low
    # and allows the NumPy conversion even if some parameter requires grad.
    with torch.no_grad():
        # get the last hidden state
        output = model.forward_features(images)

    # for each image keep only the cls token (first entry of the sequence)
    # and convert the result to numpy
    cls_tokens = output[:, 0, :].detach().cpu().numpy()

    return cls_tokens

### Extract CLS Tokens for all images in a DataLoader

This method uses the `get_cls_token` method above to extract all cls tokens from a given dataloader.

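For each image we also need its label, so the tokens and the labels are collected together.

The method returns __two parallel lists__:
1. `cls_tokens`: one CLS token (a vector of size `hidden_size`) per image
2. `cls_labels`: the label of the corresponding image

An image has the same position in both lists, so index `i` of `cls_labels` is the label for index `i` of `cls_tokens`.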

In [5]:
def get_dataset_cls_tokens(
    model: timm.models.vision_transformer.VisionTransformer,
    loader: torch.utils.data.dataloader.DataLoader,
):
    """Collect the CLS token and the label of every image served by ``loader``.

    Each batch is passed to ``get_cls_token``; the resulting rows and the
    matching labels are appended, image by image, to two parallel lists.

    Returns:
        A ``(cls_tokens, cls_labels)`` tuple of lists. Entry ``i`` of
        ``cls_labels`` is ``labels[...]`` for the image whose token is entry
        ``i`` of ``cls_tokens``.
    """
    cls_tokens = []
    cls_labels = []

    for images, labels in tqdm(loader, desc="Calculating CLS Tokens", total=len(loader)):
        # one row per image in the batch
        batch_tokens = get_cls_token(model=model, images=images)
        images_in_batch = batch_tokens.shape[0]

        # Unpack the batch so that both lists hold one entry per image
        cls_tokens.extend(batch_tokens[row, :] for row in range(images_in_batch))
        cls_labels.extend(labels[row] for row in range(images_in_batch))

    return cls_tokens, cls_labels

## Load the VIT model

To load the model we will use the `timm` library.

We could also use the `transformers` library, and more specifically the `ViTForImageClassification` class.

In [6]:
model = timm.create_model(
    "vit_tiny_patch16_224",  # Pre-trained ViT-Tiny on ImageNet-1k
    pretrained=True,        # Load pre-trained weights
    num_classes=10          # Adapt classifier head to CIFAR-10 (10 classes)
)

# The model is only used for feature extraction, so switch it to evaluation
# mode: layers such as dropout must behave deterministically at inference time
model.eval()

# We dont want to train here so we can freeze all the layers
for param in model.parameters():
    param.requires_grad = False

model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

## Generate CLS Tokens for the Train Set

In [7]:
# Extract one CLS token and one label per training image (slow on CPU)
train_cls_tokens, train_labels = get_dataset_cls_tokens(model=model, loader=train_loader)

Calculating CLS Tokens: 100%|██████████| 704/704 [26:40<00:00,  2.27s/it]


In [25]:
# Convert the list of per-image labels into a single NumPy array,
# so that it can later be indexed with an array of neighbour indices
train_labels = np.array(train_labels)

## Generate CLS Tokens for the Test Set

In [12]:
# Extract one CLS token and one label per test image (slow on CPU)
test_cls_tokens, test_labels = get_dataset_cls_tokens(model=model, loader=test_loader)

Calculating CLS Tokens: 100%|██████████| 157/157 [05:56<00:00,  2.27s/it]


In [26]:
# Convert the list of per-image labels into a single NumPy array,
# which is easier to handle when comparing against the predictions
test_labels = np.array(test_labels)

## Brute Force KNN - WITHOUT Test Prediction

In the following cells we will try to predict the labels of the test images by doing the following:

1) Use the `ViT` model to extract the cls_token for the test image
2) Find the `topK` most similar cls_tokens from the train dataset
3) Assign `y_pred` to the majority class among those `topK` nearest neighbours

__This method will act as our baseline__ as it does not require a fine-tuned Vision Transformer.



In [50]:
from sklearn.neighbors import NearestNeighbors

def brute_force_knn(distance: str, train_data: np.ndarray, train_labels: np.ndarray, test_data: np.ndarray, top_k: int = 5):
    """Predict test labels by majority vote among the ``top_k`` nearest train samples.

    Args:
        distance: Metric name passed to ``NearestNeighbors`` (e.g. "cosine").
        train_data: Sequence of train feature vectors; stacked into one 2D array.
        train_labels: One label per train vector, in the same order.
        test_data: Sequence of test feature vectors; stacked into one 2D array.
        top_k: Number of neighbours that take part in the vote.

    Returns:
        A NumPy array with one predicted label per test vector. Ties are
        resolved by ``most_common``.

    Raises:
        ValueError: If ``train_labels`` and ``train_data`` differ in length.
    """
    # Initialize the knn
    knn = NearestNeighbors(n_neighbors=top_k, algorithm="brute", metric=distance)

    # Create a numpy array from the list
    print(f'Total train images: {len(train_data)}')
    train_data = np.stack(train_data, axis=0)
    print(f'Train shape: {train_data.shape}')

    # The labels are indexed with arrays of neighbour indices below, which
    # requires an array (a plain list would fail)
    train_labels = np.asarray(train_labels)
    if len(train_labels) != len(train_data):
        raise ValueError(
            f'train_labels has {len(train_labels)} entries but train_data has {len(train_data)}'
        )

    # Fit with the train set
    knn.fit(train_data)

    # Create a numpy array for the test images
    print(f'Total test images: {len(test_data)}')
    test_data = np.stack(test_data, axis=0)
    print(f'Test shape: {test_data.shape}')

    # Apply the knn: row i holds the train indices of the neighbours of test image i
    distances, indexes = knn.kneighbors(test_data, return_distance=True)

    print(f'Distances: {distances.shape}')
    print(f'Indexes: {indexes.shape}')

    # For each test image, look up the classes of its neighbours and
    # select the majority class as the prediction
    y_pred = [
        most_common(train_labels[neighbour_ids].tolist())
        for neighbour_ids in tqdm(indexes, desc='Gathering results')
    ]

    # finally convert the y_pred to array
    return np.array(y_pred)
        

### Using Cosine Distance

In [53]:
# classification_report is not imported anywhere else in the notebook,
# so import it here to keep this cell runnable on a fresh kernel
from sklearn.metrics import classification_report

# Evaluate the cosine-distance KNN baseline for several neighbourhood sizes
for k in [3, 5, 7, 10, 11, 13, 15, 17, 19]:
    y_pred = brute_force_knn(distance="cosine", train_data=train_cls_tokens, train_labels=train_labels, test_data=test_cls_tokens, top_k=k)
    print('----------------------------')
    print(f'COSINE - TOP_K: {k}')
    print(classification_report(test_labels, y_pred))

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)
Distances: (10000, 3)
Indexes: (10000, 3)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 291172.03it/s]

----------------------------
COSINE - TOP_K: 3
              precision    recall  f1-score   support

           0       0.75      0.74      0.74      1000
           1       0.70      0.72      0.71      1000
           2       0.78      0.55      0.65      1000
           3       0.56      0.49      0.52      1000
           4       0.69      0.65      0.67      1000
           5       0.61      0.60      0.61      1000
           6       0.67      0.84      0.74      1000
           7       0.70      0.77      0.73      1000
           8       0.80      0.77      0.78      1000
           9       0.66      0.76      0.71      1000

    accuracy                           0.69     10000
   macro avg       0.69      0.69      0.69     10000
weighted avg       0.69      0.69      0.69     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 5)
Indexes: (10000, 5)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 277417.57it/s]

----------------------------
COSINE - TOP_K: 5
              precision    recall  f1-score   support

           0       0.78      0.74      0.76      1000
           1       0.72      0.75      0.74      1000
           2       0.84      0.55      0.66      1000
           3       0.59      0.53      0.56      1000
           4       0.71      0.67      0.69      1000
           5       0.63      0.63      0.63      1000
           6       0.66      0.86      0.75      1000
           7       0.73      0.78      0.75      1000
           8       0.80      0.79      0.79      1000
           9       0.66      0.79      0.72      1000

    accuracy                           0.71     10000
   macro avg       0.71      0.71      0.71     10000
weighted avg       0.71      0.71      0.71     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 7)
Indexes: (10000, 7)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 260966.39it/s]

----------------------------
COSINE - TOP_K: 7
              precision    recall  f1-score   support

           0       0.78      0.75      0.76      1000
           1       0.74      0.75      0.75      1000
           2       0.85      0.53      0.65      1000
           3       0.59      0.53      0.56      1000
           4       0.73      0.68      0.70      1000
           5       0.64      0.63      0.64      1000
           6       0.66      0.88      0.75      1000
           7       0.73      0.79      0.76      1000
           8       0.82      0.79      0.81      1000
           9       0.66      0.81      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.72      0.71      0.71     10000
weighted avg       0.72      0.71      0.71     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 10)
Indexes: (10000, 10)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 235759.56it/s]

----------------------------
COSINE - TOP_K: 10
              precision    recall  f1-score   support

           0       0.78      0.75      0.76      1000
           1       0.75      0.76      0.75      1000
           2       0.86      0.51      0.64      1000
           3       0.60      0.53      0.56      1000
           4       0.73      0.68      0.70      1000
           5       0.64      0.64      0.64      1000
           6       0.64      0.88      0.74      1000
           7       0.74      0.78      0.76      1000
           8       0.82      0.78      0.80      1000
           9       0.66      0.82      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.72      0.71      0.71     10000
weighted avg       0.72      0.71      0.71     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 11)
Indexes: (10000, 11)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 236574.71it/s]

----------------------------
COSINE - TOP_K: 11
              precision    recall  f1-score   support

           0       0.79      0.74      0.76      1000
           1       0.75      0.76      0.75      1000
           2       0.87      0.51      0.64      1000
           3       0.60      0.53      0.56      1000
           4       0.74      0.68      0.71      1000
           5       0.64      0.65      0.64      1000
           6       0.64      0.88      0.74      1000
           7       0.73      0.78      0.76      1000
           8       0.81      0.79      0.80      1000
           9       0.66      0.81      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.72      0.71      0.71     10000
weighted avg       0.72      0.71      0.71     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 13)
Indexes: (10000, 13)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 222072.42it/s]

----------------------------
COSINE - TOP_K: 13
              precision    recall  f1-score   support

           0       0.79      0.75      0.77      1000
           1       0.76      0.77      0.76      1000
           2       0.88      0.51      0.64      1000
           3       0.61      0.54      0.57      1000
           4       0.74      0.68      0.71      1000
           5       0.64      0.65      0.64      1000
           6       0.64      0.89      0.75      1000
           7       0.75      0.79      0.77      1000
           8       0.83      0.80      0.81      1000
           9       0.66      0.82      0.73      1000

    accuracy                           0.72     10000
   macro avg       0.73      0.72      0.72     10000
weighted avg       0.73      0.72      0.72     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 15)
Indexes: (10000, 15)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 210068.11it/s]

----------------------------
COSINE - TOP_K: 15
              precision    recall  f1-score   support

           0       0.79      0.75      0.77      1000
           1       0.75      0.76      0.75      1000
           2       0.88      0.50      0.64      1000
           3       0.60      0.54      0.56      1000
           4       0.73      0.68      0.70      1000
           5       0.64      0.64      0.64      1000
           6       0.64      0.89      0.75      1000
           7       0.74      0.78      0.76      1000
           8       0.83      0.79      0.81      1000
           9       0.66      0.83      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.73      0.71      0.71     10000
weighted avg       0.73      0.71      0.71     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 17)
Indexes: (10000, 17)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 195759.50it/s]

----------------------------
COSINE - TOP_K: 17
              precision    recall  f1-score   support

           0       0.79      0.75      0.77      1000
           1       0.75      0.76      0.76      1000
           2       0.90      0.49      0.64      1000
           3       0.60      0.54      0.56      1000
           4       0.74      0.67      0.70      1000
           5       0.64      0.65      0.64      1000
           6       0.64      0.89      0.74      1000
           7       0.74      0.77      0.75      1000
           8       0.82      0.79      0.80      1000
           9       0.65      0.83      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.73      0.71      0.71     10000
weighted avg       0.73      0.71      0.71     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000





Test shape: (10000, 192)
Distances: (10000, 19)
Indexes: (10000, 19)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 201624.03it/s]

----------------------------
COSINE - TOP_K: 19
              precision    recall  f1-score   support

           0       0.78      0.74      0.76      1000
           1       0.75      0.76      0.75      1000
           2       0.91      0.49      0.64      1000
           3       0.61      0.54      0.57      1000
           4       0.74      0.67      0.70      1000
           5       0.65      0.66      0.65      1000
           6       0.64      0.89      0.74      1000
           7       0.74      0.78      0.76      1000
           8       0.82      0.79      0.81      1000
           9       0.65      0.83      0.73      1000

    accuracy                           0.72     10000
   macro avg       0.73      0.72      0.71     10000
weighted avg       0.73      0.72      0.71     10000






### Using Euclidean Distance

In [55]:
# classification_report is not imported anywhere else in the notebook,
# so import it here to keep this cell runnable on a fresh kernel
from sklearn.metrics import classification_report

# Evaluate the euclidean-distance KNN baseline for several neighbourhood sizes
for k in [3, 5, 7, 10, 11, 13, 15, 17, 19]:
    y_pred = brute_force_knn(distance="euclidean", train_data=train_cls_tokens, train_labels=train_labels, test_data=test_cls_tokens, top_k=k)
    print('----------------------------')
    # Label the report with the metric actually used in this cell
    print(f'EUCLIDEAN - TOP_K: {k}')
    print(classification_report(test_labels, y_pred))

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)
Distances: (10000, 3)
Indexes: (10000, 3)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 296849.41it/s]

----------------------------
COSINE - TOP_K: 3
              precision    recall  f1-score   support

           0       0.74      0.71      0.72      1000
           1       0.71      0.70      0.71      1000
           2       0.75      0.53      0.62      1000
           3       0.53      0.49      0.51      1000
           4       0.67      0.66      0.66      1000
           5       0.60      0.59      0.59      1000
           6       0.68      0.82      0.74      1000
           7       0.71      0.77      0.74      1000
           8       0.75      0.77      0.76      1000
           9       0.66      0.74      0.70      1000

    accuracy                           0.68     10000
   macro avg       0.68      0.68      0.67     10000
weighted avg       0.68      0.68      0.67     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 5)
Indexes: (10000, 5)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 276519.56it/s]

----------------------------
COSINE - TOP_K: 5
              precision    recall  f1-score   support

           0       0.75      0.73      0.74      1000
           1       0.73      0.72      0.73      1000
           2       0.79      0.52      0.63      1000
           3       0.56      0.50      0.53      1000
           4       0.67      0.68      0.67      1000
           5       0.62      0.61      0.61      1000
           6       0.68      0.85      0.76      1000
           7       0.73      0.75      0.74      1000
           8       0.76      0.78      0.77      1000
           9       0.66      0.77      0.71      1000

    accuracy                           0.69     10000
   macro avg       0.69      0.69      0.69     10000
weighted avg       0.69      0.69      0.69     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 7)
Indexes: (10000, 7)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 255840.72it/s]

----------------------------
COSINE - TOP_K: 7
              precision    recall  f1-score   support

           0       0.77      0.72      0.75      1000
           1       0.75      0.73      0.74      1000
           2       0.82      0.51      0.63      1000
           3       0.57      0.52      0.54      1000
           4       0.69      0.68      0.69      1000
           5       0.63      0.62      0.62      1000
           6       0.66      0.86      0.75      1000
           7       0.73      0.76      0.75      1000
           8       0.77      0.80      0.78      1000
           9       0.66      0.80      0.73      1000

    accuracy                           0.70     10000
   macro avg       0.71      0.70      0.70     10000
weighted avg       0.71      0.70      0.70     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 10)
Indexes: (10000, 10)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 233521.55it/s]

----------------------------
COSINE - TOP_K: 10
              precision    recall  f1-score   support

           0       0.77      0.73      0.75      1000
           1       0.76      0.74      0.75      1000
           2       0.84      0.51      0.63      1000
           3       0.59      0.53      0.56      1000
           4       0.68      0.69      0.69      1000
           5       0.64      0.63      0.64      1000
           6       0.67      0.87      0.76      1000
           7       0.74      0.77      0.76      1000
           8       0.78      0.79      0.79      1000
           9       0.67      0.81      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.71      0.71      0.70     10000
weighted avg       0.71      0.71      0.70     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 11)
Indexes: (10000, 11)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 213484.26it/s]

----------------------------
COSINE - TOP_K: 11
              precision    recall  f1-score   support

           0       0.77      0.74      0.76      1000
           1       0.76      0.74      0.75      1000
           2       0.84      0.51      0.63      1000
           3       0.58      0.52      0.55      1000
           4       0.69      0.69      0.69      1000
           5       0.63      0.62      0.63      1000
           6       0.66      0.87      0.75      1000
           7       0.74      0.77      0.75      1000
           8       0.78      0.80      0.79      1000
           9       0.67      0.82      0.74      1000

    accuracy                           0.71     10000
   macro avg       0.71      0.71      0.70     10000
weighted avg       0.71      0.71      0.70     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 13)
Indexes: (10000, 13)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 209054.59it/s]

----------------------------
COSINE - TOP_K: 13
              precision    recall  f1-score   support

           0       0.77      0.74      0.76      1000
           1       0.76      0.74      0.75      1000
           2       0.85      0.49      0.62      1000
           3       0.58      0.52      0.55      1000
           4       0.70      0.69      0.70      1000
           5       0.64      0.63      0.64      1000
           6       0.65      0.88      0.75      1000
           7       0.74      0.76      0.75      1000
           8       0.77      0.79      0.78      1000
           9       0.66      0.81      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.71      0.71      0.70     10000
weighted avg       0.71      0.71      0.70     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000
Test shape: (10000, 192)





Distances: (10000, 15)
Indexes: (10000, 15)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 207114.87it/s]

----------------------------
COSINE - TOP_K: 15
              precision    recall  f1-score   support

           0       0.78      0.74      0.76      1000
           1       0.76      0.74      0.75      1000
           2       0.86      0.48      0.62      1000
           3       0.58      0.51      0.55      1000
           4       0.70      0.70      0.70      1000
           5       0.63      0.64      0.63      1000
           6       0.65      0.88      0.75      1000
           7       0.75      0.76      0.75      1000
           8       0.77      0.81      0.79      1000
           9       0.67      0.82      0.74      1000

    accuracy                           0.71     10000
   macro avg       0.72      0.71      0.70     10000
weighted avg       0.72      0.71      0.70     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000





Test shape: (10000, 192)
Distances: (10000, 17)
Indexes: (10000, 17)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 193190.64it/s]

----------------------------
COSINE - TOP_K: 17
              precision    recall  f1-score   support

           0       0.77      0.74      0.75      1000
           1       0.75      0.73      0.74      1000
           2       0.87      0.48      0.61      1000
           3       0.59      0.51      0.55      1000
           4       0.69      0.68      0.69      1000
           5       0.62      0.65      0.64      1000
           6       0.65      0.88      0.75      1000
           7       0.74      0.75      0.75      1000
           8       0.77      0.81      0.79      1000
           9       0.66      0.82      0.73      1000

    accuracy                           0.70     10000
   macro avg       0.71      0.70      0.70     10000
weighted avg       0.71      0.70      0.70     10000

Total train images: 45000
Train shape: (45000, 192)
Total test images: 10000





Test shape: (10000, 192)
Distances: (10000, 19)
Indexes: (10000, 19)


Gathering results: 100%|██████████| 10000/10000 [00:00<00:00, 191541.72it/s]

----------------------------
COSINE - TOP_K: 19
              precision    recall  f1-score   support

           0       0.78      0.74      0.76      1000
           1       0.76      0.72      0.74      1000
           2       0.87      0.48      0.62      1000
           3       0.61      0.52      0.56      1000
           4       0.70      0.68      0.69      1000
           5       0.63      0.65      0.64      1000
           6       0.64      0.88      0.74      1000
           7       0.75      0.75      0.75      1000
           8       0.77      0.81      0.79      1000
           9       0.66      0.82      0.73      1000

    accuracy                           0.71     10000
   macro avg       0.72      0.71      0.70     10000
weighted avg       0.72      0.71      0.70     10000




