<a href="https://colab.research.google.com/github/yuzhi535/resnet-pytorch/blob/master/resnet_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 模型定义

In [1]:
%pip install timm torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.7-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 5.1 MB/s 
[?25hCollecting torchmetrics
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 53.9 MB/s 
Installing collected packages: torchmetrics, timm
Successfully installed timm-0.6.7 torchmetrics-0.9.3


In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.models.layers import trunc_normal_


# 基本块
class Conv(nn.Module):
    def __init__(self, in_chan, out_chan, stride) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU(),
        )
        

    def forward(self, x):
        return self.conv(x)




class BasicLayerX2(nn.Module):
    def __init__(self, in_chan, stride, down_sample: bool) -> None:
        super().__init__()
        out_chan = in_chan
        in_chan = in_chan // 2 if down_sample else in_chan

        print(f'in_chan={in_chan}, out_chan={out_chan}')

        self.conv = nn.Sequential(
            Conv(in_chan, out_chan, stride),
            Conv(out_chan, out_chan, 1),
        )

        self.shortcut = None
        self.act = nn.ReLU()

        if down_sample:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_chan, out_chan, 1, stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        conv = self.conv(x)
        shortcut = self.shortcut(x)
        return self.act(conv+shortcut)


class BasicLayerX3(nn.Module):
    def __init__(self, in_chan, stride, down_sample: bool) -> None:
        super().__init__()
        out_chan = in_chan
        in_chan = in_chan // 2 if down_sample else in_chan
        self.conv = nn.Sequential(
            Conv(in_chan, out_chan, stride),
            Conv(out_chan, out_chan, stride),
        )

        self.shortcut = None
        self.act = nn.ReLU()

        if down_sample:
            self.down_sample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan, 1, stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )
        else:
            self.down_sample = nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x)+self.shortcut(x))


class Resnet(nn.Module):
    def __init__(self, num_classes, num_layers: list, chan: list) -> None:
        super().__init__()

        # 原论文是先来个7X7卷积，然后下采样，但是考虑数据集为CIFA-10，就稍微改了一下，希望影响不会太大
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, chan[0], 3, 1, 1, bias=False),
            nn.BatchNorm2d(chan[0]),
        )

        self.net = nn.ModuleList()

        for i in range(len(num_layers)):
            # 除了第一层，其他层都是先下采样的block，接着普通block
            self.net += self._make_block(num_layers[i],
                                         chan[i], 2 if i != 0 else 1)

        self.fc = nn.Linear(chan[-1], num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _make_block(self, num_layer, chan, stride):
        layer = []
        is_down = False

        strides = [stride] + [1] * (num_layer - 1)

        for stride in strides:
            if stride == 2:
                is_down = True
            layer.append(BasicLayerX2(chan, stride, is_down))
            is_down = False

        return layer

    def forward(self, x):
        out = self.first_conv(x)
        for net in self.net:
            out = net(out)

        out = F.avg_pool2d(out, out.shape[2])
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)

        return out


class Resnet34(nn.Module):
    def __init__(self, num_classes) -> None:
        super().__init__()

        self.net = Resnet(num_classes, [3, 4, 6, 3], [16, 32, 64, 128])

    def forward(self, x):
        return self.net(x)


if __name__ == '__main__':
    x = torch.randn(2, 3, 32, 32)

    net = Resnet34(10)
    out = net(x)
    print(out.shape)


in_chan=16, out_chan=16
in_chan=16, out_chan=16
in_chan=16, out_chan=16
in_chan=16, out_chan=32
in_chan=32, out_chan=32
in_chan=32, out_chan=32
in_chan=32, out_chan=32
in_chan=32, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=128
in_chan=128, out_chan=128
in_chan=128, out_chan=128
torch.Size([2, 10])


# 数据准备

In [3]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
# import albumentations as A
import numpy as np
from torchvision import transforms

# from sklearn.model_selection import train_test_split





