## Test llama with new retriever

In [1]:
import np
import pickle
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from sentence_transformers.cross_encoder import CrossEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
MODEL_NAME = dict({"OPENAI":'gpt-4o'})


In [16]:
next(iter(MODEL_NAME))

'OPENAI'

In [2]:
import ollama

class code(BaseModel):
    """
    Schema for code solutions for questions about tskit. 
    """
    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block should contain function that can be called. It should have input file_path to tree_sequence. It should not include import statements")

def rerank_documents(query: str,reranker, documents: list, top_k: int = 3) -> list:
    """Re-rank documents using cross-encoder"""
    pairs = [(query, doc.page_content) for doc in documents]
    
    scores = reranker.predict(pairs)
    
    ranked_indices = np.argsort(scores)[::-1]  
    ranked_docs = [documents[i] for i in ranked_indices]
    
    return ranked_docs[:top_k]

import re
def response_parser(text):
    """
    """
    result = {
        'prefix': '',
        'imports': '',
        'code': ''
    }
    sections = re.split(r'\*\*([^\*]+)\*\*', text)

    for i in range(1, len(sections), 2):
        header = sections[i].strip()
        content = sections[i+1].strip()
        if header == 'Prefix':
            result['prefix'] = content
        
        elif header == 'Imports':
            code_match = re.search(r'```python(.*?)```', content, re.DOTALL)
            if code_match:
                result['imports'] = code_match.group(1).strip()

        elif header == 'Code':
            code_match = re.search(r'```python(.*?)```', content, re.DOTALL)
            if code_match:
                result['code'] = code_match.group(1).strip()

    return result



def generatorTool(question, context, input_file_path=None, model_nama='llama3.2', ollama_client='https://uwx72r685xxxb8-11434.proxy.runpod.net/'):
    try:

            # Set up template
        template = """You are a Python coding generator with expertise in using tskit toolkit for analysing tree-sequences. \n 
        Here is a relevant set of tskit documentation:  \n ------- \n  {context} \n ------- \n Use the tskit module to answer the user 
        question based on the above provided documentation. Ensure any code you provide should be a callable function and can be executed \n 
        with all required imports and variables defined. Structure your answer with a description of the code solution. \n
        Do not give example usage, simply create a function that is callable with a tree file as an input. \n
        Then list the imports. And finally list the functioning code block. The function should return a string providing the answer. Maintain this order which is: \n
        1. Prefix (code description and helpful information about the tree sequence)\n
        2. Imports (required code imports like tskit, to run the code in Python, write them as import statements)\n
        3. Code (code block which is a callable function with a tree sequence file as an input parameter, does not include import statements)\n
        if the question is irrelevant to code-generation. respond appropriately
        Here is the user question: {question}"""
    
        # lm = ChatOllama(model=MODEL_NAME, temperature=0)
        
        # structured_code_llm = lm.with_structured_output(code, include_raw=True)


        code_gen_prompt = ChatPromptTemplate.from_template(template)
        filled_prompt = code_gen_prompt.format(context=context, question=question)

        client = ollama.Client(host=ollama_client)

        response = client.chat(
            messages=[
                {
                    'role': 'user',
                    'content': filled_prompt,
                }
            ],
            model=model_nama,
            # format=code.model_json_schema(),
            options={'temperature': 0}
            )
        return response
        
    except Exception as e:
        print("Tools Error:", e)
        return f"Found Error while processing your query", None

In [3]:
## retriever setup

embeddings = OllamaEmbeddings(model="nomic-embed-text")
vector_store = FAISS.load_local(folder_path="./code-chunker/faiss-vector", embeddings=embeddings, index_name="faiss_index", allow_dangerous_deserialization=True)

## Load documents for BM25Retriever
with open("./code-chunker/documents.pkl", 'rb') as file:
    all_documents = pickle.load(file)

bm25_retriever = BM25Retriever.from_documents(documents=all_documents, k=10, search_kwargs={"k": 10})

faiss_retriever = vector_store.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.5, "k": 10}
)

ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever],
    weights=[0.5, 0.5]  # Adjust based on your use case
)


In [6]:
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") 
query="how many sites have 1 mutations"
context = ensemble_retriever.invoke(query)
final_context = rerank_documents(query, reranker, context, 3)
res = generatorTool(query, "\n".join(cont.page_content.strip() for cont in final_context))
code_gen = response_parser(res.message.content)
print(code_gen)


