In [None]:
import re

In [1]:
import numpy as np
import pandas as pd
import os
import time
import sys
import matplotlib.pyplot as plt
from google.colab import auth, drive


In [None]:
# Authenticate the current Colab user.
# NOTE(review): presumably this is what lets the Vertex AI / Cloud Storage
# clients used later in the notebook pick up credentials - confirm.
auth.authenticate_user()

# Mount Google Drive at /content/drive; the CSV paths read and written later
# in this notebook live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import vertexai
from vertexai.generative_models import (
    GenerationConfig,
    GenerativeModel,
    HarmBlockThreshold,
    HarmCategory,
    Image,
    Part,
    SafetySetting,
)
from vertexai.batch_prediction import BatchPredictionJob
import json
from google.cloud import storage

# Replace with your project ID from Google Cloud Platform.
PROJECT_ID = "mit-mlhc-v2"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    # Fall back to the environment. The value is deliberately NOT wrapped in
    # str(): str(None) is the truthy string "None", which would be handed to
    # vertexai.init() as if it were a real project ID.
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
    if not PROJECT_ID:
        raise ValueError(
            "No Google Cloud project configured: set PROJECT_ID above or the "
            "GOOGLE_CLOUD_PROJECT environment variable."
        )

# Region for Vertex AI; defaults to us-central1 when GOOGLE_CLOUD_REGION is unset.
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

vertexai.init(project=PROJECT_ID, location=LOCATION)

In [None]:
def extract_text(response):
  """Return the text of the first part of the first candidate in `response`.

  When the expected nesting is absent (a missing key, an empty list, or a
  value that cannot be subscripted), the response is printed for inspection
  and an empty string is returned instead.
  """
  try:
    first_candidate = response['candidates'][0]
    first_part = first_candidate['content']['parts'][0]
    text = first_part['text']
  except (KeyError, IndexError, TypeError):
    # Show the malformed response, then fall back to the empty-string default.
    print(response)
    return ''
  return text

def concatenate_notes(group, text_col='masked_regex'):
  """Join every note in `group` into one string.

  Each row contributes "<note_type_long>\\n<text>", and rows are separated by
  a blank line, in the row order of `group`.

  Args:
    group: DataFrame with a 'note_type_long' column and a `text_col` column.
    text_col: Name of the column holding the note text. Defaults to
      'masked_regex' so existing single-argument callers are unaffected.

  Returns:
    The concatenated string ('' when `group` has no rows).
  """
  # zip over the two columns avoids building a Series per row (iterrows).
  return '\n\n'.join(
      f"{note_type}\n{note_text}"
      for note_type, note_text in zip(group['note_type_long'], group[text_col])
  )

def concatenate_notes2(group):
  """Same as `concatenate_notes`, but reads the note text from the 'gemini' column."""
  return concatenate_notes(group, text_col='gemini')

# Data investigation

In [None]:
# Read the final cohort table from Drive and order its rows by subject_id.
fn = '/content/drive/MyDrive/HSPH/Courses/MIT6.7930/AI Bias for AD/AI Bias AD/data/data_v3-rematched-USE_THIS/11_final_cohort_alldata.csv.gz'
df = pd.read_csv(fn).sort_values(by='subject_id')

In [None]:
# Row counts per subject, one column per case_status value (0 where absent).
temp = (
    df.groupby(['subject_id', 'case_status'])
      .size()
      .unstack(fill_value=0)
)
# Subjects with at least one row under status 1 (cases) / status 0 (controls).
cases = temp.loc[temp[1] > 0]
ctrls = temp.loc[temp[0] > 0]
for cohort in (cases, ctrls):
    print(len(cohort))

In [None]:
# Summary statistics of the per-subject row counts with case_status == 1.
cases[1].describe()

In [None]:
# Summary statistics of the per-subject row counts with case_status == 0.
ctrls[0].describe()

In [None]:
# print(df.loc[16, 'text'])

