In [1]:
import os 
import timeit

os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import numpy as np
from jaxtyping import Float
from einops import rearrange, einsum, reduce

import torch
from torch import nn
import torch.nn.functional as F

# Memory is determined by the (i) number of values and (ii) data type of each value.
def get_memory_usage(x: torch.Tensor):
    """Return the number of bytes occupied by the elements of `x`.

    Computed as (number of elements) * (bytes per element).
    """
    bytes_per_value = x.element_size()
    value_count = x.numel()
    return value_count * bytes_per_value

def get_promised_flop_per_sec(device: str, dtype: torch.dtype) -> float:
    """Return the peak (vendor-quoted, dense) FLOP/s for `device` operating on `dtype`.

    Args:
        device: Anything accepted by `torch.cuda.get_device_properties`
            (e.g. "cuda:0" or a `torch.device`).
        dtype: `torch.float32`, `torch.bfloat16` or `torch.float16`.

    Returns:
        The peak FLOP/s of the matched GPU model, or the placeholder `1` when
        CUDA is not available.

    Raises:
        ValueError: if the GPU model is not in the table below, or if `dtype`
            is not one of the supported dtypes for a known GPU.
    """
    if not torch.cuda.is_available():
        return 1

    gpu_name = torch.cuda.get_device_properties(device).name.strip()
    if "A100" in gpu_name:
        # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
        fp32_peak, half_peak = 19.5e12, 312e12
    elif "H100" in gpu_name:
        # https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
        fp32_peak, half_peak = 67.5e12, 1979e12 / 2  # 1979 is for sparse, dense is half of that
    elif gpu_name.endswith("RTX 4090"):
        # NVIDIA Ada GPU architecture whitepaper: 82.6 TFLOP/s FP32 and
        # 165.2 TFLOP/s dense FP16/BF16 tensor throughput with FP32 accumulate.
        # Matched on the exact suffix so the differently-specced "RTX 4090 D"
        # and "RTX 4090 Laptop GPU" are still reported as unknown.
        fp32_peak, half_peak = 82.6e12, 165.2e12
    else:
        raise ValueError(f"Unknown device: {device}")

    if dtype == torch.float32:
        return fp32_peak
    if dtype in (torch.bfloat16, torch.float16):
        return half_peak
    raise ValueError(f"Unknown dtype: {dtype}")

def same_storage(x: torch.Tensor, y: torch.Tensor):
    """Return True when the storages behind `x` and `y` start at the same address."""
    x_ptr, y_ptr = (t.untyped_storage().data_ptr() for t in (x, y))
    return x_ptr == y_ptr

