In [1]:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms


In [2]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from torchinfo import summary

In [4]:
!pip install torchdiffeq

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [5]:
!pip install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
import wandb

In [7]:
# Log in to Weights & Biases; later cells call wandb.init / wandb.watch / wandb.log.
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [8]:
# Experiment configuration read by the cells below (presumably ported from argparse flags).
network='resnet' # choices=['resnet', 'odenet']
tol=1e-3 # type=float, baseline=1e-3; used as both rtol and atol in ODEBlock.forward
block=6 # number of ResBlocks stacked as feature layers when network is not 'odenet'
adjoint=False # choices=[True, False]; True imports odeint_adjoint as odeint
downsampling_method='conv' # choices=['conv', 'res']
nepochs=10 # type=int, default=160
data_aug=True # choices=[True, False]; enables RandomCrop on the training set
lr=0.1 # type=float; base rate, rescaled by batch_size / batch_denom in learning_rate_with_decay
batch_size=1000 # type=int
test_batch_size=1000 # type=int
debug='store_true' # NOTE(review): looks like a leftover argparse action string (truthy); not read by any visible cell
gpu=0  # type=int; CUDA device index

In [9]:
# Use CUDA device number `gpu` when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:' + str(gpu) if torch.cuda.is_available() else 'cpu')

In [10]:
device

device(type='cuda', index=0)

In [11]:
# Bind the solver entry point to the single name `odeint` (used by ODEBlock.forward):
# the adjoint variant when `adjoint` is set, the plain solver otherwise.
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

In [12]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

