<a href="https://colab.research.google.com/github/ysjgithub/dl/blob/master/ode_craf_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!git clone https://github.com/rtqichen/torchdiffeq.git

Cloning into 'torchdiffeq'...
remote: Enumerating objects: 15, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 229 (delta 3), reused 4 (delta 1), pack-reused 214[K
Receiving objects: 100% (229/229), 715.37 KiB | 1.71 MiB/s, done.
Resolving deltas: 100% (100/100), done.


In [0]:
import os
import argparse
import logging
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [0]:
%cd torchdiffeq/

/content/torchdiffeq


In [0]:
from torchdiffeq import odeint_adjoint as odeint

In [0]:

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim):
    return nn.GroupNorm(min(64, dim), dim)


In [0]:
class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.relu(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + shortcut


In [0]:
class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        # 扩张输入 128 64 6 6
        tt = torch.ones_like(x[:, :1, :, :]) * t
        # 128 65 6 6
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)



In [0]:

class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

# 是一个神经网络
class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    # 神经网络的前向方法
    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        # 有odeint函数的作用，输出与odefunc,odefunc有关,odefunc也是神经网络
        # x 初始值，integration_time时间序列区间
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


In [0]:
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.ColorJitter(hue=.05, saturation=.05),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_loader = DataLoader(
        datasets.CIFAR10(root='.data/CIFAR10', train=True, download=True, transform=transform_train), batch_size=batch_size,
        shuffle=True, num_workers=2, drop_last=True
    )

    train_eval_loader = DataLoader(
        datasets.CIFAR10(root='.data/CIFAR10', train=True, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    test_loader = DataLoader(
        datasets.CIFAR10(root='.data/CIFAR10', train=False, download=True, transform=transform_test),
        batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
    )

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    initial_learning_rate = args.lr * batch_size / batch_denom
    # 速率衰减的边界
    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    # 速率
    vals = [initial_learning_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        lt = [itr < b for b in boundaries] + [True]
        i = np.argmax(lt)
        return vals[i]

    return learning_rate_fn

print(np.argmax([True,True,True,False,False]))


def one_hot(x, K):
    return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)


def accuracy(model, dataset_loader):
    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(device)
        y = one_hot(np.array(y.numpy()), 10)
        target_class = np.argmax(y, axis=1)
        predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    logger = logging.getLogger()
    if debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    for f in package_files:
        logger.info(f)
        with open(f, "r") as package_f:
            logger.info(package_f.read())

    return logger


0


In [0]:
class Arg(object):
  def __init__(self,network="odenet",tol=1e-4,adjoint=False,downsampling_method='res',nepochs=160,data_aug=True,lr=0.1,batch_size=128,test_batch_size=1000,
               save="./experiment1",gpu=0):
    self.network = network
    self.tol = tol
    self.adjoint = adjoint
    self.downsampling_method=downsampling_method
    self.nepochs  = nepochs
    self.data_aug = data_aug
    self.lr = lr
    self.batch_size = batch_size
    self.test_batch_size = test_batch_size
    self.save = save
    self.gpu = gpu

# parser = argparse.ArgumentParser()
# parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
# parser.add_argument('--tol', type=float, default=1e-3)
# parser.add_argument('--adjoint', type=eval, default=False, choices=[True, False])
# parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
# parser.add_argument('--nepochs', type=int, default=160)
# parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
# parser.add_argument('--lr', type=float, default=0.1)
# parser.add_argument('--batch_size', type=int, default=128)
# parser.add_argument('--test_batch_size', type=int, default=1000)

# parser.add_argument('--save', type=str, default='./experiment1')
# parser.add_argument('--debug', action='store_true')
# parser.add_argument('--gpu', type=int, default=0)
# args = parser.parse_args()
args = Arg()
print(args)


<__main__.Arg object at 0x7f2c783e7a58>


In [0]:

makedirs(args.save)
# logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
# logger.info(args)

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

print(device)
is_odenet = args.network == 'odenet'
print(args)
# 下采样层
if args.downsampling_method == 'conv':
    downsampling_layers = [
        nn.Conv2d(3, 128, 3, 1),
        norm(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 128, 4, 2, 1),
        norm(128),
        nn.ReLU(inplace=True),
        nn.Conv2d(128, 128, 4, 2, 1),
    ]
elif args.downsampling_method == 'res':
    downsampling_layers = [
        nn.Conv2d(3, 128, 3, 1),
        ResBlock(128, 128, stride=2, downsample=conv1x1(128, 128, 2)),
        ResBlock(128, 128, stride=2, downsample=conv1x1(128, 128, 2)),
    ]
# 特征层
feature_layers = [ODEBlock(ODEfunc(128))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]
# 全连接层
fc_layers = [norm(128), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(),nn.Linear(128, 10)]

# 需要训练的模型
model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

# logger.info(model)
print('Number of parameters: {}'.format(count_parameters(model)))

# 损失函数，交叉熵
criterion = nn.CrossEntropyLoss().to(device)

# 训练数据
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
    args.data_aug, args.batch_size, args.test_batch_size
)

# 生成数据的迭代器
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

# 训练参数
lr_fn = learning_rate_with_decay(
    args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[40, 80, 120],
    decay_rates=[1,0.1,0.01,0.001]
)

# # 优化函数
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

best_acc = 0
batch_time_meter = RunningAverageMeter()
f_nfe_meter = RunningAverageMeter()
b_nfe_meter = RunningAverageMeter()
end = time.time()

# 训练
for itr in range(args.nepochs * batches_per_epoch):
    # 限定优化器的学习速率，与迭代次数有关
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_fn(itr)

    optimizer.zero_grad()
    x, y = data_gen.__next__()
    # print(x.shape,y.shape)
    # assert 1==0
    x = x.to(device)
    y = y.to(device)
    # 前向网络
    logits = model(x)
    print(logits.shape,y.shape)
    # 计算损失
    loss = criterion(logits, y)
    print(loss)
    if is_odenet:
        nfe_forward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    # 反向传播
    loss.backward()
    optimizer.step()

    if is_odenet:
        nfe_backward = feature_layers[0].nfe
        feature_layers[0].nfe = 0

    # 计算一这个batch的时间
    batch_time_meter.update(time.time() - end)

    # 前馈时间和反向时间
    if is_odenet:
        f_nfe_meter.update(nfe_forward)
        b_nfe_meter.update(nfe_backward)
    end = time.time()

    if itr % batches_per_epoch == 0:
      with torch.no_grad():
          train_acc = accuracy(model, train_eval_loader)
          val_acc = accuracy(model, test_loader)
          if val_acc > best_acc:
              torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
              best_acc = val_acc
          print(
              "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
              "Train Acc {:.4f} | Test Acc {:.4f}".format(
                  itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                  b_nfe_meter.avg, train_acc, val_acc
              )
          )
