In [1]:
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
import h5py
from tenpy.tools import hdf5_io
from IPython.display import display
from ipywidgets import interact, FloatSlider, fixed

# Adjust these paths to your environment if needed.
# Root directories for the two result sets. The loaders below resolve both the
# per-run index (`data/<run folder>/database.json`) and every `outputFilename`
# listed in that index relative to these roots.
BASE_PATH_DMRG = Path("/home1/wttai/machine_learning/dmrg_qwz/")
BASE_PATH_NETKET = Path("/home1/wttai/machine_learning/netket_qwz/")



In [2]:
def load_dmrg_data(query_dict_dmrg, date_dmrg):
    """Load every DMRG result whose metadata matches ``query_dict_dmrg``.

    The index ``data/dmrg_qwz_{date_dmrg}/database.json`` under
    ``BASE_PATH_DMRG`` is read as a sequence of entries. An entry is kept when,
    for every key in the query, ``entry["metadata"].get(key)`` equals the
    queried value (a key missing from the metadata compares as ``None``).

    Args:
        query_dict_dmrg: Mapping of metadata key -> required value.
        date_dmrg: Suffix identifying the run folder of the index file.

    Returns:
        A list, in index order, of the objects returned by
        ``hdf5_io.load_from_hdf5`` for each matching entry's ``outputFilename``
        (resolved relative to ``BASE_PATH_DMRG``).
    """
    index_file = BASE_PATH_DMRG / f"data/dmrg_qwz_{date_dmrg}/database.json"

    def _matches(record):
        # Same test as all(meta.get(k) == v ...), written as an explicit scan.
        meta = record["metadata"]
        for key, wanted in query_dict_dmrg.items():
            if not (meta.get(key) == wanted):
                return False
        return True

    results = []
    with open(index_file, "r") as index_handle:
        for record in json.load(index_handle):
            if not _matches(record):
                continue
            result_path = BASE_PATH_DMRG / record["outputFilename"]
            with h5py.File(result_path, 'r') as hf:
                results.append(hdf5_io.load_from_hdf5(hf))
    return results

def load_netket_data(query_dict, date_netket):
    """Load every NetKet log whose metadata matches ``query_dict``.

    The index ``data/netket_qwz_{date_netket}/database.json`` under
    ``BASE_PATH_NETKET`` is read as a sequence of entries. An entry is kept
    when, for every key in the query, ``entry["metadata"].get(key)`` equals the
    queried value (a key missing from the metadata compares as ``None``).

    Args:
        query_dict: Mapping of metadata key -> required value.
        date_netket: Suffix identifying the run folder of the index file.

    Returns:
        A list, in index order, with the parsed JSON content of
        ``<outputFilename>.json`` (resolved relative to ``BASE_PATH_NETKET``)
        for each matching entry.
    """
    database_path = BASE_PATH_NETKET / f"data/netket_qwz_{date_netket}/database.json"

    # Read the index completely and close it before touching the run files.
    with open(database_path, "r") as f:
        database = json.load(f)

    data_to_plot = []
    for entry in database:
        metadata = entry["metadata"]
        if all(metadata.get(k) == v for k, v in query_dict.items()):
            path = BASE_PATH_NETKET / (entry["outputFilename"] + ".json")
            # Context manager guarantees the handle is closed; the previous
            # json.load(open(path)) left one open file per matching run.
            with open(path, "r") as log_file:
                data_to_plot.append(json.load(log_file))
    return data_to_plot

In [3]:
def plot_energies(data_dmrg, data_to_plot, title_energies, out = None):
    """Plot the NetKet energy traces together with the DMRG reference energy.

    Args:
        data_dmrg: Sequence of DMRG results; only ``data_dmrg[0]['data']['E0']``
            is read and drawn as a red dashed horizontal line labelled 'DMRG'.
        data_to_plot: Iterable of NetKet logs. Each contributes one curve of
            ``data['Energy']['Mean']['real']`` against ``data['Energy']['iters']``,
            labelled from its ``metadata`` (model, n_hidden, n_hidden_layers).
        title_energies: Axes title.
        out: Optional context manager (e.g. an output widget); when given, the
            figure is created and shown inside ``with out:``.
    """
    def _render():
        fig, ax = plt.subplots(figsize=(6,4))
        for run in data_to_plot:
            energy = run['data']["Energy"]
            iterations = energy["iters"]
            mean_real = energy['Mean']['real']
            meta = run['metadata']
            curve_label = f"{meta['model']}, n_hidden={meta['n_hidden']}, layers={meta['n_hidden_layers']}"
            ax.plot(iterations, mean_real, label=curve_label)
        reference_energy = data_dmrg[0]['data']['E0']
        ax.axhline(y=reference_energy, color='r', linestyle='--', label='DMRG')
        ax.set_title(title_energies)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Energy")
        ax.legend()
        plt.tight_layout()
        plt.show()

    # Without a target widget, draw into whatever output is currently active.
    if out is None:
        _render()
        return
    # Otherwise route everything the figure emits into the given widget.
    with out:
        _render()

