In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [2]:
class MLPDenoiser(torch.nn.Module):
    def __init__(
        self,
        d_time,
        d_time_hidden,
        d_seq,
        d_aas,
        d_hidden = 128,
        d_model = 128,
        p = 0.1,
        activation = torch.nn.ReLU(),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.activation = activation
        self.p = p


        self.forward_time_seq = FeedForward(d_seq*d_aas + d_time, d_model)

        self.forward_time_seq_1 = FeedForward(d_model, d_model)
        self.dropout_1 = torch.nn.Dropout(self.p)
        
        self.forward_time_seq_2 = FeedForward(d_model, d_model)
        self.dropout_2 = torch.nn.Dropout(self.p)
        
        self.forward_time_seq_3 = FeedForward(d_model, d_model)
        self.dropout_3 = torch.nn.Dropout(self.p)

        self.feedforward_final = FeedForward(d_model, d_seq*d_aas)
        
    def forward(self, X, t):

        time_points, batch_size, seq_length, aas = X.shape[0], X.shape[1], X.shape[2], X.shape[3]
        
        t = t.reshape(t.shape[0] * t.shape[1], 1)
        X = X.view(time_points*batch_size, seq_length, aas)
        
        X_flattened = torch.nn.Flatten(start_dim=1)(X)
        
        seq_time_encoding = torch.concat([t, X_flattened], dim=-1)

        input_encoding = self.forward_time_seq(seq_time_encoding)

        X_1 = self.forward_time_seq_1(input_encoding)
        X_1 = self.activation(X_1)
        X_1 = self.dropout_1(X_1)

        X_2 = self.forward_time_seq_2(X_1)
        X_2 = self.activation(X_2)
        X_2 = self.dropout_2(X_2)

        X_3 = self.forward_time_seq_3(X_2)
        X_3 = self.activation(X_3)
        X_3 = self.dropout_3(X_3)

        X_final = self.feedforward_final(X_3)

        X_final = X_final.view(time_points, batch_size, seq_length, aas)
        
        Y_pred = torch.nn.Softmax(dim=-1)(X_final)
        return Y_pred

In [3]:
N_samples = 10000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=2)

100%|█████████████████████████████████████████| 100/100 [00:19<00:00,  5.15it/s]


In [4]:
index_spins = ((1+spins)/2).type(torch.int64)

In [5]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [6]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [7]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [8]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [9]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

In [10]:
noised_samples.shape

torch.Size([100, 16, 100, 3])

In [11]:
denoiser = MLPDenoiser(1,128,X_spins.shape[1],X_spins.shape[2])

In [12]:
Y_pred = denoiser(noised_samples, ts)

In [13]:
Y_pred.shape

torch.Size([100, 16, 100, 3])

In [14]:
noised_samples.shape

torch.Size([100, 16, 100, 3])

In [15]:
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

In [16]:
cat_diff.denoiser(noised_samples,ts)

