
# Neural XC training using KSDFTTrainer

We build a neural exchange-correlation (XC) network, defined by a pair of functions `init_fn` and `apply_fn`, and then train it with `KSDFTTrainer`.

1. Load configuration
2. Load dataset
3. Define model
4. Train model


In [1]:
!pip3 list
# !python -c "import jax; import jaxlib; import os; import sys; print(jax.__version__); print(jaxlib.__version__); print(os.environ.get('CUDA_VISIBLE_DEVICES')); print(jax.default_backend()); prin/t(jax.devices())"
!nvidia-smi

Package                  Version     Editable project location
------------------------ ----------- -------------------------
absl-py                  2.3.1
asttokens                3.0.0
chex                     0.1.89
comm                     0.2.2
contourpy                1.3.2
cycler                   0.12.1
debugpy                  1.8.15
decorator                5.2.1
einops                   0.8.1
etils                    1.13.0
exceptiongroup           1.3.0
executing                2.2.0
flax                     0.10.4
fonttools                4.59.0
fsspec                   2025.7.0
h5py                     3.14.0
horqrux                  0.9.2
humanize                 4.12.3
importlib_resources      6.5.2
ipykernel                6.30.0
ipython                  8.37.0
jax                      0.4.38
jax-cuda12-pjrt          0.4.38
jax-cuda12-plugin        0.4.38
jax-dft                  0.0.0
jaxlib                   0.4.38
jaxopt                   0.8.5
jedi                

In [2]:
# Sanity check: which accelerators can JAX see? (A CudaDevice means the GPU backend is active.)
import jax

available_devices = jax.devices()
available_devices

[CudaDevice(id=0)]

In [3]:
import jax

# Switch JAX to 64-bit (double) precision; KS-DFT energies/densities need float64.
# This must run before any JAX arrays are created: arrays created earlier keep
# their 32-bit dtypes.
jax.config.update("jax_enable_x64", True)



In [4]:
from pathlib import Path
import os

import qedft
from qedft.config.config import Config
from qedft.models.networks import GlobalMLP, LocalMLP, LocalQNN
from qedft.train.od.trainer import KSDFTTrainer

# Repository root: two directory levels above the qedft package file.
project_path = Path(qedft.__file__).parent.parent

# Base training configuration shipped with the package; later cells override parts of it.
config_file = project_path / 'qedft' / 'config' / 'train_config.yaml'
config = Config(config_path=config_file).config
config

{'name': 'test',
 'experiment_name': 'test',
 'network_type': 'ksr',
 'molecule_name': 'h2',
 'molecule_names': ['h2'],
 'dataset1': [128, 384],
 'rng': 0,
 'save_plot_loss': False,
 'save_every_n': 20,
 'activation': 'tanh',
 'n_neurons': 513,
 'n_layers': 2,
 'n_qubits': 9,
 'n_reupload_layers': 1,
 'use_rzz_parametrized_entanglers': False,
 'chebychev_reuploading': False,
 'add_reversed_rzz': False,
 'entangling_block_type': 'alternate_linear',
 'single_qubit_rotations': ['rz', 'rx', 'rz'],
 'use_same_parameters': False,
 'add_negative_transform': False,
 'wrap_with_self_interaction_layer': False,
 'wrap_with_global_functional': False,
 'use_correlators_in_output': False,
 'output_operators': ['Z'],
 'use_bias_in_output': False,
 'max_train_steps': 10000,
 'factr': 1.0,
 'pgtol': 1e-14,
 'm': 20,
 'maxfun': 20,
 'maxiter': 2,
 'num_iterations': 15,
 'ks_iter_to_ignore': 10,
 'discount_factor': 0.9,
 'alpha': 0.5,
 'alpha_decay': 0.9,
 'num_mixing_iterations': 1,
 'density_mse_conver

In [5]:
# Update with specific settings
# Override the base configuration for a quick H2 test run.
# The MLP and QNN sections of this notebook are built from this one dict, so every
# key must appear exactly once below: in a dict literal a repeated key is silently
# replaced by its last occurrence.
config.update({
    'molecule_name': 'h2',
    'dataset1': [128, 384],  # reported by the data loader as the training distances
    'rng': 0,

    # Network architecture settings
    'network_type': 'mlp',  # 'mlp' -> LocalMLP, 'mlp_ksr' -> GlobalMLP (see the MLP cell)
    'n_neurons': 128,
    'n_layers': 4,  # single value shared by the MLP and the QNN (both read this dict)
    'activation': 'tanh',
    'density_normalization_factor': 2.0,
    'wrap_with_negative_transform': True,
    'wrap_self_interaction': True,
    'use_amplitude_encoding': False,

    # QNN settings ('n_layers' is set once above)
    'qnn_type': 'LocalQNN',  # "LocalQNN", "GlobalQNN"
    'layer_type': 'DirectQNN',  # "DirectQNN", "ChebyshevQNN", "ProductQNN"
    'map_fn': None,
    'n_qubits': 4,

    # Optimizer settings (small evaluation/iteration limits for a quick test run)
    'maxfun': 2,
    'maxiter': 2,
    'factr': 1,
    'm': 20,
    'pgtol': 1e-14,
})

***
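## Training an MLP

The next cell trains `GlobalMLP` when `network_type` is `'mlp_ksr'` and `LocalMLP` otherwise; with the settings above (`network_type='mlp'`) this is a `LocalMLP`.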

In [6]:
# Choose network type based on config
# 'mlp_ksr' selects the global MLP; any other network_type falls back to the local MLP.
if config['network_type'] == 'mlp_ksr':
    NetworkClass = GlobalMLP
else:
    NetworkClass = LocalMLP

network = NetworkClass(config)

In [7]:
# Initialize trainer
# Wrap the chosen network in a KS-DFT trainer; molecule datasets live under data/od.
od_data_dir = project_path / 'data' / 'od'
trainer = KSDFTTrainer(
    network=network,
    config_dict=config,
    data_path=od_data_dir,
)

[32m2025-07-21 15:30:47.319[0m | [1mINFO    [0m | [36mqedft.train.od.trainer[0m:[36m__init__[0m:[36m58[0m - [1mInitialized trainer with config: {'name': 'test', 'experiment_name': 'test', 'network_type': 'mlp', 'molecule_name': 'h2', 'molecule_names': ['h2'], 'dataset1': [128, 384], 'rng': 0, 'save_plot_loss': False, 'save_every_n': 20, 'activation': 'tanh', 'n_neurons': 128, 'n_layers': 4, 'n_qubits': 4, 'n_reupload_layers': 1, 'use_rzz_parametrized_entanglers': False, 'chebychev_reuploading': False, 'add_reversed_rzz': False, 'entangling_block_type': 'alternate_linear', 'single_qubit_rotations': ['rz', 'rx', 'rz'], 'use_same_parameters': False, 'add_negative_transform': False, 'wrap_with_self_interaction_layer': False, 'wrap_with_global_functional': False, 'use_correlators_in_output': False, 'output_operators': ['Z'], 'use_bias_in_output': False, 'max_train_steps': 10000, 'factr': 1, 'pgtol': 1e-14, 'm': 20, 'maxfun': 2, 'maxiter': 2, 'num_iterations': 15, 'ks_iter_to_igno

In [8]:
# Train model
# Train from freshly initialized parameters, writing checkpoints to tests/ckpts.
# To resume instead, also pass e.g.
#   checkpoint_path=project_path / 'tests' / 'ckpts' / 'ckpt-00001'
ckpt_save_dir = project_path / 'tests' / 'ckpts'
params, loss, info = trainer.train(checkpoint_save_dir=ckpt_save_dir)

[32m2025-07-21 15:30:52.125[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m63[0m - [1mLoading dataset for h2[0m
[32m2025-07-21 15:30:52.127[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m83[0m - [1mLoading dataset from /home/isokolov/qex/data/od/h2[0m
[32m2025-07-21 15:30:52.135[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m95[0m - [1mTraining distances: [128, 384][0m
[32m2025-07-21 15:30:52.135[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m101[0m - [1mNumber of electrons: 2[0m
[32m2025-07-21 15:30:52.136[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m102[0m - [1mGrid shape: (513,)[0m
[32m2025-07-21 15:30:55.134[0m | [1mINFO    [0m | [36mqedft.models.wrappers[0m:[36mwrap_network[0m:[36m189

In [9]:
# One line per item: final loss, optimizer info dict, the parameter vector, and its shape.
print(
    f"Training completed with final loss: {loss}",
    f"Optimization info: {info}",
    f"Params: {params}",
    f"Params shape: {params.shape}",
    sep="\n",
)

Training completed with final loss: 0.2735378201571361
Optimization info: {'grad': array([ 9.96311132e-04,  4.82359757e-03, -4.19719858e-02, ...,
        8.92079835e-03, -1.13485022e+00,  1.74519492e+00], shape=(49922,)), 'task': 'STOP: TOTAL NO. OF F,G EVALUATIONS EXCEEDS LIMIT', 'funcalls': 3, 'nit': 1, 'warnflag': 1}
Params: [ 0.04249795  0.10944407 -0.19366582 ... -0.11298757  0.07328377
  0.95798439]
Params shape: (49922,)


## Evaluating

In [None]:
# Evaluate the most recent checkpoint written by the training cell.
# A fixed checkpoint name is fragile: which `ckpt-NNNNN` files exist depends on how
# far training got. The zero-padded names sort lexicographically in step order.
ckpt_dir = project_path / 'tests' / 'ckpts'
saved_ckpts = sorted(ckpt_dir.glob('ckpt-*'))
if not saved_ckpts:
    raise FileNotFoundError(
        f"No checkpoints matching 'ckpt-*' in {ckpt_dir}; run the training cell first."
    )

states = trainer.evaluate(
    checkpoint_path=saved_ckpts[-1],
    plot_distances=[128, 384],  # same distances as config['dataset1']
)


[32m2025-07-21 15:31:18.590[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m63[0m - [1mLoading dataset for h2[0m
[32m2025-07-21 15:31:18.591[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m83[0m - [1mLoading dataset from /home/isokolov/qex/data/od/h2[0m
[32m2025-07-21 15:31:18.594[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m95[0m - [1mTraining distances: [128, 384][0m
[32m2025-07-21 15:31:18.595[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m101[0m - [1mNumber of electrons: 2[0m
[32m2025-07-21 15:31:18.595[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m102[0m - [1mGrid shape: (513,)[0m
[32m2025-07-21 15:31:18.632[0m | [1mINFO    [0m | [36mqedft.models.wrappers[0m:[36mwrap_network[0m:[36m189

FileNotFoundError: [Errno 2] No such file or directory: '/home/isokolov/qex/tests/ckpts/ckpt-00001'

***
## Training Quantum Neural Networks

In [None]:
# Quantum model: a LocalQNN trained with the same trainer setup as the MLP.
NetworkClass = LocalQNN
# NOTE(review): network_type stays 'mlp' here, presumably so the trainer treats the
# QNN like the local MLP; confirm against qedft's trainer/wrappers.
config.update({
    'network_type': 'mlp',
    'use_amplitude_encoding': False,
})
network = NetworkClass(config)

trainer = KSDFTTrainer(
    network=network,
    config_dict=config,
    data_path=project_path / 'data' / 'od',
)

[32m2025-04-14 18:06:15.832[0m | [1mINFO    [0m | [36mqedft.train.od.trainer[0m:[36m__init__[0m:[36m59[0m - [1mInitialized trainer with config: {'name': 'test', 'experiment_name': 'test', 'network_type': 'mlp', 'molecule_name': 'h2', 'molecule_names': ['h2'], 'dataset1': [128, 384], 'rng': 0, 'save_plot_loss': False, 'save_every_n': 20, 'activation': 'tanh', 'n_neurons': 128, 'n_layers': 4, 'n_qubits': 4, 'n_reupload_layers': 1, 'use_rzz_parametrized_entanglers': False, 'chebychev_reuploading': False, 'add_reversed_rzz': False, 'entangling_block_type': 'alternate_linear', 'single_qubit_rotations': ['rz', 'rx', 'rz'], 'use_same_parameters': False, 'add_negative_transform': False, 'wrap_with_self_interaction_layer': False, 'wrap_with_global_functional': False, 'use_correlators_in_output': False, 'output_operators': ['Z'], 'use_bias_in_output': False, 'max_train_steps': 10000, 'factr': 1, 'pgtol': 1e-14, 'm': 20, 'maxfun': 2, 'maxiter': 2, 'num_iterations': 15, 'ks_iter_to_igno

In [None]:
# Train model
# Train the QNN from fresh parameters. Checkpoints go to the same tests/ckpts directory
# as the MLP run (NOTE(review): check whether the two runs overwrite each other's files).
# To resume, also pass e.g. checkpoint_path=project_path / 'tests' / 'ckpts' / 'ckpt-00001'.
qnn_ckpt_dir = project_path / 'tests' / 'ckpts'
params, loss, info = trainer.train(checkpoint_save_dir=qnn_ckpt_dir)

[32m2025-04-14 18:06:15.837[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m63[0m - [1mLoading dataset for h2[0m
[32m2025-04-14 18:06:15.838[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m83[0m - [1mLoading dataset from /Users/igorsokolov/PycharmProjects/qedft/data/od/h2[0m
[32m2025-04-14 18:06:15.840[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m95[0m - [1mTraining distances: [128, 384][0m
[32m2025-04-14 18:06:15.840[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m101[0m - [1mNumber of electrons: 2[0m
[32m2025-04-14 18:06:15.840[0m | [1mINFO    [0m | [36mqedft.data_io.dataset_loader[0m:[36mload_molecular_datasets[0m:[36m102[0m - [1mGrid shape: (513,)[0m
[32m2025-04-14 18:06:16.202[0m | [1mINFO    [0m | [36mqedft.models.quantum.quantum_models

KeyboardInterrupt: 

In [None]:
# NOTE: if the training cell above was interrupted, `loss`, `info` and `params`
# still hold the values from the previous completed run.
print(
    f"Training completed with final loss: {loss}",
    f"Optimization info: {info}",
    f"Params: {params}",
    f"Params shape: {params.shape}",
    sep="\n",
)

Training completed with final loss: 1.5183673980519223
Optimization info: {'grad': array([-1.71702502e+00,  1.38279942e+00, -1.73077662e+00,  8.80152018e-02,
        1.89468187e-02,  7.65675404e-02, -4.16455339e-01, -6.87734670e-02,
       -4.16367478e-01, -1.77017195e+00,  1.40453201e+00, -1.78526662e+00,
       -6.49569404e-01,  4.16965948e-01, -6.48556226e-01,  4.74949575e-02,
       -1.68425865e-01,  4.74694475e-02,  6.48634315e-02,  4.82608181e-01,
        6.90992656e-02, -6.48559069e-01,  6.45917479e-01, -6.47515901e-01,
        1.30619078e-01,  4.60207926e-01,  1.39392388e-01,  1.03025516e-01,
       -2.33559947e-01,  1.04589301e-01,  3.56400514e-01, -3.69753865e-05,
        3.56650440e-01,  1.01798449e-01,  1.01815418e+00,  1.04123187e-01,
       -3.79558758e-01,  2.52509012e-01, -3.80531982e-01, -2.73927123e-01,
       -8.97630050e-03, -2.74366758e-01,  3.36531426e-02, -1.63257048e-01,
        3.47760489e-02, -3.70868828e-01,  4.56596677e-01, -3.68655690e-01,
        4.0288092

In [None]:
# Demo snippet: a minimal end-to-end run using the *base* YAML configuration.
# Nothing is overridden here, so the trainer runs with the values from the first
# config printout (e.g. network_type 'ksr', n_neurons 513, maxfun 20).

import os
from pathlib import Path

from loguru import logger

import qedft
from qedft.config.config import Config
from qedft.models.networks import LocalMLP, LocalQNN
from qedft.train.od.trainer import KSDFTTrainer

# Repository root: two directory levels above the qedft package file.
project_path = Path(qedft.__file__).parent.parent

# Re-read the base configuration, discarding the overrides made in earlier cells.
config_file = project_path / 'qedft' / 'config' / 'train_config.yaml'
config = Config(config_path=config_file).config

# NOTE(review): the network is expected to provide `build_network()` returning
# (init_fn, apply_fn); confirm against qedft.models.networks.
network = LocalMLP(config)

trainer = KSDFTTrainer(
    network=network,
    config_dict=config,
    data_path=project_path / 'data' / 'od',
)

# Train from scratch; to resume instead, also pass
#   checkpoint_path=project_path / 'tests' / 'ckpts' / 'ckpt-00001'
params, loss, info = trainer.train(
    checkpoint_save_dir=project_path / 'tests' / 'ckpts',
)

[32m2025-05-02 14:44:45.109[0m | [1mINFO    [0m | [36mqedft.train.od.trainer[0m:[36m__init__[0m:[36m58[0m - [1mInitialized trainer with config: {'name': 'test', 'experiment_name': 'test', 'network_type': 'ksr', 'molecule_name': 'h2', 'molecule_names': ['h2'], 'dataset1': [128, 384], 'rng': 0, 'save_plot_loss': False, 'save_every_n': 20, 'activation': 'tanh', 'n_neurons': 513, 'n_layers': 2, 'n_qubits': 9, 'n_reupload_layers': 1, 'use_rzz_parametrized_entanglers': False, 'chebychev_reuploading': False, 'add_reversed_rzz': False, 'entangling_block_type': 'alternate_linear', 'single_qubit_rotations': ['rz', 'rx', 'rz'], 'use_same_parameters': False, 'add_negative_transform': False, 'wrap_with_self_interaction_layer': False, 'wrap_with_global_functional': False, 'use_correlators_in_output': False, 'output_operators': ['Z'], 'use_bias_in_output': False, 'max_train_steps': 10000, 'factr': 1.0, 'pgtol': 1e-14, 'm': 20, 'maxfun': 20, 'maxiter': 2, 'num_iterations': 15, 'ks_iter_to_i

# Summary

This notebook demonstrates training a neural network model for Kohn-Sham DFT:

1. Sets up configuration for an MLP (a `LocalMLP`, since `network_type` is `'mlp'`) with `n_layers` = 4, 128 neurons, and tanh activation
2. Also sets up configuration for a local QNN model (`LocalQNN`) with 4 qubits, `n_layers` = 4, and the DirectQNN ansatz
3. Initializes KSDFTTrainer with the model configuration and H2 molecule dataset
4. Trains the model using L-BFGS-B optimization:
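   - Starts from freshly initialized parameters (no `checkpoint_path` is passed to `trainer.train`)
   - Uses small `maxfun`/`maxiter` of 2 for testing (the final demo cell uses the base configuration values instead)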
   - Minimizes loss combining energy and density errors
   - Saves checkpoints during training

The trained model learns to predict the exchange-correlation functional for H2,
with the goal of accurately reproducing DFT energies and electron densities.

