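### Gradient Explorations

Fill the empty bins of a 2D C-Mod disruptivity map, compute the gradient of the filled map, and use the filled map to follow the disruption probability along a single pulse.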

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
from scipy.signal import convolve2d
from scipy.ndimage import gaussian_filter
from scipy.optimize import minimize, LinearConstraint
from scipy.interpolate import interp1d, LinearNDInterpolator

# Move into the project's src/ directory so that the project-local modules
# imported below (disruptivity, vis, data_loader, tokamaks) can be found and
# the relative '../data/...' paths used later resolve from src/.
# NOTE(review): this relies on the kernel starting in the notebook's own
# directory (and on the cwd being on sys.path, as in IPython); a
# path-independent setup would be more robust.
import os
import importlib
os.chdir('../src/')

# Project-local modules (importable only after the chdir above)
import disruptivity as dis
import vis.disruptivity_vis as dis_vis
import vis.probability_vis as prob_vis
from vis.plot_helpers import plot_subplot as plot
import data_loader

# Import tokamak configurations
from tokamaks.cmod import CONFIG as CMOD
from tokamaks.d3d import CONFIG as D3D

# Re-import these modules so edits to them are picked up when this cell is
# re-run in the same kernel (a plain `import` would reuse the cached module).
importlib.reload(dis)
importlib.reload(dis_vis)
# Short alias used by the loading cell
load_disruptions_mat = data_loader.load_disruptions_mat

In [None]:
# Load the C-Mod disruption warning database: the time-slice DataFrame and the
# dictionary of index sets that later cells select rows with.
CMOD_DB_PATH = '../data/CMod_disruption_warning_db.mat'
cmod_df, cmod_indices = load_disruptions_mat(CMOD_DB_PATH)

In [None]:
# Axis definitions for 2D disruptivity maps. Keys are cmod_df columns (first
# key -> x axis, second -> y axis); values hold the plotting metadata
# ('range', 'axis_name', ...) taken from the C-Mod configuration.
entry_dict_VDE = {
    'z_error': CMOD["entry_dict"]["z_error"],
    'kappa': CMOD["entry_dict"]["kappa"],
}

entry_dict_DENSITY = {
    'q95': CMOD['entry_dict']['q95'],
    'n_e': CMOD['entry_dict']['n_e'],
}

# Hugill-diagram coordinates, added as new cmod_df columns.
cmod_df['inv_q95'] = 1 / cmod_df['q95']

# Murakami parameter n_e * R / B_T, in units of 1e19 m^-2 / T.
# NOTE(review): 0.68 is presumably the C-Mod major radius R in metres, and
# n_equal_1_mode / n_equal_1_normalized is presumably B_T (the field the
# n=1 amplitude was normalised by) -- confirm against the database docs.
R_MAJOR_CMOD = 0.68
b_toroidal = cmod_df['n_equal_1_mode'] / cmod_df['n_equal_1_normalized']
cmod_df['murakami'] = cmod_df['n_e'] * R_MAJOR_CMOD / b_toroidal / 1e19

entry_dict_H = {
    'murakami': {
        'range': [0, 20],
        # Raw string: "\ " is a mathtext spacing command, not a Python escape
        # (a plain string triggers an invalid-escape warning).
        'axis_name': r"$n_e R/B_T \ (10^{19}$m$^{-2}$/T)",
    },
    'inv_q95': {
        'range': [0, 1],
        'axis_name': "$1/q_{95}$",
    },
}

In [None]:
from matplotlib.collections import LineCollection
from mpl_toolkits.axes_grid1 import make_axes_locatable

def subplot_draw_trajectory(ax, dataframe, entry_dict, indices, shot):
    """Overlay one pulse's trajectory on a 2D parameter-space plot.

    The trajectory is drawn as a line whose segments are coloured by pulse
    time, and a time colorbar is added to the left of ``ax`` on the figure
    that owns ``ax``.

    Args:
        ax: The matplotlib axes to draw on.
        dataframe (pd.DataFrame): Time-slice data with 'shot' and 'time'
            columns plus the two columns named by ``entry_dict``.
        entry_dict (dict): Two entries, x first then y, each with 'range'
            and 'axis_name' fields.
        indices (dict): Index sets; only rows in
            ``indices["indices_flattop_disrupt_in_flattop"]`` are plotted.
        shot: The pulse number to draw.

    Returns:
        The 'time' value of the first plotted sample of the pulse.

    Raises:
        ValueError: If the pulse has no rows in the selected index set.
    """
    # Rows of this pulse that also belong to the selected (flattop) index set
    flattop_indices = indices["indices_flattop_disrupt_in_flattop"]
    shot_indices = dataframe.index[np.isin(dataframe.shot, [shot])]
    overlap = np.intersect1d(flattop_indices, shot_indices)
    if overlap.size == 0:
        raise ValueError(
            f"Shot {shot} has no samples in indices_flattop_disrupt_in_flattop."
        )
    flat_top_start = dataframe.loc[overlap[0], 'time']
    print(f"Flat Top Starts: {flat_top_start} s")

    # First entry is the x axis, second the y axis
    entries = list(entry_dict)
    x_key, y_key = entries[0], entries[1]

    # One segment between each pair of consecutive samples, coloured by time.
    # Based on: https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html
    time = dataframe.loc[overlap, 'time']
    x = dataframe.loc[overlap, x_key].to_numpy()
    y = dataframe.loc[overlap, y_key].to_numpy()
    points = np.column_stack([x, y]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Continuous norm mapping pulse time to colour
    norm = plt.Normalize(time.min(), time.max())
    lc = LineCollection(segments, cmap='cool', norm=norm, capstyle='round')
    lc.set_array(time.to_numpy())
    lc.set_linewidth(2)
    line = ax.add_collection(lc)

    # Time colorbar on the LEFT, attached to the figure that owns ``ax``
    # (not whatever global ``fig`` happens to exist in the notebook).
    cbar = ax.figure.colorbar(line, ax=ax, location='left', pad=0.13)
    cbar.set_label(label="Pulse Time (s)")

    # Re-apply axis labels and limits: add_collection does not autoscale.
    ax.set_xlim(entry_dict[x_key]['range'])
    ax.set_ylim(entry_dict[y_key]['range'])
    ax.set_xlabel(entry_dict[x_key]['axis_name'])
    ax.set_ylabel(entry_dict[y_key]['axis_name'])

    return flat_top_start
    
def subplot_grad2d(
    ax,
    disruptivity: np.ndarray,
    error: np.ndarray,
    bins: np.ndarray,
    entry_dict: dict,
):
    """Plot a 2D derived map (e.g. a disruptivity gradient) with a colorbar.

    The argument order mirrors the disruptivity tuple (map, error, bins,
    entry_dict), so callers can pass ``subplot_grad2d(ax, grad, *args[1:])``.

    Args:
        ax: The matplotlib axes to plot on.
        disruptivity (np.ndarray): 2D array indexed [x_bin, y_bin]; it is
            transposed for display so that x runs horizontally.
        error (np.ndarray): Unused; accepted for call-signature compatibility.
        bins (list): Histogram bin edges; unused here because the plot extent
            comes from ``entry_dict``.
        entry_dict (dict): Exactly two entries (x first, then y), each with
            'range' and 'axis_name' fields.

    Returns:
        The colorbar created for the plot.
    """
    assert len(entry_dict) == 2, "entry_dict must have exactly 2 entries."

    # Follow the dictionary order: first entry is x, second is y
    extent = []
    axis_names = []
    for key, entry in entry_dict.items():
        assert "range" in entry, f"Entry {key} of entry_dict missing range field."
        assert "axis_name" in entry, f"Entry {key} of entry_dict missing axis_name field."
        extent.extend(entry["range"])
        axis_names.append(entry["axis_name"])

    # Heatmap; origin='lower' so y increases upwards
    image = ax.imshow(
        disruptivity.T,
        cmap="viridis",
        origin="lower",
        aspect="auto",
        extent=extent,
    )

    ax.set_xlabel(axis_names[0])
    ax.set_ylabel(axis_names[1])

    # Attach the colorbar to ``ax`` explicitly: plt.colorbar(...) works on
    # pyplot's *current* figure/axes, which in a multi-panel figure is not
    # necessarily the panel being drawn.
    cbar = ax.figure.colorbar(image, ax=ax)
    cbar.ax.tick_params(labelsize="large")
    cbar.set_label(label="Disruptivity Derivative", size="large")
    return cbar


In [None]:
# Compute the 2D disruptivity map in the density-limit coordinates (q95, n_e).
# shotlist=None is passed straight through to the disruptivity module
# (presumably meaning "no shot filtering").
NBINS, TAU, WINDOW = 35, 50, 25
shotlist = None
entry_dict = entry_dict_DENSITY
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, shotlist=shotlist
)
args = dis.compute_disruptivity_likelihood(
    cmod_df,
    entry_dict,
    indices_n_total,
    nbins=NBINS,
    tau=TAU,
    window=WINDOW,
)

In [None]:
# Pulse overlaid on the maps in this and later cells. Alternatives:
#   1160930030 -> BOB pulse, 1120105021 -> VDE pulse
shot = 1140226013  # AT pulse

# Raw (unfilled) disruptivity map with the pulse's flattop trajectory on top
fig, ax = plot(
    'cmod_q95_ne_disruptivity_kaloyannis.png',
    dis_vis.subplot_disruptivity2d,
    args,
    figsize=(7, 4),
)
ax.set_title(f"Shot {shot}")
subplot_draw_trajectory(ax, cmod_df, entry_dict, cmod_indices, shot)
plt.savefig("nofill_trajectory.png", dpi=400, facecolor='white', bbox_inches="tight")

In [None]:
def iterative_fill(
    disruptivity: np.ndarray,
    bins: np.ndarray,
    d_max_mult = 1,
    filter_size = 11,
    max_iter=100,
    rel_tol = 1e-6,
    sigma=0.5,
    truncate=3):
    '''Fill the unknown bins of a 2D disruptivity map by iterative averaging.

    Bins are treated according to their value in ``disruptivity``:
      * ``-2``: unknown; these are the bins the iteration fills in.
      * ``-1`` and ``-3``: pinned to ``max_fill`` (the largest positive
        disruptivity times ``d_max_mult``). Per the original notes these are
        the "no data" / "all disrupted" bins -- confirm the exact code meanings
        in the disruptivity module.
      * anything else: pinned to its own value (real data).

    Each iteration convolves the current map with a ``filter_size`` x
    ``filter_size`` mean filter (padding outside the map with ``max_fill``)
    and then restores the pinned bins. Iteration stops once every bin's
    relative change is <= ``rel_tol``, or after ``max_iter`` iterations.
    A final Gaussian smoothing with ``sigma``/``truncate`` is applied.

    Args:
        disruptivity (np.ndarray): 2D map, possibly containing the codes above.
        bins (list): Bin edges, one array per dimension; must be 2 arrays.
        d_max_mult (float): Multiplier on the largest positive disruptivity,
            giving ``max_fill``.
        filter_size (int): Side length of the square averaging kernel.
        max_iter (int): Maximum number of averaging iterations.
        rel_tol (float): Per-bin relative-change threshold for convergence.
        sigma (float): Standard deviation of the final Gaussian smoothing.
        truncate (float): Truncation of that Gaussian, in standard deviations.

    Returns:
        np.ndarray: The filled, smoothed map (float, same shape as input).

    Raises:
        ValueError: If ``bins`` does not describe exactly 2 dimensions, or if
            ``disruptivity`` has no positive bins to derive ``max_fill`` from.

    TODO: clean this up; move all boundary-avoidance helpers into their own module.
    TODO: (from the original notes) figure out what goes wrong with filter_size-7.
    '''
    # convolve2d below only handles 2D maps
    if len(bins) != 2:
        raise ValueError(
            f"iterative_fill only supports 2D maps, got {len(bins)} bin arrays."
        )

    # Work in float so max_fill is never truncated, and never mutate the input
    disruptivity = np.asarray(disruptivity, dtype=float)
    positive = disruptivity[disruptivity > 0]
    if positive.size == 0:
        raise ValueError("disruptivity has no positive bins; cannot derive max_fill.")
    max_fill = positive.max() * d_max_mult

    # Bins that keep a fixed value every iteration (everything except -2)
    pinned = disruptivity != -2
    fixed_points = disruptivity.copy()
    fixed_points[(disruptivity == -1) | (disruptivity == -3)] = max_fill

    # Unknown bins start at 0; pinned bins start at their fixed value
    iter_array = np.zeros(disruptivity.shape)
    iter_array[pinned] = fixed_points[pinned]

    # Square mean filter of side filter_size
    ave_filter = np.full((filter_size, filter_size), 1.0 / filter_size**2)

    for i in range(max_iter):
        # Step 1: average each bin with its neighbourhood; outside the map counts as max_fill
        new_iter_array = convolve2d(iter_array, ave_filter,
                                    mode='same',
                                    fillvalue=max_fill)

        # Step 2: restore the pinned bins
        new_iter_array[pinned] = fixed_points[pinned]

        # Step 3: relative-tolerance check, skipped while any bin is still 0
        # (unknown bins start at 0, which would divide by zero)
        if (new_iter_array != 0).all():
            residuals = np.abs(new_iter_array - iter_array) / np.abs(new_iter_array)
            if (residuals <= rel_tol).all():
                print(f'Data filling converged after {i} iterations.')
                break

        iter_array = new_iter_array
    else:
        # Runs only when the loop finished without hitting ``break``
        print(f'Data filling failed to converge after {max_iter} iterations.')

    # One last light smoothing pass, honouring the caller's sigma/truncate
    return gaussian_filter(iter_array, sigma=sigma, truncate=truncate)

# Fill the empty bins of the raw map (args[0]) using its bin edges (args[2]).
iter_array = iterative_fill(args[0], args[2])

# Value used outside the grid: the same padding iterative_fill uses with its
# default d_max_mult=1 (largest positive disruptivity). It must be defined
# here because the max_fill inside iterative_fill is local to that function.
max_fill = args[0][args[0] > 0].max()

# Linear interpolator over the bin centres so the filled map can be evaluated
# at arbitrary (x, y) points; points outside the grid get max_fill.
bin_centers = (np.array(args[2])[:, 1:] + np.array(args[2])[:, :-1]) / 2
interper = scipy.interpolate.RegularGridInterpolator(
    bin_centers,
    iter_array,
    method='linear',
    bounds_error=False,
    fill_value=max_fill,
)

In [None]:
# Visualize the filling
# Same map after iterative filling, with the pulse trajectory overlaid.
# args2 is args with the raw map swapped for the filled one.
args2 = [iter_array, *args[1:]]
fig, ax = plot('iter_fill.png', dis_vis.subplot_disruptivity2d, args2, figsize=(7, 4))
ax.set_title(f"Shot {shot}")
subplot_draw_trajectory(ax, cmod_df, entry_dict, cmod_indices, shot)
plt.savefig("nograd_trajectory.png", dpi=400, facecolor='white', bbox_inches="tight")

In [None]:
# Gradient of the filled disruptivity map.
# args[2] holds the x and y bin edges; uniform bin widths are assumed
# (dx, dy are taken from the first pair of edges).
bins = args[2]
dx = bins[0][1] - bins[0][0]
dy = bins[1][1] - bins[1][0]
# Bin centres sit half a bin width above each lower edge
bin_centers = [bins[0][:-1] + dx / 2, bins[1][:-1] + dy / 2]

# Gaussian kernel helpers used to build the derivative-of-Gaussian filters below
def g(x, sig):
    """Normalised 1D Gaussian pdf with zero mean and standard deviation ``sig``.

    Works element-wise on scalars or numpy arrays.
    """
    # 1 / (sqrt(2*pi) * sig) makes the pdf integrate to 1
    # (the previous 1 / (2*pi*sqrt(sig)) did not).
    return np.exp(-0.5 * (x / sig) ** 2) / (np.sqrt(2 * np.pi) * sig)
def g_x(x, sig):
    """Derivative of ``g`` with respect to ``x``: -x / sig**2 * g(x, sig)."""
    return -(x / sig**2) * g(x, sig)
# Derivative-of-Gaussian kernel: 5 samples spanning [-5, 5] kernel units
# (2.5 units per bin) with sig = 4.
# NOTE(review): the kernel is truncated at about +/-1.25 sig, so the
# convolution is a smoothed gradient only up to an overall scale factor that
# is common to both axes (directions/ratios are meaningful, magnitudes less so).
sig = 4
n_pts = 5
x = np.linspace(-5, 5, n_pts)
g_filter = g(x, sig)
g_x_filter = g_x(x, sig)

# 2D kernel as an outer product: kernel_x[i, j] = g[i] * g_x[j] smooths along
# axis 0 and differentiates along axis 1; its transpose does the opposite.
# The filled map is indexed [x_bin, y_bin], so kernel_x.T gives d/dx and
# kernel_x gives d/dy.
kernel_x = np.outer(g_filter, g_x_filter)

# Pad outside the map with the value iterative_fill pads with by default
# (largest positive disruptivity); computed here because the max_fill inside
# iterative_fill is local to that function.
max_fill = args[0][args[0] > 0].max()

# The convolution differentiates with respect to the bin index u (or v).
# Convert to physical units via the chain rule: df/dx = (df/du) * (du/dx)
# with du/dx = 1/dx, i.e. divide (not multiply) by the bin width.
grad_x = convolve2d(iter_array, kernel_x.T, mode='same', fillvalue=max_fill) / dx
grad_y = convolve2d(iter_array, kernel_x, mode='same', fillvalue=max_fill) / dy
grad_norm = np.sqrt(grad_x**2 + grad_y**2)

# Interpolators for the gradient components at arbitrary (x, y); zero outside the grid
bin_centers = (np.array(args[2])[:, 1:] + np.array(args[2])[:, :-1]) / 2
interper_x = scipy.interpolate.RegularGridInterpolator(
    bin_centers, grad_x, method='linear', bounds_error=False, fill_value=0,
)
interper_y = scipy.interpolate.RegularGridInterpolator(
    bin_centers, grad_y, method='linear', bounds_error=False, fill_value=0,
)

# One panel per gradient component plus the magnitude
fig, ax = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)
panels = zip(ax, (grad_x, grad_y, grad_norm), ("Grad X", "Grad Y", "Grad Norm"))
for panel, grad, title in panels:
    subplot_grad2d(panel, grad, *args[1:])
    panel.set_title(title)

In [None]:
# Filled map with the pulse trajectory and the gradient field overlaid.
args2 = [iter_array, *args[1:]]
fig, ax = plot('quiver_fill.png', dis_vis.subplot_disruptivity2d, args2, figsize=(7, 4))

# Pulse trajectory; its flattop start time is reused by the time-trace cell
flat_top_start = subplot_draw_trajectory(ax, cmod_df, entry_dict, cmod_indices, shot)
ax.set_title(f"Shot {shot}")

# Gradient arrows at the bin centres. grad_* are indexed [x, y], while
# meshgrid's default 'xy' indexing yields (ny, nx) grids, hence the .T.
mesh_x, mesh_y = np.meshgrid(bin_centers[0], bin_centers[1])
ax.quiver(mesh_x, mesh_y, grad_x.T, grad_y.T, angles='xy')
plt.savefig("grad_trajectory.png", dpi=400, facecolor='white', bbox_inches="tight")

In [None]:
# Disruption probability along a single pulse, read off the filled map.
columns = list(entry_dict)
in_shot = cmod_df.shot == shot
shot_time = cmod_df.time[in_shot]
data = np.array(cmod_df.loc[in_shot, columns])

disruptivity_along_shot = interper(data)
y = dis.p_data(disruptivity_along_shot, 1, True)
# Gradient components along the pulse (not plotted below; kept for inspection)
y_x = interper_x(data)
y_y = interper_y(data)

# Assume the last data point of the pulse is the disruption time
disr_index = cmod_df.index[in_shot][-1]
disruption_time = cmod_df.time[disr_index]

# Two stacked panels sharing the time axis, with no vertical gap
fig = plt.figure(figsize=(6, 2 * 2.5))
gs = fig.add_gridspec(2, hspace=0)
ax = gs.subplots(sharex=True)

# Top: disruption probability. No x-label here: with sharex the inner x axis
# is hidden, and the bottom panel carries "Time (s)".
ax[0].plot(shot_time, y, label='Probability', color='k')
ax[0].set_ylabel("Disruption Probability")
ax[0].set_ylim(0, 1.)
ax[0].set_xlim(0, 1.5)
ax[0].grid()
ax[0].set_title(f'Shot {shot}')

# Bottom: 1/d (inverse interpolated disruptivity) as a time-to-disruption estimate
time_to_dis = 1 / disruptivity_along_shot
ax[1].plot(shot_time, time_to_dis, label="$1/d$", color='k')
ax[1].set_xlabel("Time (s)")
ax[1].set_ylabel("Time to Disruption (s)")
ax[1].set_ylim(0, 15)
ax[1].grid()

# Drop the end ticks of the bottom panel so they don't overlap the top panel
ax[1].set_yticks(ax[1].get_yticks()[1:-1])

# On both panels: the disruption time (dashed, drawn once per panel) and the
# pre-flattop window (grey). flat_top_start comes from the quiver cell above.
for panel in ax:
    x_lims = panel.get_xlim()
    y_lims = panel.get_ylim()
    # Freeze the y-limits so the artists added below do not trigger rescaling
    panel.set_ylim(y_lims)
    panel.plot([disruption_time, disruption_time], y_lims, '--', color='orange', label="Disruption Time")
    panel.fill_between([x_lims[0], flat_top_start], [y_lims[0]] * 2, [y_lims[1]] * 2,
                       color='grey', alpha=0.3, zorder=1)

fig.savefig("full_sweep.png", dpi=400, facecolor='white', bbox_inches="tight")

In [None]:
# Re-display the time-trace figure built in the previous cell
fig

In [None]:
# Inspect the last element of ``args`` (passed as ``entry_dict`` to the plotting helpers)
args[-1]