In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [2]:
class MLPDenoiser(torch.nn.Module):
    def __init__(
        self,
        d_time,
        d_seq,
        d_aas,
        d_model = 128,
        p = 0.1,
        activation = torch.nn.ReLU(),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.activation = activation
        self.p = p


        self.forward_time_seq = FeedForward(d_seq*d_aas + d_time, d_model)

        self.forward_time_seq_1 = FeedForward(d_model, d_model)
        self.dropout_1 = torch.nn.Dropout(self.p)
        
        self.forward_time_seq_2 = FeedForward(d_model, d_model)
        self.dropout_2 = torch.nn.Dropout(self.p)
        
        self.forward_time_seq_3 = FeedForward(d_model, d_model)
        self.dropout_3 = torch.nn.Dropout(self.p)

        self.feedforward_final = FeedForward(d_model, d_seq*d_aas)
        
    def forward(self, X, t):

        time_points, batch_size, seq_length, aas = X.shape[0], X.shape[1], X.shape[2], X.shape[3]
        
        t = t.reshape(t.shape[0] * t.shape[1], 1)
        X = X.view(time_points*batch_size, seq_length, aas)
        
        X_flattened = torch.nn.Flatten(start_dim=1)(X)
        
        seq_time_encoding = torch.concat([t, X_flattened], dim=-1)

        input_encoding = self.forward_time_seq(seq_time_encoding)

        X_1 = self.forward_time_seq_1(input_encoding)
        X_1 = self.activation(X_1)
        X_1 = self.dropout_1(X_1)

        X_2 = self.forward_time_seq_2(X_1)
        X_2 = self.activation(X_2)
        X_2 = self.dropout_2(X_2)

        X_3 = self.forward_time_seq_3(X_2)
        X_3 = self.activation(X_3)
        X_3 = self.dropout_3(X_3)

        X_final = self.feedforward_final(X_3)

        X_final = X_final.view(time_points, batch_size, seq_length, aas)
        
        Y_pred = torch.nn.Softmax(dim=-1)(X_final)
        return Y_pred

In [3]:
N_samples = 10000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=4)

100%|█████████████████████████████████████████| 100/100 [00:21<00:00,  4.56it/s]


In [4]:
index_spins = ((1+spins)/2).type(torch.int64)

In [5]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [6]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [7]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [8]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [9]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 200, X_spins.shape[-1])

In [10]:
noised_samples.shape

torch.Size([200, 16, 100, 3])

In [None]:
from tqdm import tqdm
epochs = 100

denoiser = MLPDenoiser(d_time = 1, d_seq = X_spins.shape[1], d_aas = X_spins.shape[2], d_model = 1024, p=0.9, activation = torch.nn.Tanh())
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in tqdm(range(epochs)):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=128, shuffle=True)
    for batch in spins_loader:
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 200, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        cat_diff.calc_forward_conditionals(noised_samples)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)

        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)


        loss = Lt0_t1_loss + Ltminus1
        
        optim.zero_grad()
        loss.backward()

        #torch.nn.utils.clip_grad_value_(cat_diff.parameters(), 0.1)

        optim.step()
        if np.isnan(loss.item()):
            break
            
        overall_loss += loss.item()
    if np.isnan(loss.item()):
        break
    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss))

  0%|                                                   | 0/100 [00:00<?, ?it/s]

In [None]:
noised_samples[99][0]

In [None]:
cat_diff.y_pred[99][0]

In [208]:
denoised = cat_diff.y_pred[2:]

px_tminus1_giv_xt = q_xtminus1_xt_giv_x0 * denoised

In [209]:
px_tminus1_giv_xt

