In [2]:
import time

import torch
import torch.nn as nn
from torch import optim
from torchvision.models import resnet50
from torchsummary import summary
from torch.optim import lr_scheduler
from torch.cuda.amp import GradScaler, autocast

In [3]:
# Build an ImageNet-pretrained ResNet-50 and place it on the first CUDA device;
# both timing runs below train this same module instance.
model = resnet50(pretrained=True)
model = model.to('cuda:0')
# To inspect the architecture, a per-layer summary can be printed with:
# summary(model, (3, 640, 640), batch_size=8)



In [4]:
# Dummy soft-label targets for the benchmark: one uniform probability
# distribution over the 1000 output classes for each of the 16 samples (the
# batch size must match the first dimension of `data`). A float target with the
# same shape as the logits is interpreted by cross-entropy as class
# probabilities, so each row should sum to 1; an all-ones tensor would sum to
# 1000 per row and scale the loss (and every gradient) by 1000x, which also
# makes FP16 gradient overflow and skipped AMP steps far more likely.
# Allocating directly on the GPU avoids a CPU allocation plus host-to-device copy.
targets = torch.full((16, 1000), 1.0 / 1000, device='cuda:0')
def compute_loss(pred: torch.Tensor, target: 'torch.Tensor | None' = None) -> torch.Tensor:
    """Return the batch-averaged cross-entropy between logits and targets.

    Args:
        pred: Unnormalized logits produced by the model (in this notebook, the
            ResNet-50 output, presumably shaped (batch, num_classes)).
        target: Either class-probability targets with the same shape as
            ``pred`` or integer class indices. Defaults to the notebook-level
            ``targets`` tensor, so existing ``compute_loss(pred)`` calls keep
            working unchanged.

    Returns:
        A scalar tensor holding the mean loss over the batch.
    """
    if target is None:
        # Fall back to the notebook's global dummy targets (original behavior).
        target = targets
    # The functional form avoids constructing a new nn.Module on every call.
    return nn.functional.cross_entropy(pred, target, reduction='mean')


In [5]:
# Number of optimizer steps performed by each timing run below.
total_epoch = 5
# Fixed random input batch shaped like 16 three-channel 640x640 images, on GPU 0.
# Its batch size (16) must match the first dimension of `targets`.
data = torch.rand(16, 3, 640, 640).to(device=0)
def without_amp():
    """Train `model` on the fixed batch `data` for `total_epoch` steps, no AMP.

    Uses plain SGD (lr=0.1, momentum=0.9) and a LambdaLR schedule whose
    multiplier falls linearly from 1.0 at step 0 to 0.1 at step `total_epoch`.
    Relies on the notebook-level `model`, `data`, `total_epoch` and
    `compute_loss`; returns nothing.
    """
    def lr_multiplier(step):
        # Linear interpolation between 1.0 and the 0.1 floor.
        return (1 - step / total_epoch) * (1 - 0.1) + 0.1

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_multiplier)

    for _ in range(total_epoch):
        optimizer.zero_grad()
        loss = compute_loss(model(data))
        loss.backward()
        optimizer.step()
        scheduler.step()
# CUDA kernels are launched asynchronously: without a synchronize on both sides
# the clock can stop while the last step's kernels are still queued on the GPU,
# so wait for all pending work before each timestamp. perf_counter is the
# high-resolution monotonic clock intended for measuring intervals.
# NOTE(review): this appears to be the first cell that runs the model, so its
# measurement likely also absorbs one-off warm-up costs (library/kernel
# initialization) that the AMP run does not pay — consider an untimed warm-up
# step before comparing the two numbers.
torch.cuda.synchronize()
t1 = time.perf_counter()
without_amp()
torch.cuda.synchronize()
t2 = time.perf_counter()
print(f'不使用amp消耗的时间：{t2 - t1}秒')

不使用用amp消耗的时间：3.076925277709961秒


In [10]:
def with_amp():
    """Train `model` on `data` for `total_epoch` steps with mixed precision.

    Uses SGD (lr=0.1, momentum=0.9) and a LambdaLR schedule decaying the
    multiplier linearly from 1.0 to 0.1. The forward pass and loss run under
    `autocast()`, and the backward pass / optimizer step go through a
    `GradScaler`. Relies on the notebook-level `model`, `data`, `total_epoch`
    and `compute_loss`; returns nothing.
    """
    def lr_multiplier(step):
        # Linear interpolation between 1.0 and the 0.1 floor.
        return (1 - step / total_epoch) * (1 - 0.1) + 0.1

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_multiplier)
    scaler = GradScaler()

    for _ in range(total_epoch):
        optimizer.zero_grad()
        # Inside autocast, eligible CUDA ops (e.g. convolutions, matmuls) run in
        # FP16 while precision-sensitive ops are kept in FP32; the weights
        # themselves stay FP32.
        with autocast():
            loss = compute_loss(model(data))
        # Multiply the loss by the current scale factor before backward so small
        # FP16 gradients do not underflow to zero.
        scaler.scale(loss).backward()
        # Unscale the gradients, then run optimizer.step() only if they contain
        # no inf/NaN; otherwise this iteration's weight update is skipped.
        scaler.step(optimizer)
        # Adjust the scale factor: it is multiplied by backoff_factor after a
        # skipped step and by growth_factor only after a run of consecutive
        # unskipped steps (update(new_scale) can set it explicitly instead).
        scaler.update()
        scheduler.step()
# Synchronize before each timestamp so queued asynchronous CUDA work is
# included in the measurement. Without this, the result is not comparable with
# a synchronized FP32 measurement. perf_counter is the interval-timing clock.
# NOTE(review): `model` has already been trained by the FP32 run, and any
# warm-up costs were paid there, which favours this measurement.
torch.cuda.synchronize()
t1 = time.perf_counter()
with_amp()
torch.cuda.synchronize()
t2 = time.perf_counter()
print(f'使用amp消耗的时间：{t2 - t1}秒')

使用amp消耗的时间：1.4285361766815186秒


In [None]:
# Demonstrate how autocast picks op dtypes. All inputs are float32 CUDA tensors.
a_float32 = torch.rand((8, 8), device='cuda:0')
b_float32 = torch.rand((8, 8), device='cuda:0')
c_float32 = torch.rand((8, 8), device='cuda:0')
d_float32 = torch.rand((8, 8), device='cuda:0')
with autocast():
    # torch.mm is an autocast FP16-eligible op: the float32 inputs are cast to
    # float16 automatically and the results come out as float16.
    e_float16 = torch.mm(a_float32, b_float32)
    f_float16 = torch.mm(d_float32, c_float32)
print(e_float16.dtype, f_float16.dtype)

# Outside the autocast region nothing is cast automatically: f_float16 is still
# float16, and torch.mm does not mix dtypes (float32 x float16 raises an
# error). Cast it explicitly to float32 to match d_float32.
g_float32 = torch.mm(d_float32, f_float16.float())
print(g_float32.dtype)

KeyboardInterrupt: 