# <div align="center"> 自动混合精度 </div>

In [2]:
%reload_ext watermark
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%watermark -v -p numpy,pandas,matplotlib,sklearn,torch,torchvision,pytorch-lightning

CPython 3.6.9
IPython 7.16.1

numpy 1.18.5
pandas 1.0.4
matplotlib 3.2.1
sklearn 0.23.1
torch 1.6.0.dev20200609+cu101
torchvision 0.7.0.dev20200609+cu101


In [43]:
import numpy as np
import pandas as pd
import torch
import torchvision
import pytorch_lightning as pl
from torch import nn

from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms, models
import os

from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast 

from k12libs.utils.nb_easy import k12ai_get_top_dir
from k12libs.utils.nb_easy import K12AI_NBDATA_ROOT

In [51]:
class MNISTClassifier(pl.LightningModule):

    def __init__(self):
        super(MNISTClassifier, self).__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        # self.layer_1 = torch.nn.Linear(28 * 28, 128)
        # self.layer_2 = torch.nn.Linear(128, 256)
        # self.layer_3 = torch.nn.Linear(256, 10)
        self.model = models.resnet50()
        self.scaler = GradScaler()
        self.amp = True

    def forward(self, x):
        # batch_size, channels, width, height = x.size()
 
        # # (b, 1, 28, 28) -> (b, 1*28*28)
        # x = x.view(batch_size, -1)
 
        # # layer 1 (b, 1*28*28) -> (b, 128)
        # x = self.layer_1(x)
        # x = torch.relu(x)

        # # layer 2 (b, 128) -> (b, 256)
        # x = self.layer_2(x)
        # x = torch.relu(x)
 
        # # layer 3 (b, 256) -> (b, 10)
        # x = self.layer_3(x)

        # # probability distribution over labels
        # x = torch.log_softmax(x, dim=1)
        
        x = self.model(x)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        if self.amp:
            with autocast():
                logits = self.forward(x)
                loss = self.cross_entropy_loss(logits, y)
        else:
            logits = self.forward(x)
            loss = self.cross_entropy_loss(logits, y)
 
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        if self.amp:
            with autocast():
                logits = self.forward(x)
                loss = self.cross_entropy_loss(logits, y)
        else:
            logits = self.forward(x)
            loss = self.cross_entropy_loss(logits, y)
        return {'val_loss': loss}
     
    def test_step(self, train_batch, batch_idx):
        pass
 
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
 
    def prepare_data(self):
        transform=transforms.Compose([
            transforms.Lambda(lambda img: img.convert('RGB')),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
       
        # prepare transforms standard to MNIST
        data_path = os.path.join(K12AI_NBDATA_ROOT, 'datasets')
        mnist_train = MNIST(data_path, train=True, download=False, transform=transform)
        mnist_test = MNIST(data_path, train=False, download=False, transform=transform)
     
        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
 
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
 
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
 
    def test_dataloader(self):
        return DataLoader(self,mnist_test, batch_size=64)
 
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,        
                       second_order_closure, using_native_amp):    
        if self.amp:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()                                                                

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        if self.amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

## Print Callback

In [14]:
class PrintingCallback(pl.Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

In [None]:
model = MNISTClassifier()
trainer = pl.Trainer(gpus=1, callbacks=[PrintingCallback()])
trainer.fit(model) 

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 25 M  


Starting to init trainer!
trainer is init now


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…