In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

In [2]:
X_train, y_train, X_test, y_test = preprocess.main()

11608it [00:50, 231.99it/s]
100%|████████████████████████████████████████| 21/21 [00:00<00:00, 26379.27it/s]


In [3]:
class Denoiser(torch.nn.Module):
    """Predict a per-position categorical distribution from a noised sequence and its time step.

    The time value and the sequence are embedded by two separate stacks of three
    ``FeedForward`` layers, concatenated along the feature axis, projected to
    ``d_model`` and passed through two attention blocks
    (attention -> dropout -> add&norm -> feed-forward). A last ``FeedForward``
    maps back to ``d_seq`` and a softmax over the last axis yields the output.

    Args:
        heads: number of attention heads handed to ``MultiHeadedAttention``.
        d_time: feature size of the time input. ``forward`` reshapes ``t`` to
            ``(N, 1)``, so this has to be 1 for the first time layer to accept it.
        d_time_hidden: width of the time embedding.
        d_seq: size of the last axis of ``X`` (and of the output).
        d_query, d_key, d_values: accepted for interface compatibility but not
            used; the attention blocks are built from ``d_model`` only.
        d_hidden: forwarded to ``MultiHeadedAttention``.
        d_model: width of the sequence embedding and of the attention blocks.
        p: dropout probability applied to each attention output.
        activation: stored on the module; ``forward`` does not call it.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.

    NOTE(review): the printed module repr shows every ``FeedForward`` carrying a
    ReLU. If that activation is applied inside ``feedforward_3`` the logits fed to
    the softmax are non-negative, which limits how peaked the output can be --
    confirm against ``architectures.FeedForward``.
    """

    def __init__(
        self,
        heads,
        d_time,
        d_time_hidden,
        d_seq,
        d_query = 128,
        d_key = 128,
        d_values = 128,
        d_hidden = 128,
        d_model = 128,
        p = 0.1,
        activation = torch.nn.ReLU(),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.heads = heads
        self.activation = activation

        # Time-step embedding stack.
        self.forward_time_1 = FeedForward(d_time, d_time_hidden)
        self.forward_time_2 = FeedForward(d_time_hidden, d_time_hidden)
        self.forward_time_3 = FeedForward(d_time_hidden, d_time_hidden)

        # Sequence embedding stack.
        self.forward_seq_1 = FeedForward(d_seq, d_model)
        self.forward_seq_2 = FeedForward(d_model, d_model)
        self.forward_seq_3 = FeedForward(d_model, d_model)

        # Fuses the concatenated [time, sequence] features back down to d_model.
        self.forward_time_seq = FeedForward(d_model + d_time_hidden, d_model)

        # First attention block.
        self.mha_1 = MultiHeadedAttention(self.heads, d_model, d_model, d_model, d_hidden, d_model)
        self.dropout_1 = torch.nn.Dropout(p)
        self.addnorm_1 = AddNorm(d_model)
        self.feedforward_1 = FeedForward(d_model, d_model)

        # Second attention block.
        self.mha_2 = MultiHeadedAttention(self.heads, d_model, d_model, d_model, d_hidden, d_model)
        self.dropout_2 = torch.nn.Dropout(p)
        self.addnorm_2 = AddNorm(d_model)
        self.feedforward_2 = FeedForward(d_model, d_model)

        # Output projection back to the per-position category axis.
        self.feedforward_3 = FeedForward(d_model, d_seq)

    def forward(self, X, t):
        """Run the denoiser.

        Args:
            X: tensor indexed as ``(time_points, batch_size, seq_length, aas)``.
            t: tensor whose first two axes multiply to ``time_points * batch_size``;
                it is flattened to one scalar per (time point, sample) pair.

        Returns:
            Tensor shaped ``(time_points, batch_size, seq_length, aas)`` whose last
            axis is a softmax output.
        """
        time_points, batch_size, seq_length, aas = X.shape[:4]

        # Fold the time axis into the batch axis. `reshape` (unlike `view`) also
        # accepts non-contiguous inputs such as slices or permuted tensors.
        t = t.reshape(t.shape[0] * t.shape[1], 1)
        X = X.reshape(time_points * batch_size, seq_length, aas)

        time_encoding = self.forward_time_1(t)
        time_encoding = self.forward_time_2(time_encoding)
        time_encoding = self.forward_time_3(time_encoding)
        # Broadcast the single time embedding of each sample to every sequence
        # position; `expand` does this without materialising the copies.
        time_encoding = time_encoding.unsqueeze(1).expand(-1, seq_length, -1)

        seq_encoding = self.forward_seq_1(X)
        seq_encoding = self.forward_seq_2(seq_encoding)
        seq_encoding = self.forward_seq_3(seq_encoding)

        seq_time_encoding = torch.concat([time_encoding, seq_encoding], dim=-1)
        input_encoding = self.forward_time_seq(seq_time_encoding)

        X_1, _ = self.mha_1(input_encoding, input_encoding, input_encoding)
        X_1    = self.dropout_1(X_1)
        X_1    = self.addnorm_1(X_1, input_encoding)
        X_1    = self.feedforward_1(X_1)

        X_2, _ = self.mha_2(X_1, X_1, X_1)
        X_2    = self.dropout_2(X_2)
        X_2    = self.addnorm_2(X_2, X_1)
        X_2    = self.feedforward_2(X_2)

        X_3    = self.feedforward_3(X_2)
        # Functional softmax: no need to build a Softmax module on every call.
        Y_pred = torch.softmax(X_3, dim=-1)
        return Y_pred.reshape(time_points, batch_size, seq_length, aas)

In [4]:
# Wrap the training sequences in a dataset/loader and pull out one batch to
# experiment with in the cells below.
protein_dataset = ProteinDataset(seq_data=X_train, include_mask=True)
protein_loader = torch.utils.data.DataLoader(protein_dataset, batch_size=32, shuffle=True)
# Only the first batch is kept. With shuffle=True and no seed set in this
# notebook, that batch (and every number printed below) changes between runs.
for batch in protein_loader:
    X, Y = batch
    X_seq = X['seq']
    Y_seq = Y['seq']
    break

In [5]:
# Transition matrix of the 'BERT-LIKE' noising scheme with beta_t = 0.01.
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01).noise_matrix
# `noiser` is not defined in this notebook -- presumably it arrives through
# `from utils import *`; verify. The later `.shape` cell shows `noised_samples`
# as [100, 32, 59, 22], so the literal 100 looks like the number of time points
# -- TODO confirm against the helper's signature.
ts, noised_samples = noiser(X_seq, noise_matrix, 100, X_seq.shape[-1])

