# CNN channel-ablation study (batch 1)

In [1]:
from pathlib import Path
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
from torch.cuda.amp import GradScaler, autocast
import torch.nn as nn
import torch.optim as optim
import copy

%load_ext autoreload
%autoreload 2

In [2]:
from src.data.dataset import CnvDataset

Project paths (resolved relative to the current working directory):

In [3]:
# Paths are relative to the working directory, which is presumably the
# repository root (hence the name `git_root`).
git_root = Path('.')
data_root = Path(git_root, 'data')
assert data_root.exists()

Embedding directories for each batch-1 split (train / val / test).

In [4]:
# One embedding directory per split of batch 1.
dataset_root_train, dataset_root_val, dataset_root_test = (
    data_root / 'embeddings' / 'batch_1' / split for split in ('train', 'val', 'test')
)

In [5]:
b1_val_path = data_root / 'splits' / 'batch1_val_filtered.tsv'
b1_train_path = data_root / 'splits' / 'batch1_training_filtered.tsv'
b1_test_path = data_root / 'splits' / 'batch1_test_filtered.tsv'

# Tab-separated split tables, read in the order val, train, test.
# NOTE(review): the rendered output has an "Unnamed: 0" column, which suggests
# the TSVs were written together with their index; index_col=0 would absorb it — confirm.
b1_val_df = pd.read_csv(b1_val_path, sep='\t')
b1_train_df = pd.read_csv(b1_train_path, sep='\t')
b1_test_df = pd.read_csv(b1_test_path, sep='\t')

# Only a cell's last expression is rendered, so only the test split is shown.
b1_test_df

Unnamed: 0,barcode,gene_id,expression_count,classification
0,AAACCGAAGGCGCATC-1,ENSG00000269113,0.825470,low
1,AAACCGAAGGCGCATC-1,ENSG00000231252,0.495597,low
2,AAACCGAAGGCGCATC-1,ENSG00000188641,0.495597,low
3,AAACCGAAGGCGCATC-1,ENSG00000265972,1.271419,high
4,AAACCGAAGGCGCATC-1,ENSG00000197956,0.495597,low
...,...,...,...,...
18630,TTTAGGATCGTTATCT-1,ENSG00000198938,1.444938,high
18631,TTTAGGATCGTTATCT-1,ENSG00000198840,1.768557,high
18632,TTTAGGATCGTTATCT-1,ENSG00000198886,0.963478,high
18633,TTTAGGATCGTTATCT-1,ENSG00000198786,0.963478,high


In [6]:
# One CnvDataset per split; built in the order val, train, test
# (the log lines below appear in that order).
b1_val_dataset = CnvDataset(root=dataset_root_val, data_df=b1_val_df)
b1_train_dataset = CnvDataset(root=dataset_root_train, data_df=b1_train_df)
b1_test_dataset = CnvDataset(root=dataset_root_test, data_df=b1_test_df)

Using 51 barcodes
Using 1093 genes
No embedding files for 988 data points in data/embeddings/batch_1/val/single_gene_barcode!
Using 356 barcodes
Using 1595 genes
No embedding files for 4335 data points in data/embeddings/batch_1/train/single_gene_barcode!
Using 102 barcodes
Using 1235 genes
No embedding files for 2149 data points in data/embeddings/batch_1/test/single_gene_barcode!


