In [1]:
import sys

In [2]:
sys.path.append('/home/brendenpelkie/Code/LLM_organic_synthesis/')

In [3]:
import json

import ord_schema
from ord_schema import message_helpers, validations
from ord_schema.proto import dataset_pb2

from google.protobuf.json_format import MessageToDict

from ord_diff.io import ModelOutput
from ord_diff.utils import json_load
from ord_diff.evaluation import get_compounds

from google.protobuf import json_format
from ord_schema.proto import reaction_pb2

## Load LLM extractd output and define parsing logic

In [5]:
with open('../models_llama/adapter/finetune_20230731/infer-14.json', 'rt') as f:
    data = json.load(f)

In [10]:
def get_llm_compounds(record, prompt_template, prompt_header, response_header):
    """
    Parse ORD compounds from syntactically correct ORD reaction messages 

    :param record: individual record from LLM output file 
    :type record: dict
    :param prompt template: Full prompt template used
    :type prompt_template: str
    :param prompt_header: prompt template header
    :type prompt_header: str
    :param response_header: prompt respponse header
    :type response_header: str
    :return compounds_from_inputs: list of compounds identified as reactants/inputs
    :rtype compounds_from_inputs: list of ORD compounds, None if parsing error
    :return compounds_from_products: list of compounds identified from reaction products
    :rtype compounds_from_products: list of ORD compounds, None if parsing error
    """
    ref_string = record['output']
    raw = record['response']

    try:
        model_output = ModelOutput.from_raw_alpaca(
            raw=raw,
            ref=ref_string,
            identifier=record['reaction_id'],
            prompt_template=prompt_template,
            prompt_header=prompt_header,
            response_header=response_header,
            instruction=record['instruction'],
        )
    except:
        print(f'Parse error for reaction {record["reaction_id"]}')
        return None, None


    reference_message = json_format.Parse(model_output.ref, reaction_pb2.Reaction())
    # this would error out if inferred string (completion) is syntactically incorrect
    try:
        inferred_message = json_format.Parse(model_output.response, reaction_pb2.Reaction())
    except:
        #print('inferred message error')
        #print(record['reaction_id'])
        return None, None

    compounds_from_inputs = get_compounds(reference_message, extracted_from="inputs")
    compounds_from_outcomes = get_compounds(reference_message, extracted_from="outcomes")

    return compounds_from_inputs, compounds_from_outcomes

In [11]:
def parse_compound_name(compound):
    """
    Parse the string name from a single compound

    :param compound: chemical compound in ORD format
    :type compound: ORD compound
    :return name: name of compound
    :rtype name: str, None if parse error
    """
    comp_dict = MessageToDict(compound)
    try:
        return comp_dict['identifiers'][0]['value']
    except:
        return None

In [12]:
def get_llm_species(reaction, prompt_template, prompt_header, response_header):
    """
    Parse the chemical species extractd by LLM model from LLM output for a reaction

    :param reaction: individual record from LLM output file 
    :type reaction: dict
    :param prompt template: Full prompt template used
    :type prompt_template: str
    :param prompt_header: prompt template header
    :type prompt_header: str
    :param response_header: prompt respponse header
    :type response_header: str
    :return species: list of identified chemicals
    :rtype species: list of strings, 'FAILED' if parse error
    """
    compounds_in, compounds_out = get_llm_compounds(reaction, prompt_template, prompt_header, response_header)

    species = []
    if compounds_in is not None:
        for comp in compounds_in:
            species.append(parse_compound_name(comp))
    if compounds_out is not None:
        for comp in compounds_out:
            species.append(parse_compound_name(comp))

    if compounds_in is None:
        return 'FAILED'

    return species

## Load chem data extractor output and define parsing logic

This was generated sepearatly using script 'chemdataextract.py' b/c I couldn't get an environment set up with both CDE and jupyter. Follow instructions in readme.md to generate this file. 

In [13]:
chemdata = []
with open('chemdataextractor_output.jsonl', 'rt') as f:
    for line in f:
        chemdata.append(json.loads(line))


chemdata_dict = {reaction['reaction_id']:reaction for reaction in chemdata}

In [15]:
def find_CDE_compounds(CDE_output):
    """find compounds identified by chem data extractor
    
    :param CDE_output: direct output from ChemDataExtractor
    :type CDE_output: dict 
    :return identified_species: list of identified chemicals
    :rtype identified_species: list of str
    """
    identified_species = []
    for item in CDE_output:
        if 'names' in item.keys():
            identified_species.append(item['names'][0])
    return identified_species

## Parse identified species from both parsing methods

In [16]:
extracted_compounds = {reaction['reaction_id']:{} for reaction in data}

prompt_template = json_load("../../LLM_organic_synthesis/models_llama/adapter/USPTO-t900/params.json")['prompt_template']
prompt_header = prompt_template.split("\n")[0]
response_header = prompt_template.split("\n")[-2]

fail_count = 0

for reaction in data:

    rid = reaction['reaction_id']

    cde_reaction = chemdata_dict[rid]

    llm_species = get_llm_species(reaction, prompt_template, prompt_header, response_header)
    if llm_species =='FAILED':
        fail_count +=1
        continue
    cde_species = find_CDE_compounds(cde_reaction['chemdataextractor_output'])

    extracted_compounds[rid]['llm_species']= llm_species
    extracted_compounds[rid]['cde_species'] = cde_species



    



In [18]:
print(f'Number of non-parsable LLM outputs: {fail_count}')

Number of non-parsable LLM outputs: 46


## Compare extracted outputs and calculate metrics

