In [None]:
import torch
import os
import models
from charge_trainer import ChargeTrainer
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging
from ocpmodels.common.utils import pyg2_data_transform
from chg_utils import ProbeGraphAdder
from chg_utils import charge_density
setup_logging()

import matplotlib.pyplot as plt
from torch_geometric.data import Batch

In [None]:
# Report GPU visibility; on CPU-only machines cap PyTorch at 8 intra-op threads.
cuda_ok = torch.cuda.is_available()
print("True" if cuda_ok else "False")
if not cuda_ok:
    torch.set_num_threads(8)

In [None]:
# Training task: regression on the 'charge_vals' labels of an LMDB dataset,
# reporting charge MSE and MAE, with MAE as the primary metric.
task = dict(
    dataset='lmdb',
    description='Initial test of training on charges',
    type='regression',
    metric=['charge_mse', 'charge_mae'],
    primary_metric='charge_mae',
    labels=['charge_vals'],
)

In [None]:
# Model configuration for the charge wrapper model.
#
# atom_model_config and probe_model_config should inherit most keywords from
# OCP models. The exceptions are the number of interactions and the channel
# widths: the wrapper model needs those too, so they are set at the top level
# rather than inside the individual sub-configurations.
model = dict(
    name='charge_model',
    num_interactions=3,
    atom_channels=256,
    probe_channels=256,

    # Network over the atomic graph.
    atom_model_config=dict(
        name='schnet_charge',
        num_filters=64,
        num_gaussians=64,
        cutoff=5,
    ),

    # Network over the probe graph.
    probe_model_config=dict(
        name='schnet_charge',
        num_filters=128,
        num_gaussians=128,
        cutoff=3,
    ),
)

In [None]:
# Optimizer and schedule: Adam starting at lr 5e-5, scaled by 0.96 by
# ReduceLROnPlateau (mode "min", patience 1); batch size 2 for train and eval.
optimizer = dict(
    optimizer='Adam',
    batch_size=2,
    eval_batch_size=2,
    num_workers=32,
    lr_initial=5e-5,
    scheduler="ReduceLROnPlateau",
    mode="min",
    factor=0.96,
    patience=1,
    max_epochs=1000,
)

In [None]:
# Dataset splits, in order: train, then val (optional). A third entry would be
# a test set (optional - writes predictions to disk).
split_root = '../charge-data/1k-no-probe-graphs'
dataset = [
    {'src': f'{split_root}/train', 'normalize_labels': False},  # train set
    {'src': f'{split_root}/val'},                               # val set (optional)
]

In [None]:
# Trainer settings: 'charge' trainer, W&B logging, AMP enabled, plus the
# probe-graph settings (300 probes each for train/val/test, cutoff 3,
# atomic edges included).
trainer_config = dict(
    trainer='charge',
    identifier='1k Dataset',
    is_debug=False,
    run_dir='./runs/',
    print_every=5,
    seed=2,
    logger='wandb',
    local_rank=0,
    amp=True,
    probe_graph_config=dict(
        train_probes=300,
        val_probes=300,
        test_probes=300,
        cutoff=3,
        include_atomic_edges=True,
    ),
)

In [None]:
# Look up the registered trainer class by name, then build it from the configs above.
trainer_cls = registry.get_trainer_class(trainer_config['trainer'])
trainer = trainer_cls(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    **trainer_config,
)

In [None]:
# Start training with the task/model/dataset/optimizer/trainer settings above.
trainer.train()

charge_mse: 9.93e-03, charge_mae: 2.14e-02, loss: 2.14e-02, lr: 5.00e-05, epoch: 5.36e+00, step: 2.41e+03


In [None]:
# Grab the bare network so it can be called directly on a batch. The trainer's
# model is presumably wrapped (DataParallel-style, exposing `.module`); fall
# back to the model itself when there is no wrapper instead of raising.
model = getattr(trainer.model, 'module', trainer.model)

# Fresh iterator over the training loader, used to pull a sample batch.
loader = iter(trainer.train_loader)

In [None]:
batch = next(loader)

