# Test Notebook

Hello! This notebook is a test to check that the code migration has worked properly. It also tests the integration of notebooks into an MkDocs site. Let us observe!

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.stats import binned_statistic_dd as hist

# Move into the source directory for this notebook to work properly
# Probably want a better way of doing this.
import os
import importlib
os.chdir('../src/')

# Import whatever we need
import disruptivity as dis
import vis.disruptivity_vis as dis_vis
import vis.probability_vis as prob_vis
from vis.plot_helpers import plot_subplot as plot
import data_loader

# Import tokamak configurations
from tokamaks.cmod import CONFIG as CMOD
from tokamaks.d3d import CONFIG as D3D

# Reload the project modules so that edits made on disk are picked up without
# restarting the kernel.
importlib.reload(dis)
importlib.reload(dis_vis)
load_disruptions_mat = data_loader.load_disruptions_mat

# Some Placeholder stuff until the infrastructure is built
# These dictionaries will automatically be generated by the structure that will surround this code
# This is yet to be done, but this is what I was planning to do anyway.
#
# Each entry maps a dataframe column name to the 'range' used for its axis and
# the 'axis_name' used as its label. The labels are raw strings so that LaTeX
# backslashes are never parsed as (invalid) Python escape sequences.
entry_dict_1D = {
    'kappa': {
        'range': [0.8, 2.0],
        'axis_name': r"$\kappa$",
    },
}

entry_dict_IP = {
    'ip': {
        'range': [-1.5e6, 1.5e6],
        'axis_name': r"$I_p$ (A)",
    },
}

# NOTE: these values are the very same objects stored in CMOD['entry_dict'];
# modifying one of them in place also modifies the CMOD configuration.
entry_dict_2D = {
    'z_error': CMOD['entry_dict']['z_error'],
    'kappa': CMOD['entry_dict']['kappa'],
}

entry_dict_H = {
    'murakami': {
        'range': [0, 20],
        'axis_name': r"$n_e R/B_T \ (10^{19}$m$^{-2}$/T)",
    },
    'inv_q95': {
        'range': [0, 0.6],
        'axis_name': r"$1/q_{95}$",
    },
}

## CMOD Data

In [None]:
# Load the C-Mod disruption database. The loader returns the data table and a
# collection of named index sets (e.g. 'indices_flattop', used further down).
# The path is relative to ../src because the first cell changes directory there.
cmod_df, cmod_indices = load_disruptions_mat('../data/CMod_disruption_warning_db.mat')

In [None]:
'''
So my goal with this block of code is to find all the portions of flat top disrupted shots 
that are in flat tops. Should be simple enough.
'''
# the big crunch
# NOTE(review): cmod_vde_shotlist is loaded here but not referenced again in
# this cell; no shotlist is passed to the calls below.
cmod_vde_shotlist = np.loadtxt("../data/cmod_vde_shotlist.txt", dtype=int)

# Compute indices of interest
# Two (numerator, denominator) index pairs are produced: one for tau=50 with
# window=20 and one for tau=0 with window=2.
# NOTE(review): tau/window are presumably in ms (tau is labelled in ms in a
# later plot) — confirm against dis.get_indices_disruptivity.
indices_n_50_disrupt, indices_n_50_total = dis.get_indices_disruptivity(CMOD, cmod_df, cmod_indices, tau=50, window=20)
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(CMOD, cmod_df, cmod_indices, tau=0, window=2)

In [None]:
# 0 Disruptivity Plots

# Compute Kappa disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_1D, indices_n_disrupt, indices_n_total, nbins=25)
# fig, ax = plot("cmod_kappa_disruptivity_vde.png", dis_vis.subplot_disruptivity1d, args)

# # Compute IP disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_IP, indices_n_disrupt, indices_n_total, nbins=25)
# fig, ax = plot("cmod_IP_disruptivity_vde.png", dis_vis.subplot_disruptivity1d, args)

# Narrow the z_error range for the 2D plots. A new entry is built instead of
# assigning into the existing one: entry_dict_2D['z_error'] is the same dict
# object as CMOD['entry_dict']['z_error'], so an in-place edit would silently
# change the shared tokamak configuration seen by every later cell.
# Rebinding the key is also idempotent when the cell is re-run.
entry_dict_2D['z_error'] = {**entry_dict_2D['z_error'], 'range': [-0.1, 0.1]}

