<a href="https://colab.research.google.com/github/yuzhi535/resnet-pytorch/blob/master/inceptionV3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install timm torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# 定义网络

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv(nn.Module):
    '''
    conv+bn+relu
    '''

    def __init__(self, in_chan, out_chan, kernel_size=1, stride=1, padding=0) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel_size, stride, padding),
            nn.BatchNorm2d(out_chan, eps=0.001),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class InceptionBlockx1(nn.Module):
    def __init__(self, in_chan, out_chan_pooling) -> None:
        super().__init__()
        self.branch1 = Conv(in_chan, 64)
        self.branch2 = nn.Sequential(
            Conv(in_chan, 48, kernel_size=1),
            Conv(48, 64, kernel_size=5, padding=2),
        )

        self.branch3 = nn.Sequential(
            Conv(in_chan, 64, kernel_size=1),
            Conv(64, 96, kernel_size=3, padding=1),
            Conv(96, 96, kernel_size=3, padding=1),
        )

        self.branch4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            Conv(in_chan, out_chan_pooling, kernel_size=1),
        )

    def forward(self, x):
        out = torch.cat([self.branch1(x), self.branch2(
            x), self.branch3(x), self.branch4(x)], dim=1)
        # print(out.shape)
        return out


class InceptionBlockx2(nn.Module):
    def __init__(self, in_chan) -> None:
        super().__init__()
        self.branch1 = Conv(in_chan, 384, 3, 2)

        self.branch2 = nn.Sequential(
            Conv(in_chan, 64, kernel_size=1),
            Conv(64, 96, kernel_size=3, padding=1),
            Conv(96, 96, kernel_size=3, stride=2),
        )

        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def forward(self, x):
        # print(x.shape)
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        # print(out1.shape, out2.shape, out3.shape)
        out = torch.cat([out1, out2, out3], dim=1)
        return out


class InceptionBlockx3(nn.Module):
    def __init__(self, in_chan, internal_chan) -> None:
        super().__init__()

        self.branch1 = Conv(in_chan, 192, 1)

        self.branch2 = nn.Sequential(
            Conv(in_chan, internal_chan, 1),
            Conv(internal_chan, internal_chan, [1, 7], padding=[0, 3]),
            Conv(internal_chan, 192, [7, 1], padding=[3, 0]),
        )

        self.branch3 = nn.Sequential(
            Conv(in_chan, internal_chan, 1),
            Conv(internal_chan, internal_chan, [7, 1], padding=[3, 0]),
            Conv(internal_chan, internal_chan, [1, 7], padding=[0, 3]),
            Conv(internal_chan, internal_chan, [7, 1], padding=[3, 0]),
            Conv(internal_chan, internal_chan, [1, 7], padding=[0, 3]),
            Conv(internal_chan, 192, [1, 7], padding=[0, 3]),
        )

        self.branch4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            Conv(in_chan, 192, kernel_size=1),
        )

    def forward(self, x):
        return torch.cat([self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)], dim=1)


class InceptionBlockx4(nn.Module):
    def __init__(self, in_chan, ) -> None:
        super().__init__()

        self.branch1 = nn.Sequential(
            Conv(in_chan, 192, kernel_size=1),
            Conv(192, 320, kernel_size=3, stride=2),
        )

        self.branch2 = nn.Sequential(
            Conv(in_chan, 192, kernel_size=1),
            Conv(192, 192, kernel_size=[1, 7], padding=[0, 3]),
            Conv(192, 192, kernel_size=[7, 1], padding=[3, 0]),
            Conv(192, 192, kernel_size=3, stride=2),
        )

        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def forward(self, x):
        return torch.cat([self.branch1(x), self.branch2(x), self.branch3(x)], dim=1)


class InceptionBlockx5(nn.Module):
    def __init__(self, in_chan) -> None:
        super().__init__()
        self.branch1 = nn.Sequential(
            Conv(in_chan, 320, kernel_size=1),
        )

        self.branch2x1 = Conv(in_chan, 384, kernel_size=1)
        self.branch2x2 = Conv(384, 384, kernel_size=[1, 3], padding=[0, 1])
        self.branch2x3 = Conv(384, 384, kernel_size=[3, 1], padding=[1, 0])

        self.branch3x1 = Conv(in_chan, 448, kernel_size=1)
        self.branch3x2 = Conv(448, 384, kernel_size=3, stride=1, padding=1)
        self.branch3x3 = Conv(384, 384, kernel_size=[1, 3], padding=[0, 1])
        self.branch3x4 = Conv(384, 384, kernel_size=[3, 1], padding=[1, 0])

        self.branch4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            Conv(in_chan, 192, kernel_size=1),
        )

    def forward(self, x):
        out1 = self.branch1(x)
        out2 = self.branch2x1(x)
        out2 = torch.cat(
            [self.branch2x2(out2), self.branch2x3(out2)], dim=1)
        out3 = self.branch3x1(x)
        out3 = self.branch3x2(out3)
        out3 = torch.cat([self.branch3x3(out3), self.branch3x4(out3)], dim=1)
        out4 = self.branch4(x)
        return torch.cat([out1, out2, out3, out4], dim=1)


