-
Notifications
You must be signed in to change notification settings - Fork 71
/
safety.py
309 lines (270 loc) · 10.8 KB
/
safety.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import re
import urllib
from time import sleep
import langchain
import molbloom
import pandas as pd
import pkg_resources
import requests
import tiktoken
from langchain import LLMChain, PromptTemplate
from langchain.llms import BaseLLM
from langchain.tools import BaseTool
from chemcrow.utils import is_smiles, pubchem_query2smiles, tanimoto
from .prompts import safety_summary_prompt, summary_each_data
class MoleculeSafety:
    """Fetch, cache and summarise PubChem safety records for molecules.

    On construction the ClinTox CSV is downloaded into ``self.clintox``.
    PubChem responses are memoised per query string in ``self.pubchem_data``.
    """

    CLINTOX_URL = (
        "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz"
    )
    # Seconds to wait for each PubChem HTTP request before giving up; without
    # a timeout a stalled connection would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(
        self,
        llm: "BaseLLM | None" = None,
        max_retries: "int | None" = None,
        retry_delay: float = 5,
    ):
        """Download the ClinTox table and initialise the PubChem cache.

        Args:
            llm: Language model used by ``get_safety_summary``. May be omitted
                when only the PubChem lookups are needed.
            max_retries: Number of retries after the first failed download.
                ``None`` (the default) retries forever, as before.
            retry_delay: Seconds to sleep between download attempts.

        Raises:
            ConnectionRefusedError, urllib.error.URLError: Only when
                ``max_retries`` is set and every attempt failed.
        """
        # Imported explicitly: a bare ``import urllib`` does not guarantee
        # that the ``urllib.error`` submodule has been loaded.
        from urllib.error import URLError

        failures = 0
        while True:
            try:
                self.clintox = pd.read_csv(self.CLINTOX_URL)
                break
            except (ConnectionRefusedError, URLError):
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise
                sleep(retry_delay)
        self.pubchem_data = {}
        self.llm = llm

    def _fetch_pubchem_data(self, cas_number):
        """Return the PubChem PUG-View JSON for ``cas_number``.

        The identifier is first resolved to a CID through the PUG REST
        ``compound/name`` endpoint; the first CID returned is used. Successful
        responses are cached in ``self.pubchem_data``.

        Returns:
            The decoded JSON record, or the string
            ``"Invalid molecule input, no Pubchem entry."`` when the lookup
            fails for any reason. Failures are not cached.
        """
        if cas_number not in self.pubchem_data:
            from urllib.parse import quote

            try:
                # Percent-encode the identifier so characters such as "/" or
                # "#" in a chemical name cannot alter the request path.
                identifier = quote(str(cas_number), safe="")
                cid_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{identifier}/cids/JSON"
                cid_reply = requests.get(cid_url, timeout=self.REQUEST_TIMEOUT)
                cid = cid_reply.json()["IdentifierList"]["CID"][0]
                view_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{cid}/JSON"
                view_reply = requests.get(view_url, timeout=self.REQUEST_TIMEOUT)
                self.pubchem_data[cas_number] = view_reply.json()
            except Exception:
                # Best-effort lookup: network errors, non-JSON bodies and
                # unexpected payload shapes all mean "no usable entry".
                # ``Exception`` (not a bare except) lets Ctrl-C through.
                return "Invalid molecule input, no Pubchem entry."
        return self.pubchem_data[cas_number]

    def ghs_classification(self, text):
        """Gives the ghs classification from Pubchem. Give this tool the name or CAS number of one molecule.

        Returns:
            * a non-empty list with the ``"Extra"`` value of every markup
              entry in the first "Chemical Safety" section that has any;
            * ``None`` when no such section exists or the record does not
              have the expected structure;
            * an error string when ``text`` is a SMILES string or the
              molecule could not be fetched.
        """
        if is_smiles(text):
            return "Please input a valid CAS number."
        data = self._fetch_pubchem_data(text)
        if isinstance(data, str):
            return "Molecule not found in Pubchem."
        try:
            for section in data["Record"]["Section"]:
                if section.get("TOCHeading") != "Chemical Safety":
                    continue
                markup = section["Information"][0]["Value"]["StringWithMarkup"][0][
                    "Markup"
                ]
                ghs = [entry["Extra"] for entry in markup]
                if ghs:
                    return ghs
        except (KeyError, IndexError, TypeError, AttributeError):
            # The record is missing one of the nested keys/lists we index.
            return None
        return None

    @staticmethod
    def _scrape_pubchem(data, heading1, heading2, heading3):
        """Collect third-level sections matching a three-heading path.

        Walks ``data["Record"]["Section"]`` and returns every sub-sub-section
        whose ``TOCHeading`` chain equals ``heading1`` / ``heading2`` /
        ``heading3``.

        Returns:
            A (possibly empty) list of section dicts, or ``None`` when
            ``data`` does not have the expected structure (for example when
            it is an error string, or a matching section has no ``"Section"``
            key).
        """
        try:
            matches = []
            for top in data["Record"]["Section"]:
                if top.get("TOCHeading") != heading1:
                    continue
                for middle in top["Section"]:
                    if middle.get("TOCHeading") != heading2:
                        continue
                    for leaf in middle["Section"]:
                        if leaf.get("TOCHeading") == heading3:
                            matches.append(leaf)
            return matches
        except (KeyError, IndexError, TypeError, AttributeError):
            return None

    def _get_safety_data(self, cas):
        """Gather the fixed set of safety-related PubChem sections for ``cas``.

        Returns:
            A list with one entry per heading looked up (eleven in total).
            Each entry is a one-element list wrapping the result of
            ``_scrape_pubchem`` for that heading; this nesting is kept
            because ``get_safety_summary`` feeds ``str(entry)`` to the LLM.
        """
        data = self._fetch_pubchem_data(cas)
        # (first-level heading, second-level heading, third-level headings)
        lookups = [
            (
                "Safety and Hazards",
                "Hazards Identification",
                [
                    "Health Hazards",
                    "GHS Classification",
                    "Hazards Summary",
                    "NFPA Hazard Classification",
                ],
            ),
            (
                "Safety and Hazards",
                "Safety and Hazard Properties",
                ["Explosive Limits and Potential", "Preventive Measures"],
            ),
            (
                "Safety and Hazards",
                "Exposure Control and Personal Protection",
                [
                    "Inhalation Risk",
                    "Effects of Long Term Exposure",
                    "Personal Protective Equipment (PPE)",
                ],
            ),
            (
                "Toxicity",
                "Toxicological Information",
                ["Toxicity Summary", "Carcinogen Classification"],
            ),
        ]
        safety_data = []
        for heading1, heading2, leaf_headings in lookups:
            for heading3 in leaf_headings:
                safety_data.append(
                    [self._scrape_pubchem(data, heading1, heading2, heading3)]
                )
        return safety_data

    @staticmethod
    def _num_tokens(string, encoding_name="text-davinci-003"):
        """Returns the number of tokens in a text string."""
        encoding = tiktoken.encoding_for_model(encoding_name)
        return len(encoding.encode(string))

    def get_safety_summary(self, cas):
        """Summarise each safety section for ``cas`` with the configured LLM.

        Every entry from ``_get_safety_data`` is stringified, truncated when
        it is too long, and passed to an ``LLMChain`` built from the
        ``summary_each_data`` prompt.

        Returns:
            A list with one LLM output per safety-data entry.
        """
        safety_data = self._get_safety_data(cas)
        per_entry_budget = (3500 * 4) / len(safety_data)
        approx_length = int(per_entry_budget - 0.1 * per_entry_budget)

        prompt_short = PromptTemplate(
            template=summary_each_data, input_variables=["data", "approx_length"]
        )
        llm_chain_short = LLMChain(prompt=prompt_short, llm=self.llm)

        llm_output = []
        for info in safety_data:
            text = str(info)
            # NOTE(review): the threshold is compared against a *token* count
            # but the truncation below cuts *characters*; kept as-is to
            # preserve behaviour -- confirm which unit was intended.
            if self._num_tokens(text) > approx_length:
                text = text[:approx_length]
            llm_output.append(
                llm_chain_short.run({"data": text, "approx_length": approx_length})
            )
        return llm_output
