This project implements a simple convolutional neural network (CNN) in PyTorch for handwritten digit recognition. It demonstrates the basic components of a CNN: convolution, pooling, activation functions, and fully connected layers.
.
├── cnn_utils/ # Utility functions for model, data processing, visualization, etc.
├── data/ # Training and test datasets
├── visualization/ # Visualization examples (training curves, predictions, etc.)
├── digi_class.py # Main entry point
├── requirements.txt # Python dependencies
├── venv.sh # Virtual environment setup script
└── README.md
- Basic CNN layers implemented with PyTorch
- Forward and backward propagation for training
- Simple image classification (e.g., handwritten digits)
- Early stopping and dropout to reduce overfitting
- Fine-tuning and resuming training from checkpoints (with model saving)
- Automatic detection of CUDA GPU acceleration with CPU fallback
- Visualization tools for training history curves and predictions
- Modular design for easy extension
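As an illustration of the early stopping feature listed above, here is a minimal, framework-free sketch of a patience-based stopper. It is not taken from this repository: the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the validation losses are all hypothetical, chosen only to show the idea.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs.

    Hypothetical helper for illustration; not the repository's implementation.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, one per epoch, for illustration only.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.50]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch  # three epochs in a row without improvement
        break

print(f"stopped at epoch {stopped_at}, best validation loss {stopper.best_loss}")
```

In a real training loop, a model checkpoint would typically be saved whenever `best_loss` improves, so that the best weights can be restored after stopping.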
Clone the repository and install dependencies:
git clone https://github.com/swangarch/convolutional_neural_network.git
cd convolutional_neural_network
# Create virtual environment and install dependencies
bash venv.sh
source venv/bin/activate
Usage:
Visualize training data:
python digi_class.py -v <train_data_csv>
Train a model:
python digi_class.py -t <train_data_csv> [weights.pth]
Predict with a model:
python digi_class.py -p <test_data_csv> [weights.pth]
Notes:
- Training data must be a CSV file where the first column is the label, followed by 784 columns representing a 28x28 image.
- Test data must be a CSV file with 784 columns only (no label column).
- If weights are provided during training, the process will perform fine-tuning.
- If weights are provided during prediction, the trained model will be used. Otherwise, an untrained model will be applied (not recommended).
- The program will automatically check for CUDA support; if unavailable, it will fall back to CPU execution.
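The CSV layouts described in the notes above can be checked with a short standard-library sketch. It assumes the files have no header row and that pixel values are plain numbers; `parse_row` and `load_rows` are hypothetical helpers, not functions from this repository.

```python
import csv
import io

IMAGE_SIZE = 28
NUM_PIXELS = IMAGE_SIZE * IMAGE_SIZE  # 784 pixel columns per image


def parse_row(row, has_label):
    """Split one CSV row into (label, 28x28 pixel grid).

    Training rows carry the label in the first column; test rows have
    pixels only, so their label is returned as None.
    """
    expected = NUM_PIXELS + 1 if has_label else NUM_PIXELS
    if len(row) != expected:
        raise ValueError(f"expected {expected} columns, got {len(row)}")
    label = int(row[0]) if has_label else None
    values = row[1:] if has_label else row
    pixels = [float(v) for v in values]
    # Reshape the flat list of 784 values into 28 rows of 28 pixels.
    image = [pixels[r * IMAGE_SIZE:(r + 1) * IMAGE_SIZE] for r in range(IMAGE_SIZE)]
    return label, image


def load_rows(text, has_label):
    """Parse CSV text (assumed to have no header row), skipping blank lines."""
    return [parse_row(row, has_label) for row in csv.reader(io.StringIO(text)) if row]


# Synthetic one-row examples: an all-black image labelled 7, and an unlabelled one.
train_text = ",".join(["7"] + ["0"] * NUM_PIXELS) + "\n"
test_text = ",".join(["0"] * NUM_PIXELS) + "\n"

train_label, train_image = load_rows(train_text, has_label=True)[0]
test_label, test_image = load_rows(test_text, has_label=False)[0]
print(train_label, len(train_image), len(train_image[0]), test_label)
```

If your CSV files do include a header row, skip the first line before calling `parse_row`.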
After training:
- The program will display training and validation accuracy curves (so you can visualize model performance over epochs).
After prediction, the program will:
- Save the prediction results to a file named predictions.csv.
- Show example images (sampled from the test set) along with their predicted labels.
98% accuracy on the MNIST dataset
Prediction example:
88% accuracy on the Fashion MNIST dataset, which uses the following class labels:
- 0 T-shirt/top
- 1 Trouser
- 2 Pullover
- 3 Dress
- 4 Coat
- 5 Sandal
- 6 Shirt
- 7 Sneaker
- 8 Bag
- 9 Ankle boot
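When reading predictions for Fashion MNIST, it can help to map the numeric labels above to their class names. The sketch below is one way to do this; `FASHION_MNIST_CLASSES` and `class_name` are hypothetical names, not part of this repository.

```python
# Class names indexed by label, in the order listed above.
FASHION_MNIST_CLASSES = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]


def class_name(label):
    """Return the Fashion MNIST class name for an integer label in 0..9."""
    if not 0 <= label < len(FASHION_MNIST_CLASSES):
        raise ValueError(f"label must be in 0..9, got {label}")
    return FASHION_MNIST_CLASSES[label]


# Made-up predicted labels, for illustration only.
predicted = [9, 0, 6]
names = [class_name(label) for label in predicted]
print(names)
```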
Prediction example: