# In [None]:
import pandas as pd
import requests
import time
import openpyxl
import urllib
import selfies as sf
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import re
import json
import openai
import os
import tiktoken
import numpy as np
from collections import defaultdict


def auto_fill(file_path, sheet_name):
    """
    Forward-fill the compound identifier columns of an Excel sheet in place.

    Blank cells in the identifier columns are filled with the closest
    non-blank value above them, and the sheet is written back to the same
    workbook.

    Parameters:
    - file_path: Path to the Excel file.
    - sheet_name: Name of the sheet to process.

    Returns:
    - Updated DataFrame with filled values.
    """
    # Load the data from the specified sheet
    data = pd.read_excel(file_path, sheet_name=sheet_name)

    # Columns to fill
    columns_to_fill = [
        "Common names", "CAS number", "SMILES code (SciFinder)",
        "Canonical SMILES (PubChem)", "Input IUPAC name (cactus)", "Input SMILES (cactus)"
    ]

    # Forward fill the specified columns
    data[columns_to_fill] = data[columns_to_fill].ffill()

    # The sheet was just read from this workbook, so it already exists. In
    # append mode pandas refuses to write to an existing sheet unless told to
    # replace it, so without if_sheet_exists the save step raises instead of
    # updating the sheet.
    with pd.ExcelWriter(
        file_path, engine='openpyxl', mode='a', if_sheet_exists='replace'
    ) as writer:
        data.to_excel(writer, sheet_name=sheet_name, index=False)

    return data


def fetch_pubchem_data(file_path):
    """
    Query PubChem for the IUPAC name and canonical SMILES of each output SMILES.

    Reads the first sheet of the workbook, looks up every value of the
    "Output SMILES" column through the PubChem PUG REST API, and writes the
    table plus two result columns ("Output IUPAC name" and
    "Output Canonical SMILES") to a sheet named "pubchem output" in the same
    workbook.

    Cell values written to the result columns:
    - "Invalid": the input SMILES was the literal string "Invalid".
    - "Unknown": PubChem answered but did not include that property.
    - "Error": the lookup failed (request or parsing error, or blank input).

    Failed lookups are retried up to two more times, with pauses, as long as
    each retry round reduces the number of failures.

    Parameters:
    - file_path: Path to the Excel file.
    """
    # Read the Excel file
    df = pd.read_excel(file_path, engine='openpyxl')

    # Base URL for PubChem API
    base_url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{}/property/IUPACName,CanonicalSMILES/JSON"
    # Upper bound (seconds) for a single request so one stalled connection
    # cannot hang the whole run.
    request_timeout = 60

    # Result lists, one slot per row of the sheet
    iupac_names = [""] * len(df)
    canonical_smiles_list = [""] * len(df)
    # Rows without a usable SMILES string; retrying them cannot help.
    blank_rows = set()

    def store(index, iupac_name, canonical_smiles):
        iupac_names[index] = iupac_name
        canonical_smiles_list[index] = canonical_smiles

    def make_api_call(smiles_to_use):
        response = requests.get(base_url.format(smiles_to_use), timeout=request_timeout)
        properties = response.json()["PropertyTable"]["Properties"][0]
        return (
            properties.get("IUPACName", "Unknown"),
            properties.get("CanonicalSMILES", "Unknown"),
        )

    def fetch_data_for_smiles(index, smiles):
        # A blank cell is read as NaN (a float). It must not reach the string
        # handling below, where `"#" in smiles` would raise a TypeError.
        if not isinstance(smiles, str) or not smiles.strip():
            print(f"Row {index + 1} has no SMILES string to look up: {smiles!r}")
            blank_rows.add(index)
            store(index, "Error", "Error")
            return

        # Directly assign "Invalid" if the input is "Invalid"
        if smiles == "Invalid":
            store(index, "Invalid", "Invalid")
            return

        # First attempt with the SMILES as given. Any failure (network, JSON
        # decoding, unexpected payload) is treated the same way on purpose.
        try:
            store(index, *make_api_call(smiles))
            return
        except Exception as e:
            first_error = e

        if "#" not in smiles:
            print(f"Error processing row {index + 1} with SMILES: {smiles}. Error: {first_error}")
            store(index, "Error", "Error")
            return

        # "#" (triple bond) starts a URL fragment, so retry with it escaped.
        corrected_smiles = smiles.replace("#", "%23")
        print(f"Retrying for row {index + 1} with corrected SMILES: {corrected_smiles}")
        try:
            store(index, *make_api_call(corrected_smiles))
        except Exception as e2:
            print(f"Error processing row {index + 1} with corrected SMILES: {corrected_smiles}. Error: {e2}")
            store(index, "Error", "Error")

    def retryable_errors():
        return [
            i for i, name in enumerate(iupac_names)
            if name == "Error" and i not in blank_rows
        ]

    # Process each row in the "Output SMILES" column for the first time
    for index, smiles in enumerate(df["Output SMILES"]):
        print(f"Processing row {index + 1} with SMILES: {smiles}")  # Print progress
        fetch_data_for_smiles(index, smiles)

    # Retry for errors
    for _ in range(2):  # Two more attempts
        error_indices = retryable_errors()
        if not error_indices:
            break  # No errors, break out of the loop

        prev_error_count = len(error_indices)
        time.sleep(30)
        for index in error_indices:
            smiles = df["Output SMILES"].iloc[index]
            fetch_data_for_smiles(index, smiles)
            time.sleep(2)
        # Check if the number of errors has decreased
        if len(retryable_errors()) >= prev_error_count:
            break  # No improvement, break out of the loop

    # Check for remaining errors (blank rows included, they need attention too)
    final_error_indices = [i for i, name in enumerate(iupac_names) if name == "Error"]
    if final_error_indices:
        error_rows = ", ".join(str(i+1) for i in final_error_indices)
        print(f"\nCompleted for {file_path}, but rows {error_rows} have errors and need to be verified. Please correct the errors and copy and paste the columns to the main Sheet1 from sheet pubchem output before moving to the next step.")

    # Assign the results to the dataframe
    df["Output IUPAC name"] = iupac_names
    df["Output Canonical SMILES"] = canonical_smiles_list

    # Save to the "pubchem output" sheet of the same workbook. Replacing an
    # existing sheet lets the step be re-run without losing the fetched data
    # to a "sheet already exists" error at the very end.
    with pd.ExcelWriter(
        file_path, engine='openpyxl', mode='a', if_sheet_exists='replace'
    ) as writer:
        df.to_excel(writer, sheet_name="pubchem output", index=False)

        
