## **Project Overview**

This project uses a LLaMA model fine-tuned on antibiotic and bacteria information to help doctors analyze antibiotic resistance from mutation data on bacterial isolates. The pipeline combines two inputs: a binary (0/1) mutation vector for each isolate and text-based context generated by the language model. Fusing the two is intended to improve the accuracy of antibiotic resistance prediction. Note that the default configuration in the code below loads `microsoft/biogpt` as the base model.




In [None]:
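# pyyaml is needed by load_config, tqdm by the progress bars, sacremoses by the BioGPT tokenizer
!pip install ipdb transformers torch pandas numpy accelerate pyyaml tqdm sacremoses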

In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import numpy as np
from tqdm.auto import tqdm


**Pipeline Steps**

1. **Configuration Setup and Data Loading**
   - The project starts by loading configurations for the model, including hyperparameters and paths for data files. This configuration helps streamline adjustments during experimentation.
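
As written, `load_config` returns either the file contents or the defaults, never a mix, so a config file that sets only one key would leave the others undefined. A minimal sketch of one alternative is shown here, assuming it is acceptable to overlay user settings on the defaults. The helper name `merge_config` is hypothetical and not part of the notebook.

```python
# Defaults copied from load_config in this notebook.
DEFAULTS = {
    "model_name": "microsoft/biogpt",
    "learning_rate": 1e-5,
    "epochs": 2,
    "beta_kl": 0.1,
    "alpha": 1.0,
    "data_file": "/workspace/train.csv",
}

def merge_config(overrides=None):
    """Return a copy of DEFAULTS with any user-supplied keys overlaid (hypothetical helper)."""
    config = dict(DEFAULTS)         # copy so the defaults are never mutated
    config.update(overrides or {})  # user values win over defaults
    return config

# Example: change only the number of epochs for one experiment.
config = merge_config({"epochs": 5})
print(config["epochs"], config["beta_kl"])
```

The `overrides` dictionary could come from `yaml.safe_load`, so a short experiment file only needs to list the values that differ from the defaults.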

In [None]:
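import yaml  # PyYAML; used only when a config file is supplied

def load_config(config_file=None):
    """Load settings from a YAML file, or fall back to the built-in defaults."""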
    if config_file:
        with open(config_file, 'r') as f:
            config = yaml.safe_load(f)
    else:
        config = {
            "model_name": "microsoft/biogpt",
            "learning_rate": 1e-5,
            "epochs": 2,
            "beta_kl": 0.1,
            "alpha": 1.0,
            "data_file": "/workspace/train.csv"
        }
    return config


2. **Training Stage I: Stylistic Token Generation**
   - The language model is first trained to produce stylistic tokens, that is, responses in the vocabulary and register of pharmacology. The loss is a cross-entropy term plus a KL divergence term, weighted by `beta_kl`, that compares the model's output distribution with that of a frozen reference copy. The aim is for the model to learn the specialized vocabulary and context of antibiotic resistance.
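
To make the Stage I objective concrete, here is a small worked sketch of cross-entropy plus a weighted KL term on a toy three-token vocabulary. It uses only the standard library and made-up logits. The function names are illustrative, and the notebook's own implementation operates on full model logits with PyTorch.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def kl_divergence(ref_log_probs, model_log_probs):
    """KL(reference || model), with both arguments given as log-probabilities."""
    return sum(math.exp(r) * (r - m) for r, m in zip(ref_log_probs, model_log_probs))

def stage1_loss(model_logits, ref_logits, target_index, beta_kl=0.1):
    """Cross-entropy on the target token plus a beta-weighted KL penalty."""
    model_lp = log_softmax(model_logits)
    ref_lp = log_softmax(ref_logits)
    cross_entropy = -model_lp[target_index]
    return cross_entropy + beta_kl * kl_divergence(ref_lp, model_lp)

# Made-up logits for a single position.
ref_logits = [2.0, 1.0, 0.1]
same = stage1_loss(ref_logits, ref_logits, target_index=0)          # KL term is zero
shifted = stage1_loss([0.1, 1.0, 2.0], ref_logits, target_index=0)  # KL term is positive
print(f"identical to reference: {same:.4f}, shifted away: {shifted:.4f}")
```

The KL term is zero when the model matches the reference and grows as the two distributions diverge, so `beta_kl` controls how strongly the model is held to its starting behavior.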




In [None]:
# Stage I: train the base model (BioGPT by default) to produce stylistic tokens.
# Loss = cross-entropy on the assistant turn + beta_kl * KL(reference || model).
def train_style(model, tokenizer, reference_model, data, epochs=2, lr=1e-5, beta_kl=0.1):
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    model.train()
    reference_model.eval()

    instruction = "You are going to become an expert in pharmacy and adverse effects of medication. Learn to generate stylistic tokens for the pharmacology domain.\n"

    for epoch in range(epochs):
        total_loss = 0.0
        for example in data:
            user_msg, assistant_msg = stage1_chat_format(example)
            # Prompt = instruction + user turn; the assistant turn is the training target.
            prompt_text = instruction + f"{user_msg['role']}: {user_msg['content']}"
            full_text = prompt_text + f" {assistant_msg['role']}: {assistant_msg['content']}"
            inputs = tokenizer(full_text, return_tensors="pt", truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Set labels to -100 for the prompt (instruction and user question) so that
            # cross-entropy is computed on the assistant turn only.
            prompt_len = tokenizer(prompt_text, return_tensors="pt")["input_ids"].shape[1]
            labels = inputs["input_ids"].clone()
            labels[:, :prompt_len] = -100
            if (labels != -100).sum() == 0:
                continue  # empty answer or truncated input: nothing to train on

            outputs = model(**inputs, labels=labels)
            cross_entropy_loss = outputs.loss

            # KL divergence between the model and the frozen reference model.
            # Both arguments are log-probabilities, hence log_target=True.
            with torch.no_grad():
                ref_log_probs = F.log_softmax(reference_model(**inputs).logits, dim=-1)
            log_probs = F.log_softmax(outputs.logits, dim=-1)
            kl_loss = F.kl_div(log_probs, ref_log_probs, reduction="batchmean", log_target=True)

            total_loss_value = cross_entropy_loss + beta_kl * kl_loss
            optimizer.zero_grad()
            total_loss_value.backward()
            optimizer.step()

            total_loss += total_loss_value.item()
        print(f"Stage I - Epoch {epoch+1}, Loss: {total_loss:.4f}")

In [None]:
def stage1_chat_format(example):
    return [
        {"role": "user", "content": example['question']},
        {"role": "assistant", "content": example.get('original_answer', '')}
    ]

3. **Initial Answer Generation**
   - Once the model has been trained for style, it generates initial answers based on the questions provided in the dataset. These answers are used as a base to add context to the text data and are saved for future stages.




In [None]:
# Note: the default output_file is the same path as the default config["data_file"],
# so the input CSV is overwritten with an added original_answer column.
def generate_initial_answers(model, tokenizer, data, output_file="/workspace/train.csv", batch_size=16):
    model.eval()
    # Decoder-only models should be left-padded so generation continues from the real prompt tokens
    tokenizer.padding_side = "left"
    generated_data = []
    for i in tqdm(range(0, len(data), batch_size), desc="Generating initial answers"):
        batch = data[i:i+batch_size]
        questions = [example['question'] for example in batch]
        inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True).to(model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

        for idx, example in enumerate(batch):
            # Decode only the newly generated tokens (everything after the padded prompt)
            generated_answer = tokenizer.decode(outputs[idx][prompt_len:], skip_special_tokens=True)
            example['original_answer'] = generated_answer.strip()
            generated_data.append(example)

    df = pd.DataFrame(generated_data)
    df.to_csv(output_file, index=False)
    print(f"Initial answers generated and saved to {output_file}")

4. **Data Preparation**
   - A dataset class is created to handle antibiotic resistance data. It includes two sets of features:
     - A numerical vector (0/1) indicating the presence or absence of mutations.
     - Encoded text features that add contextual information about antibiotic resistance, extracted from the LLaMA model output.
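
The numerical feature can be pictured with a short sketch. The mutation names and the helper `encode_isolate` are placeholders invented for illustration. The notebook itself reads ready-made 0/1 columns from a CSV file.

```python
def encode_isolate(observed_mutations, mutation_panel):
    """Return a 0/1 list with one entry per mutation in the panel, in panel order."""
    observed = set(observed_mutations)
    return [1 if mutation in observed else 0 for mutation in mutation_panel]

# Placeholder panel and isolate, for illustration only.
mutation_panel = ["mutation_A", "mutation_B", "mutation_C", "mutation_D"]
isolate = ["mutation_B", "mutation_D"]

vector = encode_isolate(isolate, mutation_panel)
print(vector)  # [0, 1, 0, 1]
```

Every isolate is encoded against the same panel, so all vectors have the same length and position `i` always refers to the same mutation.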

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

class AntibioticResistanceDataset(Dataset):
    def __init__(self, isolate_features, text_features, labels):
        self.isolate_features = isolate_features
        self.text_features = text_features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        isolate_feature = self.isolate_features[idx]
        text_feature = self.text_features[idx]
        label = self.labels[idx]
        return isolate_feature, text_feature, label


5. **Co-Attention Mechanism**
   - A Co-Attention module is introduced to combine numerical isolate features with text-based features, helping the model to attend jointly to both sets of information. This fusion aims to leverage the strengths of both numeric and language data for more accurate predictions.
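
The fusion step can be sketched without any learned layers. The sketch assumes both feature sets have already been projected to the same hidden size. It sums them, takes a softmax over the hidden units, and rescales the sum by those weights. The learned `nn.Linear` layers of the actual module are deliberately omitted, and the function names are illustrative.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def gate(isolate_proj, text_proj):
    """Fuse two same-length vectors and reweight each hidden unit by a softmax score."""
    fused = [a + b for a, b in zip(isolate_proj, text_proj)]
    scores = [max(0.0, f) for f in fused]  # ReLU; stands in for the learned attention layer
    weights = softmax(scores)              # weights over hidden units sum to 1
    attended = [w * f for w, f in zip(weights, fused)]
    return attended, weights

# Made-up projected features with hidden size 3.
attended, weights = gate([0.5, 2.0, 0.0], [0.5, 1.0, 0.1])
print([round(w, 3) for w in weights])
```

Because the weights sum to one across hidden units, the largest fused component dominates the output. Whether that is desirable depends on the data.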

In [None]:
# Co-Attention Module
class CoAttention(nn.Module):
    def __init__(self, isolate_dim, text_dim, hidden_dim):
        super(CoAttention, self).__init__()
        self.isolate_fc = nn.Linear(isolate_dim, hidden_dim)
        self.text_fc = nn.Linear(text_dim, hidden_dim)
        self.attention = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, isolate_features, text_features):
        # Project both features to a common hidden space
        isolate_proj = torch.relu(self.isolate_fc(isolate_features))
        text_proj = torch.relu(self.text_fc(text_features))
        attention_scores = torch.relu(self.attention(isolate_proj + text_proj))
        attention_weights = self.softmax(attention_scores)
        co_attended_features = attention_weights * (isolate_proj + text_proj)
        return co_attended_features


6. **Antibiotic Resistance Prediction Model**
   - A neural network model is built, which includes the co-attention mechanism followed by fully connected layers. This model takes the co-attended features and outputs a probability that predicts antibiotic resistance.

In [None]:
class AntibioticResistancePredictor(nn.Module):
    def __init__(self, isolate_dim, text_dim, hidden_dim):
        super(AntibioticResistancePredictor, self).__init__()
        self.co_attention = CoAttention(isolate_dim, text_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, isolate_features, text_features):
        co_attended_features = self.co_attention(isolate_features, text_features)
        x = torch.relu(self.fc1(co_attended_features))
        x = self.bn1(x)
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))
        return x

7. **Model Training**
   - The model is trained on the combined data using binary cross-entropy loss. This helps in minimizing the difference between predicted and true resistance labels.
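
As a worked example of the loss, the sketch below computes binary cross-entropy by hand for made-up probabilities. It shows the formula only and does not replace `nn.BCELoss`.

```python
import math

def binary_cross_entropy(probs, labels, eps=1e-12):
    """Mean of -(y*log(p) + (1-y)*log(1-p)) over all examples."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

labels = [1, 0]  # resistant, susceptible
confident_right = binary_cross_entropy([0.9, 0.1], labels)
confident_wrong = binary_cross_entropy([0.1, 0.9], labels)
print(f"{confident_right:.4f} {confident_wrong:.4f}")
```

Confident correct predictions give a loss near zero, while confident wrong ones are penalized heavily. That asymmetry is what drives the predicted probabilities toward the true labels.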


In [None]:
# Training the Model
def train_co_attn(model, dataloader, criterion, optimizer, num_epochs=20):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for isolate_features, text_features, labels in dataloader:
            isolate_features = isolate_features.float()
            text_features = text_features.float()
            labels = labels.float().view(-1, 1)

            optimizer.zero_grad()
            outputs = model(isolate_features, text_features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * isolate_features.size(0)
        epoch_loss = running_loss / len(dataloader.dataset)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


    # The predictor is a plain nn.Module, so it has no save_pretrained(); save its weights instead
    torch.save(model.state_dict(), "./trained_for_amr.pt")

8. **Textual Context Generation and Encoding**
   - Context for each antibiotic and isolate combination is generated using the LLaMA model in a specific textual format. The output is encoded using a MedCPT encoder to extract text features, which are then used in the combined model.
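
The encoding step reduces a sequence of token vectors to one fixed-length vector by averaging over tokens. A minimal sketch with made-up two-dimensional token vectors:

```python
def mean_pool(token_vectors):
    """Average a list of equal-length token vectors into a single vector."""
    n_tokens = len(token_vectors)
    dim = len(token_vectors[0])
    return [sum(vec[d] for vec in token_vectors) / n_tokens for d in range(dim)]

# Three tokens, hidden size 2 (made-up values).
tokens = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]
pooled = mean_pool(tokens)
print(pooled)  # [2.0, 2.0]
```

The pooled vector has the encoder's hidden size regardless of how long the generated text is, which is what lets it feed a fixed-size `text_dim` input.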

In [None]:
# Generate Context for textual encoding
def generate_context(model, tokenizer, data_initial):
    model.eval()
    results = []
    for item in data_initial:
        input_text = f"bacteria: {item['question']}, antibiotic: {item['correct_answer']}, affected_isolates: _______\ngive a comprehensive analysis of the situation? \n"
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_length=256)
        result = tokenizer.decode(output[0], skip_special_tokens=True)
        results.append({"generated_analysis": result})
    # Note: this overwrites any train.csv in the working directory
    pd.DataFrame(results).to_csv("train.csv", index=False)
    # Return the records so the encoder can read item['generated_analysis']
    return results

# Encode Text Features using MedCPT Encoder
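def encode_text_features_mecpt(model, tokenizer, data):
    # `model` must be an encoder whose output exposes last_hidden_state
    # (for example one loaded with AutoModel), not a causal LM head model.
    model.eval()
    encoded_features = []
    for item in data:
        inputs = tokenizer(item['generated_analysis'], return_tensors="pt", truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
            # Mean-pool the token embeddings into one vector per text
            encoded_features.append(outputs.last_hidden_state.mean(dim=1).cpu().numpy())
    return np.vstack(encoded_features)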

9. **Inference**
   - The trained model is used to predict resistance labels for new data, using both isolate features and encoded text features.
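
The final step turns each predicted probability into a binary label with a fixed cutoff. A minimal sketch of that rule follows. The helper name `to_labels` is illustrative, and the 0.5 threshold mirrors the notebook's inference code.

```python
def to_labels(probabilities, threshold=0.5):
    """Map each probability to 1 (resistant) if strictly above the threshold, else 0."""
    return [1 if p > threshold else 0 for p in probabilities]

labels = to_labels([0.92, 0.5, 0.13])
print(labels)  # [1, 0, 0]
```

A probability of exactly 0.5 maps to 0 because the comparison is strict. A different threshold may suit applications where missing a resistant isolate is costlier than a false alarm.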

In [None]:
def inference(model, tokenizer, isolate_features, text_features):
    model.eval()
    # A plain nn.Module has no .device attribute; read the device from its parameters
    device = next(model.parameters()).device
    with torch.no_grad():
        isolate_features = torch.tensor(isolate_features).float().to(device)
        text_features = torch.tensor(text_features).float().to(device)
        outputs = model(isolate_features, text_features)
        predictions = (outputs > 0.5).cpu().numpy().astype(int)  # Convert probabilities to binary predictions
    return predictions

In [1]:
def main(config_file=None):
    config = load_config(config_file)
    model_name = config["model_name"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.config.pad_token_id = tokenizer.pad_token_id
    data_file_path = config["data_file"]
    df = pd.read_csv(data_file_path)

    # Prepare the data for generating initial answers
    data_initial = df[["question", "correct_answer"]].to_dict(orient="records")

    # Write the generated answers to the same path that is read back below
    generate_initial_answers(model, tokenizer, data_initial, output_file=data_file_path)
    df_generated = pd.read_csv(data_file_path).fillna("")  # empty answers are read back as NaN
    data = df_generated[["question", "original_answer", "correct_answer"]].to_dict(orient="records")

    # Style training
    reference_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)  # Frozen reference model for KL
    train_style(
        model, tokenizer, reference_model, data,
        epochs=config["epochs"],
        lr=config["learning_rate"],
        beta_kl=config["beta_kl"],
    )

    model.save_pretrained("./trained_self_correcting_model")
    tokenizer.save_pretrained("./trained_self_correcting_model")

    # Generate textual context with the language model, then encode it with a separate encoder.
    # Assumption: the MedCPT article encoder checkpoint on the Hugging Face Hub (hidden size 768).
    from transformers import AutoModel
    context = generate_context(model, tokenizer, data_initial)
    encoder_name = "ncbi/MedCPT-Article-Encoder"
    encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
    encoder = AutoModel.from_pretrained(encoder_name).to(device)
    text_features = encode_text_features_mecpt(encoder, encoder_tokenizer, context)

    # Isolate features are every column except the label.
    # Rows must line up one-to-one with the rows of text_features.
    isolate_features_df = pd.read_csv("amr_ast_clindamycin_CJ.csv")
    labels = isolate_features_df["label"].values
    isolate_features = isolate_features_df.drop(columns=["label"]).values

    dataset = AntibioticResistanceDataset(isolate_features, text_features, labels)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Derive input sizes from the data instead of hard-coding them
    isolate_dim = isolate_features.shape[1]
    text_dim = text_features.shape[1]
    hidden_dim = 128

    predictor = AntibioticResistancePredictor(isolate_dim, text_dim, hidden_dim)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(predictor.parameters(), lr=0.001)
    train_co_attn(predictor, dataloader, criterion, optimizer, num_epochs=20)

    # Smoke test on the first five training rows (not a held-out evaluation)
    test_isolate_features = isolate_features[:5]
    test_text_features = text_features[:5]
    predictions = inference(predictor, tokenizer, test_isolate_features, test_text_features)
    print("Predictions:", predictions)


In [None]:
if __name__ == "__main__":
    main()


**Summary**

The project integrates deep learning techniques and language models to analyze antibiotic resistance. It combines numerical mutation data with text-based context generated by a finetuned LLaMA model, enhancing the prediction of resistance using a co-attention mechanism. The system aims to provide doctors with comprehensive analysis by merging structured data with contextual understanding.