# Get the 2D histograms
args = dis.compute_disruptivity_sampling(cmod_df, entry_dict_2D, indices_n_disrupt, indices_n_total, nbins=35)
fig,ax = plot("cmod_kappa_zerr_disruptivity.png", dis_vis.subplot_disruptivity2d, args)

In [None]:
# Histogram the denominator (total) indices over the 2D entry dict with 35
# bins, then compute the time associated with each bin from that histogram.
# dt_array is plotted as "Time in Bin (s)" in the next cell.
# NOTE(review): the exact return types come from dis.* helpers not shown here.
denom_dd = dis.indices_to_histogram(cmod_df, entry_dict_2D, indices_n_total, 35)
dt_array = dis.compute_variable_time(cmod_df, denom_dd, indices_n_total, 35)

In [None]:
from matplotlib import colors

# Heatmap of the time spent in each (x, y) bin, on a log color scale.
entry_dict = entry_dict_2D
fig, ax = plt.subplots(1, 1, figsize=(4, 3))

# Walk the dict in insertion order: the first entry is x, the second is y.
# Collect [xmin, xmax, ymin, ymax] for imshow and the two axis labels.
extent = []
axis_name_list = []
for key, entry in entry_dict.items():
    assert "range" in entry, f"Entry {key} of entry_dict missing range field."
    extent += entry["range"]

    assert "axis_name" in entry, f"Entry {key} of entry_dict missing axis_name field."
    axis_name_list += [entry["axis_name"]]

# Transpose so that the first entry runs along the horizontal axis.
cax = ax.imshow(
    dt_array.T,
    extent=extent,
    origin="lower",
    aspect="auto",
    cmap="viridis",
    norm=colors.LogNorm(),
)

# Axis titles follow the dictionary order.
x_label, y_label = axis_name_list[0], axis_name_list[1]
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Colorbar attached to the heatmap's own axes.
cbar = fig.colorbar(cax, ax=ax)
cbar.ax.tick_params(labelsize="large")
cbar.set_label("Time in Bin (s)", size="large")


In [None]:
from matplotlib import colors

# Heatmap of args[1] (labelled as the disruptivity error), on a log color scale.
entry_dict = entry_dict_2D
fig, ax = plt.subplots(1, 1, figsize=(4, 3))

# Walk the dict in insertion order: the first entry is x, the second is y.
# Collect [xmin, xmax, ymin, ymax] for imshow and the two axis labels.
extent = []
axis_name_list = []
for key, entry in entry_dict.items():
    assert "range" in entry, f"Entry {key} of entry_dict missing range field."
    extent += entry["range"]

    assert "axis_name" in entry, f"Entry {key} of entry_dict missing axis_name field."
    axis_name_list += [entry["axis_name"]]

# Transpose so that the first entry runs along the horizontal axis.
cax = ax.imshow(
    args[1].T,
    extent=extent,
    origin="lower",
    aspect="auto",
    cmap="viridis",
    norm=colors.LogNorm(),
)

# Axis titles follow the dictionary order.
x_label, y_label = axis_name_list[0], axis_name_list[1]
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Colorbar attached to the heatmap's own axes.
cbar = fig.colorbar(cax, ax=ax)
cbar.ax.tick_params(labelsize="large")
cbar.set_label("Disruptivity Error (1/s)", size="large")

In [None]:
# Plot the data selection for `args`. `args` is whatever the most recently
# executed cell assigned to it, so this cell depends on execution order.
fig,ax = dis_vis.plot_data_selection(*args)

In [None]:
# # 50 Disruptivity Plots

# # Compute Kappa disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_1D, indices_n_50_disrupt, indices_n_50_total, nbins=60)
# fig, ax = plot("cmod_kappa_50_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

# # Compute IP disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_IP, indices_n_50_disrupt, indices_n_50_total, nbins=60)
# fig, ax = plot("cmod_IP_50_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

# Get the 2D histograms
# Same 2D entries as the tau=0 plot, but using the tau=50 index pair.
# NOTE(review): this calls compute_disruptivity, whereas the tau=0 cell calls
# compute_disruptivity_sampling — confirm the difference is intentional.
args = dis.compute_disruptivity(cmod_df, entry_dict_2D, indices_n_50_disrupt, indices_n_50_total, nbins=35)
fig,ax = plot("cmod_kappa_zerr_50_disruptivity.png", dis_vis.subplot_disruptivity2d, args)

