In [None]:
from torch import nn, optim
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from easyfsl.utils import  sliding_average
import numpy as np

In [2]:
from transformers import AutoAdapterModel,AutoTokenizer

encoder = AutoAdapterModel.from_pretrained("distilbert-base-uncased")
adapter_name = encoder.load_adapter("Elise-hf/distilbert-base-uncased_reddit_categories_unipelft", source="hf", set_active=True)

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertAdapterModel: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertAdapterModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertAdapterModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
import datasets
dataset = datasets.load_dataset("Elise-hf/reddit_categories_clean")

In [6]:
dataset['test'][[1,2]]

{'id': ['7uzqs5', '5komf8'],
 'subreddit': ['nanocurrency', 'TalesFromThePizzaGuy'],
 'title': ['Almost 24 hrs withdrawal from Kucoin to Binance ...',
  'Confused as hell about wages.'],
 'raw_text': ['Is there anyone else with the same problem, I read a lot people got their nano deposited at binance within 10-16 hrs ... I am starting to get worried! <lb>At RaiBlocks.net where shows the transaction, it says “Received, Pending deposit to balance.(Make sure wallet is open and all blocks are downloaded) “<lb>Could it be an issue with Binance wallet? Since they suspended the deposit and withdrawals of Nank<lb><lb>Thank you guys',
  "So I'm new to this board and with tax season coming up I was looking to get some input about filing and some misc shit about my wages.<lb><lb>I'm a driver for a semi-local chain of calzone restaurants in New York (Sorry I'm not actually a pizza guy). NYS Minimum wage is currently $9/hr and I'm making $7.50 with no reimbursement from my employer for mileage. 100

In [3]:
from sklearn.preprocessing import LabelEncoder
# Fit the LabelEncoder on the 'category' column
subcategory_encoder = LabelEncoder()
category_encoder = LabelEncoder()
category_encoder.fit(dataset['train']['category'])
subcategory_encoder.fit(dataset['train']['subcategory'])

def encode_batch(batch):
    """Encodes a batch of input data using the model tokenizer."""
    # Concatenate title and selftext with [SEP] token in between
    combined_text = [title + ' [SEP] ' + selftext for title, selftext in zip(batch['title'], batch['text'])]

    # Encode the text using the tokenizer and add the attention mask
    encoding = tokenizer(combined_text, max_length=512, truncation=True, padding="longest")

    # Encode the labels and add them to the encoding
    encoding['subcategory'] = subcategory_encoder.transform(batch['subcategory'])
    encoding['category'] = category_encoder.transform(batch['category'])

    return encoding

tokenized_dataset = dataset['validation'].map(encode_batch, batched=True,num_proc=6,remove_columns=["text","category","subcategory"])
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask",  "subcategory"])


In [None]:
from clustering import Clustering
clust = Clustering(encoder,tokenizer)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
embeddings = clust.get_features(tokenized_dataset)

In [4]:
import numpy as np
embeddings = np.load('embeddings_fine_tuned/val/val_embeddings.npy')

In [5]:
from clustering import Clustering
clust = Clustering(encoder,tokenizer)
cluster_labels = clust.run_clustering(embeddings)

In [10]:
tokenized_dataset = tokenized_dataset.add_column('cluster', cluster_labels)

ValueError: The table can't have duplicated columns but columns ['cluster'] are duplicated.

In [7]:
tokenized_dataset = tokenized_dataset.filter(lambda example: example['cluster'] != -1)

Filter:   0%|          | 0/101300 [00:00<?, ? examples/s]

In [1]:
tokenized_dataset

NameError: name 'tokenized_dataset' is not defined

In [None]:
from prototypical_net import PrototypicalNetworks
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PrototypicalNetworks(encoder).to(device)


In [None]:
from fs_dataset import FewShotTaskDatasetWithWeights
n_way = 2
n_support = 1
n_query = 1
dataset_fw = FewShotTaskDatasetWithWeights(tokenized_dataset, n_way, n_support, n_query)



In [None]:
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset_fw, batch_size=1, shuffle=True)



In [26]:
from scipy.spatial.distance import cdist
import numpy as np
import pickle
from collections import defaultdict
from tqdm import tqdm
import faiss
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_representative(embeddings, representative_type="mean"):
    """
    Compute the representative point of a set of embeddings.

    Args:
    embeddings (np.array): The embeddings of the instances in a cluster.
    representative_type (str): The type of representative to compute. Options are 'mean', 'centroid', and 'medoid'.

    Returns:
    representative (np.array): The representative point of the embeddings.
    """
    if representative_type == "mean":
        representative = np.mean(embeddings, axis=0)
    elif representative_type == "centroid":
        representative = embeddings[np.argmin(cdist(embeddings, embeddings).sum(axis=1))]
    elif representative_type == "medoid":
        representative = embeddings[np.argmin(cdist(embeddings, embeddings).mean(axis=1))]
    else:
        raise ValueError(f"Unknown representative type: {representative_type}. Choices are 'mean', 'centroid', and 'medoid'.")

    return representative

#second commit
def compute_and_store_prototypes_centroids(train_loader, model, embedding_path, prototype_path, centroid_path, representative_type="mean"):
    model.eval()

    prototypes = defaultdict(list)
    embeddings_per_cluster = defaultdict(list)
    
    # Step 1: Computing the embeddings for the entire train dataset
    with torch.no_grad():
        for batch in tqdm(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            cluster_labels = batch['cluster']
            inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
            
            outputs = model.base_model(**inputs)
            print(outputs.last_hidden_state.shape)
            embeddings = outputs.last_hidden_state[:, 0, :].cpu().detach().numpy() 
            
            for emb, cluster in zip(embeddings, cluster_labels):
                embeddings_per_cluster[cluster].append(emb)

    # Step 2: Computing the prototype (mean embedding) for each class within each cluster
    for cluster, embeddings in embeddings_per_cluster.items():
        embeddings = np.array(embeddings)
        class_labels = cluster_labels[cluster_labels == cluster]
        unique_classes = np.unique(class_labels)
        
        for class_ in unique_classes:
            class_embeddings = embeddings[class_labels == class_]
            prototype = get_representative(class_embeddings, representative_type) 
            prototypes[cluster].append((class_, prototype))
    
    # Step 3: Computing the centroid for each cluster
    centroids = {}
    for cluster, embeddings in embeddings_per_cluster.items():
        centroid = get_representative(embeddings, representative_type)
        centroids[cluster] = centroid

    # Save the prototypes and centroids for each cluster
    with open(embedding_path, "wb") as f:
        pickle.dump(embeddings_per_cluster, f)
    with open(prototype_path, "wb") as f:
        pickle.dump(prototypes, f)
    with open(centroid_path, "wb") as f:
        pickle.dump(centroids, f)


def create_faiss_indexes(prototype_path, centroid_path, index_path):
    # Load the prototypes and centroids
    with open(prototype_path, "rb") as f:
        prototypes = pickle.load(f)
    with open(centroid_path, "rb") as f:
        centroids = pickle.load(f)

    prototype_indexes = {}
    centroid_indexes = {}

    for cluster in prototypes.keys():
        # Create a Faiss index for the prototypes of this cluster
        prototype_index = faiss.IndexFlatL2(prototypes[cluster][0][1].shape[0])
        prototypes_matrix = np.vstack([proto[1] for proto in prototypes[cluster]])
        prototype_index.add(prototypes_matrix)
        prototype_indexes[cluster] = prototype_index

    for cluster in centroids.keys():
        # Create a Faiss index for the centroid of this cluster
        centroid_index = faiss.IndexFlatL2(centroids[cluster].shape[0])
        centroid_index.add(np.expand_dims(centroids[cluster], axis=0))
        centroid_indexes[cluster] = centroid_index

    # Save the indexes using Faiss's built-in functions
    for cluster, index in prototype_indexes.items():
        faiss.write_index(index, f"{index_path}/prototype_{cluster}.index")
    
    for cluster, index in centroid_indexes.items():
        faiss.write_index(index, f"{index_path}/centroid_{cluster}.index")

In [23]:
tokenized_dataset[0]

{'subcategory': tensor(969),
 'input_ids': tensor([  101,  2047, 24978,  2546,  2860, 27830,  2208,  1024,  6396,  8737,
          6806,  1011, 10365,   102,  4931,  3071,   999,  2057,  2024, 27830,
         12260, 11022,  2094,  1998,  2057,  2024,  2551,  2006,  1037,  2047,
          2208,  2170,  6396,  8737,  6806,  1011, 10365,  1012, 27830, 12260,
         11022,  2094,  2003,  1037,  2208,  1011,  2458,  2194,  2437, 27830,
         22555,  6322,  1012,  2057,  2024,  1037,  2176,  1011,  2158,  2136,
          1998,  2057,  2024, 13459,  2055,  2437, 10047, 16862,  3512,  4639,
          7484,  4507,  4180,  1012,  1996,  2208,  2993,  2003,  2422,  1038,
          5104,  2213, 11773,  1010,  1998,  2838,  8750,  4427,  2396,  1012,
          2057,  1005,  2128,  2220,  1999,  2458,  1010,  1998,  2057,  2024,
          5186,  2330,  2000, 12247,  1012,  2256,  2783,  2933,  2003,  2000,
          2031,  2048,  6177,  1997,  1996,  2208,  1010,  2019,  9123,  4664,
          

In [25]:
embeddings_path = "embeddings.pt"
prototype_path = "prototype.pt"
centroids_path = "centroids.pt"
compute_and_store_prototypes_centroids (tokenized_dataset,encoder,embeddings_path,prototype_path,centroids_path)

  0%|          | 0/66488 [00:00<?, ?it/s]


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
next (iter(data_loader))[0].keys()

In [None]:
import random
from typing import Dict, Iterator, List, Tuple, Union

class TransformerTaskSampler(TaskSampler):
    """
    This sampler extends the TaskSampler but overrides the episodic_collate_fn to work with transformer-based models.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sampled_examples = set()

    def __iter__(self) -> Iterator[List[int]]:
        """
        Sample n_way labels uniformly at random,
        and then sample n_shot + n_query items for each label, also uniformly at random.
        Yields:
            a list of indices of length (n_way * (n_shot + n_query))
        """
        total_examples = sum(len(examples) for examples in self.items_per_label.values())

        for task_id in range(self.n_tasks):
            # Sample items for each label
            sampled_items = []
            for label in random.sample(sorted(self.items_per_label.keys()), self.n_way):
                sampled_examples_for_label = random.sample(self.items_per_label[label], self.n_shot + self.n_query)
                self.sampled_examples.update(sampled_examples_for_label)
                sampled_items.append(torch.tensor(sampled_examples_for_label))

       
            yield torch.cat(sampled_items).tolist()
    def episodic_collate_fn(self, input_data):
        """
        Collate function for episodic data loaders.

        Args:
            input_data: each element is a dict containing:
                - 'input_ids': Tensor
                - 'attention_mask': Tensor
                - 'labels': Tensor

        Returns:
            tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
                - support 'input_ids' of shape (n_way * n_shot, ...)
                - their labels of shape (n_way * n_shot)
                - query 'input_ids' of shape (n_way * n_query, ...)
                - their labels of shape (n_way * n_query)
                - the dataset class ids of the class sampled in the episode
        """

        # Ensure input data is of correct format
        for data in input_data:
            if not isinstance(data, dict):
                raise TypeError(f"Expected dict, got {type(data)}.")
            if not all(isinstance(value, torch.Tensor) for value in data.values()):
                raise TypeError("All dictionary values must be PyTorch Tensors.")

        # Gather true class ids
        true_class_ids = list({x['labels'].item() for x in input_data})

        # Create tensors of 'input_ids' and 'labels'
        all_input_ids = torch.stack([x['input_ids'] for x in input_data])
        all_input_ids = all_input_ids.reshape((self.n_way, self.n_shot + self.n_query, *all_input_ids.shape[1:]))

        all_attention_mask = torch.stack([x['attention_mask'] for x in input_data])
        all_attention_mask = all_attention_mask.reshape((self.n_way, self.n_shot + self.n_query, *all_attention_mask.shape[1:]))

        all_labels = torch.tensor([true_class_ids.index(x['labels'].item()) for x in input_data])
        all_labels = all_labels.reshape((self.n_way, self.n_shot + self.n_query))

        # Separate into support and query sets
        support_input_ids = all_input_ids[:, : self.n_shot].reshape((-1, *all_input_ids.shape[2:]))
        query_input_ids = all_input_ids[:, self.n_shot :].reshape((-1, *all_input_ids.shape[2:]))

        support_attention_mask = all_attention_mask[:, : self.n_shot].reshape((-1, *all_attention_mask.shape[2:]))
        query_attention_mask = all_attention_mask[:, self.n_shot :].reshape((-1, *all_attention_mask.shape[2:]))

        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()

        return (
            {'input_ids': support_input_ids, 'attention_mask': support_attention_mask},
            support_labels,
            {'input_ids': query_input_ids, 'attention_mask': query_attention_mask},
            query_labels,
            true_class_ids,
        )


In [None]:
N_WAY = 5 # Number of classes in a task
N_SHOT = 5  # Number of images per class in the support set
N_QUERY = 5  # Number of images per class in the query set
N_EVALUATION_TASKS = 100

# Your dataset still needs a "get_labels" method
dataset_val.get_labels = lambda: dataset_val['labels'].tolist()

test_sampler = TransformerTaskSampler(
    dataset_val, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    dataset_val,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn
)


In [None]:
d

We created a dataloader that will feed us with 5-way 5-shot tasks (the most common setting in the litterature).
Now, as every data scientist should do before launching opaque training scripts,
let's take a look at our dataset.

In [None]:
(
    example_support_inputs,
    example_support_labels,
    example_query_inputs,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))


In [None]:
model.eval()
example_scores = model(
    example_support_inputs,
    example_support_labels.cuda(),
    example_query_inputs,
).detach()

_, example_predicted_labels = torch.max(example_scores.data, 1)
example_predicted_labels


In [None]:
import gc

# del example_support_inputs
# del example_query_inputs
del model
gc.collect()

torch.cuda.empty_cache()

In [None]:
example_support_inputs

This doesn't look bad: keep in mind that the model was trained on very different images, and has only seen 5 examples for each class!

Now that we have a first idea, let's see more precisely how good our model is.

In [None]:
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) :
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    return (
        torch.max(
            model(support_images, support_labels.cuda(), query_images)
            .detach()
            .data,
            1,
        )[1]
        == query_labels.cuda()
    ).sum().item(), len(query_labels)

In [None]:

def evaluate(data_loader: DataLoader, sampler: TransformerTaskSampler):
    total_predictions = 0
    correct_predictions = 0
    total_classes =set()
    model.eval()
    with torch.no_grad():
        progress_bar = tqdm(enumerate(data_loader), total=len(data_loader))
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in progress_bar:

            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

            # Compute and display the percentage of examples sampled so far
            total_examples = sum(len(examples) for examples in sampler.items_per_label.values())
            percentage_sampled = len(sampler.sampled_examples) / total_examples * 100
            progress_bar.set_description(f'Sampled {percentage_sampled:.2f}% of examples')

            # store used classes
            total_classes.update(class_ids)
            progress_bar.set_postfix( # display some useful information
                accuracy=f"{100 * correct_predictions/total_predictions:.2f}%",
                sampled_classes=f"{len(total_classes)}/{len(le.classes_)}",
            )



    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )

evaluate(test_loader, test_sampler)


With absolutely zero training on Omniglot images, and only 5 examples per class, we achieve around 86% accuracy! Isn't this a great start?

Now that you know how to make Prototypical Networks work, you can see what happens if you tweak it
a little bit (change the backbone, use other distances than euclidean...) or if you change the problem
(more classes in each task, less or more examples in the support set, maybe even one example only,
but keep in mind that in that case Prototypical Networks are just standard nearest neighbour).

When you're done, you can scroll further down and learn how to **meta-train this model**, to get even better results.

## Training a meta-learning algorithm

Let's use the "background" images of Omniglot as training set. Here we prepare a data loader of 40 000 few-shot classification
tasks on which we will train our model. The alphabets used in the training set are entirely separated from those used in the testing set.
This guarantees that at test time, the model will have to classify characters that were not seen during training.

Note that we don't set a validation set here to keep this notebook concise,
but keep in mind that **this is not good practice** and you should always use validation when training a model for production.

In [None]:
N_WAY = 5 # Number of classes in a task
N_SHOT = 5  # Number of images per class in the support set
N_QUERY = 5  # Number of images per class in the query set
N_TRAINING_EPISODES = 1
N_VALIDATION_TASKS = 100
dataset_train = dataset['train']
dataset_train.get_labels = lambda: dataset_train['labels'].tolist()
train_sampler = TransformerTaskSampler(
   dataset_train, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)
train_loader = DataLoader(
    dataset_train,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

We will keep the same model. So our weights will be pre-trained on ImageNet. If you want to start a training from scratch,
feel free to set `pretrained=False` in the definition of the ResNet.

Here we define our loss and our optimizer (cross entropy and Adam, pretty standard), and a `fit` method.
This method takes a classification task as input (support set and query set). It predicts the labels of the query set
based on the information from the support set; then it compares the predicted labels to ground truth query labels,
and this gives us a loss value. Then it uses this loss to update the parameters of the model. This is a *meta-training loop*.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images, support_labels.cuda(), query_images
    )

    loss = criterion(classification_scores, query_labels.cuda())
    loss.backward()
    optimizer.step()

    return loss.item()

To train the model, we are just going to iterate over a large number of randomly generated few-shot classification tasks,
and let the `fit` method update our model after each task. This is called **episodic training**.

This took me 20mn on an RTX 2080 and I promised you that this whole tutorial would take 15mn.
So if you don't want to run the training yourself, you can just skip the training and load the model that I trained
using the exact same code.

In [None]:
# Train the model yourself with this cell

log_update_frequency = 10

all_loss = []
# model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

        if episode_index % N_VALIDATION_TASKS == 0:
          evaluate(test_loader, test_sampler)
evaluate(test_loader, test_sampler)



In [None]:
# Train the model yourself with this cell

log_update_frequency = 10

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
       batch
    ) in tqdm_train:
        support_inputs,query_inputs,weights = batch
        
        
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

        if episode_index % N_VALIDATION_TASKS == 0:
          evaluate(test_loader, test_sampler)
evaluate(test_loader, test_sampler)



In [None]:
# Or just load mine

!wget https://public-sicara.s3.eu-central-1.amazonaws.com/easy-fsl/resnet18_with_pretraining.tar
model.load_state_dict(torch.load("resnet18_with_pretraining.tar", map_location="cuda"))

Now let's see if our model got better!

In [None]:
evaluate(test_loader)

Around 98%!

It's not surprising that the model performs better after being further trained on Omniglot images than it was with its
ImageNet-based parameters. However, we have to keep in mind that the classes on which we just evaluated our model were still
**not seen during training**, so 99% (with a 12% improvement over the model trained on ImageNet) seems like a decent performance.

## What have we learned?

- What a Prototypical Network is and how to implement one in 15 lines of code.
- How to use Omniglot to evaluate few-shot models
- How to use custom PyTorch objets to sample batches in the shape of a few-shot classification tasks.
- How to use meta-learning to train a few-shot algorithm.

## What's next?

- Take this notebook in your own hands, tweak everything that there is to tweak. It's the best way to understand what does what.
- Implement other few-shot learning methods, such as Matching Networks, Relation Networks, MAML...
- Try other ways of training. Episodic training is not the only way to train a model to generalize to new classes!
- Experiment on other, more challenging few-shot learning benchmarks, such as [CUB](http://www.vision.caltech.edu/visipedia/CUB-200.html)
or [Meta-Dataset](https://github.com/google-research/meta-dataset).
- If you liked this tutorial, feel free to ⭐ [give us a star on Github](https://github.com/sicara/easy-few-shot-learning) ⭐
- **Contribute!** The companion repository of this notebook is meant to become a boilerplate, a source of useful code that
that newcomers can use to start their few-shot learning projects.
