In [1]:
print("Loading modules...")
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pixstem.api as ps
import multiprocessing
import hyperspy.api as hs
import ctypes
import csv
import tkinter as tk
print("Modules loaded.")

# Module-level state shared by the functions defined in the following cells.
# file: the signal returned by hs.load() in loadFile(); None until a file is loaded.
file = None
# distances: NumPy view of a multiprocessing.Array, reshaped to (COL, ROW) by
# analysis() and filled in by its worker processes; None until analysis() runs.
distances = None

Loading modules...
Modules loaded.


In [2]:
def loadFile(fileName):
    """Load the data file at *fileName* with HyperSpy into the global ``file``.

    An empty or whitespace-only path is rejected. A hint is printed and, if
    the Tkinter GUI cell has created ``label1``, also shown in its output box.
    Nothing is loaded in that case.

    Parameters
    ----------
    fileName : str
        Path passed to ``hs.load``.
    """
    global file
    if not fileName or not str(fileName).strip():
        message = "Please enter the path of the input file in the text box provided then press the Load File option.\n"
        print(message)
        # label1 only exists after the GUI cell has run; in console use it is absent.
        output_label = globals().get("label1")
        if output_label is not None:
            output_label['text'] = message
        return
    print("Loading file...")
    file = hs.load(fileName)
    print("File loaded.")
    # NOTE(review): when triggered from the GUI button this starts the console
    # menu, which waits on input() while the Tk window is open.
    menu()

def distance(x1, y1, x2, y2):
    """Return the Euclidean distance between the points (x1, y1) and (x2, y2)."""
    dx = x1 - x2
    dy = y1 - y2
    return np.sqrt(dx ** 2 + dy ** 2)

def pos_dist(x1, y1, x2, y2):
    """Euclidean distance from (x1, y1) to (x2, y2), restricted to one quadrant.

    The distance is only computed when both x2 > x1 and y2 > y1. Otherwise the
    sentinel -1.0 is returned.
    """
    strictly_greater = x2 > x1 and y2 > y1
    if not strictly_greater:
        return -1.0
    dx, dy = x1 - x2, y1 - y2
    return np.sqrt(dx ** 2 + dy ** 2)

def findCenter(im, peak):
    """Return the position of the brightest peak in ``im``.

    Parameters
    ----------
    im : 2-D array-like
        Image used to look up the intensity at each peak; indexed as
        ``im[row][col]``.
    peak : array-like
        Iterated with ``np.ndenumerate``. Each element is a sequence of
        ``(a, b)`` positions, where ``a`` is the row and ``b`` the column in
        ``im``. Presumably this is the object array produced by pixstem's peak
        refinement; confirm against the caller.

    Returns
    -------
    tuple
        ``(b, a)`` (column, row) of the peak whose truncated pixel has the
        highest intensity. Returns ``(0, 0)`` if no peak lies inside ``im`` with
        a positive intensity.
    """
    n_rows, n_cols = np.shape(im)[:2]
    center = (0, 0)
    maximum = 0
    for _, positions in np.ndenumerate(peak):
        for a, b in positions:
            row, col = int(a), int(b)
            # Each axis is checked against its own size; non-square images used to
            # be bounds-checked against the row count only. Negative indices are
            # rejected too, since Python would wrap them to the opposite edge.
            if 0 <= row < n_rows and 0 <= col < n_cols and im[row][col] > maximum:
                maximum = im[row][col]
                center = (b, a)
    return center