In [None]:
# the big crunch
cmod_vde_shotlist = np.loadtxt("../data/cmod_vde_shotlist.txt", dtype=int)


# Parameter setup
figtype = 'disruptivity_vde'
shotlist = None # set to None for no shotlist

# Index pairs (numerator, denominator) for tau=50 and for tau=0.
indices_n_50_disrupt, indices_n_50_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, tau=50, window=12.5, shotlist=shotlist
)
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, tau=0, window=1, shotlist=shotlist
)

for entry, entry_config in CMOD['entry_dict'].items():
    # Only these two parameters are plotted; everything else is skipped.
    if entry not in ("n_over_ncrit", "z_error"):
        continue

    # Single-parameter entry dict for the 1D computation.
    entry_dict = {entry: entry_config}

    # First the tau=0 plot, then the tau=50 plot, each saved to its own file.
    for filename, num_idx, den_idx in (
        (f'cmod_{entry}_{figtype}.png', indices_n_disrupt, indices_n_total),
        (f'cmod_{entry}_50_{figtype}.png', indices_n_50_disrupt, indices_n_50_total),
    ):
        args = dis.compute_disruptivity(cmod_df, entry_dict, num_idx, den_idx, nbins=25)
        fig, ax = plot(filename, dis_vis.subplot_disruptivity1d, args)

In [None]:
# The new plot. What we need to do is like loop through a bunch of taus and window and then fill up the plot 1 by 1.
# For every entry of the tokamak configuration, the 1D disruptivity and its
# error are computed at each tau and stacked into (n_steps, nbins) arrays.
from matplotlib import colors
import warnings

tau_min = 0
tau_max = 100
n_steps = 51

assert tau_min>=0
assert n_steps > 1, "n_steps must be at least 2 to define a tau spacing"

# np.linspace includes both endpoints, so n_steps samples are
# (tau_max - tau_min) / (n_steps - 1) apart. dtau pads the imshow extent in
# the plotting cell, so it has to equal the actual sample spacing.
dtau = (tau_max-tau_min)/(n_steps-1)
assert dtau>0

tau_list = np.linspace(tau_min,tau_max,n_steps)
# Window of 2 for tau <= 20 and of 20 above that.
window_list = np.where(tau_list<=20, 2, 20)

# FLAGS
# Sentinel values written into the disruptivity arrays for special bins.
NODATAMASK = -1
NODISRUPTIONMASK = -2
ALLDISRUPTIONMASK = -3

# Shot filter
# NOTE(review): the VDE shotlist is loaded but shotlist is None, so no shot
# filter is applied in this sweep.
cmod_vde_shotlist = np.loadtxt("../data/cmod_vde_shotlist.txt", dtype=int)
shotlist = None
figtype = 'tau_plot_sampling'

# params
dataframe = cmod_df
tokamak = CMOD
indices = cmod_indices
nbins=35
dt = None  # None -> variable time step computed from the data

# Memory
tau_results = {}