# Each sub-batch carries a list of probe graphs: convert every graph with
# pyg2_data_transform and collate the list into a single PyG Batch.
for sub in batch:
    sub.probe_data = Batch.from_data_list(
        [pyg2_data_transform(graph) for graph in sub.probe_data]
    )

In [None]:
print(batch)

# Run on the device that holds the model's weights rather than hard-coding
# 'cuda' (the setup cell explicitly handles CPU-only machines). Disable autograd
# because this is inference only; otherwise the graph stays alive on the device.
device = next(model.parameters()).device
with torch.no_grad():
    pred = model(batch[0].to(device))
true = batch[0].probe_data.target

In [None]:
# Detach from autograd and move to host memory so matplotlib can consume them.
true, pred = true.detach().cpu(), pred.detach().cpu()

# Parity plot on linear axes: points on the red y = x line are exact predictions.
fig, ax = plt.subplots()
ax.scatter(true, pred, color='blue', alpha=0.1, s=2, label='Predictions')

lb, ub = torch.min(true), torch.max(true)
ax.plot([lb, ub], [lb, ub], label='Parity line', color='red')
ax.set_xlabel('True label')
ax.set_ylabel('Predicted label')
ax.legend()
fig.set_dpi(200)
plt.show()

In [None]:
# Detach from autograd and move to host memory so matplotlib can consume them.
true, pred = true.detach().cpu(), pred.detach().cpu()

# Same parity plot on log-log axes, which spreads out small values
# (non-positive values cannot be placed on a log axis).
fig, ax = plt.subplots()
ax.scatter(true, pred, color='blue', alpha=0.05, s=2, label='Predictions')

lb, ub = torch.min(true), torch.max(true)
ax.plot([lb, ub], [lb, ub], label='Parity line', color='red')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('True label')
ax.set_ylabel('Predicted label')
ax.legend()
fig.set_dpi(200)
plt.show()

In [None]:
# Mean absolute error over all probe points. Flatten both tensors first so a
# (N, 1) vs. (N,) shape mismatch cannot silently broadcast into an (N, N) matrix.
assert pred.numel() == true.numel(), f'size mismatch: {tuple(pred.shape)} vs {tuple(true.shape)}'
err = torch.mean(torch.abs(pred.reshape(-1) - true.reshape(-1)))
print(err.item())

In [None]:
# Summary statistics of the predictions, one per line: min, max, mean, std.
for summarize in (torch.min, torch.max, torch.mean, torch.std):
    print(summarize(pred).item())

In [None]:
# Summary statistics of the targets, one per line: min, max, mean, std.
print(*(stat(true).item() for stat in (torch.min, torch.max, torch.mean, torch.std)), sep='\n')

In [None]:
# Inspect the raw prediction tensor.
print(pred)

In [None]:
import torch
import os
import models
from charge_trainer import ChargeTrainer
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging
from chg_utils import ProbeGraphAdder
setup_logging()
import yaml
from ocpmodels.preprocessing import AtomsToGraphs
from ocpmodels.datasets import data_list_collater

import matplotlib.pyplot as plt
from torch_geometric.data import Batch

from pymatgen.core.sites import PeriodicSite
from pymatgen.io.ase import AseAtomsAdaptor
from torch_geometric.data import Data
from ase import Atoms
from ase.calculators.vasp import VaspChargeDensity
import ase.neighborlist as nbl
from tqdm import tqdm
import numpy as np
import time

In [None]:
# Atoms -> graph converter: neighbours within radius 6 (at most 100 per atom);
# no energy, force, distance or fixed-atom fields are requested.
a2g_options = dict(
    max_neigh=100,
    radius=6,
    r_energy=False,
    r_forces=False,
    r_distances=False,
    r_fixed=False,
)
a2g = AtomsToGraphs(**a2g_options)

# Probe-graph builder, configured for 'slice' mode with stride 4.
pga_options = dict(num_probes=100, cutoff=4, mode='slice', slice_start=0, stride=4)
pga = ProbeGraphAdder(**pga_options)

In [None]:
path = '../shared-scratch/ethan/sample/val/random1036401_190/CHGCAR'