class GoogleNet(nn.Module):
    def __init__(self, nc: int) -> None:
        super().__init__()
        self.nc = nc

        self.conv1 = nn.Sequential(
            Conv(3, 32, kernel_size=3, stride=2),
            Conv(32, 32, kernel_size=3),
            Conv(32, 64, kernel_size=3, padding=1),
        )
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.conv2 = nn.Sequential(
            Conv(64, 80, kernel_size=1),
            Conv(80, 192, kernel_size=3, stride=1),
        )
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.mixer1 = nn.Sequential(
            InceptionBlockx1(192, 32),
            InceptionBlockx1(256, 64),
            InceptionBlockx1(288, 64),
        )

        self.mixer2 = nn.Sequential(
            InceptionBlockx2(288),
        )
        self.mixer3 = nn.Sequential(
            InceptionBlockx3(768, 128),
            InceptionBlockx3(768, 160),
            InceptionBlockx3(768, 160),
            InceptionBlockx3(768, 192),
        )

        self.mixer4 = nn.Sequential(
            InceptionBlockx4(768),
            InceptionBlockx5(1280),
            InceptionBlockx5(2048),
        )

        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(2048, self.nc)

        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            stddev = float(m.stddev) if hasattr(
                m, "stddev") else 0.1  # type: ignore
            torch.nn.init.trunc_normal_(
                m.weight, mean=0.0, std=stddev, a=-2, b=2)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.maxpool1(out)
        out = self.conv2(out)
        out = self.maxpool2(out)
        out = self.mixer1(out)
        # print(f'out1: {out.shape}')
        out = self.mixer2(out)
        out = self.mixer3(out)
        out = self.mixer4(out)
        out = F.adaptive_avg_pool2d(out, 1)
        out = self.dropout(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    x = torch.randn(2, 3, 224, 224)
    net = GoogleNet(nc=1000)
    print(net(x).shape)


torch.Size([2, 1000])


# 数据准备

In [3]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
# import albumentations as A
import numpy as np
from torchvision import transforms

# from sklearn.model_selection import train_test_split





def get_CIFAdataset_loader(root, batch_size, num_workers, pin_memory, valid_rate, shuffle: bool, random_seed=42, augment=True):
    # 引入分割CIFA数据集的包，分割数据为训练集和验证集
    from torch.utils.data.sampler import SubsetRandomSampler

    # 预处理
    train_transform = transforms.Compose([
        transforms.Compose([transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(), ]) if augment else transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    val_transform = transforms.Compose([
        transforms.Compose([transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(), ]) if augment else transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    # CIFAR-10数据集
    train_dataset = CIFAR10(root=root, train=True,
                            download=True, transform=train_transform)
    val_dataset = CIFAR10(root=root, train=True,
                          download=False, transform=val_transform)
    test_dataset = CIFAR10(root=root, train=False,
                           download=True, transform=test_transform)

    # 分割数据集
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_rate * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # 生成dataloader
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset, batch_size=1,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return train_loader, valid_loader, test_loader


训练

    优化器：Adam
    学习率：0.001
    batch size： 256
    网络：resnet34
    数据集：CIFA-10
    训练轮数: 10
    accu: 0.81

In [5]:
import random
import torch
import os
import torchmetrics
from argparse import ArgumentParser
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter



def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def arg_parser():
    parser = ArgumentParser()

    parser.add_argument('--batch-size', '-bs', type=int,
                        default=16, required=True, help='input batch size')
    parser.add_argument('--num-workers', '-nw',  type=int,
                        default=4, required=True, help='number of workers')
    # parser.add_argument('--resume', '-r', type=str,
    #                     required=False, help='resume a train')
    parser.add_argument('--device', type=str,
                        help='gpu or cpu', choices=['gpu', 'cpu'], default='gpu')
    parser.add_argument('--num-classes', '-nc', type=int,
                        help='number of classes', required=True)
    parser.add_argument('--lr', '-lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int,
                        required=True,  help='num of epochs')

    args = parser.parse_args()
    return args


def train_fn(net, dataloader, opt, device, criterion, writer, epoch):
    net.train()
    train_loss = 0
    criterion.to(device)
    for idx, (input, target) in dataloader:
        input = input.to(device)
        target = target.to(device)

        opt.zero_grad()
        pred = net(input)

        loss = criterion(pred, target)

        train_loss += loss.item()
        loss.backward()
        opt.step()

        cur_loss = train_loss/(idx+1)

        dataloader.set_postfix(loss=cur_loss)
        writer.add_scalar('training loss',
                          cur_loss,
                          epoch*len(dataloader)+idx)


def val_fn(net, dataloader, device, num_classes, writer, epoch: int):
    net.eval()
    metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    with torch.no_grad():
        for idx, (input, target) in dataloader:
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            acc = metric.update(pred, target)
        acc = metric.compute()
    writer.add_scalar('val_acc', acc, epoch*len(dataloader)+idx)
    return acc


def train(net, opt, epochs, batch_size, num_workers, device, num_classes, model='Resnet', scheduler=None):

    train_dataloader, val_dataloader, _ = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=batch_size, num_workers=num_workers, pin_memory=True, valid_rate=0.2, shuffle=True)

    # 模型权重位置
    model_path = 'runs'
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    save_path = os.path.join(model_path, model)

    if not os.path.exists(save_path):
        os.mkdir(save_path)

    log_dir = os.path.join(model_path,  model, 'logs')
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    writer = SummaryWriter(log_dir)

    net.to(device)

    best = 0.0

    early_stop_step = 0
    early_stop_limit = 15

    for idx in range(epochs):
        train_loop = tqdm(enumerate(train_dataloader),
                          total=len(train_dataloader), leave=True)
        train_loop.set_description(f'epoch: {idx}/{epochs}')

        train_fn(net=net, opt=opt,
                 dataloader=train_loop, device=device,
                 criterion=nn.CrossEntropyLoss(),
                 writer=writer, epoch=idx,
                 )

        val_loop = tqdm(enumerate(val_dataloader),
                        total=len(val_dataloader), leave=True)

        score = val_fn(net=net, dataloader=val_loop,
                       device=device, num_classes=num_classes, writer=writer, epoch=idx)

        print(f'acc={score}, best acc is {max(score, best)}')

        if (score > best):
            torch.save({
                'epoch': idx,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
            }, os.path.join(save_path, f'epoch={idx}-miou={score:.4f}.pth'))
            best = score
            early_stop_step = 0
        else:
            if early_stop_step > early_stop_limit:
                print(f'因为已经有{early_stop_limit}轮没有提升，训练提前终止')
                writer.close()
                break
            early_stop_step += 1

        if scheduler:
            writer.add_scalar(
                "lr", scheduler.get_last_lr()[-1]
            )
            scheduler.step()
    writer.close()


if __name__ == '__main__':
    # args = arg_parser()
    bs = 256  # args.batch_size
    num_workers = 2  # args.num_workers
    device = 'cuda'  # args.device
    num_classes = 10  # args.num_classes
    lr = 0.001  # args.num_classes
    epochs = 10  # args.epochs

    seed_everything(42)
    net = GoogleNet(num_classes)
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    train(net=net, epochs=epochs, batch_size=bs, num_workers=num_workers,
          device=device, num_classes=num_classes, opt=opt)

Files already downloaded and verified
Files already downloaded and verified


epoch: 0/10: 100%|██████████| 157/157 [05:28<00:00,  2.09s/it, loss=2.02]
100%|██████████| 40/40 [00:23<00:00,  1.68it/s]


acc=0.41589999198913574, best acc is 0.41589999198913574


epoch: 1/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=1.53]
100%|██████████| 40/40 [00:23<00:00,  1.73it/s]


acc=0.5206999778747559, best acc is 0.5206999778747559


epoch: 2/10: 100%|██████████| 157/157 [05:19<00:00,  2.04s/it, loss=1.31]
100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


acc=0.5389999747276306, best acc is 0.5389999747276306


epoch: 3/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=1.17]
100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


acc=0.6118999719619751, best acc is 0.6118999719619751


epoch: 4/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=0.991]
100%|██████████| 40/40 [00:23<00:00,  1.71it/s]


acc=0.6499000191688538, best acc is 0.6499000191688538


epoch: 5/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=0.859]
100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