def multiprocessing_func(i, j, rnd):
    """Measure one peak distance for probe position (i, j). Runs in a child process started by analysis().

    Steps:
    - Wrap the pattern at ``file.inav[i, j]`` in a pixstem PixelatedSTEM signal.
    - Find and refine the diffraction peaks.
    - Take the brightest peak (via findCenter) as the centre.
    - Write the distance from the centre to the peak listed immediately before
      it, rounded to ``rnd`` decimals, into ``distances[j][i]``.

    Returns nothing.

    NOTE(review): the result reaches the parent only because ``distances`` is a
    view of the multiprocessing.Array created in analysis() and the child
    inherits that global. This presumably requires the "fork" start method,
    which is not the default on Windows or recent macOS. Confirm.
    """
    s = ps.PixelatedSTEM(hs.signals.Signal2D(file.inav[i, j]))
    # Pixel data captured before any processing; findCenter reads peak intensities from it.
    imarray = np.array(s)
    # NOTE(review): rotation angle is 0, so this is presumably a no-op. Confirm whether a real angle was intended.
    s = s.rotate_diffraction(0,show_progressbar=False)
    ############################################################################################################################
    # Peak search: template-match a disk (disk_r=5), find peaks on the match result,
    # then refine the positions (centre-of-mass refinement, judging by the method name).
    st = s.template_match_disk(disk_r=5, lazy_result=False, show_progressbar=False)
    peak_array = st.find_peaks(lazy_result=False, show_progressbar=False)
    # NOTE(review): peak_array_com is never used below; only the background-subtracted refinement is.
    peak_array_com = s.peak_position_refinement_com(peak_array, lazy_result=False, show_progressbar=False)
    s_rem = s.subtract_diffraction_background(lazy_result=False, show_progressbar=False)
    peak_array_rem_com = s_rem.peak_position_refinement_com(peak_array, lazy_result=False, show_progressbar=False)
    ############################################################################################################################
    # center is (column, row), i.e. (b, a) of the matching peak entry.
    center = findCenter(imarray, peak_array_rem_com)

    # finds the specific spot and adding that distance to the array
    # posDistance stays 0 if no refined peak coincides with the centre (within 1e-5).
    posDistance = 0
        
    for (x,y) in np.ndenumerate(peak_array_rem_com):
        # y holds (a, b) = (row, column) peak positions; prev tracks the previous entry.
        # NOTE(review): if the centre is the first entry, prev is still (0, 0) and the
        # distance is measured to the image origin. Confirm this is intended.
        prev = (0, 0)
        for (a, b) in y:
            if abs(center[0] - b) < 1E-5 and abs(center[1] - a) < 1E-5:
                posDistance = distance(center[0], center[1], prev[1], prev[0])
                break  # exits only the inner loop
            prev = (a, b)
    distances[j][i] = round(posDistance, rnd)

In [3]:
def analysis(ROW = 10, COL = 10, rnd = 2):
    """Compute one peak distance per probe position over a ROW x COL scan.

    Allocates the global ``distances`` as a (COL, ROW) array backed by a
    zero-initialised ``multiprocessing.Array``. For each ``i`` in range(ROW) it
    starts one child process per ``j`` in range(COL), each running
    ``multiprocessing_func(i, j, rnd)``, which writes ``distances[j][i]``. The
    processes for one ``i`` run concurrently and are joined before the next ``i``.

    Parameters
    ----------
    ROW, COL : int
        Number of positions scanned along the first / second index of
        ``file.inav[i, j]``.
    rnd : int
        Number of decimals the distances are rounded to.
    """
    global distances
    if file is None:
        print("No file loaded. Please load a file before starting the analysis.")
        menu()
        return
    print("Starting analysis...")

    shared_array_base = multiprocessing.Array(ctypes.c_double, ROW*COL)
    distances = np.ctypeslib.as_array(shared_array_base.get_obj())
    distances = distances.reshape(COL, ROW)

    failed = []
    for i in range(ROW):
        print(i)
        processes = []
        for j in range(COL):
            p = multiprocessing.Process(target=multiprocessing_func, args=(i, j, rnd,))
            processes.append((j, p))
            p.start()

        for j, process in processes:
            process.join()
            # A crashing worker prints its traceback in the child (usually not in the
            # notebook) and leaves its cell at 0.0, which looks like a real result.
            if process.exitcode != 0:
                failed.append((i, j))
    if failed:
        print("Warning: %d position(s) failed and were left at 0: %s" % (len(failed), failed))
    print("Analysis complete.")
    menu()

def toCSV(filename = "outputDistances.csv", data = None):
    """Write the distance matrix to *filename* as CSV, one matrix row per line.

    Parameters
    ----------
    filename : str
        Destination path (default "outputDistances.csv").
    data : 2-D array-like, optional
        Rows to write. Defaults to the global ``distances`` filled by analysis().
    """
    rows = distances if data is None else data
    if rows is None:
        print("No distances to save. Please run the analysis first.")
        return
    # The csv module requires newline=""; otherwise Windows gets blank lines
    # between rows. The with-block closes the file even if writing fails.
    with open(filename, "w", newline="") as out:
        csv.writer(out).writerows(rows)

