In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [12]:
class MLPDenoiser(torch.nn.Module):
    def __init__(
        self,
        d_time,
        d_seq,
        d_aas,
        d_model = 128,
        p = 0.1,
        activation = torch.nn.ReLU(),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.activation = activation
        self.p = p


        self.forward_time_seq = FeedForward(d_seq*d_aas + d_time, d_model)

        self.forward_time_seq_1 = FeedForward(d_model, d_model)
        self.dropout_1 = torch.nn.Dropout(self.p)
        
        self.forward_time_seq_2 = FeedForward(d_model, d_model)
        self.dropout_2 = torch.nn.Dropout(self.p)
        
        self.forward_time_seq_3 = FeedForward(d_model, d_model)
        self.dropout_3 = torch.nn.Dropout(self.p)

        self.feedforward_final = FeedForward(d_model, d_seq*d_aas)
        
    def forward(self, X, t):

        time_points, batch_size, seq_length, aas = X.shape[0], X.shape[1], X.shape[2], X.shape[3]
        
        t = t.reshape(t.shape[0] * t.shape[1], 1)
        X = X.view(time_points*batch_size, seq_length, aas)
        
        X_flattened = torch.nn.Flatten(start_dim=1)(X)
        
        seq_time_encoding = torch.concat([t, X_flattened], dim=-1)

        input_encoding = self.forward_time_seq(seq_time_encoding)

        X_1 = self.forward_time_seq_1(input_encoding)

        X_1 = self.dropout_1(X_1)

        X_2 = self.forward_time_seq_2(X_1)

        X_2 = self.dropout_2(X_2)

        X_3 = self.forward_time_seq_3(X_2)

        X_3 = self.dropout_3(X_3)

        X_final = self.feedforward_final(X_3)

        X_final = X_final.view(time_points, batch_size, seq_length, aas)
        
        Y_pred = torch.nn.Softmax(dim=-1)(X_final)
        return Y_pred

In [3]:
N_samples = 10000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=4)

100%|█████████████████████████████████████████| 100/100 [00:21<00:00,  4.56it/s]


In [4]:
index_spins = ((1+spins)/2).type(torch.int64)

In [5]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [6]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [7]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [8]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [9]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 200, X_spins.shape[-1])

In [10]:
noised_samples.shape

torch.Size([200, 16, 100, 3])

In [13]:
from tqdm import tqdm
epochs = 100

denoiser = MLPDenoiser(d_time = 1, d_seq = X_spins.shape[1], d_aas = X_spins.shape[2], d_model = 128, p=0.1, activation = torch.nn.ReLU())
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in tqdm(range(epochs)):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=128, shuffle=True)
    for batch in spins_loader:
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 200, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        cat_diff.calc_forward_conditionals(noised_samples)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)

        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)


        loss = Lt0_t1_loss + Ltminus1
        
        optim.zero_grad()
        loss.backward()

        #torch.nn.utils.clip_grad_value_(cat_diff.parameters(), 0.1)

        optim.step()
        if np.isnan(loss.item()):
            break
            
        overall_loss += loss.item()
    if np.isnan(loss.item()):
        break
    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss))

  1%|▍                                        | 1/100 [01:23<2:17:58, 83.62s/it]

overall loss at epoch 0 is 100.06969994306564


  2%|▊                                        | 2/100 [03:00<2:29:03, 91.26s/it]

overall loss at epoch 1 is 64.72410082817078


  3%|█▏                                       | 3/100 [04:37<2:31:58, 94.00s/it]

overall loss at epoch 2 is 62.15826719999313


  4%|█▋                                       | 4/100 [06:17<2:33:55, 96.20s/it]

overall loss at epoch 3 is 60.835922598838806


  5%|██                                       | 5/100 [07:59<2:35:55, 98.48s/it]

overall loss at epoch 4 is 59.70040690898895


  6%|██▍                                      | 6/100 [09:41<2:36:01, 99.59s/it]

overall loss at epoch 5 is 59.211028814315796


  7%|██▊                                     | 7/100 [11:31<2:39:50, 103.12s/it]