In [7]:
BATCH_SIZE = 32
train_loader = DataLoader(b1_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation splits keep a fixed order: losses are averaged per batch, so
# shuffling can change which samples land in the smaller final batch (and
# therefore the reported loss) every time the loader is iterated — e.g. between
# the baseline run and each ablation run.
val_loader = DataLoader(b1_val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(b1_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
import torch.nn.functional as F

class ChromosomeCNN(nn.Module):
    """1-D CNN over stacked per-position channels with a linear head.

    Both Conv1d layers use kernel_size=5, padding=2 and stride 1, so they keep
    the sequence length unchanged and the flattened feature size is
    ``64 * seq_len``. ``fc1`` is therefore built eagerly in ``__init__`` when
    ``seq_len`` is known. This registers its parameters before an optimizer is
    created from ``model.parameters()`` and lets ``.to(device)`` move it with
    the rest of the model. Inputs whose length differs from ``seq_len`` then
    fail at ``fc1`` with a shape error instead of silently building a new layer.

    Args:
        input_dim: number of input channels (dim 1 of the input tensor).
        seq_len: input sequence length (dim 2 of the input). If ``None``,
            ``fc1`` is created lazily on the first forward pass, on the input's
            device; build any optimizer only after such a dry-run forward.
        output_dim: size of the output layer (callers here use 1 logit).
    """

    def __init__(self, input_dim, seq_len, output_dim):
        super().__init__()

        self.input_dim = input_dim
        self.seq_len = seq_len

        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, padding=2)

        # Built up-front when possible: a layer first assigned inside forward()
        # is missing from an optimizer created earlier and is not moved by .to().
        self.fc1 = nn.Linear(64 * seq_len, 128) if seq_len is not None else None
        # NOTE(review): there is no nonlinearity between fc1 and fc2, so the two
        # layers compose to a single affine map — confirm this is intended.
        self.fc2 = nn.Linear(128, output_dim)

    def initialize_fc1(self, x):
        """Create ``fc1`` from a 3-D feature map ``x`` if it does not exist yet.

        The input size is ``x.shape[1] * x.shape[2]``; the layer is placed on
        ``x.device``. Does nothing when ``fc1`` is already set.
        """
        if self.fc1 is None:
            flattened_size = x.shape[1] * x.shape[2]
            self.fc1 = nn.Linear(flattened_size, 128).to(x.device)

    def forward(self, inputs_seq):
        """Map a batch shaped ``(batch, input_dim, length)`` to ``(batch, output_dim)`` logits."""
        x = F.relu(self.conv1(inputs_seq))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)

        if self.fc1 is None:
            # Lazy fallback (seq_len unknown): create on the activations' device.
            self.fc1 = nn.Linear(x.shape[1], 128).to(x.device)

        x = self.fc1(x)
        return self.fc2(x)

In [9]:
def create_ablated_dataloader(loader, channel_to_remove, channel_variable_counts):
    """Collect all batches of ``loader`` into a list, with one channel group removed.

    Each ``(inputs, targets)`` pair yielded by ``loader`` has its inputs passed
    through ``ablation_study``; targets are kept unchanged. All batches are held
    in memory, and because they are drawn in a single pass, their order (and,
    for a shuffling loader, their composition) is fixed for every later
    iteration over the returned list.

    Args:
        loader: iterable of ``(inputs, targets)`` pairs.
        channel_to_remove: index of the channel group to drop.
        channel_variable_counts: number of channels in each group, in order.

    Returns:
        list of ``(ablated_inputs, targets)`` tuples.
    """
    return [
        (ablation_study(inputs, channel_to_remove, channel_variable_counts), targets)
        for inputs, targets in loader
    ]

In [10]:
def full_study(loader):
    """Collect every ``(inputs, targets)`` batch of ``loader`` into a list, unchanged.

    This is the no-ablation counterpart of ``create_ablated_dataloader``. The
    batches are drawn in a single pass, so their order (and, for a shuffling
    loader, their composition) is fixed for every later iteration over the
    returned list.

    Returns:
        list of ``(inputs, targets)`` tuples.
    """
    return [(inputs, targets) for inputs, targets in loader]

In [11]:
def ablation_study(inputs, channel_to_remove, channel_variable_counts):
    """Drop one group of consecutive channels from a batch of inputs.

    The channel axis (dim 1) of ``inputs`` is partitioned, in order, into
    groups whose sizes are ``channel_variable_counts``. The group at index
    ``channel_to_remove`` is removed and the remaining groups are concatenated
    back together along dim 1.

    Args:
        inputs: tensor whose dim 1 holds the stacked channel groups,
            e.g. ``(batch, channels, seq_len)``.
        channel_to_remove: index into ``channel_variable_counts`` of the group to drop.
        channel_variable_counts: channels per group; must sum to ``inputs.shape[1]``.

    Returns:
        Tensor with ``inputs.shape[1] - channel_variable_counts[channel_to_remove]``
        channels on dim 1; all other dims unchanged.

    Raises:
        ValueError: if the group sizes do not cover dim 1 exactly, if
            ``channel_to_remove`` is not a valid group index, or if dropping the
            group would leave no channels.
    """
    counts = list(channel_variable_counts)
    num_channels = inputs.shape[1]

    # A mismatch would otherwise silently drop trailing channels.
    if sum(counts) != num_channels:
        raise ValueError(
            f"channel_variable_counts sum to {sum(counts)}, but inputs have {num_channels} channels on dim 1"
        )
    # An out-of-range (or negative) index would otherwise silently ablate nothing.
    if not 0 <= channel_to_remove < len(counts):
        raise ValueError(
            f"channel_to_remove={channel_to_remove} is not a valid index for {len(counts)} channel groups"
        )
    if num_channels - counts[channel_to_remove] == 0:
        raise ValueError("removing this channel group would leave no input channels")

    groups = torch.split(inputs, counts, dim=1)
    kept = [group for idx, group in enumerate(groups) if idx != channel_to_remove]
    return torch.cat(kept, dim=1)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: set the real input length. NOTE(review): `sequ_len` is not referenced in
# the cells shown; the driver cell defines and passes its own `seq_len`.
sequ_len = 10000

epochs = 20

def train_(model, train_loader, val_loader, epochs, plot=True):
    """Train ``model`` with BCE-with-logits and keep the best validation checkpoint.

    Uses Adam (lr=1e-4) with a StepLR schedule that multiplies the learning
    rate by 0.1 every 10 *epochs*, and CUDA autocast with gradient scaling when
    running on a GPU. Batches are moved to the module-level ``device``.

    Args:
        model: module already placed on ``device``; its outputs must have the
            same shape as the targets (a BCEWithLogitsLoss requirement).
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches (non-empty).
        epochs: number of passes over ``train_loader``.
        plot: if True, show train/val loss curves (from the second epoch on).

    Returns:
        float: the lowest mean validation loss seen. On return, ``model`` holds
        the weights from that epoch (``inf`` and unchanged weights if no epoch ran
        or every validation loss was NaN).
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    # StepLR counts scheduler.step() calls; it is stepped once per epoch below.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    # Loss scaling keeps fp16 gradients from underflowing under autocast; when
    # disabled (CPU) scale/step/update fall through to a plain optimizer step.
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    best_val_loss = float('inf')
    best_model = None
    train_losses_avg = []
    val_losses_avg = []

    for epoch in range(epochs):
        model.train()
        train_losses = []

        for stacked_inputs_batch, y_batch in train_loader:
            stacked_inputs_batch = stacked_inputs_batch.to(device)
            y_batch = y_batch.to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast():
                outputs = model(stacked_inputs_batch)
                loss = criterion(outputs, y_batch)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_losses.append(loss.item())

        avg_train_loss = sum(train_losses) / len(train_losses)
        train_losses_avg.append(avg_train_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_train_loss:.4f}")

        model.eval()
        val_losses = []
        with torch.no_grad(), autocast():
            for stacked_inputs_batch, y_batch in val_loader:
                stacked_inputs_batch = stacked_inputs_batch.to(device)
                y_batch = y_batch.to(device, non_blocking=True)
                y_pred = model(stacked_inputs_batch)
                val_losses.append(criterion(y_pred, y_batch).item())

        avg_val_loss = sum(val_losses) / len(val_losses)
        val_losses_avg.append(avg_val_loss)
        print(f'Epoch {epoch+1}, Val loss: {avg_val_loss}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = copy.deepcopy(model.state_dict())

        # Once per epoch, not per batch: per-batch stepping would divide the
        # learning rate by 10 every 10 batches.
        scheduler.step()

    # Hand back the best-validation weights so callers save/evaluate those.
    if best_model is not None:
        model.load_state_dict(best_model)

    if plot:
        plt.plot(train_losses_avg[1:], label='Train Loss')
        plt.plot(val_losses_avg[1:], label='Val Loss')
        plt.legend()
        plt.show()

    return best_val_loss

In [13]:
def ablation_study_evaluation(train_loader, val_loader, test_loader, channel_variable_counts, seq_len, num_epochs):
    """Train a baseline CNN on all channel groups, then one CNN per dropped group.

    Every run is trained with ``train_`` for ``num_epochs`` epochs, saved to
    disk, reloaded by ``test_model`` and evaluated on the (equally ablated)
    test batches.

    Args:
        train_loader, val_loader, test_loader: iterables of ``(inputs, targets)`` batches.
        channel_variable_counts: channels per group along dim 1 of the inputs, in order.
        seq_len: input sequence length, forwarded to ``ChromosomeCNN``.
        num_epochs: training epochs for every run (baseline and ablations).

    Returns:
        dict mapping ``"Baseline"`` and ``"Ablated Channel {i}"`` to the test
        loss returned by ``test_model``.

    Side effects:
        Writes ``baseline_model.pth`` and ``ablated_model_channel_{i}.pth`` to
        the working directory.
    """
    print("Training with all channels intact...")
    num_channels = sum(channel_variable_counts)

    full_train_loader = full_study(train_loader)
    full_val_loader = full_study(val_loader)
    full_test_loader = full_study(test_loader)

    model = ChromosomeCNN(input_dim=num_channels, seq_len=seq_len, output_dim=1).to(device)
    baseline_loss = train_(model, full_train_loader, full_val_loader, num_epochs)

    torch.save({
        'model_state_dict': model.state_dict(),
    }, 'baseline_model.pth')

    results = {"Baseline": test_model("baseline_model.pth", full_test_loader, num_channels, seq_len)}

    for channel_idx in range(len(channel_variable_counts)):
        print(f"\nAblating channel {channel_idx}...")

        remaining_variables = num_channels - channel_variable_counts[channel_idx]
        print("remaing variables", remaining_variables)

        ablated_train_loader = create_ablated_dataloader(train_loader, channel_idx, channel_variable_counts)
        ablated_val_loader = create_ablated_dataloader(val_loader, channel_idx, channel_variable_counts)
        ablated_test_loader = create_ablated_dataloader(test_loader, channel_idx, channel_variable_counts)

        model_ablated = ChromosomeCNN(input_dim=remaining_variables, seq_len=seq_len, output_dim=1).to(device)
        # Same epoch budget as the baseline so the comparison is fair.
        ablated_loss = train_(model_ablated, ablated_train_loader, ablated_val_loader, num_epochs)

        ablated_model_filename = f'ablated_model_channel_{channel_idx}.pth'
        torch.save({
            'model_state_dict': model_ablated.state_dict(),
        }, ablated_model_filename)

        results[f"Ablated Channel {channel_idx}"] = test_model(
            ablated_model_filename, ablated_test_loader, remaining_variables, seq_len
        )

        print(f"Loss after ablating channel {channel_idx}: {ablated_loss:.4f}")
        # Positive means the validation loss got worse without this channel group.
        print(f"Performance drop: {ablated_loss - baseline_loss:.4f}")

    return results

In [14]:
def test_model(model_path, test_loader, total_variables, seq_len):
    """Load a saved ``ChromosomeCNN`` checkpoint and evaluate it on ``test_loader``.

    Prints the mean per-batch BCE-with-logits loss. It also prints accuracy,
    precision, recall and F1 at a 0.5 probability threshold, and ROC AUC.

    Args:
        model_path: path to a file saved as ``{'model_state_dict': ...}``.
        test_loader: iterable of ``(inputs, targets)`` batches; targets are
            presumably 0/1 with the same shape as the model output — confirm.
        total_variables: number of input channels the checkpoint was trained with.
        seq_len: input sequence length used to build the model.

    Returns:
        float: mean of the per-batch test losses.
    """
    # Build and dry-run the model on the CPU first, so a lazily created fc1
    # exists (on the same device as the rest of the model) before
    # load_state_dict; the whole model is moved to `device` afterwards.
    model = ChromosomeCNN(input_dim=total_variables, seq_len=seq_len, output_dim=1)
    with torch.no_grad():
        model(torch.zeros(1, model.input_dim, model.seq_len))

    # map_location='cpu' lets checkpoints written from a GPU run load on any machine.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    criterion = nn.BCEWithLogitsLoss()
    test_losses = []
    all_probabilities = []
    all_labels = []

    with torch.no_grad():
        for stacked_inputs_batch, y_batch in test_loader:
            stacked_inputs_batch = stacked_inputs_batch.to(device)
            y_batch = y_batch.to(device, non_blocking=True)

            with autocast():
                outputs = model(stacked_inputs_batch)
                loss = criterion(outputs, y_batch)
            test_losses.append(loss.item())

            # Sigmoid in float32: autocast may yield fp16 logits, and a NumPy
            # 1 / (1 + exp(-x)) overflows (with a warning) for large negative x.
            all_probabilities.append(torch.sigmoid(outputs.float()).cpu().numpy())
            all_labels.append(y_batch.cpu().numpy())

    avg_test_loss = sum(test_losses) / len(test_losses)
    # The criterion is BCE-with-logits, not MSE.
    print(f"Test BCE loss: {avg_test_loss:.4f}")

    # Flatten (N, 1) outputs/labels into the 1-D arrays sklearn metrics expect.
    probabilities = np.concatenate(all_probabilities, axis=0).ravel()
    labels = np.concatenate(all_labels, axis=0).ravel()
    predicted_classes = (probabilities >= 0.5).astype(int)

    accuracy = accuracy_score(labels, predicted_classes)
    precision = precision_score(labels, predicted_classes)
    recall = recall_score(labels, predicted_classes)
    f1 = f1_score(labels, predicted_classes)
    # roc_auc_score raises ValueError when only one class is present.
    auc = roc_auc_score(labels, probabilities) if np.unique(labels).size > 1 else float('nan')

    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(f'AUC: {auc:.4f}')

    return avg_test_loss

In [None]:
# Channel-group sizes along dim 1 of the input batches, in stacking order.
embedding_dim, cnv_dim, chromatin_dim, expression_dim = 4, 2, 1, 1
# NOTE(review): expression_dim is defined but not passed in
# channel_variable_counts below, so it is never ablated — confirm this is intended.
seq_len = 10_000  # presumably must equal the real input length — confirm

ablation_study_evaluation(
    train_loader,
    val_loader,
    test_loader,
    channel_variable_counts=[embedding_dim, cnv_dim, chromatin_dim],
    seq_len=seq_len,
    num_epochs=epochs,
)

Training with all channels intact...


In [None]:
#model.eval()
#correct = 0
#total = 0
#with torch.no_grad():
#    for X_batch, y_batch in test_loader:
#        X_batch, y_batch = X_batch.to(device).unsqueeze(1), y_batch.to(device).unsqueeze(1)
#        outputs = model(X_batch)
#        predictions = (outputs > 0.5).float()
#        correct += (predictions == y_batch).sum().item()
#        total += y_batch.size(0)

#accuracy = correct / total
#print(f"Test Accuracy: {accuracy * 100:.2f}%")

In [None]:
#model.load_state_dict(best_model)
#model.eval()
#test_losses = []
#y_preds = []
#y_actuals = []

#scaler = GradScaler()

#for X_batch, cnv_batch, y_batch in test_loader:

#    X_batch = X_batch.unsqueeze(1).to(device, non_blocking=True)
#    cnv_batch = cnv_batch.to(device)
#    y_batch = y_batch.to(device, non_blocking=True)
    
#    with torch.no_grad(), autocast():
#        y_pred = model(X_batch, cnv_batch)
#        lossV = criterion(y_pred, y_batch)
        
#        y_preds.extend(y_pred.cpu().numpy())
#        y_actuals.extend(y_batch.cpu().numpy())
#        test_losses.append(lossV.item())

#avg_test_loss = sum(test_losses) / len(test_losses)
#print(f'Test MSE: {avg_test_loss}')

In [None]:
def model_summary(model):
    """Print a per-parameter table and totals for ``model``.

    Each row shows the parameter name, its shape, its element count, and the
    element count again if it is trainable (``requires_grad``) or 0 if frozen.

    Args:
        model: a ``torch.nn.Module``; only ``model.named_parameters()`` is used.
    """
    row_format = "{:<50} {:<30} {:<15} {:<15}"
    print("Model Summary:")
    print(row_format.format("Layer Name", "Shape", "Parameters", "Trainable"))
    print("-" * 110)
    total_params = 0
    total_trainable_params = 0
    for name, parameter in model.named_parameters():
        param_count = parameter.numel()
        # Frozen parameters count towards the total but not the trainable total.
        trainable_count = param_count if parameter.requires_grad else 0
        total_params += param_count
        total_trainable_params += trainable_count
        print(row_format.format(name, str(parameter.size()), param_count, trainable_count))
    print("-" * 110)
    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {total_trainable_params}")

#model_summary(model)