AMP (Automatic Mixed Precision)

https://pytorch.org/docs/stable/notes/amp_examples.html#

In [None]:
import torch
import sklearn

https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html#sklearn.datasets.load_digits

In [None]:
# `import sklearn` alone does not guarantee that the `datasets` submodule is
# loaded (scikit-learn releases without lazy submodule loading raise
# AttributeError on `sklearn.datasets`), so import the submodule explicitly.
import sklearn.datasets

# Load the digits dataset as a (features, labels) pair of NumPy arrays.
digits_x, digits_y = sklearn.datasets.load_digits(return_X_y=True)

# Divide the features by 16 and create the tensors directly with their final
# dtype on the GPU, avoiding an intermediate float64 tensor on the host.
x = torch.tensor(digits_x / 16, dtype=torch.float32, device='cuda')  # FP32
y = torch.tensor(digits_y, dtype=torch.long, device='cuda')
print(x.shape, x.dtype)
print(y.shape, y.dtype)

In [None]:
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers.

        Args:
            input_size: Number of features in each input sample.
            hidden_size: Width of the hidden layer.
            output_size: Number of values produced per sample.
        """
        super().__init__()
        # fc1 is built before fc2 so parameter initialisation consumes the
        # random number generator in the same order as before.
        self.fc1 = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return fc2(relu(fc1(x))); no activation is applied to the output."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
# Build the network (64 input features, 256 hidden units, 10 outputs) and
# move its parameters to the GPU.
model = MLP(input_size=64, hidden_size=256, output_size=10)
model = model.cuda()

# Cross-entropy between the model outputs and the integer class labels.
loss_fn = torch.nn.CrossEntropyLoss()

# Plain SGD over every model parameter.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Gradient scaler that the training loop uses to scale the loss and to drive
# the optimizer step.
scaler = torch.amp.GradScaler()

# Latch for debug_forward: becomes True once the first report has been printed.
print_once = False


def debug_forward(module, input, output):
    """Forward hook that prints a shape/dtype report on its first call only.

    Args:
        module: The module whose forward pass just ran; it must expose a
            ``weight`` attribute.
        input: Tuple of positional inputs to the module; only ``input[0]``
            is reported.
        output: The value produced by the module's forward pass.

    Returns:
        None, so the module's output is left unchanged.
    """
    global print_once
    if print_once:
        return
    # Flip the latch before printing so the report is attempted at most once.
    print_once = True
    report = (
        f'{module}\n'
        f'input_shape={input[0].shape} input_dtype={input[0].dtype}\n'
        f'output_shape={output.shape} output_dtype={output.dtype}\n'
        f'weight_shape={module.weight.shape} weight_dtype={module.weight.dtype}'
    )
    print(report)
    
import math  # used only by the finite-loss guard inside the loop

# Attach the one-shot shape/dtype reporter to the first linear layer.
model.fc1.register_forward_hook(debug_forward)

# The counter is called `step` rather than `iter` so the builtin iter() is not
# shadowed for the rest of the session; the printed label stays `iter=`.
step = 0
while True:
    optimizer.zero_grad()
    # Forward pass and loss run under autocast with float16 as the
    # reduced-precision dtype.
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):  # FP16 Mix
        out = model(x)
        loss = loss_fn(out, y)
    # Backward runs on the scaled loss; the optimizer step and the scale
    # update both go through the GradScaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    step += 1

    # Read the loss back from the device once per iteration and reuse it.
    loss_value = loss.item()
    # A NaN or +inf loss can never satisfy the stop condition below, so fail
    # loudly instead of looping forever.
    if not math.isfinite(loss_value):
        raise RuntimeError(f'loss became non-finite ({loss_value}) at iter={step}')
    if step % 100000 == 0:
        print(f'iter={step} loss={loss_value} cuda_mem={torch.cuda.memory_allocated()}Bytes')
    # Stop once the training loss has reached the target threshold.
    if loss_value <= 1e-3:
        break