In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import os
from PIL import Image

In [None]:
# Root directory handed to torchvision's ImageNet dataset. Set the
# IMAGENET_ROOT environment variable to point elsewhere instead of editing
# this cell; the original author's path is kept as the fallback.
datapath = os.environ.get('IMAGENET_ROOT', '/home/zyf/Documents/ImageNet')
transform = transforms.Compose([
    # Scale to Inception v3's 299x299 input, convert to a tensor and normalise
    # with the per-channel mean/std below (the usual ImageNet statistics).
    # NOTE(review): Resize((299, 299)) squashes non-square images; torchvision's
    # reference evaluation for inception_v3 uses Resize(342) + CenterCrop(299),
    # so top-1 accuracy measured here may differ from published numbers.
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
dataset = datasets.ImageNet(root=datapath, split="val", transform=transform)

In [None]:
BATCH_SIZE = 256
# Decode/transform images in worker processes rather than in the notebook
# process (the DataLoader default, num_workers=0), so the timing benchmark
# below measures less JPEG decoding and more model inference.
# Set to 0 to reproduce the original single-process behaviour.
NUM_WORKERS = min(8, os.cpu_count() or 1)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order; top-1 accuracy does not depend on it
    num_workers=NUM_WORKERS,
    # Page-locked host buffers speed up host->GPU copies; useless without CUDA.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Use the first GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# ImageNet-pretrained Inception v3 (download progress bar disabled), switched
# to eval mode (no dropout / batch-norm updates) and moved onto `device`.
# nn.Module.eval() and .to() both return the module itself, so chaining is safe.
model = models.inception_v3(pretrained=True, progress=False).eval().to(device)

In [None]:
from torch.utils import benchmark

# Current number of intra-op CPU threads used by torch; available for
# benchmark.Timer(num_threads=...) if CPU thread count should be pinned.
num_threads = torch.get_num_threads()

# Passed to Timer.timeit(): how many times the timed statement runs. Each run of
# `inference` iterates over the whole dataloader, so keep this small.
repeat_times = 1

def inference(model, dataloader, device, enabled):
    """Run ``model`` over every batch of ``dataloader`` and report top-1 accuracy.

    Args:
        model: classifier whose output is (batch, num_classes) logits; it is
            expected to already be on ``device`` and in eval mode.
        dataloader: iterable yielding ``(images, targets)`` batches, where
            ``targets`` holds integer class indices.
        device: device the images and targets are moved to.
        enabled: if True the forward pass runs under
            ``torch.cuda.amp.autocast`` (mixed precision); otherwise it runs in
            the model's native precision.

    Returns:
        Top-1 accuracy as a float in [0, 1], or ``nan`` if no samples were seen.
        The same value is printed together with the sample count.
    """
    sum_input = 0
    sum_right = 0
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            # Only the forward pass needs autocast; argmax/compare work on
            # whatever dtype the model returns.
            with torch.cuda.amp.autocast(enabled=enabled):
                outputs = model(images)
            predictions = outputs.argmax(dim=1)
            # One vectorised compare and a single device->host sync per batch,
            # instead of one sync per sample as an element-wise Python loop does.
            sum_right += (predictions == targets).sum().item()
            sum_input += outputs.size(0)
    # Guard against an empty dataloader, which would otherwise divide by zero.
    accuracy = sum_right / sum_input if sum_input else float('nan')
    print(f'evaluated {sum_input} samples, top-1 accuracy: {accuracy}')
    return accuracy

import platform  # only needed to describe the runtime in the report header

# One timed run per entry: (autocast enabled?, sub_label shown in the results).
BENCHMARK_RUNS = [
    (False, 'Original FP32 Inference'),
    (True, 'Mixed Precision Inference'),
]
# torch.cuda memory statistics only exist for CUDA devices.
use_cuda = device.type == 'cuda'

with open(f'pytorch_fp32_amp_result_{repeat_times}.txt', 'w') as result:
    # Describe the environment from the running process rather than hard-coding
    # it, so the header stays correct when the notebook is re-run elsewhere.
    if use_cuda:
        gpu_desc = f'{torch.cuda.get_device_name(device)} * {torch.cuda.device_count()}'
    else:
        gpu_desc = 'CPU'
    result.write(
        '\n'
        f'pytorch {torch.__version__}\n'
        f'cuda {torch.version.cuda}\n'
        f'python {platform.python_version()}\n'
        f'{gpu_desc}\n'
        'Inception v3\n'
        'ILSVRC2012 validation set\n'
        'fp32(original) vs. amp(torch.cuda.amp.autocast)\n'
        '\n'
    )

    for enabled, sub_label in BENCHMARK_RUNS:
        # Reset so each run reports its own peak, not the max over earlier runs.
        if use_cuda:
            torch.cuda.reset_peak_memory_stats(device)

        timer = benchmark.Timer(
            stmt=f'inference(model, dataloader, device, {enabled})',
            # Hand `inference` over explicitly; `from __main__ import inference`
            # only works when this code happens to run as the __main__ module.
            globals={'inference': inference, 'model': model,
                     'dataloader': dataloader, 'device': device},
            label='Inference Timing',
            sub_label=sub_label,
        )
        result.write(f'{timer.timeit(repeat_times)}\n')

        if use_cuda:
            peak_mb = torch.cuda.max_memory_allocated(device) / 1024.0 / 1024
            result.write(f'peak GPU mem usage on active tensors: {peak_mb}MB\n\n')


In [None]:
# Untimed FP32 run, then AMP run. The peak-memory counter is reset before each
# run so the AMP figure is not inflated by the FP32 run's peak.
for enabled in (False, True):
    if device.type == 'cuda':
        torch.cuda.reset_peak_memory_stats(device)
    inference(model, dataloader, device, enabled)
    # torch.cuda memory statistics are only meaningful on a CUDA device.
    if device.type == 'cuda':
        print(f'peak GPU mem usage on active tensors: {torch.cuda.max_memory_allocated(device)/1024.0/1024}MB\n')

In [None]:
# test cuda: is CUDA usable, and which GPUs are visible to this process?
print(torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    # Pass the index; with no argument get_device_name() reports the *current*
    # device, so every iteration would print the same name.
    print(torch.cuda.get_device_name(i))

In [None]:
# test dummy cuda inference: a tiny Linear forward pass to check that
# computation on DEVICE works end to end.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

m = torch.nn.Linear(20, 30).to(DEVICE)
# Allocate directly on DEVICE (no host tensor + copy). The name avoids
# shadowing the built-in `input`.
dummy_input = torch.randn(128, 20, device=DEVICE)
with torch.no_grad():  # inference only; no autograd graph needed
    output = m(dummy_input)
print('output', output.size())