def request_castus_SMILES(file_path, base_url="https://cactus.nci.nih.gov/chemical/structure/"):
    """
    Resolve each "Output SMILES" through the cactus structure service.

    Reads the first sheet of the workbook and fills the
    "Output cactus SMILES" column for every row where it is still blank, then
    writes the table to a sheet named "castus output" in the same workbook.

    Cell values written:
    - the response text when the service answers with status 200;
    - "Invalid" when the service answers with any other status, or when
      "Output SMILES", "Output Canonical SMILES" or "Output IUPAC name" is
      already "Invalid".
    Rows with a blank "Output SMILES" and rows whose request fails at the
    network level are left blank and reported on stdout.

    Parameters:
    - file_path: Path to the Excel file.
    - base_url: Base URL of the structure resolver; the URL-quoted SMILES and
      "/SMILES" are appended to it.
    """
    # Imported explicitly: `import urllib` alone does not guarantee that the
    # `urllib.parse` submodule is loaded.
    from urllib.parse import quote

    df = pd.read_excel(file_path, engine='openpyxl')

    total_rows = len(df)

    # An entirely blank column is read as float; make it object-typed so text
    # can be stored in it.
    df["Output cactus SMILES"] = df["Output cactus SMILES"].astype(object)

    for idx, row in df.iterrows():
        # Skip if "Output cactus SMILES" is already filled
        if pd.notna(row["Output cactus SMILES"]):
            continue

        output_smiles = row['Output SMILES']

        # If "Output Canonical SMILES" or "Output IUPAC name" or input is "Invalid", set "Output cactus SMILES" to "Invalid"
        if row["Output Canonical SMILES"] == "Invalid" or row["Output IUPAC name"] == "Invalid" or output_smiles == "Invalid":
            df.at[idx, 'Output cactus SMILES'] = "Invalid"
            continue

        # A blank cell is read as NaN and cannot be URL-quoted.
        if not isinstance(output_smiles, str) or not output_smiles.strip():
            print(f"Row {idx + 1} has no 'Output SMILES' value; leaving 'Output cactus SMILES' blank.")
            continue

        print(f"Processing row {idx + 1} of {total_rows}...")

        # Query the resolver with the "Output SMILES" value
        encoded_smiles = quote(output_smiles)
        try:
            response = requests.get(base_url + encoded_smiles + "/SMILES", timeout=60)
        except requests.RequestException as error:
            # Keep the rows processed so far; a blank cell is picked up again
            # on the next run because filled cells are skipped above.
            print(f"Request failed for row {idx + 1}: {error}. Leaving 'Output cactus SMILES' blank.")
            continue

        # If the response code is 200 (OK), update the "Output cactus SMILES" column
        if response.status_code == 200:
            df.at[idx, 'Output cactus SMILES'] = response.text
        else:
            df.at[idx, 'Output cactus SMILES'] = "Invalid"
            print(f"SMILES code {encoded_smiles} was not found.")

    # Save to the "castus output" sheet of the same workbook; replace the
    # sheet if an earlier run already created it.
    with pd.ExcelWriter(
        file_path, engine='openpyxl', mode='a', if_sheet_exists='replace'
    ) as writer:
        df.to_excel(writer, sheet_name="castus output", index=False)
    print(f"Data saved to the 'castus output' sheet in {file_path} for excel {file_path}. Please check if there are any errors and copy and paste the columns to the main Sheet1 from sheet 'castus output' before moving to the next step.")
    
def _encode_smiles_values(smiles_values, label):
    """Return one SELFIES string (or "Invalid") per entry of smiles_values."""
    encoded = []
    for position, smiles_value in enumerate(smiles_values):
        if isinstance(smiles_value, str) and smiles_value == "Invalid":
            encoded.append("Invalid")
            continue
        # Blank cells are read as NaN; only strings are handed to the encoder.
        if not isinstance(smiles_value, str):
            print(f"Missing {label} SMILES at row {position + 1}: {smiles_value}")
            encoded.append("Invalid")
            continue
        try:
            encoded.append(sf.encoder(smiles_value))
        except sf.EncoderError:
            print(f"Error encoding {label} SMILES at row {position + 1}: {smiles_value}")
            encoded.append("Invalid")
    return encoded


def smiles_to_selfies(file_path):
    """
    Convert the input and output SMILES columns of a workbook to SELFIES.

    Reads the first sheet, encodes "Input SMILES (cactus)" into
    "Input SELFIES" and "Output cactus SMILES" into "Output SELFIES", and
    writes the table to a sheet named "SELFIES output" in the same workbook.
    A cell becomes "Invalid" when its SMILES is "Invalid", is blank, or cannot
    be encoded.

    Parameters:
    - file_path: Path to the Excel file.
    """
    # Read the Excel file
    df = pd.read_excel(file_path, engine='openpyxl')

    # Every row gets a value, so each column is assigned in one step.
    df["Input SELFIES"] = _encode_smiles_values(df["Input SMILES (cactus)"], "input")
    df["Output SELFIES"] = _encode_smiles_values(df["Output cactus SMILES"], "output")

    # Replace the sheet if an earlier run already created it.
    with pd.ExcelWriter(
        file_path, engine='openpyxl', mode='a', if_sheet_exists='replace'
    ) as writer:
        df.to_excel(writer, sheet_name="SELFIES output", index=False)

    print(f"Updated data saved to {file_path}. Please check if there are any errors and copy and paste the columns to the main Sheet1 from sheet 'SELFIES output' before moving to the next step. ")

            
def smiles_to_iupac_name(KeyWord):
    """
    Translate a SMILES string to an IUPAC name with the Syntelly web app.

    Opens a Chrome window, submits KeyWord on the smiles2iupac page, and
    extracts the text following "Results: " from the resulting page.

    Parameters:
    - KeyWord: SMILES string to translate.

    Returns:
    - "Invalid" when KeyWord is None, blank, or the string "Invalid"
      (no browser is started in that case);
    - "Unknown" when the result does not appear within 20 seconds or no
      non-blank result can be found in the page;
    - otherwise the result text up to (not including) the first ", ".
    """
    # `KeyWord == "Invalid" or ""` never matched the empty string, because the
    # right-hand operand is just a falsy constant.
    if KeyWord is None or str(KeyWord).strip() in ("", "Invalid"):
        return "Invalid"

    # Initialize Chrome browser
    browser = webdriver.Chrome()
    try:
        # Navigate to the specified URL
        browser.get("https://app.syntelly.com/smiles2iupac")

        # Find the input element on the webpage
        input_1 = browser.find_element(By.CSS_SELECTOR, 'input[aria-invalid="false"]')

        # Send the keyword to the input element and press Enter
        input_1.send_keys(KeyWord)
        input_1.send_keys("\n")

        # Wait for the page to load the results
        try:
            wait = WebDriverWait(browser, 20)
            wait.until(EC.presence_of_element_located((By.CSS_SELECTOR, 'div.sc-hybRYi.jJhBHv')))
        except Exception:
            # Best effort: any failure while waiting counts as "no result".
            print("Time out, return Unknown")
            return "Unknown"

        # Get the HTML content of the page
        html = browser.page_source
    finally:
        # Always close the browser; otherwise every call leaves a Chrome
        # instance running.
        browser.quit()

    # Find the string that starts with "Results: " and ends with "</h4>"
    matches = re.findall(r'Results: (.*?)</h4>', html)
    if not matches:
        return "Unknown"

    result = matches[0]
    if result in (" ", ""):
        return "Unknown"

    # Keep only the part before the first ", " (the whole string if absent)
    return result.split(', ', 1)[0]
