In [1]:
import torch
import numpy as np
import sklearn

from utils import *
from architectures import *
import preprocess

import matplotlib.pyplot as plt

from CategoricalDiffusion import *
from denoiser import *

In [2]:
N_samples = 100000

M = 2
N_dim = 100

bias = torch.zeros((N_dim,1))
patterns = 2*torch.bernoulli(torch.ones((N_dim, M)) * 0.5) - 1

spins = generate_data(patterns, bias, N_samples, beta=4)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:18<00:00,  1.27it/s]


In [3]:
index_spins = ((1+spins)/2).type(torch.int64)

In [4]:
one_hot_spin_vectors = torch.nn.functional.one_hot(index_spins)

In [5]:
class SpinsDataset(torch.utils.data.Dataset):
    """
    takes in one hot encoded spins X with one key - sequence, and Y with one, potentially two, keys - sequence and phenotype
    inputs:
        one_hot_spins: torch.Tensor representing spins (num_samples, dim, num_states)
    
    """
    
    def __init__(self,
                 spins,
                 include_mask = False,
                **kwargs):
        super().__init__(**kwargs)
        self.include_mask = include_mask
        self.spins = spins
        if self.include_mask:

            self.spins = torch.cat((self.spins, torch.zeros((self.spins.shape[0], self.spins.shape[1], 1))), axis=-1)

    def __len__(self):
        return self.spins.shape[0]

    def __getitem__(self, index):
        X = dict()
        Y = dict()

        Y['spins'] = self.spins[index]
        X['spins'] = self.spins[index]
        return X, Y

In [6]:
spins_dataset = SpinsDataset(one_hot_spin_vectors, include_mask=True)
spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)

In [7]:
for batch in spins_loader:
    X, Y = batch
    X_spins = X['spins']
    Y_spins = Y['spins']
    

In [8]:
noise_matrix = Noiser(noiser = 'BERT-LIKE', beta_t = 0.01, k=2).noise_matrix
ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

