# Task #1 

A template code for training an RBM on H$_2$ data for $r = 1.2$ is shown here. Modify this!

Imports and loading in data:

In [19]:
import re
from glob import glob
import numpy as np
import torch
import matplotlib.pyplot as plt

from RBM_helper import RBM

import H2_energy_calculator
from datetime import datetime


Define the RBM:

In [20]:
n_vis = 2
n_hin = 10

rbm = RBM(n_vis, n_hin)

Train the RBM:

In [21]:
epochs = 500 # number of training steps
num_samples = 1000 # number of samples to generate from the RBM to calculate the H2 energy
coeff = np.loadtxt("H2_data/H2_coefficients.txt")[20,:]

for sample_file in glob('H2_data/*_samples.txt'):
    print(sample_file)
    match = re.match(r".*R_([0-9.]+)_samples.txt", sample_file)
    r = match.groups(0)[0]
    training_data = torch.from_numpy(np.loadtxt(sample_file))

    true_energy = H2_energy_calculator.energy_from_freq(training_data, coeff)
    print(f"H2 energy for r = {r}: ",true_energy)

    start = datetime.now()
    for e in range(1, epochs+1):
        # do one epoch of training
        rbm.train(training_data)

        # now generate samples and calculate the energy
        if e % 100 == 0:
            end = datetime.now()
            print("\nEpoch: ", e)
            print("\nElapsed: ", end - start)
            print("Sampling the RBM...")

            # For sampling the RBM, we need to do Gibbs sampling.
            # Initialize the Gibbs sampling chain with init_state as defined below.
            init_state = torch.zeros(num_samples, n_vis)
            RBM_samples = rbm.draw_samples(15, init_state)

            print("Done sampling. Calculating energy...")

            energies = H2_energy_calculator.energy(RBM_samples, coeff, rbm.wavefunction)
            print("Energy from RBM samples: ", energies.item())
            start = datetime.now()

    save_dir = f'params/trained_r_{r}'
    rbm.save_params(save_dir)
    #%%



H2_data/R_1.5_samples.txt
H2 energy for r = 1.5:  -1.0497107493775149

Epoch:  100

Elapsed:  0:00:08.395806
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9561737956832073

Epoch:  200

Elapsed:  0:00:08.644560
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9912419795364287

Epoch:  300

Elapsed:  0:00:08.797434
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9763509756356642

Epoch:  400

Elapsed:  0:00:09.081706
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.953318889546748

Epoch:  500

Elapsed:  0:00:09.051178
Sampling the RBM...
Done sampling. Calculating energy...
Energy from RBM samples:  -0.9553125924624883


AttributeError: 'RBM' object has no attribute 'save_params'