In [1]:
import torch
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import collections
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import pandas as pd

2023-07-31 10:51:30.341150: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
msa = pd.read_csv('SH3_Full_Dataset_8_9_22.csv')
msa['Type'].unique()
naturals_msa = msa[msa['Type']=='Naturals']
seqs = np.asarray([list(seq) for seq in naturals_msa['Sequences']])
norm_re = np.asarray([re for re in naturals_msa['Norm_RE']])

In [3]:
default_aa_keys='-GALMFWKQESPVICYHRNDT'
def fasta_to_df(fasta_file, aa_keys = default_aa_keys):
    """
    creates one hot encoding of a fasta file using biopython's alignio.read process. 
    fasta_file : filepath leading to msa file in fasta format at hand
    """
    column_names = []
    column_names.extend(aa_keys)
    msa=AlignIO.read(fasta_file, "fasta")
    num_columns = len(msa[0].seq)
    column_names = column_names*num_columns
    column_names.append('sequence')
    column_names.append('id')
    init = np.zeros((len(msa), len(column_names)))
    df = pd.DataFrame(init, columns = column_names)
    df.sequence = df.sequence.astype(str)
    df.id=df.id.astype(str)
    
    for row_num, alignment in tqdm(enumerate(msa)):
        sequence = str(alignment.seq)
        for index, char in enumerate(sequence):
            place = aa_keys.find(char)
            df.iloc[row_num, index*len(aa_keys) + place] = 1
        
        df.iloc[row_num,-2]=str(alignment.seq)
        df.iloc[row_num,-1]=str(alignment.id)
    
    return df

In [4]:
def create_frequency_matrix(df, aa_keys = default_aa_keys):
    """takes one hot encoded msa and returns the frequency of each amino acid at each site
    df : pandas dataframe whose columns are the one hot encoding of an msa
    """
    num_columns=len(df['sequence'][0])
    
    frequency_matrix = np.zeros( (len(aa_keys) , num_columns) )
    print('calcing sum')
    freq=df.sum()
    print('sum calced')
    
    num_entries=len(df)
    len_aa_keys = len(aa_keys)
    
    for i in tqdm(range(len(aa_keys))):
        for j in range(num_columns):
            frequency_matrix[i, j] = freq[ i + len_aa_keys * j] / num_entries
    
    return frequency_matrix

In [5]:
from Bio import AlignIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm
vae_alignment = []
phenotypes = []

vae_data = msa[msa['Type']=='VAE'].reset_index()

for r in range(len(vae_data)):
    alignment = vae_data.loc[r]
    if len(alignment['Sequences'])==62:
        record = SeqRecord(seq = Seq(alignment['Sequences']), id = alignment['Header'])
    
    vae_alignment.append(record)
    phenotypes.append(alignment['Norm_RE'])

vae_alignment = AlignIO.MultipleSeqAlignment(vae_alignment)

AlignIO.write(vae_alignment, 'vae_alignment.fasta', 'fasta')

vae_df = fasta_to_df('vae_alignment.fasta')

freq_matrix = create_frequency_matrix(vae_df)

trim_positions = []

for i in range(freq_matrix.shape[1]):
    if 1 in freq_matrix[:,i]:
        trim_positions.append(i)

print(trim_positions)


vae_alignment_trimmed = []


for alignment in vae_alignment:
    new_seq = ''
    for i in range(62):
        if i not in trim_positions:
            new_seq+=alignment.seq[i]
    re_alignment = SeqRecord(seq=Seq(new_seq), id = alignment.id)
    vae_alignment_trimmed.append(re_alignment)

vae_alignment_trimmed = AlignIO.MultipleSeqAlignment(vae_alignment_trimmed)

AlignIO.write(vae_alignment_trimmed, 'vae_alignment_trimmed.fasta', 'fasta')

test_seqs = np.asarray([list(str(alignment.seq)) for alignment in vae_alignment_trimmed])

phenotypes = np.asarray(phenotypes)

11608it [01:53, 102.48it/s]


calcing sum
sum calced


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 8724.29it/s]


[16, 17, 44]


In [6]:
AMINO_ACIDS = "ARNDCQEGHILKMFPSTWYV-"
IDX_TO_AA = list(AMINO_ACIDS)
AA_TO_IDX = {aa: i for i, aa in enumerate(IDX_TO_AA)}

In [7]:
CNPRegressionDescription = collections.namedtuple(
    "CNPRegressionDescription",
    ("query", "target_y", "num_total_points", "num_context_points"))

In [12]:
seqs

