# Import The Relevant Libraries

In [1]:
import numpy as np
import pandas as pd
import matplotlib.image as mpimg
from astropy.io import fits
from matplotlib import pyplot as plt
import cv2
import os
import astropy.units as u
from astropy.convolution import convolve, Gaussian2DKernel
import random

# Obtain and Process The Image Data

In [2]:
# Load the transient candidate table. The cells below read its
# "sbid", "beam", "name" and "PSR_Label" columns.
t_data = pd.read_csv("All_Transient_Data.csv")

In [3]:
# Keep only the columns needed to name each source's FITS cube and to label it,
# discarding every row that has a missing value in any of them.
rel_data = t_data[["sbid","beam","name","PSR_Label"]].dropna(axis=0, how='any')
rel_data["sbid"] = rel_data["sbid"].astype(str)
# FITS cubes are named "SB<sbid>_<beam>_slices_<name>.fits".
rel_data["fits_path"] = (
    ("SB" + rel_data["sbid"])
    .str.cat(rel_data["beam"], sep="_")
    .str.cat(rel_data["name"], sep="_slices_")
    + ".fits"
)

## Obtain Relevant Image Slices

### Obtain a Balanced Light Curve Dataset

In [4]:
def obtain_src_name(avail_fits, chosen_idx):
    """
    Split one fits path into its underscore-separated parts and pull the
    source name out of the last part.

    The path is expected to look like
    "<folder>/<SBID>_<beam>_slices_<source name>.fits".

    Parameters
    ----------
    avail_fits : list
        List of strings of fits paths

    chosen_idx : int
        Index of the fits path of interest in avail_fits

    Returns
    -------
    src_name : str
        The last underscore-separated part of the file name with its final
        five characters (the ".fits" extension) removed.

    src_details : list
        All underscore-separated parts of the file name, i.e.
        [SBID, beam ID, "slices", "<source name>.fits"]
    """

    # Element 1 of the "/"-split path is the file name (single-folder paths).
    file_name = avail_fits[chosen_idx].split("/")[1]
    src_details = file_name.split("_")
    last_part = src_details[-1]
    src_name = last_part[:-5]
    return src_name, src_details

In [5]:
def obtain_lc_df(lc_folder_name, src_name, src_details, chosen_idx, real_cand):
    """
    Used to obtain a dataframe for the light curve information of one source.

    Reads the "<SBID>_<beam>_lightcurve_local_rms.csv" and
    "<SBID>_<beam>_lightcurve_peak_flux.csv" files, joins them on "Time",
    drops unusable rows and labels every remaining time step.

    Parameters
    ----------
    lc_folder_name : str
        Path to light curve folder

    src_name : str
        Source name; it is the column read from both csv files

    src_details : list
        Source details; only the first two entries (SBID and beam ID)
        are used

    chosen_idx : int
        Index of label of interest in the labels

    real_cand : list
        List of labels for the fits data. A label equal to 1 marks a
        true positive source, any other value a false positive source.

    Returns
    -------
    lc_df : pandas dataframe
        A dataframe containing "Time", "peak_flux", "local_rms", "Time_Idx",
        "Above_Threshold" and "Source_Info" columns
    """

    csv_prefix = f"{lc_folder_name}/{src_details[0]}_{src_details[1]}_lightcurve"
    lc_local_rms = pd.read_csv(f"{csv_prefix}_local_rms.csv")
    lc_local_rms = lc_local_rms.rename(columns = {src_name: "local_rms"})
    lc_peak_flux = pd.read_csv(f"{csv_prefix}_peak_flux.csv")
    lc_peak_flux = lc_peak_flux.rename(columns = {src_name: "peak_flux"})
    lc_df = pd.merge(lc_local_rms, lc_peak_flux, on="Time")
    # Remove unnecessary columns. The explicit copy keeps the column
    # assignments below from operating on a slice of the merged frame.
    lc_df = lc_df[["Time", "peak_flux", "local_rms"]].copy()
    # Record each row's position before any filtering so the matching
    # fits image slice can still be looked up afterwards.
    lc_df["Time_Idx"] = lc_df.index

    # Drop rows with null values
    lc_df = lc_df.dropna()

    # Remove rows where local_rms = 0 (they would divide by zero below)
    lc_df = lc_df[lc_df["local_rms"] != 0].reset_index(drop=True)

    # Obtaining Labels
    threshold = 5
    # 1 is true positive, 2 is false positive and 0 is non-detection
    positive_label = 1 if real_cand[chosen_idx] == 1 else 2
    # |peak flux - mean peak flux| / local rms > threshold
    deviation = np.abs(lc_df["peak_flux"] - lc_df["peak_flux"].mean()) / lc_df["local_rms"]
    # The expression is already vectorised over every row, so it is
    # evaluated once rather than inside a loop.
    lc_df["Above_Threshold"] = np.where(deviation > threshold, positive_label, 0)
    lc_df["Source_Info"] = f"{src_details[0]}_{src_details[1]}_{src_name}"
    return lc_df