# Mask AD mentions in clinical notes

## Regex replace

In [None]:
# Compiled once rather than on every call. A "phrase" is the run of text
# between two delimiters (start of string, ',' or '.'); any phrase containing
# one of the keywords is matched as a whole. The keyword is followed by \w*
# (not \b) so suffixed forms such as "Alzheimers" or "dementias" are masked too.
_AD_PHRASE_PATTERN = re.compile(
  r'(?:(?<=^)|(?<=,)|(?<=\.))\s*[^,.]*?'
  r'\b(?:Alzheimer|dementia|Donepezil|Memantine|Rivastigmine|Galantamine)\w*'
  r'[^,.]*?(?=[,.]|$)',
  flags=re.I,
)


def remove_phrases(string):
  """Strip comma/period-delimited phrases that mention AD, dementia or AD drugs.

  Args:
    string: Note text. Non-string values (e.g. None or NaN from a missing
      note) are treated as an empty note.

  Returns:
    The text with every matching phrase removed and the leftover punctuation
    tidied up; '' for non-string input.
  """
  if not isinstance(string, str):
    return ''
  stripped = _AD_PHRASE_PATTERN.sub('', string).strip(' ,.')
  # Collapse delimiter runs left behind where phrases were removed.
  cleaned = re.sub(r',+', ',', stripped).strip(', ').strip()
  cleaned = re.sub(r'\.+', '.', cleaned)
  # NOTE(review): a leftover ",." pair is dropped entirely (not replaced by
  # "."), which joins the neighbouring sentences - confirm this is intended.
  cleaned = re.sub(r'\,\.', '', cleaned)
  return cleaned

In [None]:
# Add a 'masked_regex' column holding each row's 'text' after remove_phrases
# has been applied. This writes onto the shared `df`, which later cells read.
df['masked_regex'] = df['text'].apply(remove_phrases)
print(df.shape)
# df

In [None]:
# Admission-level key "<subject_id>_<hadm_id>" (hadm_id is cast to int first).
df['id'] = df['subject_id'].astype(str) + '_' + df['hadm_id'].astype(int).astype(str)

# Notes ordered by subject and chart time, reduced to the columns used below.
df_notes = df.sort_values(by=['subject_id', 'charttime'])[['note_id', 'id', 'note_type', 'masked_regex']]

# 'DS' notes get the discharge-summary header; every other type is labelled radiology.
note_headers = {'DS': 'Discharge Summary:'}
df_notes['note_type_long'] = [note_headers.get(code, 'Radiology Note:') for code in df_notes['note_type']]

# One concatenated string of regex-masked notes per admission id.
notes_by_admission = df_notes.groupby('id')
df_concat = notes_by_admission.apply(concatenate_notes).reset_index(name='concatenated_notes')
df_concat

In [None]:
# Inspect the columns available on `df` before picking the admission-level
# subset. This cell used to read `df_clean.columns`, but `df_clean` is only
# created further down, so a fresh Restart & Run All raised NameError here.
df.columns

In [None]:
# Columns kept in the de-duplicated table (the note-level columns are left out).
cohort_columns = [
    'id', 'admityear', 'admitmonth', 'admitday', 'gender', 'age', 'admission_type',
    'marital_status', 'race', 'insurance_group', 'language_group',
    'race_group1', 'race_group2', 'race_group3', 'race_group4',
    'adrd', 'ad', 'case_status', 'adrd_status',
    'Stroke_History', 'Myocardial_Infarction', 'Peripheral_Vascular_Disease',
    'Cerebrovascular_Disease', 'Diabetes_Mellitus', 'Cancer',
]
df_clean = df[cohort_columns].drop_duplicates()
print(df_clean.shape)

In [None]:
# Attach the concatenated regex-masked notes to the de-duplicated rows by 'id'.
df_regex = df_clean.merge(df_concat, on='id', how='inner')
print(df_regex.shape)

