In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [2]:
N_samples = 100000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=4)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:18<00:00,  1.28it/s]


In [3]:
index_spins = ((1+spins)/2).type(torch.int64)

In [4]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [5]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [6]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [7]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [8]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

In [9]:
denoiser = Denoiser(8,1,128,X_spins.shape[2])

In [10]:
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

In [11]:
cat_diff.decode(noised_samples, ts)

In [12]:
cat_diff.calc_forward_conditionals(noised_samples)

In [13]:
cat_diff.y_pred

tensor([[[[0.3041, 0.3294, 0.3664],
          [0.3103, 0.3205, 0.3692],
          [0.2981, 0.3307, 0.3712],
          ...,
          [0.3137, 0.3137, 0.3727],
          [0.3179, 0.3179, 0.3641],
          [0.3186, 0.3186, 0.3628]],

         [[0.3010, 0.3186, 0.3804],
          [0.3062, 0.3062, 0.3877],
          [0.3171, 0.3171, 0.3659],
          ...,
          [0.2894, 0.3355, 0.3751],
          [0.3027, 0.3165, 0.3808],
          [0.3069, 0.3335, 0.3596]],

         [[0.3033, 0.3033, 0.3933],
          [0.3234, 0.3234, 0.3531],
          [0.3128, 0.3129, 0.3743],
          ...,
          [0.3272, 0.3272, 0.3456],
          [0.3179, 0.3209, 0.3612],
          [0.3172, 0.3172, 0.3656]],

         ...,

         [[0.3198, 0.3399, 0.3402],
          [0.3090, 0.3090, 0.3821],
          [0.3282, 0.3249, 0.3469],
          ...,
          [0.3007, 0.3140, 0.3853],
          [0.3058, 0.3103, 0.3838],
          [0.2952, 0.2971, 0.4077]],

         [[0.3068, 0.3068, 0.3865],
          [0.3109

In [14]:
cat_diff.L_T(noised_samples)

tensor(1.6887)

In [15]:
cat_diff.L_t0t1(Y_spins)

tensor(1.1711, grad_fn=<NegBackward0>)

In [16]:
cat_diff.L_tminus1(Y_spins, noised_samples)

tensor(0.7824, grad_fn=<MeanBackward0>)

In [None]:
from tqdm import tqdm
epochs = 100

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in range(epochs):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)
    for batch in tqdm(spins_loader):
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        cat_diff.calc_forward_conditionals(noised_samples)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)
        print(Lt0_t1_loss.item())

        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)
        print(Ltminus1.item())

        loss = Lt0_t1_loss + Ltminus1
        print(loss.item())
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        overall_loss += loss.item()
    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss/X_train.shape[0]))

  0%|                                                                                                                                                                                                                                                  | 0/3125 [00:00<?, ?it/s]

1.1699272394180298
0.784461498260498
1.9543887376785278


  0%|                                                                                                                                                                                                                                       | 1/3125 [00:43<37:19:23, 43.01s/it]

1.064200758934021
0.725055456161499
1.78925621509552


  0%|▏                                                                                                                                                                                                                                      | 2/3125 [01:20<34:29:34, 39.76s/it]

1.0468299388885498
0.6229413151741028
1.6697711944580078


In [None]:
noised_samples

In [19]:
cat_diff.y_pred

tensor([[[[0.4912, 0.5029, 0.0060],
          [0.4970, 0.4975, 0.0055],
          [0.4976, 0.4967, 0.0057],
          ...,
          [0.4794, 0.5148, 0.0057],
          [0.4925, 0.5019, 0.0056],
          [0.4943, 0.4988, 0.0069]],

         [[0.5143, 0.4802, 0.0055],
          [0.5202, 0.4739, 0.0060],
          [0.5060, 0.4885, 0.0055],
          ...,
          [0.5052, 0.4891, 0.0056],
          [0.4812, 0.5138, 0.0051],
          [0.5073, 0.4862, 0.0065]],

         [[0.4856, 0.5086, 0.0059],
          [0.4917, 0.5032, 0.0051],
          [0.5107, 0.4837, 0.0056],
          ...,
          [0.5127, 0.4810, 0.0062],
          [0.4988, 0.4958, 0.0054],
          [0.4810, 0.5132, 0.0058]],

         ...,

         [[0.4987, 0.4961, 0.0053],
          [0.5291, 0.4650, 0.0059],
          [0.4945, 0.4991, 0.0064],
          ...,
          [0.4904, 0.5043, 0.0053],
          [0.5148, 0.4792, 0.0060],
          [0.4913, 0.5025, 0.0062]],

         [[0.5160, 0.4786, 0.0054],
          [0.5015