This project implements a simple Convolutional Neural Network (CNN) framework with PyTorch for handwritten digit recognition.


swangarch/convolutional_neural_network


Convolutional Neural Network (CNN)

This project implements a simple Convolutional Neural Network (CNN) with PyTorch for handwritten digit recognition. It demonstrates the basic building blocks of CNNs: convolution, pooling, activation functions, and fully connected layers.
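
The input size of the first fully connected layer follows from how each convolution and pooling step shrinks the 28x28 image. Below is a minimal pure-Python sketch of that arithmetic. It assumes a hypothetical two-block stack (3x3 convolutions with padding 1, 2x2 max pooling with stride 2), not the repository's actual architecture. The size formula is the one PyTorch documents for `nn.Conv2d` and `nn.MaxPool2d` with dilation 1. Activation functions such as ReLU are omitted because they do not change the shape.

```python
# Hypothetical two-block layer stack, for illustration only; the architecture in
# cnn_utils/ may use different channel counts, kernels, or depths.
LAYERS = [
    # (name, out_channels or None for pooling, kernel, stride, padding)
    ("conv1", 32, 3, 1, 1),
    ("pool1", None, 2, 2, 0),
    ("conv2", 64, 3, 1, 1),
    ("pool2", None, 2, 2, 0),
]


def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a convolution or pooling window (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(size=28, channels=1, layers=LAYERS):
    """Return (name, channels, height, width) after each layer for a square input."""
    shapes = []
    for name, out_channels, kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        if out_channels is not None:  # pooling keeps the channel count
            channels = out_channels
        shapes.append((name, channels, size, size))
    return shapes


def flattened_size(shapes):
    """Number of inputs the first fully connected layer needs."""
    _, channels, height, width = shapes[-1]
    return channels * height * width


for name, c, h, w in trace_shapes():
    print(f"{name}: {c} x {h} x {w}")
print("flattened:", flattened_size(trace_shapes()))
```

For this assumed stack, the 28x28 input ends as 64 x 7 x 7, so the first fully connected layer would take 3,136 inputs.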


Project Structure

.
├── cnn_utils/ # Utility functions for model, data processing, visualization, etc.
├── data/ # Training and test datasets
├── visualization/ # Visualization examples (training curves, predictions, etc.)
├── digi_class.py # Main entry point
├── requirements.txt # Python dependencies
├── venv.sh # Virtual environment setup script
└── README.md

Features

  • Basic CNN layers implemented with PyTorch
  • Forward and backward propagation for training
  • Simple image classification (e.g., handwritten digits)
  • Early stopping and dropout to reduce overfitting
  • Support for fine-tuning and for resuming training from saved checkpoints
  • Automatic CUDA GPU detection with CPU fallback
  • Visualization tools for training history curves and predictions
  • Modular design for easy extension
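
Early stopping is commonly driven by validation loss with a patience counter. Below is a minimal, framework-agnostic sketch, assuming patience-based stopping on validation loss. The class name `EarlyStopping` and its `patience`/`min_delta` parameters are hypothetical and not taken from `cnn_utils/`.

```python
class EarlyStopping:
    """Patience-based early stopping on validation loss (hypothetical helper)."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember it (a real loop would save a checkpoint here).
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopping(patience=2)
    for epoch, loss in enumerate([0.50, 0.40, 0.35, 0.36, 0.37, 0.30]):
        if stopper.step(epoch, loss):
            print(f"stop at epoch {epoch}; best epoch {stopper.best_epoch} (loss {stopper.best_loss})")
            break
```

With `patience=2`, training stops at epoch 4, after two epochs without improvement. The best weights are the ones from epoch 2, so the later improvement at epoch 5 is never seen. That trade-off is why patience is usually tuned per dataset.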

Installation

Clone the repository and install dependencies:

git clone https://github.com/swangarch/convolutional_neural_network.git
cd convolutional_neural_network

# Create virtual environment and install dependencies
bash venv.sh

source venv/bin/activate

Usage


Visualize training data:

python digi_class.py -v <train_data_csv>

Train a model:

python digi_class.py -t <train_data_csv> [weights.pth]

Predict with a model:

python digi_class.py -p <test_data_csv> [weights.pth]

Notes:

  • Training data must be a CSV file whose first column is the label, followed by 784 pixel columns representing a flattened 28x28 image.
  • Test data must be a CSV file with only the 784 pixel columns (no label column).
  • If weights are provided during training, the model is initialized from them and fine-tuned.
  • If weights are provided during prediction, the model loads them; otherwise, an untrained model is used (not recommended).
  • The program automatically checks for CUDA support and falls back to CPU execution if it is unavailable.
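
The CSV layout above maps directly onto a 28x28 grid. Below is a minimal standard-library sketch of loading it. It assumes integer pixel values in 0..255 (scaled to [0, 1]), row-by-row flattening, and an optional header row. `parse_row` and `load_csv` are hypothetical helpers, not functions from `cnn_utils/`.

```python
import csv
import io

IMAGE_SIDE = 28
NUM_PIXELS = IMAGE_SIDE * IMAGE_SIDE  # 784 columns per image


def parse_row(row, has_label):
    """Split one CSV row into (label or None, 28x28 grid of floats in [0, 1])."""
    expected = NUM_PIXELS + (1 if has_label else 0)
    if len(row) != expected:
        raise ValueError(f"expected {expected} columns, got {len(row)}")
    label = int(row[0]) if has_label else None
    values = row[1:] if has_label else row
    # Assumption: pixel intensities are integers in 0..255, scaled to [0, 1].
    pixels = [int(v) / 255.0 for v in values]
    # Assumption: the 784 values are flattened row by row (row-major).
    grid = [pixels[r * IMAGE_SIDE:(r + 1) * IMAGE_SIDE] for r in range(IMAGE_SIDE)]
    return label, grid


def load_csv(f, has_label):
    """Parse an open CSV file; skip a header row if its first cell is not numeric."""
    rows = [row for row in csv.reader(f) if row]  # ignore blank lines
    if rows and not rows[0][0].strip().isdigit():
        rows = rows[1:]  # assumption: optional header such as "label,pixel0,..."
    return [parse_row(row, has_label) for row in rows]


# Example with an in-memory file: one labelled row whose last pixel is white.
sample = "label," + ",".join(f"pixel{i}" for i in range(NUM_PIXELS)) + "\n"
sample += ",".join(["7"] + ["0"] * (NUM_PIXELS - 1) + ["255"]) + "\n"
label, grid = load_csv(io.StringIO(sample), has_label=True)[0]
print(label, len(grid), len(grid[0]), grid[27][27])  # 7 28 28 1.0
```

For a file on disk, open it with `open(path, newline="")` and pass `has_label=False` for test data.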

Visualization

After training:

  • The program displays training and validation accuracy curves, along with a loss curve, so you can track model performance over epochs.

[Figures: loss curve, accuracy curve]
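
The accuracy plotted for each epoch is, in the usual formulation, the fraction of samples whose highest-scoring class matches the label. Below is a minimal pure-Python sketch of that computation. It is an assumed formulation, not code taken from the repository. In PyTorch the same idea is usually written as `(logits.argmax(dim=1) == labels).float().mean()`.

```python
def argmax(scores):
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(batch_scores, labels):
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(batch_scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = sum(argmax(s) == y for s, y in zip(batch_scores, labels))
    return correct / len(labels)


# Three samples, two classes: the first two are classified correctly.
scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
print(accuracy(scores, [1, 0, 0]))
```

Computing this once on the training set and once on the validation set after each epoch gives the two series that an accuracy curve plots.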

After prediction, the program:

  • Saves the prediction results to a file named predictions.csv.
  • Shows example images sampled from the test set, along with their predicted labels.


MNIST handwritten digit recognition dataset

98% accuracy on the MNIST dataset

Prediction example:

[Figure: data visualization demo]


Fashion MNIST classification dataset

88% accuracy on the Fashion MNIST dataset

Class labels:

  • 0 T-shirt/top
  • 1 Trouser
  • 2 Pullover
  • 3 Dress
  • 4 Coat
  • 5 Sandal
  • 6 Shirt
  • 7 Sneaker
  • 8 Bag
  • 9 Ankle boot
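
When reading Fashion MNIST predictions, the numeric labels can be decoded with the table above. Below is a small sketch. The `FASHION_MNIST_CLASSES` tuple and the `class_name` helper are hypothetical and not part of the repository.

```python
# Class names indexed by numeric label, in the order listed above.
FASHION_MNIST_CLASSES = (
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
)


def class_name(label):
    """Map a numeric Fashion MNIST label (0-9) to its class name."""
    if not 0 <= label < len(FASHION_MNIST_CLASSES):
        raise ValueError(f"label must be in 0..9, got {label}")
    return FASHION_MNIST_CLASSES[label]


# Example: turn a list of predicted labels into readable names.
predicted = [9, 0, 7]
print([class_name(p) for p in predicted])  # ['Ankle boot', 'T-shirt/top', 'Sneaker']
```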

Prediction example:

[Figure: data visualization demo]