array([['V', 'F', 'L', ..., 'A', 'P', 'V'],
       ['P', 'V', 'I', ..., 'C', 'G', 'Q'],
       ['K', 'A', 'R', ..., 'S', 'K', 'E'],
       ...,
       ['R', 'Q', 'S', ..., 'V', 'G', 'K'],
       ['P', 'A', 'G', ..., 'R', 'D', 'M'],
       ['L', 'D', 'P', ..., 'E', 'T', 'F']], dtype='<U1')

In [74]:
class ProteinDataset(torch.utils.data.Dataset):
    def __init__(self, data, **kwargs):
        super().__init__(**kwargs)
        self.data = data

        self.AMINO_ACIDS = "ARNDCQEGHILKMFPSTWYV-"
        self.IDX_TO_AA = list(AMINO_ACIDS)
        self.AA_TO_IDX = {aa: i for i, aa in enumerate(IDX_TO_AA)}

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):

        X = torch.unsqueeze(torch.arange(start = 0, end = self.data.shape[1]),-1)

        Y = self.data[index]

        one_hot_Y = torch.tensor(self._to_one_hot(Y))
        return X, one_hot_Y

    def _to_one_hot(self, seq):
        one_hot_encoded = np.zeros((seq.shape[0],len(self.IDX_TO_AA)))

        for index, char in enumerate(seq):
            one_hot_encoded[index, self.AA_TO_IDX[char]]=1
        return one_hot_encoded

In [237]:
Y = batch[1]

In [244]:
min_num_context = 5
max_num_context = 55

num_context = np.random.randint(min_num_context, max_num_context)
print(num_context)
num_target = np.random.randint(Y.shape[1]-max_num_context, max_num_context - num_context)
print(num_target)

20
23


In [258]:
def collate_fn(batch):
    X = batch[1][0]
    Y = batch[1][1]

    max_num_context = int(0.9 * Y.shape[1])
    min_num_context = int(0.1 * Y.shape[1])

    num_context = np.random.randint(min_num_context, max_num_context)

    context_x = torch.zeros((X.shape[0], num_context, 1))
    context_y = torch.zeros((Y.shape[0], num_context, Y.shape[2]))
    target_x  = torch.zeros((X.shape[0], X.shape[1], 1))
    target_y  = torch.zeros((X.shape[0], Y.shape[1], Y.shape[2]))

    for idx in range(Y.shape[0]):
        total_idx = np.random.choice(range(Y.shape[1]), Y.shape[1], replace=False)
        c_idx = total_idx[:num_context]
        t_idx = total_idx

        context_x[idx] = torch.tensor(c_idx).unsqueeze(-1)
        target_x[idx] = torch.tensor(t_idx).unsqueeze(-1)

        context_y[idx] = Y[idx, c_idx,:]
        target_y[idx] = Y[idx, t_idx,:]

    return context_x, context_y, target_x, target_y

In [259]:
def context_target_splitter(batch, min_context, max_context, len_seq):

    num_context = torch.randint(low=min_context, high=max_context, size=[1])

    X, Y = batch

    X_context = torch.zeros(size=(X.shape[0], num_context, X.shape[-1]))
    Y_context = torch.zeros(size=(Y.shape[0], num_context, Y.shape[-1]))

    
    X_target = torch.zeros(size=(X.shape[0], len_seq - num_context, X.shape[-1]))
    Y_target = torch.zeros(size=(Y.shape[0], len_seq - num_context, Y.shape[-1]))
    
    

    for index in range(Y.shape[0]):
        seq = Y[index]

        
        shuffled_indices = torch.randperm(len_seq)
    
        context_indices = shuffled_indices[:num_context]
        target_indices = shuffled_indices[num_context:]
        
        X_context[index] = context_indices.unsqueeze(-1)
        X_target[index]  = target_indices.unsqueeze(-1)
        
        Y_context[index] = seq[context_indices]
        Y_target[index] = seq[target_indices]

    return (((X_context, Y_context), X_target), Y_target)

In [260]:
(((X_context, Y_context), X_target), Y_target) = context_target_splitter(batch, 3,55,59)

In [263]:
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [264]:
for batch in loader:
    X, Y

In [267]:
Y.shape

torch.Size([25, 59, 21])

In [99]:
min_context = int(0.1 * seqs.shape[1])
max_context = int(0.9 * seqs.shape[1])
len_aa = seqs.shape[1]


for batch in loader:
    (((X_context, Y_context), X_target), Y_target) = context_target_splitter(batch, min_context, max_context, len_aa)

