In [1]:
!pip install sentencepiece protobuf transformers



In [2]:
import os
os.environ["TORCH_LOGS"] = "dynamic"

import torch._dynamo as dynamo
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils import _pytree as pytree
import textwrap
# Hugging Face access token for the gated Llama-2 checkpoint. It is read from
# the environment so that no credential is stored in the notebook; when neither
# variable is set this is None and is passed through unchanged to
# `from_pretrained(use_auth_token=...)`.
# NOTE(review): a literal token was previously committed on this line and
# should be revoked.
AUTH_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Both the weights and the tokenizer come from the same hub checkpoint, so the
# identifier is spelled once and shared by the two loaders.
LLAMA_CHECKPOINT = "meta-llama/Llama-2-7b-chat-hf"

# Load the causal LM weights with torch_dtype=torch.float32.
mdl = AutoModelForCausalLM.from_pretrained(
    LLAMA_CHECKPOINT, torch_dtype=torch.float32, use_auth_token=AUTH_TOKEN
)

# Load the matching tokenizer with use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
    LLAMA_CHECKPOINT, use_fast=False, use_auth_token=AUTH_TOKEN
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.45s/it]


In [4]:
def summarize_results(results):
    """Print the logits shape and a run-length summary of the KV-cache shapes.

    Args:
        results: Object exposing ``logits`` and ``past_key_values`` attributes.
            ``past_key_values`` is flattened with ``pytree.tree_flatten`` and
            every resulting leaf must have a ``shape`` attribute.

    Prints one "Logits:" line, one "PKV (len=N):" line, and then one line per
    run of consecutive identical shapes. A run longer than one entry is
    suffixed with "* K", where K is the run length.
    """
    past_key_values, _ = pytree.tree_flatten(results.past_key_values)
    print("Logits:", pytree.tree_map(lambda x: x.shape, results.logits))
    print(f"PKV (len={len(past_key_values)}):")

    # Collapse consecutive identical shapes into [shape_repr, run_length]
    # pairs. Each run is reported with its own shape and count, and shapes
    # that occur only once are reported as well.
    runs = []
    for shape_repr in (repr(x.shape) for x in past_key_values):
        if runs and runs[-1][0] == shape_repr:
            runs[-1][1] += 1
        else:
            runs.append([shape_repr, 1])

    for shape_repr, run_length in runs:
        if run_length > 1:
            print(" ", shape_repr, f"* {run_length}")
        else:
            print(" ", shape_repr)
    
    
# System prompt. The adjacent string literals below are concatenated by Python
# into a single string.
prompt = (
        "System: You are a helpful, respectful and honest assistant. Always answer "
        "as helpfully as possible, while being safe.  Your answers should not "
        "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
        "content. Please ensure that your responses are socially unbiased and positive "
        "in nature. If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. If you don't know the "
        "answer to a question, please don't share false information."
    )
# Full text fed to the model: the system prompt followed by the user question.
# NOTE(review): "<|USER|>" is plain text here; whether the checkpoint treats it
# as a turn marker is not shown in this notebook -- confirm the intended chat
# template.
conversation = prompt + "<|USER|>Should Bugs Bunny have turned left at Albuquerque?"

# Tokenize the conversation into PyTorch tensors (return_tensors="pt").
initial_input = tokenizer(conversation, return_tensors="pt")
print("Example input:", initial_input)
print("  Shape:", initial_input.input_ids.shape)
# Run the model once over the whole prompt. The returned `past_key_values` are
# what the decode loop below passes back in on each subsequent step.
# NOTE(review): this call is not wrapped in torch.no_grad().
initial_results = mdl.forward(initial_input.input_ids)
summarize_results(initial_results)

# Accumulators filled by decode_token(): the chosen token ids and their text.
all_tokens = []
all_detoks = []


