<a href="https://colab.research.google.com/github/neuromatch/course-content-template/blob/main/tutorials/W1D2_Template/W1D2_Tutorial1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> &nbsp; <a href="https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/neuromatch/course-content-template/main/tutorials/W1D2_Template/W1D2_Tutorial1.ipynb" target="_parent"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open in Kaggle"/></a>

# Tutorial 4 (final part): Inductive Biases in Self-Attention Mechanisms

**Week [1], Day [4]: [Micro-Circuits]**

**By Sciencematch Academy** # update to the correct academy

__Content creators:__ Names & Surnames

__Content reviewers:__ Names & Surnames

__Production editors:__ Names & Surnames

<br>

Acknowledgments: [ACKNOWLEDGMENT_INFORMATION]


___

Use a line (---) separator from title block to objectives. 

# Tutorial Objectives

*Estimated timing of tutorial: [insert estimated duration of whole tutorial in minutes]*

Transformers have been remarkably successful across almost every modality in ML, often surpassing other deep-learning models in performance and generalization, especially in sequence generation and self-supervised representation learning.

But unlike Multi-Layer Perceptrons (MLPs), Recurrent Neural Networks (RNNs), and Convolutional Neural Networks (CNNs), the transformer architecture seems to have no explicit inductive bias. This goes against the common understanding in machine learning that, to achieve good generalization, one must exploit the geometry of the data and the problem and build the necessary symmetries into the architecture.

This raises the question of whether there is a class of functions (i.e., a set of problems) that transformers are particularly good at representing. To this end, we present results from [Inductive Biases and Variable Creation in Self-Attention Mechanisms](https://arxiv.org/abs/2110.10090), which shows that a single self-attention head can learn to represent a sparse function with very low sample complexity. Although the original paper focuses on rigorous proofs of its claims, in this tutorial we showcase its findings and general message. *Throughout this tutorial, we refer to this paper as "the paper".*

**Tutorial Learning Structure**
* We first introduce the class of *s-sparse binary AND functions*
* You then train an MLP model on an s-sparse AND dataset for classification task
* We present a basic self-attention architecture for classification (similar to those used in BERT and ViT)
* Next you should train the transformer model on the s-sparse AND dataset for classification task
* Finally, we compare the results of both architectures.

**Tutorial Learning Objectives**
By the end you should:
* feel familiar with the self-attention architecture
* 

**References**:
- The transformer code is from [MathInf/toroidal](https://github.com/MathInf/toroidal) by Thomas Viehmann
- The content and results are inspired by "Edelman et al. (2022), [Inductive Biases and Variable Creation in Self-Attention Mechanisms](https://proceedings.mlr.press/v162/edelman22a.html)"

Tutorial Slides "link_id"s will be added in below by the curriculum or production team. You do not need to do anything but leave the block of code below here.

In [None]:
# @title Tutorial slides

# @markdown These are the slides for the videos in all tutorials today


## Uncomment the code below to test your function

#from IPython.display import IFrame
#link_id = "<YOUR_LINK_ID_HERE>"

print("If you want to download the slides: 'Link to the slides'")
      # Example: https://osf.io/download/{link_id}/

#IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{link_id}/?direct%26mode=render", width=854, height=480)

---
# Setup



In [None]:
# @title Install and import feedback gadget

# note this is not relevant for climatematch at the moment

!pip3 install vibecheck datatops --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt - leave this as is
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "sciencematch_sm", # change the name of the course : neuromatch_dl, climatematch_ct, etc
            "user_key": "y1x3mpx5",
        },
    ).render()

# Define the feedback prefix: Replace 'weeknumber' and 'daynumber' with appropriate values, underscore followed by T(stands for tutorial) and the number of the tutorial
# e.g., W1D1_T1
feedback_prefix = "W*weeknumber*D*daynumber*_T*tutorialNumber*"

In [None]:
# Imports
import random as pyrandom  # to avoid confusion with np.random
import torch
import matplotlib.pyplot as plt


In [None]:
# @title Figure settings

# logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle") # update this to match your course

In [None]:
# @title Plotting functions
def plot_loss_accuracy(t_loss, t_acc, v_loss=None, v_acc=None, title="Training and Validation"):
    plt.figure(figsize=(15, 4))
    plt.suptitle(title)  # generic default so the same helper works for the MLP and the transformer
    plt.subplot(1, 2, 1)
    plt.plot(t_loss, label="Training loss", color="red")
    if v_loss is not None:
        # plt.plot(v_loss, label="Validation loss", color="blue")
        plt.scatter(len(t_loss)-1, v_loss, label="Validation loss", color="blue", marker="*")
        # plt.text(len(t_loss)-1, v_loss, f"{v_loss:.3f}", va="bottom", ha="right")
    plt.yscale("log")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.xticks([])
    plt.legend(loc="lower right")
    plt.subplot(1, 2, 2)
    plt.plot(t_acc, label="Training accuracy", color="red", linestyle="dotted")
    if v_acc is not None:
        # plt.plot(v_acc, label="Validation accuracy", color="blue", linestyle="--")
        plt.scatter(len(t_acc)-1, v_acc, label="Validation accuracy", color="blue", marker="*")
        # plt.text(len(t_acc)-1, v_acc, f"{v_acc:.3f}", va="bottom", ha="right")
    plt.xticks([])
    plt.ylim(0, 1)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc="lower right")
    plt.show()

def plot_samples(X_plot, y_plot, correct_ids):
    n_samples, seq_length = X_plot.shape
    fig, axs = plt.subplots(1, 2, figsize=(16, 2.5), sharey=True)
    rects = []
    for ri in correct_ids:
        rects.append(plt.Rectangle((ri-0.5, -0.5), 1, n_samples, edgecolor="red", alpha=1.0, fill=False, linewidth=2))
    axs[0].imshow(X_plot, cmap="binary")
    for rect in rects:
        axs[0].add_patch(rect)
    # axs[0].axis("off")
    axs[0].set_yticks([])
    axs[0].set_xticks([])
    axs[0].set_ylabel("Samples")  # rows of X_plot are samples
    axs[0].set_xlabel("Context (input index t)")  # columns of X_plot are sequence positions
    axs[1].imshow(y_plot, cmap="binary")
    axs[1].add_patch(plt.Rectangle((-0.5, -0.5), 1, n_samples, edgecolor="black", alpha=1.0, fill=False, linewidth=2))
    axs[1].yaxis.set_label_position("right")
    axs[1].set_ylabel("Labels")
    axs[1].set_yticks([])
    axs[1].set_xticks([])
    plt.subplots_adjust(wspace=-1)
    plt.tight_layout()
    plt.show()

def plot_attention_weights(att_weights, correct_ids, context_length):
    aw_flatten = att_weights.view(-1, context_length+1)
    n_weights = aw_flatten.size(0)
    fig, ax = plt.subplots(figsize=(9, 5))
    for i in range(context_length+1):
        ax.scatter(torch.full((n_weights, ), i) , aw_flatten[:, i], alpha=0.1, c='blue')
    rects = []
    for ri in correct_ids:
        rects.append(plt.Rectangle((ri-0.5, 1e-6), 1.0, 2.0, edgecolor="red", alpha=1.0, fill=False, linewidth=2))
    for rect in rects:
        ax.add_patch(rect)
    plt.yscale("log")
    plt.ylim(1e-6, 2)
    plt.title("Attention weights for the whole batch")
    plt.xlabel("Boolean input index t")
    plt.ylabel("Attention weight")
    plt.show()


In [None]:
#@title Data retrieval
class s_Sparse_Boolean:  # 1-Dimensional AND
    def __init__(self, T: int, s: int):
        self.T = T  # context length
        self.s = s  # sparsity
        self.p = 0.5**(1.0/self.s)  # P(x_t = 1), chosen so that P(y = 1) = p**s = 0.5 (balanced data)
        self.f_i = None

    def pick_an_f(self):
        self.f_i = sorted(pyrandom.sample(range(self.T), self.s))  # the s function-relevant indices
        self.others = list(i for i in range(self.T) if i not in self.f_i)

    def generate(self, m: int, verbose: bool = False):
        if self.f_i is None:
            self.pick_an_f()
        max_try = 100
        i_try = 0
        while i_try < max_try:
            i_try += 1
            X, y = torch.zeros(m, self.T), torch.zeros(m, 1)
            X[torch.rand(m, self.T) < self.p] = 1
            y[X[:, self.f_i].sum(dim=1) == self.s] = 1
            if y.sum()/m < 0.4 or y.sum()/m > 0.6:
                verbose and print(f"Large imbalance in the training set {y.sum()/m}, retrying...")
                continue
            else:
                verbose and print(f"Data-label balance: {y.sum()/m}")
            bad_batch = False
            for i in self.f_i:
                for o in self.others:
                    if (X[:, i] == X[:, o]).all():
                        verbose and print(f"Found at least another compatible hypothesis {i} and {o}")
                        bad_batch = True
                        break
            if bad_batch:
                continue
            else:
                break
        else:
            print(f"Could not generate a balanced, unambiguous batch after {max_try} tries")
        return X.long(), y.float()


In [None]:
# @title Helper functions

class BinaryMLP(torch.nn.Module):
    def __init__(self, in_dims, h_dims, out_dims, dropout=0.1):
        super().__init__()
        self.in_dims = in_dims
        self.h_dims = h_dims
        self.out_dims = out_dims

        self.layers = torch.nn.ModuleList()
        self.layers.append(torch.nn.Linear(in_dims, h_dims[0]))
        torch.nn.init.normal_(self.layers[-1].weight, std=0.02)
        torch.nn.init.zeros_(self.layers[-1].bias)
        self.layers.append(torch.nn.GELU())
        self.layers.append(torch.nn.Dropout(dropout))
        for i in range(len(h_dims) - 1):
            self.layers.append(torch.nn.Linear(h_dims[i], h_dims[i+1]))
            torch.nn.init.normal_(self.layers[-1].weight, std=0.02)
            torch.nn.init.zeros_(self.layers[-1].bias)
            self.layers.append(torch.nn.GELU())
            self.layers.append(torch.nn.Dropout(dropout))
        self.layers.append(torch.nn.Linear(h_dims[-1], out_dims))
        self.layers.append(torch.nn.Sigmoid())

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def bin_acc(y_hat, y):
    """
    Compute the binary accuracy
    """
    y_ = y_hat.round()
    TP_TN = (y_ == y).float().sum().item()
    FP_FN = (y_ != y).float().sum().item()
    assert TP_TN + FP_FN == y.numel(), f"{TP_TN + FP_FN} != {y.numel()}"
    return TP_TN / y.numel()


def get_n_parameters(model: torch.nn.Module):
    """
    Get the number of learnable parameters in a model
    """
    i = 0
    for par in model.parameters():
        i += par.numel()
    return i


def save_model(model):
    torch.save(model.state_dict(), 'model_states.pt')


def load_model(model):
    model_states = torch.load('model_states.pt')
    model.load_state_dict(model_states)

def evaluator(model, criterion, X_v, y_v, device="cpu"):
    model.to(device)
    model.eval()
    X_v, y_v = X_v.to(device), y_v.to(device)
    with torch.no_grad():
        y_hat = model(X_v)
        loss = criterion(y_hat.squeeze(), y_v.squeeze())
        acc = bin_acc(y_hat, y_v)
    return loss.item(), acc


def trainer(model, optimizer, criterion, n_epochs, X_t, y_t, device="cpu", verbose=False):
    train_loss, train_acc = [], []
    model.to(device)
    model.train()
    X_t, y_t = X_t.to(device), y_t.to(device)
    for i in range(n_epochs):
        optimizer.zero_grad(set_to_none=True)
        y_hat = model(X_t)
        loss_t = criterion(y_hat.squeeze(), y_t.squeeze())
        loss_t.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        if (i + 1) % 10 == 0 or i == 0:
            train_loss.append(loss_t.item())
            train_acc.append(bin_acc(y_hat, y_t))
    model.eval()
    return train_loss, train_acc


In [None]:
# @title Set random seed for `Python` and `PyTorch`

# @markdown Executing `set_seed(seed=seed)` you are setting the seed

def set_seed(seed=None):
    if seed is None:
        seed = pyrandom.choice([i for i in range(1, 128)])
    pyrandom.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print(f'Random seed {seed} has been set.')

set_seed(seed=2014)  # Bahdanau et al. (2014)

In [None]:
# @title Set device (GPU, MPS or CPU). Execute `set_device()`

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda:0")  # NVIDIA GPU
        print("GPU is enabled in this notebook. \n"
              "If you want to disable it, in the menu under `Runtime` -> \n"
              "`Hardware accelerator.` and select `None` from the dropdown menu")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")  # Apple Silicon (Metal)
        print("MPS (Apple Silicon Metal) is enabled in this notebook.")
    else:
        device = torch.device("cpu")
        print("GPU is not enabled in this notebook. \n"
              "If you want to enable it, in the menu under `Runtime` -> \n"
              "`Hardware accelerator.` and select `GPU` from the dropdown menu")
    return device

device = get_device()

---

## Section 1.1: Formulation of the problem

Although one goal of AI research is a model architecture that can solve any problem regardless of its geometry and modality, current practice is rather to build or adapt existing models to accommodate the symmetries of the dataset. The authors of the paper identify that self-attention transformers have an inductive bias toward sparse functions, which they call *sparse variable creation*.

The sparse function we use in this tutorial is the 3-sparse AND function. Given a sequence of length $T$ (the context length) and three pre-selected, distinct indices, the sequence is labeled *True* if its values at all three indices are $1$, and *False* otherwise. Hence the label depends on only 3 of the $T$ elements of the sequence.

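$$X := [x_1, x_2, \dots, x_T]^{\top}, \quad x_t \in \{0, 1\}$$
$$f: \{0, 1\}^T \rightarrow \{0, 1\}, \quad f(X) = x_{i_1} \wedge x_{i_2} \wedge x_{i_3}$$

where $i_1, i_2, i_3$ are the three relevant indices.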
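To make the definition concrete, here is a small pure-Python sketch of the labeling rule. It is not the notebook's `s_Sparse_Boolean` class, and the indices `[2, 11, 25]` are hypothetical, chosen only for illustration. It also checks the reasoning behind the bit probability used by the data generator: if every bit is $1$ with probability $p$, then $P(y = 1) = p^s$, so $p = 0.5^{1/s}$ makes both classes equally likely.

```python
import random

def sparse_and(x: list[int], idx: list[int]) -> int:
    """s-sparse AND label: 1 iff every bit of x at the positions in idx is 1."""
    return int(all(x[i] == 1 for i in idx))

T, s = 30, 3
idx = [2, 11, 25]  # hypothetical relevant indices, only for this example

x = [0] * T
for i in idx:
    x[i] = 1
print(sparse_and(x, idx))   # 1: all relevant bits are on

x[0] = 1                    # flip an irrelevant bit
print(sparse_and(x, idx))   # still 1: the label ignores position 0

x[11] = 0                   # flip a relevant bit
print(sparse_and(x, idx))   # now 0

# Balance: with P(x_t = 1) = p, P(label = 1) = p**s, so p = 0.5 ** (1 / s) gives 0.5.
p = 0.5 ** (1.0 / s)
rng = random.Random(0)
samples = [[int(rng.random() < p) for _ in range(T)] for _ in range(20000)]
frac_pos = sum(sparse_and(xs, idx) for xs in samples) / len(samples)
print(round(p ** s, 6), frac_pos)  # exact value, then a Monte Carlo estimate close to it
```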

Our goal is to find a deep-learning architecture that can learn the underlying sparse Boolean function $f$ from as few training samples as possible while achieving the lowest possible generalization error.


For this tutorial, we have already defined an s-sparse AND dataset generator class, `s_Sparse_Boolean`, which generates $m$ sequences with context length $T$ and sparsity $s$. Below, we visualize a few samples and their corresponding labels. The red rectangles mark the relevant indices for this dataset.
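Besides checking the label balance, `s_Sparse_Boolean.generate` re-draws a batch whenever a relevant column is identical to an irrelevant one: swapping those two indices would then explain the labels equally well, so the target function could not be identified from the data. The rough estimate below is our own back-of-the-envelope calculation, assuming independent Bernoulli($p$) entries and ignoring the balance check; it suggests why very small batches, such as the 10 sequences used for the visualization, often need several retries.

```python
def expected_collisions(T: int, s: int, m: int) -> float:
    """Expected number of (relevant, irrelevant) column pairs that agree on all m samples.

    Assumes every entry is an independent Bernoulli(p) bit with p = 0.5 ** (1 / s).
    """
    p = 0.5 ** (1.0 / s)
    agree = p ** 2 + (1 - p) ** 2   # probability that two independent bits are equal
    n_pairs = s * (T - s)           # number of relevant x irrelevant column pairs
    return n_pairs * agree ** m

for m in (10, 20, 50):
    print(f"m={m:3d}: expected collisions = {expected_collisions(30, 3, m):.3g}")
```

With only 10 sequences, more than one such collision is expected per draw, whereas with 50 training sequences they become very unlikely.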

In [None]:
context_length = 30  # T: context length
s_sparse = 3  # s: sparsity (number of function-relevant indices)
n_sequences = 10  # m: number of samples (sequences)
data_gen = s_Sparse_Boolean(context_length, s_sparse)
X_, y_ = data_gen.generate(n_sequences, verbose=False)
correct_ids = data_gen.f_i
print(f"Target (function-relevant indices) indices: {correct_ids}")

plot_samples(X_, y_, correct_ids)

In [None]:
# @title Video 1: Video 1 Name  # put in the title of your video
# note the libraries are imported here on purpose


from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents

# curriculum or production team will provide these ids
video_ids = [('Youtube', '<video_id_1>'), ('Bilibili', '<video_id_2>'), ('Osf', '<video_id_3>')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

In [None]:
# @title Submit your feedback
content_review(f"{feedback_prefix}_Video_1_Name")

### Exercise 1.1: Multi-Layer Perceptron

Multi-Layer Perceptron (MLP) layers are a key part of most deep learning architectures. By the universal approximation theorem, an MLP with a single sufficiently wide hidden layer (i.e., with enough learnable parameters) can approximate any continuous function on a compact domain to arbitrary accuracy. In practice, however, learning such an approximation may require a prohibitive number of parameters, training samples, and iterations, and the learned solution often generalizes poorly. Nonetheless, MLPs are often a good place to start.

Below, we train an MLP with three hidden layers (by default) on the s-sparse AND dataset. Evaluate the performance and generalization of the model on this task.
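To compare model sizes later, it helps to see where the MLP's parameters come from. A minimal sketch, assuming the `BinaryMLP` layout above (fully connected layers with biases; GELU, dropout, and sigmoid have no parameters). `mlp_param_count` is our own helper, not part of the notebook:

```python
def mlp_param_count(in_dims: int, h_dims: list[int], out_dims: int) -> int:
    """Weights plus biases of a stack of Linear layers in_dims -> h_dims[0] -> ... -> out_dims."""
    sizes = [in_dims, *h_dims, out_dims]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

print(mlp_param_count(30, [512, 128, 64], 1))  # 89857
```

For the default hyperparameters this should agree with `get_n_parameters(mlp_model)`; most of the parameters sit in the 512 to 128 layer.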

#### Discussion
* What are the symmetries in a fully connected MLP?
* Why do MLPs often fall behind other architectures in representation learning and generalization?

#### Task
First, evaluate and discuss the performance of the model. Then change the setup (through the hyperparameters) so that the model performs better on the validation set (empirical generalization).

In [None]:
# # Hyperparameters
context_length = 30  # T: context length
s_sparse = 3  # s: sparsity (number of function-relevant indices)
B_train = 50  # batch size for training (number of training samples)
B_valid = 500  # batch size for validation  (number of validation samples)
hidden_layers = [512, 128, 64]  # the number of hidden units in each layer [H1, H2, ...]
etta = 1e-3  # learning rate
n_epochs = 500  # number of epochs

# # Data generation
data_gen = s_Sparse_Boolean(context_length, s_sparse)
X_train, y_train = data_gen.generate(B_train, verbose=False)
X_valid, y_valid = data_gen.generate(B_valid, verbose=False)

# # model, optimizer, and criterion 
mlp_model = BinaryMLP(context_length, hidden_layers, 1)
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=etta, weight_decay=1e-4)
criterion = torch.nn.BCELoss()
print(f"Number of model's learnable parameters: {get_n_parameters(mlp_model)}")

In [None]:
# # Training and evaluation
t_loss, t_acc = trainer(mlp_model, optimizer, criterion, n_epochs, X_train.float(), y_train, device=device, verbose=False)
v_loss, v_acc = evaluator(mlp_model, criterion, X_valid.float(), y_valid, device=device)
print(f"Training loss: {t_loss[-1]:.3f}, accuracy: {t_acc[-1]:.3f}")
print(f"Validation loss: {v_loss:.3f}, accuracy: {v_acc:.3f}")
plot_loss_accuracy(t_loss, t_acc, v_loss, v_acc)

## Section 1.2: Self-attention Transformers

Self-attention is at the heart of all transformer models, including ViT, ChatGPT, and DINO. In this tutorial, we limit ourselves to a very simple implementation of a self-attention transformer for classification. Although the model below has only one self-attention block, it can easily be extended toward state-of-the-art architectures, for example by stacking several such blocks. We use a learnable positional embedding rather than a fixed positional encoding.
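For contrast, the fixed sinusoidal positional *encoding* of the original Transformer (Vaswani et al., 2017) is computed in closed form and has no learnable parameters, whereas `BinarySAT` learns its positional embedding `self.pose` like any other weight. A short sketch of both; the sizes mirror this tutorial's `T = 30` and `d = 32`, and `sinusoidal_encoding` is our own helper, not part of the notebook:

```python
import math
import torch

def sinusoidal_encoding(n_positions: int, d: int) -> torch.Tensor:
    """Fixed sin/cos positional encoding; assumes d is even."""
    pos = torch.arange(n_positions, dtype=torch.float32).unsqueeze(1)            # [n, 1]
    freqs = torch.exp(torch.arange(0, d, 2, dtype=torch.float32)
                      * (-math.log(10000.0) / d))                                # [d/2]
    pe = torch.zeros(n_positions, d)
    pe[:, 0::2] = torch.sin(pos * freqs)  # even channels
    pe[:, 1::2] = torch.cos(pos * freqs)  # odd channels
    return pe

T, d = 30, 32
fixed = sinusoidal_encoding(T + 1, d)                          # [T+1, d], nothing to train
learned = torch.nn.Parameter(0.02 * torch.randn(1, T + 1, d))  # like BinarySAT.pose, trained by backprop

x = torch.randn(8, T + 1, d)  # a hypothetical batch of token embeddings (including cls)
x_fixed = x + fixed           # broadcasts over the batch dimension
x_learned = x + learned
print(x_fixed.shape, x_learned.shape, learned.requires_grad, fixed.requires_grad)
```

A learned embedding lets the model shape position information for the task, at the cost of $(T+1) \cdot d$ extra parameters and no built-in way to handle longer sequences.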

> TODO the attention figure!

#### Task
Look through the model and the figure and make sure you understand the purpose of each component. Try to follow the information flow, and appreciate how simple the underlying components are that give rise to such a powerful architecture.
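To help follow the information flow, the sketch below traces tensor shapes through a single-head attention step with toy sizes. It re-creates the main operations of `BinarySAT.forward` with freshly initialized layers and leaves out layer norm, dropout, and the MLP, so it is only meant for checking shapes, not for reproducing the trained model:

```python
import torch

torch.manual_seed(0)
B, T, d = 4, 5, 8  # toy batch size, context length, embedding size
E = T + 1          # sequence length after appending the cls token

x = torch.randint(0, 2, (B, T))                     # binary tokens            [B, T]
h = torch.nn.Embedding(2, d)(x)                     # token embeddings         [B, T, d]
cls = 0.02 * torch.randn(1, 1, d)                   # cls token (learnable in BinarySAT)
h = torch.cat([h, cls.expand(B, -1, -1)], dim=1)    # cls appended last        [B, E, d]
h = h + 0.02 * torch.randn(1, E, d)                 # positional embedding, broadcast over batch

q, k, v = torch.nn.Linear(d, 3 * d)(h).chunk(3, dim=-1)           # each      [B, E, d]
attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)  # weights   [B, E, E]
out = attn @ v                                                    # mixed     [B, E, d]
prob = torch.sigmoid(torch.nn.Linear(d, 1)(out[:, -1]))           # cls readout [B, 1]

for name, t in [("x", x), ("h", h), ("attn", attn), ("out", out), ("prob", prob)]:
    print(f"{name:>4}: {tuple(t.shape)}")
```

With a single head, splitting the `qkv` output into three equal chunks selects the same slices as the `.view(B, E, 3, n_heads, -1).unbind(dim=2)` used in `BinarySAT`.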

In [None]:
class BinarySAT(torch.nn.Module):
    """Binary Self-Attention Transformer
    """
    def __init__(self, T: int, d: int, n_heads: int, n: int):
        super().__init__()
        self.T = T  # context length
        self.E = T + 1  # effective length (including cls token)
        self.d = d  # embedding size
        self.n_heads = n_heads  # number of heads
        self.scale = (d // n_heads) ** -0.5  # scaling factor (1 / sqrt(d_k))
        self.n = n  # number of hidden units
        assert d % n_heads == 0, "embedding size must be divisible by number of heads"
        self.v = 2  # vocabulary size (binary input, 0 or 1)
        att_drop=0.1
        out_drop=0.1
        mlp_drop=0.1
        ln_eps=1e-6

        # embedding layers
        self.toke = torch.nn.Embedding(2, d)  # token embedding
        self.cls = torch.nn.Parameter(torch.randn(1, 1, d))  # "cls / class / global" learnable token
        self.pose = torch.nn.Parameter(torch.randn(1, T + 1, d))  # positional embedding
        self.norm1 = torch.nn.LayerNorm(d, eps=ln_eps)  # [https://arxiv.org/pdf/2002.04745.pdf]

        # self-attention layers
        self.qkv = torch.nn.Linear(d, 3 * d)  # query, key, value layers
        self.dropout_attn = torch.nn.Dropout(att_drop)
        self.proj = torch.nn.Linear(d, d)  # projection layer
        self.dropout_out = torch.nn.Dropout(out_drop)
        self.norm2 = torch.nn.LayerNorm(d, eps=ln_eps)

        # MLP layers
        self.mlp_l1 = torch.nn.Linear(d, n)
        self.mlp_l2 = torch.nn.Linear(n, d)
        self.dropout_mlp = torch.nn.Dropout(mlp_drop)
        self.norm3 = torch.nn.LayerNorm(d, eps=1e-6)

        # classification layer
        self.head = torch.nn.Linear(d, 1)

        # initialize weights and biases (per description in the paper)
        torch.nn.init.normal_(self.toke.weight, std=0.02)
        torch.nn.init.normal_(self.pose, std=0.02)
        torch.nn.init.normal_(self.cls, std=0.02)
        torch.nn.init.ones_(self.norm1.weight)
        torch.nn.init.zeros_(self.norm1.bias)

        torch.nn.init.normal_(self.qkv.weight, std=0.02)
        torch.nn.init.zeros_(self.qkv.bias)
        torch.nn.init.normal_(self.proj.weight, std=0.02)
        torch.nn.init.zeros_(self.proj.bias)
        torch.nn.init.ones_(self.norm2.weight)
        torch.nn.init.zeros_(self.norm2.bias)

        torch.nn.init.normal_(self.mlp_l1.weight, std=0.02)
        torch.nn.init.zeros_(self.mlp_l1.bias)
        torch.nn.init.normal_(self.mlp_l2.weight, std=0.02)
        torch.nn.init.zeros_(self.mlp_l2.bias)
        torch.nn.init.ones_(self.norm3.weight)
        torch.nn.init.zeros_(self.norm3.bias)

        torch.nn.init.normal_(self.head.weight, std=0.02)
        torch.nn.init.zeros_(self.head.bias)

    def forward(self, x):
        # Embedding
        B = x.size(0)  # batch size
        x = self.toke(x)
        x = torch.cat([x, self.cls.expand(B, -1, -1)], dim=1)
        x = x + self.pose
    
        # Transformer Block
        # # (Scaled Dot-Product Attention)
        norm_x = self.norm1(x)  # [https://arxiv.org/pdf/2002.04745.pdf]
        q, k, v = self.qkv(norm_x).view(B, self.E, 3, self.n_heads, -1).unbind(dim=2)
        logits = torch.einsum("bthc,bshc->bhts", q, k)  # query key product
        logits *= self.scale  # scale by 1/sqrt(d_k) to avoid softmax saturation
        attn = torch.softmax(logits, dim=-1)
        attn = self.dropout_attn(attn)
        output = torch.einsum("bhts,bshc->bthc", attn, v)  # weighted attention
        # # concat and linear projection with residual connection
        output = output.reshape(B, self.E, self.d)  # recombine
        output = self.proj(output)  # linear layer projection
        output = self.dropout_out(output)
        x = self.norm2(x + output)  # normalization and residual connection

        # MLP with residual connection
        output = torch.relu(self.mlp_l1(x))  # nonlinear layer
        output = self.dropout_mlp(output)
        output = self.mlp_l2(output)  # linear layer
        x = self.norm3(x + output)  # normalization and residual connection

        # projection
        x = self.head(x[:, -1])
        x = torch.sigmoid(x)  # binary classification task
        return x


We can now train a self-attention model with a single head and a single block on the s-sparse AND dataset. The hyperparameters here are chosen purely for demonstration; the paper uses 16 heads and twice the embedding and hidden dimensionality.

#### Task
Evaluate and discuss the performance of the attention model, then compare its results and hyperparameters with those of the MLP.
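For the comparison, the transformer's parameter count can be worked out by hand from the layers defined in `BinarySAT`. A minimal sketch; `sat_param_count` is our own helper and assumes exactly the layers above (one block, biases everywhere):

```python
def sat_param_count(T: int, d: int, n: int) -> int:
    """Learnable parameters of a BinarySAT-style model with one block."""
    embeddings = 2 * d + d + (T + 1) * d           # token table, cls token, positional embedding
    layer_norms = 3 * 2 * d                        # norm1, norm2, norm3: weight + bias each
    attention = (d * 3 * d + 3 * d) + (d * d + d)  # qkv layer + output projection
    mlp = (d * n + n) + (n * d + d)                # mlp_l1 + mlp_l2
    head = d + 1                                   # final linear classifier
    return embeddings + layer_norms + attention + mlp + head

print(sat_param_count(T=30, d=32, n=64))  # 9729
```

The number of heads does not appear because splitting into heads only reshapes the same `qkv` output. With these settings the transformer has roughly nine times fewer parameters than the default MLP (compare with `get_n_parameters(mlp_model)`).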

In [None]:
# # Hyperparameters
context_length = 30  # T: context length
s_sparse = 3  # s: sparsity (number of function-relevant indices)
B_train = 50  # batch size for training (number of training samples)
B_valid = 500  # batch size for validation  (number of validation samples)

embed_dim = 32  # embedding dimension
n_heads = 1  # number of heads
hidden_dim = 64  # number of hidden units

etta = 1e-3  # learning rate
n_epochs = 500  # number of epochs

# # Data generation
data_gen = s_Sparse_Boolean(context_length, s_sparse)
X_train, y_train = data_gen.generate(B_train, verbose=False)
X_valid, y_valid = data_gen.generate(B_valid, verbose=False)

# # model, optimizer, and criterion 
sat_model = BinarySAT(context_length, embed_dim, n_heads, hidden_dim)
optimizer = torch.optim.Adam(sat_model.parameters(), lr=etta, weight_decay=1e-4)
criterion = torch.nn.BCELoss()
print(f"Number of model's learnable parameters: {get_n_parameters(sat_model)}")

In [None]:
# # Training and evaluation
t_loss, t_acc = trainer(sat_model, optimizer, criterion, n_epochs, X_train, y_train, device=device, verbose=False)
v_loss, v_acc = evaluator(sat_model, criterion, X_valid, y_valid, device=device)
print(f"Training loss: {t_loss[-1]:.3f}, accuracy: {t_acc[-1]:.3f}")
print(f"Validation loss: {v_loss:.3f}, accuracy: {v_acc:.3f}")
plot_loss_accuracy(t_loss, t_acc, v_loss, v_acc)

### Coding Exercise 1: Weighted attention

A common figure in the attention literature is the "attention visualization", which shows how strongly the model attends to different parts of the sequence. Here, we call the underlying quantity the weighted attention, defined as follows:

$$W_{QK} = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})$$

The implementation above already computes this inside `BinarySAT.forward`, but we would like a standalone function that takes the input sequence and returns $W_{QK}$. Because $W_{QK}$ re-weights the positions of the input sequence, it has shape `[B, E, E]`, where `B` is the batch size and `E` is the extended context length ($E = T + 1$, since the `cls` token is appended to the sequence).
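Written out element by element for one sequence $b$ in the batch, with query and key vectors $q_{b,t}, k_{b,s} \in \mathbb{R}^{d_k}$ (for a single head, $d_k = d$), the definition above reads

$$[W_{QK}]_{b,t,s} = \frac{\exp\left(q_{b,t}^{\top} k_{b,s} / \sqrt{d_k}\right)}{\sum_{s'=1}^{E} \exp\left(q_{b,t}^{\top} k_{b,s'} / \sqrt{d_k}\right)}, \qquad \sum_{s=1}^{E} [W_{QK}]_{b,t,s} = 1,$$

so the softmax runs over the last (key) index $s$, and the row $t = E$ (the `cls` token) shows how strongly the classification token attends to each input position.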

In [None]:
def weighted_attention(self, x):
    """This function computes the weighted attention as a method for the BinarySAT class

    Args:
        x (Tensor): An array of shape (B:batch_size, T:context length) containing the input data

    Returns:
        Tensor: weighted attention
    """
    assert self.n_heads == 1, "This function is only implemented for a single head!"
    # Embedding
    B = x.size(0)  # batch size
    x = self.toke(x)  # token embedding
    x = torch.cat([x, self.cls.expand(B, -1, -1)], dim=1)  # concatenate cls token
    x = x + self.pose  # positional embedding
    norm_x = self.norm1(x)  # normalization

    # Scaled Dot-Product Attention (partially implemented)
    q, k, v = self.qkv(norm_x).view(B, self.E, 3, self.d).unbind(dim=2)
    #################################################
    ## TODO Implement the weighted attention 
    # Remove the following line of code once you have completed the exercise:
    raise NotImplementedError("Student exercise: compute the scaled, softmax-normalized query-key product W_qk")
    #################################################
    W_qk = ...
    return W_qk


In [None]:
# to_remove solution
def weighted_attention(self, x):
    """This function computes the weighted attention as a method for the BinarySAT class

    Args:
        x (Tensor): An array of shape (B:batch_size, T:context length) containing the input data

    Returns:
        Tensor: weighted attention
    """
    assert self.n_heads == 1, "This function is only implemented for a single head!"
    # Embedding
    B = x.size(0)  # batch size
    x = self.toke(x)  # token embedding
    x = torch.cat([x, self.cls.expand(B, -1, -1)], dim=1)  # concatenate cls token
    x = x + self.pose  # positional embedding
    norm_x = self.norm1(x)  # normalization
    
    # Scaled Dot-Product Attention (partially implemented)
    q, k, v = self.qkv(norm_x).view(B, self.E, 3, self.d).unbind(dim=2)
    W_qk = q @ k.transpose(-2, -1)
    W_qk = W_qk * self.scale
    W_qk = torch.softmax(W_qk, dim=-1)
    return W_qk


In [None]:
context_length = 30  # T: context length
s_sparse = 3  # s: sparsity (number of function-relevant indices)
n_sequences = 100  # m: number of samples (sequences)
X_, y_ = data_gen.generate(n_sequences, verbose=False)
correct_ids = data_gen.f_i
print(f"Target (function-relevant indices) indices: {correct_ids}")

with torch.no_grad():
    w_att = weighted_attention(sat_model, X_.to(device)).cpu().detach()

plot_attention_weights(w_att, correct_ids, context_length)
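Since `BinarySAT` classifies from the `cls` position, one way to summarize the plot is to ask which input positions the `cls` query attends to most. A small sketch; `cls_attention_topk` is our own helper and assumes `W_qk` has shape `[B, E, E]` with the `cls` token last, as returned by `weighted_attention`:

```python
import torch

def cls_attention_topk(W_qk: torch.Tensor, s: int) -> list[int]:
    """Input positions the cls query attends to most strongly, averaged over the batch."""
    cls_row = W_qk[:, -1, :-1]      # cls -> input tokens, dropping cls -> cls: [B, T]
    mean_att = cls_row.mean(dim=0)  # average attention per input position:    [T]
    return sorted(mean_att.topk(s).indices.tolist())

# Synthetic check: B=2, T=5, and the cls row strongly favors positions 1 and 3.
logits = torch.zeros(2, 6, 6)
logits[:, -1, [1, 3]] = 5.0
W = torch.softmax(logits, dim=-1)
print(cls_attention_topk(W, 2))  # [1, 3]
```

With the trained model, `cls_attention_topk(w_att, s_sparse)` can be compared against `correct_ids`; whether they coincide depends on how well the model trained.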

In [None]:
# @title Submit your feedback
content_review(f"{feedback_prefix}_name_of_Exercise")

## Section 1.3: Sparse Variable Creation

So far, we have seen that a simple transformer model can learn to represent an s-sparse Boolean function. Next, we would like to demonstrate the sample efficiency of self-attention for this class of functions. The paper rigorously shows that the number of training samples $m$ needed to generalize well on an s-sparse function grows only logarithmically with the context length $T$. The authors also report empirical results for the s-sparse AND dataset. Given the time and computation limits of this tutorial, we only show their results here; the implementation and detailed results are available on GitHub. The figure below is from the paper [Inductive Biases and Variable Creation in Self-Attention Mechanisms](https://proceedings.mlr.press/v162/edelman22a.html).

<div>
<img src="./static/Fig_2_ref2.png" width="500"/>
</div>
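A simple counting argument (our own heuristic, not the paper's proof) illustrates why logarithmic growth in $T$ is the natural scale: there are $\binom{T}{s}$ candidate s-sparse AND functions, and each binary label carries at most one bit of information, so identifying the target exactly from a fixed set of inputs needs at least $\log_2 \binom{T}{s}$ labeled samples, which grows like $s \log_2 T$ for large $T$.

```python
import math

def min_samples_lower_bound(T: int, s: int) -> int:
    """Bits needed to single out one of the C(T, s) candidate s-sparse AND functions."""
    return math.ceil(math.log2(math.comb(T, s)))

for T in (30, 300, 3000):
    print(T, math.comb(T, 3), min_samples_lower_bound(T, 3))
```

Making the context a hundred times longer, from 30 to 3000, only raises this bound from 12 to 33 samples. The paper's result is much stronger, since it concerns actually learning the function with a self-attention model, but it has the same logarithmic dependence on $T$.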

In [None]:
# @title Submit your feedback
content_review(f"{feedback_prefix}_Name_of_Discussion_Topic")

---
# Summary

*Estimated timing of tutorial: [minutes]* [provide the estimated time for completing the entire tutorial]

Have a summary of what they learned with specific points.

1. Specific point A

2. Specific point B