In [3]:
import json
from contextlib import nullcontext
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import float_qparams_weight_only_qconfig, default_qconfig

# Support absolute imports for a standalone script
sys.path.insert(0, Path.cwd().parent.as_posix())

from tinystories import get_tokenizer_model_path  # noqa: E402
from tokenizer import Tokenizer  # noqa: E402
from model import ModelArgs, Transformer  # noqa: E402

In [4]:
# helper functions
import functools
import py._io.capture
import py._io
import py

start = (
    ""  # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
)
max_new_tokens = 25  # number of tokens generated in each sample
temperature = (
    0  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
)
device = (
    "cuda" if torch.cuda.is_available() else "cpu"
)  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
top_k = (
    300  # retain only the top_k most likely tokens, clamp others to have 0 probability
)
tokenizer = ""  # override the tokenizer model path
seed = 1337

# original model instance loaded from checkpoint (.pt file)
orig_model_path = "../out/softmax0-15m-2023_08_26_00_08_49/ckpt.pt"
# quantized .bin model output filepath
q_filepath = "../out/quantized/softmax0-15m-2023_08_26_00_08_49.pt"
# Path to tokenizer (config.json)
tok_path = "../out/softmax0-15m-2023_08_26_00_08_49/"


def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" % (os.path.getsize("tmp.pt") / 1e6))
    os.remove("tmp.pt")


def load_model(out_dir):
    # init from a model saved in a specific directory
    checkpoint_dict = torch.load(out_dir, map_location=device)
    # del flash if exists
    if "flash" in checkpoint_dict["model_args"]:
        del checkpoint_dict["model_args"]["flash"]
    # softmax -> softmax1 in model_args
    if "softmax" in checkpoint_dict["model_args"]:
        checkpoint_dict["model_args"]["softmax1"] = checkpoint_dict["model_args"][
            "softmax"
        ]
        del checkpoint_dict["model_args"]["softmax"]
    gptconf = ModelArgs(**checkpoint_dict["model_args"])
    model = Transformer(gptconf)
    state_dict = checkpoint_dict["model"]
    unwanted_prefix = "_orig_mod."
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model, checkpoint_dict["model_args"]


def load_quant_model(dir):
    # init from a model saved in a specific directory
    checkpoint_dict = torch.load(dir)
    # del flash if exists
    if "flash" in checkpoint_dict["model_args"]:
        del checkpoint_dict["model_args"]["flash"]
    # softmax -> softmax1 in model_args
    if "softmax" in checkpoint_dict["model_args"]:
        checkpoint_dict["model_args"]["softmax1"] = checkpoint_dict["model_args"][
            "softmax"
        ]
        del checkpoint_dict["model_args"]["softmax"]
    gptconf = ModelArgs(**checkpoint_dict["model_args"])
    model = Transformer(gptconf)

    # VERY IMPORTANT embeddings only support float_qparams_weight_only_qconfig quantization
    model.tok_embeddings.qconfig = float_qparams_weight_only_qconfig
    model_dynamic_quantized = torch.quantization.quantize_dynamic(
        model, dtype=torch.qint8
    )

    state_dict = checkpoint_dict["model"]
    unwanted_prefix = "_orig_mod."
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model_dynamic_quantized.load_state_dict(state_dict, strict=True)
    model_dynamic_quantized.eval()
    return model_dynamic_quantized, checkpoint_dict["model_args"]


def load_tokenizer(out_dir):
    # load the tokenizer
    with open(f"{out_dir}/config.json", "r") as f:
        config = json.load(f)
    vocab_source = config.get("vocab_source", "llama2")
    vocab_size = config["vocab_size"]
    if tokenizer:
        # a specific tokenizer is provided, use it
        tokenizer_model = tokenizer
    else:
        # let's try to find the tokenizer model automatically. bit gross here...
        query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
        tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
    enc = Tokenizer(tokenizer_model="../tokenizer.model")
    return enc


enc = load_tokenizer(tok_path)