def barChart(INTERVAL = 0.01, LABEL_EVERY = 60):
    """Bar chart of how many probe positions share each distance value.

    The distances from analysis() are snapped to a grid that starts at the
    smallest distance, steps by INTERVAL, and ends at the largest distance
    (inclusive). There is one bar per grid point, showing how many distances
    fall on it.

    Parameters
    ----------
    INTERVAL : float
        Spacing between neighbouring x-axis positions, e.g. 0.01 gives
        1.00, 1.01, 1.02, ...  Must be positive.
    LABEL_EVERY : int
        Only every LABEL_EVERY-th x-axis label is drawn (default 60, the value
        previously hard-coded) so the rotated labels do not overlap.
    """
    if distances is None or np.size(distances) == 0:
        print("No distances to plot. Please run the analysis first.")
        menu()
        return
    if not INTERVAL > 0:
        print("The x-axis interval must be a positive number.")
        menu()
        return

    dist = np.asarray(distances, dtype=float).ravel()
    low = dist.min()
    # Grid points from min to max inclusive; np.arange(min, max, step) stopped
    # before max, so the largest distance was never counted.
    n_bins = int(np.rint((dist.max() - low) / INTERVAL)) + 1
    # Count each distance at its nearest grid point. The previous exact float
    # equality lookup silently dropped values not lying exactly on the grid.
    counts = np.bincount(np.rint((dist - low) / INTERVAL).astype(int), minlength=n_bins)
    # Use at least 2 decimals (as before), and more when INTERVAL is finer than 0.01.
    decimals = max(2, -int(np.floor(np.log10(INTERVAL))))
    x_labels = ["{0:.{1}f}".format(low + k * INTERVAL, decimals) for k in range(n_bins)]
    y_pos = np.arange(n_bins)

    # dpi is set on this figure directly. rcParams is read when a figure is
    # created, so the old assignment after plotting only affected later figures.
    fig, ax = plt.subplots(dpi=500)
    ax.bar(y_pos, counts, align='center', alpha=0.95)
    ax.set_xticks(y_pos)
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_xlabel('Distance from center peak', fontsize=5)
    ax.set_ylabel('Counts', fontsize=5)
    ax.set_title('Distance Counts', fontsize=5)
    ax.tick_params(axis='both', which='major', labelsize=5)

    # Hide all but every LABEL_EVERY-th tick label.
    step = max(1, int(LABEL_EVERY))
    for k, label in enumerate(ax.get_xticklabels()):
        label.set_visible(k % step == 0)

    fig.subplots_adjust(bottom=0.23)
    plt.show()
    menu()

def heatMap(OUTLIER = 25):
    """Seaborn heat map of the distance matrix produced by analysis().

    The matrix is mirrored left-to-right (np.fliplr). Entries greater than
    OUTLIER are replaced by the median of the mirrored matrix, taken before the
    replacement, so isolated bad values do not dominate the colour scale. The
    global ``distances`` itself is left unchanged.

    Parameters
    ----------
    OUTLIER : float
        Replacement threshold (default 25, the value previously hard-coded).
    """
    import seaborn as sns
    if distances is None or np.size(distances) == 0:
        print("No distances to plot. Please run the analysis first.")
        menu()
        return
    flipped = np.fliplr(distances)
    med = np.median(flipped)
    # np.fliplr returns a view, so the old in-place replacement wrote the median
    # back into the shared distances array, changing later bar charts and CSV
    # exports. np.where builds a new array instead.
    cleaned = np.where(flipped > OUTLIER, med, flipped)
    df = pd.DataFrame(cleaned, columns=np.arange(cleaned.shape[1]), index=np.arange(cleaned.shape[0]))
    fig, ax = plt.subplots()
    sns.heatmap(df, ax=ax)
    # Draw the figure now, before menu() below blocks waiting for console input.
    plt.show()
    menu()

In [4]:
def _print_menu():
    """Print the option list shown by menu()."""
    print("-" * 50)
    print((" " * 23) + "Menu" + (" " * 23))
    print("-" * 50)
    print("1 - Load File")
    print("2 - Start Analysis")
    print("3 - Create Bar Chart")
    print("4 - Create Heat Map")
    print("5 - Transfer Data to .csv")
    print("q - Quit Program")
    print("-" * 50)


def _require(value, message):
    """Return True if *value* is not None; otherwise print *message* and return False."""
    if value is None:
        print(message)
        return False
    return True