overall loss at epoch 6 is 58.22878170013428


  8%|███▏                                    | 8/100 [13:18<2:39:41, 104.14s/it]

overall loss at epoch 7 is 57.48782426118851


  9%|███▌                                    | 9/100 [15:10<2:41:40, 106.60s/it]

overall loss at epoch 8 is 57.26657783985138


 10%|███▉                                   | 10/100 [16:51<2:37:44, 105.16s/it]

overall loss at epoch 9 is 57.30648738145828


 11%|████▎                                  | 11/100 [18:29<2:32:36, 102.88s/it]

overall loss at epoch 10 is 57.21241170167923


 12%|████▋                                  | 12/100 [20:06<2:28:06, 100.99s/it]

overall loss at epoch 11 is 57.10530799627304


 13%|█████                                  | 13/100 [21:44<2:25:10, 100.12s/it]

overall loss at epoch 12 is 57.073893547058105


 14%|█████▍                                 | 14/100 [23:26<2:24:30, 100.82s/it]

overall loss at epoch 13 is 56.110202729701996


 15%|█████▊                                 | 15/100 [25:05<2:21:41, 100.02s/it]

overall loss at epoch 14 is 55.54197037220001


 16%|██████▏                                | 16/100 [26:46<2:20:25, 100.31s/it]

overall loss at epoch 15 is 55.242895007133484


 17%|██████▋                                | 17/100 [28:36<2:23:03, 103.41s/it]

overall loss at epoch 16 is 54.28800082206726


 18%|███████                                | 18/100 [30:16<2:19:49, 102.31s/it]

overall loss at epoch 17 is 53.56894689798355


 19%|███████▍                               | 19/100 [31:58<2:18:07, 102.32s/it]

overall loss at epoch 18 is 53.33845406770706


 20%|███████▊                               | 20/100 [33:51<2:20:45, 105.56s/it]

overall loss at epoch 19 is 53.37296104431152


 21%|████████▏                              | 21/100 [35:35<2:18:13, 104.98s/it]

overall loss at epoch 20 is 53.28182291984558


 22%|████████▌                              | 22/100 [37:13<2:13:41, 102.84s/it]

overall loss at epoch 21 is 53.26785868406296


 23%|████████▉                              | 23/100 [38:56<2:12:14, 103.05s/it]

overall loss at epoch 22 is 52.71163988113403


 24%|█████████▎                             | 24/100 [40:37<2:09:44, 102.42s/it]

overall loss at epoch 23 is 51.93733125925064


 25%|██████████                              | 25/100 [41:57<1:59:18, 95.44s/it]

overall loss at epoch 24 is 51.699801206588745


 26%|██████████▍                             | 26/100 [43:17<1:52:13, 90.99s/it]

overall loss at epoch 25 is 51.45796573162079


 27%|██████████▊                             | 27/100 [44:40<1:47:38, 88.47s/it]

overall loss at epoch 26 is 51.36398386955261


 28%|███████████▏                            | 28/100 [46:37<1:56:33, 97.13s/it]

overall loss at epoch 27 is 51.33103656768799


 29%|███████████▌                            | 29/100 [48:21<1:57:13, 99.07s/it]

overall loss at epoch 28 is 51.15240693092346


 30%|████████████                            | 30/100 [50:02<1:56:19, 99.71s/it]

overall loss at epoch 29 is 51.094566106796265


 31%|████████████                           | 31/100 [51:45<1:56:00, 100.87s/it]

overall loss at epoch 30 is 50.81892514228821


 32%|████████████▍                          | 32/100 [53:33<1:56:27, 102.76s/it]

overall loss at epoch 31 is 50.66588360071182


 33%|████████████▊                          | 33/100 [55:18<1:55:44, 103.64s/it]

overall loss at epoch 32 is 50.23755240440369


 34%|█████████████▎                         | 34/100 [56:58<1:52:45, 102.51s/it]

overall loss at epoch 33 is 49.900306046009064


 35%|█████████████▋                         | 35/100 [58:35<1:49:06, 100.71s/it]

overall loss at epoch 34 is 49.43292784690857


 36%|█████████████▎                       | 36/100 [1:00:16<1:47:36, 100.88s/it]

