In [1]:
from dataclasses import dataclass, field
from typing import Optional

from qm9 import dataset

In [2]:
from qm9.models import get_model

In [3]:
from configs.datasets_config import get_dataset_info

In [31]:
import torch

In [33]:
from os.path import join

In [49]:
from equivariant_diffusion import utils as diffusion_utils

In [50]:
from qm9.analyze import check_stability

In [62]:
import qm9.visualizer as vis

## Load dataset and dataloader

---

Load the QM9 dataset and its dataloaders, then inspect one batch to see how the model's inputs are represented.

In [34]:
@dataclass
class DatasetConfigs:
    """Settings object passed to ``dataset.retrieve_dataloaders``.

    Field order is part of the interface (positional construction), so new
    fields should only be appended.
    """

    datadir: str = "qm9/temp"
    dataset: str = "qm9"
    num_workers: int = 1
    remove_h: bool = False
    force_download: bool = False
    subtract_thermo: bool = True
    batch_size: int = 16
    # The default is None, so the annotation has to admit it (was a bare `int`).
    filter_n_atoms: Optional[int] = None
    include_charges: bool = True
    shuffle: bool = True

In [5]:
# Use the default dataset settings defined in DatasetConfigs.
cfg = DatasetConfigs()

In [7]:
# Build the data loaders from the config. The result is indexed by split name
# below ('train'); the second return value is bound to `charge_scale`.
dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg)

In [8]:
# Shorthand for the training split's loader.
loader = dataloaders['train']

In [9]:
# Pull a single batch from the training loader bound in the previous cell.
data = next(iter(loader))

In [10]:
# Show the shapes of the mask and charge tensors in the batch, one per line.
for batch_key in ('atom_mask', 'edge_mask', 'charges'):
    print(data[batch_key].shape)

torch.Size([16, 25])
torch.Size([10000, 1])
torch.Size([16, 25, 1])


In [11]:
# unsqueeze(1) inserts a singleton axis in the middle: (16, 25) -> (16, 1, 25).
data['atom_mask'].unsqueeze(1).shape

torch.Size([16, 1, 25])

In [12]:
# unsqueeze(2) appends a trailing singleton axis: (16, 25) -> (16, 25, 1).
data['atom_mask'].unsqueeze(2).shape

torch.Size([16, 25, 1])

In [13]:
# List every field a batch carries.
data.keys()

dict_keys(['num_atoms', 'charges', 'positions', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'omega1', 'zpve_thermo', 'U0_thermo', 'U_thermo', 'H_thermo', 'G_thermo', 'Cv_thermo', 'one_hot', 'atom_mask', 'edge_mask'])

In [14]:
# Positions come as (batch, padded atoms, 3); the last axis is presumably xyz — TODO confirm.
data['positions'].shape

torch.Size([16, 25, 3])

## Decompose sample workflow
---

Decompose the original sampling workflow into its individual steps. We then update the workflow to perform interpolation rather than purely unconditional sampling.

In [None]:
# Take the dataset name and hydrogen setting from `cfg` so the dataset metadata
# cannot drift away from the loaders built from the same config.
# NOTE(review): assumes the second positional argument is the remove-hydrogens
# flag (the defaults reproduce the previous literals "qm9", False) — confirm.
dataset_info = get_dataset_info(cfg.dataset, cfg.remove_h)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [25]:
@dataclass
class ModelConfigs:
    """Hyper-parameter bundle handed to ``get_model`` as its ``args`` argument.

    Field order is kept exactly as declared so positional construction
    continues to work.
    """

    # Model family and diffusion process.
    model: str = "egnn_dynamics"
    probabilistic_model: str = "diffusion"
    diffusion_steps: int = 1000
    diffusion_noise_schedule: str = "polynomial_2"
    diffusion_noise_precision: float = 1e-5

    # Input features and conditioning.
    include_charges: bool = True
    conditioning: tuple = ()

    # Network and training options.
    n_layers: int = 9
    ema_decay: float = 0.9999
    normalize_factors: tuple = (1, 4, 10)
    attention: bool = True
    inv_sublayers: int = 1
    tanh: bool = True
    condition_time: bool = True
    aggregation_method: str = "sum"
    sin_embedding: bool = False
    norm_constant: float = 1
    normalization_factor: float = 1
    context_node_nf: int = 0
    nf: int = 256
    diffusion_loss_type: str = "l2"
    actnorm: bool = True

