## **Project Information**
## **Authors: Shrey Patel (sp2675), Abhishek Jani (aj1121), Mustafa Adil (ma2398)**
- Note:
This code was developed and tested using Google Colab. For best results, and to ensure all dependencies are handled correctly, it is recommended to upload the corresponding .ipynb file to Google Colab and run the cells sequentially.



# **Setup and Installs**

In [None]:
# !pip install --upgrade numpy==1.26.4 # Specific numpy version if required by dependencies
# !pip install --upgrade thinc==8.3.5 # Specific thinc version if required by dependencies
!pip install --upgrade pandas scikit-learn transformers datasets textattack requests tqdm

import pandas as pd
import numpy as np
import torch
import re
import json
import os
import time
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm # for progress bars

# Import components from libraries
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    AutoConfig
)
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from textattack.models.wrappers import HuggingFaceModelWrapper # For TextAttack integration
tqdm.pandas()
print("libraries imported")

# Choose the compute device once for the whole notebook: CUDA when torch
# reports a usable GPU, otherwise the CPU.
gpu_available = torch.cuda.is_available()
print("gpu is available" if gpu_available else "No gpu, using cpu")
DEVICE = torch.device("cuda" if gpu_available else "cpu")

# **Configuration**

```
Note: Seeding for reproducibility is often done here, but currently commented out.

Uncomment these lines if you need strict reproducibility across runs.

np.random.seed(CONFIG["random_seed"])
torch.manual_seed(CONFIG["random_seed"])
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(CONFIG["random_seed"])
```

In [2]:
# Central configuration for the notebook: every path, model setting,
# training hyperparameter and output filename read by the cells below.
CONFIG = {
    # dataset: input parquet files and intermediate artefacts
    "classification_data_path": "hf://datasets/ade-benchmark-corpus/ade_corpus_v2/Ade_corpus_v2_classification/train-00000-of-00001.parquet",
    "relation_data_path": "hf://datasets/ade-benchmark-corpus/ade_corpus_v2/Ade_corpus_v2_drug_ade_relation/train-00000-of-00001.parquet",
    "extracted_drugs_file": "extracted_drug_names.json",
    "brand_mapping_file": "full_drug_brand_mapping.json",
    "modified_data_file": "classification_data_with_brands.parquet",
    # model: base checkpoint, where the fine-tuned copy is saved, tokenizer limits
    "base_model_name": "dmis-lab/biobert-base-cased-v1.1",
    "finetuned_model_dir": "./trained_model",
    "tokenizer_max_length": 128,
    "num_labels": 2,

    # training: values forwarded to TrainingArguments in train_model
    "training_output_dir": "./results_quick",
    "num_train_epochs": 10,
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "save_strategy_finetune": "epoch",
    "logging_steps": 200,
    "report_to": "none",
    "dataloader_num_workers": 2,
    "fp16_training": torch.cuda.is_available(),  # mixed precision only when a GPU is present

    # analysis: label value treated as "ADE", train/test split and seed
    "ade_label": 1,
    "test_split_size": 0.2,
    "random_seed": 42,
    # api: pause after each RxNorm request and per-request timeout
    "api_delay_seconds": 0.2,
    "api_timeout_seconds": 10,
    # output files: flip / non-flip example dumps written by the analysis cells
    "masking_flip_ade_file": "masking_flip_examples_ade_only.json",
    "masking_nonflip_ade_file": "masking_non_flip_examples_ade_only.json",
    "masking_flip_non_ade_file": "output/masking_non_ade_flips.json",
    "masking_nonflip_non_ade_file": "output/masking_non_ade_nonflips.json",
    "brand_flip_all_file": "brand_flip_examples.json",
    "brand_nonflip_all_file": "brand_nonflip_examples.json",
    "brand_flip_ade_file": "brand_flip_examples_ade_only.json",
    "brand_nonflip_ade_file": "brand_nonflip_examples_ade_only.json",
}

# np.random.seed(CONFIG["random_seed"])
# torch.manual_seed(CONFIG["random_seed"])
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(CONFIG["random_seed"])

# **Utility Functions**

In [3]:
def predict_label(text, model_wrapper):
    """Return the predicted class index for a single piece of text.

    Args:
        text: Input to classify. Non-string values are converted with ``str``.
        model_wrapper: Callable that takes a list of texts and returns one
            row of logits per text (a sequence, NumPy array or torch tensor).

    Returns:
        The arg-max class index as a plain Python ``int``. The default ``0``
        is returned for blank input, non-finite logits, or any error raised
        while predicting (the error is printed, not re-raised).

    Raises:
        ValueError: If ``model_wrapper`` is falsy.
    """
    if not model_wrapper:
        raise ValueError("Model wrapper is not available.")
    try:
        if not isinstance(text, str):
            text = str(text)
        if not text.strip():
            return 0  # default for empty or whitespace-only input
        logits = model_wrapper([text])[0]
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().detach().numpy()
        # Guard against NaN/inf logits coming out of the model.
        if not np.all(np.isfinite(logits)):
            print(f"Warning: Non-finite logits encountered for text: '{text[:50]}...'. Returning default (0).")
            return 0
        # Convert to a built-in int so every return path has the same type
        # (np.argmax alone yields a NumPy integer).
        return int(np.argmax(logits))
    except Exception as e:
        # Best-effort by design: one bad sentence must not abort a long run.
        print(f"Error during prediction for text: '{text[:50]}...': {e}")
        return 0