class SafetySummary(BaseTool):
    """Tool that turns PubChem safety data for a CAS number into a summary.

    The per-section summaries come from ``MoleculeSafety.get_safety_summary``
    and are then condensed by an ``LLMChain`` built from
    ``safety_summary_prompt``.
    """

    name = "SafetySummary"
    # The sentences are joined by implicit concatenation, so each fragment
    # must carry its own trailing space.
    description = (
        "Input CAS number, returns a summary of safety information. "
        "The summary includes Operator safety, GHS information, "
        "Environmental risks, and Societal impact."
    )
    # Language model shared by the per-section and the final summary chains.
    llm: BaseLLM = None
    # Chain that produces the final summary; built in ``__init__``.
    llm_chain: LLMChain = None
    # Not read or written by any method of this class.
    pubchem_data: dict = dict()
    # Helper that fetches PubChem records; built in ``__init__``.
    mol_safety: MoleculeSafety = None

    def __init__(self, llm):
        """Create the PubChem helper and the final summarisation chain.

        Args:
            llm: Language model used for every LLM call made by this tool.
        """
        super().__init__()
        self.mol_safety = MoleculeSafety(llm=llm)
        self.llm = llm
        prompt = PromptTemplate(
            template=safety_summary_prompt, input_variables=["data"]
        )
        self.llm_chain = LLMChain(prompt=prompt, llm=self.llm)

    def _run(self, cas: str) -> str:
        """Return a safety summary for ``cas``, or an error message.

        SMILES input is rejected, and a failed PubChem lookup (signalled by
        the helper returning a string) is reported as not found.
        """
        if is_smiles(cas):
            return "Please input a valid CAS number."
        data = self.mol_safety._fetch_pubchem_data(cas)
        if isinstance(data, str):
            return "Molecule not found in Pubchem."
        # The record fetched above is cached by the helper, so this call
        # does not hit the network again.
        data = self.mol_safety.get_safety_summary(cas)
        return self.llm_chain.run(" ".join(data))

    async def _arun(self, cas_number):
        """Async execution is not supported."""
        raise NotImplementedError("Async not implemented.")
