In [None]:
!pip install transformers==4.40.2
!pip install torch>=2.1.0
!pip install sentence-transformers>=2.7.0
!pip install accelerate>=0.21.0
!pip install sentence-transformers>=2.7.0
!pip install bitsandbtyes

Collecting transformers==4.40.2
  Downloading transformers-4.40.2-py3-none-any.whl.metadata (137 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/138.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.0/138.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.20,>=0.19 (from transformers==4.40.2)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.40.2-py3-none-any.whl (9.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, util
import spacy
import os

In [None]:
# --- Model identifiers ---
LLM_MODEL_ID: str = "meta-llama/Llama-3.2-1B-Instruct"  # chat LLM used for generation
CLASSIFIER_MODEL_ID: str = "all-MiniLM-L6-v2"  # sentence-embedding model for the semantic fallback check
SPACY_MODEL_ID: str = "en_core_web_sm"  # spaCy pipeline used for the pronoun / entity checks

# --- Contextual Scaffolding Analysis (CSA) scoring ---
CSA_DEPENDENCY_THRESHOLD: float = 0.65  # a score above this keeps the prompt on the current branch
PRONOUN_SCORE: float = 0.95  # score assigned when the prompt contains an anchor pronoun
ENTITY_DEFICIT_SCORE: float = 0.80  # score for a question-shaped prompt without named entities
SELF_CONTAINED_PENALTY: float = 0.5  # multiplier on topic similarity when the prompt has its own entities

In [None]:
# Resolve the Hugging Face access token: Colab secrets first, then the environment.
try:
    from google.colab import userdata
except ImportError:  # not running on Colab
    userdata = None

HF_TOKEN = None
if userdata is not None:
    try:
        HF_TOKEN = userdata.get("HF_TOKEN")
    except Exception as err:  # presumably a missing secret or notebook access not granted
        print(f"> Could not read HF_TOKEN from Colab secrets: {err}")
# Treat an empty value the same as a missing one.
HF_TOKEN = HF_TOKEN or os.environ.get("HF_TOKEN") or None

if HF_TOKEN:
    # Hugging Face libraries also pick the token up from this environment variable.
    os.environ["HF_TOKEN"] = HF_TOKEN
else:
    # os.environ values must be str, so a missing token is reported instead of assigned.
    print("> Warning: HF_TOKEN is not set; downloading gated models (e.g. Llama 3.2) will fail.")

In [None]:
# Pick the compute device once; later cells read this module-level string.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:
def load_spacy_model():
    """
    Ensure the spaCy pipeline named by SPACY_MODEL_ID is installed, and return it.

    Tries spacy.load first; if that raises OSError (pipeline package not
    installed), downloads it with spacy.cli.download and loads it afterwards.

    Returns:
        The loaded spaCy pipeline. Callers that ignore the return value keep
        working as before.
    """
    import importlib

    try:
        nlp = spacy.load(SPACY_MODEL_ID)
        print(f"SpaCy Model {SPACY_MODEL_ID} already installed.")
    except OSError:
        print(f"SpaCy Model {SPACY_MODEL_ID} not found. Starting Download...")
        spacy.cli.download(SPACY_MODEL_ID)
        print(f"{SPACY_MODEL_ID} download complete!")
        # The package was installed while the interpreter is running; invalidate
        # the import system's finder caches so the new package can be found.
        importlib.invalidate_caches()
        nlp = spacy.load(SPACY_MODEL_ID)
    return nlp

In [None]:
_MODEL_CACHE = {
    "llm_tokenizer": None,
    "llm_model": None,
    "classifier_model": None,
    "nlp_model": None,
    "loaded": False
}

In [None]:
class KVCacheManager:
    """
    Chat generation that carries the key/value (KV) cache from turn to turn.

    The first turn of a branch is generated with ``model.generate`` and its
    cache is returned. Follow-up turns feed only the new prompt tokens plus the
    stored cache through a manual token-by-token sampling loop, so the earlier
    context is not re-encoded. Prompts use the Llama 3 chat special tokens.
    """

    def __init__(self, model, tokenizer, temperature=0.7, top_p=0.9):
        """
        Args:
            model: causal LM exposing ``parameters()``, ``generate`` and a
                forward call that accepts ``past_key_values``.
            tokenizer: tokenizer matching ``model``.
            temperature: sampling temperature used on every turn.
            top_p: nucleus (top-p) cumulative-probability cutoff used on every turn.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.temperature = temperature
        self.top_p = top_p

    def generate_with_kv_cache(self, prompt, past_kv_cache=None, max_new_tokens=200):
        """
        Generate a reply to ``prompt``, optionally continuing from a stored KV cache.

        Args:
            prompt (str): the user's message for this turn.
            past_kv_cache: cache returned by a previous call on the same branch,
                or None to start a fresh conversation.
            max_new_tokens (int): upper bound on the number of generated tokens.

        Returns:
            tuple: (response_text, kv_cache); pass kv_cache to the next call on
            the same branch.
        """
        print(f"\n> Generating response for: '{prompt}...'")
        print(f"> Using past_kv_cache: {past_kv_cache is not None}")

        if past_kv_cache is None:
            return self._generate_first_turn(prompt, max_new_tokens)

        # KV cache reuse - manual token-by-token generation avoids cache_position issues
        print(f"> KV Cache Optimization Active!")
        print(f"> Reusing {past_kv_cache[0][0].shape[-2]} cached tokens")

        # The leading <|eot_id|> closes the previous assistant message: the token
        # that ended that reply is never fed back through the model, so the cache
        # holds no end-of-turn marker for it.
        follow_up = f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        # add_special_tokens=False: the cached prefix already starts with
        # <|begin_of_text|>; letting the tokenizer add BOS here would insert a
        # second one in the middle of the conversation.
        new_tokens = self.tokenizer(
            follow_up, return_tensors="pt", add_special_tokens=False
        ).to(self.device)

        print(f"> New tokens: {new_tokens['input_ids'].shape[1]}")

        return self._manual_generation_with_cache(new_tokens["input_ids"], past_kv_cache, max_new_tokens)

    def _generate_first_turn(self, prompt, max_new_tokens):
        """Generate the opening turn of a branch with ``model.generate``; returns (text, cache)."""
        messages = [{"role": "user", "content": prompt}]
        try:
            text = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            # No usable chat template: build the Llama 3 prompt format by hand.
            text = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        # The rendered prompt already carries <|begin_of_text|> (the chat template
        # emits it, and so does the fallback above), so don't let the tokenizer add a second one.
        model_inputs = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=False
        ).to(self.device)
        print(f"> First turn - Input tokens: {model_inputs['input_ids'].shape[1]}")

        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                temperature=self.temperature,
                top_p=self.top_p,
            )

        input_length = model_inputs["input_ids"].shape[1]
        response_tokens = outputs.sequences[0][input_length:]
        response_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True).strip()
        new_kv_cache = outputs.past_key_values

        print(f"> KV Cache optimization successful!")
        print(f"> Preserved cache for next turn: {len(new_kv_cache)} layers")
        print(f"\n> LLM: {response_text}")

        return response_text, new_kv_cache

    def _stop_token_ids(self):
        """Ids that end manual generation: the tokenizer's EOS plus the model's
        ``generation_config.eos_token_id`` entry or entries (what ``generate`` stops on)."""
        stop_ids = set()
        if self.tokenizer.eos_token_id is not None:
            stop_ids.add(self.tokenizer.eos_token_id)
        config_eos = getattr(getattr(self.model, "generation_config", None), "eos_token_id", None)
        if isinstance(config_eos, int):
            stop_ids.add(config_eos)
        elif config_eos is not None:
            stop_ids.update(config_eos)
        return stop_ids

    def _sample_next_token(self, logits):
        """
        Sample one token id from 1-D ``logits`` using temperature + top-p filtering.

        Returns:
            A 1-element tensor holding the sampled token id.
        """
        scaled = logits.float() / self.temperature

        sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > self.top_p
        # Shift right so the token that crosses the threshold survives, and
        # always keep the most likely token.
        sorted_to_remove[1:] = sorted_to_remove[:-1].clone()
        sorted_to_remove[0] = False
        to_remove = sorted_to_remove.scatter(0, sorted_indices, sorted_to_remove)
        scaled[to_remove] = float("-inf")

        probs = torch.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _manual_generation_with_cache(self, input_ids, past_kv_cache, max_new_tokens):
        """
        Token-by-token sampling on top of an existing KV cache.

        The first forward pass feeds all of ``input_ids``; each later pass feeds
        only the previously sampled token. Generation stops on any stop token
        (see ``_stop_token_ids``) or after ``max_new_tokens`` steps.

        Returns:
            tuple: (response_text, kv_cache) where kv_cache includes every token
            that was fed through the model, including the new prompt.
        """
        print("> Using manual generation to preserve KV cache optimization")

        stop_ids = self._stop_token_ids()
        generated_ids = input_ids.clone()
        current_kv_cache = past_kv_cache
        step_input = input_ids

        for step in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=step_input,
                    past_key_values=current_kv_cache,
                    use_cache=True,
                )
            # Keep the updated cache immediately: it now covers step_input, which
            # belongs to the conversation whether or not the next sample is a stop token.
            current_kv_cache = outputs.past_key_values

            next_token = self._sample_next_token(outputs.logits[0, -1, :])
            if next_token.item() in stop_ids:
                break

            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
            step_input = next_token.unsqueeze(0)

            if step > 0 and step % 50 == 0:
                print(f"> Generated {step} tokens...")

        response_start = input_ids.shape[1]
        response_tokens = generated_ids[0][response_start:]
        response_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True).strip()

        # Clean up Llama artifacts
        response_text = response_text.replace("<|eot_id|>", "").strip()

        print(f"> Manual KV cache generation successful!")
        print(f"> Generated {len(response_tokens)} new tokens")
        print(f"> Total context preserved: {current_kv_cache[0][0].shape[-2]} tokens")
        print(f"\n> LLM: {response_text}")

        return response_text, current_kv_cache

In [None]:
def load_models():
    """
    Load the LLM, its KVCacheManager, the sentence classifier and the spaCy model.

    With CUDA the LLM is loaded 8-bit quantized through bitsandbytes; otherwise
    it is loaded in float32 on the CPU. Everything is stored in _MODEL_CACHE, so
    only the first call does any loading.

    Returns:
        tuple: (llm_tokenizer, llm_model, kv_manager, classifier_model, nlp_model)
        -- the same five items whether freshly loaded or served from the cache.
    """
    # return cached models if already loaded
    if _MODEL_CACHE["loaded"]:
        print("> Models already loaded from cache!")
        llm_tokenizer = _MODEL_CACHE["llm_tokenizer"]
        llm_model = _MODEL_CACHE["llm_model"]
        # KVCacheManager only wraps the model/tokenizer, so rebuild it if the
        # cache entry has none (e.g. it was filled before the manager was cached).
        kv_manager = _MODEL_CACHE.get("kv_manager")
        if kv_manager is None:
            kv_manager = KVCacheManager(llm_model, llm_tokenizer)
            _MODEL_CACHE["kv_manager"] = kv_manager
        return (
            llm_tokenizer,
            llm_model,
            kv_manager,
            _MODEL_CACHE["classifier_model"],
            _MODEL_CACHE["nlp_model"],
        )

    print(f"> Loading LLM: {LLM_MODEL_ID}")

    # Llama is a built-in transformers architecture, so no remote code execution
    # (trust_remote_code) is needed.
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, token=HF_TOKEN)
    if llm_tokenizer.pad_token is None:
        llm_tokenizer.pad_token = llm_tokenizer.eos_token

    if torch.cuda.is_available():
        # BitsAndBytesConfig has no 8-bit compute-dtype option (only
        # bnb_4bit_compute_dtype exists), so load_in_8bit is all that is needed.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

        llm_model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            token=HF_TOKEN,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        print(f"> LLM Model {LLM_MODEL_ID} loaded on GPU Successfully!")
        print(f"> GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    else:
        llm_model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            token=HF_TOKEN,
            torch_dtype=torch.float32,
            device_map="cpu",
        )

        print("> LLM Model Loaded Successfully!")

    # create KV cache manager
    kv_manager = KVCacheManager(llm_model, llm_tokenizer)

    # load classifier
    print(f"> Loading Classifier: {CLASSIFIER_MODEL_ID}")
    classifier_model = SentenceTransformer(CLASSIFIER_MODEL_ID, device=device)
    print("> Classifier loaded successfully!")

    load_spacy_model()
    print(f"> Loading NLP model: {SPACY_MODEL_ID}")
    nlp_model = spacy.load(SPACY_MODEL_ID)
    print("> NLP Model loaded successfully!")

    # cache all models, including the KV manager, so cached calls return the same 5-tuple
    _MODEL_CACHE.update({
        "llm_tokenizer": llm_tokenizer,
        "llm_model": llm_model,
        "kv_manager": kv_manager,
        "classifier_model": classifier_model,
        "nlp_model": nlp_model,
        "loaded": True
    })

    print("> All models cached!")

    return llm_tokenizer, llm_model, kv_manager, classifier_model, nlp_model

In [None]:
class ConversationManager:
    """
    Tracks conversation branches and the turns stored on each one.

    ``conversation_tree`` maps a branch id (e.g. "branch_0") to
    ``{'turns': [{'prompt', 'response', 'kv_cache'}, ...]}``. Only the newest
    turn of a branch keeps a KV cache reference: it is the only one read back
    (get_last_kv_cache), and keeping every per-turn cache alive would make
    memory grow with each turn.
    """

    def __init__(self):
        self.conversation_tree = {}
        self.current_branch_id = None

    def start_new_branch(self):
        """Create an empty branch under an unused id of the form 'branch_<n>' and return the id."""
        index = len(self.conversation_tree)
        # add_turn_to_branch can create arbitrary ids, so skip any that are taken
        # instead of overwriting an existing branch.
        while f"branch_{index}" in self.conversation_tree:
            index += 1
        branch_id = f"branch_{index}"
        self.conversation_tree[branch_id] = {'turns': []}
        return branch_id

    def add_turn_to_branch(self, branch_id, prompt, response, kv_cache):
        """Append a turn to ``branch_id`` (creating the branch if needed) and
        release the KV cache held by the branch's previous turn."""
        turns = self.conversation_tree.setdefault(branch_id, {'turns': []})['turns']
        if turns:
            # Superseded: only the newest turn's cache is ever read back.
            turns[-1]['kv_cache'] = None

        turns.append({
            'prompt': prompt,
            'response': response,
            'kv_cache': kv_cache,
        })

    def get_last_kv_cache(self, branch_id):
        """Return the newest turn's KV cache for ``branch_id``, or None if there is none."""
        last_turn = self.get_last_turn(branch_id)
        return last_turn['kv_cache'] if last_turn else None

    def get_last_turn(self, branch_id):
        """Return the newest turn dict of ``branch_id`` (used for CSA analysis), or None."""
        branch = self.conversation_tree.get(branch_id)
        if branch and branch['turns']:
            return branch['turns'][-1]
        return None

    def display_conversation_tree(self):
        """Print how many turns each branch holds."""
        print(f"\n Conversation Tree State---")
        for branch_id, branch_data in self.conversation_tree.items():
            turns = len(branch_data['turns'])
            print(f"> {branch_id}: Contains {turns} turns.")

In [None]:
def csa_classifier(nlp, classifier, new_prompt, last_context):
    """
    Classifies a prompt using the Contextual Scaffolding Analysis algorithm.

    Checks run in order; the first one that produces a non-zero score decides:
    1. anchor pronoun present -> PRONOUN_SCORE
    2. question-shaped prompt with no named entities -> ENTITY_DEFICIT_SCORE
    3. cosine similarity between the prompt and the last turn, multiplied by
       SELF_CONTAINED_PENALTY when the prompt has its own named entities.

    Args:
        nlp: The spaCy NLP model.
        classifier: The sentence-embedding model (SentenceTransformer).
        new_prompt (str): User prompt.
        last_context (dict): Previous turn with 'prompt' and 'response' keys.

    Returns:
        bool: True -> stay on the same branch; False -> start a new branch
        (also False when there is no previous context).
    """

    print("> Running Contextual Scaffolding Analysis...")
    if not last_context:
        print("> No previous context, starting a new branch.")
        return False

    doc = nlp(new_prompt)
    dependency_score = 0.0

    # 1. Pronoun check (strongest signal)
    anchor_pronouns = {"it", "its", "that", "those", "they", "their", "them"}
    if any(token.lower_ in anchor_pronouns for token in doc):
        dependency_score = PRONOUN_SCORE
        print(f"> CSA -> Pronoun check passed. Score: {dependency_score}")

    # 2. Entity deficit check (strong signal)
    # An empty prompt yields an empty Doc; guard before indexing doc[0].
    if dependency_score == 0.0 and len(doc) > 0:
        is_question = doc[0].pos_ == "AUX" or doc[0].tag_ in ["WDT", "WP", "WP$", "WRB"]
        has_entities = len(doc.ents) > 0
        if is_question and not has_entities:
            dependency_score = ENTITY_DEFICIT_SCORE
            print(f"> CSA -> Entity deficit check passed. Score: {dependency_score}")

    # 3. Semantic Fallback Check (Tie-Breaker)
    if dependency_score == 0.0:
        print("> CSA -> Running semantic fallback check.")
        context_text = f"User Asked: {last_context['prompt']} | Model Response: {last_context['response']}"

        embedding_new = classifier.encode(new_prompt, convert_to_tensor=True).to(device)
        embedding_context = classifier.encode(context_text, convert_to_tensor=True).to(device)

        topic_similarity = util.cos_sim(embedding_new, embedding_context).item()
        print(f"> Topic Similarity Score: {topic_similarity:.4f}")

        # A prompt with its own named entities can stand on its own.
        is_self_contained = len(doc.ents) > 0

        if is_self_contained:
            # penalize the score because the prompt can stand on its own
            dependency_score = topic_similarity * SELF_CONTAINED_PENALTY
            print("> Prompt is self-contained. Applying penalty.")
        else:
            # no penalty as it is likely dependent
            dependency_score = topic_similarity
            print("> Prompt is not self-contained. No penalty applied.")
        print(f"> Final Score: {dependency_score}")

    # Final Decision
    if dependency_score > CSA_DEPENDENCY_THRESHOLD:
        print(f"> Decision: Same branch (Score: {dependency_score:.2f} > {CSA_DEPENDENCY_THRESHOLD})")
        return True
    print(f"> Decision: New branch (Score: {dependency_score:.2f} <= {CSA_DEPENDENCY_THRESHOLD})")
    return False

In [None]:
def _start_fresh_branch(conversation_manager):
    """Create a new branch, make it the current one, and return its id."""
    branch_id = conversation_manager.start_new_branch()
    conversation_manager.current_branch_id = branch_id
    return branch_id


def _select_branch_for_prompt(conversation_manager, nlp, classifier, user_prompt):
    """
    Decide which branch ``user_prompt`` belongs to.

    Returns:
        tuple: (branch_id, past_kv_cache). past_kv_cache is the current branch's
        last cache when CSA says the prompt continues it; otherwise a new branch
        is started and past_kv_cache is None.
    """
    current_branch_id = conversation_manager.current_branch_id
    if current_branch_id is None:
        branch_id = _start_fresh_branch(conversation_manager)
        print(f"> No context found, forcing a new branch.")
        print(f"> Starting new branch: {branch_id}")
        return branch_id, None

    last_turn = conversation_manager.get_last_turn(current_branch_id)
    if last_turn:
        last_context = {
            'prompt': last_turn['prompt'],
            'response': last_turn['response']
        }
        if csa_classifier(nlp, classifier, user_prompt, last_context):
            # Continue on same branch - reuse its KV cache
            past_kv_cache = conversation_manager.get_last_kv_cache(current_branch_id)
            print(f"> Continuing on same branch: {current_branch_id}")
            return current_branch_id, past_kv_cache

    # CSA chose a new branch, or the current branch has no turns yet: fresh context
    branch_id = _start_fresh_branch(conversation_manager)
    print(f"> Starting new branch: {branch_id}")
    return branch_id, None


def run_context_manager():
    """
    Interactive loop: read prompts, route each to a branch with CSA, and
    generate replies that reuse the branch's KV cache.

    The session ends on 'exit', end of input, or an interrupt at the prompt.
    """
    _, _, kv_manager, classifier, nlp = load_models()

    conversation_manager = ConversationManager()

    print(f"\n> Pure KV Cache Context Manager!")
    print(f"> Model: {LLM_MODEL_ID}")
    print(f"> Optimization: Pure KV cache branching")
    print("> Type 'exit' to end session.\n")

    while True:
        try:
            user_prompt = input("> Your Prompt: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Input closed or cell interrupted while waiting: end the session cleanly.
            print()
            break
        if user_prompt.lower() == 'exit':
            break
        if not user_prompt:
            # Nothing to route or answer; don't send an empty prompt through CSA and the LLM.
            continue

        branch_id, past_kv_cache = _select_branch_for_prompt(
            conversation_manager, nlp, classifier, user_prompt
        )

        response, new_kv_cache = kv_manager.generate_with_kv_cache(
            user_prompt, past_kv_cache
        )

        conversation_manager.add_turn_to_branch(
            branch_id, user_prompt, response, new_kv_cache
        )

        conversation_manager.display_conversation_tree()
        print()

In [None]:
# Start the interactive session; type 'exit' at the prompt to end it.
run_context_manager()

> Loading LLM: meta-llama/Llama-3.2-1B-Instruct
> LLM Model meta-llama/Llama-3.2-1B-Instruct Loaded on GPU Successfully!
> GPU memory: 4.79 GB
> Loading Classifier: all-MiniLM-L6-v2
> Classifier loaded successfully!
SpaCy Model en_core_web_sm already installed.
> Loading NLP model: en_core_web_sm
> NLP Model loaded successfully!
> All models cached!

> 🚀 Pure KV Cache Context Manager!
> Model: meta-llama/Llama-3.2-1B-Instruct
> Optimization: Pure KV cache branching
> Type 'exit' to end session.

 Your Prompt: where is taj mahal?
> No context found, forcing a new branch.
> Starting new branch: branch_0

> Generating response for: 'where is taj mahal?...'
> Using past_kv_cache: False
> First turn - Input tokens: 42
> ✅ KV Cache optimization successful!
> Preserved cache for next turn: 16 layers

> LLM: The Taj Mahal is located in India, specifically in the state of Uttar Pradesh. It is situated in the city of Agra, which is about 230 kilometers (143 miles) from the city of Delhi. The Taj