# VGG-19 customize

PyTorch Tutorial Notebook on Utilizing VGG-19 Pretrained Model

### Library Import

In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import CIFAR10
from torchvision import transforms
import torchvision

import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
from pytorch_lightning import loggers as pl_logger

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

### Hyperparameters

- batch_size : Number of samples per mini-batch
- max_epochs : Maximum number of training epochs
- es_monitor_metric : Metric monitored for early stopping
- es_patience : Number of epochs without improvement to wait before early stopping
- cifar_data_root : Directory where CIFAR10 is downloaded
- log_path : Directory where training logs are saved
- gpus : GPUs to use (list of device indices, or None) [Docs | Multi-GPU training](https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html)

In [2]:
# Central configuration for the notebook; every later cell reads its settings
# from this dictionary.
hyperparams = {
    'batch_size': 64,
    'max_epochs': 200,
    # Metric watched by early stopping (minimised in the training cells).
    'es_monitor_metric': 'valid_loss',
    'es_patience': 15,
    'cifar_data_root': "./data",
    'log_path': './logs',
    # Use the first GPU when CUDA is usable; otherwise fall back to None so the
    # notebook still runs on CPU-only machines instead of failing in Trainer.
    'gpus': [0] if torch.cuda.is_available() else None
}

## Prepare Dataset using CIFAR10
For CIFAR10, we:
1. Load the dataset
1. Convert the images to tensors and normalize them
1. Split the data into training / validation / test sets

In [3]:
# Preprocessing pipeline applied to every CIFAR10 image: convert to a tensor,
# then normalise each of the three channels with fixed mean / std values.
# NOTE(review): these constants look like the usual ImageNet channel statistics
# used with torchvision's pretrained models - confirm against the model code.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

transform = transforms.Compose([_to_tensor, _normalize])

In [4]:
def train_valid_split(train_data, valid_rate=0.2, seed=None):
    """Randomly split a dataset into a training part and a validation part.

    The validation size is ``valid_rate * len(train_data)`` rounded to the
    nearest integer; the training part receives every remaining sample.
    Rounding (rather than truncating ``(1 - valid_rate) * len``) avoids losing
    a sample to floating point error, e.g. ``(1 - 0.9) * 10`` evaluates to
    ``0.9999999999999998`` and would otherwise truncate to 0.

    Args:
        train_data: Sized, indexable dataset to split.
        valid_rate: Fraction of samples assigned to validation, in ``[0, 1]``.
        seed: Optional integer seed. When given, the split is reproducible;
            when ``None`` the global torch RNG is used, as before.

    Returns:
        Tuple ``(train_dataset, valid_dataset)`` as produced by
        ``torch.utils.data.random_split``.

    Raises:
        ValueError: If ``valid_rate`` lies outside ``[0, 1]``.
    """
    if not 0 <= valid_rate <= 1:
        raise ValueError(f"valid_rate must be within [0, 1], got {valid_rate}")

    total = len(train_data)
    valid_length = int(round(valid_rate * total))
    train_length = total - valid_length

    # Only pass a generator when a seed was requested so the default call keeps
    # drawing from the global torch RNG.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_data, (train_length, valid_length), **split_kwargs
    )

    return train_dataset, valid_dataset

In [5]:
# Load both official CIFAR10 partitions from the configured root. Only the
# training partition triggers a download; the test partition reuses that root.
_cifar_root = hyperparams['cifar_data_root']
_cifar_train_full = CIFAR10(root=_cifar_root, train=True, download=True, transform=transform)
test_data = CIFAR10(root=_cifar_root, train=False, download=False, transform=transform)

# Carve a validation subset out of the official training partition.
train_data, valid_data = train_valid_split(_cifar_train_full)

Files already downloaded and verified


In [6]:
def _make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook's batch size and 8 worker processes."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=hyperparams['batch_size'],
        shuffle=shuffle,
        num_workers=8,
    )


# Only the training loader reshuffles its samples; validation and test loaders
# iterate in a fixed order.
train_loader = _make_loader(train_data, shuffle=True)
valid_loader = _make_loader(valid_data, shuffle=False)
test_loader = _make_loader(test_data, shuffle=False)

## Retraining the VGG-19 model

In [7]:
from model import VGG19Transfer, VGG19FineTuning

### Transfer Learning

