- Install the PySodium Library
pip install git+https://github.com/satyajitghana/PySodium.git#egg=sodium
- Create a config.yml
name: CIFAR10_V2
save_dir: saved/
seed: 1
target_device: 0
arch:
    type: CIFAR10S8Model
    args: {}
augmentation:
    type: CIFAR10Albumentations
    args: {}
data_loader:
    type: CIFAR10DataLoader
    args:
        batch_size: 128
        data_dir: data/
        nworkers: 4
        shuffle: True
criterion: cross_entropy_loss
lr_scheduler:
    type: ReduceLROnPlateau
    args:
        mode: 'min'
batch_scheduler: False
optimizer:
    type: SGD
    args:
        lr: 0.001
        momentum: 0.95
        weight_decay: 0.0005
training:
    epochs: 10
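
Several sections of the config (`arch`, `augmentation`, `data_loader`, `lr_scheduler`, `optimizer`) share the same shape: a `type` naming a class and `args` holding its keyword arguments. The sketch below shows one common way such an entry can be turned into an object. It is only an illustration of the pattern, not PySodium's actual loader: the helper `build_from_config` is hypothetical, and the entry is resolved against the standard library's `fractions` module so the example runs without PyTorch or PySodium installed.

```python
import importlib
from typing import Any, Dict


def build_from_config(module_name: str, entry: Dict[str, Any]) -> Any:
    """Hypothetical helper: look up entry['type'] in the given module and
    call it with entry['args'] as keyword arguments."""
    module = importlib.import_module(module_name)
    cls = getattr(module, entry["type"])
    return cls(**entry.get("args", {}))


# Stand-in entry shaped like the `optimizer` section of config.yml, but
# pointing at fractions.Fraction instead of a real optimizer class.
entry = {"type": "Fraction", "args": {"numerator": 1, "denominator": 1000}}
lr = build_from_config("fractions", entry)
print(float(lr))  # 0.001
```

Under this pattern, swapping `SGD` for another optimizer would only require changing `type` and `args` in the config file, with no change to the training code.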
- Run the model
# import my baby-library
from sodium.utils import load_config
import sodium.runner as runner
# load the config
config = load_config('config.yml', tsai_mode=True)
# setup trainer
runner.setup_train(tsai_mode=True)
# find best lr
runner.find_lr()
# train the network using the best lr
runner.train(use_bestlr=True)
# plot metrics
runner.plot_metrics()
# plot grad cam
target_layers = ["layer1", "layer2", "layer3", "layer4"]
runner.plot_gradcam(target_layers=target_layers)
# plot misclassifications
runner.plot_misclassifications(target_layers=target_layers)
If you are using the library from a terminal, you can run main.py and pass it the config file:
python main.py --config=config.yml
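
For readers who want to write their own entry script, here is a minimal sketch of how a `--config` option like the one above can be parsed with the standard library's `argparse`. This is an assumption about the general shape of such a script, not the contents of the repository's main.py; the function name `parse_args` is hypothetical.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command line; argv=None makes argparse read sys.argv."""
    parser = argparse.ArgumentParser(description="Train a model from a YAML config")
    parser.add_argument("--config", required=True, help="path to the YAML config file")
    return parser.parse_args(argv)


# Simulate `python main.py --config=config.yml` by passing the arguments
# explicitly instead of reading them from the real command line.
args = parse_args(["--config=config.yml"])
print(args.config)  # config.yml
```

A real script would then hand `args.config` to `load_config` and continue with the same steps as the example above.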
To install OpenCV, first update your Anaconda environment, then install the package from conda-forge:
conda update --all
conda install -c conda-forge opencv
Made with ❤ by shadowleaf.satyajit