tensor([[[[0.3370, 0.3315, 0.3315],
          [0.3478, 0.3261, 0.3261],
          [0.3243, 0.3513, 0.3243],
          ...,
          [0.3267, 0.3267, 0.3466],
          [0.3190, 0.3315, 0.3495],
          [0.3420, 0.3355, 0.3226]],

         [[0.3403, 0.3299, 0.3299],
          [0.3494, 0.3253, 0.3253],
          [0.3217, 0.3539, 0.3244],
          ...,
          [0.3279, 0.3279, 0.3441],
          [0.3178, 0.3367, 0.3455],
          [0.3429, 0.3342, 0.3229]],

         [[0.3376, 0.3312, 0.3312],
          [0.3412, 0.3294, 0.3294],
          [0.3239, 0.3523, 0.3239],
          ...,
          [0.3274, 0.3274, 0.3452],
          [0.3275, 0.3239, 0.3486],
          [0.3471, 0.3301, 0.3228]],

         ...,

         [[0.3350, 0.3325, 0.3325],
          [0.3467, 0.3267, 0.3267],
          [0.3216, 0.3490, 0.3294],
          ...,
          [0.3270, 0.3270, 0.3460],
          [0.3281, 0.3317, 0.3402],
          [0.3436, 0.3315, 0.3249]],

         [[0.3414, 0.3293, 0.3293],
          [0.3444

In [17]:
cat_diff.decode(noised_samples, ts)

In [18]:
cat_diff.calc_forward_conditionals(noised_samples)

In [19]:
cat_diff.L_T(noised_samples)

tensor(1.7237)

In [20]:
cat_diff.L_t0t1(Y_spins)

tensor(1.0975, grad_fn=<NegBackward0>)

In [21]:
cat_diff.L_tminus1(Y_spins, noised_samples)

tensor(0.8218, grad_fn=<MeanBackward0>)

In [None]:
from tqdm import tqdm
epochs = 100

denoiser = MLPDenoiser(1,128,X_spins.shape[1],X_spins.shape[2], p=0.9, activation = torch.nn.Tanh())
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in tqdm(range(epochs)):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=128, shuffle=True)
    for batch in spins_loader:
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        cat_diff.calc_forward_conditionals(noised_samples)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)


        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)


        loss = Lt0_t1_loss + Ltminus1
        
        optim.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_value_(cat_diff.parameters(), 0.1)

        optim.step()
        #print(loss.item())
        overall_loss += loss.item()

    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss))

  1%|▍                                        | 1/100 [00:45<1:15:17, 45.63s/it]

overall loss at epoch 0 is 135.84732270240784


  2%|▊                                        | 2/100 [01:30<1:13:32, 45.03s/it]

overall loss at epoch 1 is 121.05670738220215


  3%|█▏                                       | 3/100 [02:15<1:13:07, 45.24s/it]

overall loss at epoch 2 is 111.6009372472763


  4%|█▋                                       | 4/100 [03:00<1:11:47, 44.87s/it]

overall loss at epoch 3 is 105.38810896873474


  5%|██                                       | 5/100 [03:43<1:10:22, 44.45s/it]

overall loss at epoch 4 is 103.57412123680115


  6%|██▍                                      | 6/100 [04:27<1:09:31, 44.38s/it]

overall loss at epoch 5 is 102.41586709022522


  7%|██▊                                      | 7/100 [05:15<1:10:11, 45.29s/it]

overall loss at epoch 6 is 101.82431697845459


  8%|███▎                                     | 8/100 [06:12<1:15:13, 49.06s/it]

overall loss at epoch 7 is 101.42035746574402


  9%|███▋                                     | 9/100 [07:12<1:19:31, 52.43s/it]

overall loss at epoch 8 is 100.73042058944702


 10%|████                                    | 10/100 [08:13<1:22:36, 55.07s/it]

overall loss at epoch 9 is 100.37978506088257


 11%|████▍                                   | 11/100 [09:09<1:22:24, 55.55s/it]

overall loss at epoch 10 is 99.91097223758698


 12%|████▊                                   | 12/100 [10:05<1:21:24, 55.51s/it]

overall loss at epoch 11 is 99.74670684337616


 13%|█████▏                                  | 13/100 [11:02<1:21:18, 56.07s/it]

overall loss at epoch 12 is 99.48744630813599


 14%|█████▌                                  | 14/100 [11:59<1:20:42, 56.30s/it]

overall loss at epoch 13 is 99.3787248134613


 15%|██████                                  | 15/100 [12:55<1:19:53, 56.39s/it]

overall loss at epoch 14 is 99.16166853904724


 16%|██████▍                                 | 16/100 [13:52<1:19:10, 56.56s/it]

overall loss at epoch 15 is 98.85555231571198


 17%|██████▊                                 | 17/100 [14:49<1:18:10, 56.51s/it]

overall loss at epoch 16 is 98.83682644367218


 18%|███████▏                                | 18/100 [15:45<1:17:07, 56.43s/it]

overall loss at epoch 17 is 98.70484399795532


 19%|███████▌                                | 19/100 [16:40<1:15:28, 55.90s/it]

overall loss at epoch 18 is 98.6557366847992


 20%|████████                                | 20/100 [17:36<1:14:42, 56.03s/it]

overall loss at epoch 19 is 98.4665961265564


 21%|████████▍                               | 21/100 [18:34<1:14:35, 56.65s/it]

