<a href="https://colab.research.google.com/github/yash-mandi/e3l.ai-model/blob/main/Memory_Context_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the fine-tuned model path used when
# loading the model lives under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Install required libraries
!pip install -U langchain-community
!pip install langchain faiss-cpu sentence-transformers transformers accelerate ipywidgets bitsandbytes --quiet
!pip install git+https://github.com/huggingface/transformers.git

Collecting langchain-community
  Downloading langchain_community-0.3.24-py3-none-any.whl.metadata (2.5 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.9.1-py3-none-any.whl.metadata (3.8 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting python-dotenv>=0.21.0 (from pydantic-settings<3.0.0,>=2.4.0->langchain-community)
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB

In [None]:
# Imports
import torch
import ipywidgets as widgets
from IPython.display import display, clear_output, Javascript
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
import ipywidgets as widgets
from IPython.display import display, clear_output
# Allow PyTorch to use faster reduced-precision internal paths (e.g. TF32) for
# float32 matmuls where the hardware supports it.
MATMUL_PRECISION = 'high'
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [None]:
# Load tokenizer and model
model_path = "/content/drive/MyDrive/Gemma-Merged-Finetune-Model"  # Update as needed
tokenizer = AutoTokenizer.from_pretrained(model_path)

# torch_dtype="auto" keeps the dtype recorded in the checkpoint's config instead
# of upcasting all weights to float32 (the from_pretrained default). For a
# bf16/fp16 checkpoint this roughly halves GPU memory.
# NOTE(review): BitsAndBytesConfig is imported but unused. Pass
# quantization_config=BitsAndBytesConfig(...) here if 4/8-bit loading is intended.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype="auto",
)

# Shared text-generation pipeline used by the chat UI below. The EOS token id is
# used as the padding id.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id
)

Device set to use cuda:0


In [None]:
# ----------- Memory Setup -----------
# Semantic memory: past Q&A pairs are embedded and retrieved by similarity.
# NOTE(review): the next cell defines add_to_memory / get_memory_context again
# (list-based). After Run All these FAISS versions are shadowed and never called,
# so loading embedding_model here is unused work. Keep only one implementation.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = None  # Global memory store; built by the first add_to_memory call


def add_to_memory(question, answer):
    """Embed one Q&A pair and store it in the global FAISS `vectorstore`.

    When the store does not exist yet it is created from this document.
    Otherwise the document is appended to the existing store.
    """
    global vectorstore
    qa_doc = Document(page_content=f"Q: {question}\nA: {answer}")
    if vectorstore is not None:
        vectorstore.add_documents([qa_doc])
    else:
        vectorstore = FAISS.from_documents([qa_doc], embedding_model)


def get_memory_context(query, k=3):
    """Return the stored Q&A pairs most similar to `query` (at most `k`).

    Pairs are separated by a blank line. Returns "" when nothing has been stored yet.
    """
    if vectorstore is None:
        return ""
    matches = vectorstore.similarity_search(query, k=k)
    return "\n\n".join(match.page_content for match in matches)

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# 🧠 Memory Tracking
# Conversation history as (question, answer) tuples, oldest first.
memory = []

# Requests meaning "keep going with the previous answer" rather than a new question.
CONTINUATION_PHRASES = frozenset({
    "continue",
    "continue from where you left",
    "continue from where you left off",
    "go on",
    "keep going",
})


# Detect continuation-type questions
def is_continuation(question):
    """Return True if `question` asks the model to continue its previous answer.

    Matching is case-insensitive and ignores surrounding whitespace and trailing
    punctuation, so "Continue." and "go on!" count as well as "continue".
    """
    normalized = question.lower().strip().rstrip(".!?").rstrip()
    return normalized in CONTINUATION_PHRASES


def add_to_memory(question, answer):
    """Record one chat turn in `memory`.

    A new question is stored as its own (question, answer) entry. A continuation
    request is not stored as a separate turn. Instead, its answer is appended to
    the most recent entry, so the next prompt contains everything generated so
    far and a second "continue" does not start from the same point again.
    A continuation with no earlier turn, or with an empty answer, is dropped.
    """
    if not is_continuation(question):
        memory.append((question, answer))
    elif memory and answer:
        prev_question, prev_answer = memory[-1]
        memory[-1] = (prev_question, f"{prev_answer} {answer}".strip())


