In [1]:
import torch
import pandas as pd
import numpy as np
import torch

In [2]:
#with open('data/fragments2.tsv', 'r') as file:
#    for i in range(1000):
#        print(file.readline())

In [3]:
#df = pd.read_csv('data/fragments1.tsv', sep='\t', skiprows=51)
#df = pd.read_csv('data/fragments2.tsv', sep='\t', skiprows=51, header=None)

#df.columns = ['Chromosome', 'Start', 'End', 'Barcode', 'Count']

In [4]:
#df
#df_subset = df.head(100000)

In [5]:
# ADD EMBEDDINGS CREATION        

In [6]:
#embeddingsTot = embed()  ## dont run, first add embeddings creation

#embeddings = embeddingsTot[0]
#cnv = embeddingsTot[1]
#open_cromatin = embeddingsTot[2]

In [7]:
seq_len = 6000
num_samples = 100

embeddings = np.random.rand(num_samples, seq_len, 4)  # 100 samples, each of length 6000 and 4 features
cnv = np.random.rand(num_samples, seq_len, 2)  # 100 samples, each of length 6000 and 4 features
open_cromatin = np.random.rand(num_samples, seq_len, 1)  # 100 samples, each of length 6000 and 4 features
gene_expression = np.random.rand(num_samples, 1)

In [8]:
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
from torch.cuda.amp import GradScaler, autocast
import torch.nn as nn
import torch.optim as optim
import copy

class ChromosomeDataset(Dataset):
    def __init__(self, embeddings, cnv_data, open_cromatin, gene_expression):
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.cnv_data = torch.tensor(cnv_data, dtype=torch.float32)
        self.open_cromatin = torch.tensor(open_cromatin, dtype=torch.float32)
        self.gene_expression = torch.tensor(gene_expression, dtype=torch.float32)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        embeddings = self.embeddings[idx]
        cnv = self.cnv_data[idx]
        open_cromatin = self.open_cromatin[idx]
        gene_expression = self.gene_expression[idx]
        
        return embeddings, cnv, open_cromatin, gene_expression

dataset = ChromosomeDataset(embeddings, cnv, open_cromatin, gene_expression)

train_size = int(0.6 * len(dataset))
val_size = int(0.3 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size,test_size])

BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
class StackedDataset(Dataset):
    def __init__(self, ablated_inputs, gene_expression):
        self.ablated_inputs = torch.tensor(ablated_inputs, dtype=torch.float32)
        self.gene_expression = torch.tensor(gene_expression, dtype=torch.float32)

    def __len__(self):
        return len(self.ablated_inputs)

    def __getitem__(self, idx):
        ablated_inputs = self.ablated_inputs[idx]
        gene_expression = self.gene_expression[idx]
        
        return ablated_inputs, gene_expression

In [10]:
import torch.nn.functional as F

class ChromosomeCNN(nn.Module):
    def __init__(self,  input_dim, seq_len, output_dim):
        super(ChromosomeCNN, self).__init__()
        
        self.input_dim = input_dim
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, padding=2)

        self.fc1 = None 
        self.fc2 = nn.Linear(128, 1)
        
        self.seq_len = seq_len
        
    def initialize_fc1(self, x):

        if self.fc1 is None: 
            flattened_size = x.shape[1] * x.shape[2] 
            self.fc1 = nn.Linear(flattened_size, 128).to(x.device)     
    
    def forward(self, inputs_seq):
        
        x = inputs_seq.permute(0, 2, 1)
        
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)
        
        if self.fc1 is None:

            fc1_input_size = x.shape[1]
            self.fc1 = nn.Linear(fc1_input_size, 128)
                
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x

In [11]:
def create_ablated_dataloader(loader, channel_to_remove, channel_variable_counts):
    ablated_dataloader = []
    
    for batch in loader:
                    
        inputs = batch[:-1]
        targets = batch[-1]
       
        ablated_inputs = ablation_study(inputs, channel_to_remove, channel_variable_counts)
        #ablated_data.append((ablated_inputs, targets))  # Keep targets intact
        ablated_dataloader = StackedDataset(ablated_inputs, targets)
        
    return ablated_dataloader

In [12]:
def full_study(loader):
    full_dataloader = []
    
    for batch in loader:
               
        inputs = batch[:-1]
        targets = batch[-1]
        
        embeddings, cnv, open_cromatin = inputs
       
        _inputs = (cnv, open_cromatin)
        stacked_inputs = torch.cat(_inputs, dim=-1)

        full_dataloader = StackedDataset(stacked_inputs, targets)
        
    return full_dataloader

