When I run it it prints out:

import os
import argparse
import logging
import time

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Command-line options for the MNIST experiment. They are parsed at import time
# (see parse_args() below), so importing or pasting this module into a Jupyter
# kernel fails with "unrecognized arguments: -f ..." unless sys.argv is cleaned.
parser = argparse.ArgumentParser()
# Feature-layer family: a single ODE block ('odenet') or a stack of residual blocks ('resnet').
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
# Solver tolerance; ODEBlock.forward passes it to odeint as both rtol and atol.
parser.add_argument('--tol', type=float, default=1e-3)
# NOTE(review): type=eval runs the raw command-line string through eval() before the
# `choices` check, so any Python expression given here is executed. Tolerable for a
# local research script; do not expose this flag to untrusted input.
parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
# argparse exposes this as args.downsampling_method (the dash becomes an underscore).
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
# The training loop runs nepochs * batches_per_epoch iterations.
parser.add_argument('--nepochs', type=int, default=160)
# Same type=eval caveat as --adjoint. Forwarded to get_mnist_loaders as data_aug.
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
# presumably the base learning rate for the schedule/optimizer — confirm against the training code
parser.add_argument('--lr', type=float, default=0.1)
# Batch sizes forwarded to get_mnist_loaders for the training and evaluation loaders.
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)

# presumably the output directory for the log file and checkpoint — confirm against the main script
parser.add_argument('--save', type=str, default='./experiment1')
# NOTE(review): nothing visible in this file reads args.debug.
parser.add_argument('--debug', action='store_true')
# CUDA device index; only used when torch.cuda.is_available() is true.
parser.add_argument('--gpu', type=int, default=0)
args = parser.parse_args()

# Bind `odeint` exactly once: the adjoint-based solver when --adjoint is set,
# the plain solver otherwise. The two imports must be mutually exclusive so
# that the second does not shadow the first and `odeint` is always defined.
if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d``; no padding is applied."""
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

def norm(dim):
    """Return a ``GroupNorm`` over ``dim`` channels using at most 32 groups."""
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups, dim)

class ResBlock(nn.Module):
    """Residual block: norm -> ReLU -> 3x3 conv -> norm -> ReLU -> 3x3 conv, plus a skip path.

    When ``downsample`` is given, the skip path is ``downsample`` applied to the
    first normalised-and-activated tensor; otherwise it is the raw input.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Submodules are registered in the same order as before so that
        # state_dict keys and parameter initialisation order are unchanged.
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        activated = self.relu(self.norm1(x))

        # Skip connection: identity, or a projection of the activated input.
        residual = x if self.downsample is None else self.downsample(activated)

        hidden = self.conv1(activated)
        hidden = self.relu(self.norm2(hidden))
        return self.conv2(hidden) + residual

class ConcatConv2d(nn.Module):
    """2-D (transposed) convolution that receives the time ``t`` as an extra input channel.

    Args:
        dim_in: number of channels of the state ``x``; the wrapped layer takes ``dim_in + 1``.
        dim_out: number of output channels.
        ksize, stride, padding, dilation, groups, bias: forwarded to the wrapped layer.
        transpose: wrap ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        # The "+ 1" input channel carries the broadcast time value built in forward().
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        """Prepend a constant channel filled with ``t`` to ``x`` and apply the wrapped layer."""
        # Same batch/spatial size, dtype and device as x, but a single channel.
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)

class ODEfunc(nn.Module):
    """Dynamics function ``f(t, x)`` handed to the ODE solver.

    ``nfe`` counts how many times ``forward`` has been called since it was last reset.
    """

    def __init__(self, dim):
        super().__init__()
        # Registration order is kept so state_dict keys and init order are unchanged.
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        # Every solver evaluation passes through here, so count it first.
        self.nfe += 1
        hidden = self.relu(self.norm1(x))
        hidden = self.conv1(t, hidden)
        hidden = self.relu(self.norm2(hidden))
        hidden = self.conv2(t, hidden)
        return self.norm3(hidden)

class ODEBlock(nn.Module):
    """Layer whose output is the solution at t=1 of the ODE defined by ``odefunc``, starting from the input at t=0."""

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        # Integrate from t=0 to t=1; only the end point is returned by forward().
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        # Match the dtype/device of the input so the solver does not mix tensor types.
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        # out[0] is the state at t=0 (the input); out[1] is the state at t=1.
        return out[1]

    # `nfe` must be a property: callers read `block.nfe` and assign `block.nfe = 0`.
    # As two plain methods the second definition would simply shadow the first.
    @property
    def nfe(self):
        """Number of function evaluations recorded by the wrapped ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

class Flatten(nn.Module):
    """Collapse every non-batch dimension of the input into a single feature dimension."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Features per sample = product of all dimensions after the batch dimension.
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

class RunningAverageMeter(object):
    """Tracks the latest value (``val``) and its exponential moving average (``avg``).

    Args:
        momentum: weight kept from the previous average on each update.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        # Establish `val`/`avg` so update() can be called right after construction.
        self.reset()

    def reset(self):
        """Forget all history; the next update() seeds the average."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Record ``val``: the first value seeds the average, later ones are blended in."""
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val