In [100]:
class MLP(torch.nn.Module):
    def __init__(self, 
        input_size,
        output_sizes,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.output_sizes = output_sizes
        
        self.mlp = torch.nn.Sequential()

        self.mlp.add_module('input_layer', torch.nn.Linear(input_size, output_sizes[0]))
        self.mlp.add_module('relu', torch.nn.ReLU())

        for index in range(1, len(output_sizes)):
            self.mlp.add_module('hidden_layer_{}'.format(index), torch.nn.Linear(output_sizes[index-1], output_sizes[index]))
            self.mlp.add_module('relu_{}'.format(index+1), torch.nn.ReLU())


    def forward(self, x):
        assert x.shape[-1] == self.input_size, "Input to MLP not the correct dimension"

        return self.mlp(x)

In [307]:
mlp_layer = MLP(22,[128])

In [308]:
mlp_layer

MLP(
  (mlp): Sequential(
    (input_layer): Linear(in_features=22, out_features=128, bias=True)
    (relu): ReLU()
  )
)

In [101]:
class DotProductAttention(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, queries, keys, values, d_k, mask=None):
        scores = torch.matmul(queries, keys.transpose(-1,-2))/torch.sqrt(d_k)
        if mask is not None:
            scores += -1e9*mask
        attention = torch.nn.Softmax(dim=-1)(scores)
        return torch.matmul(attention, values), attention

In [160]:
12//8

1

In [323]:
class MultiHeadedAttention(torch.nn.Module):
    def __init__(self, 
                 h,
                 d_query, 
                 d_key,
                 d_values,
                 d_hidden,
                 d_model,
                 activation = torch.nn.ReLU,
                 **kwargs
    ):
        super().__init__(**kwargs)
        self.attention = DotProductAttention()
        self.heads = h
        self.d_query = d_query
        self.d_key = d_key
        self.d_values = d_values
        self.d_hidden = d_hidden
        self.W_q = torch.nn.Linear(d_query, d_hidden*h)
        self.W_k = torch.nn.Linear(d_key, d_hidden*h)
        self.W_v = torch.nn.Linear(d_values, d_hidden*h)
        self.W_o = torch.nn.Linear(d_hidden, d_model)    

    def reshape_tensor(self, x, heads, flag):
        if flag:
            x = torch.reshape(x, shape = (x.shape[0], x.shape[1], heads, x.shape[2]//heads))
            x = x.permute(0,2,1,3)
        else:
            x = x.transpose(-1,-2)
            x = torch.reshape(x, shape = (x.shape[0], x.shape[1], x.shape[2]//heads))
        return x

    def forward(self, queries, keys, values, mask = None):
        n_context = queries.shape[1]
        n_target = keys.shape[1]

        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)
        
        output, attention = self.attention(q_reshaped, k_reshaped, v_reshaped, torch.tensor(self.d_hidden), mask)

        return self.W_o(output), output    

In [406]:
mlp_encoder = [128]

heads = 8
input_dim = 22
attention_input_dim = mlp_encoder[-1]
d_hidden = 128

output_mlp_decoder = [128,128,128,21]



encoder = [MLP(input_dim, mlp_encoder),
           MultiHeadedAttention(heads, attention_input_dim, attention_input_dim, attention_input_dim, d_hidden, d_hidden), 
           MultiHeadedAttention(heads, heads*d_hidden, heads*d_hidden, heads*d_hidden, d_hidden, d_hidden)
          ]

decoder = [MultiHeadedAttention(heads, d_hidden, d_hidden, heads*d_hidden, d_hidden, d_hidden),
           MultiHeadedAttention(heads, heads*d_hidden, d_hidden, heads*d_hidden, d_hidden, d_hidden),
           MLP(heads*d_hidden  + 1, output_mlp_decoder)
]



In [407]:
full_ANP = AttentiveNeuralProcess(encoder, decoder)

In [408]:
full_ANP.decoder

[MultiHeadedAttention(
   (attention): DotProductAttention()
   (W_q): Linear(in_features=128, out_features=1024, bias=True)
   (W_k): Linear(in_features=128, out_features=1024, bias=True)
   (W_v): Linear(in_features=1024, out_features=1024, bias=True)
   (W_o): Linear(in_features=128, out_features=128, bias=True)
 ),
 MultiHeadedAttention(
   (attention): DotProductAttention()
   (W_q): Linear(in_features=1024, out_features=1024, bias=True)
   (W_k): Linear(in_features=128, out_features=1024, bias=True)
   (W_v): Linear(in_features=1024, out_features=1024, bias=True)
   (W_o): Linear(in_features=128, out_features=128, bias=True)
 ),
 MLP(
   (mlp): Sequential(
     (input_layer): Linear(in_features=1025, out_features=128, bias=True)
     (relu): ReLU()
     (hidden_layer_1): Linear(in_features=128, out_features=128, bias=True)
     (relu_2): ReLU()
     (hidden_layer_2): Linear(in_features=128, out_features=128, bias=True)
     (relu_3): ReLU()
     (hidden_layer_3): Linear(in_featur

In [409]:
stuff = full_ANP(X_context, Y_context, X_target, Y_target)

got here
torch.Size([25, 7, 128])
torch.Size([25, 52, 128])
torch.Size([25, 52, 1024])
got here
torch.Size([25, 7, 1024])
torch.Size([25, 52, 128])
torch.Size([25, 52, 1024])


In [410]:
stuff[0].shape

torch.Size([25, 7, 21])

In [411]:
stuff[1]

tensor(1.3440, grad_fn=<MeanBackward0>)

In [405]:
class AttentiveNeuralProcess(torch.nn.Module):
    def __init__(self,
                 encoder,
                 decoder,
                 **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.target_projection = MLP(1, [128,128,128,128])
        self.context_projection = MLP(1, [128,128,128,128])

    def forward(self, context_x, context_y, target_x, target_y = None):
        concat_input = torch.concat([context_x, context_y], dim=-1)
        encoder_input = self.encoder[0](concat_input)
        for layer in self.encoder[1:]:
            encoder_input, _ = layer(encoder_input, encoder_input, encoder_input)
            encoder_input = encoder_input.permute((0,2,1,3)).reshape((encoder_input.shape[0], encoder_input.shape[2], encoder_input.shape[1]*encoder_input.shape[3]))

        query = self.target_projection(target_x)
        keys = self.context_projection(context_x)

        for layer in self.decoder[:-1]:
            print('got here')
            print(query.shape)
            print(keys.shape)
            print(encoder_input.shape)
            query , _ = layer(query, keys, encoder_input)
            query = query.permute((0,2,1,3)).reshape((query.shape[0], query.shape[2], query.shape[1]*query.shape[3]))

        concatenated_final_entry = torch.concat([query, target_x], dim=-1)
        output = self.decoder[-1](concatenated_final_entry)

        output = torch.nn.Softmax(dim=-1)(output)

        if target_y is not None:
            loss = self.cross_entropy_loss(output, target_y)
            return output, loss
        else:
            return output
        

        

    def cross_entropy_loss(self, output, target_y):
        assert output.shape == target_y.shape

        cross_entropy = torch.mean(torch.sum(-target_y * torch.log(output + 1e-6) - (1-target_y) * torch.log(1 - output + 1e-6), dim=1))
        return cross_entropy
        
        

In [329]:
(((x_context, y_context), x_target), y_target) = context_target_splitter(batch, min_context, max_context, len_aa)

In [330]:
full_context = torch.concat([x_context, y_context], dim=2)

In [331]:
full_context.shape

torch.Size([25, 44, 22])

In [332]:
full_context.shape

torch.Size([25, 44, 22])

In [333]:
Q, K, V = full_context, full_context, full_context

In [334]:
full_context.shape

torch.Size([25, 44, 22])

In [335]:
d_x = 1
d_y = 21
heads = 8
d_hidden = 128
d_model = 256

mha = MultiHeadedAttention(heads, 22,22,22, d_hidden, d_model)
mlp = MLP(d_x+d_y, [d_hidden,d_hidden,heads*d_model])

In [336]:
output, attention = mha(Q, K, V)

In [337]:
output_reshaped = output.permute((0,2,1,3)).reshape((output.shape[0], output.shape[2], output.shape[1]*output.shape[3]))

In [338]:
output.shape

torch.Size([25, 8, 44, 256])

In [339]:
output_reshaped.shape

torch.Size([25, 44, 2048])

In [340]:
mha_2 = MultiHeadedAttention(heads, output_reshaped.shape[-1], output_reshaped.shape[-1], output_reshaped.shape[-1], d_hidden, d_model)
output_second_layer, attention_2 = mha_2(output_reshaped, output_reshaped, output_reshaped)

In [341]:
output_second_layer_reshaped = output_second_layer.permute((0,2,1,3)).reshape((output_second_layer.shape[0], output_second_layer.shape[2],  output_second_layer.shape[1]* output_second_layer.shape[3])) 

In [342]:
output_second_layer_reshaped.shape

torch.Size([25, 44, 2048])

In [343]:
new_query = x_target

In [344]:
new_keys = x_context
new_values = output_second_layer_reshaped

In [345]:
new_query.shape

torch.Size([25, 15, 1])

In [346]:
mha_cross_1 = MultiHeadedAttention(8, 1, 1, 2048, d_hidden, d_model)
cross_out_1, _ = mha_cross_1(new_query, new_keys, new_values)

In [347]:
cross_out_1_reshaped = cross_out_1.permute((0,2,1,3)).reshape((cross_out_1.shape[0], cross_out_1.shape[2], cross_out_1.shape[1]*cross_out_1.shape[3]))

In [348]:
cross_out_1_reshaped.shape

torch.Size([25, 15, 2048])

In [349]:
concatenated_cross_out_1 = torch.concat([cross_out_1_reshaped, x_target], dim=-1)

In [350]:
concatenated_cross_out_1.shape

torch.Size([25, 15, 2049])

In [351]:
output_reshaped = output.permute((0,2,1,3)).reshape((output.shape[0], output.shape[2], output.shape[1]*output.shape[3]))

In [352]:
output_reshaped.shape

torch.Size([25, 44, 2048])

In [353]:
mha = MultiHeadedAttention(heads, 1,1,2048, d_hidden, d_model)


In [354]:
new_keys.shape

torch.Size([25, 44, 1])

In [355]:
final_output, _ = mha(new_query, new_keys, output_reshaped)

torch.Size([25, 8, 44, 256])

In [209]:
x_target.shape

torch.Size([25, 44, 1])

In [221]:
x_target.shape

torch.Size([25, 44, 1])

NameError: name 'd_values' is not defined

In [123]:
x_target.shape

torch.Size([25, 42, 1])

In [220]:
new_keys = x_context

In [126]:
new_keys.shape

torch.Size([25, 17, 1])

In [149]:
stuff = torch.matmul(new_query, new_keys.transpose(-1,-2))

In [229]:
randn_a = torch.randn((8,10,25))
randn_b = torch.randn((8,25,50))

In [230]:
torch.matmul(randn_a, randn_b) == randn_a.bmm(randn_b)

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [144]:
torch.matmul(torch.matmul(new_query, new_keys.transpose(-1,-2)).unsqueeze(1) * output)

RuntimeError: The size of tensor a (17) must match the size of tensor b (256) at non-singleton dimension 3

In [143]:
output.shape

torch.Size([25, 8, 17, 256])

In [168]:
reshaped_output = output.permute((0,2,1,3))

In [169]:
reshaped_output = reshaped_output.reshape((reshaped_output.shape[0], reshaped_output.shape[1], reshaped_output.shape[2]*reshaped_output.shape[3]))

In [170]:
mlp_output = mlp(full_context)

In [171]:
mlp_output.shape

torch.Size([100, 7, 2048])

In [172]:
reshaped_output.shape

torch.Size([100, 7, 2048])

In [178]:
torch.sum(reshaped_output * mlp_output, dim=1).shape

torch.Size([100, 2048])

In [84]:
output.unsqueeze(-2)

torch.Size([100, 8, 34, 1, 256])

In [92]:
torch.tile(mlped_context.unsqueeze(1).unsqueeze(2).shape

torch.Size([100, 1, 1, 34, 256])

In [93]:
torch.matmul(mlped_context.unsqueeze(1).unsqueeze(2), output.unsqueeze(-2))

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [27200, 256] but got: [27200, 1].

In [61]:
output.shape

torch.Size([100, 8, 34, 256])

In [190]:
test = full_context.unsqueeze(-2).tile((1,1,full_context.shape[1],1))

In [191]:
mha = MultiHeadedAttention(8, 22, 128, 256)

In [192]:
full_context.shape

torch.Size([100, 46, 22])

In [129]:
tiled_x_dim_1 = full_context.permute((0,2,1)).tile((1, full_context.shape[1],1))

In [131]:
tiled_x_dim_1[0].shape

torch.Size([1012, 46])

In [128]:
tensor_x_context[0]

tensor([[51.],
        [19.],
        [44.],
        [12.],
        [10.],
        [31.],
        [ 2.],
        [48.],
        [52.],
        [28.],
        [14.],
        [54.],
        [39.],
        [ 9.],
        [56.],
        [26.],
        [50.],
        [ 4.],
        [34.],
        [55.],
        [40.],
        [46.],
        [36.],
        [38.],
        [58.],
        [17.],
        [ 0.],
        [ 8.],
        [43.],
        [ 3.],
        [ 1.],
        [32.],
        [11.],
        [ 5.],
        [13.],
        [18.],
        [57.],
        [ 7.],
        [20.],
        [45.],
        [35.],
        [33.],
        [47.],
        [23.],
        [22.],
        [42.]])

In [115]:
tiled_x_dim_2 = tensor_x_context.tile((1, 1, tensor_x_context.shape[1]))

In [None]:
mha = MultiHeadedAttention()

In [118]:
tiled_x_dim_1.unsqueeze(-1)

torch.Size([100, 46, 46, 1])

In [120]:
tiled_x_dim_2.unsqueeze(-1).shape

torch.Size([100, 46, 46, 1])

In [88]:
x_context.shape

(100, 46, 1)

In [100]:
x_context

array([[[51.],
        [19.],
        [44.],
        ...,
        [23.],
        [22.],
        [42.]],

       [[44.],
        [50.],
        [26.],
        ...,
        [33.],
        [49.],
        [58.]],

       [[13.],
        [52.],
        [14.],
        ...,
        [19.],
        [53.],
        [17.]],

       ...,

       [[27.],
        [47.],
        [46.],
        ...,
        [34.],
        [20.],
        [19.]],

       [[10.],
        [ 1.],
        [47.],
        ...,
        [51.],
        [30.],
        [37.]],

       [[15.],
        [24.],
        [11.],
        ...,
        [41.],
        [ 6.],
        [32.]]], dtype=float32)

In [89]:
full_context = torch.concat((torch.tensor(x_context), torch.tensor(y_context)), dim=-1)

In [90]:
Q, K, V = full_context, full_context, full_context

In [91]:
Q, K, V

(tensor([[[51.,  0.,  0.,  ...,  0.,  0.,  0.],
          [19.,  0.,  0.,  ...,  0.,  0.,  0.],
          [44.,  0.,  1.,  ...,  0.,  0.,  0.],
          ...,
          [23.,  0.,  0.,  ...,  0.,  1.,  0.],
          [22.,  0.,  0.,  ...,  0.,  0.,  0.],
          [42.,  1.,  0.,  ...,  0.,  0.,  0.]],
 
         [[44.,  0.,  1.,  ...,  0.,  0.,  0.],
          [50.,  1.,  0.,  ...,  0.,  0.,  0.],
          [26.,  0.,  0.,  ...,  0.,  1.,  0.],
          ...,
          [33.,  0.,  0.,  ...,  0.,  0.,  0.],
          [49.,  0.,  0.,  ...,  0.,  0.,  0.],
          [58.,  0.,  0.,  ...,  0.,  0.,  1.]],
 
         [[13.,  1.,  0.,  ...,  0.,  0.,  0.],
          [52.,  0.,  0.,  ...,  1.,  0.,  0.],
          [14.,  0.,  0.,  ...,  0.,  0.,  0.],
          ...,
          [19.,  0.,  0.,  ...,  0.,  0.,  0.],
          [53.,  0.,  0.,  ...,  0.,  1.,  0.],
          [17.,  0.,  0.,  ...,  0.,  0.,  0.]],
 
         ...,
 
         [[27.,  0.,  0.,  ...,  0.,  0.,  0.],
          [47.,  0

In [92]:
mha = MultiHeadedAttention(8, 22, 128, 256)

In [94]:
representation, _ =mha(Q, K, V)

In [95]:
representation.shape

torch.Size([100, 8, 256])

In [96]:
representation_reshaped = torch.reshape(representation, shape = (representation.shape[0], representation.shape[1]*representation.shape[2]))

In [99]:
representation_reshaped.unsqueeze(1).shape

torch.Size([100, 1, 2048])

In [60]:
Q = torch.randn((10,5,4))
K = torch.randn((10,5,4))
V = torch.randn((10,5,4))


In [61]:
mha = MultiHeadedAttention(8, 4, 16, 256)

In [62]:
blah, other_blah = mha(Q, K, V)

torch.Size([10, 8, 5, 16])


In [64]:
blah.shape

torch.Size([10, 8, 256])

In [290]:
Q_unsqueezed_1 = Q.unsqueeze(2)
Q_unsqueezed_2 = Q.unsqueeze(3)


In [312]:
difference = Q_unsqueezed_1 - Q_unsqueezed_2

In [313]:
difference

tensor([[[[ 0.0000, -1.1733, -0.2043, -3.6985],
          [ 1.1733,  0.0000,  0.9690, -2.5252],
          [ 0.2043, -0.9690,  0.0000, -3.4942],
          [ 3.6985,  2.5252,  3.4942,  0.0000]],

         [[ 0.0000,  0.2828, -1.9960,  0.5130],
          [-0.2828,  0.0000, -2.2788,  0.2302],
          [ 1.9960,  2.2788,  0.0000,  2.5090],
          [-0.5130, -0.2302, -2.5090,  0.0000]],

         [[ 0.0000,  0.5370, -0.3218, -2.2304],
          [-0.5370,  0.0000, -0.8589, -2.7674],
          [ 0.3218,  0.8589,  0.0000, -1.9086],
          [ 2.2304,  2.7674,  1.9086,  0.0000]],

         [[ 0.0000, -1.5438, -2.6875,  1.5698],
          [ 1.5438,  0.0000, -1.1437,  3.1136],
          [ 2.6875,  1.1437,  0.0000,  4.2573],
          [-1.5698, -3.1136, -4.2573,  0.0000]],

         [[ 0.0000, -0.3065,  2.0844, -0.5693],
          [ 0.3065,  0.0000,  2.3909, -0.2628],
          [-2.0844, -2.3909,  0.0000, -2.6537],
          [ 0.5693,  0.2628,  2.6537,  0.0000]]],


        [[[ 0.0000, -1.0647,

In [318]:
Q_unsqueezed_1[0,1]

tensor([[ 0.9603,  1.2431, -1.0357,  1.4733]])

In [317]:
Q_unsqueezed_2[0,1]

tensor([[ 0.9603],
        [ 1.2431],
        [-1.0357],
        [ 1.4733]])

In [316]:
difference[0][1]

tensor([[ 0.0000,  0.2828, -1.9960,  0.5130],
        [-0.2828,  0.0000, -2.2788,  0.2302],
        [ 1.9960,  2.2788,  0.0000,  2.5090],
        [-0.5130, -0.2302, -2.5090,  0.0000]])

torch.Size([10, 5, 4, 1])

In [304]:
Q[0][0] - Q[0][1]

tensor([ 0.7462, -0.7099,  2.5379, -3.4653])

In [300]:
(Q_unsqueezed_1-Q_unsqueezed_2)[0][0]

tensor([[ 0.0000, -1.1733, -0.2043, -3.6985],
        [ 1.1733,  0.0000,  0.9690, -2.5252],
        [ 0.2043, -0.9690,  0.0000, -3.4942],
        [ 3.6985,  2.5252,  3.4942,  0.0000]])

In [204]:
norm_Q = torch.sum((Q*Q), dim=-1)

In [218]:
pairwise_Q = Q.matmul(Q.transpose(-1,-2))

In [221]:
norm_Q_A = torch.unsqueeze(norm_Q,2).tile(1,1,5)
norm_Q_B = torch.unsqueeze(norm_Q,1).tile(1,5,1)

In [229]:
pairwise_dists = (norm_Q_A + norm_Q_B - 2 * pairwise_Q)

In [230]:
pairwise_dists[0]

tensor([[ 0.0000, 10.0702,  1.6816,  9.2205,  4.1929],
        [10.0702,  0.0000,  4.4489, 16.1165, 12.6896],
        [ 1.6816,  4.4489,  0.0000,  8.8964,  7.0335],
        [ 9.2205, 16.1165,  8.8964,  0.0000,  8.0963],
        [ 4.1929, 12.6896,  7.0335,  8.0963,  0.0000]])

In [236]:
np.linalg.norm(Q[0][0] - Q[0][4])**2

4.192918770641882

In [216]:
np.linalg.norm(Q[0].detach().numpy()[1])**2

7.8102502911578995

In [200]:
norm_unsqueezed_1 = torch.tile(torch.unsqueeze((torch.sum((Q*Q), dim=-1)),2),(1,1,5))
norm_unsqueezed_2 = torch.tile(torch.unsqueeze((torch.sum((Q*Q), dim=-1)),2),(1,1,5))


torch.Size([10, 5, 1])

In [201]:
torch.unsqueeze((torch.sum((Q*Q), dim=-1)),1).shape

torch.Size([10, 1, 5])

In [None]:
torch.

In [196]:
torch.tile(torch.unsqueeze((torch.sum((Q*Q), dim=-1)),2),(1,1,5)).shape

torch.Size([10, 5, 5])

In [198]:
torch.tile(torch.unsqueeze((torch.sum((Q*Q), dim=-1)),1),(1,5,1)).shape

torch.Size([10, 5, 5])

In [90]:
stuff = torch.randn((10,5,4))
other_stuff = torch.randn((10,5,4))


torch.matmul(stuff, other_stuff.transpose(-1,-2)).shape

torch.Size([10, 5, 5])

In [93]:
attention = DotProductAttention()

In [51]:
Q = torch.tensor([[0.3367, 0.1288],[0.2345,0.2303],[-1.1229,-0.1863]])
K = torch.tensor([[ 2.2082, -0.6380],
        [ 0.4617,  0.2674],
        [ 0.5349,  0.8094]])

V = torch.tensor([[ 1.1103, -1.6898],
        [-0.9890,  0.9580],
        [ 1.3221,  0.8172]])

In [52]:
seq_len = 3
d_k = torch.tensor(2)

attention(Q, K, V, d_k)

(tensor([[ 0.5697, -0.1520],
         [ 0.5379, -0.0265],
         [ 0.2246,  0.5556]]),
 tensor([[0.4028, 0.2886, 0.3086],
         [0.3538, 0.3069, 0.3393],
         [0.1303, 0.4630, 0.4067]]))

In [69]:
mha = MultiHeadedAttention(8, 2, 2, 128)

mha(Q,K,V)

torch.Size([3, 16])


RuntimeError: shape '[3, 16, 8, -1]' is invalid for input of size 48

In [62]:
torch.Size(Q)

TypeError: torch.Size() takes an iterable of 'int' (item 0 is 'Tensor')

In [14]:
class MLP(torch.nn.Module):
    def __init__(
        input_size,
        output_sizes,
        is_bias = True,
        activation = torch.nn.ReLU,
        dropout = 0,
        is_res = False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self._output_sizes = output_sizes
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout)
        self.res = is_res

        self.to_hidden = torch.nn.Linear(input_size, self.output_sizes[0], bias=is_bias)
        self.linears = torch.nn.ModuleList(
            [
                torch.nn.Linear(self.output_sizes[i-1], self.output_sizes[i], bias=is_bias)
                for _ in range(1, len(self.output_sizes))
            ]
        )
    def forward(self, x):

        for linear in linears:
            output_1 = self.linear(output)
            output_1 = self.activation(output_1)
            output_1 = self.dropout(output_1)
            if self.is_res:
                output = output_1 + output
            else:
                output = output_1
        return output

        

In [None]:
class MLP_RBF(torch.nn.Module):
    def __init__(
        x_dim,
        output_sizes = [128.128,128,128],
        is_bias = True,
        activation = torch.nn.ReLU,
        dropout = 0
        is_res = False
    ):
        self.mlp = MLP(
            x_dim,
            output_sizes,
            is_bias,
            activation,
            dropout,
            is_res)

    def forward(self, x):
        """ takes x : (batch_size, num_context, x_dim + y_dim)
        passes through mlp : (batch_size, num_context, mlp_output_dim)
        passes through reshaping: (batch_size, mlp_output_dim*num_context)
        Computes radial basis function, returning (batch_size, mlp_output^2)
        """
        x_mlp = self.mlp(x)

        x_reshaped_mlp = torch.reshape(x, (x_mlp.shape[0], x_mlp.shape[1]*x_mlp.shape[2]))

        
        

        

In [88]:
stuff = torch.randn((10,5,4))

In [86]:
stuff

tensor([[[ 1.1892, -0.1593,  0.0150, -0.0488],
         [ 0.2855,  1.1376, -1.3892, -0.5053],
         [-0.0613,  1.3183, -1.3418,  0.8088],
         [-1.0177, -0.9790, -0.4899, -2.2784],
         [ 1.0995,  1.6432,  1.1753, -0.5345]],

        [[ 0.7801, -1.0657,  0.2304,  0.2352],
         [ 1.0238,  0.0708, -0.4076, -0.8022],
         [-0.3727, -0.8060, -0.6630,  1.7373],
         [-1.0747, -0.4488,  0.4065, -0.6092],
         [ 1.0579,  2.1286,  0.0697,  2.7280]],

        [[ 1.6939, -0.8498, -0.4778,  0.7137],
         [ 0.2977, -0.7736,  0.2704, -0.2792],
         [-1.3080,  2.2400, -0.7018,  0.2845],
         [-0.2582,  0.2023, -1.8169,  1.6254],
         [-0.9776,  0.1462, -1.2581,  1.1714]],

        [[-0.1130,  1.6967,  0.5609, -1.6138],
         [ 0.2231, -0.8121,  0.5768, -0.5012],
         [ 1.1451,  0.7488,  2.2243, -0.6175],
         [-0.8851, -0.4628, -0.4371,  0.7586],
         [-0.1597, -0.1398,  0.0112, -1.4179]],

        [[ 0.0439,  0.1806,  0.3941,  2.7421],
     

In [79]:
Q.view(6,2)

RuntimeError: shape '[6, 2]' is invalid for input of size 6

In [None]:
class SetConv(nn.Module):
    def __init__(

        x_dim, 
        in_channels,
        out_channels,
        RadialBasisFunc = 

    ):

        