## Imports

In [1]:
import time
import argparse
import torch

from tqdm import tqdm
from torch.optim import Adam
from pathlib import Path
from types import SimpleNamespace
from torch_geometric.data import Data, Batch, DataLoader
from torch.utils.data import Dataset
from scripts.eval_utils import load_model, lattices_to_params_shape, get_crystals_list
from diffcsp.common.constants import CompScalerMeans
from diffcsp.common.constants import SpaceGroupDist

from pymatgen.core.sites import PeriodicSite
from pymatgen.core import Structure, Lattice

from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.symmetry.groups import SpaceGroup, sg_symbol_from_int_number
from pymatgen.io.cif import CifWriter
from pyxtal.symmetry import Group
import chemparse
import numpy as np
from p_tqdm import p_map
import matplotlib.pyplot as plt
from itertools import product

import nglview
import numpy as np


import pdb
import os



The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


## Create dataset to sample from

In [2]:
# Per-dataset probability tables for the number of atoms in a sampled crystal.
# `SampleDataset` passes each list as `p` to `np.random.choice(len(list), ...)`,
# so entry `i` is the probability of drawing a crystal with `i` atoms.
# NOTE(review): presumably these are empirical training-set frequencies (hence
# the name) — confirm against the data pipeline.
train_dist = {
    # All mass on index 5: every 'perov' sample has exactly 5 atoms.
    'perov' : [0, 0, 0, 0, 0, 1],
    # Non-zero only at even indices 6..24.
    'carbon' : [0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.3250697750779839,
                0.0,
                0.27795107535708424,
                0.0,
                0.15383352487276308,
                0.0,
                0.11246100804465604,
                0.0,
                0.04958134953209654,
                0.0,
                0.038745690362830404,
                0.0,
                0.019044491873255624,
                0.0,
                0.010178952552946971,
                0.0,
                0.007059596125430964,
                0.0,
                0.006074536200952225],
    # Non-zero for indices 1..20.
    'mp' : [0.0,
            0.0021742334905660377,
            0.021079009433962265,
            0.019826061320754717,
            0.15271226415094338,
            0.047132959905660375,
            0.08464770047169812,
            0.021079009433962265,
            0.07808814858490566,
            0.03434551886792453,
            0.0972877358490566,
            0.013303360849056603,
            0.09669811320754718,
            0.02155807783018868,
            0.06522700471698113,
            0.014372051886792452,
            0.06703272405660378,
            0.00972877358490566,
            0.053176591981132074,
            0.010576356132075472,
            0.08995430424528301]
}

class SampleDataset(Dataset):
    """Dataset of empty crystal "requests" used to condition sampling.

    Each item only carries an atom count and a space-group index. Both are
    drawn once, at construction time, from ``train_dist[dataset]`` and
    ``SpaceGroupDist[dataset]`` respectively. For the ``'carbon'`` dataset
    every item additionally gets ``atom_types`` filled with 6.

    Args:
        dataset: key into ``train_dist`` / ``SpaceGroupDist``.
        total_num: number of items to pre-draw (also the dataset length).
    """

    def __init__(self, dataset, total_num):
        super().__init__()
        self.total_num = total_num

        self.distribution = train_dist[dataset]
        self.sg_distribution = SpaceGroupDist[dataset]

        # Draw atom counts first, then space groups (same RNG order as before).
        n_atom_options = len(self.distribution)
        self.num_atoms = np.random.choice(n_atom_options, total_num, p = self.distribution)

        n_sg_options = len(self.sg_distribution)
        self.sg = np.random.choice(n_sg_options, total_num, p = self.sg_distribution)

        self.is_carbon = (dataset == 'carbon')

    def __len__(self) -> int:
        """Number of pre-drawn items."""
        return self.total_num

    def __getitem__(self, index):
        """Build the ``Data`` object for the ``index``-th pre-drawn item."""
        n_atoms = self.num_atoms[index]
        item = Data(
            num_atoms=torch.LongTensor([n_atoms]),
            num_nodes=n_atoms,
            spacegroup=torch.LongTensor([self.sg[index]]),
        )
        if not self.is_carbon:
            return item

        # Carbon-only dataset: every atom has type 6.
        item.atom_types = torch.LongTensor([6] * n_atoms)
        return item

