# Gomb - Net
### Performance test on the Graphene dataset
### And analysis of Graphene experimental data

Austin Houston

[![OpenInColab](https://colab.research.google.com/assets/colab-badge.svg)](
    https://colab.research.google.com/github/AustinHouston/Gomb-Net/blob/main/Eval_Graphene_model.ipynb)

In [None]:
# basics
import os
import sys
import numpy as np

# plotting
import matplotlib.pylab as plt
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
from matplotlib import cm

# colab interactive plots and drive
drive = False
if 'google.colab' in sys.modules:
    from  google.colab import drive 
    from google.colab import output
    drive.mount('/content/drive')
    output.enable_custom_widget_manager()
    drive = True
else:
    %matplotlib widget

# other imports
from scipy.ndimage import label, center_of_mass, gaussian_filter, zoom, uniform_filter
from scipy.spatial import KDTree
from scipy.interpolate import griddata
from scipy.stats import norm, gaussian_kde
from skimage.filters import threshold_otsu
from skimage.feature import blob_log

# for cropping function
if drive:
    print('installing DataGenSTEM')
    !pip install ase
    !git clone https://github.com/ahoust17/DataGenSTEM.git
    sys.path.append('./DataGenSTEM/DataGenSTEM')
    import data_generator as dg

# for Gomb-Net
if drive:
    print('installing Gomb-Net')
    !git clone https://github.com/ahoust17/Gomb-Net.git
    sys.path.append('./Gomb-Net/GombNet')    
from GombNet.networks import *
from GombNet.loss_func import GombinatorialLoss
from GombNet.utils import *

import torch
# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print("CUDA is available. Using GPU." if _has_cuda else "CUDA is not available. Using CPU.")

In [None]:
# Select the compute device.  Apple-silicon MPS is used only when this machine
# actually provides it; otherwise fall back to CUDA, then to the CPU, so the
# notebook keeps working on Colab and on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Now, you need to add the following shared drive to your Google Drive:
*** WARNING: it is a big file.  Check before you download ***


https://drive.google.com/file/d/1DyKtrmJ8wNYQg3YEJ8_iXjz6lB_DQfwy/view?usp=sharing

### Run the following cell after the download is complete

In [None]:
# Folder that holds the datasets, pretrained models and experimental data used
# by the cells below.
# On Colab it should be something like 'content/drive/My Drive/Gomb-Net files';
# otherwise a local path is used (edit it to match your machine).
if drive:
    shared_folder = 'drive/My Drive/Gomb-Net files'
else:
    shared_folder = '/Users/austin/Desktop/gomb_beta'

# List the folder contents as a quick check that the path is correct
print('available files & directories:')
!ls '{shared_folder}'

### Now, on to running the code

let's look at the dataset:

In [None]:
# Build the train / validation / test dataloaders from the Graphene dataset
# folders (20 % validation, 10 % test, fixed seed).
images_dir = shared_folder + '/Graphene_dataset/images'
labels_dir = shared_folder + '/Graphene_dataset/labels'
train_loader, val_loader, test_loader = get_dataloaders(
    images_dir,
    labels_dir,
    batch_size=1,
    val_split=0.2,
    test_split=0.1,
    seed=42,
)


In [None]:
# Display one test sample: the input image next to its ground-truth layers.
test_iter = 3
# Fetch the sample once so the image and its labels come from the same dataset
# lookup (the item was previously requested twice).
sample = test_loader.dataset[test_iter]
test = sample[0].unsqueeze(0)  # prepend a batch axis
gt = sample[1]

titles = ['L1: C', 'L2: C']
fig, ax = plt.subplots(1, 1 + len(titles), figsize=(10, 5))
ax[0].imshow(test[0, 0].cpu().numpy(), cmap='gray')
ax[0].set_title('Input')

# one panel per ground-truth layer, to the right of the input
for panel, layer, title in zip(ax[1:], gt, titles):
    panel.imshow(layer.cpu().numpy(), cmap='gray')
    panel.set_title(title)
for a in ax:
    a.axis('off')
fig.tight_layout()

now let's look at the model:

In [None]:
# Initialize model
# These settings describe the network 'skeleton' that the pretrained weights
# are loaded into further down.
input_channels = 1  # presumably a single-channel (grayscale) input image — TODO confirm
num_classes = 2  # number of output layers; matches the two label layers plotted above
num_filters = [32, 64, 128, 256]  # passed to TwoLeggedUnet; presumably filters per level — TODO confirm

model = TwoLeggedUnet(input_channels, num_classes, num_filters, dropout = 0.2)
# The optimizer is created so that its saved state can be restored from the checkpoint.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# group_size evaluates to 1 here (num_classes // 2).
loss = GombinatorialLoss(group_size = num_classes//2, loss = 'Dice', epsilon=1e-6, class_weights = None, alpha=2)

In [None]:
def count_trainable_parameters(model):
    """Return the total number of elements in the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
num_trainable_params = count_trainable_parameters(model)
print(f"Number of trainable parameters: {num_trainable_params}")

visualize the training history for the pretrained model:

In [None]:
# Plot the training / validation loss recorded for the pretrained model.
history_path = shared_folder + '/Pretrained_models/Graphene_model_loss_history.npz'
# np.load returns an NpzFile that keeps the archive open until it is closed;
# the context manager releases it once both arrays have been read into memory.
with np.load(history_path) as loss_history:
    train_loss = loss_history['train_loss_history']
    val_loss = loss_history['val_loss_history']

plt.figure(figsize=(6, 4))
plt.plot(train_loss, label='training', color='#1f77b4')
plt.plot(val_loss, label='validation', color='#d62728')
plt.xlabel('Epoch')
plt.xlim(0, 30)  # only the first 30 epochs are shown
plt.legend(title='Losses')
plt.tight_layout()

Load the pretrained weights into our model 'skeleton':

In [None]:
# Restore the pretrained weights (and optimizer state) into the model skeleton.
model_path = shared_folder + '/Pretrained_models/Graphene_model.pth'

# No cell in this notebook moves the model or its inputs to `device`, so the
# stored tensors are mapped to the CPU.  This also keeps the load working on
# machines that lack the accelerator the checkpoint was saved from.
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# switch to inference mode
model.eval()

In [None]:
# Choose the test sample used by the prediction cells below.
test_iter = 0

# Fetch the sample once so the image and its labels come from the same dataset
# lookup (the item was previously requested twice).
sample = test_loader.dataset[test_iter]
test = sample[0].unsqueeze(0)  # prepend a batch axis
gt = sample[1]

In [None]:
# Generate prediction
# Run the network on the selected sample without tracking gradients.
with torch.no_grad():
    probability = model(test)  # network output before the sigmoid
    # torch.sigmoid replaces the deprecated F.sigmoid and matches the other cells
    prediction = torch.sigmoid(probability)
probability = probability.squeeze().cpu().numpy()
prediction = prediction.squeeze().cpu().numpy()

# Binarise with a single Otsu threshold computed over the whole sigmoid output
threshold = threshold_otsu(prediction)
prediction = (prediction > threshold).astype(float)


In [None]:
# Show the selected input image together with its intensity scale.
input_image = test.squeeze().cpu().numpy()
plt.figure()
plt.imshow(input_image, cmap='gray')
plt.colorbar()

In [None]:
# Plotting
# Rows: ground truth, thresholded prediction, network output before the sigmoid.
# Columns: one per output layer.
fig, axs = plt.subplots(3, num_classes, dpi=300, sharex=True, sharey=True)

# the first layer's output is drawn with 'plasma', every further layer with 'viridis'
output_cmaps = ['plasma'] + ['viridis'] * (num_classes - 1)
for i in range(num_classes):
    axs[0, i].imshow(gt[i], cmap='gray')
    axs[1, i].imshow(prediction[i], cmap='gray')
    axs[2, i].imshow(probability[i], cmap=output_cmaps[i])

# Hide ticks and frames but leave the axes switched on: axis('off') would also
# hide the row labels set below.
for ax in axs.ravel():
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

axs[0, 0].set_ylabel('Ground Truth')
axs[1, 0].set_ylabel('Prediction')
axs[2, 0].set_ylabel('Probability')

fig.tight_layout()

### Comparison to blob-finder

In [None]:
image_number = 0
# try a regular blob finder for comparison
blob_im = test.squeeze().cpu().numpy()

# Laplacian-of-Gaussian detection; each row of `blobs` is (row, col, sigma)
blobs = blob_log(blob_im, min_sigma=1, max_sigma=20, num_sigma=5, threshold=0.1)

# Refine each detection to the intensity-weighted centroid of the window that
# covers the blob (radius ~ sqrt(2) * sigma for a 2-D LoG blob).
# center_of_mass takes (input, labels, index); the blob coordinates must not be
# passed in those two slots, so the image is cropped around each blob instead.
blobs_com = []
for row, col, sigma in blobs:
    radius = int(np.ceil(np.sqrt(2) * sigma))
    row_start = max(int(row) - radius, 0)
    col_start = max(int(col) - radius, 0)
    window = blob_im[row_start:int(row) + radius + 1, col_start:int(col) + radius + 1]
    com_row, com_col = center_of_mass(window)
    # shift the window-local centroid back into image coordinates
    blobs_com.append((com_row + row_start, com_col + col_start))

plt.figure()
plt.imshow(blob_im, cmap='gray')
plt.scatter(blobs[:, 1], blobs[:, 0], c='r', s=20)  # scatter takes x = column, y = row
plt.axis('off')

In [None]:
# The following cell takes a few minutes to run. It measures the network's performance metrics across the test dataset.

In [None]:
# Running sums of the per-sample metrics; they are turned into averages after the loop.
pwa_total = dice_total = IoU_total = 0

def iou(pred, gt):
    """Return the intersection-over-union of two masks.

    Both inputs are interpreted as boolean masks (non-zero means foreground).
    Two empty masks agree completely, so their IoU is reported as 1.0 instead
    of the NaN that 0 / 0 would produce.
    """
    intersection = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0
    return intersection / union

def dice_coefficient(pred, gt):
    """Return the Dice coefficient 2*|A∩B| / (|A| + |B|) of two 0/1 masks.

    The mask sizes are taken as `pred.sum()` and `gt.sum()`, so both inputs are
    expected to hold only 0 and 1.  Two empty masks agree completely, so their
    score is reported as 1.0 instead of the NaN that 0 / 0 would produce.
    """
    intersection = np.logical_and(pred, gt).sum()
    total = pred.sum() + gt.sum()
    if total == 0:
        return 1.0
    return 2 * intersection / total

# Calculate pixel-wise accuracy, Dice and IoU over the whole test set.
# Each sample is scored against the ground truth in both layer orders and the
# higher value of every metric is kept.
n_samples = len(test_loader.dataset)  # count samples, not batches
for i in range(n_samples):
    sample = test_loader.dataset[i]  # load each sample once
    test = sample[0].unsqueeze(0)  # prepend a batch axis
    gt = sample[1].numpy()

    # Switch ground truth layers
    gt_switched = np.flip(gt, axis=0)

    with torch.no_grad():
        probability = model(test)
        prediction = torch.sigmoid(probability)

    probability = probability.squeeze().cpu().numpy()
    prediction = prediction.squeeze().cpu().numpy()

    # binarise with a single Otsu threshold over the whole output
    threshold = threshold_otsu(prediction)
    prediction = (prediction > threshold).astype(float)

    # Calculate metrics for original and switched ground truths
    pwa_original = np.mean(prediction == gt)
    pwa_switched = np.mean(prediction == gt_switched)

    dice_original = dice_coefficient(prediction, gt)
    dice_switched = dice_coefficient(prediction, gt_switched)

    iou_original = iou(prediction, gt)
    iou_switched = iou(prediction, gt_switched)

    # Take the highest value for each metric
    pwa_total += max(pwa_original, pwa_switched)
    dice_total += max(dice_original, dice_switched)
    IoU_total += max(iou_original, iou_switched)

# Calculate the average for each metric
pwa_total /= n_samples
dice_total /= n_samples
IoU_total /= n_samples

print(f"Pixel-wise Accuracy: {pwa_total}")
print(f"Mean Dice Coefficient: {dice_total}")
print(f"Mean IoU: {IoU_total}")

### Now, on to experimental data:

In [None]:
# Load the experimental moiré image and its pixel size.
# np.load returns an NpzFile that keeps the archive open until it is closed;
# the context manager releases it once both entries have been read into memory.
with np.load(shared_folder + '/Experimental_datasets/moire.npz') as exp_data:
    im_array = exp_data['im_array']
    pixel_size = exp_data['pixel_size']

# Rescale intensities so that the minimum maps to 0 and the maximum to 1
# im_array = gaussian_filter(im_array, sigma=1)
im_array = im_array - np.min(im_array)
im_array = im_array / np.max(im_array)
# zoom_factor = 0.7
#
# im_array = dg.resize_image(im_array, zoom_factor * 512)

print(f"Pixel size: {pixel_size.astype(float)} m/pix")
plt.figure()
plt.imshow(im_array, cmap='gray')

In [None]:
# Cut 5 crops of size 256 out of the experimental image with DataGenSTEM's
# shotgun_crop; roi=None passes no region of interest.
# NOTE(review): `dg` is imported only inside the `if drive:` (Colab) branch of
# the setup cell, so a local run needs data_generator imported first — confirm.
images = dg.shotgun_crop(im_array, crop_size=256, n_crops = 5, roi = None)

In [None]:
# Segment every experimental crop with the network and show, per row:
# input | thresholded layer 1 | thresholded layer 2 | raw output 1 | raw output 2
# The grid height and the mask array follow the crops that were actually
# produced, so changing n_crops or crop_size in the crop cell needs no edit here.
n_images = len(images)
fig, ax = plt.subplots(n_images, 5, figsize=(10, 2 * n_images), dpi=300, squeeze=False)
mask_list = []

for i, im in enumerate(images):
    # make nn prediction: float32 tensor with batch and channel axes prepended
    im = torch.tensor(im.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        probability = model(im)
        prediction = torch.sigmoid(probability)
    probability = probability.squeeze().cpu().numpy()
    prediction = prediction.squeeze().cpu().numpy()
    # binarise with a single Otsu threshold over the whole output
    threshold = threshold_otsu(prediction)
    prediction = (prediction > threshold).astype(float)
    mask_list.append(prediction)

    # plot
    ax[i, 0].imshow(im.squeeze().numpy(), cmap='gray')
    ax[i, 1].imshow(prediction[0], cmap='gray')
    ax[i, 2].imshow(prediction[1], cmap='gray')
    ax[i, 3].imshow(probability[0], cmap='plasma')
    ax[i, 4].imshow(probability[1], cmap='viridis')

# stacked binary masks, one entry per crop
masks = np.stack(mask_list)

for a in ax.ravel():
    a.axis('off')
fig.tight_layout()

#plt.savefig('moire_segmentation.png')

In [None]:
# Measure nearest-neighbour distances between the atoms found in every
# predicted layer of every crop.
resize_factor = 4  # upsampling applied to each mask before centroid finding
zoom_order = 3  # spline order used by zoom
dist_hist = []
colors = ['#d62728','#1f77b4']
for mask in masks:
    for layer in mask:
        # resize layer, smooth it and binarise it again
        layer = zoom(layer, resize_factor, order=zoom_order)
        layer = gaussian_filter(layer, sigma=1)
        threshold = threshold_otsu(layer)
        layer = (layer > threshold).astype(float)

        plt.figure()
        plt.imshow(layer, cmap='gray')

        # one connected component per atom; its centre of mass is the atom position
        labeled_array, num_features = label(layer)
        if num_features < 3:
            # the k=3 query below needs at least three points; with fewer it
            # returns infinite distances (and fails outright with none)
            continue
        centroids = np.array(center_of_mass(layer, labeled_array, range(1, num_features + 1)))

        # Calculate the distance between the centroids
        tree = KDTree(centroids)
        distances, indices = tree.query(centroids, k=3)
        # Column 0 is each point's zero distance to itself; keep the two neighbours.
        # Result is in metres (pixel_size is m/pix) but still counted in upsampled
        # pixels; the next cell divides by resize_factor and converts to Å.
        nearest_distances = distances[:, 1:] * float(pixel_size)
        dist_hist.append(nearest_distances.flatten())


In [None]:
# Flatten the per-layer distance arrays into one 1-D array, convert to Å and
# undo the mask upsampling.
dist_hist = np.concatenate(dist_hist) * 1e10 / resize_factor
avg_dist = np.mean(dist_hist)

plt.figure(figsize=(8, 8), dpi=300)
plt.hist(dist_hist, bins=100, color='gray')

# reference C-C bond length (solid) and measured average (dashed)
reference_lines = (
    (1.42, 'solid', 'C-C bond 1.42 Å'),
    (avg_dist, '--', f'Avg. bond {avg_dist:.2f} Å'),
)
for x_position, line_style, line_label in reference_lines:
    plt.vlines(x_position, 0, 1000, color='k', linestyle=line_style, label=line_label)

plt.xlabel('Distance (Å)', fontsize=24)
plt.xlim(0.8, 2)
plt.ylim(0, 850)
plt.legend(loc='upper right', fontsize=24)
# enlarge the x tick labels and drop the y ticks
plt.xticks(fontsize=24)
plt.yticks([])
plt.tight_layout()

In [None]:
print(f"Average distance: {avg_dist:.2f} Å")

# FWHM: span between the outermost histogram bins whose count exceeds half of
# the peak count
hist, bin_edges = np.histogram(dist_hist, bins=100)
half_max = hist.max() / 2
bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
above_half = bin_centers[hist > half_max]
fwhm = above_half[-1] - above_half[0]
print(f"FWHM: {fwhm:.2f} Å")

In [None]:
pixel_size