In [None]:
import torch
import os
#from ocpmodels.charge import models
import models
from charge_trainer_copy import ChargeTrainer
from ocpmodels.common import logger
from ocpmodels.common.utils import setup_logging
from DeepDFT import probe
from chg_utils import batch_to_deepDFT_dict
setup_logging()

import matplotlib.pyplot as plt

In [None]:
# Sanity check: report whether CUDA is usable from this kernel.
# When it is not, the CPU thread count for torch is set to 8.
gpu_available = torch.cuda.is_available()
print("True" if gpu_available else "False")
if not gpu_available:
    torch.set_num_threads(8)

In [None]:
# Task configuration passed to the trainer: dataset format, metrics and label names.
task = dict(
    dataset='lmdb',
    description='Initial test of training on charges',
    type='regression',
    metric=['charge_mse', 'charge_mae'],
    primary_metric='charge_mse',
    labels=['charge_vals'],
)

In [None]:
# Model configuration: a top-level 'charge_model' entry with two nested
# sub-configurations, one for the atom model and one for the probe model.
model = dict(
    name='charge_model',
    atom_model_config=dict(
        name='schnet_charge',
        hidden_channels=32,
        num_interactions=3,
        num_filters=32,
        num_gaussians=32,
    ),
    probe_model_config=dict(
        name='deepdft_probe',
        hidden_state_size=32,
        num_interactions=3,
        cutoff=3,
        gaussian_expansion_step=0.2,
    ),
    probe_state_size=32,
    # Disabled option, kept for reference:
    # atom_representation_reduction=[32, 64],
)

In [None]:
# Optimisation settings: optimizer name, batch sizes, data-loading workers,
# initial learning rate, and the learning-rate scheduler options.
optimizer = dict(
    optimizer='Adam',
    # Disabled option, kept for reference:
    # optimizer_params={'amsgrad': True},
    batch_size=1,
    eval_batch_size=1,
    num_workers=8,
    lr_initial=0.1,
    # Scheduler name followed by its options (presumably forwarded to the
    # scheduler by the trainer - confirm against the trainer code).
    scheduler="ReduceLROnPlateau",
    mode="min",
    factor=0.96,
    patience=1,
    max_epochs=1000,
)

In [None]:
# Dataset sources, in order: the training split first, then the validation split.
dataset = [
    dict(src='../chg/100/train', normalize_labels=False),  # training split
    dict(src='../chg/100/val'),  # validation split (optional)
    # dict(src=train_src),  # test split (optional - writes predictions to disk)
]

In [None]:
# Gather every constructor argument in one place, then build the trainer from it.
trainer_options = dict(
    # Configuration dictionaries defined in the preceding cells.
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier="example",
    # Results directory, used when is_debug=False. Prediction files land here,
    # so take care not to overwrite earlier ones.
    run_dir="./",
    # True => no checkpoints, logs, or results are saved.
    is_debug=True,
    print_every=1,
    # Random seed.
    seed=2,
    # Logging backend (tensorboard and wandb are supported).
    logger='wandb',
    local_rank=0,
    cutoff=4,
    num_probes=1000,
    # PyTorch Automatic Mixed Precision (faster training, lower memory usage).
    amp=True,
)
trainer = ChargeTrainer(**trainer_options)

In [None]:
# Start training with the trainer constructed in the previous cell.
trainer.train()

In [None]:
from chg_utils import BatchToChargeGraphs, get_probe_graph

# Iterator over the training loader, so individual batches can be pulled with next().
loader = iter(trainer.train_loader)
# Rebinds `model` (until now the configuration dict) to the `.module` attribute of
# the trainer's model - presumably the network inside a parallel wrapper; confirm.
model = trainer.model.module

In [None]:
# Same cutoff / num_probes values that were passed to ChargeTrainer.
probe_graph_kwargs = dict(cutoff=4, num_probes=1000)

# Release cached GPU memory before pulling the next batch from the loader.
torch.cuda.empty_cache()
batch = next(loader)

# Attach the DeepDFT-style input dictionary to every sub-batch.
for subbatch in batch:
    subbatch.input_dict = batch_to_deepDFT_dict(subbatch, **probe_graph_kwargs)

In [None]:
# Forward pass on the first sub-batch, plus the matching reference values.
#
# The device is chosen at runtime instead of hard-coding 'cuda', so this cell
# also works on the CPU-only path that the GPU sanity check at the top of the
# notebook explicitly allows for.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# This pass is only used for inspection, so skip building an autograd graph.
# NOTE(review): assumes the model's forward does not itself need gradients
# (e.g. no internal torch.autograd.grad call) - confirm.
with torch.no_grad():
    pred = model(batch[0].to(device))

# Reference values stored under 'probe_target' in the sub-batch's input dict.
true = batch[0].input_dict['probe_target']

In [None]:
# Mean absolute error between reference and predicted values (displayed as the cell output).
abs_error = torch.abs(true - pred)
torch.mean(abs_error)

In [None]:
# Summary statistics of the predictions, one per line: min, max, mean, std.
for reduce_fn in (torch.min, torch.max, torch.mean, torch.std):
    print(reduce_fn(pred).item())

In [None]:
# Summary statistics of the reference values, one per line: min, max, mean, std.
for reduce_fn in (torch.min, torch.max, torch.mean, torch.std):
    print(reduce_fn(true).item())

In [None]:
# Parity plot: predicted vs. reference values on log-log axes.
# Detach and move both tensors to the CPU once, up front, for plotting.
true_cpu = true.detach().cpu()
pred_cpu = pred.detach().cpu()

# End points of the parity line span the range of the reference values.
lb = torch.min(true_cpu)
ub = torch.max(true_cpu)

fig, ax = plt.subplots()
ax.scatter(true_cpu, pred_cpu, color='red', label='Predictions')
ax.plot([lb, ub], [lb, ub], label='Parity line')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('True label')
ax.set_ylabel('Predicted label')
ax.legend()
fig.set_dpi(200)
plt.show()