In [1]:
import graph_tool as gt
import torch
import pickle
from src.diffusion_model_discrete import DiscreteDenoisingDiffusion
from sample import get_model_sbm
from src.diffusion import diffusion_utils
import networkx as nx
import matplotlib.pyplot as plt
from src.analysis.visualization import NonMolecularVisualization
import torch.nn.functional as F



In [2]:
# Build the diffusion model via `sample.get_model_sbm` (presumably restores an
# SBM-trained checkpoint — confirm). Later cells use its node_dist, limit_dist,
# T, noise_schedule, transition_model and device attributes.
model = get_model_sbm()


	Using the fallback 'C' locale.


Marginal distribution of the classes: tensor([1.]) for nodes, tensor([0.9156, 0.0844]) for edges


In [3]:
# Number of graphs sampled together.
batch_size = 1
# Draw one node count per graph from the model's node-count distribution.
n_nodes = model.node_dist.sample_n(batch_size, model.device)
print(n_nodes)

tensor([104], device='cuda:0')


In [4]:
# Boolean mask of real (non-padding) nodes, shape (batch_size, n_max).
# Built per graph from `n_nodes` (one count per graph), so it stays correct
# when batch_size > 1; `n_nodes.item()` only works for a single graph.
n_max = int(n_nodes.max().item())
node_mask = torch.arange(n_max, device=model.device).unsqueeze(0) < n_nodes.unsqueeze(-1)
print(node_mask.shape)


torch.Size([1, 104])


In [5]:
# Sample the fully-noised starting graph z_T using the model's limit distribution.
z_T = diffusion_utils.sample_discrete_feature_noise_with_message(
    limit_dist=model.limit_dist,
    node_mask=node_mask,
)
X = z_T.X
E = z_T.E
y = z_T.y
print(E.shape)

torch.Size([1, 104, 104, 2])


In [6]:
def to_networknx(E):
    """Convert an edge-class tensor into a NetworkX graph.

    The edge class of every node pair is the argmax over the last axis; the
    resulting integer matrix is used as the adjacency matrix (class 0 -> no
    edge, class k > 0 -> edge with ``weight`` k, per ``from_numpy_array``).

    Args:
        E: edge tensor with class scores / one-hot codes on the last axis and
           a leading batch axis of size 1 that is squeezed away,
           e.g. shape (1, n, n, d_e).

    Returns:
        networkx.Graph built from the adjacency matrix.
    """
    adj_matrix = torch.argmax(E, dim=-1).squeeze(0).cpu().numpy()
    # `from_numpy_matrix` was removed in NetworkX 3.0; `from_numpy_array` is
    # the drop-in replacement for ndarray input.
    return nx.from_numpy_array(adj_matrix)
    

In [7]:
# Visualize the initial noise graph z_T; the third argument is presumably the
# output image path (see NonMolecularVisualization.visualize_non_molecule).
visualizer = NonMolecularVisualization()
visualizer.visualize_non_molecule(to_networknx(E), None, "Z_T.png")

  plt.tight_layout()


