In this assignment, we are going to learn how torch autocast affects the workflow during the forward pass of a model. For this exercise, we look at the OPT-125M model. Let's first load the model and tokenizer using Hugging Face Transformers.

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Canonical (lower-case) Hugging Face Hub id of the OPT-125M checkpoint.
model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Load the weights in full precision (FP32), move them to the GPU and switch to eval mode;
# the cells below re-cast the model to whichever precision they study.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).cuda().eval()

# The 16-bit dtype used throughout the notebook. Not all GPUs support bfloat16,
# so fall back to float16 when it is unavailable.
dtype_16bit = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16



Loading weights:   0%|          | 0/197 [00:00<?, ?it/s]



To see the workflow, we print the input and output dtypes of every target layer (each `nn.Linear`, `nn.LayerNorm` and `nn.Embedding` in the model). To do this, we register the following forward hook on each target layer.

In [3]:
# Add hooks to track dtypes
import torch.nn as nn

# Module types whose input/output dtypes get printed during the forward pass.
WATCH = (nn.Linear, nn.LayerNorm, nn.Embedding)

# Re-running this cell would otherwise stack a second set of hooks on the model
# (every line printed twice): detach the handles left over from a previous run.
for _stale_handle in globals().get("hooks", []):
    _stale_handle.remove()
hooks = []
def make_hook(name):
    """Build a forward hook that prints the dtypes flowing through one module.

    Args:
        name: Label printed in front of the dtypes (in this notebook, the
            qualified module name from ``model.named_modules()``).

    Returns:
        A callable with the ``hook(module, inputs, output)`` signature used by
        ``nn.Module.register_forward_hook``. It prints
        ``[name] in=<dtype> -> out=<dtype>`` and returns ``None``, so the
        module's output is left unchanged.
    """
    def _first_dtype(value):
        # Forward hooks receive the positional inputs as a tuple, and some modules
        # return tuples: report the dtype of the first element. An empty tuple
        # (module called with keyword arguments only) or a non-tensor first
        # element yields None instead of raising inside the forward pass.
        if isinstance(value, (tuple, list)):
            value = value[0] if value else None
        return getattr(value, "dtype", None)

    def hook(m, inp, out):
        print(f"[{name}] in={_first_dtype(inp)} -> out={_first_dtype(out)}")
    return hook

# Attach a dtype-printing hook to every watched module, keeping each handle in
# `hooks` (in named_modules() order) so the hooks can be removed later.
hooks.extend(
    module.register_forward_hook(make_hook(qualified_name))
    for qualified_name, module in model.named_modules()
    if isinstance(module, WATCH)
)


Let's define the sample input that we want to run through the model.

In [4]:
# Sample prompt: tokenize it into PyTorch tensors, then move them to the GPU
# (the model was moved there with .cuda()).
text = "Profiling AMP forward pass."
encoded = tokenizer(text, return_tensors="pt")
inputs = encoded.to("cuda")

We want to run the model in 4 settings:
- No torch autocast, model parameters cast to FP32.
- No torch autocast, model parameters cast to BF16 or FP16.
- With torch autocast, model parameters cast to FP32.
- With torch autocast, model parameters cast to BF16 or FP16.

(The cells below run the two autocast settings in the opposite order: 16-bit parameters first, then FP32.)

In [5]:
# No torch autocast; model parameters cast to FP32.
# See how the entire model runs in FP32.
model = model.to(torch.float32)
with torch.inference_mode():
    _ = model(**inputs)

# One decoder layer is shown. LN1/LN2 are the hook names
# self_attn_layer_norm/final_layer_norm.
# [int64 tokens]  (shared input)
#         |
#    +----+---------------------+
#    |                          |
# [embed_tok] int64→fp32    [embed_pos] int64→fp32
#    |                          |
#    +-------- sum (fp32+fp32→fp32) --------+
#                                           |
#                                         [LN1] fp32→fp32
#                                           |
#                       +---------+---------+---------+
#                       |         |                   |
#                    [q_proj]  [k_proj]           [v_proj]
#                     fp32→fp32  fp32→fp32         fp32→fp32
#                       \         |                 /
#                        \        |                /
#                         +----[attn + softmax]----+   (fp32→fp32)
#                                           |
#                                     [out_proj] fp32→fp32
#                                           |
#                                  (residual add) fp32→fp32
#                                           |
#                                         [LN2] fp32→fp32
#                                           |
#                                        [fc1] fp32→fp32
#                                           |
#                                        [fc2] fp32→fp32
#                                           |
#                                  (residual add) fp32→fp32


