This project provides a Python implementation of a simple neural network with a visualization of its training process, similar to the JavaScript version. It's designed to help beginners understand the fundamental concepts of neural networks, training, and backpropagation using Python and matplotlib.
- Tkinter GUI: An interactive graphical user interface for controlling the neural network.
- Neural Network Implementation: A basic feedforward neural network with a single hidden layer, using the sigmoid activation function.
- Clustered Training Data: Generates a synthetic dataset with two distinct clusters, making the classification task easier to visualize.
- Real-time Training Visualization:
  - Training Data & Decision Boundary: Displays the data points and how the network's decision boundary evolves during training.
  - Training Metrics Graphs: Tracks and plots Loss (Mean Squared Error) and Accuracy over epochs.
- Configurable Parameters: Adjust the learning rate and number of epochs directly from the GUI.
- Manual Input Prediction: Enter custom X and Y coordinates to get real-time predictions from the trained network and visualize the input point on the data plot.
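To make the clustered data and the two metrics concrete, here is a minimal sketch. It borrows the function names used in `main.py`, but the signatures, cluster centers, and spread are assumptions chosen for illustration, and it uses only the standard library rather than numpy, so the real code may look different.

```python
import random


def generate_training_data(n_per_class=20, spread=0.08, seed=None):
    """Return (points, labels): two Gaussian clusters in the unit square."""
    rng = random.Random(seed)
    centers = {0: (0.25, 0.25), 1: (0.75, 0.75)}  # assumed cluster centers
    points, labels = [], []
    for label, (cx, cy) in centers.items():
        for _ in range(n_per_class):
            points.append((rng.gauss(cx, spread), rng.gauss(cy, spread)))
            labels.append(label)
    return points, labels


def calculate_loss(predictions, labels):
    """Mean squared error between predictions in [0, 1] and 0/1 labels."""
    return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)


def calculate_accuracy(predictions, labels):
    """Fraction of predictions on the correct side of the 0.5 threshold."""
    correct = sum((p >= 0.5) == (y == 1) for p, y in zip(predictions, labels))
    return correct / len(labels)


points, labels = generate_training_data(seed=0)
untrained = [0.5] * len(labels)  # a network that always outputs 0.5
print(calculate_loss(untrained, labels))      # 0.25
print(calculate_accuracy(untrained, labels))  # 0.5
```

A constant output of 0.5 gives a loss of 0.25 and an accuracy of 0.5 on balanced classes, which is a useful baseline when reading the metrics plot: training should push the loss below that value and the accuracy above it.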
The project consists of a single Python file, `main.py`, which contains the entire logic:
- `NeuralNetwork` Class: Implements the neural network architecture, including:
  - Initialization of weights and biases.
  - Sigmoid activation function and its derivative.
  - `feedforward` method for making predictions.
  - `train` method for performing backpropagation and updating weights/biases.
- `generate_training_data` Function: Creates a dataset with two distinct, clustered classes.
- `calculate_loss` and `calculate_accuracy` Functions: Metrics to evaluate the network's performance.
- Tkinter GUI: Sets up the main application window, control buttons (Train, Reset, Predict), and input fields for learning rate, epochs, and manual X/Y coordinates.
- Matplotlib Integration: Uses `matplotlib.backends.backend_tkagg.FigureCanvasTkAgg` to embed `matplotlib` plots directly into the Tkinter GUI.
- Visualization Logic: Uses `matplotlib` to create two subplots:
  - One for visualizing the training data points, the network's decision boundary, and the manual input point.
  - Another for plotting the loss and accuracy over training iterations.
- `FuncAnimation`: Used to animate the training process, updating the plots in real-time within the GUI.
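As a rough illustration of what a class with `feedforward` and `train` methods can look like, here is a minimal sketch of a single-hidden-layer sigmoid network trained by backpropagation on squared error. The layer sizes, weight ranges, attribute names, and method signatures are all assumptions, and the sketch uses plain Python lists instead of numpy, so it should not be read as the code in `main.py`.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class NeuralNetwork:
    """2 inputs -> `hidden` sigmoid units -> 1 sigmoid output (sizes assumed)."""

    def __init__(self, hidden=4, seed=None):
        rng = random.Random(seed)
        self.w_ih = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(hidden)]
        self.b_h = [rng.uniform(-1, 1) for _ in range(hidden)]
        self.w_ho = [rng.uniform(-1, 1) for _ in range(hidden)]
        self.b_o = rng.uniform(-1, 1)

    def feedforward(self, x):
        """Return (hidden activations, output) for one (x, y) input pair."""
        h = [sigmoid(w[0] * x[0] + w[1] * x[1] + b)
             for w, b in zip(self.w_ih, self.b_h)]
        o = sigmoid(sum(w * a for w, a in zip(self.w_ho, h)) + self.b_o)
        return h, o

    def train(self, x, target, lr=0.5):
        """One gradient-descent step on the squared error of a single sample."""
        h, o = self.feedforward(x)
        # Output delta for E = 0.5 * (o - target)^2, using sigmoid'(z) = o * (1 - o).
        delta_o = (o - target) * o * (1 - o)
        for j, a in enumerate(h):
            # The hidden delta uses the output weight before it is updated.
            delta_h = delta_o * self.w_ho[j] * a * (1 - a)
            self.w_ho[j] -= lr * delta_o * a
            self.w_ih[j][0] -= lr * delta_h * x[0]
            self.w_ih[j][1] -= lr * delta_h * x[1]
            self.b_h[j] -= lr * delta_h
        self.b_o -= lr * delta_o


# Four hand-picked points from two clusters (illustrative data only).
data = [((0.2, 0.2), 0), ((0.25, 0.3), 0), ((0.8, 0.8), 1), ((0.7, 0.75), 1)]
net = NeuralNetwork(seed=1)


def mse():
    return sum((net.feedforward(x)[1] - t) ** 2 for x, t in data) / len(data)


before = mse()
for _ in range(2000):  # epochs
    for x, t in data:
        net.train(x, t)
after = mse()
print(f"loss before: {before:.4f}, after: {after:.4f}")
```

Updating after every sample, as done here, is only one option; averaging the gradients over the whole dataset before each update is an equally valid design and may be what the project does.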
To run this project locally, you'll need Python installed, along with `numpy` and `matplotlib`. Tkinter is usually included with Python installations.
- Clone or Download: Get the project files to your local machine.
- Install Dependencies: Open your terminal or command prompt, navigate to the `python_neural_network` directory, and install the required libraries: `pip install numpy matplotlib`
- Run the Script: Execute the `main.py` file: `python main.py`
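If the script fails at startup, a quick way to see which dependency is missing is to ask Python whether each module can be located. The helper below is a hypothetical convenience, not part of the project. It only checks that a module can be found, not that it imports cleanly, so a Tkinter installation without its underlying Tcl/Tk libraries could still pass.

```python
import importlib.util

REQUIRED = ("numpy", "matplotlib", "tkinter")


def missing_modules(names=REQUIRED):
    """Return the names that cannot be located in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are available.")
```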
Once you run `python main.py`:
- A Tkinter GUI window will appear, displaying controls on the left and the visualizations on the right.
- Controls Panel:
  - Train Network: Click to start the neural network training process. The plots will update in real-time.
  - Reset Network: Click to re-initialize the neural network with new random weights and a fresh dataset.
  - Learning Rate: Adjust the learning rate using the input field. Changes take effect after clicking "Train Network" or "Reset Network".
  - Epochs: Set the total number of training epochs.
  - Test Manual Input:
    - Enter `X` and `Y` coordinates.
    - Click "Predict" to see the network's prediction for that specific point.
    - The manual input point will be highlighted in orange on the "Training Data & Decision Boundary" plot.
- Visualization Panels:
  - Training Data & Decision Boundary (Left Plot):
    - Red and blue data points represent the two classes.
    - The shaded background shows the network's current decision boundary (light red for class 0, light green for class 1).
    - Small text next to each point shows the network's prediction.
    - Your manual input point will appear as a large orange circle.
  - Training Metrics (Middle Plot):
    - The red line shows the `Loss` (Mean Squared Error), which should decrease over time.
    - The blue line shows `Accuracy`, which should increase as the network learns to classify correctly.
  - Neural Network Structure (Right Plot):
    - Visualizes the input, hidden, and output layers as circles.
    - Numbers inside nodes represent their activation values.
    - Connections (weights) are shown as lines, colored green for positive and red for negative, with thickness indicating magnitude.
    - Nodes are colored light green when their activation is above 0.5.
- Observe Training: All three plots will update dynamically during training, showing the network's progress and internal state.
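The shaded decision boundary and the manual prediction both come down to evaluating the network at chosen coordinates and thresholding the output at 0.5. The sketch below shows that idea on a coarse grid. The `predict` function is a hypothetical stand-in for a trained network (a fixed logistic model), and `classify_grid` is an illustrative helper, not a function from `main.py`.

```python
import math


def predict(x, y):
    """Stand-in for a trained network: a fixed logistic model (hypothetical)."""
    return 1.0 / (1.0 + math.exp(-(8.0 * (x + y) - 8.0)))


def classify_grid(predict_fn, steps=5, lo=0.0, hi=1.0):
    """Evaluate predict_fn on a steps x steps grid and threshold at 0.5."""
    coords = [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]
    # Rows run from high y to low y so the printout matches a plot's orientation.
    return [[1 if predict_fn(x, y) >= 0.5 else 0 for x in coords]
            for y in reversed(coords)]


# Coarse text version of the shaded background: 0 = class 0, 1 = class 1.
for row in classify_grid(predict):
    print(" ".join(str(c) for c in row))

# Manual input: one point, one prediction, one class label.
x, y = 0.9, 0.8
p = predict(x, y)
print(f"({x}, {y}) -> {p:.3f} -> class {1 if p >= 0.5 else 0}")
```

In a plotting context, the same grid of outputs could be passed to a filled-contour call such as matplotlib's `contourf` to produce the shaded regions; a finer grid gives a smoother boundary at the cost of more forward passes per frame.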
This Python version with a Tkinter GUI provides a fully interactive and self-contained application for exploring neural network fundamentals.