In [1]:
import numpy as np
import torch
import torch.nn as nn

In [2]:
def softmax(Z):
    """Numerically stable softmax over the last axis of ``Z``.

    The row-wise maximum is subtracted before exponentiating so large scores
    cannot overflow; softmax is shift-invariant, so the result is unchanged.
    Entries equal to -inf (e.g. from an additive mask) get probability 0.
    """
    row_max = np.max(Z, axis=-1, keepdims=True)
    exp_shifted = np.exp(Z - row_max)
    normalizer = np.sum(exp_shifted, axis=-1, keepdims=True)
    return exp_shifted / normalizer

def self_attention(X, mask, W_KQV, W_out, verbose=True):
    """Single-head scaled dot-product self-attention (NumPy reference).

    Works on a single sequence of shape (T, d) or on a batch (..., T, d).

    Args:
        X: inputs with the feature dimension last, shape (..., T, d).
        mask: additive mask broadcastable to (..., T, T); 0 keeps a score,
            -inf removes it.
        W_KQV: stacked input projection of shape (d, 3 * d_k) whose column
            blocks are query | key | value (the transposed torch
            ``in_proj_weight`` used in this notebook).
        W_out: output projection of shape (d_k, d_out).
        verbose: print the projected and attention shapes (debug aid).

    Returns:
        (Y, attn): outputs of shape (..., T, d_out) and attention weights of
        shape (..., T, T).
    """
    # torch packs in_proj_weight as [W_q; W_k; W_v], so the first chunk is Q.
    # Split on the last axis so batched inputs are handled too.
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    if verbose:
        print(Q.shape, K.shape, V.shape)
    # Scores are Q K^T / sqrt(d_k); swapaxes transposes only the last two axes.
    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(K.shape[-1]) + mask)
    if verbose:
        print(attn.shape)
    return attn @ V @ W_out, attn

In [3]:
# Seed torch so the random inputs and the layer's weight init are reproducible.
torch.manual_seed(0)
T, d = 100, 64
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
# Causal mask: -inf strictly above the diagonal (future positions), 0 elsewhere;
# as a float attn_mask it is added to the attention scores.
M = torch.full((T, T), float("-inf")).triu(1)
X = torch.rand(1, T, d)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [4]:
# (3 * d, d): the query, key and value projections stacked along dim 0.
attn.in_proj_weight.size()

torch.Size([192, 64])

In [5]:
# (d, d): the output projection applied after the attention-weighted values.
attn.out_proj.weight.size()

torch.Size([64, 64])

In [6]:
# torch stores weights as (out, in); transpose them for the X @ W convention.
Y, A = self_attention(
    X.numpy()[0],
    M.numpy(),
    W_KQV=attn.in_proj_weight.detach().numpy().T,
    W_out=attn.out_proj.weight.detach().numpy().T,
)

(100, 64) (100, 64) (100, 64)
(100, 100)


In [7]:
np.shape(Y), np.shape(A)

((100, 64), (100, 100))

In [8]:
# Frobenius norm of the gap between torch's output and the NumPy reference.
np.linalg.norm(Y_.detach().numpy()[0] - Y)

np.float64(2.6030663793543387e-06)

In [9]:
# matmul treats the leading (5, 4) axes as a batch and multiplies the last two.
C, D = np.random.randn(5, 4, 10, 3), np.random.randn(5, 4, 3, 6)
np.matmul(C, D).shape

(5, 4, 10, 6)

In [10]:
def self_attention(X, mask, W_KQV, W_out):
    """Batched single-head scaled dot-product self-attention (NumPy reference).

    Args:
        X: inputs of shape (..., T, d).
        mask: additive mask broadcastable to (..., T, T); 0 keeps a score,
            -inf removes it.
        W_KQV: stacked input projection of shape (d, 3 * d_k) whose column
            blocks are query | key | value (the transposed torch
            ``in_proj_weight`` used in this notebook).
        W_out: output projection of shape (d_k, d_out).

    Returns:
        (Y, attn): outputs of shape (..., T, d_out) and attention weights of
        shape (..., T, T).
    """
    # torch packs in_proj_weight as [W_q; W_k; W_v], so the first chunk is Q.
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    # Scale by the key dimension d_k, not the input width d, so the result
    # stays correct when W_KQV projects to a different width.
    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(K.shape[-1]) + mask)
    return attn @ V @ W_out, attn