def encode_prompt(start, device):
    if start.startswith("FILE:"):
        with open(start[5:], "r", encoding="utf-8") as f:
            start = f.read()
    start_ids = enc.encode(start, bos=True, eos=False)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    return x


# run generation
@torch.no_grad()
def generate(
    model,
    prompt,
    device,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_k=top_k,
):
    x = encode_prompt(prompt, device=device)
    with ctx:
        y = model.generate(x, max_new_tokens, temperature, top_k)
        return enc.decode(y.tolist())[0]


def get_capture(out, in_):
    try:
        capture = py.io.StdCaptureFD(out=out, in_=in_)
    except:
        capture = None
    return capture


def reset_capture(capture):
    if capture is not None:
        capture.reset()


def hide_warnings(function=None, out=True, in_=False):
    """Suppresses C++ warnings in PyTorch underlying methods. Decorate on functions"""

    def decorator_hide_warnings(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            capture = get_capture(out, in_)
            result = func(*args, **kwargs)
            reset_capture(capture)
            return result

        return wrapper

    if function:
        return decorator_hide_warnings(function)
    return decorator_hide_warnings


@hide_warnings
def compute_perplexity(
    prompt,
    model,
    device,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_k=top_k,
):
    """
    Compute the perplexity given the logits of generated tokens and their corresponding indices.
    """
    x = encode_prompt(prompt, device=device)
    idx, logits = model.generate(
        x,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        return_logits=True,
    )
    probs = F.softmax(logits, dim=-1)
    actual_probs = torch.gather(
        probs, -1, idx.unsqueeze(-1)[:, x.shape[1]:, :])
    neg_log_probs = -torch.log(actual_probs)

    y = enc.decode(idx.tolist())[0]
    return y, torch.exp(torch.mean(neg_log_probs)).item()

# Investigating parameter count of layers

Which weight matrices are the biggest?


In [14]:
model, config = load_model(orig_model_path)

# Embeddings param count
print("Embeddings param count")
print_model_size(model.tok_embeddings)  # only 1 matrix cause weight tying

# Transformer param count
print("Transformer param count")
print_model_size(model.layers)

# Total param count
print("Total param count")
print_model_size(model)

Embeddings param count
36.86 MB
Transformer param count
25.49 MB
Total param count
62.36 MB


# Quantization of existing model to 8bit


In [15]:
model, config = load_model(orig_model_path)
# VERY IMPORTANT embeddings only support float_qparams_weight_only_qconfig quantization
model.tok_embeddings.qconfig = float_qparams_weight_only_qconfig
torch.backends.quantized.engine = "qnnpack"
model_dynamic_quantized = torch.quantization.quantize_dynamic(
    model, dtype=torch.qint8)

print_model_size(model)
print_model_size(model_dynamic_quantized)

print_model_size(model_dynamic_quantized.tok_embeddings)
print_model_size(model_dynamic_quantized.layers)

# save the quantized model
checkpoint_dict = {"model_args": config,
                   "model": model_dynamic_quantized.state_dict()}
torch.save(checkpoint_dict, q_filepath)

62.36 MB
26.51 MB
9.47 MB
7.69 MB


# PPL comparison between unquantized and quantized model


In [7]:
orig_model, config = load_model(orig_model_path)
quant_model, config = load_quant_model(q_filepath)

# only unquantized models can run on GPU
orig_model = orig_model.to("cpu")
quant_model = quant_model.to("cpu")

print_model_size(orig_model)
print_model_size(quant_model)

62.36 MB
26.51 MB


In [8]:
import warnings

warnings.simplefilter("ignore")

prompt = "Sally sold seashells"
max_new_tokens = 20

orig_ppl = compute_perplexity(prompt, orig_model, device="cpu")

quant_ppl = compute_perplexity(
    prompt,
    quant_model,
    device=next(quant_model.parameters()).device,
)

print(orig_ppl)
print(quant_ppl)

('Sally sold seashells. She had a big bag of seashells and she wanted to sell them. She put the seashells in a', 1.8265522718429565)
('Sally sold seashells. She was very happy. She had a big pile of seashells in her garden. She was so proud of', 2.160936117172241)