overall loss at epoch 20 is 98.3373372554779


 22%|████████▊                               | 22/100 [19:32<1:13:58, 56.90s/it]

overall loss at epoch 21 is 98.27106475830078


 23%|█████████▏                              | 23/100 [20:30<1:13:25, 57.22s/it]

overall loss at epoch 22 is 98.34451687335968


 24%|█████████▌                              | 24/100 [21:26<1:11:59, 56.83s/it]

overall loss at epoch 23 is 98.09732925891876


 25%|██████████                              | 25/100 [22:19<1:09:43, 55.79s/it]

overall loss at epoch 24 is 98.00088608264923


 26%|██████████▍                             | 26/100 [23:14<1:08:27, 55.51s/it]

overall loss at epoch 25 is 97.8824177980423


 27%|██████████▊                             | 27/100 [24:10<1:07:55, 55.83s/it]

overall loss at epoch 26 is 97.8516857624054


 28%|███████████▏                            | 28/100 [25:06<1:07:06, 55.93s/it]

overall loss at epoch 27 is 97.75183618068695


 29%|███████████▌                            | 29/100 [26:02<1:05:52, 55.67s/it]

overall loss at epoch 28 is 97.62625348567963


 30%|████████████                            | 30/100 [26:57<1:04:48, 55.56s/it]

overall loss at epoch 29 is 97.49396598339081


 31%|████████████▍                           | 31/100 [27:50<1:03:13, 54.98s/it]

overall loss at epoch 30 is 97.53072082996368


 32%|████████████▊                           | 32/100 [28:45<1:02:03, 54.76s/it]

overall loss at epoch 31 is 97.50134766101837


 33%|█████████████▏                          | 33/100 [29:38<1:00:49, 54.46s/it]

overall loss at epoch 32 is 97.42038226127625


 34%|█████████████▌                          | 34/100 [30:34<1:00:24, 54.92s/it]

overall loss at epoch 33 is 97.31864750385284


 35%|████████████▌                       | 35/100 [1:01:50<10:51:09, 601.06s/it]

overall loss at epoch 34 is 97.31011879444122


 36%|████████████▉                       | 36/100 [1:27:40<15:44:52, 885.81s/it]

overall loss at epoch 35 is 97.20506167411804


 37%|████████████▉                      | 37/100 [1:58:21<20:30:57, 1172.34s/it]

overall loss at epoch 36 is 97.08378505706787


 38%|█████████████▋                      | 38/100 [2:00:17<14:43:55, 855.41s/it]

overall loss at epoch 37 is 97.07422626018524


 39%|██████████████                      | 39/100 [2:18:39<15:44:48, 929.32s/it]

overall loss at epoch 38 is 96.98706912994385


 40%|██████████████                     | 40/100 [2:38:15<16:43:27, 1003.46s/it]

overall loss at epoch 39 is 97.00689280033112


 41%|██████████████▊                     | 41/100 [2:39:40<11:55:52, 728.01s/it]

overall loss at epoch 40 is 96.92144405841827


 42%|███████████████▌                     | 42/100 [2:40:53<8:33:41, 531.41s/it]

overall loss at epoch 41 is 96.88095653057098


 43%|███████████████▉                     | 43/100 [2:42:13<6:16:15, 396.05s/it]

overall loss at epoch 42 is 96.84663331508636


 44%|████████████████▎                    | 44/100 [3:00:04<9:18:32, 598.43s/it]

overall loss at epoch 43 is 96.79241728782654


 45%|████████████████▋                    | 45/100 [3:01:25<6:46:14, 443.18s/it]

overall loss at epoch 44 is 96.72847485542297


 46%|█████████████████                    | 46/100 [3:02:42<5:00:07, 333.48s/it]

overall loss at epoch 45 is 96.65947902202606


 47%|█████████████████▍                   | 47/100 [3:04:04<3:47:50, 257.93s/it]

overall loss at epoch 46 is 96.65366506576538


 48%|█████████████████▊                   | 48/100 [3:05:31<2:59:09, 206.73s/it]

overall loss at epoch 47 is 96.59584629535675


 49%|██████████████████▏                  | 49/100 [3:06:52<2:23:32, 168.87s/it]