def plot_energy_errors(data_dmrg, data_to_plot, title_relative_error, out = None):
    """Plot the relative energy error of each NetKet run w.r.t. the DMRG energy.

    For every run the curve ``-(E - E0) / E0`` is drawn against
    ``data['Energy']['iters']``, where ``E`` is ``data['Energy']['Mean']['real']``
    and ``E0`` is ``data_dmrg[0]['data']['E0']``.

    Args:
        data_dmrg: Sequence of DMRG results; only the first entry's ``E0`` is used.
        data_to_plot: Iterable of NetKet logs (one curve each, labelled from
            ``metadata``: model, n_hidden, n_hidden_layers).
        title_relative_error: Axes title.
        out: Optional context manager (e.g. an output widget); when given, the
            figure is created and shown inside ``with out:``.
    """
    def do_plot():
        fig, ax = plt.subplots(figsize=(6,4))
        dmrg_E0 = data_dmrg[0]['data']['E0']
        for data in data_to_plot:
            metadata = data['metadata']
            # Logs parsed from JSON hold plain lists. Convert explicitly so the
            # arithmetic is elementwise whether E0 is a NumPy scalar or a plain
            # Python float (list - float would raise TypeError).
            energies = np.asarray(data['data']["Energy"]['Mean']['real'])
            rel_err = -(energies - dmrg_E0)/dmrg_E0
            ax.plot(data['data']["Energy"]["iters"], rel_err,
                    label=f"{metadata['model']}, n_hidden={metadata['n_hidden']}, layers={metadata['n_hidden_layers']}")
        ax.set_title(title_relative_error)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Relative Error")
        ax.legend()
        plt.tight_layout()
        plt.show()

    # If out is given, capture the plot output into it; otherwise plot normally.
    if out is not None:
        with out:
            do_plot()
    else:
        do_plot()

def plot_correlations(data_dmrg, data_to_plot, L, N, U, m, pbc, n_samples, orbitals_nk='pp', orbitals_dmrg='sp', out = None):
    """Plot NetKet correlation traces on an L x L grid with DMRG reference lines.

    Panel ``(x, y)`` shows, for every run, ``data[orbitals_nk][i]['Mean']['real']``
    with ``i = x * L + y`` against its sample index, plus one red dashed
    horizontal line at ``corrs_results[orbitals_dmrg][x][y]`` taken from
    ``data_dmrg[0]``. A single figure-level legend lists each distinct run
    label once, followed by the DMRG reference.

    Args:
        data_dmrg: Sequence of DMRG results; only the first entry is used.
        data_to_plot: Iterable of NetKet logs.
        L: Grid side length; ``L**2`` panels are drawn.
        N, U, m, pbc, n_samples: Only used to build the figure title.
        orbitals_nk: Key of the correlation data inside each NetKet log.
        orbitals_dmrg: Key of the correlation data inside the DMRG result.
        out: Optional context manager (e.g. an output widget); when given, the
            figure is created and shown inside ``with out:``.
    """
    def do_plot():
        n_corrs = L**2

        # squeeze=False keeps `axs` two-dimensional, so axs[x][y] is valid for
        # L == 1 as well (by default a 1x1 grid is returned as a bare Axes).
        fig, axs = plt.subplots(L, L, figsize=(12, 12), sharey=True, squeeze=False)
        lines = []
        labels = []

        corrs_dmrg = data_dmrg[0]['data']['corrs_results'][orbitals_dmrg]

        for data in data_to_plot:
            metadata = data['metadata']
            label = f"{metadata['model']}, n_hidden={metadata['n_hidden']}, layers={metadata['n_hidden_layers']}"
            for i in range(n_corrs):
                x, y = divmod(i, L)
                values = data["data"][orbitals_nk][i]['Mean']['real']
                # x axis sized per trace, so runs of different length can be
                # shown together instead of failing on a shape mismatch.
                line, = axs[x][y].plot(np.arange(len(values)), values, label=label)

                # Keep one legend handle per distinct run label.
                if label not in labels:
                    lines.append(line)
                    labels.append(label)

        # The reference does not depend on the run: draw it once per panel
        # (after the traces, preserving the original draw order).
        dmrg_line = None
        for i in range(n_corrs):
            x, y = divmod(i, L)
            dmrg_line = axs[x][y].axhline(y=corrs_dmrg[x][y], color='r', linestyle='--', label='DMRG')
            axs[x][y].set_title(f"({x}, {y})")
            axs[x][y].set_xlabel('Iteration')

        # The legend is built from explicit handles, so the reference line has
        # to be added explicitly or its 'DMRG' label is never displayed.
        if dmrg_line is not None:
            lines.append(dmrg_line)
            labels.append('DMRG')

        title = f"{orbitals_nk} correlations, L={L}, N={N}, U={U}, m={m}, pbc={pbc}, n_samples={n_samples}"
        fig.suptitle(title, fontsize=16)
        fig.legend(lines, labels, loc='center right', title="Legend", bbox_to_anchor=(1.25, 0.5))
        plt.tight_layout()
        plt.show()

    # If out is given, capture the plot output into it; otherwise plot normally.
    if out is not None:
        with out:
            do_plot()
    else:
        do_plot()
        
