In [1]:
from crewai import Agent, Task, Crew, Process
from crewai_tools import Tool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import HuggingFaceDatasetLoader
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.llms.base import LLM
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, List, Optional
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers import pipeline
from peft import PeftModel, PeftConfig
from huggingface_hub import login
import torch
from datasets import load_dataset
import os
import numpy as np
from tqdm import tqdm
import torch.cuda
from torch.amp import autocast
import gc
import faiss 
import pickle
import litellm
from typing import Any, List, Optional, Dict
from pydantic import Field, BaseModel

In [None]:
# Huggingface token
def setup_huggingface(token):
    login(token)
    os.environ["HUGGINGFACE_TOKEN"] = token

In [None]:

def load_fine_tuned_model(
    base_model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    lora_model_path="Kai0123/tinyllama-medical-qa-lora-model"
):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    if torch.cuda.is_available():
        
        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
        print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    
    config = PeftConfig.from_pretrained(lora_model_path)
    
    
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        device_map="auto"
        # load_in_8bit=False,  
    )
    
    
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    
    model = PeftModel.from_pretrained(base_model, lora_model_path).to(device)

    
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
        
    
    return model, tokenizer

In [None]:
class VectorStore:
    def __init__(self, index_path="faiss_index/docs.index", data_path="faiss_index/texts.pkl", 
                 model_name="ncbi/MedCPT-Query-Encoder"):
        
        self.index = faiss.read_index(index_path)
        
        
        with open(data_path, "rb") as f:
            stored_data = pickle.load(f)
            self.texts = stored_data["texts"]
            self.ids = stored_data["ids"]
            
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        
    def compute_embedding(self, query):
        inputs = self.tokenizer([query], padding=True, truncation=True, 
                              max_length=512, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        return embedding
        
    def similarity_search(self, query, k=3):
        
        query_embedding = self.compute_embedding(query)
        
        
        distances, indices = self.index.search(query_embedding, k)
        
        
        results = []
        for i, idx in enumerate(indices[0]):
            results.append({
                'text': self.texts[idx],
                'id': self.ids[idx],
                'score': 1 - distances[0][i]  
            })
        
        return results

In [5]:
class TransformersLLM(LLM, BaseModel):
    base_model_path: str = Field(..., description="Local path to the base model")
    lora_model_path: str = Field(..., description="Local path to the LoRA model")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=512, description="Maximum number of tokens to generate")
    
    model: Any = None
    tokenizer: Any = None
    
    class Config:
        arbitrary_types_allowed = True
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        print("Loading model and tokenizer...")
        try:
            # 确保路径存在
            if not os.path.exists(self.base_model_path):
                raise ValueError(f"Base model path does not exist: {self.base_model_path}")
            if not os.path.exists(self.lora_model_path):
                raise ValueError(f"LoRA model path does not exist: {self.lora_model_path}")
                
            print(f"Loading tokenizer from {self.base_model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_path,
                trust_remote_code=True,
                local_files_only=True
            )
            
            print(f"Loading base model from {self.base_model_path}")
            base_model = AutoModelForCausalLM.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                local_files_only=True
            )
            
            print(f"Loading LoRA model from {self.lora_model_path}")
            config = PeftConfig.from_pretrained(self.lora_model_path)
            
            lora_model = PeftModel.from_pretrained(
                base_model,
                self.lora_model_path,
                config=config,
                is_trainable=False,
                device_map="auto"
            )
            
            print("Merging LoRA weights...")
            self.model = lora_model.merge_and_unload()
            
            print("Model loaded successfully")
            
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
        
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                inputs.input_ids,
                max_length=self.max_tokens,
                temperature=self.temperature,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id
            )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            print(f"Error in model inference: {str(e)}")
            return "Error generating response"
    
    @property
    def _llm_type(self) -> str:
        return "custom"
        
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            "base_model_path": self.base_model_path,
            "lora_model_path": self.lora_model_path,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens
        }

In [None]:

def create_agents(vectorstore):

    from langchain.agents import AgentType, initialize_agent
    from langchain.agents import Tool
    from langchain.memory import ConversationBufferMemory

    
    base_model_path = "./models/TinyLlama-1.1B-Chat-v1.0"  
    lora_model_path = "./models/tinyllama-medical-qa-lora-model" 

    print("Initializing custom LLM...")
    llm = TransformersLLM(
        base_model_path=base_model_path,
        lora_model_path=lora_model_path,
        temperature=0.7,
        max_tokens=512
    )
    
    
    # print("Initializing custom LLM...")
    # custom_llm = TransformersLLM(
    #     base_model_path=base_model_path,
    #     lora_model_path=lora_model_path,
    #     temperature=0.7,
    #     max_tokens=512
    # )
    
    # search_tool = Tool(
    #     name='search_medical_literature',
    #     description='Search through medical literature database',
    #     func=vectorstore.similarity_search
    # )
    
    tools = [
        Tool(
            name='search_medical_literature',
            func=vectorstore.similarity_search,
            description='Search through medical literature database'
        )
    ]

    
    retriever_memory = ConversationBufferMemory(memory_key="chat_history")
    analyzer_memory = ConversationBufferMemory(memory_key="chat_history")
    generator_memory = ConversationBufferMemory(memory_key="chat_history")



    
    # retriever_agent = Agent(
    #     name='Retriever',
    #     role='Expert responsible for retrieving relevant information from the knowledge base',
    #     goal='Find the most relevant medical literature for the query',
    #     backstory='I am an AI assistant specializing in medical literature retrieval',
    #     tools=[search_tool],
    #     llm=custom_llm,
    #     verbose=True
    # )

    
    # analyzer_agent = Agent(
    #     name='Analyzer',
    #     role='Medical information analysis expert',
    #     goal='Analyze the retrieved information and extract key information',
    #     backstory='I specialize in analyzing medical literature and extracting important information',
    #     llm=custom_llm,
    #     verbose=True
    # )

    
    # generator_agent = Agent(
    #     name='Generator',
    #     role='Medical Answer Generation Expert',
    #     goal='Generate accurate answers to medical questions based on analysis results',
    #     backstory='I am responsible for translating the analyzed information into clear answers',
    #     llm=custom_llm,
    #     verbose=True
    # )

    
    retriever_agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        memory=retriever_memory,
        verbose=True,
        agent_kwargs={
            "prefix": "You are an expert responsible for retrieving relevant information from the medical knowledge base.",
            "ai_prefix": "Retriever",
        }
    )

    analyzer_agent = initialize_agent(
        tools=[],  
        llm=llm,
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        memory=analyzer_memory,
        verbose=True,
        agent_kwargs={
            "prefix": "You are a medical information analysis expert.",
            "ai_prefix": "Analyzer",
        }
    )

    generator_agent = initialize_agent(
        tools=[],  
        llm=llm,
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        memory=generator_memory,
        verbose=True,
        agent_kwargs={
            "prefix": "You are an expert in generating accurate medical answers based on analysis results.",
            "ai_prefix": "Generator",
        }
    )


    return retriever_agent, analyzer_agent, generator_agent

