# AMR Genome Dataset: Comprehensive Exploration and Analysis

![AMR Dataset](https://img.shields.io/badge/Dataset-AMR%20Genome-blue)
![Python](https://img.shields.io/badge/Python-3.8+-green)
![License](https://img.shields.io/badge/License-MIT-yellow)

This notebook provides a comprehensive exploration of the Antimicrobial Resistance (AMR) Genome Dataset, demonstrating data analysis, visualization, and machine learning applications.

## 📋 Table of Contents

1. [Dataset Overview](#dataset-overview)
2. [Data Loading and Initial Exploration](#data-loading)
3. [Genome Characteristics Analysis](#genome-analysis)
4. [AMR Gene Distribution](#amr-analysis)
5. [Resistance Phenotype Analysis](#resistance-analysis)
6. [Metadata Exploration](#metadata-analysis)
7. [Machine Learning Applications](#ml-applications)
8. [Advanced Visualizations](#visualizations)
9. [Research Applications](#research-applications)

## 🎯 Learning Objectives

By the end of this notebook, you will:
- Understand the structure and content of the AMR dataset
- Perform exploratory data analysis on genomic and AMR data
- Apply machine learning techniques for resistance prediction
- Create publication-quality visualizations
- Learn best practices for AMR data analysis

## 1. Dataset Overview {#dataset-overview}

### What is this dataset?

The AMR Genome Dataset contains **50 *Escherichia coli* isolates** with comprehensive antimicrobial resistance annotations. Each isolate includes:

- **Complete genome sequence** with quality metrics
- **79 AMR features** (42 genes + 37 resistance classes)
- **Rich metadata** including epidemiological data and publication information
- **Engineered features** for advanced analysis

### Key Features

| Category | Count | Description |
|----------|-------|-------------|
| **Isolates** | 50 | E. coli strains from various sources |
| **Genome Size** | 4.7-5.5 Mbp | Complete genome assemblies |
| **AMR Genes** | 42 | Binary presence/absence features |
| **Resistance Classes** | 37 | Antibiotic resistance classes (binary presence/absence) |
| **Metadata Fields** | 14 | Epidemiological and publication data |
| **Total Features** | 112 | Complete feature set |

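The counts in this table can be checked against the loaded columns. Below is a minimal sketch that assumes gene and class columns use the `gene_` and `class_` prefixes that Sections 4 and 5 rely on. The helper name `summarize_schema` and the toy frame are hypothetical.

```python
import pandas as pd


def summarize_schema(df: pd.DataFrame) -> dict:
    """Count feature groups by column prefix and check that AMR columns are binary."""
    gene_cols = [c for c in df.columns if c.startswith("gene_")]
    class_cols = [c for c in df.columns if c.startswith("class_")]
    amr_cols = gene_cols + class_cols
    # Presence/absence features should only contain 0 and 1 (NaN is ignored here)
    non_binary = [c for c in amr_cols if not set(df[c].dropna().unique()) <= {0, 1}]
    return {
        "n_isolates": len(df),
        "n_gene_features": len(gene_cols),
        "n_class_features": len(class_cols),
        "n_other_features": df.shape[1] - len(amr_cols),
        "non_binary_columns": non_binary,
    }


# Toy frame standing in for the real dataset (hypothetical values)
toy = pd.DataFrame({
    "Isolate_ID": ["iso1", "iso2"],
    "gene_acrB": [1, 1],
    "gene_tet(A)": [0, 1],
    "class_tetracycline": [0, 1],
})
print(summarize_schema(toy))
```

If the table is accurate, running `summarize_schema(df)` on the loaded dataset should report 42 gene features and 37 class features.
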
### Data Sources

- **Genomic Data**: NCBI GenBank complete genome assemblies
- **AMR Annotations**: ABRicate screening against the CARD database
- **Metadata**: NCBI BioProject/BioSample and publication data
- **Processing**: Custom Python pipeline with feature engineering

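This notebook does not show how ABRicate hits become binary features. Below is a minimal sketch of one approach. It assumes ABRicate's tab-separated report, which has `#FILE` and `GENE` columns and a `;`-separated `RESISTANCE` column. The function name and sample rows are hypothetical, and the author's actual pipeline may work differently.

```python
import io

import pandas as pd

# Hypothetical excerpt of an ABRicate report (only the columns used below);
# real reports also carry coverage, identity, accession, etc.
ABRICATE_TSV = """#FILE\tGENE\tRESISTANCE
iso1.fna\tacrB\tcephalosporin;tetracycline
iso1.fna\ttet(A)\ttetracycline
iso2.fna\tacrB\tcephalosporin;tetracycline
"""


def abricate_to_features(report: pd.DataFrame) -> pd.DataFrame:
    """Pivot ABRicate hits into one row per isolate with binary gene_/class_ columns."""
    report = report.rename(columns={"#FILE": "isolate"})
    # Gene presence/absence: 1 if the gene was hit at least once in the isolate
    genes = pd.crosstab(report["isolate"], report["GENE"]).clip(upper=1)
    genes.columns = [f"gene_{g}" for g in genes.columns]
    # Drug classes: split the ';'-separated RESISTANCE field into one row per class
    exploded = report.assign(cls=report["RESISTANCE"].str.split(";")).explode("cls")
    classes = pd.crosstab(exploded["isolate"], exploded["cls"]).clip(upper=1)
    classes.columns = [f"class_{c}" for c in classes.columns]
    return genes.join(classes, how="outer").fillna(0).astype(int)


report = pd.read_csv(io.StringIO(ABRICATE_TSV), sep="\t")
features = abricate_to_features(report)
print(features)
```

Isolates with no hits would be missing from the pivot and would need to be added back as all-zero rows.
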
In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

print("Libraries imported successfully!")

## 2. Data Loading and Initial Exploration {#data-loading}

In [None]:
# Load the main dataset
df = pd.read_csv('data/processed/Kaggle_AMR_Dataset_v1.0_final.csv')

print("Dataset loaded successfully!")
print(f"Dataset shape: {df.shape}")
print(f"\nFirst 5 rows:")
df.head()

In [None]:
# Dataset information
print("Dataset Information:")
print("-" * 50)
print(f"Number of isolates: {len(df)}")
print(f"Total features: {len(df.columns)}")
print(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB")

print("\nData types:")
print(df.dtypes.value_counts())

print("\nMissing values summary:")
missing = df.isnull().sum()
missing = missing[missing > 0]
if len(missing) > 0:
    print(missing)
else:
    print("No missing values found!")

## 3. Genome Characteristics Analysis {#genome-analysis}

In [None]:
# Genome characteristics analysis
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Genome Characteristics Distribution', fontsize=16, fontweight='bold')

# Genome length distribution
axes[0,0].hist(df['Genome_Length_BP'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')
axes[0,0].set_title('Genome Length Distribution')
axes[0,0].set_xlabel('Genome Length (bp)')
axes[0,0].set_ylabel('Frequency')
axes[0,0].axvline(df['Genome_Length_BP'].mean(), color='red', linestyle='--', label=f'Mean: {df["Genome_Length_BP"].mean():.0f}')
axes[0,0].legend()

# GC content distribution
axes[0,1].hist(df['GC_Content_Percent'], bins=15, alpha=0.7, color='lightgreen', edgecolor='black')
axes[0,1].set_title('GC Content Distribution')
axes[0,1].set_xlabel('GC Content (%)')
axes[0,1].set_ylabel('Frequency')
axes[0,1].axvline(df['GC_Content_Percent'].mean(), color='red', linestyle='--', label=f'Mean: {df["GC_Content_Percent"].mean():.2f}%')
axes[0,1].legend()

# Genome length vs GC content
axes[1,0].scatter(df['Genome_Length_BP'], df['GC_Content_Percent'], alpha=0.6, color='purple')
axes[1,0].set_title('Genome Length vs GC Content')
axes[1,0].set_xlabel('Genome Length (bp)')
axes[1,0].set_ylabel('GC Content (%)')

# Summary statistics
axes[1,1].axis('off')
stats_text = f"""
Genome Statistics:
• Mean length: {df['Genome_Length_BP'].mean():.0f} bp
• Std length: {df['Genome_Length_BP'].std():.0f} bp
• Min length: {df['Genome_Length_BP'].min():.0f} bp
• Max length: {df['Genome_Length_BP'].max():.0f} bp
• Mean GC: {df['GC_Content_Percent'].mean():.2f}%
• Std GC: {df['GC_Content_Percent'].std():.2f}%"""
axes[1,1].text(0.1, 0.8, stats_text, fontsize=12, verticalalignment='top',
               bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"))

plt.tight_layout()
plt.show()

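`GC_Content_Percent` is presumably the percentage of G and C bases in each assembly. Below is a minimal sketch of that calculation. It assumes ambiguous bases such as `N` are excluded from the denominator; the dataset's own convention is not stated.

```python
def gc_content_percent(sequence: str) -> float:
    """Percentage of G/C among unambiguous bases (A, C, G, T); case-insensitive."""
    seq = sequence.upper()
    gc = seq.count("G") + seq.count("C")
    acgt = gc + seq.count("A") + seq.count("T")
    if acgt == 0:
        raise ValueError("sequence contains no A/C/G/T bases")
    return 100.0 * gc / acgt


print(gc_content_percent("ATGCGCNNAT"))  # 4 G/C out of 8 unambiguous bases -> 50.0
```
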
## 4. AMR Gene Distribution {#amr-analysis}

In [None]:
# AMR gene analysis
gene_cols = [col for col in df.columns if col.startswith('gene_')]
gene_frequencies = df[gene_cols].sum().sort_values(ascending=False)

print(f"Total AMR genes analyzed: {len(gene_cols)}")
print(f"\nTop 10 most frequent AMR genes:")
print(gene_frequencies.head(10))

# Visualize gene frequencies
plt.figure(figsize=(15, 8))
gene_frequencies.head(20).plot(kind='bar', color='coral', alpha=0.8)
plt.title('Top 20 Most Frequent AMR Genes', fontsize=16, fontweight='bold')
plt.xlabel('AMR Gene')
plt.ylabel('Frequency (Number of Isolates)')
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# AMR gene prevalence analysis
gene_prevalence = (df[gene_cols].sum() / len(df) * 100).sort_values(ascending=False)

fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# Gene prevalence distribution
axes[0].hist(gene_prevalence, bins=20, alpha=0.7, color='orange', edgecolor='black')
axes[0].set_title('AMR Gene Prevalence Distribution', fontsize=14)
axes[0].set_xlabel('Prevalence (%)')
axes[0].set_ylabel('Number of Genes')
axes[0].axvline(50, color='red', linestyle='--', label='50% threshold')
axes[0].legend()

# Gene categories by prevalence
rare_genes = len(gene_prevalence[gene_prevalence < 10])
common_genes = len(gene_prevalence[gene_prevalence >= 50])
moderate_genes = len(gene_prevalence) - rare_genes - common_genes

categories = ['Rare (<10%)', 'Moderate (10-50%)', 'Common (≥50%)']
counts = [rare_genes, moderate_genes, common_genes]
colors = ['lightcoral', 'gold', 'lightgreen']

axes[1].bar(categories, counts, color=colors, alpha=0.8)
axes[1].set_title('AMR Genes by Prevalence Category', fontsize=14)
axes[1].set_ylabel('Number of Genes')
for i, v in enumerate(counts):
    axes[1].text(i, v + 0.5, str(v), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

print(f"Gene prevalence categories:")
print(f"• Rare genes (<10%): {rare_genes}")
print(f"• Moderate genes (10-50%): {moderate_genes}")
print(f"• Common genes (≥50%): {common_genes}")

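The three bands can also be computed in one step with `pd.cut`. The sketch below uses a hypothetical helper and a toy matrix. Its bins are left-closed, so exactly 10% counts as moderate and exactly 50% counts as common, which matches the `< 10` / `>= 50` comparisons in the cell above.

```python
import numpy as np
import pandas as pd

LABELS = ["Rare (<10%)", "Moderate (10-50%)", "Common (≥50%)"]


def prevalence_categories(presence: pd.DataFrame) -> pd.Series:
    """Count genes per prevalence band from a 0/1 isolate-by-gene matrix."""
    prevalence = presence.mean() * 100  # % of isolates carrying each gene
    # right=False gives [0,10), [10,50), [50,inf)
    bands = pd.cut(prevalence, bins=[0, 10, 50, np.inf], right=False, labels=LABELS)
    return pd.Series({label: int((bands == label).sum()) for label in LABELS})


toy = pd.DataFrame({
    "gene_acrB":   [1, 1, 1, 1],  # 100% -> Common
    "gene_tet(A)": [1, 0, 0, 0],  # 25%  -> Moderate
    "gene_mphB":   [0, 0, 0, 0],  # 0%   -> Rare
})
print(prevalence_categories(toy))
```
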
## 5. Resistance Phenotype Analysis {#resistance-analysis}

In [None]:
# Resistance phenotype analysis
class_cols = [col for col in df.columns if col.startswith('class_')]
class_frequencies = df[class_cols].sum().sort_values(ascending=False)

print(f"Total resistance classes analyzed: {len(class_cols)}")
print(f"\nResistance phenotype frequencies:")
print(class_frequencies)

# Visualize resistance classes
plt.figure(figsize=(15, 8))
class_frequencies.plot(kind='bar', color='darkblue', alpha=0.8)
plt.title('Antibiotic Resistance Classes Distribution', fontsize=16, fontweight='bold')
plt.xlabel('Resistance Class')
plt.ylabel('Frequency (Number of Isolates)')
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Multi-drug resistance analysis
df['total_amr_genes'] = df[gene_cols].sum(axis=1)
df['total_resistance_classes'] = df[class_cols].sum(axis=1)

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Multi-Drug Resistance Analysis', fontsize=16, fontweight='bold')

# Distribution of total AMR genes
axes[0,0].hist(df['total_amr_genes'], bins=15, alpha=0.7, color='purple', edgecolor='black')
axes[0,0].set_title('Distribution of Total AMR Genes per Isolate')
axes[0,0].set_xlabel('Number of AMR Genes')
axes[0,0].set_ylabel('Number of Isolates')
axes[0,0].axvline(df['total_amr_genes'].mean(), color='red', linestyle='--', 
                 label=f'Mean: {df["total_amr_genes"].mean():.1f}')
axes[0,0].legend()

# Distribution of resistance classes
axes[0,1].hist(df['total_resistance_classes'], bins=10, alpha=0.7, color='green', edgecolor='black')
axes[0,1].set_title('Distribution of Resistance Classes per Isolate')
axes[0,1].set_xlabel('Number of Resistance Classes')
axes[0,1].set_ylabel('Number of Isolates')
axes[0,1].axvline(df['total_resistance_classes'].mean(), color='red', linestyle='--',
                 label=f'Mean: {df["total_resistance_classes"].mean():.1f}')
axes[0,1].legend()

# Scatter plot: genes vs classes
axes[1,0].scatter(df['total_amr_genes'], df['total_resistance_classes'], alpha=0.6, color='orange')
axes[1,0].set_title('AMR Genes vs Resistance Classes')
axes[1,0].set_xlabel('Number of AMR Genes')
axes[1,0].set_ylabel('Number of Resistance Classes')

# MDR classification
axes[1,1].axis('off')
mdr_text = f"""
Multi-Drug Resistance (MDR):
• Resistant to ≥3 classes: {len(df[df['total_resistance_classes'] >= 3])}
• Resistant to ≥5 classes: {len(df[df['total_resistance_classes'] >= 5])}
• Max resistance classes: {df['total_resistance_classes'].max()}
• Max AMR genes: {df['total_amr_genes'].max()}
• Correlation (genes vs classes): {df['total_amr_genes'].corr(df['total_resistance_classes']):.3f}"""
axes[1,1].text(0.1, 0.8, mdr_text, fontsize=12, verticalalignment='top',
               bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow"))

plt.tight_layout()
plt.show()

## 6. Metadata Exploration {#metadata-analysis}

In [None]:
# Metadata analysis
metadata_cols = ['collection_year', 'collection_month', 'collection_season', 
                'host_standardized', 'isolation_source_standardized', 'country']

print("Metadata Summary:")
print("-" * 40)

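# Plot only if every column used below is present (avoids a KeyError)
plot_cols = ['collection_year', 'collection_season', 'host_standardized', 'isolation_source_standardized']
if all(col in df.columns for col in plot_cols):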
    plt.figure(figsize=(12, 8))
    
    # Year distribution
    plt.subplot(2, 2, 1)
    year_counts = df['collection_year'].value_counts().sort_index()
    year_counts.plot(kind='bar', color='lightblue', alpha=0.8)
    plt.title('Isolates by Collection Year')
    plt.xlabel('Year')
    plt.ylabel('Number of Isolates')
    plt.xticks(rotation=45)
    
    # Season distribution
    plt.subplot(2, 2, 2)
    season_counts = df['collection_season'].value_counts()
    season_counts.plot(kind='pie', autopct='%1.1f%%', colors=sns.color_palette('pastel'))
    plt.title('Isolates by Collection Season')
    plt.ylabel('')
    
    # Host distribution
    plt.subplot(2, 2, 3)
    host_counts = df['host_standardized'].value_counts()
    host_counts.plot(kind='barh', color='lightgreen', alpha=0.8)
    plt.title('Isolates by Host Type')
    plt.xlabel('Number of Isolates')
    
    # Source distribution
    plt.subplot(2, 2, 4)
    source_counts = df['isolation_source_standardized'].value_counts()
    source_counts.plot(kind='barh', color='coral', alpha=0.8)
    plt.title('Isolates by Isolation Source')
    plt.xlabel('Number of Isolates')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Collection years: {sorted(df['collection_year'].dropna().unique())}")
    print(f"Seasons: {df['collection_season'].value_counts().to_dict()}")
    print(f"Host types: {df['host_standardized'].value_counts().to_dict()}")
    print(f"Isolation sources: {df['isolation_source_standardized'].value_counts().to_dict()}")
else:
    print("Metadata columns not found in dataset")

## 7. Machine Learning Applications {#ml-applications}

In [None]:
# Machine Learning Example: Predicting multi-drug resistance
print("Machine Learning Example: Predicting Multi-Drug Resistance")
print("=" * 60)

# Create target variable: Multi-drug resistant (≥3 resistance classes)
df['is_mdr'] = (df['total_resistance_classes'] >= 3).astype(int)

# Prepare features (AMR genes only for this example)
feature_cols = gene_cols
X = df[feature_cols]
y = df['is_mdr']

print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features")
print(f"MDR isolates: {y.sum()} ({y.sum()/len(y)*100:.1f}%)")
print(f"Non-MDR isolates: {len(y) - y.sum()} ({(len(y) - y.sum())/len(y)*100:.1f}%)")

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# Train Random Forest model
rf_model = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
rf_model.fit(X_train, y_train)

# Make predictions
y_pred = rf_model.predict(X_test)

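print("\nModel Performance:")
print("-" * 30)
if y.nunique() < 2:
    print("Warning: 'is_mdr' has a single class, so these metrics are not informative.")
# labels=[0, 1] keeps the report valid even when one class is absent from y_test
print(classification_report(y_test, y_pred, labels=[0, 1],
                            target_names=['Non-MDR', 'MDR'], zero_division=0))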

# Feature importance
feature_importance = pd.DataFrame({
    'feature': feature_cols,
    'importance': rf_model.feature_importances_
}).sort_values('importance', ascending=False)

print("\nTop 10 Most Important AMR Genes for MDR Prediction:")
print(feature_importance.head(10))

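With 50 isolates, a single 70/30 split leaves only 15 test isolates, so the report above can change a lot between random seeds. Stratified k-fold cross-validation usually gives a steadier estimate. Below is a minimal sketch. The synthetic `X_demo`/`y_demo` stand in for `df[gene_cols]` and `df['is_mdr']`. The single-class check reflects the Key Findings observation that every isolate meets the ≥3-class threshold.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score


def cv_balanced_accuracy(X, y, n_splits=5, seed=42):
    """Stratified k-fold balanced accuracy of a random forest, or None if y has one class."""
    y = np.asarray(y)
    if np.unique(y).size < 2:
        return None  # nothing to learn: every isolate has the same label
    model = RandomForestClassifier(n_estimators=100, random_state=seed,
                                   class_weight="balanced")
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return cross_val_score(model, X, y, cv=cv, scoring="balanced_accuracy")


# Synthetic stand-in for df[gene_cols] / df['is_mdr']: the label follows gene 0
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 2, size=(50, 10))
y_demo = X_demo[:, 0]
scores = cv_balanced_accuracy(X_demo, y_demo)
print(f"Balanced accuracy: {scores.mean():.2f} ± {scores.std():.2f}")
print(cv_balanced_accuracy(X_demo, np.ones(50, dtype=int)))  # single class -> None
```

This example uses balanced accuracy because it stays meaningful when one class greatly outnumbers the other.
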
In [None]:
# Visualize feature importance
plt.figure(figsize=(12, 8))
top_features = feature_importance.head(15)
plt.barh(range(len(top_features)), top_features['importance'])
# Strip only the 'gene_' prefix; split('_')[1] would truncate names containing '_'
plt.yticks(range(len(top_features)), [f[len('gene_'):] for f in top_features['feature']])
plt.xlabel('Feature Importance')
plt.title('Top 15 AMR Genes for MDR Prediction', fontsize=14, fontweight='bold')
plt.gca().invert_yaxis()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

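Impurity-based `feature_importances_` come from the training data. `sklearn.inspection.permutation_importance` takes a different approach: it measures how much a score drops when one feature is shuffled, and it can be evaluated on held-out data. Below is a minimal sketch on synthetic data. In the notebook, `rf_model`, `X_test`, and `y_test` from Section 7 would replace the toy objects.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 200 isolates x 8 binary genes; the label copies gene_g0
rng = np.random.default_rng(1)
X_demo = pd.DataFrame(rng.integers(0, 2, size=(200, 8)),
                      columns=[f"gene_g{i}" for i in range(8)])
y_demo = X_demo["gene_g0"].to_numpy()
Xd_train, Xd_test, yd_train, yd_test = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=42, stratify=y_demo)

model = RandomForestClassifier(n_estimators=100, random_state=42).fit(Xd_train, yd_train)
# Mean drop in held-out accuracy over 20 random shuffles of each column
result = permutation_importance(model, Xd_test, yd_test, n_repeats=20, random_state=42)
perm = pd.Series(result.importances_mean, index=X_demo.columns).sort_values(ascending=False)
print(perm.head())
```

When genes are strongly co-occurring, both methods can split credit between them, so neither ranking should be read as causal.
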
## 8. Advanced Visualizations {#visualizations}

In [None]:
# Advanced visualization: AMR gene co-occurrence heatmap
print("Creating AMR Gene Co-occurrence Heatmap...")

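# Select the 20 most frequent genes that vary across isolates; genes present in
# every isolate (or none) have zero variance, so their correlations are undefined (NaN)
variable_genes = [g for g in gene_frequencies.index if df[g].nunique() > 1]
top_genes = variable_genes[:20]
gene_corr = df[top_genes].corr()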

plt.figure(figsize=(14, 12))
sns.heatmap(gene_corr, cmap='RdYlBu_r', center=0, 
            square=True, linewidths=0.5, cbar_kws={'shrink': 0.8})
plt.title('AMR Gene Co-occurrence Matrix (Top 20 Genes)', fontsize=16, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

print("\nHeatmap interpretation (Pearson correlation of 0/1 columns, i.e. the phi coefficient):")
print("• Red cells: genes tend to co-occur (positive correlation)")
print("• Blue cells: genes tend not to co-occur (negative correlation)")
print("• Pale yellow cells: little or no association")

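Pearson correlation on 0/1 columns is undefined for genes present in every isolate. For presence/absence data, the Jaccard index is an alternative: the share of isolates carrying both genes among those carrying either. Below is a minimal NumPy sketch with a hypothetical helper name. Its result could replace `gene_corr` in `sns.heatmap`, with `center` dropped because Jaccard ranges from 0 to 1.

```python
import numpy as np
import pandas as pd


def jaccard_matrix(presence: pd.DataFrame) -> pd.DataFrame:
    """Pairwise Jaccard index |A∩B| / |A∪B| between the 0/1 columns of `presence`."""
    m = presence.to_numpy(dtype=float)
    both = m.T @ m                      # isolates carrying gene i AND gene j
    counts = np.diag(both)              # isolates carrying each gene
    either = counts[:, None] + counts[None, :] - both
    with np.errstate(invalid="ignore", divide="ignore"):
        jac = np.where(either > 0, both / either, np.nan)  # NaN if neither gene occurs
    return pd.DataFrame(jac, index=presence.columns, columns=presence.columns)


toy = pd.DataFrame({
    "gene_acrB":     [1, 1, 1, 1],
    "gene_tet(A)":   [1, 1, 0, 0],
    "gene_CTX-M-15": [0, 1, 1, 0],
})
print(jaccard_matrix(toy).round(2))
```
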
In [None]:
# Dimensionality reduction visualization
print("Creating PCA Visualization of AMR Profiles...")

# Prepare data for PCA
X_pca = df[gene_cols]
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_pca)

# Apply PCA
pca = PCA(n_components=2)
X_pca_result = pca.fit_transform(X_scaled)

# Create visualization
plt.figure(figsize=(12, 8))

# Color by MDR status
colors = ['red' if x >= 3 else 'blue' for x in df['total_resistance_classes']]
scatter = plt.scatter(X_pca_result[:, 0], X_pca_result[:, 1], 
                     c=colors, alpha=0.7, s=100, edgecolors='black')

# Add labels for extreme points
for i, txt in enumerate(df['Isolate_ID']):
    if abs(X_pca_result[i, 0]) > 2 or abs(X_pca_result[i, 1]) > 2:
        plt.annotate(txt, (X_pca_result[i, 0], X_pca_result[i, 1]), 
                    xytext=(5, 5), textcoords='offset points', fontsize=8)

plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)')
plt.title('PCA of AMR Gene Profiles', fontsize=16, fontweight='bold')
plt.grid(alpha=0.3)

# Legend
legend_elements = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=10, label='MDR (≥3 classes)'),
                  plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Non-MDR (<3 classes)')]
plt.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()

print(f"PCA explains {pca.explained_variance_ratio_[:2].sum()*100:.1f}% of variance in 2 dimensions")

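In presence/absence data, two components often capture only a modest share of the variance. Below is a minimal sketch that counts how many components are needed to exceed a chosen share. The 90% target is an arbitrary assumption, and the synthetic matrix stands in for `X_scaled`.

```python
import numpy as np
from sklearn.decomposition import PCA


def n_components_for(X, target=0.90):
    """Smallest number of principal components whose cumulative explained variance exceeds target."""
    ratios = PCA().fit(X).explained_variance_ratio_
    cumulative = np.cumsum(ratios)
    k = int(np.searchsorted(cumulative, target, side="right")) + 1
    return min(k, len(ratios))  # guard against rounding when target is close to 1


# Synthetic stand-in for X_scaled: 50 isolates x 20 binary genes
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 2, size=(50, 20)).astype(float)
k = n_components_for(X_demo, 0.90)
print(f"{k} components explain more than 90% of the variance")
```

scikit-learn can make a similar selection internally with `PCA(n_components=0.90, svd_solver='full')` and reports the chosen count in `n_components_`.
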
## 9. Research Applications {#research-applications}

In [None]:
# Research applications demonstration
print("Research Applications Demonstration")
print("=" * 50)

# Application 1: Temporal trends in resistance
if 'collection_year' in df.columns:
    yearly_resistance = df.groupby('collection_year')['total_resistance_classes'].mean()
    
    plt.figure(figsize=(10, 6))
    yearly_resistance.plot(kind='line', marker='o', linewidth=2, color='darkred')
    plt.title('Average Resistance Classes by Collection Year', fontsize=14, fontweight='bold')
    plt.xlabel('Year')
    plt.ylabel('Average Number of Resistance Classes')
    plt.grid(alpha=0.3)
    plt.show()

# Application 2: Host-specific resistance patterns
host_resistance = df.groupby('host_standardized')['total_resistance_classes'].mean().sort_values(ascending=False)

plt.figure(figsize=(10, 6))
host_resistance.plot(kind='bar', color='teal', alpha=0.8)
plt.title('Average Resistance by Host Type', fontsize=14, fontweight='bold')
plt.xlabel('Host Type')
plt.ylabel('Average Resistance Classes')
plt.xticks(rotation=45)
plt.grid(axis='y', alpha=0.3)
plt.show()

# Application 3: Resistance gene diversity
gene_diversity = df[gene_cols].sum(axis=0)
gene_diversity_stats = {
    'Gene features screened': len(gene_diversity),
    'Genes detected in ≥1 isolate': int((gene_diversity > 0).sum()),
    'Genes present in ≥50% of isolates': len(gene_diversity[gene_diversity >= len(df) * 0.5]),
    'Genes present in <10% of isolates': len(gene_diversity[gene_diversity < len(df) * 0.1]),
    'Isolates carrying the most common gene': gene_diversity.max(),
    'Most common gene': gene_diversity.idxmax()[len('gene_'):]
}

print("\nResistance Gene Diversity Statistics:")
for key, value in gene_diversity_stats.items():
    print(f"• {key}: {value}")

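# Application 4: Multi-drug resistance prevalence
# Thresholds are cumulative, so categories overlap (isolates with exactly 2 classes are not listed).
# Formal XDR/PDR definitions rely on phenotypic susceptibility across antimicrobial
# categories, so these class-count tiers are only rough proxies.
mdr_stats = {
    '≥10 resistance classes': len(df[df['total_resistance_classes'] >= 10]),
    '≥7 resistance classes': len(df[df['total_resistance_classes'] >= 7]),
    'Multi-drug resistant (≥3 classes)': len(df[df['total_resistance_classes'] >= 3]),
    'Single resistance class': len(df[df['total_resistance_classes'] == 1]),
    'No resistance class detected': len(df[df['total_resistance_classes'] == 0])
}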

print("\nMulti-Drug Resistance Classification:")
for category, count in mdr_stats.items():
    percentage = count / len(df) * 100
    print(f"• {category}: {count} isolates ({percentage:.1f}%)")

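The counts above use cumulative thresholds. An isolate with 12 classes therefore appears in the first three rows, and isolates with exactly two classes appear in none. For mutually exclusive tiers, `pd.cut` assigns each isolate to exactly one band. Below is a minimal sketch. The band edges follow the thresholds above, while the labels and toy counts are illustrative.

```python
import numpy as np
import pandas as pd

# Left-closed bands: [0,1), [1,2), [2,3), [3,7), [7,10), [10,inf)
EDGES = [0, 1, 2, 3, 7, 10, np.inf]
LABELS = ["0 classes", "1 class", "2 classes", "3-6 classes", "7-9 classes", "≥10 classes"]


def resistance_tiers(n_classes: pd.Series) -> pd.Series:
    """Number of isolates in each mutually exclusive resistance-class band."""
    tiers = pd.cut(n_classes, bins=EDGES, right=False, labels=LABELS)
    return pd.Series({label: int((tiers == label).sum()) for label in LABELS})


toy_counts = pd.Series([0, 1, 2, 3, 6, 7, 9, 10, 24])
print(resistance_tiers(toy_counts))
```

Calling `resistance_tiers(df['total_resistance_classes'])` would apply the same banding to the dataset, and the tier counts then add up to the number of isolates.
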
## 📚 Key Findings and Insights

### Dataset Characteristics
- **50 E. coli isolates** with complete genome sequences
- **Genome sizes**: 4.7-5.5 Mbp (mean: ~5.0 Mbp)
- **GC content**: 50-51% (typical for E. coli)
- **AMR genes per isolate**: 43-51 (mean: ~46)
- **Resistance classes per isolate**: 21-25 (mean: ~23)

### AMR Patterns
- **Most common genes**: CRP, acrB, acrA (chromosomally encoded genes intrinsic to E. coli)
- **Variable genes**: CTX-M-15, tet(A), mphB (resistance markers)
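- **All isolates meet the ≥3-class MDR threshold**, so the `is_mdr` target in Section 7 has only one class in this dataset; a stricter threshold or a specific drug class would be needed for a meaningful classifier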
- **High genetic diversity** in resistance gene profiles

### Research Implications
- **Surveillance**: Track resistance emergence patterns
- **Prediction**: ML models may predict resistance from genomic features, though with only 50 isolates any results should be cross-validated and treated as exploratory
- **Epidemiology**: Link resistance to host, source, and geography
- **Evolution**: Study co-occurrence and acquisition patterns

## 🔬 Next Steps for Research

1. **Expand Dataset**: Include more bacterial species and isolates
2. **Temporal Analysis**: Track resistance trends over time
3. **Comparative Genomics**: Compare with susceptible isolates
4. **Functional Studies**: Validate gene function predictions
5. **Clinical Correlations**: Link genomic predictions to clinical outcomes

## 📖 References and Further Reading

- **CARD Database**: Comprehensive Antibiotic Resistance Database
- **ABRicate**: AMR gene detection tool
- **BioPython**: Python tools for computational biology
- **Scikit-learn**: Machine learning in Python

---

**Dataset Citation**:
```
@dataset{kulkarni_amr_dataset_2024,
  title={AMR Genome Dataset: Antimicrobial Resistance Prediction Dataset},
  author={Kulkarni, Vihaan},
  year={2024},
  publisher={GitHub},
  url={https://github.com/vihaankulkarni29/amr-dataset}
}
```

**Contact**: For questions about this dataset or analysis, please open a GitHub issue.