In [None]:
from raffle.generator import raffle_generator
from ase.io import read
from ase import build
import numpy as np
import re

import matplotlib.pyplot as plt
from matplotlib import colors
from scipy.interpolate import griddata
from scipy.spatial import cKDTree
import ipywidgets as widgets
from ipywidgets import interactive
%matplotlib widget
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

from IPython.display import display
import ipywidgets as widgets
import tkinter as tk
from tkinter import filedialog

In [None]:
# Configure Matplotlib: render text through LaTeX with the Computer Modern
# family, and set the background colours used on screen and when saving.
# NOTE: usetex=True requires a working LaTeX installation on the machine.
plt.rc('text', usetex=True)
plt.rc('font', family='Computer Modern')
plt.rcParams.update({
    "figure.facecolor":  (1.0, 1.0, 1.0, 1.0),  # RGBA: opaque white
    "axes.facecolor":    (1.0, 1.0, 1.0, 1.0),  # RGBA: opaque white
    "savefig.facecolor": (1.0, 1.0, 1.0, 0.0),  # RGBA: white with alpha = 0 (fully transparent)
})

In [None]:
# Load the MACE model and build the calculator that is attached to the
# ASE structures in the cells below (via `<atoms>.calc = calc`).
# NOTE(review): `Path` is imported here but not used in the visible cells.
from pathlib import Path
from mace.calculators import mace_mp

# Keyword arguments forwarded to mace_mp; the model path is relative to the
# notebook's working directory.
calc_params = { 'model':  "../mace-mpa-0-medium.model" }
calc = mace_mp(**calc_params)

In [None]:
# Set up the RAFFLE generator; its `distributions` attribute is configured
# and trained in the following cells.
generator = raffle_generator()

In [None]:

# Set up the host: a Si/Ge stack built from diamond-structure bulk cells.
# The per-atom potential energy of each bulk cell (from `calc`) is used as the
# reference energy of that element.
Si_bulk = build.bulk("Si", crystalstructure="diamond", a=5.43)
Si_bulk.calc = calc
Si_reference_energy = Si_bulk.get_potential_energy() / len(Si_bulk)
# presumably maps the primitive cell onto the conventional cubic cell — TODO confirm
Si_cubic = build.make_supercell(Si_bulk, [[-1, 1, 1], [1, -1, 1], [1, 1, -1]])
Ge_bulk = build.bulk("Ge", crystalstructure="diamond", a=5.65)
Ge_bulk.calc = calc
Ge_cubic = build.make_supercell(Ge_bulk, [[-1, 1, 1], [1, -1, 1], [1, 1, -1]])
Ge_reference_energy = Ge_bulk.get_potential_energy() / len(Ge_bulk)

# Repeat the cubic cells 2 x 2 x 1
Si_supercell = build.make_supercell(Si_cubic, [[2, 0, 0], [0, 2, 0], [0, 0, 1]])
Ge_supercell = build.make_supercell(Ge_cubic, [[2, 0, 0], [0, 2, 0], [0, 0, 1]])

# Two-layer (001) surfaces without vacuum; these are the pieces stacked into the host
Si_surface = build.surface(Si_supercell, indices=(0, 0, 1), layers=2)
Ge_surface = build.surface(Ge_supercell, indices=(0, 0, 1), layers=2)

# Two-layer (001) slabs with 12 of vacuum passed to build.surface.
# NOTE(review): Si_slab and Ge_slab are not used by any of the visible cells.
Si_slab = build.surface(Si_supercell, indices=(0, 0, 1), layers=2, vacuum=12, periodic=True)
Si_slab.calc = calc
Ge_slab = build.surface(Ge_supercell, indices=(0, 0, 1), layers=2, vacuum=12, periodic=True)
Ge_slab.calc = calc

# Stack Si and Ge along the third cell axis, then shorten the cell along that
# axis without moving the atoms (scale_atoms=False).
host = build.stack(Si_surface, Ge_surface, axis=2, distance= 5.43/2 + 5.65/2)
cell = host.get_cell()
# NOTE(review): the formula in the trailing comment evaluates to 4.155, not
# 3.8865 — confirm which value is intended.
cell[2, 2] -= 3.8865 # (5.43 + 5.65) / 2 * 3/4
host.set_cell(cell, scale_atoms=False)

# Set up the database and the reference energies
generator.distributions.set_element_energies(
    {
        'Si': Si_reference_energy,
        'Ge': Ge_reference_energy,
    }
)
# NOTE(review): this two-structure database is overwritten by the next line,
# so only the structures read from the .traj file are used.
database = [ Si_bulk, Ge_bulk ]
database = read("../Si-Ge_learn/DRAFFLE/DOutput/rlxd_structures_seed0.traj", index=":")
# Chemical formula string; parsed into element symbols in a later cell
species = "SiGe"
# bounds = [[0, 0, 0.34], [1, 1, 0.52]]