In [13]:
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` (no padding).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer

In [14]:
def norm(dim):
    """Return a ``nn.GroupNorm`` over ``dim`` channels using ``min(32, dim)`` groups.

    Args:
        dim: number of channels to normalise.

    Returns:
        The freshly constructed ``nn.GroupNorm`` layer.
    """
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups, dim)

In [15]:
class ResBlock(nn.Module):
    """Residual block: (norm -> ReLU -> 3x3 conv) applied twice, plus a shortcut.

    The shortcut is the raw input unless a ``downsample`` module is given, in
    which case it is ``downsample`` applied to the first normalised+activated
    tensor.

    Args:
        inplanes: channels of the input tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution (default 1).
        downsample: optional module producing the shortcut; ``None`` keeps the
            identity shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Sub-modules are created in this order so that parameter initialisation
        # and the printed module layout stay the same.
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        """Return ``conv2(relu(norm2(conv1(a)))) + shortcut`` with ``a = relu(norm1(x))``."""
        activated = self.relu(self.norm1(x))

        if self.downsample is None:
            residual = x
        else:
            residual = self.downsample(activated)

        branch = self.conv1(activated)
        branch = self.relu(self.norm2(branch))
        branch = self.conv2(branch)

        return branch + residual

In [16]:
class ConcatConv2d(nn.Module):
    """Convolution that receives the time ``t`` as one extra input channel.

    Args:
        dim_in: channels of the data tensor (the wrapped layer takes ``dim_in + 1``).
        dim_out: channels produced by the wrapped layer.
        ksize, stride, padding, dilation, groups, bias: forwarded to the wrapped layer.
        transpose: wrap ``nn.ConvTranspose2d`` instead of ``nn.Conv2d`` when true.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super().__init__()
        if transpose:
            conv_cls = nn.ConvTranspose2d
        else:
            conv_cls = nn.Conv2d
        # One additional input channel is reserved for the time plane built in forward().
        self._layer = conv_cls(
            dim_in + 1,
            dim_out,
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, t, x):
        """Prepend a channel filled with ``t`` to ``x`` and apply the wrapped layer."""
        # Same batch/spatial size as x, a single channel, every entry equal to t.
        time_plane = torch.ones_like(x[:, :1, :, :]) * t
        stacked = torch.cat([time_plane, x], 1)
        return self._layer(stacked)

In [17]:
class ODEfunc(nn.Module):
    """Function ``f(t, x)`` evaluated by the ODE solver.

    Computes norm -> ReLU -> time-conditioned 3x3 conv, twice, followed by a
    final norm. ``nfe`` counts how many times ``forward`` has run since it was
    last reset by the caller.

    Args:
        dim: number of channels of the state (kept unchanged by every layer).
    """

    def __init__(self, dim):
        super().__init__()
        # Creation order is kept so parameter initialisation stays the same.
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        """Evaluate the function at time ``t`` and state ``x``; bumps ``nfe`` by one."""
        self.nfe += 1
        hidden = self.relu(self.norm1(x))
        hidden = self.conv1(t, hidden)
        hidden = self.relu(self.norm2(hidden))
        hidden = self.conv2(t, hidden)
        return self.norm3(hidden)


In [18]:
class ODEBlock(nn.Module):
    """Layer that integrates ``odefunc`` from t=0 to t=1 with ``odeint``.

    Args:
        odefunc: module called by the solver as ``odefunc(t, x)``; its ``nfe``
            attribute is exposed through the ``nfe`` property of this block.
    """

    def __init__(self, odefunc):
        super().__init__()
        self.odefunc = odefunc
        # Plain attribute (not a registered buffer); forward() re-casts it to match the input.
        self.integration_time = torch.tensor([0, 1], dtype=torch.float32)

    def forward(self, x):
        """Return the solver state at the final integration time (t=1)."""
        span = self.integration_time.type_as(x)
        self.integration_time = span
        # The module-level `tol` is used for both the relative and absolute tolerance.
        states = odeint(self.odefunc, x, span, rtol=tol, atol=tol)
        # states[0] corresponds to t=0, states[1] to t=1.
        return states[1]

    @property
    def nfe(self):
        """Evaluation counter stored on the wrapped ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

In [19]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single feature dimension.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``; a
    1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` viewed as a 2-D tensor with the leading dimension inferred."""
        # torch.Size.numel() multiplies the trailing sizes as a plain Python int
        # (1 when there are none). This avoids building a temporary tensor and
        # calling .item() on every forward pass, and avoids handing a float to
        # view() for 1-D inputs.
        features = x.shape[1:].numel()
        return x.view(-1, features)


In [20]:
class RunningAverageMeter(object):
    """Keeps the most recent value and an exponential moving average of all values.

    Args:
        momentum: weight given to the previous average on each update; the new
            value receives ``1 - momentum``.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history: no current value, average back to 0."""
        self.val, self.avg = None, 0

    def update(self, val):
        """Record ``val``; the first value after a reset seeds the average directly."""
        if self.val is not None:
            new_weight = 1 - self.momentum
            self.avg = self.avg * self.momentum + val * new_weight
        else:
            self.avg = val
        self.val = val


In [21]:
def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    """Create the MNIST data loaders (downloading to ``.data/mnist`` if needed).

    Args:
        data_aug: when true, training images get ``RandomCrop(28, padding=4)``
            before conversion to tensors.
        batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of both evaluation loaders.
        perc: accepted for interface compatibility; not used by this function.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)`` — note the order.
        ``train_eval_loader`` iterates the training split without augmentation
        and without shuffling.
    """
    train_steps = [transforms.RandomCrop(28, padding=4)] if data_aug else []
    transform_train = transforms.Compose(train_steps + [transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    # Training keeps drop_last=True so every optimisation step sees a full batch.
    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train),
        batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True
    )

    # Evaluation loaders must visit every sample: with drop_last=True a trailing
    # partial batch would be skipped while accuracy is still computed over the
    # whole dataset, under-reporting it whenever the batch size does not divide
    # the dataset size.
    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=False
    )

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=False
    )

    return train_loader, test_loader, train_eval_loader

In [22]:
def inf_generator(iterable):
    """Yield the items of ``iterable`` forever, restarting it each time it runs out.

    Lets a DataLoader drive a single endless training loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    while True:
        # Each pass asks the iterable for a fresh iterator.
        for item in iterable:
            yield item

In [23]:
def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    """Build a piecewise-constant learning-rate schedule indexed by iteration.

    The starting rate is the module-level ``lr`` scaled by
    ``batch_size / batch_denom``.

    Args:
        batch_size: training batch size.
        batch_denom: reference batch size for the scaling above.
        batches_per_epoch: iterations per epoch, used to turn epochs into iterations.
        boundary_epochs: epochs at which the schedule moves to the next value.
        decay_rates: multipliers of the starting rate, one per stage
            (one more entry than ``boundary_epochs``).

    Returns:
        ``learning_rate_fn(itr)``: the rate of the first stage whose boundary is
        still ahead of ``itr``, or the last stage once all boundaries are passed.
    """
    initial_learning_rate = lr * batch_size / batch_denom

    boundaries = []
    for epoch in boundary_epochs:
        boundaries.append(int(batches_per_epoch * epoch))

    vals = []
    for decay in decay_rates:
        vals.append(initial_learning_rate * decay)

    def learning_rate_fn(itr):
        for stage, boundary in enumerate(boundaries):
            if itr < boundary:
                return vals[stage]
        # Past every boundary: use the stage that follows the last one.
        return vals[len(boundaries)]

    return learning_rate_fn


