# 3. Attention メカニズムのコーディング

## 3.3 Self-Attention を使って入力の異なる部分に注意を払う

In [None]:
import torch

In [None]:
# Six 3-dimensional input vectors, one row per token of the example sentence
# "Your journey starts with one step".
inputs = torch.tensor([
  [0.43, 0.15, 0.89],  # x^1: "Your"
  [0.55, 0.87, 0.66],  # x^2: "journey"
  [0.57, 0.85, 0.64],  # x^3: "starts"
  [0.22, 0.85, 0.33],  # x^4: "with"
  [0.77, 0.25, 0.10],  # x^5: "one"
  [0.05, 0.80, 0.55],  # x^6: "step"
])

In [None]:
query = inputs[1]  # the 2nd token ("journey") acts as the query

# Unnormalized attention scores: dot product of every token with the query.
attn_scores_2 = torch.stack([torch.dot(token_vec, query) for token_vec in inputs])

print(attn_scores_2)

In [None]:
# Simplest normalization: divide each score by the total so the weights sum to 1.
score_total = attn_scores_2.sum()
attn_weights_2_tmp = attn_scores_2 / score_total
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

In [None]:
def softmax_naive(x):
  """Return the softmax of `x` along dimension 0 via exp(x) / sum(exp(x)).

  No maximum is subtracted before exponentiating, so large inputs can
  overflow to inf/nan; use torch.softmax in real code.
  """
  exps = torch.exp(x)
  return exps / exps.sum(dim=0)

In [None]:
# Normalize the same scores with the naive softmax: exp() makes every weight
# positive, and the weights still sum to 1.
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum()) 

In [None]:
# Same normalization with PyTorch's built-in softmax; dim=0 is the only
# dimension of the 1-D score vector. The result should match softmax_naive
# for these small scores.
attn_weight_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weight_2)
print("Sum:", attn_weight_2.sum())

In [None]:
query = inputs[1]

# Context vector for "journey": the attention-weighted sum of all input vectors.
context_vec_2 = torch.zeros(query.shape)
for weight, token_vec in zip(attn_weight_2, inputs):
  context_vec_2 += weight * token_vec

print(context_vec_2)

In [None]:
# Pairwise dot products between every pair of tokens, computed with explicit
# loops: attn_scores[i, j] = x_i . x_j. The matrix is sized from `inputs`
# rather than a hard-coded 6, so the cell keeps working if the inputs change.
num_tokens = inputs.shape[0]
attn_scores = torch.empty(num_tokens, num_tokens)
for i, x_i in enumerate(inputs):
  for j, x_j in enumerate(inputs):
    attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

In [None]:
# Same pairwise scores as one matrix multiplication: (6, 3) @ (3, 6) -> (6, 6).
attn_scores = inputs @ inputs.T

print(attn_scores)

In [None]:
# Softmax over the last dimension, so each row (one query token) sums to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

In [None]:
# Check that the softmax rows are normalized. Row 2 (index 1, the "journey"
# query) is read directly from attn_weights instead of re-typed numbers, so the
# check always reflects the tensor actually computed above.
row_2_sum = attn_weights[1].sum()
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))

In [None]:
# Bare expression: in a notebook the cell displays the input matrix again.
inputs

In [None]:
# All context vectors at once: row i is the attention-weighted sum of the
# input rows for query token i; (6, 6) @ (6, 3) -> (6, 3).
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

In [None]:
# Loop-based context vector for the 2nd token from earlier; it should match
# row 2 (index 1) of all_context_vecs up to float rounding.
print("Previous 2nd context vector:", context_vec_2)

## 3.4 訓練可能な重みをもつ Self-Attention を実装する

In [None]:
# Walk through the computation for a single query token first.
x_2 = inputs[1]  # the 2nd token, "journey"
d_in = inputs.shape[1] # input vector size (= 3)
d_out = 2  # size of the projected query/key/value vectors

In [None]:
# Fixed seed so the random projection matrices are reproducible.
torch.manual_seed(123)
# Query/key/value projection matrices, each of shape (d_in, d_out).
# requires_grad=False: autograd does not track them in this walkthrough.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [None]:
# Project the 2nd token into query, key and value vectors: (3,) @ (3, 2) -> (2,).
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value
print(query_2)

In [None]:
# Keys and values are needed for every token, not only the query token:
# (6, 3) @ (3, 2) -> (6, 2).
keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

In [None]:
# Unnormalized attention score between query 2 and key 2 (omega_22).
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

In [None]:
# Scores of query 2 against all keys in one product: (2,) @ (2, 6) -> (6,).
# This rebinds attn_scores_2 from section 3.3.
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

In [None]:
# Scaled dot-product attention: divide the scores by sqrt(d_k), where d_k is
# the key size (d_out = 2 here), then softmax over the keys.
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1)
print(attn_weights_2)

In [None]:
# Context vector for token 2: attention-weighted sum of the value vectors,
# (6,) @ (6, 2) -> (2,).
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