acc=0.7031999826431274, best acc is 0.7031999826431274


epoch: 6/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=0.789]
100%|██████████| 40/40 [00:23<00:00,  1.72it/s]


acc=0.7384999990463257, best acc is 0.7384999990463257


epoch: 7/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=0.665]
100%|██████████| 40/40 [00:23<00:00,  1.73it/s]


acc=0.7491999864578247, best acc is 0.7491999864578247


epoch: 8/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=0.614]
100%|██████████| 40/40 [00:23<00:00,  1.70it/s]

acc=0.7390000224113464, best acc is 0.7491999864578247



epoch: 9/10: 100%|██████████| 157/157 [05:20<00:00,  2.04s/it, loss=0.559]
100%|██████████| 40/40 [00:23<00:00,  1.71it/s]


acc=0.7843999862670898, best acc is 0.7843999862670898


# 模型评估

In [6]:
def test(net, dataloader, device, num_classes):
    net.eval()
    # metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    correct=0.0
    with torch.no_grad():
        for  _, (input, target) in tqdm(enumerate(dataloader), total=len(dataloader), leave=True):
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            pred = F.softmax(pred, 1).argmax(1)
            correct += pred.eq(target).sum()
            # acc = metric.update(pred, target)
    # acc = metric.compute()
    print(correct / len(dataloader))

_, _, test_dataloader = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=128, num_workers=2, pin_memory=True, valid_rate=0.2, shuffle=True)

test(net=net, device=device, num_classes=num_classes, dataloader=test_dataloader)


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 10000/10000 [02:58<00:00, 55.89it/s]

tensor(0.7928, device='cuda:0')