Tools Error:  (status code: 404)


AttributeError: 'tuple' object has no attribute 'message'

# Ignore

In [1]:
import marko
from bs4 import BeautifulSoup
import requests
from langchain_core.documents import Document
import pickle


input_url = "https://tskit.dev/tskit/docs/stable/python-api.html"

response = requests.get(input_url)
html = marko.convert(response.text)
soup = BeautifulSoup(html, "html.parser")
# sections = soup.find_all(["h1", 'h2','h3'])  # Split by headings
article = soup.find_all('article')
documents = []

if article:
    for section in article[0].find_all('section', recursive=False):  # Top-level sections
        section_id = section.get('id')
        header = section.find('h1') or section.find('h2')
        paragraph = section.find('p')
        content = ""
        if header:
            content += header.get_text(strip=True) + " "
        if paragraph:
            content += paragraph.get_text(strip=True)
        print("Section", header)
        # file.write(f"section: {header}\n")

        if content:
            document = Document(
                 page_content=content,
                 metadata={"title": header, 'type':'text'}
                 )
            documents.append(document)

        for subsection in section.find_all('section', recursive=True):

            subheader = subsection.find('h2') or subsection.find('h3') or subsection.find('h4') or subsection.find('h5')

            if subheader:
                    title = subheader.get_text(strip=True)
                    subsection_id = subsection.get('id')
                    subparagraph = subsection.find('p')
                    print("subheader", title)
                    sub_content = ""
                    if subparagraph:
                        sub_content += subparagraph.get_text(strip=False) + " "
                        document = Document(
                            page_content=sub_content,
                            metadata={"title": title, 'type':'text'}
                            )
                        documents.append(document)
                        
            if subsection.find_all('section'):
                continue

            tables = subsection.find_all('table')
            dls = subsection.find_all('dl')
            sub_content = ""
            if dls and len(tables)==0:
                for dl in dls:
                    dl_title = dl.find('dt').get_text()
                    dl_text = dl.get_text(strip=False)
                    document = Document(
                        page_content=dl_text,
                        metadata={"title": dl_title, 'type':'code'}
                        )
                    documents.append(document)
            if len(tables)>0:
                for table in tables:
                    rows = table.find('tbody').find_all('tr')
                    for row in rows:
                        table_paragraph = ""
                        cells = row.find_all('td')
                        if len(cells) == 2:
                            property_name = cells[0].get_text(strip=True)
                            description = cells[1].get_text(strip=True)
                            table_paragraph += f"{property_name}: {description}. "

                            document = Document(
                                page_content=table_paragraph,
                                metadata={"title": property_name, 'type':'code'}
                                )
                            documents.append(document)

with open("/storage2/pratik/git/code-chunker/documents.pkl", 'rb') as file:
    all_documents = pickle.load(file)
all_documents.extend(documents)


Section <h1>Python API<a class="headerlink" href="#python-api" title="Link to this heading">#</a></h1>
subheader Trees and tree sequences#
subheader TreeSequenceAPI#
subheader General properties#
subheader Efficient table column access#
subheader Loading and saving#
subheader Obtaining trees#
subheader Obtaining other objects#
subheader Tree topology#
subheader Genetic variation#
subheader Demography#
subheader Other#
subheader Tree sequence modification#
subheader Identity by descent#
subheader Tables#
subheader Statistics#
subheader Topological analysis#
subheader Display#
subheader Export#
subheader TreeAPI#
subheader General properties#
subheader Creating new trees#
subheader Node measures#
subheader Simple measures#
subheader Array access#
subheader Tree traversal#
subheader Topological analysis#
subheader Comparing trees#
subheader Balance/imbalance indices#
subheader Sites and mutations#
subheader Moving to other trees#
subheader Display#
subheader Export#
subheader Tables and T

In [2]:
from langchain_community.retrievers import BM25Retriever
bm25_retriever = BM25Retriever.from_documents(documents=all_documents, k=10, search_kwargs={"k": 10})

In [3]:
# Create hybrid retriever
from langchain.retrievers import EnsembleRetriever

from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="nomic-embed-text")

vector_store = FAISS.load_local(folder_path="/storage2/pratik/git/code-chunker/faiss-vector", embeddings=embeddings, index_name="faiss_index", allow_dangerous_deserialization=True)


