<a href="https://colab.research.google.com/github/vignesh-0510/SolarFlareExplainableWindowDetection/blob/main/phase_3_localization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install shap imbalanced-learn

In [1]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from imblearn.ensemble import EasyEnsembleClassifier, BalancedRandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score, precision_score, recall_score, balanced_accuracy_score, jaccard_score
from imblearn.metrics import geometric_mean_score
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
import tqdm
import pickle
from collections import deque

In [None]:
def extract_start_end_time(filename):
  """Extracts start and end time from a filename string.

  Args:
    filename: The filename string in the format
      'FQ_ar146_s2010-08-29T17:12:00_e2010-08-30T05:00:00.csv'.

  Returns:
    A (start_time, end_time) tuple of strings with the leading 's' / 'e'
    character dropped and a trailing '.csv' removed from the end time, or
    (None, None) if the filename has fewer than four '_'-separated parts.
  """
  parts = filename.split('_')
  if len(parts) < 4:
    return None, None

  start_time_part = parts[2][1:]  # Remove 's' prefix
  end_time_part = parts[3][1:]   # Remove 'e' prefix

  # In the documented format the end-time field is the last one, so it still
  # carries the file extension; strip it so the value is a bare timestamp.
  extension = '.csv'
  if end_time_part.endswith(extension):
    end_time_part = end_time_part[:-len(extension)]
  return start_time_part, end_time_part

def process_file(filepath, interacting_columns, meta_dict, frequency_modes= 10):
  """Builds a one-row DataFrame of truncated FFT features.

  For every column in `interacting_columns` the real FFT of the series is
  computed and the real parts of its leading and trailing coefficients are
  kept, `frequency_modes` values per column in total.

  Args:
    filepath: Either a DataFrame that already holds the columns, or a path
      to a tab-separated file from which they are read.
    interacting_columns: Column names to transform, in output order.
    meta_dict: Not used by this function; kept so the positional signature
      stays the same for existing callers.
    frequency_modes: Number of coefficients kept per column. The trailing
      `frequency_modes // 2` coefficients are kept and the remainder is taken
      from the start of the spectrum.

  Returns:
    A DataFrame with a single row and columns named '<col>_real_<k>' for
    k in range(frequency_modes), grouped by column.

  Raises:
    ValueError: If `frequency_modes` is smaller than 1, or a column is empty
      or too short to provide the requested number of coefficients.
  """
  if frequency_modes < 1:
    raise ValueError(f'frequency_modes must be >= 1, got {frequency_modes}')

  if isinstance(filepath, pd.DataFrame):
    frame = filepath
  else:
    frame = pd.read_csv(filepath, sep='\t', usecols=interacting_columns)

  high_modes = frequency_modes // 2
  # For an even count this equals high_modes; for an odd count the extra
  # coefficient comes from the low end so each column fills its full slot.
  low_modes = frequency_modes - high_modes

  # Shape (1, n_features): one sample row.
  data_arr = np.zeros((1, frequency_modes * len(interacting_columns)))
  data_col_list = []
  for i, col in enumerate(interacting_columns):
    y = frame[col].to_numpy(dtype=np.float64)
    if y.size == 0:
      raise ValueError(f'column {col!r} has no rows to transform')
    y_f = torch.fft.rfft(torch.tensor(y))
    n_coeffs = y_f.shape[0]
    if n_coeffs < low_modes:
      raise ValueError(
          f'column {col!r} yields {n_coeffs} FFT coefficients; '
          f'at least {low_modes} are needed for frequency_modes={frequency_modes}')
    # Slice the tail with an explicit start: y_f[-0:] would be the whole array.
    y_f = torch.cat((y_f[:low_modes], y_f[n_coeffs - high_modes:]))
    data_col_list.extend([f'{col}_real_{c}' for c in range(frequency_modes)])
    data_arr[0, i*frequency_modes: (i+1)*frequency_modes] = torch.real(y_f).numpy()
  result_df = pd.DataFrame(data_arr, columns=data_col_list)
  return result_df



In [None]:
file_name = ''  # placeholder: set to the tab-separated file to analyze
model = None  # placeholder: set to a classifier exposing predict_proba
interacting_columns = ['USFLUX', 'R_VALUE', 'TOTBSQ']

