In [1]:
import torch
import numpy as np
import sklearn


In [2]:
class Noiser():
    def __init__(self,
                 noiser = 'Uniform',
                 beta_t = 0.001,
                 k = 21
    ):
        if noiser == 'Uniform':
            self.noise_matrix = (1.-beta_t) * torch.eye(k)
            self.noise_matrix = self.noise_matrix + beta_t/k * (torch.ones((k,k)) - torch.eye(k))
        elif noiser == 'BERT-LIKE':
            self.noise_matrix = (1.-beta_t) * torch.eye(k+1)
            self.noise_matrix[:,k] = torch.ones(k+1)*beta_t
            self.noise_matrix[k,:] = torch.zeros(k+1)
            self.noise_matrix[k,k] = 1.
        elif noiser == 'Gaussian':
            one_way = torch.arange(0,k,1).unsqueeze(0)
            other_way = torch.arange(0,k,1).unsqueeze(1)

            self.noise_matrix = torch.exp(-4 * (one_way-other_way).pow(2)/((k-1)**2 * beta_t))/torch.sum(torch.exp(-4 * torch.arange(-k+1,k,1)**2/((k-1)**2 * beta_t)))
            diagonal = 1-torch.sum(self.noise_matrix - self.noise_matrix[0,0]*torch.eye(k), dim=1)
            
            mask = torch.diag(torch.ones_like(self.noise_matrix))

            self.noise_matrix = mask * torch.diag(diagonal) + (1.-torch.diag(mask))*self.noise_matrix

In [3]:
default_aa_keys='-GALMFWKQESPVICYHRNDT'
def fasta_to_df(fasta_file, aa_keys = default_aa_keys):
    """
    creates one hot encoding of a fasta file using biopython's alignio.read process. 
    fasta_file : filepath leading to msa file in fasta format at hand
    """
    column_names = []
    column_names.extend(aa_keys)
    msa=AlignIO.read(fasta_file, "fasta")
    num_columns = len(msa[0].seq)
    column_names = column_names*num_columns
    column_names.append('sequence')
    column_names.append('id')
    init = np.zeros((len(msa), len(column_names)))
    df = pd.DataFrame(init, columns = column_names)
    df.sequence = df.sequence.astype(str)
    df.id=df.id.astype(str)
    
    for row_num, alignment in tqdm(enumerate(msa)):
        sequence = str(alignment.seq)
        for index, char in enumerate(sequence):
            place = aa_keys.find(char)
            df.iloc[row_num, index*len(aa_keys) + place] = 1
        
        df.iloc[row_num,-2]=str(alignment.seq)
        df.iloc[row_num,-1]=str(alignment.id)
    
    return df

In [4]:
def create_frequency_matrix(df, aa_keys = default_aa_keys):
    """takes one hot encoded msa and returns the frequency of each amino acid at each site
    df : pandas dataframe whose columns are the one hot encoding of an msa
    """
    num_columns=len(df['sequence'][0])
    
    frequency_matrix = np.zeros( (len(aa_keys) , num_columns) )
    print('calcing sum')
    freq=df.sum()
    print('sum calced')
    
    num_entries=len(df)
    len_aa_keys = len(aa_keys)
    
    for i in tqdm(range(len(aa_keys))):
        for j in range(num_columns):
            frequency_matrix[i, j] = freq[ i + len_aa_keys * j] / num_entries
    
    return frequency_matrix

In [5]:
import pandas as pd

In [6]:
msa = pd.read_csv('SH3_Full_Dataset_8_9_22.csv')
msa['Type'].unique()
naturals_msa = msa[msa['Type']=='Naturals']
seqs = np.asarray([list(seq) for seq in naturals_msa['Sequences']])
norm_re = np.asarray([re for re in naturals_msa['Norm_RE']])

In [7]:
phyla = np.asarray([domain for domain in naturals_msa['Phylum']])

In [8]:
from Bio import AlignIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm
vae_alignment = []
phenotypes = []

vae_data = msa[msa['Type']=='VAE'].reset_index()

for r in range(len(vae_data)):
    alignment = vae_data.loc[r]
    if len(alignment['Sequences'])==62:
        record = SeqRecord(seq = Seq(alignment['Sequences']), id = alignment['Header'])
    
    vae_alignment.append(record)
    phenotypes.append(alignment['Norm_RE'])

