In [None]:
# @title

import importlib.util
import os

def install_and_import(package, alias=None, pip_name=None):
    """Import ``package``, attempting a pip install first if it is missing.

    Args:
        package: Dotted module path to import (e.g. ``'matplotlib.pyplot'``).
        alias: Optional name under which the imported module is stored in
            this module's globals (e.g. ``'pd'``).
        pip_name: Optional distribution name to pass to pip. When omitted it
            is derived from the top-level package of ``package``.

    Returns:
        The imported module.

    Raises:
        ImportError: If the module still cannot be imported after the
            installation attempt.
    """
    import subprocess
    import sys

    # Import names whose distribution on PyPI is published under another name.
    known_pip_names = {'sklearn': 'scikit-learn'}

    try:
        module = importlib.import_module(package)
    except ImportError:
        if pip_name is None:
            # pip installs distributions, not submodules: 'matplotlib.pyplot'
            # must be installed as 'matplotlib'.
            top_level = package.split('.')[0]
            pip_name = known_pip_names.get(top_level, top_level)
        # Run pip through the running interpreter (so it targets the same
        # environment) and pass arguments as a list (no shell interpolation).
        # A failed install is not raised here; the import below reports it as
        # ImportError, matching the previous behaviour.
        subprocess.run([sys.executable, '-m', 'pip', 'install', pip_name], check=False)
        importlib.invalidate_caches()
        module = importlib.import_module(package)
    if alias:
        globals()[alias] = module
    return module

# Import each third-party dependency via install_and_import, which attempts a
# pip install when the import fails. When an alias is given, the module is
# also stored as a notebook-level global under that name (pd, np, sns, plt).
install_and_import('pandas', 'pd')
install_and_import('numpy', 'np')
install_and_import('seaborn', 'sns')
install_and_import('matplotlib.pyplot', 'plt')
# No alias for these two: the call only makes sure the package is importable;
# the names actually used are imported further down in this cell.
install_and_import('sklearn')
install_and_import('ipywidgets')

import warnings
# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')

from matplotlib.colors import ListedColormap, Normalize
from matplotlib.lines import Line2D


from sklearn.model_selection import *
from sklearn.metrics import *
from sklearn.preprocessing import *
from sklearn.base import ClassifierMixin, RegressorMixin, clone
from sklearn.decomposition import PCA


import ipywidgets as widgets
from ipywidgets import *
from IPython.display import display, clear_output
import inspect


# Shared colour palette (hex strings) used by the plots in this notebook:
# the correlation heatmap colormap, the line/scatter/bar colour and the
# decision-boundary colormaps.
colors = ['#012d9c', '#39b7ff', '#ffb310']

In [None]:
def data_load_widget():
  """Build the "Load Data" panel: a dataset dropdown, a load button and a status area.

  Clicking the button reads the selected CSV into the notebook-level ``df``
  and reports success or failure in the status area.

  Returns:
      widgets.VBox: the dropdown/button row stacked above the status output.
  """
  # Display name -> CSV location for each dataset offered in the dropdown
  datasets = {
      "Anscombe's Quartet": "https://raw.githubusercontent.com/the-codingschool/TRAIN-datasets/main/anscombe_quartets/Anscombe_quartet_data.csv",
      "Star Classes": "https://raw.githubusercontent.com/the-codingschool/TRAIN-datasets/main/stars/stars.csv",
  }

  dataset_selector = widgets.Dropdown(options=datasets.keys(), description='Datasets', disabled=False)
  load_button = widgets.Button(description="Load Data")

  # Status area; each click replaces whatever the previous click printed
  status = widgets.Output()

  @status.capture(clear_output=True)
  def _on_load_clicked(_button):
      global df

      chosen_url = datasets[dataset_selector.value]
      try:
          df = pd.read_csv(chosen_url)
          print(f"Dataset '{dataset_selector.value}' loaded successfully!")
      except Exception as error:
          print(f"Failed to load the dataset: {error}")

  load_button.on_click(_on_load_clicked)

  controls = widgets.HBox([dataset_selector, load_button])
  return widgets.VBox([controls, status])



