# Test Notebook

Hello! This notebook checks whether the code migration worked correctly. It also tests integrating notebooks into an MkDocs documentation site. Let's see!

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.stats import binned_statistic_dd as hist

# Move into the source directory. This presumably lets the project modules below
# (disruptivity, vis, data_loader, tokamaks) be imported from the cwd, and the
# '../data/...' paths used by later cells are written relative to src/.
# Re-running is harmless while the cwd is already src/ ('../src/' resolves back to it).
# TODO: replace the chdir with a path-independent mechanism (installed package,
# explicit DATA_DIR constant, ...).
import os
import importlib
os.chdir('../src/')

# Project modules
import disruptivity as dis
import vis.disruptivity_vis as dis_vis
import vis.probability_vis as prob_vis
from vis.plot_helpers import plot_subplot as plot
import data_loader

# Import tokamak configurations (later cells read CMOD['entry_dict'])
from tokamaks.cmod import CONFIG as CMOD
from tokamaks.d3d import CONFIG as D3D

# Pick up edits to these modules without restarting the kernel. Only `dis` and
# `dis_vis` are reloaded; `plot` (bound via `from ... import`) and the other
# modules keep whatever code was loaded first.
importlib.reload(dis)
importlib.reload(dis_vis)
# Short alias for the .mat loader used by the data-loading cells.
load_disruptions_mat = data_loader.load_disruptions_mat

# Placeholder entry dictionaries until the surrounding infrastructure is built;
# eventually that code will generate them automatically.
# Each key names a quantity (presumably a column of the disruption dataframe)
# and maps to its histogram 'range' and the LaTeX 'axis_name' used on plots.
# Raw strings keep LaTeX backslashes such as \kappa from being read as
# (invalid) Python escape sequences, which warn on Python >= 3.12.
entry_dict_1D = {
    'kappa': {
        'range': [0.8, 2.0],
        'axis_name': r"$\kappa$",
    },
}

entry_dict_IP = {
    'ip': {
        'range': [-1.5e6, 1.5e6],
        'axis_name': r"$I_p$ (A)",
    },
}

# 2D entries (z_error vs kappa) taken from the C-Mod configuration.
# Each entry is copied so that later per-plot tweaks (e.g. reassigning 'range')
# do not silently modify CMOD['entry_dict'], which other cells iterate over.
entry_dict_2D = {
    'z_error': dict(CMOD['entry_dict']['z_error']),
    'kappa': dict(CMOD['entry_dict']['kappa']),
}

# Axes for the Hugill-style plot: the 'murakami' and 'inv_q95' columns are
# computed from the dataframe in a later cell.
# Raw strings keep the LaTeX backslashes (including the "\ " math-mode space)
# from being parsed as invalid Python escape sequences.
entry_dict_H = {
    'murakami': {
        'range': [0, 20],
        'axis_name': r"$n_e R/B_T \ (10^{19}$m$^{-2}$/T)",
    },
    'inv_q95': {
        'range': [0, 0.6],
        'axis_name': r"$1/q_{95}$",
    },
}

## C-Mod Data

In [None]:
# Load the C-Mod disruption database, then sanity-check the shot bookkeeping:
# the disrupted and non-disrupted shot counts must add up to the total.
cmod_df, cmod_indices = load_disruptions_mat('../data/CMod_disruption_warning_db.mat')
n_shots = len(np.unique(cmod_df.shot))
n_shots_no_disrupt = len(np.unique(cmod_df.shot[cmod_indices['indices_no_disrupt']]))
n_shots_disrupt = len(np.unique(cmod_df.shot[cmod_indices['indices_disrupt']]))
assert n_shots == n_shots_disrupt + n_shots_no_disrupt, \
    'Number of disrupts plus number of non disruptions does not equal the total shot number'
print(f'Total Shot Number: {n_shots}, Non-Disrupted Shots: {n_shots_no_disrupt}, Disrupted Shots: {n_shots_disrupt}')

In [None]:
# Compute the (disrupted, total) index sets used by the disruptivity plots below,
# for two settings: tau=50 with window=20 ("n_50") and tau=0 with window=2 ("n").
# The stated goal is to keep only the flat-top portions of the shots; that
# selection presumably happens inside dis.get_indices_disruptivity — confirm.