In [6]:
def obtain_balanced_lc_df(lc_df, random_state = 43, ratio_d_to_nd = 2):
    """
    Balance one light curve dataframe so that positives (true or false)
    and non detections end up in a ratio_d_to_nd:1 ratio.

    Parameters
    ----------
    lc_df : pandas dataframe
        A dataframe with relevant light curve information. Rows whose
        "Above_Threshold" value is 0 are non detections, all other rows
        are positives.

    random_state : int, default 43
        The random state used for reproducibility.

    ratio_d_to_nd : int, default 2
        The ratio of positives (either true or false) to non detections.

    Returns
    -------
    balanced_lc_df_ls : list
        A two element list: the kept positives followed by the kept
        non detections.

    positive_counts : int
        Number of positives in the balanced light curve list
    """

    positives = lc_df[lc_df["Above_Threshold"] != 0]
    non_detections = lc_df[lc_df["Above_Threshold"] == 0]
    n_positives = len(positives)
    n_non_detections = len(non_detections)

    if n_positives <= n_non_detections:
        # Every positive is kept; its count sets the size of the other class.
        n_reference = n_positives
        kept_positives = positives
    else:
        # More positives than non detections: cut the positives down to the
        # non detection count first.
        n_reference = n_non_detections
        kept_positives = positives.sample(n=n_reference, random_state = random_state, axis = 0)

    # Divide by the ratio to get ratio_d_to_nd positives per non detection.
    kept_non_detections = non_detections.sample(
        n=round(n_reference/ratio_d_to_nd), random_state = random_state, axis = 0
    )

    balanced_lc_df_ls = [kept_positives, kept_non_detections]
    positive_counts = len(kept_positives)
    return balanced_lc_df_ls, positive_counts

#### Finding Number of Occurrences of Each Label

In [7]:
# Folder holding the "<SBID>_<beam>_lightcurve_*.csv" light curve tables.
lc_folder_name = "VAST 10s lightcurve"
# Folder holding the "<SBID>_<beam>_slices_<name>.fits" image cubes.
fits_folder_name = "VAST 10s fitscube"

In [8]:
# Label of every expected fits file, indexed by file name. Building it once
# avoids recomputing unique() and rescanning rel_data for every file on disk.
labels_by_path = rel_data.set_index("fits_path")["PSR_Label"]

# Paths of the files in fits_folder that also appear in rel_data
avail_fits = []
# Label of each path in avail_fits (same order)
real_cand = []
for file in os.listdir(fits_folder_name):
    if file in labels_by_path.index:
        avail_fits.append(fits_folder_name+"/"+file)
        # .item() still raises if one file name maps to several rows.
        real_cand.append(labels_by_path.loc[[file]].item())