def data_view_widget():
  """Build the "View Data" panel: a button that renders the global ``df``.

  Returns:
      widgets.VBox: the button stacked above a fixed-height, scrollable box
      holding the rendered DataFrame.
  """
  frame_view = widgets.Output()

  def _render_frame(_button):
    # Replace the previous rendering instead of appending to it
    with frame_view:
        clear_output(wait=True)
        display(df)

  view_button = widgets.Button(description='View Data')
  view_button.on_click(_render_frame)

  # Fixed-height bordered box that scrolls when the frame is larger than it
  scroll_layout = {'overflow': 'auto', 'height': '200px', 'border': '1px solid black'}
  scroll_box = widgets.VBox([frame_view], layout=scroll_layout)

  return widgets.VBox([view_button, scroll_box])



def eda_widget():
  """Build the "Exploratory Data Analysis" panel for the global ``df``.

  Three buttons share one output area: "Show Info" (``df.info()``),
  "Describe Data" (``df.describe()``) and "Correlations" (an annotated heatmap
  of the correlation matrix of the numeric columns).

  Returns:
      widgets.VBox: the button row stacked above the shared output area.
  """

  # Buttons for DataFrame methods/attributes
  info_button = widgets.Button(description="Show Info")
  describe_button = widgets.Button(description="Describe Data")
  corr_button = widgets.Button(description="Correlations")

  # Display EDA buttons
  ws = widgets.HBox([info_button, describe_button, corr_button])

  # Placeholder for EDA output
  eda_output = widgets.Output()
  ws = widgets.VBox([ws, eda_output])

  # Event handlers for EDA buttons
  def show_info(b):
      with eda_output:
          clear_output(wait=True)
          # DataFrame.info() prints its summary and returns None, so wrapping
          # it in display() would append a stray "None" to the output.
          df.info()

  def describe_data(b):
      with eda_output:
          clear_output(wait=True)
          display(df.describe())

  def correlation_heatmap(b):
      with eda_output:
        clear_output(wait=True)

        # Correlation is only defined for numeric columns; recent pandas
        # raises when text columns are passed to corr().
        numeric_df = df.select_dtypes(include='number')
        if numeric_df.shape[1] == 0:
          print('There are no numeric columns to correlate.')
          return

        correlation_matrix = numeric_df.corr()
        sns.heatmap(correlation_matrix, annot=True, linewidths=0.5, cmap = colors, annot_kws={"size": 12})

        plt.title('Correlation Matrix', weight = 'bold', fontsize = 18)
        plt.tight_layout()
        plt.show()


  info_button.on_click(show_info)
  describe_button.on_click(describe_data)
  corr_button.on_click(correlation_heatmap)

  return ws



def visualize_widget():
  """Build the "Visualize" panel for plotting one column of ``df`` against another.

  "Choose Plot" shows dropdowns for the X column, Y column and plot type and
  enables the "Plot" button; "Plot" replaces them with the requested figure.
  The three dropdowns are stored as notebook-level globals.

  Returns:
      widgets.VBox: choose button, shared output area, plot button.
  """
  canvas = widgets.Output()

  choose_plot_button = widgets.Button(description='Choose Plot')
  plot_button = widgets.Button(description="Plot", disabled = True)

  def _show_selectors(_button):
    global x_var_selector, y_var_selector, plot_type_selector

    with canvas:
      clear_output(wait = True)

      x_var_selector = widgets.Dropdown(options=df.columns, description='X-axis')
      y_var_selector = widgets.Dropdown(options=df.columns, description='Y-axis')
      plot_type_selector = widgets.Dropdown(options=['line', 'scatter', 'bar'], description='Plot Type')

      selectors = widgets.VBox([x_var_selector, y_var_selector, plot_type_selector])
      display(selectors)

      # Plotting only makes sense once the selectors exist
      plot_button.disabled = False

  def _draw_plot(_button):
    with canvas:
      clear_output(wait=True)

      x = df[x_var_selector.value]
      y = df[y_var_selector.value]
      plt.figure(figsize=(10, 6))

      # Plot type -> pyplot function; every type is drawn in the first palette colour
      kind = plot_type_selector.value
      draw = {'line': plt.plot, 'scatter': plt.scatter, 'bar': plt.bar}.get(kind)
      if draw is not None:
        draw(x, y, color = colors[0])
      if kind == 'bar':
        # Bar plots additionally get vertical tick labels
        plt.xticks(rotation = 90)

      label_style = {'fontsize': 14, 'fontweight': 'bold'}
      plt.xlabel(x_var_selector.value, **label_style)
      plt.ylabel(y_var_selector.value, **label_style)
      plt.title(f'{x_var_selector.value} vs. {y_var_selector.value}', fontsize = 16, fontweight = 'bold')

      plt.show()

  choose_plot_button.on_click(_show_selectors)
  plot_button.on_click(_draw_plot)

  return widgets.VBox([choose_plot_button, canvas, plot_button])