# Flattened (time*batch, ...) copies, mirroring the reshape Denoiser.forward does internally.
ts_reshaped = ts.reshape(ts.shape[0] * ts.shape[1], 1)
noised_samples_reshaped = noised_samples.view(noised_samples.shape[0]*noised_samples.shape[1], noised_samples.shape[2], noised_samples.shape[3])

In [7]:
# Positional arguments are heads=8, d_time=1, d_time_hidden=128 and
# d_seq=X_seq.shape[2] (the size of the last axis of the batch).
denoiser = Denoiser(8,1,128,X_seq.shape[2])

In [8]:
denoiser.eval()

Denoiser(
  (activation): ReLU()
  (forward_time_1): FeedForward(
    (layer): Linear(in_features=1, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_time_2): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_time_3): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_seq_1): FeedForward(
    (layer): Linear(in_features=22, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_seq_2): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_seq_3): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_time_seq): FeedForward(
    (layer): Linear(in_features=256, out_features=128, bias=True)
    (activation): ReLU()
  )
  (mha_1): MultiHeadedAttention(
    (attention): Attention()
    (W_q): Li

In [9]:
noised_samples.shape

torch.Size([100, 32, 59, 22])

In [10]:
y_pred = denoiser(noised_samples, ts)

In [11]:
y_pred.shape

torch.Size([100, 32, 59, 22])

In [12]:
y_pred[1][0:2]

tensor([[[0.0496, 0.0415, 0.0415,  ..., 0.0426, 0.0496, 0.0442],
         [0.0496, 0.0415, 0.0415,  ..., 0.0426, 0.0496, 0.0442],
         [0.0521, 0.0417, 0.0417,  ..., 0.0433, 0.0491, 0.0428],
         ...,
         [0.0508, 0.0415, 0.0415,  ..., 0.0430, 0.0490, 0.0434],
         [0.0508, 0.0416, 0.0416,  ..., 0.0449, 0.0484, 0.0417],
         [0.0521, 0.0417, 0.0417,  ..., 0.0433, 0.0491, 0.0428]],

        [[0.0517, 0.0416, 0.0416,  ..., 0.0431, 0.0488, 0.0429],
         [0.0509, 0.0415, 0.0415,  ..., 0.0442, 0.0486, 0.0437],
         [0.0496, 0.0415, 0.0415,  ..., 0.0426, 0.0495, 0.0441],
         ...,
         [0.0508, 0.0416, 0.0416,  ..., 0.0449, 0.0483, 0.0416],
         [0.0508, 0.0415, 0.0415,  ..., 0.0440, 0.0483, 0.0440],
         [0.0521, 0.0417, 0.0417,  ..., 0.0438, 0.0483, 0.0423]]],
       grad_fn=<SliceBackward0>)

In [45]:
class CategoricalDiffusion(torch.nn.Module):
    """Loss terms for a categorical diffusion model built on a single transition matrix.

    The same ``noise_matrix`` is applied at every step, so the ``t``-step
    transition is ``noise_matrix.matrix_power(t)``. Index 0 along the leading
    axis of a noised sample is treated as the clean sequence.

    Args:
        denoiser: callable invoked as ``denoiser(noised_sample, ts)`` that returns
            a tensor of the same shape as ``noised_sample``.
        noise_matrix: square transition matrix applied as ``x @ noise_matrix``.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(
        self,
        denoiser,
        noise_matrix,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.denoiser = denoiser
        self.noise_matrix = noise_matrix

    def L_T(self, noised_sample, noise_matrix):
        """Summed KL divergence between ``noised_sample`` and the chain's stationary distribution.

        Args:
            noised_sample: tensor whose last axis indexes categories.
            noise_matrix: transition matrix whose stationary distribution is used.

        Returns:
            Scalar tensor, summed over every element of ``noised_sample``.
        """
        # The stationary distribution is the eigenvector of the transposed
        # transition matrix that belongs to the largest (real) eigenvalue.
        vals, vecs = torch.linalg.eig(noise_matrix.t())
        vals = torch.real(vals)
        vecs = torch.real(vecs)

        stationary = vecs[:, torch.argmax(vals)]
        # `eig` returns eigenvectors with unit L2 norm and arbitrary sign. Dividing
        # by the sum turns the vector into a probability distribution whichever
        # sign or scale came back.
        stationary = stationary / stationary.sum()
        PxT = stationary.unsqueeze(0).unsqueeze(1)

        dkl_steady_state = torch.sum(
            (noised_sample+1e-6) * torch.log((noised_sample+1e-6)/(PxT+1e-6))
        )

        return dkl_steady_state

    def L_t0t1(self, real, ts, noised_sample):
        """Cross-entropy between ``real`` and the denoiser's prediction from time index 1.

        Args:
            real: clean target tensor.
            ts: time tensor; only ``ts[1]`` is used.
            noised_sample: noised tensor; only ``noised_sample[1]`` is used.
        """
        # Re-add the leading time axis the denoiser is called with elsewhere.
        t_ones = ts[1].unsqueeze(0)
        noised_ones = noised_sample[1].unsqueeze(0)
        y_pred = self.denoiser(noised_ones, t_ones)
        return self.cross_entropy(real, y_pred)

    def L_tminus1(self, real, ts, noised_sample):
        """Summed KL term between the true and the model's one-step reverse distributions.

        Covers time indices ``2 .. noised_sample.shape[0] - 1``.

        Args:
            real: clean tensor passed to ``one_step_reverse_marginal`` as x0.
            ts: time tensor handed to the denoiser.
            noised_sample: noised tensor with time as the leading axis.
        """
        reverse_marginals = self.one_step_reverse_marginal(real, noised_sample)

        denoised = self.denoiser(noised_sample, ts)
        # Use the arguments and the matrix stored on the instance rather than
        # whatever globals of the same name happen to exist in the notebook.
        px_tminus1_giv_xt = self.px_tminus1_giv_xt(noised_sample, denoised, self.noise_matrix)

        return torch.sum(reverse_marginals * torch.log((reverse_marginals+1e-6)/(px_tminus1_giv_xt+1e-6)))

    def cross_entropy(self, real, y_pred):
        """Summed cross-entropy ``-sum(real * log(y_pred + 1e-6))``."""
        return -torch.sum(real * torch.log(y_pred+1e-6))

    def one_step_reverse_marginal(self, real, noised_sample):
        """Posterior over the previous state given the current state and ``real``.

        For each time index ``t >= 2`` this evaluates
        ``(x_t Q^T) * (x_0 Q^(t-1)) / (x_0 Q^t . x_t)`` with ``Q = self.noise_matrix``.

        Args:
            real: clean tensor used as x0.
            noised_sample: noised tensor with time as the leading axis.

        Returns:
            Tensor with leading size ``noised_sample.shape[0] - 2``; entry ``i``
            belongs to time index ``i + 2``.
        """
        x0 = real
        Q = self.noise_matrix
        power = Q  # holds Q^(t-1); starts at t = 2
        steps = []
        for t in range(2, noised_sample.shape[0]):
            xt = noised_sample[t]
            prior = torch.matmul(x0, power)    # x0 Q^(t-1)
            forward = torch.matmul(prior, Q)   # x0 Q^t
            numer = torch.matmul(xt, Q.t()) * prior
            # Row-wise dot product. This equals the diagonal of
            # (x0 Q^t) @ xt^T without building the seq x seq matrix.
            denom = (forward * xt).sum(dim=-1, keepdim=True)
            # For one-hot xt the denominator is the sum of the numerator, so it is
            # zero only where the numerator is all zero. The floor turns that 0/0
            # into 0 instead of NaN.
            steps.append(numer / denom.clamp_min(1e-12))
            power = torch.matmul(power, Q)

        return self._stack_steps(steps, noised_sample)

    def px_tminus1_giv_xt(self, noised_samples, denoised, noise_matrix):
        """Model-side one-step-back term for each time index ``t >= 2``.

        Computes ``(noised_samples[0] * denoised[t]) @ noise_matrix^(t-1)``.

        Args:
            noised_samples: noised tensor; index 0 is used as the clean sequence.
            denoised: denoiser output with time as the leading axis.
            noise_matrix: transition matrix.

        Returns:
            Tensor with leading size ``denoised.shape[0] - 2``.
        """
        real = noised_samples[0]
        power = noise_matrix  # holds noise_matrix^(t-1); starts at t = 2
        steps = []
        for t in range(2, denoised.shape[0]):
            weighted_expectation = real * denoised[t]
            steps.append(torch.matmul(weighted_expectation, power))
            power = torch.matmul(power, noise_matrix)
        return self._stack_steps(steps, denoised)

    @staticmethod
    def _stack_steps(steps, like):
        """Stack per-time-step results; return an empty tensor on ``like``'s device if there are none."""
        if steps:
            return torch.stack(steps)
        return like.new_zeros((0, *like.shape[1:]))
    

In [46]:
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

In [47]:
cat_diff.L_T(noised_samples, noise_matrix)

tensor(325148.0625)

In [48]:
cat_diff.L_t0t1(X_seq[0:2],ts[:,0:2], noised_samples[:,0:2,:,:])

tensor(364.4633, grad_fn=<NegBackward0>)

In [49]:
cat_diff.L_tminus1(X_seq, ts, noised_samples)

torch.Size([100, 32, 59, 22])


tensor(720177.4375, grad_fn=<SumBackward0>)

In [53]:
y_pred[1][0:2]

tensor([[[0.0681, 0.0589, 0.0496,  ..., 0.0395, 0.0395, 0.0431],
         [0.0698, 0.0512, 0.0459,  ..., 0.0393, 0.0393, 0.0403],
         [0.0639, 0.0480, 0.0537,  ..., 0.0403, 0.0403, 0.0413],
         ...,
         [0.0711, 0.0543, 0.0466,  ..., 0.0394, 0.0394, 0.0394],
         [0.0685, 0.0519, 0.0478,  ..., 0.0406, 0.0391, 0.0396],
         [0.0639, 0.0603, 0.0415,  ..., 0.0401, 0.0401, 0.0423]],

        [[0.0693, 0.0582, 0.0470,  ..., 0.0406, 0.0391, 0.0428],
         [0.0682, 0.0500, 0.0521,  ..., 0.0396, 0.0396, 0.0418],
         [0.0679, 0.0548, 0.0482,  ..., 0.0398, 0.0398, 0.0433],
         ...,
         [0.0657, 0.0554, 0.0494,  ..., 0.0403, 0.0403, 0.0403],
         [0.0653, 0.0551, 0.0487,  ..., 0.0394, 0.0394, 0.0404],
         [0.0755, 0.0506, 0.0526,  ..., 0.0398, 0.0398, 0.0400]]],
       grad_fn=<SliceBackward0>)

In [98]:
noised_samples.shape

torch.Size([100, 32, 59, 22])

In [102]:
y_pred_reshaped = y_pred.view(noised_samples.shape)

In [120]:
(y_pred_reshaped[2][0] == y_pred[64]).all()

tensor(True)

In [123]:
(noised_samples[2][0]==noised_samples_reshaped[64]).all()

tensor(True)

In [125]:
new_test = y_pred_reshaped[2]

In [129]:
(new_test * noised_samples[0])

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0409, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0415, 0.0000, 0.

In [94]:
test[test>0]

tensor([1.0826, 1.2775, 0.1294, 1.0052, 0.5819, 0.8000, 0.4890, 0.0455])

In [58]:
# Three-step forward marginal of the clean batch. `x0` is a notebook global that
# is only assigned inside the reverse-marginal loop cell further down, so this
# cell fails on a fresh kernel unless that cell has been run first.
torch.matmul(x0,noise_matrix.matrix_power(3))

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.9703, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0297],
         [0.0000, 0.0000, 0.0000,  ..., 0.9703, 0.0000, 0.

In [39]:
# Output buffer for the loop in the next cell: one slot per time index from 2 upward,
# hence the leading size of shape[0] - 2.
reverse_marginals = torch.zeros(noised_samples.shape[0]-2, noised_samples.shape[1], noised_samples.shape[2], noised_samples.shape[3])

In [41]:
# Scratch version of CategoricalDiffusion.one_step_reverse_marginal, using
# noised_samples[0] as the clean sequence. The loop variables (t, xt, x0, numer,
# denom) stay in the kernel afterwards and are read by other cells.
for t in range(2, noised_samples.shape[0]):
    xt = noised_samples[t]
    x0 = noised_samples[0]

    # (x_t Q^T) * (x_0 Q^(t-1))
    numer = torch.matmul(xt, noise_matrix.t()) * torch.matmul(x0, noise_matrix.matrix_power(t-1))
    # (x_0 Q^t) @ x_t^T gives a seq x seq matrix per sample; only its diagonal
    # (each position paired with itself) is kept.
    denom = torch.matmul(torch.matmul(x0, noise_matrix.matrix_power(t)), xt.permute(0,2,1))
    denom = torch.diagonal(denom, dim1=-2, dim2=-1).unsqueeze(-1)
    reverse_marginals[t-2] = numer/denom

In [63]:
noised_samples[3][0][0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1.])

