# In [2]:
# %% [markdown]
# ## GUI for Quantum Harmonic Oscillator
#
# This cell builds a Tkinter GUI for training and querying a neural network that predicts
# the ground-state energy E_0 = ħω/2 of a quantum harmonic oscillator from its mass and
# angular frequency ω.

# Enable Tkinter event loop in Jupyter
%gui tk

# Imports
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import sympy as sp
import tkinter as tk
from tkinter import filedialog, messagebox, scrolledtext
import os

# Seed NumPy's global RNG (synthetic data generation, validation sample picks)
# and TensorFlow's global seed (used by Keras while building/training models)
# so that repeated runs of this cell behave the same way.
for _seed_fn in (np.random.seed, tf.random.set_seed):
    _seed_fn(42)
del _seed_fn

# Shared state written by the GUI callbacks below; empty until a dataset is loaded.
model = history = None                  # trained Keras model and its fit History
scaler = None                           # StandardScaler fitted on the training features
X_train_scaled = X_test_scaled = None   # scaled feature matrices for train / test
X_test = None                           # unscaled test features, read by validate_analytical
y_train = y_test = None                 # target vectors for train / test
dataset_loaded = False                  # set True by load_dataset once preprocessing succeeds

# Functions
def generate_dataset(n_samples=10000,
                     dataset_path='quantum_harmonic_oscillator_data.csv',
                     hbar=1.0):
    """Write a synthetic quantum-harmonic-oscillator dataset to a CSV file.

    Masses are drawn uniformly from [0.1, 2.0) and angular frequencies from
    [0.5, 5.0) using NumPy's global RNG, in that order. The target column is
    the analytical ground-state energy E_0 = hbar * omega / 2. Mass does not
    enter E_0; it is still written as a feature column.

    Args:
        n_samples: Number of rows to generate.
        dataset_path: Destination CSV path. An existing file is overwritten.
        hbar: Reduced Planck constant in the chosen unit system.

    Returns:
        The path the CSV was written to (``dataset_path``).
    """
    mass = np.random.uniform(0.1, 2.0, n_samples)
    frequency = np.random.uniform(0.5, 5.0, n_samples)
    ground_state_energy = 0.5 * hbar * frequency
    df = pd.DataFrame({
        'mass': mass,
        'frequency': frequency,
        'ground_state_energy': ground_state_energy,
    })
    df.to_csv(dataset_path, index=False)
    return dataset_path

def load_dataset():
    """Ask the user for a CSV file, then split and scale it for training.

    The CSV must contain ``mass``, ``frequency`` and ``ground_state_energy``
    columns. Values are coerced to float (non-numeric entries cause an error
    that is logged), missing values are filled with their column mean, the
    rows are split 80/20 into train/test sets (``random_state=42``), and a
    ``StandardScaler`` is fitted on the training features only.

    On success every dataset global is replaced together, and any previously
    trained ``model``/``history`` is discarded. That model was fitted against
    a different scaler and split, so pairing it with the new data would give
    meaningless predictions. On any failure the existing globals are left
    untouched and the error is written to the log.
    """
    global model, scaler, X_train_scaled, X_test_scaled, y_train, y_test
    global dataset_loaded, X_test, history

    def log(message):
        log_text.insert(tk.END, message)
        log_text.see(tk.END)  # keep the newest message visible

    try:
        file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
        if not file_path:
            return  # dialog was cancelled
        if not os.path.exists(file_path):
            # Defensive fallback; an open-file dialog normally returns only existing paths.
            if messagebox.askyesno("Dataset Not Found", "Generate dataset?"):
                file_path = generate_dataset()
            else:
                return
        df = pd.read_csv(file_path)
        log("Dataset Preview:\n" + str(df.head()) + "\n\n")

        feature_columns = ['mass', 'frequency']
        target_column = 'ground_state_energy'
        if not all(col in df.columns for col in feature_columns + [target_column]):
            log("Error: Required columns missing.\n")
            return

        # Coerce to float first so non-numeric data raises here instead of
        # silently bypassing NaN handling; then impute per-column means.
        features = df[feature_columns].astype(float)
        target = df[target_column].astype(float)
        X = features.fillna(features.mean()).to_numpy()
        y = target.fillna(target.mean()).to_numpy()
        # A column that is entirely NaN has a NaN mean and stays NaN; infinities
        # would wreck the scaler. Reject both rather than train on them.
        if not (np.isfinite(X).all() and np.isfinite(y).all()):
            log("Error: Dataset contains infinite values or all-missing columns.\n")
            return

        X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
        new_scaler = StandardScaler()
        X_tr_scaled = new_scaler.fit_transform(X_tr)  # fit on training data only
        X_te_scaled = new_scaler.transform(X_te)
    except Exception as e:
        log(f"Error loading dataset: {str(e)}\n")
        return

    # Commit all state at once so a failure above never leaves old and new data mixed.
    scaler = new_scaler
    X_train_scaled, X_test_scaled = X_tr_scaled, X_te_scaled
    y_train, y_test = y_tr, y_te
    X_test = X_te
    model = None    # stale: trained against the previous scaler/split
    history = None
    dataset_loaded = True
    log(f"Dataset loaded: {X_train_scaled.shape[0]} training samples, "
        f"{X_test_scaled.shape[0]} test samples.\n")

