# DNA-based Antimicrobial Resistance Prediction - Colab Version

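This notebook runs end to end in Google Colab. It installs its dependencies, downloads and parses CARD reference data, and trains an `AMRModel`. It then predicts antimicrobial resistance (AMR) either for a single pasted DNA sequence or for every record in uploaded FASTA files.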

In [None]:
# Install required packages
!pip install -q tensorflow biopython scikit-learn pandas numpy requests

# Clone the repository
!git clone https://github.com/tmone/amr-system.git
%cd amr-system

import sys
sys.path.append('./python-app')

In [None]:
# Bring every package the notebook uses into scope, grouped as
# standard library -> third-party scientific stack -> notebook / Colab tooling.
import gzip
import io
import tarfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import tensorflow as tf
from Bio import SeqIO
from tqdm.notebook import tqdm

import ipywidgets as widgets
from google.colab import files
from IPython.display import HTML, display

print("All required packages imported successfully")

In [None]:
from model import AMRModel
import requests
import tarfile
import gzip
from pathlib import Path
from tqdm.notebook import tqdm
import pandas as pd

class DataDownloader:
    """Download, unpack and parse a data release into (sequence, label) pairs.

    Args:
        base_url: URL prefix that ``filename`` is appended to.
        filename: Name of the ``.tar.bz2`` archive to fetch.
        output_dir: Directory the archive is saved to and extracted into
            (created, with parents, if missing).
    """

    def __init__(self, base_url="https://card.mcmaster.ca/download/0/",
                 filename="broadstreet-v3.2.4.tar.bz2",
                 output_dir="data"):
        self.base_url = base_url
        self.filename = filename
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def download_file(self, chunk_size=64 * 1024, timeout=60):
        """Download the archive with a progress bar unless it is already present.

        The payload is streamed into a ``<filename>.part`` file. Only after the
        stream completes is it renamed to its final name, so an interrupted
        download is never mistaken for a finished one on the next run.

        Args:
            chunk_size: Bytes read from the response per iteration.
            timeout: Seconds to wait on the server (passed to ``requests.get``).

        Returns:
            pathlib.Path: Location of the downloaded archive.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        url = f"{self.base_url}{self.filename}"
        target_path = self.output_dir / self.filename

        if target_path.exists():
            print("File already downloaded")
            return target_path

        print(f"Downloading from {url}")
        partial_path = target_path.with_name(target_path.name + '.part')
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page would be saved as the archive.
            response.raise_for_status()
            # content-length may be missing; None gives tqdm an open-ended bar.
            total_size = int(response.headers.get('content-length', 0)) or None
            with open(partial_path, 'wb') as file, tqdm(
                    total=total_size, unit='B', unit_scale=True,
                    unit_divisor=1024, desc=self.filename) as bar:
                for chunk in response.iter_content(chunk_size):
                    file.write(chunk)
                    bar.update(len(chunk))
        partial_path.replace(target_path)

        print("Download completed")
        return target_path

    def extract_files(self, archive_path):
        """Extract a ``.tar.bz2`` archive into ``output_dir``.

        The archive comes from the network. When the running Python provides
        tarfile's ``'data'`` extraction filter, members are passed through it.
        The filter rejects absolute paths, ``..`` traversal and links that
        escape the target directory.
        """
        print(f"Extracting {archive_path}")
        with tarfile.open(archive_path, 'r:bz2') as tar:
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=self.output_dir, filter='data')
            else:
                tar.extractall(path=self.output_dir)
        print("Extraction completed")

    @staticmethod
    def _iter_fasta_sequences(path):
        """Yield the sequence of each record in the FASTA file at ``path``.

        Lines after a ``>`` header are stripped and concatenated. Text before
        the first header forms its own record, and records with no sequence
        text are skipped.
        """
        chunks = []
        with open(path) as handle:
            for line in handle:
                if line.startswith('>'):
                    sequence = ''.join(chunks)
                    if sequence:
                        yield sequence
                    chunks = []
                else:
                    chunks.append(line.strip())
        sequence = ''.join(chunks)
        if sequence:
            yield sequence

    def process_fasta(self):
        """Parse every ``*.fasta`` file under ``output_dir`` (recursively).

        Every record in a file gets the same label. The label is 1
        ("resistant") if the file stem contains "resistant"
        (case-insensitive), otherwise 0.

        NOTE(review): the label is a filename heuristic. Confirm the extracted
        archive really has files whose names carry this signal; if none do,
        every label is 0 and the classifier only ever sees one class.

        Returns:
            tuple[list[str], list[int]]: Parallel lists of sequences and labels.
        """
        sequences = []
        labels = []
        # Sorted so record order (and any downstream split) does not depend on
        # the filesystem's directory listing order.
        for fasta in sorted(self.output_dir.glob('**/*.fasta')):
            label = 1 if 'resistant' in fasta.stem.lower() else 0
            for sequence in self._iter_fasta_sequences(fasta):
                sequences.append(sequence)
                labels.append(label)
        return sequences, labels

    def run(self):
        """Download (if needed), extract and parse; return ``(sequences, labels)``."""
        archive_path = self.download_file()
        self.extract_files(archive_path)
        return self.process_fasta()

# Download (skipped if the archive is already on disk), extract and parse the data
downloader = DataDownloader()
sequences, labels = downloader.run()
# Stop here with a clear message rather than failing somewhere inside training.
assert sequences, f"No FASTA records found under {downloader.output_dir}"
print(f"Loaded {len(sequences)} sequences")

In [None]:
# Initialize and train model
EPOCHS = 10
SEED = 42

# Seed Python, NumPy and TensorFlow together so a rerun trains comparably.
# Done before AMRModel() in case its constructor already builds the network.
tf.keras.utils.set_random_seed(SEED)
model = AMRModel()

# Train with progress tracking
from IPython.display import HTML, display
import ipywidgets as widgets

progress = widgets.FloatProgress(value=0, min=0, max=100, description='Training:')
display(progress)

epochs = EPOCHS


def update_progress(epoch, logs):
    """Move the bar to the share of `epochs` finished after (presumably 0-based) `epoch`."""
    progress.value = (epoch + 1) * 100 / epochs


# NOTE(review): update_progress is never passed to model.train, so the bar sits
# at 0 until training returns. If AMRModel.train accepts Keras callbacks, pass
# tf.keras.callbacks.LambdaCallback(on_epoch_end=update_progress).
history = model.train(sequences, labels, epochs=epochs)
progress.value = 100
progress.bar_style = 'success'

In [None]:
# Reuse the sequences/labels loaded above instead of downloading and parsing a
# second time. Show the label balance: labels come from a filename heuristic in
# DataDownloader.process_fasta, and a single-class dataset cannot train a
# useful classifier.
print(f"Loaded {len(sequences)} sequences")
display(pd.Series(labels).map({0: 'Sensitive', 1: 'Resistant'}).value_counts())

In [None]:
# Reuse `model` and `history` from the training cell above. Training again here
# would build a second AMRModel from scratch, replace the first one and double
# the run time, so only report the finished run.
print(f"Trained for {len(history['loss'])} epochs; "
      f"final training loss {history['loss'][-1]:.4f}")

In [None]:
# Plot training results
import matplotlib.pyplot as plt

def plot_training_history(history):
    """Plot per-epoch loss and accuracy curves side by side.

    Args:
        history: Per-epoch metrics indexed by name. This is either a mapping
            such as ``{'loss': [...], 'accuracy': [...], 'val_loss': [...]}``
            or a Keras ``History`` object, whose ``.history`` dict is used.
            ``loss`` and ``accuracy`` are required. The ``val_`` curves are
            drawn only when present, e.g. not when training ran without a
            validation split.

    Returns:
        matplotlib.figure.Figure: The figure, for saving or further tweaks.
    """
    # Keras' fit() returns a History whose metric lists live in `.history`;
    # a plain dict has no such attribute and is used as-is.
    metrics = getattr(history, 'history', history)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    for ax, key, title in ((ax1, 'loss', 'Loss'), (ax2, 'accuracy', 'Accuracy')):
        ax.plot(metrics[key], label=f'Training {title}')
        val_key = f'val_{key}'
        if val_key in metrics:
            ax.plot(metrics[val_key], label=f'Validation {title}')
        ax.set_title(f'Model {title}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.legend()

    plt.tight_layout()
    plt.show()
    return fig

# Trailing ';' keeps Jupyter from rendering the returned figure a second time.
plot_training_history(history);

In [None]:
# Interactive prediction interface
import html


def predict_sequence(sequence, preview_length=50):
    """Predict resistance for one DNA sequence and render the result as an HTML card.

    All whitespace is removed first. A sequence pasted across several lines,
    or containing stray spaces, therefore reaches the model as one contiguous
    string.

    Args:
        sequence: DNA sequence text as typed or pasted by the user.
        preview_length: How many leading characters to echo back in the card.
            "..." is appended only when the sequence is actually longer.

    Returns:
        IPython.display.HTML with the prediction on success. On failure
        (empty input, or prediction raised), returns a plain ``"Error: ..."``
        string instead.
    """
    cleaned = ''.join(sequence.split())
    try:
        if not cleaned:
            raise ValueError("empty sequence")
        result = model.predict_sequence(cleaned)
        prediction = 'Resistant' if result['class'] == 1 else 'Sensitive'
        confidence = f"{result['confidence']:.2f}%"
    except Exception as e:
        return f"Error: {str(e)}"

    preview = cleaned[:preview_length]
    if len(cleaned) > preview_length:
        preview += '...'
    # The sequence is free-form user input: escape it so any markup in the
    # text box is shown literally rather than rendered.
    return HTML(f"""
        <div style='background:#f0f0f0;padding:10px;border-radius:5px'>
            <h3>Analysis Results:</h3>
            <p><b>Sequence:</b> {html.escape(preview)}</p>
            <p><b>Prediction:</b> {prediction}</p>
            <p><b>Confidence:</b> {confidence}</p>
        </div>""")

# Single-sequence form: a text box, a Predict button and an output area that
# is refreshed with predict_sequence's rendering of the current text.
sequence_input = widgets.Textarea(
    description='DNA Sequence:',
    value='ATGCATGCATGC',
    layout=dict(width='100%', height='100px'),
)
predict_button = widgets.Button(description='Predict')
output = widgets.Output()


def on_button_click(_button):
    """Button callback: replace the output area with a fresh prediction."""
    with output:
        output.clear_output()
        rendered = predict_sequence(sequence_input.value)
        display(rendered)


predict_button.on_click(on_button_click)
display(sequence_input, predict_button, output)

## Batch Prediction

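Upload one or more FASTA files to run a prediction on every record they contain. The results are shown as a table and saved to `prediction_results.csv`.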

In [None]:
from google.colab import files
import io

def _parse_fasta_text(text):
    """Split FASTA text into ``(header, sequence)`` pairs.

    Header lines start with ``>`` (dropped from the stored header). The
    stripped lines that follow are concatenated into that record's sequence.
    Lines before the first header are kept under an empty header, and records
    without any sequence text are skipped. ``splitlines`` accepts ``\\n``,
    ``\\r\\n`` and bare ``\\r`` line endings.
    """
    records = []
    header, chunks = '', []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line.startswith('>'):
            sequence = ''.join(chunks)
            if sequence:
                records.append((header, sequence))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    sequence = ''.join(chunks)
    if sequence:
        records.append((header, sequence))
    return records


def _predict_record(header, seq, preview_length=50):
    """Run the model on one record and return its results-table row.

    A failure on this record becomes an 'Error' row with the exception text,
    so one bad record does not abort the whole batch.
    """
    preview = seq if len(seq) <= preview_length else seq[:preview_length] + '...'
    try:
        result = model.predict_sequence(seq)
        prediction = 'Resistant' if result['class'] == 1 else 'Sensitive'
        confidence = f"{result['confidence']:.2f}%"
    except Exception as e:
        prediction, confidence = 'Error', str(e)
    return {
        'header': header,
        'sequence': preview,
        'prediction': prediction,
        'confidence': confidence,
    }


def process_uploaded_fasta(output_file='prediction_results.csv'):
    """Prompt for FASTA upload(s), predict every record, then show and save the results.

    Args:
        output_file: Path the results table is written to as CSV.

    Returns:
        pandas.DataFrame with columns header/sequence/prediction/confidence.
        Returns None if nothing was processed or an error was reported.
    """
    try:
        uploaded = files.upload()
        results = []

        for filename, content in uploaded.items():
            print(f"Processing {filename}...")
            # 'utf-8-sig' drops a leading byte-order mark. With plain 'utf-8'
            # the BOM stays glued to the first '>' and that header line would
            # be read as sequence data.
            records = _parse_fasta_text(content.decode('utf-8-sig'))
            for header, seq in tqdm(records, desc="Predicting"):
                results.append(_predict_record(header, seq))

        if not results:
            print("No sequences processed")
            return None

        df = pd.DataFrame(results)
        display(HTML("<h3>Prediction Results:</h3>"))
        # to_html escapes cell contents, so headers from the upload render as text.
        display(HTML(df.to_html(index=False)))

        df.to_csv(output_file, index=False)
        print(f"\nResults saved to {output_file}")
        return df

    except Exception as e:
        print(f"Error processing file: {str(e)}")
        return None

print("Upload a FASTA file for batch prediction:")
batch_results = process_uploaded_fasta()