# Read the CHGCAR; keep the last stored structure and its density grid.
vcd = VaspChargeDensity(path)
atoms, dens = vcd.atoms[-1], vcd.chg[-1]

# Build the atomic graph and attach the raw density grid to it.
data_object = a2g.convert(atoms)
data_object.charge_density = dens

print(atoms)

In [None]:
# Add probe graphs to the structure (mode='all', stride 4).
slice0 = pga(data_object, slice_start=0, num_probes=300, mode='all', stride=4)

# Collate the single structure and separately batch its probe graph.
collated = data_list_collater([slice0])
collated.probe_data = Batch.from_data_list([slice0.probe_data])

# Wrap in a one-element list, matching the list-of-sub-batches layout that the
# training loader yields; keep a handle for re-running the inference cell.
batch = [collated]
TEST_BATCH = batch

In [None]:
batch = TEST_BATCH

# Run on the device holding the model's weights instead of assuming CUDA.
device = next(model.parameters()).device
on_cuda = device.type == 'cuda'

# CUDA work is asynchronous: synchronize before each clock read so t2 - t1
# covers the actual transfer + forward pass, not just kernel launches.
# perf_counter is monotonic and high-resolution, unlike time.time().
if on_cuda:
    torch.cuda.synchronize(device)
t1 = time.perf_counter()

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    pred = model(batch[0].to(device))

if on_cuda:
    torch.cuda.synchronize(device)
t2 = time.perf_counter()

true = batch[0].probe_data.target

print(len(pred))
print(len(true))

print(t2 - t1)

In [None]:
# Detach from autograd and move to host memory so matplotlib can consume them.
true, pred = true.detach().cpu(), pred.detach().cpu()

# Parity plot (linear axes) of DFT vs. predicted density; the red line is y = x.
fig, ax = plt.subplots()
ax.scatter(true, pred, color='blue', alpha=0.1, s=2, label='Predictions')

lb, ub = torch.min(true), torch.max(true)
ax.plot([lb, ub], [lb, ub], label='Parity line', color='red')
ax.set_xlabel('DFT Charge Density ($e/Å^3$)')
ax.set_ylabel('Model Predicted Charge Density ($e/Å^3$)')
ax.legend()
fig.set_dpi(200)
plt.show()

In [None]:
# Detach from autograd and move to host memory so matplotlib can consume them.
true, pred = true.detach().cpu(), pred.detach().cpu()

# Same parity plot on log-log axes (non-positive values cannot appear there);
# very low alpha because the whole grid is plotted.
fig, ax = plt.subplots()
ax.scatter(true, pred, color='blue', alpha=0.01, s=2, label='Predictions')

lb, ub = torch.min(true), torch.max(true)
ax.plot([lb, ub], [lb, ub], label='Parity line', color='red')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('DFT Charge Density ($e/Å^3$)')
ax.set_ylabel('Model Predicted Charge Density ($e/Å^3$)')
ax.legend()
fig.set_dpi(200)
plt.show()

In [None]:
# Absolute and relative error on this structure. Flatten first so a (N, 1) vs.
# (N,) shape mismatch cannot silently broadcast into an (N, N) matrix, and
# compute the MAE once instead of twice.
assert pred.numel() == true.numel(), f'size mismatch: {tuple(pred.shape)} vs {tuple(true.shape)}'
mae = torch.mean(torch.abs(true.reshape(-1) - pred.reshape(-1)))
mean_true = torch.mean(true)
print(mae)
print(mean_true)

# MAE relative to the mean target density.
print(mae / mean_true)

In [None]:
# Shape of the density grid after keeping every 4th point along each axis.
full_grid = batch[0].charge_density[0]
shape = full_grid[::4, ::4, ::4].shape
print(shape)

