In [1]:
# Optional environment check, left disabled: uncomment the next line to list
# the installed packages whose name contains "torch".
# !pip list | grep torch

### Import the libraries

In [1]:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import numpy as np
from transformer_engine.pytorch import Float8Tensor, E4M3, tensor_to_scale

# Seed both random number generators with the same value so the random
# inputs drawn in later cells are reproducible from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

someone called API registrations


### Create the model as a single `te.Linear` layer

In [2]:
# All three sizes are 16. `hidden_size` keeps its own name because the later
# cells use it as the number of rows of the random input tensor.
in_features = out_features = hidden_size = 16

# The whole "model" is one bias-free Transformer Engine linear layer whose
# parameters are created in FP32.
model = te.Linear(
    in_features,
    out_features,
    bias=False,
    params_dtype=torch.float32,
)

# SGD with momentum over the layer's parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
)

before parameter init
aten.detach.default
aten.detach.default
after parameter init
after register param




### Single iteration

In [3]:
# Random input of shape (hidden_size, in_features), created on the GPU.
inp = torch.randn((hidden_size, in_features), device="cuda")
first_row = inp[0]
print(first_row)

# Delayed-scaling FP8 recipe using the E4M3 format.
# Note: every argument of DelayedScaling is optional.
fp8_recipe = recipe.DelayedScaling(
    margin=0,
    interval=1,
    fp8_format=recipe.Format.E4M3,
)

# Start from clean gradients before the backward pass below.
optimizer.zero_grad()

# Run the forward pass under FP8 autocasting.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # Dump the module's FP8 metadata as it looks right before the forward call.
    print(model.fp8_meta)
    out = model(inp, is_first_microbatch=None)

# Reduce the output to a scalar and backpropagate through the layer.
loss = out.sum()
loss.backward()

tensor([ 0.1808, -0.5523,  0.9238, -0.7350,  1.3800,  0.8676,  0.1297, -0.9406,
         0.8109,  0.8821, -1.0133, -0.3634,  0.5101,  0.4179, -0.6888, -0.1347],
       device='cuda:0')
{'fp8_group': None, 'recipe': DelayedScaling(margin=0, interval=1, fp8_format=<Format.HYBRID: _FormatHelper(max_fwd=448, max_bwd=57344)>, amax_history_len=1024, amax_compute_algo='max', override_linear_precision=_OverrideLinearPrecision(fprop=False, dgrad=False, wgrad=False), scaling_factor_compute_algo=None, reduce_amax=True), 'autocast_id_fwd_stack': [], 'async_amax_reduction': False}
None
fp8.py:  tensor([0., 0.], device='cuda:0')


### Check scales 

In [4]:
# Show the forward scaling factors held in the module's FP8 metadata next to
# the scale attribute stored on the weight itself.
fwd_scale = model.fp8_meta['scaling_fwd'].scale
weight_scale = model.weight._scale
print(fwd_scale, weight_scale)

tensor([1., 1., 1.], device='cuda:0') tensor(5874.6479, device='cuda:0')


### Check `amax_history`

In [5]:
# Copy the forward amax history to host memory as a NumPy array. Despite the
# name it is deliberately NOT flattened, so the indices below come back as
# (row, column) pairs. Shape noted by the author as [1024, 3].
fwd_meta = model.fp8_meta['scaling_fwd']
a_h_flat = fwd_meta.amax_history.cpu().numpy()

# Display where the strictly positive entries sit and what their values are.
positive_mask = a_h_flat > 0.0
np.where(positive_mask), a_h_flat[positive_mask]

((array([0]), array([0])), array([3.0268958], dtype=float32))

### Do one optimizer step

In [6]:
# Apply one SGD update to the layer's parameters using the gradients
# computed by the earlier `loss.backward()` call.
optimizer.step()

aten.add_.Tensor


### Check scales

In [7]:
# Re-read the forward scaling factors now that the optimizer step has run.
fwd_meta = model.fp8_meta['scaling_fwd']
print("scale info: ", fwd_meta.scale)

scale info:  tensor([1., 1., 1.], device='cuda:0')


### Check amax history

In [8]:
# Same inspection as before the optimizer step: pull the forward amax history
# to the host, kept 2-D (author's shape note: [1024, 3]) so the indices are
# reported as (row, column) pairs.
amax_history = model.fp8_meta['scaling_fwd'].amax_history
a_h_flat = amax_history.cpu().numpy()

# Locations and values of every strictly positive entry.
is_positive = a_h_flat > 0.0
np.where(is_positive), a_h_flat[is_positive]

((array([0, 0]), array([0, 1])), array([3.0268958, 4.9162116], dtype=float32))

### Do subsequent iterations

In [9]:
# Two more forward/backward passes on fresh random inputs, printing the
# forward scaling factors after each one.
# NOTE(review): the loop calls neither optimizer.zero_grad() nor
# optimizer.step(), so the weights are not updated here and gradients from
# both backward passes accumulate; confirm that is intended.
for _ in range(2):
    inp = torch.randn((hidden_size, in_features), device="cuda")

    # Forward pass under FP8 autocasting, reusing the recipe built earlier.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp, is_first_microbatch=None)

    # Scalar loss, then backpropagate.
    loss = out.sum()
    loss.backward()

    # Report the current forward scaling factors.
    fwd_scale = model.fp8_meta['scaling_fwd'].scale
    print("scale info: ", fwd_scale)

fp8.py:  tensor([3.0269, 0.0000, 0.0000], device='cuda:0')
after amax scale update
None
fp8.py:  tensor([1., 0.], device='cuda:0')
scale info:  tensor([128.,   1.,   1.], device='cuda:0')
fp8.py:  tensor([3.6680, 0.0000, 0.0000], device='cuda:0')
after amax scale update
None
fp8.py:  tensor([1., 0.], device='cuda:0')
scale info:  tensor([64.,  1.,  1.], device='cuda:0')


In [10]:
# After the extra iterations, print the forward scaling factors from the FP8
# metadata together with the scale attribute stored on the weight.
scaling_fwd = model.fp8_meta['scaling_fwd']
print(scaling_fwd.scale, model.weight._scale)

tensor([64.,  1.,  1.], device='cuda:0') tensor(5874.6479, device='cuda:0')


In [11]:
# Final look at the forward amax history: copy it to the host as a NumPy
# array, left un-flattened so positions are reported as (row, column) indices.
history_tensor = model.fp8_meta['scaling_fwd'].amax_history
a_h_flat = history_tensor.cpu().numpy()

# Indices and values of the strictly positive entries.
nonzero_mask = a_h_flat > 0.0
np.where(nonzero_mask), a_h_flat[nonzero_mask]

((array([   0, 1022, 1023]), array([0, 0, 0])),
 array([3.8595376, 3.0268958, 3.668017 ], dtype=float32))