# maskencode.ipynb
This notebook demonstrates how to use the `mlgreens.data` module to mask and positionally encode 2D Green's function data.

In [None]:
import numpy as np
from matplotlib import pyplot as plt

from mlgreens import data

## Load Green's function data
First, we need some Green's function data to work with. The data set in this folder is low-resolution, but sufficient for demonstrating masking and positional encoding.

In [None]:
# Open the Green's function data set and read the file-level attributes used below.
# NOTE(review): `_h5file` is a private attribute of GreensData; switch to a public
# accessor if the class provides one.
fpath = "./GData.h5"
Gdata = data.GreensData(fpath)
file_attrs = Gdata._h5file.attrs
Lc, a = file_attrs["Lc"], file_attrs["a"]
n_m, n_lam = file_attrs["n_modes"]

Next, we'll select a single Green's function mode and randomly choose one sample of the loaded data to use for the visualizations.

In [None]:
# Choose which Green's function mode to visualize. Modes are flattened as
# mode = m*n_lam + lam, so the (m, lam) index pair is recovered with divmod.
mode = 5
if not 0 <= mode < n_m*n_lam:
    raise ValueError(f"mode must be in [0, {n_m*n_lam}), got {mode}")
m, lam = divmod(mode, n_lam)

# Randomly choose a sample from Green's function data. A seeded generator makes
# Restart & Run All reproduce the same sample (and therefore the same figures).
sample_seed = 1234
rng = np.random.default_rng(sample_seed)
rID = rng.integers(len(Gdata))
Gsample = Gdata[rID][..., mode]
# LaTeX label such as "$G_{05}$"; doubled braces are literal braces in the f-string
Glabel = rf"$G_{{{m:d}{lam:d}}}$"

# Retrieve attribute information for chosen sample
sample_attrs = Gdata.get_attrs(rID)
sample_group = sample_attrs["group"]
sample_xi = sample_attrs["xi"]

## Create positional encodings & mask patches of data
To encode & mask the Green's function data, we first divide the domain into sub-patches:

In [None]:
# Divide the sample into sub-patches of shape `patch_dims` (see `data.patch2D`)
patch_dims = np.array([5, 10])
patchG = data.patch2D(Gsample, patch_dims)

# Patch count (leading axis) and the shape of a single patch (remaining axes)
n_patches, *single_patch_shape = patchG.shape
G_dims = np.array(single_patch_shape)

Since each patch represents one input to a masked autoencoder, positional encoding is applied per patch. The single-dimension `PE` values below are appropriate for a CNN-based autoencoder (CAE). An autoencoder without a means of representing spatial information within each input (i.e., without convolutions) would instead need one encoding dimension per pixel of each patch (`dim = patch_dims.prod()`).

In [None]:
# Compute positional encodings, scaled to Green's function data
dim = 1     # encoding width per patch; a non-convolutional autoencoder would use patch_dims.prod()
fenc = .25  # encoding amplitude as a fraction of the data maximum
PE = data.encode(n_patches, dim, n=int(1e2), scale=fenc*patchG.max())

# Verify that the encoding for each patch is unique. Whole per-patch encoding
# vectors are compared (one row per patch), so that with dim > 1 two patches
# sharing a single component value are not falsely reported as duplicates.
n_unique = len(np.unique(PE.reshape(len(PE), -1), axis=0))
if n_unique != len(PE):
    print("WARNING: Positional encodings not unique!")
else:
    print("Encoding uniqueness verified.")

Finally, we can mask the data and encode it. We'll also create masked versions of the encoding for plotting purposes:

In [None]:
# Masking parameters shared by both `mask_data` calls below (.6 is presumably the
# fraction of patches masked). Using the same values for the data and the
# encodings is what keeps their masks aligned. This assumes `mask_data` draws its
# mask from `rseed` and the number of patches only — TODO confirm.
mask_frac = .6
mask_seed = 1234

# Compute masked (M) and hidden (Mi, i.e. M^{-1}) versions of Green's data & encoding arrays
MG, MiG, _ = data.mask_data(patchG, mask_frac, rseed=mask_seed)
MPE, MiPE, _ = data.mask_data(PE, mask_frac, rseed=mask_seed)

def add_encoding(patches, encodings):
    """Add `encodings[i]` to `patches[i]` (NumPy broadcasting) for every patch and stack the results."""
    return np.array([patches[i] + encodings[i] for i in range(n_patches)])