def decode_token(results, index=-1, store=True):
    """Greedily choose the token at sequence position ``index`` and decode it.

    Args:
        results: Object with a ``logits`` attribute that is indexed as
            ``logits[:, index, :]``.
        index: Position along the second logits axis to decode (default: last).
        store: When true, append the first chosen id to ``all_tokens`` and the
            decoded text to ``all_detoks``.

    Returns:
        ``(token, detok)``: the argmax result over the last logits axis and the
        string produced by ``tokenizer.decode`` for it.
    """
    print("Logits:", results.logits.shape)
    position_logits = results.logits[:, index, :]
    print("Logits reshaped:", position_logits.shape)

    best_id = torch.argmax(position_logits, dim=1)
    text = tokenizer.decode(best_id, skip_special_tokens=False)
    print(f"--> Decoded: '{text}' ({best_id})")

    if not store:
        return best_id, text

    all_tokens.append(best_id[0])
    all_detoks.append(text)
    return best_id, text

# Decode initial token
# for i in range(initial_results.logits.shape[1]):
#     token, detok = decode_token(initial_results, index=i)
token, detok = decode_token(initial_results, store=True)

# Stop on the tokenizer's own end-of-sequence id instead of a hard-coded
# constant, so the loop stays correct if the checkpoint/tokenizer changes.
eos_token_id = tokenizer.eos_token_id

# Decode loop for subsequent tokens.
# Each step feeds only the newest token plus the cached keys/values from the
# previous step. Gradients are never used here, so autograd is disabled to
# avoid retaining a graph for every step.
current_results = initial_results
with torch.no_grad():
    for _step in range(5):
        prior_pkvs, _ = pytree.tree_flatten(current_results.past_key_values)
        # The model is fed a [1, 1] token tensor.
        next_input_token = torch.reshape(token, [1, 1])
        print("Next input token:", next_input_token)
        step_results = mdl.forward(
            next_input_token, past_key_values=current_results.past_key_values
        )
        summarize_results(step_results)
        token, detok = decode_token(step_results)
        if eos_token_id is not None and token[0].item() == eos_token_id:
            break
        current_results = step_results

        # Sanity check: every cache position that existed before this step
        # must be carried over unchanged (axis 2 is indexed per position).
        current_pkvs, _ = pytree.tree_flatten(current_results.past_key_values)
        pkv_len = prior_pkvs[0].shape[2]
        for check_step in range(pkv_len):
            for left, right in zip(prior_pkvs, current_pkvs):
                if not torch.equal(left[:, :, check_step, :], right[:, :, check_step, :]):
                    print(f"PKVS MISMATCH AT STEP {check_step}!")

print("All tokens:", all_tokens)
print("All detoks:", all_detoks)

print(conversation)
print(tokenizer.decode(all_tokens))

