# quick_start.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import torch
from torchao.quantization import int4_weight_only, quantize_
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
benchmark_model,
unwrap_tensor_subclass,
)
# ================
# | Set up model |
# ================
class ToyLinearModel(torch.nn.Module):
    """Minimal two-layer linear network used as the quantization example.

    The input is passed through two bias-free ``torch.nn.Linear`` layers in
    sequence: ``m -> n`` followed by ``n -> k``.

    Args:
        m: Input feature size of the first linear layer.
        n: Output feature size of the first layer and input size of the second.
        k: Output feature size of the second linear layer.
    """

    def __init__(self, m: int, n: int, k: int) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear1`` and then ``linear2`` to ``x`` and return the result."""
        x = self.linear1(x)
        x = self.linear2(x)
        return x
# Build the 1024 -> 1024 -> 1024 toy network, switch it to eval mode, then
# cast it to bfloat16 and move it onto the CUDA device (same order as a
# chained `.eval().to(torch.bfloat16).to("cuda")`).
model = ToyLinearModel(1024, 1024, 1024)
model = model.eval()
model = model.to(torch.bfloat16)
model = model.to("cuda")

# Optional: compile model for faster inference and generation
model = torch.compile(
    model,
    mode="max-autotune",
    fullgraph=True,
)

# Independent copy that serves as the bf16 baseline in the benchmark, taken
# before `model` is quantized.
model_bf16 = copy.deepcopy(model)
# ========================
# | torchao quantization |
# ========================
# torch 2.4+ only
# Describe the int4 weight-only scheme (group_size=32) first, then apply it
# to `model`; the call's result is not used, and `model` itself is what gets
# benchmarked as the int4 variant afterwards.
int4_config = int4_weight_only(group_size=32)
quantize_(model, int4_config)
# =============
# | Benchmark |
# =============
# Temporary workaround for tensor subclass + torch.compile
# Only needed for torch version < 2.5
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(model)

# Number of runs handed to benchmark_model for each variant.
num_runs = 100

# Reset TorchDynamo state before timing either model.
torch._dynamo.reset()

# Single bfloat16 CUDA input of shape (1, 1024); 1024 matches the input
# feature size the toy model was constructed with.
example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)

# NOTE(review): if benchmark_model performs no warm-up call, these timings
# may include torch.compile's compilation on the first run -- confirm.
bf16_time = benchmark_model(model_bf16, num_runs, example_inputs)
int4_time = benchmark_model(model, num_runs, example_inputs)

# Report both timings (printed as mean milliseconds) and the relative speedup
# of the quantized model over the bf16 baseline.
print("bf16 mean time: %0.3f ms" % bf16_time)
print("int4 mean time: %0.3f ms" % int4_time)
print("speedup: %0.1fx" % (bf16_time / int4_time))