In [24]:
def one_hot(x, K):
    """Return an integer matrix whose row ``i`` has a 1 in column ``x[i]``.

    Args:
        x: 1-D NumPy array of class indices.
        K: number of classes (columns).

    Returns:
        Array of shape ``(len(x), K)`` with dtype ``int``; a row is all zeros
        when its label is outside ``0..K-1``.
    """
    labels_column = x[:, None]
    class_row = np.arange(K)[None, :]
    matches = labels_column == class_row
    return np.array(matches, dtype=int)

In [25]:
def accuracy(model, dataset_loader):
    """Return the fraction of samples in ``dataset_loader`` that ``model`` classifies correctly.

    Args:
        model: callable mapping an input batch to per-class scores of shape
            ``(batch, num_classes)``.
        dataset_loader: iterable of ``(x, y)`` batches where ``y`` holds integer
            class labels.

    Returns:
        ``correct / evaluated`` as a Python float, or ``0.0`` when the loader
        yields no samples.
    """
    total_correct = 0
    total_seen = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x, y in dataset_loader:
            scores = model(x.to(device))
            # Compare on the CPU; labels are already class indices, so no
            # one-hot encoding followed by argmax is needed.
            predicted_class = scores.argmax(dim=1).cpu()
            total_correct += int((predicted_class == y.cpu()).sum().item())
            total_seen += int(y.shape[0])

    # Divide by the samples actually evaluated. A loader built with
    # drop_last=True can yield fewer samples than len(dataset_loader.dataset),
    # which would otherwise under-report accuracy.
    if total_seen == 0:
        return 0.0
    return total_correct / total_seen

In [26]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to false are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [27]:
# Any value of `network` other than 'odenet' selects the stacked-ResBlock variant.
is_odenet = network == 'odenet'

# Stem: one unpadded 3x3 convolution followed by two stride-2 stages, all 64 channels wide.
if downsampling_method == 'conv':
    downsampling_layers = [
        nn.Conv2d(1, 64, 3, 1),
        norm(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 4, 2, 1),
        norm(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 4, 2, 1),
    ]
elif downsampling_method == 'res':
    downsampling_layers = [
        nn.Conv2d(1, 64, 3, 1),
        ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
        ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),
    ]
else:
    # Fail here with a clear message instead of a NameError on
    # `downsampling_layers` further down.
    raise ValueError(
        "downsampling_method must be 'conv' or 'res', got {!r}".format(downsampling_method)
    )

# Feature extractor: a single ODE block, or `block` stacked residual blocks.
feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(block)]
# Head: norm -> ReLU -> global average pool -> flatten -> 10-way linear classifier.
fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)


In [None]:
#torch.save(model, 'model_resnet.pth')

