In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [2]:
N_samples = 100000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=4)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:18<00:00,  1.28it/s]


In [3]:
index_spins = ((1+spins)/2).type(torch.int64)

In [4]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [5]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [6]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [7]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [8]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

In [9]:
denoiser = Denoiser(8,1,128,X_spins.shape[2])

In [10]:
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

In [11]:
cat_diff.decode(noised_samples, ts)

In [12]:
cat_diff.calc_forward_conditionals(noised_samples)

In [13]:
cat_diff.y_pred

tensor([[[[0.3041, 0.3294, 0.3664],
          [0.3103, 0.3205, 0.3692],
          [0.2981, 0.3307, 0.3712],
          ...,
          [0.3137, 0.3137, 0.3727],
          [0.3179, 0.3179, 0.3641],
          [0.3186, 0.3186, 0.3628]],

         [[0.3010, 0.3186, 0.3804],
          [0.3062, 0.3062, 0.3877],
          [0.3171, 0.3171, 0.3659],
          ...,
          [0.2894, 0.3355, 0.3751],
          [0.3027, 0.3165, 0.3808],
          [0.3069, 0.3335, 0.3596]],

         [[0.3033, 0.3033, 0.3933],
          [0.3234, 0.3234, 0.3531],
          [0.3128, 0.3129, 0.3743],
          ...,
          [0.3272, 0.3272, 0.3456],
          [0.3179, 0.3209, 0.3612],
          [0.3172, 0.3172, 0.3656]],

         ...,

         [[0.3198, 0.3399, 0.3402],
          [0.3090, 0.3090, 0.3821],
          [0.3282, 0.3249, 0.3469],
          ...,
          [0.3007, 0.3140, 0.3853],
          [0.3058, 0.3103, 0.3838],
          [0.2952, 0.2971, 0.4077]],

         [[0.3068, 0.3068, 0.3865],
          [0.3109

In [14]:
cat_diff.L_T(noised_samples)

tensor(1.6887)

In [15]:
cat_diff.L_t0t1(Y_spins)

tensor(1.1711, grad_fn=<NegBackward0>)

In [16]:
cat_diff.L_tminus1(Y_spins, noised_samples)

tensor(0.7824, grad_fn=<MeanBackward0>)

In [39]:
from tqdm import tqdm
epochs = 100

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in range(epochs):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)
    for batch in tqdm(spins_loader):
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        cat_diff.calc_forward_conditionals(noised_samples)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)
        print(Lt0_t1_loss.item())

        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)
        print(Ltminus1.item())

        loss = Lt0_t1_loss + Ltminus1
        print(loss.item())
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        overall_loss += loss.item()
    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss/X_train.shape[0]))

  0%|                                                                                                                                                                                                                                                  | 0/3125 [00:00<?, ?it/s]

0.07430996000766754
0.37103477120399475
0.4453447461128235


  0%|                                                                                                                                                                                                                                       | 1/3125 [00:36<31:29:28, 36.29s/it]

0.40446680784225464
1.9333828687667847
2.3378496170043945


  0%|▏                                                                                                                                                                                                                                      | 2/3125 [01:11<31:03:29, 35.80s/it]

0.07624992728233337
1.2241350412368774
1.3003849983215332


  0%|▏                                                                                                                                                                                                                                      | 3/3125 [01:45<30:11:39, 34.82s/it]

0.09118326008319855
0.8738067746162415
0.9649900197982788


  0%|▎                                                                                                                                                                                                                                      | 4/3125 [02:18<29:45:10, 34.32s/it]

0.10097696632146835
0.7012823820114136
0.8022593259811401


  0%|▎                                                                                                                                                                                                                                      | 5/3125 [02:49<28:42:55, 33.13s/it]

0.1322679966688156
0.4703066647052765
0.6025746464729309


  0%|▍                                                                                                                                                                                                                                      | 6/3125 [03:26<29:36:48, 34.18s/it]