class ExplosiveCheck(BaseTool):
    """Tool that reports whether PubChem's GHS data marks a molecule explosive."""

    name = "ExplosiveCheck"
    description = "Input CAS number, returns if molecule is explosive."
    # Helper that fetches the GHS classification; built in ``__init__``.
    mol_safety: MoleculeSafety = None

    def __init__(self):
        super().__init__()
        self.mol_safety = MoleculeSafety()

    def _run(self, cas_number):
        """Checks if a molecule has an explosive GHS classification using pubchem."""
        # first check if the input is a CAS number
        if is_smiles(cas_number):
            return "Please input a valid CAS number."
        ghs_labels = self.mol_safety.ghs_classification(cas_number)
        if ghs_labels is None:
            return (
                "Explosive Check Error. The molecule may not be assigned a GHS rating. "
            )
        if isinstance(ghs_labels, str):
            # ``ghs_classification`` signals lookup failures with a plain
            # string. Surface it instead of scanning it for "explos", which
            # would wrongly report an unknown molecule as non-explosive.
            return ghs_labels
        if "explos" in str(ghs_labels).lower():
            return "Molecule is explosive"
        return "Molecule is not known to be explosive"

    async def _arun(self, cas_number):
        """Async execution is not supported."""
        raise NotImplementedError("Async not implemented.")
class SimilarControlChemCheck(BaseTool):
    """Tool that scores a SMILES string against the controlled-chemical list."""

    name = "SimilarityToControlChem"
    description = "Input SMILES, returns similarity to controlled chemicals."

    def _run(self, smiles: str) -> str:
        """Checks max similarity between compound and controlled chemicals.
        Input SMILES string.

        Returns a message stating whether the highest Tanimoto similarity to
        any entry of ``data/chem_wep_smi.csv`` is above 0.35, a prompt to
        supply valid SMILES, or ``"Tool error."`` if the comparison fails.
        """
        data_path = pkg_resources.resource_filename("chemcrow", "data/chem_wep_smi.csv")
        cw_df = pd.read_csv(data_path)
        try:
            if not is_smiles(smiles):
                return "Please input a valid SMILES string."

            similarities = cw_df["smiles"].apply(
                lambda reference: self.tanimoto(smiles, reference)
            )
            max_sim = similarities.max()
            if max_sim > 0.35:
                return (
                    f"{smiles} has a high similarity "
                    f"({max_sim:.4}) to a known controlled chemical."
                )
            return (
                f"{smiles} has a low similarity "
                f"({max_sim:.4}) to a known controlled chemical. "
                "This substance is safe, you may proceed with the task."
            )
        except Exception:
            # Deliberate best-effort: any failure while scoring is reported
            # as a generic tool error. ``Exception`` rather than a bare
            # except so KeyboardInterrupt/SystemExit still propagate.
            return "Tool error."

    def tanimoto(self, s1, s2):
        """Return the Tanimoto similarity of two SMILES, or 0.0 if unavailable.

        The module-level ``tanimoto`` helper is used; any non-float result
        from it is mapped to ``0.0``.
        """
        sim = tanimoto(s1, s2)
        return sim if isinstance(sim, float) else 0.0

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
class ControlChemCheck(BaseTool):
    """Tool that checks a molecule against the controlled-chemical list.

    An exact match on the SMILES or CAS column is tried first; if there is
    none, the query is converted to SMILES and handed to
    ``SimilarControlChemCheck`` for a similarity check.
    """

    name = "ControlChemCheck"
    description = "Input CAS number, True if molecule is a controlled chemical."
    similar_control_chem_check = SimilarControlChemCheck()

    def _run(self, query: str) -> str:
        """Checks if compound is a controlled chemical. Input CAS number."""
        data_path = pkg_resources.resource_filename("chemcrow", "data/chem_wep_smi.csv")
        cw_df = pd.read_csv(data_path)
        try:
            # The query is always escaped so regex metacharacters in it are
            # matched literally; raw f-strings keep ``\(`` a regex escape
            # rather than an invalid Python string escape.
            escaped = re.escape(query)
            if is_smiles(query):
                column, pattern = "smiles", rf"^{escaped}$"
            else:
                # This pattern only matches CAS cells wrapped in parentheses.
                column, pattern = "cas", rf"^\({escaped}\)$"
            found = cw_df[column].astype(str).str.contains(pattern, regex=True).any()

            if found:
                return (
                    f"The molecule {query} appears in a list of "
                    "controlled chemicals."
                )

            # Get smiles of CAS number
            try:
                smi = pubchem_query2smiles(query)
            except ValueError as e:
                return str(e)
            # Check similarity to known controlled chemicals
            return self.similar_control_chem_check._run(smi)
        except Exception as e:
            return f"Error: {e}"

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()