## 기초 연습

In [73]:
import torch

# Model width and head count; later cells read these module-level names.
dim, n_heads = 512, 8

# Each head gets an equal integer slice of the model width.
dim_per_head = dim // n_heads

print(dim_per_head)

64


In [74]:
# A is 2x3 and B is 3x2 (the transpose of A), so the product C is 2x2.
A = torch.Tensor([[1, 2, 3], [4, 5, 6]])
B = torch.Tensor([[1, 4], [2, 5], [3, 6]])

# The ``@`` operator performs the same matrix product as ``torch.matmul``.
C = A @ B
print(C)

tensor([[14., 32.],
        [32., 77.]])


In [75]:
# Batch size and sequence length used by the head-splitting demo in this cell.
bs, seq_length = 1, 5

def shape(x: torch.Tensor) -> torch.Tensor:
    """Split the feature axis into heads and bring the head axis forward.

    ``x`` is viewed as ``(bs, -1, n_heads, dim_per_head)`` using the
    module-level ``bs``, ``n_heads`` and ``dim_per_head``; axes 1 and 2 are
    then swapped, so the result is ``(bs, n_heads, -1, dim_per_head)``.
    """
    per_head = x.view(bs, -1, n_heads, dim_per_head)
    return per_head.transpose(1, 2)

def unshape(x: torch.Tensor) -> torch.Tensor:
    """Merge the heads back into a single feature axis.

    Swaps axes 1 and 2, makes the result contiguous so it can be viewed, and
    flattens the last two axes into ``n_heads * dim_per_head`` features. The
    module-level ``bs``, ``n_heads`` and ``dim_per_head`` are used.
    """
    merged_width = n_heads * dim_per_head
    heads_last = x.transpose(1, 2).contiguous()
    return heads_last.view(bs, -1, merged_width)

# Random input of size (bs, seq_length, dim), split into heads and displayed.
query = torch.rand(bs, seq_length, dim)
q = shape(query)
print(q, q.size(), sep="\n")

# Merging the heads again; only the resulting size is displayed.
q_ = unshape(q)
print(q_.size())