Example input: {'input_ids': tensor([[    1,  2184, 29901,   887,   526,   263,  8444, 29892,  3390,  1319,
           322, 15993, 20255, 29889, 29849,  1234,   408,  1371,  3730,   408,
          1950, 29892,  1550,  1641,  9109, 29889, 29871,  3575,  6089,   881,
           451,  3160,   738, 10311,  1319, 29892,   443,   621,   936, 29892,
         11021,   391, 29892,  7916,   391, 29892,   304, 27375, 29892, 18215,
         29892,   470, 27302,  2793, 29889,  3529,  9801,   393,   596, 20890,
           526,  5374,   635,   443,  5365,  1463,   322,  6374,   297,  5469,
         29889,   960,   263,  1139,   947,   451,  1207,   738,  4060, 29892,
           470,   338,   451,  2114,  1474, 16165,   261,   296, 29892,  5649,
          2020,  2012,   310, 22862,  1554,   451,  1959, 29889,   960,   366,
          1016, 29915, 29873,  1073,   278,  1234,   304,   263,  1139, 29892,
          3113,  1016, 29915, 29873,  6232,  2089,  2472, 19423, 29989, 11889,
         29989, 29958, 

## Attempt to FX Trace the initialization graph

In [5]:
import collections
from torch._export.constraints import constrain_as_size, constrain_as_value

def summarize_state_shape(tree):
    """Print ``tree`` with each tensor leaf replaced by a ``Tensor(<shape>)`` string.

    Leaves that are not ``torch.Tensor`` instances are printed unchanged.
    """
    def _describe(leaf):
        if isinstance(leaf, torch.Tensor):
            return f"Tensor({leaf.shape})"
        return leaf

    described = pytree.tree_map(_describe, tree)
    print(described)

BATCH_SIZE = 1
# Sequence capacity of the padded example state buffers built below.
MAX_STEP_SEQ = 4095

# Container pairing the current sequence length with the per-layer cache tree.
StateStruct = collections.namedtuple("InferenceState", "step_seq,past_key_values")

# Axis 2 of a cache tensor is the sequence axis (the decode loop earlier in
# this notebook indexes it per position), so the number of positions already
# cached is shape[2]. shape[3] is the trailing per-position feature axis and
# must not be used as a sequence length.
initial_step_seq = initial_results.past_key_values[0][0].shape[2]

# Zero-filled example state: same tree structure, axis-1 size, trailing axis
# size and dtype as the real cache, with the sequence axis padded to
# MAX_STEP_SEQ.
# NOTE(review): this allocates a full MAX_STEP_SEQ-sized buffer per cache
# tensor; only the tree structure is consumed by `state_schema` below.
step_example_states = StateStruct(
    step_seq=initial_step_seq,
    past_key_values=pytree.tree_map(
        lambda x: torch.zeros(
            BATCH_SIZE, x.shape[1], MAX_STEP_SEQ, x.shape[3],
            dtype=x.dtype),
        initial_results.past_key_values),
)
# Tree spec used to rebuild the nested cache structure from a flat tensor list.
_, state_schema = pytree.tree_flatten(step_example_states.past_key_values)

class InferenceModel(torch.nn.Module):
    """Wraps a causal LM so prefill and single-step decode use flat tensors.

    Both entry points return ``(token, *flat_state)``, where the flat state is
    ``result.past_key_values`` flattened with ``pytree.tree_flatten``. This
    keeps the signatures to plain positional tensors for export.
    """

    def __init__(self, base_model):
        """Store ``base_model``, whose ``forward`` must return an object with
        ``logits`` and ``past_key_values`` attributes."""
        super().__init__()
        self.base_model = base_model

    def initialize(self, input_ids: torch.Tensor):
        """Run the model over the full prompt.

        Args:
            input_ids: Token ids passed straight to ``base_model.forward``.

        Returns:
            ``(token1, *state1_flat)``: the greedy next token with a leading
            axis added (``token1[None, :]``), followed by every cache tensor
            in flattened order.
        """
        result = self.base_model.forward(input_ids)
        state1_flat, _ = pytree.tree_flatten(result.past_key_values)
        token1 = torch.argmax(result.logits[:, -1, :], dim=1)
        token1 = token1[None, :]
        return token1, *state1_flat

    def forward(self, token0: torch.Tensor, *state0_flat):
        """Run one decode step.

        Args:
            token0: The previously produced token, in the same layout that
                ``initialize``/``forward`` return it.
            *state0_flat: Flattened cache tensors matching the module-level
                ``state_schema``.

        Returns:
            ``(token1, *state1_flat)``: the greedy next token (same layout as
            ``initialize`` returns, so it can be fed to the next step) and,
            for each cache tensor, only the slice for the position appended
            by this step.
        """
        # Rebuild the nested cache structure from the flat argument list.
        state0 = pytree.tree_unflatten(state0_flat, state_schema)
        result = self.base_model.forward(token0, past_key_values=state0)
        state1_flat, _ = pytree.tree_flatten(result.past_key_values)
        # The entry added by this step is the LAST position on axis 2. Slicing
        # with -1: (not -2:-1) keeps that axis and returns the new entry rather
        # than the final entry of the incoming state.
        state1_flat = [x[:, :, -1:, :] for x in state1_flat]
        token1 = torch.argmax(result.logits[:, -1, :], dim=1)
        # Match initialize(): add the leading axis so the output token can be
        # passed back in as `token0` on the following step.
        token1 = token1[None, :]
        return token1, *state1_flat


# Wrap the loaded model in the flat-tensor interface defined by InferenceModel.
sm = InferenceModel(mdl)
# Reuse the tokenized prompt as the example input for initialize().
input_ids = initial_input.input_ids

# Eager run of initialize(): the first returned value is the token, and the
# remaining values are the flattened cache tensors.
print("Example initialize:")
example_token0, *example_state0 = sm.initialize(input_ids)
print("example_token0 =", example_token0)
summarize_state_shape(example_state0)

# Eager run of one forward() step, fed with the outputs of initialize().
# `example_token0` and `example_state0` are also the example inputs used when
# exporting forward().
print("Example step:")
example_token1, *example_state1 = sm.forward(example_token0, *example_state0)
print("example_token1 =", example_token1)
summarize_state_shape(example_state1)

Example initialize:
example_token0 = tensor([[829]])
['Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Tensor(torch.Size([1, 32, 136, 128]))', 'Te

In [6]:
# Export initializer
# No constraints are supplied; with assume_static_by_default=True the prompt
# length of the example input is baked into the exported graph.
exp_initialize = dynamo.export(
    sm.initialize,
    aten_graph=True,
    assume_static_by_default=True,
    constraints=[],
)
g, guards = exp_initialize(input_ids)
# print_readable() prints the graph AND returns it as a string. The trailing
# semicolon stops the notebook from echoing that return value as a second,
# escaped copy of the same (very large) listing.
g.print_readable();

[2023-09-07 18:13:41,668] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2023-09-07 18:13:58,072] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards


class GraphModule(torch.nn.Module):
    def forward(self, input_ids):
        arg0: i64[1, 136], = fx_pytree.tree_flatten_spec(([input_ids], {}), self._in_spec)
        # File: /home/stella/src/venv/Turbine/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:654, code: position_ids = torch.arange(
        arange_start: i64[136] = torch.ops.aten.arange.start(0, 136, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
        
        # File: /home/stella/src/venv/Turbine/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        unsqueeze_default: i64[1, 136] = torch.ops.aten.unsqueeze.default(arange_start, 0);  arange_start = None
        view_default: i64[1, 136] = torch.ops.aten.view.default(unsqueeze_default, [-1, 136]);  unsqueeze_default = None
        
        # File: /home/stella/src/venv/Turbine/lib/python3.11/site-packages/transformers/models/ll

"class GraphModule(torch.nn.Module):\n    def forward(self, input_ids):\n        arg0: i64[1, 136], = fx_pytree.tree_flatten_spec(([input_ids], {}), self._in_spec)\n        # File: /home/stella/src/venv/Turbine/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:654, code: position_ids = torch.arange(\n        arange_start: i64[136] = torch.ops.aten.arange.start(0, 136, dtype = torch.int64, device = device(type='cpu'), pin_memory = False)\n        \n        # File: /home/stella/src/venv/Turbine/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:657, code: position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        unsqueeze_default: i64[1, 136] = torch.ops.aten.unsqueeze.default(arange_start, 0);  arange_start = None\n        view_default: i64[1, 136] = torch.ops.aten.view.default(unsqueeze_default, [-1, 136]);  unsqueeze_default = None\n        \n        # File: /home/stella/src/venv/Turbine/lib/python3.11/site-packages/transformer

In [7]:
# Export forward
exp_forward = dynamo.export(
    sm.forward,
    aten_graph=True,
    assume_static_by_default=True,
    # Constrain the first state dim and then form an equality
    # on all of the others. If we don't specify sufficient constraints
    # for these, Dynamo will print two pages of a copy-pastable version
    # of basically this based on what it found in the graph but wants
    # you to be explicit about.
    # The upper bound is the shared MAX_STEP_SEQ constant rather than a
    # repeated literal, so the export limit and the state buffers cannot drift.
    constraints=[
        dynamic_dim(example_state0[0], 2) < MAX_STEP_SEQ
    ] + [
        (dynamic_dim(x, 2) == (dynamic_dim(example_state0[0], 2))) for x in example_state0[1:]
    ],
)
g, guards = exp_forward(example_token0, *example_state0)
# print_readable() prints the graph AND returns it as a string. The trailing
# semicolon stops the notebook from echoing that return value as a second,
# escaped copy of the same (very large) listing.
g.print_readable();

[2023-09-07 18:14:02,427] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2023-09-07 18:14:02,450] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s0 = 136 for L['state0_flat'][0].size()[2]
[2023-09-07 18:14:02,533] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s1 = 136 for L['state0_flat'][1].size()[2]
[2023-09-07 18:14:02,545] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s2 = 136 for L['state0_flat'][2].size()[2]
[2023-09-07 18:14:02,561] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s3 = 136 for L['state0_flat'][3].size()[2]
[2023-09-07 18:14:02,571] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s4 = 136 for L['state0_flat'][4].size()[2]
[2023-09-07 18:14:02,594] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s5 = 136 for L['state0_flat'][5].size()[2]
[2023-09-07 18:14:02,607] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s6 = 136 

class GraphModule(torch.nn.Module):
    def forward(self, token0, state0_flat_0, state0_flat_1, state0_flat_2, state0_flat_3, state0_flat_4, state0_flat_5, state0_flat_6, state0_flat_7, state0_flat_8, state0_flat_9, state0_flat_10, state0_flat_11, state0_flat_12, state0_flat_13, state0_flat_14, state0_flat_15, state0_flat_16, state0_flat_17, state0_flat_18, state0_flat_19, state0_flat_20, state0_flat_21, state0_flat_22, state0_flat_23, state0_flat_24, state0_flat_25, state0_flat_26, state0_flat_27, state0_flat_28, state0_flat_29, state0_flat_30, state0_flat_31, state0_flat_32, state0_flat_33, state0_flat_34, state0_flat_35, state0_flat_36, state0_flat_37, state0_flat_38, state0_flat_39, state0_flat_40, state0_flat_41, state0_flat_42, state0_flat_43, state0_flat_44, state0_flat_45, state0_flat_46, state0_flat_47, state0_flat_48, state0_flat_49, state0_flat_50, state0_flat_51, state0_flat_52, state0_flat_53, state0_flat_54, state0_flat_55, state0_flat_56, state0_flat_57, state0_flat_58, 

"class GraphModule(torch.nn.Module):\n    def forward(self, token0, state0_flat_0, state0_flat_1, state0_flat_2, state0_flat_3, state0_flat_4, state0_flat_5, state0_flat_6, state0_flat_7, state0_flat_8, state0_flat_9, state0_flat_10, state0_flat_11, state0_flat_12, state0_flat_13, state0_flat_14, state0_flat_15, state0_flat_16, state0_flat_17, state0_flat_18, state0_flat_19, state0_flat_20, state0_flat_21, state0_flat_22, state0_flat_23, state0_flat_24, state0_flat_25, state0_flat_26, state0_flat_27, state0_flat_28, state0_flat_29, state0_flat_30, state0_flat_31, state0_flat_32, state0_flat_33, state0_flat_34, state0_flat_35, state0_flat_36, state0_flat_37, state0_flat_38, state0_flat_39, state0_flat_40, state0_flat_41, state0_flat_42, state0_flat_43, state0_flat_44, state0_flat_45, state0_flat_46, state0_flat_47, state0_flat_48, state0_flat_49, state0_flat_50, state0_flat_51, state0_flat_52, state0_flat_53, state0_flat_54, state0_flat_55, state0_flat_56, state0_flat_57, state0_flat_58