In [13]:
def ablation_study(inputs, channel_to_remove, channel_variable_counts):

    embeddings, cnv, open_cromatin = inputs
    
    if channel_to_remove == 0:
        # Ablating embeddings: remove embeddings
        ablated_inputs = (cnv, open_cromatin)
    elif channel_to_remove == 1:
        # Ablating cnv: remove cnv
        ablated_inputs = (embeddings, open_cromatin)
    elif channel_to_remove == 2:
        # Ablating open_cromatin: remove open_cromatin
        ablated_inputs = (embeddings, cnv)
        
    stacked_inputs = torch.cat(ablated_inputs, dim=-1)

    return stacked_inputs

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sequ_len = 6000 ##### add correct one

epochs = 2

def train_(model, train_loader, val_loader, epochs):
    
    best_val_loss = float('inf')
    train_losses_avg = []
    val_losses_avg = []
    
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    

    for epoch in range(epochs):

        model.train()
        
        total_loss = 0
        
        for stacked_inputs_batch, y_batch in train_loader:

            stacked_inputs_batch = stacked_inputs_batch.to(device)
            y_batch = y_batch.to(device, non_blocking=True)
            stacked_inputs_batch = stacked_inputs_batch.unsqueeze(0)

            optimizer.zero_grad()

            with autocast():   
                outputs = model(stacked_inputs_batch)
                loss = criterion(outputs, y_batch)
            
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
    
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")
    
        model.eval()
        val_losses = []
        for stacked_inputs_batch, y_batch in val_loader:

            stacked_inputs_batch = stacked_inputs_batch.to(device)
            y_batch = y_batch.to(device, non_blocking=True)
            stacked_inputs_batch = stacked_inputs_batch.unsqueeze(0)

            with torch.no_grad(), autocast():
                y_pred = model(stacked_inputs_batch)
                lossV = criterion(y_pred, y_batch)
                val_losses.append(lossV.item())

        avg_val_loss = sum(val_losses) / len(val_losses)
        val_losses_avg.append(avg_val_loss)
        print(f'Epoch {epoch+1}, Val loss: {avg_val_loss}')
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = copy.deepcopy(model.state_dict())
            
    return avg_val_loss

In [15]:
def ablation_study_evaluation(dataloader, channel_variable_counts, seq_len, num_epochs):

    print("Training with all channels intact...")
    num_channels = 3
    
    full_train_loader = full_study(train_loader)
    full_val_loader = full_study(val_loader)
    full_test_loader = full_study(test_loader)
    
    model = ChromosomeCNN(input_dim = num_channels, seq_len = seq_len, output_dim = 1).to(device)
    baseline_loss = train_(model, full_train_loader, full_val_loader, num_epochs)
    
    torch.save({
        'model_state_dict': model.state_dict(),
    }, 'baseline_model.pth')
    
    baseline_test = test_model("baseline_model.pth", full_test_loader, num_channels, seq_len)

    for channel_idx in range(3):
        print(f"\nAblating channel {channel_idx}...")
        
        remaining_channels = [i for i in range(3) if i != channel_idx]
        remaining_variables = sum(channel_variable_counts[i] for i in remaining_channels)
        print("remaing variables", remaining_variables)
        
        model = ChromosomeCNN(input_dim = remaining_variables, seq_len = seq_len, output_dim = 1).to(device)
        
        ablated_train_loader = create_ablated_dataloader(train_loader, channel_idx, channel_variable_counts)
        ablated_val_loader = create_ablated_dataloader(val_loader, channel_idx, channel_variable_counts)
        ablated_test_loader = create_ablated_dataloader(test_loader, channel_idx, channel_variable_counts)

        
        model_ablated = ChromosomeCNN(input_dim=remaining_variables, seq_len=seq_len, output_dim=1).to(device)
        ablated_model_name = f"ablated_model_channel_{channel_idx}"
        
        ablated_loss = train_(model_ablated, ablated_train_loader, ablated_val_loader, epochs)#, ablated_model_name)
        
        ablated_model_filename = f'ablated_model_channel_{channel_idx}.pth'
        torch.save({
            'model_state_dict': model_ablated.state_dict(),
        }, ablated_model_filename)

        results = {}
        results[f"Ablated Channel {channel_idx}"] = test_model(
            f"{ablated_model_name}.pth", ablated_test_loader, remaining_variables, seq_len
        )
        
        
        print(f"Loss after ablating channel {channel_idx}: {ablated_loss:.4f}")
        print(f"Performance drop: {baseline_loss - ablated_loss:.4f}")