start_time, end_time = extract_start_end_time(file_name)
# `filepath` was referenced here without ever being assigned; the only
# location this cell knows about is file_name, so read from that.
filepath = file_name
df = pd.read_csv(filepath, sep='\t', usecols=interacting_columns)

In [None]:
class Node:
  """Record of one window: its bounds, a confidence value and a parent link.

  Attributes are plain and writable:
    window_start: Value passed as `start`.
    window_size: Value passed as `size`.
    window_end: `start + size`, or None when either bound was not given.
    confidence: Value passed as `confidence`.
    parent: Value passed as `parent`.
  """
  def __init__(self, start=None, size=None, confidence=None, parent=None):
    self.window_start = start
    self.window_size = size
    # Both bounds default to None, so the end can only be derived when the
    # caller supplied them; otherwise `None + None` would raise TypeError.
    if start is None or size is None:
      self.window_end = None
    else:
      self.window_end = start + size
    self.confidence = confidence
    self.parent = parent

In [None]:
window_start = 0
window_size = df.shape[0]
# A tenth of the series per step, but never zero: with fewer than 10 rows the
# integer division yields 0 and a zero step can never shrink a window.
step_size = max(1, window_size // 10)

freq_modes = 10


In [None]:
def run_analysis(df):
  """Scores one window with the module-level `model`.

  The window is converted to FFT features by `process_file` and passed to
  `model.predict_proba`; the raw output is printed for inspection.

  Args:
    df: DataFrame holding the `interacting_columns` for the window. If it
      also has a 'class' column, a second diagnostic line is printed.

  Returns:
    The first row of `model.predict_proba`'s output.
  """
  # meta_dict is the third positional parameter of process_file, so the mode
  # count must be passed by keyword or it is silently bound to meta_dict.
  freq_df = process_file(df, interacting_columns, None, frequency_modes=freq_modes)
  c = model.predict_proba(freq_df)
  print(f'confidence: {c}')
  # The label column is optional: frames loaded with only the interacting
  # columns do not carry it.
  if 'class' in df.columns:
    labels = df['class'].values
    print(f'normalized_conf {1-(c[0] - labels)}')
  # NOTE(review): callers compare this value with > and >=. If predict_proba
  # returns more than one class column, c[0] is an array and those comparisons
  # are ambiguous - confirm which class column is the intended confidence.
  return c[0]

In [None]:
# Greedy window shrinking: starting from the whole series, repeatedly drop
# `step_size` rows from either the tail (child 1) or the head (child 2) and
# keep whichever child scores at least as high as the current window.
parent_window_start = 0
parent_window_size = window_size
parent_conf = None

q = deque()
aux_q = deque()

q.append(Node(parent_window_start, parent_window_size, parent_conf, None))
while parent_window_size >= freq_modes:
  if parent_conf is None:
    # First pass: score the full series before looking at any child.
    parent_conf = run_analysis(df)
    continue

  child_1_start = parent_window_start
  child_2_start = parent_window_start + step_size
  child_window_size = parent_window_size - step_size

  # Stop before scoring windows the loop condition itself would reject
  # (down to an empty slice), and never iterate with a step that cannot
  # shrink the window.
  if step_size <= 0 or child_window_size < freq_modes:
    break

  child_1_df = df.iloc[child_1_start:child_1_start + child_window_size]
  child_2_df = df.iloc[child_2_start:child_2_start + child_window_size]

  child_1_conf = run_analysis(child_1_df)
  child_2_conf = run_analysis(child_2_df)

  if child_1_conf > child_2_conf and child_1_conf >= parent_conf:
    # Head-anchored child wins: keep the start, drop rows from the tail.
    parent_window_start = child_1_start
    parent_window_size = child_window_size
    parent_conf = child_1_conf
  elif child_2_conf >= child_1_conf and child_2_conf >= parent_conf:
    # Tail-anchored child wins (ties between children go here).
    parent_window_start = child_2_start
    parent_window_size = child_window_size
    parent_conf = child_2_conf
  elif parent_conf > child_1_conf and parent_conf > child_2_conf:
    # Neither child improves on the current window: it is the result.
    break
  else:
    # None of the orderings above held (e.g. a NaN confidence); report and stop.
    print(f'parent: {parent_conf}, child 1: {child_1_conf}, child 2: {child_2_conf}')
    break