# AML Exercise Modern Transformer Architecture

In the lecture it was mentioned that transformer architectures dominate in many domains. A recent architecture that uses transformers is ChatGPT. The system was roughly explained in the lecture. 
To further bridge the gap between the original transformer architecture shown in the lecture and modern architectures, we want you to look at another very new system: SAM (Segment Anything Model).

Read the corresponding paper, published in April 2023: 
[Segment Anything](https://arxiv.org/pdf/2304.02643.pdf). 
There is a [website](https://segment-anything.com/) available where you can try out the system (being run on a web server). Please do so.

Use the paper to answer the following questions:

 
1) What is the goal of the model and what makes it special?

2) Explain how the architecture relates to a standard transformer architecture. Where can you identify components you know from the lecture and what is new?

3) What is a foundation model? Name another foundation model that was covered in the lecture. 

4) Foundation models require huge amounts of labelled data. How do foundation models (like the one in your previous answer) typically solve this issue? 

5) Would a similar approach have been possible for SAM? Explain the strategy the authors of SAM used to overcome the data problem.

6) Can SAM be used for Zero-Shot object classification? What if there are multiple objects in the image?

# AML Exercise Distance Metric Learning

In this exercise sheet, we will review pair-based losses and distance-based losses.
We will also look into a novel proxy-based loss for distance metric learning.
Lastly, you will use prototypical networks to perform few-shot learning on the
[Omniglot dataset](https://github.com/brendenlake/omniglot).


## Exercise 1 Pair-based vs. Proxy-based DML Methods

In the lecture, you have encountered pair-based DML, e.g. the triplet loss, and proxy-based methods, e.g. the Proxy-NCA loss.
Briefly describe the main difference between the two DML methods and what their pros and cons are.
Feel free to use section 2 of the paper of the next exercise: https://arxiv.org/pdf/2003.13911.pdf

`TODO Enter answers`


## Exercise 2 Proxy-Anchor-Loss
The Proxy-NCA loss was one of the first proxy-based DML methods and was described in the lecture. In recent years, novel proxy-based losses have been introduced.
One example of such a loss is the Proxy-Anchor loss, introduced by Kim et al.
Read through section 3 of the paper [Kim et al. (2020): "Proxy Anchor Loss for Deep Metric Learning"](https://arxiv.org/pdf/2003.13911.pdf) and answer:
 - How does the Proxy-Anchor loss work?
 - What are its benefits in comparison to Proxy-NCA?



`TODO Enter answers`

## Exercise 3 Few-Shot Learning using Prototypical Networks

We will now look at the prototypical networks you encountered in the lecture.
In particular, we will try to replicate the experiments of section 3.1
of the [paper which introduced prototypical networks](https://arxiv.org/pdf/1703.05175.pdf).
We will train on the
[Omniglot data set](https://github.com/brendenlake/omniglot), which consists of images of characters from 50 alphabets.
30 alphabets belong to the "background" data set and should be used to train FSL models,
while the remaining 20 alphabets are used for evaluation.
The task of the prototypical network is to decide which character class a query image belongs to, where each class is represented only by a few given sample images (the support set).
Let's first load necessary imports.

In [None]:
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import io
import json

import PIL
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

from torch.optim.adam import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from sklearn.manifold import TSNE

We will now load some data utilities we have set up. Make sure to load the corresponding files from StudIP.
In the `prepare_data_set` method, the Omniglot data set is downloaded to a `data` directory next to this notebook.
The images are also downscaled to 28x28 pixels and rotated four times (0, 90, 180, 270 degrees).
Each rotation of a character forms its own distinct class.

Note that this operation can take a while.
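
The rotation step described above can be sketched as follows. This is only an illustration on a toy tensor: the helper `rotate_into_classes` and the label format are assumptions made for this example and are not part of the supplied `prepare_data` module, which may implement the augmentation differently.

```python
import torch


def rotate_into_classes(image: torch.Tensor, label: str) -> dict:
    """Return four rotated copies of a (C, H, W) image, one per new class label."""
    return {
        # torch.rot90 rotates by k * 90 degrees in the plane of the last two dims.
        f"{label}/rot{90 * k:03d}": torch.rot90(image, k=k, dims=(-2, -1))
        for k in range(4)
    }


# A toy 28x28 "image" with one channel (hypothetical data, not an Omniglot sample).
toy_image = torch.arange(28 * 28, dtype=torch.float32).reshape(1, 28, 28)
rotated = rotate_into_classes(toy_image, "Latin/character01")
print(list(rotated))  # four class labels, one per rotation
```

Treating each rotation as a separate class multiplies the number of classes by four without requiring any new data.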

In [None]:
import sys
sys.path.insert(0, '.')

from prepare_data import OmniGlotDataSet, sample_from_torch_sample_dict
DATA_DIR = Path(".") / "data"
# Load example data
omniglot_dataset = OmniGlotDataSet(DATA_DIR)
omniglot_dataset.prepare_data_set()

We will now set up our neural network encoder as defined in the paper. It is a ConvNet with four blocks, each consisting of a 2D Conv, ReLU, BatchNorm and MaxPool layer. It maps our 28 x 28 grayscale images to a 64-dimensional embedding vector.
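
To see why four such blocks end up with a 64-dimensional vector, here is a minimal sketch of an encoder of this kind. It is not the supplied `OmniGlotEncoder`; the class name `SketchEncoder`, the 3x3 kernels with padding 1, the 2x2 pooling and the 64 channels per block are assumptions for illustration.

```python
import torch
import torch.nn as nn


def conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """One block: 2D Conv, ReLU, BatchNorm, MaxPool (halves height and width)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(2),
    )


class SketchEncoder(nn.Module):
    """Hypothetical stand-in for the supplied encoder, for shape checking only."""

    def __init__(self, emb_dim: int = 64):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_block(1, emb_dim),
            conv_block(emb_dim, emb_dim),
            conv_block(emb_dim, emb_dim),
            conv_block(emb_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Spatial size shrinks 28 -> 14 -> 7 -> 3 -> 1, leaving emb_dim values per image.
        return self.blocks(x).flatten(start_dim=1)


sketch_encoder = SketchEncoder(64)
sketch_out = sketch_encoder(torch.rand(8, 1, 28, 28))
print(sketch_out.shape)  # torch.Size([8, 64])
```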

In [None]:
from encoder import OmniGlotEncoder

EMB_DIM = 64
encoder = OmniGlotEncoder(EMB_DIM)
print(encoder)


### Exercise 3.1 Implementing the ProtoNet Loss

We will now get to the core of this exercise. You will need to implement the ProtoNet loss as described in the original paper.
The `ProtoNet.calc_loss` method has comments which will guide you through your implementation.
We have set up a JSON file with pre-defined labels and their support / query sets, on which you can test your implementation (see the following cells).
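
If you have not used `torch.cdist` and `torch.gather` before, it may help to try them on toy tensors first. The snippet below is only a warm-up on made-up 2-D points (all `toy_*` names are hypothetical) and does not implement the loss itself.

```python
import torch

# Two toy query points and three toy prototypes in a 2-D embedding space.
toy_queries = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
toy_protos = torch.tensor([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]])

# cdist returns all pairwise Euclidean distances with shape (num_queries, num_protos).
toy_dists = torch.cdist(toy_queries, toy_protos)
print(toy_dists.shape)  # torch.Size([2, 3])

# gather picks one entry per row along dim=1: the column given by the index tensor.
toy_index = torch.tensor([[0], [2]])
toy_picked = torch.gather(toy_dists, 1, toy_index)
print(toy_picked.shape)  # torch.Size([2, 1])
```

Here the second row of `toy_dists` holds the distances from the point (3, 4) to each prototype, and `toy_picked` contains the entry of column 0 for the first row and of column 2 for the second row.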


In [None]:
class ProtoNet(nn.Module):
    """
    Implements a ProtoNet as introduced by Snell et al. (2017): https://arxiv.org/pdf/1703.05175.pdf
    """

    def __init__(self, encoder: nn.Module, emb_dim: int):
        """
        Initializes the Protonet.
        :encoder: A neural network which maps data samples to embedding vectors.
        :emb_dim: The dimensionality of the embedding vectors produced by the encoder.
        """
        super().__init__()
        self.encoder = encoder
        self.emb_dim = emb_dim

    def calc_loss(self, support_set_batch: Tensor, query_set_batch: Tensor, device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Calculates the loss for one training episode of a prototypical network
        :param support_set_batch: A tensor of shape: num_classes x num_support_samples_per_class x [data_shape]
        :param query_set_batch: A tensor of shape: num_classes x num_query_samples_per_class x [data_shape]
        :param device: The device on which target values for the loss calculation should be placed.
        :return: A tuple of the loss and the accuracy of predicting the true class of the query samples.
        """
        # The dimensionality of the batches
        num_classes, num_support = support_set_batch.shape[0], support_set_batch.shape[1]
        num_query = query_set_batch.shape[1]
        data_shape = query_set_batch.shape[2:]

        # Step 1: Compute embeddings using the encoder
        ## YOUR CODE START
        all_data = None # Placeholder
        ## S1.1. Merge support and query samples to one large batch, such that we can compute the embeddings faster using parallelization

        ## S1.1.1 First "unravel" / reshape all support samples to one batch of shape: (num_classes * num_support, [data_shape])

        ## S1.1.2 Also "unravel" the query batch to shape (num_classes * num_query, [data_shape])

        ## S1.1.3 Concatenate the batches (support-set first) to the shape (num_classes * (num_support + num_query), data_shape)

        ## YOUR CODE END
        ## Step 1.2 Use the encoder to compute the embeddings of all data samples in the merged batch
        embeddings = self.encoder(all_data)
        assert embeddings.shape == (num_classes * (num_support + num_query), self.emb_dim)

        ## Step 1.3 Calculate the proto vectors (centroids)
        ### Step 1.3.1 Get the support set  from the embeddings batch and reshape it to (num_classes, num_support, emb_dim)
        ### YOUR CODE START
        ### Step 1.3.2 Calculate the mean of the embeddings. This should result in a tensor of shape (num_classes, emb_dim)

        ## Step 1.4 Calculate the distances between proto and query vectors
        ### Step 1.4.1 Get the query embeddings from the large embeddings batch of step 1.2

        ### Step 1.4.2 Compute the pairwise Euclidean distance between query embeddings and prototype vectors
        ### Tip: Use torch.cdist

        ## Step 1.5 Compute the loss of the protonet
        ### Step 1.5.1 Apply the log_softmax on negative (!) distances;
        ### Note: Be sure that the softmax is applied such that for each query embedding, there is a (log) prob distribution over the classes
        ### Step 1.5.2 Multiply the log_probs with -1, as we aim to minimize the loss (but would maximize the log probs)
        neg_log_probs = None # Placeholder

        ### Step 1.5.3 Reshape the negative log probs to the shape (num_classes, num_query, num_classes)
        ### I.e., we again have the same batch shape as the input (num_classes, num_query, ...rest),
        ### but now, the rest is not the shape of the original data, but the negative log probs for each of the num_classes classes

        ### Step 1.5.4 We now set up target indices representing the class labels
        # We want to set up a (num_classes, num_query, 1) tensor which we can then use to get the negative log prob value
        # of the actual target class of each sample
        # For example, if we had num_classes=5 and num_query=4, then the targets would be (as a 2d matrix)
        # [[0 0 0 0]
        #  [1 1 1 1]
        #  [2 2 2 2]
        #  [3 3 3 3]
        #  [4 4 4 4]]
        targets = None # Placeholder

        ### Step 1.5.5 Use "torch.gather" along dim=2 to get the negative log prob of the target class of each sample

        ### Step 1.5.6 Take the mean of the negative log probs of the target classes. This is our loss
        loss = None # Placeholder
        ### YOUR CODE END

        # We also predict the classes of the query samples by picking the class with the smallest negative log probability
        _, class_predictions = neg_log_probs.min(2)
        accuracy = torch.eq(class_predictions, targets.squeeze()).float().mean()

        return loss, accuracy

### Testing the ProtoNet Loss
We will first test the untrained ProtoNet on the example data defined in the `example_data.json` uploaded on StudIP.
Using an untrained ProtoNet, you should get an accuracy of approx 0.72 and a loss of approx 1.5.
Note that due to implementation details, your loss might differ slightly. 

In [None]:
# Load Example JSON
example_json_file = Path(".") / "example_data.json"
print_json_content = False  ## Set to True to print the JSON contents (the output can be very long)
if example_json_file.exists():
    print(f"Loading example data json from {example_json_file.absolute()}")
    with open(example_json_file) as fp:
        example_dict = json.load(fp)
        if print_json_content:
            print("Example JSON file content")
            print(json.dumps(example_dict, indent=4))

else:
    print(f"Please download the 'example_data.json' file from StudIP.")
    sys.exit()
example_support, example_query, example_classes = omniglot_dataset.load_example_label_to_img_dict(example_dict)


# Init Protonet and test on example data
DEVICE = "cpu"
proto_net = ProtoNet(encoder, EMB_DIM)
loss, acc = proto_net.calc_loss(example_support.to(DEVICE), example_query.to(DEVICE), DEVICE)
print(f"Example loss before loading: {loss}")
print(f"Example accuracy before loading: {acc}")

assert loss > 1.0
assert acc < 1.0

### Testing a Pre-trained ProtoNet

We will now test the supplied `proto-net-release.pth`, which has been trained for 1000 iterations.
You should see a significantly smaller loss (~0.016) and an accuracy of 100%. 
The exact loss might also differ because of implementation details. 


In [None]:
MODEL_PATH = Path(".") / "proto-net-release.pth"
# map_location ensures the checkpoint loads on DEVICE even if it was saved on another device
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
print("Loading ProtoNet state dict")
proto_net.load_state_dict(state_dict["proto_net"])

loss, acc = proto_net.calc_loss(example_support.to(DEVICE), example_query.to(DEVICE), DEVICE)
print(f"Example loss after loading: {loss}")
print(f"Example accuracy after loading: {acc}")

assert np.isclose(acc.item(), 1.0)
assert loss < 1.0

### Exercise 3.3 Training the ProtoNet
We will now train the ProtoNet for a small number of epochs. You do not have to implement anything in this part. 
If your loss is correctly implemented, you should see a high accuracy (>95%) after training for one epoch of 2000 episodes.
If you want to have the same settings as in the paper, you would need to set the number of epochs to 5 and train for 2000 episodes per epoch.
Note, however, that one iteration is rather slow if you do not have a GPU.




In [None]:
# Initializing some training settings
MAX_NUM_EPOCHS = 1  # Use 5 epochs for a setup like in the paper
NUM_EPISODES_PER_EPOCH = 1000  # Use 2000 episodes for a setup like in the paper
RANDOM_SEED = 42

# Constants
LEARNING_RATE = 10 ** -3
NUM_CLASSES_PER_TRAIN_EPISODE = 60
NUM_SUPPORT_SAMPLES = 5
NUM_QUERY_SAMPLES = 5

from random import seed
# First we set some random seeds.
seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Reset the network
def weight_reset(m):
    # https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/7
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

print("Resetting weights of ProtoNet")
proto_net.apply(weight_reset)

# Set up a TensorBoard writer
LOG_DIR = Path(".")  / "logs"
TB_DIR = LOG_DIR / "tb" / str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
TB_DIR.mkdir(exist_ok=True, parents=True)
tb_writer = SummaryWriter(TB_DIR)

We will now load dictionaries which map the character classes to a list of tensors representing the pre-processed
images.
We load all images into memory so that training becomes faster.
This operation will also take a while to finish.
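
To illustrate how such a dictionary can be turned into training episodes, here is a minimal sketch of an N-way K-shot sampler. The function `sample_episode` and the toy dictionary are hypothetical and written only for this example; the supplied `sample_from_torch_sample_dict` has a different signature and may work differently.

```python
import random

import torch


def sample_episode(label_to_imgs, num_classes, num_support, num_query):
    """Sample classes, then disjoint support and query images for each class."""
    classes = random.sample(sorted(label_to_imgs), num_classes)
    support, query = [], []
    for label in classes:
        # Draw support and query images without replacement so they do not overlap.
        imgs = random.sample(label_to_imgs[label], num_support + num_query)
        support.append(torch.stack(imgs[:num_support]))
        query.append(torch.stack(imgs[num_support:]))
    # Shapes: (num_classes, num_support, C, H, W) and (num_classes, num_query, C, H, W)
    return torch.stack(support), torch.stack(query), classes


# Toy dictionary: 10 made-up classes with 20 random 28x28 images each.
toy_label_to_img = {
    f"char_{i}": [torch.rand(1, 28, 28) for _ in range(20)] for i in range(10)
}
toy_support, toy_query, toy_classes = sample_episode(toy_label_to_img, 5, 5, 5)
print(toy_support.shape)  # torch.Size([5, 5, 1, 28, 28])
```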

In [None]:
USE_MULTI_PROC = False
bg_label_to_img = omniglot_dataset.init_torch_sample_dict(part="background", use_multi_proc=USE_MULTI_PROC)
eval_label_to_img = omniglot_dataset.init_torch_sample_dict(part="evaluation", use_multi_proc=USE_MULTI_PROC)

Next, we set up the optimizer and a learning rate scheduler which halves the learning rate every epoch.
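
The effect of such a scheduler can be checked in isolation. The sketch below uses a single dummy parameter (the `toy_*` names are assumptions for this example) and records the learning rate at the start of three epochs.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# A single dummy parameter is enough to construct an optimizer.
toy_param = torch.nn.Parameter(torch.zeros(1))
toy_optim = Adam([toy_param], lr=1e-3)
# step_size=1 and gamma=0.5: multiply the learning rate by 0.5 after every epoch.
toy_scheduler = StepLR(toy_optim, step_size=1, gamma=0.5)

toy_lrs = []
for epoch in range(3):
    toy_lrs.append(toy_optim.param_groups[0]["lr"])
    toy_optim.step()      # stands in for one epoch of training
    toy_scheduler.step()  # called once per epoch, after the optimizer steps

print(toy_lrs)
```

The recorded values start at 1e-3 and are halved from one epoch to the next.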



In [None]:
LEARNING_RATE = 10 ** -3
optim = Adam(params=proto_net.parameters(), lr=LEARNING_RATE)
# A scheduler which halves the learning rate after every epoch
scheduler = StepLR(optim, 1, gamma=0.5, last_epoch=-1)

We will now start training, which will take a while. You can see the number of iterations per second in the progress bar. 
If your loss is correctly implemented, the loss should decline while the accuracy increases.
Let the training script finish and post plots of the training loss and accuracy (see TensorBoard), as well as the eval results below.
Note: When you test on the eval set, the model sees samples of classes that were not used during training.

In [None]:
current_epoch = 0
global_step = 0

bg_classes = list(sorted(bg_label_to_img.keys()))
eval_classes = list(sorted(eval_label_to_img.keys()))

print("Starting training...")
print(f"See the progress in tensorboard. Run:")
print(f"tensorboard --logdir {TB_DIR}")

while current_epoch < MAX_NUM_EPOCHS:
    print(f"Starting new train epoch: {current_epoch}/{MAX_NUM_EPOCHS}")

    print(f"Starting new training loop of {NUM_EPISODES_PER_EPOCH} episodes")
    epoch_train_losses = []
    epoch_train_accs = []
    # Train Loop
    proto_net.train()

    print("Starting train loop")
    train_p_bar = tqdm(range(NUM_EPISODES_PER_EPOCH), total=NUM_EPISODES_PER_EPOCH)
    for episode in train_p_bar:
        support_set, query_set, sampled_classes = sample_from_torch_sample_dict(
            bg_classes, bg_label_to_img, NUM_CLASSES_PER_TRAIN_EPISODE, NUM_SUPPORT_SAMPLES, NUM_QUERY_SAMPLES
        )
        optim.zero_grad()
        loss, acc = proto_net.calc_loss(support_set.to(DEVICE), query_set.to(DEVICE), DEVICE)
        loss.backward()
        optim.step()
        # Log the loss and accuracy
        loss_float = float(loss.item())
        acc_float = float(acc.item())
        tb_writer.add_scalar("train_step_loss", loss_float, global_step)
        tb_writer.add_scalar("train_step_acc", acc_float, global_step)
        epoch_train_losses.append(loss_float)
        epoch_train_accs.append(acc_float)
        global_step += 1

        if episode % 10 == 0:
            train_p_bar.set_postfix(accuracy=100. * np.mean(epoch_train_accs),
                                    loss=np.mean(epoch_train_losses))

    print(f"Avg Train Loss: {np.mean(epoch_train_losses)}")
    print(f"Train Acc. ({NUM_CLASSES_PER_TRAIN_EPISODE}-way {NUM_SUPPORT_SAMPLES}-Shot): {np.mean(epoch_train_accs)}")

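    # Pass the epoch index as the step so that each epoch gets its own data point
    tb_writer.add_scalar("epoch_avg_train_loss", np.mean(epoch_train_losses), current_epoch)
    tb_writer.add_scalar("epoch_avg_train_acc", np.mean(epoch_train_accs), current_epoch)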

    # Halve the learning rate
    print("Halving learning rate...")
    scheduler.step()
    current_epoch += 1

print("Finished training!")
# Test loop
test_losses = []
test_accs = []
proto_net.eval()
print("Starting to calculate eval results")
NUM_TESTING_EPISODES = 1000
NUM_CLASSES_PER_TEST_EPISODE = 5

test_p_bar = tqdm(range(NUM_TESTING_EPISODES), total=NUM_TESTING_EPISODES)
for test_episode in test_p_bar:
    with torch.no_grad():
        support_set, query_set, sampled_classes = sample_from_torch_sample_dict(
            eval_classes, eval_label_to_img, NUM_CLASSES_PER_TEST_EPISODE, NUM_SUPPORT_SAMPLES, NUM_QUERY_SAMPLES
        )
        loss, acc = proto_net.calc_loss(support_set.to(DEVICE), query_set.to(DEVICE), DEVICE)
    # Log the loss and accuracy
    test_losses.append(float(loss.item()))
    test_accs.append(float(acc.item()))

    if test_episode % 100 == 0:
        test_p_bar.set_postfix(accuracy=100. * np.mean(test_accs), loss=np.mean(test_losses))


print(f"Test-Loss: {np.mean(test_losses)}")
print(f"Test-Acc ({NUM_CLASSES_PER_TEST_EPISODE}-way {NUM_SUPPORT_SAMPLES}-Shot): {np.mean(test_accs)}")


`TODO Report test accuracy, tensorboard logs`