In [16]:
def test_model(model_path, test_loader, total_variables, seq_len):

    model = ChromosomeCNN(input_dim=total_variables, seq_len=seq_len, output_dim=1).to(device)
    checkpoint = torch.load(model_path)
    
    input_tensor = torch.zeros(1, model.seq_len, model.input_dim).to(device)
    model(input_tensor)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    model.eval()

    criterion = nn.MSELoss()
    test_losses = []

    with torch.no_grad():
        for stacked_inputs_batch, y_batch in test_loader:
            stacked_inputs_batch = stacked_inputs_batch.to(device)
            y_batch = y_batch.to(device, non_blocking=True)
            stacked_inputs_batch = stacked_inputs_batch.unsqueeze(0)

            with autocast():
                outputs = model(stacked_inputs_batch)
                loss = criterion(outputs, y_batch)
                test_losses.append(loss.item())

    avg_test_loss = sum(test_losses) / len(test_losses)
    print(f"Test MSE: {avg_test_loss:.4f}")
    return avg_test_loss

In [17]:
embedding_dim = 4  
cnv_dim = 2        
chromatin_dim = 1  
expression_dim = 1

ablation_study_evaluation(dataset, channel_variable_counts=[embedding_dim, cnv_dim, chromatin_dim, expression_dim], seq_len=seq_len, num_epochs=epochs)


Training with all channels intact...


  This is separate from the ipykernel package so we can avoid doing imports until
  after removing the cwd from sys.path.
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/2, Loss: 0.4120
Epoch 1, Val loss: 0.2943000201679145
Epoch 2/2, Loss: 0.4016
Epoch 2, Val loss: 0.2942292742198333
Test MSE: 0.2305

Ablating channel 0...
remaing variables 3
Epoch 1/2, Loss: 0.3617
Epoch 1, Val loss: 0.22877146471674525
Epoch 2/2, Loss: 0.3530
Epoch 2, Val loss: 0.22871025035417838
Test MSE: 0.1708
Loss after ablating channel 0: 0.2287
Performance drop: 0.0655

Ablating channel 1...
remaing variables 5
Epoch 1/2, Loss: 0.3815
Epoch 1, Val loss: 0.2554075235806522
Epoch 2/2, Loss: 0.3722
Epoch 2, Val loss: 0.255347382388815
Test MSE: 0.1988
Loss after ablating channel 1: 0.2553
Performance drop: 0.0389

Ablating channel 2...
remaing variables 6
Epoch 1/2, Loss: 0.4337
Epoch 1, Val loss: 0.3329188159356515
Epoch 2/2, Loss: 0.4240
Epoch 2, Val loss: 0.33284032370429484
Test MSE: 0.2713
Loss after ablating channel 2: 0.3328
Performance drop: -0.0386


In [18]:
#model.eval()
#correct = 0
#total = 0
#with torch.no_grad():
#    for X_batch, y_batch in test_loader:
#        X_batch, y_batch = X_batch.to(device).unsqueeze(1), y_batch.to(device).unsqueeze(1)
#        outputs = model(X_batch)
#        predictions = (outputs > 0.5).float()
#        correct += (predictions == y_batch).sum().item()
#        total += y_batch.size(0)

#accuracy = correct / total
#print(f"Test Accuracy: {accuracy * 100:.2f}%")

In [19]:
#model.load_state_dict(best_model)
#model.eval()
#test_losses = []
#y_preds = []
#y_actuals = []

#scaler = GradScaler()

#for X_batch, cnv_batch, y_batch in test_loader:

#    X_batch = X_batch.unsqueeze(1).to(device, non_blocking=True)
#    cnv_batch = cnv_batch.to(device)
#    y_batch = y_batch.to(device, non_blocking=True)
    
#    with torch.no_grad(), autocast():
#        y_pred = model(X_batch, cnv_batch)
#        lossV = criterion(y_pred, y_batch)
        
#        y_preds.extend(y_pred.cpu().numpy())
#        y_actuals.extend(y_batch.cpu().numpy())
#        test_losses.append(lossV.item())

#avg_test_loss = sum(test_losses) / len(test_losses)
#print(f'Test MSE: {avg_test_loss}')

In [21]:
def model_summary(model):
    print("Model Summary:")
    print("{:<50} {:<30} {:<15} {:<15}".format("Layer Name", "Shape", "Parameters", "Trainable"))
    print("-" * 110)
    total_params = 0
    total_trainable_params = 0
    lm_params = 0
    lm_trainable_params = 0
    lm_layers = 0
    for name, parameter in model.named_parameters():
        param = parameter.numel()
        total_params += param
        # Check if the parameter is trainable
        trainable = parameter.requires_grad
        trainable_param = param if trainable else 0
        total_trainable_params += trainable_param
        print("{:<50} {:<30} {:<15} {:<15}".format(name, str(parameter.size()), param, trainable_param))
    print("-" * 110)
    print(f"Total Parameters: {total_params}")
    print(f"Trainable Parameters: {total_trainable_params}")

#model_summary(model)