# HF Accelerate

In [None]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

import time, gc, torch
from tqdm import tqdm

In [2]:
# Checkpoint shared by the tokenizer and the model.
model_name = "facebook/opt-13b"

# Tokenizer configured to pad on the left side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

# Extra keyword arguments for loading the model: request float16 weights.
kwargs = {"torch_dtype": torch.float16}

Downloading (…)okenizer_config.json:   0%|          | 0.00/721 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

In [None]:
# Load the causal-LM checkpoint with the options collected in `kwargs`, then
# move the whole model onto the first CUDA device.
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).to("cuda:0")

Downloading (…)lve/main/config.json:   0%|          | 0.00/719 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/52.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00003.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00003.bin:   0%|          | 0.00/9.86G [00:00<?, ?B/s]

Downloading (…)l-00003-of-00003.bin:   0%|          | 0.00/5.87G [00:00<?, ?B/s]

In [None]:
# Batch of prompts; raise the multiplier to profile a larger batch.
inputs = ["In a galaxy far, far away"] * 1
input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)

# Move every tensor field of the encoding onto the GPU that holds the model.
for field, value in input_tokens.items():
    if torch.is_tensor(value):
        input_tokens[field] = value.to("cuda:0")

In [None]:
# Reference run through `generate`: sampling disabled, KV cache enabled,
# at most 100 newly generated tokens.
generate_kwargs = {
    "max_new_tokens": 100,
    "use_cache": True,
    "do_sample": False,
}
output_tokens = model.generate(**input_tokens, **generate_kwargs)

In [None]:
# Decode the token ids produced by `generate` back into text and show them.
print(tokenizer.batch_decode(output_tokens))

### RECREATING `generate()` WITH A MANUAL GREEDY DECODING LOOP

In [None]:
from tqdm import tqdm

In [None]:
# Sequence lengths at which the manual decoding loop snapshots its state.
shapes = [10, 128, 256, 512, 1024]

# Maps a sequence length to the snapshot captured at that length.
inputs_saver = dict()

In [None]:
# Re-implement greedy `model.generate(...)` step by step so the decoding state
# (input ids + model kwargs, including the KV cache) can be snapshotted at the
# sequence lengths listed in `shapes`.

# `update` writes the generation settings into the model's generation config
# and returns the keyword arguments it did not consume; those become the
# kwargs passed to the model.
model_kwargs = model.generation_config.update(**input_tokens, **generate_kwargs)
inputs_tensor, model_input_name, model_kwargs = model._prepare_model_inputs(
    None, model.generation_config.bos_token_id, model_kwargs
)

# Keep the attention mask produced by the tokenizer when there is one, and only
# infer a mask from the pad/eos ids when it is missing. Overwriting it
# unconditionally would discard the real padding information for padded
# batches.
if model_kwargs.get("attention_mask") is None:
    model_kwargs["attention_mask"] = model._prepare_attention_mask_for_generation(
        inputs_tensor,
        model.generation_config.pad_token_id,
        model.generation_config.eos_token_id,
    )
model_kwargs["output_attentions"] = model.generation_config.output_attentions
model_kwargs["output_hidden_states"] = model.generation_config.output_hidden_states
model_kwargs["use_cache"] = model.generation_config.use_cache

input_ids = inputs_tensor

model = model.eval()
with torch.no_grad():
    # Fixed number of decoding steps; there is deliberately no EOS check so the
    # sequence keeps growing past the lengths listed in `shapes`.
    for _ in tqdm(range(1024)):
        seq_len = input_ids.shape[1]
        if seq_len in shapes:
            # Snapshot the state *before* this step runs. The copy is shallow:
            # the dict is new, the tensors inside it are shared.
            inputs_saver[seq_len] = {
                "model_kwargs": model_kwargs.copy(),
                "input_ids": input_ids,
            }

        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=model.generation_config.output_attentions,
            output_hidden_states=model.generation_config.output_hidden_states,
        )

        # Greedy step: take the highest-scoring token at the last position and
        # append it to the running sequence.
        next_token_logits = outputs.logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # Carry the step's outputs forward into the kwargs for the next step.
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )
print()

In [None]:
# Show which fields the last forward pass of the manual loop returned.
outputs.keys()

In [None]:
# Print the reference `generate` output next to the manual loop's tokens so the
# two decodings can be compared by eye.
print(tokenizer.batch_decode(output_tokens))
# NOTE(review): 108 is presumably the prompt length plus the 100 new tokens
# requested from `generate` — confirm against output_tokens.shape[1].
print(tokenizer.batch_decode(input_ids[:,:108]))

## PROFILING

In [None]:
import time

### DECODE (forward-pass latency with a populated KV cache)

In [None]:
# Time repeated forward passes for each snapshotted sequence length and report
# throughput and latency.
iterations = 100
for shape in shapes:
    # Restore the decoding state captured when the sequence had length `shape`.
    input_ids = inputs_saver[shape]["input_ids"]
    model_kwargs = inputs_saver[shape]["model_kwargs"]
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # Reported before the timer starts so console I/O is not measured.
    print(f'KV Cache Shape: {model_inputs["past_key_values"][0][0].shape}')

    with torch.no_grad():
        # CUDA kernels launch asynchronously: drain any previously queued work
        # before starting the clock so only the passes below are measured.
        torch.cuda.synchronize()
        start = time.perf_counter()

        for _ in tqdm(range(iterations)):
            outputs = model(
                **model_inputs,
                return_dict=True,
                output_attentions=model.generation_config.output_attentions,
                output_hidden_states=model.generation_config.output_hidden_states,
            )

        # Wait for the last launched kernels to finish before stopping the clock.
        torch.cuda.synchronize()
        end = time.perf_counter()

    print(f"Decode with input_ids.shape = {input_ids.shape}")
    print(f"Time: {end-start: .2f}")
    print(f"Iterations: {iterations}")
    print(f"Throughput (tokens/sec): {iterations / (end-start) : .2f}")
    print(f"Latency (sec/token): {(end-start) / iterations : .3f}")
    print("\n")