In [11]:
B, T, d = 50, 100, 64
X = torch.randn(B, T, d)
# Same causal mask as before: -inf strictly above the diagonal, 0 elsewhere.
M = torch.full((T, T), float("-inf")).triu(1)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [12]:
Y, A = self_attention(
    X.numpy(),
    M.numpy(),
    W_KQV=np.transpose(attn.in_proj_weight.detach().numpy()),
    W_out=np.transpose(attn.out_proj.weight.detach().numpy()),
)

In [13]:
# Gap between torch's attention weights and the NumPy reference weights.
np.linalg.norm(A_.detach().numpy() - A)

np.float64(8.928570338039558e-07)

In [14]:
X.size()

torch.Size([50, 100, 64])

In [15]:
def multihead_attention(X, mask, heads, W_KQV, W_out):
    """Multi-head scaled dot-product self-attention (NumPy reference).

    Args:
        X: inputs of shape (B, T, d).
        mask: additive mask broadcastable to (B, heads, T, T), e.g. a (T, T)
            causal mask with -inf strictly above the diagonal.
        heads: number of attention heads; must divide d.
        W_KQV: stacked input projection of shape (d, 3 * d) whose column
            blocks are query | key | value (the transposed torch
            ``in_proj_weight`` used in this notebook).
        W_out: output projection of shape (d, d_out).

    Returns:
        (Y, attn): outputs of shape (B, T, d_out) and per-head attention
        weights of shape (B, heads, T, T).

    Raises:
        ValueError: if d is not divisible by heads.
    """
    B, T, d = X.shape
    if d % heads:
        raise ValueError(f"model dim {d} must be divisible by heads={heads}")
    d_head = d // heads
    # torch packs in_proj_weight as [W_q; W_k; W_v], so the first chunk is Q.
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    # B x T x d  =>  B x heads x T x d_head
    Q, K, V = [a.reshape(B, T, heads, d_head).swapaxes(1, 2) for a in (Q, K, V)]
    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(d_head) + mask)
    # B x heads x T x d_head  =>  B x T x d (concatenate heads), then project.
    Y = (attn @ V).swapaxes(1, 2).reshape(B, T, d) @ W_out
    return Y, attn

In [16]:
heads = 4
# A fresh bias-free 4-head layer run on the same batched X and causal mask.
attn = nn.MultiheadAttention(d, heads, batch_first=True, bias=False)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [17]:
np.shape(A)

(50, 100, 100)

In [18]:
A_.size()

torch.Size([50, 100, 100])

In [19]:
X.size()

torch.Size([50, 100, 64])

In [20]:
# Still (3 * d, d): heads split the projected features, not the weight shape.
attn.in_proj_weight.size()

torch.Size([192, 64])

In [21]:
Y, A = multihead_attention(
    X.numpy(), M.numpy(), heads,
    W_KQV=attn.in_proj_weight.detach().numpy().T,
    W_out=attn.out_proj.weight.detach().numpy().T,
)

In [22]:
np.linalg.norm(Y_.detach().numpy() - Y)

np.float64(1.0437048537286735e-05)

In [23]:
# nn.MultiheadAttention returns head-averaged weights by default
# (average_attn_weights=True), so average the NumPy per-head weights too.
np.linalg.norm(np.mean(A, axis=1) - A_.detach().numpy())

np.float64(7.301352137936004e-07)

In [24]:
def layer_norm(Z, eps):
    """Normalize ``Z`` over its last axis to zero mean and unit variance.

    No learnable scale or shift is applied. ``eps`` is added to the
    (biased) variance before the square root to avoid division by zero.
    """
    centered = Z - np.mean(Z, axis=-1, keepdims=True)
    variance = np.var(Z, axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps)

def relu(Z):
    """Elementwise rectified linear unit: negative entries become 0."""
    return np.clip(Z, 0, None)