vae_alignment = AlignIO.MultipleSeqAlignment(vae_alignment)

AlignIO.write(vae_alignment, 'vae_alignment.fasta', 'fasta')

vae_df = fasta_to_df('vae_alignment.fasta')

freq_matrix = create_frequency_matrix(vae_df)

trim_positions = []

for i in range(freq_matrix.shape[1]):
    if 1 in freq_matrix[:,i]:
        trim_positions.append(i)

print(trim_positions)


vae_alignment_trimmed = []


for alignment in vae_alignment:
    new_seq = ''
    for i in range(62):
        if i not in trim_positions:
            new_seq+=alignment.seq[i]
    re_alignment = SeqRecord(seq=Seq(new_seq), id = alignment.id)
    vae_alignment_trimmed.append(re_alignment)

vae_alignment_trimmed = AlignIO.MultipleSeqAlignment(vae_alignment_trimmed)

AlignIO.write(vae_alignment_trimmed, 'vae_alignment_trimmed.fasta', 'fasta')

test_seqs = np.asarray([list(str(alignment.seq)) for alignment in vae_alignment_trimmed])

phenotypes = np.asarray(phenotypes)

11608it [02:10, 89.02it/s] 


calcing sum
sum calced


100%|█████████████████████████████████████████| 21/21 [00:00<00:00, 6318.99it/s]

[16, 17, 44]





In [9]:
test_seqs

array([['-', '-', '-', ..., 'S', 'C', 'C'],
       ['T', 'A', 'K', ..., 'I', 'V', 'S'],
       ['P', 'A', 'K', ..., 'I', 'E', 'V'],
       ...,
       ['S', 'A', 'Q', ..., 'V', '-', '-'],
       ['K', 'K', 'H', ..., 'T', 'G', 'S'],
       ['K', 'K', 'H', ..., 'T', 'G', 'S']], dtype='<U1')

In [10]:
AMINO_ACIDS = "ARNDCQEGHILKMFPSTWYV-"
IDX_TO_AA = list(AMINO_ACIDS)
AA_TO_IDX = {aa: i for i, aa in enumerate(IDX_TO_AA)}