# VDE shotlist (a later cell uses it as a shot filter)
cmod_vde_shotlist = np.loadtxt("../data/cmod_vde_shotlist.txt", dtype=int)

indices_n_50_disrupt, indices_n_50_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, tau=50, window=20)
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, tau=0, window=2)

In [None]:
# 0 Disruptivity Plots (tau = 0 index sets)

# Compute Kappa disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_1D, indices_n_disrupt, indices_n_total, nbins=25)
# fig, ax = plot("cmod_kappa_disruptivity_vde.png", dis_vis.subplot_disruptivity1d, args)

# # Compute IP disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_IP, indices_n_disrupt, indices_n_total, nbins=25)
# fig, ax = plot("cmod_IP_disruptivity_vde.png", dis_vis.subplot_disruptivity1d, args)

# Narrow the z_error range for the 2D plots. Rebind entry_dict_2D with fresh
# dicts instead of assigning into the existing entry: that inner dict may be
# the very object stored in CMOD['entry_dict'], and assigning into it would
# silently change every later plot that reads the C-Mod config.
entry_dict_2D = {
    **entry_dict_2D,
    'z_error': {**entry_dict_2D['z_error'], 'range': [-0.1, 0.1]},
}

# Get the 2D histograms
args = dis.compute_disruptivity(cmod_df, entry_dict_2D, indices_n_disrupt, indices_n_total, nbins=35)
fig, ax = plot("cmod_kappa_zerr_disruptivity.png", dis_vis.subplot_disruptivity2d, args)

In [None]:
# # 50 Disruptivity Plots (tau = 50 index sets)

# # Compute Kappa disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_1D, indices_n_50_disrupt, indices_n_50_total, nbins=60)
# fig, ax = plot("cmod_kappa_50_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

# # Compute IP disruptivity
# args = dis.compute_disruptivity(cmod_df, entry_dict_IP, indices_n_50_disrupt, indices_n_50_total, nbins=60)
# fig, ax = plot("cmod_IP_50_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

# 2D (entry_dict_2D) disruptivity with the tau=50 index sets
args = dis.compute_disruptivity(
    cmod_df,
    entry_dict_2D,
    indices_n_50_disrupt,
    indices_n_50_total,
    nbins=35,
)
fig, ax = plot(
    "cmod_kappa_zerr_50_disruptivity.png",
    dis_vis.subplot_disruptivity2d,
    args,
)

In [None]:
# VDE shotlist (reloaded here so this cell can run on its own)
cmod_vde_shotlist = np.loadtxt("../data/cmod_vde_shotlist.txt", dtype=int)


# Parameter setup
figtype = 'disruptivity_vde'  # suffix for the saved figure names
shotlist = None  # None -> no shot filtering

# Recompute both index sets with this cell's windows. NOTE(review): this
# overwrites the indices_n_50_* / indices_n_* from the earlier cell, and later
# cells that read them (e.g. the Hugill plot) will see these values.
indices_n_50_disrupt, indices_n_50_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, tau=50, window=12.5, shotlist=shotlist)
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(
    CMOD, cmod_df, cmod_indices, tau=0, window=1, shotlist=shotlist)

# 1D disruptivity plots, only for the n_over_ncrit and z_error entries
for entry in CMOD['entry_dict']:
    if entry not in ("n_over_ncrit", "z_error"):
        continue

    # Single-entry dict for compute_disruptivity
    entry_dict = {entry: CMOD['entry_dict'][entry]}

    # tau = 0 index sets
    args = dis.compute_disruptivity(cmod_df, entry_dict, indices_n_disrupt, indices_n_total, nbins=25)
    fig, ax = plot(f'cmod_{entry}_{figtype}.png', dis_vis.subplot_disruptivity1d, args)

    # tau = 50 index sets
    args = dis.compute_disruptivity(cmod_df, entry_dict, indices_n_50_disrupt, indices_n_50_total, nbins=25)
    fig, ax = plot(f'cmod_{entry}_50_{figtype}.png', dis_vis.subplot_disruptivity1d, args)

In [None]:
# Tau-scan heatmaps: for each C-Mod entry, recompute the 1D disruptivity over a
# sweep of tau values and stack the rows into an image (entry value on x, tau on y).
from matplotlib import colors