overall loss at epoch 35 is 49.06233096122742


 37%|██████████████                        | 37/100 [1:01:53<1:44:39, 99.68s/it]

overall loss at epoch 36 is 48.983771204948425


 38%|██████████████                       | 38/100 [1:03:39<1:45:05, 101.70s/it]

overall loss at epoch 37 is 48.90850472450256


 39%|██████████████▍                      | 39/100 [1:05:15<1:41:41, 100.02s/it]

overall loss at epoch 38 is 48.70907241106033


 40%|███████████████▏                      | 40/100 [1:06:51<1:38:40, 98.67s/it]

overall loss at epoch 39 is 48.59527415037155


 41%|███████████████▌                      | 41/100 [1:08:30<1:37:15, 98.92s/it]

overall loss at epoch 40 is 48.53797972202301


 42%|███████████████▉                      | 42/100 [1:10:08<1:35:16, 98.56s/it]

overall loss at epoch 41 is 48.35039621591568


 43%|████████████████▎                     | 43/100 [1:11:45<1:33:05, 97.99s/it]

overall loss at epoch 42 is 48.26460725069046


 44%|████████████████▎                    | 44/100 [1:13:40<1:36:12, 103.09s/it]

overall loss at epoch 43 is 48.18594175577164


 45%|████████████████▋                    | 45/100 [1:15:19<1:33:31, 102.03s/it]

overall loss at epoch 44 is 48.066030979156494


 46%|█████████████████                    | 46/100 [1:17:06<1:33:01, 103.35s/it]

overall loss at epoch 45 is 47.954599261283875


 47%|█████████████████▍                   | 47/100 [1:18:48<1:31:04, 103.10s/it]

overall loss at epoch 46 is 47.848890244960785


 48%|█████████████████▊                   | 48/100 [1:20:36<1:30:32, 104.47s/it]

overall loss at epoch 47 is 47.69820964336395


 49%|██████████████████▏                  | 49/100 [1:22:25<1:30:04, 105.98s/it]

overall loss at epoch 48 is 47.62423038482666


 50%|██████████████████▌                  | 50/100 [1:24:07<1:27:05, 104.51s/it]

overall loss at epoch 49 is 47.586633920669556


 51%|██████████████████▊                  | 51/100 [1:25:46<1:24:04, 102.95s/it]

overall loss at epoch 50 is 47.51938647031784


 52%|███████████████████▏                 | 52/100 [1:27:25<1:21:32, 101.93s/it]

overall loss at epoch 51 is 47.44450807571411


 53%|███████████████████▌                 | 53/100 [1:29:04<1:19:08, 101.03s/it]

overall loss at epoch 52 is 47.307036101818085


 54%|███████████████████▉                 | 54/100 [1:30:43<1:16:56, 100.37s/it]

overall loss at epoch 53 is 47.28451108932495


 55%|████████████████████▉                 | 55/100 [1:32:21<1:14:40, 99.56s/it]

overall loss at epoch 54 is 47.24653619527817


 56%|████████████████████▋                | 56/100 [1:34:03<1:13:29, 100.21s/it]

overall loss at epoch 55 is 47.11367845535278


 57%|█████████████████████▋                | 57/100 [1:35:42<1:11:37, 99.95s/it]

overall loss at epoch 56 is 47.07690840959549


 58%|██████████████████████                | 58/100 [1:37:13<1:08:10, 97.40s/it]

overall loss at epoch 57 is 46.9900666475296


 59%|█████████████████████▊               | 59/100 [1:39:07<1:09:48, 102.15s/it]

overall loss at epoch 58 is 46.93121302127838


 60%|██████████████████████▏              | 60/100 [1:40:52<1:08:43, 103.09s/it]

overall loss at epoch 59 is 46.861046731472015


 61%|██████████████████████▌              | 61/100 [1:42:41<1:08:08, 104.84s/it]

overall loss at epoch 60 is 46.787689208984375


 62%|██████████████████████▉              | 62/100 [1:45:01<1:13:07, 115.46s/it]