def get_num_atoms_per_sg(dataset):
    """Load the ``num_atoms_per_sg.csv`` table for a dataset.

    Args:
        dataset: one of ``'perov'``, ``'carbon'`` or ``'mp'``.

    Returns:
        The array parsed by ``np.loadtxt`` from
        ``./data/<dataset_name>/num_atoms_per_sg.csv`` (comma-delimited).

    Raises:
        NotImplementedError: for any other ``dataset`` value.
    """
    known_datasets = (
        ('perov', 'perov_4'),
        ('carbon', 'carbon_24'),
        ('mp', 'mp_20'),
    )
    # Equality-based scan (not a dict lookup) so comparison semantics are unchanged.
    for short_name, dataset_name in known_datasets:
        if dataset == short_name:
            break
    else:
        raise NotImplementedError

    dist_file = f'./data/{dataset_name}/num_atoms_per_sg.csv'
    return np.loadtxt(dist_file, delimiter=',')
    

In [None]:

# Build a small loader of sampling requests drawn from the 'mp' distributions.
batch_size = 1
num_batches_to_sample = 10
# SampleDataset.__init__ takes exactly (dataset, total_num); the extra third
# positional argument that used to be passed here raised a TypeError.
test_set = SampleDataset("mp", batch_size * num_batches_to_sample)
test_loader = DataLoader(test_set, batch_size = batch_size)


## Visualize results

In [3]:
def plot3d(structure, spacefill=True, show_axes=True):
    """Render a structure in an nglview widget with its unit cell.

    Fractional coordinates are wrapped into [0, 1). A site whose wrapped
    coordinate is (numerically) zero along some axis is additionally copied to
    the opposite cell face so that it is drawn on both boundaries.

    Args:
        structure: iterable of sites exposing ``species`` / ``frac_coords``,
            with a ``lattice`` attribute (e.g. a pymatgen ``Structure``).
        spacefill: if True draw space-filling spheres, otherwise ball-and-stick.
        show_axes: if True add three labelled axis arrows.

    Returns:
        The configured nglview widget.
    """
    tol = 1e-8
    # All 8 combinations of "stay" / "shift by almost one cell" along a, b, c.
    image_shifts = [np.array(shift) for shift in product([0, 1 - tol], repeat=3)]

    display_sites = []
    for site in structure:
        wrapped = np.remainder(site.frac_coords, 1)
        for shift in image_shifts:
            candidate = wrapped + shift
            if not np.all(candidate < 1 + tol):
                continue
            display_sites.append(
                PeriodicSite(species=site.species, coords=candidate, lattice=structure.lattice)
            )
    structure_display = Structure.from_sites(display_sites)

    view = nglview.show_pymatgen(structure_display)
    view.add_unitcell()

    if not spacefill:
        view.add_ball_and_stick()
    else:
        view.add_spacefill(radius_type='vdw', radius=0.5, color_scheme='element')
        view.remove_ball_and_stick()

    if show_axes:
        # One arrow per axis, all starting at the same corner offset from the cell.
        for axis, label in enumerate(("x-axis", "y-axis", "z-axis")):
            start = [-4, -4, -4]
            end = [-4, -4, -4]
            end[axis] = 0
            colour = [0, 0, 0]
            colour[axis] = 1
            view.shape.add_arrow(start, end, colour, 0.5, label)

    view.camera = "perspective"
    return view



A Jupyter Widget

## Load model

In [7]:
# Directory of the trained run to load. The default is the original
# machine-specific location; set the DIFFCSP_MODEL_PATH environment variable
# to point at a checkpoint directory on another machine.
model_path = Path(os.environ.get(
    "DIFFCSP_MODEL_PATH",
    "/home/mila/d/daniel.levy/scratch/MatSci/intel-mat-diffusion/hydra/singlerun/2023-12-22/mp_ks_cond_sg/",
))
# load_data=False: only the model and its config are needed for sampling here.
model, _, cfg = load_model(
    model_path, load_data=False)
model = model.to('cuda')

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize_config_dir(str(model_path)):


## Code for sampling

In [10]:
from diffcsp.common.data_utils import (
    EPSILON, cart_to_frac_coords, mard, lengths_angles_to_volume, lattice_params_to_matrix_torch,
    frac_to_cart_coords, min_distance_sqr_pbc, lattice_ks_to_matrix_torch, sg_to_ks_mask, mask_ks, N_SPACEGROUPS)
# Width of the per-atom type vector used by `sample` below: atom-type noise is
# drawn with shape [num_nodes, MAX_ATOMIC_NUM], and the final type is
# argmax over that axis + 1.
MAX_ATOMIC_NUM=100