def plot_correlation_errors(data_dmrg, data_to_plot, L, N, U, m, pbc, n_samples, orbitals_nk='pp', orbitals_dmrg='sp', out = None):
    """Plot relative errors of NetKet correlations w.r.t. DMRG on an L x L grid.

    Panel ``(x, y)`` shows, for every run, ``|(c - c_ref) / c_ref|`` where
    ``c`` is ``data[orbitals_nk][i]['Mean']['real']`` with ``i = x * L + y`` and
    ``c_ref`` is ``corrs_results[orbitals_dmrg][x][y]`` from ``data_dmrg[0]``.
    A reference value of exactly zero makes the quotient non-finite.

    Args:
        data_dmrg: Sequence of DMRG results; only the first entry is used.
        data_to_plot: Iterable of NetKet logs.
        L: Grid side length; ``L**2`` panels are drawn.
        N, U, m, pbc, n_samples: Only used to build the figure title.
        orbitals_nk: Key of the correlation data inside each NetKet log.
        orbitals_dmrg: Key of the correlation data inside the DMRG result.
        out: Optional context manager (e.g. an output widget); when given, the
            figure is created and shown inside ``with out:``.
    """
    def do_plot():
        n_corrs = L**2

        # squeeze=False keeps `axs` two-dimensional, so axs[x][y] is valid for
        # L == 1 as well (by default a 1x1 grid is returned as a bare Axes).
        fig, axs = plt.subplots(L, L, figsize=(12, 12), sharey=True, squeeze=False)
        lines = []
        labels = []

        corrs_dmrg = data_dmrg[0]['data']['corrs_results'][orbitals_dmrg]

        for data in data_to_plot:
            metadata = data['metadata']
            label = f"{metadata['model']}, n_hidden={metadata['n_hidden']}, layers={metadata['n_hidden_layers']}"
            for i in range(n_corrs):
                x, y = divmod(i, L)
                reference = corrs_dmrg[x][y]
                # Logs parsed from JSON hold plain lists; convert explicitly so
                # the arithmetic is elementwise for any scalar reference type.
                values = np.asarray(data["data"][orbitals_nk][i]['Mean']['real'])
                rel_err = np.abs((values - reference) / reference)
                # x axis sized per trace, so runs of different length can be
                # shown together instead of failing on a shape mismatch.
                line, = axs[x][y].plot(np.arange(len(values)), rel_err, label=label)
                axs[x][y].set_title(f"({x}, {y})")
                axs[x][y].set_xlabel('Iteration')

                # Keep one legend handle per distinct run label.
                if label not in labels:
                    lines.append(line)
                    labels.append(label)

        title = f"Relative error, {orbitals_nk} correlations, L={L}, N={N}, U={U}, m={m}, pbc={pbc}, n_samples={n_samples}"
        fig.suptitle(title, fontsize=16)
        fig.legend(lines, labels, loc='center right', title="Legend", bbox_to_anchor=(1.25, 0.5))
        plt.tight_layout()
        plt.show()

    # If out is given, capture the plot output into it; otherwise plot normally.
    if out is not None:
        with out:
            do_plot()
    else:
        do_plot()


In [4]:
from ipywidgets import Output, HBox

# Four independent output areas, one per figure produced by `visualize`
# (energies, energy errors, correlations, correlation errors).
out1, out2, out3, out4 = (Output() for _ in range(4))