overall loss at epoch 61 is 46.72508245706558


 63%|███████████████████████▎             | 63/100 [1:46:59<1:11:45, 116.37s/it]

overall loss at epoch 62 is 46.71050316095352


 64%|███████████████████████▋             | 64/100 [1:48:45<1:07:52, 113.13s/it]

overall loss at epoch 63 is 46.61284452676773


 65%|████████████████████████             | 65/100 [1:50:28<1:04:14, 110.12s/it]

overall loss at epoch 64 is 46.561629474163055


 66%|████████████████████████▍            | 66/100 [1:52:27<1:03:49, 112.62s/it]

overall loss at epoch 65 is 46.54560285806656


 67%|████████████████████████▊            | 67/100 [1:54:19<1:01:53, 112.53s/it]

overall loss at epoch 66 is 46.52586871385574


 68%|██████████████████████████▌            | 68/100 [1:56:01<58:24, 109.53s/it]

overall loss at epoch 67 is 46.4762402176857


 69%|██████████████████████████▉            | 69/100 [1:57:48<56:09, 108.69s/it]

overall loss at epoch 68 is 46.280232667922974


 70%|███████████████████████████▎           | 70/100 [1:59:37<54:21, 108.72s/it]

overall loss at epoch 69 is 46.25557631254196


 71%|███████████████████████████▋           | 71/100 [2:01:46<55:26, 114.70s/it]

overall loss at epoch 70 is 46.24967133998871


 72%|████████████████████████████           | 72/100 [2:03:16<50:07, 107.42s/it]

overall loss at epoch 71 is 46.144320011138916


 73%|█████████████████████████████▏          | 73/100 [2:04:35<44:31, 98.94s/it]

overall loss at epoch 72 is 46.010707437992096


 74%|█████████████████████████████▌          | 74/100 [2:05:53<40:10, 92.72s/it]

overall loss at epoch 73 is 46.083362340927124


 75%|██████████████████████████████          | 75/100 [2:07:13<36:57, 88.69s/it]

overall loss at epoch 74 is 46.009502589702606


 76%|██████████████████████████████▍         | 76/100 [2:08:33<34:29, 86.25s/it]

overall loss at epoch 75 is 45.898802638053894


 77%|██████████████████████████████▊         | 77/100 [2:10:01<33:12, 86.63s/it]

overall loss at epoch 76 is 45.94507896900177


 78%|███████████████████████████████▏        | 78/100 [2:11:20<30:57, 84.43s/it]

overall loss at epoch 77 is 45.86523634195328


 79%|███████████████████████████████▌        | 79/100 [2:12:39<28:58, 82.79s/it]

overall loss at epoch 78 is 45.7953125834465


 80%|████████████████████████████████        | 80/100 [2:13:58<27:13, 81.67s/it]

overall loss at epoch 79 is 45.89622259140015


 81%|████████████████████████████████▍       | 81/100 [2:15:17<25:35, 80.79s/it]

overall loss at epoch 80 is 45.756455302238464


 82%|████████████████████████████████▊       | 82/100 [2:16:35<23:59, 79.96s/it]

overall loss at epoch 81 is 45.715791404247284


 83%|█████████████████████████████████▏      | 83/100 [2:17:53<22:27, 79.28s/it]

overall loss at epoch 82 is 45.746273934841156


 84%|█████████████████████████████████▌      | 84/100 [2:19:11<21:06, 79.16s/it]

overall loss at epoch 83 is 45.700822710990906


 85%|██████████████████████████████████      | 85/100 [2:20:31<19:49, 79.28s/it]

overall loss at epoch 84 is 45.62009572982788


 86%|██████████████████████████████████▍     | 86/100 [2:21:49<18:23, 78.85s/it]

overall loss at epoch 85 is 45.55710655450821


 87%|██████████████████████████████████▊     | 87/100 [2:23:08<17:06, 78.93s/it]

overall loss at epoch 86 is 45.47724175453186


 88%|███████████████████████████████████▏    | 88/100 [2:24:27<15:47, 78.93s/it]

overall loss at epoch 87 is 45.39179515838623


 89%|███████████████████████████████████▌    | 89/100 [2:25:45<14:26, 78.78s/it]

