# <div align="center"> 自动混合精度 </div>

In [None]:
# Load the watermark extension (used on the last line to report interpreter/package versions).
%reload_ext watermark
# Auto-reload imported modules before each cell executes, so edits to imported code are picked up live.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Print the Python version (-v) and the versions of the listed packages (-p) for reproducibility.
%watermark -v -p numpy,pandas,matplotlib,sklearn,torch,torchvision,pytorch_lightning

In [2]:
import numpy as np
import pandas as pd
import torch
import torchvision
import pytorch_lightning as pl
import torch.optim as optim

from apex import amp
from tqdm import tqdm
from pytorch_lightning.loggers import TensorBoardLogger  

from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torch import nn
from torchvision.datasets import MNIST
from torchvision import datasets, transforms, models
import os

from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast 

from k12libs.utils.nb_easy import k12ai_get_top_dir,k12ai_start_tensorboard,k12ai_set_notebook
from k12libs.utils.nb_easy import K12AI_NBDATA_ROOT

In [None]:
class MNISTClassifier(pl.LightningModule):
    """ResNet-50 digit classifier for MNIST with an optional hand-rolled native-AMP path.

    When ``self.amp`` is True, the forward pass and loss run under
    ``torch.cuda.amp.autocast``, and backward/step go through a ``GradScaler`` owned by
    this module. This bypasses Lightning's own AMP support. The Trainer is meant to be
    run with ``precision=32`` so that two loss scalers do not interfere.

    Args:
        num_classes: size of the ResNet classification head (MNIST has 10 digits).
        amp: enable the autocast + GradScaler path (default off, i.e. plain fp32).
    """

    def __init__(self, num_classes=10, amp=False):
        super(MNISTClassifier, self).__init__()
        # Randomly initialised torchvision ResNet-50 with the head sized to the label
        # set. The stock 1000-way head would carry 990 logits no label ever targets.
        self.model = models.resnet50(num_classes=num_classes)
        # Loss scaler for the manual AMP path; only consulted when self.amp is True.
        self.scaler = GradScaler()
        self.amp = amp

    def forward(self, x):
        """Return raw, unnormalised class logits for a batch of images.

        NOTE(review): the transform in ``prepare_data`` yields 3-channel 28x28 tensors,
        so x is presumably (batch, 3, 28, 28) and the result (batch, num_classes).
        """
        return self.model(x)

    def cross_entropy_loss(self, logits, labels):
        """Mean cross-entropy between raw logits and integer class labels.

        ``forward`` returns unnormalised logits, so ``F.cross_entropy`` (log-softmax
        followed by NLL) is required. ``F.nll_loss`` on raw logits is unbounded below
        and does not train a classifier.
        """
        return F.cross_entropy(logits, labels)

    def _compute_loss(self, batch):
        """Forward + loss for one ``(images, labels)`` batch, under autocast if ``self.amp``."""
        x, y = batch
        # autocast(enabled=False) is a no-op, so one code path serves both modes.
        with autocast(enabled=self.amp):
            logits = self.forward(x)
            loss = self.cross_entropy_loss(logits, y)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss; the 'log' dict is handed to the logger."""
        loss = self._compute_loss(train_batch)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, val_batch, batch_idx):
        """Compute the loss on one validation batch."""
        loss = self._compute_loss(val_batch)
        return {'val_loss': loss}

    def test_step(self, train_batch, batch_idx):
        """Placeholder: no test metrics are computed yet."""
        pass

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the mean."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        # Release cached CUDA blocks between epochs (this affects reported reserved memory).
        torch.cuda.empty_cache()
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def prepare_data(self):
        """Download MNIST and build the train/val/test datasets.

        Sets ``self.mnist_train`` / ``self.mnist_val`` (a 55000/5000 random split of the
        official training set) and ``self.mnist_test`` (the official test set).
        """
        # The transform has three steps:
        #   1. Convert the grayscale PIL image to RGB (ResNet-50's stem takes 3 channels).
        #   2. Convert it to a tensor.
        #   3. Normalise with the standard single-channel MNIST mean/std; the 1-element
        #      statistics broadcast over all three channels.
        transform = transforms.Compose([
            transforms.Lambda(lambda img: img.convert('RGB')),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        data_path = os.path.join(K12AI_NBDATA_ROOT, 'datasets')
        mnist_train = MNIST(data_path, train=True, download=True, transform=transform)
        # Stored on self so test_dataloader can use it.
        self.mnist_test = MNIST(data_path, train=False, download=True, transform=transform)

        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

    def train_dataloader(self):
        """Shuffled batches of 64 from the training split."""
        # Reshuffle every epoch. The one-off permutation from random_split alone would
        # feed the samples in the same order each epoch.
        return DataLoader(self.mnist_train, batch_size=64, shuffle=True)

    def val_dataloader(self):
        """Batches of 64 from the held-out validation split."""
        return DataLoader(self.mnist_val, batch_size=64)

    def test_dataloader(self):
        """Batches of 64 from the official MNIST test set."""
        return DataLoader(self.mnist_test, batch_size=64)

    def configure_optimizers(self):
        """Adam over all parameters with lr=1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None, using_native_amp=False, **kwargs):
        """Apply the parameter update, going through the GradScaler when ``self.amp`` is set.

        ``second_order_closure``, ``using_native_amp`` and any extra keyword arguments
        passed by the Trainer are accepted for signature compatibility but are unused.
        """
        if self.amp:
            # Unscale the grads and skip the step if they contain inf/NaN,
            # then adapt the loss scale for the next iteration.
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()
        # Clear gradients here so they cannot accumulate across steps, whether or not
        # this Lightning version zeroes them elsewhere after a custom optimizer_step.
        optimizer.zero_grad()

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        """Backpropagate ``loss``, scaling it first when ``self.amp`` is set."""
        if self.amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

## Print Callback

In [None]:
class PrintingCallback(pl.Callback):
    """Lightning callback that prints a short message at a few Trainer lifecycle points."""

    @staticmethod
    def _announce(message):
        """Write ``message`` to stdout."""
        print(message)

    def on_init_start(self, trainer):
        """Hook: the Trainer is starting its initialisation."""
        self._announce('Starting to init trainer!')

    def on_init_end(self, trainer):
        """Hook: the Trainer has finished initialising."""
        self._announce('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        """Hook: the training run has completed."""
        self._announce('do something when training ends')

In [None]:
# Root directory for TensorBoard event files, under the notebook data root.
log_dir = os.path.join(K12AI_NBDATA_ROOT, 'logs')
# k12libs helper; presumably sets the notebook cell width to 95% — TODO confirm.
k12ai_set_notebook(cellw=95)
# k12libs helper: start TensorBoard over log_dir. 9004 is presumably the port, and
# clear=True presumably wipes previous runs first — confirm against k12libs.
k12ai_start_tensorboard(9004, log_dir, clear=True)

In [None]:
# Train MNISTClassifier on one GPU with Lightning at precision=32.
# Model-level AMP is controlled separately by MNISTClassifier.amp.
logger = TensorBoardLogger(save_dir=log_dir, name='tb_logs')
model = MNISTClassifier()
trainer = pl.Trainer(
    gpus=1,
    precision=32,
    callbacks=[PrintingCallback()],
    logger=logger,
    log_gpu_memory='min_max',
)
trainer.fit(model)

---------------

## (New) Using autocast

In [None]:
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Reset the peak counter so max_memory_allocated() below reflects only this
# benchmark, not allocations made by earlier cells in the same process.
torch.cuda.reset_peak_memory_stats()

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

use_amp = True
model = torchvision.models.vgg16().cuda()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
# With enabled=False the scaler is a pass-through: scale() returns the loss unchanged,
# step() just calls optimizer.step(), and update() does nothing.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(501):
    _inputs = torch.randn(20, 3, 32, 32).cuda()
    labels = torch.randn(20, 1000).cuda()

    # Clear gradients every iteration. Zeroing only once before the loop would let
    # .grad accumulate across all iterations.
    optimizer.zero_grad()

    # autocast(enabled=False) is a no-op, so the fp32 baseline runs the same code path.
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(_inputs)
        loss = loss_fn(outputs, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if epoch % 100 == 0:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        # Columns: iteration, peak allocated MiB, allocated MiB, reserved MiB.
        print('%03d' % epoch,
              round(torch.cuda.max_memory_allocated()/2**20, 2),
              round(torch.cuda.memory_allocated()/2**20, 2),
              round(torch.cuda.memory_reserved()/2**20, 2))

Observed results for this cell (GPU memory, wall-clock time):

- `use_amp = False`: 3563 MB, 1m33s
- `use_amp = True`: 3545 MB, 1m38s

## (Old) Using apex.amp

In [3]:
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Reset the peak counters so the max_* statistics below reflect only this benchmark,
# not allocations made by earlier cells in the same process.
torch.cuda.reset_peak_memory_stats()

use_amp = False

model = torchvision.models.vgg16().cuda()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

if use_amp:
    # Per the apex docs, O2 casts the model to FP16 and keeps FP32 master weights
    # in the optimizer, with dynamic loss scaling.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')


def step(inputs, targets):
    """Run one forward/backward/update on a batch.

    When use_amp is set, the loss is scaled through apex before backward.
    """
    # Clear gradients every call. Zeroing only once before the loop would let
    # .grad accumulate across all iterations.
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()


for epoch in range(301):
    inputs = torch.randn(32, 3, 64, 64).cuda()
    labels = torch.randn(32, 1000).cuda()
    step(inputs, labels)
    # Return cached blocks to the driver every iteration, so memory_reserved tracks
    # live usage (at some speed cost).
    torch.cuda.empty_cache()
    if epoch % 100 == 0:
        # Columns: iteration, peak allocated MiB, peak reserved MiB, allocated MiB, reserved MiB.
        print('%03d' % epoch,
              round(torch.cuda.max_memory_allocated()/2**20, 2),
              round(torch.cuda.max_memory_reserved()/2**20, 2),
              round(torch.cuda.memory_allocated()/2**20, 2),
              round(torch.cuda.memory_reserved()/2**20, 2))

000 8793.74 17016.0 1109.27 4012.0
100 8793.88 17016.0 1109.27 3964.0
200 8793.88 17016.0 1109.27 3964.0


KeyboardInterrupt: 

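Recorded output with `use_amp = False`, batch size 128 (columns presumably as printed by the cell above: iteration, max allocated MiB, max reserved MiB, allocated MiB, reserved MiB):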

    000 2025.62 2964.0 1064.57 2964.0
    100 2206.83 2964.0 1064.57 2964.0
    200 2206.83 2964.0 1064.57 2964.0
    300 2206.83 2964.0 1064.57 2964.0
    400 2206.83 2964.0 1064.57 2964.0
    500 2206.83 2964.0 1064.57 2964.0

---------------

Recorded output with `use_amp = True`:

    000 1836.56 2652.0 1064.51 2652.0
    100 1837.94 2652.0 1064.89 2652.0
    200 1837.94 2652.0 1064.89 2652.0
    300 1837.94 2652.0 1064.89 2652.0
    400 1837.94 2652.0 1064.89 2652.0
    500 1837.94 2652.0 1064.89 2652.0

    000 8793.74 17016.0 1109.27 4012.0
    100 8793.88 17016.0 1109.27 3964.0
    200 8793.88 17016.0 1109.27 3964.0

--------------

    000 7281.77 15746.0 1109.31 3370.0
    100 7283.71 16200.0 1109.31 5616.0
    200 7283.71 16200.0 1109.31 5616.0

In [4]:
# Inspect the cuFFT plan cache for the current CUDA device: (capacity, plans cached now).
(torch.backends.cuda.cufft_plan_cache.max_size,
 torch.backends.cuda.cufft_plan_cache.size)

(4096, 0)