In [None]:
## Q2.1.2/3
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Scores are ``query @ key^T / sqrt(d_k)`` where ``d_k`` is the size of the
    last dimension of ``query``; they are normalised with a softmax over the
    key axis and used to take a weighted sum of ``value``.

    Args:
        query: Tensor of shape ``(..., q_len, d_k)``.
        key: Tensor of shape ``(..., k_len, d_k)``.
        value: Tensor of shape ``(..., k_len, d_v)``.
        mask: Optional tensor broadcastable to the score shape
            ``(..., q_len, k_len)``. Positions where ``mask == 0`` are
            excluded from attention.
        dropout: Optional callable (e.g. an ``nn.Dropout`` module) applied to
            the attention weights.

    Returns:
        A tuple ``(out, attn)``: the attended values of shape
        ``(..., q_len, d_v)`` and the attention weights of shape
        ``(..., q_len, k_len)`` (after dropout, when dropout is given).
    """
    d_k = query.size(-1)

    # Scale by sqrt(d_k) so the logits do not grow with the head dimension.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Fill with the most negative *finite* value instead of -inf: a row
        # whose keys are all masked would otherwise be softmax(-inf, ..., -inf)
        # = NaN, which then poisons the output and gradients. With a finite
        # fill, masked weights are still exactly 0 whenever a row has at least
        # one unmasked key, and a fully masked row becomes uniform.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)

    # Normalise over the key axis.
    attn = torch.softmax(scores, dim=-1)

    if dropout is not None:
        attn = dropout(attn)

    # Weighted sum of the values.
    out = torch.matmul(attn, value)

    return out, attn

In [None]:
## Q2.1.4
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on the module-level ``attention`` function.

    The inputs are linearly projected, split into ``h`` heads of size
    ``d_model // h``, attended independently, concatenated back to
    ``d_model`` features, and passed through a final linear layer.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Create the projection layers.

        Args:
            h: Number of attention heads; must divide ``d_model`` evenly.
            d_model: Feature size of the inputs and of the output.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        # Each head gets an equal slice of the model dimension.
        assert d_model % h == 0, "d_model must be divisible by h"

        # Number of attention heads.
        self.h = h
        # Feature size of each head.
        self.d_k = d_model // h

        # Projections for query, key and value (all heads computed at once).
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)

        # Output projection that mixes the concatenated heads.
        self.linear_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # Attention weights from the most recent forward pass (None before
        # the first call); kept for inspection.
        self.attn = None

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, q_len, d_model)``.
            key: Tensor of shape ``(batch, k_len, d_model)``.
            value: Tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional mask; positions equal to 0 are not attended to.
                A 3-D mask such as ``(batch, 1, k_len)`` or
                ``(batch, q_len, k_len)`` is shared across all heads. Masks
                of any other rank are passed through unchanged and must
                broadcast against ``(batch, h, q_len, k_len)``.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)

        if mask is not None and mask.dim() == 3:
            # The scores are (batch, h, q_len, k_len). Without a head axis a
            # per-batch mask would line its batch dimension up with the head
            # dimension, which either fails to broadcast or masks the wrong
            # entries. Insert the axis so one mask applies to every head.
            mask = mask.unsqueeze(1)

        # Project, then reshape (batch, len, d_model) -> (batch, h, len, d_k).
        query = self.linear_q(query).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        key = self.linear_k(key).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        value = self.linear_v(value).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        # Attend within each head; keep the weights for later inspection.
        out, self.attn = attention(query, key, value, mask, self.dropout)

        # Concatenate heads: (batch, h, len, d_k) -> (batch, len, h * d_k).
        # transpose() yields a non-contiguous tensor, so view() needs
        # contiguous() first.
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        # Final projection that mixes information across heads.
        out = self.linear_out(out)

        return out