# @title

def variable_choose_widget():
  """Build the "Choose Variables" panel: assign column roles and split the data.

  "Pick Variables" shows one role dropdown (Feature / Label / Don't Use) per
  column of the global ``df`` and stores them in the global ``variable_roles``.
  "Split Data" validates the roles and writes the train/test split to the
  globals ``X_train``, ``X_test``, ``y_train`` and ``y_test``.

  Returns:
      widgets.VBox: train-size slider, pick button with role dropdowns,
      split button and a message area.
  """

  # Deciding role of each variable
  pick_variables_button = widgets.Button(description="Pick Variables")
  train_split = widgets.FloatSlider(value=0.80, min=0.05, max=0.95, step=0.05, description='Train %:')
  split_button = widgets.Button(description="Split Data", disabled = True)

  output = widgets.Output()
  # Separate area for validation/result messages so they can be replaced
  # without removing the role dropdowns shown in `output`.
  messages = widgets.Output()

  def pick_variables(b):
    global variable_roles

    variable_roles = {col: widgets.Dropdown(options=['Feature', 'Label', 'Don\'t Use'],
                                            value='Don\'t Use',
                                            description=col)
                      for col in df.columns}

    with output:
      # Replace any earlier set of dropdowns: the old ones are no longer in
      # variable_roles, so leaving them on screen would show dead controls.
      clear_output(wait=True)
      display(widgets.VBox(list(variable_roles.values())))

    with messages:
      clear_output()

    split_button.disabled = False

  pick_variables_button.on_click(pick_variables)


  # Splitting data
  def split_data(b):

    global X_train, X_test, y_train, y_test

    with messages:
      clear_output(wait=True)

      # Only consider columns that still exist in df and have a role dropdown
      roles = {col: variable_roles[col].value for col in df.columns if col in variable_roles}
      features = [col for col, role in roles.items() if role == 'Feature']
      labels = [col for col, role in roles.items() if role == 'Label']

      if len(labels) == 0:
        display('You must choose a label to proceed.')
        return

      if len(labels) > 1:
        # Previously the last column marked as Label silently won.
        display('You must choose exactly one label to proceed.')
        return

      if len(features) == 0:
        display('You must choose at least one feature to proceed.')
        return

      X_train, X_test, y_train, y_test = train_test_split(df[features], df[labels[0]], test_size = 1 - train_split.value, random_state = 42)

      display('Training set has: ' + str(len(X_train)) + ' data points.')
      display('Test set has: ' + str(len(X_test)) + ' data points.')

  split_button.on_click(split_data)

  ws = widgets.HBox([pick_variables_button, output])
  ws = widgets.VBox([train_split, ws, split_button, messages])
  return ws




# @title
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

def _parse_hyperparameter_text(text):
  """Convert the contents of a hyperparameter text box into a Python value.

  Text that is a valid Python literal (``None``, ``5``, ``0.1``, ``True``,
  ``(1, 2)`` ...) is converted to that value; anything else (``uniform``,
  ``lbfgs`` ...) is returned as the stripped string.
  """
  import ast

  text = text.strip()
  try:
    return ast.literal_eval(text)
  except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
    return text


