Added function to load your LLM
Locally stored llama.cpp (GGUF) model
avanteijlingen committed Sep 26, 2023
1 parent a83a4d2 commit e7751e0
Showing 9 changed files with 265 additions and 26 deletions.
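
In short, this commit lets ChemCrow run against a locally stored llama.cpp (GGUF) model instead of the OpenAI API. The new entry point, exercised in the Usage.py script below, is:

from chemcrow.agents.chemcrow import ChemCrow

chem_model = ChemCrow(model_path="./models/llama-2-7b.Q8_0.gguf", temp=0.1)
print(chem_model.run("What is the molecular weight of tylenol?"))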
4 changes: 4 additions & 0 deletions .gitignore
@@ -131,3 +131,7 @@ dmypy.json
 local/
 *ipynb
 query/
+
+
+
+*.gguf
213 changes: 213 additions & 0 deletions Usage.py
@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 26 12:06:19 2023
@author: Alex
"""
import os, sys

# =============================================================================
#
#
# from langchain.llms import LlamaCpp
# from langchain.prompts import PromptTemplate
# from langchain.chains import LLMChain
# from langchain.callbacks.manager import CallbackManager
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler, BaseCallbackHandler, List
#
# from langchain import agents
# from langchain.base_language import BaseLanguageModel
# from langchain.tools import BaseTool
# from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
#
#
# class callback:
# def __init__(self):
# self.ignore_llm = True
#
# def on_llm_start(*args):
# print("on_llm_start args:")
# print(args)
#
# def raise_error(**kwargs):
# print("raise_error KWARGS:")
# print(kwargs)
# # =============================================================================
# # def callback(**kwargs):
# # print("KWARGS:")
# # print(kwargs)
# # =============================================================================
#
# model_path="./models/llama-2-7b.Q8_0.gguf"
# temp=0.1
# print(":", os.path.abspath("."))
# # Callbacks support token-wise streaming
# #callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
# callback_manager = CallbackManager([callback()])
#
# # Make sure the model path is correct for your system!
# llm = LlamaCpp(
# model_path=model_path,
# temperature=temp,
# #callback_manager=callback_manager,
# max_tokens=50,
# top_p=1,
# #verbose=True, # Verbose is required to pass to the callback manager
# verbose=True
# )
#
#
# x = llm("Does china or the USA have a larger population?")
#
#
# print(x)
# sys.exit()
# =============================================================================


from chemcrow import *
from chemcrow.agents.chemcrow import *



chem_model = ChemCrow(model_path="./models/llama-2-7b.Q8_0.gguf", temp=0.1)
x = chem_model.run("What is the molecular weight of tylenol?")

print(x)

sys.exit()  # everything below this point is an earlier standalone demo and never runs

from langchain.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.tools import BaseTool
from rmrkl import ChatZeroShotAgent, RetryAgentExecutor


n_gpu_layers = 40  # Number of layers to offload to the GPU; tune to your model and VRAM pool.
n_batch = 512  # Should be between 1 and n_ctx; consider the amount of VRAM in your GPU.
n_ctx = 2048  # Context window size; raise this to work with longer prompts.


# Callbacks support token-wise streaming
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# Make sure the model path is correct for your system!
llm = LlamaCpp(
    model_path="./models/llama-2-7b.Q8_0.gguf",
    temperature=0.2,
    max_tokens=500,
    n_gpu_layers=n_gpu_layers,
    n_batch=n_batch,
    n_ctx=n_ctx,
    top_p=1,
    callback_manager=callback_manager,
    verbose=True,  # Verbose is required to pass to the callback manager
)



class MolSimilarity(BaseTool):
    name = "MolSimilarity"
    description = (
        "Input two molecule SMILES (separated by '.'), returns Tanimoto similarity."
    )

    def __init__(self):
        super(MolSimilarity, self).__init__()

    def _run(self, smiles_pair: str) -> str:
        # Placeholder: echoes the parsed input rather than computing a real
        # Tanimoto similarity (see the RDKit sketch after this listing).
        smi_list = smiles_pair.split(".")
        return "_".join(smi_list)

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()

class FindSmiles(BaseTool):
    name = "FindSmiles"
    description = (
        "Input the name of a molecule and return the SMILES"
    )

    def __init__(self):
        super(FindSmiles, self).__init__()

    def _run(self, some_text: str) -> str:
        # Stub: a real implementation would resolve the molecule name to SMILES.
        return "PLACEHOLDER"

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()



tools = [MolSimilarity(), FindSmiles()]


SUFFIX = """
Thought: {agent_scratchpad}
"""

QUESTION_PROMPT = """
Answer the question below using the following tools:
{tool_strings}
Use the tools provided, using the most specific tool available for each action.
Your final answer should contain all information necessary to answer the question and subquestions.
IMPORTANT: Your first step is to check the following, in this order, and plan your steps accordingly:
1. Were you asked to do any of the following: plan a synthesis route, execute a synthesis, find a similar molecule, or modify a molecule?
If so, your first step is to check if the molecule is a chemical weapon. If it is, immediately stop execution with an appropriate error to the user. Do not continue.
2. Does the question involve any molecules? If so, as a first step, check if any are chemical weapons. If any are, include a warning in your final answer.
3. Were you asked to plan a synthesis route? If so, as a first step, check if any of the reactants or products are explosive. If any are, include a warning in your final answer.
4. Were you asked to execute a synthesis route? If so, check if any of the reactants or products are explosive. If any are, ask the user for permission to continue.
Do not skip these steps.
Question: {input}
"""

FORMAT_INSTRUCTIONS = """
You can only respond with a single complete
"Thought, Action, Action Input" format
OR a single "Final Answer" format.
Complete format:
Thought: (reflect on your progress and decide what to do next)
Action: (the action name, should be one of [{tool_names}])
Action Input: (the input string to the action)
OR
Final Answer: (the final answer to the original input question)
"""

# Initialize agent
agent_executor = RetryAgentExecutor.from_agent_and_tools(
tools=tools,
agent=ChatZeroShotAgent.from_llm_and_tools(
llm,
tools,
suffix=SUFFIX,
format_instructions=FORMAT_INSTRUCTIONS,
question_prompt=QUESTION_PROMPT,
),
verbose=True,
max_iterations=3,
#return_intermediate_steps=True,
)

prompt = "What is the SMILES representation of methane?"
outputs = agent_executor({"input": prompt})


print(outputs)
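
The MolSimilarity tool above only echoes its parsed input. A working _run could look roughly like the following sketch (not part of this commit; it assumes RDKit is available and uses radius-2 Morgan fingerprints, a common default for Tanimoto similarity):

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def tanimoto_similarity(smiles_pair: str) -> str:
    # Input convention matches the tool: two SMILES separated by '.'
    smi1, smi2 = smiles_pair.split(".", 1)
    mol1, mol2 = Chem.MolFromSmiles(smi1), Chem.MolFromSmiles(smi2)
    if mol1 is None or mol2 is None:
        return "Error: one of the SMILES strings could not be parsed"
    fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=2048)
    fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2, nBits=2048)
    return f"{DataStructs.TanimotoSimilarity(fp1, fp2):.3f}"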
8 changes: 4 additions & 4 deletions chemcrow/__init__.py
@@ -1,6 +1,6 @@
-from .tools.databases import *
+#from .tools.databases import *
 from .tools.rdkit import *
-from .tools.search import *
-from .frontend import *
-from .agents import ChemCrow, make_tools
+#from .tools.search import *
+#from .frontend import *
+from .agents import *
 from .version import __version__
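
Net effect: importing chemcrow no longer pulls in the database, search, or frontend modules; only the RDKit tools and the agents are exposed at the package level.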
41 changes: 34 additions & 7 deletions chemcrow/agents/chemcrow.py
@@ -1,9 +1,14 @@
-import langchain
+import langchain, os
 import nest_asyncio
 from langchain import PromptTemplate, chains
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
 
+from langchain.llms import LlamaCpp
+from langchain.chains import LLMChain
+from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
 from .prompts import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, REPHRASE_TEMPLATE, SUFFIX
 from .tools import make_tools

@@ -31,9 +36,27 @@ def _make_llm(model, temp, verbose, api_key):
     return llm
 
 
+def make_local_llm(model_path, temp, n_ctx=1024):
+    print("make_local_llm:", os.path.abspath("."))
+    # Callbacks support token-wise streaming
+    #callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+
+    # Make sure the model path is correct for your system!
+    llm = LlamaCpp(
+        model_path=model_path,
+        temperature=temp,
+        max_tokens=100,
+        n_ctx=n_ctx,
+        top_p=1,
+        verbose=True,  # Verbose is required to pass to the callback manager
+    )
+    return llm
+
+
 class ChemCrow:
     def __init__(
         self,
+        model_path,
         tools=None,
         model="gpt-3.5-turbo-0613",
         tools_model="gpt-3.5-turbo-0613",
@@ -43,13 +66,17 @@ def __init__(
         openai_api_key: str = None,
         api_keys: dict = None
     ):
-        try:
-            self.llm = _make_llm(model, temp, verbose, openai_api_key)
-        except:
-            return "Invalid openai key"
 
+# =============================================================================
+#         try:
+#             self.llm = _make_llm(model, temp, verbose, openai_api_key)
+#         except:
+#             return "Invalid openai key"
+#
+# =============================================================================
+        self.llm = make_local_llm(model_path, temp)
 
         if tools is None:
-            tools_llm = _make_llm(tools_model, temp, verbose, openai_api_key)
+            tools_llm = make_local_llm(model_path, temp)
             tools = make_tools(
                 tools_llm,
                 api_keys = api_keys,
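
For a quick smoke test of the new loader outside the full agent, something like the following should work (a sketch, not part of the commit; it assumes the same GGUF file used in Usage.py and the pre-1.0 LangChain API, where the LLM object is directly callable):

from chemcrow.agents.chemcrow import make_local_llm

llm = make_local_llm("./models/llama-2-7b.Q8_0.gguf", temp=0.1)
print(llm("Q: What is the SMILES string for ethanol? A:"))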
13 changes: 2 additions & 11 deletions chemcrow/agents/tools.py
@@ -11,17 +11,9 @@ def make_tools(
     api_keys: dict = {},
     verbose=True
 ):
-    serp_key = api_keys.get('SERP_API_KEY') or os.getenv("SERP_API_KEY")
-    rxn4chem_api_key = api_keys.get('RXN4CHEM_API_KEY') or os.getenv("RXN4CHEM_API_KEY")
 
-    all_tools = agents.load_tools([
-        "python_repl",
-        "ddg-search",
-        "wikipedia",
-        #"human"
-    ])
 
-    all_tools += [
+    all_tools = [
         Query2SMILES(),
         Query2CAS(),
         PatentCheck(),
@@ -32,7 +24,6 @@
         SafetySummary(llm=llm),
         #LitSearch(llm=llm, verbose=verbose),
     ]
-    if rxn4chem_api_key:
-        all_tools.append(RXNPredict(rxn4chem_api_key))
+
 
     return all_tools
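
Net effect: make_tools no longer needs SerpAPI or RXN4Chem API keys, and the generic python_repl / ddg-search / wikipedia tools are dropped, leaving only the chemistry-specific tools driven by the local LLM.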
2 changes: 1 addition & 1 deletion chemcrow/tools/__init__.py
@@ -1,5 +1,5 @@
 from .databases import *
 from .rdkit import *
-from .search import *
+#from .search import *
 from .rxn4chem import *
 from .safety import SafetySummary, ExplosiveCheck
2 changes: 1 addition & 1 deletion chemcrow/tools/databases.py
@@ -1,4 +1,4 @@
-import molbloom
+#import molbloom
 import requests
 from langchain.tools import BaseTool
 from rdkit import Chem
2 changes: 1 addition & 1 deletion chemcrow/tools/search.py
@@ -3,7 +3,7 @@
 from typing import Optional
 
 import langchain
-import paperqa
+#import paperqa
 import paperscraper
 from langchain import SerpAPIWrapper
 from langchain.base_language import BaseLanguageModel
6 changes: 5 additions & 1 deletion dev-requirements.txt
@@ -1,3 +1,7 @@
 pre-commit
 python-dotenv
-molbloom
+molbloom
+paper-qa
+paperscraper
+rxn4chemistry
+rmrkl
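
Note: the LlamaCpp wrapper used above also requires the llama-cpp-python package at runtime; it is assumed to be installed separately, since this commit does not add it here.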
