In [None]:
# Cap OpenMP at one thread per process. OpenMP reads this variable when its
# runtime initializes, which is why this cell runs before `import torch`.
# Keep comments off the magic line: IPython passes everything after `%env`
# to the magic as the variable assignment.
# NOTE(review): presumably meant to keep the DataLoader workers from
# oversubscribing cores; on CPU-only hosts a later cell sets torch's
# intra-op thread count to 8 via torch.set_num_threads, which overrides
# this value for torch's own CPU ops.
%env OMP_NUM_THREADS = 1

In [None]:
import torch
# NOTE(review): `os` is not referenced in the cells shown; confirm before removing.
import os

# NOTE(review): `logger` is not referenced in the cells shown.
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging

# NOTE(review): none of these names is referenced directly later in this
# notebook; presumably they are imported for their side effect of registering
# the 'charge' trainer and the charge/schnet models with `registry`, which the
# trainer-construction cell looks up by name. Verify before removing them.
from cdm.charge_trainer import ChargeTrainer
from cdm.chg_utils import ProbeGraphAdder
from cdm import models

# Initialize logging through OCP's helper before the trainer is built.
setup_logging()

In [None]:
# Report whether a GPU is visible. On CPU-only hosts, give torch several
# intra-op threads; torch.set_num_threads takes precedence over the
# OMP_NUM_THREADS value set in the first cell for torch's own CPU ops.
CPU_NUM_THREADS = 8  # intra-op threads for CPU-only runs; tune to the host

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if not cuda_available:
    torch.set_num_threads(CPU_NUM_THREADS)

In [None]:
# Task definition: a regression on charge-density values ('charge_vals')
# read from an LMDB dataset. 'charge_mae' (one of the listed metrics) is the
# primary metric.
task = dict(
    dataset='lmdb',
    description='Training on charge density',
    type='regression',
    metric=['charge_mse', 'charge_mae', 'charge_fe'],
    primary_metric='charge_mae',
    labels=['charge_vals'],
)

In [None]:
# Model configuration for the registered 'charge_model'. Atoms and probe
# points each get their own 'schnet_charge' sub-network config below.
model = dict(
    name='charge_model',
    num_interactions=3,
    atom_channels=32,
    probe_channels=32,
    enforce_zero_for_disconnected_probes=True,
    # Sub-network config for the atom side.
    atom_model_config=dict(
        name='schnet_charge',
        num_filters=16,
        num_gaussians=16,
        cutoff=5,
    ),
    # Sub-network config for the probe side.
    # NOTE(review): this cutoff (4) equals probe_graph_config['cutoff'] in
    # trainer_config; presumably the two should be kept in sync.
    probe_model_config=dict(
        name='schnet_charge',
        num_filters=16,
        num_gaussians=32,
        cutoff=4,
    ),
)

In [None]:
# Optimization settings: Adam, small batches, and a plateau-based LR schedule.
optimizer = dict(
    optimizer='Adam',
    batch_size=4,
    eval_batch_size=4,
    num_workers=24,
    lr_initial=5e-5,
    # Presumably forwarded to torch.optim.lr_scheduler.ReduceLROnPlateau:
    # multiply the LR by 0.96 once the monitored (minimized) metric has not
    # improved for more than `patience` consecutive scheduler steps.
    scheduler="ReduceLROnPlateau",
    mode="min",
    factor=0.96,
    patience=1,
    max_epochs=1000,
)

In [None]:
# Root of the LMDB charge-density splits, relative to this notebook.
DATA_ROOT = '../../charge-data/1k-no-probe-graphs'

# Dataset splits in order: train, then val (optional), then test (optional).
dataset = [
    # Train set; label normalization is disabled for it.
    {'src': f'{DATA_ROOT}/train', 'normalize_labels': False},
    # Val set (optional).
    {'src': f'{DATA_ROOT}/val'},
    # Test set (optional - writes predictions to disk). To enable, append an
    # entry whose 'src' points at an existing test split, e.g.
    # {'src': f'{DATA_ROOT}/<test split dir>'}
]

In [None]:
# Trainer settings. These are passed as keyword arguments to the trainer class
# that `registry` resolves from the 'trainer' entry.
trainer_config = dict(
    trainer='charge',
    identifier='New package',
    # NOTE(review): OCP's BaseTrainer appears to skip creating checkpoint
    # directories and the logger (so 'wandb' would go unused) when is_debug is
    # set - verify this is intended for a 1000-epoch run.
    is_debug=True,
    run_dir='./runs/',
    print_every=5,
    seed=2,
    logger='wandb',
    local_rank=0,
    # NOTE(review): mixed precision is requested even when the CUDA check
    # finds no GPU; confirm how the trainer handles amp on CPU.
    amp=True,
    # Probe-graph settings: presumably the number of probe points per split
    # and the probe neighbor cutoff (4, currently the same value as
    # model['probe_model_config']['cutoff']).
    probe_graph_config=dict(
        train_probes=200,
        val_probes=200,
        test_probes=200,
        cutoff=4,
        include_atomic_edges=False,
        implementation='SKIP',
    ),
)

In [None]:
# Look up the trainer class registered under trainer_config['trainer'] and
# build it from the config dicts defined above. Note that **trainer_config
# also forwards the 'trainer' key itself as a keyword argument.
trainer_cls = registry.get_trainer_class(trainer_config['trainer'])
trainer = trainer_cls(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    **trainer_config,
)

In [None]:
# Start training with the configuration assembled above.
# NOTE(review): presumably runs for up to optimizer['max_epochs'] (1000)
# epochs - confirm against the ChargeTrainer implementation.
trainer.train()