#### Summary
- torch.cuda.amp怎么使用
    - fp16: loss scaling
        - https://github.com/mli/transformers-benchmarks/blob/main/transformers.ipynb
    - 极大地提升batch_size

#### cnn pipeline

In [1]:
import torch
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

#### Model

In [3]:
# Simple CNN
class CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=5120,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # /2, downsampling
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=5120,
            out_channels=10240,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # (channels*w*h)
        # w, h: 取决于初始的 width, height
        self.fc1 = nn.Linear(10240 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # /2
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        # /2
        x = self.pool(x)
        print(x.shape)
        # 4d => 2d, (bs, features)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

In [29]:
from torchsummary import summary

In [19]:
model = CNN(in_channels=3)
summary(model, input_size=(3, 224, 224), batch_size=32, device='cpu')

torch.Size([2, 5120, 224, 224])
torch.Size([2, 5120, 112, 112])
torch.Size([2, 10240, 56, 56])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1       [32, 5120, 224, 224]         143,360
         MaxPool2d-2       [32, 5120, 112, 112]               0
            Conv2d-3      [32, 10240, 112, 112]     471,869,440
         MaxPool2d-4        [32, 10240, 56, 56]               0
            Linear-5                   [32, 10]     321,126,410
Total params: 793,139,210
Trainable params: 793,139,210
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 18.38
Forward/backward pass size (MB): 117600.00
Params size (MB): 3025.59
Estimated Total Size (MB): 120643.96
----------------------------------------------------------------


#### training pipeline

In [4]:
in_channels = 1
num_classes = 10

learning_rate = 3e-4
batch_size = 32
num_epochs = 3

In [21]:
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4210621.16it/s] 


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1010727.52it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1255326.08it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10986464.11it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [23]:
batch_x, batch_y = next(iter(train_loader))
print(batch_x.shape, batch_y.shape)

torch.Size([32, 1, 28, 28]) torch.Size([32])


#### float32

In [None]:
in_channels = 1
num_classes = 10

learning_rate = 3e-4
batch_size = 128
num_epochs = 3

train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