# Persist the merged table to Drive, then show the case/control balance.
fn = '/content/drive/MyDrive/HSPH/Courses/MIT6.7930/AI Bias for AD/AI Bias AD/Gemini Prediction Model/12_regex_masked_concat.csv'
df_regex.to_csv(fn, index=False, header=True)
df_regex['case_status'].value_counts()

## Gemini Masking

In [None]:
def df_to_jsonl_gcs(df, bucket_name, blob_name, max_chars=50000,
                    temperature=0.4, max_output_tokens=4096):
    """Builds one Gemini batch request per row and uploads them to GCS as JSONL.

    Args:
        df: Pandas DataFrame with a 'note_id' column (used as the request id)
            and a 'text' column (the note inserted into the prompt).
        bucket_name: Name of your Google Cloud Storage bucket. The bucket is
            created in location 'US' if it does not exist yet.
        blob_name: Desired name for the JSONL file on GCS.
        max_chars: Each note is truncated to this many characters before it
            is placed in the prompt.
        temperature: Value written to generationConfig.temperature.
        max_output_tokens: Value written to generationConfig.maxOutputTokens.

    Returns:
        The gs:// URI of the uploaded JSONL file.

    Raises:
        KeyError: If `df` lacks the 'note_id' or 'text' column. Raised before
            any bucket is looked up or created.
    """
    # Fail before any GCS side effect if a required column is missing.
    missing = [col for col in ('note_id', 'text') if col not in df.columns]
    if missing:
        raise KeyError(f"DataFrame is missing required column(s): {missing}")

    # Initialize a GCS client
    storage_client = storage.Client(project=PROJECT_ID)
    bucket = storage_client.bucket(bucket_name)
    if not bucket.exists():
        bucket.create(location='US')
        print(f'Bucket {bucket_name} created.')
    else:
        print(f'Bucket {bucket_name} already exists.')

    blob = bucket.blob(blob_name)

    # Collect one JSON line per row and join once at the end, instead of
    # growing a single string inside the loop. tolist() yields plain Python
    # scalars, which json.dumps can serialize.
    jsonl_lines = []
    for note_id, note_text in zip(df['note_id'].tolist(), df['text'].tolist()):
        text = note_text[:max_chars]

        ### TO DO: Provide Gemini a prompt ##############################
        # Edit the prompt to tell Gemini how to handle your input text.
        # The leading whitespace of the continuation lines is part of the
        # prompt string and is left exactly as it was.
        prompt = f"""You are a clinical documentation specialist assisting in a medical research project aiming to predict Alzheimer's Disease from clinical notes. Please carefully review the following clinical notes and remove phrases with explicit diagnosis or direct mention Alzheimer's Disease (AD), dementia (only if directly diagnosed, general symptoms should be preserved), and medications specifically prescribed for Alzheimer's Disease (e.g., Donepezil, Memantine, Rivastigmine, Galantamine).

        A phrase is defined as any series of words contained within two commas (,), periods (.), parentheses (() or []), or single line (indicated by \n).

        Do not remove symptoms that could indicate cognitive decline, memory loss, confusion, disorientation, difficulty with language, executive dysfunction, or behavior changes.
        Do not remove general clinical observations that might hing at early-stage Alzheimer's or other neurological impairments.
        Clinical notes: {text}"""
        #################################################################

        json_data = {
            "id": note_id,
            "request": {
                "contents": [
                    {
                        "role": "user",
                        "parts": [{"text": prompt}]
                    }
                ],
                "generationConfig": {"temperature": temperature, "maxOutputTokens": max_output_tokens},

            }
        }
        jsonl_lines.append(json.dumps(json_data) + '\n')

    # Upload the JSONL data to GCS
    blob.upload_from_string(''.join(jsonl_lines), content_type='application/jsonl')
    print(f"JSONL file uploaded to gs://{bucket_name}/{blob_name}")

    return f"gs://{bucket_name}/{blob_name}"

