In [3]:
import pathlib
import subprocess
import sys

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display, clear_output

# ---------- Widgets for user input ----------
# Numeric inputs are bounded so the form itself rejects values that make no
# sense for training (non-positive epochs or batch size, negative learning
# rate, dropout outside [0, 1)) instead of failing only after train.py has
# been launched. Bounded* widgets clamp out-of-range entries to [min, max].
# The upper bounds are deliberately generous caps (the Bounded* widgets
# require one); raise them if a run ever needs more.
dataset_widget = widgets.Dropdown(options=['fashion', 'cifar10'], description='Dataset:')
epochs_widget = widgets.BoundedIntText(value=5, min=1, max=10000, description='Epochs:')
batch_size_widget = widgets.BoundedIntText(value=64, min=1, max=65536, description='Batch Size:')
lr_widget = widgets.BoundedFloatText(value=0.001, min=0.0, max=10.0, description='LR:')
optimizer_widget = widgets.Dropdown(options=['sgd', 'momentum', 'adam', 'nadam'], description='Optimizer:')
activation_widget = widgets.Dropdown(options=['relu', 'sigmoid', 'tanh', 'gelu', 'selu'], description='Activation:')
# Comma-separated layer widths, forwarded verbatim to train.py's --hidden-layers.
hidden_layers_widget = widgets.Text(value='256,128', description='Hidden Layers:')
dropout_widget = widgets.BoundedFloatText(value=0.0, min=0.0, max=0.99, step=0.05, description='Dropout:')
batch_norm_widget = widgets.Checkbox(value=False, description='Batch Norm')

run_button = widgets.Button(description='Run Training')
# Receives all printed text and plots produced by run_training.
output = widgets.Output()

# ---------- Function to run train.py ----------
def run_training(b):
    """Run train.py with the current form values, then plot its logged metrics.

    Used as the ``run_button`` click handler. Everything printed or plotted
    goes into the ``output`` widget, which is cleared at the start of each
    run. The training process's combined stdout/stderr is streamed into that
    widget line by line, so progress and errors are visible in the notebook.

    Args:
        b: The clicked button, passed in by ``Button.on_click`` (unused).
    """
    with output:
        clear_output()
        print("Running training...")

        cmd = _build_training_command()

        # Metrics CSV expected from train.py; 'logs' is the --log-dir we pass.
        log_file = pathlib.Path('logs') / f"{dataset_widget.value}_metrics.csv"
        # Delete any previous log so a stale file is never plotted as this run's.
        if log_file.exists():
            log_file.unlink()

        returncode = _run_and_stream(cmd)
        # A non-zero exit means train.py itself failed; say so rather than
        # falling through to the "log file not found" hint, which would
        # point the user at the wrong problem.
        if returncode != 0:
            print(f"Training failed (exit code {returncode}). See the output above.")
            return

        if not log_file.exists():
            print("Log file not found. Make sure CSVLogger is enabled in train.py.")
            return

        df = pd.read_csv(log_file)
        print("Training finished. Displaying graphs...")
        _plot_metrics(df)


def _build_training_command():
    """Return the argv list that runs train.py with the current widget values."""
    return [
        # sys.executable runs train.py with the kernel's own interpreter (and
        # therefore its installed packages); a bare 'python' is resolved via
        # PATH and may belong to a different environment. '-u' disables the
        # child's stdout buffering so progress lines arrive as they are printed.
        sys.executable, '-u', 'train.py',
        '--dataset', dataset_widget.value,
        '--epochs', str(epochs_widget.value),
        '--batch-size', str(batch_size_widget.value),
        '--lr', str(lr_widget.value),
        '--optimizer', optimizer_widget.value,
        '--activation', activation_widget.value,
        '--hidden-layers', hidden_layers_widget.value,
        '--dropout', str(dropout_widget.value),
        # NOTE(review): this sends the literal string 'True'/'False'. Confirm
        # train.py parses it explicitly -- argparse `type=bool` would turn
        # 'False' into True, and a `store_true` flag would reject the value.
        '--batch-norm', str(batch_norm_widget.value),
        '--log-dir', 'logs',
    ]


def _run_and_stream(cmd):
    """Run *cmd*, print its merged stdout/stderr line by line, return the exit code."""
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave errors with normal output
        text=True,
        errors='replace',  # never crash on undecodable progress-bar bytes
        bufsize=1,
    ) as proc:
        for line in proc.stdout:
            print(line, end='')
    # Popen.__exit__ waits for the process, so returncode is populated here.
    return proc.returncode


def _plot_metrics(df):
    """Plot train/val loss and accuracy per epoch in two side-by-side panels.

    Reads the columns ``epoch``, ``train_loss``, ``val_loss``, ``train_acc``
    and ``val_acc`` -- presumably the names train.py's CSV logger writes.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss
    ax_loss.plot(df['epoch'], df['train_loss'], label='Train Loss')
    ax_loss.plot(df['epoch'], df['val_loss'], label='Val Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Loss over Epochs')
    ax_loss.legend()

    # Accuracy
    ax_acc.plot(df['epoch'], df['train_acc'], label='Train Acc')
    ax_acc.plot(df['epoch'], df['val_acc'], label='Val Acc')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Accuracy over Epochs')
    ax_acc.legend()

    plt.show()

# ---------- Hook up the button and render the form ----------
# Clicking "Run Training" invokes run_training, which writes into `output`.
run_button.on_click(run_training)

display(
    widgets.VBox(
        children=[
            # hyperparameter inputs, top to bottom
            dataset_widget, epochs_widget, batch_size_widget, lr_widget,
            optimizer_widget, activation_widget, hidden_layers_widget,
            dropout_widget, batch_norm_widget,
            # action button followed by the area it prints/plots into
            run_button, output,
        ]
    )
)


VBox(children=(Dropdown(description='Dataset:', options=('fashion', 'cifar10'), value='fashion'), IntText(valu…