In [8]:
# Transfer-learning experiment: train VGG19Transfer and log to TensorBoard.
PROJECT_NAME = 'vgg_cifer_transfer'
# NOTE(review): 10 is presumably the number of CIFAR10 classes - confirm
# against the VGG19Transfer constructor in model.py.
exp_transfer = VGG19Transfer(10)

# Logs are written below <log_path>/tensorboard/<PROJECT_NAME>.
tb_logger = pl_logger.TensorBoardLogger(
    save_dir=os.path.join(hyperparams['log_path'], 'tensorboard'),
    name=PROJECT_NAME,
    default_hp_metric=False
)

# Stop once the monitored metric has not improved for `es_patience` epochs.
es_callback = EarlyStopping(
    monitor=hyperparams['es_monitor_metric'],
    patience=hyperparams['es_patience'],
    mode='min'
)

# Keep the weights of the best epoch. Early stopping alone ends training
# `es_patience` epochs after the best score, so the final weights are not the
# best ones.
ckpt_callback = ModelCheckpoint(
    monitor=hyperparams['es_monitor_metric'],
    mode='min',
    save_top_k=1
)

trainer_transfer = Trainer(
    gpus=hyperparams['gpus'],
    max_epochs=hyperparams['max_epochs'],
    logger=[tb_logger],
    callbacks=[
        es_callback,
        ckpt_callback,
    ]
)

trainer_transfer.fit(
    exp_transfer,
    train_loader,
    valid_loader
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: ./logs/tensorboard/vgg_cifer_transfer

  | Name      | Type             | Params
-----------------------------------------------
0 | vgg19     | VGG              | 139 M 
1 | criterion | CrossEntropyLoss | 0     
2 | train_acc | Accuracy         | 0     
3 | valid_acc | Accuracy         | 0     
-----------------------------------------------
41.0 K    Trainable params
139 M     Non-trainable params
139 M     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

### Fine Tuning

In [9]:
# Fine-tuning experiment: train VGG19FineTuning and log to TensorBoard.
PROJECT_NAME = 'vgg_cifer_finetuning'
# NOTE(review): 10 is presumably the number of CIFAR10 classes - confirm
# against the VGG19FineTuning constructor in model.py.
exp_finetuning = VGG19FineTuning(10)


# Logs are written below <log_path>/tensorboard/<PROJECT_NAME>.
tb_logger = pl_logger.TensorBoardLogger(
    save_dir=os.path.join(hyperparams['log_path'], 'tensorboard'),
    name=PROJECT_NAME,
    default_hp_metric=False
)

# Stop once the monitored metric has not improved for `es_patience` epochs.
es_callback = EarlyStopping(
    monitor=hyperparams['es_monitor_metric'],
    patience=hyperparams['es_patience'],
    mode='min'
)

# Keep the weights of the best epoch. Early stopping alone ends training
# `es_patience` epochs after the best score, so the final weights are not the
# best ones.
ckpt_callback = ModelCheckpoint(
    monitor=hyperparams['es_monitor_metric'],
    mode='min',
    save_top_k=1
)

trainer_finetuning = Trainer(
    gpus=hyperparams['gpus'],
    max_epochs=hyperparams['max_epochs'],
    logger=[tb_logger],
    callbacks=[
        es_callback,
        ckpt_callback,
    ]
)

trainer_finetuning.fit(
    exp_finetuning,
    train_loader,
    valid_loader
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: ./logs/tensorboard/vgg_cifer_finetuning

  | Name      | Type             | Params
-----------------------------------------------
0 | vgg19     | VGG              | 139 M 
1 | criterion | CrossEntropyLoss | 0     
2 | train_acc | Accuracy         | 0     
3 | valid_acc | Accuracy         | 0     
-----------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

## Plot in Tensorboard

In [11]:
# (Re)load the TensorBoard notebook extension, then open the dashboard on the
# directory the TensorBoardLogger instances in the training cells write to
# (./logs/tensorboard), served on port 6008.
# NOTE(review): --bind_all presumably makes the server reachable from other
# hosts on the network - confirm this exposure is intended.
%reload_ext tensorboard
%tensorboard --logdir ./logs/tensorboard --bind_all --port 6008

Reusing TensorBoard on port 6008 (pid 22848), started 12:26:11 ago. (Use '!kill 22848' to kill it.)