def train_model():
    """Build, train and evaluate the feed-forward regression network.

    Requires ``load_dataset`` to have run. It trains a 64-32-16-1 ReLU MLP
    with Adam and MSE loss for up to 100 epochs. 20% of the training data is
    held out for validation, with early stopping (patience 10) that restores
    the best weights. It then logs MSE and R² on the test split.

    The fit runs directly in this Tk callback, so the window does not respond
    until training finishes. The new model and its history replace the global
    ``model``/``history`` only after training and evaluation succeed. A failed
    run leaves any previous model untouched instead of exposing a half-trained
    one to the plot/predict buttons.
    """
    global model, history

    def log(message):
        log_text.insert(tk.END, message)
        log_text.see(tk.END)  # keep the newest message visible

    if not dataset_loaded:
        log("Error: Load dataset first.\n")
        return
    try:
        new_model = Sequential([
            # Input width follows the loaded feature matrix instead of a hard-coded 2.
            Input(shape=(X_train_scaled.shape[1],)),
            Dense(64, activation='relu'),
            Dense(32, activation='relu'),
            Dense(16, activation='relu'),
            Dense(1, activation='linear'),
        ])
        new_model.compile(optimizer='adam', loss='mse')

        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=10, restore_best_weights=True
        )

        log("Training started...\n")
        root.update()  # repaint so the message shows before the blocking fit
        new_history = new_model.fit(
            X_train_scaled, y_train,
            validation_split=0.2,
            epochs=100,
            batch_size=32,
            callbacks=[early_stopping],
            verbose=0,
        )

        y_pred = new_model.predict(X_test_scaled, verbose=0).flatten()
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
    except Exception as e:
        log(f"Error training model: {str(e)}\n")
        return

    model, history = new_model, new_history
    log(f"Training complete.\n"
        f"Mean Squared Error (MSE): {mse:.6f}\n"
        f"R² Score: {r2:.6f}\n")