[model.decoder.embed_tokens] in=torch.int64 -> out=torch.float32
[model.decoder.embed_positions] in=torch.int64 -> out=torch.float32
[model.decoder.layers.0.self_attn_layer_norm] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.self_attn.q_proj] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.self_attn.k_proj] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.self_attn.v_proj] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.self_attn.out_proj] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.final_layer_norm] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.fc1] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.fc2] in=torch.float32 -> out=torch.float32
[model.decoder.layers.1.self_attn_layer_norm] in=torch.float32 -> out=torch.float32
[model.decoder.layers.1.self_attn.q_proj] in=torch.float32 -> out=torch.float32
[model.decoder.layers.1.self_attn.k_proj] in=torch.float32 -> out=torch.float32
[mo

In [6]:
# No torch autocast; model parameters cast to BF16 or FP16 (dtype_16bit).
# See how the entire model runs in 16 bit.
# (The diagram shows bf16; read fp16 if dtype_16bit is float16 on this GPU.)
model = model.to(dtype_16bit)
with torch.inference_mode():
    _ = model(**inputs)
# One decoder layer is shown. LN1/LN2 are the hook names
# self_attn_layer_norm/final_layer_norm.
# [int64 tokens]  (shared input)
#         |
#    +----+---------------------+
#    |                          |
# [embed_tok] int64→bf16    [embed_pos] int64→bf16
#    |                          |
#    +-------- sum (bf16+bf16→bf16) --------+
#                                           |
#                                         [LN1] bf16→bf16
#                                           |
#                       +---------+---------+---------+
#                       |         |                   |
#                    [q_proj]  [k_proj]           [v_proj]
#                     bf16→bf16  bf16→bf16         bf16→bf16
#                       \         |                 /
#                        \        |                /
#                         +----[attn + softmax]----+   (bf16→bf16)
#                                           |
#                                     [out_proj] bf16→bf16
#                                           |
#                                  (residual add) bf16→bf16
#                                           |
#                                         [LN2] bf16→bf16
#                                           |
#                                        [fc1] bf16→bf16
#                                           |
#                                        [fc2] bf16→bf16
#                                           |
#                                  (residual add) bf16→bf16


[model.decoder.embed_tokens] in=torch.int64 -> out=torch.bfloat16
[model.decoder.embed_positions] in=torch.int64 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn_layer_norm] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.q_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.k_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.v_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.out_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.final_layer_norm] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.fc1] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.fc2] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn_layer_norm] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn.q_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn.k_proj] in=torch.bfloat16 -

In [7]:
# With torch autocast; model parameters cast to BF16 or FP16 (dtype_16bit).
# See how the LayerNorms output FP32 even though their inputs and weights are 16 bit,
# and how the linear layers run in BF16 even though their inputs are in FP32!
# (The diagram shows bf16; read fp16 if dtype_16bit is float16 on this GPU.)
model = model.to(dtype_16bit)
with torch.inference_mode():
    with torch.autocast(device_type="cuda", dtype=dtype_16bit):
        _ = model(**inputs)

# One decoder layer is shown. LN1/LN2 are the hook names
# self_attn_layer_norm/final_layer_norm.
# [int64 tokens]  (shared input)
#         |
#    +----+---------------------+
#    |                          |
# [embed_tok] int64→bf16    [embed_pos] int64→bf16
#    |                          |
#    +-------- sum (bf16+bf16→bf16) --------+
#                                           |
#                                         [LN1] bf16→fp32
#                                           |
#                       +---------+---------+---------+
#                       |         |                   |
#                    [q_proj]  [k_proj]           [v_proj]
#                     fp32→bf16  fp32→bf16         fp32→bf16
#                       \         |                 /
#                        \        |                /
#                         +----[attn + softmax]----+   (bf16→bf16)
#                                           |
#                                     [out_proj] bf16→bf16
#                                           |
#                 (residual add with skip: bf16 + bf16 → bf16)
#                                           |
#                                         [LN2] bf16→fp32
#                                           |
#                                        [fc1] fp32→bf16
#                                           |
#                                        [fc2] bf16→bf16
#                                           |
#                 (residual add with skip: bf16 + bf16 → bf16)


[model.decoder.embed_tokens] in=torch.int64 -> out=torch.bfloat16
[model.decoder.embed_positions] in=torch.int64 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn_layer_norm] in=torch.bfloat16 -> out=torch.float32
[model.decoder.layers.0.self_attn.q_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.k_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.v_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.out_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.final_layer_norm] in=torch.bfloat16 -> out=torch.float32
[model.decoder.layers.0.fc1] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.fc2] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn_layer_norm] in=torch.bfloat16 -> out=torch.float32
[model.decoder.layers.1.self_attn.q_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn.k_proj] in=torch.float32 -> out=tor

