In [3]:
#load libraries
import numpy as np
import math
from math import sqrt
from matplotlib.image import NonUniformImage
import matplotlib.pyplot as plt
import seaborn as sns
from statistics import mean
from scipy.stats import norm, lognorm
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import distance, Voronoi, voronoi_plot_2d, ConvexHull, Delaunay
from collections import defaultdict
import itertools
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import random
import sys
from itertools import combinations 
import time
from scipy import stats
from sklearn.svm import SVC
import sobol_seq
import os,glob

#plotting style: show matplotlib figures inline in the notebook, apply
#seaborn's default settings (sns.set), then switch to the plain 'white' style
%matplotlib inline
sns.set()
sns.set_style('white')

def js_divergence_scipy(hist1, hist2):
    """Return the base-2 Jensen-Shannon divergence between two histograms.

    SciPy's ``jensenshannon`` yields the JS *distance* (the square root of the
    divergence), so the value is squared here. With base 2 the result lies in
    [0, 1].
    """
    js_distance = distance.jensenshannon(hist1, hist2, base=2)
    return js_distance ** 2

def nearest_neighbors(values, all_values, nbr_neighbors=1):
    """Return, for each row of ``values``, the indices of its nearest rows in ``all_values``.

    Parameters
    ----------
    values : array-like, shape (m, k)
        Query points.
    all_values : array-like, shape (n, k)
        Reference points searched with a Euclidean k-d tree.
    nbr_neighbors : int, default 1
        Number of neighbours returned per query point.

    Returns
    -------
    ndarray of int, shape (m, nbr_neighbors)
        Row indices into ``all_values``, nearest first. The result is always
        2-D, even when ``nbr_neighbors == 1``.
    """
    # n_neighbors must be passed by keyword: scikit-learn >= 1.0 made
    # estimator constructor arguments keyword-only, so the positional form
    # raises TypeError there.
    nn = NearestNeighbors(n_neighbors=nbr_neighbors, metric='euclidean', algorithm='kd_tree')
    nn.fit(all_values)
    # Only the indices are used, so skip building the distance array.
    return nn.kneighbors(values, return_distance=False)

def golden(n, d=2, seed=0.5):
    """Return ``n`` points of the additive-recurrence (R_d) low-discrepancy sequence.

    Point ``i`` (1-based) is ``(seed + i * alpha) % 1``, where
    ``alpha[j] = (1 / g) ** (j + 1) % 1`` and ``g`` is the generalised golden
    ratio for dimension ``d``: the positive root of ``x**(d + 1) = x + 1``.

    Parameters
    ----------
    n : int
        Number of points.
    d : int, default 2
        Dimension of each point.
    seed : float, default 0.5
        Offset added to every point before wrapping into [0, 1).

    Returns
    -------
    ndarray of float, shape (n, d)
        Points in the unit hypercube [0, 1)^d.
    """
    if d == 2:
        # Plastic number (phi_2). Kept as a literal so the default 2-D sequence
        # is bit-for-bit the same as before.
        g = 1.32471795724474602596
    else:
        # phi_d via fixed-point iteration x <- (1 + x) ** (1 / (d + 1)); the map
        # is a contraction near the root, so 64 steps reach float precision.
        g = 2.0
        for _ in range(64):
            g = pow(1.0 + g, 1.0 / (d + 1))
    alpha = np.array([pow(1 / g, j + 1) % 1 for j in range(d)], dtype=float)
    # Row i holds (i + 1) * alpha: the same products the per-row loop computed.
    steps = np.arange(1, n + 1)
    return (seed + np.outer(steps, alpha)) % 1





#get area of each voronoi region
def voronoi_volumes(points, percent=100, window_vol=False, max_vert=None):
    """Return the area of each input point's Voronoi cell, clipped to a box.

    Every vertex of a bounded cell is clipped to ``[0, upper]`` in both x and
    y before the cell's convex-hull area is taken. ``upper`` is
    ``max_vert * percent`` when ``window_vol`` is true, else ``max_vert``.

    Parameters
    ----------
    points : array-like, shape (n, 2)
        Seed points of the Voronoi diagram.
    percent : float, default 100
        Factor applied to the clipping bound when ``window_vol`` is true.
    window_vol : bool, default False
        Whether to scale the clipping bound by ``percent``.
    max_vert : float, optional
        Unscaled clipping bound. When omitted it is read from the module-level
        ``raw_data`` as ``max(ceil(max(raw_data[:, 3])), ceil(max(raw_data[:, 4])))``.

    Returns
    -------
    ndarray of float, shape (n,)
        Cell area per input point, in input order. Unbounded cells, and cells
        whose clipped vertices form a degenerate hull, get 0.
    """
    from scipy.spatial import QhullError

    v = Voronoi(points)
    vol = np.zeros(v.npoints)
    if max_vert is None:
        # Computed once up front instead of once per region: raw_data may be large.
        max_vert = max(math.ceil(np.amax(raw_data[:, 3])),
                       math.ceil(np.amax(raw_data[:, 4])))
    upper = max_vert * percent if window_vol else max_vert

    # v.point_region maps each input point to its region, so vol stays in input order.
    for i, reg_num in enumerate(v.point_region):
        indices = v.regions[reg_num]
        if -1 in indices:
            # -1 marks a vertex at infinity: the cell is unbounded, leave area 0.
            continue
        vert = np.clip(v.vertices[indices], 0, upper)
        try:
            vol[i] = ConvexHull(vert).volume  # in 2-D, "volume" is the area
        except (QhullError, ValueError):
            # Clipping can collapse a cell to a line/point; exclude its area.
            vol[i] = 0
    return vol