# Create encoded Green's function arrays
EG = add_encoding(patchG, PE)
MEG = add_encoding(MG, MPE)
MiEG = add_encoding(MiG, MiPE)

## Plot results
Before plotting, we'll rearrange data back into image-like shapes:

In [None]:
# Resize positional encodings for plotting: broadcast each patch's encoding over
# a patch-shaped array so the encodings can be unpatched and displayed like image
# data. Starting from explicit zeros (rather than `0.*patchG[i]`) keeps any
# NaN/inf present in the Green's data from leaking into the encoding images.
def fill_patches(encodings):
    """Broadcast `encodings[i]` over a zero array shaped like `patchG[i]`, for every patch."""
    return np.array([np.zeros(patchG[i].shape) + encodings[i] for i in range(n_patches)])

fullPE = fill_patches(PE)
fullMPE = fill_patches(MPE)
fullMiPE = fill_patches(MiPE)

# Unpatch Green's function data (the unmasked sample is already image-shaped)
plotG = Gsample
plotMG = data.unpatch2D(MG, plotG.shape)
plotMiG = data.unpatch2D(MiG, plotG.shape)

# Unpatch resized positional encodings
plotPE = data.unpatch2D(fullPE, plotG.shape)
plotMPE = data.unpatch2D(fullMPE, plotG.shape)
plotMiPE = data.unpatch2D(fullMiPE, plotG.shape)

# Unpatch encoded Green's function data
plotEG = data.unpatch2D(EG, plotG.shape)
plotMEG = data.unpatch2D(MEG, plotG.shape)
plotMiEG = data.unpatch2D(MiEG, plotG.shape)

Now we'll plot the Green's function data, the encodings, and their sum. We'll also plot the masked ($M[\ \cdot\ ]$) and hidden ($M^{-1}[\ \cdot\ ]$) versions of each quantity:

In [None]:
# Axis extent shared by every panel: Z spans [0, Lc] and R spans [0, xi*a].
# The factor of 10 scales the stored lengths to the "mm" used in the axis labels
# (presumably Lc and a are stored in cm — TODO confirm).
to_mm = 10.
ex = [0., Lc*to_mm, 0., sample_xi*a*to_mm]
plot_args = {
    "cmap": 'gray',
    "origin": 'lower',
    "extent": ex,
    # Data-unit aspect chosen so each panel is drawn 0.6 times as tall as it is wide
    "aspect": .6*ex[1]/ex[-1],
}
# Rows: Green's data, positional encoding, encoded data.
# Columns: full, masked (M), hidden (M^{-1}).
plot_data = [[plotG, plotMG, plotMiG], [plotPE, plotMPE, plotMiPE], [plotEG, plotMEG, plotMiEG]]
plot_titles = [
    [Glabel, r"$M[$"+Glabel+r"$]$", r"$M^{-1}[$"+Glabel+r"$]$"],
    [r"$PE$", r"$M[PE]$", r"$M^{-1}[PE]$"],
    [Glabel+r"$+PE$", r"$M[$"+Glabel+r"$+PE]$", r"$M^{-1}[$"+Glabel+r"$+PE]$"]
]

# Use one color scale per row so the masked/hidden panels are shaded on the same
# scale as the full panel, instead of each image being auto-rescaled on its own.
# Set to False to restore independent per-panel scaling.
share_row_clim = True

fig, axs = plt.subplots(3, 3, figsize=(15, 11))
fig.suptitle(r"Masking & Positional Encoding", fontsize=24, y=.95)
for row_axs, row_data, row_titles in zip(axs, plot_data, plot_titles):
    clim = {}
    if share_row_clim:
        clim = {
            "vmin": min(np.nanmin(img) for img in row_data),
            "vmax": max(np.nanmax(img) for img in row_data),
        }
    for ax, img, title in zip(row_axs, row_data, row_titles):
        ax.set_title(title, fontsize=14)
        ax.imshow(img, **clim, **plot_args)
        ax.set_ylabel(r"$R$ (mm)", fontsize=14)
        ax.set_xlabel(r"$Z$ (mm)", fontsize=14)
        ax.tick_params(labelsize=14)
fig.subplots_adjust(hspace=.3, wspace=.25)
plt.show()