In [1]:
import nglview
from pymatgen.core import Structure, Lattice
import numpy as np



In [2]:
def plot3d(structure, spacefill=True, show_axes=True, eps=1e-8):
    """Render ``structure`` as an nglview widget with boundary atoms replicated.

    Every site is wrapped into the unit cell. Extra copies are then placed just below 1
    along every combination of axes on which the site's fractional coordinate is
    (within tolerance) 0. As a result, atoms on faces, edges and corners appear on all
    sides of the drawn unit cell.

    Parameters
    ----------
    structure : pymatgen.core.Structure
        Structure to display; it is not modified (a new display structure is built).
    spacefill : bool, default True
        Draw atoms as space-filling spheres (radius 0.5, coloured by element);
        otherwise add a ball-and-stick representation.
    show_axes : bool, default True
        Draw red/green/blue arrows along the Cartesian x/y/z directions.
    eps : float, default 1e-8
        Fractional tolerance for treating a coordinate as lying on a cell boundary.

    Returns
    -------
    nglview.NGLWidget
        The configured viewer (rendered when it is the last expression of a cell).
    """
    from itertools import product

    import nglview
    import numpy as np
    from pymatgen.core import Structure
    from pymatgen.core.sites import PeriodicSite

    lattice = structure.lattice
    # Shift by 0 or by just under one cell per axis. Using 1 - eps rather than 1 presumably
    # keeps the copy from being wrapped back onto fractional coordinate 0.
    shifts = [np.array(jimage) for jimage in product([0, 1 - eps], repeat=3)]

    display_sites = []
    for site in structure:
        frac = np.remainder(site.frac_coords, 1)
        # Round-off can leave a boundary atom at (or just below) 1; for example,
        # np.remainder(-1e-17, 1) == 1.0. In that case no copy would be drawn on the 0 face.
        # Snap such coordinates to 0 so they are replicated like exact zeros.
        frac = np.where(frac > 1 - eps, 0.0, frac)
        for shift in shifts:
            shifted = frac + shift
            # The unshifted site always passes. A shifted copy passes only if every
            # shifted axis started within ~2*eps of 0.
            if np.all(shifted < 1 + eps):
                display_sites.append(
                    PeriodicSite(species=site.species, coords=shifted, lattice=lattice)
                )
    structure_display = Structure.from_sites(display_sites)

    view = nglview.show_pymatgen(structure_display)
    view.add_unitcell()

    if spacefill:
        view.add_spacefill(radius_type='vdw', radius=0.5, color_scheme='element')
        # Drop the ball-and-stick representation (presumably nglview's default one here).
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()

    if show_axes:
        # Arrows of length 4 (radius 0.5) from the Cartesian point (-4, -4, -4),
        # coloured x = red, y = green, z = blue.
        view.shape.add_arrow([-4, -4, -4], [0, -4, -4], [1, 0, 0], 0.5, "x-axis")
        view.shape.add_arrow([-4, -4, -4], [-4, 0, -4], [0, 1, 0], 0.5, "y-axis")
        view.shape.add_arrow([-4, -4, -4], [-4, -4, 0], [0, 0, 1], 0.5, "z-axis")

    view.camera = "perspective"
    return view


In [None]:
# Rutile-type TiO2 cell: tetragonal lattice (lengths in Angstrom), Ti on Wyckoff 2a and
# O on Wyckoff 4f, whose four positions follow from the single free parameter x_4f.
a, c = 4.653, 2.969
x_4f = 0.3046

lattice = Lattice.tetragonal(a, c)  # identical to Lattice.from_parameters(a, a, c, 90, 90, 90)
species = ["Ti"] * 2 + ["O"] * 4
frac_coords = np.array(
    [[0, 0, 0], [0.5, 0.5, 0.5]]                                       # Ti(2a)
    + [[x_4f, x_4f, 0], [1 - x_4f, 1 - x_4f, 0]]                       # O(4f), z = 0
    + [[0.5 - x_4f, 0.5 + x_4f, 0.5], [0.5 + x_4f, 0.5 - x_4f, 0.5]]   # O(4f), z = 1/2
)
structure = Structure(lattice, species, frac_coords)