In [26]:
# Use the default model hyper-parameters defined in ModelConfigs.
args = ModelConfigs()

In [27]:
# Build the generative model from the configs and the training loader. The call
# also hands back the molecule-size distribution and a third value bound to
# `prop_dist`.
flow, nodes_dist, prop_dist = get_model(args, device, dataset_info, dataloaders['train'])

Entropy of n_nodes: H[N] -2.475700616836548
alphas2 [9.99990000e-01 9.99988000e-01 9.99982000e-01 ... 2.59676966e-05
 1.39959211e-05 1.00039959e-05]
gamma [-11.51291546 -11.33059532 -10.92513058 ...  10.55863126  11.17673063
  11.51251595]


In [28]:
# Display the model's module tree.
flow

EnVariationalDiffusion(
  (gamma): PredefinedNoiseSchedule()
  (dynamics): EGNN_dynamics_QM9(
    (egnn): EGNN(
      (embedding): Linear(in_features=7, out_features=256, bias=True)
      (embedding_out): Linear(in_features=256, out_features=7, bias=True)
      (e_block_0): EquivariantBlock(
        (gcl_0): GCL(
          (edge_mlp): Sequential(
            (0): Linear(in_features=514, out_features=256, bias=True)
            (1): SiLU()
            (2): Linear(in_features=256, out_features=256, bias=True)
            (3): SiLU()
          )
          (node_mlp): Sequential(
            (0): Linear(in_features=512, out_features=256, bias=True)
            (1): SiLU()
            (2): Linear(in_features=256, out_features=256, bias=True)
          )
          (att_mlp): Sequential(
            (0): Linear(in_features=256, out_features=1, bias=True)
            (1): Sigmoid()
          )
        )
        (gcl_equiv): EquivariantUpdate(
          (coord_mlp): Sequential(
            (0):

In [None]:
# Disabled alternative: read the args saved alongside a model from its
# args.pickle file and bind them to `args`. To enable it you need
# `import pickle` and an `eval_args.model_path`; neither appears in the cells above.
#with open(join(eval_args.model_path, 'args.pickle'), 'rb') as f:
#        args = pickle.load(f)

### Categorical 

`nodes_dist` is a categorical distribution that can be sampled to determine the molecule size. This refers to `P(M)` in the paper.

In [29]:
# Show the molecule-size distribution object returned by get_model.
nodes_dist

<qm9.models.DistributionNodes at 0x7f3416c083a0>

In [36]:
# Where the pretrained weights live (relative to the working directory).
model_path = 'outputs/edm_qm9'
fn = 'generative_model_ema.npy'

# Put the model on the target device, then read the checkpoint onto the same
# device. The weights are only read here; they are applied to `flow` in a
# later cell via load_state_dict.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
flow.to(device)
flow_state_dict = torch.load(join(model_path, fn), map_location=device)

In [39]:
# Draw 16 molecule sizes from the size distribution.
nodes_dist.sample(16)

tensor([19, 23, 10, 16, 18, 17, 20, 21, 21, 20, 19, 19, 14, 16, 25, 17])

In [46]:
# Copy the checkpoint weights read earlier into the model.
flow.load_state_dict(flow_state_dict)

<All keys matched successfully>

## Sampling a latent

Recreating the `sample` function in `qm9.sampling`.

In [66]:
batch_size = 128
# Pad every molecule to the largest size the dataset declares instead of a
# hard-coded 25: a sampled size above the padded width would be silently
# truncated when the node mask is filled by slicing.
max_n_nodes = dataset_info['max_n_nodes']
# Draw one molecule size (atom count) per sample in the batch.
mol_sizes = nodes_dist.sample(batch_size).to(device)
assert int(mol_sizes.max()) <= max_n_nodes, \
    "sampled a molecule larger than max_n_nodes"
# All-zero mask, allocated directly on the target device; filled in below.
node_mask = torch.zeros(batch_size, max_n_nodes, device=device)

In [67]:
# Inspect the sampled molecule sizes.
mol_sizes

tensor([16, 17, 18, 17, 19, 20, 21, 23, 18, 17, 20, 17, 20, 19, 17, 21, 14, 18,
        18, 21, 16, 17, 24, 15, 20, 19, 14, 21, 19, 23, 15, 19, 18, 23, 21, 17,
        16, 18, 23, 24, 16, 10, 19, 24, 19, 19, 19, 11, 15, 21, 19, 23, 23, 17,
        23, 14, 22, 13, 17, 19, 12, 15, 14, 17, 17, 14, 15, 17, 20, 18, 21, 17,
        18, 19, 15, 19, 21, 13, 17, 18, 18, 21, 13, 17, 18, 17, 17, 20, 18, 19,
        18, 23, 19, 16, 17, 19, 16, 18, 19, 23, 16, 18, 19, 23, 21, 16, 21, 18,
        19, 23, 18, 17, 18, 15, 16, 19, 14, 23, 18, 19, 17, 19, 21, 16, 13, 14,
        20, 18], device='cuda:0')

In [68]:
# Row i gets ones in its first `mol_sizes[i]` slots (real atoms) and zeros in
# the rest (padding). One broadcast comparison replaces the per-row Python loop,
# which sliced with a device tensor on every iteration; rebuilding the mask from
# scratch also makes the cell safe to re-run.
atom_slots = torch.arange(node_mask.size(1), device=node_mask.device)
node_mask = (atom_slots.unsqueeze(0) < mol_sizes.unsqueeze(1)).to(node_mask.dtype)

In [69]:
# Pairwise mask: entry (b, i, j) is non-zero only when slots i and j of
# molecule b are both real atoms.
edge_mask = node_mask[:, None, :] * node_mask[:, :, None].to(device)
# Zero the diagonal so that no atom is paired with itself.
diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0).to(device)
edge_mask = edge_mask * diag_mask
# Flatten to a single column with one row per (molecule, i, j) triple.
edge_mask = edge_mask.view(batch_size * max_n_nodes * max_n_nodes, 1)
# Append a trailing singleton axis: (batch, nodes) -> (batch, nodes, 1).
node_mask = node_mask[:, :, None]

In [70]:
# Both values are forwarded unchanged to every flow.sample_* call in the
# sampling loop below.
context = None
fix_noise = False

In [71]:
# NOTE: this rebinds `dataset_info`, replacing the value obtained from
# get_dataset_info("qm9", False) in an earlier cell. From here on, the stability
# check and the xyz export in the sampling loop read this literal.
dataset_info = {
    'name': 'qm9',
    'atom_encoder': {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4},
    'atom_decoder': ['H', 'C', 'N', 'O', 'F'],
    'n_nodes': {22: 3393, 17: 13025, 23: 4848, 21: 9970, 19: 13832, 20: 9482, 16: 10644, 13: 3060,
                15: 7796, 25: 1506, 18: 13364, 12: 1689, 11: 807, 24: 539, 14: 5136, 26: 48, 7: 16, 10: 362,
                8: 49, 9: 124, 27: 266, 4: 4, 29: 25, 6: 9, 5: 5, 3: 1},
    'max_n_nodes': 29,
    'atom_types': {1: 635559, 2: 101476, 0: 923537, 3: 140202, 4: 2323},
    'distances': [903054, 307308, 111994, 57474, 40384, 29170, 47152, 414344, 2202212, 573726,
                  1490786, 2970978, 756818, 969276, 489242, 1265402, 4587994, 3187130, 2454868, 2647422,
                  2098884,
                  2001974, 1625206, 1754172, 1620830, 1710042, 2133746, 1852492, 1415318, 1421064, 1223156,
                  1322256,
                  1380656, 1239244, 1084358, 981076, 896904, 762008, 659298, 604676, 523580, 437464, 413974,
                  352372,
                  291886, 271948, 231328, 188484, 160026, 136322, 117850, 103546, 87192, 76562, 61840,
                  49666, 43100,
                  33876, 26686, 22402, 18358, 15518, 13600, 12128, 9480, 7458, 5088, 4726, 3696, 3362, 3396,
                  2484,
                  1988, 1490, 984, 734, 600, 456, 482, 378, 362, 168, 124, 94, 88, 52, 44, 40, 18, 16, 8, 6,
                  2,
                  0, 0, 0, 0,
                  0,
                  0, 0],
    'colors_dic': ['#FFFFFF99', 'C7', 'C0', 'C3', 'C1'],
    'radius_dic': [0.46, 0.77, 0.77, 0.77, 0.77],
    'with_h': True}

In [72]:
# Keep drawing batches until at least this many stable molecules have been
# saved. Every batch is scanned to its end, so the final count can overshoot.
min_valid_samples = 100

num_valid = 0

flow.eval()

# NOTE(review): machine-specific absolute path — point this at your own checkout.
root_path = "/home/szaman5/e3_diffusion_for_molecules"
# Hoisted out of the loops: the output directory never changes.
molecule_dir = join(root_path, 'eval/molecules/')
while num_valid < min_valid_samples:
    # The steps below are meant to reproduce this single call, one stage at a time:
    # x, h = flow.sample(batch_size, max_n_nodes, node_mask, edge_mask, context, fix_noise=fix_noise)
    with torch.no_grad():
        # Initial latent for the whole batch.
        z = flow.sample_combined_position_feature_noise(batch_size,
                                                        max_n_nodes,
                                                        node_mask)
        z = z.to(device)
        # Walk the step index from flow.T - 1 down to 0.
        for s in reversed(range(0, flow.T)):
            # s and t = s + 1 are consecutive step indices, both divided by flow.T.
            s_array = torch.full((batch_size, 1), fill_value=s, device=device)
            t_array = s_array + 1
            s_array = s_array / flow.T
            t_array = t_array / flow.T
            # p(z_{t-1} | z_t)
            z = flow.sample_p_zs_given_zt(s_array,
                                          t_array,
                                          z,
                                          node_mask,
                                          edge_mask,
                                          context,
                                          fix_noise=fix_noise)
        # p(x, h | z_0)
        x, h = flow.sample_p_xh_given_z0(z,
                                         node_mask,
                                         edge_mask,
                                         context,
                                         fix_noise=fix_noise)
        # Largest absolute per-molecule coordinate sum; re-center the positions
        # when it has drifted above the tolerance.
        max_cog = torch.sum(x, dim=1, keepdim=True).abs().max().item()
        if max_cog > 5e-2:
            print(f'Warning cog drift with error {max_cog:.3f}. Projecting '
                  f'the positions down.')
            x = diffusion_utils.remove_mean_with_mask(x, node_mask)

        one_hot = h['categorical']
        charges = h['integer']

        # Transfer the batch to the host once instead of once per molecule:
        # atom-type indices, positions, and the atom count of every molecule.
        atom_types_np = one_hot.argmax(2).cpu().detach().numpy()
        positions_np = x.cpu().detach().numpy()
        atoms_per_mol = node_mask.flatten(start_dim=1).sum(dim=1).long().tolist()

        for i, num_atoms in enumerate(atoms_per_mol):
            # Drop the padding slots before checking the molecule.
            atom_type = atom_types_np[i, :num_atoms]
            x_squeeze = positions_np[i, :num_atoms]
            mol_stable = check_stability(x_squeeze, atom_type, dataset_info)[0]

            if mol_stable:
                print('Found stable mol.')
                # Save the single molecule (batch slice of size 1); the running
                # count of stable molecules is passed as the file id offset.
                vis.save_xyz_file(
                    molecule_dir,
                    one_hot[i:i+1], charges[i:i+1], x[i:i+1],
                    id_from=num_valid, name='molecule_stable',
                    dataset_info=dataset_info,
                    node_mask=node_mask[i:i+1])
                num_valid += 1

Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stable mol.
Found stab