In [1]:
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import shap
import matplotlib.pyplot as plt
import seaborn as sns


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
adata = sc.read_h5ad("data/cell_atlas_of_the_human_lung_in_health_and_disease_full.h5ad")

In [4]:
subset_mask = (
    (adata.obs['ann_level_1'] == 'Epithelial') &
    (adata.obs['ann_level_2'].isin(['Airway epithelium', 'Alveolar epithelium', 'Submucosal Gland'])) &
    (adata.obs['ann_level_3'].isin(['Basal', 'Secretory', 'Submucosal Secretory'])) &
    (adata.obs['ann_level_4'].isin(['Basal resting', 'Club', 'Deuterosomal', 'Goblet', 
                                    'Hillock-like', 'SMG duct', 'SMG mucous', 
                                    'SMG serous', 'Suprabasal', 'Transitional Club-AT2'])) &
    (adata.obs['ann_level_5'].isin(['Club (non-nasal)', 'Goblet (bronchial)', 
                                    'Goblet (nasal)', 'Goblet (subsegmental)', 
                                    'SMG serous (bronchial)', 'SMG serous (nasal)', 
                                    'pre-TB secretory']))
)

adata = adata[subset_mask].copy()
print(adata)

AnnData object with n_obs × n_vars = 106931 × 56239
    obs: 'suspension_type', 'donor_id', 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', "3'_or_5'", 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ann_coarse_for_GWAS_and_modeling', 'ann_finest_level', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'cause_of_death', 'core_or_extension', 'dataset', 'fresh_or_frozen', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'reannotation_type', 'sample', 'scanvi_label', 'sequencing_platform', 'smoking_status', 'study', 'subject_type', 'tissue_coarse_unharmonized', '

In [5]:
desired_diseases = ["normal", "chronic obstructive pulmonary disease", "chronic rhinitis", "pulmonary fibrosis"]
adata = adata[adata.obs['disease'].isin(desired_diseases)].copy()
print(adata)


AnnData object with n_obs × n_vars = 98731 × 56239
    obs: 'suspension_type', 'donor_id', 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', "3'_or_5'", 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ann_coarse_for_GWAS_and_modeling', 'ann_finest_level', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'cause_of_death', 'core_or_extension', 'dataset', 'fresh_or_frozen', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'reannotation_type', 'sample', 'scanvi_label', 'sequencing_platform', 'smoking_status', 'study', 'subject_type', 'tissue_coarse_unharmonized', 't

In [6]:
sc.pp.filter_cells(adata, min_counts=500)
sc.pp.filter_genes(adata, min_cells=10)


In [7]:
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)


In [8]:
sc.pp.highly_variable_genes(adata, n_top_genes=4000, flavor='seurat_v3')
adata = adata[:, adata.var['highly_variable']].copy()




In [9]:
le = LabelEncoder()
adata.obs['disease_encoded'] = le.fit_transform(adata.obs['disease'])
disease_classes = le.classes_  # This will store the class order
print(disease_classes)


['chronic obstructive pulmonary disease' 'chronic rhinitis' 'normal'
 'pulmonary fibrosis']


In [10]:
# Simple split: 80% train, 10% val, 10% test
all_indices = np.arange(adata.n_obs)
np.random.shuffle(all_indices)

train_size = int(0.8 * len(all_indices))
val_size = int(0.1 * len(all_indices))
test_size = len(all_indices) - train_size - val_size

train_idx = all_indices[:train_size]
val_idx = all_indices[train_size:train_size+val_size]
test_idx = all_indices[train_size+val_size:]

X = adata.X

# Convert to np.array if not already
X = X.toarray() if hasattr(X, 'toarray') else X

y = adata.obs['disease_encoded'].values


In [11]:
class MyDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_dataset = MyDataset(X[train_idx], y[train_idx])
val_dataset = MyDataset(X[val_idx], y[val_idx])
test_dataset = MyDataset(X[test_idx], y[test_idx])

batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [12]:
input_dim = X.shape[1]
num_classes = len(disease_classes)

class SimpleNN(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, num_classes=4):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