In [4]:
def save_json(data, filename):
    """Write ``data`` to ``filename`` as indented UTF-8 JSON.

    The parent directory is created when missing. Any failure is printed
    rather than raised, so callers are never interrupted by a save error.

    Args:
        data: JSON-serializable object to store.
        filename: Destination path.
    """
    try:
        # Serialize first: if ``data`` is not serializable we fail before
        # opening the file, so no truncated/partial file is left behind.
        payload = json.dumps(data, indent=4, ensure_ascii=False)
        output_dir = os.path.dirname(filename)
        if output_dir:
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(output_dir, exist_ok=True)
        with open(filename, "w", encoding='utf-8') as f:
            f.write(payload)
    except Exception as e:
        print(f"Had an error saving data to {filename}: {e}")

In [5]:
def load_json(filename):
    """Read and parse a UTF-8 JSON file.

    Returns the parsed content, or ``None`` (after printing a message) when
    the file does not exist, is not valid JSON, or cannot be read.
    """
    if os.path.exists(filename):
        try:
            with open(filename, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except json.JSONDecodeError as decode_err:
            print(f"error decoding {filename}: {decode_err}")
        except Exception as read_err:
            print(f"erroe reading {filename}: {read_err}")
        return None

    print(f"file not found {filename}")
    return None

In [6]:
def check_and_create_dir(directory_path):
    """Create ``directory_path`` (including parents) if it does not exist.

    An empty or ``None`` path is ignored; an existing path is left untouched.
    """
    if not directory_path:
        return
    if os.path.exists(directory_path):
        return
    print(f"Creating directory {directory_path}")
    os.makedirs(directory_path)

# **Data Loading and Preprocessing Functions**

In [7]:
def load_parquet_data(path):
    """Load a parquet file into a DataFrame.

    Returns the DataFrame on success; on any failure the error is printed
    and ``None`` is returned.
    """
    try:
        frame = pd.read_parquet(path)
        row_count = len(frame)
        print(f"loaded {row_count} rows from {path}")
        return frame
    except Exception as load_err:
        print(f"Error loading parquet file from {path}: {load_err}")
        return None

In [8]:
def extract_drug_names(df_relation, drug_col='drug'):
    """Collect the unique drug names found in one DataFrame column.

    Each non-null cell is split on commas; every piece is stripped and
    lower-cased, and pieces shorter than two characters are discarded.

    Args:
        df_relation: DataFrame holding the drug column (may be ``None``).
        drug_col: Name of the column to read.

    Returns:
        Alphabetically sorted list of unique names, or ``[]`` when the
        DataFrame is ``None`` or lacks ``drug_col``.
    """
    if df_relation is None or drug_col not in df_relation.columns:
        print(f"Error in drug data '{drug_col}'.")
        return []

    cleaned_pieces = (
        piece.strip().lower()
        for cell in df_relation[drug_col].dropna()
        for piece in str(cell).split(',')
    )
    # Single-character leftovers are treated as noise, not drug names.
    names = sorted({piece for piece in cleaned_pieces if len(piece) > 1})
    print(f"there are {len(names)} unique drug names")
    return names

In [9]:
def tokenize_function_hf(examples, tokenizer, max_length, text_col="text"):
    """Tokenize ``examples[text_col]``, truncating and padding to ``max_length``.

    Returns whatever ``tokenizer`` returns for the call.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": max_length,
    }
    return tokenizer(examples[text_col], **tokenizer_options)

In [10]:
def prepare_hf_dataset(df, tokenizer, config, text_col="text", label_col="label"):
    """Build tokenized train/test datasets from a DataFrame.

    The text and label columns are turned into a ``Dataset``, split using
    ``config["test_split_size"]`` and ``config["random_seed"]``, tokenized
    with ``tokenize_function_hf``, formatted as torch tensors, and the label
    column is renamed to ``"labels"``.

    Returns:
        ``(train_dataset, test_dataset)``, or ``(None, None)`` when the input
        is unusable or formatting fails.
    """
    inputs_ok = (
        df is not None
        and text_col in df.columns
        and label_col in df.columns
    )
    if not inputs_ok:
        print("Error in dataset creation")
        return None, None  # one None per expected split

    full_dataset = Dataset.from_pandas(df[[text_col, label_col]])

    print("Splitting dataset")
    parts = full_dataset.train_test_split(
        test_size=config["test_split_size"],
        seed=config["random_seed"],
    )

    def _tokenize_batch(batch):
        # Same tokenization settings for both splits.
        return tokenize_function_hf(batch, tokenizer, config["tokenizer_max_length"], text_col)

    train_dataset = parts["train"].map(_tokenize_batch, batched=True)
    test_dataset = parts["test"].map(_tokenize_batch, batched=True)

    try:
        train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", label_col])
        test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", label_col])
        train_dataset = train_dataset.rename_column(label_col, "labels")
        test_dataset = test_dataset.rename_column(label_col, "labels")
    except Exception as e:
        print(f"error{e}")
        return None, None

    return train_dataset, test_dataset

# **Model Functions**

In [11]:
def load_tokenizer_and_model(model_name_or_path, num_labels):
    """Load a tokenizer and a sequence-classification model, moved to DEVICE.

    Args:
        model_name_or_path: Name or directory passed to ``from_pretrained``.
        num_labels: Number of output classes set on the model config.

    Returns:
        ``(tokenizer, model)``, or ``(None, None)`` if any step fails (the
        error is printed).
    """
    try:
        print(f"loading tokenizer from {model_name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

        print(f"loading model from {model_name_or_path}")
        classifier_config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        classifier = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, config=classifier_config
        )
        classifier.to(DEVICE)
        print("loaded successfully")
        return tokenizer, classifier
    except Exception as load_err:
        print(f"Error loading{model_name_or_path}: {load_err}")
        return None, None

In [12]:
def create_model_wrapper(model, tokenizer):
    """Wrap a model/tokenizer pair in a ``HuggingFaceModelWrapper``.

    Returns the wrapper, or ``None`` (after printing a message) when either
    argument is ``None`` or wrapping raises.
    """
    something_missing = model is None or tokenizer is None
    if something_missing:
        print("Error, no model or tokenizer")
        return None

    try:
        wrapped = HuggingFaceModelWrapper(model, tokenizer)
        print("Model is wrapped successfully.")
        return wrapped
    except Exception as wrap_err:
        print(f"Error in wrapping model: {wrap_err}")
        return None

In [13]:
def compute_metrics_hf(eval_pred):
    """Compute accuracy and macro-averaged F1/precision/recall.

    Args:
        eval_pred: ``(logits, labels)`` pair; predictions are the arg-max of
            ``logits`` over the last axis.

    Returns:
        Dict with keys ``"accuracy"``, ``"f1"``, ``"precision"``, ``"recall"``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # All three macro metrics share zero_division=0, so a class that is never
    # predicted scores 0 without F1 alone emitting an undefined-metric warning.
    macro_kwargs = {"average": "macro", "zero_division": 0}
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, **macro_kwargs),
        "precision": precision_score(labels, predictions, **macro_kwargs),
        "recall": recall_score(labels, predictions, **macro_kwargs),
    }