In [9]:
def obtain_positive_counts(avail_fits, label):
    """
    Counts the total number of positives, either
    true or false positives, kept after balancing.

    Relies on the module-level `real_cand` list (labels aligned with
    avail_fits) and `lc_folder_name` (light curve folder).

    Parameters
    ----------
    avail_fits : list
        List of strings of the available fits paths

    label : int
        Either 1 for true positive or 0 for false positive

    Returns
    -------
    idx_count_dict : dict
        Dictionary which maps the indices of the fits paths in the avail_fits
        to the count of positives for the associated light curve dataframe.
        Sources without both light curve csv files are left out.

    total : int
        Total number of positives
    """
    # label = 1 for true positive and label = 0 for false positive
    chosen_idx_ls = [i for i, val in enumerate(real_cand) if val == label]
    idx_count_dict = {}
    total = 0
    for chosen_idx in chosen_idx_ls:
        src_name, src_details = obtain_src_name(avail_fits, chosen_idx)
        csv_prefix = f"{lc_folder_name}/{src_details[0]}_{src_details[1]}_lightcurve"
        # obtain_lc_df reads both files, so skip the source unless both exist.
        has_both_csvs = (
            os.path.isfile(f"{csv_prefix}_local_rms.csv")
            and os.path.isfile(f"{csv_prefix}_peak_flux.csv")
        )
        if not has_both_csvs:
            continue
        lc_df = obtain_lc_df(lc_folder_name, src_name, src_details, chosen_idx, real_cand)
        _, positive_counts = obtain_balanced_lc_df(lc_df)
        idx_count_dict[chosen_idx] = positive_counts
        total += positive_counts
    return idx_count_dict, total

In [10]:
# Balanced positive counts for the sources labelled 0 (false positives).
idx_fp_count_dict, total_fp = obtain_positive_counts(avail_fits, 0)

In [11]:
# Balanced positive counts for the sources labelled 1 (true positives).
idx_tp_count_dict, total_tp = obtain_positive_counts(avail_fits, 1)

In [12]:
# Positive rows kept after balancing, summed over all false positive sources.
print(total_fp)

18895


In [13]:
# Positive rows kept after balancing, summed over all true positive sources.
print(total_tp)

1611


In [14]:
# Indices into avail_fits of the true positive sources that had a light curve file.
tp_src_idx = list(idx_tp_count_dict.keys())
print(len(tp_src_idx)) # number of true positive sources

252


In [15]:
# Indices into avail_fits of the false positive sources that had a light curve file.
fp_src_idx = list(idx_fp_count_dict.keys())
print(len(fp_src_idx)) # number of false positive sources

2277


#### Concatenate the false positives and true positives into 1 dataframe

In [16]:
def stack_balanced_df(src_idx):
    """
    Build the balanced light curve dataframe of every requested source
    and stack them all into 1 dataframe.

    Uses the module-level `avail_fits`, `lc_folder_name` and `real_cand`.

    Parameters
    ----------
    src_idx : list
        List of integers of indices for fits paths in avail_fits

    Returns
    -------
    balanced_lc_df : pandas dataframe
        A dataframe containing the rows of all the balanced dataframes
        of the relevant light curves, with a fresh 0..n-1 index.
    """

    stacked_parts = []

    for idx in src_idx:
        src_name, src_details = obtain_src_name(avail_fits, idx)
        # Each source contributes two frames: its positives and its non detections.
        balanced_pair, _ = obtain_balanced_lc_df(
            obtain_lc_df(lc_folder_name, src_name, src_details, idx, real_cand)
        )
        stacked_parts += balanced_pair

    return pd.concat(stacked_parts).reset_index(drop=True)

In [17]:
# Stack the balanced light curves of every false positive source.
fp_balanced_lc_df = stack_balanced_df(fp_src_idx)

In [18]:
# (rows, columns) of the stacked false positive frame.
fp_balanced_lc_df.shape

(28259, 6)

In [19]:
# Rows per label: 2 = false positive, 0 = non-detection.
fp_balanced_lc_df["Above_Threshold"].value_counts()

Above_Threshold
2    18895
0     9364
Name: count, dtype: int64