In [None]:
# # Set up the database and the reference energies
# host = read("../graphene_grain_boundary_learn/POSCAR_host_gb")
# graphene = read("../graphene_grain_boundary_learn/POSCAR_graphene")
# h2 = build.molecule("H2")
# database = [host]
# generator.distributions.set_element_energies(
#     {
#         'C': 0.0,
#     }
# )
# species = "C"
# bounds = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]

In [None]:
# Learn the RAFFLE descriptor from the structures in `database`
generator.distributions.create(database, deallocate_systems=True)

In [None]:
# Build the list of element symbols from the formula string: split it into
# element tokens (symbol plus optional count), then drop the counts.
species_name_list = [
    re.sub(r'\d+', '', token)
    for token in re.findall(r'[A-Z][a-z]?\d*', species)
]

In [None]:
# Evaluate the probability density for `species` in the host.
# How the later cells consume the result:
#   - probability_density[0..2] are scaled by the cell lengths and used as point coordinates,
#   - probability_density[3] is stored as `void`,
#   - probability_density[4:] hold one row of values per species,
#   - grid[0..2] are used as the number of grid points along each cell axis.
probability_density, grid = generator.get_probability_density(host, species=species, grid_spacing=0.1, return_grid=True)

In [None]:
# Fetch the learned descriptor; the plotting cell below indexes it as
# descriptor[j][idx, :] for j = 0, 1, 2 (labelled 2-, 3- and 4-body).
descriptor = generator.get_descriptor()

In [None]:
# One row of three panels, one per n-body descriptor (2-, 3- and 4-body)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for j, (panel, x_label) in enumerate(zip(
    axes,
    ('Distance (Å)', '3-body angle (radians)', 'Improper dihedral angle (radians)'),
)):
    # Abscissa for this descriptor: from its lower cutoff up to the upper
    # cutoff, in steps of the bin width
    lower = generator.distributions.cutoff_min[j]
    upper = generator.distributions.cutoff_max[j]
    step = generator.distributions.width[j]
    x = np.arange(lower, upper + step, step)

    # One curve per row of the descriptor array
    for curve in descriptor[j]:
        panel.plot(x, curve)

    # Panel decorations
    panel.set_ylabel('Descriptor value')
    panel.set_title(f'{j+2}-body descriptor')
    panel.set_xlabel(x_label)

plt.tight_layout()
plt.show()

In [None]:
# Convert the probability density to a meshgrid
# Rows 0-3 of probability_density are not species values, so the number of
# species is the row count minus four.
species_list = range(probability_density.shape[0] - 4)

# Lengths of the three host cell vectors
cell = host.get_cell()
a, b, c = (np.linalg.norm(cell[k]) for k in range(3))

# Extent of the grid along each axis; the upper bound stops one grid step
# short of the full cell length
x_min, y_min, z_min = a * ( 0.0 ), b * ( 0.0 ), c * ( 0.0 )
x_max = a * ( 1.0 - 1.0/grid[0] )
y_max = b * ( 1.0 - 1.0/grid[1] )
z_max = c * ( 1.0 - 1.0/grid[2] )

# Create a 3D grid; a complex step makes np.mgrid treat its magnitude as the
# number of points along that axis
grid_x, grid_y, grid_z = np.mgrid[
    x_min:x_max:complex(grid[0]),
    y_min:y_max:complex(grid[1]),
    z_min:z_max:complex(grid[2]),
]
# Flattened (n_points, 3) list of grid coordinates
flat_axes = (grid_x.ravel(), grid_y.ravel(), grid_z.ravel())
grid_points = np.vstack(flat_axes).T

In [None]:
# Extract the coordinates of the sampled points and scale them by the cell
# lengths so they are in the same units as the meshgrid
x = probability_density[0] * a
y = probability_density[1] * b
z = probability_density[2] * c

# Row 3 is kept under the name `void`; it is not used in this cell
void = probability_density[3]

# Populate the grid with the values: one 3D array per species, in species order
grid_values = []
distance_threshold = 1e-3

# The sample positions are the same for every species, so the nearest-sample
# distance of each grid point is computed once, outside the loop.
tree = cKDTree(np.c_[x, y, z])
distances, _ = tree.query(grid_points, k=1)
# grid_points was built from grid_x.ravel(), so this restores the grid shape
distances = distances.reshape(grid_x.shape)
far_from_data = distances > distance_threshold