overall loss at epoch 88 is 45.45664024353027


 90%|████████████████████████████████████    | 90/100 [2:27:04<13:08, 78.83s/it]

overall loss at epoch 89 is 45.41257292032242


 91%|████████████████████████████████████▍   | 91/100 [2:28:23<11:49, 78.83s/it]

overall loss at epoch 90 is 45.377503991127014


 92%|████████████████████████████████████▊   | 92/100 [2:29:43<10:32, 79.07s/it]

overall loss at epoch 91 is 45.26091003417969


 93%|█████████████████████████████████████▏  | 93/100 [2:31:03<09:16, 79.45s/it]

overall loss at epoch 92 is 45.37160402536392


 94%|█████████████████████████████████████▌  | 94/100 [2:32:26<08:02, 80.37s/it]

overall loss at epoch 93 is 45.136182963848114


 95%|██████████████████████████████████████  | 95/100 [2:33:45<06:40, 80.15s/it]

overall loss at epoch 94 is 44.84328246116638


 96%|██████████████████████████████████████▍ | 96/100 [2:35:12<05:28, 82.13s/it]

overall loss at epoch 95 is 44.78454273939133


 97%|██████████████████████████████████████▊ | 97/100 [2:36:42<04:13, 84.46s/it]

overall loss at epoch 96 is 44.68179214000702


 98%|███████████████████████████████████████▏| 98/100 [2:38:01<02:45, 82.90s/it]

overall loss at epoch 97 is 44.59868311882019


 99%|███████████████████████████████████████▌| 99/100 [2:39:21<01:22, 82.09s/it]

overall loss at epoch 98 is 44.650520503520966


100%|███████████████████████████████████████| 100/100 [2:40:40<00:00, 96.41s/it]

overall loss at epoch 99 is 44.63087421655655





In [17]:
noised_samples[50][0]

tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1

In [16]:
cat_diff.y_pred[50][0]

