In [None]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('dark_background')

# Reweighting based on proximity

We start with a simple, noisy time series:

In [None]:
t = torch.linspace(0, 4 * torch.pi, 100)


def f(t):
    # draw one noise value per time step, whatever the length of t
    noise = torch.randn_like(t) * 0.1
    return torch.sin(t) + noise


x = f(t)

Next, let's create a filter that smooths the signal by mixing each value with its neighbouring values:

In [None]:
# initialize a convolution; bias=False so the output is only the weighted sum
conv = torch.nn.Conv1d(1, 1, 8, bias=False)

# build a gaussian kernel
n = torch.distributions.normal.Normal(0, 1)
v = torch.arange(-4, 4)
gaussian = torch.exp(n.log_prob(v))[None, None, :]

# replace the random weights
d = conv.state_dict()
d["weight"] = gaussian
conv.load_state_dict(d)
x_ = conv(x[None, None, :])[0][0]


We can plot everything like this:

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(12, 4))

ax[0].plot(t, x)
ax[0].set_title("original")

ax[1].scatter([*range(8)], gaussian[0][0])
ax[1].set_title("gaussian filter")

ax[2].plot(x_.detach())
ax[2].set_title("smoothed")


So, you see a noisy signal, the filter, and the smoothed result. The gaussian filter adds more of the close values and less of the values further away. The filter only works on the direct neighbourhood.
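
To make the "weighted sum of direct neighbours" concrete, the sketch below computes one output value by hand and compares it with the convolution. It is an illustration under a few assumptions: it uses the functional `conv1d` without a bias term, and the names `signal`, `kernel`, `smoothed` and `manual` are made up for this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# a short noisy signal and an 8-tap gaussian kernel, similar to the cells above
signal = torch.sin(torch.linspace(0, 4 * torch.pi, 100)) + 0.1 * torch.randn(100)
offsets = torch.arange(-4, 4, dtype=torch.float32)
kernel = torch.exp(torch.distributions.Normal(0.0, 1.0).log_prob(offsets))

# conv1d expects (batch, channels, length) input and (out, in, width) weights
smoothed = F.conv1d(signal[None, None, :], kernel[None, None, :])[0, 0]

# the first output is the kernel-weighted sum of the first 8 inputs
manual = (signal[:8] * kernel).sum()

print(smoothed.shape)  # 100 - 8 + 1 = 93 output values
print(torch.allclose(smoothed[0], manual, atol=1e-5))
```

Every output value only ever sees 8 consecutive inputs, which is exactly the limitation we run into with text below.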

# Text

Now, we consider text.

In [None]:
from mltrainer import tokenizer

corpus = ["ik zit op de bank", "ik werk bij de bank", "bij de bank is het heel druk"]
v = tokenizer.build_vocab(corpus, max=11)


Now, the problem with text is that the words with the most impact on the meaning are not necessarily the words that are close by.

In [None]:
v["bank"], v["test"], len(v)


We have three sentences of at most seven words, so the dimensions are (3, 7):

In [None]:
x = tokenizer.tokenize(corpus, v)
x, x.shape
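
To show where the (3, 7) shape comes from, here is a minimal sketch in plain PyTorch. It is not the `mltrainer` implementation: the `vocab` dictionary, the `<PAD>` and `<UNK>` entries and the order in which indices are handed out are assumptions made for illustration, so the integers may differ from what `tokenizer.tokenize` returns.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

corpus = ["ik zit op de bank", "ik werk bij de bank", "bij de bank is het heel druk"]

# hypothetical vocabulary: index 0 is reserved for padding, 1 for unknown words
vocab = {"<PAD>": 0, "<UNK>": 1}
for sentence in corpus:
    for word in sentence.split():
        vocab.setdefault(word, len(vocab))

# map every word to its index; unseen words fall back to <UNK>
encoded = [
    torch.tensor([vocab.get(word, vocab["<UNK>"]) for word in sentence.split()])
    for sentence in corpus
]

# pad the shorter sentences with 0 so the batch becomes one rectangular tensor
batch = pad_sequence(encoded, batch_first=True, padding_value=0)
print(batch.shape)  # torch.Size([3, 7])
```

The two five-word sentences get two padding zeros each, and every occurrence of a word maps to the same integer, whatever its context.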


Our sentences are now encoded as integers, and the word "bank" gets the integer 6 assigned. However, the meaning of this word depends on the context: in the first sentence it is a couch, in the second a bank (the financial institution). If we make an embedding:

In [None]:
import torch.nn as nn

emb = nn.Embedding(num_embeddings=len(v), embedding_dim=4, padding_idx=0)

embeddings = emb(x)

embeddings, embeddings.shape


We have added 4 dimensions to every word, so now we have a shape of (3, 7, 4).
You can see that the word "bank" gets exactly the same vector in both sentences, as expected:

In [None]:
bank1 = embeddings[0][4]
bank2 = embeddings[1][4]
bank1, bank2


# Attention
Now we turn to the attention mechanism.
We need a key, a query and a value. Because we use self-attention, all three are just clones of the embeddings.

In [None]:
key = embeddings.detach().clone()
query = embeddings.detach().clone()
values = embeddings.detach().clone()
key.shape


We have 4 features

In [None]:
# store as float so torch.sqrt works regardless of the PyTorch version
d_features = torch.tensor(query.shape[-1], dtype=torch.float32)
d_features

And with this, we can calculate $$\frac{(QK^T)}{\sqrt{d}}$$

In [None]:
dots = torch.bmm(query, key.transpose(1, 2)) / torch.sqrt(d_features)
dots.shape


This gives us a shape of (3, 7, 7):
for every sentence, and for every word in it, we get a score for how much of every other word we want to mix in. So the last two dimensions always have the shape (sequence, sequence).

We obtain the weights with a softmax:

In [None]:
weights = nn.Softmax(dim=-1)(dots)

weights[0]


Finally, we can do a matrix multiplication with the values:

$$\text{attention} = \text{softmax}\left(\frac{(QK^T)}{\sqrt{d}}\right)V$$

In [None]:
activations = torch.bmm(weights, values)
activations.shape, embeddings.shape


Note how we end up with exactly the same size: 3 sentences of at most 7 words. But now the 4 dimensions of every word are reweighted by the other words in the sentence, regardless of the distance, depending mainly on the semantics (meaning) of every word as encoded in the embedding.
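
The steps above can be collected in a single function. This is a minimal sketch of scaled dot-product self-attention without learned projections, where query, key and value are all the same tensor; the names `self_attention` and `demo` are illustrative, and `demo` is random data standing in for the (3, 7, 4) embeddings.

```python
import math
import torch


def self_attention(x):
    """Scaled dot-product self-attention on x of shape (batch, seq, features)."""
    d = x.shape[-1]
    # (batch, seq, features) @ (batch, features, seq) -> (batch, seq, seq)
    scores = torch.bmm(x, x.transpose(1, 2)) / math.sqrt(d)
    # every row becomes a distribution over the words in the sentence
    weights = torch.softmax(scores, dim=-1)
    # mix the value vectors according to those weights
    return torch.bmm(weights, x), weights


torch.manual_seed(0)
demo = torch.randn(3, 7, 4)  # stand-in for the embeddings
out, weights = self_attention(demo)

print(out.shape, weights.shape)
print(torch.allclose(weights.sum(dim=-1), torch.ones(3, 7)))
```

Because every row of `weights` sums to one, each output vector is a weighted average of all the word vectors in its sentence.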

In [None]:
bank1 = activations[0][4]
bank2 = activations[1][4]
bank1, bank2


Now, the vector for the word "bank" has been mixed with all the other words in its sentence, so the two vectors are different!

PyTorch has multi-head attention built in as `nn.MultiheadAttention`. With that, we can add a mask to cover the padding.

In [None]:
mask = x == 0
mask.shape
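
To see what such a mask does, here is a minimal sketch that applies it by hand to plain dot-product attention. It assumes the common approach of setting the scores of padded keys to minus infinity before the softmax; the token indices and the names `tokens`, `pad_mask` and `masked_weights` are made up for this example and do not come from the tokenizer above.

```python
import math
import torch

torch.manual_seed(0)

# hypothetical batch: one sentence of 5 tokens followed by 2 padding positions
tokens = torch.tensor([[2, 3, 4, 5, 6, 0, 0]])
emb = torch.nn.Embedding(num_embeddings=13, embedding_dim=4, padding_idx=0)
e = emb(tokens).detach()

pad_mask = tokens == 0  # (batch, seq), True where the input is padding
scores = torch.bmm(e, e.transpose(1, 2)) / math.sqrt(e.shape[-1])

# broadcast the mask over the query axis so every padded key scores -inf
scores = scores.masked_fill(pad_mask[:, None, :], float("-inf"))
masked_weights = torch.softmax(scores, dim=-1)

print(masked_weights[0, :, 5:])  # the padding columns receive zero weight
```

The softmax turns minus infinity into a weight of exactly zero, so no word mixes in any of the padding.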


In [None]:
multihead = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
attn, attn_w = multihead(query, key, values, key_padding_mask=mask)


In [None]:
attn.shape


It is possible to visualize the weights. In this case, the model is untrained, so the weights do not carry any meaning yet.
What you expect is that, after training, the vector for the word "bank" is mostly mixed with the word "zit" (sit), so that it makes more sense in its context.

In [None]:
import seaborn as sns

plt.figure(figsize=(10, 10))
labels = corpus[0].split()
labels = labels + ["PAD", "PAD"]

plot = sns.heatmap(attn_w[0].detach().numpy())

plot.set_xticklabels(labels);
plot.set_yticklabels(labels);