In [None]:
import http.client
import json
import pandas as pd
import xml.etree.ElementTree as ET
from bs4 import BeautifulSoup 
import time
from concurrent.futures import ThreadPoolExecutor
import threading
import multiprocessing
import glob
import traceback
import socket
from datetime import datetime
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm
import time
import os
import shutil
import regex as re


In [None]:
# Load the four 2017-2021 clinical-note extracts (EC, OC, RC, BC), one CSV per
# cohort, in the same order as before. Paths are relative to the notebook's
# working directory.
(
    ec_2017_2021_df,
    oc_2017_2021_df,
    rc_2017_2021_df,
    bc_2017_2021_df,
) = (
    pd.read_csv(csv_name)
    for csv_name in (
        'final_ec_combined_cdoc.csv',
        'final_oc_combined_cdoc.csv',
        'RC_2017_2021_CDOC_CLEANED.csv',
        'final_bc_combined_cdoc.csv',
    )
)

In [None]:
# Keyword pre-filter for mismatch-repair (MMR) / microsatellite-instability mentions.
# The pattern is written for re.VERBOSE (whitespace and newlines inside it are
# ignored) and is applied case-insensitively by the flagging cell below.
# Only the hyphenated protein spellings (MLH-1, MSH-2, MSH-6, PMS-2) are listed,
# so un-hyphenated spellings such as "MLH1" are not matched by this filter.

pattern = r""" 
\bMSI\b | 
\bMicrosatellite\b | 
\bMLH-1\b | 
\bMSH-2\b |
\bMSH-6\b |
\bPMS-2\b |
\bMismatch\s+repair\b 
"""

In [None]:
# Keyword pre-filter for breast-cancer receptor mentions (triple-negative
# breast cancer, TNBC): ER, PR and HER2, plus the "c-erb-b2" / "cerbb2" spellings.
# Written for re.VERBOSE, so whitespace and newlines inside the pattern are ignored.
# NOTE(review): when applied case-insensitively, \bER\b and \bPR\b also match the
# lowercase tokens "er" / "pr" — confirm that this recall/precision trade-off is intended.
pattern_tnbc = r""" 
\bER\b | 
\bPR\b | 
\bHER2\b | 
\bc-erb-b2\b |
\bcerbb2\b 
"""

In [None]:
# Mark every EC and OC note whose text mentions at least one MMR/MSI keyword.
# Missing notes (NaN) are flagged False via na=False.
for cohort_df in (ec_2017_2021_df, oc_2017_2021_df):
    cohort_df['mmr_flag'] = cohort_df['CDOC'].str.contains(
        pattern,
        flags=re.IGNORECASE | re.VERBOSE,
        case=False,
        regex=True,
        na=False,
    )

In [None]:
# Mark every BC note that mentions at least one receptor keyword (ER / PR / HER2).
# The column keeps the name 'mmr_flag' because the filtering cell below reads it
# under that name, even though it holds the TNBC keyword flag here.
bc_2017_2021_df['mmr_flag'] = bc_2017_2021_df['CDOC'].str.contains(
    pattern_tnbc,
    flags=re.IGNORECASE | re.VERBOSE,
    case=False,
    regex=True,
    na=False,
)

In [None]:
# Keep only the keyword-positive notes for each cohort.
# 'mmr_flag' is a plain boolean column (str.contains was called with na=False),
# so it can be used directly as a mask. `.copy()` makes each subset an
# independent frame: process_in_batches assigns new result columns onto the frame
# it is given, and doing that on a filtered slice of the parent frame triggers
# pandas' SettingWithCopy ambiguity.
pos_ec = ec_2017_2021_df.loc[ec_2017_2021_df['mmr_flag']].copy()
pos_oc = oc_2017_2021_df.loc[oc_2017_2021_df['mmr_flag']].copy()
# RC is not filtered here: no 'mmr_flag' column is created for rc_2017_2021_df above.
#pos_rc = rc_2017_2021_df[rc_2017_2021_df['mmr_flag'] == True]