def menu():
    """Run the interactive console menu until the user enters "q".

    Options dispatch to loadFile (1), analysis (2), barChart (3), heatMap (4)
    and toCSV (5). After each action the menu is printed again and a new
    choice is read.

    Several action functions call ``menu()`` themselves when they finish. A
    nested call made while a menu loop is already running returns immediately.
    That keeps a single loop in control, so one "q" exits instead of dropping
    back into an outer loop holding a stale choice.
    """
    if getattr(menu, "_active", False):
        return
    menu._active = True
    try:
        _print_menu()
        inpt = input("Please enter an option: ").strip()
        while inpt != "q":
            if inpt == "1":
                fileName = input("Please enter the file name. ")
                loadFile(fileName)
            elif inpt == "2":
                if _require(file, "Please load a file (option 1) before starting the analysis."):
                    answer = input("Please enter the number of rows and columns you would like to analyze and the number of decimal point to which to round values to seperated by a space. ")
                    try:
                        row, col, rnd = (int(v) for v in answer.split())
                    except ValueError:
                        print("Please enter three whole numbers separated by spaces, e.g. 10 10 2.")
                    else:
                        analysis(row, col, rnd)
            elif inpt == "3":
                if _require(distances, "Please run the analysis (option 2) first."):
                    # e.g. 0.01 gives x-axis positions 1.00, 1.01, 1.02, ...
                    INTERVAL = input("Please enter the x-axis interval. ")
                    try:
                        interval = float(INTERVAL)
                    except ValueError:
                        print("Please enter a number, e.g. 0.01.")
                    else:
                        barChart(interval)
            elif inpt == "4":
                if _require(distances, "Please run the analysis (option 2) first."):
                    heatMap()
            elif inpt == "5":
                if _require(distances, "Please run the analysis (option 2) first."):
                    filename = input("Please enter the filename you wish to save the data to. Dont forget the .csv extension. ")
                    toCSV(filename)
            else:
                inpt = input("Please enter a valid number or q to quit. ").strip()
                continue
            # The action has finished: show the menu again and read the next choice.
            _print_menu()
            inpt = input("Please enter an option: ").strip()
    finally:
        menu._active = False
# Launch the console menu as soon as this cell runs. menu() reads choices with
# input(), so the cell (and a Restart & Run All) waits here for console input
# until the menu loop exits.
menu()

In [56]:
import tkinter as tk
from tkinter import font

# Window size in pixels.
HEIGHT = 900
WIDTH = 1400

root = tk.Tk()

# The canvas only reserves the window size; all visible widgets sit on `frame`.
canvas = tk.Canvas(root, height=HEIGHT, width=WIDTH)
canvas.pack()
frame = tk.Frame(root, bg='#80c1ff')
frame.place(relwidth=1, relheight=1)

# Title
label = tk.Label(frame, text='Menu', bg='#80c1ff', font=('Calibri', 50), fg='#ffffff')
label.place(relx=0.45, rely=0.05, relwidth=0.1, relheight=0.05)

# Options shared by every menu button.
_BUTTON_STYLE = {
    'bg': '#80c1ff',
    'font': ('Calibri', 30),
    'highlightthickness': 0,
    'bd': 0,
    'activebackground': '#339cff',
    'pady': 0.02,
}

# "Load File" passes the path typed into the entry box to loadFile().
button = tk.Button(frame, text='Load File', command=lambda: loadFile(entry.get()), **_BUTTON_STYLE)
button.place(relx=0.4, rely=0.15, relwidth=0.2, relheight=0.05)
# NOTE(review): the remaining buttons have no command attached yet.
button1 = tk.Button(frame, text='Start Analysis', **_BUTTON_STYLE)
button1.place(relx=0.35, rely=0.2, relwidth=0.3, relheight=0.05)
button2 = tk.Button(frame, text='Create Bar Chart', **_BUTTON_STYLE)
button2.place(relx=0.35, rely=0.25, relwidth=0.3, relheight=0.05)
button3 = tk.Button(frame, text='Create Heat Map', **_BUTTON_STYLE)
button3.place(relx=0.35, rely=0.3, relwidth=0.3, relheight=0.05)
button4 = tk.Button(frame, text='Transfer Data to .csv', **_BUTTON_STYLE)
button4.place(relx=0.3, rely=0.35, relwidth=0.4, relheight=0.05)

# Output box; loadFile() writes its hint text into label1.
label1 = tk.Label(frame, bg='#cce6ff', font=('Calibri', 15), anchor='nw', highlightthickness=0, bd=0)
label1.place(relx=0.1, rely=0.5, relwidth=0.8, relheight=0.35)

# Path entry read by the Load File button.
entry = tk.Entry(frame, font=40)
entry.place(relx=0.1, rely=0.9, relwidth=0.8, relheight=0.05)

root.mainloop()