In [None]:
import torch.nn as nn 

class SelfAttention_v1(nn.Module):
  """Scaled dot-product self-attention with raw nn.Parameter weight matrices.

  Args:
    d_in: size of each input vector.
    d_out: size of the query/key/value projections and of the output.

  forward(x) accepts x of shape (num_tokens, d_in) or a batch of shape
  (..., num_tokens, d_in) and returns context vectors of shape
  (..., num_tokens, d_out).
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # Weights come from torch.rand (uniform on [0, 1)); seed beforehand with
    # torch.manual_seed for reproducible results. Creation order matters for RNG.
    self.W_query  = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key    = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value  = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    keys    = x @ self.W_key
    queries = x @ self.W_query
    values  = x @ self.W_value

    # Swap only the token and feature axes. `keys.T` would reverse *all*
    # dimensions, which is only correct for 2-D tensors and breaks batches.
    attn_scores = queries @ keys.transpose(-2, -1)
    # Scale by sqrt(d_k) and normalize over the keys (last dimension).
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    context_vec = attn_weights @ values
    return context_vec



In [None]:
# Seed 123 again: __init__ draws W_query, W_key and W_value with torch.rand in
# the same order and with the same (d_in, d_out) shapes as the manual cell
# above, so row 2 (index 1) of this output should equal context_vec_2.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

In [None]:
# Manually computed context vector, for comparison with row 2 of sa_v1(inputs).
print(context_vec_2)

In [None]:
class SelfAttention_v2(nn.Module):
  """Scaled dot-product self-attention using nn.Linear projections.

  Args:
    d_in: size of each input vector.
    d_out: size of the query/key/value projections and of the output.
    qkv_bias: whether the query/key/value projections have a bias term.

  forward(x) accepts x of shape (num_tokens, d_in) or a batch of shape
  (..., num_tokens, d_in) and returns context vectors of shape
  (..., num_tokens, d_out).
  """

  def __init__(self, d_in, d_out, qkv_bias=False):
    super().__init__()
    self.W_query  = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key    = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value  = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    keys    = self.W_key(x)
    queries = self.W_query(x)
    values  = self.W_value(x)

    # Swap only the token and feature axes. `keys.T` would reverse *all*
    # dimensions, which is only correct for 2-D tensors and breaks batches.
    attn_scores = queries @ keys.transpose(-2, -1)
    # Scale by sqrt(d_k) and normalize over the keys (last dimension).
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    context_vec = attn_weights @ values
    return context_vec
    

In [None]:
# nn.Linear uses its own default weight initialization rather than torch.rand
# (and stores weights transposed), so these outputs are not expected to match sa_v1's.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

## 3.5 Causal Attention で未来の単語を隠す

In [None]:
# Recompute sa_v2's attention weights step by step (no mask yet), reusing its
# query/key projections. This rebinds the global keys/attn_scores/attn_weights.
queries = sa_v2.W_query(inputs)
keys    = sa_v2.W_key(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

In [None]:
# Lower-triangular matrix of ones: entry (i, j) is 1 when j <= i, i.e. token i
# may attend only to itself and earlier tokens.
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

In [None]:
# Zero the weights above the diagonal (future tokens); rows no longer sum to 1.
masked_simple = attn_weights * mask_simple
print(masked_simple)

In [None]:
# Renormalize so each row sums to 1 again; keepdim=True keeps row_sums with
# shape (context_length, 1) so it broadcasts across each row.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

In [None]:
# Alternative that avoids renormalizing: set the future positions of the raw
# *scores* to -inf before the softmax. triu(..., diagonal=1) is 1 strictly
# above the diagonal; .bool() turns it into the fill mask.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

In [None]:
# Softmax of the masked, scaled scores: exp(-inf) = 0, so future positions get
# zero weight and every row already sums to 1. Normalize over the last
# dimension (dim=-1) like the other score-matrix softmaxes in this notebook;
# for this 2-D matrix it equals dim=1, and it stays correct for batched scores.
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

In [None]:
# dropout
# Demo on a matrix of ones: nn.Dropout zeroes each entry with probability p
# and scales the surviving entries by 1 / (1 - p) (= 2 here).
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones(6, 6)
print(dropout(example))

In [None]:
# Apply dropout to the causal attention weights. A newly created module is in
# training mode, so dropout is active here.
torch.manual_seed(789)
print(dropout(attn_weights))

In [None]:
# Batch of two copies of the same sequence: shape (2, 6, 3).
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

In [None]:
class CausalAttention(nn.Module):
  """Single-head self-attention with a causal mask and dropout.

  Args:
    d_in: size of each input vector.
    d_out: size of the query/key/value projections and of the output.
    context_length: maximum number of tokens covered by the pre-built mask.
    dropout: dropout probability applied to the attention weights.
    qkv_bias: whether the query/key/value projections have a bias term.

  forward(x) takes x of shape (batch, num_tokens, d_in), or an unbatched
  (num_tokens, d_in), and returns context vectors with d_out features in the
  last dimension.

  Raises:
    ValueError: (in forward) if num_tokens exceeds context_length.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.W_query  = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key    = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value  = nn.Linear(d_in, d_out, bias=qkv_bias)

    self.dropout  = nn.Dropout(dropout)
    # 1s strictly above the diagonal mark "future" positions to hide. A buffer
    # moves with the module across devices but is not a trainable parameter.
    self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):
    # Tokens live on the second-to-last axis for both batched and unbatched input.
    num_tokens = x.shape[-2]
    context_length = self.mask.shape[0]
    if num_tokens > context_length:
      # Fail clearly instead of with a broadcasting error inside masked_fill_.
      raise ValueError(
        f"input has {num_tokens} tokens but context_length is {context_length}"
      )

    keys    = self.W_key(x)
    queries = self.W_query(x)
    values  = self.W_value(x)

    # Swap only the last two axes: (..., num_tokens, d_out) -> (..., d_out, num_tokens).
    attn_scores = queries @ keys.transpose(-2, -1)
    # Hide future tokens; exp(-inf) = 0 after the softmax.
    attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    context_vec = attn_weights @ values
    return context_vec