In [None]:
# Keyword-positive BC notes. `.copy()` detaches the subset from the parent frame
# so that result columns can later be assigned onto it without pandas'
# SettingWithCopy ambiguity; the flag column is boolean, so it is used as the mask.
pos_bc = bc_2017_2021_df.loc[bc_2017_2021_df['mmr_flag']].copy()

In [None]:
# Number of distinct (non-null) patients with at least one keyword-positive BC note.
cnt = len(pos_bc['PATIENT_IDENTIFIER'].dropna().unique())
cnt


In [None]:
# Dump every note in `ss`, each followed by an 80-dash separator line.
# NOTE(review): `ss` is not defined in any cell visible in this notebook, so this
# cell fails with NameError on a fresh kernel — define it (or drop the cell).
for txt in ss['CDOC'].tolist():
    print(txt, f"\n{'-' * 80}")

In [None]:
# LLM configuration shared by the classification cells below.

# Llama 3 inference endpoint as "host:port" (no URL scheme), the form expected by
# http.client.HTTPConnection. Set the LLAMA3_URL environment variable to point
# the notebook at another deployment; the previous hard-coded address stays the default.
llama3_url = os.environ.get("LLAMA3_URL", "10.18.212.70:30012")

# System prompt sent as the first message of every dialog.
default_sys_prompt = '''
You are a clinical-text classifier. You will be given clinical notes and your job is to follow the rules stated in the user prompt. 
DO NOT MAKE ASSUMPTIONS ABOUT PROTEINS THAT ARE NOT MENTIONED. DO not fabricate or generate results.
'''


# FOR MMR

In [None]:
def construct_dialog(CDOC, PATIENT_IDENTIFIER):
    """Build the chat payload asking the LLM to classify MMR protein expression.

    The clinical note is embedded exactly once in the user prompt; the closing
    rules refer to it as "the clinical note" instead of repeating the full text,
    which previously tripled the prompt size.

    Args:
        CDOC: Clinical note text to classify.
        PATIENT_IDENTIFIER: Not used when building the prompt; accepted so the
            function can be called as ``construct_dialog(note, patient_id)``.

    Returns:
        A list containing a single dialog. The dialog is a list of two message
        dicts (keys "role" and "content"): the system message built from the
        module-level ``default_sys_prompt`` and the user message with the task,
        the note and the classification rules.
    """
    user_prompt = f''' 
         
         Task: Detect whether patients have an abnormal expression of DNA mismatch repair (MMR) proteins: MLH-1, 
         MSH-2, MSH-6 or PMS-2 from the clinical notes.
         
         Clinical Note: {CDOC}
         
         **KEY PRINCIPLE**
         1. LOOKOUT FOR THE FOLLOWING KEYWORDS:
             - MSI High
             - Microsatellite unstable
             - MLH-1 = Negative
             - MSH-2 = Negative
             - MSH-6 = Negative
             - PMS-2 = Negative 
             - Negative expression of DNA mismatch repair proteins
        
         **CLASSIFICATION CRITERIA**
         1. If Keywords specified above are found do the following: 
             - If there is an abnormal or negative expression, indicate overall_result as "abnormal".
             - If the report indicates that the MMR expression is normal OR all the proteins are normal, indicate the overall result as "normal".
             - If no protein is abnormal but at least one protein is equivocal, indicate the overall result as "equivocal".
             - if the report has no mentions of keywords, DNA mismatch repairs or any of the proteins, return "null".
        
         2. IF KEYWORDS ARE NOT FOUND:
            - Return overall_result as null since it has no correlation to the proteins. 
            - Examples: Random alpha-numeric values "CECnnnnn" where n is a numeric value. 
            - Rows with names or time does not equate to abnormal cases.
            - Do not infer standalone words like "yes", "as above", "Absent" or "Present", "False", "True" unless directly following one of the protein names.
            - If there is no mentions of DNA mismatch repairs or any of the proteins do not generate false results.
    
            Example 1:
            {{
                "PMS-2": "normal",
                "MLH-1": "abnormal",
                "overall_result": "abnormal"
            }}

            Example 2:
            {{
                "PMS-2": "normal",
                "MLH-1": "normal",
                "overall_result": "normal"
            }}

            Example 3:
            {{
                "MLH-1": "normal",
                "MSH-2": "normal", 
                "MSH-6": "equivocal",
                "overall_result": "equivocal"
            }}

             **Response Format: MANDATORY JSON**
             {{
                 "overall_result": abnormal|normal|equivocal|null,
                 "rationale": "short reasoning to justify overall_result",
                 "supporting_text": "the snippet you matched"
             }}

            Respond ONLY with the specific JSON format. Do not generate or fabricate results to justify the outcome. 
             
            If there are medical terms like "nursing", "multidisciplinary notes", "cancerline" or any generic medical terminologies without any mentions of DNA mismatch repairs or any of the proteins flag as "null".  

            DO NOT MAKE ASSUMPTIONS ABOUT PROTEINS THAT ARE NOT MENTIONED. 
            
            IF NO PROTEINS IN THE KEYWORDS ARE FOUND IN THE CLINICAL NOTE SIMPLY STATE AS "null". 
            
            IF alpha-numeric strings are found with no relation to the keywords in THE CLINICAL NOTE simply state as "null".
            

        '''
    return [
        [
            {"role": "system", "content": default_sys_prompt},
            {"role": "user", "content": user_prompt},
        ],
    ]