def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    """Post-norm Transformer encoder block (NumPy reference, no biases).

    Computes LN(h + FF(h)) with h = LN(X + MHA(X)), where FF is
    relu(. @ W_ff1) @ W_ff2 and LN is the non-affine ``layer_norm``.
    Only the block output is returned; attention weights are discarded.
    """
    attn_out, _ = multihead_attention(X, mask, heads, W_KQV, W_out)
    hidden = layer_norm(X + attn_out, eps)
    ff_out = relu(hidden @ W_ff1) @ W_ff2
    return layer_norm(hidden + ff_out, eps)

In [25]:
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)
# The NumPy `transformer` reference has no bias terms, so zero every linear and
# attention bias explicitly instead of relying on default initialization.
# torch.no_grad() replaces the discouraged `.data` attribute for in-place edits.
# (The LayerNorms still rely on their default weight=1, bias=0 initialization.)
with torch.no_grad():
    for param in (trans.linear1.bias, trans.linear2.bias,
                  trans.self_attn.in_proj_bias, trans.self_attn.out_proj.bias):
        if param is not None:
            param.zero_()
Y_ = trans(X, src_mask=M)

In [26]:
# Feed the torch layer's (transposed) weights into the NumPy reference block.
Y = transformer(
    X.numpy(), M.numpy(), heads,
    W_KQV=trans.self_attn.in_proj_weight.detach().numpy().T,
    W_out=trans.self_attn.out_proj.weight.detach().numpy().T,
    W_ff1=trans.linear1.weight.detach().numpy().T,
    W_ff2=trans.linear2.weight.detach().numpy().T,
    eps=trans.norm1.eps,
)

In [27]:
# %env VAR=value sets os.environ in this kernel (inherited by child processes);
# it does not change sys.path of the already-running interpreter.
%env PYTHONPATH=./python
%env NEEDLE_BACKEND=nd

env: PYTHONPATH=./python
env: NEEDLE_BACKEND=nd


In [28]:
import sys

# Make the local `needle` package importable in this kernel; PYTHONPATH set via
# a magic does not affect sys.path of the running interpreter. Guard the append
# so re-running the cell does not add duplicate entries.
if "./python" not in sys.path:
    sys.path.append("./python")

In [34]:
from typing import List
from needle.autograd import Tensor
import needle.backend_ndarray.ndarray as ndarray
from needle import ops
import needle.init as init
import numpy as np
from needle.nn.nn_sequence import Embedding
from needle.nn.nn_basic import (
    Parameter, 
    Module, 
    ReLU,
    Dropout,
    LayerNorm1d,
    Linear,
    Sequential
)
from needle.nn.nn_transformer import (
    MultiHeadAttention,
    TransformerLayer
)

In [None]:
# `needle` itself was never bound as a name by the imports above, and
# `hidden_size` was undefined, so the original cell raised NameError.
import needle

# Mirror the PyTorch reference layer: per-head width d // heads and a 128-wide
# feed-forward hidden layer (dim_feedforward=128 above).
# NOTE(review): assumes the signature
#   TransformerLayer(q_features, num_head, dim_head, hidden_size, *, dropout, causal, device, ...)
# — confirm against needle/nn/nn_transformer.py. needle.cuda() also requires a GPU build.
dim_head = d // heads
hidden_size = 128
my_transformer_layer = TransformerLayer(
        d, heads, dim_head, hidden_size,
        dropout=0.0, causal=True, device=needle.cuda())

In [None]:
# W_ff2 maps the 128-wide hidden layer back to d, so it is (dim_feedforward, d).
np.transpose(trans.linear2.weight.detach().numpy()).shape

(128, 64)

In [None]:
X.size()

torch.Size([50, 100, 64])

In [None]:
# Gap between the torch encoder layer and the NumPy `transformer` reference.
np.linalg.norm(Y_.detach().numpy() - Y)

np.float64(5.206937086569848e-05)

In [None]:
# reshape fills row-major (C order): the first 3 values form row 0.
a = np.arange(1, 7)
b = a.reshape((2, 3))
b

array([[1, 2, 3],
       [4, 5, 6]])

In [None]:
c = np.reshape(a, (6,))
c

array([1, 2, 3, 4, 5, 6])