def get_memory_context(max_turns=10):
    """Format the last `max_turns` stored turns as "Q: ...\\nA: ..." lines.

    Turns are returned oldest first and joined with a newline. Returns "" when
    nothing is stored or when `max_turns` <= 0.
    """
    if max_turns <= 0:
        return ""
    # Continuation turns are never stored by add_to_memory; filter defensively
    # in case `memory` was modified directly.
    recent = [(q, a) for q, a in memory if not is_continuation(q)][-max_turns:]
    return "\n".join(f"Q: {q}\nA: {a}" for q, a in recent)

# 🧠 Prompt Builder
def build_prompt(user_input):
    """Assemble the text sent to the model for `user_input`.

    Recent conversation from get_memory_context() comes first, followed by a
    blank line. Then comes the new question and an "Answer:" cue for the model
    to complete. With no history, only the question and cue are returned.
    """
    sections = []
    history = get_memory_context()
    if history:
        sections.append(history)
    sections.append(f"Question: {user_input}\nAnswer:")
    return "\n\n".join(sections)

In [None]:
# ✅ Chat UI
def _extract_answer(generated_text, stop_markers=("\nQuestion:", "\nQ:")):
    """Return the model's answer with any invented follow-up turn removed.

    `generated_text` is only the newly generated continuation. The model may keep
    writing past its answer in the prompt's own format (a new "Question:" or
    "Q:" line). Everything from the earliest such marker onward is dropped, so
    it is neither shown nor stored in memory.
    """
    cut = len(generated_text)
    for marker in stop_markers:
        pos = generated_text.find(marker)
        if pos != -1:
            cut = min(cut, pos)
    return generated_text[:cut].strip()


def create_prompt_widget(max_new_tokens=300, temperature=0.8, top_p=0.95):
    """Build one question/answer widget for the chat UI.

    The widget has a text area, a Submit button and an output area. On submit:
    - the question is wrapped with recent memory by build_prompt and sent to `pipe`;
    - the answer is printed and stored with add_to_memory;
    - on the first successful answer, a fresh prompt widget with the same
      generation settings is appended below so the conversation can continue.

    Args:
        max_new_tokens: maximum number of tokens generated per answer.
        temperature: sampling temperature passed to the pipeline.
        top_p: nucleus-sampling threshold passed to the pipeline.

    Returns:
        A widgets.VBox holding this prompt (and later its follow-up prompt).
    """
    label = widgets.HTML("<b>Ask your question:</b>")
    text_area = widgets.Textarea(
        placeholder='Enter your question here...',
        layout=widgets.Layout(width='90%', height='60px')
    )
    submit_button = widgets.Button(description="Submit", button_style='success')
    output_box = widgets.Output()
    container = widgets.VBox()
    # Ensures re-submitting in this widget does not stack another input box each time.
    follow_up_added = False

    def on_submit_clicked(_button):
        nonlocal follow_up_added
        output_box.clear_output()
        question = text_area.value.strip()
        if not question:
            with output_box:
                print("❗ Please enter a question.")
            return

        # Block repeat clicks while the generation call below runs.
        submit_button.disabled = True
        try:
            with output_box:
                print("Generating response... Please wait.")

            prompt = build_prompt(question)
            response = pipe(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                # Return only the new text. With the full text, searching for
                # "Answer:" could match an earlier occurrence in the prompt (the
                # question itself or stored history) instead of the final cue.
                return_full_text=False,
            )
            generated_answer = _extract_answer(response[0]['generated_text'])

            with output_box:
                output_box.clear_output()
                print(f"Answer:\n{generated_answer}")

            add_to_memory(question, generated_answer)

            # Add a new input widget below after the first answer only.
            if not follow_up_added:
                follow_up_added = True
                container.children += (
                    create_prompt_widget(max_new_tokens, temperature, top_p),
                )

        except Exception as e:
            with output_box:
                output_box.clear_output()
                print(f"⚠️ Error generating response:\n{str(e)}")
        finally:
            submit_button.disabled = False

    submit_button.on_click(on_submit_clicked)
    container.children = [label, text_area, submit_button, output_box]
    return container

# Display the first prompt
display(create_prompt_widget())

VBox(children=(HTML(value='<b>Ask your question:</b>'), Textarea(value='', layout=Layout(height='60px', width=…