# def new_vor_p(num_bins=25):
#     #get areas
#     vol=voronoi_volumes(coords)   
#     #turn area into percent of total area
#     norm_vol=(vol)/sum(vol)
#     #multiply by total number of pixels to find how many pixels in each area
#     scale_vol=norm_vol*len(raw_data)
#     #round to nearest whole number
#     round_vol=np.round(scale_vol,decimals=0)
#     #create array, col1 is ID, col2 is area
#     grain_IDs_and_areas=np.stack((grain_IDs,round_vol),axis=1)
#     #get unique IDs and their count
#     unique_IDs,unique_IDs_count=np.unique(grain_IDs_and_areas[:,0],return_counts=True)
#     #create array to hold unique IDs and their sizes added together
#     unique_grain_IDs_and_total_areas=np.stack((unique_IDs,np.zeros(len(unique_IDs))),axis=1)
    
#     #construct IDs and count
#     for row in unique_grain_IDs_and_total_areas:
#         grain=grain_IDs_and_areas[np.where(grain_IDs_and_areas[:,0] == row[0])]
#         total=np.sum(grain[:,1])
#         row[1]=total
    
#     #double edge grains
#     row_ID=0
#     for row in unique_grain_IDs_and_total_areas:
#         #if the grain ID is an edge
#         if np.isin(row[0],edge_grains[:,0]):
#             #double the total grain area
#             unique_grain_IDs_and_total_areas[row_ID,1]*=2
#         row_ID+=1
        
#     #return count for hist
#     count = 0.138*unique_grain_IDs_and_total_areas[:,1]
#     q, q_bin_edges = np.histogram(count, bins=num_bins, range=(0.01,50), density=True)
#     q = np.append(q, 0)
#     #mean,var 
#     total_mean=np.mean(count)
#     var=np.var(count)
    
#     return q, total_mean, var
    
    

    

def _total_area_per_grain(ids, areas):
    """Sum ``areas`` over equal ``ids``; return an (n_unique, 2) float array of [id, total]."""
    unique_ids, inverse = np.unique(ids, return_inverse=True)
    totals = np.bincount(np.ravel(inverse), weights=areas, minlength=len(unique_ids))
    return np.stack((unique_ids, totals), axis=1)


def _window_edge_ids(coords, ids):
    """Return the IDs of points lying on the axis-aligned bounding box of ``coords``."""
    coords = np.asarray(coords)
    x, y = coords[:, 0], coords[:, 1]
    on_edge = (x <= np.amin(x)) | (x >= np.amax(x)) | (y <= np.amin(y)) | (y >= np.amax(y))
    return np.asarray(ids)[on_edge]