# A later cell defines another function that is also named `construct_dialog`
# (the TNBC prompt) and replaces this one once it runs. Keep a stable handle on
# the MMR prompt builder so it stays reachable regardless of cell execution order.
construct_dialog_mmr = construct_dialog


# FOR TNBC

In [None]:
def construct_dialog(CDOC, PATIENT_PATIENT_IDENTIFIERKEY):
    """Build the chat payload asking the LLM to classify triple-negative status.

    The clinical note is embedded exactly once in the user prompt; the closing
    rules refer to it as "the clinical note" instead of repeating the full text.

    NOTE: this definition reuses the name of the MMR prompt builder defined in
    an earlier cell and replaces it once this cell runs.

    Args:
        CDOC: Clinical note text to classify.
        PATIENT_PATIENT_IDENTIFIERKEY: Not used when building the prompt; the
            (unusual) parameter name is kept unchanged for compatibility. The
            function is called positionally as ``construct_dialog(note, patient_id)``.

    Returns:
        A list containing a single dialog. The dialog is a list of two message
        dicts (keys "role" and "content"): the system message built from the
        module-level ``default_sys_prompt`` and the user message with the task,
        the note and the classification rules.
    """
    user_prompt = f''' 
         
         Task: Detect whether patients have negative expression of breast cancer markers ER, PR, and HER2 from the clinical notes.
         HER2 may also be referred to as "c-erb-b2" or "cerbb2".
         
         Clinical Note: {CDOC}
         
         **KEY PRINCIPLE**
         1. LOOKOUT FOR THE FOLLOWING KEYWORDS:
             - ER = Negative
             - PR = Negative
             - HER2 = Negative
             - c-erb-b2 = Negative
             - cerbb2 = Negative
        
         **CLASSIFICATION CRITERIA**
         1. If all 3 markers (ER, PR, HER2/c-erb-b2/cerbb2) are reported as **negative** in the same note:
             - Indicate overall_result as "triple_negative"
             
         2. If one or more of the 3 markers are **not found**, or are reported as **positive**, or **unknown**, and the rest are negative:
             - Indicate overall_result as "not_triple_negative"
             
         3. If there are no mentions of ER, PR, HER2, c-erb-b2, or cerbb2:
             - Indicate overall_result as "null"

             **Response Format: MANDATORY JSON**
             {{
                 "overall_result": triple_negative| not_triple_negative | null,
                 "rationale": "short reasoning to justify overall_result",
                 "supporting_text": "the snippet you matched"
             }}

            Respond ONLY with the specific JSON format. Do not generate or fabricate results to justify the outcome. 
             

            DO NOT MAKE ASSUMPTIONS ABOUT RECEPTORS THAT ARE NOT MENTIONED. 
            
            IF NO RECEPTORS ARE FOUND IN THE CLINICAL NOTE SIMPLY STATE AS "null". 
            
            IF alpha-numeric strings are found with no relation to the keywords in THE CLINICAL NOTE simply state as "null".
            

        '''
    return [
        [
            {"role": "system", "content": default_sys_prompt},
            {"role": "user", "content": user_prompt},
        ],
    ]


