# Task #1 

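Template code for training an RBM on H$_2$ data. The original template handled a single bond length ($r = 1.2$); it has been modified below to train one RBM for every `R_*_samples.txt` file found in `H2_data/`.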

Imports and loading in data:

In [1]:
import csv
import re
from glob import glob
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt

from RBM_helper import RBM

import H2_energy_calculator
from datetime import datetime


Define the RBM size (numbers of visible and hidden units):

In [2]:
# RBM dimensions shared by every cell below: visible units, hidden units.
n_vis, n_hin = 2, 10


Train the RBM:

In [13]:
# Training hyper-parameters (applied to every bond length).
num_samples = 2000  # RBM samples drawn at each checkpoint to estimate the H2 energy
epochs = 1000       # number of training epochs (calls to rbm.train) per bond length

# Hamiltonian coefficient table; the training loop matches column 0 against r.
coeffs = np.loadtxt("H2_data/H2_coefficients.txt")

[26]


In [14]:
# Train one RBM per bond length r found in H2_data/, logging the energy error and
# the overlap with the exact wavefunction every 100 epochs, then save each model.
log_dir = Path("training_logs")
log_dir.mkdir(parents=True, exist_ok=True)  # the CSV cell writes into this directory
logs = {}
# glob() returns files in filesystem-dependent order; sort for reproducible runs/logs.
for sample_file in sorted(glob('H2_data/*_samples.txt')):
    print(sample_file)
    match = re.match(r".*R_([0-9.]+)_samples\.txt", sample_file)
    if match is None:
        print(f"Skipping {sample_file}: file name does not contain R_<r>")
        continue
    r = match.group(1)

    # Pick the coefficient row whose first column equals r. Compare with a
    # tolerance (exact float equality is fragile) and use the builtin float:
    # the np.float alias was removed in NumPy 1.24.
    coeff_rows = np.flatnonzero(np.isclose(coeffs[:, 0].astype(float), float(r)))
    if coeff_rows.size == 0:
        raise ValueError(f"No row in H2_data/H2_coefficients.txt for r = {r}")
    coeff = coeffs[coeff_rows[0], :]

    psi_file = f"H2_data/R_{r}_psi.txt"
    psi = np.loadtxt(psi_file)  # exact wavefunction, used for the overlap check
    rbm = RBM(n_vis, n_hin)
    training_data = torch.from_numpy(np.loadtxt(sample_file))

    true_energy = H2_energy_calculator.energy_from_freq(training_data, coeff)
    print(f"H2 energy for r = {r}: ", true_energy)
    energy_diffs = logs[f"r_{r}_energy_diff"] = []
    fidelities = logs[f"r_{r}_fidelity"] = []
    save_dir = f'params/trained_r_{r}'
    for e in range(1, epochs + 1):
        # do one epoch of training
        rbm.train(training_data)

        # every 100 epochs, generate samples and calculate the energy
        if e % 100 == 0:
            print("Epoch: ", e)
            # Sampling the RBM needs Gibbs sampling; start every chain from the
            # all-zeros state. The first argument (15) is presumably the number
            # of Gibbs steps — TODO confirm against RBM_helper.draw_samples.
            init_state = torch.zeros(num_samples, n_vis)
            RBM_samples = rbm.draw_samples(15, init_state)
            energy = H2_energy_calculator.energy(RBM_samples, coeff, rbm.wavefunction).item()
            print("Energy from RBM samples: ", energy)

            energy_difference = abs(true_energy - energy)
            print("Energy difference from RBM samples: ", energy_difference)
            energy_diffs.append(energy_difference)

            # |<psi|psi_rbm>| (overlap magnitude). Note: some texts define the
            # fidelity as the square of this quantity.
            rbm_psi = rbm.psi()
            overlap = np.abs(np.vdot(psi, rbm_psi))
            print("Fidelity from RBM samples: ", overlap)
            fidelities.append(overlap)
            print()

    # Persist this bond length's model so the reload cell can find it in save_dir
    # (the same path it passes to load_params).
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    rbm.save_params(save_dir)

H2_data/R_1.5_samples.txt
H2 energy for r = 1.5:  -1.0066304034616804
Epoch:  100
Energy from RBM samples:  -0.9515390090882518
Energy difference from RBM samples:  0.05509139437342858
Fidelity from RBM samples:  0.983651347758133

Epoch:  200
Energy from RBM samples:  -0.9843282391311512
Energy difference from RBM samples:  0.02230216433052923
Fidelity from RBM samples:  0.9937666636642544

Epoch:  300
Energy from RBM samples:  -0.9929160953349632
Energy difference from RBM samples:  0.013714308126717256
Fidelity from RBM samples:  0.9963411048669342

Epoch:  400
Energy from RBM samples:  -0.9861757212580404
Energy difference from RBM samples:  0.020454682203639996
Fidelity from RBM samples:  0.997122477887463

Epoch:  500
Energy from RBM samples:  -1.0061258859130253
Energy difference from RBM samples:  0.0005045175486551035
Fidelity from RBM samples:  0.9980645207483961

Epoch:  600
Energy from RBM samples:  -1.0030413672851974
Energy difference from RBM samples:  0.0035890361764829

In [15]:
# Write the metrics collected during training as a single CSV row keyed by
# "r_<r>_energy_diff" / "r_<r>_fidelity". Each value is a list, so DictWriter
# stores its string representation, e.g. "[0.05, 0.02]".
log_dir.mkdir(parents=True, exist_ok=True)  # open() fails if the directory is missing
# newline='' is required by the csv module so it controls line endings itself.
with open(log_dir / "task1_log.csv", 'w', newline='') as f:
    w = csv.DictWriter(f, logs.keys())
    w.writeheader()
    w.writerow(logs)

In [16]:
# Save the parameters of `rbm` as left by the training cell. Because that loop
# rebinds `rbm` for every bond length, this is the model trained on the last
# sample file processed there; per-r models use params/trained_r_<r> instead.
rbm.save_params("params/")



In [None]:

# Reload each saved RBM from params/trained_r_<r> and check it against the exact
# wavefunction, using the same overlap measure |<psi|psi_rbm>| as the training loop.
loaded_fidelities = {}
for psi_file in sorted(glob('H2_data/*_psi.txt')):
    match = re.match(r".*R_([0-9.]+)_psi\.txt", psi_file)
    if match is None:
        print(f"Skipping {psi_file}: file name does not contain R_<r>")
        continue
    r = match.group(1)
    rbm = RBM(n_vis, n_hin)
    save_dir = f'params/trained_r_{r}'
    rbm.load_params(save_dir)
    psi = np.loadtxt(psi_file)
    loaded_fidelities[r] = np.abs(np.vdot(psi, rbm.psi()))

loaded_fidelities
