# Lecture 2: PyTorch & Resource Accounting

In [None]:
import torch

## Basic Concepts

**Efficiency** matters!

- Compute: FLOPS

- Memory: GB


Definition of FLOPS: a metric used to measure the computational power of a computer or processor. It indicates how many **floating-point operations** (calculations involving decimal numbers like addition, subtraction, multiplication, and division) a system can perform per second.

$$ \text{FLOPS (Ideal)} = \text{Number of Cores} \times \text{Clock Frequency per Core} \times \text{Floating Point Operations per cycle} $$

$$ \text{FLOPS (Actual)} = \frac{\text{Total Number of Floating Point Operations Performed}}{\text{Execution Time}} $$
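As a rough illustration of the second formula, the sketch below times a square matrix multiplication on the CPU and divides the operation count by the elapsed time. The helper name `measure_matmul_flops_per_sec` is hypothetical, and the count of about $2n^3$ operations per $n \times n$ matmul is an approximation.

```python
import time

import torch


def measure_matmul_flops_per_sec(n: int = 256, repeats: int = 10) -> float:
    """Estimate the achieved FLOP/s of an n x n float32 matmul on the CPU."""
    a = torch.randn(n, n)
    b = torch.randn(n, n)
    a @ b  # warm-up run so one-time setup costs are not timed
    start = time.perf_counter()
    for _ in range(repeats):
        a @ b
    elapsed = time.perf_counter() - start
    total_flops = repeats * 2 * n**3  # about 2 * n^3 FLOPs per matmul
    return total_flops / elapsed


achieved = measure_matmul_flops_per_sec()
print(f"Achieved: {achieved / 1e9:.1f} GFLOP/s")
```

The number printed will vary from machine to machine and from run to run; it is normally well below the ideal figure from the first formula.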


## Memory Accounting

Tensors are the basic building block for storing everything: parameters, gradients, optimizer state, data, activations.

[Official Docs](https://docs.pytorch.org/docs/stable/tensors.html)

Almost everything (parameters, gradients, activations, optimizer states) is stored as floating-point numbers.

How to compute memory bytes in Tensors?

```python
def get_memory_usage(x: torch.Tensor):
    return x.numel() * x.element_size()

# torch.numel: Returns the total number of elements in the input tensor.
# torch.element_size: Returns the size in bytes of an individual element.
```

The result is the tensor's size in bytes. Divide by $2^{20}$ to convert it to MiB (1 MiB = $2^{20}$ bytes).

### Basic Type

- `float32`: 1 + 8 + 23 (sign + exponent + mantissa bits), the default type
- `float16`: 1 + 5 + 10, halves the memory but has a much smaller dynamic range
- `bfloat16`: 1 + 8 + 7
- `fp8`: 1 + 4 + 3 (FP8E4M3) & 1 + 5 + 2 (FP8E5M2)

Google Brain developed bfloat16 (brain floating point) in 2018 to address the limited dynamic range of float16, which makes very small values underflow and very large values overflow. bfloat16 uses the same memory as float16 but has the same dynamic range as float32! The only catch is that the resolution is worse, but this matters less for deep learning.
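A small sketch of this trade-off, assuming only standard PyTorch dtype conversion: the same float32 values are cast to `float16` and `bfloat16`, and we look at what survives. The specific test values (`1e-8`, `70000`, `1.001`) are chosen here purely for illustration.

```python
import torch

# Dynamic range: 1e-8 is too small for float16 and underflows to 0,
# while bfloat16 (8 exponent bits, like float32) still represents it.
tiny = torch.tensor(1e-8)
tiny_fp16 = tiny.to(torch.float16)
tiny_bf16 = tiny.to(torch.bfloat16)
print(tiny_fp16.item(), tiny_bf16.item())

# 70000 is above the largest finite float16 value (65504), so it overflows.
big = torch.tensor(70000.0)
big_fp16 = big.to(torch.float16)
big_bf16 = big.to(torch.bfloat16)
print(big_fp16.item(), big_bf16.item())

# Resolution: with only 7 mantissa bits, bfloat16 cannot tell 1.001 from 1.0,
# whereas float16 (10 mantissa bits) can.
fine = torch.tensor(1.001)
fine_fp16 = fine.to(torch.float16)
fine_bf16 = fine.to(torch.bfloat16)
print(fine_fp16.item(), fine_bf16.item())
```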

[FP8](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html)

Solution: **use mixed precision training**
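One way to try this is PyTorch's `torch.autocast` context manager. The sketch below is a minimal CPU example under the assumption that the linear layer is eligible for bfloat16 autocasting on your PyTorch build; the layer sizes are arbitrary. The parameters stay in float32 while the computation inside the context runs in lower precision.

```python
import torch

model = torch.nn.Linear(16, 4)  # parameters are created in float32
x = torch.randn(8, 16)

# Inside autocast, eligible ops such as the linear layer run in bfloat16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(model.weight.dtype)  # the stored parameters are unchanged
print(out.dtype)           # the activation produced under autocast
```

Keeping the parameters in float32 while computing in bfloat16 is the basic idea; a full training loop would also have to handle the loss and optimizer step, which this sketch leaves out.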

In [None]:
def get_memory_usage(x: torch.Tensor):
    return x.numel() * x.element_size()


# torch.numel: Returns the total number of elements in the input tensor.
# torch.element_size: Returns the size in bytes of an individual element.

# for float 32
x = torch.zeros((4, 8, 20))  # @inspect x
print(x.dtype)
print("Number of elements in this tensor: ", x.numel())
print("The size of bytes for an individual element in this tensor: ", x.element_size())
print(get_memory_usage(x), "bytes")
print(get_memory_usage(x) / 2**20)

# torch.empty allocates memory without initializing it, so the tensor
# still occupies numel * element_size bytes (4 * 8 * 4 = 128 here)
empty_tensor = torch.empty(4, 8)
print(get_memory_usage(empty_tensor))


# for float 16
x = torch.ones((4, 8, 20), dtype=torch.float16)
print(x.dtype)
print(x.numel())
print(x.element_size())
# half the memory of float32
print(get_memory_usage(x))

In [None]:
def get_bytes_information(dtype: torch.dtype):
    # `dtype` instead of `type` avoids shadowing the Python builtin
    x = torch.ones((4, 8, 20), dtype=dtype)
    print(f"=============={dtype}================")
    print(f"Dtype: {x.dtype}")
    print(f"Element size: {x.element_size()}")
    print(f"Bytes: {get_memory_usage(x)}")
    print(f"=============={dtype}================")
    print("\n")


# torch.float is an alias for torch.float32, so both print the same numbers
DTYPE_LIST = [torch.float64, torch.float32, torch.float, torch.float16, torch.bfloat16]

for dtype in DTYPE_LIST:
    get_bytes_information(dtype=dtype)

In [None]:
float32_info = torch.finfo(torch.float32)  # @inspect float32_info
float16_info = torch.finfo(torch.float16)  # @inspect float16_info
bfloat16_info = torch.finfo(torch.bfloat16)  # @inspect bfloat16_info
print(float16_info)
print(float32_info)
print(bfloat16_info)

## Compute Accounting

### Tensors on GPU

By default, tensors are stored in CPU memory. However, in order to take advantage of the massive parallelism of GPUs, we need to move them to GPU memory.

![GPU and CPU](https://stanford-cs336.github.io/spring2025-lectures/images/cpu-gpu.png)

In [None]:
# basic information of GPUs
num_gpus = torch.cuda.device_count()  # @inspect num_gpus
for i in range(num_gpus):
    properties = torch.cuda.get_device_properties(i)  # @inspect properties
    print(properties)

print(num_gpus)

In [None]:
import time

x = torch.zeros((4, 8, 10))
print(x.device)

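# move the tensor from CPU memory to GPU memory
# (this can be quite slow if the tensor is large)
y = x.to(device=torch.device("cuda:0"))


def test_time_compute(x: torch.Tensor):
    start_time = time.time()
    moved = x.to(device=torch.device("cuda:0"))
    # CUDA work may run asynchronously, so wait for it before stopping the clock
    torch.cuda.synchronize()
    end_time = time.time()
    print(moved.device)
    print(end_time - start_time)


test_time_compute(torch.zeros(size=(20, 20)))
# 50000 x 50000 float32 values take about 10 GB of memory
test_time_compute(torch.zeros(size=(50000, 50000)))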


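# create a tensor directly on the GPU
# (this cell uses "cuda:1", so it assumes at least two GPUs are present)
memory_allocated = torch.cuda.memory_allocated("cuda:1")
time_1 = time.time()
z = torch.zeros(size=(50000, 50000), device="cuda:1")
torch.cuda.synchronize("cuda:1")  # wait for the GPU before stopping the clock
time_2 = time.time()
print(time_2 - time_1)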
memory_allocated_new = torch.cuda.memory_allocated(device="cuda:1")
memory_used = memory_allocated_new - memory_allocated

print(f"Memory Used: {memory_used}")

### Tensor Operations

#### Tensor Storage

PyTorch tensors are pointers into allocated memory, with metadata describing how to get to **any element** of the tensor.

To inspect this metadata, use [`.stride()`](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.stride.html).

![Stride in torch](https://martinlwx.github.io/img/2D_tensor_strides.png)

The stride is the number of elements to skip in the underlying storage to go from one element to the next along the specified dimension `dim`. A tuple of all strides is returned when no argument is passed in. Otherwise, an integer is returned as the stride of that particular dimension `dim`.

In [None]:
test_tensor = torch.randint(1, 1000, size=(10, 10, 20))
print(test_tensor.shape)
print(test_tensor.dim())

# along the first dimension, the next element is 200 steps away (10 * 20)
print(test_tensor.stride(0))

# along the second dimension, the next element is 20 steps away
print(test_tensor.stride(1))

# along the last dimension, the next element is 1 step away
print(test_tensor.stride(-1))

print(test_tensor[2,3,4])

How does it work? To access `test_tensor[i, j, k]`, PyTorch reads the element at this offset (counted in elements) from the start of the tensor's storage:
`test_tensor.stride(0) * i + test_tensor.stride(1) * j + test_tensor.stride(2) * k`.
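The offset formula can be checked directly. The sketch below uses `torch.arange` so that every element equals its own position in storage, then shows that a transposed view only changes the strides and still points at the same memory. The variable names are illustrative.

```python
import torch

# Each element's value equals its position in the underlying storage.
t = torch.arange(10 * 10 * 20).reshape(10, 10, 20)

i, j, k = 2, 3, 4
offset = sum(index * stride for index, stride in zip((i, j, k), t.stride()))
flat = t.flatten()
print(offset, flat[offset].item(), t[i, j, k].item())

# A transpose is a view: same memory, different strides, no longer contiguous.
tt = t.transpose(0, 2)
same_storage = tt.data_ptr() == t.data_ptr()
print(tt.stride(), tt.is_contiguous(), same_storage)
```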

In [None]:
# other tensor operations: slicing & elementwise operations
# elementwise operations are applied to each element independently
x = torch.tensor([3.0, 3.0, 4.0])
print(x.pow(2))
print(x.rsqrt())

# `triu` takes the upper triangular part of a matrix
# (for a 3-D tensor it is applied to each matrix in the last two dimensions)
test = torch.randint(1, 1000, size=(2, 2, 2))
print(test.triu())

### Tensor Einops

Einops is a library for manipulating tensors where dimensions are named. It is inspired by Einstein summation notation (Einstein, 1916).

[Official Docs](https://einops.rocks/1-einops-basics/)



In [None]:
# einops demo
from einops import rearrange, reduce, repeat
from PIL import Image
from torchvision import transforms
from torchvision.transforms import ToPILImage

transform = transforms.ToTensor()
# assume there is an image called background.png
try:
    pil_image = Image.open("../../img/background.png")
    original_tensor = transform(pil_image)
except Exception as e:
    print(f"{e}")
    original_tensor = torch.randn(size=(4, 1144, 1718))

# crop so that h and w are divisible by 20 (needed by the pooling below)
original_tensor = original_tensor[:, :1140, :1700]
print(original_tensor.shape)
# torch.Size([4, 1140, 1700]): (c, h, w)


def _to_img(my_tensor, file_path):
    if my_tensor.max() > 1.0:
        my_tensor = my_tensor / 255.0
    to_pil_image = ToPILImage()
    pil_image = to_pil_image(my_tensor)
    pil_image.save(f"../../img/{file_path}.png")


rearrange_tensor = rearrange(original_tensor, "c h w -> c w h")
_to_img(rearrange_tensor, "rearrange_background")

# average pooling over 20 x 20 patches (like a pooling layer in a CNN)
reduce_tensor = reduce(
    original_tensor, "c (h h2) (w w2) -> c h w", "mean", h2=20, w2=20
)
print(reduce_tensor.shape)
_to_img(reduce_tensor, "reduce_background")

reduce_tensor = original_tensor[1, :, :].squeeze()
print(reduce_tensor.shape)
repeat_tensor = repeat(reduce_tensor, "h w -> c h w", c=4)
_to_img(repeat_tensor, "repeat_background")

#### JaxTyping

`jaxtyping` is a Python library that provides shape and dtype annotations for array-based code. It was written with JAX in mind, but it lets you put precise shape and data type information in your function signatures, going far beyond the basic `jax.Array` or `np.ndarray` type hints.

The same annotations work for `torch.Tensor`.

Separately, JAX itself (not `jaxtyping`) supports JIT compilation and automatic differentiation.

In [None]:
from jaxtyping import Float
x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones(2, 2, 1, 3)  # @inspect x

print(x)

In [None]:
import jax
from jax import numpy as jnp
from jaxtyping import Float, Int

def matmul(
    A: Float[jax.Array, "batch_size in_features"],
    B: Float[jax.Array, "in_features out_features"]
) -> Float[jax.Array, "batch_size out_features"]:
    """Performs matrix multiplication."""
    return A @ B

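# The shapes here agree with the annotations
A_good = jnp.zeros((128, 784))
B_good = jnp.zeros((784, 10))
result = matmul(A_good, B_good)
print(result.shape) # (128, 10)

# Here "in_features" does not match (784 vs 600). Static type checkers do not
# verify shapes; jaxtyping annotations are only enforced at runtime, and only
# when a runtime type checker is enabled (e.g. via the `jaxtyped` decorator).
# Without one, this call fails inside `A @ B` with a shape error instead.
A_bad = jnp.zeros((128, 784))
B_bad = jnp.zeros((600, 10))
# result = matmul(A_bad, B_bad)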

#### `einsum`

`einops.einsum` is a generalized matrix multiplication with named dimensions, so we can write batched products like the one below more readably.

In [None]:
from jaxtyping import Float
from einops import einsum

B = 128
SEQ1 = 100
SEQ2 = 200
HIDDEN_DIM = 128

x: Float[torch.Tensor, "batch seq1 hidden_dim"] = torch.randn(size=(B, SEQ1, HIDDEN_DIM))
y: Float[torch.Tensor, "batch seq2 hidden_dim"] = torch.randn(size=(B, SEQ2, HIDDEN_DIM))

z = einsum(x, y, "batch seq1 hidden_dim, batch seq2 hidden_dim -> batch seq1 seq2")
print(x.shape)
print(y.shape)
# compare with a tolerance: floating-point summation order may differ
print(torch.allclose(z, x @ y.transpose(-2, -1), atol=1e-4))
print(z.shape)


#### `reduce`

You can reduce a single tensor via some operation (e.g., sum, mean, max, min).

```python
x: Float[torch.Tensor, "batch seq hidden"] = torch.ones(2, 3, 4)  # @inspect x
# Old way:
y = x.sum(dim=-1)  # @inspect y
# New (einops) way:
y = reduce(x, "... hidden -> ...", "sum")  # @inspect y
```

In [None]:
x: Float[torch.Tensor, "batch seq hidden"] = torch.randn(2, 3, 4)  # @inspect x

# center x so its mean over the last dimension is 0 (the sums below should be ~0)
x = x - torch.mean(x, dim=-1, keepdim=True)

y = reduce(x, "... hidden -> ...", "sum")  # @inspect y
print(y.shape)
print(y)

#### `rearrange`

Sometimes, a single dimension is really a flattened representation of two dimensions (for example, heads and a per-head hidden size).

In [None]:
x: Float[torch.Tensor, "batch seq total_hidden"] = torch.ones(2, 3, 8)  # @inspect x
# ...where total_hidden is a flattened representation of heads * hidden1
w: Float[torch.Tensor, "hidden1 hidden2"] = torch.ones(4, 2)

print(f"x shape: {x.shape}")

# Break up total_hidden into two dimensions (heads and hidden1):
# total_hidden = heads * hidden1
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)  # @inspect x
print(f"x shape: {x.shape}")

# Perform the transformation by w:
x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")  # @inspect x
print(f"x shape: {x.shape}")
# Combine heads and hidden2 back together:
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")  # @inspect x
print(f"x shape: {x.shape}")

### Computation Cost

Having gone through all the operations, let us examine their computational cost.

**A floating-point operation (FLOP)** is a basic operation like addition (`x + y`) or multiplication (`x * y`).

- FLOPs: floating-point operations (measure of **computation done**)
- FLOP/s: floating-point operations per second (also written as FLOP**S**), which is used to measure the **speed of hardware**.

#### Several Statistics

- GPT-3: `3.14e23` FLOPs

- GPT-4: `2e25` FLOPs

- A100 has a peak performance of 312 teraFLOP/s (1 teraFLOP/s = `1e12` FLOP/s).

    - At that peak rate, the `2e25` FLOPs of GPT-4 would take about 17,806,267 A100-hours in total.
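The hour count above follows from dividing the total FLOPs by the peak rate and converting seconds to hours. A quick check, using only the figures listed above (the helper name `a100_hours` is made up for this sketch, and running at peak throughput the whole time is an idealization):

```python
A100_PEAK_FLOPS_PER_SEC = 312e12  # peak rate quoted above


def a100_hours(total_flops: float, flops_per_sec: float = A100_PEAK_FLOPS_PER_SEC) -> float:
    """Hours a single A100 would need at peak throughput (an idealization)."""
    seconds = total_flops / flops_per_sec
    return seconds / 3600


gpt3_hours = a100_hours(3.14e23)
gpt4_hours = a100_hours(2e25)
print(f"GPT-3: {gpt3_hours:,.0f} A100-hours")
print(f"GPT-4: {gpt4_hours:,.0f} A100-hours")
```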


#### Linear Model Demo

The core operation is matrix multiplication.

In [None]:
if torch.cuda.is_available():
    B = 16384  # Number of points
    D = 32768  # Dimension
    K = 8192   # Number of outputs
else:
    B = 1024
    D = 256
    K = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.ones(B, D, device=device)
w = torch.randn(D, K, device=device)
y = x @ w
# We have one multiplication (x[i][j] * w[j][k]) and one addition per (i, j, k) triple.
actual_num_flops = 2 * B * D * K  # @inspect actual_num_flops

print(actual_num_flops)

Interpretation:

- `B` is the number of data points.

- `D * K` is the number of parameters.

The forward pass therefore takes 2 × (# tokens) × (# parameters) FLOPs, where each data point plays the role of a token.

It turns out this generalizes to Transformers (to a first-order approximation).

#### Model FLOPS Utilization

$\text{MFU} = \frac{\text{actual FLOP/s}}{\text{promised FLOP/s}}$

- $\text{actual FLOP/s} = \frac{\text{total FLOPs}}{\text{time}}$

- The promised FLOP/s is the peak figure published by the hardware vendor.

Usually, an MFU of >= 0.5 is quite good (and it will be higher if matmuls dominate).
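As a worked example of the definition, the sketch below plugs in the FLOP count of the GPU-sized matmul from the linear model demo and the A100 peak of 312 teraFLOP/s. The run time of 0.05 seconds is a made-up number for illustration, not a measurement, and the function name is hypothetical.

```python
def model_flops_utilization(
    num_flops: float, seconds: float, promised_flops_per_sec: float
) -> float:
    """MFU = (FLOPs performed / elapsed time) / promised peak FLOP/s."""
    actual_flops_per_sec = num_flops / seconds
    return actual_flops_per_sec / promised_flops_per_sec


B, D, K = 16384, 32768, 8192
num_flops = 2 * B * D * K  # forward FLOPs of x @ w in the demo

# Assumption: the matmul took 0.05 s (hypothetical timing).
mfu = model_flops_utilization(num_flops, seconds=0.05, promised_flops_per_sec=312e12)
print(f"MFU: {mfu:.3f}")
```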

#### Time Complexity for Several Operations

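Consider a matrix $A$ of shape $m \times k$ and a matrix $B$ of shape $k \times n$.

- FLOPs for the matrix multiplication $AB$: $m \times n \times (2k - 1)$, since each of the $m n$ outputs needs $k$ multiplications and $k - 1$ additions. This is approximately $2mnk$.

- An elementwise operation on an $m \times n$ matrix requires $O(m n)$ FLOPs.
    
- Addition of two $m \times n$ matrices requires $m n$ FLOPs.

The achieved FLOP/s depends heavily on the hardware and the data type.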
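To see how close the exact matmul count is to the simpler $2mnk$ rule used in the linear model demo, here is a small sketch for an $m \times k$ matrix times a $k \times n$ matrix. The function names are illustrative; the relative gap works out to $1/(2k)$.

```python
def matmul_flops_exact(m: int, k: int, n: int) -> int:
    """k multiplications and k - 1 additions for each of the m * n outputs."""
    return m * n * (2 * k - 1)


def matmul_flops_approx(m: int, k: int, n: int) -> int:
    """The common 2 * m * k * n approximation."""
    return 2 * m * k * n


m, k, n = 1024, 256, 64
exact = matmul_flops_exact(m, k, n)
approx = matmul_flops_approx(m, k, n)
relative_gap = (approx - exact) / approx  # equals 1 / (2 * k)
print(exact, approx, relative_gap)
```

For realistic inner dimensions the gap is negligible, which is why the approximation is used throughout.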

### Gradient Basics

Computing gradients also requires compute!

Consider a simple linear regression model:

```python
x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)  # Want gradient
pred_y = x @ w
loss = 0.5 * (pred_y - 5).pow(2)
```

Let's work out the math:

$$ L(\vec{w}) = \frac{1}{2}(\vec{x} \cdot \vec{w} - y_{\text{true}})^2 $$

$$ \nabla L(\vec{w}) = \frac{\partial L}{\partial \vec{w}} = \left(\frac{\partial L}{\partial w_1}, \frac{\partial L}{\partial w_2}, \frac{\partial L}{\partial w_3}\right) = (\vec{x} \cdot \vec{w} - y_{\text{true}}) \, (x_1, x_2, x_3) $$
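The closed-form gradient can be compared against what autograd computes. This sketch reuses the values from the example above, with $y_{\text{true}} = 5$; the name `analytic_grad` is introduced here for illustration.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
y_true = 5.0

loss = 0.5 * (x @ w - y_true).pow(2)
loss.backward()  # autograd fills in w.grad

# Closed form from the derivation: (x . w - y_true) * x
with torch.no_grad():
    analytic_grad = (x @ w - y_true) * x

print(w.grad)
print(analytic_grad)
```

Here $\vec{x} \cdot \vec{w} = 6$, so the residual is 1 and both gradients equal $\vec{x}$.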

In [None]:
x = torch.tensor([1., 2, 3])
w = torch.tensor([1., 1, 1], requires_grad=True)  # Want gradient

# dot product
pred_y = x @ w

# squared error (with a factor of 0.5)
loss = 0.5 * (pred_y - 5).pow(2)

print(x.shape)
print(w.shape)
print(loss)


# run loss backward
loss.backward()
print(w.grad)

### Gradient FLOPs

For a neural network with several layers, the accounting becomes more involved.

<!-- todo add more complex interpretation for gradient descent. -->

In [None]:
import torch
if torch.cuda.is_available():
    B = 16384  # Number of points
    D = 32768  # Dimension
    K = 8192   # Number of outputs
else:
    B = 1024
    D = 256
    K = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.ones(B, D, device=device)
w1 = torch.randn(D, D, device=device, requires_grad=True)
w2 = torch.randn(D, K, device=device, requires_grad=True)
# Model: x --w1--> h1 --w2--> h2 -> loss
h1 = x @ w1
h2 = h1 @ w2
loss = h2.pow(2).mean()

# FLOPs
num_forward_flops = (2 * B * D * D) + (2 * B * D * K)

In [4]:
print(f"{num_forward_flops:.3e}")

4.398e+13
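The count above covers only the forward pass. A common rule of thumb, taken here as an assumption rather than derived in these notes, is that the backward pass of each matmul needs two matmuls of the same size (one for the gradient with respect to the input, one for the gradient with respect to the weights), so it costs about twice the forward pass. The sketch below applies that rule to the same two-layer model at small CPU-friendly sizes and runs `backward()` to confirm that gradients are produced.

```python
import torch

B, D, K = 64, 32, 16  # small sizes so the sketch runs quickly on a CPU
x = torch.ones(B, D)
w1 = torch.randn(D, D, requires_grad=True)
w2 = torch.randn(D, K, requires_grad=True)

# Same model as above: x --w1--> h1 --w2--> h2 -> loss
loss = ((x @ w1) @ w2).pow(2).mean()
loss.backward()

num_params = D * D + D * K
num_forward_flops = 2 * B * num_params
# Assumed rule of thumb: the backward pass costs twice the forward pass.
num_backward_flops = 4 * B * num_params
num_total_flops = num_forward_flops + num_backward_flops

print(w1.grad.shape, w2.grad.shape)
print(num_forward_flops, num_backward_flops, num_total_flops)
```

Under that assumption the total is 6 × (# data points) × (# parameters). It slightly overcounts, because the first layer does not need a gradient with respect to the input `x`.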