In [20]:
# Stack the balanced light curves of every true positive source.
tp_balanced_lc_df = stack_balanced_df(tp_src_idx)

In [21]:
# Rows per label: 1 = true positive, 0 = non-detection.
tp_balanced_lc_df["Above_Threshold"].value_counts()

Above_Threshold
1    1611
0     793
Name: count, dtype: int64

In [22]:
# (rows, columns) of the stacked true positive frame.
tp_balanced_lc_df.shape

(2404, 6)

In [23]:
# One frame holding the false positive rows followed by the true positive rows,
# re-indexed from 0.
final_balanced_df = pd.concat([fp_balanced_lc_df, tp_balanced_lc_df], ignore_index=True, axis = 0)

In [24]:
# Rows per label in the combined frame (0 = non-detection, 1 = true positive, 2 = false positive).
final_balanced_df["Above_Threshold"].value_counts()

Above_Threshold
2    18895
0    10157
1     1611
Name: count, dtype: int64

In [25]:
# Total number of labelled rows, i.e. the per-label counts added together.
sum(final_balanced_df["Above_Threshold"].value_counts())

30663

## Splitting The Data into Training, Validation and Testing Sets

In [26]:
def split_train_test_data(final_balanced_df, prop, det_src_type = 0, thresh = 5, random_state = None):
    """
    Splits the data into training and testing data, one whole source at a
    time so that no source ends up in both sets.

    Only sources that have at least one row labelled det_src_type + 1 are
    considered; rows of every other source appear in neither output.

    Parameters
    ----------
    final_balanced_df : pandas dataframe
        A light curve dataframe with a balanced number of
        positives and non detections. Must contain the "Above_Threshold"
        and "Source_Info" columns.

    prop : float
        The proportion of the positives to go into the training
        data.

    det_src_type : int, default 0
        If det_src_type is 0 then the detected sources are true positives
        (label 1), if det_src_type is 1 then the detected sources are
        false positives (label 2)

    thresh : int, default 5
        A randomly drawn source is only accepted while the number of
        positives in the training data stays within the desired count
        plus this threshold value

    random_state : int, optional
        Seed for the source selection. When None (the default) the global
        numpy random state is used, as before.

    Returns
    -------
    train_df : pandas dataframe
        The light curve dataframe to be used for training

    test_df : pandas dataframe
        The light curve dataframe to be used for testing
    """
    pos_label = det_src_type + 1
    labels = final_balanced_df["Above_Threshold"]

    # Desired number of positives in the training data.
    train_count = round(prop * int((labels == pos_label).sum()))
    sources = list(final_balanced_df.loc[labels == pos_label, "Source_Info"].unique())

    # Positives per source, computed once instead of rescanning the whole
    # frame for every draw. Only true/false positives (non-zero labels) count.
    positives_per_src = final_balanced_df.loc[labels != 0, "Source_Info"].value_counts().to_dict()

    rng = np.random if random_state is None else np.random.RandomState(random_state)

    act_train_count = 0 # Actual train_count that we keep adding to
    chsen_srcs = []
    candidates = list(sources)
    # Every source is drawn at most once. A source that would overshoot can
    # never fit later (the running count only grows), so it is dropped from
    # the candidates and left for the test set. This also guarantees the loop
    # terminates once the candidates run out.
    while act_train_count < train_count and candidates:
        sel_src = candidates.pop(rng.randint(len(candidates)))
        new_count = positives_per_src[sel_src]
        if act_train_count + new_count > train_count + thresh:
            continue
        act_train_count += new_count
        chsen_srcs.append(sel_src)

    chosen = set(chsen_srcs)
    non_chsen_srcs = [src for src in sources if src not in chosen]

    train_df = final_balanced_df[final_balanced_df["Source_Info"].isin(chsen_srcs)]

    # Every remaining source with the label of interest goes into the test df.
    test_df = final_balanced_df[final_balanced_df["Source_Info"].isin(non_chsen_srcs)]

    return train_df, test_df