for spec in sorted(set(species_list)):
    # Species rows start at index 4 of probability_density
    values = probability_density[spec+4]

    # Interpolate data onto the 3D grid
    interpolated = griddata((x, y, z), values, (grid_x, grid_y, grid_z), method='nearest', fill_value=0.0)

    # Zero every grid point with no sample within the threshold. The mask is
    # applied to the array produced in this iteration; indexing
    # grid_values[spec-1] here would hit the previous species' array and
    # leave the last species unmasked.
    interpolated[far_from_data] = 0

    grid_values.append(interpolated)

In [None]:
# Set up interactive sliders: slice position (fraction of the cell), slicing
# axis, and species index.
# NOTE(review): these three widgets are not referenced by the visible code;
# the interactive plot and save_plot use the `*_2d` widgets created in the
# last cell.
slice_slider = widgets.FloatSlider(value=0.5, min=0, max=1-0.01, step=0.01, description='Slice Position:')
axis_selector = widgets.RadioButtons(options=['a', 'b', 'c', 'ab'], description='Axis:', value='c')
species_selector = widgets.Dropdown(options=sorted(set(species_list)), description='Species:')

In [None]:
# Define the 2D plotting function
def plot_2d_slice(slice_level, species, axis):
    """Draw a 2D slice of the gridded probability density on the current figure.

    The current pyplot figure is cleared and redrawn. The slice is shown with a
    square-root colour scale (``PowerNorm(gamma=0.5)``) whose limits span all
    arrays in ``grid_values``; values below 1e-8 are drawn in grey.

    Parameters
    ----------
    slice_level : float
        Fractional position of the slice. For 'a', 'b' and 'c' it is converted
        to a grid index with ``int(slice_level * n_points)``, clamped to the
        last valid index. For 'ab' it shifts the diagonal cut.
    species : int
        Selector; ``species - 1`` is used to index both ``grid_values`` and
        ``species_name_list``.
    axis : str
        'a', 'b' or 'c' for a plane at fixed index along that grid axis, or
        'ab' for a cut along a diagonal of the first two grid axes.

    Raises
    ------
    ValueError
        If ``axis`` is not one of the four supported values.
    """
    plt.clf()  # Clear the current figure

    norm = colors.PowerNorm(gamma=0.5)

    # NOTE(review): with the 0-based species values offered by the notebook's
    # dropdown, `species - 1` makes 0 select the last entry through negative
    # indexing. Data and colour-bar label use the same index, so they stay
    # consistent with each other — confirm the offset is intended.
    density = grid_values[species-1]

    def _index_along(n_points):
        # int() truncates; clamp so slice_level == 1.0 maps to the last plane
        return min(int(slice_level * n_points), n_points - 1)

    # Pick the slice, its physical extent and its labels for the chosen axis.
    # Raw strings keep the LaTeX '\AA' free of invalid escape sequences.
    if axis == 'c':
        slice_data = density[:, :, _index_along(grid[2])]
        extent = (x_min, x_max, y_min, y_max)
        x_label, y_label = r'$x$ (\AA)', r'$y$ (\AA)'
        plane_label = '$p_{3}$'
    elif axis == 'b':
        slice_data = density[:, _index_along(grid[1]), :]
        extent = (x_min, x_max, z_min, z_max)
        x_label, y_label = r'$x$ (\AA)', r'$z$ (\AA)'
        plane_label = '$p_{4}$'
    elif axis == 'a':
        slice_data = density[_index_along(grid[0]), :, :]
        extent = (y_min, y_max, z_min, z_max)
        x_label, y_label = r'$y$ (\AA)', r'$z$ (\AA)'
        plane_label = '$p_{1}$'
    elif axis == 'ab':
        # Diagonal cut: pair a descending index along the first axis with an
        # ascending index along the second, shift both by an offset derived
        # from slice_level, and keep only the pairs that fall inside the grid.
        shift_x = round( ( -0.5 + slice_level ) * (grid[0] + 1) )
        shift_y = round( ( -0.5 + slice_level ) * (grid[1] + 1) )
        slice_idx_xy = [
            (shift_x + i, shift_y + j)
            for i, j in zip(range(grid[0], -1, -1), range(grid[1] + 1))
            if 0 <= shift_x + i < grid[0] and 0 <= shift_y + j < grid[1]
        ]
        slice_idx_x = [ i for i, j in slice_idx_xy ]
        slice_idx_y = [ j for i, j in slice_idx_xy ]
        slice_data = density[slice_idx_x, slice_idx_y, :]
        # Length of the cut, from the number of retained points and the grid spacing
        min_loc = 0.0
        max_loc = np.sqrt( (len(slice_idx_x) * a / grid[0])**2 + (len(slice_idx_y) * b / grid[1])**2 )
        extent = (min_loc, max_loc, z_min, z_max)
        x_label, y_label = r'$xy$ (\AA)', r'$z$ (\AA)'
        plane_label = '$p_{2}$'
    else:
        raise ValueError(f"Unknown axis {axis!r}; expected 'a', 'b', 'c' or 'ab'")

    plt.imshow(slice_data.T, extent=extent, origin='lower', cmap='viridis', norm=norm, aspect='auto')
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    plt.title(f'Plane {plane_label}', fontsize=25, y=1.0, pad=10)
    # Use one colour range for every slice and species so plots are comparable
    min_val = min(grid_value.min() for grid_value in grid_values)
    max_val = max(grid_value.max() for grid_value in grid_values)
    plt.clim(min_val, max_val)
    cbar = plt.colorbar(label='Viability', orientation='vertical', fraction=0.085, pad=0.02)
    ax = plt.gca()
    # ax.set_aspect('equal')
    # Set tick and label font size
    ax.tick_params(axis='both', which='major', labelsize=20)

    # Redraw the slice with (near-)zero values masked so they appear grey.
    # The same `norm` object is reused, so the colour limits set above apply.
    cmap = plt.cm.viridis.copy()
    cmap.set_bad('grey')
    slice_data = np.ma.masked_where(slice_data < 1e-8, slice_data)
    plt.imshow(slice_data.T, extent=ax.get_images()[0].get_extent(), origin='lower', cmap=cmap, norm=norm, aspect='auto')
    ax.set_aspect('equal')

    # Re-apply the axis labels with the final font size, then tidy the ticks
    plt.xlabel(plt.gca().get_xlabel(), fontsize=25)
    plt.ylabel(plt.gca().get_ylabel(), fontsize=25)
    ax.xaxis.set_major_locator(plt.MaxNLocator(4))
    ax.yaxis.set_major_locator(plt.MaxNLocator(4))
    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))
    cbar.ax.tick_params(labelsize=20)
    # Replace the placeholder colour-bar label with the species-specific one
    cbar.set_label(r'$P_{\mathrm{'+species_name_list[species-1]+r'}}$', fontsize=25)
    # cbar.set_ticks(np.linspace(min_val, max_val, 4))
    cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=4))
    cbar.ax.minorticks_on()