In [86]:
test = reverse_marginals[1]

In [87]:
test

tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.6700],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 1.0000, 0.0000, 0.

In [69]:
test.shape

torch.Size([32, 59, 22])

In [73]:
other_test = test * torch.matmul(noised_samples[0], noise_matrix.matrix_power(3))

In [78]:
other_test[0]

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0239],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [28]:
ts[1].unsqueeze(1).shape

torch.Size([32, 1])

In [30]:
noised_samples[1].shape

torch.Size([32, 59, 22])

In [38]:
torch.matmul(noised_samples, noise_matrix.t()) * torch.matmul(

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.9900, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

In [21]:
# NOTE(review): this calls a module-level `L_T`, which no cell shown here defines
# (only the CategoricalDiffusion.L_T method exists) -- presumably left over from a
# deleted cell or pulled in by a star import; confirm before relying on it.
L_T(noised_samples_reshaped, noise_matrix)

tensor(325811.2188)

In [None]:
class CategoricalDiffusion(torch.nn.Module):
    """Early draft of the categorical-diffusion loss wrapper.

    NOTE(review): an earlier cell of this notebook defines a fuller class under
    the same name; running this cell afterwards rebinds ``CategoricalDiffusion``
    to this draft.

    Args:
        denoiser: callable invoked as ``denoiser(noised_sample, ts)``.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(
        self,
        denoiser,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.denoiser = denoiser

    def denoising_process(self, noised_sample, ts):
        """Return the denoiser's prediction for ``noised_sample`` at times ``ts``."""
        y_pred = self.denoiser(noised_sample, ts)
        return y_pred

    def L_T(self, noised_sample, noise_matrix):
        """Summed KL divergence between ``noised_sample`` and the chain's stationary distribution.

        Args:
            noised_sample: tensor whose last axis indexes categories.
            noise_matrix: transition matrix whose stationary distribution is used.

        Returns:
            Scalar tensor, summed over every element of ``noised_sample``.
        """
        vals, vecs = torch.linalg.eig(noise_matrix.t())
        vals = torch.real(vals)
        vecs = torch.real(vecs)

        stationary = vecs[:, torch.argmax(vals)]
        # `eig` returns unit-L2-norm eigenvectors with arbitrary sign; dividing by
        # the sum makes this a probability distribution.
        stationary = stationary / stationary.sum()
        PxT = stationary.unsqueeze(0).unsqueeze(1)

        # Small offset keeps the logarithm finite where either side is zero.
        no_zeros_PxT = PxT + 1e-6
        no_zeros_X   = noised_sample + 1e-6

        dkl_steady_state = torch.sum(
            no_zeros_X * torch.log(no_zeros_X/no_zeros_PxT)
        )

        return dkl_steady_state

    def L_tminus1(self, real, noise, forward_step):
        """Unfinished one-step-back loss term.

        The draft body could not run: it called ``denoising_process`` without the
        time argument, used ``reverse_step`` (never assigned in this cell) and a
        notebook-global ``noise_matrix``, and returned nothing. The steps it
        sketched were:

        1. ``fake = denoiser(noise, <time>)`` and ``output = real * fake``
        2. ``q(x_{t-1} | x_t, x~) = reverse_marginal(noise_matrix, forward_step,
           reverse_step, output, noise)``
        3. ``q(x_{t-1}, x_t | x~) = (2) * noise_matrix.matrix_power(forward_step)``

        Raises:
            NotImplementedError: always, until the steps above are completed.
        """
        raise NotImplementedError(
            "L_tminus1 is an unfinished draft; see the docstring for the intended steps"
        )

    def L_t1(self, real, fake):
        """Return ``sum(real * log(fake + 1e-6))``.

        NOTE(review): this is the log-likelihood without a minus sign -- confirm
        whether callers expect the negated (cross-entropy) form.
        """
        return torch.sum(real * torch.log(fake+1e-6))
        