In [None]:
!pip install -qU rich
!pip install -qU wandb
!pip install -qU git+https://github.com/wandb/weave.git@feat/groq
!pip install -qU llama-index groq
!pip install -qU llama-index-embeddings-huggingface

In [None]:
import os
from typing import Optional, Tuple

import rich
import wandb
import weave
from google.colab import userdata

import instructor
from groq import Groq
from pydantic import BaseModel
from llama_index.core import (
    ServiceContext, StorageContext, load_index_from_storage
)
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

In [None]:
# Initialise Weave for the "groq-rag" project.
weave.init(project_name="groq-rag")

# Look up the W&B artifact and download it; the resulting local directory is
# later used as the `persist_dir` the index is loaded from.
artifact = wandb.Api().artifact("geekyrakshit/groq-rag/ncert-flamingoes-prose-embeddings:latest")
artifact_dir = artifact.download()

In [None]:
# Hugging Face embedding model used by the index / retriever.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Service context carrying the embedding model. `llm=None` is passed explicitly;
# answer generation in this notebook goes through the Groq client instead.
# NOTE(review): newer llama-index releases deprecate `ServiceContext` in favour of
# `Settings` -- confirm against the version pulled in by the unpinned install.
service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model)

In [None]:
# Point a storage context at the downloaded artifact directory and load the
# persisted index from it, using the service context configured above.
storage_context = StorageContext.from_defaults(persist_dir=artifact_dir)
index = load_index_from_storage(storage_context, service_context=service_context)

In [None]:
# Retriever over the loaded index, configured with `similarity_top_k=10`.
# (The spelling of `retreival_engine` is kept: later cells reference this exact name.)
retreival_engine = index.as_retriever(similarity_top_k=10, service_context=service_context)

In [None]:
# Sample question reused by the cells below.
query = "what was the mood in the classroom when M. Hamel gave his last French lesson?"

# Retrieval results for the sample question.
response = retreival_engine.retrieve(query)

In [None]:
# Only the first retrieval result is inspected.
top_node = response[0].node

# Turn the source file name into a title: drop everything after the first ".",
# replace underscores with spaces and title-case the result.
source_file = top_node.metadata["file_name"]
chapter_name = source_file.split(".")[0].replace("_", " ").title()

# Text of the top node; passed as `context` to the assistant further down.
context = top_node.text

rich.print(f"{chapter_name=}")
rich.print(f"{context=}")

In [None]:
class EnglishDoubtClearningAssistant(weave.Model):
    """Answers a question from a supplied context passage via a Groq chat model.

    Attributes:
        model: Name of the Groq chat model used for completions.
    """

    model: str = "llama3-8b-8192"
    # Groq client; created in `__init__` from the `GROQ_API_KEY` environment variable.
    _groq_client: Optional[Groq] = None

    def __init__(self, model: Optional[str] = None):
        """Create the assistant.

        Args:
            model: Optional model name; when ``None`` the class default is kept.
        """
        super().__init__()
        self.model = model if model is not None else self.model
        self._groq_client = Groq(
            api_key=os.environ.get("GROQ_API_KEY")
        )

    @weave.op()
    def predict(self, question: str, context: str) -> str:
        """Answer ``question`` using ``context`` in 50-150 words.

        Args:
            question: The question to answer.
            context: Passage inserted into the prompt as context information.

        Returns:
            The message content of the first completion choice.
        """
        chat_completion = self._groq_client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": """
You are a student in a class and your teacher has asked you to answer the following question.
You have to write the answer in the given word limit.""",
                },
                {
                    "role": "user",
                    # Interpolate the `question` argument; the prompt previously read the
                    # notebook-level `query` global and ignored the caller's question.
                    "content": f"""
We have provided context information below. 

---
{context}
---

Answer the following question within 50-150 words:

```
{question}
```""",
                },
            ],
            model=self.model,
        )
        return chat_completion.choices[0].message.content

In [None]:
# Build the assistant with its default model and answer the sample question
# against the context taken from the top retrieval result.
assistant = EnglishDoubtClearningAssistant()

doubt_answer = assistant.predict(question=query, context=context)
rich.print(doubt_answer)

In [None]:
class EnglishStudentResponseAssistant(weave.Model):
    """Writes a student-style answer whose length depends on the marks at stake.

    The context is fetched with the notebook-level `retreival_engine`, and the
    answer is generated by a Groq chat model.

    Attributes:
        model: Name of the Groq chat model used for completions.
    """

    model: str = "llama3-8b-8192"
    # Groq client; created in `__init__` from the `GROQ_API_KEY` environment variable.
    _groq_client: Optional[Groq] = None

    def __init__(self, model: Optional[str] = None):
        """Create the assistant; ``model=None`` keeps the class default."""
        super().__init__()
        self.model = self.model if model is None else model
        self._groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

    @weave.op()
    def get_prompt(
        self, question: str, context: str, word_limit_min: int, word_limit_max: int
    ) -> Tuple[str, str]:
        """Build the prompts for answering ``question`` from ``context``.

        Returns:
            ``(system_prompt, user_prompt)``; the user prompt embeds the context,
            the word-limit range and the question.
        """
        system_prompt = """
You are a student in a class and your teacher has asked you to answer the following question.
You have to write the answer in the given word limit."""
        user_prompt = f"""
We have provided context information below. 

---
{context}
---

Answer the following question within {word_limit_min}-{word_limit_max} words:

---
{question}
---"""
        return system_prompt, user_prompt

    @weave.op()
    def predict(self, question: str, total_marks: int) -> str:
        """Generate an answer to ``question`` sized for ``total_marks``.

        Word limits: fewer than 3 marks -> 5-50 words, fewer than 5 marks ->
        50-100 words, otherwise 100-200 words.

        Returns:
            The message content of the first completion choice.
        """
        # Only the first retrieved node is used as context.
        top_hit = retreival_engine.retrieve(question)[0]
        passage = top_hit.node.text

        word_limit_min, word_limit_max = (
            (5, 50) if total_marks < 3
            else (50, 100) if total_marks < 5
            else (100, 200)
        )

        system_prompt, user_prompt = self.get_prompt(
            question, passage, word_limit_min, word_limit_max
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        completion = self._groq_client.chat.completions.create(
            messages=messages, model=self.model
        )
        return completion.choices[0].message.content

# Exercise the class defined in this cell. The driver previously instantiated
# `EnglishDoubtClearningAssistant`, which left the new class unused here.
assistant = EnglishStudentResponseAssistant()

rich.print(assistant.predict(question=query, total_marks=5))

In [None]:
class EnglishStudentResponseAssistant(weave.Model):
    """Writes a student-style answer whose length depends on the marks at stake.

    The context is fetched with the notebook-level `retreival_engine`, and the
    answer is generated by a Groq chat model.

    Attributes:
        model: Name of the Groq chat model used for completions.
    """

    model: str = "llama3-8b-8192"
    # Groq client; created in `__init__` from the `GROQ_API_KEY` environment variable.
    _groq_client: Optional[Groq] = None

    def __init__(self, model: Optional[str] = None):
        """Create the assistant; ``model=None`` keeps the class default."""
        super().__init__()
        self.model = self.model if model is None else model
        self._groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

    @weave.op()
    def get_prompt(
        self, question: str, context: str, word_limit_min: int, word_limit_max: int
    ) -> Tuple[str, str]:
        """Build the prompts for answering ``question`` from ``context``.

        Returns:
            ``(system_prompt, user_prompt)``; the user prompt embeds the context,
            the word-limit range and the question.
        """
        system_prompt = """
You are a student in a class and your teacher has asked you to answer the following question.
You have to write the answer in the given word limit."""
        user_prompt = f"""
We have provided context information below. 

---
{context}
---

Answer the following question within {word_limit_min}-{word_limit_max} words:

---
{question}
---"""
        return system_prompt, user_prompt

    @weave.op()
    def predict(self, question: str, total_marks: int) -> str:
        """Generate an answer to ``question`` sized for ``total_marks``.

        Word limits: fewer than 3 marks -> 5-50 words, fewer than 5 marks ->
        50-100 words, otherwise 100-200 words.

        Returns:
            The message content of the first completion choice.
        """
        # Only the first retrieved node is used as context.
        top_hit = retreival_engine.retrieve(question)[0]
        passage = top_hit.node.text

        word_limit_min, word_limit_max = (
            (5, 50) if total_marks < 3
            else (50, 100) if total_marks < 5
            else (100, 200)
        )

        system_prompt, user_prompt = self.get_prompt(
            question, passage, word_limit_min, word_limit_max
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        completion = self._groq_client.chat.completions.create(
            messages=messages, model=self.model
        )
        return completion.choices[0].message.content

In [None]:
# Generate a model-written answer to the sample question for 5 marks and keep it
# in `ideal_student_response`; the grading cell passes it in as the student's answer.
assistant = EnglishStudentResponseAssistant()
ideal_student_response = assistant.predict(
    question=query,
    total_marks=5,
)

rich.print(ideal_student_response)

In [None]:
class GradeExtractor(BaseModel):
    # Structured grading result; `EnglishGradingAssistant.grade_answer` passes this
    # class as `response_model` and returns an instance of it.
    # NOTE(review): documented with comments rather than a docstring on purpose --
    # a class docstring presumably ends up in the schema sent to the LLM; confirm
    # before adding one.
    question: str  # the question that was asked
    student_answer: str  # the student's answer being graded
    marks: float  # marks awarded to the student
    total_marks: float  # total marks available for the question
    feedback: str  # constructive feedback on the student's answer


class EnglishGradingAssistant(EnglishStudentResponseAssistant):
    """Grades a student's answer and returns the result as a `GradeExtractor`.

    Extends `EnglishStudentResponseAssistant` with an `instructor`-wrapped Groq
    client that is called with `response_model=GradeExtractor`.

    Attributes:
        model: Name of the Groq chat model used for completions.
    """

    model: str = "llama3-8b-8192"
    _groq_client: Optional[Groq] = None
    # instructor-wrapped Groq client; created in `__init__`.
    _instructor_groq_client: Optional[instructor.Instructor] = None

    def __init__(self, model: Optional[str] = None):
        """Create the assistant; ``model=None`` keeps the class default."""
        # The parent constructor already resolves `model` and builds `_groq_client`,
        # so only the instructor client is set up here.
        super().__init__(model=model)
        self._instructor_groq_client = instructor.from_groq(
            Groq(api_key=os.environ.get("GROQ_API_KEY"))
        )

    @weave.op()
    def get_prompt_for_grading(
        self,
        question: str,
        context: str,
        total_marks: int,
        student_answer: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Build the prompts used to grade an answer.

        Args:
            question: The question that was asked.
            context: Passage the answer is judged against.
            total_marks: Marks available for the question.
            student_answer: Answer to grade; when ``None`` an answer is generated
                with `predict(question, total_marks)` and graded instead.

        Returns:
            ``(user_prompt, system_prompt)`` -- note the order is the reverse of
            what `get_prompt` returns.
        """
        system_prompt = """
You are a helpful assistant to an English teacher meant to grade the answer given by a student to a question.
You have to extract the question , the student's answer, the marks awarded to the student out of total marks,
the total marks and a contructive feedback to the student's answer with regards to how accurate it is with
respect to the context.
        """
        student_answer = (
            self.predict(question, total_marks)
            if student_answer is None
            else student_answer
        )
        user_prompt = f"""
We have provided context information below. 

---
{context}
---

We have asked the following question to the student for total_marks={total_marks}:

---
{question}
---

The student has responded with the following answer:

---
{student_answer}
---"""
        return user_prompt, system_prompt

    @weave.op()
    def grade_answer(
        self, question: str, student_answer: str, total_marks: int
    ) -> GradeExtractor:
        """Grade ``student_answer`` to ``question`` out of ``total_marks``.

        The context is taken from the first node returned by the notebook-level
        `retreival_engine`, the same way `predict` obtains it.

        Returns:
            The `GradeExtractor` produced by the instructor client.
        """
        context = retreival_engine.retrieve(question)[0].node.text
        # Pass by keyword: the positional call used to put `student_answer` into the
        # `context` slot, so the real answer was never graded and a generated one
        # was graded in its place.
        user_prompt, system_prompt = self.get_prompt_for_grading(
            question=question,
            context=context,
            total_marks=total_marks,
            student_answer=student_answer,
        )
        return self._instructor_groq_client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            model=self.model,
            response_model=GradeExtractor,
        )

In [None]:
# Grade the previously generated answer to the sample question out of 5 marks
# and print the structured result.
assistant = EnglishGradingAssistant()

grading_result = assistant.grade_answer(
    question=query,
    student_answer=ideal_student_response,
    total_marks=5,
)
rich.print(grading_result)