@torch.no_grad()
def sample(self, batch, diff_ratio = 1.0, step_lr = 1e-5):
    """Run the reverse diffusion loop and return the final state plus the trajectory.

    Written as a free function taking the model as ``self`` so it can be
    called as ``sample(model, batch, ...)``.

    Runs under ``torch.no_grad()``: the loop performs two decoder passes per
    timestep and keeps every intermediate state in ``traj``, so with autograd
    enabled the whole computation graph would be retained for the entire
    trajectory. Nothing in this function differentiates through the samples.

    Args:
        self: the loaded model. Read attributes: ``use_spacegroup``, ``use_ks``,
            ``keep_coords``, ``keep_lattice``, ``device``, ``beta_scheduler``,
            ``sigma_scheduler``, ``time_embedding``, ``decoder``.
        batch: batched graph data providing ``num_graphs``, ``num_nodes``,
            ``num_atoms``, ``batch`` and ``spacegroup`` (plus ``frac_coords`` when
            ``keep_coords`` is set, and ``ks`` / ``lengths`` / ``angles`` when
            ``keep_lattice`` is set).
        diff_ratio: currently unused; kept for interface compatibility.
        step_lr: scale of the corrector step size for fractional coordinates.

    Returns:
        ``(traj[0], traj_stack)``: the state dict at step 0 (keys ``num_atoms``,
        ``atom_types``, ``frac_coords``, ``lattices``, ``ks`` and, when
        ``use_spacegroup`` is set, ``spacegroup``) and a dict of tensors stacked
        over all steps from ``timesteps`` down to 0.
    """

    batch_size = batch.num_graphs
    if self.use_spacegroup and self.use_ks:
        ks_mask, ks_add = sg_to_ks_mask(batch.spacegroup)

    # Initial noise for the lattice (either the 6-vector `k` or a 3x3 matrix).
    if self.use_ks:
        k_T = torch.randn([batch_size, 6]).to(self.device)
        if self.use_spacegroup:
            k_T = mask_ks(k_T, ks_mask, ks_add)
        l_T = lattice_ks_to_matrix_torch(k_T)
    else:
        l_T = torch.randn([batch_size, 3, 3]).to(self.device)
        k_T = torch.zeros([batch_size, 6]).to(self.device) # dummy 
    # Initial fractional coordinates (uniform) and atom-type vectors (normal).
    x_T = torch.rand([batch.num_nodes, 3]).to(self.device)
    t_T = torch.randn([batch.num_nodes, MAX_ATOMIC_NUM]).to(self.device)


    # Optionally pin coordinates and/or lattice to the values given in `batch`.
    if self.keep_coords:
        x_T = batch.frac_coords

    if self.keep_lattice:
        k_T = batch.ks
        l_T = lattice_params_to_matrix_torch(batch.lengths, batch.angles)     

    # traj maps timestep -> state; start at the last timestep.
    traj = {self.beta_scheduler.timesteps : {
        'num_atoms' : batch.num_atoms,
        'atom_types' : t_T,
        'frac_coords' : x_T % 1.,
        'lattices' : l_T,
        'ks': k_T
    }}

    for t in tqdm(range(self.beta_scheduler.timesteps, 0, -1)):

        times = torch.full((batch_size, ), t, device = self.device)

        time_emb = self.time_embedding(times)
        
        alphas = self.beta_scheduler.alphas[t]
        alphas_cumprod = self.beta_scheduler.alphas_cumprod[t]

        sigmas = self.beta_scheduler.sigmas[t]
        sigma_x = self.sigma_scheduler.sigmas[t]
        sigma_norm = self.sigma_scheduler.sigmas_norm[t]

        # Coefficients of the lattice / atom-type update in the predictor step.
        c0 = 1.0 / torch.sqrt(alphas)
        c1 = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)

        x_t = traj[t]['frac_coords']
        l_t = traj[t]['lattices']
        k_t = traj[t]['ks']
        t_t = traj[t]['atom_types']

        if self.keep_coords:
            x_t = x_T

        if self.keep_lattice:
            k_t = k_T
            l_t = l_T

        # Corrector
        # For whatever reason, lattice parameters are not updated in the original code.
        # Only rand_x is consumed by the corrector; the other draws are kept
        # so the random-number stream is unchanged.
        if self.use_ks:
            rand_k = torch.randn_like(k_T) if t > 1 else torch.zeros_like(k_T)
        else:
            rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
        rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
        rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

        step_size = step_lr * (sigma_x / self.sigma_scheduler.sigma_begin) ** 2
        std_x = torch.sqrt(2 * step_size)

        lattice_feats_t = k_t if self.use_ks else l_t
        pred_lattice, pred_x, pred_t = self.decoder(time_emb, t_t, x_t, lattice_feats_t, l_t, batch.num_atoms, batch.batch, batch.spacegroup)

        pred_x = pred_x * torch.sqrt(sigma_norm)

        # Half step: only fractional coordinates move; lattice and types carry over.
        x_t_minus_05 = x_t - step_size * pred_x + std_x * rand_x if not self.keep_coords else x_t
        k_t_minus_05 = k_t
        l_t_minus_05 = l_t

        t_t_minus_05 = t_t


        # Predictor
        if self.use_ks:
            rand_k = torch.randn_like(k_T) if t > 1 else torch.zeros_like(k_T)
        else:
            rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
        rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
        rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

        adjacent_sigma_x = self.sigma_scheduler.sigmas[t-1] 
        step_size = (sigma_x ** 2 - adjacent_sigma_x ** 2)
        std_x = torch.sqrt((adjacent_sigma_x ** 2 * (sigma_x ** 2 - adjacent_sigma_x ** 2)) / (sigma_x ** 2))   
        lattice_feats_t_minus_05 = k_t_minus_05 if self.use_ks else l_t_minus_05

        pred_lattice, pred_x, pred_t = self.decoder(time_emb, t_t_minus_05, x_t_minus_05, lattice_feats_t_minus_05, l_t_minus_05, batch.num_atoms, batch.batch, batch.spacegroup)

        pred_x = pred_x * torch.sqrt(sigma_norm)

        x_t_minus_1 = x_t_minus_05 - step_size * pred_x + std_x * rand_x if not self.keep_coords else x_t
        if self.use_ks:
            k_t_minus_1 = c0 * (k_t_minus_05 - c1 * pred_lattice) + sigmas * rand_k if not self.keep_lattice else k_t
            # Re-apply the space-group mask after every lattice update.
            if self.use_spacegroup and not self.keep_lattice:
                k_t_minus_1 = mask_ks(k_t_minus_1, ks_mask, ks_add)
            l_t_minus_1 = lattice_ks_to_matrix_torch(k_t_minus_1) if not self.keep_lattice else l_t
        else:
            l_t_minus_1 = c0 * (l_t_minus_05 - c1 * pred_lattice) + sigmas * rand_l if not self.keep_lattice else l_t
            k_t_minus_1 = k_t

        t_t_minus_1 = c0 * (t_t_minus_05 - c1 * pred_t) + sigmas * rand_t

        traj[t - 1] = {
            'num_atoms' : batch.num_atoms,
            'atom_types' : t_t_minus_1,
            'frac_coords' : x_t_minus_1 % 1.,
            'lattices' : l_t_minus_1,
            'ks': k_t_minus_1,
        }
        if self.use_spacegroup:
            traj[t - 1]['spacegroup'] = batch.spacegroup

    # Stack every stored state, ordered from the last timestep down to 0.
    traj_stack = {
        'num_atoms' : batch.num_atoms,
        'atom_types' : torch.stack([traj[i]['atom_types'] for i in range(self.beta_scheduler.timesteps, -1, -1)]).argmax(dim=-1) + 1,
        'all_frac_coords' : torch.stack([traj[i]['frac_coords'] for i in range(self.beta_scheduler.timesteps, -1, -1)]),
        'all_lattices' : torch.stack([traj[i]['lattices'] for i in range(self.beta_scheduler.timesteps, -1, -1)]),
        'all_ks': torch.stack([traj[i]['ks'] for i in range(self.beta_scheduler.timesteps, -1, -1)])
    }

    return traj[0], traj_stack


