<a href="https://colab.research.google.com/github/yuzhi535/resnet-pytorch/blob/master/PL_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 安装依赖

In [None]:
%pip install pytorch-lightning torchmetrics einops albumentations timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.7.3-py3-none-any.whl (705 kB)
[K     |████████████████████████████████| 705 kB 30.8 MB/s 
[?25hCollecting torchmetrics
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 70.3 MB/s 
[?25hCollecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting timm
  Downloading timm-0.6.7-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 73.4 MB/s 
Collecting tensorboard>=2.9.1
  Downloading tensorboard-2.10.0-py3-none-any.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 57.8 MB/s 
[?25hCollecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Installing collected packages: torchmetrics, tensorboard, pyDeprecate, timm, pytorch-lightning, einops
  Attempting uninstall: tensorboard
    Found e

# 定义网络

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.models.layers import trunc_normal_


# 基本块
class Conv(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        stride: stride of the 3x3 convolution.
    """

    def __init__(self, in_chan, out_chan, stride) -> None:
        super().__init__()
        # 3x3 kernel with padding 1 and no bias term; BatchNorm follows
        # directly after the convolution.
        stages = [
            nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride,
                      padding=1, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        activated = self.conv(x)
        return activated


class BasicLayerX2(nn.Module):
    """Residual block: two 3x3 ``Conv`` units added to a shortcut branch.

    Args:
        in_chan: number of channels the block outputs.  When ``down_sample``
            is true the block takes half that many channels as input;
            otherwise input and output widths are equal.
        stride: stride of the first ``Conv`` and, when down-sampling, of the
            1x1 shortcut convolution.
        down_sample: selects the projection shortcut (1x1 conv + batch norm)
            instead of the identity shortcut.
    """

    def __init__(self, in_chan, stride, down_sample: bool) -> None:
        super().__init__()
        width_out = in_chan
        width_in = width_out // 2 if down_sample else width_out

        # Main branch: only the first unit may be strided.
        self.conv = nn.Sequential(
            Conv(width_in, width_out, stride),
            Conv(width_out, width_out, 1),
        )

        self.act = nn.ReLU()

        # Shortcut branch: projection when shape changes, identity otherwise.
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(width_in, width_out, 1, stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
            if down_sample
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``ReLU(main_branch(x) + shortcut(x))``."""
        return self.act(self.conv(x) + self.shortcut(x))


class BasicLayerX3(nn.Module):
    """Residual block built from two 3x3 ``Conv`` units and a shortcut branch.

    Args:
        in_chan: number of channels the block outputs.  When ``down_sample``
            is true the block takes ``in_chan // 2`` input channels;
            otherwise input and output widths are equal.
        stride: stride of the first ``Conv`` and, when down-sampling, of the
            1x1 shortcut convolution.
        down_sample: use a 1x1 conv + batch norm shortcut instead of the
            identity.
    """

    def __init__(self, in_chan, stride, down_sample: bool) -> None:
        super().__init__()
        out_chan = in_chan
        in_chan = in_chan // 2 if down_sample else in_chan

        # Only the first unit carries the stride.  Striding both units would
        # shrink the main branch twice while the shortcut shrinks once, and
        # the residual addition in ``forward`` would fail on a shape mismatch.
        self.conv = nn.Sequential(
            Conv(in_chan, out_chan, stride),
            Conv(out_chan, out_chan, 1),
        )

        self.act = nn.ReLU()

        if down_sample:
            # Projection shortcut: matches both channel count and resolution.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_chan, out_chan, 1, stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``ReLU(main_branch(x) + shortcut(x))``."""
        return self.act(self.conv(x) + self.shortcut(x))


class Resnet(nn.Module):
    """Configurable ResNet made of ``BasicLayerX2`` residual blocks.

    Args:
        num_classes: size of the final linear layer's output.
        num_layers: number of residual blocks in each stage.
        chan: output channel count of each stage; must have the same length
            as ``num_layers``.  Every stage after the first opens with a
            down-sampling ``BasicLayerX2``, which takes ``chan[i] // 2`` input
            channels, so each width is expected to be twice the previous one.

    Raises:
        ValueError: if ``num_layers`` and ``chan`` are empty or differ in
            length.
    """

    def __init__(self, num_classes, num_layers: list, chan: list) -> None:
        super().__init__()

        if len(num_layers) == 0 or len(num_layers) != len(chan):
            raise ValueError(
                "num_layers and chan must be non-empty and of equal length, "
                f"got {len(num_layers)} and {len(chan)}")

        # The original paper starts with a 7x7 convolution followed by
        # down-sampling; because the dataset here is CIFAR-10 the stem was
        # changed to a single stride-1 3x3 convolution.
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, chan[0], 3, 1, 1, bias=False),
            nn.BatchNorm2d(chan[0]),
        )

        self.net = nn.ModuleList()

        for i, (depth, width) in enumerate(zip(num_layers, chan)):
            # Except for the first stage, every stage starts with a
            # down-sampling block followed by ordinary blocks.
            self.net.extend(self._make_block(depth, width, 2 if i != 0 else 1))

        self.fc = nn.Linear(chan[-1], num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` modules in place.

        Called once per sub-module through ``self.apply``; every other module
        type keeps its default initialisation.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _make_block(self, num_layer, chan, stride):
        """Return the list of ``BasicLayerX2`` blocks forming one stage.

        The first block uses ``stride``; the remaining ``num_layer - 1``
        blocks use stride 1.  A block down-samples exactly when its stride
        is 2.
        """
        strides = [stride] + [1] * (num_layer - 1)
        return [BasicLayerX2(chan, s, s == 2) for s in strides]

    def forward(self, x):
        """Compute class logits for a batch of images ``x``."""
        out = self.first_conv(x)
        for block in self.net:
            out = block(out)

        # Global average pooling to 1x1.  The adaptive form gives the same
        # result as a full-size kernel on square maps and also handles
        # non-square feature maps.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)

        return out


class Restnet34(nn.Module):
    """``Resnet`` preset with stage depths (3, 4, 6, 3) and stage widths
    (16, 32, 64, 128).

    Args:
        num_classes: number of output classes, forwarded to ``Resnet``.
    """

    def __init__(self, num_classes) -> None:
        super().__init__()

        blocks_per_stage = [3, 4, 6, 3]
        channels_per_stage = [16, 32, 64, 128]
        self.net = Resnet(num_classes, blocks_per_stage, channels_per_stage)

    def forward(self, x):
        """Delegate to the wrapped ``Resnet``."""
        logits = self.net(x)
        return logits


if __name__ == '__main__':
    # Smoke test: push a random batch of two 3x32x32 images through the
    # 10-class network and show the shape of the result.
    dummy_batch = torch.randn(2, 3, 32, 32)

    smoke_model = Restnet34(10)
    logits = smoke_model(dummy_batch)
    print(logits.shape)


torch.Size([2, 10])


# 数据准备

In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
# import albumentations as A
import numpy as np
from torchvision import transforms

# from sklearn.model_selection import train_test_split


def get_CIFAdataset_loader(root, batch_size, num_workers, pin_memory, valid_rate, shuffle: bool, random_seed=42, augment=True, test_batch_size=1):
    """Build the CIFAR-10 train, validation and test ``DataLoader`` objects.

    The official training split is divided into a training part and a
    validation part; the official test split is returned unchanged.

    Args:
        root: directory where the dataset is stored / downloaded to.
        batch_size: batch size of the training and validation loaders.
        num_workers: worker processes used by every loader.
        pin_memory: forwarded to every loader.
        valid_rate: fraction (0..1) of the training split held out for
            validation.
        shuffle: shuffle the indices before splitting.
        random_seed: seed used for that shuffle.  Note that this reseeds
            NumPy's global random generator.
        augment: apply random crop + horizontal flip to the *training*
            images.  Validation and test images are never augmented, so that
            evaluation metrics are deterministic.
        test_batch_size: batch size of the test loader (default 1).

    Returns:
        ``(train_loader, valid_loader, test_loader)``.

    Raises:
        ValueError: if ``valid_rate`` is outside ``[0, 1]``.
    """
    # Sampler used to split the CIFAR training data into train / validation.
    from torch.utils.data.sampler import SubsetRandomSampler

    if not 0.0 <= valid_rate <= 1.0:
        raise ValueError(f"valid_rate must be within [0, 1], got {valid_rate}")

    # Pre-processing.  The constants are presumably the per-channel CIFAR-10
    # mean / std; the same normalisation is shared by all three splits.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )
    eval_transform = transforms.Compose([transforms.ToTensor(), normalize])

    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = eval_transform

    # CIFAR-10 datasets.  Train and validation read the same underlying split
    # and differ only in transform and in the indices sampled from them.
    train_dataset = CIFAR10(root=root, train=True,
                            download=True, transform=train_transform)
    val_dataset = CIFAR10(root=root, train=True,
                          download=False, transform=eval_transform)
    test_dataset = CIFAR10(root=root, train=False,
                           download=True, transform=eval_transform)

    # Split the training indices: the first `split` go to validation.
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_rate * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Build the loaders.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset, batch_size=test_batch_size,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return train_loader, valid_loader, test_loader



# 开始训练

In [None]:
import argparse
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics
# from utils.dataloader import get_CIFAdataset_loader
# from networks.resnet import Restnet34
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint


def arg_parser(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``batch_size``, ``num_workers``,
        ``device``, ``num_classes``, ``lr`` and ``epochs``.

    Options that have a default are optional; ``--num-classes`` and
    ``--epochs`` have none and stay required.
    """
    parser = argparse.ArgumentParser()

    # These two carry defaults, so they must not be marked required:
    # a required option never falls back to its default.
    parser.add_argument('--batch-size', '-bs', type=int,
                        default=16, help='input batch size')
    parser.add_argument('--num-workers', '-nw', type=int,
                        default=4, help='number of workers')
    parser.add_argument('--device', type=str,
                        help='gpu or cpu', choices=['gpu', 'cpu'], default='gpu')
    parser.add_argument('--num-classes', '-nc', type=int,
                        help='number of classes', required=True)
    parser.add_argument('--lr', '-lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int,
                        required=True, help='num of epochs')

    return parser.parse_args(argv)


# 配置训练步骤
# Training / evaluation steps.
class MyResnet(pl.LightningModule):
    """LightningModule that trains a classifier with cross-entropy and Adam.

    Args:
        net: the network to train; it is called as ``net(x)`` and its output
            is passed to ``nn.CrossEntropyLoss`` together with the targets.
        lr: learning rate of the Adam optimiser.
        num_classes: forwarded to ``torchmetrics.Accuracy``.  Defaults to
            ``None``, the metric's own default in torchmetrics 0.9, so the
            class no longer depends on a module-level ``num_classes`` global.
    """

    def __init__(self, net,
                 lr=1e-4, num_classes=None) -> None:
        super().__init__()
        self.net = net
        self.lr = lr

        # Loss function.
        self.criterion = nn.CrossEntropyLoss()
        # One accuracy metric shared by validation and test; it is reset at
        # the end of each of those epochs.
        self.acc = torchmetrics.Accuracy(num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch
        pred = self(x)
        loss = self.criterion(pred, y)
        # Logged without `enable_graph`, so the stored value is detached and
        # does not keep the autograd graph alive.
        self.log('loss', loss, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Update the accuracy metric and log the validation loss."""
        x, y = batch

        pred = self(x)
        self.acc.update(pred, y)
        loss = self.criterion(pred, y)
        self.log('eval_loss', loss)

        return loss

    def test_step(self, batch, batch_idx):
        """Update the accuracy metric and log the test loss."""
        x, y = batch

        pred = self(x)
        self.acc.update(pred, y)
        loss = self.criterion(pred, y)
        # NOTE(review): the test loss is logged under the validation key
        # 'eval_loss'; kept for compatibility with existing logs.
        self.log('eval_loss', loss)

        return loss

    def test_epoch_end(self, outputs) -> None:
        """Print the accumulated test accuracy, then clear the metric."""
        print(self.acc.compute())
        # Reset so a later validation / test run does not start from the
        # counts accumulated here.
        self.acc.reset()

    def validation_epoch_end(self, outputs):
        """Log the accumulated validation accuracy as 'acc' and reset it."""
        acc = self.acc.compute()
        self.log('acc', acc)
        self.acc.reset()

    def configure_optimizers(self):
        """Return an Adam optimiser over the network's parameters."""
        opt = torch.optim.Adam(self.net.parameters(),
                               lr=self.lr)
        return opt


def train(epochs, bs, nw, lr, num_classes, device):
    """Train ``Restnet34`` on CIFAR-10, then evaluate it on the test split.

    Args:
        epochs: maximum number of training epochs.
        bs: batch size of the training and validation loaders.
        nw: number of data-loader worker processes.
        lr: learning rate passed to ``MyResnet``.
        num_classes: number of classes of the network.
        device: Lightning accelerator name, e.g. ``'gpu'`` or ``'cpu'``.
    """

    pl.seed_everything(42)

    net = Restnet34(num_classes)

    # Set up data loading; CIFAR-10 ships with torchvision, so this is simple.
    # Pinned memory is only requested when training on a GPU.
    train_dataloader, val_dataloader, test_dataloader = \
        get_CIFAdataset_loader(
            root='./data/CIFA', batch_size=bs, num_workers=nw,
            pin_memory=(device == 'gpu'), valid_rate=0.2, shuffle=True)

    model = MyResnet(net=net,
                     lr=lr)

    early_stop_callback = EarlyStopping(
        monitor="acc", min_delta=0.0, patience=30, verbose=True, mode="max")

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # The filename template must use a metric the module actually logs; the
    # validation accuracy is logged as 'acc', which is also the monitored key.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=5, monitor='acc', mode='max', filename="{epoch:03d}-{acc:.4f}", )

    trainer = pl.Trainer(max_epochs=epochs,
                         log_every_n_steps=10,
                         benchmark=True,
                         check_val_every_n_epoch=1,
                         gradient_clip_val=0.5,
                         devices=1,
                         accelerator=device,
                         callbacks=[
                             checkpoint_callback,
                             lr_monitor,
                             early_stop_callback,
                         ],
    )

    trainer.logger._default_hp_metric = None

    # No checkpoint to resume from: training starts from scratch.
    resume_path = None

    trainer.fit(model=model, ckpt_path=resume_path, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    trainer.test(model=model, dataloaders=test_dataloader)

if __name__ == '__main__':
    # Hard-coded run configuration.  The trailing comments name the
    # `arg_parser` attribute each value would come from when run as a script.
    # These are module-level names, so other notebook cells may read them.
    # args = arg_parser()
    bs = 128  # args.batch_size
    num_workers = 2  # args.num_workers
    device = 'gpu'  # args.device
    num_classes = 10  # args.num_classes
    lr = 0.001  # args.lr
    epochs = 10  # args.epochs
    train(epochs, bs, num_workers, lr, num_classes, device=device)


INFO:pytorch_lightning.utilities.seed:Global seed set to 42


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params
-----------------------------------------------
0 | net       | Restnet34        | 1.3 M 
1 | criterion | CrossEntropyLoss | 0     
2 | acc       | Accuracy         | 0     
-----------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.338     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric acc improved. New best score: 0.495


Validation: 0it [00:00, ?it/s]

Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 232, in _feed
    close()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor

During handling of the above exception, another exception occurred:

Traceback 

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric acc improved by 0.027 >= min_delta = 0.0. New best score: 0.645


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric acc improved by 0.054 >= min_delta = 0.0. New best score: 0.700


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric acc improved by 0.022 >= min_delta = 0.0. New best score: 0.722


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric acc improved by 0.056 >= min_delta = 0.0. New best score: 0.778


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric acc improved by 0.016 >= min_delta = 0.0. New best score: 0.794


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

tensor(0.7858, device='cuda:0')
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        eval_loss           0.6506151556968689
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
%load_ext tensorboard
%tensorboard --logdir=lightning_logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


<IPython.core.display.Javascript object>