<a href="https://colab.research.google.com/github/yangliupku/cs336_assignment2_systems/blob/main/notebooks/mixed_precision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

In [2]:
# Sanity check that a CUDA device is visible; later cells in this notebook
# place tensors and models on the 'cuda' device.
torch.cuda.is_available()

True

In [8]:
# Baseline: add 0.01 one thousand times with both the accumulator and the
# increment stored in float32.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = torch.tensor(0, dtype=torch.float32, device=device)
# The increment never changes, so build it once instead of allocating a new
# tensor on the device in every iteration.
step = torch.tensor(0.01, dtype=torch.float32, device=device)
for i in range(1000):
  s += step
print(s)

tensor(10.0001, device='cuda:0')


In [11]:
# Same accumulation, but the running sum and the increment are both float16,
# so every in-place add is rounded to float16 precision.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = torch.tensor(0, dtype=torch.float16, device=device)
# The increment never changes, so build it once instead of allocating a new
# tensor on the device in every iteration.
step = torch.tensor(0.01, dtype=torch.float16, device=device)
for i in range(1000):
  s += step
print(s)

tensor(9.9531, device='cuda:0', dtype=torch.float16)


In [12]:
# Same accumulation in bfloat16 for 1000 iterations. The recorded output is
# 4.0: once the sum reaches 4, adding 0.01 no longer changes the bfloat16 value.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = torch.tensor(0, dtype=torch.bfloat16, device=device)
# The increment never changes, so build it once instead of allocating a new
# tensor on the device in every iteration.
step = torch.tensor(0.01, dtype=torch.bfloat16, device=device)
for i in range(1000):
  s += step
print(s)

tensor(4., device='cuda:0', dtype=torch.bfloat16)


In [13]:
# Mixed case: the accumulator is float32 while the increment is float16.
# The in-place add keeps `s` in float32, so only the increment's own rounding
# (0.01 stored as float16) shows up in the result.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = torch.tensor(0, dtype=torch.float32, device=device)
# The increment never changes, so build it once instead of allocating a new
# tensor on the device in every iteration.
step = torch.tensor(0.01, dtype=torch.float16, device=device)
for i in range(1000):
  s += step
print(s)

tensor(10.0021, device='cuda:0')


In [14]:
# Same as the mixed float32/float16 accumulation, but the float16 increment is
# upcast to float32 explicitly before it is added.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = torch.tensor(0, dtype=torch.float32, device=device)
# Both the float16 value and its float32 copy are loop-invariant, so create
# them once rather than on every iteration.
x = torch.tensor(0.01, dtype=torch.float16, device=device)
x_fp32 = x.to(torch.float32)
for i in range(1000):
  s += x_fp32
print(s)

tensor(10.0021, device='cuda:0')


In [15]:
# 0.01 stored as a bfloat16 scalar tensor.
# NOTE(review): `a` is not read by any other cell visible here.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
a = torch.tensor(0.01, dtype=torch.bfloat16,
                 device='cuda' if torch.cuda.is_available() else 'cpu')

In [25]:
# bfloat16 accumulation again, but with only 500 iterations. The recorded
# output is still 4.0, the same value the 1000-iteration run ends at.
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
s = torch.tensor(0, dtype=torch.bfloat16, device=device)
# The increment never changes, so build it once instead of allocating a new
# tensor on the device in every iteration.
step = torch.tensor(0.01, dtype=torch.bfloat16, device=device)
for i in range(500):
  s += step
print(s)

tensor(4., device='cuda:0', dtype=torch.bfloat16)


In [43]:
class ToyModel(torch.nn.Module):
  """Small MLP (Linear -> LayerNorm -> ReLU -> Linear) that prints activation dtypes.

  Every stage of ``forward`` prints the dtype of its output, which makes it
  easy to see which operations change precision (for example when the model
  is called inside a ``torch.autocast`` context).

  Args:
    in_features: Size of the last dimension of the input.
    out_features: Size of the last dimension of the output.
    hidden_features: Width of the hidden layer shared by ``fc1``, ``ln`` and
      ``fc2``. Defaults to 10, the value that was previously hard-coded.
  """

  def __init__(self, in_features: int, out_features: int, hidden_features: int = 10):
    super().__init__()
    # Keep the attribute order fc1, fc2, ln, relu: it determines the order in
    # which named_parameters() yields the parameters.
    self.fc1 = torch.nn.Linear(in_features=in_features, out_features=hidden_features, bias=False)
    self.fc2 = torch.nn.Linear(in_features=hidden_features, out_features=out_features, bias=False)
    self.ln = torch.nn.LayerNorm(hidden_features)
    self.relu = torch.nn.ReLU()

  def forward(self, x):
    """Run the network on ``x``, printing the dtype after every stage.

    Args:
      x: Input tensor whose last dimension is ``in_features``.

    Returns:
      The output of ``fc2``; its last dimension is ``out_features``.
    """
    print("forward")
    print("input:", x.dtype)
    x = self.fc1(x)
    print("fc1:", x.dtype)
    x = self.ln(x)
    print("ln:", x.dtype)
    x = self.relu(x)
    print("relu:", x.dtype)
    x = self.fc2(x)
    print("fc2:", x.dtype)
    return x