def plot_predictions():
    """Show a scatter of test-set predictions against the true energies.

    If no model or test data is available, an error is written to the log
    instead. The dashed red diagonal (over the range of the actual values)
    marks perfect agreement. Any plotting error is caught and logged.
    """
    no_model = model is None
    no_targets = y_test is None
    if no_model or no_targets:
        log_text.insert(tk.END, "Error: Train model first.\n")
        return
    try:
        predicted = model.predict(X_test_scaled, verbose=0).flatten()
        plt.figure(figsize=(8, 6))
        plt.scatter(y_test, predicted, alpha=0.5, color='blue',
                    label='Predicted vs. Actual')
        # Identity line spanning the observed energy range.
        lo = y_test.min()
        hi = y_test.max()
        plt.plot([lo, hi], [lo, hi], 'r--', label='Ideal Fit')
        plt.xlabel('Actual Ground State Energy')
        plt.ylabel('Predicted Ground State Energy')
        plt.title('Predicted vs. Actual Ground State Energies')
        plt.legend()
        plt.grid(True)
        plt.show()
    except Exception as err:
        log_text.insert(tk.END, f"Error plotting predictions: {str(err)}\n")

def plot_loss():
    """Plot per-epoch training and validation MSE from the most recent fit.

    If no training history exists yet, an error is written to the log
    instead. Any plotting error is caught and logged.
    """
    if history is None:
        log_text.insert(tk.END, "Error: Train model first.\n")
        return
    try:
        curves = history.history
        plt.figure(figsize=(8, 6))
        # Draw training first, then validation, so legend order is stable.
        for key, caption in (('loss', 'Training Loss'),
                             ('val_loss', 'Validation Loss')):
            plt.plot(curves[key], label=caption)
        plt.xlabel('Epoch')
        plt.ylabel('Mean Squared Error Loss')
        plt.title('Training and Validation Loss Over Epochs')
        plt.legend()
        plt.grid(True)
        plt.show()
    except Exception as err:
        log_text.insert(tk.END, f"Error plotting loss: {str(err)}\n")

def predict_energy():
    """Predict E_0 for the mass/frequency typed into the GUI and compare with theory.

    Reads the two entry fields and rejects values that are not numeric, not
    finite, or not positive. It runs the trained model on the scaled input
    and logs the prediction next to the analytical ground-state energy
    E_0 = hbar * omega / 2 (with hbar = 1) and their absolute difference.
    Mass is passed to the model but does not appear in E_0.
    """
    if model is None or scaler is None:
        log_text.insert(tk.END, "Error: Train model first.\n")
        return
    try:
        mass = float(entry_mass.get())
        frequency = float(entry_frequency.get())
        # float() accepts "nan"/"inf", and NaN slips past a plain "<= 0" test,
        # so require finite values explicitly.
        if not (np.isfinite(mass) and np.isfinite(frequency)) or mass <= 0 or frequency <= 0:
            log_text.insert(tk.END, "Error: Mass and frequency must be positive, finite numbers.\n")
            return

        input_scaled = scaler.transform(np.array([[mass, frequency]]))
        pred_energy = float(model.predict(input_scaled, verbose=0)[0][0])

        # Analytical energy E_0 = hbar * omega / 2. Convert the sympy result to a
        # plain float so formatting and arithmetic with the prediction are well-defined.
        omega_sym, hbar_sym = sp.symbols('omega hbar')
        E_0_sym = (hbar_sym / 2) * omega_sym
        analytical_energy = float(E_0_sym.subs({hbar_sym: 1.0, omega_sym: frequency}))

        log_text.insert(tk.END, f"Prediction for mass = {mass:.2f}, frequency = {frequency:.2f}:\n"
                        f"  Predicted E_0 = {pred_energy:.6f}\n"
                        f"  Analytical E_0 = {analytical_energy:.6f}\n"
                        f"  Absolute Error = {abs(analytical_energy - pred_energy):.6f}\n")
    except ValueError:
        log_text.insert(tk.END, "Error: Invalid input. Enter numeric values.\n")
    except Exception as e:
        log_text.insert(tk.END, f"Error predicting energy: {str(e)}\n")
    log_text.see(tk.END)  # keep the newest message visible