overall loss at epoch 48 is 96.59914898872375


 50%|██████████████████▌                  | 50/100 [3:08:18<2:00:06, 144.13s/it]

overall loss at epoch 49 is 96.57844698429108


 51%|██████████████████▊                  | 51/100 [3:09:45<1:43:32, 126.79s/it]

overall loss at epoch 50 is 96.50876569747925


 52%|███████████████████▏                 | 52/100 [3:11:12<1:31:56, 114.93s/it]

overall loss at epoch 51 is 96.4804699420929


 53%|███████████████████▌                 | 53/100 [3:12:37<1:23:06, 106.09s/it]

overall loss at epoch 52 is 96.41776490211487


 54%|███████████████████▉                 | 54/100 [3:14:04<1:16:57, 100.37s/it]

overall loss at epoch 53 is 96.35691034793854


 55%|████████████████████▉                 | 55/100 [3:15:36<1:13:19, 97.76s/it]

overall loss at epoch 54 is 96.35047352313995


 56%|█████████████████████▎                | 56/100 [3:17:07<1:10:05, 95.59s/it]

overall loss at epoch 55 is 96.31885266304016


 57%|█████████████████████▋                | 57/100 [3:18:40<1:08:00, 94.89s/it]

overall loss at epoch 56 is 96.33619499206543


 58%|██████████████████████                | 58/100 [3:20:10<1:05:31, 93.61s/it]

overall loss at epoch 57 is 96.31549298763275


 59%|██████████████████████▍               | 59/100 [3:21:41<1:03:21, 92.72s/it]

overall loss at epoch 58 is 96.28060913085938


 60%|██████████████████████▊               | 60/100 [3:23:09<1:00:52, 91.32s/it]

overall loss at epoch 59 is 92.02845573425293


 61%|████████████████████████▍               | 61/100 [3:24:38<58:51, 90.55s/it]

overall loss at epoch 60 is 86.50207257270813


 62%|████████████████████████▊               | 62/100 [3:26:12<58:03, 91.67s/it]

overall loss at epoch 61 is 83.93247842788696


 63%|█████████████████████████▏              | 63/100 [3:27:41<55:59, 90.79s/it]

overall loss at epoch 62 is 83.21084862947464


 64%|█████████████████████████▌              | 64/100 [3:29:07<53:39, 89.43s/it]

overall loss at epoch 63 is 82.41088318824768


 65%|██████████████████████████              | 65/100 [3:30:32<51:25, 88.16s/it]

overall loss at epoch 64 is 82.29489922523499


 66%|██████████████████████████▍             | 66/100 [3:31:21<43:14, 76.31s/it]

overall loss at epoch 65 is 81.89134383201599


 67%|█████████████████████▍          | 67/100 [13:46:16<101:57:02, 11121.91s/it]

overall loss at epoch 66 is 81.2218416929245


 68%|███████████████████████           | 68/100 [14:26:10<75:35:09, 8503.43s/it]

overall loss at epoch 67 is 80.85507422685623


 69%|███████████████████████▍          | 69/100 [15:19:47<59:34:07, 6917.67s/it]

overall loss at epoch 68 is 80.58616369962692


 70%|███████████████████████▊          | 70/100 [15:20:33<40:28:03, 4856.11s/it]

overall loss at epoch 69 is 80.15747720003128


 71%|████████████████████████▏         | 71/100 [15:28:55<28:35:44, 3549.83s/it]

overall loss at epoch 70 is 79.81055492162704


 72%|████████████████████████▍         | 72/100 [15:30:35<19:33:37, 2514.91s/it]

overall loss at epoch 71 is 79.28730779886246


 73%|████████████████████████▊         | 73/100 [15:31:23<13:18:39, 1774.80s/it]

overall loss at epoch 72 is 79.12084889411926


 74%|█████████████████████████▉         | 74/100 [15:32:09<9:04:23, 1256.27s/it]

overall loss at epoch 73 is 78.83034461736679


 75%|███████████████████████████         | 75/100 [15:32:55<6:12:09, 893.16s/it]