In [27]:
# False positive sources (label 2): aim for 70% of their positive rows in train, the rest in test.
fp_train_df, fp_test_df = split_train_test_data(final_balanced_df, 0.7, det_src_type = 1)

In [28]:
# Re-split the false positive training part: 70% stays in train, the rest becomes validation.
fp_train_df, fp_val_df = split_train_test_data(fp_train_df, 0.7, det_src_type = 1)

In [29]:
# True positive sources (label 1): aim for 70% of their positive rows in train, the rest in test.
tp_train_df, tp_test_df = split_train_test_data(final_balanced_df, 0.7, det_src_type = 0)

In [30]:
# Re-split the true positive training part: 70% stays in train, the rest becomes validation.
tp_train_df, tp_val_df = split_train_test_data(tp_train_df, 0.7, det_src_type = 0)

#### Train Data

In [31]:
random_state = 43
# Combine the true and false positive training rows, shuffle them
# reproducibly and renumber the index from 0.
train_df = (
    pd.concat([tp_train_df, fp_train_df], ignore_index=True, axis = 0)
    .sample(frac=1, random_state = random_state, axis=0)
    .reset_index(drop=True)
)

In [32]:
# Show the training rows labelled as true positives.
train_df[train_df["Above_Threshold"] == 1]

Unnamed: 0,Time,peak_flux,local_rms,Time_Idx,Above_Threshold,Source_Info
8,2024-06-14T08:34:34.780029,0.001022,0.001555,16,1,SB62877_beam18_J082015.63-411435.09
36,2024-06-07T15:57:52.708989,0.184731,0.003788,24,1,SB62643_beam31_J175258.58-280637.95
37,2024-04-22T19:17:59.240063,0.041809,0.002569,34,1,SB61530_beam17_J182530.54-093523.39
107,2024-04-04T21:57:31.420030,0.142498,0.003661,38,1,SB60803_beam26_J175258.73-280636.72
111,2024-06-23T15:25:43.542141,0.178184,0.002973,11,1,SB63235_beam31_J175258.58-280637.95
...,...,...,...,...,...,...
14950,2024-04-04T21:39:06.605950,0.130343,0.002121,17,1,SB60802_beam19_J175258.66-280637.15
14960,2024-04-19T19:25:02.475647,0.544332,0.002875,28,1,SB61302_beam11_J164449.31-455910.49
14974,2024-04-16T19:34:55.691135,0.159635,0.002889,29,1,SB61159_beam31_J175258.58-280637.95
14999,2024-04-05T21:23:12.639356,0.017459,0.001392,67,1,SB60866_beam03_J185257.51-063559.86


#### Validation Data

In [33]:
# Combine the true and false positive validation rows, shuffle them
# reproducibly and renumber the index from 0.
val_df = (
    pd.concat([tp_val_df, fp_val_df], ignore_index=True, axis = 0)
    .sample(frac=1, random_state = random_state, axis=0)
    .reset_index(drop=True)
)

#### Test Data

In [34]:
# Combine the true and false positive test rows, shuffle them
# reproducibly and renumber the index from 0.
test_df = (
    pd.concat([tp_test_df, fp_test_df], ignore_index=True, axis = 0)
    .sample(frac=1, random_state = random_state, axis=0)
    .reset_index(drop=True)
)

In [35]:
# (rows, columns) of the shuffled test frame.
test_df.shape

(9192, 6)

#### Test Sources Overlap

In [36]:
# Sources shared by train and test; an empty set means no leakage between them.
set(train_df["Source_Info"]).intersection(set(test_df["Source_Info"]))

set()

In [37]:
# Sources shared by train and validation; an empty set means no leakage between them.
set(train_df["Source_Info"]).intersection(set(val_df["Source_Info"]))

set()

In [38]:
# Sources shared by test and validation; an empty set means no leakage between them.
set(test_df["Source_Info"]).intersection(set(val_df["Source_Info"]))

set()

#### Obtain The Relevant Fits Files for Each Source and Preprocess The Images