structure

In [None]:
# Interactive 3D view of the rutile cell defined above (space-filling atoms, axis arrows).
plot3d(structure, spacefill=True, show_axes=True)

In [8]:
import os

import torch as th
from pymatgen.core.periodic_table import Element

# Generation dump to inspect. Set the EVAL_GEN_PATH environment variable to point
# elsewhere instead of editing this user-specific absolute default.
EVAL_GEN_PATH = os.environ.get(
    "EVAL_GEN_PATH",
    "/home/mila/s/siba-smarak.panigrahi/scratch/DiffCSP/hydra/singlerun/2023-12-15/perov_w_symm_type/eval_gen_final_num_samples_20.pt",
)
CRYSTAL_IDX = 0  # which generated crystal to inspect

# map_location="cpu" lets the file load on a machine without the device it was saved from.
# th.load unpickles the file, so only point it at files you trust.
eval_gen = th.load(EVAL_GEN_PATH, map_location="cpu")
print(eval_gen.keys())

# NOTE(review): the per-atom arrays appear to be concatenated over all crystals
# (frac_coords is 2-D, while lengths/angles are indexed per crystal). Crystal CRYSTAL_IDX
# therefore starts after the atoms of all preceding crystals. Confirm against the generator.
atom_counts = eval_gen["num_atoms"]
num_atoms = int(atom_counts[CRYSTAL_IDX])
start = int(atom_counts[:CRYSTAL_IDX].sum())
stop = start + num_atoms

frac_coords = eval_gen["frac_coords"][start:stop].numpy()
lengths = eval_gen["lengths"][CRYSTAL_IDX].numpy()
angles = eval_gen["angles"][CRYSTAL_IDX].numpy()
# Assumes each atom_types row scores element classes with column i <-> atomic number i + 1.
atomic_numbers = eval_gen["atom_types"][start:stop].argmax(dim=1).numpy() + 1

print(frac_coords)

lattice = Lattice.from_parameters(*lengths, *angles)
species = [Element.from_Z(int(z)).symbol for z in atomic_numbers]
structure = Structure(lattice, species, frac_coords)
structure

dict_keys(['eval_setting', 'frac_coords', 'num_atoms', 'atom_types', 'lengths', 'angles'])
[[0.5        0.         0.02193174]
 [0.5        0.5        0.5197196 ]
 [0.         0.         0.02419412]
 [0.         0.5        0.5816288 ]
 [0.5        0.5        0.02181946]]


Structure Summary
Lattice
    abc : 4.599706170228435 4.3861853623413225 4.298975467681885
 angles : 91.17674165836917 87.98918121561209 87.8072945577717
 volume : 86.59486984166713
      A : 4.596873760223389 0.0 0.16139543056488037
      B : 0.17108403146266937 4.381921768188477 -0.09007721394300461
      C : 0.0 0.0 4.298975467681885
    pbc : True True True
PeriodicSite: O (2.298, 0.0, 0.175) [0.5, 0.0, 0.02193]
PeriodicSite: Ag (2.384, 2.191, 2.27) [0.5, 0.5, 0.5197]
PeriodicSite: V (0.0, 0.0, 0.104) [0.0, 0.0, 0.02419]
PeriodicSite: S (0.08554, 2.191, 2.455) [0.0, 0.5, 0.5816]
PeriodicSite: O (2.384, 2.191, 0.1295) [0.5, 0.5, 0.02182]

In [4]:
# Project-level validity check on the current `structure`. The Hydra `version_base`
# warnings in this cell's output presumably come from importing this module.
from scripts.eval_utils import structure_validity

is_valid = structure_validity(structure)
is_valid

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


True

In [7]:
# Interactive 3D view of whatever `structure` currently holds. In a top-to-bottom run,
# that is the generated crystal loaded earlier (space-filling atoms, axis arrows).
plot3d(structure, spacefill=True, show_axes=True)

NGLWidget()