
Experiments with Neural Ordinary Differential Equations on image and text classification tasks

saparina/neural-ode

For image classification we use a ResNet model on the MNIST and CIFAR-10 datasets; for text classification we use a VdCNN model on the Ag-News dataset.
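
The experiments compare stacks of residual blocks with an ODE block. The link between the two shows up in a toy setting: a residual update h + f(h) is one explicit Euler step of dh/dt = f(h, t) (with the step size absorbed into f), so more, smaller steps approach the continuous solution. The sketch below is plain Python with a scalar state and a hypothetical hand-picked f; it is an illustration, not code from this repository.

```python
import math

def f(h, t):
    # Hypothetical dynamics standing in for a learned conv block: dh/dt = -h.
    return -h

def residual_stack(h, num_blocks):
    # Each "block" computes h <- h + dt * f(h, t). With dt = 1 / num_blocks the
    # stack is explicit Euler on t in [0, 1]. A real ResNet gives every block its
    # own weights; here every block shares one f, as an ODE block does.
    dt = 1.0 / num_blocks
    for k in range(num_blocks):
        h = h + dt * f(h, k * dt)
    return h

exact = math.exp(-1.0)  # exact h(1) for h(0) = 1
for n in (1, 6, 100):
    approx = residual_stack(1.0, n)
    print(f"{n:4d} blocks/steps: h(1) ~ {approx:.4f}, error {abs(approx - exact):.4f}")
```

The error shrinks as the number of steps grows. An ODE block gets this extra "depth" by taking more solver steps with the same f, which costs time rather than parameters.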

Requirements

  • PyTorch >= 1.0
  • NumPy
  • TensorFlow==1.13.1

Spiral experiment

Run the ODE spiral experiment (PyTorch or TensorFlow version):

PYTHONPATH=. python ./experiments/spiral-torch.py # pytorch
python -m experiments.spiral_tf # tensorflow
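
For orientation, spiral data of this kind can be generated by integrating a nonlinear 2-D ODE. The sketch below assumes the dynamics of the torchdiffeq demo, dy/dt = (y**3) A with A = [[-0.1, 2.0], [-2.0, -0.1]] and y(0) = (2, 0) on t in [0, 25]; whether experiments/spiral-torch.py uses exactly these values is an assumption, so check the script. It uses a fixed-step RK4 integrator in plain Python, and `make_spiral` and its parameters are hypothetical names.

```python
# Assumed spiral dynamics (from the torchdiffeq demo, not verified for this repo)
A = [[-0.1, 2.0], [-2.0, -0.1]]

def dynamics(y):
    # dy/dt = (y**3) @ A, treating y as a row vector
    c0, c1 = y[0] ** 3, y[1] ** 3
    return (c0 * A[0][0] + c1 * A[1][0], c0 * A[0][1] + c1 * A[1][1])

def rk4_step(y, dt):
    # One classical fourth-order Runge-Kutta step
    k1 = dynamics(y)
    k2 = dynamics((y[0] + dt / 2 * k1[0], y[1] + dt / 2 * k1[1]))
    k3 = dynamics((y[0] + dt / 2 * k2[0], y[1] + dt / 2 * k2[1]))
    k4 = dynamics((y[0] + dt * k3[0], y[1] + dt * k3[1]))
    return tuple(y[i] + dt / 6 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i]) for i in range(2))

def make_spiral(y0=(2.0, 0.0), t_end=25.0, n_points=1000, substeps=10):
    # Save n_points evenly spaced states, taking `substeps` RK4 steps between them
    dt = t_end / (n_points - 1) / substeps
    traj, y = [y0], y0
    for _ in range(n_points - 1):
        for _ in range(substeps):
            y = rk4_step(y, dt)
        traj.append(y)
    return traj

points = make_spiral()
print(len(points), points[0], points[-1])
```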

Result

[Figure: spiral]

MNIST classification

Run ResNet with 6 blocks:

PYTHONPATH=. python ./experiments/train.py  --data mnist --save ./log_resnet6_mnist --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 6

Run ResNet with 1 block:

PYTHONPATH=. python ./experiments/train.py  --data mnist --save ./log_resnet1_mnist --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 1

Run OdeNet with the explicit Runge-Kutta solver and tolerance 1e-2:

PYTHONPATH=. python ./experiments/train.py  --data mnist --save ./log_odenet_mnist --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --solver runge_kutta  --tol 1e-2 --use_ode

* Alternatively, use the explicit Euler solver: --solver euler
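
The two solver options differ in cost per step. Explicit Euler evaluates the dynamics once per step, while the classical fourth-order Runge-Kutta method evaluates it four times per step but is far more accurate for the same step size. When the dynamics is a neural network, every evaluation is a forward pass, which likely accounts for much of the timing gap in the results tables. The sketch below uses fixed steps on the toy ODE dy/dt = -y; since the repository's solvers also take --tol, its Runge-Kutta variant may be adaptive. All names here are illustrative, not the repository's API.

```python
import math

class CountingODE:
    """Wraps dy/dt = f(t, y) and counts function evaluations (NFE)."""
    def __init__(self, f):
        self.f = f
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        return self.f(t, y)

def euler(f, y, t0, t1, steps):
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        y = y + h * f(t, y)                       # one evaluation per step
        t += h
    return y

def rk4(f, y, t0, t1, steps):
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(t, y)                              # four evaluations per step
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y

exact = math.exp(-1.0)
for solver in (euler, rk4):
    ode = CountingODE(lambda t, y: -y)
    y1 = solver(ode, 1.0, 0.0, 1.0, steps=10)
    print(f"{solver.__name__}: {ode.nfe} evaluations, error {abs(y1 - exact):.1e}")
```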

Results

[Figures: Test Accuracy (mnist_score), Loss (mnist_loss)]
Model                   Test error, %   # parameters   Time (s/epoch)
ResNet(6)               0.34            577 K          13.18
ResNet(1)               0.37            207 K          11.21
OdeNet (Runge-Kutta)    0.45            207 K          254.42

CIFAR-10 classification

Run ResNet with 6 blocks:

PYTHONPATH=. python ./experiments/train.py  --data cifar --save ./log_resnet6_cifar --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 6

Run ResNet with 1 block:

PYTHONPATH=. python ./experiments/train.py  --data cifar --save ./log_resnet1_cifar --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 1

Run OdeNet with the explicit Runge-Kutta solver and tolerance 1e-2 (slow: about 1860 s per epoch in the run reported below):

PYTHONPATH=. python ./experiments/train.py  --data cifar --save ./log_odenet_cifar --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --use_ode --solver runge_kutta  --tol 1e-2 
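
The --tol flag presumably sets an error tolerance for the solver; how the repository applies it is not shown here. In adaptive solvers generally, the tolerance bounds an estimate of the local error and the step size shrinks until the estimate falls below it, so a looser tolerance such as 1e-2 means fewer, larger steps and faster training. The sketch below shows that mechanism with an embedded Heun-Euler pair, a textbook method swapped in for illustration and not necessarily the repository's solver; `heun_euler_adaptive` is a hypothetical name.

```python
import math

def heun_euler_adaptive(f, y, t0, t1, tol, h=0.1):
    """Integrate dy/dt = f(t, y) from t0 to t1 with an embedded Heun-Euler pair.
    The step size adapts so the estimated local error stays below tol.
    Returns (y(t1), number of accepted steps)."""
    t, accepted = t0, 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)                     # do not step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_low = y + h * k1                     # first-order (Euler) estimate
        y_high = y + h / 2 * (k1 + k2)         # second-order (Heun) estimate
        err = abs(y_high - y_low)              # local error estimate
        if err <= tol:                         # accept the step and advance
            t, y = t + h, y_high
            accepted += 1
        # Rescale h toward the tolerance (err grows like h**2), within [0.2x, 2x].
        h *= min(2.0, max(0.2, 0.9 * math.sqrt(tol / max(err, 1e-16))))
    return y, accepted

for tol in (1e-2, 1e-6):
    y1, steps = heun_euler_adaptive(lambda t, y: -y, 1.0, 0.0, 1.0, tol)
    print(f"tol={tol:g}: {steps} accepted steps, error {abs(y1 - math.exp(-1)):.2e}")
```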

Results

[Figures: Test Accuracy (cifar_score), Loss (cifar_loss)]
Model                   Accuracy, %   # parameters   Time (s/epoch)
ResNet(6)               86.7          577 K          12.25
ResNet(1)               84.19         207 K          9.84
OdeNet (Runge-Kutta)    84.85         207 K          1860.31
OdeNet (Euler)          84.62         207 K          159.02
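
Relative to ResNet(1), which has the same 207 K parameters, the cost of the ODE variants can be read off the CIFAR-10 table: the Runge-Kutta OdeNet is about 189x slower per epoch for +0.66 accuracy points, and the Euler OdeNet about 16x slower for +0.43 points. The snippet only redoes this arithmetic on the reported numbers.

```python
# Reported CIFAR-10 numbers: model -> (accuracy %, seconds per epoch)
results = {
    "ResNet(6)": (86.7, 12.25),
    "ResNet(1)": (84.19, 9.84),
    "OdeNet (Runge-Kutta)": (84.85, 1860.31),
    "OdeNet (Euler)": (84.62, 159.02),
}

base_acc, base_time = results["ResNet(1)"]
for name, (acc, sec) in results.items():
    # Slowdown and accuracy change relative to the 1-block ResNet
    print(f"{name:22s} {sec / base_time:6.1f}x time  {acc - base_acc:+.2f} points")
```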

Text classification

Download the Ag-News data and create the class list:

mkdir -p .data/ag_news
cd .data/ag_news
wget https://raw.githubusercontent.com/tothanhtung0205/VDCNN/master/ag_news_csv/test.csv
wget https://raw.githubusercontent.com/tothanhtung0205/VDCNN/master/ag_news_csv/train.csv
printf 'World\nSports\nBusiness\nSci/Tech\n' > classes.txt
cd ../..
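
After the download, each row of train.csv and test.csv is expected to hold a 1-based class index, a title and a description, with the index referring to the line order of classes.txt. This layout is an assumption based on the widely used ag_news_csv release; the repository's own loader in experiments/texts/vdcnn.py may read the files differently. A small stdlib-only sketch for inspecting the data (`read_classes` and `load_ag_news` are hypothetical helpers, and the demo rows are made up):

```python
import csv
import io

def read_classes(lines):
    # One class name per line; line k (1-based) names label k.
    return [line.strip() for line in lines if line.strip()]

def load_ag_news(csv_file, class_names):
    """Yield (class_name, text) pairs from an AG News style CSV.
    Assumed columns: 1-based class index, title, description (no header)."""
    for index, title, description in csv.reader(csv_file):
        yield class_names[int(index) - 1], f"{title} {description}"

# Self-contained demo on made-up rows; for the real files use e.g.
# open(".data/ag_news/train.csv", newline="") and open(".data/ag_news/classes.txt").
classes = read_classes(io.StringIO("World\nSports\nBusiness\nSci/Tech\n"))
sample = io.StringIO('"2","Title A","Body A"\n"4","Title B","Body B"\n')
for label, text in load_ag_news(sample, classes):
    print(label, "|", text)
```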

Run VdCNN with 6 blocks:

PYTHONPATH='.' python ./experiments/texts/vdcnn.py --batch_size 256 --max_epo 20 --save vdcnn6  

Run VdCNN with 1 block:

PYTHONPATH='.' python ./experiments/texts/vdcnn.py --batch_size 256 --max_epo 20 --save vdcnn1 \
--num_blocks 1

Run OdeNet with the explicit Euler solver and tolerance 1e-2 (slow: about 4874 s per epoch in the run reported below):

PYTHONPATH='.' python ./experiments/texts/vdcnn.py --batch_size 256 --max_epo 20 --save vdcnn_ode \
--use_ode --solver euler  --tol 1e-2 

Results

[Figures: Test Accuracy (text_score), Loss (text_loss)]
Model                   Accuracy, %   # parameters   Time (s/epoch)
VdCNN(6)                88.46         287 K          311
VdCNN(1)                87.75         162 K          172
OdeNet (Euler)          84.21         162 K          4874
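
The parameter columns are consistent with simple per-block arithmetic. Going from 1 to 6 blocks adds (577 K - 207 K) / 5 = 74 K parameters per ResNet block and (287 K - 162 K) / 5 = 25 K per VdCNN block. That is close to a block of two bias-free convolutions with 64 channels (3x3 for images, width 3 for text) plus two normalization layers with a per-channel scale and shift. The channel count, kernel sizes, bias setting and block layout are assumptions, not read from the repository. The OdeNet rows match the 1-block models, presumably because the ODE block reuses the same block weights at every function evaluation, so extra solver steps cost time rather than parameters.

```python
def conv_params(c_in, c_out, kernel_elems, bias=False):
    # Weight tensor plus an optional bias per output channel.
    return c_in * c_out * kernel_elems + (c_out if bias else 0)

def block_params(channels, kernel_elems, bias=False):
    # Assumed layout: two same-width convolutions, each with a normalization
    # layer that learns a scale and a shift per channel.
    return 2 * conv_params(channels, channels, kernel_elems, bias) + 2 * (2 * channels)

def reported_per_block(deep_k, shallow_k, extra_blocks=5):
    # Parameter difference between the 6-block and 1-block models (in K), per block.
    return (deep_k - shallow_k) * 1000 / extra_blocks

for name, deep_k, shallow_k, kernel_elems in [
    ("ResNet block (3x3 conv)", 577, 207, 3 * 3),
    ("VdCNN block (width-3 conv)", 287, 162, 3),
]:
    print(f"{name}: reported ~{reported_per_block(deep_k, shallow_k):.0f}, "
          f"assumed {block_params(64, kernel_elems)}")
```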

References

Original implementation
