In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os

In [2]:
# Minimal generate function for token-by-token generation
def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)

            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break

    # Return just the newly generated part
    return output_ids[:, origin_len:]

In [3]:
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])

In [4]:
def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
    # Encode prompt
    device = model.model.embed_tokens.weight.device
    print(device)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()

    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache

In [5]:
def clean_up(cache: DynamicCache, origin_len: int):
    # Remove any tokens appended to the original knowledge
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]

In [6]:
def get_env():
    env_dict = {}
    env_file = ".env" if os.path.exists(".env") else "env"
    if os.path.exists(env_file):
        with open(env_file, mode="r") as f:
            for line in f:
                key, value = line.strip().split("=")
                env_dict[key] = value.strip('"')
    else:
        print("No .env or env file found; HF_TOKEN may not be set.")
    return env_dict

env = get_env()
HF_TOKEN = env.get("HF_TOKEN", None)

# Global placeholders (if needed)
model_name = None
model = None
tokenizer = None
rand_seed = None

print("Environment and imports are set.")

Environment and imports are set.


In [7]:
#model_name = "mistralai/Mistral-7B-Instruct-v0.1"
#model_name = "meta-llama/Llama-3.1-8B-Instruct"
#model_name = "microsoft/Phi-3-mini-4k-instruct"
#model_name = "ibm-granite/granite-3.1-8b-instruct"
#model_name = "meta-llama/Llama-3.2-1B"
model_name = "ibm-granite/granite-3.1-1b-a400m-instruct"
max_memory = {"cpu": "15GiB"}
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    #torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    torch_dtype=torch.float32,
    device_map="balanced",
    trust_remote_code=True,
    token=HF_TOKEN,
    max_memory=max_memory
)
#device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"

#model.to(device)
print(f"Loaded {model_name}.")

tokenizer_config.json:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.48M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/87.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/889 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

Loaded ibm-granite/granite-3.1-1b-a400m-instruct.


In [30]:
if not os.path.exists("document.txt"):
    raise FileNotFoundError("Please create a `document.txt` with info about Ravi.")

with open("document.txt", "r", encoding="utf-8") as f:
    doc_text = f.read()

#system_prompt = f"""
#<|system|>
#You are an assistant who provides concise factual answers.
#<|user|>
#Context:
#{doc_text}
#Question:
#""".strip()

#system_prompt = f"""You are an expert in providing factual answers from the context.
#Context:
#{doc_text}
#Question:
#""".strip()

system_prompt = f"""<|start_of_role|>system<|end_of_role|>You are an expert in extracting content from the context for the given question.<|end_of_text|>
<|start_of_role|>user<|end_of_role|>Context:{doc_text}<|end_of_text|>
Question:
<|start_of_role|>assistant<|end_of_role|>Answer:
""".strip()

# Build the cache
cache = get_kv_cache(model, tokenizer, system_prompt)
origin_len = cache.key_cache[0].shape[-2]
print("KV cache built.")

cpu
KV cache built.


In [31]:
# 1st query
device="cpu"
question1 = "Who is Ravi Kumar Srirangam?"
clean_up(cache, origin_len)
input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q1 = generate(model, input_ids_q1, cache)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)

print("Q1:", question1)
print("A1:", answer1)

Q1: Who is Ravi Kumar Srirangam?
A1: 
Ravi Kumar Srirangam is an experienced technology and product leader with a strong background in engineering and product development. He has extensive experience in developing scalable platforms, middleware, and distributed applications, and has successfully migrated applications to the


In [21]:
# 2nd query
question2 = "What is his education?"
clean_up(cache, origin_len)
input_ids_q2 = tokenizer(question2 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q2 = generate(model, input_ids_q2, cache)
answer2 = tokenizer.decode(gen_ids_q2[0], skip_special_tokens=True)
print("Q1:", question2)
print("A1:", answer2)

Q1: What is his education?
A1: Answer:
Ravi Kumar Srirangam has a Bachelor of Technology (BTech) in Mechanical Engineering from KL University and a Master of Business Administration (MBA) from Amrita School of Business.
