# Gomb - Net
### Performance test on the WSSe dataset
### And analysis of WSSe experimental data

Austin Houston

[![OpenInColab](https://colab.research.google.com/assets/colab-badge.svg)](
    https://colab.research.google.com/github/ahoust17/Gomb-Net/blob/main/Eval_WSSe_model.ipynb)

### Necessary installs

In [None]:
# basics
import os
import sys
import numpy as np

# plotting
import matplotlib.pylab as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
from matplotlib import cm

# colab interactive plots and drive
drive = False
if 'google.colab' in sys.modules:
    from  google.colab import drive 
    from google.colab import output
    drive.mount('/content/drive')
    output.enable_custom_widget_manager()
    drive = True
else:
    %matplotlib widget

# other imports
from scipy.ndimage import label, center_of_mass, gaussian_filter, zoom, uniform_filter
from scipy.spatial import KDTree
from scipy.interpolate import griddata
from scipy.stats import norm, gaussian_kde
from skimage.filters import threshold_otsu
from skimage.feature import blob_log

# for cropping function
if drive:
    print('installing DataGenSTEM')
    !pip install ase
    !git clone https://github.com/ahoust17/DataGenSTEM.git
    sys.path.append('./DataGenSTEM/DataGenSTEM')
    import data_generator as dg

# for Gomb-Net
if drive:
    print('installing Gomb-Net')
    !git clone https://github.com/ahoust17/Gomb-Net.git
    sys.path.append('./Gomb-Net/GombNet')    
from GombNet.networks import *
from GombNet.loss_func import GombinatorialLoss
from GombNet.utils import *

import torch
# Pick the compute device: the GPU when CUDA is available, the CPU otherwise
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print("CUDA is available. Using GPU." if _cuda_ok else "CUDA is not available. Using CPU.")

In [None]:
# Prefer Apple-silicon MPS, but only when this machine actually provides it;
# anywhere else (e.g. Colab) the CUDA/CPU device chosen earlier is kept, so
# torch.load(..., map_location=device) keeps working.
if torch.backends.mps.is_available():
    device = torch.device('mps')

### Now, you need to add the following shared drive to your google drive:
*** WARNING: it is a big file.  Check before you download ***


https://drive.google.com/file/d/1DyKtrmJ8wNYQg3YEJ8_iXjz6lB_DQfwy/view?usp=sharing

### Run the following cell after the download is complete

In [None]:
if drive:
    shared_folder = 'drive/My Drive/Gomb-Net files'
else:
    shared_folder = '/Users/austin/Desktop/gomb_beta'

print('available files & directories:')
!ls '{shared_folder}'

### Now we can start with the actual code

let's look at the dataset:

In [None]:
# Build the train / validation / test dataloaders from the WSSe dataset folders
images_dir = f'{shared_folder}/WSSe_dataset/images'
labels_dir = f'{shared_folder}/WSSe_dataset/labels'
train_loader, val_loader, test_loader = get_dataloaders(
    images_dir,
    labels_dir,
    batch_size=1,
    val_split=0.2,
    test_split=0.1,
    seed=42,
)


In [None]:
test_iter = 1 # 2,4,5,7,8,9, 10, 18
# Fetch the sample once so the image and its labels come from the same
# __getitem__ call (two separate lookups load the sample twice and could
# disagree if the dataset applies any random transform).
sample = test_loader.dataset[test_iter]
test = sample[0].unsqueeze(0)  # add a leading batch axis
gt = sample[1]

# One row: the input image followed by the six ground-truth channels
fig, ax = plt.subplots(1, 7, figsize=(12, 5))
ax[0].imshow(test[0, 0].cpu().numpy(), cmap='gray')
ax[0].set_title('Input')

titles = ['L1: S', 'L1: Se', 'L1: W', 'L2: S', 'L2: Se', 'L2: W']
for i in range(6):
    ax[i+1].imshow(gt[i].cpu().numpy(), cmap='gray')
    ax[i+1].set_title(titles[i])
for a in ax:
    a.axis('off')
fig.tight_layout()

now let's look at the model:

In [None]:
# Initialize model 'skeleton'
input_channels = 1
num_classes = 6    # number of output classes
num_filters = [32, 64, 128, 256, 512]

# Freshly initialised network; checkpoint weights are loaded into it in a later cell
model = TwoLeggedUnet(input_channels, num_classes, num_filters, dropout = 0.2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Loss configured with group_size = num_classes//2 = 3
# (presumably one group of three channels per layer -- confirm against GombinatorialLoss)
loss = GombinatorialLoss(group_size = num_classes//2, loss = 'Dice', epsilon=1e-6, class_weights = None, alpha=2)

In [None]:
# Get the number of trainable parameters
def count_trainable_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report the trainable-parameter count of the `model` built in the previous cell
num_trainable_params = count_trainable_parameters(model)
print(f"Number of trainable parameters: {num_trainable_params}")

visualize the training history for the pretrained model:

In [None]:
# Plot the stored training / validation loss curves of every run in WSSe_models
fig, ax = plt.subplots(1,2, sharex=True, sharey=True)
for file in os.listdir(str(shared_folder + '/WSSe_models')):
    if file.endswith('.npz'):
        # an .npz is opened lazily; the context manager closes the archive
        # once both histories have been read
        with np.load(str(shared_folder + '/WSSe_models/' + file)) as loss_history:
            train_loss = loss_history['train_loss_history']
            val_loss = loss_history['val_loss_history']

        # Do NOT name this `label`: that would overwrite scipy.ndimage.label,
        # which the centroid-extraction cells call further down the notebook.
        run_label = file.split('_')[2]
        print(run_label)
        ax[0].plot(train_loss, label=run_label)
        ax[1].plot(val_loss, label=run_label)

ax[0].set_title('Training loss')
ax[1].set_title('Validation loss')

ax[0].legend()
ax[1].legend()


load in the pretrained weights onto our model 'skeleton'

In [None]:
import copy

# Load every '*best*' .pth checkpoint found in WSSe_models as one ensemble member
models_dict = {}
for file in os.listdir(str(shared_folder + '/WSSe_models')):
    if file.endswith('.pth') and 'best' in file:
        model_path = str(shared_folder + '/WSSe_models/' + file)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        model.eval()

        # Store an independent copy. Storing `model` itself would make every
        # dictionary entry the same object, so all ensemble members would end
        # up carrying the weights of whichever checkpoint was loaded last.
        models_dict[str('model_' + file.split('_')[2])] = copy.deepcopy(model)

model_names = list(models_dict)
ensemble = GombSemble(models_dict)

def threshold(prediction, transfrom_list, param = 1):
    """Binarize a stack of prediction maps, one map (first-axis slice) at a time.

    Parameters
    ----------
    prediction : numpy.ndarray
        Stack of maps. It is never modified: a copy is processed and returned,
        so the caller keeps its original (probability) values.
    transfrom_list : container of str
        Steps to apply, always in this order when present:
        'sigmoid'   -- logistic function applied element-wise;
        'threshold' -- 1 where a map exceeds ``param`` times its own standard
                       deviation, else 0;
        'otsu'      -- 1 where a map exceeds ``param`` times its Otsu
                       threshold, else 0.
    param : float
        Multiplier applied to the std / Otsu cutoff.

    Returns
    -------
    numpy.ndarray
        New array with the same shape as ``prediction``.
    """
    # np.array copies by default -> in-place writes below cannot leak to the caller
    prediction = np.array(prediction)

    if 'sigmoid' in transfrom_list:
        prediction = 1/(1 + np.exp(-prediction))

    if 'threshold' in transfrom_list:
        for i in range(prediction.shape[0]):
            cutoff = np.std(prediction[i]) * param
            prediction[i] = (prediction[i] > cutoff).astype(int)

    if 'otsu' in transfrom_list:
        for i in range(prediction.shape[0]):
            thresh = threshold_otsu(prediction[i]) * param
            prediction[i] = (prediction[i] > thresh).astype(int)

    return prediction


In [None]:
test_iter = 0 # 2,4,5,7,8,9, 10, 18
# Fetch the sample once so the image and its labels come from the same
# __getitem__ call (two separate lookups load the sample twice and could
# disagree if the dataset applies any random transform).
sample = test_loader.dataset[test_iter]
test = sample[0].unsqueeze(0)  # add a leading batch axis
gt = sample[1]

# One row: the input image followed by the six ground-truth channels
fig, ax = plt.subplots(1, 7, figsize=(12, 5))
ax[0].imshow(test[0, 0].cpu().numpy(), cmap='gray')
ax[0].set_title('Input')

titles = ['L1: S', 'L1: Se', 'L1: W', 'L2: S', 'L2: Se', 'L2: W']
for i in range(6):
    ax[i+1].imshow(gt[i].cpu().numpy(), cmap='gray')
    ax[i+1].set_title(titles[i])
for a in ax:
    a.axis('off')
fig.tight_layout()


In [None]:
# Predict using the ensemble
# (input is the batched `test` sample selected above; return_plot=False
# presumably skips the built-in plot -- confirm against GombSemble.predict)
ensemble.predict(test, return_plot=False)


In [None]:
# Combine the member predictions with a 'mean' vote; the result is read
# from ensemble.voted_prediction in the following cells
ensemble.rearrange_ensemble()
ensemble.vote(mode='mean')
# ensemble.plot_vote()

In [None]:
# Pass a copy so the thresholding can never overwrite ensemble.voted_prediction
binary = threshold(ensemble.voted_prediction.copy(), ['threshold'], param = 1)

In [None]:
# Ground truth (top row) against the binarized ensemble prediction (bottom row)
fig, ax = plt.subplots(2,6, figsize=(12, 5))
gt_row, pred_row = ax
for i in range(6):
    gt_row[i].imshow(gt[i].cpu().numpy(), cmap='gray')
    gt_row[i].set_title('GT')

    pred_row[i].imshow(binary[i], cmap='gray')
    pred_row[i].set_title(titles[i])

for a in ax.flat:
    a.axis('off')
    

## Now, on Experimental data:

### First, some useful functions:

In [None]:
def plot_atom_histograms(results, n_bins=50):
    """Plot, for every entry of ``results``, the image with its centroids next
    to per-layer intensity histograms with a KDE curve on top.

    Parameters
    ----------
    results : dict
        Each value must provide ``'image'``, ``'centroids'`` and
        ``'intensities'``; the last two are indexed ``[element][layer]`` with
        elements 'W', 'Se', 'S' and layer indices 0 and 1.  Every
        ``intensities[element][layer]`` is a (possibly empty) list of arrays.
    n_bins : int
        Number of histogram bins.
    """
    colors = {'W': ['blue', 'cyan'], 'Se': ['green', 'lime'], 'S': ['red', 'magenta']}

    def _flat(chunks):
        # join the per-atom arrays of one layer into a single 1-D array
        return np.concatenate(chunks).flatten() if chunks else np.array([])

    for result in results.values():
        intensities = {el: [_flat(result['intensities'][el][0]),
                            _flat(result['intensities'][el][1])]
                       for el in colors}

        fig, axs = plt.subplots(1, 3, figsize=(20, 8))
        # Plot the image with centroids (centroids are (row, col) -> x = col, y = row)
        axs[0].imshow(result['image'], cmap='gray')
        for element, colors_list in colors.items():
            for idx, (color, layer) in enumerate(zip(colors_list, result['centroids'][element])):
                if layer.size > 0:
                    axs[0].scatter(layer[:, 1], layer[:, 0], color=color, label=f'{element} Layer {idx + 1} Centroids', alpha=0.6)
        axs[0].axis('off')
        axs[0].legend(loc='upper right', fontsize=8)

        # Plot histograms and KDEs for both layers
        for j, layer in enumerate(['Layer 1', 'Layer 2']):
            panel = axs[j + 1]
            for element, layer_colors in colors.items():
                color = layer_colors[j]
                values = intensities[element][j]
                if values.size == 0:
                    continue
                # NOTE(review): counts (density=False) share the axis with a KDE
                # that integrates to 1 and the y-label says 'Density' -- confirm
                # whether density=True was intended.
                panel.hist(values, bins=n_bins, color=color, alpha=0.5, label=f'{element} {layer} Intensities', density=False)
                # gaussian_kde cannot be fitted to fewer than two distinct samples
                lo, hi = values.min(), values.max()
                if values.size > 1 and hi > lo:
                    density = gaussian_kde(values)
                    x = np.linspace(lo, hi, 1000)
                    panel.plot(x, density(x), color=color)
            panel.set_title(f'{layer} Intensities')
            panel.set_xlabel('Intensity')
            panel.set_ylabel('Density')
            panel.legend(loc='upper right', fontsize=8)

        plt.tight_layout()
        plt.show()


In [None]:
def fit_gaussian_mixture(metric, n_components):
    """Fit an ``n_components`` Gaussian mixture to the values in ``metric``.

    Returns ``(means, stds, weights, log_likelihood)``; means and stds are
    flattened arrays and ``log_likelihood`` is ``score(...) * len(metric)``.
    NOTE(review): GaussianMixture is not imported in the visible cells --
    presumably scikit-learn's sklearn.mixture.GaussianMixture; confirm.
    """
    samples = metric.reshape(-1, 1)  # one column: one feature per sample
    mixture = GaussianMixture(n_components=n_components)
    mixture.fit(samples)
    component_means = mixture.means_.flatten()
    component_stds = np.sqrt(mixture.covariances_).flatten()
    total_log_likelihood = mixture.score(samples) * len(metric)
    return component_means, component_stds, mixture.weights_, total_log_likelihood

def plot_histogram_with_gaussian(ax, metric, title, n_components, color_bounds):
    """Draw a viridis-coloured density histogram of ``metric`` on ``ax`` and
    overlay a fitted Gaussian mixture.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    metric : 1-D array of values.
    title : axes title.
    n_components : number of mixture components passed to fit_gaussian_mixture.
    color_bounds : (low, high) values mapped to the two ends of the colormap.

    Returns
    -------
    (means, stds, weights, x, total_pdf, log_likelihood) -- the mixture
    parameters, the x grid used for the curves and the summed mixture pdf.
    """
    _, bins, patches = ax.hist(metric, bins=20, density=True, edgecolor='k', alpha=1)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    norm_dist = mcolors.Normalize(color_bounds[0], color_bounds[1])
    # matplotlib.cm.get_cmap() was removed in Matplotlib 3.9; the colormap
    # object itself is available as an attribute of the cm module.
    cmap = cm.viridis
    for c, p in zip(bin_centers, patches):
        p.set_facecolor(cmap(norm_dist(c)))
    ax.set_yticks([])

    means, stds, weights, log_likelihood = fit_gaussian_mixture(metric, n_components)
    # curves are evaluated over the x-range the histogram just set up
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 1000)

    total_pdf = np.zeros_like(x)
    for mean, std, weight in zip(means, stds, weights):
        print('mean:', mean)
        # weighted normal pdf of this component
        pdf = weight * (1 / (std * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mean) / std) ** 2)
        total_pdf += pdf
        ax.plot(x, pdf, '--', c='k', linewidth=2)

    ax.plot(x, total_pdf, 'k-', linewidth=3, alpha=0.8)
    ax.set_title(title, fontsize=14)
    ax.set_xlim(np.min(metric), np.max(metric))
    return means, stds, weights, x, total_pdf, log_likelihood


In [None]:
def plot_histogram(ax, metric, color_bounds):
    """Draw a 20-bin density histogram of ``metric`` on ``ax``.

    Each bar is coloured with viridis according to its bin centre, with
    ``color_bounds`` = (low, high) mapped to the two ends of the colormap.
    The y ticks are removed and the x-limits are set to the data range.
    Returns None.
    """
    _, bins, patches = ax.hist(metric, bins=20, density=True, edgecolor='k', alpha=1)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    norm_dist = mcolors.Normalize(color_bounds[0], color_bounds[1])
    # matplotlib.cm.get_cmap() was removed in Matplotlib 3.9; the colormap
    # object itself is available as an attribute of the cm module.
    cmap = cm.viridis
    for c, p in zip(bin_centers, patches):
        p.set_facecolor(cmap(norm_dist(c)))
    ax.set_yticks([])

    ax.set_xlim(np.min(metric), np.max(metric))
    return

def plot_moire_distances(results, image_number=0):
    """Map and histogram the distance between nearest sites of the two layers.

    The layer index with more Se centroids is treated as 'bottom', the other
    as 'top'.  Four maps are built on a 512 x 512 grid -- W/W, W/chalcogen,
    chalcogen/W and chalcogen/chalcogen (top/bottom).  Each map gives, per
    grid point, the distance between the nearest top site and the nearest
    bottom site, multiplied by the module-level ``pixel_size``.  Every map is
    shown above a histogram of its values interpolated at the bottom-layer
    Se sites.

    Parameters
    ----------
    results : dict
        ``results[image_number]`` must hold ``"image"`` and ``"centroids"``
        (per element 'S', 'Se', 'W': two (row, col) centroid arrays).
    image_number : key into ``results``.

    Returns
    -------
    list
        The four distance maps, in the order listed above.
    """
    centroids = results[image_number]["centroids"]

    def get_layer_data(layer):
        return np.vstack(centroids[layer][0]), np.vstack(centroids[layer][1])

    Se_0, Se_1 = get_layer_data("Se")
    W_0, W_1 = get_layer_data("W")
    S_0, S_1 = get_layer_data("S")

    if Se_0.shape[0] > Se_1.shape[0]:
        Se_bottom, W_bottom, S_bottom, W_top = Se_0, W_0, S_0, W_1
        Se_top, S_top = Se_1, S_1
    else:
        Se_bottom, W_bottom, S_bottom, W_top = Se_1, W_1, S_1, W_0
        Se_top, S_top = Se_0, S_0
    chalc_top = np.vstack((S_top, Se_top))
    chalc_bottom = np.vstack((S_bottom, Se_bottom))

    grid_density = 512
    x = np.linspace(0, 512, grid_density)
    y = np.linspace(0, 512, grid_density)
    X, Y = np.meshgrid(x, y)
    # Centroids are stored as (row, col), so the query points must use the
    # same order, i.e. (Y, X).  With (X, Y) the maps come out transposed
    # relative to the image and to the (col, row) sites sampled below.
    points = np.vstack((Y.ravel(), X.ravel())).T

    def nearest_site(sites):
        # coordinates of the site closest to each grid point (one KD-tree query)
        return sites[KDTree(sites).query(points)[1]]

    near_W_top, near_W_bottom = nearest_site(W_top), nearest_site(W_bottom)
    near_chalc_top, near_chalc_bottom = nearest_site(chalc_top), nearest_site(chalc_bottom)

    fig, ax = plt.subplots(2, 5, figsize=(20, 5))
    # maps are now in image orientation, so the image is shown untransposed
    ax[0,0].imshow(results[image_number]["image"], cmap='gray')
    ax[0,0].set_title('Image')

    # (title, nearest top-layer site, nearest bottom-layer site) per column
    pairs = [
        ('W-W Distances', near_W_top, near_W_bottom),
        ('W-Chalc Distances', near_W_top, near_chalc_bottom),
        ('Chalc-W Distances', near_chalc_top, near_W_bottom),
        ('Chalc-Chalc Distances', near_chalc_top, near_chalc_bottom),
    ]

    distance_arrays = []
    for col, (title, top_sites, bottom_sites) in enumerate(pairs, start=1):
        distances = np.linalg.norm(top_sites - bottom_sites, axis=1)
        distances = distances.reshape((grid_density, grid_density)) * pixel_size
        # sample the map at the bottom-layer Se sites: x = col, y = row
        moire_sites = griddata((X.ravel(), Y.ravel()), distances.ravel(), (Se_bottom[:, 1], Se_bottom[:, 0]), method='cubic')
        ax[0, col].imshow(distances, cmap='viridis', extent=(0, 512, 512, 0))
        ax[0, col].set_title(title)
        color_bounds = [distances.min(), distances.max()]
        plot_histogram(ax[1, col], moire_sites.ravel(), color_bounds)
        distance_arrays.append(distances)

    for a in ax.ravel()[:5]:
        a.axis('off')
    ax[1,0].axis('off')
    fig.tight_layout()
    plt.show()

    return distance_arrays

In [None]:
def visualize_predictions(image_number, results, n_classes):
    """Show the image stored in ``results[image_number]`` and, in a second
    figure, its first ``n_classes`` masks (top row) and probability maps
    (bottom row)."""
    entry = results[image_number]

    # Figure 1: the original image
    plt.figure(figsize=(10, 10))
    plt.imshow(entry['image'], cmap='gray')
    plt.axis('off')
    plt.title('Original Image')

    # Figure 2: one column per class, masks above probabilities
    fig, axs = plt.subplots(2, n_classes, figsize=(20, 10))
    rows = (('mask', 'gray', 'Mask'), ('probability', 'plasma', 'Probability'))
    for row, (key, cmap_name, prefix) in enumerate(rows):
        for j in range(n_classes):
            panel = axs[row, j]
            panel.imshow(entry[key][j], cmap=cmap_name)
            panel.axis('off')
            panel.set_title(f'{prefix} {j+1}')

    plt.suptitle(f'Results for Image {image_number}', fontsize=16)
    plt.tight_layout()
    plt.show()

### Image 1 ~ 20 deg twist (or -40, however you want to think about it)

In [None]:
# !git clone https://github.com/ahoust17/DataGenSTEM.git
sys.path.append('./DataGenSTEM/DataGenSTEM')
import DataGenSTEM.DataGenSTEM.data_generator as dg

In [None]:
# Code to crop the experimental images

# Load the experimental HAADF data and its pixel size from the shared folder
exp_data = np.load(str(shared_folder + '/Experimental_datasets/WSSe_haadf.npz'))
im_array = exp_data['im_array']
pixel_size = exp_data['pixel_size']

# collapse the first axis by summation
im_array = np.sum(im_array, axis = 0)

# im_array = gaussian_filter(im_array, sigma=1)
# power-law contrast adjustment (exponent 2.6), then rescale to [0, 1]
im_array = np.power(im_array,2.6)
im_array = im_array - np.min(im_array)

im_array = im_array / np.max(im_array)


# NOTE(review): printed as m/pix, but later cells label pixel_size-scaled
# distances as angstroms -- confirm the unit stored in the file.
print(f"Pixel size: {pixel_size.astype(float)} m/pix")
plt.figure()
plt.imshow(im_array, cmap='gray')
plt.axis('off')


exp_hist, bins = np.histogram(im_array.ravel(), bins=256, range=(0.0, 1.0))

# Collect the images of the first 10 training batches for an intensity comparison
selected_images = []
for i, data in enumerate(train_loader):
    if i >= 10:
        break
    images = data[0].numpy()  # Convert batch of images to numpy array
    selected_images.append(images)

selected_images = np.concatenate(selected_images, axis=0)
selected_images = selected_images / selected_images.max()
selected_images_raveled = selected_images.ravel()
train_hist, _ = np.histogram(selected_images_raveled, bins=256, range=(0.0, 1.0))

# Plot the histograms
plt.figure()
plt.hist(im_array.ravel(), bins=256, range=(0.0, 1.0), fc='k', ec='k', alpha=0.5, label='Experimental Image')
plt.plot(bins[:-1], train_hist, 'r', alpha=0.5, label='Training Images')
plt.legend()
plt.xlabel('Intensity')
plt.ylabel('Frequency')
plt.title('Histogram of Experimental Image vs Training Images')
plt.show()

# Cut n_crops crops of 512 px out of the experimental image
n_crops = 10
images = dg.shotgun_crop(im_array, crop_size=512, n_crops = n_crops, roi = None)
# normalize each image in images to 0,1
for i in range(n_crops):
    images[i] = images[i] - np.min(images[i])
    images[i] = images[i] / np.max(images[i])

In [None]:
# select a 512 x 512 window by hand: the last 512 rows, columns 200-711
# NOTE(review): `adj` is not used in this cell; the column offset is the literal 200
adj = 0
manual_image = im_array[-512:, 0+200:512+200]
# gaussian_filter returns a new array, so the in-place rescaling below
# does not touch im_array
manual_image = gaussian_filter(manual_image, sigma=1)

# rescale to [0, 1]
manual_image -= manual_image.min()
manual_image /= manual_image.max()

plt.figure()
plt.imshow(manual_image, cmap='gray')


In [None]:
# from scipy.ndimage import center_of_mass
# from scipy.spatial.distance import cdist

In [None]:
# select input image
im = manual_image
# add batch and channel axes in front of the 2-D image
test = torch.tensor(im.astype(np.float32))[None, None]

# show input image
fig, ax = plt.subplots(1,1, figsize=(4,4))
ax.imshow(test[0, 0].cpu().numpy(), cmap='gray')
ax.axis('off')

# ensemble inference, then a 'max' vote over the members, then plot the vote
ensemble.predict(test, return_plot=False)
ensemble.rearrange_ensemble()
ensemble.vote(mode='max')
ensemble.plot_vote()


In [None]:
# binarize a copy of the voted prediction; the original array is left as it is
binary = threshold(ensemble.voted_prediction.copy(), ['threshold'], param = 1)

# one panel per output channel
fig, ax = plt.subplots(1,6, figsize=(12, 5))
for i, panel in enumerate(ax):
    panel.imshow(binary[i], cmap='gray')
    panel.axis('off')

### Do it batchwise

In [None]:
# Run the ensemble on every crop and store image, masks, probabilities and
# per-element atom centroids in `results`, keyed by crop index
results = {}
for i, im in enumerate(images):
    print(f"Processing image {i}")
    # Model Prediction
    im_tensor = torch.tensor(im.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    ensemble.predict(im_tensor, return_plot=False)
    ensemble.rearrange_ensemble()
    ensemble.vote(mode='max')
    # Threshold a copy: `probability` must keep the continuous vote values
    # instead of sharing memory with the binary masks.
    masks = threshold(ensemble.voted_prediction.copy(), ['threshold'], param = 0)
    probability = ensemble.voted_prediction

    centroids = {'W': [], 'S': [], 'Se': []}

    # channel -> element: 0/3 are S, 1/4 are Se, the remaining ones W;
    # each element list therefore receives one centroid array per layer
    for j, layer in enumerate(masks):
        if j == 0 or j == 3:
            element = 'S'
        elif j == 1 or j == 4:
            element = 'Se'
        else:
            element = 'W'

        # one centroid (row, col) per connected region of the mask
        labeled_array, num_features = label(layer)
        layer_centroids = np.array(center_of_mass(layer, labeled_array, range(1, num_features + 1)))
        centroids[element].append(layer_centroids)

    results[i] = {
        "image": im,
        "mask": masks,
        "probability": probability,
        "centroids": centroids,
    }

In [None]:
def plot_vegards_law(results):
    """Plot the measured W-W spacing against the Se fraction of every layer,
    together with a Vegard's-law reference line.

    Parameters
    ----------
    results : dict
        Maps an image key to a dict whose ``"centroids"`` entry holds, for
        each element ('S', 'Se', 'W'), a list with one centroid array per layer.

    Reads the module-level ``pixel_size`` and draws on a new matplotlib
    figure; returns None.  Layers without any chalcogen or W detections are
    skipped.
    """
    S_atoms = []
    Se_atoms = []
    W_atoms = []
    by_element = {"S": S_atoms, "Se": Se_atoms}

    # flatten to one entry per (image, layer); anything that is not S/Se counts as W
    for result in results.values():
        for element, layers in result["centroids"].items():
            by_element.get(element, W_atoms).extend(layers)

    WS2_lattice = 3.15 # Angstrom, https://www.hqgraphene.com/WS2.php
    WSe2_lattice = 3.28 # Angstrom, https://www.hqgraphene.com/WSe2.php
    WSSe_lattice = 3.24 # Angstrom

    # Vegard's law
    def vegards_law(a1, a2, x1, x2):
        return x1*a1 + x2*a2

    # NOTE(review): the reference line mixes WS2 with WSSe_lattice while
    # WSe2_lattice is unused -- confirm which end member is intended.
    x_line = np.linspace(0,1,11)/2
    ideal_lattice = vegards_law(WS2_lattice, WSSe_lattice, 1-x_line, x_line)

    # plot the ideal lattice
    plt.figure(figsize=(6,4), dpi = 300)
    plt.plot(x_line, ideal_lattice, 'k--', label='Vegard\'s Law')

    n_layer_a = 0
    n_layer_b = 0
    for S, Se, W in zip(S_atoms, Se_atoms, W_atoms):
        # calculate stoichiometry:
        n_S = len(S)
        n_Se = len(Se)

        # same value as n_Se / (n_S * (2 + n_Se/n_S)), but defined for n_S == 0
        weighted_total = 2 * n_S + n_Se
        if weighted_total == 0 or np.size(W) == 0:
            continue
        x_Se = n_Se / weighted_total

        # calculate the standard error of the mean
        std_err_stoic = np.sqrt((x_Se * (1 - x_Se)) / (n_S + n_Se))

        tree = KDTree(W)
        distances, indices = tree.query(W, k=4)

        # The first column of distances is zero (distance to itself), we want the next three columns
        nearest_distances = distances[:, 1:4] * pixel_size # angstroms
        # keep distances between 3.0 and 3.25
        nearest_distances = nearest_distances[(nearest_distances > 3.0) & (nearest_distances < 3.25)]
        avg_distance = np.mean(nearest_distances)
        std_distance = np.std(nearest_distances)
        std_error = std_distance / np.sqrt(len(nearest_distances))

        # colour by Se content; only the first point of each group gets a legend entry
        if x_Se < 0.07:
            color, legend_label, first = '#1f77b4', 'layer A', n_layer_a == 0
            n_layer_a += 1
        else:
            color, legend_label, first = '#d62728', 'layer B', n_layer_b == 0
            n_layer_b += 1
        plt.errorbar(x_Se, avg_distance, yerr=std_error, xerr=std_err_stoic,capsize=5, c=color, fmt='o', label=legend_label if first else None)

    plt.xlabel('Se Fraction (x)')
    # raw string: '\A' is not a valid escape sequence in a normal string literal
    plt.ylabel(r'Lattice Constant ($\AA$)')
    # plt.legend(loc='lower right', fontsize=12)
    plt.legend(loc='upper left', fontsize=12)
    plt.xlim(0,0.2)
    plt.ylim(3.125, 3.185)
    plt.xticks([0, 0.05, 0.1, 0.15, 0.2])
    plt.tight_layout()


In [None]:
# Lattice spacing vs. Se fraction for every layer of every processed crop
plot_vegards_law(results)

In [None]:
# plot all 10 crops on the same plot
fig, ax = plt.subplots(2,5, figsize=(20, 12), sharex=True, sharey=True)
colors = {'W': ['blue', 'cyan'], 'Se': ['green', 'lime'], 'S': ['red', 'magenta']}

for i in range(10):
    panel = ax[i%2, i//2]
    panel.imshow(images[i], cmap='gray')
    panel.axis('off')

    # The inner loop variable must not be called `colors`: that would replace
    # the palette dict with a list after the first pass.
    for element, layer_colors in colors.items():
        for idx, (color, layer) in enumerate(zip(layer_colors, results[i]['centroids'][element])):
            if layer.size > 0:
                # centroids are (row, col) -> x = col, y = row
                panel.scatter(layer[:, 1], layer[:, 0], color=color, label=f'{element} Layer {idx + 1} Centroids', alpha=0.4, s=5)
    panel.legend(loc='upper right', fontsize=8)
fig.tight_layout()

In [None]:
image_number = 3  # key into `results`; replace with the desired image number
# masks and probability maps for all 6 output channels
visualize_predictions(image_number, results, 6)

# Getting the moiré distribution

In [None]:
image_number = 0

# define two layers - the one with more Se gets called 'top'
n_se_0 = len(results[image_number]['centroids']['Se'][0])
n_se_1 = len(results[image_number]['centroids']['Se'][1])
top_idx, bottom_idx = (1, 0) if n_se_1 >= n_se_0 else (0, 1)

s_top = results[image_number]['centroids']['S'][top_idx]
se_top = results[image_number]['centroids']['Se'][top_idx]
w_top = results[image_number]['centroids']['W'][top_idx]
chalc_top = np.vstack([arr for arr in [s_top, se_top] if arr.size > 0])

s_bottom = results[image_number]['centroids']['S'][bottom_idx]
se_bottom = results[image_number]['centroids']['Se'][bottom_idx]
w_bottom = results[image_number]['centroids']['W'][bottom_idx]
chalc_bottom = np.vstack([arr for arr in [s_bottom, se_bottom] if arr.size > 0])

# clean up chalcogen sites that may overlap
# NOTE(review): with k=2 the filter uses the distance to the SECOND-nearest W
# site (column 1) -- confirm that column 0 (nearest W) was not intended.
distance_cutoff = 10 # pixel

tree = KDTree(w_top)
distances, indices = tree.query(chalc_top, k=2)
chalc_top = chalc_top[distances[:, 1] > distance_cutoff]

tree = KDTree(w_bottom)
distances, indices = tree.query(chalc_bottom, k=2)
chalc_bottom = chalc_bottom[distances[:, 1] > distance_cutoff]

# make updated trees for distance maps
w_top_tree, w_bottom_tree = KDTree(w_top), KDTree(w_bottom)
chalc_top_tree, chalc_bottom_tree = KDTree(chalc_top), KDTree(chalc_bottom)


# Create meshgrid for points
grid_density = 512
x = np.linspace(0, 512, grid_density)
y = np.linspace(0, 512, grid_density)
X, Y = np.meshgrid(x, y)
# Centroids are stored as (row, col), so the query points must use the same
# order, i.e. (Y, X).  With (X, Y) every map comes out transposed relative to
# the image and to the (x = col, y = row) Se sites it is sampled at.
points = np.vstack((Y.ravel(), X.ravel())).T


# Helper: per grid point, distance between the nearest layer1 site and the
# nearest layer2 site (in pixels), reshaped to the grid
def compute_distances(query_tree1, query_tree2, layer1, layer2):
    distances = np.linalg.norm(layer1[query_tree1.query(points)[1]] - layer2[query_tree2.query(points)[1]], axis=1)
    distances = distances.reshape((grid_density, grid_density))
    return distances

# Compute the distance maps for the four top/bottom site combinations
w_w_distances = compute_distances(w_top_tree, w_bottom_tree, w_top, w_bottom)
w_chalc_distances = compute_distances(w_top_tree, chalc_bottom_tree, w_top, chalc_bottom)
chalc_w_distances = compute_distances(chalc_top_tree, w_bottom_tree, chalc_top, w_bottom)
chalc_chalc_distances = compute_distances(chalc_top_tree, chalc_bottom_tree, chalc_top, chalc_bottom)

# Store all distances for further use
distance_arrays = [w_w_distances, w_chalc_distances, chalc_w_distances, chalc_chalc_distances]
# combined map: quadrature sum of the W-chalc and chalc-W distances
map_2H = np.sqrt(distance_arrays[1]**2 + distance_arrays[2]**2)

# Plot results
fig, ax = plt.subplots(1, 4, figsize=(20, 5))
titles = ['W-W Distances', 'W-Chalc Distances', 'Chalc-W Distances', 'Chalc-Chalc Distances']
for i in range(4):
    ax[i].imshow(distance_arrays[i], cmap='viridis', extent=(0, 512, 512, 0))
    ax[i].set_title(titles[i])
    ax[i].axis('off')



In [None]:
# Scale the map by pixel_size.
# NOTE(review): map_2H is overwritten with its scaled version, so running
# this cell twice applies pixel_size twice.
map_2H = map_2H * pixel_size
plt.figure()
plt.imshow(map_2H, cmap='bone')
plt.axis('off')

# overlay the top-layer Se sites (x = column 1, y = column 0 of the centroids)
plt.scatter(se_top[:, 1], se_top[:, 0], s=5, c='r')

In [None]:
# Value of map_2H at each top-layer Se site, by cubic interpolation of the gridded map
moire_sites = griddata((X.ravel(), Y.ravel()), map_2H.ravel(), (se_top[:, 1], se_top[:, 0]), method='cubic')

In [None]:
# Left: map_2H with the top-layer Se sites on top.  Right: density histogram
# of the whole map (bars) against the values at the Se sites (step line).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
im1 = ax1.imshow(map_2H, cmap='inferno')
ax1.scatter(se_top[:, 1], se_top[:, 0], s=20, c='orange', linewidths=0.5, edgecolors='k', label='Se Sites')
fig.colorbar(im1, ax=ax1)
ax1.axis('off')

# Histogram calculations
# shared bin edges so both histograms are directly comparable
binedges = np.histogram_bin_edges(np.concatenate([moire_sites.ravel(), map_2H.ravel()]), bins=50)
hist_moire_sites_top, _ = np.histogram(moire_sites.ravel(), bins=binedges, density=False)
hist_map_2H_top, _ = np.histogram(map_2H.ravel(), bins=binedges, density=False)

# counts -> probability density (each histogram integrates to 1)
bin_widths = np.diff(binedges)
hist_moire_sites_density_top = hist_moire_sites_top / (np.sum(hist_moire_sites_top) * bin_widths)
hist_map_2H_density_top = hist_map_2H_top / (np.sum(hist_map_2H_top) * bin_widths)

# Create colormap for Se_top
norm_top = Normalize(vmin=np.min(map_2H), vmax=np.max(map_2H))
colors_top = cm.inferno(norm_top(binedges[:-1]))
# NOTE(review): bars and step are positioned at the LEFT bin edges with the
# default centred alignment, i.e. drawn half a bin left of the true bin centres.
ax2.bar(binedges[:-1], hist_map_2H_density_top, width=bin_widths, color=colors_top, edgecolor='k', alpha=1, label='map_ravel')
ax2.step(binedges[:-1], hist_moire_sites_density_top, where='mid', color='#E69F00', linewidth=3, label='Se_site_distribution_top')

ax2.legend()
# ax2.set_xlim(0, 1)
ax2.set_yticks([])

plt.tight_layout()
plt.show()


In [None]:
results = {}
moire_sites_total = []
map_2H_total = []

# Grid on which the distance maps are evaluated -- identical for every image,
# so it is built once instead of once per loop iteration.
grid_density = 512
x = np.linspace(0, 512, grid_density)
y = np.linspace(0, 512, grid_density)
X, Y = np.meshgrid(x, y)
# Centroids are stored as (row, col), so the query points must use the same
# order, i.e. (Y, X).  With (X, Y) every map comes out transposed relative to
# the image and to the (x = col, y = row) Se sites it is sampled at.
points = np.vstack((Y.ravel(), X.ravel())).T

# Helper: per grid point, distance between the nearest layer1 site and the
# nearest layer2 site (in pixels), reshaped to the grid
def compute_distances(query_tree1, query_tree2, layer1, layer2):
    distances = np.linalg.norm(layer1[query_tree1.query(points)[1]] - layer2[query_tree2.query(points)[1]], axis=1)
    distances = distances.reshape((grid_density, grid_density))
    return distances

selected_images = [0,2,3,5,6,7]
for i in selected_images:
    im = images[i]
    print(f"Processing image {i}")
    # Model Prediction
    im_tensor = torch.tensor(im.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    ensemble.predict(im_tensor, return_plot=False)
    ensemble.rearrange_ensemble()
    ensemble.vote(mode='max')
    # Threshold a copy: `probability` must keep the continuous vote values
    # instead of sharing memory with the binary masks.
    masks = threshold(ensemble.voted_prediction.copy(), ['threshold'], param = 0)
    probability = ensemble.voted_prediction

    # channel -> element: 0/3 are S, 1/4 are Se, the remaining ones W
    centroids = {'W': [], 'S': [], 'Se': []}
    for j, layer in enumerate(masks):
        if j == 0 or j == 3:
            element = 'S'
        elif j == 1 or j == 4:
            element = 'Se'
        else:
            element = 'W'

        # one centroid (row, col) per connected region of the mask
        labeled_array, num_features = label(layer)
        layer_centroids = np.array(center_of_mass(layer, labeled_array, range(1, num_features + 1)))
        centroids[element].append(layer_centroids)

    results[i] = {
        "image": im,
        "mask": masks,
        "probability": probability,
        "centroids": centroids,
    }

    # define two layers - the one with more Se gets called 'top'
    n_se_0 = len(results[i]['centroids']['Se'][0])
    n_se_1 = len(results[i]['centroids']['Se'][1])
    top_idx, bottom_idx = (1, 0) if n_se_1 >= n_se_0 else (0, 1)

    s_top = results[i]['centroids']['S'][top_idx]
    se_top = results[i]['centroids']['Se'][top_idx]
    w_top = results[i]['centroids']['W'][top_idx]
    chalc_top = np.vstack([arr for arr in [s_top, se_top] if arr.size > 0])

    s_bottom = results[i]['centroids']['S'][bottom_idx]
    se_bottom = results[i]['centroids']['Se'][bottom_idx]
    w_bottom = results[i]['centroids']['W'][bottom_idx]
    chalc_bottom = np.vstack([arr for arr in [s_bottom, se_bottom] if arr.size > 0])

    results[i]['se_top'] = se_top

    # clean up chalcogen sites that may overlap
    # NOTE(review): with k=2 the filter uses the distance to the SECOND-nearest
    # W site (column 1) -- confirm that column 0 (nearest W) was not intended.
    distance_cutoff = 10  # pixel

    tree = KDTree(w_top)
    distances, indices = tree.query(chalc_top, k=2)
    chalc_top = chalc_top[distances[:, 1] > distance_cutoff]

    tree = KDTree(w_bottom)
    distances, indices = tree.query(chalc_bottom, k=2)
    chalc_bottom = chalc_bottom[distances[:, 1] > distance_cutoff]

    # make updated trees for distance maps
    w_top_tree, w_bottom_tree = KDTree(w_top), KDTree(w_bottom)
    chalc_top_tree, chalc_bottom_tree = KDTree(chalc_top), KDTree(chalc_bottom)

    # Compute the distance maps for the four top/bottom site combinations
    w_w_distances = compute_distances(w_top_tree, w_bottom_tree, w_top, w_bottom)
    w_chalc_distances = compute_distances(w_top_tree, chalc_bottom_tree, w_top, chalc_bottom)
    chalc_w_distances = compute_distances(chalc_top_tree, w_bottom_tree, chalc_top, w_bottom)
    chalc_chalc_distances = compute_distances(chalc_top_tree, chalc_bottom_tree, chalc_top, chalc_bottom)

    # Store all distances for further use
    distance_arrays = [w_w_distances, w_chalc_distances, chalc_w_distances, chalc_chalc_distances]
    # combined map: quadrature sum of the W-chalc and chalc-W distances, scaled by pixel_size
    map_2H = np.sqrt(distance_arrays[1]**2 + distance_arrays[2]**2)
    map_2H = map_2H * pixel_size
    map_2H_total.append(map_2H)

    # value of the map at each top-layer Se site (x = col, y = row)
    moire_sites = griddata((X.ravel(), Y.ravel()), map_2H.ravel(), (se_top[:, 1], se_top[:, 0]), method='cubic')
    moire_sites_total.append(moire_sites)

    # store these now
    results[i]['map_2H'] = map_2H
    results[i]['moire_sites'] = moire_sites
    results[i]['distance_arrays'] = distance_arrays

    # Plot results
    plt.figure()
    plt.imshow(results[i]['image'])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
    im1 = ax1.imshow(map_2H, cmap='inferno')
    ax1.scatter(se_top[:, 1], se_top[:, 0], s=20, c='orange', linewidths=0.5, edgecolors='k', label='Se Sites')
    fig.colorbar(im1, ax=ax1)
    ax1.axis('off')

    # Histogram calculations (shared bin edges, counts converted to densities)
    binedges = np.histogram_bin_edges(np.concatenate([moire_sites.ravel(), map_2H.ravel()]), bins=50)
    hist_moire_sites_top, _ = np.histogram(moire_sites.ravel(), bins=binedges, density=False)
    hist_map_2H_top, _ = np.histogram(map_2H.ravel(), bins=binedges, density=False)

    bin_widths = np.diff(binedges)
    hist_moire_sites_density_top = hist_moire_sites_top / (np.sum(hist_moire_sites_top) * bin_widths)
    hist_map_2H_density_top = hist_map_2H_top / (np.sum(hist_map_2H_top) * bin_widths)

    # Create colormap for Se_top
    norm_top = Normalize(vmin=np.min(map_2H), vmax=np.max(map_2H))
    colors_top = cm.inferno(norm_top(binedges[:-1]))
    ax2.bar(binedges[:-1], hist_map_2H_density_top, width=bin_widths, color=colors_top, edgecolor='k', alpha=1, label='map_ravel')
    ax2.step(binedges[:-1], hist_moire_sites_density_top, where='mid', color='#E69F00', linewidth=3, label='Se_site_distribution_top')
    ax2.legend()
    ax2.set_yticks([])
    plt.tight_layout()

In [None]:
# Pool the per-image Se-site values and map values into two flat arrays.
# NOTE(review): both names are rebound from lists to arrays here.
moire_sites_total = np.concatenate([np.ravel(element) for element in moire_sites_total])
map_2H_total = np.concatenate([np.ravel(element) for element in map_2H_total])

fig, ax = plt.subplots(1, 1, figsize=(6, 6))

# shared bin edges so both pooled histograms are directly comparable
binedges = np.histogram_bin_edges(np.concatenate([moire_sites_total, map_2H_total]), bins=50)
hist_moire_sites_top, _ = np.histogram(moire_sites_total, bins=binedges, density=False)
hist_map_2H_top, _ = np.histogram(map_2H_total.ravel(), bins=binedges, density=False)

# counts -> probability density (each histogram integrates to 1)
bin_widths = np.diff(binedges)
hist_moire_sites_density_top = hist_moire_sites_top / (np.sum(hist_moire_sites_top) * bin_widths)
hist_map_2H_density_top = hist_map_2H_top / (np.sum(hist_map_2H_top) * bin_widths)

# Create colormap for Se_top
norm_top = Normalize(vmin=np.min(map_2H_total), vmax=np.max(map_2H_total))
colors_top = cm.inferno(norm_top(binedges[:-1]))
ax.bar(binedges[:-1], hist_map_2H_density_top, width=bin_widths, color=colors_top, edgecolor='k', alpha=1, label='map_ravel')
ax.step(binedges[:-1], hist_moire_sites_density_top, where='mid', color='#E69F00', linewidth=3, label='Se_site_distribution_top')
ax.legend()
ax.set_yticks([])
plt.tight_layout()

In [None]:
# Kernel density estimation function
def compute_kde(data, bandwidth=0.1, n_bins=100, x_min=0, x_max=10):
    """Evaluate a kernel density estimate of ``data`` on ``n_bins`` evenly
    spaced points between ``x_min`` and ``x_max``.

    Returns ``(density, grid)``.
    NOTE(review): KernelDensity is not imported in the visible cells --
    presumably scikit-learn's sklearn.neighbors.KernelDensity; confirm.
    """
    grid = np.linspace(x_min, x_max, n_bins)
    estimator = KernelDensity(bandwidth=bandwidth)
    estimator.fit(data[:, None])  # samples as a single column
    # score_samples yields log-densities; exponentiate to get densities
    density = np.exp(estimator.score_samples(grid[:, None]))
    return density, grid

# Bootstrap KDE
def bootstrap_kde(data, n_bootstrap=1000, bandwidth=0.1, n_bins=100, x_min=0, x_max=10, rng=None):
    """Bootstrap a KDE curve to get a point-wise mean and spread.

    `data` is resampled with replacement `n_bootstrap` times; each resample
    (same length as `data`) is passed to `compute_kde`, and the resulting
    curves are averaged point-wise.

    Parameters
    ----------
    data : 1-D array of samples.
    n_bootstrap : number of bootstrap resamples; must be at least 1.
    bandwidth, n_bins, x_min, x_max : forwarded unchanged to `compute_kde`.
    rng : optional `numpy.random.Generator` used for resampling. When omitted
        the global `np.random` state is used, as before, so existing callers
        are unaffected; pass a seeded generator for reproducible bands.

    Returns
    -------
    (kde_mean, kde_std, x_d) : point-wise mean and standard deviation over
        the bootstrap curves, and the evaluation grid from `compute_kde`.

    Raises
    ------
    ValueError : if `n_bootstrap` is smaller than 1 (previously this surfaced
        as NaN warnings followed by an unbound-variable error on `x_d`).
    """
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap}")

    draw = np.random.choice if rng is None else rng.choice

    kde_samples = np.zeros((n_bootstrap, n_bins))
    x_d = None
    for i in range(n_bootstrap):
        resample = draw(data, size=len(data), replace=True)
        # The grid is identical on every iteration; the last one is returned
        kde_samples[i], x_d = compute_kde(resample, bandwidth, n_bins, x_min, x_max)

    kde_mean = np.mean(kde_samples, axis=0)
    kde_std = np.std(kde_samples, axis=0)
    return kde_mean, kde_std, x_d

# KDE settings shared by both distributions
n_bins = 100
bandwidth = 0.1

# Work on copies of the pooled sample arrays
combined_atom_histogram = moire_sites_total.copy()
combined_2H_map_histogram = map_2H_total.copy()

# Bootstrapped KDE of each distribution; x_min / x_max stay at the function
# defaults, so both curves are evaluated on the same grid
kde_atom_mean, kde_atom_std, x_d = bootstrap_kde(
    combined_atom_histogram,
    n_bootstrap=10,
    bandwidth=bandwidth,
    n_bins=n_bins,
)
kde_2H_map_mean, kde_2H_map_std, _ = bootstrap_kde(
    combined_2H_map_histogram,
    n_bootstrap=10,
    bandwidth=bandwidth,
    n_bins=n_bins,
)

# x_d is the KDE evaluation grid; it is reused below as the x-axis for plotting
binedges_total = x_d
bin_widths_total = np.diff(binedges_total)


In [None]:
# Pixel size taken from the experimental data record; later cells multiply pixel
# coordinates and distance maps by it.
# NOTE(review): units are presumably Angstrom/pixel (a later cell sets 0.106 with that unit) - confirm.
pixel_size = exp_data['pixel_size']

In [None]:
# x-axis for the KDE curves (the KDE evaluation grid)
binedges_total = x_d

# Which entry of `results` to display, and the limits shared by the color
# scale of the map and the x-range of the density plot
dispaly_iamge_number = 0
lims = [0.2, 4.7]

map_2H = results[dispaly_iamge_number]['map_2H']
moire_sites = results[dispaly_iamge_number]['moire_sites']
# distance maps are scaled by the pixel size
distance_arrays = results[dispaly_iamge_number]['distance_arrays'] * pixel_size
se_top = results[dispaly_iamge_number]['se_top']

fig, ax = plt.subplots(1, 2, figsize=(18, 6), dpi=300)

# Left panel: 2H map with the se_top sites on top (column 1 -> x, column 0 -> y)
im1 = ax[0].imshow(map_2H, cmap='inferno', vmin=lims[0], vmax=lims[1])
ax[0].scatter(
    se_top[:, 1],
    se_top[:, 0],
    s=120,
    c='green',
    linewidths=1,
    edgecolors='k',
    label='Se Sites',
)
fig.colorbar(im1, ax=ax[0])
ax[0].axis('off')

# Right panel: bootstrapped KDE curves, each with a +/- one-std band
for curve_mean, curve_std, curve_color, curve_name in (
    (kde_2H_map_mean, kde_2H_map_std, 'purple', '2H Map Density'),
    (kde_atom_mean, kde_atom_std, 'green', 'Se Site Density'),
):
    ax[1].plot(binedges_total, curve_mean, color=curve_color, label=curve_name)
    ax[1].fill_between(
        binedges_total,
        curve_mean - curve_std,
        curve_mean + curve_std,
        color=curve_color,
        alpha=0.3,
    )

# Site density minus map density, shaded down/up to zero
kde_difference = kde_atom_mean - kde_2H_map_mean
ax[1].plot(binedges_total, kde_difference, color='grey', label='Difference', linestyle='--')
ax[1].fill_between(
    binedges_total,
    kde_difference,
    np.zeros_like(kde_difference),
    color='grey',
    alpha=0.3,
)

# ax.legend(loc='upper left')
ax[1].set_xlim(*lims)
ax[1].set_ylim(-0.13, 0.8)
ax[1].axhline(y=0, color='k', linestyle='--', linewidth=1)
# for dist in [0, 1, np.sqrt(2)]:
    # ax[1].axvline(x=dist * 3.16, ymin=0, ymax=100, color='k', linestyle='--', linewidth=3)



# Distance maps: one panel per entry of `titles`, drawn with the same color
# limits (`lims`) as the 2H map
titles = ['W-W Distances', 'W-Chalc Distances', 'Chalc-W Distances', 'Chalc-Chalc Distances']
fig, ax = plt.subplots(1, 4, figsize=(20, 5))
for i, panel_title in enumerate(titles):
    ax[i].imshow(distance_arrays[i], cmap='inferno', vmin=lims[0], vmax=lims[1])
    ax[i].set_title(panel_title)
    ax[i].axis('off')


# Side-by-side view of the selected image: se_top sites on the left, the
# per-element / per-layer centroids on the right
fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True, dpi=300)

# Left panel: image with the se_top sites (column 1 -> x, column 0 -> y)
ax[0].imshow(results[dispaly_iamge_number]['image'], cmap='gray')
ax[0].scatter(se_top[:, 1], se_top[:, 0], s=60, c='limegreen', linewidths=0.5, edgecolors='k', label='Se Sites')
ax[0].axis('off')

# Right panel: same image as background for the centroid overlay
ax[1].imshow(results[dispaly_iamge_number]['image'], cmap='gray')

# One color per (element, layer). The loop variable gets its own name so the
# `colors` dict is not overwritten while it is being iterated and is still a
# dict after the loop.
colors = {'W': ['blue', 'cyan'], 'Se': ['green', 'lime'], 'S': ['red', 'magenta']}
for element, layer_colors in colors.items():
    for idx, (color, layer) in enumerate(zip(layer_colors, results[dispaly_iamge_number]['centroids'][element])):
        # Layers without any centroid are skipped
        if layer.size > 0:
            ax[1].scatter(layer[:, 1], layer[:, 0], color=color, label=f'{element} Layer {idx + 1} Centroids', alpha=0.4, s=5, linewidths=0.5, edgecolors='k')

ax[1].axis('off')


In [None]:
# KL divergence compares probability vectors, so each KDE curve is first
# rescaled to unit sum over the evaluation grid
kde_atom_mean_normalized = kde_atom_mean / np.sum(kde_atom_mean)
kde_2H_map_mean_normalized = kde_2H_map_mean / np.sum(kde_2H_map_mean)

# NOTE(review): `entropy` is presumably scipy.stats.entropy, i.e. KL(atom || 2H map) - confirm the import
kl_divergence = entropy(
    np.ravel(kde_atom_mean_normalized),
    np.ravel(kde_2H_map_mean_normalized),
)
print(f"KL Divergence: {kl_divergence}")

In [None]:
# Mean squared error between the two unit-sum KDE curves; both normalized
# arrays must already exist (they are created in the KL-divergence cell)
mse = mean_squared_error(
    np.ravel(kde_2H_map_mean_normalized),
    np.ravel(kde_atom_mean_normalized),
)
print(f"Mean Squared Error: {mse}")

In [None]:
display_image_number = 0

# Color pair per element: index 0 is used for the first (bottom) layer,
# index 1 for the second (top) layer
colors = {'W': ['blue', 'cyan'], 'Se': ['red', 'magenta'], 'S': ['red', 'magenta']}

# Four panels that share both axes
fig, ax = plt.subplots(1, 4, figsize=(20, 5), sharex=True, sharey=True)

# Panel 0: the image without any overlay
ax[0].imshow(results[display_image_number]['image'], cmap='gray')
ax[0].axis('off')
ax[0].set_title('Original Image')

# Panels 1 and 2: centroids of one layer only (layer 0 = bottom, layer 1 = top)
for layer_index, layer_name in enumerate(('Bottom', 'Top')):
    panel_index = layer_index + 1
    ax[panel_index].imshow(results[display_image_number]['image'], cmap='gray')
    for element, color in colors.items():
        layer = results[display_image_number]['centroids'][element][layer_index]
        if layer.size > 0:
            ax[panel_index].scatter(
                layer[:, 1],
                layer[:, 0],
                color=color[layer_index],
                label=f'{element} {layer_name} Layer Centroids',
                alpha=0.6,
                s=60,
                linewidths=0.5,
                edgecolors='k',
            )
    ax[panel_index].axis('off')
    ax[panel_index].set_title(f'{layer_name} Layer Atoms')

# Panel 3: Se centroids from every layer, all in one color
ax[3].imshow(results[display_image_number]['image'], cmap='gray')
for layer in results[display_image_number]['centroids']['Se']:
    if layer.size == 0:
        continue
    ax[3].scatter(
        layer[:, 1],
        layer[:, 0],
        color='lime',
        label='Se Sites',
        alpha=0.6,
        s=60,
        linewidths=0.5,
        edgecolors='k',
    )
ax[3].axis('off')
ax[3].set_title('All Se Atoms')

plt.tight_layout()
plt.show()

### Reconstruction image for figure

In [None]:
import ase
from ase.io import read, write
from ase import Atoms, Atom
from ase.visualize import view 

elements = ['W', 'Se', 'S']
atoms = Atoms()

# Turn the detected centroids into a 3-D structure: one z-height per layer
# index, in-plane coordinates scaled from pixels by pixel_size
for element in elements:
    is_chalcogen = element in ('S', 'Se')
    for z, layer in enumerate(results[display_image_number]['centroids'][element]):
        # Layer index -> z coordinate; layers are spaced 200 * pixel_size apart
        z_height = (0.2 - z) * 200 * pixel_size
        print(f"Layer {z} for element {element}:")
        for atom in layer:
            # centroid column 1 is x, column 0 is y
            x = atom[1] * pixel_size
            y = atom[0] * pixel_size

            if not is_chalcogen:
                # W: a single atom at the layer height
                atoms.append(Atom(element, (x, y, z_height)))
                continue

            # S / Se: two atoms, shifted by +1 and -1 along z around the layer height
            atoms.append(Atom(element, (x, y, z_height + 1)))
            atoms.append(Atom(element, (x, y, z_height - 1)))

# Cell lengths = extent of the atomic positions along each axis
positions = atoms.get_positions()
min_pos = positions.min(axis=0)
max_pos = positions.max(axis=0)
cell_range = max_pos - min_pos
atoms.set_cell(cell_range)

# Save the Atoms object to a CIF file
write('atoms.cif', atoms)

In [None]:
# Inspect the reconstructed structure with ASE's `view` helper.
# NOTE(review): presumably opens an interactive viewer window - may not work on a headless / Colab runtime.
view(atoms)

In [None]:
# Reload the structure written to 'atoms.cif' and work on a copy of it
atoms = read('atoms.cif')
xtal = atoms.copy()

# Define the image size in Angstroms
# (in-plane x, y positions; only referenced by the commented-out lines below)
positions = xtal.get_positions()[:, :2]
# xmin, xmax = np.min(positions[:, 0]), np.max(positions[:, 0])
# ymin, ymax = np.min(positions[:, 1]), np.max(positions[:, 1])
borders = 0
# Field of view fixed to 512 pixels per side, using the pixel_size that is
# defined at this point of execution.
# NOTE(review): pixel_size is reassigned at the bottom of this cell, so the
# first run uses the earlier value (from exp_data) here, while a re-run uses
# 0.106 - confirm which one is intended and that both agree.
xmin, xmax = 0, 512 * pixel_size
ymin, ymax = 0, 512 * pixel_size
# (xmin, xmax, ymin, ymax), padded by `borders` on every side
axis_extent = (xmin - borders, xmax + borders, ymin - borders, ymax + borders)

pixel_size = 0.106 # Angstrom/pixel, determines number of points, aka resolution of maps.  the xtal determines the fov

In [None]:
# Pseudo-potential of the reconstructed structure on the chosen field of view;
# atom_var is handed to the generator as its `sigma` argument
atom_var = 0.22
potential = dg.get_pseudo_potential(
    xtal=xtal,
    pixel_size=pixel_size,
    sigma=atom_var,
    axis_extent=axis_extent,
)

plt.figure()
plt.imshow(potential, extent=axis_extent, cmap='gray')
plt.colorbar()

In [None]:
# Build the point spread function and convolve the potential with it
airy_disk_size = 1
psf = dg.get_point_spread_function(airy_disk_radius=airy_disk_size, size=32)

# PSF resized to the larger image dimension, for plotting on the same axes as the image
psf_resize = dg.resize_image(np.array(psf), n=max(potential.shape))

# Noise-free simulated image
perfect_image = dg.convolve_kernel(potential, psf)

plt.figure()
plt.imshow(perfect_image, extent=axis_extent, cmap='gray')
plt.colorbar()

In [None]:
# Add Poisson (shot) noise to the noise-free simulation
shot_noise = 0.7
noisy_image = dg.add_poisson_noise(perfect_image, shot_noise=shot_noise)

# Min-max rescale to [0, 1] ...
noisy_image = noisy_image - np.min(noisy_image)
# ... then smooth; the Gaussian filter is applied after the rescaling
noisy_image = gaussian_filter(noisy_image / np.max(noisy_image), sigma=3)

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# Left: noise-free simulation, autoscaled intensities
ax[0].set_title('Perfect Image')
ax[0].imshow(perfect_image, cmap='gray')

# Right: noisy, smoothed simulation on a fixed [0, 1] gray scale
ax[1].set_title('Noisy Image')
ax[1].imshow(noisy_image, vmin=0, vmax=1, cmap='gray')

In [None]:
plt.figure()

# Same binning for both histograms so the intensity distributions can be compared
shared_hist_kwargs = dict(bins=20, range=(0.0, 1.0), alpha=0.5)
plt.hist(noisy_image.ravel(), fc='k', ec='k', label='Noisy Image', **shared_hist_kwargs)
plt.hist(
    results[display_image_number]['image'].ravel(),
    fc='r',
    ec='r',
    label='Experimental Image',
    **shared_hist_kwargs,
)
plt.legend()

In [None]:
# Smoothed copy of the experimental image.
# NOTE(review): exp_image is not used below - the difference is taken against the
# unfiltered image while noisy_image has been Gaussian-filtered; confirm this is intended.
exp_image = gaussian_filter(results[display_image_number]['image'], sigma=3)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 5), sharex=True, sharey=True)

# Panel 0: experimental image, panel 1: simulated noisy image
ax[0].imshow(results[display_image_number]['image'], cmap='gray')
ax[1].imshow(noisy_image, vmin=0, vmax=1, cmap='gray')

# Panel 2: absolute pixel-wise difference, experiment minus simulation
difference = results[display_image_number]['image'] - noisy_image
diff = np.absolute(difference)
ax[2].imshow(diff, vmin=0, vmax=1, cmap='gray')



In [None]:
crop_size = 100
n_crops = 9

# Random top-left corners (x, y) of square crops that fit inside noisy_image.
# randint's upper bound is exclusive, hence the +1: the last valid offset can
# be drawn and an image exactly crop_size wide/high does not raise.
crops = [
    (
        np.random.randint(0, noisy_image.shape[1] - crop_size + 1),
        np.random.randint(0, noisy_image.shape[0] - crop_size + 1),
    )
    for _ in range(n_crops)
]

# One row for the full images plus one row per crop; three columns:
# smoothed experimental image, simulated noisy image, absolute difference.
# The grid follows n_crops (height 4.9 per row -> 15 x 49 for 9 crops);
# squeeze=False keeps `ax` two-dimensional for any row count.
n_rows = n_crops + 1
fig, ax = plt.subplots(n_rows, 3, figsize=(15, 49 * n_rows / 10), dpi=300, squeeze=False)
ax[0, 0].imshow(gaussian_filter(results[display_image_number]['image'], sigma=3), cmap='gray')
ax[0, 1].imshow(noisy_image, cmap='gray', vmin=0, vmax=1)
ax[0, 2].imshow(diff, cmap='gray', vmin=0, vmax=1)

# Plot the cropped images
for i, (x, y) in enumerate(crops):
    crop_window = (slice(y, y + crop_size), slice(x, x + crop_size))
    # NOTE(review): the experimental crop is smoothed after cropping, whereas the
    # simulated image was smoothed before cropping - edge pixels are treated differently.
    ax[i + 1, 0].imshow(
        gaussian_filter(results[display_image_number]['image'][crop_window], sigma=3),
        cmap='gray',
        vmin=0,
        vmax=1,
    )
    ax[i + 1, 1].imshow(noisy_image[crop_window], cmap='gray', vmin=0, vmax=1)
    ax[i + 1, 2].imshow(diff[crop_window], cmap='gray', vmin=0, vmax=1)

for a in ax.flat:
    a.axis('off')
plt.tight_layout()