def create_model_widget():
  """Build the "Create Model" panel: choose an estimator and its hyperparameters.

  "Import Model" shows one input widget per constructor parameter that has a
  default value. "Initialize Model" instantiates the chosen class with the
  entered values and stores the instance in the global ``model``.

  Returns:
      widgets.VBox: model dropdown, import button and the hyperparameter area.
  """

  # Model selection dropdown
  model_options = {
      'KNN (Classification)': KNeighborsClassifier,
      'KNN (Regression)': KNeighborsRegressor,
      'Logistic Regression (Classification)': LogisticRegression,
      'Linear Regression (Regression)': LinearRegression
  }

  model_selector = widgets.Dropdown(options=model_options.keys(), description='Model:')
  model_import_button = widgets.Button(description="Import Model")

  # Placeholder for the hyperparameters UI
  output = widgets.Output()


  # Function to handle model import and display hyperparameters
  def on_import_model_clicked(b):
      with output:
          clear_output(wait=True)
          selected_model_class = model_options[model_selector.value]
          display_hyperparameters(selected_model_class)

  model_import_button.on_click(on_import_model_clicked)

  # Function to display hyperparameters and initialize model
  def display_hyperparameters(model_class):
      params = inspect.signature(model_class.__init__).parameters
      param_widgets = []

      # This currently outputs all possible parameters
      # We could use a dictionary here to just output the parameters we want
      # So we don't overwhelm students
      for name, param in params.items():
          if name == 'self' or param.default is param.empty:
              continue
          # bool must be tested before int: bool is a subclass of int
          elif isinstance(param.default, bool):
              widget = widgets.Checkbox(value=param.default, description=name)
          elif isinstance(param.default, int):
              widget = widgets.IntText(value=param.default, description=name)
          elif isinstance(param.default, float):
              widget = widgets.FloatText(value=param.default, description=name)
          else:
              widget = widgets.Text(value=str(param.default), description=name)

          param_widgets.append(widget)

      initialize_button = widgets.Button(description="Initialize Model")

      with output:
        clear_output(wait = True)
        display(widgets.VBox(param_widgets + [initialize_button]))

      # Define the initialize model function inside to capture the current param_widgets
      def initialize_model(b):

          global model

          hyperparams = {}
          for widget in param_widgets:
              value = widget.value
              if isinstance(widget, widgets.Text):
                  # Free-text boxes hold parameters whose default is a string
                  # or None. Parse literals so that e.g. "None" becomes None
                  # and "2" becomes 2 instead of being passed on as strings.
                  value = _parse_hyperparameter_text(str(value))
              hyperparams[widget.description] = value

          # Instantiate inside the output widget so a rejected hyperparameter
          # shows its traceback there instead of being lost.
          with output:
            clear_output(wait = True)
            model = model_class(**hyperparams)
            display(f"Model initialized with hyperparameters: {hyperparams}")


      initialize_button.on_click(initialize_model)

  return widgets.VBox([model_selector, model_import_button, output])




# @title