In [39]:
def get_img_slices_and_labels(df, fits_fol = "VAST 10s fitscube", img_shape = (120, 120)):
    """
    Obtain the images and labels that the classifier
    will be trained on.

    For every row the fits cube "<SBID>_<beam>_slices_<name>.fits" is opened
    and the slice at the row's "Time_Idx" is taken, triplicated into three
    channels and min-max normalised. Rows whose slice does not have the
    expected shape are dropped from both outputs.

    Parameters
    ----------
    df : pandas dataframe
        Light curve dataframe with "Source_Info", "Time_Idx" and
        "Above_Threshold" columns

    fits_fol : str, default "VAST 10s fitscube"
        Path to the fits cube folder

    img_shape : tuple, default (120, 120)
        Expected (rows, columns) of a single image slice

    Returns
    -------
    img_labels : pandas series
        Labels for the kept images, indexed from 0

    img_slice : numpy array
        The kept images, of shape (number kept, rows, columns, 3)
    """

    img_shape = tuple(img_shape)
    img_labels = df["Above_Threshold"]
    img_slice = np.empty(shape=(len(img_labels), *img_shape, 3))
    # Row positions (not index labels) of the images that are kept, so the
    # function does not depend on df having a 0..n-1 index.
    keep_pos = []
    for pos, (source_info, time_idx) in enumerate(zip(df["Source_Info"], df["Time_Idx"])):
        src_info = source_info.split("_")
        src_info.insert(2, "slices")
        fits_filename = f'{fits_fol}/{"_".join(src_info)}.fits'

        # Close the file as soon as the slice is read; copy the slice so it
        # stays valid after the file is closed.
        with fits.open(fits_filename) as hdul:
            new_img_slice = np.array(hdul[0].data[time_idx])

        if new_img_slice.shape != img_shape:
            continue
        # Since this is a valid image let's keep this image
        keep_pos.append(pos)

        # Triplicate the single channel so the image has three channels.
        new_img_slice = np.repeat(new_img_slice[..., np.newaxis], 3, -1)
        img_slice[pos] = normalise_img(new_img_slice)

    # keep the relevant images and their labels
    img_slice = img_slice[keep_pos]
    img_labels = img_labels.iloc[keep_pos].reset_index(drop=True)

    return img_labels, img_slice

In [40]:
def normalise_img(img):
    """
    Applies min-max normalisation on an image.

    Parameters
    ----------
    img : numpy array
        An image array of integers or floats

    Returns
    -------
    img : numpy array
        A min-max normalised image with values in [0, 1]. An image whose
        pixels are all equal has no range to scale by and is returned as
        all zeros instead of NaN.
    """
    lowest = img.min()
    value_range = img.max() - lowest
    if value_range == 0:
        # Constant image: dividing would give 0/0.
        return np.zeros_like(img, dtype=float)
    return (img - lowest) / value_range

#### Get Image Slices and Labels

In [41]:
# Load the normalised image slice and label behind every row of each split.
test_img_labels, test_img_slice = get_img_slices_and_labels(test_df)
val_img_labels, val_img_slice = get_img_slices_and_labels(val_df)
train_img_labels, train_img_slice = get_img_slices_and_labels(train_df)

In [42]:
# Check if images are normalised
# NOTE: only the last expression of the cell is echoed, so the max() result
# on the next line is computed but not displayed.
test_img_slice.max()
test_img_slice.min()

0.0

In [43]:
# check for NaN values in the images
# Each built-in sum() collapses the leading axis, so the three nested sums
# leave one NaN count per colour channel.
sum(sum(sum(np.isnan(test_img_slice))))

array([0, 0, 0])

In [44]:
# (images, rows, columns, channels) of the test images that were kept.
test_img_slice.shape

(9173, 120, 120, 3)

#### Generate True Positives and Non-Detections For Train, Test and Validation Images