## Sample crystal

In [307]:

# Pull one request from the loader and move it to the GPU.
iter_test_loader = iter(test_loader)
batch = next(iter_test_loader)
batch = batch.to('cuda')

print("Num atoms: ", batch.num_atoms.item())
print("Spacegroup: ", batch.spacegroup.item())

# Final state (`outputs`) and full stacked trajectory (`traj`).
outputs, traj = sample(model, batch, diff_ratio=1.0, step_lr=1e-5)
def outputs_to_structure(output, batch_i=0):
    """Convert one crystal of a sampled state dict into a pymatgen ``Structure``.

    Args:
        output: state dict as returned by ``sample`` (keys ``lattices``,
            ``atom_types``, ``frac_coords`` and, normally, ``num_atoms``).
        batch_i: index of the crystal inside the batch.

    Returns:
        ``Structure`` built from the ``batch_i``-th lattice and that crystal's
        atoms. Atom types are ``argmax`` over the type axis, plus 1.

    Previously ``batch_i`` only selected the lattice while *all* atoms in the
    batch were used, which is wrong for any batch with more than one crystal.
    The atoms are now sliced with ``num_atoms``.
    Assumes atoms of each crystal are stored contiguously, in batch order —
    TODO confirm against the batching used by the loader.
    """
    lattice = output['lattices'][batch_i].detach().cpu().numpy()
    atom_types = output['atom_types'].detach().cpu().numpy().argmax(1)+1
    frac_coords = output['frac_coords'].detach().cpu().numpy()

    # Fall back to the old "use everything" behaviour if counts are missing.
    if 'num_atoms' in output:
        counts = output['num_atoms'].detach().cpu().numpy().reshape(-1)
        # Normalise negative indices and raise IndexError on out-of-range ones.
        crystal_idx = range(len(counts))[batch_i]
        start = int(counts[:crystal_idx].sum())
        stop = start + int(counts[crystal_idx])
        atom_types = atom_types[start:stop]
        frac_coords = frac_coords[start:stop]

    return Structure(lattice, atom_types, frac_coords)

