# Model Training

In [None]:
import os
import sys
sys.path.append('../../src')

%load_ext autoreload
%autoreload 2
%load_ext autotime

from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

matplotlib.rcParams.update({"font.size": 14})

from models import sqRBM_em
from models import RBM_em
from utils import (
    Discretizer,
    kl_divergence,
    get_project_dir, # returns: Path object of the project directory.
    get_rng, # returns: Numpy RandomState object.
    lr_exp_decay, # returns: The learning rate scaling factor
    load_params,
)

### Define callback functions

In [None]:
def callback_em(model):
    """Collect diagnostics for an EM-trained model.

    Evaluates the model distribution on ``model.prob_data`` and reports both the
    quantum relative entropy returned by ``model.compute_qre`` and the smoothed
    KL divergence ``sum_i p_data[i] * (log(p_data[i] + 1e-16) - log(p_model[i] + 1e-16))``.

    :param model: Object exposing ``prob_data`` (mapping state -> probability),
        ``compute_p_model(prob_data)`` returning ``(p_model, Z)``, and
        ``compute_qre(Z)``.

    :returns: dict with keys ``"qre"``, ``"kld"`` and a human-readable ``"print"``.
    """
    eps = 1e-16  # keeps log() finite for zero-probability states
    target = model.prob_data
    p_model, Z = model.compute_p_model(target)

    # NOTE(review): Z is presumably the partition function; confirm in the model class.
    qre = model.compute_qre(Z)

    kld = sum(
        target[state] * (np.log(target[state] + eps) - np.log(p_model[state] + eps))
        for state in target.keys()
    )

    return {"qre": qre, "kld": kld, "print": f"qre = {qre}, kld = {kld}"}

def callback_gd(model):
    """Collect diagnostics for a gradient-descent-trained model.

    Reports the smoothed KL divergence between ``model.prob_data`` and the
    model distribution:
    ``sum_i p_data[i] * (log(p_data[i] + 1e-16) - log(p_model[i] + 1e-16))``.

    :param model: Object exposing ``prob_data`` (mapping state -> probability)
        and ``compute_p_model(prob_data)`` returning ``(p_model, Z)``.

    :returns: dict with keys ``"kld"`` and a human-readable ``"print"``.
    """
    eps = 1e-16  # keeps log() finite for zero-probability states
    target = model.prob_data
    p_model, _Z = model.compute_p_model(target)

    kld = sum(
        target[state] * (np.log(target[state] + eps) - np.log(p_model[state] + eps))
        for state in target.keys()
    )

    return {"kld": kld, "print": f"kld = {kld}"}

## Experiments

In [None]:
from scipy.spatial import distance

def _binary_to_eigen(x):
    """
    Map bits to spin eigenvalues: 0 -> +1 and 1 -> -1.

    :param x: Array of values in {0, 1}.

    :returns: Array of the same shape with values in {+1, -1}, dtype ``int8``.
    """
    spins = 1 - 2 * x
    return spins.astype(np.int8)

def bernoulli(p, v, M, n_visible, seed):
    """Monte Carlo estimate of a Bernoulli-noise weight for bit vector ``v``.

    Draws ``M`` uniformly random bit strings ``s`` of length ``n_visible`` and
    averages ``p**(n_visible - d) * (1 - p)**d``, where ``d`` is the Hamming
    distance (number of differing positions) between ``v`` and ``s``.

    :param p: Per-bit probability weight.
    :param v: Target bit vector (iterable of 0/1).
    :param M: Number of random samples; must be > 0.
    :param n_visible: Length of each sampled bit string.
    :param seed: Seed passed to ``np.random.seed``.

    :returns: The sample mean of the weights.

    NOTE: this reseeds NumPy's *global* RNG, so it changes the random state
    seen by any code that draws from ``np.random`` afterwards.
    """
    np.random.seed(seed)
    target_bits = list(v)
    width = len(v)

    def _weight(sample):
        # scipy's hamming() returns the *fraction* of differing positions.
        flips = distance.hamming(target_bits, list(sample)) * width
        return (p ** (n_visible - flips)) * ((1 - p) ** flips)

    # Lazy generator: samples are drawn in the same order as they are weighted.
    samples = (np.random.randint(2, size=n_visible) for _ in range(M))
    return sum(_weight(sample) for sample in samples) / M

def generate_bitstrings(n):
    """Enumerate all n-bit strings in increasing integer order.

    :param n: Number of bits.

    :returns: List of ``2**n`` lists of ints (0/1), most significant bit first.
    """
    pattern = f'0{n}b'
    return [[int(digit) for digit in format(value, pattern)] for value in range(2**n)]

def even_parity_probability_distribution(n):
    """Uniform distribution over n-bit strings with an even number of ones.

    :param n: Number of bits.

    :returns: List of ``2**n`` probabilities indexed like ``generate_bitstrings(n)``;
        even-parity strings share equal mass and all others get 0.
    """
    indicators = [1 if sum(bits) % 2 == 0 else 0 for bits in generate_bitstrings(n)]
    support_size = sum(indicators)  # >= 1: the all-zero string is always even
    return [flag / support_size for flag in indicators]