for entry in tokamak['entry_dict']:
    # Tracking
    print('Working on '+entry)
    
    # Create the entry dict
    entry_dict = {entry:tokamak['entry_dict'][entry]}
    
    # One row per tau
    results = []
    errors = []

    for tau, window in zip(tau_list,window_list):
        print("tau:",tau)
        # Compute indices of interest
        num_indices, denom_indices = dis.get_indices_disruptivity(tokamak, dataframe, indices, 
                                                                          tau=tau, window=window, shotlist=shotlist)
        
        
        # Compute the histograms
        num_dd = dis.indices_to_histogram(dataframe, entry_dict, num_indices, nbins)
        denom_dd = dis.indices_to_histogram(dataframe, entry_dict, denom_indices, nbins)

        # Parse the data
        n_disrupt = num_dd.statistic
        n_total = denom_dd.statistic

        # Compute dt using the denominator histogram
        if dt is None:
            dt_array = dis.compute_variable_time(
                dataframe, denom_dd, denom_indices, nbins
            )
        else:
            # Fixed Timestep Computation
            dt_array = dt * n_total

        # Surpress printing Division by 0 warnings since we handle them manually.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Disruptivity equation sourced from verbal description in deVries 2009.
            # Here we handle divisions by 0 by setting the division result to 0.
            # Note that this is not the exact equation from the paper. Since
            # dt_array is a sum of all the times, not an average, we can just
            # divide by it instead.
            disruptivity = n_disrupt / dt_array
            disruptivity[~np.isfinite(disruptivity)] = 0

            # Error calculation assumes histogram error e = sqrt(n)
            # and propagates it over a division.
            # Here we handle divisions by 0 by setting the division result to 0.
            error = disruptivity * np.sqrt(1 / n_disrupt + 1 / n_total)
            error[~np.isfinite(error)] = 0

        # Get the histogram indices to mask.
        n_no_disrupt = n_disrupt == 0
        n_all_disrupt = n_disrupt == n_total
        n_no_data = n_total == 0

        # Apply the masks
        # Order matters since no data must be applied last.
        disruptivity[n_no_disrupt] = NODISRUPTIONMASK
        disruptivity[n_all_disrupt] = ALLDISRUPTIONMASK
        disruptivity[n_no_data] = NODATAMASK

        error[n_no_disrupt] = 0
        error[n_all_disrupt] = 0
        error[n_no_data] = 0

        # Save the output.
        results.append(disruptivity)
        errors.append(error)

    # Save results in dictionary.
    tau_results[entry]={}
    tau_results[entry]['results'] = np.array(results)
    tau_results[entry]['errors'] = np.array(errors)
    # NOTE(review): only the dt_array of the last tau in the sweep is kept here.
    tau_results[entry]['data'] = np.array(dt_array)
    
# fig, ax = plot(f'cmod_{entry}_{figtype}.png', dis_vis.subplot_disruptivity1d, args)


In [None]:
# save that oh so valuable dictionary
# Persists tau_results so the tau sweep does not have to be recomputed.
# The file is written relative to the current working directory.
# Only unpickle files from a trusted source: loading a pickle can run code.
import pickle 

with open('tau_results_no_vde.pkl', 'wb') as f:
    pickle.dump(tau_results, f)

In [None]:
def subplot_draw_trajectory(ax, dataframe, entry_dict, indices, shot):
    """Overlay one shot's trajectory on ``ax`` as a line colored by time.

    Args:
        ax: Matplotlib axes to draw on.
        dataframe: Table with ``shot`` and ``time`` columns plus one column
            for each of the first two keys of ``entry_dict``.
        entry_dict: Mapping whose first two keys name the x and y columns.
            Each value must provide ``'range'`` (used as axis limits) and
            ``'axis_name'`` (used as axis label).
        indices: Mapping containing ``'indices_flattop_disrupt_in_flattop'``.
        shot: Shot number to draw.

    Returns:
        The ``LineCollection`` added to ``ax`` (usable for a colorbar).

    Raises:
        IndexError: If ``shot`` has no samples among the selected indices.
    """
    # Imported locally so the function does not rely on a later cell having
    # already imported LineCollection into the notebook namespace.
    from matplotlib.collections import LineCollection

    # Filter the dataframe to get the pulse of interest in the flattop
    disruptive_indices = indices["indices_flattop_disrupt_in_flattop"]
    shotlist_bool = np.isin(dataframe.shot, [shot])
    shot_indices = dataframe[shotlist_bool].index
    overlap = np.array(np.intersect1d(disruptive_indices, shot_indices))
    if overlap.size == 0:
        # Same exception type as indexing the empty array, but with a message
        # that says which shot is the problem.
        raise IndexError(
            f"Shot {shot} has no samples in "
            "indices['indices_flattop_disrupt_in_flattop']."
        )
    print(f"Flat Top Starts: {dataframe.time[overlap[0]]} s")

    # Get the entry information
    entries = list(entry_dict.keys())

    # Create line segments to color individually
    # Based on this example: https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html
    time = dataframe['time'][overlap]
    x = dataframe[entries[0]][overlap]
    y = dataframe[entries[1]][overlap]
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Create a continuous norm to map from data points to colors
    norm = plt.Normalize(time.min(), time.max())
    lc = LineCollection(segments, cmap='cool', norm=norm)
    # Set the values used for colormapping
    lc.set_array(time)
    lc.set_linewidth(2)

    # Add the collection
    line = ax.add_collection(lc)

    # Re-apply axis labels and limits, since add_collection does not set them.
    ax.set_xlim(entry_dict[entries[0]]['range'])
    ax.set_ylim(entry_dict[entries[1]]['range'])
    ax.set_xlabel(entry_dict[entries[0]]['axis_name'])
    ax.set_ylabel(entry_dict[entries[1]]['axis_name'])

    return line