def generate_iupac_names(filename):
    """
    Fill missing IUPAC names from "Output cactus SMILES" and save the result.

    Reads the first sheet of the workbook. For every row whose
    "Output cactus SMILES" is neither blank nor "Invalid", and whose
    "Output IUPAC name" is blank, "Unknown" or "Error", a name is generated
    with smiles_to_iupac_name and stored. The table is written to a sheet
    named "IUPAC output" in the same workbook and a summary is printed.

    Parameters:
    - filename: Path to the Excel file.
    """
    # Load the data from the given filename
    df = pd.read_excel(filename)

    # Create a new column for IUPAC names if it doesn't exist
    if 'Output IUPAC name' not in df.columns:
        df['Output IUPAC name'] = ""
    # An entirely blank column is read as float; make it object-typed so
    # names can be stored in it.
    df['Output IUPAC name'] = df['Output IUPAC name'].astype(object)

    # Counters
    total_updates = 0
    blanks_filled = 0
    unknown_or_error_updated = 0

    for i, smiles in enumerate(df['Output cactus SMILES']):
        # Only rows with a usable SMILES can be translated
        if pd.isna(smiles) or smiles == "Invalid":
            continue

        current_iupac_name = df.loc[i, 'Output IUPAC name']
        # "" counts as blank too: it is what a freshly created column holds.
        is_blank = pd.isna(current_iupac_name) or current_iupac_name == ""
        if not is_blank and current_iupac_name not in ["Unknown", "Error"]:
            continue

        try:
            # Generate the IUPAC name using the provided function
            iupac_name = smiles_to_iupac_name(smiles)
        except Exception as e:
            # Best effort: report the row and carry on with the next one.
            print(f"Exception at row {i+1}: {e}")
            continue

        # Assign the generated IUPAC name to the "Output IUPAC name" column in the same row
        df.loc[i, 'Output IUPAC name'] = iupac_name

        # Update counters
        total_updates += 1
        if is_blank:
            blanks_filled += 1
            print(f"Row {i+1}: Filled blank with {iupac_name}")
        else:
            unknown_or_error_updated += 1
            print(f"Row {i+1}: Updated '{current_iupac_name}' to {iupac_name}")

    # Save the modified dataframe to a sheet called "IUPAC output" in the same
    # Excel file; replace the sheet if an earlier run already created it.
    with pd.ExcelWriter(
        filename, engine='openpyxl', mode='a', if_sheet_exists='replace'
    ) as writer:
        df.to_excel(writer, sheet_name='IUPAC output', index=False)

    # Print summary
    print(f"\nSummary:")
    print(f"Total updates made: {total_updates}")
    print(f"Number of blanks filled: {blanks_filled}")
    print(f"Number of 'Unknown' or 'Error' entries updated: {unknown_or_error_updated}")
    print(f"Data saved to the 'IUPAC output' sheet in {filename}. Please check if there are any errors and copy and paste the columns to the main Sheet1 from sheet 'IUPAC output' before moving to the next step.")
    

def double_check(file_path):
    """
    Cross-check SMILES and IUPAC name columns against the cactus service.

    Reads the first sheet of the workbook and, for each checked column, sends
    every value (skipping blanks, configured skip values, and values equal to
    the one in the previous row) to the cactus structure resolver. A line is
    printed for each value whose response has a non-200 status or differs
    from the input (case-insensitively); 404 responses for IUPAC name columns
    are not reported. Row numbers in the messages are the DataFrame index
    plus 2.

    Parameters:
    - file_path: Path to the Excel file.
    """
    # Read the Excel file into a pandas DataFrame using the provided file path.
    df = pd.read_excel(file_path, engine='openpyxl')

    # Define the columns to check, the corresponding base URLs, endpoints, and any values to skip.
    columns_to_check = [
        ("Input SMILES (cactus)", "https://cactus.nci.nih.gov/chemical/structure/", "/SMILES", []),
        ("Input IUPAC name (cactus)", "https://cactus.nci.nih.gov/chemical/structure/", "/iupac_name", []),
        ("Output IUPAC name", "https://cactus.nci.nih.gov/chemical/structure/", "/iupac_name", ["Unknown", "Invalid"]),
        ("Output cactus SMILES", "https://cactus.nci.nih.gov/chemical/structure/", "/SMILES", ["Unknown", "Invalid"])
    ]

    # Initialize a counter for the number of values checked.
    checked_count = 0

    # Iterate over the columns to check.
    for column_name, base_url, endpoint, skip_values in columns_to_check:
        previous_value = None  # Keep track of the previous value to identify duplicates.

        # For each column, iterate over each row in the DataFrame.
        for idx, row in df.iterrows():
            input_value = row[column_name]

            # If the input value is the same as the previous row, skip checking.
            if input_value == previous_value:
                continue

            # Update the previous value tracker.
            previous_value = input_value

            # Additional check for "Output IUPAC name" based on "Output Canonical SMILES" column.
            if column_name == "Output IUPAC name" and (row["Output Canonical SMILES"] in ["Unknown", "Invalid"]):
                continue

            # Check if the value in the current row and column is NaN or if it's in the list of values to be skipped.
            if pd.isna(input_value) or (skip_values and input_value in skip_values):
                continue

            # Cells may hold non-text values; compare and send their text form.
            text_value = str(input_value)

            # Replace any '#' characters with '%23' for URL encoding.
            url_value = text_value.replace('#', '%23')

            # Only the HTTP request is guarded, and only for request-level
            # failures, so unrelated errors are not reported as connection
            # problems.
            try:
                response = requests.get(base_url + url_value + endpoint, timeout=60)
            except requests.RequestException:
                print(f"ConnectionError happens in row {idx + 2} for {url_value}")
            else:
                # If the response is more than 100 words, truncate it.
                truncated_response = ' '.join(response.text.split()[:100])
                # IUPAC names that the service does not know (404) are not reported.
                iupac_not_found = "IUPAC" in column_name and response.status_code == 404
                mismatch = response.status_code != 200 or truncated_response.lower() != text_value.lower()
                if not iupac_not_found and mismatch:
                    print(f"Error in row {idx + 2} with CAS number {row.get('CAS number', 'N/A')}. Input: {input_value}, Response: {truncated_response}")

            # Increment the counter since a value was checked.
            checked_count += 1
            # Check if the current count is a multiple of 100
            if checked_count % 100 == 0:
                print(f"Checked {checked_count} values so far.")

    # Print the final number of values checked.
    print(f"{checked_count} values have been checked for Sheet1 in {file_path} and done.")
    return


