# Gomb - Net
### Performance test on the Graphene dataset
### And analysis of Graphene experimental data

Austin Houston

[![OpenInColab](https://colab.research.google.com/assets/colab-badge.svg)](
    https://colab.research.google.com/github/ahoust17/Gomb-Net/blob/main/Eval_Graphene_model.ipynb)

In [None]:
# basics
import os
import sys
import numpy as np

# plotting
import matplotlib.pylab as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
from matplotlib import cm

# colab interactive plots and drive
drive = False
if 'google.colab' in sys.modules:
    from  google.colab import drive 
    from google.colab import output
    drive.mount('/content/drive')
    output.enable_custom_widget_manager()
    drive = True
else:
    %matplotlib widget

# other imports
from scipy.ndimage import label, center_of_mass, gaussian_filter, zoom, uniform_filter
from scipy.spatial import KDTree
from scipy.interpolate import griddata
from scipy.stats import norm, gaussian_kde
from skimage.filters import threshold_otsu
from skimage.feature import blob_log

# for cropping function
if drive:
    # Colab only: install ASE, clone DataGenSTEM and put its module directory on the import path
    print('installing DataGenSTEM')
    !pip install ase
    !git clone https://github.com/ahoust17/DataGenSTEM.git
    sys.path.append('./DataGenSTEM/DataGenSTEM')
    # NOTE(review): `dg` is only bound on this Colab branch, but `dg.shotgun_crop` is called
    # unconditionally later in the notebook; a local run presumably needs its own import — confirm.
    import data_generator as dg

# for Gomb-Net
if drive:
    # Colab only: clone Gomb-Net and extend the import path
    print('installing Gomb-Net')
    !git clone https://github.com/ahoust17/Gomb-Net.git
    sys.path.append('./Gomb-Net/GombNet')    
# Star imports: names used below that are not defined in this notebook
# (e.g. get_dataloaders, TwoLeggedUnet, F) presumably come from these modules — TODO confirm.
from GombNet.networks import *
from GombNet.loss_func import GombinatorialLoss
from GombNet.utils import *

import torch
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("CUDA is available. Using GPU." if use_cuda else "CUDA is not available. Using CPU.")

In [None]:
# set mps, just for my computer
# device = torch.device('mps')

### Now, you need to add the following shared drive to your google drive:
**WARNING: it is a big file. Check before you download.**


https://drive.google.com/drive/folders/1tDF283xry5op3t594oBUlcNLKbjRTV7C?usp=sharing

### Run the following cell after the download is complete

In [None]:
# should be something like 'content/drive/My Drive/Gomb-Net files'
if drive:
    # Colab: relative path below the Drive mount point
    # (assumes the working directory is the parent of the mounted `drive` folder — TODO confirm)
    shared_folder = 'drive/My Drive/Gomb-Net files'
else:
    # local run: machine-specific absolute path; change it to wherever the shared folder was downloaded
    shared_folder = '/Users/austin/Desktop/Gomb-Net aux files'

# list the folder contents as a quick check that the path is correct
print('available files & directories:')
!ls '{shared_folder}'

### Now, on to running the code

let's look at the dataset:

In [None]:
# Build the train / validation / test dataloaders for the graphene dataset
dataset_root = shared_folder + '/Graphene_dataset'
images_dir = dataset_root + '/images'
labels_dir = dataset_root + '/labels'
train_loader, val_loader, test_loader = get_dataloaders(
    images_dir,
    labels_dir,
    batch_size=1,
    val_split=0.2,
    test_split=0.1,
    seed=42,
)


In [None]:
# Index of the test-set sample to display
test_iter = 3
# Fetch the sample once so the image and its labels come from the same dataset access
sample = test_loader.dataset[test_iter]
test = sample[0].unsqueeze(0)
gt = sample[1]

# Left panel: input image; remaining panels: one ground-truth channel each
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
ax[0].imshow(test[0, 0].cpu().numpy(), cmap='gray')
ax[0].set_title('Input')

titles = ['L1: C', 'L2: C']
for panel, layer, title in zip(ax[1:], gt, titles):
    panel.imshow(layer.cpu().numpy(), cmap='gray')
    panel.set_title(title)
for a in ax:
    a.axis('off')
fig.tight_layout()

now let's look at the model:

In [None]:
# Network hyper-parameters
input_channels = 1
num_classes = 2
num_filters = [32, 64, 128, 256]

# Model skeleton, optimizer and loss
model = TwoLeggedUnet(input_channels, num_classes, num_filters, dropout=0.2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss = GombinatorialLoss(
    group_size=num_classes // 2,
    loss='Dice',
    epsilon=1e-6,
    class_weights=None,
    alpha=2,
)

In [None]:
# Get the number of trainable parameters
def count_trainable_parameters(model):
    """Return the total number of elements in the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
num_trainable_params = count_trainable_parameters(model)
print(f"Number of trainable parameters: {num_trainable_params}")

visualize the training history for the pretrained model:

In [None]:
# Load the stored training / validation loss curves of the pretrained model
history_path = shared_folder + '/Pretrained_models/Graphene_model_loss_history.npz'
loss_history = np.load(history_path)
train_loss = loss_history['train_loss_history']
val_loss = loss_history['val_loss_history']

plt.figure(figsize=(6, 4))
for curve, name, colour in (
    (train_loss, 'training', '#1f77b4'),
    (val_loss, 'validation', '#d62728'),
):
    plt.plot(curve, label=name, color=colour)
plt.xlabel('Epoch')
plt.xlim(0, 30)
plt.legend(title='Losses')
plt.tight_layout()

load in the pretrained weights onto our model 'skeleton'

In [None]:
model_path = shared_folder + '/Pretrained_models/Graphene_model.pth'

# Restore the network weights and the optimizer state, then switch to evaluation mode
checkpoint = torch.load(model_path, map_location=device)
for target, key in ((model, 'model_state_dict'), (optimizer, 'optimizer_state_dict')):
    target.load_state_dict(checkpoint[key])
model.eval()

In [None]:
# Index of the test-set sample used for the prediction below
test_iter = 0

# Fetch the sample once so the image and its labels come from the same dataset access
sample = test_loader.dataset[test_iter]
test = sample[0].unsqueeze(0)  # add a leading batch dimension
gt = sample[1]

In [None]:
# Generate prediction
with torch.no_grad():
    #test.to(device)
    probability = model(test)  # raw network output
    # torch.sigmoid instead of F.sigmoid: `F` is never imported explicitly in this
    # notebook, and torch.sigmoid matches the experimental-data cell further down
    prediction = torch.sigmoid(probability)
probability = probability.squeeze().cpu().numpy()
prediction = prediction.squeeze().cpu().numpy()

# Binarize with one Otsu threshold computed over the whole prediction array
threshold = threshold_otsu(prediction)
prediction = (prediction > threshold).astype(float)


In [None]:
# Show the network input together with its intensity scale
input_image = test.squeeze().cpu().numpy()
plt.figure()
plt.imshow(input_image, cmap='gray')
plt.colorbar()

In [None]:
# Plotting: rows are ground truth, thresholded prediction and raw network output;
# one column per output channel.  squeeze=False keeps `axs` two-dimensional even
# for a single channel.
fig, axs = plt.subplots(3, num_classes, dpi=300, sharex=True, sharey=True, squeeze=False)

for i in range(num_classes):
    axs[0, i].imshow(gt[i], cmap='gray')
    axs[1, i].imshow(prediction[i], cmap='gray')
    # first channel in plasma, remaining channels in viridis
    axs[2, i].imshow(probability[i], cmap='plasma' if i == 0 else 'viridis')

# Hide ticks and frames but leave the axes switched on: axis('off') would also
# hide the row labels set below.
for ax in axs.ravel():
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

axs[0, 0].set_ylabel('Ground Truth')
axs[1, 0].set_ylabel('Prediction')
axs[2, 0].set_ylabel('Probability')

fig.tight_layout()

### Comparison to blob-finder

In [None]:
image_number = 0
# try a regular blob finder for comparison
blob_im = test.squeeze().cpu().numpy()

# each row of `blobs` is (row, col, sigma)
blobs = blob_log(blob_im, min_sigma=1, max_sigma=20, num_sigma=5, threshold=0.1)

# Refine every detection to the intensity-weighted centre of mass of a small window
# around it.  center_of_mass expects a label array and label ids as its 2nd/3rd
# arguments, so blob coordinates must not be passed there.
blobs_com = []
for row, col, sigma in blobs:
    radius = int(np.ceil(np.sqrt(2) * sigma))  # approximate blob radius for a LoG detection
    r0 = max(int(row) - radius, 0)
    c0 = max(int(col) - radius, 0)
    window = blob_im[r0:int(row) + radius + 1, c0:int(col) + radius + 1]
    d_row, d_col = center_of_mass(window)
    blobs_com.append((r0 + d_row, c0 + d_col))

plt.figure()
plt.imshow(blob_im, cmap='gray')
plt.scatter(blobs[:, 1], blobs[:, 0], c='r', s=20)
plt.axis('off')

In [None]:
# The following code is commented out because it takes a few minutes to run. It only measures network performance metrics across the test dataset.

In [None]:
# pwa_total = 0
# dice_total = 0
# IoU_total = 0
# 
# def iou(pred, gt):
#     intersection = np.logical_and(pred, gt).sum()
#     union = np.logical_or(pred, gt).sum()
#     return intersection / union
# 
# def dice_coefficient(pred, gt):
#     intersection = np.logical_and(pred, gt).sum()
#     return 2 * intersection / (pred.sum() + gt.sum())
# 
# # Calculate the accuracy
# for i in range(len(test_loader)):
#     test = test_loader.dataset[i][0].unsqueeze(0)
#     gt = test_loader.dataset[i][1].numpy()  # Convert to numpy array
#     
#     # Switch ground truth layers
#     gt_switched = np.flip(gt, axis=0)
#     
#     with torch.no_grad():
#         probability = model(test)
#         prediction = torch.sigmoid(probability)  # Use torch.sigmoid instead of F.sigmoid (deprecated)
#     
#     probability = probability.squeeze().cpu().numpy()
#     prediction = prediction.squeeze().cpu().numpy()
# 
#     threshold = threshold_otsu(prediction)
#     prediction = (prediction > threshold).astype(float)
#     
#     # Calculate metrics for original and switched ground truths
#     pwa_original = np.sum(prediction == gt) / np.prod(gt.shape)
#     pwa_switched = np.sum(prediction == gt_switched) / np.prod(gt_switched.shape)
#     
#     dice_original = dice_coefficient(prediction, gt)
#     dice_switched = dice_coefficient(prediction, gt_switched)
#     
#     iou_original = iou(prediction, gt)
#     iou_switched = iou(prediction, gt_switched)
#     
#     # Take the highest value for each metric
#     pwa_total += max(pwa_original, pwa_switched)
#     dice_total += max(dice_original, dice_switched)
#     IoU_total += max(iou_original, iou_switched)
# 
# # Calculate the average for each metric
# pwa_total /= len(test_loader)
# dice_total /= len(test_loader)
# IoU_total /= len(test_loader)
# 
# print(f"Pixel-wise Accuracy: {pwa_total}")
# print(f"Mean Dice Coefficient: {dice_total}")
# print(f"Mean IoU: {IoU_total}")

### Now, on Experimental data:

In [None]:
# Load the experimental moire image and its pixel size
moire_path = shared_folder + '/Experimental_datasets/moire.npz'
exp_data = np.load(moire_path)
im_array = exp_data['im_array']
pixel_size = exp_data['pixel_size']

# Rescale intensities to the [0, 1] range
# im_array = gaussian_filter(im_array, sigma=1)
im_array = im_array - im_array.min()
im_array = im_array / im_array.max()
# zoom_factor = 0.7
#
# im_array = dg.resize_image(im_array, zoom_factor * 512)

print(f"Pixel size: {pixel_size.astype(float)} m/pix")
plt.figure()
plt.imshow(im_array, cmap='gray')

In [None]:
# Request 5 crops with crop_size=256 from the normalized experimental image.
# NOTE(review): `dg` is only imported inside the `if drive:` setup branch, so this presumably
# raises NameError on a non-Colab run unless data_generator is imported separately — confirm.
images = dg.shotgun_crop(im_array, crop_size=256, n_crops = 5, roi = None)

In [None]:
# One figure row per crop: input | mask 0 | mask 1 | raw output 0 | raw output 1.
# The number of rows and the mask size follow `images` instead of being hard-coded.
n_images = len(images)
fig, ax = plt.subplots(n_images, 5, figsize=(10, 2 * n_images), dpi=300, squeeze=False)
masks = np.zeros((n_images, 2) + tuple(np.shape(images[0])))

for i, im in enumerate(images):
    # make nn prediction (add batch and channel dimensions first)
    im = torch.tensor(im.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        probability = model(im)
        prediction = torch.sigmoid(probability)
    probability = probability.squeeze().cpu().numpy()
    prediction = prediction.squeeze().cpu().numpy()

    # binarize with one Otsu threshold computed over both channels
    threshold = threshold_otsu(prediction)
    prediction = (prediction > threshold).astype(float)
    masks[i] = prediction

    # plot
    ax[i, 0].imshow(im.squeeze().numpy(), cmap='gray')
    ax[i, 1].imshow(prediction[0], cmap='gray')
    ax[i, 2].imshow(prediction[1], cmap='gray')
    ax[i, 3].imshow(probability[0], cmap='plasma')
    ax[i, 4].imshow(probability[1], cmap='viridis')

for a in ax.ravel():
    a.axis('off')
fig.tight_layout()

#plt.savefig('moire_segmentation.png')

In [None]:
# Upsampling applied to every mask before centroid detection
resize_factor = 4
zoom_order = 3
dist_hist = []
colors = ['#d62728','#1f77b4']
for i, (image, mask) in enumerate(zip(images, masks)):
    # iterate over the mask channels only; no dependence on axes left over from another cell
    for layer in mask:
        # resize layer, smooth, and binarize again
        layer = zoom(layer, resize_factor, order=zoom_order)
        layer = gaussian_filter(layer, sigma=1)
        threshold = threshold_otsu(layer)
        layer = (layer > threshold).astype(float)

        plt.figure()
        plt.imshow(layer, cmap='gray')

        # one centroid per connected region
        labeled_array, num_features = label(layer)
        if num_features < 2:
            # no neighbour distances can be measured with fewer than two regions
            continue
        centroids = np.array(center_of_mass(layer, labeled_array, range(1, num_features + 1)))

        # Distances to the two nearest neighbours (column 0 is the point itself).
        # Values are in units of `pixel_size` on the upsampled grid; the histogram
        # cell below applies the `1e10 / resize_factor` rescaling.
        tree = KDTree(centroids)
        distances, indices = tree.query(centroids, k=3)
        nearest_distances = (distances[:, 1:] * float(pixel_size)).flatten()
        # KDTree pads missing neighbours with inf; keep real distances only
        dist_hist.append(nearest_distances[np.isfinite(nearest_distances)])

    #plt.savefig(f'atoms_{i}.png', transparent=True, bbox_inches='tight', pad_inches=0)


In [None]:
# make dist_hist a 1D array
# Merge the per-layer distance arrays into one 1D array and rescale:
# factor 1e10, divided by the upsampling factor used before centroid detection
dist_hist = np.concatenate(dist_hist) * 1e10 / resize_factor
avg_dist = np.mean(dist_hist)

label_size = 24
plt.figure(figsize=(8, 8), dpi=300)
plt.hist(dist_hist, bins=100, color='gray')

# reference bond length (solid) and measured mean (dashed)
plt.vlines(1.42, 0, 1000, color='k', label='C-C bond 1.42 Å')
plt.vlines(avg_dist, 0, 1000, color='k', linestyle='--', label=f'Avg. bond {avg_dist:.2f} Å')

plt.xlabel('Distance (Å)', fontsize=label_size)
plt.xlim(0.8, 2)
plt.ylim(0, 850)
plt.legend(loc='upper right', fontsize=label_size)
plt.xticks(fontsize=label_size)
plt.yticks([])
plt.tight_layout()

In [None]:
print(f"Average distance: {avg_dist:.2f} Å")

# FWHM: distance between the first and last histogram bin centres whose count
# exceeds half of the peak count
hist, bin_edges = np.histogram(dist_hist, bins=100)
half_max = np.max(hist) / 2
bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
above_half = np.flatnonzero(hist > half_max)
fwhm = bin_centers[above_half[-1]] - bin_centers[above_half[0]]
print(f"FWHM: {fwhm:.2f} Å")

In [None]:
pixel_size

### Below is some old code not used in the paper


In [None]:
def crop_image(image, crop_size, stride):
    """Cut a 2D image into square crops on a regular grid.

    Crops of side `crop_size` are taken every `stride` pixels along both axes,
    in row-major order; only crops that fit entirely inside the image are kept.

    Returns:
        (crops, coords): `crops` is an array stacking all crops, `coords` the
        list of (row, col) top-left corners in the same order.
    """
    row_starts = range(0, image.shape[0] - crop_size + 1, stride)
    col_starts = range(0, image.shape[1] - crop_size + 1, stride)
    coords = [(top, left) for top in row_starts for left in col_starts]
    crops = [image[top:top + crop_size, left:left + crop_size] for top, left in coords]
    return np.array(crops), coords

def reconstruct_image(crops, coords, image_shape, crop_size, stride):
    """Stitch overlapping crops back into one image by averaging the overlaps.

    Args:
        crops: iterable of 2D arrays of side `crop_size`.
        coords: (row, col) top-left corner of every crop, in the same order.
        image_shape: shape of the output image.
        crop_size: side length of each crop.
        stride: kept for interface symmetry with `crop_image`; not used here.

    Returns:
        float32 array of shape `image_shape`.  Each pixel is the mean of all crops
        covering it; pixels covered by no crop are 0 (instead of NaN from 0/0).
    """
    reconstructed = np.zeros(image_shape, dtype=np.float32)
    count_map = np.zeros(image_shape, dtype=np.float32)

    for crop, (i, j) in zip(crops, coords):
        reconstructed[i:i+crop_size, j:j+crop_size] += crop
        count_map[i:i+crop_size, j:j+crop_size] += 1

    # divide only where at least one crop contributed; uncovered pixels keep their 0
    np.divide(reconstructed, count_map, out=reconstructed, where=count_map > 0)
    return reconstructed

In [None]:
crop_size = 256
stride = 16 
# number of crops pushed through the network at once (bounds peak memory)
inference_batch_size = 32

# load data
# NOTE(review): this path differs from the `shared_folder + '/Experimental_datasets/moire.npz'`
# location used earlier in the notebook — presumably the same file; confirm.
exp_data = np.load('./Exp_data/moire.npz')
im_array = exp_data['im_array']
pixel_size = exp_data['pixel_size']
pixel_size = float(pixel_size) * 1e10  # Convert to angstroms

# rescale intensities to [0, 1]
im_array = im_array - np.min(im_array)
im_array = im_array / np.max(im_array)

# pad input image by half a crop on every side (periodic wrap)
im_array = np.pad(im_array, ((crop_size//2, crop_size//2), (crop_size//2, crop_size//2)), mode='wrap')

crops, coords = crop_image(im_array, crop_size, stride)
crops = crops.astype(np.float32)
crops = torch.tensor(crops).unsqueeze(1)  # add the channel dimension

# make prediction in mini-batches instead of one pass over the whole crop stack
outputs = []
with torch.no_grad():
    for start in range(0, crops.shape[0], inference_batch_size):
        outputs.append(model(crops[start:start + inference_batch_size]).cpu())
probability = torch.cat(outputs).numpy()

# split the two output channels and average the overlapping crops back into full maps
prob_A = probability[:, 0, :, :]
prob_B = probability[:, 1, :, :]

map_A = reconstruct_image(prob_A, coords, im_array.shape, crop_size, stride)
map_B = reconstruct_image(prob_B, coords, im_array.shape, crop_size, stride)


In [None]:
# de-pad the image and maps
# Remove the half-crop padding from the image and both maps.  The stop index is
# written as `size - pad`: `-crop_size//2` parses as `(-crop_size)//2`, which removes
# one pixel too many for odd crop sizes.
pad = crop_size // 2
n_rows, n_cols = im_array.shape
inner = (slice(pad, n_rows - pad), slice(pad, n_cols - pad))
im_array = im_array[inner]
map_A = map_A[inner]
map_B = map_B[inner]

In [None]:
# Input image next to the two reconstructed output maps
fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi=300)
panels = zip(ax, (im_array, map_A, map_B), ('gray', 'plasma', 'viridis'))
for a, data, cmap_name in panels:
    a.imshow(data, cmap=cmap_name)
    a.axis('off')
plt.tight_layout()

In [None]:
# histograms of the maps
# Overlaid value histograms of the two maps
plt.figure(figsize=(8, 8))
for values, colour, name in ((map_A, 'purple', 'Class A'), (map_B, 'green', 'Class B')):
    plt.hist(values.ravel(), bins=100, color=colour, alpha=0.5, label=name);

In [None]:
dist_hist = []
colors = ['#d62728','#1f77b4']
for layer, color in zip([map_A, map_B], colors):
    # binarize: keep pixels with a positive map value
    layer = (layer > 0).astype(float)

    fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, dpi=300, figsize=(14, 8))
    mask_panel, atom_panel = ax[0], ax[1]
    mask_panel.imshow(layer, cmap='gray')

    # one centroid per connected region of the mask
    labeled_array, num_features = label(layer)
    region_ids = range(1, num_features + 1)
    centroids = np.array(center_of_mass(layer, labeled_array, region_ids))

    # distances from every centroid to its two nearest neighbours
    # (column 0 of the query result is the point itself)
    tree = KDTree(centroids)
    distances, indices = tree.query(centroids, k=3)
    nearest_distances = distances[:, 1:] * float(pixel_size)
    dist_hist.append(nearest_distances.flatten())

    # white background with the centroids drawn in the layer colour
    atom_panel.imshow(np.ones_like(layer), vmax=1, vmin=0, cmap='gray')
    atom_panel.scatter(centroids[:, 1], centroids[:, 0], s=30, c=color, edgecolors='k', linewidths=0.5)
    atom_panel.set_xlim(0, layer.shape[1])
    atom_panel.set_ylim(layer.shape[0], 0)  # Flip the y-axis to match image coordinates

    for a in ax:
        a.axis('off')

In [None]:
# NOTE(review): `image` is the loop variable left over from the earlier
# `for image, mask in zip(images, masks)` cell; this presumably gives the physical
# side length of one crop — confirm.
pixel_size * len(image)

In [None]:
# Label the connected positive regions of each map and take the map-weighted
# centre of mass of every region
labeled_array_A, num_features_A = label(map_A > 0)
centroids_A = np.array(center_of_mass(map_A, labeled_array_A, range(1, num_features_A + 1)))

labeled_array_B, num_features_B = label(map_B > 0)
centroids_B = np.array(center_of_mass(map_B, labeled_array_B, range(1, num_features_B + 1)))

# Overlay each set of centroids on the input image
fig, ax = plt.subplots(1, 2, figsize=(12, 8), sharex=True, sharey=True)
for a, points, colour in zip(ax, (centroids_A, centroids_B), ('purple', 'green')):
    a.imshow(im_array, cmap='gray')
    a.scatter(points[:, 1], points[:, 0], s=60, c=colour, edgecolors='k', linewidths=0.5, alpha=0.8)
    a.axis('off')
plt.tight_layout()


In [None]:
### Strain map of both layers


In [None]:

# Create KDTree and find the distances to the nearest 3 centroids
tree_A = KDTree(centroids_A)
distances_A, _ = tree_A.query(centroids_A, k=4)  # k=4 because the closest point will be itself

tree_B = KDTree(centroids_B)
distances_B, _ = tree_B.query(centroids_B, k=4)

# Bond lengths to the 3 nearest neighbours (pixels), excluding the atom itself
bonds_A = distances_A[:, 1:4]
bonds_B = distances_B[:, 1:4]

# Per-atom mean bond length
average_distance_A = np.mean(bonds_A, axis=1)
average_distance_B = np.mean(bonds_B, axis=1)

# Strain as deviation from the layer-wide mean bond length.  Using each atom's own
# mean as the reference would make the per-atom average identically zero, so the
# colour scale would only show floating-point noise.
reference_distance_A = np.mean(bonds_A)
reference_distance_B = np.mean(bonds_B)
strain_A = bonds_A - reference_distance_A
strain_B = bonds_B - reference_distance_B

# Flatten the strain arrays for plotting
strain_A_flat = strain_A.flatten()
strain_B_flat = strain_B.flatten()

# Plot the strain maps
fig, ax = plt.subplots(1, 2, figsize=(12, 8), sharex=True, sharey=True)
ax[0].imshow(im_array, cmap='gray')
ax[1].imshow(im_array, cmap='gray')

# Scatter plot coloured by the per-atom mean strain
sc_A = ax[0].scatter(centroids_A[:, 1], centroids_A[:, 0], c=np.mean(strain_A, axis=1), cmap='viridis', s=60, edgecolors='k', linewidths=0.5)
sc_B = ax[1].scatter(centroids_B[:, 1], centroids_B[:, 0], c=np.mean(strain_B, axis=1), cmap='viridis', s=60, edgecolors='k', linewidths=0.5)

# Colorbar
cbar_A = plt.colorbar(sc_A, ax=ax[0], fraction=0.046, pad=0.04)
cbar_A.set_label('Strain Magnitude')

cbar_B = plt.colorbar(sc_B, ax=ax[1], fraction=0.046, pad=0.04)
cbar_B.set_label('Strain Magnitude')

for a in ax:
    a.axis('off')

plt.tight_layout()

In [None]:

# Create KDTree and find the distances to the nearest 3 centroids
tree_A = KDTree(centroids_A)
distances_A, _ = tree_A.query(centroids_A, k=4)  # k=4 because the closest point will be itself

tree_B = KDTree(centroids_B)
distances_B, _ = tree_B.query(centroids_B, k=4)

# Bond lengths to the 3 nearest neighbours (pixels), excluding the atom itself
bonds_A = distances_A[:, 1:4]
bonds_B = distances_B[:, 1:4]

# Per-atom mean bond length
average_distance_A = np.mean(bonds_A, axis=1)
average_distance_B = np.mean(bonds_B, axis=1)

# Strain as deviation from the layer-wide mean bond length.  Using each atom's own
# mean as the reference would make the per-atom average identically zero, leaving
# nothing but floating-point noise to interpolate and correlate.
strain_A = bonds_A - np.mean(bonds_A)
strain_B = bonds_B - np.mean(bonds_B)

# Average strain per centroid
average_strain_A = np.mean(strain_A, axis=1)
average_strain_B = np.mean(strain_B, axis=1)

# Create a grid for interpolation (first index: row, second index: column)
grid_x, grid_y = np.mgrid[0:512, 0:512]

# Interpolate strain values over a 512x512 grid; points outside the convex hull of the centroids get 0
strain_map_A = griddata(centroids_A, average_strain_A, (grid_x, grid_y), method='cubic', fill_value=0)
strain_map_B = griddata(centroids_B, average_strain_B, (grid_x, grid_y), method='cubic', fill_value=0)

# Plot the interpolated strain maps
fig, ax = plt.subplots(1, 2, figsize=(12, 8), sharex=True, sharey=True)
cax_A = ax[0].imshow(strain_map_A, cmap='viridis', origin='lower')
cax_B = ax[1].imshow(strain_map_B, cmap='viridis', origin='lower')

ax[0].set_title('Strain Map A')
ax[1].set_title('Strain Map B')

# Colorbars
cbar_A = plt.colorbar(cax_A, ax=ax[0], fraction=0.046, pad=0.04)
cbar_A.set_label('Strain Magnitude')

cbar_B = plt.colorbar(cax_B, ax=ax[1], fraction=0.046, pad=0.04)
cbar_B.set_label('Strain Magnitude')

plt.tight_layout()
plt.show()

# Compute correlation between the two strain maps
correlation = np.corrcoef(strain_map_A.flatten(), strain_map_B.flatten())[0, 1]
print(f'Correlation between the strain maps: {correlation:.4f}')

In [None]:
# Define a function to compute local correlation
def local_correlation(x, y, window_size):
    x_mean = uniform_filter(x, window_size)
    y_mean = uniform_filter(y, window_size)
    x2_mean = uniform_filter(x**2, window_size)
    y2_mean = uniform_filter(y**2, window_size)
    xy_mean = uniform_filter(x * y, window_size)
    
    covariance = xy_mean - x_mean * y_mean
    variance_x = x2_mean - x_mean**2
    variance_y = y2_mean - y_mean**2
    
    correlation_map = covariance / np.sqrt(variance_x * variance_y)
    return correlation_map

# Local correlation between the two interpolated strain maps
window_size = 21  # You can adjust the window size
correlation_map = local_correlation(strain_map_A, strain_map_B, window_size)

# Show the magnitude of the local correlation
abs_correlation = np.abs(correlation_map)
plt.figure(figsize=(6, 6))
correlation_image = plt.imshow(abs_correlation, cmap='coolwarm', origin='lower')
plt.colorbar(correlation_image, label='Local Correlation')
plt.title('Local Correlation Map')
plt.show()

In [None]:
# Histogram of the finite local-correlation values: the histogram range cannot
# be auto-detected when NaN/inf entries are present
finite_correlations = correlation_map[np.isfinite(correlation_map)]
plt.figure()
plt.hist(finite_correlations);

In [None]:
lattice_param = 1.39 # angstroms




In [None]:
pixel_size

In [None]:

# For each atom in layer A, find its nearest neighbour in layer B
tree_B = KDTree(centroids_B)
distances, indices = tree_B.query(centroids_A)

# Displacement vectors from each layer-A atom to its nearest layer-B atom (centroid coordinates)
order_parameter_vectors = centroids_B[indices] - centroids_A
order_parameters = np.linalg.norm(order_parameter_vectors, axis=1)  # just the magnitude
order_parameters = order_parameters * pixel_size  # scale the magnitude by the pixel size
angles = np.arctan2(order_parameter_vectors[:, 1], order_parameter_vectors[:, 0])
norm_angles = plt.Normalize(vmin=-np.pi, vmax=np.pi)  # not used in this cell

plt.figure(figsize=(10, 10))
# Divide the displacement vectors by the pixel-size-scaled magnitude for the quiver plot
# NOTE(review): because the magnitude was already multiplied by pixel_size, the arrows have
# length 1 / pixel_size in plot units rather than unit length — confirm this is intended.
# NOTE(review): an A atom coinciding exactly with its B neighbour gives a zero magnitude and a division by zero.
u = order_parameter_vectors[:, 0] / order_parameters
v = order_parameter_vectors[:, 1] / order_parameters
# arrows are coloured by the scaled magnitude; the first centroid coordinate is used as the plot x-axis
plt.quiver(centroids_A[:, 0], centroids_A[:, 1], u, v, order_parameters, cmap='viridis', scale=1, scale_units='xy', angles='xy')

plt.scatter(centroids_A[:, 0], centroids_A[:, 1], s=60, c='purple', edgecolors='k', linewidths=0.5, alpha=0.8)
plt.scatter(centroids_B[:, 0], centroids_B[:, 1], s=60, c='green', edgecolors='k', linewidths=0.5, alpha=0.8)

plt.colorbar()
plt.axis('off')
plt.title('Order parameter u map - graphene moire')
plt.show()

In [None]:

# Side length of the square output images
image_size = 512

# `angles` and `order_parameters` hold one value per layer-A atom, so the lookup
# tree has to be built on the layer-A centroids for the returned indices to match.
tree_A = KDTree(centroids_A)

# Nearest-atom lookup for all pixels in one vectorized query.  As in the quiver
# plot, the first centroid coordinate is treated as x; pixel (x, y) is stored at [y, x].
grid_y, grid_x = np.mgrid[0:image_size, 0:image_size]
_, nearest = tree_A.query(np.column_stack([grid_x.ravel(), grid_y.ravel()]))
nearest = nearest.reshape(image_size, image_size)

# Every pixel takes the angle / magnitude of its nearest order-parameter vector
angle_image = angles[nearest]
magnitude_image = order_parameters[nearest]

# Plotting the angle image
plt.figure(figsize=(10, 10))
plt.imshow(angle_image, cmap='twilight', origin='lower')
plt.colorbar(label='Direction (radians)')
plt.title('Angle Image')
plt.axis('off')
plt.show()

# Plotting the magnitude image
plt.figure(figsize=(10, 10))
plt.imshow(magnitude_image, cmap='viridis', origin='lower')
plt.colorbar(label='Magnitude')
plt.title('Magnitude Image')
plt.axis('off')
plt.show()