In [1]:
import time
import torch
import torch.nn as nn
import numpy as np
import random
from torch import optim
import matplotlib.pyplot as plt
from typing import List
import math

In [62]:
class AttentionHead(nn.Module):
    def __init__(self, d_model=512, d_internal=64, dropout=0.1):
        """
        :param d_model: The dimension of the inputs and outputs of the layer (note that the inputs and outputs
        have to be the same size for the residual connection to work)
        :param d_internal: The "internal" dimension used in the self-attention computation. Your keys and queries
        should both be of this length.
        """
        super().__init__()
        self.d_model = d_model
        self.d_internal = d_internal
        self.query = nn.Linear(d_model, d_internal)
        self.key = nn.Linear(d_model, d_internal)
        self.value = nn.Linear(d_model, d_model)
        
        self.softmax = nn.Softmax(dim=-1)
#         self.linear = nn.Linear(d_model, d_model)
#         self.relu = nn.ReLU()
        
#         self.linear2 = nn.Linear(d_model, d_internal)
#         self.relu2 = nn.ReLU()
#         self.dropout2 = nn.Dropout(dropout)
#         self.linear3 = nn.Linear(d_internal, d_model)
#         self.dropout3 = nn.Dropout(dropout)
#         self.layernorm3 = nn.LayerNorm(d_model)
        

    def forward(self, query, key, value):
        """
        :param input_vecs: an input tensor of shape [seq len, d_model]
        :return: a tuple of two elements:
            - a tensor of shape [seq len, d_model] representing the log probabilities of each position in the input
            - a tensor of shape [seq len, seq len], representing the attention map for this layer
        """
        #n_pixels, batch, dim
        q = self.query(query).permute(1, 0, 2) # batch, n_pixels, dim 
        k = self.key(key).permute(1, 0, 2) # batch, m_pixels, dim 
        v = self.value(value).permute(1, 0, 2) # batch, m_pixels, dim 
        q_k = torch.matmul(q, k.transpose(1,2)) # batch, n_pixels, m_pixels
        q_k /= self.d_internal**0.5
        probs = self.softmax(q_k)
        probs /= (1e-9 + probs.sum(dim=1, keepdim=True))
        aten_scores = torch.matmul(probs, v).permute(1,0,2)
#         res_con = aten_scores + query
#         aten_weights = self.linear(res_con)

#         aten_weights = self.relu(aten_weights)
#         aten_weights2 = self.linear2(aten_weights)
#         aten_weights2 = self.relu2(aten_weights2)
#         aten_weights2 = self.dropout2(aten_weights2)
#         aten_weights2 = self.linear3(aten_weights2)
#         aten_weights2 = self.dropout3(aten_weights2)
#         aten_weights = aten_weights2 + aten_weights
#         aten_weights = self.layernorm3(aten_weights)

        return aten_scores

In [63]:
a = AttentionHead(512,64)

In [64]:
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_internal=64, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(d_model, d_internal)
        self.linear2 = nn.Linear(d_internal, d_model)
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.instancenorm = nn.InstanceNorm1d(d_model)
        self.relu2 = nn.ReLU()
        
    def forward(self, input):
        out = self.linear(input)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.linear2(out)
        out = self.dropout2(out)
        out = out + input
        
        out = self.instancenorm(out.permute(1, 2, 0))
        out = self.relu2(out.permute(2, 0, 1))
        return out
        


