In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [181]:
class RefactorConfig:
    """Hyper-parameters read by ``MultiHeadAttention`` in this notebook."""

    # Number of attention heads and the feature width of each head.
    num_heads = 12
    head_size = 64
    # Model width; equal to the concatenated width of all heads (12 * 64 = 768).
    hidden_size = num_heads * head_size


# The class object itself (not an instance) is passed around as the config.
config = RefactorConfig

In [182]:
class MultiHeadAttention(nn.Module):
    """Fused query/key/value projection split into attention heads.

    A single linear layer produces Q, K and V in one matmul; ``forward``
    then reshapes each of them into per-head tensors. No attention scores
    are computed here.

    Args:
        config: object exposing ``hidden_size`` (input feature width),
            ``head_size`` (features per head) and ``num_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.head_size = config.head_size
        self.num_heads = config.num_heads
        # Width of each of Q, K and V after projection. Sizing the layer by
        # this value (instead of hidden_size) keeps the projection, the split
        # and the per-head reshape consistent even when
        # hidden_size != num_heads * head_size.
        self.all_head_size = self.num_heads * self.head_size

        self.qkv = nn.Linear(self.hidden_size, self.all_head_size * 3)

    def split_heads(self, x, k=False):
        """Reshape ``(batch, seq, num_heads * head_size)`` into heads.

        Args:
            x: tensor whose trailing dimension is ``num_heads * head_size``.
            k: when True, return the key layout with the last two axes
                swapped so it can be right-multiplied by the query directly.

        Returns:
            ``(batch, head, seq_length, head_features)`` by default, or
            ``(batch, head, head_features, seq_length)`` when ``k`` is True.
        """
        x = x.view(x.shape[0], -1, self.num_heads, self.head_size)
        if k:
            # (batch, head, head_features, seq_length)
            return x.permute(0, 2, 3, 1)
        else:
            # (batch, head, seq_length, head_features)
            return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, mask=None):
        """Project ``hidden_states`` and return per-head ``(query, key, value)``.

        Args:
            hidden_states: tensor of shape ``(batch, seq, hidden_size)``.
            mask: accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(query, key, value)``; ``key`` is returned in the
            transposed layout produced by ``split_heads(..., k=True)``.
        """
        # compute mixed qkv
        mixed_qkv = self.qkv(hidden_states)
        # Each chunk is all_head_size wide, matching what split_heads expects.
        query, key, value = mixed_qkv.split(self.all_head_size, dim=2)
        # q, k, v
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        return query, key, value


In [177]:
class BertSelfAttention(nn.Module):
    """BERT-style Q/K/V projection with three separate linear layers.

    ``forward`` stops after the per-head reshape and returns the projected
    query, key and value tensors; no attention scores are computed.

    Args:
        hidden_size: width of the incoming hidden states.
        num_attention_heads: number of attention heads.
        attention_head_size: feature width of each head.

    The defaults (768 / 12 / 64) reproduce the previously hard-coded sizes.
    """

    def __init__(self, hidden_size=768, num_attention_heads=12,
                 attention_head_size=64):
        super(BertSelfAttention, self).__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Creation order (query, key, value) is kept so random initialisation
        # draws from the RNG in the same sequence as before.
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

    def transpose_for_scores(self, x):
        """Reshape ``(..., all_head_size)`` to ``(batch, head, seq, head_size)``.

        The trailing dimension is split into
        ``(num_attention_heads, attention_head_size)`` and the head axis is
        moved in front of the sequence axis.
        """
        new_x_shape = x.size(
        )[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        """Return per-head ``(query, key, value)`` for ``hidden_states``.

        Args:
            hidden_states: tensor of shape ``(batch, seq, hidden_size)``.
            attention_mask: accepted but not used by this method.
            head_mask: accepted but not used by this method.

        Returns:
            Three tensors, each of shape ``(batch, head, seq, head_size)``.
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        return query_layer, key_layer, value_layer
#         # Take the dot product between "query" and "key" to get the raw attention scores.
#         attention_scores = torch.matmul(
#             query_layer, key_layer.transpose(-1, -2)
#         )
#         attention_scores = attention_scores / math.sqrt(
#             self.attention_head_size
#         )
#         if attention_mask is not None:
#             # Apply the attention mask (precomputed for all layers in BertModel forward() function)
#             attention_scores = attention_scores + attention_mask

