# BERT Model Train

In [1]:
import pandas as pd
import numpy as np
import torch
import random
import re
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from imblearn.over_sampling import SMOTE

# Function to set seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Set the seed
SEED = 42
set_seed(SEED)

# Ensure determinism in PyTorch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Function to clean text by removing punctuations
def clean_text(text):
    text = str(text)
    text = re.sub(r'[^\w\s]', '', text)
    return text

# Function to tokenize text
def tokenize_function(text):
    return tokenizer(text, padding="max_length", truncation=True, max_length=512)

# Function to compute metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=-1)
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    precision = precision_score(labels, preds, average='weighted')
    recall = recall_score(labels, preds, average='weighted')
    
    # Returning metrics as a dictionary
    return {
        'accuracy': accuracy, 
        'f1': f1, 
        'precision': precision, 
        'recall': recall
    }

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load the dataset
file_path = '/data/Bert_Model_Train.csv'
data = pd.read_csv(file_path, encoding='ISO-8859-1')

# Clean and tokenize the Context column
data['Context'] = data['Context'].apply(clean_text)
data['tokenized'] = data.apply(lambda row: tokenize_function(row['Name'] + " " + row['Context']), axis=1)

# Prepare data for oversampling
data['input_ids'] = data['tokenized'].apply(lambda x: x['input_ids'])
data['attention_mask'] = data['tokenized'].apply(lambda x: x['attention_mask'])

# Map labels
label_mapping = {'BCRP_substrate': 0, 'BCRP_inhibitor': 1}
data['Activity_encoded'] = data['Activity'].map(label_mapping)

# Define dataset class
class DrugDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __getitem__(self, idx):
        return {
            'input_ids': torch.tensor(self.input_ids[idx], dtype=torch.long),
            'attention_mask': torch.tensor(self.attention_mask[idx], dtype=torch.long),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

    def __len__(self):
        return len(self.labels)

# Load the BERT model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=20,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_dir='./logs',
    learning_rate=2e-5,
    logging_steps=10,
)

# Apply cross-validation
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=SEED)
for fold, (train_idx, val_idx) in enumerate(skf.split(data, data['Activity_encoded'])):
    train_data = data.iloc[train_idx]
    val_data = data.iloc[val_idx]

    # Oversampling with SMOTE
    smote = SMOTE(random_state=SEED)
    train_resampled, train_labels = smote.fit_resample(
        np.hstack((np.array(train_data['input_ids'].tolist()), np.array(train_data['attention_mask'].tolist()))),
        train_data['Activity_encoded']
    )

    val_resampled, val_labels = smote.fit_resample(
        np.hstack((np.array(val_data['input_ids'].tolist()), np.array(val_data['attention_mask'].tolist()))),
        val_data['Activity_encoded']
    )

    # Split input_ids and attention_mask after oversampling
    train_input_ids = train_resampled[:, :512]
    train_attention_mask = train_resampled[:, 512:]
    val_input_ids = val_resampled[:, :512]
    val_attention_mask = val_resampled[:, 512:]

    # Convert to DrugDataset format
    train_dataset = DrugDataset(train_input_ids, train_attention_mask, train_labels)
    val_dataset = DrugDataset(val_input_ids, val_attention_mask, val_labels)

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics)

    # Train the model
    print(f"Training on fold {fold+1}")
    trainer.train()

    # Evaluate the model
    train_results = trainer.evaluate(train_dataset)
    val_results = trainer.evaluate(val_dataset)

    # Print training results
    print(f"Results for fold {fold+1}:")
    print("Training Results:")
    print(f"Accuracy: {train_results['eval_accuracy']}")
    print(f"F1: {train_results['eval_f1']}")
    print(f"Precision: {train_results['eval_precision']}")
    print(f"Recall: {train_results['eval_recall']}")

    train_predictions = trainer.predict(train_dataset)
    train_preds = np.argmax(train_predictions.predictions, axis=-1)
    train_labels = train_predictions.label_ids
    
    # Compute and print confusion matrices
    train_conf_matrix = confusion_matrix(train_labels, train_preds)
    print(f"Confusion Matrix for Training Data on Fold {fold+1}:")
    print(train_conf_matrix)

    # Print Validation results
    print("Validation Results:")
    print(f"Accuracy: {val_results['eval_accuracy']}")
    print(f"F1: {val_results['eval_f1']}")
    print(f"Precision: {val_results['eval_precision']}")
    print(f"Recall: {val_results['eval_recall']}")
    
    # Prediction on validation data
    val_predictions = trainer.predict(val_dataset)
    val_preds = np.argmax(val_predictions.predictions, axis=-1)
    val_labels = val_predictions.label_ids
    
    # Compute and print confusion matrices
    val_conf_matrix = confusion_matrix(val_labels, val_preds)
    
    print(f"Confusion Matrix for Validation Data on Fold {fold+1}:")
    print(val_conf_matrix)




Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 4.14.326, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Training on fold 1


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.673767,0.5,0.333333,0.25,0.5
2,0.687200,0.603976,0.833333,0.833043,0.835664,0.833333
3,0.687200,0.511473,0.854167,0.853595,0.859788,0.854167


  _warn_prf(average, modifier, msg_start, len(result))


