# Variant Calling with Deep Learning on 1000 Genomes Data

**Duration:** 60-90 minutes | **Platform:** Colab or Studio Lab | **Data:** ~1.5GB

Train a CNN to call genetic variants from sequencing reads using real 1000 Genomes Project data.

## Research Goal

Build a deep learning variant caller that identifies SNPs and indels from aligned BAM files, competing with traditional tools like GATK HaplotypeCaller.

## Setup

In [None]:
# Install dependencies (pre-installed in Studio Lab)
!pip install -q pysam biopython tensorflow scikit-learn

In [None]:
import urllib.request
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pysam
import tensorflow as tf
from sklearn.metrics import auc, precision_recall_fscore_support, roc_curve
from tensorflow import keras
from tqdm import tqdm

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

## 1. Download 1000 Genomes Data

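Download the low-coverage chromosome 20 BAM file for individual NA12878, together with its index (~1.5GB, 15-20 minutes).

**Note:** The BAM files are served from the 1000 Genomes bucket on the AWS Open Data Registry (no credentials required); the reference sequence comes from UCSC.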

In [None]:
# Create data directory
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# AWS Open Data Registry URLs (public, no credentials)
BASE_URL = "https://s3.amazonaws.com/1000genomes/"

# Files to download
files_to_download = [
    "phase3/data/NA12878/alignment/NA12878.chrom20.ILLUMINA.bwa.CEU.low_coverage.20121211.bam",
    "phase3/data/NA12878/alignment/NA12878.chrom20.ILLUMINA.bwa.CEU.low_coverage.20121211.bam.bai",
]

# Reference genome (hg19), chromosome 20 only.
# UCSC names this contig "chr20", whereas the 1000 Genomes BAM names it "20".
reference_url = "https://hgdownload.cse.ucsc.edu/goldenPath/hg19/chromosomes/chr20.fa.gz"


def download_file(url, destination):
    """Download a file to destination, skipping the download if a cached copy exists."""
    if destination.exists():
        print(f"✓ Using cached file: {destination.name}")
        return

    print(f"Downloading {destination.name}...")
    urllib.request.urlretrieve(url, destination)
    print(f"✓ Downloaded {destination.name}")


# Download BAM files
for file_path in files_to_download:
    filename = Path(file_path).name
    url = BASE_URL + file_path
    download_file(url, data_dir / filename)

# Download reference genome
download_file(reference_url, data_dir / "chr20.fa.gz")

# Uncompress reference
if not (data_dir / "chr20.fa").exists():
    !gunzip -k {data_dir}/chr20.fa.gz
    print("✓ Uncompressed reference genome")

print("\n✓ Data download complete!")
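
The `download_file` helper above prints only a start and a finish message. If live progress is wanted, `urllib.request.urlretrieve` accepts a `reporthook` callback, which it calls with the block count, block size, and total size. The sketch below shows one possible hook; `format_progress` and `report_progress` are hypothetical names that are not part of this notebook.

```python
def format_progress(block_num, block_size, total_size):
    """Build a one-line progress message from urlretrieve's reporthook arguments."""
    downloaded = block_num * block_size
    if total_size > 0:
        # Cap at 100% because the last block may overshoot the file size
        pct = min(100.0, 100.0 * downloaded / total_size)
        return f"{pct:5.1f}% of {total_size / 1e6:.1f} MB"
    # The server did not report a size, so show only the running total
    return f"{downloaded / 1e6:.1f} MB downloaded"


def report_progress(block_num, block_size, total_size):
    """reporthook-compatible callback that overwrites a single console line."""
    print("\r" + format_progress(block_num, block_size, total_size), end="", flush=True)


# Possible use inside download_file (not executed here):
# urllib.request.urlretrieve(url, destination, reporthook=report_progress)
```

Keeping the formatting separate from the printing makes the hook easy to check without touching the network.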

## 2. Explore the BAM File

Load and inspect the aligned sequencing reads.

In [None]:
# Load BAM file
bam_file = data_dir / "NA12878.chrom20.ILLUMINA.bwa.CEU.low_coverage.20121211.bam"
bamfile = pysam.AlignmentFile(str(bam_file), "rb")

# Get basic statistics
print(f"BAM file: {bam_file.name}")
print(f"References: {bamfile.references[:5]}")  # Chromosome names
print(f"Number of mapped reads: {bamfile.mapped}")
print(f"Number of unmapped reads: {bamfile.unmapped}")

# Sample a few reads
print("\nSample reads:")
for i, read in enumerate(bamfile.fetch("20", 10000000, 10000100)):
    if i >= 3:
        break
    print(
        f"  Position: {read.reference_start:,}, Length: {read.query_length}, Quality: {read.mapping_quality}"
    )

## 3. Generate Pileup Tensors

Convert aligned reads to image-like tensors for CNN input.

**Pileup representation:**
- Each genomic position becomes a row
- Each overlapping read becomes a column
- Channels encode: base, quality, mapping quality, strand, etc.
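
The layout can be illustrated on a toy example before touching the BAM file. The sketch below assumes three made-up, gap-free reads over a 6 bp window and only two channels (encoded base and reference match); plain nested lists stand in for the NumPy array used in the notebook. As in `create_pileup_tensor`, reads are packed from slot 0 at every position.

```python
BASE_ENCODING = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}

reference = "ACGTAC"  # hypothetical 6 bp window
# (0-based start within the window, read sequence); made-up reads without indels
toy_reads = [(0, "ACGT"), (1, "CGTA"), (2, "GAAC")]

max_depth = 4
n_channels = 2  # channel 0: encoded base, channel 1: matches reference

# tensor[position][read_slot] -> [base_code, is_match]
tensor = [[[0.0] * n_channels for _ in range(max_depth)] for _ in range(len(reference))]
depth = [0] * len(reference)

for pos, ref_base in enumerate(reference):
    slot = 0  # reads are packed from slot 0 at each position
    for start, seq in toy_reads:
        offset = pos - start
        if 0 <= offset < len(seq) and slot < max_depth:
            base = seq[offset]
            tensor[pos][slot][0] = BASE_ENCODING.get(base, 4) / 4.0
            tensor[pos][slot][1] = 1.0 if base == ref_base else 0.0
            slot += 1
    depth[pos] = slot  # number of reads covering this position

print("depth per position:", depth)
print("position 3:", tensor[3])
```

Note that an `A` encodes to 0.0, the same value as an empty slot, so coverage cannot be read off the base channel alone; the toy example tracks depth separately.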

In [None]:
# Encoding dictionaries
BASE_ENCODING = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
STRAND_ENCODING = {"+": 0, "-": 1}


def create_pileup_tensor(bamfile, reference, chrom, start, end, max_depth=100):
    """
    Create pileup tensor for a genomic region.

    Args:
        bamfile: pysam.AlignmentFile object
        reference: Reference sequence (string)
        chrom: Chromosome name
        start, end: Genomic coordinates (0-based)
        max_depth: Maximum read depth to consider

    Returns:
        tensor: Shape (length, max_depth, n_channels)
    """
    length = end - start
    n_channels = 7  # base, base_qual, map_qual, strand, is_match, is_del, is_ins

    tensor = np.zeros((length, max_depth, n_channels), dtype=np.float32)

    # For each position, collect overlapping reads
    for pos_idx, pos in enumerate(range(start, end)):
        pileup_column = bamfile.pileup(chrom, pos, pos + 1, truncate=True, max_depth=max_depth)

        read_idx = 0
        for pileup_col in pileup_column:
            for pileup_read in pileup_col.pileups:
                if read_idx >= max_depth:
                    break

                read = pileup_read.alignment

                # Channel 0: Base encoding
                if not pileup_read.is_del and not pileup_read.is_refskip:
                    base = read.query_sequence[pileup_read.query_position]
                    tensor[pos_idx, read_idx, 0] = BASE_ENCODING.get(base, 4) / 4.0

                    # Channel 1: Base quality
                    base_qual = read.query_qualities[pileup_read.query_position]
                    tensor[pos_idx, read_idx, 1] = base_qual / 40.0  # Normalize by typical max

                # Channel 2: Mapping quality
                tensor[pos_idx, read_idx, 2] = read.mapping_quality / 60.0

                # Channel 3: Strand (forward = 0, reverse = 1, as defined in STRAND_ENCODING)
                tensor[pos_idx, read_idx, 3] = STRAND_ENCODING["-" if read.is_reverse else "+"]

                # Channel 4: Is match to reference
                # (upper-case the reference base: UCSC FASTA files soft-mask repeats in lower case)
                ref_base = reference[pos - start].upper()
                if not pileup_read.is_del and not pileup_read.is_refskip:
                    read_base = read.query_sequence[pileup_read.query_position]
                    tensor[pos_idx, read_idx, 4] = 1.0 if read_base == ref_base else 0.0

                # Channel 5: Is deletion
                tensor[pos_idx, read_idx, 5] = 1.0 if pileup_read.is_del else 0.0

                # Channel 6: Is insertion (simplified)
                tensor[pos_idx, read_idx, 6] = 1.0 if pileup_read.indel > 0 else 0.0

                read_idx += 1

    return tensor


print("✓ Pileup tensor generation functions ready")

In [None]:
# Load reference sequence
reference_file = pysam.FastaFile(str(data_dir / "chr20.fa"))

# Generate a sample pileup tensor
test_region = (10000000, 10000221)  # 221bp window
# The UCSC FASTA names the contig "chr20" (the BAM uses "20")
ref_seq = reference_file.fetch("chr20", test_region[0], test_region[1])
sample_tensor = create_pileup_tensor(
    bamfile, ref_seq, "20", test_region[0], test_region[1], max_depth=100
)

print(f"Sample pileup tensor shape: {sample_tensor.shape}")
print(
    f"  Position x Read Depth x Channels: {sample_tensor.shape[0]} x {sample_tensor.shape[1]} x {sample_tensor.shape[2]}"
)

# Visualize pileup (base channel only)
plt.figure(figsize=(14, 6))
plt.imshow(sample_tensor[:, :, 0].T, aspect="auto", cmap="viridis", interpolation="none")
plt.colorbar(label="Base Encoding")
plt.xlabel("Genomic Position")
plt.ylabel("Read Depth")
plt.title("Pileup Visualization (Base Channel)")
plt.tight_layout()
plt.show()

## 4. Load Truth Set Labels

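A real pipeline would label each position using the GIAB (Genome in a Bottle) high-confidence variant calls for NA12878.

**Note:** This section is simulated. The labels below are drawn at random and do not depend on the reads; a real implementation would download the GIAB VCF files instead.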
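
A real labelling step would mark every window position that coincides with a truth-set variant. The following is a minimal sketch, assuming the truth positions have already been read from the VCF into a sorted list of 0-based coordinates (VCF `POS` is 1-based, so subtract 1). `labels_for_window` and `truth_positions` are hypothetical names, and the coordinates are made up.

```python
from bisect import bisect_left


def labels_for_window(variant_positions, window_start, window_size=221):
    """Return a 0/1 label per position for the half-open window
    [window_start, window_start + window_size).

    variant_positions must be a sorted list of 0-based variant coordinates.
    """
    labels = [0] * window_size
    # Binary search for the slice of variants that fall inside the window
    lo = bisect_left(variant_positions, window_start)
    hi = bisect_left(variant_positions, window_start + window_size)
    for pos in variant_positions[lo:hi]:
        labels[pos - window_start] = 1
    return labels


# Made-up truth positions for illustration only
truth_positions = [10000050, 10000220, 10000221, 10000400]
labels = labels_for_window(truth_positions, 10000000)
print(sum(labels), "variant positions in this window")
```

The binary search keeps the lookup cheap even when the truth list covers a whole chromosome.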

In [None]:
# In a real implementation, download GIAB truth set:
# GIAB_URL = "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/NA12878_HG001/latest/GRCh37/"
# truth_vcf = "HG001_GRCh37_1_22_v4.2.1_benchmark.vcf.gz"

# For this demo, create synthetic labels
def generate_training_data(bamfile, reference_file, chrom, regions, window_size=221):
    """
    Generate training examples from BAM file.

    In real implementation, would use GIAB truth set for labels.
    """
    X_train = []
    y_train = []

    for start, end in tqdm(regions, desc="Generating training data"):
        for pos in range(start, end - window_size, window_size // 2):  # 50% overlap
            window_end = pos + window_size

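            # Get reference sequence (the UCSC FASTA uses "chr20" where the BAM uses "20")
            ref_chrom = chrom if chrom in reference_file.references else f"chr{chrom}"
            ref_seq = reference_file.fetch(ref_chrom, pos, window_end)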

            # Create pileup tensor
            tensor = create_pileup_tensor(bamfile, ref_seq, chrom, pos, window_end)

            # Generate labels (simplified - real version uses GIAB VCF)
            # Label: 1 if variant, 0 if reference
            # In practice, this would come from truth VCF
            label = np.random.binomial(1, 0.001, window_size)  # ~0.1% variant rate

            X_train.append(tensor)
            y_train.append(label)

    return np.array(X_train), np.array(y_train)


# Define training and validation regions
train_regions = [
    (10000000, 12000000),  # 2Mb for training
]

val_regions = [
    (12000000, 13000000),  # 1Mb for validation
]

print("⚠️  NOTE: This demo uses simulated labels.")
print("    Real implementation would use GIAB high-confidence variant calls.")
print("\nGenerating training data (this will take 10-15 minutes)...")
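
Before running the generation step, it can help to estimate how much memory the stacked tensors will need. The sketch below repeats the window arithmetic of `generate_training_data` (50% overlap, float32 values) without reading any data; `estimate_tensor_bytes` is a hypothetical helper, not part of the notebook.

```python
def estimate_tensor_bytes(regions, window_size=221, max_depth=100, n_channels=7,
                          bytes_per_value=4):
    """Count the windows produced with 50% overlap and the bytes needed to stack them."""
    step = window_size // 2
    n_windows = sum(len(range(start, end - window_size, step)) for start, end in regions)
    n_bytes = n_windows * window_size * max_depth * n_channels * bytes_per_value
    return n_windows, n_bytes


n_windows, n_bytes = estimate_tensor_bytes([(10000000, 12000000)])
print(f"{n_windows:,} windows, {n_bytes / 1e9:.1f} GB")
```

For the 2 Mb training region this works out to 18,180 windows and roughly 11.2 GB of float32 values, so a smaller region or a lower `max_depth` may be needed on machines with limited RAM.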

In [None]:
# Generate training data
X_train, y_train = generate_training_data(
    bamfile, reference_file, "20", train_regions, window_size=221
)

X_val, y_val = generate_training_data(bamfile, reference_file, "20", val_regions, window_size=221)

print(f"\nTraining data shape: {X_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Validation data shape: {X_val.shape}")
print(f"Validation labels shape: {y_val.shape}")
print(f"\nVariant rate in training data: {y_train.mean():.4f}")

## 5. Build CNN Variant Caller

Create a convolutional neural network architecture for variant calling.

In [None]:
def build_variant_caller_cnn(input_shape=(221, 100, 7)):
    """
    Build CNN variant caller model.

    Architecture inspired by DeepVariant (Poplin et al. 2018).
    """
    model = keras.Sequential(
        [
            # Input layer
            keras.layers.Input(shape=input_shape),
            # Conv block 1
            keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same"),
            keras.layers.BatchNormalization(),
            keras.layers.MaxPooling2D((2, 2)),
            # Conv block 2
            keras.layers.Conv2D(64, (3, 3), activation="relu", padding="same"),
            keras.layers.BatchNormalization(),
            keras.layers.MaxPooling2D((2, 2)),
            # Conv block 3
            keras.layers.Conv2D(128, (3, 3), activation="relu", padding="same"),
            keras.layers.BatchNormalization(),
            keras.layers.MaxPooling2D((2, 2)),
            # Global pooling
            keras.layers.GlobalAveragePooling2D(),
            # Dense layers
            keras.layers.Dense(256, activation="relu"),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(128, activation="relu"),
            keras.layers.Dropout(0.3),
            # Output layer (per-position variant probability)
            keras.layers.Dense(221, activation="sigmoid"),  # 221 positions
        ]
    )

    return model


# Build model
model = build_variant_caller_cnn()

# Compile
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="binary_crossentropy",
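    # Explicit names keep the history keys ("precision", "recall") stable if this cell is re-run
    metrics=[
        "accuracy",
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
    ],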
)

print("Model architecture:")
model.summary()
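
The spatial size of the feature maps can be checked by hand. The convolutions use `padding="same"` and therefore keep height and width unchanged, while each `MaxPooling2D((2, 2))` layer (default `"valid"` padding) halves both and rounds down. A small sketch of that arithmetic, where `pooled_shape` is a hypothetical helper:

```python
def pooled_shape(height, width, n_blocks=3, pool=2):
    """Spatial size after n_blocks rounds of pool x pool max pooling with valid padding."""
    for _ in range(n_blocks):
        height, width = height // pool, width // pool
    return height, width


print(pooled_shape(221, 100))
```

For the default input this gives 27 x 12 feature maps with 128 channels, which `GlobalAveragePooling2D` then reduces to a single 128-value vector before the dense layers.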

## 6. Train the Model

Train the CNN variant caller (60-75 minutes on GPU, 4-6 hours on CPU).

In [None]:
# Setup callbacks
callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3),
    keras.callbacks.ModelCheckpoint(
        "variant_caller_best.h5", monitor="val_loss", save_best_only=True
    ),
]

# Train
print("Training CNN variant caller...")
print("This will take 60-75 minutes on GPU (4-6 hours on CPU)")
print("\n" + "=" * 60)

history = model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    epochs=50,
    batch_size=32,
    callbacks=callbacks,
    verbose=1,
)

print("\n" + "=" * 60)
print("✓ Training complete!")

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss
axes[0, 0].plot(history.history["loss"], label="Train")
axes[0, 0].plot(history.history["val_loss"], label="Validation")
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Training Loss")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy
axes[0, 1].plot(history.history["accuracy"], label="Train")
axes[0, 1].plot(history.history["val_accuracy"], label="Validation")
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Accuracy")
axes[0, 1].set_title("Accuracy")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Precision
axes[1, 0].plot(history.history["precision"], label="Train")
axes[1, 0].plot(history.history["val_precision"], label="Validation")
axes[1, 0].set_xlabel("Epoch")
axes[1, 0].set_ylabel("Precision")
axes[1, 0].set_title("Precision")
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Recall
axes[1, 1].plot(history.history["recall"], label="Train")
axes[1, 1].plot(history.history["val_recall"], label="Validation")
axes[1, 1].set_xlabel("Epoch")
axes[1, 1].set_ylabel("Recall")
axes[1, 1].set_title("Recall")
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Evaluate Performance

Assess variant calling accuracy on the held-out validation set.

In [None]:
# Make predictions on validation set
y_pred_prob = model.predict(X_val)
y_pred = (y_pred_prob > 0.5).astype(int)

# Flatten for metrics calculation
y_val_flat = y_val.flatten()
y_pred_flat = y_pred.flatten()
y_pred_prob_flat = y_pred_prob.flatten()

# Calculate metrics
# zero_division=0 avoids a warning when the model predicts no variants at all
precision, recall, f1, _ = precision_recall_fscore_support(
    y_val_flat, y_pred_flat, average="binary", zero_division=0
)

print("Validation Set Performance:")
print("=" * 40)
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")
print("\nNote: These metrics are based on simulated labels.")
print("      Real evaluation would use GIAB truth set.")
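
To make the reported numbers concrete, the sketch below recomputes precision, recall, and F1 from raw counts on a tiny made-up example. `binary_prf` is a hypothetical helper that returns 0.0 whenever a denominator is zero, which is the situation that arises when a model predicts no variants at all.

```python
def binary_prf(y_true, y_pred):
    """Precision, recall, and F1 for 0/1 labels, returning 0.0 for undefined ratios."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Made-up labels and predictions for illustration only
toy_true = [0, 0, 1, 0, 1, 0, 0, 1]
toy_pred = [0, 0, 1, 0, 0, 0, 1, 1]
print(binary_prf(toy_true, toy_pred))

# A model that never calls a variant scores zero on all three metrics
print(binary_prf(toy_true, [0] * len(toy_true)))
```

With a variant rate near 0.1%, accuracy stays close to 1.0 even for the all-zero predictor, which is why precision and recall are the more informative metrics here.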

In [None]:
# ROC Curve
fpr, tpr, thresholds = roc_curve(y_val_flat, y_pred_prob_flat)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.3f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve - Variant Calling Performance")
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 8. Call Variants on Test Region

Apply the trained model to call variants in a held-out genomic region.

In [None]:
# Define test region
test_region = (15000000, 15010000)  # 10kb test region

# Generate test data
X_test, _ = generate_training_data(bamfile, reference_file, "20", [test_region], window_size=221)

# Call variants
print(f"Calling variants in region chr20:{test_region[0]:,}-{test_region[1]:,}")
variant_probs = model.predict(X_test)

# Identify high-confidence variants (prob > 0.9)
high_conf_threshold = 0.9
variants = []

for window_idx, window_probs in enumerate(variant_probs):
    window_start = test_region[0] + window_idx * (221 // 2)  # 50% overlap

    for pos_offset, prob in enumerate(window_probs):
        if prob > high_conf_threshold:
            pos = window_start + pos_offset
            variants.append({"chrom": "20", "pos": pos, "prob": prob})

# Convert to DataFrame
variants_df = pd.DataFrame(variants)

if len(variants_df) > 0:
    # Remove duplicates from overlapping windows, keeping the highest probability per position
    variants_df = (
        variants_df.sort_values("prob", ascending=False)
        .drop_duplicates(subset=["pos"])
        .sort_values("pos")
    )

    print(f"\n✓ Called {len(variants_df)} high-confidence variants")
    print("\nTop 10 variants:")
    print(variants_df.head(10))
else:
    print("\nNo high-confidence variants found in test region.")
    print("(This is expected with simulated labels)")
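
Because windows overlap by 50%, most positions receive two predictions. One alternative to dropping duplicates after thresholding is to combine the overlapping predictions first and threshold afterwards. The sketch below averages them; `merge_window_probs` is a hypothetical helper, and averaging is only one possible choice (taking the maximum would favour recall).

```python
from collections import defaultdict


def merge_window_probs(window_probs, region_start, window_size=221):
    """Average per-position probabilities across windows that overlap by 50%."""
    step = window_size // 2
    per_pos = defaultdict(list)
    for window_idx, probs in enumerate(window_probs):
        window_start = region_start + window_idx * step
        for offset, prob in enumerate(probs):
            per_pos[window_start + offset].append(float(prob))
    # Return positions in genomic order with their mean probability
    return {pos: sum(vals) / len(vals) for pos, vals in sorted(per_pos.items())}


# Toy example: two windows of size 4 (step 2) starting at position 100
toy_windows = [[0.1, 0.2, 0.9, 0.4], [0.7, 0.6, 0.0, 0.1]]
merged = merge_window_probs(toy_windows, region_start=100, window_size=4)
print(merged)
```

Positions 102 and 103 are covered by both toy windows, so their merged values are the means of the two predictions.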

## 9. Summary and Next Steps

### What We Built

- Downloaded 1.5GB of 1000 Genomes data
- Generated pileup tensors from BAM files
- Trained a CNN variant caller (60-75 min)
- Evaluated performance on validation set
- Called variants on held-out region

### Limitations in Colab

You may have noticed:
- **20-minute download** at every session start
- **Training time close to timeout** (75 min)
- **Can't save model** between sessions
- **Limited to small regions** (~11GB RAM limit)

### Tier 1: Studio Lab (Multi-Cohort Analysis)

Upgrade to [Tier 1](../tier-1/) for:
- **8-12GB cached data** (download once, persist forever)
- **Multi-sample ensemble callers** (5-6 hour training)
- **Model checkpointing** (resume training)
- **Population analysis** (multiple individuals)
- **No session timeouts**

### Real Research Applications

This approach (deep learning for variant calling) is used in:
- **Clinical diagnostics:** Rare disease variant identification
- **Cancer genomics:** Somatic mutation detection
- **Population genetics:** Large-scale GWAS
- **Agricultural genomics:** Crop improvement

### References

- Poplin et al. (2018) "A universal SNP and small-indel variant caller using deep neural networks" *Nature Biotechnology* 36:983-987
- 1000 Genomes Project Consortium (2015) "A global reference for human genetic variation" *Nature* 526:68-74
- Genome in a Bottle Consortium: https://www.nist.gov/programs-projects/genome-bottle

---

## Clean Up

Close file handles to free resources.

In [None]:
bamfile.close()
reference_file.close()
print("✓ Resources cleaned up")