#         # Normalize the attention scores to probabilities.
#         attention_probs = nn.Softmax(dim=-1)(attention_scores)

#         # This is actually dropping out entire tokens to attend to, which might
#         # seem a bit unusual, but is taken from the original Bert paper.
#         attention_probs = self.dropout(attention_probs)

#         context_layer = torch.matmul(attention_probs, value_layer)

#         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
#         new_context_layer_shape = context_layer.size(
#         )[:-2] + (self.all_head_size, )
#         context_layer = context_layer.view(*new_context_layer_shape)

#         outputs = (context_layer, attention_probs
#                   ) if self.output_attentions else (context_layer, )
#         return outputs


In [178]:
class Conv1D(nn.Module):
    """Affine layer whose weight is stored as ``(nx, nf)``, i.e. (in, out).

    Computes ``x @ weight + bias`` over the last dimension of ``x``.

    Args:
        nf: number of output features.
        nx: number of input features.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """Apply the affine map to the trailing dimension of ``x``.

        Args:
            x: tensor of shape ``(..., nx)``.

        Returns:
            Tensor of shape ``(..., nf)``.
        """
        size_out = x.size()[:-1] + (self.nf, )
        # reshape (rather than view) so non-contiguous inputs, e.g. the
        # result of a transpose/permute, are flattened instead of raising.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x

class GPTSelfAttention(nn.Module):
    """GPT-2 style self-attention built on a fused ``Conv1D`` QKV projection.

    This version is modified to support batch training; the mask needs to be
    precomputed by the caller.

    By default ``forward`` stops after the projection and returns the
    per-head ``(query, key, value)`` tensors. Pass ``return_qkv=False`` to
    run the full attention path (scores, softmax, merge, output projection).

    Args:
        nx: width of the input hidden states; also the width of Q, K and V.
        n_ctx: kept for interface compatibility; not read by this class.
        scale: stored as ``self.scale``.
        n_head: number of attention heads (default 12).
        attn_pdrop: dropout probability on the attention weights
            (default 0.0, i.e. no dropout).
        resid_pdrop: dropout probability on the projected output
            (default 0.0, i.e. no dropout).

    Raises:
        ValueError: if ``nx`` is not divisible by ``n_head``.
    """

    def __init__(self, nx, n_ctx, scale=False, n_head=12, attn_pdrop=0.0,
                 resid_pdrop=0.0):
        super().__init__()
        if nx % n_head != 0:
            raise ValueError(
                "nx (%d) must be divisible by n_head (%d)" % (nx, n_head))
        n_state = nx

        self.n_head = n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        # These were referenced by _attn/forward but never created.
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def _attn(self, q, k, v, mask):
        """Scaled dot-product attention over pre-split heads.

        Args:
            q: ``(batch, head, seq_length, head_features)``.
            k: ``(batch, head, head_features, key_length)`` (already
                transposed by ``split_heads(..., k=True)``).
            v: ``(batch, head, key_length, head_features)``.
            mask: boolean tensor broadcastable to the score matrix; positions
                where it is False are filled with -1e4 before the softmax.
                ``None`` disables masking.

        Returns:
            ``(batch, head, seq_length, head_features)`` context tensor.
        """
        w = torch.matmul(q, k)
        # NOTE(review): scaling is applied unconditionally; self.scale is
        # stored but not consulted here - confirm which behaviour is wanted.
        w = w / math.sqrt(v.size(-1))
        if mask is not None:
            # Same effect as: w = w * mask - 1e4 * (1 - mask)
            w.masked_fill_(~mask, -1e4)
        w = F.softmax(w, dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """Inverse of ``split_heads``: ``(batch, head, seq, feat)`` -> ``(batch, seq, head * feat)``."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1), )
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """Split the trailing dimension of ``x`` into ``n_head`` heads.

        Returns ``(batch, head, seq_length, head_features)``, or
        ``(batch, head, head_features, seq_length)`` when ``k`` is True.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            # (batch, head, head_features, seq_length)
            return x.permute(0, 2, 3, 1)
        else:
            # (batch, head, seq_length, head_features)
            return x.permute(0, 2, 1, 3)

    def forward(self, x, layer_past=None, mask=None, return_qkv=True):
        """Project ``x`` and optionally run full self-attention.

        Args:
            x: hidden states of shape ``(batch, seq, nx)``.
            layer_past: optional pair ``(past_key, past_value)`` indexed as
                ``layer_past[0]`` / ``layer_past[1]``, in the layout of the
                ``present`` tensor returned by a previous call. Only used
                when ``return_qkv`` is False.
            mask: boolean attention mask forwarded to ``_attn``. Only used
                when ``return_qkv`` is False.
            return_qkv: when True (default) return the per-head projections
                only, which is what the comparison cells in this notebook
                consume.

        Returns:
            ``(query, key, value)`` when ``return_qkv`` is True, otherwise
            ``(output, present)`` where ``output`` is ``(batch, seq, nx)``
            and ``present`` stacks the (un-transposed) keys and values.
        """
        x = self.c_attn(x)

        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        if return_qkv:
            return query, key, value

        if layer_past is not None:
            # transpose back cf below
            past_key, past_value = layer_past[0].transpose(-2,
                                                           -1), layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        # transpose to have same shapes for stacking
        present = torch.stack((key.transpose(-2, -1), value))

        a = self._attn(query, key, value, mask)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a, present


In [179]:
# Seed the global RNG so the random input (and every module initialised
# after this cell) is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
# Dummy hidden states: (batch=1, seq=10, hidden=768).
x = torch.randn(1, 10, 768)

In [183]:
# The three implementations under comparison: the BERT reference, the GPT
# reference, and the refactored fused-QKV module.
bert_attention = BertSelfAttention()
# nx is the hidden width; the n_ctx argument is not read by GPTSelfAttention.
gpt_attention = GPTSelfAttention(nx=768, n_ctx=12)
attention = MultiHeadAttention(config=config)

In [184]:
# Reference projections from the BERT module.
q, k, v = bert_attention(hidden_states=x)

In [185]:
# Stack BERT's Q, K, V weights row-wise into one (3 * out, in) matrix, then
# swap the axes because Conv1D keeps its weight as (in, out).
gpt_attention.c_attn.weight.data = torch.cat(
    [
        layer.weight.data
        for layer in (bert_attention.query, bert_attention.key, bert_attention.value)
    ],
    dim=0,
).transpose(0, 1)

In [186]:
# Concatenate BERT's Q, K, V biases into the fused bias. The biases are 1-D,
# so no transpose is needed (the previous trailing .t() was a no-op).
gpt_attention.c_attn.bias.data = torch.cat(
    [
        layer.bias.data
        for layer in (bert_attention.query, bert_attention.key, bert_attention.value)
    ],
    dim=0,
)

In [187]:
# Load the refactored module's fused layer with BERT's Q, K, V parameters.
# attention.qkv is an nn.Linear, so the row-wise stack is used as-is.
attention.qkv.weight.data = torch.cat(
    [
        layer.weight.data
        for layer in (bert_attention.query, bert_attention.key, bert_attention.value)
    ],
    dim=0,
)

attention.qkv.bias.data = torch.cat(
    [
        layer.bias.data
        for layer in (bert_attention.query, bert_attention.key, bert_attention.value)
    ],
    dim=0,
)

In [188]:
# Projections from the GPT module, now carrying BERT's weights.
q_g, k_g, v_g = gpt_attention(x, layer_past=None, mask=None)

In [195]:
# Checksum of the GPT keys; compared by eye with the BERT keys' checksum.
torch.sum(k_g)

tensor(-49.4289, grad_fn=<SumBackward0>)

In [197]:
# Checksum of the BERT reference keys.
torch.sum(k)

tensor(-49.4289, grad_fn=<SumBackward0>)

In [198]:
# Projections from the refactored module, now carrying BERT's weights.
q_a, k_a, v_a = attention(hidden_states=x, mask=None)

In [199]:
# Checksum of the refactored module's keys.
torch.sum(k_a)

tensor(-49.4289, grad_fn=<SumBackward0>)