Results for fold 1:
Training Results:
Accuracy: 0.9166666666666666
F1: 0.9160839160839161
Precision: 0.9285714285714285
Recall: 0.9166666666666666
Confusion Matrix for Training Data on Fold 1:
[[20  4]
 [ 0 24]]
Validation Results:
Accuracy: 0.8541666666666666
F1: 0.85359477124183
Precision: 0.8597883597883597
Recall: 0.8541666666666666


Detected kernel version 4.14.326, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Confusion Matrix for Validation Data on Fold 1:
[[19  5]
 [ 2 22]]
Training on fold 2


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.450366,0.9375,0.937255,0.944444,0.9375
2,0.516400,0.401578,0.916667,0.916084,0.928571,0.916667
3,0.516400,0.307943,0.979167,0.979158,0.98,0.979167


Results for fold 2:
Training Results:
Accuracy: 0.9583333333333334
F1: 0.9582608695652176
Precision: 0.9615384615384616
Recall: 0.9583333333333334
Confusion Matrix for Training Data on Fold 2:
[[24  0]
 [ 2 22]]
Validation Results:
Accuracy: 0.9791666666666666
F1: 0.9791576204950064
Precision: 0.98
Recall: 0.9791666666666666


Confusion Matrix for Validation Data on Fold 2:
[[24  0]
 [ 1 23]]


# API and NLP Workflow

In [4]:
import requests
from Bio import Entrez
from bs4 import BeautifulSoup
import pandas as pd
import spacy
import warnings
from drug_named_entity_recognition import find_drugs
import collections
import string
import time

# Function to fetch Canonical SMILES and Chemical ID for a given drug name from PubChem
def get_canonical_smiles_and_chem_id(drug_name):
    base_url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name"
    # Fetch Canonical SMILES
    smiles_response = requests.get(f"{base_url}/{drug_name}/property/CanonicalSMILES/JSON")
    if smiles_response.status_code == 200:
        canonical_smiles = smiles_response.json()['PropertyTable']['Properties'][0]['CanonicalSMILES']
    else:        
        canonical_smiles = None
        pass
    time.sleep(0.2)
    # Fetch Chemical ID
    cid_response = requests.get(f"{base_url}/{drug_name}/cids/JSON")
    if cid_response.status_code == 200:
        chem_id = cid_response.json()['IdentifierList']['CID'][0]
    else:
        chem_id = None
        pass

    return canonical_smiles, chem_id

# Function to search PubMed Central with a given query
def search_pubmed_central(query, max_results=1000):
    Entrez.email = "vishesh.walia@outlook.com"  # Set your email here
    terms = query.split()
    and_query = " AND ".join([f'"{term}"[Title]' for term in terms])
    handle = Entrez.esearch(db="pmc", term=and_query, retmax=max_results)
    record = Entrez.read(handle)
    handle.close()
    pmc_ids = record['IdList']
    return pmc_ids