In [8]:
# With torch autocast; model parameters cast to FP32.
# See how the linear layers run in BF16 even though their parameters are in FP32!
# See how the residual sum runs in FP32 even though one input is BF16 and the other
# FP32: ordinary type promotion picks the wider dtype.
# (The diagram shows bf16; read fp16 if dtype_16bit is float16 on this GPU.)
model = model.to(torch.float32)
with torch.inference_mode():
    with torch.autocast(device_type="cuda", dtype=dtype_16bit):
        _ = model(**inputs)
# One decoder layer is shown. LN1/LN2 are the hook names
# self_attn_layer_norm/final_layer_norm.
# [int64 tokens]  (shared input)
#         |
#    +----+---------------------+
#    |                          |
# [embed_tok] int64→fp32    [embed_pos] int64→fp32
#    |                          |
#    +-------- sum (fp32+fp32→fp32) --------+
#                                           |
#                                         [LN1] fp32→fp32
#                                           |
#                       +---------+---------+---------+
#                       |         |                   |
#                    [q_proj]  [k_proj]           [v_proj]
#                     fp32→bf16  fp32→bf16         fp32→bf16
#                       \         |                 /
#                        \        |                /
#                         +----[attn + softmax]----+   (bf16→bf16)
#                                           |
#                                     [out_proj] bf16→bf16
#                                           |
#               (residual add: bf16 + fp32 → fp32)   ← promotion to fp32
#                                           |
#                                         [LN2] fp32→fp32
#                                           |
#                                        [fc1] fp32→bf16
#                                           |
#                                        [fc2] bf16→bf16
#                                           |
#               (residual add: bf16 + fp32 → fp32)   ← promotion to fp32


[model.decoder.embed_tokens] in=torch.int64 -> out=torch.float32
[model.decoder.embed_positions] in=torch.int64 -> out=torch.float32
[model.decoder.layers.0.self_attn_layer_norm] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.self_attn.q_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.k_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.v_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.self_attn.out_proj] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.0.final_layer_norm] in=torch.float32 -> out=torch.float32
[model.decoder.layers.0.fc1] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.0.fc2] in=torch.bfloat16 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn_layer_norm] in=torch.float32 -> out=torch.float32
[model.decoder.layers.1.self_attn.q_proj] in=torch.float32 -> out=torch.bfloat16
[model.decoder.layers.1.self_attn.k_proj] in=torch.float32 -> out=torch.bf

We saw above that it is possible to feed a bfloat16 (bf16) tensor as input to a linear layer whose weights are stored in float32 (fp32), and the operation completes successfully.

This works because, under the hood, automatic mixed precision (AMP) autocast intercepts the call to the linear kernel and casts its floating-point arguments (the input and the weights) to the autocast dtype (bf16 here). The matrix multiplication can then proceed with matching dtypes. In other words, autocast ensures type compatibility by downcasting fp32 tensors to bf16 on the fly, without requiring explicit intervention from the user.

In the absence of autocast, such an operation raises a runtime error. PyTorch enforces strict dtype checks: you cannot directly multiply an fp32 weight matrix with a bf16 input tensor (or, as in the cell below, a bf16 weight matrix with an fp32 input), because there's no implicit casting at the operator boundary.

In [9]:
# Without autocast: a 16-bit layer receiving a 32-bit input. The linear kernel
# needs both operands in the same dtype, so this call is expected to fail.
layer = torch.nn.Linear(2, 2).to("cuda", dtype_16bit)  # 16bit layer
inp = torch.randn(2, 2).to("cuda", torch.float32)  # 32bit input
try:
    out = layer(inp)
except RuntimeError as exc:
    # Show the dtype-mismatch message instead of a full traceback.
    print("ERROR!", exc)

ERROR! mat1 and mat2 must have the same dtype, but got Float and BFloat16


Let's try again with autocast

In [10]:
# Same 16-bit layer / 32-bit input pairing as above, but this time inside an
# autocast region, which casts the operands to dtype_16bit before the matmul.
layer = torch.nn.Linear(2, 2).to("cuda", dtype_16bit)  # 16bit layer
inp = torch.randn(2, 2).to("cuda", torch.float32)  # 32bit input
with torch.autocast(device_type="cuda", dtype=dtype_16bit):
    out = layer(inp)
print("SUCCESS!")

SUCCESS!


We can also see how autocast affects the loss computation. Let's check the MSE loss.

In [12]:
# Two random 16-bit tensors on the GPU (target first, then prediction).
gt = torch.randn(2, 2).to("cuda", dtype_16bit)
pred = torch.randn(2, 2).to("cuda", dtype_16bit)
mse = nn.MSELoss()