def grain_areas(sampled_coords, sampled_IDs, percent=100, num_bins=25, window_vol=False,
                return_p=False, numerical=False):
    """Estimate per-grain areas from sampled points and summarise their distribution.

    Each sampled point gets the area of its Voronoi cell (``voronoi_volumes``).
    The cell areas are rescaled so they sum to the number of pixels they
    represent, rounded, and summed per grain ID. Grains on the image/window
    edge have their total doubled. Totals are multiplied by 0.138, and
    non-positive values are dropped. A ``num_bins`` density histogram over
    (0.01, 50) is then built from what remains.

    Reads the module-level ``raw_data`` and ``p``. Also reads ``edge_grains``
    unless ``window_vol`` is true.

    Parameters
    ----------
    sampled_coords : ndarray, shape (n, 2)
        Sampled point coordinates.
    sampled_IDs : ndarray, shape (n,)
        Grain ID of each sampled point.
    percent : float, default 100
        Forwarded to ``voronoi_volumes``; scales its clipping bound when
        ``window_vol`` is true.
    num_bins : int, default 25
        Number of histogram bins.
    window_vol : bool, default False
        True when the sample is a sub-window. Areas are then scaled by
        ``len(sampled_coords)`` instead of ``len(raw_data)``. Edge grains are
        taken from the window's bounding box instead of the global
        ``edge_grains``.
    return_p : bool, default False
        Return the histogram itself instead of its JS divergence from ``p``.
    numerical : bool, default False
        True: histogram grain counts. False: weight each grain by its area.

    Returns
    -------
    tuple
        ``(q, mean, var)`` if ``return_p``, else ``(js_divergence(p, q), mean, var)``.
        ``q`` is the density histogram with a trailing 0 appended (length
        ``num_bins + 1``). ``mean``/``var`` are those of the positive scaled
        grain areas.
    """
    # percent and window_vol go into their own slots of voronoi_volumes; passing
    # window_vol positionally used to land in ``percent``, so the window clip
    # bound (max_vert * percent) was never applied.
    vol = voronoi_volumes(sampled_coords, percent, window_vol)
    # Express each cell as a share of the total and convert to whole pixels.
    norm_vol = vol / sum(vol)
    n_pixels = len(sampled_coords) if window_vol else len(raw_data)
    round_vol = np.round(norm_vol * n_pixels, decimals=0)

    per_grain = _total_area_per_grain(sampled_IDs, round_vol)

    # Window edges are computed locally. Unlike before, they no longer
    # overwrite the module-level ``edge_grains`` that full-image calls rely on.
    if window_vol:
        edge_ids = _window_edge_ids(sampled_coords, sampled_IDs)
    else:
        edge_ids = edge_grains[:, 0]
    # Edge grains are truncated by the boundary, so their total area is doubled.
    per_grain[np.isin(per_grain[:, 0], edge_ids), 1] *= 2

    # NOTE(review): 0.138 looks like a pixel-to-physical-area factor -- confirm.
    count = 0.138 * per_grain[:, 1]
    count = count[count > 0]
    mean_area = np.mean(count)
    var_area = np.var(count)
    q, _ = np.histogram(count, bins=num_bins, range=(0.01, 50), density=True,
                        weights=None if numerical else count)
    q = np.append(q, 0)
    if return_p:
        return q, mean_area, var_area
    return js_divergence_scipy(p, q), mean_area, var_area
    
    
    
    
    

def window_slices(data=None):
    """Split point data into four quadrant windows, each shifted to start at the origin.

    Parameters
    ----------
    data : ndarray, optional
        Array whose columns 3 and 4 hold x and y coordinates. Defaults to the
        module-level ``raw_data``. It is not modified.

    Returns
    -------
    list of ndarray
        ``[bottom_left, bottom_right, top_left, top_right]`` row subsets (copies)
        of ``data``. Every row lands in exactly one window. Rows on a midline
        (``x == x_max / 2`` or ``y == y_max / 2``) go to the left/bottom side.
        Right-hand windows are shifted so their minimum x is 0. Top windows are
        shifted so their minimum y is 0. Empty windows are returned unshifted.
    """
    if data is None:
        data = raw_data
    x, y = data[:, 3], data[:, 4]
    # NOTE(review): splitting at half the maximum assumes coordinates start at 0.
    left = x <= np.amax(x) / 2
    bottom = y <= np.amax(y) / 2

    windows = []
    for in_x, in_y, shift_cols in (
        (left, bottom, ()),          # window one (bottom left)
        (~left, bottom, (3,)),       # window two (bottom right)
        (left, ~bottom, (4,)),       # window three (top left)
        (~left, ~bottom, (3, 4)),    # window four (top right)
    ):
        window = data[in_x & in_y]   # boolean indexing returns a copy
        if len(window):
            for col in shift_cols:
                window[:, col] -= np.min(window[:, col])
        windows.append(window)
    return windows






def _sample_square_grid(size):
    """Return a rectangular lattice of about ``size`` points stretched over the raw_data x/y extent."""
    n_x = int(np.sqrt(size))
    n_y = size // n_x
    xv, yv = np.meshgrid(np.arange(n_x), np.arange(n_y), sparse=False, indexing='xy')
    # Build the lattice as float: an integer array silently truncated the
    # rescaled coordinates below, making the spacing uneven.
    lattice = np.concatenate((xv.reshape(-1, 1), yv.reshape(-1, 1)), axis=1).astype(float)
    lattice[:, 0] *= np.amax(raw_data[:, 3]) / np.amax(lattice[:, 0])
    lattice[:, 1] *= np.amax(raw_data[:, 4]) / np.amax(lattice[:, 1])
    return lattice


