<a href="https://colab.research.google.com/github/yuzhi535/resnet-pytorch/blob/master/alexnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install einops torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch.nn as nn
import torch
from einops.layers.torch import Rearrange

class Conv(nn.Module):
    def __init__(self, in_chan, out_chan, kernel_size=1, stride=1, padding=1) -> None:
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


'''
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

'''


class Alexnet(nn.Module):
    '''
    input size: batch size x 3 x 224 x 224
    output size: batch size x num_classes
    '''

    def __init__(self) -> None:
        super().__init__()

        self.conv1 = nn.Sequential(
            Conv(3, 64, kernel_size=7, stride=4, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.conv2 = nn.Sequential(
            Conv(64, 192, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.conv3 = nn.Sequential(
            Conv(192, 384, kernel_size=3, padding=1),
            Conv(384, 256, kernel_size=3, padding=1),
            Conv(256, 256, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
        )

        self.dense = nn.Sequential(
            Rearrange('b c h w -> b (c h w)'),
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        out = self.conv1(x)
        # print(out.shape)
        out = self.conv2(out)
        # print(out.shape)
        out = self.conv3(out)

        out = self.dense(out)

        return out


if __name__ == '__main__':
    net = Alexnet()
    x = torch.randn(2, 3, 224, 224)
    out = net(x)
    print(out.shape)


torch.Size([2, 10])


In [3]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
# import albumentations as A
import numpy as np
from torchvision import transforms

# from sklearn.model_selection import train_test_split





def get_CIFAdataset_loader(root, batch_size, num_workers, pin_memory, valid_rate, shuffle: bool, random_seed=42, augment=True):
    # 引入分割CIFA数据集的包，分割数据为训练集和验证集
    from torch.utils.data.sampler import SubsetRandomSampler

    # 预处理
    train_transform = transforms.Compose([
        transforms.Compose([transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(), ]) if augment else transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    val_transform = transforms.Compose([
        transforms.Compose([transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(), ]) if augment else transforms.ToTensor(),
                            transforms.Resize(224),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    ]
    )

    # CIFAR-10数据集
    train_dataset = CIFAR10(root=root, train=True,
                            download=True, transform=train_transform)
    val_dataset = CIFAR10(root=root, train=True,
                          download=False, transform=val_transform)
    test_dataset = CIFAR10(root=root, train=False,
                           download=True, transform=test_transform)

    # 分割数据集
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_rate * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # 生成dataloader
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset, batch_size=1,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return train_loader, valid_loader, test_loader

In [4]:
import random
import torch
import os
import torchmetrics
from argparse import ArgumentParser
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter



def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def arg_parser():
    parser = ArgumentParser()

    parser.add_argument('--batch-size', '-bs', type=int,
                        default=16, required=True, help='input batch size')
    parser.add_argument('--num-workers', '-nw',  type=int,
                        default=4, required=True, help='number of workers')
    # parser.add_argument('--resume', '-r', type=str,
    #                     required=False, help='resume a train')
    parser.add_argument('--device', type=str,
                        help='gpu or cpu', choices=['gpu', 'cpu'], default='gpu')
    parser.add_argument('--num-classes', '-nc', type=int,
                        help='number of classes', required=True)
    parser.add_argument('--lr', '-lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int,
                        required=True,  help='num of epochs')

    args = parser.parse_args()
    return args


def train_fn(net, dataloader, opt, device, criterion, writer, epoch):
    net.train()
    train_loss = 0
    criterion.to(device)
    for idx, (input, target) in dataloader:
        input = input.to(device)
        target = target.to(device)

        opt.zero_grad()
        pred = net(input)

        loss = criterion(pred, target)

        train_loss += loss.item()
        loss.backward()
        opt.step()

        cur_loss = train_loss/(idx+1)

        dataloader.set_postfix(loss=cur_loss)
        writer.add_scalar('training loss',
                          cur_loss,
                          epoch*len(dataloader)+idx)


def val_fn(net, dataloader, device, num_classes, writer, epoch: int):
    net.eval()
    metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    with torch.no_grad():
        for idx, (input, target) in dataloader:
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            acc = metric.update(pred, target)
        acc = metric.compute()
    writer.add_scalar('val_acc', acc, epoch*len(dataloader)+idx)
    return acc


def train(net, opt, epochs, batch_size, num_workers, device, num_classes, model='Resnet', scheduler=None):

    train_dataloader, val_dataloader, _ = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=batch_size, num_workers=num_workers, pin_memory=True, valid_rate=0.2, shuffle=True)

    # 模型权重位置
    model_path = 'runs'
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    save_path = os.path.join(model_path, model)

    if not os.path.exists(save_path):
        os.mkdir(save_path)

    log_dir = os.path.join(model_path,  model, 'logs')
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    writer = SummaryWriter(log_dir)

    net.to(device)

    best = 0.0

    early_stop_step = 0
    early_stop_limit = 15

    for idx in range(epochs):
        train_loop = tqdm(enumerate(train_dataloader),
                          total=len(train_dataloader), leave=True)
        train_loop.set_description(f'epoch: {idx}/{epochs}')

        train_fn(net=net, opt=opt,
                 dataloader=train_loop, device=device,
                 criterion=nn.CrossEntropyLoss(),
                 writer=writer, epoch=idx,
                 )

        val_loop = tqdm(enumerate(val_dataloader),
                        total=len(val_dataloader), leave=True)

        score = val_fn(net=net, dataloader=val_loop,
                       device=device, num_classes=num_classes, writer=writer, epoch=idx)

        print(f'acc={score}, best acc is {max(score, best)}')

        if (score > best):
            torch.save({
                'epoch': idx,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
            }, os.path.join(save_path, f'epoch={idx}-miou={score:.4f}.pth'))
            best = score
            early_stop_step = 0
        else:
            if early_stop_step > early_stop_limit:
                print(f'因为已经有{early_stop_limit}轮没有提升，训练提前终止')
                writer.close()
                break
            early_stop_step += 1

        if scheduler:
            writer.add_scalar(
                "lr", scheduler.get_last_lr()[-1]
            )
            scheduler.step()
    writer.close()


if __name__ == '__main__':
    # args = arg_parser()
    bs = 64  # args.batch_size
    num_workers = 2  # args.num_workers
    device = 'cuda'  # args.device
    num_classes = 10  # args.num_classes
    lr = 0.001  # args.num_classes
    epochs = 10  # args.epochs

    seed_everything(42)
    net = Alexnet()
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    train(net=net, epochs=epochs, batch_size=bs, num_workers=num_workers,
          device=device, num_classes=num_classes, opt=opt)


Files already downloaded and verified
Files already downloaded and verified


epoch: 0/10: 100%|██████████| 625/625 [01:10<00:00,  8.81it/s, loss=1.92]
100%|██████████| 157/157 [00:15<00:00, 10.10it/s]


acc=0.3675000071525574, best acc is 0.3675000071525574


epoch: 1/10: 100%|██████████| 625/625 [01:25<00:00,  7.35it/s, loss=1.61]
100%|██████████| 157/157 [00:20<00:00,  7.61it/s]


acc=0.4415000081062317, best acc is 0.4415000081062317


epoch: 2/10: 100%|██████████| 625/625 [01:28<00:00,  7.04it/s, loss=1.48]
100%|██████████| 157/157 [00:20<00:00,  7.55it/s]


acc=0.5080999732017517, best acc is 0.5080999732017517


epoch: 3/10: 100%|██████████| 625/625 [01:28<00:00,  7.05it/s, loss=1.39]
100%|██████████| 157/157 [00:20<00:00,  7.52it/s]


acc=0.5182999968528748, best acc is 0.5182999968528748


epoch: 4/10: 100%|██████████| 625/625 [01:09<00:00,  9.04it/s, loss=1.32]
100%|██████████| 157/157 [00:14<00:00, 10.52it/s]


acc=0.558899998664856, best acc is 0.558899998664856


epoch: 5/10: 100%|██████████| 625/625 [01:26<00:00,  7.19it/s, loss=1.25]
100%|██████████| 157/157 [00:20<00:00,  7.72it/s]


acc=0.567300021648407, best acc is 0.567300021648407


epoch: 6/10: 100%|██████████| 625/625 [01:28<00:00,  7.10it/s, loss=1.2]
100%|██████████| 157/157 [00:21<00:00,  7.25it/s]


acc=0.5995000004768372, best acc is 0.5995000004768372


epoch: 7/10: 100%|██████████| 625/625 [01:27<00:00,  7.13it/s, loss=1.15]
100%|██████████| 157/157 [00:20<00:00,  7.53it/s]


acc=0.6069999933242798, best acc is 0.6069999933242798


epoch: 8/10: 100%|██████████| 625/625 [01:29<00:00,  7.01it/s, loss=1.12]
100%|██████████| 157/157 [00:20<00:00,  7.62it/s]


acc=0.6177999973297119, best acc is 0.6177999973297119


epoch: 9/10: 100%|██████████| 625/625 [01:27<00:00,  7.16it/s, loss=1.09]
100%|██████████| 157/157 [00:21<00:00,  7.26it/s]


acc=0.6263999938964844, best acc is 0.6263999938964844


In [6]:
def test(net, dataloader, device, num_classes):
    from torch.nn import functional as F
    net.eval()
    # metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    correct=0.0
    with torch.no_grad():
        for  _, (input, target) in tqdm(enumerate(dataloader), total=len(dataloader), leave=True):
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            pred = F.softmax(pred, 1).argmax(1)
            correct += pred.eq(target).sum()
            # acc = metric.update(pred, target)
    # acc = metric.compute()
    print(correct / len(dataloader))

_, _, test_dataloader = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=128, num_workers=2, pin_memory=True, valid_rate=0.2, shuffle=True)

test(net=net, device=device, num_classes=num_classes, dataloader=test_dataloader)


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 10000/10000 [00:40<00:00, 250.00it/s]


tensor(0.6505, device='cuda:0')
