In [None]:
import torch

# Create a large random tensor with 1 billion elements.
# This simulates the size of a large model's parameters. Note that it really
# allocates the memory: ~3.7 GiB for the float32 copy plus ~1.9 GiB for the
# bfloat16 copy, on torch's default device (CPU unless it has been changed).
num_elements = 1_000_000_000

# Tensor in standard float32 format (4 bytes per element)
tensor_fp32 = torch.randn(num_elements, dtype=torch.float32)

# Tensor in bfloat16 format (2 bytes per element).
# We create it in fp32 first and then convert it, which is a common practice.
tensor_bf16 = tensor_fp32.to(dtype=torch.bfloat16)


def tensor_size_gib(t: torch.Tensor) -> float:
    """Return the memory taken by ``t``'s elements, in GiB (2**30 bytes).

    Computed from metadata only (bytes per element * element count), so it
    does not account for allocator overhead or any shared/oversized storage.
    """
    return t.element_size() * t.nelement() / 1024**3


# Sizes are in gibibytes (1024**3 bytes), hence the "GiB" label below;
# dividing by 1e9 instead would give decimal gigabytes (4.00 / 2.00 GB).
size_fp32_gb = tensor_size_gib(tensor_fp32)
size_bf16_gb = tensor_size_gib(tensor_bf16)

print(f"Size of float32 tensor: {size_fp32_gb:.2f} GiB")
print(f"Size of bfloat16 tensor: {size_bf16_gb:.2f} GiB")

Size of float32 tensor: 3.73 GB
Size of bfloat16 tensor: 1.86 GB


In [None]:
import torch

# A number with lots of decimal detail. Python parses this literal as a
# float64, so every digit beyond ~17 significant figures is already gone
# before torch ever sees the value -- "Original Number" below prints that
# float64, not the literal as typed.
original_number = 3.141592653523452345235235234523452352345234523452352345234

# Store the same value in two formats: float32 keeps a 24-bit significand,
# bfloat16 only an 8-bit one, so bfloat16 loses far more decimal detail.
num_fp32 = torch.tensor(original_number, dtype=torch.float32)
num_bf16 = torch.tensor(original_number, dtype=torch.bfloat16)

# Each row: (label padded to align the values, value to show).
for label, value in (
    ("Original Number:      ", original_number),
    ("Stored as float32:    ", num_fp32.item()),
    ("Stored as bfloat16:   ", num_bf16.item()),
):
    print(f"{label}{value}")

Original Number:      3.1415926535234524
Stored as float32:    3.1415927410125732
Stored as bfloat16:   3.140625


In [None]:
import torch
import time


def time_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    n_warmup: int = 1,
    n_iters: int = 3,
) -> float:
    """Return the mean wall-clock time, in seconds, of ``torch.matmul(a, b)``.

    CUDA kernels are launched asynchronously: ``torch.matmul`` returns as soon
    as the kernel is queued. The device is therefore synchronized before the
    clock starts and before it stops; otherwise only launch overhead would be
    measured. Warm-up calls absorb one-time costs (e.g. cuBLAS handle creation
    and kernel selection) that would otherwise inflate whichever run is first.

    Args:
        a, b: Operands for ``torch.matmul``; synchronization is done on
            ``a``'s device when it is a CUDA tensor.
        n_warmup: Untimed calls made before measuring.
        n_iters: Timed calls to average over; must be >= 1.

    Returns:
        Average seconds per call over ``n_iters`` timed calls.

    Raises:
        ValueError: If ``n_iters`` is less than 1.
    """
    if n_iters < 1:
        raise ValueError(f"n_iters must be >= 1, got {n_iters}")

    def _sync() -> None:
        # Only CUDA work is asynchronous; CPU matmul has finished on return.
        if a.is_cuda:
            torch.cuda.synchronize(a.device)

    for _ in range(n_warmup):
        torch.matmul(a, b)
    _sync()  # make sure warm-up work is not billed to the timed section

    start = time.perf_counter()  # monotonic, high-resolution interval clock
    for _ in range(n_iters):
        torch.matmul(a, b)
    _sync()  # wait for every queued kernel before reading the clock
    return (time.perf_counter() - start) / n_iters


# Ensure we are using a GPU
if not torch.cuda.is_available():
    print("Please enable GPU in Runtime -> Change runtime type")
else:
    # Create large matrices on the GPU
    matrix_size = 16000
    a = torch.randn(matrix_size, matrix_size, device='cuda')
    b = torch.randn(matrix_size, matrix_size, device='cuda')

    # Time the float32 multiplication
    fp32_seconds = time_matmul(a, b)
    print(f"Time for float32 multiplication: {fp32_seconds:.4f} seconds")

    # Convert matrices to bfloat16
    a_bf16 = a.to(torch.bfloat16)
    b_bf16 = b.to(torch.bfloat16)

    # Time the bfloat16 multiplication. The operands are already bfloat16, so
    # the matmul runs in bfloat16 directly -- no autocast context is needed
    # (and torch.cuda.amp.autocast is deprecated in favour of torch.autocast).
    # NOTE(review): a bf16 speed-up is only expected on GPUs with native bf16
    # support (see torch.cuda.is_bf16_supported()); older GPUs may show none.
    bf16_seconds = time_matmul(a_bf16, b_bf16)
    print(f"Time for bfloat16 multiplication: {bf16_seconds:.4f} seconds")

Time for float32 multiplication: 0.0002 seconds
Time for bfloat16 multiplication: 0.0001 seconds


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
