In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [2]:
N_samples = 100000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=4)

100%|█████████████████████████████████████████| 100/100 [03:14<00:00,  1.94s/it]


In [3]:
index_spins = ((1+spins)/2).type(torch.int64)

In [4]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [5]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [6]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [7]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [8]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

In [9]:
ts.shape

torch.Size([100, 32])

In [10]:
noised_samples.shape

torch.Size([100, 32, 100, 3])

In [11]:
noised_samples.shape

torch.Size([100, 32, 100, 3])

In [22]:
denoiser = AttentionDenoiser(8,1,128,X_spins.shape[2])

In [23]:
Y_pred = denoiser(noised_samples, ts)

In [24]:
Y_pred.shape

torch.Size([100, 32, 100, 3])

In [25]:
denoiser

AttentionDenoiser(
  (activation): ReLU()
  (forward_time_1): FeedForward(
    (layer): Linear(in_features=1, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_time_2): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_time_3): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_seq_1): FeedForward(
    (layer): Linear(in_features=3, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_seq_2): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_seq_3): FeedForward(
    (layer): Linear(in_features=128, out_features=128, bias=True)
    (activation): ReLU()
  )
  (forward_time_seq): FeedForward(
    (layer): Linear(in_features=256, out_features=128, bias=True)
    (activation): ReLU()
  )
  (mha_1): MultiHeadedAttention(
    (attention): Attention()
    (

In [26]:
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

In [27]:
cat_diff.decode(noised_samples, ts)

In [28]:
cat_diff.calc_forward_conditionals(noised_samples)

In [29]:
cat_diff.y_pred

tensor([[[[0.3666, 0.3440, 0.2894],
          [0.3653, 0.3328, 0.3019],
          [0.3655, 0.3346, 0.2999],
          ...,
          [0.3537, 0.3535, 0.2928],
          [0.3781, 0.3257, 0.2962],
          [0.3929, 0.3203, 0.2868]],

         [[0.4009, 0.3083, 0.2908],
          [0.3802, 0.3233, 0.2965],
          [0.3628, 0.3294, 0.3078],
          ...,
          [0.3805, 0.3129, 0.3066],
          [0.3924, 0.3038, 0.3038],
          [0.3858, 0.3237, 0.2905]],

         [[0.3619, 0.3428, 0.2953],
          [0.4102, 0.3021, 0.2878],
          [0.4272, 0.2864, 0.2864],
          ...,
          [0.3633, 0.3398, 0.2969],
          [0.3867, 0.3083, 0.3050],
          [0.3800, 0.3354, 0.2847]],

         ...,

         [[0.3722, 0.3339, 0.2939],
          [0.3838, 0.3165, 0.2997],
          [0.3825, 0.3208, 0.2967],
          ...,
          [0.4035, 0.2933, 0.3032],
          [0.3527, 0.3408, 0.3065],
          [0.3703, 0.3310, 0.2987]],

         [[0.3749, 0.3123, 0.3128],
          [0.3859

In [30]:
cat_diff.L_T(noised_samples)

tensor(1.7190)

In [31]:
cat_diff.L_t0t1(Y_spins)

tensor(1.0542, grad_fn=<NegBackward0>)

In [32]:
cat_diff.L_tminus1(Y_spins, noised_samples)

tensor(0.8860, grad_fn=<MeanBackward0>)

In [63]:
from tqdm import tqdm
epochs = 100

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in range(epochs):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)
    for batch in tqdm(spins_loader):
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        cat_diff.calc_forward_conditionals(noised_samples)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)
        print(Lt0_t1_loss.item())

        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)
        print(Ltminus1.item())

        loss = Lt0_t1_loss + Ltminus1
        print(loss.item())
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        overall_loss += loss.item()
    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss/X_train.shape[0]))

  0%|                                                  | 0/3125 [00:00<?, ?it/s]

0.16421779990196228
0.8125371932983398
0.9767550230026245


  0%|                                     | 1/3125 [02:28<129:03:56, 148.73s/it]

0.30432596802711487
0.7709723114967346
1.0752983093261719


  0%|                                     | 2/3125 [04:28<114:19:30, 131.79s/it]

0.6773074269294739
0.8158389329910278
1.4931464195251465


  0%|                                     | 3/3125 [06:11<102:55:50, 118.69s/it]

0.11997884511947632
0.7584633231163025
0.8784421682357788


  0%|                                      | 4/3125 [07:54<97:35:22, 112.57s/it]

0.1489763706922531
0.7436121106147766
0.8925884962081909


  0%|                                      | 5/3125 [09:36<94:07:11, 108.60s/it]

0.1511092633008957
0.6768866777420044
0.8279959559440613


  0%|                                      | 6/3125 [11:17<91:54:59, 106.09s/it]

0.1121843159198761
0.6301479339599609
0.7423322200775146


  0%|                                      | 7/3125 [13:08<93:16:04, 107.69s/it]

0.11779102683067322
0.5986711978912354
0.716462254524231


  0%|                                      | 8/3125 [14:50<91:29:10, 105.66s/it]

0.10749813169240952
0.5632871389389038
0.6707852482795715


  0%|                                      | 9/3125 [16:34<91:08:00, 105.29s/it]

0.10868397355079651
0.5163050293922424
0.6249890327453613


  0%|                                     | 10/3125 [18:16<90:10:33, 104.22s/it]

0.08883217722177505
0.48652133345603943
0.5753535032272339


  0%|                                    | 10/3125 [20:05<104:21:02, 120.60s/it]


KeyboardInterrupt: 

In [68]:
cat_diff.y_pred[-20]

tensor([[[0.1477, 0.1477, 0.7045],
         [0.1340, 0.1340, 0.7320],
         [0.1431, 0.1431, 0.7139],
         ...,
         [0.1390, 0.1390, 0.7220],
         [0.1402, 0.1402, 0.7196],
         [0.1352, 0.1352, 0.7295]],

        [[0.1360, 0.1360, 0.7279],
         [0.1444, 0.1444, 0.7113],
         [0.1399, 0.1399, 0.7203],
         ...,
         [0.1459, 0.1459, 0.7083],
         [0.1341, 0.1341, 0.7318],
         [0.1393, 0.1393, 0.7215]],

        [[0.1431, 0.1431, 0.7138],
         [0.1357, 0.1357, 0.7286],
         [0.1397, 0.1397, 0.7206],
         ...,
         [0.1368, 0.1368, 0.7265],
         [0.1412, 0.1412, 0.7176],
         [0.1506, 0.1506, 0.6988]],

        ...,

        [[0.1430, 0.1430, 0.7140],
         [0.1346, 0.1346, 0.7308],
         [0.1407, 0.1407, 0.7186],
         ...,
         [0.1325, 0.1325, 0.7350],
         [0.1360, 0.1360, 0.7279],
         [0.1437, 0.1437, 0.7126]],

        [[0.1378, 0.1378, 0.7244],
         [0.1362, 0.1362, 0.7276],
         [0.

In [69]:
noised_samples[-20]

tensor([[[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         ...,
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         ...,
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         ...,
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.]],

        ...,

        [[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         ...,
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         ...,
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         ...,
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.]]])

In [58]:
noised_samples.shape

torch.Size([100, 32, 100, 3])

In [62]:
cat_diff.y_pred[-80][0]

tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.333

In [34]:
cat_diff.y_pred

tensor([[[[0.1709, 0.7877, 0.0414],
          [0.8030, 0.1634, 0.0336],
          [0.1539, 0.8111, 0.0349],
          ...,
          [0.1511, 0.8132, 0.0357],
          [0.8139, 0.1538, 0.0323],
          [0.1841, 0.7795, 0.0364]],

         [[0.1337, 0.8304, 0.0358],
          [0.8365, 0.1304, 0.0331],
          [0.1624, 0.8012, 0.0364],
          ...,
          [0.1556, 0.8074, 0.0371],
          [0.7740, 0.1878, 0.0382],
          [0.1822, 0.7812, 0.0366]],

         [[0.1527, 0.8109, 0.0364],
          [0.8175, 0.1504, 0.0321],
          [0.1919, 0.7714, 0.0368],
          ...,
          [0.1483, 0.8171, 0.0346],
          [0.7637, 0.1947, 0.0417],
          [0.1757, 0.7848, 0.0394]],

         ...,

         [[0.1291, 0.8332, 0.0377],
          [0.8154, 0.1508, 0.0338],
          [0.8095, 0.1545, 0.0361],
          ...,
          [0.1802, 0.7799, 0.0398],
          [0.1443, 0.8168, 0.0389],
          [0.8178, 0.1449, 0.0372]],

         [[0.8469, 0.1207, 0.0325],
          [0.1528

In [51]:
noised_samples[0][0][4]

tensor([0., 1., 0.])

In [None]:
y_

In [54]:
cat_diff.y_pred[20][0]

tensor([[0.2546, 0.0875, 0.6578],
        [0.2047, 0.1238, 0.6715],
        [0.2216, 0.0783, 0.7002],
        [0.2528, 0.0928, 0.6544],
        [0.2295, 0.0936, 0.6769],
        [0.1915, 0.0878, 0.7207],
        [0.2639, 0.0820, 0.6541],
        [0.1980, 0.0884, 0.7136],
        [0.2126, 0.0955, 0.6918],
        [0.2177, 0.0831, 0.6992],
        [0.2414, 0.0764, 0.6823],
        [0.1962, 0.0821, 0.7217],
        [0.2059, 0.0703, 0.7239],
        [0.2776, 0.0904, 0.6320],
        [0.2183, 0.0807, 0.7010],
        [0.0770, 0.7789, 0.1442],
        [0.1740, 0.0795, 0.7464],
        [0.1991, 0.0740, 0.7268],
        [0.0728, 0.8356, 0.0917],
        [0.2441, 0.1066, 0.6493],
        [0.0836, 0.7678, 0.1487],
        [0.1881, 0.0797, 0.7322],
        [0.1819, 0.0731, 0.7450],
        [0.8415, 0.0379, 0.1206],
        [0.8280, 0.0375, 0.1345],
        [0.2189, 0.0915, 0.6896],
        [0.2233, 0.0904, 0.6864],
        [0.0873, 0.7565, 0.1562],
        [0.2481, 0.0756, 0.6762],
        [0.219

In [32]:
torch.sum(noised_samples[20][0], axis=0)

tensor([ 3.,  3., 94.])

In [36]:
noised_samples[0][0]

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0

In [38]:
noised_samples[20][0]

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0

In [37]:
cat_diff.y_pred[20][0]

tensor([[0.0759, 0.1358, 0.7883],
        [0.0824, 0.1348, 0.7828],
        [0.0750, 0.0975, 0.8275],
        [0.0765, 0.1111, 0.8124],
        [0.0663, 0.1048, 0.8290],
        [0.0616, 0.1049, 0.8335],
        [0.0601, 0.1525, 0.7874],
        [0.0730, 0.0903, 0.8367],
        [0.0723, 0.0850, 0.8427],
        [0.0797, 0.0934, 0.8268],
        [0.0668, 0.1035, 0.8298],
        [0.0774, 0.1301, 0.7925],
        [0.0748, 0.1057, 0.8195],
        [0.0677, 0.1093, 0.8230],
        [0.0919, 0.1118, 0.7963],
        [0.0778, 0.1123, 0.8099],
        [0.0902, 0.1442, 0.7656],
        [0.0881, 0.1147, 0.7972],
        [0.0806, 0.0900, 0.8294],
        [0.0793, 0.0919, 0.8288],
        [0.1042, 0.1194, 0.7764],
        [0.0615, 0.1023, 0.8362],
        [0.0635, 0.0857, 0.8508],
        [0.1065, 0.1346, 0.7589],
        [0.0830, 0.1108, 0.8062],
        [0.0927, 0.1189, 0.7884],
        [0.0675, 0.1347, 0.7979],
        [0.1007, 0.1121, 0.7872],
        [0.0790, 0.1328, 0.7883],
        [0.065