In [None]:

def create_tasks(retriever_agent, analyzer_agent, generator_agent, query):
    task1 = Task(
        description=f"Search medical literature related to the question: {query}",
        expected_output="A list of relevant medical literature references related to the query",
        agent=retriever_agent
    )
    
    task2 = Task(
        description="Analyze the retrieved literature and extract key information",
        expected_output="A summary of key findings and important information from the literature",
        agent=analyzer_agent
    )
    
    task3 = Task(
        description="Generate the final answer",
        expected_output="A concise and accurate answer to the medical question based on the analyzed information",
        agent=generator_agent
    )
    
    return [task1, task2, task3]

In [None]:
def main():
    
        
        HUGGINGFACE_TOKEN = "your huggingface tocken"
        setup_huggingface(HUGGINGFACE_TOKEN)
        
        
        print("\nLoading model and vectorstore...")
        # model, tokenizer = load_fine_tuned_model()
        vectorstore = VectorStore()
        
        
        print("\nCreating agents...")
        retriever_agent, analyzer_agent, generator_agent = create_agents(vectorstore)
        
        
        print("\nProcessing query...")
        query = "What are the common treatments for diabetes?"


        #
        retrieval_result = retriever_agent.run(f"Find relevant information about: {query}")
        analysis_result = analyzer_agent.run(f"Analyze this medical information: {retrieval_result}")
        final_answer = generator_agent.run(f"Generate a clear answer based on this analysis: {analysis_result}")
    
        print("\nFinal Answer:", final_answer)
        

if __name__ == "__main__":
    main()



Loading model and vectorstore...

Creating agents...
Initializing custom LLM...
Loading model and tokenizer...
Loading tokenizer from ./models/TinyLlama-1.1B-Chat-v1.0
Loading base model from ./models/TinyLlama-1.1B-Chat-v1.0
Loading LoRA model from ./models/tinyllama-medical-qa-lora-model
Merging LoRA weights...
Model loaded successfully


  retriever_memory = ConversationBufferMemory(memory_key="chat_history")
  retriever_agent = initialize_agent(
  retrieval_result = retriever_agent.run(f"Find relevant information about: {query}")
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Processing query...


[1m> Entering new AgentExecutor chain...[0m


  attn_output = torch.nn.functional.scaled_dot_product_attention(


[32;1m[1;3mYou are an expert responsible for retrieving relevant information from the medical knowledge base.

> search_medical_literature: Search through medical literature database

To use a tool, please use the following format:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [search_medical_literature]
Action Input: the input to the action
Observation: the result of the action
```

When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:

```
Thought: Do I need to use a tool? No
Retriever: [your response here]
```

Begin!

Previous conversation history:


New input: Find relevant information about: What are the common treatments for diabetes?

Previous response: Yes, I need to use a tool.

New action: Search medical literature database.

Observation: The result is not relevant.

Previous response: No, I do not need to use a tool.
```
[0m

[1m> Finished chain.[0m


[1m> Entering new Agen



[32;1m[1;3mYou are a medical information analysis expert.



To use a tool, please use the following format:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of []
Action Input: the input to the action
Observation: the result of the action
```

When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:

```
Thought: Do I need to use a tool? No
Analyzer: [your response here]
```

Begin!

Previous conversation history:


New input: Analyze this medical information: [your response here]
```

Begin!

Previous conversation history:


New input: Find relevant information about: What are the common treatments for diabetes?

Previous response: Yes, I need to use a tool.

New action: Search medical literature database.

Observation: The result is not relevant.

Previous response: No, I do not need to use a tool.
```

End!

Previous conversation history:


New input: Analyze this medical information: [your res



[32;1m[1;3mYou are an expert in generating accurate medical answers based on analysis results.



To use a tool, please use the following format:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of []
Action Input: the input to the action
Observation: the result of the action
```

When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:

```
Thought: Do I need to use a tool? No
Generator: [your response here]
```

Begin!

Previous conversation history:


New input: Generate a clear answer based on this analysis: [your response here]
```

Begin!

Previous conversation history:


New input: Analyze this medical information: [your response here]
```

Begin!

Previous conversation history:


New input: Find relevant information about: What are the common treatments for diabetes?

Previous response: Yes, I need to use a tool.

New action: Search medical literature database.

Observation: The result is n