In [8]:
def visualize_graph_batch(l):
    """Plot graph snapshots in a grid with 5 panels per row.

    Args:
        l: list of ``(step, E)`` pairs, where ``E`` is an edge tensor with a
           leading batch axis of size 1 and edge classes on the last axis
           (argmax-ed into an adjacency matrix). Each panel is titled with
           its ``step``.
    """
    n_cols = 5
    # Ceil-divide so a list whose length is a multiple of 5 does not get an
    # extra empty row; keep at least one row so an empty list still works.
    n_rows = max(1, (len(l) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(30, 6 * n_rows))
    for i, (step, E) in enumerate(l):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        adj_matrix = torch.argmax(E, dim=-1).squeeze(0).cpu().numpy()
        # `from_numpy_matrix` was removed in NetworkX 3.0.
        G = nx.from_numpy_array(adj_matrix)
        nx.draw(G, nx.circular_layout(G), ax=ax, with_labels=False)
        ax.set_title(f'Step {step}')
    fig.tight_layout()
    plt.show()

In [9]:
def get_partial_result(s_ideal, log_every=50):
    """Run the reverse chain from fresh noise and stop at normalized step ``s_ideal``.

    Starting from z_T (sampled with ``model.limit_dist``), the chain is stepped
    with ``model.sample_p_zs_given_zt`` for s = T-1, T-2, ...  When ``s / T``
    matches ``s_ideal`` the current state is returned *before* that step is
    applied, i.e. the result is z_t with ``t = s_ideal + 1/T``.

    Uses the notebook globals ``model``, ``node_mask``, ``batch_size`` and
    ``visualizer``.

    Args:
        s_ideal: target normalized step; a tensor broadcastable to
            (batch_size, 1) or a plain float.
        log_every: every ``log_every`` integer steps the current graph is passed
            to ``visualizer.visualize_non_molecule`` as ``chains/Z_{s}.png``.

    Returns:
        Tuple ``(X, E, y)`` of the noisy graph at ``t = s_ideal + 1/T``.

    Raises:
        ValueError: if ``s_ideal`` is not equal (``torch.allclose``) to any
            ``s / T`` for ``s`` in ``range(T)``, instead of returning ``None``.
    """
    z_T = diffusion_utils.sample_discrete_feature_noise_with_message(limit_dist=model.limit_dist, node_mask=node_mask)
    X, E, y = z_T.X, z_T.E, z_T.y
    # Pure inference: don't build autograd graphs through the denoiser.
    with torch.no_grad():
        for s_int in reversed(range(0, model.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / model.T
            t_norm = t_array / model.T
            # Match dtype/device so a float or CPU-tensor target also compares.
            target = torch.as_tensor(s_ideal).to(s_norm)
            if torch.allclose(s_norm, target):
                return X, E, y
            sampled_s, _ = model.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
            if s_int % log_every == 0:
                visualizer.visualize_non_molecule(to_networknx(E), None, f"chains/Z_{s_int}.png")
    raise ValueError(f"s_ideal={s_ideal} does not match any step s/T for T={model.T}")

In [10]:
# Reverse-diffusion step pair to inspect. `get_partial_result(s)` returns the
# chain state *before* denoising step s, i.e. z_t with t = s + 1/T, so check
# that the hard-coded t agrees with the model's number of steps T.
# Tensors live on the model's device instead of a hard-coded 'cuda'.
s = torch.tensor([[0.9540]], device=model.device)
t = torch.tensor([[0.9560]], device=model.device)
assert torch.allclose(t, s + 1 / model.T), "t must equal s + 1/T for the state returned by get_partial_result"
X_t, E_t, y_t = get_partial_result(s)

In [11]:
# Batch size, number of nodes and node-feature width of the noisy graph z_t.
bs, n, dxs = X_t.shape
# Noise level for the step t -> s from the model's noise schedule.
beta_t = model.noise_schedule(t_normalized=t)
print(beta_t)

tensor([[0.0849]], device='cuda:0')


In [12]:
# Cumulative alpha-bar at s and t, then the transition matrices:
# cumulative Q-bar at t and at s, plus the one-step Q_t built from beta_t.
alpha_s_bar, alpha_t_bar = (
    model.noise_schedule.get_alpha_bar(t_normalized=step) for step in (s, t)
)
Qtb, Qsb = (
    model.transition_model.get_Qt_bar(alpha_bar, model.device)
    for alpha_bar in (alpha_t_bar, alpha_s_bar)
)
Qt = model.transition_model.get_Qt(beta_t, model.device)

In [13]:
# Package the noisy graph the way the model expects, then compute the extra
# features it feeds to the denoiser alongside z_t.
noisy_data = dict(X_t=X_t, E_t=E_t, y_t=y_t, t=t, node_mask=node_mask)
extra_data = model.compute_extra_data(noisy_data)
print(extra_data.X.shape, extra_data.E.shape, sep="\n")

torch.Size([1, 104, 6])
torch.Size([1, 104, 104, 0])


In [14]:
# Denoiser prediction of the clean graph's class distributions (d0 axis) given z_t.
# Pure inference, so skip autograd bookkeeping; otherwise `pred` keeps the whole
# forward graph alive for as long as it is referenced in the notebook.
with torch.no_grad():
    pred = model.forward(noisy_data, extra_data, node_mask)
pred_X = F.softmax(pred.X, dim=-1)
pred_E = F.softmax(pred.E, dim=-1)
print(pred_E)

tensor([[[[0.5000, 0.5000],
          [0.9095, 0.0905],
          [0.9093, 0.0907],
          ...,
          [0.9094, 0.0906],
          [0.9094, 0.0906],
          [0.9094, 0.0906]],

         [[0.9095, 0.0905],
          [0.5000, 0.5000],
          [0.9088, 0.0912],
          ...,
          [0.9089, 0.0911],
          [0.9089, 0.0911],
          [0.9089, 0.0911]],

         [[0.9093, 0.0907],
          [0.9088, 0.0912],
          [0.5000, 0.5000],
          ...,
          [0.9086, 0.0914],
          [0.9086, 0.0914],
          [0.9068, 0.0932]],

         ...,

         [[0.9094, 0.0906],
          [0.9089, 0.0911],
          [0.9086, 0.0914],
          ...,
          [0.5000, 0.5000],
          [0.9087, 0.0913],
          [0.9087, 0.0913]],

         [[0.9094, 0.0906],
          [0.9089, 0.0911],
          [0.9086, 0.0914],
          ...,
          [0.9087, 0.0913],
          [0.5000, 0.5000],
          [0.9087, 0.0913]],

         [[0.9094, 0.0906],
          [0.9089, 0.0911],
    

In [15]:
# Per-x_0 posterior terms from `compute_batched_over0_posterior_distribution`,
# computed for node features first, then edge features.
p_s_and_t_given_0_X, p_s_and_t_given_0_E = (
    diffusion_utils.compute_batched_over0_posterior_distribution(X_t=feats_t, Qt=q_t, Qsb=q_sb, Qtb=q_tb)
    for feats_t, q_t, q_sb, q_tb in (
        (X_t, Qt.X, Qsb.X, Qtb.X),
        (E_t, Qt.E, Qsb.E, Qtb.E),
    )
)

In [16]:
# Dim of these two tensors: bs, N, d0, d_t-1
weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X         # bs, n, d0, d_t-1
unnormalized_prob_X = weighted_X.sum(dim=2)                     # bs, n, d_t-1
unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E        # bs, N, d0, d_t-1
unnormalized_prob_E = weighted_E.sum(dim=-2)
unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

In [17]:

# Sanity check: two draws with the same seed must yield identical edge samples.
E1, E2 = (
    diffusion_utils.sample_discrete_features(prob_X.clone(), prob_E.clone(), node_mask=node_mask, seed=42).E
    for _ in range(2)
)
print(torch.equal(E1, E2))

True


In [18]:
@torch.no_grad()  # sampling only: no autograd graph is needed
def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask):
    """Sample z_s ~ p(z_s | z_t) for one reverse-diffusion step. Only used during sampling.

    Args:
        self: the diffusion model (provides noise_schedule, transition_model,
            compute_extra_data, forward, device, Xdim_output, Edim_output).
        s, t: normalized time steps; shape (bs, 1) in this notebook's callers.
        X_t: one-hot node features of z_t, shape (bs, n, dx).
        E_t: one-hot edge features of z_t, shape (bs, n, n, de).
        y_t: global features of z_t; outputs are cast to its type.
        node_mask: mask of real nodes, shape (bs, n).

    Returns:
        ``(out_one_hot, out_discrete)``: two ``utils.PlaceHolder`` objects for
        the sampled z_s — one masked via ``mask(node_mask)`` and one via
        ``mask(node_mask, collapse=True)`` — each cast with ``type_as(y_t)``.
    """
    # `utils` is not imported anywhere in this notebook; without this import the
    # function raised NameError when building its outputs.
    from src import utils  # NOTE(review): assumes PlaceHolder lives in src/utils.py — confirm

    bs, n, _ = X_t.shape
    beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
    alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
    alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

    # Transition matrices: cumulative Q-bar at t and s, and the one-step Q_t.
    Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
    Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
    Qt = self.transition_model.get_Qt(beta_t, self.device)

    # Neural net predictions of the clean graph.
    noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask}
    extra_data = self.compute_extra_data(noisy_data)
    pred = self.forward(noisy_data, extra_data, node_mask)

    # Normalize predictions
    pred_X = F.softmax(pred.X, dim=-1)               # bs, n, d0
    pred_E = F.softmax(pred.E, dim=-1)               # bs, n, n, d0

    p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=X_t,
                                                                                       Qt=Qt.X,
                                                                                       Qsb=Qsb.X,
                                                                                       Qtb=Qtb.X)

    p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=E_t,
                                                                                       Qt=Qt.E,
                                                                                       Qsb=Qsb.E,
                                                                                       Qtb=Qtb.E)
    # Dim of these two tensors: bs, N, d0, d_t-1
    weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X         # bs, n, d0, d_t-1
    unnormalized_prob_X = weighted_X.sum(dim=2)                     # bs, n, d_t-1
    # Rows summing to zero would divide by zero below.
    unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
    prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

    pred_E_flat = pred_E.reshape((bs, -1, pred_E.shape[-1]))       # bs, n*n, d0
    weighted_E = pred_E_flat.unsqueeze(-1) * p_s_and_t_given_0_E   # bs, n*n, d0, d_t-1
    unnormalized_prob_E = weighted_E.sum(dim=-2)                    # bs, n*n, d_t-1
    unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
    prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
    # The last axis is d_t-1 (not d0), so let reshape infer it.
    prob_E = prob_E.reshape(bs, n, n, -1)

    assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
    assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

    sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

    X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
    E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
    assert (E_s == torch.transpose(E_s, 1, 2)).all()
    assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

    # Two separate PlaceHolders so the collapse=True masking of one cannot affect
    # the other (in case mask() modifies the object in place — confirm in utils).
    out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))
    out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))

    return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t)