n_steps = 20
n_bins = 35
tau_range = [0, 60]
tau_list = np.linspace(*tau_range, n_steps)
# Window used at each tau: 2 while tau <= 20, 20 afterwards.
window_list = np.where(tau_list <= 20, 2, 20)
# Spacing between consecutive taus (one heatmap row each).
tau_step = (tau_list[-1] - tau_list[0]) / (n_steps - 1) if n_steps > 1 else 1.0

# Shot filter
shotlist = cmod_vde_shotlist

for entry in CMOD['entry_dict']:
    # Tracking
    print('Working on ' + entry)

    # Single-entry dict for compute_disruptivity
    entry_dict = {entry: CMOD['entry_dict'][entry]}

    # One row per tau. The index sets get loop-local names so the sweep does not
    # overwrite the notebook-level indices_n_disrupt / indices_n_total used by
    # the other disruptivity cells.
    results = []
    for tau, window in zip(tau_list, window_list):
        print("tau:", tau)
        tau_disrupt, tau_total = dis.get_indices_disruptivity(
            CMOD, cmod_df, cmod_indices, tau=tau, window=window, shotlist=shotlist)
        args = dis.compute_disruptivity(cmod_df, entry_dict, tau_disrupt, tau_total, nbins=n_bins)
        # Assumes args[0] is the per-bin disruptivity array — TODO confirm.
        results.append(args[0])
    results = np.array(results)

    fig, ax = plt.subplots(1, 1)

    # imshow's extent gives the outer edges of the image. x keeps the entry's
    # range (assumed to be the histogram bin range — confirm). Each row is a
    # single tau sample, so pad y by half a step: row i is then centred on
    # tau_list[i] instead of the rows being stretched evenly over tau_range.
    x_min, x_max = entry_dict[entry]['range']
    extent = [x_min, x_max, tau_list[0] - tau_step / 2, tau_list[-1] + tau_step / 2]

    # Heatmap (log colour scale)
    cax = ax.imshow(
        results,
        cmap="viridis",
        origin="lower",
        aspect="auto",
        extent=extent,
        norm=colors.LogNorm(),
    )

    # Colorbar
    cbar = fig.colorbar(cax, ax=ax)
    cbar.ax.tick_params(labelsize="large")
    cbar.set_label(label="Disruptivity ($s^{-1}$)", size="large")

    ax.set_xlabel(entry_dict[entry]['axis_name'])
    ax.set_ylabel(r'$\tau$ (ms)')
    fig.savefig(f'{entry}_tau_plot_vde.png', dpi=400, facecolor="w", bbox_inches="tight")


In [None]:
# Hugill-diagram quantities, stored as dataframe columns (re-running recomputes them):
#   inv_q95  = 1 / q95
#   murakami = n_e * 0.68 / (n_equal_1_mode / n_equal_1_normalized) / 1e19
# NOTE(review): 0.68 looks like the C-Mod major radius (m), and the n_equal_1
# ratio looks like the toroidal field, matching the axis label
# n_e R / B_T in 1e19 m^-2 / T — confirm.
cmod_df['inv_q95'] = 1 / cmod_df['q95']
cmod_df['murakami'] = (
    cmod_df['n_e'] * 0.68
    / (cmod_df['n_equal_1_mode'] / cmod_df['n_equal_1_normalized'])
    / 1e19
)

# 2D disruptivity on the Hugill axes. Uses whichever indices_n_50_* an earlier
# cell computed last (execution-order dependent). The filename is None here,
# presumably so the reference lines can be drawn before saving below.
args = dis.compute_disruptivity(cmod_df, entry_dict_H, indices_n_50_disrupt, indices_n_50_total, nbins=35)
fig, ax = plot(None, dis_vis.subplot_disruptivity2d, args)

# Dashed orange reference lines: 1/q95 = 0.5, and the diagonal from (0, 0) to (20, 0.5)
ax.plot([0, 20], [0.5, 0.5], '--', color='orange')
ax.plot([0, 20], [0.0, 0.5], '--', color='orange')
fig.savefig('cmod_hugill_disruptivity.png', dpi=400, facecolor='w', bbox_inches="tight")

## DIII-D Data

