<a href="https://colab.research.google.com/github/yuzhi535/resnet-pytorch/blob/master/vgg16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Fetch the project sources; the notebook later imports from its utils/ and networks/ packages.
!git clone https://github.com/yuzhi535/resnet-pytorch.git

Cloning into 'resnet-pytorch'...
remote: Enumerating objects: 46, done.[K
remote: Counting objects: 100% (46/46), done.[K
remote: Compressing objects: 100% (31/31), done.[K
remote: Total 46 (delta 18), reused 35 (delta 10), pack-reused 0[K
Unpacking objects: 100% (46/46), done.


In [2]:
%cd resnet-pytorch
%pip install timm einops torchmetrics albumentations

/content/resnet-pytorch
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.7-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 21.4 MB/s 
[?25hCollecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting torchmetrics
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 52.2 MB/s 
Installing collected packages: torchmetrics, timm, einops
Successfully installed einops-0.4.1 timm-0.6.7 torchmetrics-0.9.3


In [3]:
import random
import torch
import os
import torchmetrics
from argparse import ArgumentParser
import torch.nn as nn
from tqdm import tqdm
from utils.dataloader import get_CIFAdataset_loader
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from networks.resnet import Resnet, Restnet34
from networks.vgg import VGG16

In [6]:
def seed_everything(seed: int):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and torch's CPU and CUDA
    generators. ``PYTHONHASHSEED`` is exported for processes started later;
    it does not change hashing in the already-running interpreter.

    Args:
        seed: the seed value applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every visible GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick kernels by timing them, which can vary between
    # runs and defeats deterministic=True above.
    torch.backends.cudnn.benchmark = False


def arg_parser(argv=None):
    """Build the command-line parser for script-style training and parse arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``. Pass an explicit list when calling from a notebook,
            where ``sys.argv`` holds the kernel's own arguments.

    Returns:
        argparse.Namespace with ``batch_size``, ``num_workers``, ``device``,
        ``num_classes``, ``lr`` and ``epochs``.
    """
    parser = ArgumentParser()

    # --batch-size / --num-workers are optional so that their defaults actually apply.
    parser.add_argument('--batch-size', '-bs', type=int,
                        default=16, help='input batch size')
    parser.add_argument('--num-workers', '-nw', type=int,
                        default=4, help='number of workers')
    parser.add_argument('--device', type=str,
                        help='gpu or cpu', choices=['gpu', 'cpu'], default='gpu')
    parser.add_argument('--num-classes', '-nc', type=int,
                        help='number of classes', required=True)
    parser.add_argument('--lr', '-lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--epochs', type=int,
                        required=True, help='num of epochs')

    return parser.parse_args(argv)


def train_fn(net, dataloader, opt, device, criterion, writer, epoch):
    """Run one training epoch and report the running mean loss after every batch.

    ``dataloader`` must yield ``(batch_index, (inputs, targets))`` pairs and
    support ``len()`` and ``set_postfix()``, e.g. ``tqdm(enumerate(loader), total=...)``.
    The running mean is shown on the progress bar and written to ``writer``
    as 'training loss' at step ``epoch * len(dataloader) + batch_index``.
    """
    net.train()
    criterion.to(device)
    running_total = 0

    for step, (inputs, targets) in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        opt.zero_grad()
        batch_loss = criterion(net(inputs), targets)
        running_total += batch_loss.item()
        batch_loss.backward()
        opt.step()

        mean_loss = running_total / (step + 1)
        dataloader.set_postfix(loss=mean_loss)
        global_step = epoch * len(dataloader) + step
        writer.add_scalar('training loss', mean_loss, global_step)


def val_fn(net, dataloader, device, num_classes, writer, epoch: int):
    """Compute top-1 accuracy of ``net`` over ``dataloader`` and log it.

    Args:
        net: classifier whose output holds class scores along dim 1
            (assumed (batch, num_classes) logits — TODO confirm for new models).
        dataloader: iterable of ``(batch_index, (input, target))`` pairs,
            e.g. ``tqdm(enumerate(loader), total=...)``.
        device: device the batches are moved to.
        num_classes: number of classes. It is kept for API compatibility, since
            accuracy is taken from the argmax over dim 1.
        writer: object with ``add_scalar(tag, value, step)``, e.g. a SummaryWriter.
        epoch: current epoch, used as the logging step for 'val_acc'.

    Returns:
        float: fraction of correctly classified samples, or 0.0 for an empty loader.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for _, (input, target) in dataloader:
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            correct += (pred.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    # Guard against an empty loader instead of dividing by zero.
    acc = correct / total if total else 0.0
    # One point per epoch. Validation-batch indices would not line up with the
    # training-loss step axis.
    writer.add_scalar('val_acc', acc, epoch)
    return acc


def train(net, opt, epochs, batch_size, num_workers, device, num_classes, model='Resnet', scheduler=None):
    """Train ``net`` with per-epoch validation, checkpointing and early stopping.

    Args:
        net: model to train; it is moved to ``device``.
        opt: optimizer already bound to ``net``'s parameters.
        epochs: maximum number of epochs.
        batch_size: forwarded to ``get_CIFAdataset_loader``.
        num_workers: forwarded to ``get_CIFAdataset_loader``.
        device: torch device string, e.g. ``'cuda:0'`` or ``'cpu'``.
        num_classes: forwarded to ``val_fn``.
        model: sub-directory of ``runs/`` for checkpoints and TensorBoard logs.
        scheduler: optional LR scheduler, stepped once per epoch.

    A checkpoint (epoch, model and optimizer state) is saved to ``runs/<model>/``
    each time validation accuracy beats the best so far. Training stops early
    after ``early_stop_limit`` consecutive epochs without improvement.
    """
    train_dataloader, val_dataloader, _ = get_CIFAdataset_loader(
        root='./data/CIFA', batch_size=batch_size, num_workers=num_workers, pin_memory=True, valid_rate=0.2, shuffle=True)

    # Checkpoints go to runs/<model>/, TensorBoard logs to runs/<model>/logs/.
    model_path = 'runs'
    save_path = os.path.join(model_path, model)
    log_dir = os.path.join(save_path, 'logs')
    os.makedirs(log_dir, exist_ok=True)  # also creates runs/ and runs/<model>/

    writer = SummaryWriter(log_dir)
    criterion = nn.CrossEntropyLoss()  # built once and reused for every epoch

    net.to(device)

    best = 0.0

    early_stop_step = 0  # consecutive epochs without a new best accuracy
    early_stop_limit = 15

    try:
        for idx in range(epochs):
            train_loop = tqdm(enumerate(train_dataloader),
                              total=len(train_dataloader), leave=True)
            train_loop.set_description(f'epoch: {idx}/{epochs}')

            train_fn(net=net, opt=opt,
                     dataloader=train_loop, device=device,
                     criterion=criterion,
                     writer=writer, epoch=idx,
                     )

            val_loop = tqdm(enumerate(val_dataloader),
                            total=len(val_dataloader), leave=True)

            score = val_fn(net=net, dataloader=val_loop,
                           device=device, num_classes=num_classes, writer=writer, epoch=idx)

            print(f'acc={score}, best acc is {max(score, best)}')

            if score > best:
                # NOTE: the 'miou' tag in the filename holds validation accuracy;
                # the name is kept unchanged for compatibility with saved checkpoints.
                torch.save({
                    'epoch': idx,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                }, os.path.join(save_path, f'epoch={idx}-miou={score:.4f}.pth'))
                best = score
                early_stop_step = 0
            else:
                # Count this epoch first, so training stops after exactly
                # early_stop_limit non-improving epochs.
                early_stop_step += 1
                if early_stop_step >= early_stop_limit:
                    # "No improvement for {early_stop_limit} epochs; stopping early."
                    print(f'因为已经有{early_stop_limit}轮没有提升，训练提前终止')
                    break

            if scheduler:
                # Log the LR used during this epoch (at step idx) before advancing the schedule.
                writer.add_scalar(
                    "lr", scheduler.get_last_lr()[-1], idx
                )
                scheduler.step()
    finally:
        # Close exactly once, including on early stop or an exception.
        writer.close()

In [7]:
if __name__ == '__main__':
    # arg_parser() is bypassed in the notebook; each value notes its CLI flag.
    bs = 32  # --batch-size
    num_workers = 2  # --num-workers
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # --device
    num_classes = 10  # --num-classes
    lr = 0.001  # --lr
    epochs = 10  # --epochs

    # Seed before building the model so weight initialisation is reproducible too.
    seed_everything(42)
    # Alternative backbone: net = Resnet(num_classes, [3, 4, 6, 3], [16, 32, 64, 128])
    # NOTE(review): VGG16() is built without num_classes; confirm its default output size is 10.
    net = VGG16()
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    # model='VGG16' puts checkpoints/logs under runs/VGG16 instead of the default runs/Resnet.
    train(net=net, epochs=epochs, batch_size=bs, num_workers=num_workers,
          device=device, num_classes=num_classes, opt=opt, model='VGG16')

Files already downloaded and verified
Files already downloaded and verified


epoch: 0/10: 100%|██████████| 1250/1250 [00:48<00:00, 25.88it/s, loss=1.83]
100%|██████████| 313/313 [00:05<00:00, 60.65it/s]


acc=0.41990000009536743, best acc is 0.41990000009536743


epoch: 1/10: 100%|██████████| 1250/1250 [00:40<00:00, 30.75it/s, loss=1.34]
100%|██████████| 313/313 [00:04<00:00, 63.80it/s]


acc=0.5759000182151794, best acc is 0.5759000182151794


epoch: 2/10: 100%|██████████| 1250/1250 [00:38<00:00, 32.29it/s, loss=1.04]
100%|██████████| 313/313 [00:05<00:00, 56.67it/s]


acc=0.6572999954223633, best acc is 0.6572999954223633


epoch: 3/10: 100%|██████████| 1250/1250 [00:38<00:00, 32.30it/s, loss=0.905]
100%|██████████| 313/313 [00:04<00:00, 64.31it/s]


acc=0.6949999928474426, best acc is 0.6949999928474426


epoch: 4/10: 100%|██████████| 1250/1250 [00:39<00:00, 32.03it/s, loss=0.798]
100%|██████████| 313/313 [00:04<00:00, 63.48it/s]


acc=0.7276999950408936, best acc is 0.7276999950408936


epoch: 5/10: 100%|██████████| 1250/1250 [00:38<00:00, 32.19it/s, loss=0.703]
100%|██████████| 313/313 [00:04<00:00, 65.07it/s]


acc=0.7849000096321106, best acc is 0.7849000096321106


epoch: 6/10: 100%|██████████| 1250/1250 [00:38<00:00, 32.68it/s, loss=0.629]
100%|██████████| 313/313 [00:04<00:00, 64.74it/s]


acc=0.7882000207901001, best acc is 0.7882000207901001


epoch: 7/10: 100%|██████████| 1250/1250 [00:39<00:00, 31.81it/s, loss=0.571]
100%|██████████| 313/313 [00:04<00:00, 65.64it/s]


acc=0.8109999895095825, best acc is 0.8109999895095825


epoch: 8/10: 100%|██████████| 1250/1250 [00:38<00:00, 32.08it/s, loss=0.527]
100%|██████████| 313/313 [00:04<00:00, 64.15it/s]


acc=0.8119000196456909, best acc is 0.8119000196456909


epoch: 9/10: 100%|██████████| 1250/1250 [00:38<00:00, 32.29it/s, loss=0.483]
100%|██████████| 313/313 [00:05<00:00, 56.36it/s]


acc=0.8256000280380249, best acc is 0.8256000280380249


# 模型评估 (Model evaluation on the test split)

In [9]:
def test(net, dataloader, device, num_classes):
    from torch.nn import functional as F
    net.eval()
    # metric = torchmetrics.Accuracy(numClass=num_classes).to(device)
    correct=0.0
    with torch.no_grad():
        for  _, (input, target) in tqdm(enumerate(dataloader), total=len(dataloader), leave=True):
            input = input.to(device)
            target = target.to(device)
            pred = net(input)
            pred = F.softmax(pred, 1).argmax(1)
            correct += pred.eq(target).sum()
            # acc = metric.update(pred, target)
    # acc = metric.compute()
    print(correct / len(dataloader))

# Rebuild the loaders and keep only the held-out test split.
# NOTE(review): the progress bar showed 10000 iterations despite batch_size=128;
# confirm whether get_CIFAdataset_loader applies batch_size to the test split.
_, _, test_dataloader = get_CIFAdataset_loader(
    root='./data/CIFA',
    batch_size=128,
    num_workers=2,
    pin_memory=True,
    valid_rate=0.2,
    shuffle=True,
)

test(
    net=net,
    dataloader=test_dataloader,
    device=device,
    num_classes=num_classes,
);

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 10000/10000 [00:52<00:00, 191.25it/s]


tensor(0.8342, device='cuda:0')