tensor([[[[0.0000e+00, 3.1462e-01, 0.0000e+00],
          [3.5677e-01, 0.0000e+00, 0.0000e+00],
          [3.3245e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.2670e-01, 0.0000e+00],
          [0.0000e+00, 3.0302e-01, 0.0000e+00],
          [0.0000e+00, 2.9244e-01, 0.0000e+00]],

         [[0.0000e+00, 3.3716e-01, 0.0000e+00],
          [3.2219e-01, 0.0000e+00, 0.0000e+00],
          [2.9704e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.2619e-01, 0.0000e+00],
          [3.6161e-01, 0.0000e+00, 0.0000e+00],
          [2.8870e-01, 0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 3.0515e-01, 0.0000e+00],
          [0.0000e+00, 3.1967e-01, 0.0000e+00],
          [3.1448e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [2.8615e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 3.0698e-01, 0.0000e+00],
          [3.2339e-01, 0.0000e+00, 0.0000e+00]],

         ...,

         [[0.0000e+00, 1.5856e-01, 3.2519e-03],
          [3.3326e-01,

In [212]:
torch.mean(
            torch.sum(
                cat_diff.forward_conditionals[2:] * (
                    stuff * torch.log(
                        (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
                    )
                ),
                axis=-1
            )
        )

tensor(nan, grad_fn=<MeanBackward0>)

In [226]:
np.where(np.isnan(cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ).detach().numpy()))

(array([94, 94, 94]),
 array([24, 24, 24]),
 array([15, 15, 15]),
 array([0, 1, 2]))

In [234]:
stuff[94,24,15,2]

tensor(nan)

In [230]:
cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
)[94,24,15,0]
    

tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         ...,

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
         

In [215]:
np.isnan(torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
).detach().numpy()).any()

True

In [216]:
np.where(np.isnan(torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
).detach().numpy()))

(array([94]), array([24]), array([15]))

In [220]:
torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
)[94][24]

tensor([1.0332, 0.6706, 0.9716, 0.9757, 0.9757, 0.9757, 1.0741, 0.7729, 1.1290,
        0.9757, 0.9757, 1.4398, 0.9554, 0.7639, 0.8608,    nan, 0.7205, 0.7551,
        1.3197, 1.0787, 1.0720, 1.4066, 0.6798, 1.0329, 1.2583, 1.5059, 1.4623,
        0.8676, 0.9757, 0.8594, 1.1232, 0.9757, 0.9757, 0.9757, 1.1720, 1.2588,
        0.7667, 0.7437, 1.1948, 0.9847, 0.7845, 0.5300, 0.6469, 1.4376, 1.0879,
        0.9078, 0.7355, 1.7366, 1.0102, 1.1947, 0.6780, 1.0162, 1.2185, 1.2742,
        1.0918, 0.9757, 1.2834, 0.9914, 0.9757, 1.0105, 1.1534, 1.3210, 1.2183,
        0.6016, 1.0976, 0.8873, 1.2643, 1.6309, 0.6532, 1.2003, 1.3970, 1.0385,
        0.8668, 0.9967, 1.5562, 0.8305, 1.2290, 0.5932, 0.9200, 1.2263, 1.3277,
        0.9840, 1.3741, 0.9757, 0.9794, 1.1538, 0.9106, 1.0421, 1.3159, 0.8671,
        0.7996, 1.5560, 1.2296, 1.1848, 1.4853, 0.9757, 1.4080, 0.9757, 1.1681,
        0.9757], grad_fn=<SelectBackward0>)

In [222]:
cat_diff.forward_conditionals[2:].shape

torch.Size([98, 128, 100, 3])

In [180]:
Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)

In [181]:
Lt0_t1_loss

tensor(1.1005, grad_fn=<NegBackward0>)

In [182]:
Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)

In [91]:
cat_diff.eval()

t = ts.reshape(ts.shape[0] * ts.shape[1], 1)

time_encoding = cat_diff.denoiser.forward_time_1(t)
cat_diff.denoiser.forward_time_2(time_encoding).shape

torch.Size([12800, 128])

In [93]:
print(time_encoding[:,0])

tensor([ 0.0000,  0.0000,  0.0000,  ..., 37.5615, 37.5615, 37.5615],
       grad_fn=<SelectBackward0>)


In [183]:
Ltminus1

tensor(nan, grad_fn=<MeanBackward0>)

In [201]:
noised_samples[1][3][:,2]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 1., 0.]],

        ...,

        [[0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]]])