def visualize(L=4, N=8, U=8.0, m=3.5, pbc=True, n_samples=32768, orbitals = "ss", date_dmrg="20241217_07", date_netket="20241217_05"):
    """Load the matching DMRG/NetKet results and redraw the four comparison plots.

    The DMRG query uses ``L, N, t=1.0, U, m, pbc``; the NetKet query uses the
    same keys plus ``n_samples``. Energies and energy errors are drawn into
    ``out1``/``out2``, correlations and their relative errors into
    ``out3``/``out4``. When either query matches nothing, the four outputs are
    emptied and "No matching data found." is printed.

    Args:
        L, N, U, m, pbc: Model parameters used in both queries and plot titles.
        n_samples: Extra NetKet query key; also shown in the plot titles.
        orbitals: NetKet correlation key. The DMRG key is derived from it:
            "ss" -> "ss", "pp" -> "sp", anything else -> "pp".
        date_dmrg: Run-folder suffix passed to ``load_dmrg_data``.
        date_netket: Run-folder suffix passed to ``load_netket_data``.
    """
    query_dict_dmrg = {"L": L, "N": N, "t": 1.0, "U": U, "m": m, "pbc": pbc}
    query_dict_netket = {"L": L, "N": N, "t": 1.0, "U": U, "m": m, "pbc": pbc, "n_samples": n_samples}
    data_dmrg = load_dmrg_data(query_dict_dmrg, date_dmrg)
    data_to_plot = load_netket_data(query_dict_netket, date_netket)
    has_data = len(data_dmrg) > 0 and len(data_to_plot) > 0

    # wait=True postpones the clearing until replacement output arrives, which
    # avoids flicker while redrawing. When nothing will be drawn, the stale
    # figures would stay on screen forever, so clear immediately in that case.
    for widget in (out1, out2, out3, out4):
        widget.clear_output(wait=has_data)

    if not has_data:
        print("No matching data found.")
        return

    title_energies = f"L={L}, N={N}, U={U}, m={m}, pbc={pbc}, n_samples={n_samples}"
    title_relative_error = f"Relative error in E, L={L}, N={N}, U={U}, m={m}, pbc={pbc}, n_samples={n_samples}"
    plot_energies(data_dmrg, data_to_plot, title_energies, out=out1)
    plot_energy_errors(data_dmrg, data_to_plot, title_relative_error, out=out2)

    # The two data sets use different keys for the same selection.
    # NOTE(review): "pp" <-> "sp" are swapped between NetKet and DMRG here;
    # presumably a labelling difference between the two pipelines - confirm.
    orbitals_nk = orbitals
    orbitals_dmrg = {"ss": "ss", "pp": "sp"}.get(orbitals, "pp")
    print(f"Orbitals: {orbitals_nk} (NetKet), {orbitals_dmrg} (DMRG)")
    plot_correlations(data_dmrg, data_to_plot, L, N, U, m, pbc, n_samples, orbitals_nk, orbitals_dmrg, out=out3)
    plot_correlation_errors(data_dmrg, data_to_plot, L, N, U, m, pbc, n_samples, orbitals_nk, orbitals_dmrg, out=out4)

In [5]:
from ipywidgets import interact, Dropdown, fixed

# Define the discrete options for U and m
# (these lists are the choices offered by the U and m dropdowns below)
U_values = [0.0, 1.0, 2.0, 4.0, 8.0]
m_values = [1.0, 2.0, 3.0, 3.5, 4.5, 5.0]

# Choices offered by the orbitals dropdown below
orbitals = ["ss", "sp", "pp"]


# Only U, m and orbitals get a dropdown; every other argument of `visualize`
# is pinned with `fixed(...)`. The trailing `;` suppresses the cell's repr of
# the value returned by `interact`.
interact(
    visualize,
    L=fixed(4),
    N=fixed(8),
    U=Dropdown(options=U_values, value=8.0, description='U:'),
    m=Dropdown(options=m_values, value=3.5, description='m:'),
    pbc=fixed(True),
    n_samples=fixed(32768),
    date_dmrg=fixed("20241217_07"),
    date_netket=fixed("20241217_05"),
    orbitals=Dropdown(options=orbitals, value="ss", description="Orbitals:")
);

# Show the output areas that `visualize` draws into, two per row:
# out1/out2 (energy plots) first, then out3/out4 (correlation plots).
display(HBox([out1, out2]))
display(HBox([out3, out4]))

interactive(children=(Dropdown(description='U:', index=4, options=(0.0, 1.0, 2.0, 4.0, 8.0), value=8.0), Dropd…

HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))