tensor([[4.5065e-08, 8.9144e-01, 1.0856e-01],
        [4.7650e-05, 5.7511e-02, 9.4244e-01],
        [5.6188e-01, 1.3088e-06, 4.3812e-01],
        [7.6338e-01, 5.6680e-07, 2.3662e-01],
        [9.3882e-01, 1.2777e-06, 6.1175e-02],
        [5.4274e-06, 9.7300e-01, 2.6997e-02],
        [7.8631e-08, 9.2914e-01, 7.0860e-02],
        [6.3727e-01, 8.9061e-08, 3.6273e-01],
        [7.6796e-01, 1.1465e-06, 2.3204e-01],
        [1.8923e-08, 9.7842e-01, 2.1581e-02],
        [5.6258e-01, 3.1737e-05, 4.3739e-01],
        [8.6632e-06, 4.8849e-01, 5.1150e-01],
        [8.2178e-01, 1.8486e-05, 1.7821e-01],
        [5.0460e-07, 6.9131e-01, 3.0869e-01],
        [8.7242e-01, 6.6569e-05, 1.2751e-01],
        [7.7559e-01, 3.3692e-07, 2.2441e-01],
        [3.1820e-06, 9.1739e-01, 8.2606e-02],
        [2.1653e-07, 8.8438e-01, 1.1562e-01],
        [8.1607e-08, 7.6258e-01, 2.3742e-01],
        [7.5530e-01, 2.0627e-05, 2.4468e-01],
        [2.5186e-07, 6.8006e-01, 3.1993e-01],
        [6.0897e-01, 4.5614e-08, 3

In [208]:
denoised = cat_diff.y_pred[2:]

px_tminus1_giv_xt = q_xtminus1_xt_giv_x0 * denoised

In [209]:
px_tminus1_giv_xt

tensor([[[[0.0000e+00, 3.1462e-01, 0.0000e+00],
          [3.5677e-01, 0.0000e+00, 0.0000e+00],
          [3.3245e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.2670e-01, 0.0000e+00],
          [0.0000e+00, 3.0302e-01, 0.0000e+00],
          [0.0000e+00, 2.9244e-01, 0.0000e+00]],

         [[0.0000e+00, 3.3716e-01, 0.0000e+00],
          [3.2219e-01, 0.0000e+00, 0.0000e+00],
          [2.9704e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.2619e-01, 0.0000e+00],
          [3.6161e-01, 0.0000e+00, 0.0000e+00],
          [2.8870e-01, 0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 3.0515e-01, 0.0000e+00],
          [0.0000e+00, 3.1967e-01, 0.0000e+00],
          [3.1448e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [2.8615e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 3.0698e-01, 0.0000e+00],
          [3.2339e-01, 0.0000e+00, 0.0000e+00]],

         ...,

         [[0.0000e+00, 1.5856e-01, 3.2519e-03],
          [3.3326e-01,

In [212]:
torch.mean(
            torch.sum(
                cat_diff.forward_conditionals[2:] * (
                    stuff * torch.log(
                        (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
                    )
                ),
                axis=-1
            )
        )

tensor(nan, grad_fn=<MeanBackward0>)

In [226]:
np.where(np.isnan(cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ).detach().numpy()))

(array([94, 94, 94]),
 array([24, 24, 24]),
 array([15, 15, 15]),
 array([0, 1, 2]))

In [234]:
stuff[94,24,15,2]

tensor(nan)

In [230]:
cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
)[94,24,15,0]
    

tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         ...,

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
         

In [215]:
np.isnan(torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
).detach().numpy()).any()

True

In [216]:
np.where(np.isnan(torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
).detach().numpy()))

(array([94]), array([24]), array([15]))

In [220]:
torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
)[94][24]

tensor([1.0332, 0.6706, 0.9716, 0.9757, 0.9757, 0.9757, 1.0741, 0.7729, 1.1290,
        0.9757, 0.9757, 1.4398, 0.9554, 0.7639, 0.8608,    nan, 0.7205, 0.7551,
        1.3197, 1.0787, 1.0720, 1.4066, 0.6798, 1.0329, 1.2583, 1.5059, 1.4623,
        0.8676, 0.9757, 0.8594, 1.1232, 0.9757, 0.9757, 0.9757, 1.1720, 1.2588,
        0.7667, 0.7437, 1.1948, 0.9847, 0.7845, 0.5300, 0.6469, 1.4376, 1.0879,
        0.9078, 0.7355, 1.7366, 1.0102, 1.1947, 0.6780, 1.0162, 1.2185, 1.2742,
        1.0918, 0.9757, 1.2834, 0.9914, 0.9757, 1.0105, 1.1534, 1.3210, 1.2183,
        0.6016, 1.0976, 0.8873, 1.2643, 1.6309, 0.6532, 1.2003, 1.3970, 1.0385,
        0.8668, 0.9967, 1.5562, 0.8305, 1.2290, 0.5932, 0.9200, 1.2263, 1.3277,
        0.9840, 1.3741, 0.9757, 0.9794, 1.1538, 0.9106, 1.0421, 1.3159, 0.8671,
        0.7996, 1.5560, 1.2296, 1.1848, 1.4853, 0.9757, 1.4080, 0.9757, 1.1681,
        0.9757], grad_fn=<SelectBackward0>)

In [222]:
cat_diff.forward_conditionals[2:].shape

torch.Size([98, 128, 100, 3])

In [180]:
Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)

In [181]:
Lt0_t1_loss

tensor(1.1005, grad_fn=<NegBackward0>)

In [182]:
Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)

In [91]:
cat_diff.eval()

t = ts.reshape(ts.shape[0] * ts.shape[1], 1)

time_encoding = cat_diff.denoiser.forward_time_1(t)
cat_diff.denoiser.forward_time_2(time_encoding).shape

torch.Size([12800, 128])

In [93]:
print(time_encoding[:,0])

tensor([ 0.0000,  0.0000,  0.0000,  ..., 37.5615, 37.5615, 37.5615],
       grad_fn=<SelectBackward0>)


In [183]:
Ltminus1

tensor(nan, grad_fn=<MeanBackward0>)

In [201]:
noised_samples[1][3][:,2]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 1., 0.]],

        ...,

        [[0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]]])