In [None]:
# Single causal head over the batch with dropout disabled (0.0); the output
# should have shape (2, 6, d_out) = (2, 6, 2).
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

In [None]:
class MultiHeadAttentionWrapper(nn.Module):
  """Multi-head attention built from independent CausalAttention heads.

  Every head receives the same input; the head outputs are concatenated along
  the last dimension, giving num_heads * d_out output features.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    heads = nn.ModuleList()
    for _ in range(num_heads):
      heads.append(CausalAttention(d_in, d_out, context_length, dropout, qkv_bias))
    self.heads = heads

  def forward(self, x):
    per_head_outputs = []
    for attention_head in self.heads:
      per_head_outputs.append(attention_head(x))
    return torch.cat(per_head_outputs, dim=-1)
  

In [None]:
# Two CausalAttention heads with d_out=2 each; their outputs are concatenated,
# so the expected shape is (2, 6, 4).
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

In [None]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention with all heads computed in one pass.

  The d_out-wide query/key/value projections are split into num_heads heads
  of size head_dim = d_out // num_heads. Attention is computed per head, and
  the heads are concatenated back to d_out features and mixed by `out_proj`.

  Args:
    d_in: size of each input vector.
    d_out: total output size; must be divisible by num_heads.
    context_length: maximum number of tokens covered by the pre-built mask.
    dropout: dropout probability applied to the attention weights.
    num_heads: number of attention heads.
    qkv_bias: whether the query/key/value projections have a bias term.

  forward(x) takes x of shape (batch, num_tokens, d_in) and returns a tensor
  of shape (batch, num_tokens, d_out).

  Raises:
    AssertionError: (in __init__) if d_out is not divisible by num_heads.
    ValueError: (in forward) if num_tokens exceeds context_length.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_query  = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key    = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value  = nn.Linear(d_in, d_out, bias=qkv_bias)

    # Mixes the concatenated head outputs.
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)
    # 1s strictly above the diagonal mark future positions. A buffer moves with
    # the module across devices but is not a trainable parameter.
    self.register_buffer(
      "mask",
      torch.triu(torch.ones(context_length, context_length), diagonal=1)
    )

  def forward(self, x):
    b, num_tokens, _ = x.shape
    context_length = self.mask.shape[0]
    if num_tokens > context_length:
      # Fail clearly instead of with a broadcasting error inside masked_fill_.
      raise ValueError(
        f"input has {num_tokens} tokens but context_length is {context_length}"
      )

    # (b, num_tokens, d_out)
    keys    = self.W_key(x)
    queries = self.W_query(x)
    values  = self.W_value(x)

    # Split the feature dimension into heads: (b, num_tokens, num_heads, head_dim).
    keys    = keys.view(    b, num_tokens, self.num_heads, self.head_dim)
    values  = values.view(  b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view( b, num_tokens, self.num_heads, self.head_dim)

    # Put heads before tokens: (b, num_heads, num_tokens, head_dim).
    keys    = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values  = values.transpose(1, 2)

    # Per-head scores: (b, num_heads, num_tokens, num_tokens).
    attn_scores = queries @ keys.transpose(2, 3)

    # Hide future tokens; the 2-D mask broadcasts over batch and heads.
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
    attn_scores.masked_fill_(mask_bool, -torch.inf)

    # keys.shape[-1] is head_dim here, so scaling is by sqrt(head_dim).
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
    context_vec = (attn_weights @ values).transpose(1, 2)
    # Concatenate the heads; view() needs contiguous memory after transpose().
    context_vec = context_vec.contiguous().view(
      b, num_tokens, self.d_out
    )

    context_vec = self.out_proj(context_vec)

    return context_vec
    


In [None]:
# Same batch through the fused implementation: d_out=2 is split into 2 heads
# of head_dim 1, and the output after out_proj has shape (2, 6, 2).
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)