# Stable handle on the TNBC prompt builder, independent of whichever
# `construct_dialog` definition was executed last.
construct_dialog_tnbc = construct_dialog


In [None]:
def get_llama3_response(dialogs, llama3_url, top_p = 0.9, temperature = 0.01,
                        max_seq_len = 8000, max_gen_len = 2000, timeout = 60):
    """POST a batch of dialogs to the Llama 3 service and return the raw response body.

    Args:
        dialogs: JSON-serialisable list of dialogs (each a list of message dicts).
        llama3_url: Endpoint as "host:port" (no scheme); the request goes to
            the "/batch-dialogs" path over plain HTTP.
        top_p: Nucleus-sampling parameter forwarded to the service.
        temperature: Sampling temperature forwarded to the service.
        max_seq_len: Maximum sequence length forwarded to the service.
        max_gen_len: Maximum generation length forwarded to the service.
        timeout: Socket timeout in seconds for the HTTP connection.

    Returns:
        The response body decoded as UTF-8, or None when the request could not
        be built or sent, or when the service answered with a non-2xx status.
        Errors are printed, never raised.
    """
    connection = None
    try:
        # The constructor validates "host:port" and raises on a malformed value;
        # the TCP connection itself is only opened by request().
        connection = http.client.HTTPConnection(llama3_url, timeout = timeout)
        payload = {
            "dialogs": dialogs,
            "top_p": top_p,
            "temperature": temperature,
            "max_seq_len": max_seq_len,
            "max_gen_len": max_gen_len,
        }
        connection.request(
            "POST",
            "/batch-dialogs",
            json.dumps(payload),
            {"Content-Type": "application/json"},
        )
        response = connection.getresponse()
        data = response.read().decode("utf-8")
        print(f"Response status: {response.status}")
        if not 200 <= response.status < 300:
            # An error page is not a model answer: do not hand it to the parser.
            print(f"Request failed: HTTP {response.status} {response.reason}")
            return None
        return data
    except Exception as e:
        print(f"Request failed: {str(e)}")
        return None
    finally:
        if connection is not None:
            connection.close()
    
def _clean_field_value(raw):
    """Strip whitespace, a trailing comma and surrounding double quotes from a raw value."""
    # The comma must go before the quotes: stripping quotes first leaves the
    # closing quote behind on every line that ends with '",'.
    return raw.strip().rstrip(',').strip().strip('"')


def _fields_from_json(content, field_names):
    """Return the fields from the outermost JSON object in `content`, or None if unusable."""
    start, end = content.find('{'), content.rfind('}')
    if start == -1 or end <= start:
        return None
    try:
        candidate = json.loads(content[start:end + 1])
    except ValueError:
        return None
    if not isinstance(candidate, dict) or not any(name in candidate for name in field_names):
        return None
    fields = {}
    for name in field_names:
        value = candidate.get(name, '')
        # A JSON null is reported as the text "null", matching what the
        # line-based parser yields for an unquoted null.
        fields[name] = 'null' if value is None else str(value)
    return fields


def _fields_from_lines(content, field_names):
    """Line-by-line fallback for answers that are not valid JSON (e.g. unquoted values)."""
    fields = dict.fromkeys(field_names, '')
    for line in content.split('\n'):
        for name in field_names:
            marker = f'"{name}":'
            if marker in line:
                # Split on the key itself so colons inside the value are preserved.
                fields[name] = _clean_field_value(line.split(marker, 1)[1])
                break
    return fields