class MultiheadAttention(nn.Module):
    def __init__(self, num_layers=1, d_model=512, d_internal=64):
        """
        :param vocab_size: vocabulary size of the embedding layer
        :param num_positions: max sequence length that will be fed to the model; should be 20
        :param d_model: see TransformerLayer
        :param d_internal: see TransformerLayer
        :param num_classes: number of classes predicted at the output layer; should be 3
        :param num_layers: number of TransformerLayers to use; can be whatever you want
        """
        super().__init__()
        self.num_layers = num_layers
        self.attention_heads = nn.ModuleList()
        for i in range(num_layers):
            self.attention_heads.append(AttentionHead(d_model, d_internal))
                     
        self.linear = nn.Linear(d_model, d_model)
        self.relu = nn.ReLU()
        self.layernorm = nn.LayerNorm(d_model)
        
        self.feedforward = FeedForward(d_model, d_internal)

    def forward(self, q, k, v):
        """

        :param indices: list of input indices
        :return: A tuple of the softmax log probabilities (should be a 20x3 matrix) and a list of the attention
        maps you use in your layers (can be variable length, but each should be a 20x20 matrix)
        """
        # inp = self.emb(indices)
        aten_scores = None
        first_layer = True
        for attention_head in self.attention_heads:
            if first_layer:
                aten_scores = attention_head(q, k, v)
                concat = aten_scores
                first_layer = False
            else:
                aten_scores = attention_head(q,k,v)
                concat = torch.cat((concat, aten_scores), -1)

        out = self.linear(concat)
        out = self.relu(out)
        out = out + q
        out = self.layernorm(out)
        
        out = self.feedforward(out)

        return out

In [65]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int=256, num_positions: int=3, batched=False):
        """
        :param d_model: dimensionality of the embedding layer to your model; since the position encodings are being
        added to character encodings, these need to match (and will match the dimension of the subsequent Transformer
        layer inputs/outputs)
        :param num_positions: the number of positions that need to be encoded; the maximum sequence length this
        module will see
        :param batched: True if you are using batching, False otherwise
        """
        super().__init__()
        self.conv = nn.Conv1d(num_positions, d_model, kernel_size=1)
        self.batchnorm = nn.BatchNorm1d(d_model)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)

    def forward(self, x):
        """
        :param x: If using batching, should be [batch size, seq len, embedding dim]. Otherwise, [seq len, embedding dim]
        :return: a tensor of the same size with positional embeddings added in
        """
        # B, n_pixels, 3
        x = x.transpose(1,2).contiguous()
        out = self.conv(x)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out

In [66]:
p = PositionalEncoding()

In [67]:
p(torch.ones(1,10,3))

tensor([[[ 0.0410,  0.0410,  0.0410,  ...,  0.0410,  0.0410,  0.0410],
         [-0.0038, -0.0038, -0.0038,  ..., -0.0038, -0.0038, -0.0038],
         [ 0.0067,  0.0067,  0.0067,  ...,  0.0067,  0.0067,  0.0067],
         ...,
         [-0.0259, -0.0259, -0.0259,  ..., -0.0259, -0.0259, -0.0259],
         [-0.0021, -0.0021, -0.0021,  ..., -0.0021, -0.0021, -0.0021],
         [ 0.0542,  0.0542,  0.0542,  ...,  0.0542,  0.0542,  0.0542]]],
       grad_fn=<SqueezeBackward1>)

In [68]:
b = MultiheadAttention(4)

In [69]:
b(torch.ones(1,10,512))

TypeError: forward() missing 2 required positional arguments: 'k' and 'v'

In [70]:
z=torch.ones(2,3,512)
a(z,z,z)

tensor([[[ 0.2936,  1.0847,  0.2158,  ..., -0.0325,  0.5631,  0.3757],
         [ 0.2936,  1.0847,  0.2158,  ..., -0.0325,  0.5631,  0.3757],
         [ 0.2936,  1.0847,  0.2158,  ..., -0.0325,  0.5631,  0.3757]],

        [[ 0.2936,  1.0847,  0.2158,  ..., -0.0325,  0.5631,  0.3757],
         [ 0.2936,  1.0847,  0.2158,  ..., -0.0325,  0.5631,  0.3757],
         [ 0.2936,  1.0847,  0.2158,  ..., -0.0325,  0.5631,  0.3757]]],
       grad_fn=<PermuteBackward>)

In [None]:
x = torch.zeros((2,3,4))
y = torch.zeros((2,3,4))

In [None]:
y=y.transpose(1,2)

In [None]:
y.shape

In [None]:
z=torch.matmul(x, y)

In [None]:
z.shape