def cardinality_distribution(n):
    """Uniform distribution over n-bit strings containing exactly ``n // 2`` ones.

    :param n: Number of bits (for odd ``n`` the target count is rounded down).

    :returns: List of ``2**n`` probabilities indexed like ``generate_bitstrings(n)``.
    """
    target_ones = n // 2
    indicators = [int(sum(bits) == target_ones) for bits in generate_bitstrings(n)]
    support_size = sum(indicators)
    return [flag / support_size for flag in indicators]

def on2_distribution(n, seed=0):
    """Uniform distribution over ``n**2`` randomly chosen n-bit states.

    A random permutation of the ``2**n`` state indices is drawn and the first
    ``n**2`` of them receive equal probability; all other states get 0.

    When ``n**2`` exceeds ``2**n`` (this happens only for n = 3 among n >= 1),
    every state is chosen. The mass is then normalised by the real support size,
    so the result still sums to 1. Previously each entry got ``1 / n**2``, giving
    a total of 8/9 for n = 3.

    :param n: Number of bits; must be >= 1 (n = 0 has an empty support).
    :param seed: Seed passed to ``np.random.seed``.

    :returns: NumPy array of length ``2**n``.

    NOTE: this reseeds NumPy's *global* RNG, so it changes the random state
    seen by any code that draws from ``np.random`` afterwards.
    """
    np.random.seed(seed)
    size = 2**n
    probs = np.zeros(size)

    perm = np.random.permutation(size)
    # Cannot choose more distinct states than exist.
    n_support = min(n**2, size)
    probs[perm[:n_support]] = 1.0 / n_support

    return probs

In [None]:
import pickle
import os

# select situations
situations = [101]

start = 1
end = 100


def _build_prob_data(data_type, n_visible):
    """Return the target distribution as a dict ``{state_index: probability}``.

    :param data_type: One of 'bernoulli', 'even_parity', 'cardinality', 'on2'.
    :param n_visible: Number of visible units; the dict has ``2**n_visible`` keys.

    :raises ValueError: If ``data_type`` is not recognised. Previously an
        unknown type silently produced an empty distribution.
    """
    n_states = 2 ** n_visible
    if data_type == 'bernoulli':
        # bernoulli() depends on each state's bit pattern, so it is evaluated per state.
        return {
            i: bernoulli(0.9, np.array(Discretizer.int_to_bit_vector(i, n_visible)), 8, n_visible, seed=0)
            for i in range(n_states)
        }

    # These generators return the whole distribution in one call. Build it once
    # instead of once per state; the old per-state calls were quadratic in the
    # number of states. on2_distribution reseeds the global RNG with seed=0 on
    # every call, so calling it once leaves np.random in the same state as the
    # repeated calls did.
    if data_type == 'even_parity':
        probs = even_parity_probability_distribution(n_visible)
    elif data_type == 'cardinality':
        probs = cardinality_distribution(n_visible)
    elif data_type == 'on2':
        probs = on2_distribution(n_visible, seed=0)
    else:
        raise ValueError(f"unknown data_type: {data_type!r}")
    return {i: probs[i] for i in range(n_states)}


def _spin_expectation(prob_data, n_visible):
    """Return the expected spin vector (bits mapped by _binary_to_eigen) under prob_data.

    The result has shape ``(1, n_visible)``.
    """
    expected = np.zeros((1, n_visible))
    for i, p in prob_data.items():
        bits = np.array(Discretizer.int_to_bit_vector(i, n_visible)).reshape(1, -1)
        expected += p * _binary_to_eigen(bits)
    return expected


def _init_model(model_cls, prob_data, expected_value_V, n_visible, n_hidden,
                W_init, b_init, Gamma_init, seed):
    """Construct ``model_cls`` with the keyword arguments shared by all runs.

    The initial arrays are copied so that every model starts from exactly the
    same values, even if a model updates its parameters in place.
    NOTE(review): whether sqRBM_em / RBM_em mutate the arrays they receive is
    not visible here; the copies are harmless if they do not. A ``Gamma=1``
    keyword was previously left commented out for the sqRBM runs and is still
    not passed.
    """
    return model_cls(
        prob_data=prob_data,
        expected_value_V=expected_value_V,
        n_visible=n_visible,
        n_hidden=n_hidden,
        W_init=W_init.copy(),
        b_init=b_init.copy(),
        Gamma_init=Gamma_init.copy(),
        B_freeze=1,
        beta_initial=1,
        seed=seed,
    )