# Without autocast the loss keeps the dtype of its inputs.
out = mse(pred, gt)
print(out)
print(out.dtype) # dtype_16bit

# Under autocast the same loss on the same tensors comes back in FP32.
with torch.autocast(device_type="cuda", dtype=dtype_16bit):
    out = mse(pred, gt)
print(out.dtype) # FP32
print(out)

tensor(2.3906, device='cuda:0', dtype=torch.bfloat16)
torch.bfloat16
torch.float32
tensor(2.3890, device='cuda:0')


Autocast has no cast policy for ReLU: whatever dtype goes in is the dtype that comes out. See below.

In [13]:
import torch
import torch.nn as nn
dtype_16bit = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

x16 = torch.randn(2, 2).to("cuda", dtype_16bit)
x32 = torch.randn(2, 2).to("cuda", torch.float32)
relu = nn.ReLU()

# Without autocast: prints dtype_16bit, then FP32.
for x in (x16, x32):
    out = relu(x)
    print(out.dtype)

# With autocast: again dtype_16bit, then FP32 -- the input dtype passes straight through.
for x in (x16, x32):
    with torch.autocast(device_type="cuda", dtype=dtype_16bit):
        out = relu(x)
    print(out.dtype)

torch.bfloat16
torch.float32
torch.bfloat16
torch.float32


Softmax has an FP32 policy under autocast: whatever dtype goes in, FP32 comes out. Without autocast, the output keeps the input dtype. See below.

In [14]:
import torch
import torch.nn as nn
dtype_16bit = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

x16 = torch.randn(2, 2).to(device="cuda", dtype=dtype_16bit)
x32 = torch.randn(2, 2).to(device="cuda", dtype=torch.float32)

# Pass `dim` explicitly: nn.Softmax() without it emits the "implicit dimension
# choice" deprecation warning. For these 2-D inputs the implicit choice was the
# last dimension, so dim=-1 keeps the results identical.
softmax = nn.Softmax(dim=-1)

# without autocast
out = softmax(x16)
print(out.dtype) # dtype_16bit

# without autocast
out = softmax(x32)
print(out.dtype) # FP32

# with autocast
with torch.autocast(device_type="cuda", dtype=dtype_16bit):
    out = softmax(x16)
print(out.dtype) # FP32

# with autocast
with torch.autocast(device_type="cuda", dtype=dtype_16bit):
    out = softmax(x32)
print(out.dtype) # FP32

torch.bfloat16
torch.float32
torch.float32
torch.float32


  return self._call_impl(*args, **kwargs)


😎 Now, let's have some fun! What do you think the output data type would be for the following operations? Why are sum and mean different?!

In [22]:
import torch
import torch.nn as nn
dtype_16bit = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

x16 = torch.randn(2, 20000).to("cuda", dtype_16bit)

# All seven ops run inside one autocast region. Reductions print their value;
# element-wise ops print only their shape (log of the negative entries is NaN anyway).
with torch.autocast(device_type="cuda", dtype=dtype_16bit):
    # reductions
    out_max, out_min = x16.max(), x16.min()
    out_sum, out_mean, out_prod = x16.sum(), x16.mean(), x16.prod()
    # element-wise ops
    out_exp, out_log = x16.exp(), x16.log()

    print("out_max:", out_max.dtype, out_max)
    print("out_min:", out_min.dtype, out_min)
    print("out_sum:", out_sum.dtype, out_sum)
    print("out_mean:", out_mean.dtype, out_mean)
    print("out_prod:", out_prod.dtype, out_prod)
    print("out_exp:", out_exp.dtype, out_exp.shape)
    print("out_log:", out_log.dtype, out_log.shape)

out_max: torch.bfloat16 tensor(4.4688, device='cuda:0', dtype=torch.bfloat16)
out_min: torch.bfloat16 tensor(-4.3750, device='cuda:0', dtype=torch.bfloat16)
out_sum: torch.float32 tensor(221.7129, device='cuda:0')
out_mean: torch.bfloat16 tensor(0.0056, device='cuda:0', dtype=torch.bfloat16)
out_prod: torch.float32 tensor(-0., device='cuda:0')
out_exp: torch.float32 torch.Size([2, 20000])
out_log: torch.float32 torch.Size([2, 20000])


In [21]:
import torch
import torch.nn as nn
dtype_16bit = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# FP32 input of the same (2, 20000) size as the previous cell.
x32 = torch.randn(2, 20000).to("cuda", torch.float32)
relu = nn.ReLU()

# Does an FP32 input to ReLU stay FP32 inside an autocast region?
with torch.autocast(device_type="cuda", dtype=dtype_16bit):
    out_relu = relu(x32)
    print("out_relu:", out_relu.dtype)

out_relu: torch.float32
