In [None]:
import torch
import os
import models
from charge_trainer import ChargeTrainer
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging
from DeepDFT import densitymodel
from chg_utils import ProbeGraphAdder
setup_logging()

import matplotlib.pyplot as plt
from torch_geometric.data import Batch

In [None]:
# Report whether CUDA is usable; on CPU-only machines additionally pin
# torch to 8 threads.
if not torch.cuda.is_available():
    print("False")
    torch.set_num_threads(8)
else:
    print("True")

In [None]:
# Task definition handed to the trainer: a regression on the 'charge_vals'
# label, evaluated with the 'charge_mse' and 'charge_mae' metrics, the latter
# being the primary one.
task = dict(
    dataset='lmdb',
    description='Initial test of training on charges',
    type='regression',
    metric=['charge_mse', 'charge_mae'],
    primary_metric='charge_mae',
    labels=['charge_vals'],
)

In [None]:
'''
The atom_model_config and probe_model_config should inherit most keywords
from OCP models. The exception is specifications for the number of
interactions and the number of channels. These hyperparameters are needed
in the wrapper model as well, so they are specified outside of the
individual configurations.
'''

# Configuration of the 'charge_model' wrapper. The interaction count and the
# channel widths live at the top level; the two nested dicts configure the
# atom and probe sub-models, both set to 'schnet_charge' here.
model = dict(
    name='charge_model',
    num_interactions=5,
    atom_channels=256,
    probe_channels=256,
    atom_model_config=dict(
        name='schnet_charge',
        num_filters=64,
        num_gaussians=32,
        cutoff=5,
    ),
    probe_model_config=dict(
        name='schnet_charge',
        num_filters=64,
        num_gaussians=32,
        cutoff=5,
    ),
)

In [None]:
# Optimisation settings: optimizer choice, batch sizes, loader workers and the
# initial learning rate.
optimizer = dict(
    optimizer='Adam',
    batch_size=10,
    eval_batch_size=10,
    num_workers=1,
    lr_initial=5e-4,
)
# Learning-rate schedule and epoch budget. The mode/factor/patience keys are
# presumably forwarded to the ReduceLROnPlateau scheduler -- confirm against
# the trainer implementation.
optimizer.update(
    scheduler="ReduceLROnPlateau",
    mode="min",
    factor=0.96,
    patience=1,
    max_epochs=1000,
)

In [None]:
# Dataset splits in the order train, val, (test). Only the training entry
# sets 'normalize_labels'.
dataset = [
    dict(src='../chg/100/train', normalize_labels=False),  # train set
    dict(src='../chg/100/val'),  # val set (optional)
    # {'src': train_src}  # test set (optional - writes predictions to disk)
]

In [None]:
# General run settings; the 'trainer' entry is the name later used to look up
# the trainer class in the registry.
trainer_config = dict(
    trainer='charge',
    identifier='A Good Run',
    is_debug=False,
    run_dir='./runs/',
    print_every=1,
    seed=2,
    logger='wandb',
    local_rank=0,
    amp=True,
)
# Cutoff and the per-split probe counts.
trainer_config.update(
    cutoff=5,
    train_probes=500,
    val_probes=1000,
    test_probes=1000,
)

In [None]:
# Look up the trainer class registered under the configured name, then
# instantiate it. trainer_config is unpacked in full, so its 'trainer' entry
# is passed along as a keyword argument as well.
trainer_cls = registry.get_trainer_class(trainer_config['trainer'])
trainer = trainer_cls(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    **trainer_config,
)

In [None]:
# Start training with the trainer assembled from the configuration cells.
trainer.train()

In [None]:
# Iterator over the training loader, used to pull individual batches by hand.
loader = iter(trainer.train_loader)
# Take the inner module out of the trainer's model wrapper.
# NOTE: this rebinds `model`, which until now held the model configuration
# dict -- re-running the trainer construction cell afterwards would pass the
# module instead of the config.
model = trainer.model.module

In [None]:
# Uncomment to release cached GPU memory before fetching a batch.
# torch.cuda.empty_cache()
batch = next(loader)
# Collate each sub-batch's probe_data (expected to be a list of data objects)
# into a single Batch, replacing the attribute in place.
for subbatch in batch:
    probe_graphs = subbatch.probe_data
    subbatch.probe_data = Batch.from_data_list(probe_graphs)

In [None]:
# Run the model on the first sub-batch. The input is moved to whichever device
# the model's parameters live on instead of a hard-coded 'cuda', so the cell
# also works on CPU-only machines (the notebook has an explicit CPU path).
model_device = next(model.parameters()).device
# Inference only: skip building an autograd graph for the forward pass.
with torch.no_grad():
    pred = model(batch[0].to(model_device))
# Reference values for the same sub-batch.
true = batch[0].probe_data.target

In [None]:
# Drop autograd history and bring both tensors to the CPU for plotting.
pred = pred.detach().cpu()
true = true.detach().cpu()

# Draw on the current axes (created on demand).
axes = plt.gca()
axes.scatter(true, pred, s=1, color='red', label='Predictions')

# Parity line spanning the range of the reference values.
lb, ub = torch.min(true), torch.max(true)
axes.plot([lb, ub], [lb, ub], label='Parity line')

# Both axes are logarithmic.
axes.set_xscale('log')
axes.set_yscale('log')
axes.set_xlabel('True label')
axes.set_ylabel('Predicted label')
axes.legend()
plt.gcf().set_dpi(200)
plt.show()

In [None]:
# Mean absolute error between predictions and reference values.
# Both tensors are flattened first: if one were shaped (N,) and the other
# (N, 1), a plain `pred - true` would broadcast to (N, N) and silently report
# a wrong error. For equally shaped inputs the result is unchanged.
err = torch.mean(torch.abs(pred.reshape(-1) - true.reshape(-1)))
print(err.item())

In [None]:
# Summary of the predictions, one value per line:
# minimum, maximum, mean, standard deviation.
for reduce_fn in (torch.min, torch.max, torch.mean, torch.std):
    print(reduce_fn(pred).item())

In [None]:
# Summary of the reference values, one value per line:
# minimum, maximum, mean, standard deviation.
for reduce_fn in (torch.min, torch.max, torch.mean, torch.std):
    print(reduce_fn(true).item())