In [None]:
# Load the DIII-D disruption database, then sanity-check the shot bookkeeping:
# the disrupted and non-disrupted shot counts must add up to the total.
d3d_df, d3d_indices = load_disruptions_mat('../data/d3d-db-220420.mat')
n_shots = len(np.unique(d3d_df.shot))
n_shots_no_disrupt = len(np.unique(d3d_df.shot[d3d_indices['indices_no_disrupt']]))
n_shots_disrupt = len(np.unique(d3d_df.shot[d3d_indices['indices_disrupt']]))
assert n_shots == n_shots_disrupt + n_shots_no_disrupt, \
    'Number of disrupts plus number of non disruptions does not equal the total shot number'
print(f'Total Shot Number: {n_shots}, Non-Disrupted Shots: {n_shots_no_disrupt}, Disrupted Shots: {n_shots_disrupt}')

In [None]:
# DIII-D index sets. NOTE(review): the "n_50" names mirror the C-Mod cells but
# hold tau=350 (window=12.5) here; the "n" sets use tau=0 (window=2.5).
# Both assignments overwrite the C-Mod index sets of the same names.
indices_n_50_disrupt, indices_n_50_total = dis.get_indices_disruptivity(
    D3D, d3d_df, d3d_indices, tau=350, window=12.5)
indices_n_disrupt, indices_n_total = dis.get_indices_disruptivity(
    D3D, d3d_df, d3d_indices, tau=0, window=2.5)

# Plasma-current (Ip) disruptivity with the tau=0 index sets
args = dis.compute_disruptivity(d3d_df, entry_dict_IP, indices_n_disrupt, indices_n_total)
fig, ax = plot("d3d_ip_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

In [None]:
# "Detectable" 1D disruptivity plots, using the indices_n_50_* sets
# (tau=350 when the preceding DIII-D cell was the last to set them).

# Kappa
args = dis.compute_disruptivity(
    d3d_df, entry_dict_1D, indices_n_50_disrupt, indices_n_50_total)
fig, ax = plot(
    "d3d_kappa_detectable_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

# Plasma current
args = dis.compute_disruptivity(
    d3d_df, entry_dict_IP, indices_n_50_disrupt, indices_n_50_total)
fig, ax = plot(
    "d3d_IP_detectable_disruptivity.png", dis_vis.subplot_disruptivity1d, args)

In [None]:
# 2D disruptivity over entry_dict_2D (z_error vs kappa) with the indices_n_50_* sets.
# The output name lists the plotted quantities (kappa, z_error), matching the
# C-Mod "cmod_kappa_zerr_*" figures; entry_dict_2D contains no Ip axis.
args = dis.compute_disruptivity(d3d_df, entry_dict_2D, indices_n_50_disrupt, indices_n_50_total)
fig, ax = plot("d3d_kappa_zerr_detectable_disruptivity.png", dis_vis.subplot_disruptivity2d, args)

In [None]:
# How to find ramp-down indices:
# rows that are NOT in the flat top AND have time > 1.
# NOTE(review): the 1 s cut-off presumably excludes the ramp-up — confirm.
# NOTE(review): np.arange(len(cmod_df)) gives positions while .index gives labels;
# the intersection assumes a default RangeIndex, where they coincide — confirm.
# setdiff1d is "all rows minus flat-top rows". The symmetric difference
# (setxor1d) would also keep any flat-top index that is not a valid row.
not_ft = np.setdiff1d(np.arange(len(cmod_df)), cmod_indices['indices_flattop'])
not_ru = cmod_df[cmod_df['time'] > 1].index
ramp_down = np.intersect1d(not_ft, not_ru)

In [None]:
# Show the available dataframe columns
cmod_df.columns

In [None]:
# Distribution of n/n_crit over all rows (25 bins on [-2.5, 2.5])
plt.hist(cmod_df['n_over_ncrit'], bins=25, range=[-2.5, 2.5])
plt.xlabel(r'$n/n_{crit}$')

# Alternative: z_error distribution
# plt.hist(cmod_df['z_error'], range=[-0.02,0.02], bins=25)
# plt.xlabel('$z_{error}$')

In [None]:
# Spacing between consecutive 'time' samples for the first 100 row pairs.
# Negative steps (presumably where one shot ends and the next begins) fall
# outside the y-limits.
plt.plot(np.diff(np.array(cmod_df['time'])[:101]))
plt.ylim(0, 0.05)

In [None]:
# First 100 rows of time_until_disrupt, y-axis clipped to [0, 0.1]
plt.plot(cmod_df["time_until_disrupt"].iloc[:100])
plt.ylim(0, 0.1)