In [None]:
# For every entry of the tau sweep, draw three panels: the disruptivity map,
# its error map, and the time per bin; the trajectory of one shot is overlaid
# on the disruptivity map. Relies on tau_results, tokamak, indices, tau_min,
# tau_max, dtau, figtype and the mask constants from the tau-sweep cell.
cmod_df['time_until_disrupt_ms'] = cmod_df['time_until_disrupt']*1000
shot = 1160930030 # BOB PULSE
dataframe = cmod_df
from matplotlib.collections import LineCollection


for entry in tokamak['entry_dict']:
    # Tracking
    print('Working on '+entry)
    
    # Create the entry dict
    entry_dict = {entry:tokamak['entry_dict'][entry]}

    fig, ax = plt.subplots(1,3, constrained_layout=True, figsize=(12,5))
    # x spans the entry's range; y spans the tau sweep padded by dtau/2 per side.
    extent = entry_dict[entry]['range']+[tau_min-dtau/2,tau_max+dtau/2]
    
    results = tau_results[entry]['results']
    errors = tau_results[entry]['errors']
    data = tau_results[entry]['data']

    # Disruptivity heatmap. The negative mask sentinels are not valid on a log
    # scale; those bins are covered by the overlays below where applicable.
    cax1 = ax[0].imshow(
        results,
        cmap="viridis",
        origin="lower",
        aspect="auto",
        extent=extent,
        norm=colors.LogNorm(),
    )
    
    # Overlays: an array of ones that is visible only where the sentinel matches.
    no_disruptions = np.ma.masked_where(
        results != NODISRUPTIONMASK, np.ones(results.shape)
    )
    all_disruptions = np.ma.masked_where(
        results != ALLDISRUPTIONMASK, np.ones(results.shape)
    )

    # Bins without any disruption: flat color from the top of 'Accent'.
    ax[0].imshow(
        no_disruptions,
        cmap="Accent",
        origin="lower",
        aspect="auto",
        extent=extent,
        vmin=0,
        vmax=1,
    )

    # Bins where every sample is a disruption: flat color from the top of 'Set1'.
    ax[0].imshow(
        all_disruptions,
        cmap="Set1",
        origin="lower",
        aspect="auto",
        extent=extent,
        vmin=0,
        vmax=1,
    )
    
    # Error heatmap
    cax2 = ax[1].imshow(
        errors,
        cmap="viridis",
        origin="lower",
        aspect="auto",
        extent=extent,
        norm=colors.LogNorm(),
    )
    
    # Data Density
    x = np.linspace(extent[0], extent[1], len(data))
    ax[2].plot(x,data)
    ax[2].set_xlim(extent[:2])
    ax[2].grid()
    ax[2].set_xlabel(entry_dict[entry]['axis_name'])
    ax[2].set_ylabel('Time (s)')
    
    # Trajectory of the selected shot: entry value against time until disruption.
    entry_dict_temp = {
        entry:tokamak['entry_dict'][entry],
        'time_until_disrupt_ms':{
            "range": [-1, 101],
            "axis_name": "tau",
        }
    }
    subplot_draw_trajectory(ax[0], dataframe, entry_dict_temp, indices, shot)
    
    # Colorbars for the two heatmaps, each attached to its own panel.
    # The third panel is a line plot, so it has no mappable and no colorbar.
    fig.colorbar(cax1, ax=ax[0], label="Disruptivity ($s^{-1}$)")
    fig.colorbar(cax2, ax=ax[1], label="Disruptivity Error ($s^{-1}$)")
    
    ax[0].set_title("Disruptivity")
    ax[1].set_title("Error Bars")
    ax[2].set_title("Time In Each Bin")

    # The trajectory helper re-labels ax[0]; restore the heatmap labels.
    for i in range(2):
        ax[i].set_xlabel(entry_dict[entry]['axis_name'])
        ax[i].set_ylabel(r'''$\tau$ (ms)''')
    fig.savefig(f'{entry}_{figtype}', dpi=400, facecolor="w", bbox_inches="tight")
    