model = SimpleNN(input_dim=input_dim, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [13]:
epochs = 10
for epoch in range(epochs):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x, batch_y
        
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * batch_x.size(0)
        _, pred = outputs.max(1)
        correct += pred.eq(batch_y).sum().item()
        total += batch_y.size(0)
    
    train_loss /= total
    train_acc = correct / total

    # Validate
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x, batch_y = batch_x, batch_y
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            val_loss += loss.item() * batch_x.size(0)
            _, pred = outputs.max(1)
            val_correct += pred.eq(batch_y).sum().item()
            val_total += batch_y.size(0)
    val_loss /= val_total
    val_acc = val_correct / val_total
    
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/10, Train Loss: 0.0546, Train Acc: 0.9805, Val Loss: 0.0305, Val Acc: 0.9876
Epoch 2/10, Train Loss: 0.0290, Train Acc: 0.9888, Val Loss: 0.0333, Val Acc: 0.9873
Epoch 3/10, Train Loss: 0.0206, Train Acc: 0.9921, Val Loss: 0.0343, Val Acc: 0.9888
Epoch 4/10, Train Loss: 0.0165, Train Acc: 0.9939, Val Loss: 0.0295, Val Acc: 0.9896
Epoch 5/10, Train Loss: 0.0130, Train Acc: 0.9953, Val Loss: 0.0371, Val Acc: 0.9860
Epoch 6/10, Train Loss: 0.0100, Train Acc: 0.9964, Val Loss: 0.0357, Val Acc: 0.9909
Epoch 7/10, Train Loss: 0.0093, Train Acc: 0.9968, Val Loss: 0.0345, Val Acc: 0.9901
Epoch 8/10, Train Loss: 0.0072, Train Acc: 0.9974, Val Loss: 0.0372, Val Acc: 0.9892
Epoch 9/10, Train Loss: 0.0065, Train Acc: 0.9975, Val Loss: 0.0412, Val Acc: 0.9895
Epoch 10/10, Train Loss: 0.0054, Train Acc: 0.9982, Val Loss: 0.0407, Val Acc: 0.9897


In [31]:
model.eval()
test_loss = 0
test_correct = 0
test_total = 0
all_preds = []
all_targets = []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x, batch_y
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        test_loss += loss.item() * batch_x.size(0)
        _, pred = outputs.max(1)
        test_correct += pred.eq(batch_y).sum().item()
        test_total += batch_y.size(0)
        all_preds.append(pred.cpu().numpy())
        all_targets.append(batch_y.cpu().numpy())

test_loss /= test_total
test_acc = test_correct / test_total
train_loss /= total
train_acc = correct / total
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

Test Loss: 0.0392, Test Acc: 0.9918
Train Loss: 0.0000, Train Acc: 0.9982


In [26]:
# Take a small sample of the test set
sample_size = 2000
sample_idx = np.random.choice(test_idx, sample_size, replace=False)
X_sample = X[sample_idx]
X_sample = torch.tensor(X_sample, dtype=torch.float32)

model.eval()
background = X[train_idx][:1000]  # background for SHAP
background = torch.tensor(background, dtype=torch.float32)
background = background.detach()
explainer = shap.DeepExplainer(model, background)
# explainer.explainer.multi_input, explainer.explainer.multi_output = True, True
shap_values = explainer.shap_values(X_sample, check_additivity=False)


In [29]:
# Assuming shap_values is a list of arrays with shape [num_classes, sample_size, num_genes]
# Convert to a numpy array if needed
shap_array = np.array(shap_values)  # shape: (num_classes, sample_size, num_genes)

# Compute the average absolute SHAP value across classes and samples
shap_avg = np.mean(np.abs(shap_array), axis=(0, 1))  # shape: (num_genes,)

# Get gene names from adata.var_names
gene_names = adata.var_names

# Sort genes by importance (descending order)
important_genes_idx = np.argsort(shap_avg)[::-1]

# Extract top 20 genes
top_20_genes = gene_names[important_genes_idx[:20]]

print("Top 20 important genes identified by SHAP:")
for gene in top_20_genes:
    print(gene)

# If you want a SHAP summary plot for a single class:
# shap.summary_plot(shap_values[0], features=X_sample, feature_names=gene_names)


Top 20 important genes identified by SHAP:
ENSG00000175899
ENSG00000197953
ENSG00000184389
ENSG00000114771
