In [None]:
# Limit OpenMP-backed libraries to a single thread per process. This runs
# before the imports in the next cell so that, presumably, the libraries they
# load read the value when they initialise. (Keep any comment OFF the %env
# line itself: IPython would treat trailing text as part of the value.)
%env OMP_NUM_THREADS = 1

In [None]:
# `cdm` is never referenced by name in this notebook. It is presumably
# imported for its side effect of registering the 'charge' trainer and the
# 'charge_model' / 'schnet_charge' models with the ocpmodels registry that the
# trainer cell queries. Keep this import even though it looks unused.
import cdm

# NOTE(review): `logger` is not used in any of the cells shown here.
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging

# Initialise ocpmodels' logging helper before any trainer is constructed.
setup_logging()

In [None]:
# Task-level settings passed to the trainer as `task=`.
# 'dataset' selects the dataset backend ('lmdb'); the concrete paths live in
# the `dataset` list defined in a later cell.
task = dict(
    description='Predicting electron density from atomic positions',
    dataset='lmdb',
)

In [None]:
# Model configuration passed to the trainer as `model=`.
# The top-level 'name' selects the registered 'charge_model' class. The
# boolean flags are interpreted by that class; their exact semantics are
# defined in cdm, not in this notebook.
model = dict(
    name='charge_model',
    enforce_zero_for_disconnected_probes=True,
    enforce_charge_conservation=True,
    freeze_atomic=False,

    # Sub-network applied over the atoms (uses the architecture's defaults).
    atom_model_config=dict(
        name='schnet_charge',
    ),

    # Sub-network applied over the probe points.
    # NOTE(review): cutoff units are presumably Angstrom; confirm in cdm.
    probe_model_config=dict(
        name='schnet_charge',
        num_interactions=3,
        cutoff=5,
    ),

    # Probe sampling/graph settings for the 'otf_pga' component.
    # NOTE(review): the acronym is not expanded here; see cdm for semantics.
    otf_pga_config=dict(
        num_probes=100_000,
        cutoff=6,
    ),
)

In [None]:
# Optimisation settings passed to the trainer as `optimizer=`.
# 'mode', 'factor' and 'patience' sit alongside 'scheduler'. They are
# presumably forwarded to the ReduceLROnPlateau scheduler by the trainer;
# confirm in the trainer/scheduler code.
# NOTE(review): no 'batch_size' key is set here; confirm that the 'charge'
# trainer supplies a default before running.
optimizer = dict(
    optimizer='Adam',
    num_workers=7,
    lr_initial=5e-5,
    scheduler="ReduceLROnPlateau",
    mode="min",
    factor=0.96,
    patience=1,
    max_epochs=1000,
    loss_charge='normed_mae',
)

In [None]:
# Dataset splits passed to the trainer as `dataset=`, in order [train, val].
# The 'src' values are placeholders; point them at real dataset locations
# before running.
dataset = [
    dict(src='path/to/train'),
    dict(src='path/to/val'),
]

In [None]:
# Trainer-level settings. The cell that builds the trainer forwards EVERY key
# here (including 'trainer' itself) as a keyword argument, so the key names
# must match parameters that the registered trainer class accepts.
# NOTE(review): in ocpmodels' BaseTrainer, is_debug=True typically skips
# checkpoint-directory creation and logger setup, so 'wandb' may be inactive
# and no checkpoints saved. Confirm that the 'charge' trainer behaves the same
# before launching a long (max_epochs=1000) run.
trainer_config = dict(
    trainer='charge',
    identifier='Electron Density Prediction with SchNet',
    is_debug=True,
    run_dir='../runs/',  # relative to the notebook's working directory
    print_every=1,
    seed=2,
    logger='wandb',
    local_rank=0,
    amp=True,
)

In [None]:
# Resolve the trainer class registered under trainer_config['trainer'], then
# instantiate it with the task/model/dataset/optimizer configs from the cells
# above. All trainer_config entries (including the 'trainer' key used for the
# lookup) are also passed through as keyword arguments.
trainer_cls = registry.get_trainer_class(trainer_config['trainer'])
trainer = trainer_cls(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    **trainer_config
)

In [None]:
# Display the underlying network. `.module` unwraps the wrapper that the
# trainer puts around its model (presumably a DataParallel/DDP-style wrapper;
# confirm in the trainer implementation).
trainer.model.module

In [None]:
# Launch training with the configuration assembled above. The run length is
# presumably bounded by optimizer['max_epochs']; confirm the trainer honours it.
trainer.train()