In [46]:
# Baseline forward pass: float32 parameters, float32 input, no autocast.
model = ToyModel(in_features=5, out_features=8).cuda()
x = torch.rand((15, 5), device='cuda')
y = model(x)


forward
input: torch.float32
fc1: torch.float32
ln: torch.float32
relu: torch.float32
fc2: torch.float32


In [47]:
# Forward pass under autocast: the input is created directly in float16 while
# the model parameters stay in their default dtype.
dtype = torch.float16
model = ToyModel(in_features=5, out_features=8).cuda()
x = torch.rand((15, 5), dtype=dtype, device='cuda')
with torch.autocast('cuda', dtype=dtype):
  y = model(x)


forward
input: torch.float16
fc1: torch.float16
ln: torch.float32
relu: torch.float32
fc2: torch.float16


In [64]:
# One full-precision training step, printing the dtype of the parameters,
# the model output, the loss and the gradients along the way.
model = ToyModel(in_features=5, out_features=8).cuda()
x = torch.rand((15, 5), device='cuda')
opt = torch.optim.AdamW(model.parameters())
opt.zero_grad()

y = model(x)
print("---- params ----")
for name, param in model.named_parameters():
  print(name, param.dtype)
print("y dtype", y.dtype)

# Scalar loss: squared mean of the output.
loss = y.mean().pow(2)
print("loss dtype", loss.dtype)

loss.backward()
print("---- params grad----")
for name, param in model.named_parameters():
  print(name, param.grad.dtype)
opt.step()


forward
input: torch.float32
fc1: torch.float32
ln: torch.float32
relu: torch.float32
fc2: torch.float32
---- params ----
fc1.weight torch.float32
fc2.weight torch.float32
ln.weight torch.float32
ln.bias torch.float32
y dtype torch.float32
loss dtype torch.float32
---- params grad----
fc1.weight torch.float32
fc2.weight torch.float32
ln.weight torch.float32
ln.bias torch.float32


In [66]:
# One mixed-precision (float16 autocast) training step, printing the dtype of
# the parameters, the model output, the loss and the gradients.
model = ToyModel(5, 8).to('cuda')
x = torch.rand(15, 5, device='cuda')
opt = torch.optim.AdamW(params=model.parameters())
opt.zero_grad()
# Autocast wraps only the forward pass and the loss computation. PyTorch's AMP
# documentation advises against running backward() under autocast; backward
# ops reuse the dtype autocast picked for the corresponding forward ops.
with torch.autocast(device_type='cuda', dtype=torch.float16):
  y = model(x)
  print("---- params ----")
  for k, v in model.named_parameters():
    print(k, v.dtype)
  print("y dtype", y.dtype)
  loss = (y.mean())**2
  print("loss dtype", loss.dtype)
# NOTE(review): real float16 training normally pairs autocast with a gradient
# scaler (torch.amp.GradScaler) to guard against gradient underflow. It is left
# out here because this cell only inspects dtypes.
loss.backward()
print("---- params grad----")
for k, v in model.named_parameters():
  print(k, v.grad.dtype)
opt.step()

forward
input: torch.float32
fc1: torch.float16
ln: torch.float32
relu: torch.float32
fc2: torch.float16
---- params ----
fc1.weight torch.float32
fc2.weight torch.float32
ln.weight torch.float32
ln.bias torch.float32
y dtype torch.float16
loss dtype torch.float32
---- params grad----
fc1.weight torch.float32
fc2.weight torch.float32
ln.weight torch.float32
ln.bias torch.float32


In [67]:
# One mixed-precision (bfloat16 autocast) training step, printing the dtype of
# the parameters, the model output, the loss and the gradients.
model = ToyModel(5, 8).to('cuda')
x = torch.rand(15, 5, device='cuda')
opt = torch.optim.AdamW(params=model.parameters())
opt.zero_grad()
# Autocast wraps only the forward pass and the loss computation. PyTorch's AMP
# documentation advises against running backward() under autocast; backward
# ops reuse the dtype autocast picked for the corresponding forward ops.
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
  y = model(x)
  print("---- params ----")
  for k, v in model.named_parameters():
    print(k, v.dtype)
  print("y dtype", y.dtype)
  loss = (y.mean())**2
  print("loss dtype", loss.dtype)
loss.backward()
print("---- params grad----")
for k, v in model.named_parameters():
  print(k, v.grad.dtype)
opt.step()

forward
input: torch.float32
fc1: torch.bfloat16
ln: torch.float32
relu: torch.float32
fc2: torch.bfloat16
---- params ----
fc1.weight torch.float32
fc2.weight torch.float32
ln.weight torch.float32
ln.bias torch.float32
y dtype torch.bfloat16
loss dtype torch.float32
---- params grad----
fc1.weight torch.float32
fc2.weight torch.float32
ln.weight torch.float32
ln.bias torch.float32