def model_train_widget():
  """Build the "Train Model" panel: Train / Make Predictions / Evaluate buttons.

  The panel works on the notebook-level ``model``, ``X_train``, ``X_test``,
  ``y_train`` and ``y_test``. "Train" fits the model and enables the other two
  buttons. "Make Predictions" draws a decision-boundary plot (classification)
  or a predicted-vs-true scatter (regression). "Evaluate" shows a
  classification report with confusion matrix, or regression metrics.

  Returns:
      widgets.VBox: the button row stacked above the shared output area.
  """
  train_button = widgets.Button(description="Train")
  make_predictions_button = widgets.Button(description="Make Predictions", disabled = True)
  evaluate_button = widgets.Button(description="Evaluate", disabled = True)
  output = widgets.Output()


  def train_model(b):
    # Fit inside the output widget so a failing fit shows its traceback there
    # instead of being lost in the button callback.
    with output:
      clear_output(wait=True)
      model.fit(X_train, y_train)
      display('Model trained.')

      # Only reached when fit() succeeded
      make_predictions_button.disabled = False
      evaluate_button.disabled = False


  def _encode_labels(labels, classes):
    """Map class labels to their integer position in ``classes`` (-1 if absent)."""
    lookup = {label: code for code, label in enumerate(classes)}
    return np.array([lookup.get(label, -1) for label in np.asarray(labels).ravel()], dtype=int)


  def _class_colormap(n_classes):
    """Return a colormap with exactly one colour per class."""
    if n_classes <= len(colors):
      return ListedColormap(colors[:n_classes])
    # More classes than palette colours: fall back to a resampled built-in map
    return plt.get_cmap('viridis', n_classes)


  def _add_class_legend(classes, cmap):
    """Add a legend that maps each class colour to its label."""
    handles = [Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(code), markersize=10, label=str(label))
               for code, label in enumerate(classes)]
    plt.legend(handles=handles, title="Classes")


  def _plot_regions_2d(classifier, points, true_codes, classes, cmap, xlabel, ylabel, title, columns=None):
    """Shade the predicted class over a 2-D grid and overlay the test points.

    ``points`` is an (n, 2) array. When ``columns`` is given, the grid is
    passed to ``classifier.predict`` as a DataFrame with those column names.
    """
    n_classes = len(classes)

    # Create a grid to cover the plot area
    x_min, x_max = points[:, 0].min() - 1, points[:, 0].max() + 1
    y_min, y_max = points[:, 1].min() - 1, points[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100), np.linspace(y_min, y_max, 100))

    grid = np.c_[xx.ravel(), yy.ravel()]
    if columns is not None:
      grid = pd.DataFrame(grid, columns=columns)

    # Encode predictions as integers: contourf cannot draw string labels
    Z = _encode_labels(classifier.predict(grid), classes).reshape(xx.shape)

    # Levels at half-integers give one filled band per class code; the
    # matching vmin/vmax on the scatter keeps point and region colours equal.
    plt.contourf(xx, yy, Z, levels=np.arange(n_classes + 1) - 0.5, alpha=0.4, cmap=cmap)
    plt.scatter(points[:, 0], points[:, 1], c=true_codes, cmap=cmap,
                vmin=-0.5, vmax=n_classes - 0.5, s=20, edgecolor='k')

    plt.xlim([x_min, x_max])
    plt.ylim([y_min, y_max])

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    _add_class_legend(classes, cmap)
    plt.show()


  def plot_decision_boundary(model, X_test, y_test):
    """Plot the regions predicted by ``model`` together with the test points.

    One feature: shaded bands along the feature axis. Two features: shaded
    regions in feature space. More: a clone of the model is re-trained on the
    first two principal components of the global training data and plotted
    in that space. Points are coloured by their true class.
    """
    if hasattr(model, 'classes_'):
      classes = list(model.classes_)
    else:
      classes = list(np.unique(y_train))

    n_classes = len(classes)
    cmap = _class_colormap(n_classes)
    true_codes = _encode_labels(y_test, classes)

    if len(X_test.columns) == 1:

      x_values = X_test.iloc[:, 0].to_numpy(dtype=float)
      x_range = np.linspace(x_values.min(), x_values.max(), 500)
      range_frame = pd.DataFrame(x_range.reshape(-1, 1), columns=X_test.columns)

      range_codes = _encode_labels(model.predict(range_frame), classes)
      predicted_codes = _encode_labels(model.predict(X_test), classes)

      # Predicted class on the y-axis, coloured by the true class
      plt.scatter(x_values, predicted_codes, c=true_codes, cmap=cmap,
                  vmin=-0.5, vmax=n_classes - 0.5, edgecolor='k', s=20)

      # Shade each run of identical predictions from its first grid point up
      # to the first grid point of the next run, so the bands are contiguous.
      run_starts = np.concatenate(([0], np.flatnonzero(np.diff(range_codes)) + 1))
      run_stops = np.concatenate((run_starts[1:], [len(x_range) - 1]))
      for start, stop in zip(run_starts, run_stops):
        plt.axvspan(x_range[start], x_range[stop], color=cmap(int(range_codes[start])), alpha=0.3)

      plt.yticks(range(n_classes), [str(label) for label in classes])
      plt.xlabel(X_test.columns[0])
      plt.ylabel('Class')
      plt.title('Decision Boundary')
      _add_class_legend(classes, cmap)
      plt.show()

    elif len(X_test.columns) == 2:

      # Positional .to_numpy() access: X_test is a DataFrame, so X_test[:, 0] would raise
      _plot_regions_2d(model, X_test.to_numpy(dtype=float), true_codes, classes, cmap,
                       xlabel=X_test.columns[0], ylabel=X_test.columns[1],
                       title='Decision Boundary', columns=X_test.columns)

    else:

      # Step 1: Apply PCA to reduce dimensionality
      pca = PCA(n_components=2)
      X_train_pca = pca.fit_transform(X_train)  # Fit and transform X_train
      X_test_pca = pca.transform(X_test)  # Transform X_test using the same PCA

      # Step 2: Re-train a copy of the classifier on the PCA-transformed training data
      model_pca = clone(model)
      model_pca.fit(X_train_pca, y_train)

      # Step 3: Visualize the decision boundary on the PCA-transformed test data
      _plot_regions_2d(model_pca, X_test_pca, true_codes, classes, cmap,
                       xlabel='Principal Component 1', ylabel='Principal Component 2',
                       title='Decision Boundary after PCA')


  def plot_predictions_vs_real_values(model, X_test, y_test):
    """Scatter predicted against true values with the line of perfect prediction."""
    predictions = model.predict(X_test)

    plt.scatter(y_test, predictions)
    plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], color = 'black', label="Correct Prediction")

    plt.xlabel('True Value', fontsize = 'x-large')
    plt.ylabel('Predicted Value', fontsize = 'x-large')
    plt.title("Real vs. Predicted Value", fontsize = 'x-large')
    plt.legend()

    plt.show()


  def _run_with_task_type(action):
    """Call ``action(task_type)`` once the task type of the global model is known.

    Classifiers and regressors are detected from their scikit-learn mixins.
    For any other model a dropdown is shown and ``action`` runs when the user
    picks a value. Must be called inside ``with output``.
    """
    if isinstance(model, ClassifierMixin):
      action('Classification')
      return
    if isinstance(model, RegressorMixin):
      action('Regression')
      return

    # value=None starts with no selection, so either choice fires a change event
    task_widget = widgets.Dropdown(options=['Classification', 'Regression'], value=None,
                                   description='What type of task is this?',
                                   style={'description_width': 'initial'})

    def on_task_chosen(change):
      if change['new'] is None:
        return
      with output:
        clear_output(wait=True)
        action(change['new'])

    # Dropdowns have no on_click; selection changes are delivered via observe
    task_widget.observe(on_task_chosen, names='value')
    display(task_widget)


  def _show_predictions(task_type):
    if task_type == 'Classification': plot_decision_boundary(model, X_test, y_test)
    else: plot_predictions_vs_real_values(model, X_test, y_test)


  def _show_evaluation(task_type):
    predictions = model.predict(X_test)

    if task_type == 'Classification':

      # Not every model the user marks as a classifier exposes classes_
      labels = getattr(model, 'classes_', None)

      # The report is a pre-formatted multi-line string: print it, because
      # display() would show its repr with escaped newlines.
      print(classification_report(y_test, predictions, labels = labels))

      cm = confusion_matrix(y_test, predictions, labels=labels)
      disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
      disp.plot()

      plt.xticks(rotation=90)
      plt.show()

    else:

      display("R-squared: " + str(r2_score(y_test, predictions)))
      display("Mean Squared Error: " + str(mean_squared_error(y_test, predictions)))
      display("Mean Absolute Error: " + str(mean_absolute_error(y_test, predictions)))


  def make_predictions(b):
    with output:
      clear_output(wait=True)
      _run_with_task_type(_show_predictions)


  def evaluate_model(b):
    with output:
      clear_output(wait=True)
      _run_with_task_type(_show_evaluation)


  train_button.on_click(train_model)
  make_predictions_button.on_click(make_predictions)
  evaluate_button.on_click(evaluate_model)

  return widgets.VBox([widgets.HBox([train_button, make_predictions_button, evaluate_button]), output])