tensor([[[[0.0697, 0.7858, 0.5453,  ..., 0.6775, 0.4480, 0.6910],
          [0.2766, 0.0698, 0.4162,  ..., 0.1777, 0.8022, 0.3807],
          [0.5000, 0.7420, 0.9270,  ..., 0.9319, 0.3514, 0.9430],
          [0.3707, 0.3878, 0.5769,  ..., 0.4808, 0.9184, 0.0537],
          [0.0121, 0.0971, 0.2607,  ..., 0.8670, 0.4894, 0.2651]],

         [[0.0725, 0.6900, 0.1731,  ..., 0.4663, 0.8738, 0.1976],
          [0.0947, 0.8805, 0.0407,  ..., 0.6522, 0.6374, 0.1722],
          [0.3488, 0.7712, 0.9439,  ..., 0.7995, 0.6770, 0.7722],
          [0.1152, 0.0839, 0.3617,  ..., 0.3688, 0.2944, 0.2753],
          [0.5751, 0.7502, 0.0508,  ..., 0.3478, 0.1272, 0.2526]],

         [[0.6463, 0.3965, 0.9159,  ..., 0.3100, 0.7720, 0.0145],
          [0.2610, 0.3123, 0.7381,  ..., 0.2167, 0.9840, 0.7408],
          [0.9107, 0.2784, 0.8956,  ..., 0.6077, 0.8119, 0.0838],
          [0.6687, 0.6608, 0.5911,  ..., 0.7178, 0.1452, 0.9359],
          [0.4970, 0.0737, 0.5581,  ..., 0.9836, 0.5183, 0.9118]],

    

## 멀티 헤드 어텐션(Multi-head Attention)

In [76]:
class PretrainedConfig:
    """Minimal settings holder with three recognised keyword arguments.

    Recognised keywords and their defaults:
        n_heads: 8
        dim: 512
        attention_dropout: 0.2

    Each is stored as an attribute of the same name. Any other keyword
    argument is accepted and silently ignored.
    """

    def __init__(self, **kwargs):
        # Insertion order fixes the order in which the attributes are created.
        defaults = {"n_heads": 8, "dim": 512, "attention_dropout": 0.2}
        for name, fallback in defaults.items():
            setattr(self, name, kwargs.get(name, fallback))

In [77]:
import math
from typing import Dict, List, Optional, Set, Tuple, Union

from torch import nn

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention that prints its intermediates.

    Query, key and value are each passed through their own linear projection,
    split into ``n_heads`` heads, combined with a masked softmax (dropout is
    applied to the attention weights), merged back and passed through an
    output projection. ``forward`` prints the sizes it derives, the boolean
    mask, the masked scores and the softmax weights while it runs.
    """

    def __init__(self, config: PretrainedConfig):
        """Create the dropout layer and the four ``dim -> dim`` projections.

        Reads ``config.n_heads``, ``config.dim`` and
        ``config.attention_dropout``.

        Raises:
            ValueError: if ``dim`` is not an exact multiple of ``n_heads``.
        """
        super().__init__()
        self.config = config

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.is_causal = False

        # Every head must receive the same number of features.
        if self.dim % self.n_heads:
            raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly")

        # Created in the order q, k, v, out so parameter initialisation and
        # submodule registration follow that order.
        width = config.dim
        self.q_lin = nn.Linear(width, width)
        self.k_lin = nn.Linear(width, width)
        self.v_lin = nn.Linear(width, width)
        self.out_lin = nn.Linear(width, width)

        self.pruned_heads: Set[int] = set()
        self.attention_head_size = self.dim // self.n_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, q_length, dim)
            key: torch.tensor(bs, k_length, dim)
            value: torch.tensor(bs, k_length, dim)
            mask: tensor with bs * k_length elements (it is viewed as
                (bs, 1, 1, k_length)); entries equal to 0 mark key positions
                whose scores are replaced by the most negative finite value
                of the score dtype before the softmax.
            head_mask: optional tensor multiplied into the attention weights
                after dropout.
            output_attentions: when True, the attention weights are returned
                as well.

        Returns:
            ``(context,)`` or, if ``output_attentions`` is True,
            ``(context, weights)`` where
            context: torch.tensor(bs, q_length, dim), the output projection
                of the merged heads.
            weights: torch.tensor(bs, n_heads, q_length, k_length), the
                attention weights after dropout and ``head_mask``.
        """
        bs, q_length, dim = query.size()
        print(f"bs: {bs}, q_length: {q_length}, dim: {dim}")
        k_length = key.size(1)
        print(f"k_length: {k_length}")

        dim_per_head = self.dim // self.n_heads
        print(f"dim_per_head: {dim_per_head}")

        mask_reshp = (bs, 1, 1, k_length)
        print(f"mask_reshp: {mask_reshp}")

        # Project, then separate heads: (bs, length, dim) -> (bs, n_heads, length, dim_per_head).
        split_size = (bs, -1, self.n_heads, dim_per_head)
        q = self.q_lin(query).view(split_size).transpose(1, 2)
        k = self.k_lin(key).view(split_size).transpose(1, 2)
        v = self.v_lin(value).view(split_size).transpose(1, 2)

        # The query is scaled before the dot product with the keys.
        q = q / math.sqrt(dim_per_head)
        scores = q @ k.transpose(2, 3)  # (bs, n_heads, q_length, k_length)

        # True where the key position is masked out; broadcast over heads and queries.
        blocked = (mask == 0).view(mask_reshp).expand_as(scores)
        print("=" * 80)
        print(blocked)
        print("-" * 80)

        lowest = torch.tensor(torch.finfo(scores.dtype).min)
        scores = scores.masked_fill(blocked, lowest)
        print(">" * 80)
        print(scores)
        print("<" * 80)

        weights = nn.functional.softmax(scores, dim=-1)
        print("#" * 80)
        print(weights)
        print("@" * 80)
        weights = self.dropout(weights)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        # Weighted sum of values, then group heads: back to (bs, q_length, dim).
        per_head_context = weights @ v
        merged = per_head_context.transpose(1, 2).contiguous()
        merged = merged.view(bs, -1, self.n_heads * dim_per_head)
        context = self.out_lin(merged)

        return (context, weights) if output_attentions else (context,)

In [78]:
# Attention layer with 8 heads over a 512-wide model and 0.2 attention dropout.
config = PretrainedConfig(dim=512, n_heads=8, attention_dropout=0.2)
mh_attention = MultiHeadSelfAttention(config)

In [79]:
# Problem size for the attention demo.
bs, seq_length, dim = 1, 5, 512

# Seed the global RNG so the random inputs below are reproducible.
torch.manual_seed(42)

# Three independent draws, taken in the order query, key, value.
query, key, value = (torch.rand(bs, seq_length, dim) for _ in range(3))

# Alternative random mask, kept for reference:
# mask = torch.where(torch.randn(bs, seq_length) > 0, 1, 0)
# Fixed mask: only the last of the five positions is 0.
mask = torch.Tensor([1, 1, 1, 1, 0])

print(query.size(), query, mask, sep="\n")

torch.Size([1, 5, 512])
tensor([[[0.8823, 0.9150, 0.3829,  ..., 0.4078, 0.5411, 0.0410],
         [0.6556, 0.1186, 0.1836,  ..., 0.3092, 0.0702, 0.1836],
         [0.7785, 0.4253, 0.7124,  ..., 0.4593, 0.4520, 0.1866],
         [0.5729, 0.3465, 0.2419,  ..., 0.5122, 0.5909, 0.9712],
         [0.7322, 0.6075, 0.3988,  ..., 0.7021, 0.7154, 0.4832]]])
tensor([1., 1., 1., 1., 0.])


In [80]:
# Call the module instance rather than ``.forward`` directly: ``nn.Module.__call__``
# dispatches to ``forward`` and also runs any hooks registered on the module.
# The result is a tuple; with ``output_attentions`` left at its default it
# holds only the context tensor.
context = mh_attention(query, key, value, mask, None)
print(context)

bs: 1, q_length: 5, dim: 512
k_length: 5
dim_per_head: 64
mask_reshp: (1, 1, 1, 5)
tensor([[[[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]],

         [[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]],

         [[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True]],

         [[False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, False, False, False,  True],
          [False, 