In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def plot_heatdis(data_list, shape2, title, cmap_name, figname,
                 vmin=0, vmax=100, save=False):
    """Plot heat-distribution fields side by side with one shared colorbar.

    Every entry of ``data_list`` is drawn as an image. Each entry after the
    first is also compared with ``data_list[0]`` (the reference), and its max
    absolute error and MSE are printed.

    Parameters
    ----------
    data_list : sequence of 2-D arrays
        Fields to plot; ``data_list[0]`` is the reference. The first four
        panels are labelled Reference, DOC, prePred, postPred; any extra
        panels get a generic "Panel <i>" label.
    shape2 : sequence of two ints
        (rows, cols) used as the image extent ``[0, cols] x [0, rows]``.
    title : str
        Figure super-title.
    cmap_name : str
        Matplotlib colormap name.
    figname : str
        Output path; only used when ``save`` is True.
    vmin, vmax : float, optional
        Color-scale limits shared by all panels (default 0 and 100).
    save : bool, optional
        If True, write the figure to ``figname`` before showing it.

    Raises
    ------
    ValueError
        If ``data_list`` is empty.
    """
    m = len(data_list)
    if m == 0:
        raise ValueError("plot_heatdis: data_list is empty")
    figwidth = 4

    title_list = ["Reference", "DOC", "prePred", "postPred"]
    extent = [0, shape2[1], 0, shape2[0]]
    cmap = plt.get_cmap(cmap_name)
    reference = data_list[0]

    gs = gridspec.GridSpec(1, m, width_ratios=[1] * m, wspace=0.5)
    fig = plt.figure(figsize=(figwidth * m, figwidth))

    axs = []
    for i, data in enumerate(data_list):
        label = title_list[i] if i < len(title_list) else "Panel {}".format(i)
        ax = fig.add_subplot(gs[0, i])
        im = ax.imshow(
            data,
            origin="lower",
            interpolation="quadric",
            extent=extent,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
        if i != 0:
            diff = data - reference
            # "{:.4e}" = 4 digits after the decimal point; the previous "{:4e}"
            # only set a minimum field width.
            print("{}: max error = {}, MSE = {:.4e}".format(
                label, np.max(np.abs(diff)), np.mean(diff ** 2)))
        # Outline the bottom and left edges of the domain in black.
        ax.plot([0, shape2[1]], [0, 0], color="black")
        ax.plot([0, 0], [0, shape2[0]], color="black")
        ax.set_title(label, fontsize=16)
        axs.append(ax)

    # All panels share vmin/vmax, so the last image is representative.
    cbar = fig.colorbar(im, ax=axs, orientation="vertical", fraction=0.05, pad=0.05)
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label('temperature', fontsize=12)
    fig.suptitle(title)

    if save:
        fig.savefig(figname)
    plt.show()
    # Release the figure so repeated calls (e.g. one per time step) do not
    # accumulate open figures.
    plt.close(fig)
    
def plot_heatdis_error(data_list, shape2, title, cmap_name, figname,
                       vlim=1e-4, save=False):
    """Plot the signed error of each field against the reference.

    One panel is drawn per field after the reference. The panel for
    ``data_list[k]`` shows ``data_list[k] - data_list[0]`` on a symmetric
    color scale ``[-vlim, vlim]``. Its max absolute error and MSE are
    printed, labelled with the name of the field being compared (not the
    reference).

    Parameters
    ----------
    data_list : sequence of 2-D arrays
        ``data_list[0]`` is the reference; every later entry is compared
        against it. Labels follow Reference, DOC, prePred, postPred; extra
        entries get a generic "Panel <k>" label.
    shape2 : sequence of two ints
        (rows, cols) used as the image extent ``[0, cols] x [0, rows]``.
    title : str
        Figure super-title.
    cmap_name : str
        Matplotlib colormap name (a diverging map suits signed errors).
    figname : str
        Output path; only used when ``save`` is True.
    vlim : float, optional
        Half-width of the symmetric color scale (default 1e-4).
    save : bool, optional
        If True, write the figure to ``figname`` before showing it.

    Raises
    ------
    ValueError
        If fewer than two fields are given (nothing to compare).
    """
    m = len(data_list)
    if m < 2:
        raise ValueError(
            "plot_heatdis_error: need a reference and at least one field to compare")
    # The reference itself gets no panel, so size the grid for m - 1 panels
    # (sizing it for m left an empty column).
    n_panels = m - 1
    figwidth = 4

    title_list = ["Reference", "DOC", "prePred", "postPred"]
    extent = [0, shape2[1], 0, shape2[0]]
    cmap = plt.get_cmap(cmap_name)
    reference = data_list[0]

    gs = gridspec.GridSpec(1, n_panels, width_ratios=[1] * n_panels, wspace=0.5)
    fig = plt.figure(figsize=(figwidth * n_panels, figwidth))

    axs = []
    for k in range(1, m):
        label = title_list[k] if k < len(title_list) else "Panel {}".format(k)
        error = data_list[k] - reference
        ax = fig.add_subplot(gs[0, k - 1])
        im = ax.imshow(
            error,
            origin="lower",
            interpolation="none",
            extent=extent,
            cmap=cmap,
            vmin=-vlim,
            vmax=vlim,
        )
        # Label with the compared field's name (previously printed the name
        # of the preceding entry, e.g. "Reference" for the DOC error).
        print("{}: max error = {}, MSE = {:.4e}".format(
            label, np.max(np.abs(error)), np.mean(error ** 2)))
        # Outline the bottom and left edges of the domain in black.
        ax.plot([0, shape2[1]], [0, 0], color="black")
        ax.plot([0, 0], [0, shape2[0]], color="black")
        ax.set_title(label, fontsize=16)
        axs.append(ax)

    cbar = fig.colorbar(im, ax=axs, orientation="vertical", fraction=0.05, pad=0.05)
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label('temperature error', fontsize=12)
    fig.suptitle(title)

    if save:
        fig.savefig(figname)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
def read_data(fname, dtype, shape2):
    """Load a headerless binary array from ``fname`` and reshape it.

    The raw file is read with ``np.fromfile`` as elements of ``dtype``. The
    element count must be compatible with ``shape2``, otherwise numpy's
    reshape raises.
    """
    return np.reshape(np.fromfile(fname, dtype=dtype), shape2)

def read_trim_data(fname, dtype, ref_shape2):
    """Load a headerless binary array shaped ``ref_shape2`` and drop its border.

    The first and last rows and the first and last columns are discarded, so
    an ``(R, C)`` file yields an ``(R - 2, C - 2)`` slice of the loaded array.
    """
    full = np.fromfile(fname, dtype=dtype).reshape(ref_shape2)
    row_stop = ref_shape2[0] - 1
    col_stop = ref_shape2[1] - 1
    return full[1:row_stop, 1:col_stop]

def array_max_diff(a1, a2):
    """Return the largest absolute element-wise difference between two arrays."""
    abs_diff = np.absolute(a1 - a2)
    return np.amax(abs_diff)

In [None]:
import json
import os

# All notebook paths hang off one project root. Set the HOMOAPP_ROOT
# environment variable to point at another checkout; the default is the
# original hard-coded location, so the derived paths are unchanged.
project_root = os.environ.get("HOMOAPP_ROOT", "/Users/xuanwu/github/backup/homoApplication")
json_file = os.path.join(project_root, "test", "heatdis.json")
plot_work_dir = os.path.join(project_root, "plot")
ht_data_dir = os.path.join(plot_work_dir, "ht_data")

# Run parameters read from the JSON config.
with open(json_file, 'r') as file:
    settings = json.load(file)

dim1 = settings["dim1"]
dim2 = settings["dim2"]
lorenzo = settings["lorenzo"]
B = settings["B"]
eb = settings["eb"]
offset = settings["offset"]
plot_gap = settings["plotgap"]
max_iter = settings["steps"]
shape2 = [dim1, dim2]
# Reference/DOC dumps carry one extra row/column on each side; they are read
# at this shape and trimmed back to shape2 by read_trim_data.
ref_shape2 = [dim1 + 2, dim2 + 2]

# File-name prefixes, one per plotted field: reference, DOC, prePred and
# postPred. No postPred file is read when lorenzo == 2.
suffix_list = ["h.ref.", "h{}d.doc.".format(lorenzo), "h{}d.pre.".format(lorenzo)]
if lorenzo != 2:
    suffix_list.append("h{}d.post.".format(lorenzo))

In [None]:
from matplotlib import colormaps
cmap_name = "nipy_spectral"

# Steps to visualize: start at max(plot_gap, offset), advance by plot_gap,
# and include max_iter itself when it lands on that grid.
plot_steps = np.arange(max(plot_gap, offset), max_iter + 1, plot_gap)
for step in plot_steps:
    # One file per entry of suffix_list, all for the same step.
    fname_list = ["{}/{}{}".format(ht_data_dir, suffix, step) for suffix in suffix_list]
    # The first two files (ref, doc) are stored at ref_shape2 and trimmed to
    # shape2; the remaining prediction files are already shape2.
    data_list = [read_trim_data(fname, np.float32, ref_shape2) for fname in fname_list[:2]]
    data_list += [read_data(fname, np.float32, shape2) for fname in fname_list[2:]]
    figtitle = "{}d lorenzo, B = {}, eb = {}, compute step {}".format(lorenzo, B, str(eb), step)
    figname = plot_work_dir + "/img/{}x{}_{}_{}_{}d_{}.png".format(dim1, dim2, B, eb, lorenzo, step)
    plot_heatdis(data_list, shape2, figtitle, cmap_name, figname)
    # NOTE(review): the call below passes a hard-coded [10, 20] extent rather
    # than shape2 — confirm before re-enabling.
    # plot_heatdis_error(data_list, [10, 20], figtitle, "coolwarm", figname)