model = CNN(in_channels=in_channels, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def train():
    for epoch in tqdm(range(num_epochs)):
        for batch_idx, (batch_x, batch_y) in tqdm(enumerate(train_loader)):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            loss = criterion(logits, batch_y)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()


def evalute(model, test_loader):
    total_correct = 0
    total_samples = 0
    model.eval()

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            _, preds = logits.max(1)
            total_correct += (preds == batch_y).sum()
            total_samples += batch_y.size(0)
    model.train()
    return total_correct / total_samples


train()

  0%|          | 0/3 [00:00<?, ?it/s]
0it [00:00, ?it/s][A

torch.Size([128, 10240, 7, 7])



1it [00:03,  3.41s/it][A
2it [00:04,  1.85s/it][A

torch.Size([128, 10240, 7, 7])



3it [00:06,  2.10s/it][A

torch.Size([128, 10240, 7, 7])



4it [00:08,  2.23s/it][A

torch.Size([128, 10240, 7, 7])



5it [00:11,  2.30s/it][A

torch.Size([128, 10240, 7, 7])



6it [00:13,  2.34s/it][A

torch.Size([128, 10240, 7, 7])



7it [00:16,  2.37s/it][A

torch.Size([128, 10240, 7, 7])



8it [00:18,  2.38s/it][A

torch.Size([128, 10240, 7, 7])



9it [00:21,  2.39s/it][A

torch.Size([128, 10240, 7, 7])



10it [00:23,  2.41s/it][A

torch.Size([128, 10240, 7, 7])



11it [00:25,  2.41s/it][A

torch.Size([128, 10240, 7, 7])



12it [00:28,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



13it [00:30,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



14it [00:33,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



15it [00:35,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



16it [00:38,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



17it [00:40,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



18it [00:43,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



19it [00:45,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



20it [00:47,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



21it [00:50,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



22it [00:52,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



23it [00:55,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



24it [00:57,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



25it [01:00,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



26it [01:02,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



27it [01:04,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



28it [01:07,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



29it [01:09,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



30it [01:12,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



31it [01:14,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



32it [01:17,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



33it [01:19,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



34it [01:21,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



35it [01:24,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



36it [01:26,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



37it [01:29,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



38it [01:31,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



39it [01:34,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



40it [01:36,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



41it [01:38,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



42it [01:41,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



43it [01:43,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



44it [01:46,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



45it [01:48,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



46it [01:51,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



47it [01:53,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



48it [01:55,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



49it [01:58,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



50it [02:00,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



51it [02:03,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



52it [02:05,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



53it [02:08,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



54it [02:10,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



55it [02:12,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



56it [02:15,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



57it [02:17,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



58it [02:20,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



59it [02:22,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



60it [02:25,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



61it [02:27,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



62it [02:29,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



63it [02:32,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



64it [02:34,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



65it [02:37,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



66it [02:39,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



67it [02:42,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



68it [02:44,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



69it [02:46,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



70it [02:49,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



71it [02:51,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



72it [02:54,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



73it [02:56,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



74it [02:58,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



75it [03:01,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



76it [03:03,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



77it [03:06,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



78it [03:08,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



79it [03:11,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



80it [03:13,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



81it [03:15,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



82it [03:18,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



83it [03:20,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



84it [03:23,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



85it [03:25,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



86it [03:28,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



87it [03:30,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



88it [03:32,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



89it [03:35,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



90it [03:37,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



91it [03:40,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



92it [03:42,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



93it [03:45,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



94it [03:47,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



95it [03:49,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



96it [03:52,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



97it [03:54,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



98it [03:57,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



99it [03:59,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



100it [04:01,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



101it [04:04,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



102it [04:06,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



103it [04:09,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



104it [04:11,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



105it [04:14,  2.43s/it][A

torch.Size([128, 10240, 7, 7])



106it [04:16,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



107it [04:18,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



108it [04:21,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



109it [04:23,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



110it [04:26,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



111it [04:28,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



112it [04:31,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



113it [04:33,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



114it [04:35,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



115it [04:38,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



116it [04:40,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



117it [04:43,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



118it [04:45,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



119it [04:47,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



120it [04:50,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



121it [04:52,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



122it [04:55,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



123it [04:57,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



124it [05:00,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



125it [05:02,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



126it [05:04,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



127it [05:07,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



128it [05:09,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



129it [05:12,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



130it [05:14,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



131it [05:16,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



132it [05:19,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



133it [05:21,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



134it [05:24,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



135it [05:26,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



136it [05:29,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



137it [05:31,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



138it [05:33,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



139it [05:36,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



140it [05:38,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



141it [05:41,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



142it [05:43,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



143it [05:46,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



144it [05:48,  2.42s/it][A

torch.Size([128, 10240, 7, 7])



145it [05:50,  2.42s/it][A

torch.Size([128, 10240, 7, 7])


In [None]:
print(f"Accuracy on training set: {evalute(model, train_loader) * 100:.2f}")
print(f"Accuracy on test set: {evalute(model, test_loader) * 100:.2f}")

#### 混合精度训练

In [None]:
in_channels = 1
num_classes = 10

learning_rate = 3e-4
batch_size = 256
num_epochs = 3

train_dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transforms.ToTensor(), download=True
)
test_dataset = datasets.MNIST(
    root="dataset/", train=False, transform=transforms.ToTensor(), download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

model = CNN(in_channels=in_channels, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

scaler = torch.cuda.amp.GradScaler()


def train():
    for epoch in tqdm(range(num_epochs)):
        for batch_idx, (batch_x, batch_y) in tqdm(enumerate(train_loader)):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            with torch.cuda.amp.autocast():
                logits = model(batch_x)
                loss = criterion(logits, batch_y)

            optimizer.zero_grad()
            scaler.scale(loss).backward()

            scaler.step(optimizer)
            scaler.update()


def evalute(model, test_loader):
    total_correct = 0
    total_samples = 0
    model.eval()

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            _, preds = logits.max(1)
            total_correct += (preds == batch_y).sum()
            total_samples += batch_y.size(0)
    model.train()
    return total_correct / total_samples


train()