0.12514491379261017
0.4814998209476471
0.6066447496414185


  0%|▌                                                                                                                                                                                                                                      | 7/3125 [04:00<29:45:40, 34.36s/it]

0.10772933065891266
0.5380478501319885
0.64577716588974


  0%|▌                                                                                                                                                                                                                                      | 8/3125 [04:35<29:50:06, 34.46s/it]

0.09788627922534943
0.5315903425216675
0.6294766068458557


  0%|▋                                                                                                                                                                                                                                      | 9/3125 [05:11<30:14:04, 34.93s/it]

0.08516058325767517
0.49674561619758606
0.5819061994552612


  0%|▋                                                                                                                                                                                                                                     | 10/3125 [05:46<30:11:45, 34.90s/it]

0.09699426591396332
0.4506804943084717
0.5476747751235962


  0%|▊                                                                                                                                                                                                                                     | 11/3125 [06:21<30:09:54, 34.87s/it]

0.09590338915586472
0.41700008511543274
0.5129034519195557


  0%|▉                                                                                                                                                                                                                                     | 12/3125 [06:56<30:14:02, 34.96s/it]

0.08559533953666687
0.4061691462993622
0.49176448583602905


  0%|▉                                                                                                                                                                                                                                     | 13/3125 [07:30<30:07:19, 34.85s/it]

0.09758633375167847
0.4148550033569336
0.5124413371086121


  0%|█                                                                                                                                                                                                                                     | 14/3125 [08:05<30:05:06, 34.81s/it]

0.09096696972846985
0.4159482419490814
0.5069152116775513


  0%|█                                                                                                                                                                                                                                     | 15/3125 [08:38<29:39:29, 34.33s/it]

0.09014927595853806
0.40489521622657776
0.4950444996356964


  1%|█▏                                                                                                                                                                                                                                    | 16/3125 [09:13<29:42:01, 34.39s/it]

0.08645707368850708
0.3999840021133423
0.48644107580184937


  1%|█▎                                                                                                                                                                                                                                    | 17/3125 [09:45<29:09:25, 33.77s/it]

0.09553436189889908
0.39514976739883423
0.4906841218471527


  1%|█▎                                                                                                                                                                                                                                    | 18/3125 [10:19<29:00:32, 33.61s/it]

0.08652438223361969
0.3944658637046814
0.4809902310371399


  1%|█▍                                                                                                                                                                                                                                    | 19/3125 [10:51<28:39:48, 33.22s/it]

0.0887974202632904
0.392701119184494
0.4814985394477844


  1%|█▍                                                                                                                                                                                                                                    | 20/3125 [11:23<28:24:18, 32.93s/it]

0.08361733704805374
0.39234301447868347
0.4759603440761566


  1%|█▌                                                                                                                                                                                                                                    | 21/3125 [11:54<27:57:24, 32.42s/it]

0.0913640707731247
0.3919111490249634
0.48327523469924927


  1%|█▌                                                                                                                                                                                                                                    | 22/3125 [12:26<27:49:41, 32.29s/it]

0.09220419079065323
0.39292898774147034
0.48513317108154297


  1%|█▋                                                                                                                                                                                                                                    | 23/3125 [13:00<28:05:25, 32.60s/it]

0.0884355902671814
0.3911934792995453
0.4796290695667267


  1%|█▊                                                                                                                                                                                                                                    | 24/3125 [13:30<27:27:20, 31.87s/it]

0.08324029296636581
0.3863489329814911
0.4695892333984375


  1%|█▊                                                                                                                                                                                                                                    | 25/3125 [14:04<27:57:26, 32.47s/it]

0.09620465338230133
0.38294392824172974
0.4791485667228699


  1%|█▉                                                                                                                                                                                                                                    | 26/3125 [14:35<27:44:58, 32.24s/it]

0.08487823605537415
0.3784160614013672
0.46329429745674133


  1%|█▉                                                                                                                                                                                                                                    | 27/3125 [15:09<28:07:03, 32.67s/it]

