In [1]:
import torch as t
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests


In [58]:
def raw_attention_pattern(token_activations, num_heads, project_query, project_key):
    """Compute scaled dot-product attention scores for every head.

    Args:
        token_activations: 3-D tensor indexed as (batch, seq, hidden).
        num_heads: number of heads; the projected width must be divisible by it.
        project_query: callable mapping token_activations to queries.
        project_key: callable mapping token_activations to keys.

    Returns:
        Tensor of shape (batch, num_heads, key_pos, query_pos) holding
        K . Q / sqrt(head_size), where head_size is the per-head width of the
        projected keys/queries.

    Raises:
        RuntimeError: if the projected width is not divisible by num_heads.
    """
    keys = project_key(token_activations)
    queries = project_query(token_activations)

    # Split the last dim heads-major into (num_heads, head_size); this is the
    # same layout as einops "b s (n h) -> b s n h".
    keys = keys.reshape(keys.shape[0], keys.shape[1], num_heads, -1)
    queries = queries.reshape(queries.shape[0], queries.shape[1], num_heads, -1)

    scores = t.einsum("bknh,bqnh->bnkq", keys, queries)

    # Scale by the width each head actually sees. A plain Python float keeps the
    # result on the input's device and in its dtype; a 1-element CPU tensor
    # would fail for CUDA inputs and up-cast half-precision scores.
    head_size = queries.shape[-1]
    return scores / head_size ** 0.5

# Run the bert_tests check for raw_attention_pattern (prints MATCH on success).
bert_tests.test_attention_pattern_fn(raw_attention_pattern)


attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.007011 STD: 0.115 VALS [0.07875 0.1818 0.1318 0.1061 -0.176 0.1207 -0.04416 -0.08958 -0.008028 -0.03693...]


In [59]:
def bert_attention(token_activations, num_heads, attention_pattern, project_value, project_output):
    """Mix per-head value vectors with softmaxed attention and project the result.

    Args:
        token_activations: (batch, seq, hidden) activations fed to project_value.
        num_heads: number of heads the projected values are split into.
        attention_pattern: raw scores laid out (batch, heads, key_pos, query_pos).
        project_value: callable producing values from token_activations.
        project_output: callable applied to the re-merged heads.

    Returns:
        project_output applied to a (batch, query_pos, heads * head_size) tensor.
    """
    # Normalise over the key axis (dim 2) so each query's weights sum to 1.
    weights = attention_pattern.softmax(dim=2)

    values = rearrange(
        project_value(token_activations), "b k (n h) -> b n k h", n=num_heads
    )

    # Weighted sum of value vectors over key positions, per head and query.
    per_head = einsum("bnkq,bnkh->bnqh", weights, values)

    # Concatenate heads back into one feature axis before the output projection.
    merged = rearrange(per_head, "b n q h -> b q (n h)")
    return project_output(merged)

# Run the bert_tests check for bert_attention (prints MATCH on success).
bert_tests.test_attention_fn(bert_attention)



attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.002339 STD: 0.1166 VALS [0.106 -0.1369 -0.0374 -0.0001321 -0.07958 0.09887 0.02255 0.05939 0.086 0.1509...]


In [60]:
class MultiHeadedSelfAttention(t.nn.Module):
    """Multi-head self-attention built from raw_attention_pattern and bert_attention.

    Args:
        num_heads: number of attention heads. hidden_size must be divisible by
            it: the heads split the projected width, they do not multiply it.
        hidden_size: width of the input activations and of every projection.
    """

    def __init__(self, num_heads, hidden_size):
        super().__init__()

        # All four projections are hidden_size -> hidden_size; each head works
        # on a hidden_size // num_heads slice of the query/key/value width.
        self.query = t.nn.Linear(hidden_size, hidden_size)
        self.key = t.nn.Linear(hidden_size, hidden_size)
        self.value = t.nn.Linear(hidden_size, hidden_size)
        self.output = t.nn.Linear(hidden_size, hidden_size)

        self.num_heads = num_heads
        self.hidden_size = hidden_size

    def forward(self, input):
        """Return the self-attention output for `input`.

        Assumes `input` is (batch, seq, hidden_size); the result has the same shape.
        """
        attention_scores = raw_attention_pattern(input, self.num_heads, self.query, self.key)
        return bert_attention(input, self.num_heads, attention_scores, self.value, self.output)

# Run the bert_tests check on the class itself (not an instance); presumably the
# test constructs it internally. Prints MATCH on success.
bert_tests.test_bert_attention(MultiHeadedSelfAttention)

        

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [65]:
def bert_mlp(token_activations, linear_1, linear_2):
    """Position-wise feed-forward block: linear_1, then exact GELU, then linear_2.

    Args:
        token_activations: tensor passed to linear_1.
        linear_1: callable expanding the activations.
        linear_2: callable applied after the GELU non-linearity.

    Returns:
        linear_2(gelu(linear_1(token_activations))).
    """
    expanded = linear_1(token_activations)
    return linear_2(t.nn.functional.gelu(expanded))

# Run the bert_tests check for bert_mlp (prints MATCH on success).
bert_tests.test_bert_mlp(bert_mlp)


bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.003054 STD: 0.1041 VALS [0.1262 0.01134 0.06912 0.05845 0.06832 0.06498 -0.07017 -0.1155 -0.004871 0.2145...]


In [67]:
class BertMLP(t.nn.Module):
    """Feed-forward sub-layer: input_size -> intermediate_size -> input_size.

    Args:
        input_size: width of the incoming (and outgoing) activations.
        intermediate_size: width of the hidden expansion layer.
    """

    def __init__(self, input_size, intermediate_size):
        super().__init__()
        # Use the constructor arguments; the builtin `input` and an undefined
        # `intermediate` made construction fail unconditionally.
        self.linear_1 = t.nn.Linear(input_size, intermediate_size)
        self.linear_2 = t.nn.Linear(intermediate_size, input_size)

    def forward(self, input):
        """Apply linear_1, the GELU inside bert_mlp, then linear_2."""
        return bert_mlp(input, self.linear_1, self.linear_2)



EinopsError:  Error while processing sum-reduction pattern "...d -> ...".
 Input tensor shape: torch.Size([3, 5, 6]). Additional info: {}.
 Invalid axis identifier: …d
not a valid python identifier

In [74]:
class LayerNorm(t.nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        normalized_dim: size of the last dimension of the inputs.
        eps: added to the variance before the square root so constant rows
            normalise to 0 instead of NaN (same default as torch.nn.LayerNorm).
    """

    def __init__(self, normalized_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = t.nn.Parameter(t.ones((normalized_dim,)))
        self.bias = t.nn.Parameter(t.zeros((normalized_dim,)))

    def forward(self, input):
        """Normalise `input` along its last dim, then apply weight and bias.

        Gradients flow through the mean and variance, as in standard LayerNorm.
        """
        mean = input.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching the LayerNorm definition.
        var = input.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (input - mean) / t.sqrt(var + self.eps)
        return normalized * self.weight + self.bias
# Run the bert_tests check on the LayerNorm class (prints MATCH on success).
bert_tests.test_layer_norm(LayerNorm)



layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: 4.768e-09 STD: 1.003 VALS [-0.94 -1.641 -0.1301 -0.3103 1.493 -0.2086 -0.1952 -0.2518 1.973 0.2104...]


In [None]:
class BertBlock(t.nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_heads, dropout):
        