# Function to predict the class label using the model
def predict_with_model(model, tokenizer, drug_names, contexts):
    # Concatenate drug_name and context for each pair
    combined_texts = [drug_name + " " + context for drug_name, context in zip(drug_names, contexts)]

    # Tokenize the combined input
    encodings = tokenizer(combined_texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
    
    # Predict with the model
    with torch.no_grad():
        outputs = model(**encodings)
    
    # Convert logits to probabilities and get predicted class labels
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predictions = torch.argmax(probabilities, dim=1)
    return predictions

# Function to fetch and process an article from PubMed Central
def fetch_and_process_article(pmc_id, cache, nlp):
    headers = {'User-Agent': 'Mozilla/5.0'}
    url = f"https://www.ncbi.nlm.nih.gov/pmc/articles/{pmc_id}/?report=printable"
    time.sleep(1/3)
    with requests.get(url, headers=headers) as response:
        if response.status_code == 200:
            html_content = response.text
            soup = BeautifulSoup(html_content, 'html.parser')
            text = soup.get_text()
            title_tag = soup.find('h1')
            article_title = title_tag.get_text().strip() if title_tag else "No Title Found"

            # Preprocess the context by removing punctuation and leading/trailing spaces
            doc = nlp(text)
            context = " ".join([token.text.strip(string.punctuation) for token in doc])

            # Identify drug names within the text
            drugs_info = find_drugs(context.split(), is_ignore_case=True)

            keywords = ["inhibitor", "transport inhibitor", "inhibits", "non transporter", 
                        "unable to transport", "bcrp", "non inhibitor", "substrate", 
                        "transport substrate", "transporter", 'substrates']
            
            unique_drug_context_pairs = set()
            article_drugs = []

            # Extract relevant drug-context pairs and fetch SMILES and ChemID
            for drug, start, end in drugs_info:
                drug_pos = context.lower().find(drug['name'].lower())
                if drug_pos != -1:
                    start_pos = max(drug_pos - 300, 0)
                    end_pos = min(drug_pos + 300, len(context))
                    drug_context = context[start_pos:end_pos].lower()

                    if any(keyword in drug_context for keyword in keywords):
                        drug_context_key = (drug['name'].lower(), drug_context)
                        if drug_context_key not in unique_drug_context_pairs:
                            unique_drug_context_pairs.add(drug_context_key)
                            canonical_smiles, chem_id = cache.get(drug['name'], get_canonical_smiles_and_chem_id(drug['name']))
                            cache[drug['name']] = (canonical_smiles, chem_id)
                            article_drugs.append({
                                'PMC ID': pmc_id, 
                                'Drug Name': drug['name'], 
                                'Context': drug_context, 
                                'Canonical SMILES': canonical_smiles, 
                                'ChemID': chem_id
                            })

            return article_title, article_drugs
        else:
            print(f"Failed to fetch article for PMC ID {pmc_id}. Status code: {response.status_code}")
            return "", []



# Initialize Spacy model outside of the function for efficiency
nlp = spacy.blank("en")


def main(search_queries):
    max_results = 1000
    all_drugs = []
    all_references = []
    processed_pmc_ids = set()  # Set to track processed PMC IDs
    cache = {}  # Cache for storing already fetched SMILES and ChemIDs
    nlp = spacy.blank("en")  # Load Spacy language model
    warnings.filterwarnings('ignore')  # Optional to suppress warnings

    for query in search_queries:
        pmc_ids = search_pubmed_central(query, max_results)
        for pmc_id in pmc_ids:
            if pmc_id not in processed_pmc_ids:
                article_title, article_drugs = fetch_and_process_article(pmc_id, cache, nlp)
                all_drugs.extend(article_drugs)  # Add unique drug entries to all_drugs list
                all_references.append({'Index': len(all_references) + 1, 'PMC ID': pmc_id, 'Article Title': article_title})
                processed_pmc_ids.add(pmc_id)  # Mark the PMC ID as processed

    df = pd.DataFrame(all_drugs)
    references_df = pd.DataFrame(all_references)
    references_dict = references_df.set_index('PMC ID')['Index'].to_dict()

    #Assuming the existence of a model and tokenizer for prediction
    predicted_labels = predict_with_model(model, tokenizer, df['Drug Name'].tolist(), df['Context'].tolist())
    reverse_label_mapping = {0: 'BCRP_substrate', 1: 'BCRP_inhibitor'}
    df['Predicted Label'] = [reverse_label_mapping[label.item()] for label in predicted_labels]
    
    #Group by Drug Name and determine the most frequent label
    df_grouped = df.groupby(['Drug Name', 'Predicted Label']).size().reset_index(name='Count')
    df_max_label = df_grouped.loc[df_grouped.groupby('Drug Name')['Count'].idxmax()]

    #Aggregate PMC IDs, Canonical SMILES, and Chemical ID
    df_agg = df.groupby('Drug Name').agg({
        'PMC ID': lambda x: ', '.join(sorted(set(x))),
        'Canonical SMILES': 'first',
        'ChemID': 'first'
    }).reset_index()

    df_combined = df_max_label.merge(df_agg, on='Drug Name')

    def pmc_to_ref_index(pmc_ids):
        pmc_ids = pmc_ids.split(', ')
        return [references_dict.get(pmc_id) for pmc_id in pmc_ids]

    df_combined['Ref Indexes'] = df_combined['PMC ID'].apply(pmc_to_ref_index)
    df_combined['Ref Indexes'] = df_combined['Ref Indexes'].apply(lambda x: [i for i in x if i is not None])

    ref_index_columns = df_combined['Ref Indexes'].apply(pd.Series)
    ref_index_columns = ref_index_columns.rename(columns=lambda x: 'Ref' + str(x + 1))

    df_final = pd.concat([df_combined.drop(['Ref Indexes', 'PMC ID'], axis=1), ref_index_columns], axis=1)
    df_final = df_final.drop(columns=['Count'])
    ref_index_columns = [col for col in df_final.columns if col.startswith('Ref')]
    df_final[ref_index_columns] = df_final[ref_index_columns].fillna(0).applymap(lambda x: int(x) if not pd.isna(x) else pd.NA)

    return df_final, references_df

if __name__ == "__main__":
    search_queries = ["BCRP DRUG TRANSPORT", "ABCG2 DRUG TRANSPORT", "Drug Response ABCG2", "Drug Response BCRP", "BCRP Efflux Transport", 'ABCG2 Efflux',
                      "Multidrug resistance ATP-binding cassette transporters", 'Multidrug BCRP','Multidrug ABCG2', 'BCRP Efflux', 'BCRP Resistance']
    df_final, references_df = main(search_queries)
    references_df.to_csv('/results/references.tsv', sep = '\t')

    df_final_dedup = df_final.drop_duplicates(subset='Drug Name')

    # Calculations for metrics_summary
    team_name = "Model Kombat"
    team_contact = ["vxw220000@utdallas.edu", "axc220027@utdallas.edu"]
    num_samples = df_final_dedup['Drug Name'].nunique()
    num_references = references_df['PMC ID'].nunique()
    num_inhibitor = df_final_dedup[df_final_dedup['Predicted Label'] == 'BCRP_inhibitor'].shape[0]
    num_substrate = df_final_dedup[df_final_dedup['Predicted Label'] == 'BCRP_substrate'].shape[0]
    percent_minority = min(num_inhibitor, num_substrate) / num_samples if num_samples > 0 else 0
    num_smiles = df_final_dedup[df_final_dedup['Canonical SMILES'].notna() & (df_final_dedup['Canonical SMILES'] != 'missing') & (df_final_dedup['Canonical SMILES'] != 'not available')].shape[0]
    percent_smiles = num_smiles / num_samples if num_samples > 0 else 0

    # Creating metrics_summary dataframe
    metrics_summary = pd.DataFrame({
        "team_name": [team_name],
        "team_contact": [team_contact],
        "num_samples": [num_samples],
        "num_references": [num_references],
        "num_inhibitor": [num_inhibitor],
        "num_substrate": [num_substrate],
        "percent_minority": [percent_minority],
        "num_smiles": [num_smiles],
        "percent_smiles": [percent_smiles]
    })

    metrics_summary.to_csv('/results/metrics_summary.tsv', sep = '\t')

    # Renaming and reordering columns for training_data
    df_final.rename(columns={
        'Drug Name': 'Name',
        'Canonical SMILES': 'SMILES',
        'Predicted Label': 'Activity',
        'ChemID': 'PubChem_CID'
    }, inplace=True)
    
    # Assuming 'Ref Indexes' is a list of references for each drug
    # Expand this list into separate columns (Ref1, Ref2, etc.)
    df_final.to_csv('/results/training_data.tsv', sep = '\t')



# Metrics Summary

In [None]:
metrics_summary

# Validation

In [7]:


# Extract relevant columns from training data for comparison
file_path = '/data/Validation_Data.tsv'
training_data = pd.read_csv(file_path, sep = '\t')
training_data_relevant = training_data[['Name', 'Activity']].rename(columns={'Activity': 'Activity_actual'})

# Merge the two dataframes on the 'Name' column to find matching entries
validation_df = pd.merge(df_final, training_data_relevant, on='Name', how='inner')

# Calculate the number of matching predictions
matching_predictions = validation_df[validation_df['Activity'] == validation_df['Activity_actual']].shape[0]

# Total number of drug names in df_final that are also in the training data
total_matching_drug_names = validation_df.shape[0]

print(f'Total values in validation data: {training_data_relevant.shape[0]}')
print(f'Matching drug names in extracted data & validation data: {total_matching_drug_names}')
print(f'Matching substrates/inhibitors in extracted data & validation data: {matching_predictions}')
print('Accuracy: ' + str(round(matching_predictions/total_matching_drug_names,2)*100) + '%' + ' on validation set')


Total values in validation data: 51
Matching drug names in extracted data & validation data: 35
Matching substrates/inhibitors in extracted data & validation data: 28
Accuracy: 80.0% on validation set
