In [2]:
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from decoding_analysis_utils import DecodingVisualizer

In [None]:
# Directory where Hugging Face model/tokenizer files are cached.
# The default is a machine-specific absolute path; set the DOLA_HF_CACHE_DIR
# environment variable before running this cell to use a different location
# without editing the notebook.
cache_dir = os.environ.get(
    "DOLA_HF_CACHE_DIR",
    "/Users/zarreennaowalreza/Documents/openmined-new/Research/rivanna/hf_cache_models",
)

# Identifier of the model to analyse.
# Commented-out alternative: "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "google/gemma-2b"

In [None]:
def load_model(model_id):
    """Load the tokenizer and causal language model for ``model_id``.

    Both are fetched with ``from_pretrained`` using the module-level
    ``cache_dir``. The model is loaded in bfloat16 with ``device_map="auto"``.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto", cache_dir=cache_dir
    )

    # Fill in a missing pad token from EOS; never overwrite EOS itself.
    # Assigning pad -> eos (the reverse direction) replaces the real
    # end-of-sequence id with the pad id (or with None when the model defines
    # no pad token), so generation can no longer stop at the true EOS token.
    gen_config = model.generation_config
    if gen_config.pad_token_id is None:
        eos_token_id = gen_config.eos_token_id
        if eos_token_id is None:
            eos_token_id = tokenizer.eos_token_id
        # eos_token_id may be a list of ids; pad_token_id must be a single id.
        if isinstance(eos_token_id, (list, tuple)):
            eos_token_id = eos_token_id[0] if eos_token_id else None
        gen_config.pad_token_id = eos_token_id

    return model, tokenizer
    

def process_input(query, prompt=""):
    """Build the question/answer prompt text sent to the model.

    Args:
        query: Question text placed after the ``Question:`` label.
        prompt: Optional instruction placed before the question. When it is
            falsy, the default instruction "Answer with a short answer." is
            used instead.

    Returns:
        The formatted prompt string, which ends with ``"Answer: "``.
    """
    # Fall back to the default instruction when no prompt was supplied.
    instruction = prompt or "Answer with a short answer."
    return f"{instruction}\n\nQuestion: {query}\n\nAnswer: "


def run_dola_decoding(
    model_id,
    query,
    prompt="",
    decoding_analysis=True,
    max_new_tokens=40,
    dola_layers=(14, 16, 18),
    repetition_penalty=1.2,
):
    """Generate an answer to ``query`` with DoLa decoding and print diagnostics.

    The model and tokenizer are loaded via ``load_model`` on every call, the
    prompt is built with ``process_input``, and ``model.generate`` is run
    greedily (``do_sample=False``) with attentions, hidden states, scores and
    logits requested in the returned output.

    Args:
        model_id: Model identifier forwarded to ``load_model``.
        query: Question text forwarded to ``process_input``.
        prompt: Optional instruction forwarded to ``process_input``.
        decoding_analysis: When true, build a ``DecodingVisualizer`` from the
            generation output and print the generated tokens, the decoded
            text and the per-layer information attached to the output.
        max_new_tokens: Maximum number of tokens to generate.
        dola_layers: Value forwarded to ``generate`` as ``dola_layers``. A
            non-string iterable is converted to a list; a string is passed
            through unchanged. The default indices were hard-coded alongside
            the "google/gemma-2b" model id, so pass suitable layers when
            using a different model.
        repetition_penalty: Repetition penalty forwarded to ``generate``.
    """
    model, tokenizer = load_model(model_id)
    input_text = process_input(query, prompt)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    generate_kwargs = {
        "do_sample": False,
        "max_new_tokens": max_new_tokens,
        "top_p": None,
        "temperature": None,
        "output_attentions": True,
        "output_hidden_states": True,
        "output_scores": True,
        "output_logits": True,
        "return_dict_in_generate": True,
    }

    # Accept any iterable of layer indices (the default is an immutable tuple
    # to avoid a shared mutable default); leave string values untouched.
    layers = dola_layers if isinstance(dola_layers, str) else list(dola_layers)

    dola_outputs = model.generate(
        **inputs,
        **generate_kwargs,
        dola_layers=layers,
        repetition_penalty=repetition_penalty,
    )

    if decoding_analysis:

        # NOTE(review): dec_viz is constructed but not used in the code
        # visible in this cell.
        dec_viz = DecodingVisualizer(dola_outputs)

        # Read everything from dola_outputs; the previous code referenced an
        # undefined name `outputs` here, which raised NameError on a fresh
        # kernel (or silently picked up a stale notebook variable).
        sequences, scores = dola_outputs.sequences, dola_outputs.scores
        attentions, hidden_states = dola_outputs.attentions, dola_outputs.hidden_states
        print("sequences shape", sequences.shape)

        # skip the tokens in the input prompt
        gen_sequences = sequences[:, inputs.input_ids.shape[-1]:][0, :]
        gen_arr = gen_sequences.cpu().numpy()

        print("gen_sequences", gen_sequences)
        print("gen_sequences len", len(gen_sequences))
        print("gen_arr", gen_arr)

        output_str = tokenizer.decode(gen_sequences, skip_special_tokens=True)

        print('MODEL OUTPUT: \n{0}'.format(output_str))

        print("### outside of generate ###")

        # NOTE(review): premature_layer_dist and layer_tokens are read from
        # the generate output; presumably they come from a customised
        # generate implementation -- confirm they exist in the installed
        # transformers build.
        premature_layer_dist = dola_outputs.premature_layer_dist
        print("outputs.premature_layer_dist", premature_layer_dist)

        layer_tokens = dola_outputs.layer_tokens["layer_tokens"]
        layer_tokens_logits = dola_outputs.layer_tokens["layer_tokens_logits"]

        print(len(layer_tokens))
        print(len(layer_tokens_logits))