In [None]:
# Put the flat per-probe values back onto the downsampled grid. Convert to
# NumPy explicitly: np.reshape on a torch.Tensor defers to the tensor's own
# methods and may hand back a Tensor rather than an ndarray.
# NOTE(review): assumes probes are ordered like a C-order walk of the grid — confirm.
n_grid = int(np.prod(shape))
assert true.numel() == n_grid and pred.numel() == n_grid, (tuple(true.shape), tuple(pred.shape), shape)
true_shaped = true.detach().cpu().numpy().reshape(shape)
pred_shaped = pred.detach().cpu().numpy().reshape(shape)

In [None]:
# Reload the same CHGCAR that TEST_BATCH was built from. Reuse `path` rather
# than repeating the literal, so the template cannot drift from the data it
# will be filled with.
cd = charge_density(path)
print(cd)

In [None]:
# Replace the template's density with the downsampled DFT targets and set its
# grid to the strided shape so the two agree.
# NOTE(review): assumes charge_density's writer uses .charge/.grid as-is — confirm.
cd.charge = true_shaped
cd.grid = shape

In [None]:
# Inspect the charge_density object after its charge/grid were replaced.
print(cd)

In [None]:
# Write the downsampled ground-truth density held in `cd` to 'true_downsampled'.
# NOTE(review): presumably in CHGCAR format, judging by the method name — confirm.
cd.write_CHGCAR('true_downsampled')

In [None]:
# Largest target density value for this structure.
print(true.max())

In [None]:
# Show the summary repr of the collated test batch.
print(batch[0])

In [None]:
# Coordinates of the probe points only, one row per target value.
# NOTE(review): assumes probe_data.pos stores the atom positions first and the
# probe points after them. The offset is derived from the number of targets
# instead of a hard-coded atom count, so other structures work too.
probe_pos = batch[0].probe_data.pos
atom_offset = probe_pos.shape[0] - true.numel()
probe_coords = probe_pos[atom_offset:].detach().to('cpu').numpy()
print(probe_coords.shape)

In [None]:
import pandas as pd

# Per-probe table: position, true and predicted density, and absolute error.
# Convert to flat NumPy arrays first, so every column is a plain numeric 1-D
# array regardless of how pandas treats torch.Tensor inputs.
true_np = true.detach().cpu().numpy().reshape(-1)
pred_np = pred.detach().cpu().numpy().reshape(-1)
dd = {'x': probe_coords[:, 0],
      'y': probe_coords[:, 1],
      'z': probe_coords[:, 2],
      'true': true_np,
      'pred': pred_np,
      'diff': np.abs(true_np - pred_np)}

In [None]:
# Tabulate probe coordinates with true/predicted densities and their absolute difference.
df = pd.DataFrame(dd)

In [None]:
# Preview the first rows of the per-probe table.
df.head()

In [None]:
# Save the per-probe table (including pandas' index column) to data.csv.
df.to_csv('data.csv')

In [None]:
# Print the cell stored on the test batch (presumably the lattice vectors — confirm).
print(batch[0].cell)

In [None]:
# Atom positions and atomic numbers of the test structure, as host NumPy arrays.
atom_data = batch[0]
coords = atom_data.pos.cpu().numpy()
dd = dict(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    an=atom_data.atomic_numbers.cpu().numpy(),
)

In [None]:
from ase.data import chemical_symbols

# One row per atom, labelled with its element symbol. Look the symbol up in
# ASE's periodic table so every element is covered, not just a hand-picked few.
# int() because atomic numbers may be stored as floats.
df = pd.DataFrame(dd)
df['atom'] = [chemical_symbols[int(z)] for z in df['an']]
df.to_csv('atoms.csv')

In [None]:
# Print the cell stored on the test batch again (same output as the earlier cell).
print(batch[0].cell)

In [None]:
from ase import Atoms
from ase.calculators.vasp import VaspChargeDensity

In [None]:
# Load a second CHGCAR from the 1k_sample val split; keep its last structure and
# density grid. This overwrites the earlier vcd/atoms/dens.
other_chgcar = '../shared-scratch/ethan/density/1k_sample/val/random1005744_130/CHGCAR'
vcd = VaspChargeDensity(other_chgcar)
atoms, dens = vcd.atoms[-1], vcd.chg[-1]

In [None]:
# Show the structure read from the second CHGCAR.
print(atoms)