In [None]:
import torch
import os
import models
from charge_trainer import ChargeTrainer
from ocpmodels.common import logger
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import setup_logging
from DeepDFT import densitymodel
from chg_utils import ProbeGraphAdder
setup_logging()

import matplotlib.pyplot as plt
from torch_geometric.data import Batch

In [None]:
!pip install wandb --upgrade
import wandb
wandb.login()
import pprint

In [None]:
def _int_uniform(lo, hi):
    """Build a W&B ``int_uniform`` parameter spec spanning ``lo``..``hi``."""
    return {'distribution': 'int_uniform', 'max': hi, 'min': lo}


def _q_log_uniform(lo, hi, q):
    """Build a W&B ``q_log_uniform_values`` parameter spec with bounds ``lo``/``hi`` and step ``q``."""
    return {'distribution': 'q_log_uniform_values', 'min': lo, 'max': hi, 'q': q}


# Bayesian search minimising the validation charge MAE. Each entry under
# 'parameters' is read back as an attribute of `wandb.config` inside `train`.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/charge_mae', 'goal': 'minimize'},
    'parameters': {
        'num_interactions': _int_uniform(1, 6),
        'atom_channels': _q_log_uniform(16, 128, 8),
        'probe_channels': _q_log_uniform(16, 128, 8),
        'batch_size': _q_log_uniform(1, 16, 2),
        'train_probes': _q_log_uniform(100, 1000, 100),
        'atom_filters': _q_log_uniform(8, 128, 8),
        'probe_filters': _q_log_uniform(8, 128, 8),
        'atom_gaussians': _q_log_uniform(8, 32, 8),
        'probe_gaussians': _q_log_uniform(8, 128, 8),
        'cutoff': _int_uniform(3, 6),
    },
}

pprint.pprint(sweep_config)

In [None]:
# Register `sweep_config` with W&B under the sweeps project; the returned id is
# what an agent needs in order to run trials of this sweep.
sweep_id = wandb.sweep(sweep=sweep_config, project="charge-density-models-sweeps")
print(sweep_id)

In [None]:
def _schnet_branch(filters, gaussians, cutoff):
    """Sub-model config for one 'schnet_charge' branch (atom or probe) of the charge model."""
    return {
        'name': 'schnet_charge',
        'num_filters': filters,
        'num_gaussians': gaussians,
        'cutoff': cutoff,
    }


def train(config=None):
    """Run a single sweep trial: assemble OCP configs from the sampled hyperparameters and train.

    This is the function passed to ``wandb.agent``. ``config`` is forwarded to
    ``wandb.init``; the hyperparameters are then read from the resolved
    ``wandb.config``.
    """
    with wandb.init(config=config):
        hp = wandb.config

        task_cfg = {
            'dataset': 'lmdb',
            'description': 'Initial test of training on charges',
            'type': 'regression',
            'metric': ['charge_mse', 'charge_mae'],
            'primary_metric': 'charge_mae',
            'labels': ['charge_vals'],
        }

        # The atom and probe branches share one sampled cutoff but have independent
        # filter/gaussian counts.
        model_cfg = {
            'name': 'charge_model',
            'num_interactions': hp.num_interactions,
            'atom_channels': hp.atom_channels,
            'probe_channels': hp.probe_channels,
            'atom_model_config': _schnet_branch(hp.atom_filters, hp.atom_gaussians, hp.cutoff),
            'probe_model_config': _schnet_branch(hp.probe_filters, hp.probe_gaussians, hp.cutoff),
        }

        optim_cfg = {
            'optimizer': 'Adam',
            'batch_size': hp.batch_size,
            'eval_batch_size': 10,
            'num_workers': 1,
            'lr_initial': 5e-4,
            'scheduler': "ReduceLROnPlateau",
            'mode': "min",
            'factor': 0.96,
            'patience': 1,
            'max_epochs': 300,
        }

        # Order matters: train split first, then val split. An optional third entry
        # would be a test split (per the original note, it writes predictions to disk).
        splits = [
            {'src': '../chg/100/train', 'normalize_labels': False},
            {'src': '../chg/100/val'},
        ]

        run_cfg = {
            'trainer': 'charge',
            'identifier': 'sweep_run',
            'is_debug': False,
            'run_dir': './runs/',
            'print_every': 1,
            'seed': 2,
            'logger': 'wandb',
            'local_rank': 0,
            'amp': True,

            'cutoff': hp.cutoff,
            'train_probes': hp.train_probes,
            'val_probes': 1000,
            'test_probes': 1000,
        }

        # Resolve the registered trainer class by name, then pass the whole run
        # config (including the 'trainer' key itself) as keyword arguments.
        trainer_cls = registry.get_trainer_class(run_cfg['trainer'])
        trainer = trainer_cls(task=task_cfg,
                              model=model_cfg,
                              dataset=splits,
                              optimizer=optim_cfg,
                              **run_cfg)
        trainer.train()

In [None]:
# Run agents against the sweep registered above (`sweep_id`), so the `sweep_config`
# defined in this notebook is the one actually searched. The previous hardcoded id
# pointed at an older sweep and left the freshly created one unused.
# To resume an earlier sweep instead, pass its id string,
# e.g. "charge-density-models-sweeps/<sweep-id>".
wandb.agent(sweep_id, function=train, project="charge-density-models-sweeps", count=100)  # count = max trials for this agent