def validate_analytical(n_cases=3):
    """Compare model predictions with E_0 = hbar * omega / 2 on random test rows.

    Draws up to ``n_cases`` distinct rows from the test split using NumPy's
    global RNG. If the split has fewer rows than requested, all of them are
    used. For each row it logs the analytical energy (hbar = 1), the model's
    prediction and their absolute difference.

    Args:
        n_cases: Maximum number of test samples to report (default 3).
    """
    if model is None or X_test_scaled is None:
        log_text.insert(tk.END, "Error: Train model first.\n")
        return
    try:
        omega_sym, hbar_sym = sp.symbols('omega hbar')
        E_0_sym = (hbar_sym / 2) * omega_sym

        n_available = X_test_scaled.shape[0]
        # choice(..., replace=False) raises if asked for more rows than exist.
        n_draw = min(n_cases, n_available)
        if n_draw <= 0:
            log_text.insert(tk.END, "Error: No test samples available for validation.\n")
            return
        test_indices = np.random.choice(n_available, n_draw, replace=False)
        test_cases = X_test[test_indices]
        # One batched predict call instead of one per sample.
        predictions = model.predict(X_test_scaled[test_indices], verbose=0).flatten()

        log_text.insert(tk.END, "\nAnalytical vs. Predicted Energies (Sampled Test Cases):\n")
        for i, (case, E_0_pred) in enumerate(zip(test_cases, predictions), start=1):
            mass, omega = case[0], case[1]
            E_0_analytical = float(E_0_sym.subs({hbar_sym: 1.0, omega_sym: omega}))
            E_0_pred = float(E_0_pred)

            log_text.insert(tk.END, f"Test Case {i} (mass = {mass:.2f}, ω = {omega:.2f}):\n"
                            f"  Analytical E_0 = {E_0_analytical:.6f}\n"
                            f"  Predicted E_0 = {E_0_pred:.6f}\n"
                            f"  Absolute Error = {abs(E_0_analytical - E_0_pred):.6f}\n")
    except Exception as e:
        log_text.insert(tk.END, f"Error in analytical validation: {str(e)}\n")
    log_text.see(tk.END)  # keep the newest message visible

# ---- Main window -----------------------------------------------------------
# Rows are packed top-to-bottom in the order they are created below.
root = tk.Tk()
root.title("Quantum Harmonic Oscillator Energy Predictor")
root.geometry("600x500")


def _new_row():
    """Create a Frame in the main window, pack it with vertical padding, and return it."""
    row = tk.Frame(root)
    row.pack(pady=10)
    return row


# Row 1: dataset loading
frame_dataset = _new_row()
tk.Button(frame_dataset, text="Load Dataset", command=load_dataset).pack()

# Row 2: training
frame_train = _new_row()
tk.Button(frame_train, text="Train Model", command=train_model).pack()

# Row 3: plotting buttons, side by side
frame_plot = _new_row()
for _caption, _callback in (("Plot Predictions", plot_predictions),
                            ("Plot Loss", plot_loss)):
    tk.Button(frame_plot, text=_caption, command=_callback).pack(side=tk.LEFT, padx=5)
del _caption, _callback

# Row 4: manual prediction (labels and entries laid out left to right)
frame_predict = _new_row()
tk.Label(frame_predict, text="Mass:").pack(side=tk.LEFT)
entry_mass = tk.Entry(frame_predict, width=10)
entry_mass.pack(side=tk.LEFT, padx=5)
tk.Label(frame_predict, text="Frequency:").pack(side=tk.LEFT)
entry_frequency = tk.Entry(frame_predict, width=10)
entry_frequency.pack(side=tk.LEFT, padx=5)
tk.Button(frame_predict, text="Predict Energy", command=predict_energy).pack(side=tk.LEFT, padx=5)

# Row 5: analytical validation
frame_validate = _new_row()
tk.Button(frame_validate, text="Analytical Validation", command=validate_analytical).pack()

# Scrollable log shared by every callback
log_text = scrolledtext.ScrolledText(root, width=60, height=15)
log_text.pack(pady=10)

# No mainloop(): the `%gui tk` magic lets Jupyter drive the Tk event loop,
# so a single update() is enough to draw the window.
root.update()