In [20]:
def compare_outputs(LLM_species, cde_species):
    """
    Count overlap between LLM and chemdataextractor outputs
    """
    LLM_species = set(LLM_species)
    cde_species = set(cde_species)
    
    overlap_count = len(LLM_species.intersection(cde_species))
    
    llm_extras = LLM_species.difference(cde_species)
    cde_extras = cde_species.difference(LLM_species)
    
    llm_extras_count = len(llm_extras)
    cde_extras_count = len(cde_extras)
    
    return overlap_count, llm_extras_count, cde_extras_count, llm_extras, cde_extras 

In [25]:
total_overlap_count = 0
llm_extra_count = 0
cde_extra_count = 0

llm_total_count = 0
cde_total_count = 0

llm_exclusive_finds = []
cde_exclusive_finds = []

for reaction_id, reaction in extracted_compounds.items():

    if len(reaction) == 0:
        continue
    cde_species = reaction['cde_species']
    LLM_species = reaction['llm_species']
    
    #cde_species = find_CDE_compounds(cde_entry['chemdataextractor_output'])
    #identified_species = []
    #LLM_species = find_ORD_compounds(json.loads(llm_entry['output']), identified_species)
    
    
    overlap_count, llm_count, cde_count, llm_extras, cde_extras = compare_outputs(LLM_species, cde_species)
    
    total_overlap_count += overlap_count
    llm_extra_count += llm_count
    cde_extra_count += cde_count

    llm_total_count += len(LLM_species)
    cde_total_count += len(cde_species)
    
    llm_exclusive_finds.extend(list(llm_extras))
    cde_exclusive_finds.extend(list(cde_extras))
    

    

In [28]:
print(f'Total number of species IDd by both methods: {total_overlap_count}')
print(f'Species IDd exclusively by LLM method: {llm_extra_count}')
print(f'Species IDd exclusively by CDE method: {cde_extra_count}')
print(f'Species IDd in total by LLM: {llm_total_count}')
print(f'Species IDd in total by CDE: {cde_total_count}')


Total number of species IDd by both methods: 5888
Species IDd exclusively by LLM method: 2137
Species IDd exclusively by CDE method: 2917
Species IDd in total by LLM: 8122
Species IDd in total by CDE: 8805


In [31]:
len(set(llm_exclusive_finds))

795

In [32]:
len(set(cde_exclusive_finds))

846

In [33]:
llm_exclusive_finds

['solid',
 'title compound',
 '4-(4-{[(2S)-2,3-Dihydroxypropyl]oxy}phenyl)-2-methoxy-6-sulphanylpyridine-3,5-dicarbonitrile',
 'N-(4-Methoxybenzylidene)-4-carboxybenzenesulfonamide',
 'crude product',
 'ethylmagnesium bromide diethyl ether',
 'oil',
 'title compound',
 'desired material',
 'DCM',
 'title compound',
 'acetic acid',
 'title compound',
 'methylphenyl phosphinic acid, phenethyl ester',
 'white solid',
 'Pd/C',
 'title compound',
 'Ethylene Urea Ethylene Glycol Glyoxal',
 'crude product',
 'DCM',
 'title compound',
 'title compound',
 '4-(3-hydroxy-3-(5,6,7,8-tetrahydro-5,5,8,8-tetramethyl-2-naphthyl)-1-propynyl)phenylsulfinylmethane',
 'product',
 'Pd/C',
 'title compound',
 'title compound',
 '4-hydroxy-3-methoxy-2-nitrobenzonitrile',
 'palladium on charcoal',
 'title compound',
 'orange coloured oil',
 '1,4-Dihydro-2,6-dimethyl-3-methoxycarbonyl-4-(2-nitrophenyl)-5-(6-aminohexyloxy)carbonyl-pyridine',
 'product',
 'powder',
 'HCl MeOH',
 'palladium on carbon',
 'title co

In [34]:
cde_exclusive_finds

['SULF-8',
 'Nujol',
 'DMSO-d6',
 'nitrogen',
 '1H',
 'bromide-diethyl ether',
 'ethylmagnesium',
 '2H',
 '1H',
 'glacial acetic acid',
 '1H',
 'phenethyl ester',
 'dichloromethane',
 'Pd / C',
 '2H, ( α -- CH2 )',
 'CH2',
 'Celite',
 'COOH',
 'hydrogen',
 '1H',
 'd6-DMSO',
 '1H',
 'C26H32FN4O4S',
 'silica gel',
 'CDCl3',
 '1H',
 'C17H21FN6O',
 'DMSO-d6',
 'C27H26ClF3N4O',
 '2H',
 'CD3OD',
 '1H',
 'C13H20SO2',
 'H',
 'Pd',
 '1H',
 'C21H21ClNO3',
 'palladium',
 'Kieselguhr',
 'AP+',
 'charcoal',
 'C11H22N2O3',
 'CCl4',
 'C9H10F8O4C',
 'H',
 '5-bromo-2-formylfuran',
 'hexylboronic acid',
 'EtOAc',
 '1H',
 'Hex',
 "5'-OH",
 'N',
 'DMSO-d6',
 'CH2',
 'NH',
 'OCH2',
 "3'-OH",
 'C19H29N5O6',
 'CH3',
 'H',
 'OCH3',
 'argon',
 'H2',
 'HCl',
 'D2O',
 '1H',
 'carbon',
 'palladium',
 'CH2Cl2',
 'EtOAc',
 'MeOH',
 'Pd / C',
 '1H',
 '1 N hydrochloric acid',
 'DMSO-d6',
 'H',
 '1H',
 'DMSO-d6',
 'DMSO-d6',
 '2H',
 '1H',
 '— NH',
 'methanol',
 'hydrogen chloride',
 '3 / 2oxalate',
 '2H',
 'CDCl3',
 '