def get_CIFAdataset_loader(root, batch_size, num_workers, pin_memory, valid_rate, shuffle: bool, random_seed=42, augment=True):
    # 引入分割CIFA数据集的包，分割数据为训练集和验证集
    from torch.utils.data.sampler import SubsetRandomSampler

    # 预处理
    train_transform = transforms.Compose([
        transforms.Compose([transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(), ]) if augment else transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    val_transform = transforms.Compose([
        transforms.Compose([transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(), ]) if augment else transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    # CIFAR-10数据集
    train_dataset = CIFAR10(root=root, train=True,
                            download=True, transform=train_transform)
    val_dataset = CIFAR10(root=root, train=True,
                          download=False, transform=val_transform)
    test_dataset = CIFAR10(root=root, train=False,
                           download=True, transform=test_transform)

    # 分割数据集
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_rate * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # 生成dataloader
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset, batch_size=1,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return train_loader, valid_loader, test_loader

# 训练
1. 优化器：Adam
2. 学习率：0.001
3. batch size： 64
4. 网络：resnet34
5. 数据集：CIFA-10
6. 训练轮数: 10
7. accu: 0.81

In [4]:
import random
import torch
import os
import torchmetrics
from argparse import ArgumentParser
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter



def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def arg_parser():
    parser = ArgumentParser()

    parser.add_argument('--batch-size', '-bs', type=int,
                        default=16, required=True, help='input batch size')
    parser.add_argument('--num-workers', '-nw',  type=int,
                        default=4, required=True, help='number of workers')
    # parser.add_argument('--resume', '-r', type=str,
    #                     required=False, help='resume a train')
    parser.add_argument('--device', type=str,
                        help='gpu or cpu', choices=['gpu', 'cpu'], default='gpu')
    parser.add_argument('--num-classes', '-nc', type=int,
                        help='number of classes', required=True)
    parser.add_argument('--lr', '-lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int,
                        required=True,  help='num of epochs')

    args = parser.parse_args()
    return args


def train_fn(net, dataloader, opt, device, criterion, writer, epoch):
    net.train()
    train_loss = 0
    criterion.to(device)
    for idx, (input, target) in dataloader:
        input = input.to(device)
        target = target.to(device)

        opt.zero_grad()
        pred = net(input)

        loss = criterion(pred, target)

        train_loss += loss.item()
        loss.backward()
        opt.step()

        cur_loss = train_loss/(idx+1)

        dataloader.set_postfix(loss=cur_loss)
        writer.add_scalar('training loss',
                          cur_loss,
                          epoch*len(dataloader)+idx)


def val_fn(net, dataloader, device, num_classes, writer, epoch: int):
    net.eval()
    metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    with torch.no_grad():
        for idx, (input, target) in dataloader:
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            acc = metric.update(pred, target)
        acc = metric.compute()
    writer.add_scalar('val_acc', acc, epoch*len(dataloader)+idx)
    return acc


def train(net, opt, epochs, batch_size, num_workers, device, num_classes, model='Resnet', scheduler=None):

    train_dataloader, val_dataloader, _ = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=batch_size, num_workers=num_workers, pin_memory=True, valid_rate=0.2, shuffle=True)

    # 模型权重位置
    model_path = 'runs'
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    save_path = os.path.join(model_path, model)

    if not os.path.exists(save_path):
        os.mkdir(save_path)

    log_dir = os.path.join(model_path,  model, 'logs')
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    writer = SummaryWriter(log_dir)

    net.to(device)

    best = 0.0

    early_stop_step = 0
    early_stop_limit = 15

    for idx in range(epochs):
        train_loop = tqdm(enumerate(train_dataloader),
                          total=len(train_dataloader), leave=True)
        train_loop.set_description(f'epoch: {idx}/{epochs}')

        train_fn(net=net, opt=opt,
                 dataloader=train_loop, device=device,
                 criterion=nn.CrossEntropyLoss(),
                 writer=writer, epoch=idx,
                 )

        val_loop = tqdm(enumerate(val_dataloader),
                        total=len(val_dataloader), leave=True)

        score = val_fn(net=net, dataloader=val_loop,
                       device=device, num_classes=num_classes, writer=writer, epoch=idx)

        print(f'acc={score}, best acc is {max(score, best)}')

        if (score > best):
            torch.save({
                'epoch': idx,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
            }, os.path.join(save_path, f'epoch={idx}-miou={score:.4f}.pth'))
            best = score
            early_stop_step = 0
        else:
            if early_stop_step > early_stop_limit:
                print(f'因为已经有{early_stop_limit}轮没有提升，训练提前终止')
                writer.close()
                break
            early_stop_step += 1

        if scheduler:
            writer.add_scalar(
                "lr", scheduler.get_last_lr()[-1]
            )
            scheduler.step()
    writer.close()


if __name__ == '__main__':
    # args = arg_parser()
    bs = 64  # args.batch_size
    num_workers = 2  # args.num_workers
    device = 'cuda'  # args.device
    num_classes = 10  # args.num_classes
    lr = 0.001  # args.num_classes
    epochs = 10  # args.epochs

    seed_everything(42)
    net = Resnet(num_classes, [3, 4, 6, 3], [16, 32, 64, 128])
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    train(net=net, epochs=epochs, batch_size=bs, num_workers=num_workers,
          device=device, num_classes=num_classes, opt=opt)


in_chan=16, out_chan=16
in_chan=16, out_chan=16
in_chan=16, out_chan=16
in_chan=16, out_chan=32
in_chan=32, out_chan=32
in_chan=32, out_chan=32
in_chan=32, out_chan=32
in_chan=32, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=64
in_chan=64, out_chan=128
in_chan=128, out_chan=128
in_chan=128, out_chan=128
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/CIFA/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/CIFA/cifar-10-python.tar.gz to ./data/CIFA
Files already downloaded and verified


epoch: 0/10: 100%|██████████| 625/625 [00:40<00:00, 15.62it/s, loss=1.52]
100%|██████████| 157/157 [00:04<00:00, 34.28it/s]

acc=0.5236999988555908, best acc is 0.5236999988555908



epoch: 1/10: 100%|██████████| 625/625 [00:33<00:00, 18.53it/s, loss=1.11]
100%|██████████| 157/157 [00:04<00:00, 37.25it/s]

acc=0.6183000206947327, best acc is 0.6183000206947327



epoch: 2/10: 100%|██████████| 625/625 [00:33<00:00, 18.93it/s, loss=0.907]
100%|██████████| 157/157 [00:04<00:00, 37.71it/s]


acc=0.7002999782562256, best acc is 0.7002999782562256


epoch: 3/10: 100%|██████████| 625/625 [00:32<00:00, 19.18it/s, loss=0.779]
100%|██████████| 157/157 [00:05<00:00, 30.80it/s]

acc=0.7269999980926514, best acc is 0.7269999980926514



epoch: 4/10: 100%|██████████| 625/625 [00:33<00:00, 18.92it/s, loss=0.691]
100%|██████████| 157/157 [00:04<00:00, 36.37it/s]

acc=0.7429999709129333, best acc is 0.7429999709129333



epoch: 5/10: 100%|██████████| 625/625 [00:34<00:00, 18.08it/s, loss=0.62]
100%|██████████| 157/157 [00:04<00:00, 36.17it/s]

acc=0.7886000275611877, best acc is 0.7886000275611877



epoch: 6/10: 100%|██████████| 625/625 [00:33<00:00, 18.41it/s, loss=0.575]
100%|██████████| 157/157 [00:04<00:00, 36.45it/s]


acc=0.7986000180244446, best acc is 0.7986000180244446


epoch: 7/10: 100%|██████████| 625/625 [00:34<00:00, 17.92it/s, loss=0.532]
100%|██████████| 157/157 [00:04<00:00, 36.06it/s]


acc=0.8004999756813049, best acc is 0.8004999756813049


epoch: 8/10: 100%|██████████| 625/625 [00:34<00:00, 18.12it/s, loss=0.501]
100%|██████████| 157/157 [00:04<00:00, 36.48it/s]

acc=0.8004999756813049, best acc is 0.8004999756813049



epoch: 9/10: 100%|██████████| 625/625 [00:33<00:00, 18.49it/s, loss=0.468]
100%|██████████| 157/157 [00:04<00:00, 36.49it/s]


acc=0.8213000297546387, best acc is 0.8213000297546387


# 模型评估

In [5]:
def test(net, dataloader, device, num_classes):
    net.eval()
    # metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    correct=0.0
    with torch.no_grad():
        for  _, (input, target) in tqdm(enumerate(dataloader), total=len(dataloader), leave=True):
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            pred = F.softmax(pred, 1).argmax(1)
            correct += pred.eq(target).sum()
            # acc = metric.update(pred, target)
    # acc = metric.compute()
    print(correct / len(dataloader))

_, _, test_dataloader = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=128, num_workers=2, pin_memory=True, valid_rate=0.2, shuffle=True)

test(net=net, device=device, num_classes=num_classes, dataloader=test_dataloader)


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 10000/10000 [01:20<00:00, 123.78it/s]

tensor(0.8191, device='cuda:0')