In [None]:
# Debug aid: display the current value of `extent` (set by whichever cell
# assigned it last).
extent

In [None]:
# Compute the murakami parameter
# Adds two derived columns to cmod_df; their names match the keys of entry_dict_H.
# NOTE(review): 0.68 is presumably the C-Mod major radius in m and
# n_equal_1_mode / n_equal_1_normalized presumably recovers B_T, matching the
# n_e R / B_T axis label — confirm both.
cmod_df['inv_q95'] = 1/cmod_df['q95']
cmod_df['murakami'] = cmod_df['n_e']*0.68/(cmod_df['n_equal_1_mode']/cmod_df['n_equal_1_normalized'])/1e19

# Get the 2D histograms
# The filename is None here; the figure is saved explicitly below, after the
# constraint lines have been added.
args = dis.compute_disruptivity_sampling(cmod_df, entry_dict_H, indices_n_50_disrupt, indices_n_50_total, nbins=35)
fig,ax = plot(None, dis_vis.subplot_disruptivity2d, args)

# Plotting Constraints
# Two dashed orange lines: a horizontal one at y=0.5 and a diagonal from
# (0, 0) to (20, 0.5).
ax.plot([0,20], [0.5, 0.5], '--', c='orange')
ax.plot([0,20], [0.0, 0.5], '--', c='orange')
fig.savefig('cmod_hugill_disruptivity.png', dpi=400, facecolor='w', bbox_inches="tight")
fig

## DIII-D Data

In [None]:
# Load the DIII-D database and sanity-check the shot bookkeeping: every unique
# shot must be counted exactly once as either disrupted or non-disrupted.
d3d_df, d3d_indices = load_disruptions_mat('../data/d3d-db-220420.mat')
n_shots = np.unique(d3d_df.shot).shape[0]
n_shots_no_disrupt = np.unique(d3d_df.shot[d3d_indices['indices_no_disrupt']]).shape[0]
n_shots_disrupt = np.unique(d3d_df.shot[d3d_indices['indices_disrupt']]).shape[0]
assert n_shots_disrupt+n_shots_no_disrupt == n_shots, \
    'Number of disrupts plus number of non disruptions does not equal the total shot number'
print(f'Total Shot Number: {n_shots}, Non-Disrupted Shots: {n_shots_no_disrupt}, Disrupted Shots: {n_shots_disrupt}')

In [None]:
'''
So my goal with this block of code is to find all the portions of flat top disrupted shots 
that are in flat tops. Should be simple enough.
'''
# Compute indices of interest
# These rebind the same index variable names that the C-Mod cells use, so the
# C-Mod index sets are no longer available after this cell runs.
# NOTE(review): the names say "n_50" but tau=350 is used here.
indices_n_50_disrupt, indices_n_50_total = dis.get_indices_disruptivity(D3D, d3d_df, d3d_indices, tau=350, window=12.5)
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(D3D, d3d_df, d3d_indices, tau=0, window=2.5)