def generate_training_data(file_path):
    """
    Build the fine-tuning spreadsheets for one mutation method.

    The method letter (S, I, R or P) is the character that follows
    "Method " in file_path, e.g. "Method S.xlsx". The first sheet of that
    workbook is read and three files are written to the current directory:
    "Model 1<letter>.xlsx" (SMILES), "Model 2<letter>.xlsx" (SELFIES) and
    "Model 3<letter>.xlsx" (IUPAC names, without the rows whose
    "Output IUPAC name" is "Unknown"). Each file has the columns "system",
    "user" and "assistant".

    Parameters:
    - file_path: Path to the Excel file, named "Method <letter>.xlsx".

    Raises:
    - KeyError: if the method letter is not one of S, I, R, P.
    """
    marker = "Method "

    # Report naming problems; kept as printed warnings, as before.
    if marker in file_path and file_path.endswith(".xlsx"):
        method_label = file_path.rsplit(marker, 1)[1].split(".")[0]
        if len(method_label) == 1:  # Check if the method_type is a single character
            print(method_label)
        else:
            print("Error: Invalid method type format!")
    else:
        print("Error: Invalid file_path format!")

    # Take the letter right after the last "Method " so that spaces elsewhere
    # in the path (e.g. in a directory name) do not shift the result. Paths
    # without the marker fall back to the former rule: first character of the
    # second space-separated part.
    if marker in file_path:
        method_type = file_path.rsplit(marker, 1)[1][:1]
    else:
        method_type = file_path.split(" ")[1][0]

    # Define the action maps for each method type
    action_maps = {
        "S": {
            1: "Introduce or remove a methyl group from the ring.",
            2: "Introduce or remove a hydroxyl group from the ring.",
            3: "Introduce or remove an amino group from the ring.",
            4: "Introduce or remove a nitro group from the ring.",
            5: "Introduce or remove a fluoro group to the ring."
        },
        "I": {
            1: "Insert or remove an unsubstituted phenyl ring at the connection where the carboxylate group is directly attached to either a ring, C=C, C#C, or N=N, ensuring para-positioning.",
            2: "Insert or remove two carbons along with a triple bond at the connection where the carboxylate group is directly attached to either a ring, C=C, C#C, or N=N.",
            3: "Insert or remove two carbons along with a double bond at the connection where the carboxylate group is directly attached to either a ring, C=C, C#C, or N=N.",
            4: "Insert or remove an azo group (-N=N-) at the connection where the carboxylate group is directly attached to either a ring, C=C, C#C, or N=N."
        },
        "R": {
            1: "Replace a carbon atom in the ring with nitrogen, or vice versa.",
            2: "Replace a carbon atom in the ring with oxygen, or vice versa.",
            3: "Replace a carbon atom in the ring with sulfur, or vice versa."
        },
        "P": {
            1: "Shift the position of a COOH group within any ring type to another position on the same ring.",
            2: "Relocate the position of N donor, excluding NH, within any ring type to another position on the same ring."
        }
    }

    if method_type not in action_maps:
        raise KeyError(
            f"Unsupported method type {method_type!r} parsed from {file_path!r}; "
            f"expected one of {sorted(action_maps)}"
        )

    # Select the action map based on the method type
    action_map = action_maps[method_type]

    # Construct the mutation actions section of the prompt
    mutation_actions = "\n".join([f"({key}) {value}" for key, value in action_map.items()])

    # Define the objectives and mutation issues for each method type
    objectives = {
        "S": "introduce new functional groups or alter existing ones to the linker, then provide the correct molecular representation for the modified linker",
        "I": "insert or delete a linker expansion spacer like phenyl ring, double bond, triple bond, or azo group specifically at the location where a carboxylate group is directly connected to either a ring, a C=C double bond, a C#C triple bond, or an N=N azo group within the linker",
        "R": "swap out atoms in the linker with different heteroatoms (e.g., replace a carbon atom with a nitrogen or sulfur atom), while adhering to general chemical rules and bonding constraints, such as ensuring ring stability and proper valence for atoms, then provide the correct correct molecular representation for the modified linker",
        "P": "change the position of coordination sites, such as COOH or N, within aromatic or non-aromatic rings including 5-membered, 6-membered, 7-membered, and fused rings, then provide the correct correct molecular representation for the modified linker"
    }

    mutation_issues = {
        "S": "(e.g., it lacks a ring or a suitable substitution site)",
        "I": "(e.g., it lacks a ring or a suitable insertion site between carboxylate and qualified qualified structural groups mentioned above)",
        "R": "(e.g., it lacks a ring or a suitable substitution site)",
        "P": "(e.g., it lacks a ring or a suitable position for the coordination site shift)"
    }

    # Define the common prompt with placeholders. The f-string fills in the
    # method-specific text now; the doubled braces survive as {desc},
    # {additional_info} and {type} for the .format() calls below.
    common_prompt = f"""You are an AI assistant with expertise in organic chemistry. Your task is to make theoretical modifications to a given {{desc}} of a MOF linker. {{additional_info}} Your objective is {objectives[method_type]}. You should never remove or modify the carboxylate groups, as they are essential to MOF linkers. The user can choose from {len(action_map)} mutation actions:

    {mutation_actions}

    The user will first specify the desired mutation action, followed by 'Action: '. In the next line, the user will provide the {{type}} of the MOF linker to be mutated, starting with 'Compound: '.

    Your response should begin with 'New Compound: ', followed by the updated {{type}}. If the requested mutation isn't chemically feasible, due to bonding constraints or if the given structure isn't compatible with the mutation {mutation_issues[method_type]}, you should respond with 'New Compound: Invalid'."""

    # Specific parts for each type of data
    prompts = {
        'smiles': common_prompt.format(
            desc='SMILES code', 
            additional_info='', 
            type='SMILES code'
        ),
        'selfies': common_prompt.format(
            desc='SELFIES string', 
            additional_info='Here SELFIES (SELF-referencIng Embedded Strings) is a string-based representation of molecules. Every SELFIES string corresponds to a valid molecule, similar to the way Canonical SMILES representations work. ',
            type='SELFIES string'
        ),
        'iupac': common_prompt.format(
            desc='IUPAC name', 
            additional_info='', 
            type='IUPAC name'
        )
    }

    # Function to construct the user message
    def construct_user_message(row, data_type):
        input_columns_map = {
            'smiles': 'Input SMILES (cactus)',
            'canonical smiles': 'Input Canonical SMILES (PubChem)',
            'selfies': 'Input SELFIES',
            'iupac': 'Input IUPAC name (cactus)'
        }

        # Action numbers missing from the map are labelled "Invalid Action".
        action = action_map.get(row["Action"], "Invalid Action")
        return f"Action: {action}\nCompound: {row[input_columns_map[data_type]]}"

    # Function to construct the expected assistant reply
    def construct_assistant_message(row, data_type):
        output_columns_map = {
            'smiles': 'Output cactus SMILES',
            'canonical smiles': 'Output Canonical SMILES',
            'selfies': 'Output SELFIES',
            'iupac': 'Output IUPAC name'
        }

        return f"New Compound: {row[output_columns_map[data_type]]}"

    # 1. Load the data from the provided Excel file
    data = pd.read_excel(file_path)

    # For model using SMILES code
    data["Modified User Message SMILES"] = data.apply(lambda row: construct_user_message(row, 'smiles'), axis=1)
    data["Modified Assistant Message SMILES"] = data.apply(lambda row: construct_assistant_message(row, 'smiles'), axis=1)

    # For model using SMILES code (the canonical smiles choice is optional)
    #data["Modified User Message Canonical SMILES"] = data.apply(lambda row: construct_user_message(row, 'canonical smiles'), axis=1)
    #data["Modified Assistant Message Canonical SMILES"] = data.apply(lambda row: construct_assistant_message(row, 'canonical smiles'), axis=1)

    # For model using SELFIES string
    data["Modified User Message SELFIES"] = data.apply(lambda row: construct_user_message(row, 'selfies'), axis=1)
    data["Modified Assistant Message SELFIES"] = data.apply(lambda row: construct_assistant_message(row, 'selfies'), axis=1)

    # For model using IUPAC name
    data["Modified User Message IUPAC"] = data.apply(lambda row: construct_user_message(row, 'iupac'), axis=1)
    data["Modified Assistant Message IUPAC"] = data.apply(lambda row: construct_assistant_message(row, 'iupac'), axis=1)

    # 2. Model 1 - Using SMILES code
    model_1R = pd.DataFrame({
        "system": [prompts['smiles']] * len(data),
        "user": data["Modified User Message SMILES"],
        "assistant": data["Modified Assistant Message SMILES"]
    })
    model_1R.to_excel("Model 1"+method_type+".xlsx", index=False)

    # 3. Model 2 - Using SELFIES string
    model_2R = pd.DataFrame({
        "system": [prompts['selfies']] * len(data),
        "user": data["Modified User Message SELFIES"],
        "assistant": data["Modified Assistant Message SELFIES"]
    })
    model_2R.to_excel("Model 2"+method_type+".xlsx", index=False)

    # 4. Model 3 - Using IUPAC name, leaving out rows without a known output name
    filtered_data_3R = data[data["Output IUPAC name"] != "Unknown"]
    model_3R = pd.DataFrame({
        "system": [prompts['iupac']] * len(filtered_data_3R),
        "user": filtered_data_3R["Modified User Message IUPAC"],
        "assistant": filtered_data_3R["Modified Assistant Message IUPAC"]
    })
    model_3R.to_excel("Model 3"+method_type+".xlsx", index=False)

    print ("Training data for Model 1 2 3 for Method " + method_type + " generated.")

    return