### TO DO: Change bucket name ##############################

# Build the request file from every row of `df` and upload it to the bucket.
# df_to_jsonl_gcs reads the raw 'text' column, not 'masked_regex'.
BUCKET_NAME = 'project_masking'
input_uri = df_to_jsonl_gcs(df, BUCKET_NAME, 'gemini_batch_requests.jsonl')

#################################################################

# Prefix under which the batch job is asked to write its results.
output_uri = f"gs://{BUCKET_NAME}/batch-prediction/"

# Submit a batch prediction job with Gemini model
batch_prediction_job = BatchPredictionJob.submit(
    source_model="gemini-1.5-flash-001",
    input_dataset=input_uri,
    output_uri_prefix=output_uri,
)

# Check job status
print(f"Job resource name: {batch_prediction_job.resource_name}")
print(f"Model resource name with the job: {batch_prediction_job.model_name}")
print(f"Job state: {batch_prediction_job.state.name}")

# Poll every 5 seconds until the job reports that it has ended.
while not batch_prediction_job.has_ended:
    time.sleep(5)
    batch_prediction_job.refresh()

# Report the outcome. NOTE(review): a failed job is only printed, not raised,
# so execution continues with the output-location print below.
if batch_prediction_job.has_succeeded:
    print("Job succeeded!")
else:
    print(f"Job failed: {batch_prediction_job.error}")

# Check the location of the output
print(f"Job output location: {batch_prediction_job.output_location}")

# Example response:
#  Job output location: gs://your-bucket/gen-ai-batch-prediction/prediction-model-year-month-day-hour:minute:second.12345

In [None]:
# Read the batch-prediction results (JSONL) back from Cloud Storage.
# `path` is the output folder of a finished job; the predictions file sits
# directly underneath it.
path = 'gs://project_masking/batch-prediction/prediction-model-2025-05-01T15:12:34.600818Z'
output_path = path + '/predictions.jsonl'

masked_df = pd.read_json(output_path, lines=True)
candidate_fields = pd.json_normalize(masked_df["response"], "candidates")
masked_df = masked_df.join(candidate_fields)
print(masked_df.shape)

# Note some inputs may not generate predictions due to SAFETY constraints;
# extract_text returns '' for those and they are dropped here.
masked_df['summary'] = masked_df['response'].apply(extract_text)
has_text = masked_df['summary'] != ''
masked_df = masked_df.loc[has_text, ['id', 'summary']]
masked_df.columns = ['note_id', 'gemini']
masked_df.shape

In [None]:
# Note-level keys, joined to the Gemini-masked text by note_id and ordered by
# subject and chart time.
note_keys = df[['note_id', 'id', 'note_type', 'subject_id', 'charttime']]
df_notes = note_keys.merge(masked_df, on='note_id', how='inner').sort_values(by=['subject_id', 'charttime'])

# 'DS' notes get the discharge-summary header; every other type is labelled radiology.
section_titles = {'DS': 'Discharge Summary:'}
df_notes['note_type_long'] = [section_titles.get(kind, 'Radiology Note:') for kind in df_notes['note_type']]

# One concatenated string of Gemini-masked notes per admission id.
gemini_by_admission = df_notes.groupby('id')
df_concat = gemini_by_admission.apply(concatenate_notes2).reset_index(name='concatenated_notes')
df_concat

In [None]:
# Attach the concatenated Gemini-masked notes to the de-duplicated rows by 'id'.
df_gemini = df_clean.merge(df_concat, on='id', how='inner')
print(df_gemini.shape)

# Persist the merged table to Drive, then show the case/control balance.
fn = '/content/drive/MyDrive/HSPH/Courses/MIT6.7930/AI Bias for AD/AI Bias AD/Gemini Prediction Model/12_gemini_masked_concat.csv'
df_gemini.to_csv(fn, index=False, header=True)
df_gemini['case_status'].value_counts()