overall loss at epoch 74 is 78.51904970407486


 76%|███████████████████████████▎        | 76/100 [15:33:37<4:15:09, 637.91s/it]

overall loss at epoch 75 is 78.29425460100174


 77%|███████████████████████████▋        | 77/100 [15:34:21<2:56:08, 459.52s/it]

overall loss at epoch 76 is 78.08866488933563


 78%|████████████████████████████        | 78/100 [15:35:06<2:02:55, 335.26s/it]

overall loss at epoch 77 is 78.01650094985962


 79%|████████████████████████████▍       | 79/100 [15:35:52<1:27:00, 248.59s/it]

overall loss at epoch 78 is 77.84435653686523


 80%|████████████████████████████▊       | 80/100 [15:36:35<1:02:19, 186.95s/it]

overall loss at epoch 79 is 77.63565111160278


 81%|██████████████████████████████▊       | 81/100 [15:37:20<45:38, 144.15s/it]

overall loss at epoch 80 is 77.67585772275925


 82%|███████████████████████████████▏      | 82/100 [15:38:04<34:17, 114.29s/it]

overall loss at epoch 81 is 77.34525293111801


 83%|████████████████████████████████▎      | 83/100 [15:38:49<26:29, 93.50s/it]

overall loss at epoch 82 is 77.02098935842514


 84%|████████████████████████████████▊      | 84/100 [15:39:33<20:59, 78.69s/it]

overall loss at epoch 83 is 77.31703680753708


 85%|█████████████████████████████████▏     | 85/100 [15:40:19<17:08, 68.60s/it]

overall loss at epoch 84 is 77.15096467733383


 86%|█████████████████████████████████▌     | 86/100 [15:41:03<14:17, 61.27s/it]

overall loss at epoch 85 is 77.17066425085068


 87%|█████████████████████████████████▉     | 87/100 [15:41:47<12:10, 56.21s/it]

overall loss at epoch 86 is 76.77328258752823


 88%|██████████████████████████████████▎    | 88/100 [15:42:33<10:36, 53.06s/it]

overall loss at epoch 87 is 76.7115079164505


 89%|██████████████████████████████████▋    | 89/100 [15:43:16<09:11, 50.17s/it]

overall loss at epoch 88 is 76.94848650693893


 90%|███████████████████████████████████    | 90/100 [15:44:00<08:02, 48.25s/it]

overall loss at epoch 89 is 76.72017842531204


 91%|███████████████████████████████████▍   | 91/100 [15:44:43<07:00, 46.72s/it]

overall loss at epoch 90 is 76.74418246746063


 92%|███████████████████████████████████▉   | 92/100 [15:45:27<06:05, 45.74s/it]

overall loss at epoch 91 is 76.71150887012482


 93%|████████████████████████████████████▎  | 93/100 [15:46:12<05:18, 45.53s/it]

overall loss at epoch 92 is 76.60798090696335


 94%|████████████████████████████████████▋  | 94/100 [15:46:56<04:31, 45.25s/it]

overall loss at epoch 93 is 76.48578530550003


In [204]:
stuff = cat_diff.one_step_reverse_conditional(Y_spins, noised_samples)

In [206]:
q_xtminus1_xt_giv_x0 = cat_diff.q_xtminus1_xt_giv_x0(noised_samples, stuff)


In [208]:
denoised = cat_diff.y_pred[2:]

px_tminus1_giv_xt = q_xtminus1_xt_giv_x0 * denoised

In [209]:
px_tminus1_giv_xt

