In [19]:
import torch
from torch import nn

In [20]:
class ToyModel(nn.Module):
    """Tiny Linear -> ReLU -> LayerNorm -> Linear network used to inspect dtypes.

    Every forward pass prints, for each of the three parameterised layers,
    the dtype of the layer's weight followed by the dtype of the activation
    it produced. Both linear layers are bias-free and the hidden width is 10.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        hidden_width = 10
        # Registration order (fc1, ln, fc2, relu) is kept on purpose: it fixes
        # both the state_dict key order and the order of random initialisation.
        self.fc1 = nn.Linear(in_features, hidden_width, bias=False)
        self.ln = nn.LayerNorm(hidden_width)
        self.fc2 = nn.Linear(hidden_width, out_features, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network on ``x``, printing weight/activation dtypes per stage."""
        # (weight label, activation label, layer whose weight is reported, op to apply)
        stages = (
            ("fc1", "after fc1", self.fc1, lambda t: self.relu(self.fc1(t))),
            ("ln", "after ln", self.ln, self.ln),
            ("fc2", "after fc2", self.fc2, self.fc2),
        )
        activation = x
        for weight_label, activation_label, layer, apply_stage in stages:
            activation = apply_stage(activation)
            print(weight_label, layer.weight.dtype)
            print(activation_label, activation.dtype)
        return activation


In [35]:

m = ToyModel(4, 4).cuda()

# Autocast wraps only the forward pass and the loss computation; the forward
# prints show which dtype each stage ran in under bfloat16 autocast.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # Allocate the input directly on the GPU instead of copying it from host.
    s = m(torch.zeros((1, 4), device="cuda"))
    s = torch.sum(s)
    print(s.dtype)

# The backward pass runs outside the autocast context: the PyTorch AMP docs
# advise against calling backward() under autocast, since backward ops reuse
# the dtype that autocast chose for the matching forward op.
s.backward()
print("fc1 grad", m.fc1.weight.grad.dtype)
print("ln grad", m.ln.weight.grad.dtype)
print("fc2 grad", m.fc2.weight.grad.dtype)


RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [1]:
import os

# Set before importing torch so the values are in place when CUDA is initialised.
# CUDA_LAUNCH_BLOCKING=1 is a debugging aid (suggested by the CUDA error message);
# it is kept here, but the timings below no longer rely on it for correctness.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# NOTE(review): 'PYTHONLIBRARY' is not an environment variable the Python
# interpreter reads, so this assignment presumably has no effect on imports.
# If the intent was to make the project packages importable, confirm and use
# PYTHONPATH (set before the kernel starts) or an installed/editable package.
os.environ['PYTHONLIBRARY'] = '/home/qwertier/projects/stanford_cs336/spring2024-assignment2-systems/cs336-systems'

from cs336_basics import model
import timeit
from torch import nn
import torch
from cs336_systems import rms_norm


def _time_forward(module, batch, number=1000):
    """Return the wall-clock seconds taken by ``number`` forward passes of ``module``.

    One untimed warm-up call is made first so that one-time costs (for example
    kernel compilation or lazy initialisation) are not charged to the
    measurement. The GPU is synchronised before the clock starts and again
    before it stops, because CUDA kernels may be launched asynchronously and
    would otherwise still be running when the timer is read.
    """
    module(batch)
    torch.cuda.synchronize()

    def run_all():
        for _ in range(number):
            module(batch)
        torch.cuda.synchronize()

    return timeit.timeit(run_all, number=1)


for n_dim in [1024, 2048, 4096, 8192]:
    # Named `batch` rather than `input` to avoid shadowing the builtin.
    batch = torch.zeros((50000, n_dim), device="cuda")
    rms_norm_naive = model.RMSNorm(n_dim).cuda()
    layer_norm = nn.LayerNorm(n_dim).cuda()
    rms_norm_triton = rms_norm.RMSNorm(n_dim).cuda()

    print(n_dim, "layer_norm", _time_forward(layer_norm, batch))
    print(n_dim, "rms_norm_naive", _time_forward(rms_norm_naive, batch))
    print(n_dim, "rms_norm_triton", _time_forward(rms_norm_triton, batch))



1024 layer_norm 0.7614577860003919


KeyboardInterrupt: 

In [3]:
import os

# Point CUDA_VISIBLE_DEVICES at device index 0 for this kernel process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [2]:
# Shell escape: run the nvidia-smi CLI to show driver/CUDA versions and the
# current memory use and utilisation of the visible GPU(s).
!nvidia-smi

Sun Apr 28 21:17:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.76                 Driver Version: 550.76         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 30%   27C    P5             41W /  350W |    1337MiB /  24576MiB |     41%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                