In [45]:
def generate_slice(imsize_x = 120, imsize_y = 120, bmaj = 5, bmin = 6, bpa = 0*u.deg, inj_flux = 10, noise_level=2.5):
    """
    Generate an image comprised of gaussian-smoothed noise plus a
    gaussian source placed at the centre of the image.

    Parameters
    ----------
    imsize_x : int, default 120
        Number of rows in the image array

    imsize_y : int, default 120
        Number of columns in the image array

    bmaj : int, default 5
        Twice the standard deviation of the gaussian kernel in x

    bmin : int, default 6
        Twice the standard deviation of the gaussian kernel in y

    bpa : astropy angle quantity, default 0 deg
        Angle of the gaussian kernel's orientation

    inj_flux : int, default 10
        Peak value of the injected source

    noise_level : float, default 2.5
        Standard deviation of the normal distribution the noise is drawn from

    Returns
    -------
    injected_data : numpy array
        An image array composed of gaussian noise combined with a central source
        of flux
    """

    # Uncorrelated noise, one draw per pixel.
    noise_data = np.random.normal(loc=0, scale=noise_level, size=(imsize_x, imsize_y))

    psf_kernel = Gaussian2DKernel(bmaj/2, bmin/2, bpa)
    psf_kernel.normalize('peak')

    # Smoothing lowers the spread of the noise, so rescale it back to the
    # spread of the raw noise. Skip the rescale when the smoothed noise is
    # flat (e.g. noise_level = 0), which would otherwise divide by zero.
    gaussian_noise_data = convolve(noise_data, psf_kernel)
    smoothed_std = gaussian_noise_data.std()
    if smoothed_std > 0:
        gaussian_noise_data *= noise_data.std()/smoothed_std

    # A kernel array is laid out as (y_size, x_size). Passing the column count
    # as x_size and the row count as y_size makes it match the
    # (imsize_x, imsize_y) noise array for non-square images as well.
    source_inj = Gaussian2DKernel(bmaj/2, bmin/2, bpa, x_size=imsize_y, y_size=imsize_x)
    source_inj.normalize('peak')
    source_inj *= inj_flux
    injected_data = gaussian_noise_data+source_inj.array

    return injected_data

In [46]:
def generate_num_of_imgs(img_labels, label, inj_flux_bds, b_thresh = 1, noise_level = 2.5, imsize_x = 120, imsize_y = 120):
    """
    Generate as many simulated images as are needed for the given label
    to reach the number of false positives (label 2) in img_labels.

    Parameters
    ----------
    img_labels : pandas series
        The image labels

    label : int
        0 for non detections and 1 for true positives

    inj_flux_bds : tuple
        The (lower, upper) bounds inj_flux is drawn from uniformly

    b_thresh : int, default 1
        bmin and bmaj are two distinct values drawn from
        range(5 - b_thresh, 5 + b_thresh); with the default the only
        candidates are 4 and 5

    noise_level : float, default 2.5
        The noise level of the image calculated using a normal distribution

    imsize_x : int, default 120
        Number of rows in the image array

    imsize_y : int, default 120
        Number of columns in the image array

    Returns
    -------
    img_arr : numpy array
        An array of generated images of shape (n, imsize_x, imsize_y, 3).
        It is empty when the label already has at least as many images
        as there are false positives.
    """

    # Never negative: nothing is generated when the label is not in deficit.
    n_gen = max(0, len(img_labels[img_labels == 2]) - len(img_labels[img_labels == label]))
    img_arr = np.empty(shape=(n_gen, imsize_x, imsize_y, 3))
    for n in range(n_gen):

        inj_flux = random.uniform(inj_flux_bds[0], inj_flux_bds[1])
        bvals = random.sample(range(5-b_thresh, 5+b_thresh), 2)
        bmin = min(bvals)
        bmaj = max(bvals)
        bpa = random.uniform(0,360)*u.deg

        img = generate_slice(imsize_x = imsize_x, imsize_y = imsize_y, bmaj = bmaj, bmin = bmin, bpa = bpa, inj_flux = inj_flux, noise_level=noise_level)
        # Triplicate the single channel so the image has three channels.
        img = np.repeat(img[..., np.newaxis], 3, -1)
        img_arr[n] = normalise_img(img)

    return img_arr