def time_matmul(a: torch.Tensor, b: torch.Tensor, num_trials: int = 5, num_warmup: int = 1) -> float:
    """Return the average number of seconds required to perform `a @ b`.

    Args:
        a: Left operand.
        b: Right operand.
        num_trials: Number of timed repetitions that are averaged (must be >= 1).
        num_warmup: Number of untimed repetitions run first, so that one-time
            costs of the first call are not attributed to the matmul itself
            (must be >= 0).

    Returns:
        Mean wall-clock seconds per `a @ b`; when CUDA is available each
        repetition includes waiting for the GPU to finish.

    Raises:
        ValueError: if `num_trials` < 1 or `num_warmup` < 0.
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")
    if num_warmup < 0:
        raise ValueError(f"num_warmup must be >= 0, got {num_warmup}")

    use_cuda = torch.cuda.is_available()

    # Wait until previously queued CUDA work is done so it is not timed.
    if use_cuda:
        torch.cuda.synchronize()

    def run():
        # Perform the operation
        a @ b
        # CUDA kernels launch asynchronously; wait so the full cost is measured.
        if use_cuda:
            torch.cuda.synchronize()

    # Untimed warm-up runs.
    for _ in range(num_warmup):
        run()

    # Time the operation `num_trials` times and report the mean.
    total_time = timeit.timeit(run, number=num_trials)
    return total_time / num_trials

def get_device(index: int = 0) -> torch.device:
    """Return the CUDA device number `index` when CUDA is available, else the CPU device."""
    device_name = f"cuda:{index}" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

# Tensor Basics

In [2]:
# A tensor built from explicit nested-list data.
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Each assignment below rebinds x to a fresh 4x8 tensor.
x = torch.zeros((4, 8))  # all zeros
x = torch.ones((4, 8))  # all ones
x = torch.randn((4, 8))  # iid Normal(0, 1) samples

# torch.empty allocates without initializing, so the printed values are arbitrary.
x = torch.empty((4, 8))
print(f"empty: {x}")
# Fill x in place with Normal(0, 1) samples truncated to [-2, 2].
torch.nn.init.trunc_normal_(x, mean=0, std=1, a=-2, b=2)
print(f"trunc_normal: {x}")

empty: tensor([[ 7.5551e+31,  1.8672e+25,  1.2709e+31,  4.5277e+21,  7.1561e+22,
          9.4082e-39, -1.1160e+21,  2.5353e+30],
        [ 2.2421e-44,  1.2747e-40,  0.0000e+00,  0.0000e+00,  2.7551e-40,
          4.5918e-40,  1.0561e-38,  1.2864e-20],
        [ 2.6585e+36,  1.6929e+17,  1.0757e+37, -1.0880e-19,  2.6893e+36,
          3.5907e+13,  6.3384e+29,  5.5143e+11],
        [ 4.1735e-39,  2.4832e-37,  7.9081e-41,  2.6568e+27,  1.6531e+19,
          1.0903e+27,  2.5986e+11, -9.0072e+15]])
trunc_normal: tensor([[-0.8814,  1.3793,  0.1921,  0.6386, -0.1770, -0.4021, -1.6637,  0.4388],
        [-1.0273,  1.3479, -1.0692,  0.0289, -0.2132, -1.0884,  0.3821,  1.8391],
        [-0.0670, -1.4548, -0.8951, -0.0703,  0.4263,  0.2505, -0.7096, -0.9426],
        [ 0.7440, -1.6184, -1.1916,  0.0469,  0.0388, -0.8756,  0.6893, -1.3750]])


# Tensor Memory

In [3]:
# fp32: what torch.zeros produces when no dtype is given
x = torch.zeros((4, 8))
assert x.dtype == torch.float32  # Default type
assert x.element_size() == 4  # Each float32 value takes 4 bytes
assert x.numel() == 4 * 8  # 32 values
assert get_memory_usage(x) == 4 * 8 * 4  # 32 values * 4 bytes = 128 bytes

In [4]:
# fp16: half the bytes per value of fp32
x = torch.zeros((4, 8), dtype=torch.half)  # @inspect x
assert x.element_size() == 2

# 1e-8 cannot be represented in fp16 and is stored as exactly zero.
x = torch.tensor([1e-8], dtype=torch.half)  # @inspect x
assert x.item() == 0  # Underflow!

In [5]:
# bf16: the same tiny value that fp16 flushed to zero is kept as a nonzero number
x = torch.tensor([1e-8], dtype=torch.bfloat16)  # @inspect x
assert x.item() != 0  # No underflow!

In [6]:
# Numeric limits (range, eps, resolution) of the three floating-point dtypes
float32_info = torch.finfo(torch.float)  # @inspect float32_info
float16_info = torch.finfo(torch.half)  # @inspect float16_info
bfloat16_info = torch.finfo(torch.bfloat16)  # @inspect bfloat16_info

print(f"float32: {float32_info}")
print(f"float16: {float16_info}")
print(f"bfloat16: {bfloat16_info}")

float32: finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)
float16: finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)
bfloat16: finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16)


# Tensor on GPUs

In [7]:
# Tensors are created in CPU memory unless a device is given.
x = torch.zeros(32, 32)
assert x.device == torch.device("cpu")

if torch.cuda.is_available():
    # List every visible GPU with its name, compute capability and total memory.
    num_gpus = torch.cuda.device_count()  # @inspect num_gpus
    for i in range(num_gpus):
        properties = torch.cuda.get_device_properties(i)  # @inspect properties
        print(f"GPU {i}: {properties.name}, Compute Capability: {properties.major}.{properties.minor}, Memory: {properties.total_memory / (1024 ** 3):.2f} GB")

    # Baseline, so the cost of the two tensors below can be measured as a difference.
    memory_allocated = torch.cuda.memory_allocated()  # @inspect memory_allocated
    print(f"Initial memory allocated on GPU: {memory_allocated / (1024 ** 2):.2f} MB")

    print("Move the tensor to GPU memory (device 0).")
    y = x.to("cuda:0")
    assert y.device == torch.device("cuda", 0)

    print("Or create a tensor directly on the GPU:")
    z = torch.zeros(32, 32, device="cuda:0")

    new_memory_allocated = torch.cuda.memory_allocated()  # @inspect new_memory_allocated
    print(f"Memory allocated after moving tensor to GPU: {new_memory_allocated / (1024 ** 2):.2f} MB")
    memory_used = new_memory_allocated - memory_allocated  # @inspect memory_used
    assert memory_used == 2 * (32 * 32 * 4)  # 2 32x32 matrices of 4-byte floats
else:
    # Skip rather than raise: an exception here would stop "Run All" before the
    # later cells, which have CPU fallbacks, get a chance to execute.
    print("CUDA is not available; skipping the GPU memory demo.")

GPU 0: NVIDIA GeForce RTX 4090, Compute Capability: 8.9, Memory: 23.65 GB
Initial memory allocated on GPU: 0.00 MB
Move the tensor to GPU memory (device 0).
Or create a tensor directly on the GPU:
Memory allocated after moving tensor to GPU: 0.01 MB


# Tensor Operations

In [8]:
# tensor storage: strides say how far to move in storage per step along each dim
x = torch.tensor([
    [0.0, 1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0, 7.0],
    [8.0, 9.0, 10.0, 11.0],
    [12.0, 13.0, 14.0, 15.0],
])

# Moving one row down (dim 0) skips 4 elements in storage.
assert x.stride()[0] == 4

# Moving one column right (dim 1) skips 1 element in storage.
assert x.stride()[1] == 1

# Storage offset of element (r, c) is the stride-weighted sum of its indices.
r, c = 1, 2
index = x.stride(0) * r + x.stride(1) * c  # @inspect index
assert index == 6

In [9]:
# tensor slicing: indexing and transposing return views that share x's storage
x = torch.tensor([[1., 2, 3], [4, 5, 6]])

# Row 0 (a view, not a copy):
y = x[0]  # @inspect y
assert torch.equal(y, torch.tensor([1., 2, 3]))
assert same_storage(x, y)

# Column 1 (also a view):
y = x[:, 1]  # @inspect y
assert torch.equal(y, torch.tensor([2, 5]))
assert same_storage(x, y)

# The 2x3 matrix seen as a 3x2 matrix:
y = x.transpose(0, 1)  # @inspect y
assert torch.equal(y, torch.tensor([[1, 4], [2, 5], [3, 6]]))
assert same_storage(x, y)

# Because storage is shared, writing through x is visible through y.
x[0, 0] = 100  # @inspect x, @inspect y
assert y[0, 0] == 100

# Some views are non-contiguous, which means that further views aren't possible.
x = torch.tensor([[1., 2, 3], [4, 5, 6]])  # @inspect x
print("Original x:")
print("x shape:", x.shape)
print("x stride:", x.stride())
y = x.transpose(0, 1)  # @inspect y
print("After transpose:")
print("y shape:", y.shape)
print("y stride:", y.stride())
assert not y.is_contiguous()
try:
    y.view(2, 3)
except RuntimeError as e:
    assert "view size is not compatible with input tensor's size and stride" in str(e)
else:
    # Reaching this branch means view() unexpectedly succeeded.
    assert False

# Making the tensor contiguous first copies the data, after which view() works
# (x.transpose(0, 1).reshape(2, 3) is the one-call alternative):
y = x.transpose(0, 1).contiguous().view(2, 3)  # @inspect y
assert not same_storage(x, y)

Original x:
x shape: torch.Size([2, 3])
x stride: (3, 1)
After transpose:
y shape: torch.Size([3, 2])
y stride: (1, 3)


In [10]:
# tensor elementwise: each operation is applied independently to every element
x = torch.tensor([1, 4, 9])
assert torch.equal(x.pow(2), torch.tensor([1, 16, 81]))
assert torch.equal(x.sqrt(), torch.tensor([1, 2, 3]))
# rsqrt maps x_i -> 1/sqrt(x_i). 1/3 is not exactly representable in floating
# point, so compare with a tolerance instead of requiring bit-exact equality.
assert torch.allclose(x.rsqrt(), torch.tensor([1, 1 / 2, 1 / 3]))

assert torch.equal(x + x, torch.tensor([2, 8, 18]))
assert torch.equal(x * 2, torch.tensor([2, 8, 18]))
assert torch.equal(x / 0.5, torch.tensor([2, 8, 18]))

# triu takes the upper triangular part of a matrix.
x = torch.ones(3, 3).triu()  # @inspect x
assert torch.equal(x, torch.tensor([
    [1, 1, 1],
    [0, 1, 1],
    [0, 0, 1]],
))
# This is useful for computing a causal attention mask, where M[i, j] is the contribution of i to j.

In [11]:
# tensor matmul: (16 x 32) times (32 x 2) gives (16 x 2)
x = torch.ones((16, 32))
w = torch.ones((32, 2))
y = torch.matmul(x, w)
assert y.shape == (16, 2)

# In general, we perform operations for every example in a batch and token in a sequence:
# the matmul acts on the last two dims of x and is repeated over the leading ones.
x = torch.ones((4, 8, 16, 32))
w = torch.ones((32, 2))
y = torch.matmul(x, w)
assert y.shape == (4, 8, 16, 2)

# Torch einops

In [12]:
# motivation: with positional dims like -2 and -1 the axis meaning lives only in comments
x = torch.ones((2, 2, 3))  # batch, sequence, hidden  @inspect x
y = torch.ones((2, 2, 3))  # batch, sequence, hidden  @inspect y
z = torch.matmul(x, y.transpose(-1, -2))  # batch, sequence, sequence  @inspect z

In [13]:
# jaxtyping basics
# Old way: the axis names exist only in a trailing comment.
x = torch.ones((2, 2, 1, 3))  # batch seq heads hidden
# New (jaxtyping) way: the axis names are part of the type annotation.
# Note: this is just documentation (no enforcement).
x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones((2, 2, 1, 3))

In [14]:
# einops einsum: matrix multiplication written with named axes
x: Float[torch.Tensor, "batch seq1 hidden"] = torch.ones((2, 3, 4))  # @inspect x
y: Float[torch.Tensor, "batch seq2 hidden"] = torch.ones((2, 3, 4))  # @inspect y

# Old way: positional dims
z = torch.matmul(x, y.transpose(-1, -2))  # batch, sequence, sequence  @inspect z

# New (einops) way: dimensions that are not named in the output are summed over.
z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")  # @inspect z
# `...` stands for any number of leading dimensions to broadcast over:
z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")  # @inspect z

In [15]:
# einops reduce
x: Float[torch.Tensor, "batch seq hidden"] = torch.ones(2, 3, 4)  # @inspect x

# Old way: average over the last (hidden) axis, addressed by position.
# (dim=-1, not dim=1: dim=1 would average over seq and would not match the
# einops pattern below, which removes `hidden`.)
y = x.mean(dim=-1)  # @inspect y
# New (einops) way: the axis missing from the output pattern (`hidden`) is the one reduced.
y = reduce(x, "... hidden -> ...", "mean")  # @inspect y
# Both forms produce one value per (batch, seq) position.
assert y.shape == x.mean(dim=-1).shape

In [16]:
# einops rearrange
x: Float[torch.Tensor, "batch seq total_hidden"] = torch.ones((2, 3, 8))  # @inspect x
w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones((4, 4))

# Step 1: split total_hidden into (heads, hidden1); heads=2 fixes how the axis is divided.
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)  # @inspect x

# Step 2: apply w; hidden1 is summed over and replaced by hidden2.
x = einsum(x, w, "... heads hidden1, hidden1 hidden2 -> ... heads hidden2")  # @inspect x

# Step 3: merge heads and hidden2 back into a single trailing axis.
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")  # @inspect x

# Tensor Operation FLOPs

In [18]:
# Linear model
# As motivation, suppose you have a linear model:
#   - we have B points,
#   - each point is D-dimensional,
#   - the linear model maps each D-dimensional vector to K outputs.
# The sizes are large on a GPU and small on the CPU fallback.
B, D, K = (16384, 32768, 8192) if torch.cuda.is_available() else (1024, 256, 64)

device = get_device()
x = torch.ones(B, D, device=device)
w = torch.randn(D, K, device=device)
y = torch.matmul(x, w)

# One multiply and one add for every (point, dimension, output) triple.
actual_num_flops = 2 * B * D * K  # @inspect actual_num_flops
actual_time = time_matmul(x, w)  # @inspect actual_time
actual_flop_per_sec = actual_num_flops / actual_time  # @inspect actual_flop_per_sec
print(f"actual_flop_per_sec: {actual_flop_per_sec / 1e12:.2f} TFLOP/s")  # @inspect actual_flop_per_sec
# promised_flop_per_sec = get_promised_flop_per_sec(device, x.dtype)  # @inspect promised_flop_per_sec

actual_flop_per_sec: 51.54 TFLOP/s


In [19]:
# Model FLOPs utilization (MFU)
# mfu = actual_flop_per_sec / promised_flop_per_sec  # @inspect mfu

# Repeat the measurement with both operands cast to bf16; the FLOP count is unchanged.
x = x.bfloat16()
w = w.bfloat16()
bf16_actual_time = time_matmul(x, w)  # @inspect bf16_actual_time
bf16_actual_flop_per_sec = actual_num_flops / bf16_actual_time  # @inspect bf16_actual_flop_per_sec
print(f"bf16_actual_flop_per_sec: {bf16_actual_flop_per_sec / 1e12:.2f} TFLOP/s")  # @inspect bf16_actual_flop_per_sec
# bf16_promised_flop_per_sec = get_promised_flop_per_sec(device, x.dtype)  # @inspect bf16_promised_flop_per_sec
# bf16_mfu = bf16_actual_flop_per_sec / bf16_promised_flop_per_sec  # @inspect bf16_mfu

bf16_actual_flop_per_sec: 148.85 TFLOP/s


# Gradient Basics

In [20]:
# Forward pass: loss = 0.5 * (x . w - 5)^2
x = torch.tensor([1., 2, 3])
w = torch.ones(3, requires_grad=True)  # Want gradient
pred_y = torch.dot(x, w)
loss = 0.5 * (pred_y - 5) ** 2

# Backward pass: only w (a leaf created with requires_grad=True) ends up with a .grad.
loss.backward()
assert loss.grad is None
assert pred_y.grad is None
assert x.grad is None
# d(loss)/dw = (pred_y - 5) * x = (6 - 5) * [1, 2, 3]
assert torch.equal(w.grad, torch.tensor([1, 2, 3]))

  assert loss.grad is None
  assert pred_y.grad is None


# Gradient FLOPs

In [21]:
# Let us count the FLOPs for computing gradients.
# Revisit our linear model, now with two weight matrices.
# Sizes are large when CUDA is available and small on the CPU fallback.
if torch.cuda.is_available():
    B = 16384  # Number of points
    D = 32768  # Dimension
    K = 8192   # Number of outputs
else:
    B = 1024
    D = 256
    K = 64

device = get_device()
x = torch.ones(B, D, device=device)
w1 = torch.randn(D, D, device=device, requires_grad=True)
w2 = torch.randn(D, K, device=device, requires_grad=True)

# Model: x --w1--> h1 --w2--> h2 -> loss
h1 = x @ w1
h2 = h1 @ w2
loss = h2.pow(2).mean()
# Forward FLOPs: one multiply and one add per term of each of the two matmuls.
num_forward_flops = (2 * B * D * D) + (2 * B * D * K)

# retain_grad() must be called before backward() so that the intermediate
# (non-leaf) tensors h1 and h2 keep their .grad for the shape checks below.
h1.retain_grad()  # For debugging
h2.retain_grad()  # For debugging
loss.backward()

# How many FLOPs is running the backward pass?
# Invoke the chain rule.
num_backward_flops = 0

# w2.grad[j,k] = sum_i h1[i,j] * h2.grad[i,k]
assert w2.grad.size() == torch.Size([D, K])
assert h1.size() == torch.Size([B, D])
assert h2.grad.size() == torch.Size([B, K])
# For each (i, j, k), multiply and add.
num_backward_flops += 2 * B * D * K

# h1.grad[i,j] = sum_k w2[j,k] * h2.grad[i,k]
assert h1.grad.size() == torch.Size([B, D])
assert w2.size() == torch.Size([D, K])
assert h2.grad.size() == torch.Size([B, K])
# For each (i, j, k), multiply and add.
num_backward_flops += 2 * B * D * K

# This was for just w2 (D*K parameters).
# Can do it for w1 (D*D parameters) as well (though don't need x.grad).
# The (2 + 2) mirrors the two 2 * B * D * K terms above, with K replaced by D.
num_backward_flops += (2 + 2) * B * D * D 