In [79]:
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, d_internal=128, n_heads=1, n_layers=1):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(n_heads=n_heads, d_model=d_model,d_internal=d_internal)
        self.encoder_layers = nn.ModuleList([encoder_layer for i in range(n_layers)])
        
    def forward(self, inp):
        intermediate = inp
        for encoder in self.encoder_layers:
            intermediate = encoder(intermediate)
        return intermediate
        
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=256, d_internal=128, n_heads=1):
        super().__init__()
        self.self_attention = MultiheadAttention(num_layers=n_heads, d_internal=d_internal,d_model=d_model)
        
        
    def forward(self,inp):
        return self.self_attention(inp, inp, inp)
        
        
class TransformerDecoder(nn.Module):
    def __init__(self, d_model=256, d_internal=128, n_heads=1, n_layers=1):
        super().__init__()
        decoder_layer = TransformerDecoderLayer(n_heads=n_heads, d_model=d_model,d_internal=d_internal)
        self.decoder_layers = nn.ModuleList([decoder_layer for i in range(n_layers)])
        
    def forward(self, inp_decoder, kv_encoder):
        intermediate = inp_decoder
        for decoder in self.decoder_layers:
            intermediate = decoder(intermediate, kv_encoder)
        return intermediate
    
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_internal=128, n_heads=1):
        super().__init__()
        self.self_attention = MultiheadAttention(num_layers=n_heads, d_internal=d_internal,d_model=d_model)
        self.cross_attention = MultiheadAttention(num_layers=n_heads, d_internal=d_internal,d_model=d_model)
        
        
    def forward(self, inp, kv_encoder):
        out = self.self_attention(inp, inp, inp)
        res = self.cross_attention(out, kv_encoder, kv_encoder)
        return res
        

In [92]:
class TransformerFusion(nn.Module):
    def __init__(self, num_layers_encoder = 1, num_layers_decoder=1, d_model=256, d_internal=128, n_heads=1):
        super().__init__()
        self.transformer_encoder = TransformerEncoder(d_model=d_model, d_internal=d_internal, n_heads=n_heads, n_layers=num_layers_encoder)
        self.transformer_decoder = TransformerDecoder(d_model=d_model, d_internal=d_internal, n_heads=n_heads, n_layers=num_layers_decoder)
        self.pos_emb_enc = PositionalEncoding(d_model=d_model, num_positions=3)
        self.pos_emb_dec = PositionalEncoding(d_model=d_model, num_positions=3)
        # self.feature_convs = (pt_utils.Seq(256)
        #         .conv1d(256, bn=True)
        #         .conv1d(256, activation=None))
        
    def forward(self, search_feature, search_xyz, template_feature, template_xyz):
        search_feature = search_feature.permute(2,0,1) + self.pos_emb_dec(search_xyz).permute(2, 0, 1)
        template_feature = template_feature.permute(2,0,1) + self.pos_emb_enc(template_xyz).permute(2,0,1)
        kv_encoder = self.transformer_encoder(template_feature)
        out = self.transformer_decoder(search_feature, kv_encoder)
        out = out.permute(1,2,0)
        # out = self.feature_convs(out)
        return out

In [93]:
t = TransformerFusion()

In [95]:
# bcn
t(torch.ones(2,256,10),torch.ones(2,10,3), torch.ones(2,256,10), torch.ones(2,10,3))

tensor([[[1.0554, 0.4846, 0.0000,  ..., 0.0000, 0.0000, 1.2612],
         [0.0000, 0.0000, 0.0000,  ..., 0.0653, 0.0000, 0.0000],
         [0.4004, 2.3250, 0.0000,  ..., 0.0000, 0.2690, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.7656],
         [0.1326, 0.0000, 1.5889,  ..., 0.0000, 1.5916, 0.0000],
         [2.0896, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 1.2883, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.6922, 0.0000, 0.0000,  ..., 1.8916, 1.3286, 0.6990],
         [0.0000, 1.4544, 0.9785,  ..., 0.3826, 0.0000, 1.0939],
         ...,
         [0.6329, 0.9107, 0.3306,  ..., 0.0000, 0.0000, 0.0000],
         [2.5605, 0.0000, 0.7686,  ..., 0.0000, 0.0000, 0.3391],
         [0.0000, 0.1047, 0.0000,  ..., 0.0000, 0.8144, 0.0000]]],
       grad_fn=<PermuteBackward>)