In [11]:
class ProteinDataset(torch.utils.data.Dataset):
    """
    takes in sequence data and phenotype data and spits back two dictionaries, X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        seq_data: np.array representing sequences
        phenotype_data np.array representing phenotypes
    
    """
    
    def __init__(self,
                 seq_data,
                 phenotype_data = None,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.seqs = seq_data
        self.phenotype_data = phenotype_data
        self.AMINO_ACIDS = "ARNDCQEGHILKMFPSTWYV-"
        self.IDX_TO_AA = list(self.AMINO_ACIDS)
        self.AA_TO_IDX = {aa: i for i, aa in enumerate(self.IDX_TO_AA)}

    def __len__(self):
        return self.seqs.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()
        if self.phenotype_data is not None:
            Y['pheno'] = self.phenotype_data[index]

        one_hot_seq = self._to_one_hot(self.seqs[index])
        Y['seq'] = one_hot_seq
        X['seq'] = one_hot_seq
        return X, Y

    def _to_one_hot(self, seq):
        if self.include_mask:
            one_hot_encoded = np.zeros((seq.shape[0],len(self.IDX_TO_AA)+1))
            for index, char in enumerate(seq):
                one_hot_encoded[index, self.AA_TO_IDX[char]]=1
            return torch.tensor(one_hot_encoded, dtype=torch.float32)
        else:
            one_hot_encoded = np.zeros((seq.shape[0],len(self.IDX_TO_AA)))
            for index, char in enumerate(seq):
                one_hot_encoded[index, self.AA_TO_IDX[char]]=1
            return torch.tensor(one_hot_encoded, dtype=torch.float32)

In [12]:
def sampler(sequence_matrix, noise_matrix, num_states):
    
    probabilities = torch.matmul(sequence_matrix, noise_matrix.unsqueeze(0))
    results = torch.nn.functional.one_hot(torch.multinomial(probabilities.view(-1, results.shape[-1]), 1), num_classes=num_states)

    results = results.view(probabilities.shape[0], probabilities.shape[1], num_states)

    return results

In [13]:
delta_t = 0.001

times = torch.arange(0, 1 + delta_t, step=0.001)

In [14]:
class Attention(torch.nn.Module):
    def __init__(self,
                 **kwargs
                ):
        super().__init__()


    def forward(queries,
                keys,
                values,
                d_k,
                mask = None
               ):
        scores = torch.matmul(queries, keys.transpose(-1,-2))/torch.sqrt(d_k)
        if mask is not None:
            scores += -1e9*mask

        attention = torch.nn.Softmax(dim=-1)(scores)
        return torch.matmul(attention, values), attention
        
                 

In [15]:
class MultiHeadedAttention(torch.nn.Module):
    def __init__(self, 
                 heads,
                 d_query,
                 d_key,
                 d_value,
                 d_hidden,
                 d_model,
                 activation = torch.nn.ReLU,
                 **kwargs
                ):
        super().__init__(**kwargs)
        self.attention = DotProductAttention()
        self.heads = heads
        self.d_query = d_query
        self.d_key = d_key
        self.d_values = d_values
        self.d_hidden = d_hidden
        self.W_q = torch.nn.Linear(d_query, d_hidden*self.heads)
        self.W_k = torch.nn.Linear(d_key, d_hidden*self.heads)
        self.W_v = torch.nn.Linear(d_values, d_hidden*self.heads)
        self.W_o = torch.nn.Linear(self.heads*d_hidden, d_model)

    def reshape_tensor(self,
                       x,
                       heads,
                       flag
                      ):
        if flag:
            x = x.view(x.shape[0], x.shape[1], heads, x.shape[2]//heads)
            x = x.permute(0,2,1,3)
        else:
            x = x.permute(0,2,1,3)
            x = x.view(x.shape[0], x.shape[1], self.d_model*self.heads)

    def forward(self,
                query,
                keys,
                value
               ):
        q, k, v = self.W_q(query), self.W_k(query), self.W_v(value)
        if self.activation is not None:
            query_reshaped = self.reshape_tensor(self.activation(q), self.heads, True)
            key_reshaped   = self.reshape_tensor(self.activation(k), self.heads, True)
            value_reshaped = self.reshape_tensor(self.activation(v), self.heads, True)

        activations, attention = self.attention(query_reshaped, key_reshaped, value_reshaped)
        activations = self.reshape_tensor(activations, self.heads, False)
        if self.activation is not None:
            return self.W_o(output), output
        else:
            return output, output
        

In [16]:
class AddNorm(torch.nn.Module):
    def __init__(
        self,
        normalized_shape, 
        **kwargs
    ):
        super().__init__(**kwargs)
        self.layer_norm = torch.nn.LayerNorm(normalized_shape)
    def forward(
        self,
        x,
        sub_layer_x
    ):
        add = x + sub_layer_x
        return self.layer_norm(add)        

In [17]:
class FeedForward(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        activation = torch.nn.ReLU,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.layer = torch.nn.Linear(input_dim, output_dim)
        self.activation = activation
    def forward(
        self,
        x
    ):
        if self.activation is not None:
            return self.activation(self.layer(x))
        else:
            return self.layer(x)

In [74]:
class D3PM(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        d_feedforward,
        heads,
        d_model,
        p_dropout,
        noise_matrix
    ):
        super().__init__(**kwargs)
        self.heads = heads
        self.input_dim = input_dim
        self.d_feedforward = d_feedforward
        self.d_model = d_model
        self.p_dropout = p_dropout
        self.flatten = torch.nn.Flatten(start_dim=1)
        self.feedforward_1 = FeedForward(input_dim, d_feedforward)
        self.mha_1 = MultiHeadedAttention(heads, self.d_feedforward, self.d_feedforward, self.d_feedforward, self.d_feedforward, self.d_model)
        self.dropout_1 = torch.nn.Dropout(0.1)
        self.addnorm_1 = AddNorm()
        self.feedforward_2 = FeedForward(self.d_model, self.d_feedforward)
        self.mha_2 = MultiHeadedAttention(heads, self.d_feedforward, self.d_feedforward, self.d_feedforward, self.d_feedforward, self.d_model)
        self.dropout_2 = torch.nn.Dropout(0.1)
        self.addnorm_2 = AddNorm()
        self.feedforward_3 = FeedForward(self.d_model, self.input_dim, activation = None)
        self.noise_matrix = noise_matrix

    def forward(self, X, Y = None):
        X = self.flatten(X)
        X_1 = self.feedforward_1(X)
        X_2, _ = self.mha_1(X_1,X_1,X_1)
        X_2 = self.dropout_1(X_2)
        X_2 = self.addnorm_1(X_2, X_1)

        X_2 = self.feedforward_2(X_2)
        X_3, _ = self.mha_2(X_2, X_2, X_2)
        X_3 = self.dropout_2(X_3)
        X_3 = self.addnorm_2(X_3, X_2)
        Y_pred = self.feedforward_3(X_3)
        Y_pred = Y_pred.view(Y.shape[0], Y.shape[1], Y.shape[2])

    def L_T(self, X, pX):
        vals, vecs = torch.linalg.eig(noise_matrix.noise_matrix.t())
        vals = torch.real(vals)
        vecs = torch.real(vecs)

        PxT = vecs[:, torch.argmax(vals)].unsqueeze(0).unsqueeze(1)

        no_zeros_PxT = PxT + 1e-6
        no_zeros_X   = X + 1e-6
        
        dkl_steady_state = torch.sum(
            no_zeros_X * torch.log(no_zeros_X/no_zeros_PxT)
        )

        return dkl_steady_state

    def cross_entropy_loss(self, y, yhat):
        no_zeros_y = y+1e-6
        no_zeros_yhat = yhat + 1e-6

        return torch.sum(y * torch.log(y/yhat))  


In [52]:
an = AddNorm()

TypeError: LayerNorm.__init__() missing 1 required positional argument: 'normalized_shape'

In [21]:
proteins = ProteinDataset(seq_data=seqs, include_mask=True)

In [22]:
loader = torch.utils.data.DataLoader(proteins, batch_size=32, shuffle=True)


In [23]:
for batch in loader:
    print(1)

1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1


In [24]:
X, Y = batch

In [25]:
X_seq = X['seq']
Y_seq = Y['seq']


In [26]:
noise_matrix.noise_matrix

tensor([[0.6000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4000],
        [0.0000, 0.6000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4000],
        [0.0000, 0.0000, 0.6000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4000],
        [0.0000, 0.0000, 0.0000, 0.6000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.000

In [77]:
noise_matrix_04 = Noiser(noiser = 'BERT-LIKE', beta_t = 0.4, k =21).noise_matrix
noise_matrix_03 = Noiser(noiser = 'BERT-LIKE', beta_t = 0.3, k =21).noise_matrix
noise_matrix_01 = Noiser(noiser = 'BERT-LIKE', beta_t = 0.1, k =21).noise_matrix


In [107]:
results_04 = torch.matmul(X_seq, noise_matrix_04.view(1,22,22))
results_03 = torch.matmul(X_seq, noise_matrix_03.view(1,22,22))
results_01 = torch.matmul(sampled[3], noise_matrix_01.view(1,22,22))

sampled_04 = torch.nn.functional.one_hot(torch.multinomial(results_04.view(-1, results_04.shape[-1]), 1), num_classes=22)
sampled_04 = sampled_04.view(results_04.shape[0], results_04.shape[1], 22)

sampled_03 = torch.nn.functional.one_hot(torch.multinomial(results_03.view(-1, results_03.shape[-1]), 1), num_classes=22)
sampled_03 = sampled_03.view(results_03.shape[0], results_03.shape[1], 22)

sampled_01 = torch.nn.functional.one_hot(torch.multinomial(results_01.view(-1, results_01.shape[-1]), 1), num_classes=22)
sampled_01 = sampled_01.view(results_01.shape[0], results_04.shape[1], 22)



In [125]:
numer = (torch.matmul(sampled_04.type(torch.FloatTensor), noise_matrix_01.view(1,22,22).transpose(-1,-2)) * torch.matmul(X_seq, noise_matrix_03.view(1,22,22)))

In [None]:
denom = torch.matmul(X_seq, noise_matrix_04.view(1,22,22))

In [129]:
sampled_04.shape

torch.Size([25, 59, 22])

In [132]:
denom = torch.matmul(torch.matmul(X_seq, noise_matrix_04.view(1,22,22)), sampled_04.type(torch.FloatTensor).transpose(-1,-2))

In [134]:
numer.shape

torch.Size([25, 59, 22])

In [135]:
denom.shape

torch.Size([25, 59, 59])

In [140]:
torch.diagonal(denom,dim1=-2, dim2=-1)

tensor([[0.4000, 0.4000, 0.6000,  ..., 0.6000, 0.6000, 0.6000],
        [0.6000, 0.4000, 0.4000,  ..., 0.6000, 0.4000, 0.6000],
        [0.4000, 0.4000, 0.6000,  ..., 0.6000, 0.6000, 0.6000],
        ...,
        [0.4000, 0.4000, 0.6000,  ..., 0.4000, 0.4000, 0.6000],
        [0.6000, 0.6000, 0.6000,  ..., 0.6000, 0.6000, 0.4000],
        [0.6000, 0.6000, 0.6000,  ..., 0.4000, 0.4000, 0.4000]])

In [133]:
numer/denom

RuntimeError: The size of tensor a (22) must match the size of tensor b (59) at non-singleton dimension 2

In [121]:
sampled_04.type(torch.FloatTensor)

tensor([[[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0., 

In [116]:
noise_matrix_01.t().shape

torch.Size([22, 22])

In [109]:
((results_01+1e-6) * (results_03+1e-6)/(results_04+1e-6))[0][0]

tensor([1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0500e+00, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 7.5001e-02])

In [92]:
sampled = torch.zeros((5, X_seq.shape[0], X_seq.shape[1], X_seq.shape[2]))

In [93]:
sampled[0] = X_seq
for i in range(1,5):
    results = torch.matmul(sampled[i-1], noise_matrix_01.view(1,22,22))
    sampled_dummy = torch.nn.functional.one_hot(torch.multinomial(results.view(-1, results.shape[-1]), 1), num_classes=22)
    
    sampled[i] = sampled_dummy.view(results.shape[0], results.shape[1], 22)

In [95]:
sampled.shape

torch.Size([5, 25, 59, 22])

In [84]:
qxt1givxt = ((sampled[+1e-6)*(sampled_03+1e-6)/(sampled_04+1e-6))

In [85]:
qxt1givxt[0][0]

tensor([1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e+06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-12])

In [29]:
results = torch.matmul(X_seq, noise_matrix.noise_matrix.view(1,22,22))

In [30]:
torch.nn.functional.one_hot(torch.multinomial(results[0,0], 1)).shape

torch.Size([1, 22])

In [31]:
results.shape

torch.Size([25, 59, 22])

In [32]:
results[0][0]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.4000])

In [33]:
X_seq[0][0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0.])

In [34]:
sampled = torch.nn.functional.one_hot(torch.multinomial(results.view(-1, results.shape[-1]), 1), num_classes=22)

In [39]:
sampled = sampled.view(results.shape[0], results.shape[1], 22)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 1,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 1]])

In [36]:
results.shape

torch.Size([25, 59, 22])

In [42]:
noise_matrix = Noiser('BERT-LIKE',  beta_t = 0.4, k =21)

In [43]:
vals, vecs = torch.linalg.eig(noise_matrix.noise_matrix.t())
vals = torch.real(vals)
vecs = torch.real(vecs)

In [44]:
vecs[:,0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1.])

In [45]:
vals

tensor([1.0000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000,
        0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000,
        0.6000, 0.6000, 0.6000, 0.6000])

In [47]:
PxT = vecs[:,torch.argmax(vals)]

In [50]:
PxT.shape

torch.Size([22])

In [52]:
perturbed_sampled = sampled + 1e-6

In [59]:
unsqueeze_PxT = PxT.unsqueeze(0).unsqueeze(1) + 1e-6

In [60]:
unsqueeze_PxT.shape

torch.Size([1, 1, 22])

In [63]:
(perturbed_sampled/unsqueeze_PxT).shape

torch.Size([25, 59, 22])

In [73]:
perturbed_sampled * torch.log(perturbed_sampled/unsqueeze_PxT)

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         [ 0.0000e+00,  0.0000e+00,  1.3816e+01,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -1.3816e-05],
         ...,
         [ 0.0000e+00,  0

In [67]:
perturbed_sampled[0][0]

tensor([1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e+00, 1.0000e-06, 1.0000e-06,
        1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06])

In [68]:
unsqueeze_PxT

tensor([[[1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
          1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
          1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
          1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06, 1.0000e-06,
          1.0000e-06, 1.0000e+00]]])