0.08192338794469833
0.37581926584243774
0.4577426612377167


  1%|██                                                                                                                                                                                                                                    | 28/3125 [15:43<28:22:15, 32.98s/it]

0.08595511317253113
0.3744770288467407
0.46043214201927185


  1%|██▏                                                                                                                                                                                                                                   | 29/3125 [16:17<28:44:12, 33.41s/it]

0.0805070623755455
0.3729054033756256
0.4534124732017517


  1%|██▏                                                                                                                                                                                                                                   | 29/3125 [16:52<30:00:53, 34.90s/it]


KeyboardInterrupt: 

In [49]:
noised_samples[5][0][4]

tensor([0., 0., 1.])

In [51]:
noised_samples[0][0][4]

tensor([0., 1., 0.])

In [None]:
y_

In [54]:
cat_diff.y_pred[20][0]

tensor([[0.2546, 0.0875, 0.6578],
        [0.2047, 0.1238, 0.6715],
        [0.2216, 0.0783, 0.7002],
        [0.2528, 0.0928, 0.6544],
        [0.2295, 0.0936, 0.6769],
        [0.1915, 0.0878, 0.7207],
        [0.2639, 0.0820, 0.6541],
        [0.1980, 0.0884, 0.7136],
        [0.2126, 0.0955, 0.6918],
        [0.2177, 0.0831, 0.6992],
        [0.2414, 0.0764, 0.6823],
        [0.1962, 0.0821, 0.7217],
        [0.2059, 0.0703, 0.7239],
        [0.2776, 0.0904, 0.6320],
        [0.2183, 0.0807, 0.7010],
        [0.0770, 0.7789, 0.1442],
        [0.1740, 0.0795, 0.7464],
        [0.1991, 0.0740, 0.7268],
        [0.0728, 0.8356, 0.0917],
        [0.2441, 0.1066, 0.6493],
        [0.0836, 0.7678, 0.1487],
        [0.1881, 0.0797, 0.7322],
        [0.1819, 0.0731, 0.7450],
        [0.8415, 0.0379, 0.1206],
        [0.8280, 0.0375, 0.1345],
        [0.2189, 0.0915, 0.6896],
        [0.2233, 0.0904, 0.6864],
        [0.0873, 0.7565, 0.1562],
        [0.2481, 0.0756, 0.6762],
        [0.219

In [32]:
torch.sum(noised_samples[20][0], axis=0)

tensor([ 3.,  3., 94.])

In [36]:
noised_samples[0][0]

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0

In [38]:
noised_samples[20][0]

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0

In [37]:
cat_diff.y_pred[20][0]

tensor([[0.0759, 0.1358, 0.7883],
        [0.0824, 0.1348, 0.7828],
        [0.0750, 0.0975, 0.8275],
        [0.0765, 0.1111, 0.8124],
        [0.0663, 0.1048, 0.8290],
        [0.0616, 0.1049, 0.8335],
        [0.0601, 0.1525, 0.7874],
        [0.0730, 0.0903, 0.8367],
        [0.0723, 0.0850, 0.8427],
        [0.0797, 0.0934, 0.8268],
        [0.0668, 0.1035, 0.8298],
        [0.0774, 0.1301, 0.7925],
        [0.0748, 0.1057, 0.8195],
        [0.0677, 0.1093, 0.8230],
        [0.0919, 0.1118, 0.7963],
        [0.0778, 0.1123, 0.8099],
        [0.0902, 0.1442, 0.7656],
        [0.0881, 0.1147, 0.7972],
        [0.0806, 0.0900, 0.8294],
        [0.0793, 0.0919, 0.8288],
        [0.1042, 0.1194, 0.7764],
        [0.0615, 0.1023, 0.8362],
        [0.0635, 0.0857, 0.8508],
        [0.1065, 0.1346, 0.7589],
        [0.0830, 0.1108, 0.8062],
        [0.0927, 0.1189, 0.7884],
        [0.0675, 0.1347, 0.7979],
        [0.1007, 0.1121, 0.7872],
        [0.0790, 0.1328, 0.7883],
        [0.065