# Target space group the sample was conditioned on (stored by `sample`).
spacegroup = outputs['spacegroup'].item()

# First (and here only) crystal of the batch as a pymatgen Structure.
structure = outputs_to_structure(outputs, batch_i=0)
print(structure)


tensor([5], device='cuda:0')
tensor([164], device='cuda:0')
DataBatch(num_atoms=[1], num_nodes=5, spacegroup=[1], batch=[5], ptr=[2])


100%|██████████| 1000/1000 [00:12<00:00, 80.04it/s]


### Look at resulting crystal

In [309]:

# Compare the space group detected in the sampled structure with the target.
spga = SpacegroupAnalyzer(structure, symprec=0.1)
real_sg = spga.get_space_group_number()

for label, value in (("real sg: ", real_sg), ("target sg: ", spacegroup)):
    print(label, value)

plot3d(structure, spacefill=True)



real sg:  164
target sg:  164


A Jupyter Widget

In [None]:
# Symmetry operations of the *target* space group ...
target_sg_symbol = sg_symbol_from_int_number(spacegroup)
symmops = list(SpaceGroup(target_sg_symbol).symmetry_ops)
# ... and those detected in the sampled structure.
symmops_real = spga.get_symmetry_operations()

: 

In [313]:
def apply_symmop(symmop, structure, frac_coords=None):
    """Apply a symmetry operation to fractional coordinates and rebuild the structure.

    Args:
        symmop: object exposing ``operate_multi`` (e.g. a pymatgen ``SymmOp``).
        structure: source of the lattice and species (and of the coordinates
            when ``frac_coords`` is None).
        frac_coords: optional coordinates to transform instead of the structure's own.

    Returns:
        A new ``Structure`` with the transformed coordinates wrapped by ``% 1``.
    """
    source_coords = structure.frac_coords if frac_coords is None else frac_coords
    transformed = symmop.operate_multi(source_coords)
    return Structure(structure.lattice, structure.species, transformed % 1)

# Visual check: the structure after applying one target symmetry operation.
transformed_structure = apply_symmop(symmops[3], structure)
plot3d(transformed_structure, spacefill=True)

A Jupyter Widget

In [314]:
from diffcsp.common.utils import SinkhornDistance
import torch
# Distance used below to compare a point set with its symmetry-transformed copy;
# calling it returns a (dist, P, C) triple.
# NOTE(review): eps is presumably the entropic regularisation strength and
# max_iter the Sinkhorn iteration cap — confirm in diffcsp.common.utils.
sinkhorn = SinkhornDistance(eps = 0.1, max_iter=100)