def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    """Build the MNIST data loaders (downloading the data to ``.data/mnist`` if needed).

    Args:
        data_aug: when true, training images are randomly cropped (28x28 after 4 px padding).
        batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the two evaluation loaders.
        perc: accepted for interface compatibility; not used by this function.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)`` where ``train_eval_loader``
        iterates the training split without augmentation or shuffling.
    """
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    # NOTE: drop_last=True also applies to the evaluation loaders, so samples in a
    # final partial batch are skipped when the batch size does not divide the split.
    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader

def inf_generator(iterable):
    """Yield the items of ``iterable`` forever, restarting it whenever it is exhausted.

    Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    If a pass over ``iterable`` produces nothing (empty, or a one-shot iterator
    that is already used up) the generator stops instead of spinning forever.
    """
    while True:
        produced = False
        # Each `for` obtains a fresh iterator, i.e. a new pass over the data.
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return

def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates, base_lr=None):
    """Build a piecewise-constant learning-rate schedule indexed by iteration.

    Args:
        batch_size: actual batch size; the base rate is scaled by ``batch_size / batch_denom``.
        batch_denom: reference batch size for which the base rate was chosen.
        batches_per_epoch: iterations per epoch, used to convert epochs to iterations.
        boundary_epochs: epochs at which the rate switches to the next entry of ``decay_rates``.
        decay_rates: multipliers of the scaled base rate, one more than ``boundary_epochs``.
        base_lr: base learning rate; defaults to the ``--lr`` command-line value.

    Returns:
        A function mapping an iteration number to its learning rate.
    """
    if base_lr is None:
        base_lr = args.lr
    initial_learning_rate = base_lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        # The first boundary still ahead of `itr` selects the stage; past the
        # last boundary the entry after all boundaries is used.
        for stage, boundary in enumerate(boundaries):
            if itr < boundary:
                return vals[stage]
        return vals[len(boundaries)]

    return learning_rate_fn

def one_hot(x, K):
    """Encode the 1-D integer label array ``x`` as an integer matrix with ``K`` indicator columns."""
    class_ids = np.arange(K)
    # Broadcasting a column of labels against a row of class ids gives the indicator mask.
    matches = np.equal(x[:, None], class_ids[None, :])
    return np.array(matches, dtype=int)

def accuracy(model, dataset_loader):
    """Return the fraction of samples in ``dataset_loader`` that ``model`` classifies correctly.

    Args:
        model: module mapping a batch ``x`` to per-class scores of shape (batch, classes).
        dataset_loader: iterable of ``(x, y)`` batches exposing ``.dataset``; ``y`` holds
            integer class labels.

    Returns:
        Correct predictions divided by ``len(dataset_loader.dataset)``. Samples the
        loader skips (e.g. through ``drop_last=True``) therefore count as misses.
    """
    # Evaluate on the device that holds the model's weights, so the function does
    # not depend on a module-level `device` variable.
    first_param = next(iter(model.parameters()), None)

    total_correct = 0
    for x, y in dataset_loader:
        if first_param is not None:
            x = x.to(first_param.device)

        # The labels already are class indices; no one-hot round trip is needed.
        target_class = np.asarray(y.numpy())
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)