In [9]:
class CategoricalDiffusion(torch.nn.Module):
    """
    Base class for Categorical diffusion.
    input:
        noise_matrix: tensor, K x K states
        denoiser: torch.nn.Module, outputs denoised samples
    """
    def __init__(
        self,
        denoiser,
        noise_matrix,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.denoiser = denoiser
        self.noise_matrix = noise_matrix
        vals, vecs = torch.linalg.eig(self.noise_matrix.t())
        vals = torch.real(vals)
        vecs = torch.real(vecs)
        self.PxT = vecs[:, torch.argmax(vals)].unsqueeze(0).unsqueeze(1)


    def decode(self, noised_samples, ts):
        """
        Forward pass of denoiser
        input:
            noised_samples: tensor (num_time_steps, batch_size, seq_length, num_classes)
            ts: tensor (num_time_steps, batch_size)
        """
        self.y_pred = self.denoiser(noised_samples, ts)

    def calc_forward_conditionals(self, noised_samples):
        """
        computes forward conditionals of the noised samples
        input:
            noised_samples: tensor (num_time_steps, batch_size, seq_length, num_classes)
        """
        forward_conditionals = torch.zeros(noised_samples.shape)

        for t in range(forward_conditionals.shape[0]):
            forward_conditionals[t] = torch.matmul(noised_samples[0], self.noise_matrix.matrix_power(t))
            
        self.forward_conditionals = forward_conditionals

    def one_step_reverse_conditional(self, real, noised_sample):
        """
        computes the one step reverse conditional
        input:
            real: tensor, (batch_size, seq_length, num_states)
            noised_samples: tensor, (num_time_steps, batch_size, seq_length, num_states)
        output:
            reverse_conditionals: tensor, (num_time_steps-2, batch_size, seq_length, num_states)
        """
        reverse_conditionals = torch.zeros(
            noised_samples.shape[0]-2, noised_samples.shape[1], noised_samples.shape[2], noised_samples.shape[3]
        )
        x0=real
        for t in range(2, noised_samples.shape[0]):
            
            xt = noised_samples[t]
            numer = torch.matmul(xt, noise_matrix.t()) * torch.matmul(x0, self.noise_matrix.matrix_power(t-1))
            denom = torch.matmul(torch.matmul(x0, self.noise_matrix.matrix_power(t)), xt.permute(0,2,1))
            denom = torch.diagonal(denom, dim1=-2, dim2=-1).unsqueeze(-1)
            reverse_conditionals[t-2] = numer/denom
            
        return reverse_conditionals

    def q_xtminus1_xt_giv_x0(self, noised_samples, reverse_conditionals):
        
        q_xtminus1_xt_giv_x0 = torch.zeros(
            (noised_samples.shape[0]-2, noised_samples.shape[1], noised_samples.shape[2], noised_samples.shape[3])
        )

        for t in range(q_xtminus1_xt_giv_x0.shape[0]):
            q_xtminus1_xt_giv_x0[t] = reverse_conditionals[t] * torch.matmul(noised_samples[0], self.noise_matrix.matrix_power(t+2))

        return q_xtminus1_xt_giv_x0

    def L_T(self, noised_sample): 
        """
        computes D_KL of noised samples from the steady state distribution. Not relevant for gradient updates
        input:
            noised_samples: tensor (num_time_steps, batch_size, seq_length, num_classes)
        output:
            dkl_steady_state: tensor, (1, grad=None)
        """
        dkl_steady_state = torch.mean(
            torch.sum(
                (noised_sample+1e-6) * torch.log((noised_sample+1e-6)/(self.PxT+1e-6)),
                axis=-1
            )
        )
    
        return dkl_steady_state        
    
    def L_t0t1(self, real):
        """
        computes cross entropy of the one step decoder. requires self.decode and self.calc_forward_conditionals
        input:
            real: tensor, (batch_size, seq_length, states)
        output:
            one_step_cross_ent_loss: tensor, (1, grad=True)
        """
        y_pred = self.y_pred[1]
        return -torch.mean(torch.sum(self.forward_conditionals[1] * torch.log(y_pred),axis=-1))

        

    def L_tminus1(self, real, noised_sample):
        """
        computes D_KL of each one step backwards decoding
        input:
            real: tensor, (batch_size, seq_length, states)
            noised_sample: tensor, (num_time_steps, batch_size, seq_length, states)
        output:
            d_kl_per_time_step: tensor, (1, grad=True)
        """
        
        reverse_conditionals = self.one_step_reverse_conditional(real, noised_sample)

        
        q_xtminus1_xt_giv_x0 = self.q_xtminus1_xt_giv_x0(noised_sample, reverse_conditionals)

        denoised = self.y_pred[2:]

        px_tminus1_giv_xt = q_xtminus1_xt_giv_x0 * denoised
        
        return torch.mean(
            torch.sum(
                self.forward_conditionals[2:] * (
                    reverse_conditionals * torch.log(
                        (reverse_conditionals + 1e-6)/(px_tminus1_giv_xt+1e-6)
                    )
                ),
                axis=-1
            )
        )

In [10]:
denoiser = Denoiser(8,1,128,X_spins.shape[2])

In [11]:
cat_diff = CategoricalDiffusion(denoiser, noise_matrix)

In [12]:
cat_diff.decode(noised_samples, ts)

In [13]:
cat_diff.calc_forward_conditionals(noised_samples)

In [14]:
cat_diff.L_T(noised_samples)

tensor(1.7133)

In [15]:
cat_diff.L_t0t1(Y_spins)

tensor(1.2170, grad_fn=<NegBackward0>)

In [16]:
cat_diff.L_tminus1(Y_spins, noised_samples)

tensor(0.8247, grad_fn=<MeanBackward0>)

In [None]:
from tqdm import tqdm
epochs = 100

cat_diff.train()
optim = torch.optim.Adam(cat_diff.parameters(), lr = 1e-3)

for epoch in range(epochs):
    overall_loss = 0
    spins_loader = torch.utils.data.DataLoader(spins_dataset, batch_size=32, shuffle=True)
    for batch in tqdm(spins_loader):
        X, Y = batch
        
        X_spins = X['spins']
        Y_spins = Y['spins']
        
        ts, noised_samples = noiser(X_spins, noise_matrix, 100, X_spins.shape[-1])

        cat_diff.decode(noised_samples, ts)
        
        LT_loss     = cat_diff.L_T(noised_samples)
        
        Lt0_t1_loss = cat_diff.L_t0t1(Y_spins)
        print(Lt0_t1_loss.item())

        Ltminus1    = cat_diff.L_tminus1(Y_spins, noised_samples)
        print(Ltminus1.item())

        loss = Lt0_t1_loss + Ltminus1
        print(loss.item())
        
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        overall_loss += loss.item()
    print('overall loss at epoch {} is '.format(epoch) + str(overall_loss/X_train.shape[0]))

  0%|                                                                                                                                                                                                                                                  | 0/3125 [00:00<?, ?it/s]

1.21830415725708
0.7575411796569824
1.9758453369140625


  0%|                                                                                                                                                                                                                                       | 1/3125 [00:43<37:25:04, 43.12s/it]

1.0813572406768799
0.7494692802429199
1.8308265209197998


  0%|▏                                                                                                                                                                                                                                      | 2/3125 [01:19<34:09:35, 39.38s/it]

0.9573370218276978
0.7444831132888794
1.7018201351165771


  0%|▏                                                                                                                                                                                                                                      | 3/3125 [01:55<32:32:48, 37.53s/it]

0.8924304246902466
0.742101788520813
1.6345322132110596


  0%|▎                                                                                                                                                                                                                                      | 4/3125 [02:29<31:22:26, 36.19s/it]

0.8543373346328735
0.739834189414978
1.5941715240478516


  0%|▎                                                                                                                                                                                                                                      | 5/3125 [03:04<31:00:06, 35.77s/it]

0.8293694853782654
0.7438334226608276
1.5732028484344482


  0%|▍                                                                                                                                                                                                                                      | 6/3125 [03:39<30:51:31, 35.62s/it]

0.8111544251441956
0.7444204688072205
1.555574893951416


  0%|▌                                                                                                                                                                                                                                      | 7/3125 [04:15<30:57:04, 35.74s/it]

0.7992609143257141
0.7437841296195984
1.5430450439453125


  0%|▌                                                                                                                                                                                                                                      | 8/3125 [04:50<30:34:29, 35.31s/it]

0.7883644104003906
0.7389649152755737
1.5273293256759644


  0%|▋                                                                                                                                                                                                                                      | 9/3125 [05:27<31:01:16, 35.84s/it]

0.7792285084724426
0.7418856620788574
1.5211141109466553


  0%|▋                                                                                                                                                                                                                                     | 10/3125 [06:01<30:42:54, 35.50s/it]

0.7719051241874695
0.7500696778297424
1.521974802017212


  0%|▊                                                                                                                                                                                                                                     | 11/3125 [06:35<30:14:41, 34.97s/it]

0.7654650211334229
0.7501637935638428
1.5156288146972656


  0%|▉                                                                                                                                                                                                                                     | 12/3125 [07:09<30:00:43, 34.71s/it]

0.7604207396507263
0.7371329069137573
1.4975535869598389


  0%|▉                                                                                                                                                                                                                                     | 13/3125 [07:45<30:21:17, 35.12s/it]

0.7549020648002625
0.7424541711807251
1.4973561763763428


  0%|█                                                                                                                                                                                                                                     | 14/3125 [08:24<31:16:39, 36.19s/it]

0.7514030933380127
0.7450118660926819
1.4964148998260498


  0%|█                                                                                                                                                                                                                                     | 15/3125 [09:00<31:15:33, 36.18s/it]

0.7493706345558167
0.7482198476791382
1.4975905418395996


  1%|█▏                                                                                                                                                                                                                                    | 16/3125 [09:35<31:01:04, 35.92s/it]

0.7472766041755676
0.7448191046714783
1.492095708847046


  1%|█▎                                                                                                                                                                                                                                    | 17/3125 [10:11<30:58:26, 35.88s/it]

0.746214747428894
0.7539440393447876
1.5001587867736816


  1%|█▎                                                                                                                                                                                                                                    | 18/3125 [10:48<31:08:30, 36.08s/it]

0.7434281706809998
0.7501891255378723
1.493617296218872


  1%|█▍                                                                                                                                                                                                                                    | 19/3125 [11:23<30:52:52, 35.79s/it]

0.7432730197906494
0.7322185039520264
1.4754915237426758


  1%|█▍                                                                                                                                                                                                                                    | 20/3125 [11:58<30:42:51, 35.61s/it]

0.742911159992218
0.7474761009216309
1.490387201309204


  1%|█▌                                                                                                                                                                                                                                    | 21/3125 [12:32<30:14:18, 35.07s/it]

0.7425565123558044
0.7397345900535583
1.4822911024093628


  1%|█▌                                                                                                                                                                                                                                    | 22/3125 [13:07<30:10:04, 35.00s/it]