In [1]:
import time
import torch
import torch.nn as nn
import numpy as np
import random
from torch import optim
import matplotlib.pyplot as plt
from typing import List
import math

In [31]:
class AttentionHead(nn.Module):
    """Single-head attention followed by a feed-forward sub-layer with residual + LayerNorm.

    Inputs are sequence-first, ``[seq_len, batch, d_model]``: ``forward`` swaps axes 0 and 1
    to get batch-first tensors for the matmuls and swaps them back before returning.
    Query length and key/value length may differ (cross-attention), but key and value must match.
    """

    def __init__(self, d_model=512, d_internal=64, dropout=0.1):
        """
        :param d_model: The dimension of the inputs and outputs of the layer (note that the inputs and outputs
        have to be the same size for the residual connection to work)
        :param d_internal: The "internal" dimension used in the self-attention computation. Your keys and queries
        should both be of this length. Also used as the hidden width of the feed-forward sub-layer.
        :param dropout: dropout probability applied inside the feed-forward sub-layer.
        """
        super().__init__()
        self.d_model = d_model
        self.d_internal = d_internal
        # Queries/keys are projected to d_internal; values stay at d_model so the attention
        # output can be added straight back onto the query (residual connection).
        self.query = nn.Linear(d_model, d_internal)
        self.key = nn.Linear(d_model, d_internal)
        self.value = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(d_model, d_model)
        self.relu = nn.ReLU()

        # Feed-forward sub-layer: d_model -> d_internal -> d_model, with its own residual + LayerNorm.
        self.linear2 = nn.Linear(d_model, d_internal)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.linear3 = nn.Linear(d_internal, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.layernorm3 = nn.LayerNorm(d_model)

    def forward(self, query, key, value):
        """
        :param query: tensor of shape [n_query, batch, d_model]
        :param key: tensor of shape [n_key, batch, d_model]
        :param value: tensor of shape [n_key, batch, d_model]
        :return: a tensor of shape [n_query, batch, d_model] (same shape as ``query``)
        """
        q = self.query(query).permute(1, 0, 2)  # batch, n_query, d_internal
        k = self.key(key).permute(1, 0, 2)  # batch, n_key, d_internal
        v = self.value(value).permute(1, 0, 2)  # batch, n_key, d_model

        # All ops below are out-of-place on purpose: autograd's softmax backward reuses the
        # softmax *output*, so dividing `probs` in place made every .backward() call fail with
        # "one of the variables needed for gradient computation has been modified by an inplace operation".
        scores = torch.matmul(q, k.transpose(1, 2)) / self.d_internal ** 0.5  # batch, n_query, n_key
        probs = self.softmax(scores)
        # NOTE(review): after the softmax over keys, weights are additionally renormalised across
        # the query axis (dim=1), so rows no longer sum to 1. This resembles slot-attention-style
        # weighting rather than standard attention — kept unchanged, but confirm it is intended.
        probs = probs / (1e-9 + probs.sum(dim=1, keepdim=True))
        attended = torch.matmul(probs, v).permute(1, 0, 2)  # n_query, batch, d_model

        # Residual onto the query, then a projection + ReLU.
        hidden = self.relu(self.linear(attended + query))

        ff = self.relu2(self.linear2(hidden))
        ff = self.dropout2(ff)
        ff = self.linear3(ff)
        ff = self.dropout3(ff)
        return self.layernorm3(ff + hidden)

In [32]:
torch.manual_seed(0)  # fix weight initialisation so the demo output below is reproducible
a = AttentionHead()

In [None]:
class MultiheadAttention(nn.Module):
    """A stack of ``num_layers`` AttentionHead layers applied after a positional encoding.

    Despite the name, this is not multi-head attention: each layer is a single-head
    AttentionHead, and the layers run one after another as self-attention (q = k = v).
    """

    def __init__(self, num_positions=3, num_layers=8, d_model=512, d_internal=64):
        """
        :param num_positions: max sequence length; currently unused (kept for interface compatibility)
        :param num_layers: number of stacked AttentionHead layers
        :param d_model: see AttentionHead
        :param d_internal: see AttentionHead
        """
        super().__init__()
        self.num_layers = num_layers
        # NOTE(review): PositionalEncoding is not defined in the visible notebook cells; it must be
        # defined or imported before this class is instantiated. Assumed signature:
        # PositionalEncoding(d_model), returning a tensor shaped like its input — TODO confirm.
        self.positional_encoding = PositionalEncoding(d_model)
        # ModuleList, not Sequential: each layer is called with (query, key, value), but
        # nn.Sequential.forward accepts a single input and raised TypeError with three.
        self.attention_heads = nn.ModuleList(
            [AttentionHead(d_model, d_internal) for _ in range(num_layers)]
        )

    def forward(self, inp):
        """
        :param inp: input tensor, presumably [seq_len, batch, d_model] as AttentionHead expects — TODO confirm
        :return: the positional-encoded input after passing through every layer in order
        """
        hidden = self.positional_encoding(inp)
        for layer in self.attention_heads:
            hidden = layer(hidden, hidden, hidden)
        return hidden

In [33]:
# Self-attention on a constant input: seq_len=2, batch=3, d_model=512 (sequence-first layout).
a(query=torch.ones(2, 3, 512), key=torch.ones(2, 3, 512), value=torch.ones(2, 3, 512))

tensor([[[0.8626, 0.5521, 0.0000,  ..., 0.0000, 0.0000, 0.1414],
         [0.8626, 0.5521, 0.0000,  ..., 0.0000, 0.0000, 0.1414],
         [0.8626, 0.5521, 0.0000,  ..., 0.0000, 0.0000, 0.1414]],

        [[0.8626, 0.5521, 0.0000,  ..., 0.0000, 0.0000, 0.1414],
         [0.8626, 0.5521, 0.0000,  ..., 0.0000, 0.0000, 0.1414],
         [0.8626, 0.5521, 0.0000,  ..., 0.0000, 0.0000, 0.1414]]],
       grad_fn=<ReluBackward0>)

In [14]:
# Scratch check of batched matmul shapes: two independent all-zero tensors of shape (2, 3, 4).
x = torch.zeros(2, 3, 4)
y = torch.zeros_like(x)

In [15]:
# Swap axes 1 and 2: (2, 3, 4) -> (2, 4, 3). Not idempotent — re-running this cell flips y back.
y = torch.transpose(y, 1, 2)

In [19]:
y.size()  # expected: torch.Size([2, 4, 3])

torch.Size([2, 4, 3])

In [17]:
z = x @ y  # batched matmul: (2, 3, 4) @ (2, 4, 3) -> (2, 3, 3)

In [18]:
z.size()  # expected: torch.Size([2, 3, 3])

torch.Size([2, 3, 3])