# Attention intuition

The Deep Learning lecture notes by [François Fleuret](https://fleuret.org/dlc/materials/dlc-handout-13-2-attention-mechanisms.pdf) give a very nice intuition for why, and in which situations, the attention mechanism works better than convolutional networks.

In the example, he considers a toy sequence-to-sequence problem with triangular and rectangular shapes with random heights as input.
The expected target contains the same shapes but with their heights averaged, as in the figure below.


![](images/data_example.png)

Since there was no source code available in his lecture (as far as I know), I have tried to reproduce the same intuition in this notebook.
As we can see, with the same training procedure, the attention model learns the task much faster than the conv net.
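The conv net's poor performance is expected: it cannot look far along the input signal. With five convolutions of kernel size 5, each output sample only sees a window of 21 input samples out of 100, whereas averaging the heights requires combining information from both shapes, which may lie far apart.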
There are plenty of mechanisms we can equip the conv net with to make it work better (more layers, fully connected layers, ...), but the attention mechanism is a very simple and elegant solution to this problem.



In [None]:
import torch
import torch.nn as nn
from shape_dataset import ShapeDataset
import matplotlib.pyplot as plt

# Implementation of the self attention layer
class SelfAttentionLayer(nn.Module):
    """Single-head self-attention over the length axis of a 1-D feature map.

    Queries, keys and values are produced by bias-free 1x1 convolutions, i.e.
    per-position linear projections of the input channels. Each output
    position is a softmax-weighted sum of the value vectors at *all*
    positions, so the layer can mix information from anywhere in the sequence.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels (size of the value projection).
        key_dim: size of the query/key projections.

    Note:
        The attention scores are not scaled by ``1 / sqrt(key_dim)``.
    """

    def __init__(self, in_dim, out_dim, key_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv_Q = nn.Conv1d(in_dim, key_dim, kernel_size=1, bias=False)
        self.conv_K = nn.Conv1d(in_dim, key_dim, kernel_size=1, bias=False)
        self.conv_V = nn.Conv1d(in_dim, out_dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (batch, in_dim, length).

        Returns:
            Tensor of shape (batch, out_dim, length).
        """
        queries = self.conv_Q(x)  # (batch, key_dim, length)
        keys = self.conv_K(x)     # (batch, key_dim, length)
        values = self.conv_V(x)   # (batch, out_dim, length)
        # scores[b, i, j]: dot product between the query at i and the key at j.
        scores = queries.transpose(1, 2) @ keys
        # Normalise over key positions j so each query's weights sum to 1.
        weights = torch.softmax(scores, dim=2)
        # out[b, c, i] = sum_j values[b, c, j] * weights[b, i, j]
        return values @ weights.transpose(1, 2)
    
def train_model(
    model,
    epochs=20,
    device=None,
    *,
    dataset=None,
    batch_size=32,
    lr=0.001,
    num_workers=4,
    log_every=10,
):
    """Train ``model`` with Adam on an MSE objective and return its loss curve.

    Args:
        model: module whose output is compared element-wise (MSE) with the
            target batch, so it must produce the target's shape.
        epochs: number of passes over the training data.
        device: device to train on. ``None`` picks ``"cuda"`` when available,
            otherwise ``"cpu"``.
        dataset: dataset yielding ``(input, target)`` pairs. ``None`` builds
            the default ``ShapeDataset(size=100, max_height=50,
            noise_std=0.3, max_samples=10000)``.
        batch_size: mini-batch size.
        lr: Adam learning rate.
        num_workers: number of DataLoader worker processes.
        log_every: print the epoch loss every ``log_every`` epochs
            (``0``/``None`` disables printing).

    Returns:
        List with one float per epoch: the sample-weighted mean of the
        mini-batch MSE losses seen during that epoch (not just the last
        batch's loss).

    Raises:
        ValueError: if the data loader yields no batches in an epoch.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if dataset is None:
        dataset = ShapeDataset(size=100, max_height=50, noise_std=0.3, max_samples=10000)

    train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=torch.device(device).type == "cuda",
    )

    model.train()
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    loss_per_epoch = []
    for epoch in range(epochs):
        # Accumulate on-device so we do not force a host sync every step.
        total_loss = torch.zeros((), device=device)
        total_samples = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # MSELoss averages within the batch; weight by batch size so the
            # (possibly smaller) last batch does not skew the epoch mean.
            batch_samples = inputs.shape[0]
            total_loss += loss.detach() * batch_samples
            total_samples += batch_samples
        if total_samples == 0:
            raise ValueError("train_loader yielded no batches; is the dataset empty?")
        epoch_loss = total_loss.item() / total_samples
        if log_every and epoch % log_every == 0:
            print(f"Epoch: {epoch}, Loss: {epoch_loss}")
        loss_per_epoch.append(epoch_loss)

    return loss_per_epoch


# Training the regular model with no self attention.
# Five convolutions of kernel size 5 give each output a receptive field of
# only 1 + 5 * (5 - 1) = 21 input samples.
print("Training with conv model")
# Re-seed before each model so both runs start from the same, reproducible
# torch RNG state (weight init and DataLoader shuffling).
# NOTE(review): ShapeDataset may use its own RNG; confirm it is seeded too.
torch.manual_seed(0)
conv_model = nn.Sequential(
    nn.Conv1d(1, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    nn.Conv1d(64, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    nn.Conv1d(64, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    nn.Conv1d(64, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    nn.Conv1d(64, 1, kernel_size=(5,), padding="same"),
)
loss_per_epoch_conv = train_model(conv_model, epochs=50)

# Training with self attention: same layout, but one convolution is replaced
# by a self-attention layer that lets every position attend to all others.
print("Training with attention model")
torch.manual_seed(0)
attention_model = nn.Sequential(
    nn.Conv1d(1, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    nn.Conv1d(64, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    SelfAttentionLayer(in_dim=64, out_dim=64, key_dim=64),
    nn.Conv1d(64, 64, kernel_size=(5,), padding="same"),
    nn.ReLU(),
    nn.Conv1d(64, 1, kernel_size=(5,), padding="same"),
)
loss_per_epoch_attention = train_model(attention_model, epochs=50)

plt.figure()
plt.plot(loss_per_epoch_conv, label="Without Attention")
plt.plot(loss_per_epoch_attention, label="With Attention")
plt.xlabel("Epoch")
plt.ylabel("Training MSE loss")
plt.legend()
# plt.show() returns None, so the notebook does not echo a Legend object.
plt.show()

Training with conv model
Epoch: 0, Loss: 0.0066954898647964
Epoch: 10, Loss: 0.006643436849117279
Epoch: 20, Loss: 0.003941401373594999
Epoch: 30, Loss: 0.005726605653762817
Epoch: 40, Loss: 0.007211305201053619
Training with attention model
Epoch: 0, Loss: 0.008388693444430828
Epoch: 10, Loss: 0.002573534846305847
Epoch: 20, Loss: 0.0019854912534356117
Epoch: 30, Loss: 0.0011634223628789186