def _sample_hex_grid(size):
    """Return an offset-row lattice of about ``size`` points stretched over the raw_data x/y extent."""
    ratio = np.sqrt(3) / 2  # sin(60 deg)
    n_x = int(np.sqrt(size) / ratio)
    n_y = size // n_x
    xv, yv = np.meshgrid(np.arange(n_x), np.arange(n_y), sparse=False, indexing='xy')
    # NOTE(review): columns are `ratio` apart and rows 1 apart, with every other
    # row shifted by ratio/2. A regular hexagonal lattice has those proportions
    # transposed (rows sqrt(3)/2 apart), and x/y are also rescaled independently
    # below -- confirm the intended geometry.
    xv = xv * ratio
    xv[::2, :] += ratio / 2
    lattice = np.concatenate((xv.reshape(-1, 1), yv.reshape(-1, 1)), axis=1)
    lattice[:, 0] *= np.amax(raw_data[:, 3]) / np.amax(lattice[:, 0])
    lattice[:, 1] *= np.amax(raw_data[:, 4]) / np.amax(lattice[:, 1])
    return lattice


def _sample_window(percent):
    """Return (coords, grain IDs) of every raw_data row inside the lower-left ``percent`` window."""
    x_limit = int(np.amax(raw_data[:, 3]) * percent)
    y_limit = int(np.amax(raw_data[:, 4]) * percent)
    inside = (raw_data[:, 3] <= x_limit) & (raw_data[:, 4] <= y_limit)
    if not inside.any():
        raise ValueError('window contains no points')
    window = raw_data[inside, 3:6]
    return window[:, 0:2], window[:, 2]


def _snap_to_pixels(targets):
    """Return raw_data[:, 3:6] rows of the points in ``coords`` nearest to each target.

    Assumes the module-level ``coords`` rows align with ``raw_data`` rows (true
    when set by the square/hex branch of ``sample``) -- TODO confirm for sobol/gold.
    """
    nearest = nearest_neighbors(targets, coords)[:, 0]
    return raw_data[nearest, 3:6]


def sample(size, method):
    """Draw about ``size`` points from ``raw_data`` with ``method`` and score them with ``grain_areas``.

    Reads the module-level ``raw_data`` (columns 3, 4, 5 used as x, y, grain
    ID) and ``coords_and_grains`` (for 'random'). It also reads ``coords``, the
    point cloud that lattice/quasi-random targets are snapped to.

    Parameters
    ----------
    size : int
        Requested number of points. Lattice methods produce
        ``n_x * (size // n_x)`` points, which can be fewer. 'window' uses it to
        set the window fraction ``sqrt(size / 100000)``.
    method : {'random', 'square', 'hex', 'sobol', 'gold', 'window'}
        - 'random': rows of ``coords_and_grains`` drawn without replacement.
        - 'square' / 'hex': lattice points snapped to the nearest pixel. These
          also rebind the module-level ``coords`` to ``raw_data[:, 3:5]``.
        - 'sobol' / 'gold': Sobol / ``golden`` quasi-random points snapped to
          the nearest pixel.
        - 'window': every raw_data row in the lower-left window.

    Returns
    -------
    tuple
        Whatever ``grain_areas`` returns for the drawn points.

    Raises
    ------
    ValueError
        For an unknown ``method`` or an empty window.
    """
    global coords

    if method == 'window':
        # NOTE(review): 100000 is a hard-coded total point count ("filesize");
        # an earlier version used len(raw_data) -- confirm which is intended.
        percent = sqrt(size / 100000)
        window_coords, window_ids = _sample_window(percent)
        # Forward the window fraction so the Voronoi clip can follow the window.
        return grain_areas(window_coords, window_ids, percent, window_vol=True)

    if method == 'random':
        rows = np.random.choice(coords_and_grains.shape[0], size, replace=False)
        chosen = coords_and_grains[rows, :]
        return grain_areas(chosen[:, 0:2], chosen[:, 2])

    if method in ('square', 'hex'):
        coords = raw_data[:, 3:5]
        targets = _sample_square_grid(size) if method == 'square' else _sample_hex_grid(size)
    elif method == 'sobol':
        targets = sobol_seq.i4_sobol_generate(2, size)
        targets[:, 0] *= np.amax(raw_data[:, 3])
        targets[:, 1] *= np.amax(raw_data[:, 4])
    elif method == 'gold':
        targets = golden(size)
        targets[:, 0] *= np.amax(raw_data[:, 3])
        targets[:, 1] *= np.amax(raw_data[:, 4])
    else:
        raise ValueError('unknown sampling method: {!r}'.format(method))

    snapped = _snap_to_pixels(targets)
    return grain_areas(snapped[:, 0:2], snapped[:, 2])