In [315]:
def apply_symmop_torch(affine_matrix, frac_coords):
    """Apply a 4x4 affine symmetry operation to fractional coordinates (differentiably).

    Args:
        affine_matrix: tensor of shape (4, 4).
        frac_coords: tensor of shape (N, 3).

    Returns:
        Tensor of shape (N, 3): the transformed coordinates wrapped by ``% 1``.

    The homogeneous "ones" column is sized from ``frac_coords`` itself. It used
    to be sized from the notebook-global ``structure``, which broke the function
    whenever the argument had a different number of rows (or the global was
    missing).
    """
    # new_ones inherits dtype and device from frac_coords.
    ones_vec = frac_coords.new_ones((frac_coords.shape[0], 1))
    affine_points = torch.cat([frac_coords, ones_vec], 1)
    # inner -> (4, N); drop the homogeneous row, transpose back to (N, 3).
    return torch.inner(affine_matrix, affine_points)[:-1].T%1


# NOTE(review): this re-defines apply_symmop with the same behaviour as the
# definition in an earlier cell; one of the two could be dropped.
def apply_symmop(symmop, structure, frac_coords=None):
    """Apply a symmetry operation to fractional coordinates and rebuild the structure.

    Uses ``structure.frac_coords`` when ``frac_coords`` is None. Returns a new
    ``Structure`` sharing the lattice and species of ``structure``, with the
    transformed coordinates wrapped by ``% 1``.
    """
    source_coords = structure.frac_coords if frac_coords is None else frac_coords
    transformed = symmop.operate_multi(source_coords)
    return Structure(structure.lattice, structure.species, transformed % 1)


In [317]:
# Fractional coordinates of the sampled structure, as numpy and as a torch tensor.
frac_coords = structure.frac_coords
frac_coords_torch = torch.Tensor(frac_coords)
print(frac_coords)

[[0.9200089  0.38620993 0.26396334]
 [0.2531947  0.05177674 0.76721644]
 [0.92101866 0.3859573  0.64865893]
 [0.5892098  0.7189829  0.01634433]
 [0.25398096 0.05229295 0.38554794]]


### Use a symmetry loss to symmetrize the crystal

In [354]:
from torch.autograd import grad
# Step size of the plain gradient-descent update on the coordinates.
gradient_factor = 0.1

# Histories filled by the optimisation loop.
gradients = []
losses = []
x = frac_coords_torch.clone()
x_syms = [x]
# torch.tensor (not the legacy torch.Tensor constructor, which cannot target a
# non-CPU device) so the matrices follow x's dtype and device.
symmops_torch = [torch.tensor(symmop.affine_matrix, dtype=x.dtype, device=x.device) for symmop in symmops]
    

In [355]:

# Gradient descent on the coordinates: each step lowers the summed distance
# between the point set and its image under every target symmetry operation.
for i in tqdm(range(10)):
    x.requires_grad_(True)
    loss_list = []
    loss = torch.zeros(1, requires_grad=True)

    for symmop_i in range(len(symmops)):
        x_sym = apply_symmop_torch(symmops_torch[symmop_i], x)
        dist, P, C = sinkhorn(x, x_sym)
        loss_i = dist
        loss_list.append(loss_i)
        loss = loss + 0.5 * loss_i

    gradient = torch.autograd.grad(loss, x, allow_unused=False)
    # Detach so the next iteration starts a fresh graph; wrap back into [0, 1).
    x = (x - gradient_factor * gradient[0]).detach() % 1

    losses.append(loss.item())
    gradients.append(gradient[0].detach())
    x_syms.append(x.detach())


100%|██████████| 100/100 [00:01<00:00, 60.75it/s]


In [None]:
# Loss recorded at each optimisation step, on a labelled figure that stands alone.
fig, ax = plt.subplots()
ax.plot(np.array(losses))
ax.set(xlabel="optimization step", ylabel="symmetry loss", title="Symmetry loss during optimization")
plt.show()

#### Visualize result

In [363]:
# Structure rebuilt from the last optimised coordinates, with the original
# lattice and species.
sym_structure = Structure(structure.lattice, structure.species, x_syms[-1])
print(sym_structure)

sym_sg_number = SpacegroupAnalyzer(sym_structure, symprec=0.1).get_space_group_number()
print("spacegroup: ", sym_sg_number)

plot3d(sym_structure, spacefill=True)

Full Formula (Ca1 Mg2 As2)
Reduced Formula: Ca(MgAs)2
abc   :   6.759699   5.424852   7.598948
angles:  90.000000  90.000000 153.670496
pbc   :       True       True       True
Sites (5)
  #  SP           a         b    c
---  ----  --------  --------  ---
  0  As    0.010886  0.48845     0
  1  As    0.488468  0.007997    1
  2  Mg    0.992003  0.511532    1
  3  Ca    0.503169  0.496831    0
  4  Mg    0.51155   0.989114    0
