In [1]:
import torch
import torch.nn as nn

import bitsandbytes
from bitsandbytes.nn import Linear8bitLt

In [2]:
# define fp16 model
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)

In [3]:
# Assuming the model has been trained
torch.save(fp16_model.state_dict(), "model.pt")

In [4]:
# define an int8 model
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False), # important to add the flag "has_fp16_weights"  
    Linear8bitLt(64, 64, has_fp16_weights=False)  # has_fp16_weights = True: 혼합 정밀도(mixed precision),Int8와 FP16을 섞어서 학습(training) 할 때 사용
)

In [5]:
# before calling the .to function
int8_model[0].weight

Parameter containing:
Parameter(Int8Params([[ 0.0177, -0.1196, -0.0731,  ..., -0.0559,  0.1001, -0.0698],
            [ 0.0621, -0.0373, -0.0265,  ...,  0.0484,  0.0413, -0.0669],
            [ 0.0821, -0.0486,  0.0993,  ..., -0.0769,  0.0651, -0.0104],
            ...,
            [-0.0200, -0.0156,  0.0902,  ..., -0.1025, -0.0956, -0.1116],
            [-0.1015,  0.0990,  0.1098,  ..., -0.0298, -0.0594, -0.0190],
            [-0.0742,  0.0241,  0.0710,  ...,  0.0275,  0.0411, -0.0178]]))

In [6]:
# load the model in 8-bit
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0) # Quantization happens here

In [7]:
int8_model[0].weight

Parameter containing:
Parameter(Int8Params([[  73,  116,   60,  ...,   95,   89,   -3],
            [ -58,  -13,   60,  ...,   38,   27,   32],
            [ -43,  -82,   76,  ...,  -77,  -40,   -7],
            ...,
            [ -65,   66,   10,  ..., -125,    2,  -38],
            [-127,  -20,  -23,  ..., -106,   75, -110],
            [  21,  -45,   89,  ...,   20, -125, -100]], device='cuda:0',
           dtype=torch.int8))

In [8]:
# how to retrieve the FP16 weight in order to perform the outlier MatMul in fp16?
(int8_model[0].weight.CB * int8_model[0].weight.SCB) / 127

tensor([[ 0.0697,  0.1136,  0.0580,  ...,  0.0924,  0.0871, -0.0029],
        [-0.0554, -0.0127,  0.0580,  ...,  0.0369,  0.0264,  0.0310],
        [-0.0411, -0.0803,  0.0735,  ..., -0.0749, -0.0392, -0.0068],
        ...,
        [-0.0621,  0.0646,  0.0097,  ..., -0.1215,  0.0020, -0.0368],
        [-0.1213, -0.0196, -0.0223,  ..., -0.1031,  0.0734, -0.1064],
        [ 0.0201, -0.0441,  0.0861,  ...,  0.0194, -0.1224, -0.0967]],
       device='cuda:0')

In [9]:
# infer using the model
input_ = torch.randn((1, 64), dtype=torch.float16)
hidden_states = int8_model(input_.to(torch.device('cuda', 0)))