# @title

def use_model_widget():
    """Build the "Use Model" panel: enter one sample and predict with the global ``model``.

    "Input Prediction Values" shows one input per column of the global
    ``X_test`` (numeric columns get a number box, others a text box) and
    enables "Predict", which prints the model's prediction for the entered
    sample below the inputs.

    The input widgets are kept in the notebook-level global ``input_widgets``
    (feature name -> widget).

    Returns:
        widgets.VBox: input button, form/result area, predict button.
    """

    input_pred_button = widgets.Button(description = 'Input Prediction Values')
    predict_button = widgets.Button(description="Predict", disabled = True)
    output = widgets.Output()

    def show_form():
        display(widgets.VBox(list(input_widgets.values())))

    def input_pred(b):

      # Container for input widgets
      global input_widgets
      input_widgets = {}

      # Generate input widgets for each feature
      for feature, dtype in zip(X_test.columns.tolist(), X_test.dtypes):
          # is_numeric_dtype covers every numeric width (int32, float32, ...),
          # not just int64/float64
          if pd.api.types.is_numeric_dtype(dtype):
              # Use FloatText for numeric inputs
              widget = widgets.FloatText(
                  value=0.0,
                  description=feature,
                  continuous_update=False
              )
          else:
              # Use Text for categorical inputs
              widget = widgets.Text(
                  value='',
                  description=feature,
                  continuous_update=False
              )
          input_widgets[feature] = widget

      with output:
        # Replace an earlier form rather than stacking a second one below it
        clear_output(wait=True)
        show_form()

      predict_button.disabled = False

    input_pred_button.on_click(input_pred)


    # Function to handle prediction on button click
    def on_predict_clicked(b):
        with output:
            clear_output(wait=True)
            # Keep the form visible so another sample can be entered directly
            show_form()

            # Build a one-row frame with the training column names: this keeps
            # each value's own type (a mixed list in np.array would become all
            # strings) and gives the model the feature names it was fitted with.
            input_data = [widget.value for widget in input_widgets.values()]
            sample = pd.DataFrame([input_data], columns=list(input_widgets.keys()))

            # Make prediction
            prediction = model.predict(sample)
            print(f"Prediction: {prediction[0]}")


    predict_button.on_click(on_predict_clicked)

    return widgets.VBox([input_pred_button, output, predict_button])