# Get the Ip histograms
# nbins is not passed, so the helper's own default is used.
args = dis.compute_disruptivity(d3d_df, entry_dict_IP, indices_n_disrupt, indices_n_total)
fig, ax = plot("d3d_ip_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

In [None]:
# Compute Kappa disruptivity
# Uses the tau=350 index pair computed in the previous cell.
args = dis.compute_disruptivity(d3d_df, entry_dict_1D, indices_n_50_disrupt, indices_n_50_total)
fig, ax = plot("d3d_kappa_detectable_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

# Compute IP disruptivity
args = dis.compute_disruptivity(d3d_df, entry_dict_IP, indices_n_50_disrupt, indices_n_50_total)
fig, ax = plot("d3d_IP_detectable_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

In [None]:
# Get the 2D histograms
# NOTE(review): entry_dict_2D holds the z_error and kappa entries taken from
# the CMOD configuration, while the output file is named "kappa_ip" — confirm
# which pair of parameters (and which ranges) is intended for DIII-D.
args = dis.compute_disruptivity(d3d_df, entry_dict_2D, indices_n_50_disrupt, indices_n_50_total)
fig,ax = plot("d3d_kappa_ip_detectable_disruptivity.png", dis_vis.subplot_disruptivity2d, args)

In [None]:
# HOW TO FIND RAMP DOWN INDICES
# Try and get some ramp down indices. 
# We will do this by only taking indices of each shot after the flattop
# not_ft: values present in exactly one of {0..len(cmod_df)-1} and the flattop
#         indices, i.e. the non-flattop rows as long as every flattop index
#         lies inside that range.
# not_ru: index labels of rows with time > 1.
# NOTE(review): presumably time is in seconds and 1 s is past the ramp-up — confirm.
# ramp_down: rows that satisfy both conditions.
not_ft = np.setxor1d(np.arange(0,len(cmod_df)) ,cmod_indices['indices_flattop'])
not_ru = cmod_df[cmod_df['time']>1].index
ramp_down = np.intersect1d(not_ft, not_ru)

In [None]:
def big_crunch(dataframe, num_indices, denom_indices, shotlist, tokamak, figtype, nbins=25):
    """Compute and plot the 1D sampled disruptivity of every tokamak entry.

    One figure is produced per key of ``tokamak['entry_dict']`` except
    ``'z_error'``, and saved as ``{name}_{entry}_{figtype}.png`` where
    ``name`` is ``tokamak['name']``.

    Args:
        dataframe: Data table passed through to the disruptivity computation.
        num_indices: Numerator indices.
        denom_indices: Denominator indices.
        shotlist: Not used by this function; kept so existing calls still work.
        tokamak: Configuration mapping with ``'entry_dict'`` and ``'name'``.
        figtype: Suffix used in the output file names.
        nbins: Number of histogram bins forwarded to the computation.
    """
    for entry in tokamak['entry_dict']:
        # z_error is deliberately left out of this batch.
        if entry=='z_error':
            continue

        # Information Lines
        print("Working on "+entry)

        # Create the entry dict
        entry_dict = {entry:tokamak['entry_dict'][entry]}

        # Get the tokamak name
        name = tokamak['name']

        # Compute Disruptivity and save the plot.
        # nbins comes from the caller instead of a hard-coded value.
        args = dis.compute_disruptivity_sampling(dataframe,
                                                   entry_dict,
                                                   num_indices,
                                                   denom_indices,
                                                   nbins=nbins,
                                                  )
        fig, ax = plot(f'{name}_{entry}_{figtype}.png', dis_vis.subplot_disruptivity1d, args)

In [None]:
# Parameter setup
figtype = 'disruptivity_sampling'
shotlist = None # set to None for no shotlist
     
# Compute indices of interest
# Rebinds indices_n_disrupt / indices_n_total to a DIII-D pair at tau=350, window=25.
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(D3D, d3d_df, d3d_indices, shotlist=shotlist, tau=350, window=25)

# One 1D disruptivity figure per D3D entry (z_error is skipped inside big_crunch).
big_crunch(d3d_df, indices_n_disrupt, indices_n_total, shotlist, D3D, figtype, nbins=35)

In [None]:
# List the columns available in the C-Mod table.
cmod_df.keys()

In [None]:
# Distribution of n_over_ncrit over all C-Mod samples, 25 bins on [-2.5, 2.5].
plt.hist(cmod_df['n_over_ncrit'], range=[-2.5,2.5], bins=25)
plt.xlabel('$n/n_{crit}$')

# plt.hist(cmod_df['z_error'], range=[-0.02,0.02], bins=25)
# plt.xlabel('$z_{error}$')

In [None]:
# Spacing between consecutive time samples, for the first 100 differences.
plt.plot(np.diff(np.array(cmod_df['time']))[:100])
plt.ylim(0, 0.05)

In [None]:
# time_until_disrupt for the first 100 rows of the C-Mod table.
plt.plot(cmod_df["time_until_disrupt"][:100])
plt.ylim(0,0.1)

In [None]:
# Inspect the raw time_until_disrupt column.
cmod_df['time_until_disrupt']


In [None]:
# Inspect the derived column (time_until_disrupt * 1000) added by the tau plotting cell.
cmod_df['time_until_disrupt_ms']