In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
from pathlib import Path
from typing import List

import os
import hydra
import numpy as np
import torch
import omegaconf
import wandb
import pytorch_lightning as pl
import json


from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything, Callback
from pytorch_lightning.loggers import WandbLogger
from source.common.utils import build_callbacks, log_hyperparameters, PROJECT_ROOT

import torch
import lz4.frame
from tqdm import tqdm
import plotly.graph_objects as go
from IPython.display import Image, display

# Physical GPU this notebook is pinned to. CUDA_VISIBLE_DEVICES is only honoured if it is set
# before CUDA is first initialised, so keep this above any CUDA work; inside the process the
# selected GPU is then addressed as the default "cuda" device.
GPU_ID = '5'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
device = torch.device('cuda')

In [None]:
# Hydra command lines used to train each (dataset, model) pair.
# Outer key: dataset name (selected via `current_data`); inner key: model name passed to
# model_path(). Each "command" is split on whitespace into Hydra overrides, and a leading
# "--config-name=<name>" token selects the primary config.
# The "svrN"/"cluster" notes are the author's; presumably the machine each group ran on.

# NOTE(review): for qm9 the model configs look swapped relative to every other dataset:
# "infgcn" uses model=RNO_QM9_v6-sep with RNO options, while "GPW" uses model=infgcnQM9_rep.
# Confirm which checkpoint folder each entry should be paired with before trusting qm9 results.
qm9 = {}  # svr2
qm9["infgcn"] = {
    "command": "--config-name=qm9 postfix=final data.rotate=True data.train_max_iter=80000 data.num_test_samples=1600 train.random_seed=42 data.reverse_order=True logging=draw model.num_fourier=30 model=RNO_QM9_v6-sep model.residual=True model.max_cell_size=16 model.num_spherical=4 model.num_spherical_RNO=3 model.scalar_mask=True model.grid_cutoff=0.75 model.mask_cutoff=2.0",
}
qm9["GPW"] = {
    "command": "--config-name=qm9 postfix=final data.rotate=True data.train_max_iter=80000 data.num_test_samples=1600 train.random_seed=42 data.reverse_order=True logging=draw model=infgcnQM9_rep model.num_spherical=7"
}

