# Tutorial 2: 3D Visualization of molecules, electron densities and basis functions

The goal of this tutorial is to demonstrate how to visualize molecules in 3D via molview.org, and how to visualize molecules, electron densities and individual basis functions in 3D directly in the notebook via PyVista.

In [None]:
# import necessary packages
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from hydra import compose, initialize
from hydra.utils import instantiate

# this makes sure that code changes are reflected without restarting the notebook
# this can be helpful if you want to play around with the code in the repo
%load_ext autoreload
%autoreload 2

# omegaconf is used for configuration management
# omegaconf custom resolvers are small functions used in the config files like "get_len" to get lengths of lists
from mldft.utils import omegaconf_resolvers  # this registers omegaconf custom resolvers
from mldft.utils.molecules import build_molecule_ofdata

# Fetch a small dataset with QM9 and QMugs data from the Hugging Face Hub (nothing is
# re-downloaded if it is already in the cache) and point DFT_DATA at its local directory.

# Cache location for Hugging Face downloads, see
# https://huggingface.co/docs/datasets/cache#cache-directory
# By default this is `~/.cache/huggingface/datasets`; set it to any path you like,
# e.g. "./hf_cache", to store the data somewhere else.
CACHE_DIR = None

# Note: the full repo can be cloned from https://huggingface.co/sciai-lab/structures25/tree/main;
# this cell only downloads the minimal dataset repository named below.
DATA_REPO_ID = "sciai-lab/minimal_data_QM9_QMugs"

# Progress bars are switched off to avoid rendering problems in some environments.
# The variable is set before importing huggingface_hub in case the library reads it at import time.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from huggingface_hub import snapshot_download  # noqa: E402

data_path = snapshot_download(repo_id=DATA_REPO_ID, cache_dir=CACHE_DIR, repo_type="dataset")

# Keep the previous value around only to report the change to the user.
dft_data = os.environ.get("DFT_DATA", None)
os.environ["DFT_DATA"] = data_path
print(
    f"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}."
)

In [None]:
# Compose the (large) training config, set up the datamodule and pick the first training sample.
with initialize(version_base=None, config_path="../../configs/ml"):
    config = compose(config_name="train.yaml")

# The configured output directory relies on Hydra-specific values that are only available
# inside @hydra.main-decorated functions, so a placeholder path is assigned instead.
config.paths.output_dir = "example_path"

datamodule = instantiate(config.data.datamodule)
datamodule.setup(stage="fit")
sample = datamodule.train_set[0]

# Building a pySCF molecule object requires the basis information
# (see below for more details on basis_info).
basis_info = instantiate(config.data.basis_info)

# Convert the OFData sample into a pySCF molecule object.
mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)

## 3D visualization based on [molview.org](https://molview.org)

In [None]:
from mldft.utils.molecules import get_mol_view_link

# molview.org offers a quick way to look at the molecule in 3D in the browser.
# Caveat: molview infers the geometry from the SMILES string, so the structure it shows
# does not exactly match the coordinates stored in the sample object. For a given set of
# bonds and atom types, the spatial layout is not necessarily unique.
# Compare with the PyVista rendering below, which uses the atom coordinates from the
# xyz files, to spot the subtle differences.
# Following the link displayed by this cell opens the 3D structure in the browser.
print("Click on the following link to visualize the molecule in 3D in your browser:")
get_mol_view_link(mol)

## 3D visualization in the notebook based on PyVista

In [None]:
import sys

# keep only the program name so downstream parsers don't see Jupyter's -f=...
sys.argv = sys.argv[:1]

import pyvista

from mldft.utils.visualize_3d import (
    get_local_frames_mesh_dict,
    get_sticks_mesh_dict,
    visualize_orbital,
)

# this gives a ball and stick model of the molecule
molecule_mesh = get_sticks_mesh_dict(mol)

# this can be used to visualize local frames (in this case just the global coordinate frame at the origin)
global_frame_mesh = get_local_frames_mesh_dict(
    origins=torch.zeros(1, 3),
    bases=torch.eye(3)[None],
    scale=2,
)

# Camera bounds in the order (xmin, xmax, ymin, ymax, zmin, zmax), taken from the atom
# coordinates and shrunk by 10% around the molecule's center. Scaling about the center
# (rather than multiplying the raw bounds, which shrinks them toward the origin) keeps the
# view centered on the molecule even if it is not located at the origin.
atom_coords = mol.atom_coords()
coords_min, coords_max = atom_coords.min(0), atom_coords.max(0)
coords_center = 0.5 * (coords_min + coords_max)
corners = np.stack([coords_min, coords_max], axis=1)  # one (min, max) row per axis
camera_bounds = (coords_center[:, None] + 0.9 * (corners - coords_center[:, None])).flatten()

# plot the molecule and the global frame using pyvista:
pyvista.set_jupyter_backend("html")
pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl.add_mesh(**global_frame_mesh)
pl.add_mesh(**molecule_mesh)
pl.reset_camera(bounds=camera_bounds)