def parse_response(data):
    """Extract the classification fields from the raw body returned by the LLM service.

    The body is expected to be a JSON list whose first item is a dict with a
    'generation' dict holding the model text under 'content'. That text is
    parsed as a JSON object when possible, otherwise scanned line by line.

    Args:
        data: Raw response body (str), or None when the request failed.

    Returns:
        A dict with the string keys 'overall_result', 'rationale' and
        'supporting_text' (missing fields are ''), or None when `data` is None,
        is not valid JSON, or does not have the expected structure.
    """
    if data is None:
        print("No data to parse")
        return None

    field_names = ('overall_result', 'rationale', 'supporting_text')
    try:
        parsed_data = json.loads(data)

        if isinstance(parsed_data, list) and len(parsed_data) > 0:
            first_item = parsed_data[0]
            print(f"First item: {first_item}")

            generation = first_item.get('generation') if isinstance(first_item, dict) else None
            if isinstance(generation, dict):
                content = generation.get('content', '')
                if isinstance(content, str):
                    content_dict = _fields_from_json(content, field_names)
                    if content_dict is None:
                        content_dict = _fields_from_lines(content, field_names)
                    print(f"Parsed content: {content_dict}")
                    return content_dict

        print(f"Unhandled response structure: {parsed_data}")
        return None

    except Exception as e:
        print(f"Error in parse_response: {str(e)}")
        return None
    
def _record_row_result(df, idx, llama3_url, dialog_builder):
    """Classify the note in row `idx` and write the three result columns in place."""
    try:
        # Plain scalars: the prompt must receive the note text itself, not a 1-tuple.
        note = df.at[idx, 'CDOC']
        patient_id = df.at[idx, 'PATIENT_IDENTIFIER']
        print(f"\n Processing row {idx} (Patient key: {patient_id})")

        dialog = dialog_builder(note, patient_id)
        print(f"Sending request for row {idx} to LLM... ")
        result = parse_response(get_llama3_response(dialog, llama3_url))

        if result:
            print(f"Successfully processed row {idx}")
            values = (
                str(result['overall_result']),
                str(result['rationale']),
                str(result['supporting_text']),
            )
        else:
            print(f"No Valid result for row {idx}")
            values = ('', 'error: No Valid result', '')
    except Exception as processing_error:
        print(f"Error processing row {idx}: {processing_error}")
        values = ('', f'error: {processing_error}', '')

    for column, value in zip(('overall_result', 'rationale', 'supporting_text'), values):
        df.at[idx, column] = value


def _save_results(df, save_path, label):
    """Write `df` to `save_path` as CSV; report (but do not raise) any failure."""
    try:
        print(f"Saving {label} to {save_path}")
        df.to_csv(save_path, index = False)
        print(f"Successfully saved {label} to {save_path}")
    except Exception as e:
        print(f"Error saving {label}: {e}")


def process_in_batches(df, llama3_url, batch_size = 100, checkpoint_interval = 5000, start_row = 0,
                       dialog_builder = None, output_prefix = 'oc_final_count_v2', request_delay = 0.1):
    """Classify every note in `df` with the LLM, checkpointing to CSV along the way.

    Rows are processed one request at a time, from `start_row` to the end. The
    input frame is not modified: a re-indexed copy is filled in and returned.

    Args:
        df: Frame with at least the 'CDOC' and 'PATIENT_IDENTIFIER' columns. When
            resuming, it may already hold the result columns.
        llama3_url: Endpoint passed through to get_llama3_response.
        batch_size: Number of rows per progress-bar step.
        checkpoint_interval: Save a checkpoint each time this many further rows
            (counted from `start_row`) have been processed. Values <= 0 disable
            checkpoints.
        start_row: Position of the first row to process.
        dialog_builder: Callable ``(note, patient_id) -> dialogs``. Defaults to
            the current module-level ``construct_dialog``; pass it explicitly
            when several prompt builders exist.
        output_prefix: Path prefix of the CSV files; checkpoints are written to
            ``{output_prefix}_{rows_done}.csv`` and the final result to
            ``{output_prefix}_{total_rows}.csv``.
        request_delay: Seconds to sleep after each row.

    Returns:
        The processed frame with a fresh 0..n-1 index and the string columns
        'overall_result', 'rationale' and 'supporting_text'. Rows that failed
        carry '' as overall_result and an 'error: ...' rationale.
    """
    print("Starting process in batches function")
    if dialog_builder is None:
        dialog_builder = construct_dialog

    # Work on a re-indexed copy so row labels equal positions and the caller's
    # frame (often a filtered slice) is never written to.
    df = df.reset_index(drop = True)
    for column in ('overall_result', 'rationale', 'supporting_text'):
        if column in df.columns:
            df[column] = df[column].astype('string')
        else:
            df[column] = pd.Series(pd.NA, index = df.index, dtype = 'string')

    total_rows = len(df)
    print(f"Starting processing from row {start_row} out of {total_rows} total rows")
    print("starting main loop")

    # Checkpoints are counted from start_row rather than tested with a modulo on
    # the batch start, so they also fire when start_row is not a multiple of the
    # interval (the usual situation when resuming).
    next_checkpoint = start_row + checkpoint_interval

    for start_idx in tqdm(range(start_row, total_rows, batch_size)):
        print(f"\nStarting batch at index {start_idx}")
        end_idx = min(start_idx + batch_size, total_rows)

        for idx in range(start_idx, end_idx):
            _record_row_result(df, idx, llama3_url, dialog_builder)
            time.sleep(request_delay)

        # The last batch is covered by the final save below.
        if checkpoint_interval > 0 and next_checkpoint <= end_idx < total_rows:
            _save_results(df, f'{output_prefix}_{end_idx}.csv', 'checkpoint')
            while next_checkpoint <= end_idx:
                next_checkpoint += checkpoint_interval

    print(f"Saving final results at row {total_rows}")
    _save_results(df, f'{output_prefix}_{total_rows}.csv', 'final results')

    return df

    
