In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

"""

# Key problem: the model is too large

Not enough RAM to both randomly initialize the model and load the checkpoint.

- Solution:
    Lazily load the model using meta tensors.
    Does it really use meta tensors? How is it done? (device_map or offload_state_dict?)

Not enough GPU memory to run the whole computation at once.

- Solution:
    Load the model gradually during computation.
    How is it done? (device_map or offload_state_dict?)

Does it work with export (torch.jit.trace)?
"""

# Load model.
# The model and tokenizer must come from the same checkpoint; keep the id in one
# place so they cannot drift apart.
MODEL_NAME = "bigscience/bloom-560m"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",  # requires `pip install accelerate`
    offload_state_dict=True,
    torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sentence = "Question: Can I run BLOOM on a single GPU? Answer:"
# Move inputs to `model.device` (the device of the model's first parameters)
# instead of hard-coding GPU 0: with device_map="auto" layers need not live on
# cuda:0 (the graph dump further down shows layer-0 weights on cuda:3), and
# `.to(0)` raises outright on a CPU-only machine.
inputs = tokenizer(sentence, return_tensors="pt").to(model.device)
print(inputs.keys())

# Inference in PyTorch.
with torch.no_grad():
    outputs = model(**inputs, return_dict=False)

# With return_dict=False the result is a tuple whose first element is the logits
# (what `outputs.logits` would be in dict mode); take batch 0, last position,
# and greedily pick the most likely next token.
token_id = outputs[0][0][-1].argmax()
answer = tokenizer.decode([token_id.item()])

print(answer)


# Export to ONNX.
# The trailing dict in the args tuple is forwarded to `model.forward` as keyword
# arguments, so `attention_mask` is passed by name.
torch.onnx.export(
    model,
    (inputs["input_ids"], {"attention_mask": inputs["attention_mask"]}),
    "bloom.onnx",
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
)

# `verification.find_mismatch` only supports positional args for now, so wrap the
# Hugging Face model to accept `attention_mask` positionally.
class ArgsWrapperModel(torch.nn.Module):
    """Adapter exposing a purely positional ``forward(input_ids, attention_mask)``.

    Delegates to ``model(input_ids, attention_mask=..., return_dict=False)`` so
    the wrapped model returns a plain tuple rather than a ModelOutput object.
    """

    def __init__(self, model):
        super(ArgsWrapperModel, self).__init__()
        # Registered as the child module "model": its parameters therefore show
        # up as `model.<name>` (e.g. `model.transformer.h.0...`) in traced graphs,
        # so keep this attribute name stable.
        self.model = model

    def forward(self, input_ids, attention_mask):
        inner = self.model
        return inner(
            input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )


from torch.onnx import verification

# Compare the PyTorch model with its ONNX export; when outputs differ, the graph
# is recursively partitioned (see the tree printed below) to localize the
# smallest mismatching subgraphs.
# NOTE(review): the default tolerances reported below (atol 1e-07, rtol 1e-3)
# are very tight for a float16 model, so plain fp16 rounding differences may be
# flagged. Consider passing `options=verification.VerificationOptions(...)` with
# looser rtol/atol if only real divergences are of interest.
positional_args = (inputs["input_ids"], inputs["attention_mask"])
graph_info = verification.find_mismatch(ArgsWrapperModel(model), positional_args)
leafs = graph_info.all_mismatch_leaf_graph_info()

  from .autonotebook import tqdm as notebook_tqdm


dict_keys(['input_ids', 'attention_mask'])
 Yes


  base = torch.tensor(
  if src_length > 1:


verbose: False, log level: Level.ERROR



  base = torch.tensor(
  if src_length > 1:


Tensor-likes are not close!

Mismatched elements: 126732 / 3512320 (3.6%)
Greatest absolute difference: 0.5 at index (0, 0, 10) (up to 1e-07 allowed)
Greatest relative difference: 1070.2615384615385 at index (0, 3, 24753) (up to 0.001 allowed)
2429 X   __1214 X   __607 X   __303 X    __151 X     __75 ✓
id:     |  id: 0   |  id: 00 |  id: 000 |  id: 0000 |  id: 00000
        |          |         |          |           |
        |          |         |          |           |__76 X       __38 X        __19 X         __9 X           __4 ✓
        |          |         |          |              id: 00001 |  id: 000010 |  id: 0000100 |  id: 00001000 |  id: 000010000
        |          |         |          |                        |             |              |               |
        |          |         |          |                        |             |              |               |__5 X            __2 ✓
        |          |         |          |                        |             |       

In [6]:
# Inspect a single mismatching leaf subgraph: print its graph and mismatch
# details, then its position in the partition tree.
# NOTE(review): index 2 is an arbitrary pick; it raises IndexError if fewer than
# three mismatching leaves are found.
leaf_idx = 2
leafs = graph_info.all_mismatch_leaf_graph_info()
selected_leaf = leafs[leaf_idx]
selected_leaf.pretty_print_mismatch(graph=True)
selected_leaf.pretty_print_tree()

graph(%aten::linear_1359 : Half(1, 14, 1024, strides=[14336, 1024, 1], requires_grad=0, device=cuda:3),
      %model.transformer.h.0.self_attention.dense.weight : Half(1024, 1024, strides=[1024, 1], requires_grad=1, device=cuda:3),
      %model.transformer.h.0.self_attention.dense.bias : Half(1024, strides=[1], requires_grad=1, device=cuda:3)):
  %x.1 : Half(1, 14, 1024, strides=[14336, 1024, 1], requires_grad=0, device=cuda:3) = aten::linear(%aten::linear_1359, %model.transformer.h.0.self_attention.dense.weight, %model.transformer.h.0.self_attention.dense.bias) # /bert_ort/bowbao/pytorch/torch/nn/modules/linear.py:114:0
  return (%x.1)

graph(%aten::linear_1359 : Half(1, 14, 1024, strides=[14336, 1024, 1], requires_grad=0, device=cuda:3),
      %model.transformer.h.0.self_attention.dense.bias : Half(1024, strides=[1], requires_grad=0, device=cuda:3),
      %6 : Half(1024, 1024, strides=[1, 1024], requires_grad=0, device=cuda:3)):
  %4 : Half(1, 14, 1024, strides=[14336, 1024, 1], devi