In [28]:
print(model)

Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): GroupNorm(32, 64, eps=1e-05, affine=True)
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (4): GroupNorm(32, 64, eps=1e-05, affine=True)
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ResBlock(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (8): ResBlock(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): Conv2d(64, 64, kernel_size=(3,

In [29]:
summary(model)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            640
├─GroupNorm: 1-2                         128
├─ReLU: 1-3                              --
├─Conv2d: 1-4                            65,600
├─GroupNorm: 1-5                         128
├─ReLU: 1-6                              --
├─Conv2d: 1-7                            65,600
├─ResBlock: 1-8                          --
│    └─GroupNorm: 2-1                    128
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       36,864
│    └─GroupNorm: 2-4                    128
│    └─Conv2d: 2-5                       36,864
├─ResBlock: 1-9                          --
│    └─GroupNorm: 2-6                    128
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       36,864
│    └─GroupNorm: 2-9                    128
│    └─Conv2d: 2-10                      36,864
├─ResBlock: 1-10                        

In [30]:
# Keep the torchinfo summary object: its total_params is recorded in the wandb config below.
ad=summary(model)
#summary(model)

In [31]:

# Settings and model size attached to the wandb run as its config.
config = {"network": network, "adjoint": adjoint,"Model params":ad.total_params ,"tolerance":tol,'blocks': block}
# Start the run in the "ODE_ResNet" project, tagged "new".
wandb.init(project="ODE_ResNet", config=config,tags=["new"])


[34m[1mwandb[0m: Currently logged in as: [33msojohan[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [32]:
# SGD with momentum 0.9. The lr given here is only a starting value: the training
# loop overwrites param_group['lr'] from lr_fn on every iteration.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) 
# Cross-entropy loss applied to the model's logits and the integer labels.
criterion = nn.CrossEntropyLoss().to(device)

In [33]:
# get_mnist_loaders returns (train, test, train_eval) in that order.
train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_batch_size)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to .data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting .data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to .data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to .data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting .data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to .data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to .data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting .data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to .data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to .data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting .data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to .data/mnist/MNIST/raw



In [34]:
# Endless stream of training batches; the loop below pulls one batch per iteration.
data_gen = inf_generator(train_loader)
# Number of batches per pass over train_loader, used to convert epochs to iterations.
batches_per_epoch = len(train_loader)

In [35]:
 # Step schedule: the scaled base rate is multiplied by 1, 0.1, 0.01 and 0.001, switching at
 # epochs 60, 100 and 140. With nepochs=10 as configured above, only the first stage is reached.
 lr_fn = learning_rate_with_decay(batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

In [36]:
# Best validation accuracy seen so far (updated in the training loop).
best_acc = 0
# Running averages of the per-iteration wall time and of the forward/backward
# function-evaluation counts (the latter two are only updated for the ODE network).
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
# Reference timestamp for the first batch-time measurement.
end = time.time()

In [37]:
print('nepochs',nepochs)
print('batches_per_epoch',batches_per_epoch)
wandb.watch(model)

# Wall-clock seconds accumulated over every iteration so far (logged as "TotalTime").
timecount=0.0
# Restart the timer here so the first measurement does not include the time
# that passed between running the setup cell and this one.
end = time.time()
for itr in range(nepochs*batches_per_epoch):
    # The schedule is evaluated every iteration and written into the optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_fn(itr)

    optimizer.zero_grad()
    x, y = next(data_gen)
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    loss = criterion(logits, y)

    if is_odenet:
        # Function evaluations used by the forward solve; reset so the
        # backward pass is counted separately.
        nfe_forward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    loss.backward()
    optimizer.step()

    if is_odenet:
        nfe_backward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    batch_time_meter.update(time.time() - end)
    # Accumulate every iteration's elapsed time. Adding only the epoch-boundary
    # batch would make "TotalTime" the sum of one batch per epoch.
    timecount += batch_time_meter.val
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)
    # Taken before evaluation, so the next iteration's measurement includes
    # the evaluation time below.
    end = time.time()

    # Evaluate on the first iteration of each epoch (itr 0 is after one step).
    if itr % batches_per_epoch == 0:
        with torch.no_grad():
            train_acc = accuracy(model, train_eval_loader)
            val_acc = accuracy(model, test_loader)
        if is_odenet:
            # Evaluation also runs the solver; clear its evaluations so they are
            # not attributed to the next training forward pass.
            feature_layers[0].nfe = 0
        if val_acc > best_acc:
            best_acc = val_acc
        # Log the loss as a plain float rather than a tensor attached to the graph.
        wandb.log({"loss": loss.item(), "Train Accuraccy": train_acc, "Validation Accuracy": val_acc,"NFE-F": f_nfe_meter.avg, "NFE-B" :  b_nfe_meter.avg,"TotalTime":timecount } )
        print("Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | Train Acc {:.4f} | Test Acc {:.4f}".format(
                itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                b_nfe_meter.avg, train_acc, val_acc
            )
        )

nepochs 10
batches_per_epoch 60
Epoch 0000 | Time 20.059 (20.059) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.1044 | Test Acc 0.1028
Epoch 0001 | Time 0.429 (11.079) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.4213 | Test Acc 0.4205
Epoch 0002 | Time 0.485 (6.173) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.4029 | Test Acc 0.4094
Epoch 0003 | Time 0.482 (3.485) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.9393 | Test Acc 0.9423
Epoch 0004 | Time 0.472 (2.016) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.8957 | Test Acc 0.9027
Epoch 0005 | Time 0.464 (1.215) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.9569 | Test Acc 0.9601
Epoch 0006 | Time 0.463 (0.776) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.9762 | Test Acc 0.9768
Epoch 0007 | Time 0.497 (0.535) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.9832 | Test Acc 0.9839
Epoch 0008 | Time 0.476 (0.403) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.9849 | Test Acc 0.9843
Epoch 0009 | Time 0.486 (0.330) | NFE-F 0.0 | NFE-B 0.0 | Train Acc 0.9879 | Test Acc 0.9881
