In [2]:
:dep dfdx = {version="*", git="https://github.com/coreylowman/dfdx", features=["cuda"]}
:dep mnist = "0.5.0"
:dep indicatif = "0.17.3"
:dep rand = { version = "0.8.5", default-features = false, features = ["std_rng"] }


In [12]:
use std::time::Instant;

use indicatif::ProgressIterator;
use mnist::*;
use rand::prelude::{SeedableRng, StdRng};

use dfdx::{data::*, optim::Adam, prelude::*};

type Dev = Cuda;


In [24]:
struct MnistTrainSet(Mnist);

impl MnistTrainSet {
    fn new(path: &str) -> Self {
        Self(MnistBuilder::new().base_path(path).finalize())
    }
}

impl ExactSizeDataset for MnistTrainSet {
    type Item<'a> = (Vec<f32>, usize) where Self: 'a;
    fn get(&self, index: usize) -> Self::Item<'_> {
        let mut img_data: Vec<f32> = Vec::with_capacity(784);
        let start = 784 * index;
        img_data.extend(
            self.0.trn_img[start..start + 784]
                .iter()
                .map(|x| *x as f32 / 255.0),
        );
        (img_data, self.0.trn_lbl[index] as usize)
    }
    fn len(&self) -> usize {
        self.0.trn_lbl.len()
    }
}

// our network structure
type Mlp = (
    (Linear<784, 512>, ReLU),
    (Linear<512, 128>, ReLU),
    (Linear<128, 32>, ReLU),
    Linear<32, 10>,
);

// training batch size
const BATCH_SIZE: usize = 32;


fn preprocess(dev:&Dev, (img, lbl): <MnistTrainSet as ExactSizeDataset>::Item<'_>) -> (Tensor<Rank1<784>, f32, Dev>, Tensor<Rank1<10>, f32, Dev>){
    let mut one_hotted = [0.0; 10];
    one_hotted[lbl] = 1.0;
    (
        dev.tensor_from_vec(img, (Const::<784>,)),
        dev.tensor(one_hotted),
    )
}


In [20]:
let mnist_path = "./tmp";

In [25]:
// ftz substantially improves performance
dfdx::flush_denormals_to_zero();

let dev: Dev = Default::default();
let mut rng = StdRng::seed_from_u64(0);

// initialize model, gradients, and optimizer
let mut model = dev.build_module::<Mlp, f32>();
let mut grads = model.alloc_grads();
let mut opt = Adam::new(&model, Default::default());

// initialize dataset
let dataset = MnistTrainSet::new(&mnist_path);
println!("Found {:?} training images", dataset.len());

for i_epoch in 0..10 {
    let mut total_epoch_loss = 0.0;
    let mut num_batches = 0;
    let start = Instant::now();
    for (img, lbl) in dataset
        .shuffled(&mut rng)
        .map(|x| preprocess(&dev, x))
        .batch(Const::<BATCH_SIZE>)
        .collate()
        .stack()
        .progress()
    {
        let logits = model.forward_mut(img.traced_into(grads));
        let loss = cross_entropy_with_logits_loss(logits, lbl);

        total_epoch_loss += loss.array();
        num_batches += 1;

        grads = loss.backward();
        opt.update(&mut model, &grads).unwrap();
        model.zero_grads(&mut grads);
    }
    let dur = Instant::now() - start;

    println!(
        "Epoch {i_epoch} in {:?} ({:.3} batches/s): avg sample loss {:.5}",
        dur,
        num_batches as f32 / dur.as_secs_f32(),
        BATCH_SIZE as f32 * total_epoch_loss / num_batches as f32,
    );
}


Found 60000 training images
Epoch 0 in 1.975190251s (949.276 batches/s): avg sample loss 7.72665
Epoch 1 in 1.894973108s (989.460 batches/s): avg sample loss 3.01805
Epoch 2 in 1.822600741s (1028.750 batches/s): avg sample loss 2.01029
Epoch 3 in 1.821525288s (1029.357 batches/s): avg sample loss 1.51861
Epoch 4 in 1.811539245s (1035.031 batches/s): avg sample loss 1.20415
Epoch 5 in 1.813613414s (1033.848 batches/s): avg sample loss 1.00066
Epoch 6 in 1.813043999s (1034.172 batches/s): avg sample loss 0.87972
Epoch 7 in 1.821649142s (1029.287 batches/s): avg sample loss 0.70782
Epoch 8 in 1.82468534s (1027.574 batches/s): avg sample loss 0.66483
Epoch 9 in 1.81730959s (1031.745 batches/s): avg sample loss 0.57505


()