tensor([[[[0.0000e+00, 3.1462e-01, 0.0000e+00],
          [3.5677e-01, 0.0000e+00, 0.0000e+00],
          [3.3245e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.2670e-01, 0.0000e+00],
          [0.0000e+00, 3.0302e-01, 0.0000e+00],
          [0.0000e+00, 2.9244e-01, 0.0000e+00]],

         [[0.0000e+00, 3.3716e-01, 0.0000e+00],
          [3.2219e-01, 0.0000e+00, 0.0000e+00],
          [2.9704e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.2619e-01, 0.0000e+00],
          [3.6161e-01, 0.0000e+00, 0.0000e+00],
          [2.8870e-01, 0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 3.0515e-01, 0.0000e+00],
          [0.0000e+00, 3.1967e-01, 0.0000e+00],
          [3.1448e-01, 0.0000e+00, 0.0000e+00],
          ...,
          [2.8615e-01, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 3.0698e-01, 0.0000e+00],
          [3.2339e-01, 0.0000e+00, 0.0000e+00]],

         ...,

         [[0.0000e+00, 1.5856e-01, 3.2519e-03],
          [3.3326e-01,

In [212]:
torch.mean(
            torch.sum(
                cat_diff.forward_conditionals[2:] * (
                    stuff * torch.log(
                        (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
                    )
                ),
                axis=-1
            )
        )

tensor(nan, grad_fn=<MeanBackward0>)

In [226]:
np.where(np.isnan(cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ).detach().numpy()))

(array([94, 94, 94]),
 array([24, 24, 24]),
 array([15, 15, 15]),
 array([0, 1, 2]))

In [234]:
stuff[94,24,15,2]

tensor(nan)

In [230]:
cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
)[94,24,15,0]
    

tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         ...,

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
          ...,
          [nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan],
         

In [215]:
np.isnan(torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
).detach().numpy()).any()

True

In [216]:
np.where(np.isnan(torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
).detach().numpy()))

(array([94]), array([24]), array([15]))

In [220]:
torch.sum(
    cat_diff.forward_conditionals[2:] * (
        stuff * torch.log(
            (stuff + 1e-6)/(px_tminus1_giv_xt+1e-6)
        )
    ),
    axis=-1
)[94][24]

tensor([1.0332, 0.6706, 0.9716, 0.9757, 0.9757, 0.9757, 1.0741, 0.7729, 1.1290,
        0.9757, 0.9757, 1.4398, 0.9554, 0.7639, 0.8608,    nan, 0.7205, 0.7551,
        1.3197, 1.0787, 1.0720, 1.4066, 0.6798, 1.0329, 1.2583, 1.5059, 1.4623,
        0.8676, 0.9757, 0.8594, 1.1232, 0.9757, 0.9757, 0.9757, 1.1720, 1.2588,
        0.7667, 0.7437, 1.1948, 0.9847, 0.7845, 0.5300, 0.6469, 1.4376, 1.0879,
        0.9078, 0.7355, 1.7366, 1.0102, 1.1947, 0.6780, 1.0162, 1.2185, 1.2742,
        1.0918, 0.9757, 1.2834, 0.9914, 0.9757, 1.0105, 1.1534, 1.3210, 1.2183,
        0.6016, 1.0976, 0.8873, 1.2643, 1.6309, 0.6532, 1.2003, 1.3970, 1.0385,
        0.8668, 0.9967, 1.5562, 0.8305, 1.2290, 0.5932, 0.9200, 1.2263, 1.3277,
        0.9840, 1.3741, 0.9757, 0.9794, 1.1538, 0.9106, 1.0421, 1.3159, 0.8671,
        0.7996, 1.5560, 1.2296, 1.1848, 1.4853, 0.9757, 1.4080, 0.9757, 1.1681,
        0.9757], grad_fn=<SelectBackward0>)

In [222]:
cat_diff.forward_conditionals[2:].shape

torch.Size([98, 128, 100, 3])

In [180]:
Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)

In [181]:
Lt0_t1_loss

tensor(1.1005, grad_fn=<NegBackward0>)

In [182]:
Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)

In [91]:
cat_diff.eval()

t = ts.reshape(ts.shape[0] * ts.shape[1], 1)

time_encoding = cat_diff.denoiser.forward_time_1(t)
cat_diff.denoiser.forward_time_2(time_encoding).shape

torch.Size([12800, 128])

In [93]:
print(time_encoding[:,0])

tensor([ 0.0000,  0.0000,  0.0000,  ..., 37.5615, 37.5615, 37.5615],
       grad_fn=<SelectBackward0>)


In [183]:
Ltminus1

tensor(nan, grad_fn=<MeanBackward0>)

In [201]:
noised_samples[1][3][:,2]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         ...,
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 1., 0.]],

        ...,

        [[0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]],

        [[0., 1., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         ...,
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.]]])