In [4]:
from IPython.display import display
from pathlib import Path
import sys
import timeit
import os

# Locate the "cs336-basics" directory under the parent of the current working
# directory and make sure it is the first entry on the import search path, so
# that the `cs336_basics` import further down resolves against it.
project_dir = Path.cwd().parent
basics_path = project_dir.joinpath("cs336-basics").as_posix()
if basics_path != sys.path[0]:
    sys.path.insert(0, basics_path)

import pandas as pd
import numpy as np
import torch
from torch import nn
from tqdm import tqdm

from cs336_basics.model import BasicsTransformerLM

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

%matplotlib inline

In [3]:
def _accumulate(acc_dtype, addend_dtype, n_steps, increment, upcast=False):
    """Add a constant to a zero-initialised scalar tensor ``n_steps`` times.

    Args:
        acc_dtype: dtype of the running sum.
        addend_dtype: dtype the increment is stored in (rounding to this dtype
            happens before any addition).
        n_steps: Number of in-place additions to perform.
        increment: Python float that is added at every step.
        upcast: If True, explicitly convert the (already rounded) addend to
            ``acc_dtype`` before adding instead of relying on type promotion
            inside ``+=``.

    Returns:
        A 0-dim tensor of dtype ``acc_dtype`` holding the accumulated sum.
    """
    total = torch.zeros((), dtype=acc_dtype)
    # The addend never changes, so build (and optionally upcast) it once.
    addend = torch.tensor(increment, dtype=addend_dtype)
    if upcast:
        addend = addend.to(acc_dtype)
    for _ in range(n_steps):
        total += addend
    return total


def mixed_precision_accumulation(n_steps: int = 1000, increment: float = 0.01) -> None:
    """Print the result of repeated addition under four dtype combinations.

    The four printed tensors are, in order:
      1. float32 accumulator, float32 addend
      2. float16 accumulator, float16 addend
      3. float32 accumulator, float16 addend (mixed dtypes in ``+=``)
      4. float32 accumulator, float16 addend explicitly converted to float32

    Args:
        n_steps: Number of additions per case (default 1000).
        increment: Value added at each step (default 0.01).
    """
    cases = (
        (torch.float32, torch.float32, False),
        (torch.float16, torch.float16, False),
        (torch.float32, torch.float16, False),
        (torch.float32, torch.float16, True),
    )
    for acc_dtype, addend_dtype, upcast in cases:
        print(_accumulate(acc_dtype, addend_dtype, n_steps, increment, upcast))
mixed_precision_accumulation()

tensor(10.0001)
tensor(9.9531, dtype=torch.float16)
tensor(10.0021)
tensor(10.0021)


In [14]:
class ToyModel(nn.Module):
    """Small Linear -> ReLU -> LayerNorm -> Linear network for dtype inspection.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the returned logits.
        hidden_features: Width of the hidden layer (default 10).
    """

    def __init__(self, in_features: int, out_features: int, hidden_features: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.ln = nn.LayerNorm(hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(ln(relu(fc1(x))))``."""
        x = self.relu(self.fc1(x))
        x = self.ln(x)
        x = self.fc2(x)
        return x

    def print_dtype(self, x: torch.Tensor) -> None:
        """Print the dtype of every parameter, intermediate activation and gradient.

        Runs one forward pass layer by layer (same sequence of layers as
        ``forward``), then back-propagates ``logits.sum()`` and reports the
        dtype of each parameter's gradient.

        Side effects: writes to stdout and replaces every parameter's ``.grad``
        with the gradient of this single call.
        """
        for name, param in self.named_parameters():
            print(f"{name}: {param.dtype}")

        # Drop gradients left over from earlier calls; otherwise backward()
        # below would silently accumulate into them on every invocation.
        self.zero_grad(set_to_none=True)

        hidden = self.fc1(x)
        print(f"fc1(x): {hidden.dtype}")
        activated = self.relu(hidden)
        print(f"relu(fc1(x)): {activated.dtype}")
        normed = self.ln(activated)
        print(f"ln(relu(fc1(x))): {normed.dtype}")
        logits = self.fc2(normed)
        print(f"logits: {logits.dtype}")

        loss = logits.sum()
        loss.backward()
        for name, param in self.named_parameters():
            # A parameter that took no part in the graph has no gradient.
            grad_dtype = None if param.grad is None else param.grad.dtype
            print(f"{name}.grad: {grad_dtype}")


# Build the toy model and one random input vector on the selected device.
model = ToyModel(20, 5).to(device)
x = torch.rand(20, device=device)
print("x:", x.dtype)

# Baseline: no autocast.
print("-" * 5, "normal", "-" * 5)
model.print_dtype(x)

# Repeat under autocast for each reduced-precision dtype. `device.type` is used
# instead of a hard-coded "cuda" so the autocast context always targets the
# device the model and input actually live on (the notebook falls back to CPU
# when CUDA is unavailable).
# NOTE(review): float16 autocast on CPU depends on the installed PyTorch
# version -- confirm if this cell must run without a GPU.
for label, autocast_dtype in (("float16", torch.float16), ("bfloat16", torch.bfloat16)):
    print("-" * 5, f"in autocast {label}", "-" * 5)
    with torch.autocast(device_type=device.type, dtype=autocast_dtype):
        model.print_dtype(x)

x: torch.float32
----- normal -----
fc1.weight: torch.float32
ln.weight: torch.float32
ln.bias: torch.float32
fc2.weight: torch.float32
fc1(x): torch.float32
relu(fc1(x)): torch.float32
ln(relu(fc1(x))): torch.float32
logits: torch.float32
fc1.weight.grad: torch.float32
ln.weight.grad: torch.float32
ln.bias.grad: torch.float32
fc2.weight.grad: torch.float32
----- in autocast float16 -----
fc1.weight: torch.float32
ln.weight: torch.float32
ln.bias: torch.float32
fc2.weight: torch.float32
fc1(x): torch.float16
relu(fc1(x)): torch.float16
ln(relu(fc1(x))): torch.float32
logits: torch.float16
fc1.weight.grad: torch.float32
ln.weight.grad: torch.float32
ln.bias.grad: torch.float32
fc2.weight.grad: torch.float32
----- in autocast bfloat16 -----
fc1.weight: torch.float32
ln.weight: torch.float32
ln.bias: torch.float32
fc2.weight: torch.float32
fc1(x): torch.bfloat16
relu(fc1(x)): torch.bfloat16
ln(relu(fc1(x))): torch.float32
logits: torch.bfloat16
fc1.weight.grad: torch.float32
ln.weight.gra