In [None]:
# Save the current plot as a PDF
def save_plot(b):
    """Button callback: redraw the current selection and save it to a PDF.

    Reads the slice position, species and axis from the `*_2d` widgets,
    re-renders the slice with plot_2d_slice, then asks for a destination via a
    Tk "save as" dialog. Nothing is written if the dialog is cancelled.

    Parameters
    ----------
    b : object
        The argument passed by ``Button.on_click``; not used.
    """
    slice_level = slice_slider_2d.value
    species = species_selector_2d.value
    axis = axis_selector_2d.value
    plot_2d_slice(slice_level, species, axis)
    # Open file dialog to select file path and name
    root = tk.Tk()
    root.withdraw()  # Hide the root window
    try:
        file_path = filedialog.asksaveasfilename(defaultextension=".pdf", filetypes=[("PDF files", "*.pdf")])
    finally:
        # Destroy the hidden root so each click does not leave a Tk window behind
        root.destroy()
    if file_path:
        fig = plt.gcf()
        fig.patch.set_facecolor('none')  # Make the figure background transparent
        ax = plt.gca()
        ax.patch.set_facecolor('none')  # Make the axes background transparent
        plt.savefig(file_path, bbox_inches='tight', pad_inches=0, facecolor=fig.get_facecolor(), edgecolor='none')
        print(f'Plot saved as {file_path}')

In [None]:
# Create the figure for the 2D plot; plot_2d_slice clears and redraws the
# current pyplot figure on every call
plt.figure(figsize=(8, 6))

# Button that triggers save_plot, which re-renders the current selection and
# writes it to a user-chosen PDF
save_button = widgets.Button(description="Save Plot")
save_button.on_click(save_plot)
display(save_button)

# Set up interactive sliders for 2D plot; save_plot reads these widgets by
# name, so they must exist before the button is clicked.
# The slider stops at 0.99 so that int(slice_level * n_points) stays below n_points.
slice_slider_2d = widgets.FloatSlider(value=0.5, min=0, max=1-0.01, step=0.01, description='Slice Position:')
axis_selector_2d = widgets.RadioButtons(options=['a', 'b', 'c', 'ab'], description='Axis:', value='c')
species_selector_2d = widgets.Dropdown(options=sorted(set(species_list)), description='Species:')

# Set up interactive plot for 2D slices; leaving the widget as the last
# expression of the cell displays it
interactive_plot_2d = interactive(plot_2d_slice, slice_level=slice_slider_2d, axis=axis_selector_2d, species=species_selector_2d)
interactive_plot_2d