print(
    "3d visualization of our sample molecule together with the global coordinate frame placed at the origin:"
)
# with screenshot=True the rendered image is also returned, see its use with imshow below
img = pl.show(screenshot=True, window_size=(800, 400))

# the screenshot can also be used for programmatic plotting of 3d molecules in matplotlib:
print("\n\nWe can also obtain a non-interactive image of the molecule:")
plt.imshow(img)
plt.axis("off")  # pixel coordinates carry no information for a rendered image
plt.show()

For instance, we can use the screenshot function to create a small GIF of the camera rotating around the molecule:

In [None]:
# create a series of images in which the camera rotates around the molecule


def rotate_around_molecule(
    pl,
    n_frames=36,
    radius=1.5,
    focal_point=(0, 0, 0),
    elevation=0.5,
    window_size=(800, 400),
    view_up=(0, 0, 1),
):
    """Collect screenshots while the camera orbits around a focal point.

    The camera moves on a horizontal circle of radius ``radius`` around ``focal_point``.
    It is raised by ``elevation * radius`` along z and always looks at ``focal_point``.
    One screenshot is taken per camera position.

    Args:
        pl: Plotter with the scene already added (e.g. a pyvista ``Plotter``). Its
            ``camera_position`` is overwritten in place.
        n_frames: Number of frames, evenly spaced over one full revolution. The end point
            2*pi is excluded, so the last frame does not duplicate the first one.
        radius: Horizontal distance between the camera and the focal point.
        focal_point: Point the camera orbits around and looks at. The default (the origin)
            assumes the molecule is centered at the origin.
        elevation: Height of the camera above the focal point, in units of ``radius``.
        window_size: Size of each screenshot, passed on to ``pl.screenshot``.
        view_up: Up-vector of the camera.

    Returns:
        List with one screenshot per frame, as returned by ``pl.screenshot``.
    """
    fx, fy, fz = focal_point
    frames = []
    for angle in np.linspace(0, 2 * np.pi, n_frames, endpoint=False):
        position = [
            fx + radius * np.cos(angle),
            fy + radius * np.sin(angle),
            fz + elevation * radius,
        ]
        pl.camera_position = (position, focal_point, view_up)
        pl.render()  # force a re-render so the screenshot reflects the new camera position
        frames.append(pl.screenshot(transparent_background=False, window_size=window_size))
    return frames


import imageio
from IPython.display import Image

# Orbit the camera around the molecule and collect one screenshot per frame.
# The radius was chosen for this sample; adjust it if the molecule looks too small or too large.
images = rotate_around_molecule(pl, n_frames=60, radius=26.0)

# Write the collected frames into an animated GIF file.
gif_path = "molecule_rotation.gif"
imageio.mimsave(gif_path, images, fps=30, loop=0)

# Jupyter renders the last expression of a cell, so this displays the GIF inline.
Image(filename=gif_path)

In [None]:
# let us visualize the electron density, i.e. the linear combination of basis functions
# weighted by the coefficients of the sample.
# .detach().cpu() lets the conversion also work for tensors that require grad or live on a GPU.
density_coeffs = sample.coeffs.detach().cpu().numpy()

pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl = visualize_orbital(
    mol=mol,
    coeff=density_coeffs,
    plotter=pl,
)

# Camera bounds (xmin, xmax, ymin, ymax, zmin, zmax) from the atom coordinates, shrunk by 10%
# around the molecule's center so the view stays centered even if the molecule is not at the origin.
atom_coords = mol.atom_coords()
coords_min, coords_max = atom_coords.min(0), atom_coords.max(0)
coords_center = 0.5 * (coords_min + coords_max)
corners = np.stack([coords_min, coords_max], axis=1)  # one (min, max) row per axis
pl.reset_camera(
    bounds=(coords_center[:, None] + 0.9 * (corners - coords_center[:, None])).flatten()
)

print("Hint: you might have to zoom in or out a bit to see something at first.")
print(
    "3d visualization of the electron density as linear combination of basis functions\nusing the coefficients in the sample:"
)
img = pl.show(screenshot=True, window_size=(800, 400))

In [None]:
# let us visualize a single basis function: one coefficient is set to one, all others to zero.
# Index of the basis function to display. It must be smaller than the number of coefficients
# of the sample, which depends on the molecule and the basis.
BASIS_FUNCTION_INDEX = 194

n_coeffs = sample.coeffs.shape[0]
if not 0 <= BASIS_FUNCTION_INDEX < n_coeffs:
    raise IndexError(
        f"BASIS_FUNCTION_INDEX={BASIS_FUNCTION_INDEX} is out of range: "
        f"the sample only has {n_coeffs} coefficients."
    )
coeffs = np.zeros(sample.coeffs.shape)
coeffs[BASIS_FUNCTION_INDEX] = 1.0

pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl = visualize_orbital(
    mol=mol,
    coeff=coeffs,
    plotter=pl,
    mode="isosurface",
    resolution=0.15,
    isosurface_quantile=0.95,
)

print("3d visualization of a single basis function (one coefficient set to 1, all others to 0):")
img = pl.show(screenshot=True, window_size=(800, 400))