# Minimal Example Unsupervised

This Jupyter notebook provides a minimal example of training a neural-network quantum state (NQS) in the unsupervised setting, i.e. without training data, by minimising the energy of a given Hamiltonian.

## Import library

In [1]:
import sys
sys.path.append(r'../')

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

## Set the same seed
np.random.seed(42)
tf.random.set_seed(42)

%load_ext autoreload
%autoreload 2
%matplotlib inline

## Training

Here, we use the library to find the ground state of a one-dimensional Ising model with $4$ spins and open boundary conditions. We set $J = 2$ and $h = 1$ as the parameters of the Ising model, which means that the model is in the ferromagnetic phase. We use Gibbs sampling with $1000$ samples and a restricted Boltzmann machine (RBM) for a real, positive wave function with hidden-unit density $\alpha = 2$. The weights and biases are initialised from a normal distribution with zero mean and standard deviation $0.01$. We use the Adam optimiser ([Kingma & Ba, 2014](https://arxiv.org/abs/1412.6980)) with a learning rate of $0.001$. For the stopping criterion, we set $\varepsilon_\sigma = 0.01$ and train for at most $1000$ epochs. At the end of training, we compute the $M^2_F$ observable.

In [None]:
from graph import Hypercube
from hamiltonian import Ising
from sampler import Gibbs
from functools import partial
from observable import MagnetizationZSquare, CorrelationZ
from model.rbm.realpos import RBMRealPos
from learner import Learner
from logger import Logger

## Problem and model settings (values match the description above)
chain_length = 4    # number of sites in the open 1D chain
j_coupling = 2.0    # Ising coupling J
h_field = 1.0       # Ising field strength h
num_samples = 1000  # number of Gibbs samples
density = 2         # RBM hidden-unit density alpha
init_std = 0.01     # std of the normal initialiser for weights and biases

## Define the model
chain1d = Hypercube(length=chain_length, dimension=1, pbc=False, next_nearest=False)
ising = Ising(chain1d, j=j_coupling, h=h_field)
## Diagonalising the Hamiltonian provides the ground-state energy that is
## passed to the learner as `reference_energy` below.
ising.diagonalize()

## Define the sampler
gibbs = Gibbs(num_samples=num_samples)

## Define the neural networks model
## Zero-mean normal initialiser drawing from NumPy's global (seeded) RNG.
initializer = partial(np.random.normal, loc=0.0, scale=init_std)
## Size the RBM by the number of lattice sites, as the observable below does.
## For this 1D chain this equals chain1d.length; presumably the two differ
## for dimension > 1, where num_points is the correct site count.
rbm = RBMRealPos(chain1d.num_points, density=density, initializer=initializer)

## Define the observables
mz2 = MagnetizationZSquare(chain1d.num_points, chain1d.dimension)
obs = [mz2]

## Define hyperparameters for learner
learning_rate = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate)
stopping_threshold = 0.01  # epsilon_sigma in the description above
num_epochs = 1000          # maximum number of training epochs

## Training Process
learner = Learner(hamiltonian=ising, model=rbm, sampler=gibbs, optimizer=optimizer,
                  num_epochs=num_epochs, stopping_threshold=stopping_threshold,
                  observables=obs, reference_energy=ising.get_gs_energy())
learner.learn()

## Logging process
## Results are handed to Logger with base path 'result/' and subpath 'ising'.
result_path = 'result/'
subpath = 'ising'
logger = Logger(result_path, subpath)
logger.log(learner)

===== Training start
Epoch: 0, energy: -4.0400, std: 3.3990, std / mean: 0.8413, relerror: 0.41056, time: 0.16372
Epoch: 1, energy: -3.8920, std: 3.5083, std / mean: 0.9014, relerror: 0.43216, time: 0.01446
Epoch: 2, energy: -3.9079, std: 3.4626, std / mean: 0.8860, relerror: 0.42984, time: 0.01449
Epoch: 3, energy: -3.9242, std: 3.3594, std / mean: 0.8561, relerror: 0.42747, time: 0.01556
Epoch: 4, energy: -4.2158, std: 3.5878, std / mean: 0.8510, relerror: 0.38492, time: 0.01445
Epoch: 5, energy: -4.3039, std: 3.4864, std / mean: 0.8101, relerror: 0.37206, time: 0.01440
Epoch: 6, energy: -4.0482, std: 3.4942, std / mean: 0.8631, relerror: 0.40937, time: 0.01443
Epoch: 7, energy: -3.9318, std: 3.4883, std / mean: 0.8872, relerror: 0.42635, time: 0.01427
Epoch: 8, energy: -4.1921, std: 3.3951, std / mean: 0.8099, relerror: 0.38838, time: 0.01437
Epoch: 9, energy: -4.2518, std: 3.4232, std / mean: 0.8051, relerror: 0.37967, time: 0.01420
Epoch: 10, energy: -3.9884, std: 3.3901, std / me

Epoch: 89, energy: -4.8997, std: 3.2291, std / mean: 0.6590, relerror: 0.28514, time: 0.01478
Epoch: 90, energy: -4.9374, std: 3.2174, std / mean: 0.6516, relerror: 0.27964, time: 0.01438
Epoch: 91, energy: -4.8701, std: 3.2026, std / mean: 0.6576, relerror: 0.28946, time: 0.01414
Epoch: 92, energy: -4.6250, std: 3.2668, std / mean: 0.7063, relerror: 0.32521, time: 0.01410
Epoch: 93, energy: -4.7456, std: 3.2310, std / mean: 0.6808, relerror: 0.30763, time: 0.01426
Epoch: 94, energy: -4.8747, std: 3.1526, std / mean: 0.6467, relerror: 0.28878, time: 0.01409
Epoch: 95, energy: -4.8164, std: 3.1762, std / mean: 0.6595, relerror: 0.29730, time: 0.01398
Epoch: 96, energy: -5.1002, std: 3.1170, std / mean: 0.6112, relerror: 0.25589, time: 0.01396
Epoch: 97, energy: -5.1193, std: 3.2122, std / mean: 0.6275, relerror: 0.25310, time: 0.01428
Epoch: 98, energy: -4.9184, std: 3.1117, std / mean: 0.6327, relerror: 0.28241, time: 0.01407
Epoch: 99, energy: -5.0257, std: 3.2049, std / mean: 0.6377,

Epoch: 176, energy: -6.6833, std: 1.1468, std / mean: 0.1716, relerror: 0.02491, time: 0.01463
Epoch: 177, energy: -6.6276, std: 1.2501, std / mean: 0.1886, relerror: 0.03304, time: 0.01411
Epoch: 178, energy: -6.6270, std: 1.1820, std / mean: 0.1784, relerror: 0.03312, time: 0.01431
Epoch: 179, energy: -6.6251, std: 1.2144, std / mean: 0.1833, relerror: 0.03341, time: 0.01437
Epoch: 180, energy: -6.6509, std: 1.1785, std / mean: 0.1772, relerror: 0.02965, time: 0.01444
Epoch: 181, energy: -6.6953, std: 1.0726, std / mean: 0.1602, relerror: 0.02317, time: 0.01409
Epoch: 182, energy: -6.6349, std: 1.1388, std / mean: 0.1716, relerror: 0.03198, time: 0.01411
Epoch: 183, energy: -6.6960, std: 1.0327, std / mean: 0.1542, relerror: 0.02306, time: 0.01419
Epoch: 184, energy: -6.5840, std: 1.1771, std / mean: 0.1788, relerror: 0.03940, time: 0.01445
Epoch: 185, energy: -6.6398, std: 1.1534, std / mean: 0.1737, relerror: 0.03126, time: 0.01402
Epoch: 186, energy: -6.6722, std: 1.1132, std / me

Epoch: 262, energy: -6.7231, std: 0.6476, std / mean: 0.0963, relerror: 0.01911, time: 0.01524
Epoch: 263, energy: -6.7359, std: 0.7025, std / mean: 0.1043, relerror: 0.01724, time: 0.01420
Epoch: 264, energy: -6.7428, std: 0.7021, std / mean: 0.1041, relerror: 0.01623, time: 0.01402
Epoch: 265, energy: -6.7328, std: 0.6430, std / mean: 0.0955, relerror: 0.01769, time: 0.01385
Epoch: 266, energy: -6.7592, std: 0.6948, std / mean: 0.1028, relerror: 0.01385, time: 0.01409
Epoch: 267, energy: -6.7335, std: 0.7382, std / mean: 0.1096, relerror: 0.01760, time: 0.01441
Epoch: 268, energy: -6.7804, std: 0.7055, std / mean: 0.1040, relerror: 0.01075, time: 0.01424
Epoch: 269, energy: -6.7321, std: 0.7569, std / mean: 0.1124, relerror: 0.01780, time: 0.01421
Epoch: 270, energy: -6.7079, std: 0.7460, std / mean: 0.1112, relerror: 0.02133, time: 0.01426
Epoch: 271, energy: -6.7466, std: 0.6891, std / mean: 0.1021, relerror: 0.01567, time: 0.01432
Epoch: 272, energy: -6.7469, std: 0.6964, std / me

Epoch: 349, energy: -6.7507, std: 0.5821, std / mean: 0.0862, relerror: 0.01508, time: 0.01472
Epoch: 350, energy: -6.7509, std: 0.7055, std / mean: 0.1045, relerror: 0.01506, time: 0.01419
Epoch: 351, energy: -6.7239, std: 0.6869, std / mean: 0.1022, relerror: 0.01898, time: 0.01402
Epoch: 352, energy: -6.7806, std: 0.5309, std / mean: 0.0783, relerror: 0.01072, time: 0.01402
Epoch: 353, energy: -6.7211, std: 0.6631, std / mean: 0.0987, relerror: 0.01940, time: 0.01404
Epoch: 354, energy: -6.7370, std: 0.6807, std / mean: 0.1010, relerror: 0.01709, time: 0.01418
Epoch: 355, energy: -6.7665, std: 0.6276, std / mean: 0.0928, relerror: 0.01278, time: 0.01425
Epoch: 356, energy: -6.7420, std: 0.6821, std / mean: 0.1012, relerror: 0.01636, time: 0.01401
Epoch: 357, energy: -6.7474, std: 0.6985, std / mean: 0.1035, relerror: 0.01557, time: 0.01412
Epoch: 358, energy: -6.7360, std: 0.5884, std / mean: 0.0874, relerror: 0.01723, time: 0.01414
Epoch: 359, energy: -6.7609, std: 0.5705, std / me

Epoch: 435, energy: -6.7834, std: 0.5305, std / mean: 0.0782, relerror: 0.01031, time: 0.01444
Epoch: 436, energy: -6.7637, std: 0.5534, std / mean: 0.0818, relerror: 0.01319, time: 0.01398
Epoch: 437, energy: -6.7639, std: 0.4912, std / mean: 0.0726, relerror: 0.01316, time: 0.01422
Epoch: 438, energy: -6.7305, std: 0.6346, std / mean: 0.0943, relerror: 0.01802, time: 0.01406
Epoch: 439, energy: -6.7610, std: 0.4885, std / mean: 0.0722, relerror: 0.01358, time: 0.01393
Epoch: 440, energy: -6.7960, std: 0.4722, std / mean: 0.0695, relerror: 0.00848, time: 0.01398
Epoch: 441, energy: -6.7642, std: 0.4830, std / mean: 0.0714, relerror: 0.01311, time: 0.01413
Epoch: 442, energy: -6.7611, std: 0.5628, std / mean: 0.0832, relerror: 0.01357, time: 0.01429
Epoch: 443, energy: -6.7492, std: 0.5683, std / mean: 0.0842, relerror: 0.01530, time: 0.01406
Epoch: 444, energy: -6.7426, std: 0.5226, std / mean: 0.0775, relerror: 0.01627, time: 0.01408
Epoch: 445, energy: -6.7459, std: 0.6108, std / me

Epoch: 524, energy: -6.7691, std: 0.4866, std / mean: 0.0719, relerror: 0.01240, time: 0.03491
Epoch: 525, energy: -6.7894, std: 0.4429, std / mean: 0.0652, relerror: 0.00944, time: 0.02957
Epoch: 526, energy: -6.7684, std: 0.5031, std / mean: 0.0743, relerror: 0.01250, time: 0.02604
Epoch: 527, energy: -6.7474, std: 0.5957, std / mean: 0.0883, relerror: 0.01557, time: 0.02311
Epoch: 528, energy: -6.7520, std: 0.6269, std / mean: 0.0928, relerror: 0.01489, time: 0.02101
Epoch: 529, energy: -6.7443, std: 0.6028, std / mean: 0.0894, relerror: 0.01601, time: 0.01951
Epoch: 530, energy: -6.7409, std: 0.5491, std / mean: 0.0815, relerror: 0.01651, time: 0.01899
Epoch: 531, energy: -6.7828, std: 0.4567, std / mean: 0.0673, relerror: 0.01040, time: 0.01800
Epoch: 532, energy: -6.7641, std: 0.5343, std / mean: 0.0790, relerror: 0.01313, time: 0.01728
Epoch: 533, energy: -6.7763, std: 0.4719, std / mean: 0.0696, relerror: 0.01135, time: 0.01652
Epoch: 534, energy: -6.7580, std: 0.5224, std / me

Epoch: 615, energy: -6.7272, std: 0.6407, std / mean: 0.0952, relerror: 0.01851, time: 0.01810
Epoch: 616, energy: -6.7696, std: 0.5685, std / mean: 0.0840, relerror: 0.01233, time: 0.01689
Epoch: 617, energy: -6.7899, std: 0.4788, std / mean: 0.0705, relerror: 0.00935, time: 0.01610
Epoch: 618, energy: -6.7235, std: 0.6663, std / mean: 0.0991, relerror: 0.01905, time: 0.01543
Epoch: 619, energy: -6.7568, std: 0.5968, std / mean: 0.0883, relerror: 0.01419, time: 0.01530
Epoch: 620, energy: -6.7864, std: 0.5717, std / mean: 0.0842, relerror: 0.00987, time: 0.01490
Epoch: 621, energy: -6.7778, std: 0.6038, std / mean: 0.0891, relerror: 0.01112, time: 0.01467
Epoch: 622, energy: -6.7640, std: 0.5912, std / mean: 0.0874, relerror: 0.01315, time: 0.01415
Epoch: 623, energy: -6.7409, std: 0.5447, std / mean: 0.0808, relerror: 0.01651, time: 0.01412
Epoch: 624, energy: -6.7314, std: 0.6640, std / mean: 0.0986, relerror: 0.01790, time: 0.01379
Epoch: 625, energy: -6.7464, std: 0.6082, std / me

Epoch: 702, energy: -6.7626, std: 0.4935, std / mean: 0.0730, relerror: 0.01334, time: 0.01454
Epoch: 703, energy: -6.7653, std: 0.5940, std / mean: 0.0878, relerror: 0.01295, time: 0.01397
Epoch: 704, energy: -6.7887, std: 0.5184, std / mean: 0.0764, relerror: 0.00954, time: 0.01412
Epoch: 705, energy: -6.7754, std: 0.5647, std / mean: 0.0834, relerror: 0.01148, time: 0.01405
Epoch: 706, energy: -6.7729, std: 0.5701, std / mean: 0.0842, relerror: 0.01185, time: 0.01407
Epoch: 707, energy: -6.7626, std: 0.5938, std / mean: 0.0878, relerror: 0.01334, time: 0.01406
Epoch: 708, energy: -6.7757, std: 0.5305, std / mean: 0.0783, relerror: 0.01144, time: 0.01399
Epoch: 709, energy: -6.7798, std: 0.5319, std / mean: 0.0785, relerror: 0.01084, time: 0.01402
Epoch: 710, energy: -6.7650, std: 0.5794, std / mean: 0.0856, relerror: 0.01299, time: 0.01416
Epoch: 711, energy: -6.7538, std: 0.6037, std / mean: 0.0894, relerror: 0.01462, time: 0.01404
Epoch: 712, energy: -6.7394, std: 0.6532, std / me

Epoch: 789, energy: -6.7867, std: 0.6214, std / mean: 0.0916, relerror: 0.00982, time: 0.01461
Epoch: 790, energy: -6.7738, std: 0.5006, std / mean: 0.0739, relerror: 0.01171, time: 0.01436
Epoch: 791, energy: -6.7771, std: 0.6731, std / mean: 0.0993, relerror: 0.01123, time: 0.01415
Epoch: 792, energy: -6.7898, std: 0.6801, std / mean: 0.1002, relerror: 0.00937, time: 0.01439
Epoch: 793, energy: -6.7586, std: 0.6484, std / mean: 0.0959, relerror: 0.01393, time: 0.01409
Epoch: 794, energy: -6.8051, std: 0.4329, std / mean: 0.0636, relerror: 0.00714, time: 0.01427
Epoch: 795, energy: -6.7816, std: 0.4669, std / mean: 0.0688, relerror: 0.01057, time: 0.01406
Epoch: 796, energy: -6.7663, std: 0.6107, std / mean: 0.0903, relerror: 0.01280, time: 0.01403
Epoch: 797, energy: -6.7310, std: 0.6179, std / mean: 0.0918, relerror: 0.01796, time: 0.01403
Epoch: 798, energy: -6.7706, std: 0.5153, std / mean: 0.0761, relerror: 0.01218, time: 0.01394
Epoch: 799, energy: -6.7822, std: 0.5504, std / me

Epoch: 875, energy: -6.7913, std: 0.5343, std / mean: 0.0787, relerror: 0.00915, time: 0.01432
Epoch: 876, energy: -6.7564, std: 0.5793, std / mean: 0.0857, relerror: 0.01425, time: 0.01419
Epoch: 877, energy: -6.7992, std: 0.5096, std / mean: 0.0750, relerror: 0.00801, time: 0.01432
Epoch: 878, energy: -6.7810, std: 0.6099, std / mean: 0.0899, relerror: 0.01066, time: 0.01398
Epoch: 879, energy: -6.7613, std: 0.6124, std / mean: 0.0906, relerror: 0.01354, time: 0.01398
Epoch: 880, energy: -6.7862, std: 0.4997, std / mean: 0.0736, relerror: 0.00990, time: 0.01409
Epoch: 881, energy: -6.7704, std: 0.6347, std / mean: 0.0938, relerror: 0.01221, time: 0.01401
Epoch: 882, energy: -6.7928, std: 0.5367, std / mean: 0.0790, relerror: 0.00894, time: 0.01416
Epoch: 883, energy: -6.7828, std: 0.5131, std / mean: 0.0756, relerror: 0.01040, time: 0.01412
Epoch: 884, energy: -6.7843, std: 0.5223, std / mean: 0.0770, relerror: 0.01019, time: 0.01429
Epoch: 885, energy: -6.7914, std: 0.6802, std / me