def count_parameters(model):
    """Count the scalar entries of every parameter of ``model`` that requires gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def makedirs(dirname):
    """Create ``dirname`` (including missing parents) unless the path already exists."""
    if not os.path.exists(dirname):
        # exist_ok guards against another process creating it between the check and the call.
        os.makedirs(dirname, exist_ok=True)

def get_logger(logpath, filepath, package_files=None, displaying=True, saving=True, debug=False):
    """Configure the root logger and record the source of the running script in it.

    Args:
        logpath: file that log records are appended to when ``saving`` is true.
        filepath: path of the script being run; its path and full text are logged.
        package_files: optional iterable of further source files to log the same way.
        displaying: also emit records through a ``StreamHandler`` (stderr).
        saving: also append records to ``logpath``.
        debug: log at DEBUG level instead of INFO.

    Returns:
        The configured root logger.

    Raises:
        OSError: if ``logpath`` cannot be opened or a source file cannot be read.
    """
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)

    # Snapshot the code that produced this run so the log is self-describing.
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    # `None` stands in for "no extra files" (avoids a shared mutable default).
    for f in package_files or ():
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger

if __name__ == '__main__':
    # Training script: build the model selected on the command line, train it on
    # MNIST with SGD under a stepwise learning-rate schedule, and once per epoch
    # log accuracy and checkpoint the best model under args.save.

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    is_odenet = args.network == 'odenet'

    # Two stride-2 stages reduce the spatial resolution before the feature layers.
    if args.downsampling_method == 'conv':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
            norm(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 4, 2, 1),
        ]
    elif args.downsampling_method == 'res':
        downsampling_layers = [
            nn.Conv2d(1, 64, 3, 1),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
            ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ]

    feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
    fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.CrossEntropyLoss().to(device)

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    # The per-iteration schedule below overwrites this initial learning rate.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    for itr in range(args.nepochs * batches_per_epoch):

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        loss = criterion(logits, y)

        # Read and reset the evaluation counter so forward and backward
        # evaluations are reported separately.
        if is_odenet:
            nfe_forward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        loss.backward()
        optimizer.step()

        if is_odenet:
            nfe_backward = feature_layers[0].nfe
            feature_layers[0].nfe = 0

        batch_time_meter.update(time.time() - end)
        if is_odenet:
            f_nfe_meter.update(nfe_forward)
            b_nfe_meter.update(nfe_backward)
        end = time.time()

        # Evaluate at the first iteration of every epoch.
        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader)
                val_acc = accuracy(model, test_loader)
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )

Namespace(adjoint=False, batch_size=128, data_aug=True, debug=False, downsampling_method='conv', gpu=0, lr=0.1, nepochs=160, network='odenet', save='./experiment1', test_batch_size=1000, tol=0.001)
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ODEBlock(
    (odefunc): ODEfunc(
      (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (relu): ReLU(inplace)
      (conv1): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
      (conv2): ConcatConv2d(
        (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
  (8): GroupNorm(32, 64, eps=1e-05, affine=True)
  (9): ReLU(inplace)
  (10): AdaptiveAvgPool2d(output_size=(1, 1))
  (11): Flatten()
  (12): Linear(in_features=64, out_features=10, bias=True)
Number of parameters: 208266

and then uses 100% CPU for hours - when I checked in the morning, nothing further seemed to have happened.

It doesn't finish - Ctrl+C stops it at this point:

  File "~/torchdiffeq/torchdiffeq/_impl/", line 103, in _adaptive_dopri5_step
    y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU)
  File "~/torchdiffeq/torchdiffeq/_impl/", line 52, in _runge_kutta_step
    tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi)))
  File "~/torchdiffeq/torchdiffeq/_impl/", line 179, in <lambda>
    func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "./examples/", line 108, in forward
    out = self.conv1(t, out)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "./examples/", line 89, in forward
    return self._layer(ttx)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/", line 320, in forward
    self.padding, self.dilation, self.groups)

rtqichen commented Feb 5, 2019

This might be because you're running on CPU. Being extremely slow on CPU is expected, as training requires evaluating a neural net multiple times. Does ode-demo work?

programmin1 commented Feb 6, 2019

Yes, ./examples/ode_demo.py works. I thought the examples would finish overnight on a 2.6 GHz i5.
Is a GeForce GT 630M enough to run this example?

shsad commented Nov 22, 2019

I also let the code run for ~12 hours and it managed to finish 13 epochs (so roughly one epoch per hour). Does this mean that I have to stick to fewer epochs (than the default of 160, if I am not mistaken) since I don't have a GPU?

Thanks in advance.

Hmm.. I'd suggest using a GPU, as running neural nets on a CPU is way too slow. Colaboratory (https://colab.research.google.com) lets you use a free GPU. You'll be able to install torchdiffeq with the following command:

pip install git+https://github.com/rtqichen/torchdiffeq

shsad commented Nov 23, 2019

Thank you for the suggestion. I find the Jupyter Notebook terrible for debugging. I tried running ode_demo.py and odenet_mnist.py with a GPU, and for ode_demo.py I get:
`usage: ODE demo [-h] [--method {dopri5,adams}] [--data_size DATA_SIZE]
[--batch_time BATCH_TIME] [--batch_size BATCH_SIZE]
[--niters NITERS] [--test_freq TEST_FREQ] [--viz] [--gpu GPU]
ODE demo: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-eddba8cd-8dd7-4e14-8e81-6da72c82bf76.json
An exception has occurred, use %tb to see the full traceback.

SystemExit: 2
/usr/local/lib/python3.6/dist-packages/IPython/core/ UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)`

and for odenet_mnist.py:
`usage: [-h] [--network {resnet,odenet}] [--tol TOL]
[--adjoint {True,False}]
[--downsampling-method {conv,res}]
[--nepochs NEPOCHS] [--data_aug {True,False}]
[--lr LR] [--batch_size BATCH_SIZE]
[--test_batch_size TEST_BATCH_SIZE] [--save SAVE]
[--debug] [--gpu GPU] error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-eddba8cd-8dd7-4e14-8e81-6da72c82bf76.json
An exception has occurred, use %tb to see the full traceback.

SystemExit: 2
/usr/local/lib/python3.6/dist-packages/IPython/core/ UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)`

Could you tell me why you use 160 epochs (I know the more the better)?

Thanks in advance!

The idea is you can just copy the Python code instead of running it as a script.

Yeah, you should be able to get the same accuracy with far fewer epochs.

shsad commented Nov 24, 2019

Thank you!