In [14]:
def train_model(model, tokenizer, train_dataset, eval_dataset, config):
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and save it.

    Hyperparameters are read from ``config``. When
    ``config["save_strategy_finetune"]`` is not ``"no"``, evaluation runs
    every epoch and the best checkpoint by ``"f1"`` is loaded at the end.
    After training, the evaluation metrics are printed and the model is
    saved to ``config["finetuned_model_dir"]``.

    Returns:
        The ``model`` argument, or ``None`` when any input is ``None``.
    """
    required_inputs = (model, tokenizer, train_dataset, eval_dataset)
    if any(item is None for item in required_inputs):
        print("Error in training")
        return None

    # Evaluation and best-model selection are only enabled with checkpointing.
    checkpointing = config["save_strategy_finetune"] != "no"

    training_args = TrainingArguments(
        output_dir=config["training_output_dir"],
        num_train_epochs=config["num_train_epochs"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_eval_batch_size"],
        save_strategy=config["save_strategy_finetune"],
        logging_steps=config["logging_steps"],
        report_to=config["report_to"],
        fp16=config["fp16_training"],
        dataloader_num_workers=config["dataloader_num_workers"],
        eval_strategy="epoch" if checkpointing else "no",
        logging_strategy="epoch",
        load_best_model_at_end=checkpointing,
        metric_for_best_model="f1" if checkpointing else None,
        seed=config["random_seed"],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_hf,
    )

    print("Starting model finetuning")
    trainer.train()
    print("finished model finetuning")

    print("test set evaluation")
    metrics = trainer.evaluate()
    print("Evaluation metrics:", metrics)

    target_dir = config["finetuned_model_dir"]
    check_and_create_dir(target_dir)
    print(f"Saving model and tokenizer to: {target_dir}")
    trainer.save_model(target_dir)

    print("Model training and saving complete.")
    return model

# **Analysis Functions**

In [15]:
# Compiled whole-word patterns keyed by drug name. This function is applied
# to every sentence, so each drug's regex is built once and then reused.
_DRUG_PATTERN_CACHE = {}


def _get_drug_pattern(drug):
    """Return the cached case-insensitive whole-word pattern for ``drug``."""
    pattern = _DRUG_PATTERN_CACHE.get(drug)
    if pattern is None:
        pattern = re.compile(r'\b' + re.escape(drug) + r'\b', flags=re.IGNORECASE)
        _DRUG_PATTERN_CACHE[drug] = pattern
    return pattern


def mask_drug_names(text, drug_names_list, mask_token='[MASK]'):
    """Replace every whole-word occurrence of a listed drug with ``mask_token``.

    Drugs are processed longest-first so a multi-word name is masked before
    any shorter name it contains. Matching is case-insensitive.

    Args:
        text: Sentence to mask. Non-string input is returned unchanged.
        drug_names_list: Drug names to look for; empty entries are skipped.
        mask_token: Text inserted literally in place of each match.

    Returns:
        ``(masked_text, found)`` where ``found`` is True when at least one
        substitution changed the text.
    """
    if not isinstance(text, str) or not drug_names_list:
        return text, False

    masked_text = text
    found = False
    for drug in sorted(drug_names_list, key=len, reverse=True):
        if not drug:
            continue
        # A callable replacement inserts mask_token verbatim; a plain string
        # would be parsed as a regex template (backslashes, group refs).
        candidate = _get_drug_pattern(drug).sub(lambda _match: mask_token, masked_text)
        if candidate != masked_text:
            found = True
            masked_text = candidate
    return masked_text, found

In [16]:
def prepare_brand_replacement(mapping_dict):
    """Precompute the lookup structures used for generic-to-brand replacement.

    Args:
        mapping_dict: Mapping of generic drug name to brand name.

    Returns:
        A 3-tuple ``(lower_mapping_dict, compiled_patterns, sorted_keys)``:
        the mapping with lower-cased keys, a dict from each lower-cased
        generic name to its case-insensitive whole-word regex, and the
        generic names sorted longest-first. For empty input the same shapes
        are returned empty: ``({}, {}, [])``.
    """
    if not mapping_dict:
        # Same (dict, dict, list) layout as the normal return below.
        return {}, {}, []

    # Lower-case the keys so matching is consistent regardless of input case.
    lower_mapping_dict = {k.lower(): v for k, v in mapping_dict.items()}
    compiled_patterns = {
        generic_lower: re.compile(r'\b' + re.escape(generic_lower) + r'\b', flags=re.IGNORECASE)
        for generic_lower in lower_mapping_dict
    }
    sorted_generic_keys = sorted(lower_mapping_dict, key=len, reverse=True)
    return lower_mapping_dict, compiled_patterns, sorted_generic_keys

In [17]:
def replace_generic_with_brand(text, lower_mapping_dict, compiled_patterns, sorted_keys):
    """Replace generic drug names in ``text`` with their brand names.

    Args:
        text: Sentence to rewrite. Non-string input is returned unchanged.
        lower_mapping_dict: Lower-cased generic name -> brand name.
        compiled_patterns: Lower-cased generic name -> compiled regex.
        sorted_keys: Generic names in the order they should be applied.

    Returns:
        ``(new_text, replaced)`` where ``replaced`` is True when at least one
        substitution changed the text.
    """
    if not isinstance(text, str) or not lower_mapping_dict:
        return text, False

    current_text = text
    replaced = False
    for generic_lower in sorted_keys:
        pattern = compiled_patterns.get(generic_lower)
        brand_name = lower_mapping_dict.get(generic_lower)
        if not (pattern and brand_name):
            continue
        # Brand names come from an external API; insert them verbatim. A
        # string replacement would be parsed as a regex template, so a
        # backslash in a name could raise or be silently rewritten.
        new_text = pattern.sub(lambda _match, brand=brand_name: brand, current_text)
        if new_text != current_text:
            replaced = True
            current_text = new_text
    return current_text, replaced

In [18]:
def calculate_flip_rate(df, text_col1, text_col2, model_wrapper, description="Processing"):
    """Measure how often the prediction differs between two text columns.

    For each row whose two texts are non-empty strings, both texts are
    classified with ``predict_label`` and the row is recorded as a flip when
    the predictions differ. A summary is printed at the end.

    Args:
        df: DataFrame containing both text columns (and optionally ``label``).
        text_col1: Column with the reference text.
        text_col2: Column with the perturbed text.
        model_wrapper: Model wrapper handed to ``predict_label``.
        description: Label used in the progress bar and printed summary.

    Returns:
        ``(flip_rate, flip_examples, non_flip_examples)``; ``(0.0, [], [])``
        when the inputs are missing or the DataFrame is empty.
    """
    if df is None or df.empty or model_wrapper is None:
        print(f"'{description}': invalid inputs.")
        return 0.0, [], []

    flipped = []
    unchanged = []
    scored_rows = 0

    print(f"\nCalculating flip rate for: {description}")
    for row_id, record in tqdm(df.iterrows(), total=len(df), desc=description):
        first_text = record[text_col1]
        second_text = record[text_col2]

        # Only rows with two non-empty strings are scored.
        usable = bool(
            isinstance(first_text, str)
            and isinstance(second_text, str)
            and first_text
            and second_text
        )
        if not usable:
            continue

        try:
            first_pred = predict_label(first_text, model_wrapper)
            second_pred = predict_label(second_text, model_wrapper)
            scored_rows += 1

            entry = {
                'index': row_id,
                text_col1: first_text,
                text_col2: second_text,
                'prediction1': int(first_pred),
                'prediction2': int(second_pred),
            }
            if 'label' in record:
                entry['original_label'] = int(record['label'])

            target = flipped if first_pred != second_pred else unchanged
            target.append(entry)
        except Exception as e:
            print(f"error in flip calculation for index {row_id} in {description}: {e}")

    flip_total = len(flipped)
    flip_rate = flip_total / scored_rows if scored_rows > 0 else 0.0
    print(f"Results ({description}) ---")
    print(f"Total sentences processed: {scored_rows}")
    print(f"sentences where prediction FLIPPED: {flip_total}")
    print(f"sentences where prediction DID NOT FLIP: {len(unchanged)}")
    print(f"Flip Rate: {flip_rate:.4f}")

    return flip_rate, flipped, unchanged

# **API Functions (RxNorm)**

In [19]:
def get_rxcui(generic_name, config):
    """Look up the first RxCUI that RxNav returns for a drug name.

    Args:
        generic_name: Drug name sent as the ``name`` query parameter.
        config: Settings dict; ``api_timeout_seconds`` is the request timeout.

    Returns:
        The first entry of ``idGroup.rxnormId`` in the response, or ``None``
        when nothing is found or the request fails (failures are printed).
    """
    endpoint = "https://rxnav.nlm.nih.gov/REST/rxcui.json"
    # NOTE(review): search=2 is believed to select RxNav's exact-then-normalized
    # matching rather than approximate matching - confirm against the API docs.
    query = {'name': generic_name, 'search': 2}
    try:
        response = requests.get(endpoint, params=query, timeout=config["api_timeout_seconds"])
        response.raise_for_status()
        payload = response.json()
        if 'idGroup' not in payload:
            return None
        id_group = payload['idGroup']
        if 'rxnormId' in id_group and id_group['rxnormId']:
            return id_group['rxnormId'][0]  # first RxCUI only
        return None
    except requests.exceptions.Timeout:
        print(f"Timeout fetching RxCUI for '{generic_name}'")
        return None
    except requests.exceptions.RequestException as request_err:
        print(f"API request failed for RxCUI lookup '{generic_name}': {request_err}")
        return None
    except Exception as unexpected_err:
        print(f"Unexpected error in get_rxcui for '{generic_name}': {unexpected_err}")
        return None

In [20]:
def get_brand_names(rxcui, config):
    """Fetch the brand names (term type ``BN``) related to an RxCUI.

    Args:
        rxcui: RxNorm concept identifier; a falsy value returns ``[]``.
        config: Settings dict; ``api_timeout_seconds`` is the request timeout.

    Returns:
        List of brand-name strings. Failures are printed and whatever was
        collected so far (usually an empty list) is returned. An HTTP 404 is
        treated as "no brands" and not reported.
    """
    if not rxcui:
        return []
    base_url = f"https://rxnav.nlm.nih.gov/REST/rxcui/{rxcui}/related.json"
    params = {'tty': 'BN'}
    brands = []
    try:
        response = requests.get(base_url, params=params, timeout=config["api_timeout_seconds"])
        response.raise_for_status()
        data = response.json()
        if 'relatedGroup' in data and 'conceptGroup' in data['relatedGroup']:
            for group in data['relatedGroup']['conceptGroup']:
                if group.get('tty') == 'BN' and 'conceptProperties' in group:
                    for prop in group['conceptProperties']:
                        if 'name' in prop:
                            brands.append(prop['name'])
    except requests.exceptions.HTTPError as http_err:
        # raise_for_status() raised, so ``response`` is bound here.
        if response.status_code != 404:
            print(f"HTTP error fetching brands for RxCUI {rxcui}: {http_err}")
    except requests.exceptions.Timeout:
        print(f"Timeout fetching brand names for RxCUI {rxcui}")
    except requests.exceptions.RequestException as e:
        print(f"API request failed for brand lookup of RxCUI {rxcui}: {e}")
    except Exception as e:
        # Report what went wrong instead of a bare "error".
        print(f"Unexpected error in get_brand_names for RxCUI {rxcui}: {e}")
    return brands

In [21]:
def fetch_brand_mappings(generic_names_list, config):
    """Look up one brand name per generic drug name through the RxNorm API.

    For each valid name the RxCUI is fetched with ``get_rxcui`` and its brand
    names with ``get_brand_names``; the first brand (stripped, lower-cased)
    is kept. The loop sleeps ``config["api_delay_seconds"]`` after every
    request.

    Args:
        generic_names_list: Generic drug names; non-strings and names shorter
            than two characters after stripping are skipped.
        config: Settings dict passed through to the API helpers.

    Returns:
        List of ``{'generic_name': ..., 'brand_name': ...}`` dicts, or ``[]``
        for empty input.
    """
    if not generic_names_list:
        return []
    all_mappings = []
    print(f"API lookups for {len(generic_names_list)} names")

    for name in tqdm(generic_names_list, desc="Fetching Brand Mappings"):
        if not isinstance(name, str) or len(name.strip()) < 2:
            continue  # skip invalid names

        generic_name_cleaned = name.strip()
        current_rxcui = get_rxcui(generic_name_cleaned, config)
        time.sleep(config["api_delay_seconds"])  # pause between requests
        if not current_rxcui:
            continue

        brands = get_brand_names(current_rxcui, config)
        time.sleep(config["api_delay_seconds"])  # pause between requests
        if not brands:
            continue

        selected_brand = brands[0].strip().lower()
        if selected_brand:
            # Store the cleaned name - the one actually sent to the API - so
            # later whole-word matching is not thrown off by stray whitespace.
            all_mappings.append({'generic_name': generic_name_cleaned, 'brand_name': selected_brand})

    print(f"API lookups completed and found {len(all_mappings)}/{len(generic_names_list)} names")
    return all_mappings

# **Workflow - Step 1: Fine-tuning**

In [None]:
# Step 1: load the classification data, load the base model, build the
# train/test datasets and fine-tune. Each stage runs only if the previous
# one produced a usable result; otherwise the matching failure is printed.
df_classification = load_parquet_data(CONFIG["classification_data_path"])
if df_classification is None:
    print("Classification data loading failed")
else:
    tokenizer, model = load_tokenizer_and_model(CONFIG["base_model_name"], CONFIG["num_labels"])

    if not (tokenizer and model):
        print("Base model or tokenizer loading failed")
    else:
        train_ds, test_ds = prepare_hf_dataset(df_classification, tokenizer, CONFIG)

        if not (train_ds and test_ds):
            print("Dataset preparation failed")
        else:
            # train_model returns None on failure, so `model` doubles as the status.
            model = train_model(model, tokenizer, train_ds, test_ds, CONFIG)
            print("Finetuning completed" if model else "Model training failed")

# **Workflow - Step 2: Drug Name Extraction and Saving**

In [None]:
# Step 2: read the drug/ADE relation data, collect the unique drug names
# and save them to JSON for the later cells.
df_relation = load_parquet_data(CONFIG["relation_data_path"])
extracted_drugs = []
if df_relation is None:
    print("Relation data loading failed")
else:
    extracted_drugs = extract_drug_names(df_relation)
    if not extracted_drugs:
        print("No drug names were extracted")
    else:
        save_json(extracted_drugs, CONFIG["extracted_drugs_file"])
        print(f"Saved {len(extracted_drugs)} drug names to {CONFIG['extracted_drugs_file']}")

# **Workflow - Step 3: Drug Masking Analysis (ADE Sentences)**

In [None]:
# Step 3: mask drug names in ADE-labelled sentences and record which
# predictions change between the original and the masked text.
#
# The cell can run on its own: each input is reused when an earlier cell
# already defined it, and reloaded from disk otherwise. This is top-level
# code, so locals() here is the module/notebook namespace.
if 'df_classification' not in locals() or df_classification is None:
    print("reloading classification data for masking analysis")
    df_classification = load_parquet_data(CONFIG["classification_data_path"])

if 'extracted_drugs' not in locals() or not extracted_drugs:
     print(f"Loading extracted drug names from {CONFIG['extracted_drugs_file']}")
     extracted_drugs = load_json(CONFIG["extracted_drugs_file"])

if 'model' not in locals() or 'tokenizer' not in locals() or model is None or tokenizer is None:
    print(f"loading finetuned model from {CONFIG['finetuned_model_dir']}...")
    tokenizer, model = load_tokenizer_and_model(CONFIG["finetuned_model_dir"], CONFIG["num_labels"])

# Create the model wrapper, or reuse the one an earlier cell built
if 'model_wrapper' not in locals() or model_wrapper is None:
    model_wrapper = create_model_wrapper(model, tokenizer)


if df_classification is not None and extracted_drugs and model_wrapper and tokenizer:
    # Keep only sentences labelled as ADE
    ade_df = df_classification[df_classification['label'] == CONFIG['ade_label']].copy()
    print(f"Processing {len(ade_df)} ADE sentences for masking analysis")

    if not ade_df.empty:
        # mask_drug_names returns (masked_text, found) per sentence; the
        # tokenizer's own mask token is used as the replacement
        mask_results = ade_df['text'].progress_apply(
            lambda x: mask_drug_names(x, extracted_drugs, tokenizer.mask_token)
        )
        ade_df['masked_text'] = mask_results.apply(lambda x: x[0])
        ade_df['drug_found_flag'] = mask_results.apply(lambda x: x[1])

        # Keep only rows where a drug was actually found and masked
        ade_df_masked = ade_df[ade_df['drug_found_flag'] == True].copy()
        print(f"Found drugs and applied masking to {len(ade_df_masked)} ADE sentences.")

        if not ade_df_masked.empty:
             # Compare predictions on original vs masked text (rate is printed, not kept)
             _, flip_examples, non_flip_examples = calculate_flip_rate(
                  ade_df_masked, 'text', 'masked_text', model_wrapper, "Masking (ADE Only)"
             )
             # Save the flipped and non-flipped examples to separate files
             save_json(flip_examples, CONFIG["masking_flip_ade_file"])
             save_json(non_flip_examples, CONFIG["masking_nonflip_ade_file"])
             print(f"Saved masking analysis examples for ADE sentences.")
        else:
             print("No ADE sentences had detectable drug names from the list. Skipping flip rate calculation.")
    else:
        print("No ADE sentences found in the dataset.")
else:
    print("Skipping masking analysis due to missing data, drug names, or model.")

# **Workflow - Step 4: Drug Masking Analysis (Non ADE Sentences)**

In [None]:
# Step 4: same masking analysis as for ADE sentences, applied to sentences
# whose label is NOT the ADE label.
#
# Inputs are reused when already defined in the notebook namespace and
# reloaded from disk otherwise (locals() at top level is that namespace).
if 'df_classification' not in locals() or df_classification is None:
    print("Reloading classification data for masking analysis...")
    df_classification = load_parquet_data(CONFIG["classification_data_path"])

if 'extracted_drugs' not in locals() or not extracted_drugs:
     print(f"Loading extracted drug names from {CONFIG['extracted_drugs_file']}")
     extracted_drugs = load_json(CONFIG["extracted_drugs_file"])

if 'model' not in locals() or 'tokenizer' not in locals() or model is None or tokenizer is None:
    print(f"Loading finetuned model from {CONFIG['finetuned_model_dir']}")
    tokenizer, model = load_tokenizer_and_model(CONFIG["finetuned_model_dir"], CONFIG["num_labels"])

# Create the model wrapper, or reuse the one an earlier cell built
if 'model_wrapper' not in locals() or model_wrapper is None:
    model_wrapper = create_model_wrapper(model, tokenizer)

# Fail early if the output filenames for this analysis are not configured
if "masking_flip_non_ade_file" not in CONFIG or "masking_nonflip_non_ade_file" not in CONFIG:
    raise KeyError("Config Error")

if df_classification is not None and extracted_drugs and model_wrapper and tokenizer:
    # Keep only sentences that are not labelled as ADE
    non_ade_df = df_classification[df_classification['label'] != CONFIG['ade_label']].copy()
    print(f"Processing {len(non_ade_df)}")

    if not non_ade_df.empty:
        # mask_drug_names returns (masked_text, found) per sentence
        mask_results = non_ade_df['text'].progress_apply(
            lambda x: mask_drug_names(x, extracted_drugs, tokenizer.mask_token)
        )
        non_ade_df['masked_text'] = mask_results.apply(lambda x: x[0])
        non_ade_df['drug_found_flag'] = mask_results.apply(lambda x: x[1])
        # Keep only rows where a drug was actually found and masked
        non_ade_df_masked = non_ade_df[non_ade_df['drug_found_flag'] == True].copy()
        print(f"Found drugs and applied masking to {len(non_ade_df_masked)} Non-ADE sentences")

        if not non_ade_df_masked.empty:
             # Compare predictions on original vs masked text (rate is printed, not kept)

             _, flip_examples, non_flip_examples = calculate_flip_rate(
                  non_ade_df_masked, 'text', 'masked_text', model_wrapper, "Masking (Non-ADE Only)"
             )
             # Save the flipped and non-flipped examples to separate files
             save_json(flip_examples, CONFIG["masking_flip_non_ade_file"])
             save_json(non_flip_examples, CONFIG["masking_nonflip_non_ade_file"])
             print(f"Saved masking analysis examples for Non-ADE sentences.")
        else:
             print("No drug names")
    else:
        print("No Non-ADE sentences found in the dataset")
else:
    print("error in masking analysis due to missing data")

# **Workflow - Step 5: Fetch and Save Brand Name Mappings**

In [None]:
# Step 5: query the RxNorm API for a brand name per extracted drug and save
# the resulting generic -> brand mappings to JSON.
#
# The drug list is reused when an earlier cell defined it, otherwise it is
# loaded from the JSON file written by the extraction step.
if 'extracted_drugs' not in locals() or not extracted_drugs:
    print(f"Loading extracted drug names from {CONFIG['extracted_drugs_file']}")
    extracted_drugs = load_json(CONFIG["extracted_drugs_file"])

brand_mappings_list = []
if extracted_drugs:
    brand_mappings_list = fetch_brand_mappings(extracted_drugs, CONFIG)
    if brand_mappings_list:
        save_json(brand_mappings_list, CONFIG["brand_mapping_file"])
        print(f"Saved {len(brand_mappings_list)} brand mappings to {CONFIG['brand_mapping_file']}")
    else:
        # Message corrected: it previously read "etched".
        print("No brand mappings were fetched")
else:
    print("extracted drug list is missing")

# **Workflow - Step 6: Apply Brand Name Replacement**

In [None]:
# Step 6: rewrite every sentence with generic drug names replaced by brand
# names, flag the sentences that changed, and save the result as parquet.
#
# The classification data is reloaded when it is missing OR when it already
# carries a 'text_brand_replaced' column, so a re-run starts from clean data.
if 'df_classification' not in locals() or df_classification is None or 'text_brand_replaced' in df_classification.columns:
    print(f"Loading classification data from {CONFIG['classification_data_path']}")
    df_classification = load_parquet_data(CONFIG["classification_data_path"])

# Reuse the mappings fetched earlier, or load them from the saved JSON file
if 'brand_mappings_list' not in locals() or not brand_mappings_list:
     print(f"loading brand mappings from {CONFIG['brand_mapping_file']}")
     brand_mappings_list = load_json(CONFIG["brand_mapping_file"])

if df_classification is not None and brand_mappings_list:
    # Flatten the list of {'generic_name', 'brand_name'} records into a dict,
    # skipping records that lack either key
    brand_mapping_dict = {item['generic_name'].lower(): item['brand_name']
                          for item in brand_mappings_list if 'generic_name' in item and 'brand_name' in item}

    if brand_mapping_dict:
        # Precompute the lowercase mapping, compiled regexes and key order once
        lower_map_dict, compiled_patterns, sorted_keys = prepare_brand_replacement(brand_mapping_dict)

        print("Applying generic-to-brand replacement")
        # replace_generic_with_brand returns (new_text, replaced) per sentence
        replacement_results = df_classification['text'].progress_apply(
            lambda x: replace_generic_with_brand(x, lower_map_dict, compiled_patterns, sorted_keys)
        )
        df_classification['text_brand_replaced'] = replacement_results.apply(lambda x: x[0])
        df_classification['brand_replaced_flag'] = replacement_results.apply(lambda x: x[1])
        replaced_count = df_classification['brand_replaced_flag'].sum()
        print(f"{replaced_count} sentences had replacements")
        try:
            check_and_create_dir(os.path.dirname(CONFIG["modified_data_file"])) # Ensure dir exists
            df_classification.to_parquet(CONFIG["modified_data_file"], index=False)
            print(f"saved modified datato {CONFIG['modified_data_file']}")
        except Exception as e:
            print(f"error saving modified data: {e}")
    else:
        print("brand drug names dictionary is empty")
else:
    print("missing data")

# **Workflow - Step 7: Brand Replacement Flip Rate (All Sentences)**


In [None]:
# Step 7: for every sentence where a brand replacement happened, compare the
# prediction on the original text with the prediction on the rewritten text.
#
# Reload the modified data when the DataFrame is undefined, is None (an
# earlier load failed), or lacks the replacement columns. The explicit None
# check matters: without it `.columns` raises AttributeError on None.
if ('df_classification' not in locals()
        or df_classification is None
        or 'brand_replaced_flag' not in df_classification.columns):
    print(f"Loading modified data from {CONFIG['modified_data_file']}")
    # Check that the file exists before loading
    if os.path.exists(CONFIG['modified_data_file']):
        df_classification = load_parquet_data(CONFIG["modified_data_file"])
    else:
        print(f"data not found {CONFIG['modified_data_file']}")
        df_classification = None

# Reuse the model wrapper from an earlier cell, or rebuild it from disk
if 'model_wrapper' not in locals() or model_wrapper is None:
    print(f"Reloading fine-tuned model from {CONFIG['finetuned_model_dir']}...")
    tokenizer, model = load_tokenizer_and_model(CONFIG["finetuned_model_dir"], CONFIG["num_labels"])
    model_wrapper = create_model_wrapper(model, tokenizer)


if df_classification is not None and 'brand_replaced_flag' in df_classification.columns and model_wrapper:
    # Keep only rows where a replacement actually happened
    df_replaced_all = df_classification[df_classification['brand_replaced_flag'] == True].copy()
    print(f"Analyzing {len(df_replaced_all)} sentences where brand replacement occurred.")

    if not df_replaced_all.empty:
        # Compare predictions on original vs brand-replaced text
        _, flip_examples, non_flip_examples = calculate_flip_rate(
            df_replaced_all, 'text', 'text_brand_replaced', model_wrapper, "Brand Replace (All)"
        )
        # Save the flipped and non-flipped examples to separate files
        save_json(flip_examples, CONFIG["brand_flip_all_file"])
        save_json(non_flip_examples, CONFIG["brand_nonflip_all_file"])
        print(f"Saved brand replacement flip analysis")
    else:
        print("No brand replacements")
else:
    print("missing data")

# **Workflow - Step 8: Brand Replacement Flip Rate (ADE Sentences Only)**

In [None]:
# Step 8: brand-replacement flip analysis restricted to ADE-labelled
# sentences in which a replacement happened.
required_cols = ['label', 'brand_replaced_flag', 'text', 'text_brand_replaced']

# Reload the modified data when the DataFrame is undefined, is None (an
# earlier load failed), or lacks a required column. The explicit None check
# matters: without it `.columns` raises AttributeError on None.
if ('df_classification' not in locals()
        or df_classification is None
        or not all(col in df_classification.columns for col in required_cols)):
    print(f"Loading modified data from {CONFIG['modified_data_file']}")
    if os.path.exists(CONFIG['modified_data_file']):
        df_classification = load_parquet_data(CONFIG["modified_data_file"])
    else:
        print(f"data not found {CONFIG['modified_data_file']}")
        df_classification = None

# Reuse the model wrapper from an earlier cell, or rebuild it from disk
if 'model_wrapper' not in locals() or model_wrapper is None:
    print(f"reloading finetuned model from {CONFIG['finetuned_model_dir']}")
    tokenizer, model = load_tokenizer_and_model(CONFIG["finetuned_model_dir"], CONFIG["num_labels"])
    model_wrapper = create_model_wrapper(model, tokenizer)

if df_classification is not None and all(col in df_classification.columns for col in required_cols) and model_wrapper:
    # Rows that were both rewritten and labelled as ADE
    df_replaced_ade = df_classification[
        (df_classification['brand_replaced_flag'] == True) &
        (df_classification['label'] == CONFIG['ade_label'])
    ].copy()
    print(f"Analyzing {len(df_replaced_ade)} ADE sentences where brand replacement occurred.")

    if not df_replaced_ade.empty:
        # Compare predictions on original vs brand-replaced text
        _, flip_examples, non_flip_examples = calculate_flip_rate(
            df_replaced_ade, 'text', 'text_brand_replaced', model_wrapper, "Brand Replace (ADE Only)"
        )
        save_json(flip_examples, CONFIG["brand_flip_ade_file"])
        save_json(non_flip_examples, CONFIG["brand_nonflip_ade_file"])
        print(f"Saved brand replacement flip analysis examples (ADE sentences only).")
    else:
        print("No ADE sentences found where brand replacement occurred. Skipping analysis.")
else:
    # Report the first missing prerequisite
    if df_classification is None:
        print("Skipping brand replacement flip analysis (ADE Only) because DataFrame is missing.")
    elif not all(col in df_classification.columns for col in required_cols):
        print(f"Skipping brand replacement flip analysis (ADE Only) because DataFrame is missing required columns: {required_cols}")
    elif model_wrapper is None:
        print("Skipping brand replacement flip analysis (ADE Only) because Model Wrapper is missing.")
    else:
        print("Skipping brand replacement flip analysis (ADE Only) due to missing data or model.")