In [None]:
import torch
import os
import yaml
import matplotlib.pyplot as plt
import numpy as np
import time

from tqdm import tqdm
from torch_geometric.data import Batch, Data
from pymatgen.core.sites import PeriodicSite
from pymatgen.io.ase import AseAtomsAdaptor
from ase import neighborlist as nbl
from ase import Atoms
from ase.calculators.vasp import VaspChargeDensity

from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging
from ocpmodels.preprocessing import AtomsToGraphs
from ocpmodels.datasets import data_list_collater

import cdm.models
from cdm.charge_trainer import ChargeTrainer
from cdm.utils.probe_graph import ProbeGraphAdder
from cdm.utils.inference import inference

# Initialise logging via the ocpmodels helper before any model/trainer code runs
# (presumably configures handlers/format on the root logger — TODO confirm).
setup_logging()

In [None]:
def make_parity_plot(x, y, LOG):
    """Scatter predicted vs. ground-truth electron density with a y = x reference line.

    Args:
        x: Ground-truth densities, e.g. ``target.flatten()`` (any array-like
           with a ``.max()`` method: torch tensor or numpy array).
        y: Predicted densities, same length as ``x``.
        LOG: If True, use logarithmic scales on both axes.

    Both axes span from 1e-10 to one unit above the largest value found in
    *either* ``x`` or ``y``, so over-predictions are never clipped out of view.
    The figure is shown (and consumed) via ``plt.show()``.
    """
    # Shared, strictly positive range so the same limits are valid on log axes.
    # float() turns 0-d tensors into plain numbers that matplotlib handles cleanly.
    lower = 1e-10
    upper = max(float(x.max()), float(y.max())) + 1

    plt.scatter(x, y, color='blue', alpha=0.1, s=1.5)

    plt.gcf().set_dpi(200)
    plt.axis('square')

    if LOG:
        plt.gca().set_xscale('log')
        plt.gca().set_yscale('log')

    # Start the reference line at `lower` rather than 0: 0 cannot be placed on a log axis.
    plt.plot([lower, upper], [lower, upper], label='Parity line', color='red')
    plt.xlabel('Ground truth electron density\nelectrons per cubic Angstrom')
    plt.ylabel('Predicted electron density\nelectrons per cubic Angstrom')
    plt.xlim([lower, upper])
    plt.ylim([lower, upper])
    # The parity line is labelled; show that label.
    plt.legend()
    plt.show()

In [None]:
# Report whether a GPU is visible; print() renders the bool as "True"/"False".
print(torch.cuda.is_available())

# Without a GPU, pin PyTorch's intra-op CPU parallelism to 8 threads.
if not torch.cuda.is_available():
    torch.set_num_threads(8)

In [None]:
# Model hyperparameters. These must match the configuration the checkpoint below
# was trained with; load_state_dict (strict by default) raises on key/shape mismatches.
model_config = {
    'name': 'charge_model',
    'num_interactions': 4,
    'atom_channels': 64,
    'probe_channels': 64,
    'enforce_zero_for_disconnected_probes': True,
    'enforce_charge_conservation': True,

    'atom_model_config': {
        'name': 'schnet_charge',
        'num_filters': 64,
        'num_gaussians': 64,
        'cutoff': 5,
    },

    'probe_model_config': {
        'name': 'schnet_charge',
        'num_filters': 32,
        'num_gaussians': 32,
        'cutoff': 4,
    },
}

# Build the model and switch it to inference mode up front (matters for any
# dropout / batch-norm style layers it may contain).
model = cdm.models.ChargeModel(**model_config)
model.eval()

checkpoint_path = '../runs/checkpoints/2022-11-01-18-54-56-Approximate Charge Conservation, 100k/checkpoint.pt'
# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only machine;
# load_state_dict copies the tensors into the (CPU-constructed) model either way.
state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']

# Parameter names in the checkpoint carry a wrapper prefix (presumably 'module.'
# from DataParallel/DistributedDataParallel — the original code dropped exactly 7
# characters). Strip it only where present instead of truncating every key.
prefix = 'module.'
sd = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in state_dict.items()
}

model.load_state_dict(sd)

In [None]:
path = '../cdm/tests/test_structure'

# VaspChargeDensity keeps lists of structures and density grids;
# take the final entry of each.
vcd = VaspChargeDensity(path)
atoms, dens = vcd.atoms[-1], vcd.chg[-1]

# Shape of the density grid, and the grid itself as a torch tensor (copied from numpy).
grid = dens.shape
target = torch.tensor(dens)

print(atoms, grid, sep='\n')

In [None]:
# Use the GPU when one is present and fall back to CPU otherwise; a hard-coded
# 'cuda' fails on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pred = inference(
    atoms,
    model,
    grid,
    # Take the cutoffs from the model definition so graph construction cannot
    # drift out of sync with the configuration the model was built with.
    atom_cutoff=model_config['atom_model_config']['cutoff'],
    probe_cutoff=model_config['probe_model_config']['cutoff'],
    batch_size=1000,
    use_tqdm=True,
    device=device,
    # Sum of all ground-truth grid values; presumably used by `inference` for the
    # charge-conservation rescaling — TODO confirm against cdm.utils.inference.
    total_density=torch.sum(target),
)

# Bring predictions back to CPU for comparison with `target`.
pred = pred.to('cpu')

In [None]:
# Draw the same parity plot twice: first with linear axes, then with log axes.
target_flat = target.flatten()
pred_flat = pred.flatten()
for log_axes in (False, True):
    make_parity_plot(target_flat, pred_flat, LOG=log_axes)

In [None]:
# Mean absolute error over every grid point, in the same units as the density grid.
print((pred - target).abs().mean().item())

In [None]:
# Compare the average predicted density with the ground truth, then report the
# relative deviation of the prediction's mean.
pred_mean = torch.mean(pred).item()
target_mean = torch.mean(target).item()

print(pred_mean)
print(target_mean)
print((pred_mean - target_mean) / target_mean)

In [None]:
# Spread of predicted vs. ground-truth densities, one value per line.
print(pred.std().item(), target.std().item(), sep='\n')