def generic_section():
    """Return a placeholder text widget whose value is 'Hi'."""
    placeholder = widgets.Text('Hi')
    return placeholder

def data_loaded():
    """Return True when a module-level global named ``df`` exists."""
    module_names = globals()
    return 'df' in module_names
def data_split():
    """Return True when a module-level global named ``X_train`` exists."""
    module_names = globals()
    return 'X_train' in module_names


def create_section(section_number):
    """Build the bordered box for one section of the explorer.

    The box contains a centred title, the section's widget and, for every
    section except the last, a "Next Section" button. The button reveals the
    following section once this section's condition function returns True;
    otherwise it shows a reminder next to the button.

    Args:
        section_number: Index into ``section_titles``, ``section_widgets``
            and ``section_conditions``.

    Returns:
        widgets.VBox: the assembled section.
    """

    section_title = section_titles[section_number]
    section_widget = section_widgets[section_number]()
    section_widget.layout = widgets.Layout(align_items='center', justify_content='center')

    section_condition = section_conditions[section_number]


    # Define the HTML for the title with centered text
    title_html = f"<h2 style='text-align: center;'>{section_title}</h2>"
    title = widgets.HTML(value=title_html)

    output = widgets.Output()
    ws = [title, section_widget]

    # Create the next button for all sections except the last
    if section_number < total_sections - 1:

      def on_button_clicked(b):
        if section_condition():
            # `sections` is looked up at click time, after the whole list exists
            sections[section_number + 1].layout.display = 'flex'
            b.disabled = True
            # Remove a reminder left over from an earlier, premature click
            with output:
              clear_output()
        else:
          with output:
            clear_output(wait = True)
            print('You haven\'t done everything necessary to move onto the next section!')

      button = widgets.Button(description='Next Section')
      button.on_click(on_button_clicked)

      button.style.button_color = '#98dcff'

      ws += [widgets.HBox([button, output])]


    # Combine title, button, and output into a VBox
    section_box = widgets.VBox(ws, layout=widgets.Layout(align_items='center', justify_content='center', border='2px solid black', padding='10px'))

    return section_box




# The three lists below are parallel: entry i gives section i's title, the
# factory that builds its widget, and the condition its "Next Section" button
# checks before revealing section i + 1.
section_titles = ['Load Data', 'View Data', 'Exploratory Data Analysis', 'Visualize', 'Choose Variables', 'Create Model', 'Train Model', 'Use Model']
total_sections = len(section_titles)

section_widgets = [data_load_widget, data_view_widget, eda_widget, visualize_widget, variable_choose_widget, create_model_widget, model_train_widget, use_model_widget]
section_conditions = [data_loaded, data_loaded, data_loaded, data_loaded, data_loaded, data_split, data_split, data_split]

# Build every section up front; create_section reads the lists above.
sections = [create_section(i) for i in range(total_sections)]

# Hide everything except the first section; later sections are switched to
# 'flex' by the previous section's "Next Section" button.
for i in range(1, total_sections):
    sections[i].layout.display = 'none'

# Lay the sections out in a two-column grid under a title banner and show it.
grid = widgets.GridBox(sections, layout=widgets.Layout(grid_template_columns="repeat(2, 1fr)", grid_gap='20px'))
title = widgets.HTML("<h1 style='text-align: center; background-color: #012d9c; padding: 10px; color: #ffffff'>TRAIN Dataset Explorer</h1>")
display(widgets.VBox([title, grid]))