In [47]:
def gen_rearrange_imgs_and_labels(img_slice, img_labels, random_state = 43):
    """
    Top up the real images with simulated true positive and non detection
    images, then shuffle images and labels together.

    Parameters
    ----------
    img_slice : numpy array
        The images obtained from the light curve dataframe
        and fits files.

    img_labels : series
        The labels for the images obtained from the light curve dataframes
        and fits files.

    random_state : int, default 43
        Random state used for the shuffle

    Returns
    -------
    img_slice : numpy array
        The image slices combined with the simulated images
        and shuffled

    img_labels : pandas series
        The integer image labels combined with the simulated images'
        labels and shuffled; its index holds the pre-shuffle positions
    """

    # Simulate the missing images (true positives first, then non detections).
    sim_tp_imgs = generate_num_of_imgs(img_labels, 1, (15,100))
    sim_nd_imgs = generate_num_of_imgs(img_labels, 0, (0,12))

    # Labels and images are appended in the same order:
    # real data, simulated non detections, simulated true positives.
    sim_nd_labels = pd.Series(np.zeros(sim_nd_imgs.shape[0]))
    sim_tp_labels = pd.Series(np.ones(sim_tp_imgs.shape[0]))
    all_labels = pd.concat([img_labels, sim_nd_labels, sim_tp_labels], ignore_index=True)
    all_imgs = np.concatenate([img_slice, sim_nd_imgs, sim_tp_imgs], axis=0)

    # Shuffle the labels, then reorder the images with the shuffled index.
    shuffled_labels = all_labels.sample(frac=1, random_state = random_state, axis=0).astype(int)
    shuffled_imgs = all_imgs[shuffled_labels.index]

    return shuffled_imgs, shuffled_labels

In [48]:
# Add simulated images to the test set and shuffle.
final_test_img_slice, final_test_img_labels = gen_rearrange_imgs_and_labels(test_img_slice, test_img_labels)

In [49]:
# Add simulated images to the validation set and shuffle.
final_val_img_slice, final_val_img_labels = gen_rearrange_imgs_and_labels(val_img_slice, val_img_labels)

In [50]:
# Add simulated images to the training set and shuffle.
final_train_img_slice, final_train_img_labels = gen_rearrange_imgs_and_labels(train_img_slice, train_img_labels)

#### Save The Image Slices With Simulated Images and Their Labels

In [51]:
def save_img_slice(img_slice_data, img_labels_data, data_class, data_fol = "CNN_Data"):
    """
    Save the images and their labels as .npy files.

    The files are written to
    "<data_fol>/<data_class>/<data_class>_With_Sim_Imgs.npy" and
    "<data_fol>/<data_class>/<data_class>_With_Sim_Labels.npy". The target
    folder is created if it does not exist yet.

    Parameters
    ----------
    img_slice_data : numpy array
        The images

    img_labels_data : pandas series
        The labels for the images

    data_class : str
        Whether the image belongs to "Train", "Validation" or "Test" folders

    data_fol : str, default "CNN_Data"
        The folder the images will be saved in
    """

    # np.save does not create missing folders, so make sure the target exists.
    os.makedirs(f"{data_fol}/{data_class}", exist_ok=True)

    f_path = f"{data_fol}/{data_class}/{data_class}"
    np.save(f_path+"_With_Sim_Imgs", img_slice_data)
    np.save(f_path+"_With_Sim_Labels", img_labels_data)

In [52]:
# Write the images and labels of each split to CNN_Data/<split>/.
save_img_slice(final_test_img_slice, final_test_img_labels, "Test", data_fol = "CNN_Data")
save_img_slice(final_train_img_slice, final_train_img_labels, "Train", data_fol = "CNN_Data")
save_img_slice(final_val_img_slice, final_val_img_labels, "Validation", data_fol = "CNN_Data")