def generate_json_from_excel(model):
    """
    Convert "Model <model>.xlsx" into a chat-format JSONL file.

    Each spreadsheet row becomes one line holding a "messages" list with the
    row's "system", "user" and "assistant" cells, in that order.

    Parameters:
    - model: Model label used in the file names, e.g. "2R".

    Returns:
    - Path of the written file, "Model <model>_json.jsonl".
    """
    roles = ("system", "user", "assistant")

    def as_chat_example(record):
        return {
            "messages": [{"role": role, "content": record[role]} for role in roles]
        }

    # Default NA parsing is switched off so that cells such as 'N/A' stay text.
    table = pd.read_excel("Model " + model + ".xlsx", keep_default_na=False, na_values=[])

    output_path = "Model " + model + "_json.jsonl"

    # One JSON document per line
    with open(output_path, 'w') as handle:
        for _, record in table.iterrows():
            handle.write(f"{json.dumps(as_chat_example(record))}\n")

    print(f"Data of Model {model} has been successfully converted and saved to {output_path}")
    return output_path

    
def check_json(model):
    """
    Validate "Model <model>_json.jsonl" and print fine-tuning statistics.

    Adapted from the data-preparation checks published by OpenAI. Every
    example is checked against the chat message structure. When the format is
    clean, message counts, token counts and a rough estimate of billed tokens
    and default epochs are printed.

    Parameters:
    - model: Model label used in the file name, e.g. "2R".

    Returns:
    - Total number of format errors found; 0 means the file passed. When
      errors are found (an empty file counts as one), the token statistics
      are skipped.
    """
    data_path = "Model "+model+"_json.jsonl"

    # Load dataset; blank lines are not examples and are ignored
    with open(data_path) as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    # Initial dataset stats
    print("Num examples:", len(dataset))
    if dataset:
        print("First example:")
        for message in dataset[0]["messages"]:
            print(message)

    # Format error checks: make sure every example matches the Chat
    # completions message structure
    format_errors = defaultdict(int)

    if not dataset:
        format_errors["empty_dataset"] += 1

    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue

        for message in messages:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1

        if not any(message.get("role", None) == "assistant" for message in messages):
            format_errors["example_missing_assistant_message"] += 1

    # The count is computed in every case, so the return value is always
    # defined (it used to be set only when no error was found).
    num_error = sum(format_errors.values())
    if num_error:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
        # The statistics below index into each example and assume it is well
        # formed, so they are not attempted on a file that needs fixing.
        print("Token statistics are skipped until the format errors above are fixed.")
        return num_error

    print("No errors found")

    # Beyond the structure of the message, we also need to ensure that the length does not exceed the 4096 token limit.

    # Token counting functions
    encoding = tiktoken.get_encoding("cl100k_base")

    # not exact!
    # simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                if not isinstance(value, str):
                    print(f"Error in message: {message}")
                    print(f"Invalid value for key '{key}': {value} (type: {type(value)})")
                    continue
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3
        return num_tokens

    def num_assistant_tokens_from_messages(messages):
        num_tokens = 0
        for message in messages:
            if message["role"] == "assistant":
                num_tokens += len(encoding.encode(message["content"]))
        return num_tokens

    def print_distribution(values, name):
        print(f"\n#### Distribution of {name}:")
        print(f"min / max: {min(values)}, {max(values)}")
        print(f"mean / median: {np.mean(values)}, {np.median(values)}")
        # Label matches the quantiles actually computed (0.1 and 0.9).
        print(f"p10 / p90: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

    # Warnings and tokens counts
    n_missing_system = 0
    n_missing_user = 0
    n_messages = []
    convo_lens = []
    assistant_message_lens = []

    for ex in dataset:
        messages = ex["messages"]
        if not any(message["role"] == "system" for message in messages):
            n_missing_system += 1
        if not any(message["role"] == "user" for message in messages):
            n_missing_user += 1
        n_messages.append(len(messages))
        convo_lens.append(num_tokens_from_messages(messages))
        assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

    print("Num examples missing system message:", n_missing_system)
    print("Num examples missing user message:", n_missing_user)
    print_distribution(n_messages, "num_messages_per_example")
    print_distribution(convo_lens, "num_total_tokens_per_example")
    print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
    n_too_long = sum(l > 4096 for l in convo_lens)
    print(f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning")

    # Pricing and default n_epochs estimate
    MAX_TOKENS_PER_EXAMPLE = 4096

    MIN_TARGET_EXAMPLES = 100
    MAX_TARGET_EXAMPLES = 25000
    TARGET_EPOCHS = 3
    MIN_EPOCHS = 1
    MAX_EPOCHS = 25

    n_epochs = TARGET_EPOCHS
    n_train_examples = len(dataset)  # non-zero: an empty dataset returned above
    if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
        n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
    elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
        n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

    n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
    print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
    print(f"By default, you'll train for {n_epochs} epochs on this dataset")
    print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")

    return num_error
    
def start_ft(model):
    """
    Upload the training data of one model to OpenAI and create a fine-tuning job for it.

    The file "Model <model>_json.jsonl" is uploaded with purpose 'fine-tune'. After a
    fixed 2 minute wait, up to 100 attempts are made to create a "gpt-3.5-turbo"
    fine-tuning job on the uploaded file. "Still being processed" and rate-limit
    errors are retried after a pause; any other API error is re-raised.

    Parameters:
    - model (str): Model name (e.g. "1R") used to build the JSONL file name.

    Returns:
    - ID of the uploaded training file. It is returned even when no job could be
      created within the maximum number of attempts.

    Raises:
    - openai.error.InvalidRequestError: for request errors other than "still being processed".
    - openai.error.RateLimitError: for rate-limit errors whose message is not "rate-limited".
    """
    # Close the file handle as soon as the (synchronous) upload has finished.
    with open(f"Model {model}_json.jsonl", "rb") as training_file:
        new_upload = openai.File.create(
            file=training_file,
            purpose='fine-tune'
        )
    print("\n")
    print(f"Data of Model {model} has been uploaded. Please wait for 2 minutes to start the job.")
    print(new_upload)
    time.sleep(120)  # Give OpenAI time to process the uploaded file

    for _ in range(100):  # Max attempts
        # new_upload.id identifies the uploaded file; the job itself only exists after creation.
        print(f"Try to create fine-tuning job for Model {model}. The training file ID is {new_upload.id}")
        try:
            ft_model = openai.FineTuningJob.create(training_file=new_upload.id, model="gpt-3.5-turbo")
            print("Fine-tuning job created successfully! Please check the email for the model ID.")
            print(ft_model)
            print("\n")
            break
        except openai.error.InvalidRequestError as e:
            if "still being processed" in str(e):
                print("File is still being processed. Retrying in 30 seconds...")
                time.sleep(30)  # Wait for 30 seconds
            else:
                raise
        except openai.error.RateLimitError as e:
            if "rate-limited" in str(e):
                if "per day" in str(e):
                    print("12 task per day limit reached.")
                    time.sleep(10000)  # Wait for 10000 s (about 2 h 47 min)
                # The daily-limit case also falls through to the regular 10 minute pause.
                print("Rate limit reached. Retrying in 10 minutes...")
                time.sleep(600)  # Wait for 10 minutes
            else:
                raise
    else:
        # Only reached when the loop ran out of attempts without a `break`.
        print("Max attempts reached. Fine-tuning job could not be created.")
    return new_upload.id





def get_ft_responses(model, ft_model_id, user_messages_sample, action_number="all"):
    """
    Function to get AI responses based on the provided model, action number, and a list of user messages.

    The system prompt is taken from the first row of the 'system' column of
    "Model <model>.xlsx" (the file is re-read on every call). The available actions
    are parsed from that prompt: each "(n)" marker counts as one action and its
    description is the text following "(n) " up to the next ".\\n".
    NOTE: only single-digit markers "(1)".."(9)" are recognised.

    Parameters:
    - model (str): Model name used to load the corresponding Excel file.
    - ft_model_id (str): Fine-tuned model ID to be used for AI completions.
    - user_messages_sample (list): List of user messages (compounds) for testing.
    - action_number (str): The specific action number or "all". Default is "all".
      With "all", one request is sent per compound and per action.

    Returns:
    - List of AI's responses (the `message` of the first choice of each completion),
      ordered by compound and then by action.

    Raises:
    - ValueError: if action_number is neither "all" nor a valid action number, or if
      the system message has no description line for one of its "(n)" markers.
    """

    # 1. LOAD THE DATA
    # Read the .xlsx file, ensuring 'N/A' is treated as a string
    data = pd.read_excel(f"Model {model}.xlsx", na_values=[], keep_default_na=False)

    # Extract the system message from the first non-header row of the 'system' column
    system_message = data['system'].iloc[0]

    # 2. EXTRACT ACTION CHOICES
    # Every pattern like (1), (2), ... in the system message counts as one action
    number_of_actions = len(re.findall(r'\(\d\)', system_message))

    # Extract descriptions for each action for later use in user messages
    action_descriptions = []
    for number in range(1, number_of_actions + 1):
        found = re.search(r'\({}\) (.+?)\.\n'.format(number), system_message)
        if found is None:
            # Fail with a readable message instead of "'NoneType' has no attribute 'group'"
            raise ValueError(
                f"Could not find a description for action ({number}) in the system message of Model {model}."
            )
        action_descriptions.append(found.group(1))

    # 3. ADJUST USER MESSAGES BASED ON ACTION_NUMBER
    # `selected_actions` holds (label, description) pairs to combine with every compound
    if action_number == "all":
        selected_actions = list(enumerate(action_descriptions, start=1))
    else:
        invalid_message = f"Invalid action_number. It should be between 1 and {number_of_actions} or 'all'."
        try:
            chosen = int(action_number)
        except (TypeError, ValueError):
            raise ValueError(invalid_message) from None
        if not 1 <= chosen <= number_of_actions:
            raise ValueError(invalid_message)
        # The label is formatted exactly as the caller passed it
        selected_actions = [(action_number, action_descriptions[chosen - 1])]

    adjusted_user_messages = [
        "Action: ({}) {}\nCompound: {}".format(label, description, compound)
        for compound in user_messages_sample
        for label, description in selected_actions
    ]

    # 4. INTERACT WITH THE AI MODEL
    # Loop through the adjusted user messages and send them to the AI model
    responses = []
    for user_message in adjusted_user_messages:
        completion = openai.ChatCompletion.create(
            model=ft_model_id,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ]
        )

        # Append the AI's response to the responses list
        responses.append(completion.choices[0].message)

    return responses


def evaluation(model_type):
    """
    Query the baseline and fine-tuned models on the evaluation workbook and save their answers.

    For every output column, the system prompt of the paired "Model <n><model_type>.xlsx"
    and the paired model ID are used to request one completion per row of
    "Evaluation Medtod <model_type>.xlsx" (via get_ft_responses). The text after
    "New Compound: " of each answer is stored. Results are appended to the workbook
    as sheet 'Eva Output'; sheet 'Eva Output Decode' holds the same data plus a
    column with the SELFIES answers passed through selfies' decoder.

    Parameters:
    - model_type (str): Method letter appended to model and file names (e.g. "R").

    Notes:
    - Reads the notebook-level global `model_id_list`; its entries are paired by
      position with the output columns defined below.
    - A request that still fails after 5 attempts is recorded as "Error".

    Returns:
    - None

    Raises:
    - ValueError: if an output column name carries no known representation tag.
    """
    output_column_list = [
        "GPT-3.5 (SMILES)",
        "GPT-3.5 (IUPAC)",
        "GPT-4 (SMILES)",
        "GPT-4 (IUPAC)",
        f"FT Model 1{model_type} (SMILES)",
        f"FT Model 2{model_type} (SELFIES)",
        f"FT Model 3{model_type} (IUPAC)",
    ]
    # More models can be included if needed, e.g. "FT Model 4<type> (SMILES)",
    # "FT Model 5<type> (SELFIES)", "FT Model 6<type> (IUPAC)".

    # Which "Model <name>.xlsx" supplies the system prompt for each output column
    model_name_list = [
        f"1{model_type}",  # using same prompt as 1R (SMILES)
        f"3{model_type}",  # using same prompt as 3R (IUPAC)
        f"1{model_type}",
        f"3{model_type}",
        f"1{model_type}",
        f"2{model_type}",
        f"3{model_type}",
    ]
    # Extend with "4"/"5"/"6" + model_type together with the columns above.

    selfies_column = f"FT Model 2{model_type} (SELFIES)"

    # NOTE(review): "Medtod" looks like a typo of "Method", but it is part of the
    # file name this function opens - confirm against the actual workbook before renaming.
    file_path = f"Evaluation Medtod {model_type}.xlsx"

    # zip() below stops at the shortest list, so a length mismatch would silently drop columns.
    if len(model_id_list) != len(output_column_list):
        print(f"Warning: model_id_list has {len(model_id_list)} entries but "
              f"{len(output_column_list)} output columns are defined; extra entries on either side are ignored.")

    # Read the Excel file
    df = pd.read_excel(file_path)

    # Initialize an empty DataFrame for the output
    output_df = pd.DataFrame(index=df.index)

    # Define a dictionary to map output column description to input column names
    input_column_map = {
        "(SMILES)": "Input SMILES (cactus)",
        "(IUPAC)": "Input IUPAC name (cactus)",
        "(SELFIES)": "Input SELFIES"
    }

    max_attempts = 5
    action_numbers = df["Action"].tolist()

    # Iterate over each output column
    for idx, (output_column, model_name, model_id) in enumerate(zip(output_column_list, model_name_list, model_id_list)):

        print(f"Processing column {idx+1}/{len(output_column_list)}: {output_column} using system prompt model: {model_name} for model {model_id}")

        # Determine the input column based on the description in the output column name.
        # Fail loudly rather than reusing the column picked for the previous model.
        input_column = next(
            (input_col for desc, input_col in input_column_map.items() if desc in output_column),
            None,
        )
        if input_column is None:
            raise ValueError(f"No input column is defined for output column '{output_column}'.")

        user_messages_sample = df[input_column].tolist()

        # Collect AI's responses
        ai_responses = []
        for uidx, (user_message, action_number) in enumerate(zip(user_messages_sample, action_numbers)):
            print(f"\tRequesting response for message {uidx+1}/{len(user_messages_sample)}")

            waiting_time = 10
            for attempt in range(1, max_attempts + 1):
                try:
                    responses = get_ft_responses(model_name, model_id, [user_message], str(action_number))
                    # Extract the part after "New Compound:" from each JSON-like response.
                    # Built completely before storing so a retry can never duplicate entries.
                    compounds = [
                        response_json["content"].split("New Compound: ")[-1].strip()
                        for response_json in responses
                    ]
                except Exception as e:
                    if attempt == max_attempts:
                        print(f"\tError on attempt {attempt}: {e}")
                        print(f"\tFailed after {max_attempts} attempts. Moving on to the next message.")
                        ai_responses.append("Error")
                    else:
                        print(f"\tError on attempt {attempt}: {e}. Retrying in {waiting_time} seconds...")
                        time.sleep(waiting_time)
                        waiting_time = waiting_time*2  # Exponential back-off
                else:
                    ai_responses.extend(compounds)
                    break

        # Store the AI's responses in the output_df
        output_df[output_column] = ai_responses

    print("Processing complete!")

    # Append the collected answers to the original Excel file as a new sheet
    with pd.ExcelWriter(file_path, engine='openpyxl', mode='a') as writer:
        output_df.to_excel(writer, sheet_name='Eva Output', index=False)

    # Re-open the 'Eva Output' sheet from the Excel file to add the decoded SELFIES column
    with pd.ExcelFile(file_path) as xls:
        eva_output_df = pd.read_excel(xls, 'Eva Output')

    # Function to decode the values or return "Invalid"
    def decode_or_invalid(value):
        if value == "Invalid":
            return "Invalid"
        try:
            return sf.decoder(value)
        except Exception:
            # Anything the decoder cannot handle (malformed SELFIES, non-string cells)
            return "Invalid"

    # Decode the values in the SELFIES column and insert them right after it
    eva_output_df.insert(
        eva_output_df.columns.get_loc(selfies_column) + 1,
        f"FT Model 2{model_type} (SELFIES_Decode)",
        eva_output_df[selfies_column].apply(decode_or_invalid)
    )
    # Repeat the insert above for "FT Model 5<type> (SELFIES)" if additional
    # models need SELFIES conversion.

    # Save the updated DataFrame to a new sheet named 'Eva Output Decode'
    with pd.ExcelWriter(file_path, engine='openpyxl', mode='a') as writer:
        eva_output_df.to_excel(writer, sheet_name='Eva Output Decode', index=False)

    return





# Fetch Chemical Representations

In [None]:
# Select the method to process. Please double check both the method letter and
# the resulting file path before running the cells below!!
method_type = "X"  # Replace "X" with "S", "I", "R" or "P" as needed
file_path = f"Method {method_type}.xlsx"

In [None]:
fetch_pubchem_data(file_path)  # This function checks and retrieves the PubChem IUPAC names and canonical SMILES for the given compounds.
                               # If the result is "Unknown", the compound is likely to be purely hypothetical and has not been reported.

In [None]:
request_castus_SMILES(file_path) # This function gets a valid SMILES code in a systematic way
                                # for all structures after mutation, regardless of whether they have been reported or not.

In [None]:
smiles_to_selfies(file_path)   # This function converts all cactus SMILES codes to SELFIES strings and adds them to the Excel table.

In [None]:
generate_iupac_names(file_path)  # This function goes through all compounds that do not have IUPAC names on PubChem
                                    # and retrieves their standard IUPAC names using Syntelly.

In [None]:
double_check(file_path)     # This function checks that all IUPAC names are valid and consistent.

In [None]:
generate_training_data(file_path)  # This function generates a "system-user-assistant" prompt dialogue file for every mutation, to be used for training.

# Training GPT Models

In [None]:
# Set the OpenAI API key used by all training and evaluation cells below.
# Do not share or commit this notebook once a real key has been pasted in.
openai.api_key = "Replace with your real OpenAI API Key"     # e.g. "ab-cdxyzABCkkkQQ2SsAb123DDeeFFggHh"
                                                            # For more information: https://openai.com/blog/openai-api

In [None]:
# Fine-tune Model 1X, 2X and 3X one after another, where X = S, I, R, P
for i in (1, 2, 3):
    model = f"{i}{method_type}"

    print(
        "\n++++++++++++++++++++++++++++++++++++++++++++",
        f"Start to process Model {model}.\n",
        sep="\n",
    )
    start_time = time.time()  # Timing covers conversion, check and job creation

    # Pipeline: Excel -> JSON, validate the JSON with the OpenAI-provided checks,
    # then upload it to OpenAI and start the fine-tuning job.
    generate_json_from_excel(model)
    check_json(model)
    start_ft(model)

    end_time = time.time()
    elapsed_time = end_time - start_time  # Seconds spent on this model

    print(
        f"Finish Model {model}.\n",
        f"Model {model} took {elapsed_time:.2f} seconds to process.",
        "\n++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )
    
# Optional: List 20 fine-tuning jobs
#openai.FineTuningJob.list(limit=20)    

# Evaluate FT Models

In [None]:
# One model ID per output column of evaluation(), in the same order:
# GPT-3.5 (SMILES), GPT-3.5 (IUPAC), GPT-4 (SMILES), GPT-4 (IUPAC),
# FT Model 1 (SMILES), FT Model 2 (SELFIES), FT Model 3 (IUPAC).
model_id_list = [
    "gpt-3.5-turbo-0613",  # 3.5 SMILES
    "gpt-3.5-turbo-0613",  # 3.5 IUPAC
    "gpt-4-0613",  # 4 SMILES
    "gpt-4-0613",  # 4 IUPAC
    "ft:gpt-3.5-turbo-0613:ADD_Model_ID_Here",  # replace with your model ID
    "ft:gpt-3.5-turbo-0613:ADD_Model_ID_Here",  # replace with your model ID
    "ft:gpt-3.5-turbo-0613:ADD_Model_ID_HereG",  # replace with your model ID
]

# The method letter chosen at the top of the notebook is stored in `method_type`;
# no variable named `model_type` is defined at notebook level.
evaluation(method_type)

In [None]:
# Reference: models already fine-tuned on Methods S, I, R and P that are ready for use.
# Copy and paste the model IDs you need into the evaluation cell above to run the evaluation.
# Use the exact same system prompt and user input format as in training, otherwise the
# model is not guaranteed to function as expected.
# For system prompt and user input samples, see the function "generate_training_data".
# Alternatively, read the system prompt from the "Model XY.xlsx" file, where X is a
# number from 1-3 and Y is either S, I, R, or P.
# The system prompt and user message may vary with the chemical representation
# (SMILES, SELFIES, IUPAC) and the mutation method (insertion, replacement, etc.).
#
# NOTE: every assignment below rebinds `model_id_list`, so running this cell unchanged
# leaves only the last list (Method P) in effect.

# model_type = "R"
model_id_list = [
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7vd4eEZu",  # model 1R  SMILES
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7veHJ0eR",  # model 2R  SELFIES
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7vyL332G",  # model 3R  IUPAC
]

# model_type = "S"
model_id_list = [
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7wF4Wvdr",  # model 1S  SMILES
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7wGGcyfU",  # model 2S  SELFIES
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7wHSe0sw",  # model 3S  IUPAC
]

# model_type = "I"
model_id_list = [
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7xJmyNlq",  # model 1I  SMILES
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7xKePzT5",  # model 2I  SELFIES
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7xM2Vcbv",  # model 3I  IUPAC
]

# model_type = "P"
model_id_list = [
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7xiQHz21",  # model 1P
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7xjKObLF",  # model 2P
    "ft:gpt-3.5-turbo-0613:uc-berkeley::7xkDldW9",  # model 3P
]