mixed = {}  # svr9
mixed["infgcn"] = {
    "command": "--config-name=mp_mixed data=all_mp_mixed train.random_seed=42 postfix=rep model.cutoff=3.0 data.num_workers=32 logging=draw model.residual=True model.num_spherical=3"
}
# `data=all_mp_mixed` appears exactly once: a second copy of the same group override is at best
# redundant and may be rejected by Hydra.
mixed["GPW"] = {
    "command": "--config-name=mp_mixed data=all_mp_mixed model=RNO_v8-pbc-soul2-3 train.random_seed=1 postfix=test2 data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.cutoff=3.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

tetragonal = {}  # svr6
tetragonal["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_tetragonal train.random_seed=1 postfix=rep model.cutoff=3.0 data.num_workers=32 logging=draw model.residual=True model.num_spherical=3"
}
tetragonal["GPW"] = {
    "command": "--config-name=mp_mixed data=mp_tetragonal model=RNO_v8-pbc-soul2-3 train.random_seed=1 postfix=test2 data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.cutoff=3.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

cubic = {}  # cluster
cubic["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_cubic train.random_seed=1 postfix=rep data.num_workers=16 logging=draw model.residual=True model.cutoff=3 model.num_spherical=3"
}
cubic["GPW"] = {
    "command": "--config-name=mp_mixed data=mp_cubic model=RNO_v8-pbc-soul2-3 train.random_seed=1 postfix=tt data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.cutoff=3.0 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

triclinic = {}  # cluster
triclinic["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_triclinic train.random_seed=1 postfix=rep data.num_workers=16 logging=draw model.residual=True model.cutoff=3 model.num_spherical=3"
}
triclinic["GPW"] = {
    "command": "--config-name=mp_mixed data=mp_triclinic model=RNO_v8-pbc-soul2-3 train.random_seed=1 postfix=tt data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.cutoff=3.0 model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

hexagonal = {}  # svr9
hexagonal["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_hexagonal train.random_seed=42 postfix=rep model.cutoff=3.0 data.num_workers=32 logging=draw model.residual=True model.num_spherical=3"
}
hexagonal["GPW"] = {
    "command": "--config-name=mp_mixed data=mp_hexagonal model=RNO_v8-pbc-soul2-3 train.random_seed=42 postfix=test2 data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.cutoff=3.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

monoclinic = {}  # svr10
monoclinic["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_monoclinic train.random_seed=2 postfix=rep data.num_workers=32 logging=draw model.residual=True model.num_spherical=3 model.cutoff=3.0"
}
monoclinic["GPW"] = {
    "command": "--config-name=mp_mixed model=RNO_v8-pbc-soul2-3 data=mp_monoclinic train.random_seed=2 postfix=test2 data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.cutoff=3.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

orthorhombic = {}  # svr10
orthorhombic["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_orthorhombic train.random_seed=42 postfix=rep model.cutoff=3.0 data.num_workers=32 logging=draw model.residual=True model.num_spherical=3"
}
orthorhombic["GPW"] = {
    "command": "--config-name=mp_mixed data=mp_orthorhombic model=RNO_v8-pbc-soul2-3 train.random_seed=42 postfix=test2 data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.cutoff=3.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

trigonal = {}  # svr9
trigonal["infgcn"] = {
    "command": "--config-name=mp_mixed data=mp_trigonal train.random_seed=1 postfix=rep data.num_workers=32 logging=draw model.residual=True model.cutoff=3.0 model.num_spherical=3"
}
trigonal["GPW"] = {
    "command": "--config-name=mp_mixed model=RNO_v8-pbc-soul2-3 data=mp_trigonal train.random_seed=1 postfix=test2 data.num_workers=32 logging=draw model.num_fourier=20 model.residual=True model.num_spherical=3 model.num_spherical_RNO=3 model.scalar_mask=True model.probe_cutoff=1.5 model.grid_cutoff=0.75 model.mask_cutoff=1.0 model.cutoff=3.0 model.input_infgcn=False +model.input_dist=False +model.atomic_gauss_dist=False"
}

# Dataset name -> {model name -> {"command": ...}}.
data_dict = {
    "qm9": qm9,
    "mixed": mixed,
    "tetragonal": tetragonal,
    "cubic": cubic,
    "triclinic": triclinic,
    "hexagonal": hexagonal,
    "monoclinic": monoclinic,
    "orthorhombic": orthorhombic,
    "trigonal": trigonal,
}

In [None]:
def model_path(data_type, model_type, root="../GPW-NO-model"):
    """Return the checkpoint path for a (dataset, model) pair.

    The checkpoint folder is ``<root>/<data_type>_<model>/`` for ``"qm9"`` and
    ``<root>/mp-<data_type>-<model>/`` for every other dataset, where the short model
    name ``"GPW"`` maps to the folder suffix ``"GPW-NO"`` (other names are used verbatim).

    Args:
        data_type: dataset key, e.g. ``"qm9"`` or ``"tetragonal"``.
        model_type: model key, e.g. ``"GPW"`` or ``"infgcn"``.
        root: directory holding the per-model checkpoint folders.

    Returns:
        ``os.path.join(folder, name)`` for the first non-hidden entry of the folder in
        sorted order, so the choice is deterministic when several files are present.

    Raises:
        FileNotFoundError: if the folder is missing or holds no non-hidden entries.
    """
    folder_model = "GPW-NO" if model_type == "GPW" else model_type
    if data_type == "qm9":
        folder = f"{root}/{data_type}_{folder_model}/"
    else:
        folder = f"{root}/mp-{data_type}-{folder_model}/"
    # os.listdir order is arbitrary; sort it, and skip hidden entries such as
    # .ipynb_checkpoints or .DS_Store that would otherwise be picked up as the "checkpoint".
    candidates = sorted(name for name in os.listdir(folder) if not name.startswith("."))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint file found in {folder}")
    return os.path.join(folder, candidates[0])

In [None]:
# random_seed = 2

In [None]:
from hydra import compose, initialize
from omegaconf import OmegaConf


# context initialization
# https://hydra.cc/docs/advanced/compose_api/


# Experiment to load, and whether to run prediction now (True) or reuse cached predictions (False).
current_data = "tetragonal"
model_type = "GPW"
inference = True

file_path = model_path(current_data, model_type)
options = data_dict[current_data][model_type]["command"]


def _escape_hydra_value(value):
    """Backslash-escape the characters Hydra's override grammar reserves in unquoted values.

    Checkpoint paths may contain ``=`` (and possibly other reserved characters), which would
    otherwise be parsed as override syntax.
    """
    reserved = "\\()[]{}:=, \t"
    return "".join("\\" + ch if ch in reserved else ch for ch in value)


file_path_upd = _escape_hydra_value(file_path)
overrides = options.split()
# Number of test samples to load for this notebook.
test_num_samples = 10
overrides.append(f"+data.datamodule.datasets.test.num_samples={test_num_samples}")
overrides.append(f"+model.checkpoint_path={file_path_upd}")

# "--config-name=<name>" selects the primary config rather than being an override, so pull it
# out wherever it appears; split on the first "=" only so the name itself is kept intact.
config_tokens = [tok for tok in overrides if tok.startswith("--config-name=")]
config_name = config_tokens[-1].split("=", 1)[1] if config_tokens else "default"
overrides = [tok for tok in overrides if not tok.startswith("--config-name=")]

with initialize(config_path="conf"):
    cfg = compose(config_name=config_name, overrides=overrides, return_hydra_config=True)
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Datamodule from its Hydra config; nested configs are left for the datamodule to handle itself.
datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.data.datamodule, _recursive_=False)

# Model from its Hydra config. The optim/data/logging sub-configs are forwarded as keyword
# arguments; this instance mainly provides the class used to load the checkpoint next.
hydra.utils.log.info(f"Instantiating <{cfg.model._target_}>")
model_kwargs = dict(optim=cfg.optim, data=cfg.data, logging=cfg.logging)
model: pl.LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False, **model_kwargs)

In [None]:
# Re-create the model from the checkpoint via the class Hydra instantiated above; rebinding
# `model` drops the freshly instantiated (untrained) instance.
model_cls = type(model)
model = model_cls.load_from_checkpoint(checkpoint_path=cfg.model.checkpoint_path)

In [None]:
# No W&B logger for this notebook; the Trainer is built with logger=wandb_logger (None).
wandb_logger = None

# Register the composed config as the active Hydra config (compose() alone does not appear to
# do this). HydraConfig.get() below needs it; build_callbacks may rely on it as well.
HydraConfig.instance().set_config(cfg)
callbacks: List[Callback] = build_callbacks(cfg=cfg)

# Hydra run directory.
hydra_dir = Path(HydraConfig.get().run.dir)

# To log to W&B instead, replace `wandb_logger = None` with something like:
# if "wandb" in cfg.logging:
#     hydra.utils.log.info("Instantiating <WandbLogger>")
#     wandb_config = cfg.logging.wandb
#     wandb_logger = WandbLogger(
#         **wandb_config,
#         save_dir=hydra_dir,
#         tags=cfg.core.tags,
#     )
#     hydra.utils.log.info(f"W&B is now watching <{cfg.logging.wandb_watch.log}>!")
#     wandb_logger.watch(
#         model,
#         log=cfg.logging.wandb_watch.log,
#         log_freq=cfg.logging.wandb_watch.log_freq,
#     )

In [None]:
# Options fixed by this notebook; everything else (including max_steps) comes from
# cfg.train.pl_trainer. check_val_every_n_epoch=cfg.logging.val_check_interval is left disabled.
trainer_kwargs = dict(
    accelerator="auto",
    default_root_dir=hydra_dir,
    logger=wandb_logger,
    callbacks=callbacks,
    deterministic=cfg.train.deterministic,
    log_every_n_steps=1,
)
trainer = pl.Trainer(**trainer_kwargs, **cfg.train.pl_trainer)

In [None]:
# Build the datasets and the test loader used for prediction.
datamodule.setup()
test_data_loader = datamodule.test_dataloader()

# Outputs go to ../outputs_res/<dataset>/; save_path is a filename prefix named after the model.
save_dic = f"../outputs_res/{current_data}"
os.makedirs(save_dic, exist_ok=True)
save_path = os.path.join(save_dic, model_type)
print(len(test_data_loader), save_path)

In [None]:
# Either run prediction and cache the result, or reuse a cache written by an earlier run.
pred_file = save_path + "_pred.pt"
if inference:
    hydra.utils.log.info("Starting testing!")
    # Keep the in-memory result; re-reading the file just written would only add I/O.
    pred = trainer.predict(model=model, dataloaders=test_data_loader)
    torch.save(pred, pred_file)
else:
    if not os.path.exists(pred_file):
        raise FileNotFoundError(
            f"No cached predictions at {pred_file}; set inference=True to generate them."
        )
    # Map to CPU so a cache produced on a GPU machine loads anywhere and supports .numpy().
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if the cached
    # predictions hold non-tensor objects this may need weights_only=False (trusted file only).
    pred = torch.load(pred_file, map_location="cpu")

In [None]:
# Fields of one prediction entry; the plotting cells read grid_coord, density, pred,
# atom_type, atom_coord, cell and mae from these entries.
pred[0].keys()

In [None]:
# Shape of one entry's target density; the plotting loop indexes it with [0], presumably
# stripping a leading batch axis — confirm against this output.
pred[1]["density"].shape

In [None]:
def draw_volume(grid, density, atom_type, atom_coord, isomin=0.05, isomax=None, surface_count=5, title=None):
    """Build an interactive Plotly 3D figure of a density field with atom markers on top.

    Args:
        grid: grid coordinates whose last axis holds (x, y, z).
        density: value at each grid point, forwarded as ``go.Volume(value=...)``.
        atom_type: per-atom colour value, mapped onto a five-colour scale over [0, 4].
        atom_coord: atom positions; columns 0, 1, 2 are x, y, z.
        isomin, isomax: iso-value range forwarded to ``go.Volume``.
        surface_count: number of isosurfaces to draw.
        title: optional caption, centred horizontally and anchored at y=0.3.

    Returns:
        The ``go.Figure``; it is neither shown nor saved here.
    """
    # Shared axis settings: hide grid, background, zero line and the axes themselves.
    hidden_axis = dict(showgrid=False, showbackground=False, zeroline=False, visible=False)

    volume_trace = go.Volume(
        x=grid[..., 0],
        y=grid[..., 1],
        z=grid[..., 2],
        value=density,
        isomin=isomin,
        isomax=isomax,
        opacity=0.1,  # kept low so the inner isosurfaces remain visible
        surface_count=surface_count,
        caps=dict(x_show=False, y_show=False, z_show=False),
    )
    atom_trace = go.Scatter3d(
        x=atom_coord[:, 0],
        y=atom_coord[:, 1],
        z=atom_coord[:, 2],
        mode='markers',
        marker=dict(
            size=10,
            color=atom_type,
            cmin=0,
            cmax=4,
            colorscale=['grey', 'white', 'red', 'blue', 'green'],
            opacity=0.6,
        ),
    )

    layout_title = None
    if title is not None:
        layout_title = dict(text=title, x=0.5, y=0.3, xanchor='center', yanchor='bottom')

    fig = go.Figure(data=[volume_trace, atom_trace])
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        showlegend=False,
        scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
        title=layout_title,
        title_font_family='Times New Roman',
    )
    return fig

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np

%matplotlib inline
def draw_scatter_volume(grid_coord, density, atom_type, atom_coord, cell,
                        cmap1="bwr", cmap2="summer", save_path="test.pdf"):
    """Render a density field as a 3D point cloud with atoms and the cell outline, save and show it.

    Args:
        grid_coord: grid coordinates indexed as ``grid_coord[0, ..., k]`` for k = 0, 1, 2, i.e.
            with a leading axis whose first entry is used (presumably a batch of one — confirm).
        density: one value per point of ``grid_coord[0]``; coloured with ``cmap1`` under a
            norm centred at 0, so signed fields such as differences are shown symmetrically.
        atom_type: per-atom colour values mapped onto ``cmap2`` over [1, 100]
            (presumably atomic numbers — verify against the data).
        atom_coord: atom positions; column k is the k-th coordinate.
        cell: three lattice vectors ``cell[0]``, ``cell[1]``, ``cell[2]``; the parallelepiped
            they span is drawn in black.
        cmap1: colormap for the density points.
        cmap2: colormap for the atom markers.
        save_path: file the figure is written to.

    Inputs may be CPU torch tensors or numpy arrays; both are converted with ``np.asarray``.
    """
    grid_np = np.asarray(grid_coord)[0]
    density_np = np.asarray(density)
    atom_type_np = np.asarray(atom_type)
    atom_coord_np = np.asarray(atom_coord)
    cell_np = np.asarray(cell)

    fig = plt.figure(figsize=(4, 4), dpi=500)
    ax = fig.add_subplot(111, projection='3d')

    # Density: very transparent points so overlapping regions build up colour.
    ax.scatter(
        grid_np[..., 0],
        grid_np[..., 1],
        grid_np[..., 2],
        c=density_np,
        cmap=cmap1,
        norm=colors.CenteredNorm(),
        alpha=0.01,
    )
    ax.axis('off')
    ax.scatter(
        atom_coord_np[:, 0],
        atom_coord_np[:, 1],
        atom_coord_np[:, 2],
        marker='o',
        c=atom_type_np,
        cmap=cmap2,
        vmin=1,
        vmax=100,
    )

    # Cell outline: corner 4*i + 2*j + k sits at i*a + j*b + k*c.
    corners = np.array([i * cell_np[0] + j * cell_np[1] + k * cell_np[2]
                        for i in (0, 1) for j in (0, 1) for k in (0, 1)])
    edges = [(0, 1), (1, 3), (3, 2), (2, 0),   # face with i = 0
             (4, 5), (5, 7), (7, 6), (6, 4),   # face with i = 1
             (0, 4), (1, 5), (2, 6), (3, 7)]   # edges joining the two faces
    for start, end in edges:
        segment = corners[[start, end]]
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], color='black', linewidth=0.5)

    fig.tight_layout()
    fig.savefig(save_path, dpi=500)
    plt.show()
    # The figure is large (4x4 in at 500 dpi) and this is called repeatedly in a loop: free it.
    plt.close(fig)

In [None]:
# Per-entry MAE for the first `test_num_samples` predictions. Bounded by len(pred) because
# batching can yield fewer prediction entries than samples.
for i in range(min(test_num_samples, len(pred))):
    print(i, pred[i]["mae"])

In [None]:
# For each of the first few prediction entries, save target / prediction / difference plots as
# <save_path>_<index>_{target,pred,diff}.pdf and print that entry's MAE. The bound never exceeds
# test_num_samples and never indexes past the end of `pred`.
num_to_plot = min(test_num_samples, len(pred))
for index in range(num_to_plot):
    curdata = pred[index]
    grid = curdata["grid_coord"]
    density = curdata["density"][0]
    pred_val = curdata["pred"][0]
    atom_type = curdata["atom_type"]
    atom_coord = curdata["atom_coord"]
    cell = curdata["cell"][0]

    panels = (("target", density), ("pred", pred_val), ("diff", pred_val - density))
    for suffix, values in panels:
        draw_scatter_volume(
            grid, values, atom_type, atom_coord, cell,
            cmap1="Spectral",
            save_path=save_path + f"_{index}_{suffix}.pdf",
        )
    print(index, curdata["mae"])

In [None]:
# draw_volume(curdata["grid_coord"], curdata["density"], curdata["atom_type"], curdata["atom_coord"], isomin=None, title="Target")
# draw_volume(curdata["grid_coord"], curdata["pred"], curdata["atom_type"], curdata["atom_coord"], title="Prediction")
# draw_volume(curdata["grid_coord"], curdata["pred"]-curdata["density"], curdata["atom_type"], curdata["atom_coord"], title="Difference")

In [None]:
# %matplotlib inline
# import matplotlib.pyplot as plt
# import matplotlib.colors as colors
# import numpy as np

# # Assuming chgcar is the charge density data
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# norm = CenteredNorm

# # Plot the charge density
# ax.scatter(
#     curdata["grid_coord"][0,...,0].numpy(),
#     curdata["grid_coord"][0,...,1].numpy(),
#     curdata["grid_coord"][0,...,2].numpy(),
#     c=curdata["density"][0,...].numpy(),
#     cmap='bwr',
#     norm=colors.CenteredNorm(),
#     alpha=0.01
# )

# atom_coord = curdata["atom_coord"].numpy()

# ax.scatter(
#     atom_coord[:,0],
#     atom_coord[:,1],
#     atom_coord[:,2],
#     marker='o',
#     color='k'
# )

# ax.axis('off')
# # Set labels and title
# # ax.set_xlabel('X')
# # ax.set_ylabel('Y')
# # ax.set_zlabel('Z')

# # ax.set_title('Charge Density')
# plt.tight_layout()
# # Show the plot
# plt.show()