for num_situation in situations:

    seed = 0
    np.random.seed(seed)

    # model params
    params_path = f'../data/params/situation{num_situation}.json'
    params = load_params(params_path)
    print(params)
    n_visible = params["n_visible"]
    n_hidden = params["n_hidden"]
    n_qubits = n_visible + n_hidden
    n_epochs = params["n_epochs"]
    n_epochs_m = params["n_epochs_m"]
    learning_rate = params["learning_rate"]
    epsilon_em = params["epsilon_em"]
    epsilon_gd = params["epsilon_gd"]
    # seed = params["seed"]
    data_type = params["data_type"]

    # generate data distribution
    prob_data = _build_prob_data(data_type, n_visible)

    print('prob_data', prob_data)
    print('prob_data_sum', sum(prob_data.values()))

    expected_value_V = _spin_expectation(prob_data, n_visible)

    print('expected_value_V', expected_value_V)

    for time in range(start, end + 1):
        print('time', time)

        # These three draws consume the global RNG in this order; keep it fixed
        # for reproducibility across runs.
        W_init = (np.random.rand(n_visible, n_hidden) - 0.5) * 10
        print('W_init', W_init)
        b_init = (np.random.rand(n_qubits) - 0.5) * 10
        print('b_init', b_init)
        Gamma_init = np.concatenate((np.zeros(n_visible), (np.random.rand(n_hidden) - 0.5) * 10))
        print('Gamma_init', Gamma_init)

        # model training: four models from the same initial parameters
        model_sqRBM_em = _init_model(sqRBM_em, prob_data, expected_value_V, n_visible, n_hidden,
                                     W_init, b_init, Gamma_init, seed)
        model_sqRBM_em.train_em(
            n_epochs=n_epochs,
            n_epochs_m=n_epochs_m,
            learning_rate=learning_rate,
            epsilon=epsilon_em,
            callback=callback_em,
        )

        model_sqRBM_gd = _init_model(sqRBM_em, prob_data, expected_value_V, n_visible, n_hidden,
                                     W_init, b_init, Gamma_init, seed)
        model_sqRBM_gd.train_gd(
            n_epochs=n_epochs,
            learning_rate=learning_rate,
            epsilon=epsilon_gd,
            callback=callback_gd,
        )

        model_RBM_em = _init_model(RBM_em, prob_data, expected_value_V, n_visible, n_hidden,
                                   W_init, b_init, Gamma_init, seed)
        model_RBM_em.train_em(
            n_epochs=n_epochs,
            n_epochs_m=n_epochs_m,
            learning_rate=learning_rate,
            epsilon=epsilon_em,
            callback=callback_em,
        )

        model_RBM_gd = _init_model(RBM_em, prob_data, expected_value_V, n_visible, n_hidden,
                                   W_init, b_init, Gamma_init, seed)
        model_RBM_gd.train_gd(
            n_epochs=n_epochs,
            learning_rate=learning_rate,
            epsilon=epsilon_gd,
            callback=callback_gd,
        )

        # save results
        kld_em_quantum_epoch_m = [x["kld"] for x in model_sqRBM_em.callback_history_epoch_m]
        kld_emq_epoch_m = {'kld_emq_epoch_m': kld_em_quantum_epoch_m}
        kld_em_quantum_epoch = [x["kld"] for x in model_sqRBM_em.callback_history_epoch]
        kld_emq_epoch = {'kld_emq_epoch': kld_em_quantum_epoch}

        # qre_em_quantum_epoch_m = [x["qre"] for x in model_sqRBM_em.callback_history_epoch_m]
        # qre_emq_epoch_m = {'qre_emq_epoch_m': qre_em_quantum_epoch_m}
        # qre_em_quantum_epoch = [x["qre"] for x in model_sqRBM_em.callback_history_epoch]
        # qre_emq_epoch = {'qre_emq_epoch': qre_em_quantum_epoch}

        kld_gd_quantum = [x["kld"] for x in model_sqRBM_gd.callback_history]
        kld_gdq = {'kld_gdq': kld_gd_quantum}

        kld_em_classical_epoch_m = [x["kld"] for x in model_RBM_em.callback_history_epoch_m]
        kld_emc_epoch_m = {'kld_emc_epoch_m': kld_em_classical_epoch_m}
        kld_em_classical_epoch = [x["kld"] for x in model_RBM_em.callback_history_epoch]
        kld_emc_epoch = {'kld_emc_epoch': kld_em_classical_epoch}

        # qre_em_classical_epoch_m = [x["qre"] for x in model_RBM_em.callback_history_epoch_m]
        # qre_emc_epoch_m = {'qre_emc_epoch_m': qre_em_classical_epoch_m}
        # qre_em_classical_epoch = [x["qre"] for x in model_RBM_em.callback_history_epoch]
        # qre_emc_epoch = {'qre_emc_epoch': qre_em_classical_epoch}

        kld_gd_classical = [x["kld"] for x in model_RBM_gd.callback_history]
        kld_gdc = {'kld_gdc': kld_gd_classical}

        # One directory per (situation, run); each history is pickled to a file
        # named after its key (no extension, as before).
        save_dir = f'../data/output/results/situation{num_situation}/time{time}'
        os.makedirs(save_dir, exist_ok=True)
        for record in [kld_emq_epoch_m, kld_emq_epoch, kld_gdq, kld_emc_epoch_m, kld_emc_epoch, kld_gdc]:
            for name, values in record.items():
                with open(os.path.join(save_dir, name), 'wb') as f:
                    pickle.dump(values, f)