faiss_retriever = vector_store.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.5, "k": 10}
)

ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever],
    weights=[0.5, 0.5]  # Adjust based on your use case
)

In [4]:
from langchain_core.prompts import ChatPromptTemplate

In [5]:
from ollama import chat, generate
from pydantic import BaseModel, Field
from lorax.faiss_vector import rerank_documents

  from .autonotebook import tqdm as notebook_tqdm
  memory = ConversationBufferMemory(return_messages=True)


In [34]:
import ollama

class code(BaseModel):
    """
    Schema for code solutions for questions about tskit. 
    """
    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block should contain function that can be called. It should have input file_path to tree_sequence. It should not include import statements")



In [35]:
query="how many sites have 1 mutations"

context = ensemble_retriever.invoke(query)

In [36]:
final_context = rerank_documents(query, context, 3)

In [37]:
final_context

[Document(id='e4573bd8-2a28-4c08-8fec-e8cd928fa50d', metadata={'title': 'Plotting mutations'}, page_content='<p>Note that, unusually, the rightmost site on the axis has more than one stacked chevron,\nindicating that multiple mutations in the tree occur at the same site. These could be\nmutations to different allelic states, or recurrent/back mutations. In this case the\nmutations, 14 and 15 (above nodes 1 and 6) are recurrent mutations from T to G.</p><pre><code class="language-{code-cell}">:"tags": ["hide-input"]\nts_mutated = tskit.load("data/viz_ts_small_mutated.trees")\nsite_descr = str(next(ts_mutated.at_index(2).sites()))\nprint(site_descr.replace("[", "[\\n  ").replace("),", "),\\n ").replace("],", "],\\n"))\n</code>'),
 Document(metadata={'title': 'num_mutations'}, page_content='    @property\n    def num_mutations(self):\n        """\n        Returns the total number of mutations across all sites on this tree.\n\n        :return: The total number of mutations over all sites o

In [38]:

res = generatorTool(query, "\n".join(cont.page_content.strip() for cont in final_context))

In [54]:
import re
def response_parser(text):
    """
    """
    result = {
        'prefix': '',
        'imports': '',
        'code': ''
    }
    sections = re.split(r'\*\*([^\*]+)\*\*', text)

    for i in range(1, len(sections), 2):
        header = sections[i].strip()
        content = sections[i+1].strip()
        if header == 'Prefix':
            result['prefix'] = content
        
        elif header == 'Imports':
            code_match = re.search(r'```python(.*?)```', content, re.DOTALL)
            if code_match:
                result['imports'] = code_match.group(1).strip()

        elif header == 'Code':
            code_match = re.search(r'```python(.*?)```', content, re.DOTALL)
            if code_match:
                result['code'] = code_match.group(1).strip()

    return result

In [56]:
code_gen = response_parser(res.message.content)

In [58]:
print(code_gen['code'])

def count_sites_with_one_mutation(tree_file):
    ts_mutated = tskit.load(tree_file)
    sites_with_one_mutation = sum(1 for site in ts_mutated.sites() if len(site.mutations) == 1)
    return f"There are {sites_with_one_mutation} sites with exactly one mutation."


In [59]:
import tskit
def count_sites_with_one_mutation(tree_file):
    ts_mutated = tskit.load(tree_file)
    sites_with_one_mutation = sum(1 for site in ts_mutated.sites() if len(site.mutations) == 1)
    return f"There are {sites_with_one_mutation} sites with exactly one mutation."

In [60]:
count_sites_with_one_mutation('data/sample.trees')

'There are 503556 sites with exactly one mutation.'

## Rerank

In [None]:
from sentence_transformers.cross_encoder import CrossEncoder
# model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
# scores = model.predict([["My first", "sentence pair"], ["Second text", "pair"]])

In [141]:
import numpy as np

In [None]:
rerank_documents(query, context, 5)

In [None]:
from lorax.tools import generatorTool

In [None]:
res = generatorTool("how many intervals does all the trees cover in the given tree-sequence")

In [None]:
print(res.response)

In [None]:
import json
print(json.loads(res.response)['code'].strip())

In [None]:
json.loads(res.response)

In [9]:

tree_sequence = tskit.load("./data/sample.trees")

In [None]:
tree_sequence.num_sites

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-multi")

text = "def hello_world():"
input_ids = tokenizer(text, return_tensors="pt").input_ids

generated_ids = model.generate(input_ids, max_length=128)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