def resume_from_checkpoint(checkpoint_path, new_df):
    """Load a checkpoint CSV and work out the row at which processing should resume.

    A row counts as processed when any of its result columns is non-null.
    'rationale' and 'supporting_text' are considered as well as
    'overall_result' because pandas' default CSV parsing reads the text "null"
    (a legitimate classification) and the empty string back as missing values.

    Args:
        checkpoint_path: Path of a CSV previously written by process_in_batches.
        new_df: Frame to start from when no usable checkpoint exists.

    Returns:
        ``(frame, start_row)``: the checkpoint frame and the position after the
        last processed row, or ``(new_df, 0)`` when the checkpoint cannot be
        read, lacks the result columns, or holds no processed row.
    """
    result_columns = ('overall_result', 'rationale', 'supporting_text')
    try:
        checkpoint = pd.read_csv(
            checkpoint_path,
            dtype = {column: 'string' for column in result_columns},
        )

        present = [column for column in result_columns if column in checkpoint.columns]
        if not present:
            print(f"Error loading checkpoint: no result columns in {checkpoint_path}")
            return new_df, 0

        processed = checkpoint[present].notna().any(axis = 1)
        if not processed.any():
            return new_df, 0

        # read_csv yields a 0..n-1 index, so the last label is also a position.
        last_processed = int(processed[processed].index[-1])
        print(f"Resuming from row {last_processed + 1}")
        return checkpoint, last_processed + 1
    except Exception as e:
        # Report the caught exception; the previous handler referenced an
        # undefined name and raised NameError instead of falling back.
        print(f"Error loading checkpoint: {e}")
        return new_df, 0
    

In [None]:
# Classify the keyword-positive OC notes from the first row.
# NOTE(review): `construct_dialog` is defined twice above (MMR prompt, then TNBC
# prompt); this call uses whichever definition was executed last, so make sure
# the MMR cell was the last one run before starting an OC run.
result_df = process_in_batches(
    pos_oc,
    llama3_url,
    batch_size = 100,
    checkpoint_interval = 40000,
)

In [None]:
# Resume the BC (TNBC) run from its saved checkpoint; falls back to pos_bc from
# row 0 when the checkpoint cannot be used.
# NOTE(review): process_in_batches writes its files under the 'oc_final_count_v2_'
# prefix, while this checkpoint is named 'tnbc_final_count_...' — confirm that the
# file names line up. This cell also overwrites `result_df` from the OC run above.
df, start_row = resume_from_checkpoint('tnbc_final_count_40000.csv', pos_bc)

result_df = process_in_batches(
    df,
    llama3_url,
    batch_size = 100,
    checkpoint_interval = 20000,
    start_row = start_row,
)