In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split
import pandas as pd
import numpy as np
from scipy.stats import t, shapiro, kstest
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# Data preparation: load the proteomics table, drop sparse cell lines / proteins,
# impute the remaining gaps and standardise each protein.
def prepare_data(path='~/icr/simko/data/simko2_data/passport_prots.csv',
                 max_nans_per_cell_line=4000,
                 max_prot_nan_percent=80,
                 lower_fraction=0.25):
    """Load, filter, impute and standardise the protein-abundance table.

    The CSV is read with its first column as the index (protein identifiers,
    cast to str); every other column is treated as one cell line.

    Steps:
      1. keep cell lines (columns) with fewer than ``max_nans_per_cell_line`` NaNs;
      2. keep proteins (rows) whose NaN percentage over the retained cell lines is
         below ``max_prot_nan_percent``;
      3. fill each protein's NaNs with the mean of its lowest ``lower_fraction``
         of observed values (lower quartile by default);
      4. transpose to (cell lines x proteins) and z-score every protein with
         ``StandardScaler``.

    Args:
        path: CSV location (defaults to the original project path).
        max_nans_per_cell_line: cell lines with this many NaNs or more are dropped.
        max_prot_nan_percent: proteins with this NaN percentage or more are dropped.
        lower_fraction: fraction of a protein's lowest observed values averaged for
            imputation; at least one observed value is always used.

    Returns:
        DataFrame indexed by cell line with one standardised column per protein.
    """
    abundance = pd.read_csv(path, index_col=0)
    abundance.index = abundance.index.astype(str)

    # 1. drop cell lines (columns) with too many missing proteins
    nans_per_cl = abundance.isna().sum(axis=0)
    abundance_cl_filtered = abundance.loc[:, nans_per_cl < max_nans_per_cell_line]

    # 2. drop proteins that are mostly missing among the retained cell lines
    prot_nan_percent = abundance_cl_filtered.isna().mean(axis=1) * 100
    abundance_filtered = abundance_cl_filtered[prot_nan_percent < max_prot_nan_percent]

    # 3. impute each protein's NaNs with the mean of its lowest observed values
    def average_lower_quartile(x):
        observed = x.dropna().sort_values()
        # max(1, ...): with fewer than 1/lower_fraction observations the slice would
        # otherwise be empty and the "average" NaN, leaving the gaps unfilled.
        n_low = max(1, int(len(observed) * lower_fraction))
        return observed.iloc[:n_low].mean()

    # Compute and apply the fill value row by row, so duplicated protein IDs each
    # get their own value instead of an ambiguous lookup by name.
    abundance_filtered_no_nan = abundance_filtered.apply(
        lambda row: row.fillna(average_lower_quartile(row)), axis=1)

    # 4. samples as rows; each protein scaled to zero mean / unit variance.
    # NOTE(review): the scaler is fit on all cell lines, including those later held
    # out for validation - minor leakage; confirm this is acceptable.
    abundance_imputed = abundance_filtered_no_nan.T
    scaler = StandardScaler()
    scaled_data = pd.DataFrame(scaler.fit_transform(abundance_imputed),
                               index=abundance_imputed.index,
                               columns=abundance_imputed.columns)
    return scaled_data


scaled_data = prepare_data()
scaled_data

In [None]:
# Condition vector: the (standardised) PBRM1 abundance min-max rescaled to [0, 1],
# so c = 0 is the lowest and c = 1 the highest PBRM1 level observed.
raw_pbrm1 = scaled_data["PBRM1"].values
raw_min, raw_max = raw_pbrm1.min(), raw_pbrm1.max()
pbrm1_range = raw_max - raw_min
# A constant or NaN-containing PBRM1 column would produce NaN/inf conditions
# that silently poison training, so fail loudly instead.
if not np.isfinite(pbrm1_range) or pbrm1_range == 0:
    raise ValueError(
        "PBRM1 must be finite and vary across cell lines to build a [0, 1] condition; "
        f"got min={raw_min}, max={raw_max}"
    )
c = (raw_pbrm1 - raw_min) / pbrm1_range

In [None]:
# Model inputs: every protein except PBRM1 (PBRM1 enters the model only as the
# condition c), as a float32 array of shape (cell lines, proteins).
X = scaled_data.drop(columns="PBRM1").to_numpy(dtype=np.float32)


In [None]:
# Custom Dataset pairing each sample's profile with its condition value, so later
# steps (splitting, batching, augmentation) can work on (x, c) pairs.
class ProteomeDataset(Dataset):
    """Pairs each protein-abundance vector with its scalar condition value.

    Args:
        X: array-like of shape (n_samples, n_proteins).
        c: array-like of n_samples condition values, either 1-D ``(n_samples,)``
           or already a column ``(n_samples, 1)``.

    Items are ``(x_i, c_i)``: float32 tensors of shape ``(n_proteins,)`` and ``(1,)``.

    Raises:
        ValueError: if X and c do not describe the same number of samples.
    """

    def __init__(self, X, c):
        self.X = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        # reshape(-1, 1) accepts both (N,) and (N, 1); a plain unsqueeze would turn
        # an (N, 1) input into (N, 1, 1).
        self.c = torch.as_tensor(np.asarray(c), dtype=torch.float32).reshape(-1, 1)
        if len(self.X) != len(self.c):
            raise ValueError(
                f"X has {len(self.X)} samples but c has {len(self.c)} condition values"
            )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.c[i]

# Wrap all samples; random_split below carves the train/validation subsets out of this.
full_ds = ProteomeDataset(X, c)
# NOTE(review): this shuffling loader over the *full* dataset is not used by the
# training/validation loops, which iterate train_loader / val_loader instead.
loader  = DataLoader(full_ds, batch_size=32, shuffle=True, drop_last=True)

In [None]:
# Hold out 20% of the samples as a validation set to check the model is not over-fitting.
# A fixed-seed generator makes the split identical on every Restart & Run All.
VAL_FRACTION = 0.2
SPLIT_SEED = 42
n_val = int(len(full_ds) * VAL_FRACTION)
n_train = len(full_ds) - n_val
train_ds, val_ds = random_split(
    full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# drop_last discards the final partial training batch each epoch (presumably so
# BatchNorm never sees a tiny batch in training mode).
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, drop_last=False)
# With drop_last=True, fewer than 32 training samples means zero batches, which
# would otherwise surface later as a ZeroDivisionError when averaging the loss.
if len(train_loader) == 0:
    raise ValueError(f"Only {n_train} training samples: fewer than one full batch of 32.")

In [None]:
class CVAE(nn.Module):
    """Conditional variational autoencoder for protein-abundance profiles.

    Both networks receive the scalar condition ``c`` as one extra input column:
    the encoder sees ``[x, c]`` and the decoder sees ``[z, c]``.

    Args:
        n_proteins: number of features (proteins) per sample.
        latent_dim: size of the latent space (default 50).
    """

    def __init__(self, n_proteins, latent_dim=50):
        super().__init__()

        def hidden_unit(n_in, n_out):
            # Repeating unit of both networks: Linear -> BatchNorm -> ReLU -> Dropout(0.3).
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU(), nn.Dropout(p=0.3)]

        # Encoder: (n_proteins + 1) -> 1024 -> 512 -> 128
        encoder_widths = [n_proteins + 1, 1024, 512, 128]
        encoder_layers = []
        for width_in, width_out in zip(encoder_widths[:-1], encoder_widths[1:]):
            encoder_layers.extend(hidden_unit(width_in, width_out))
        self.enc = nn.Sequential(*encoder_layers)

        # Heads producing the posterior mean and log-variance.
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

        # Decoder: (latent_dim + 1) -> 128 -> 512 -> 1024 -> n_proteins (linear output)
        decoder_widths = [latent_dim + 1, 128, 512, 1024]
        decoder_layers = []
        for width_in, width_out in zip(decoder_widths[:-1], decoder_widths[1:]):
            decoder_layers.extend(hidden_unit(width_in, width_out))
        decoder_layers.append(nn.Linear(1024, n_proteins))
        self.dec = nn.Sequential(*decoder_layers)

    def encode(self, x, c):
        """Return (mu, logvar) of q(z | x, c)."""
        features = self.enc(torch.cat((x, c), dim=1))
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = torch.exp(logvar * 0.5)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x, c):
        """Encode, sample z, and decode conditioned on c; returns (x_rec, mu, logvar)."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.dec(torch.cat((z, c), dim=1)), mu, logvar

In [None]:
# VAE loss with a squared-error reconstruction term; the KL weight (beta) is
# annealed by the training loop.
def loss_function(recon_x, x_batch, mu, logvar, beta=1.0):
    """Beta-weighted VAE loss, averaged per sample.

    Args:
        recon_x:  (B, P) reconstructed batch.
        x_batch:  (B, P) original batch.
        mu:       (B, L) latent means.
        logvar:   (B, L) latent log-variances.
        beta:     weight on the KL term.

    Returns:
        ``(total, reconstruction, kl)``: the squared error summed over all
        elements, the Gaussian KL summed over batch and latent dims, and
        ``reconstruction + beta * kl`` - each divided by the batch size B.
    """
    n_samples = x_batch.size(0)
    squared_error = F.mse_loss(recon_x, x_batch, reduction='sum')
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (squared_error + beta * kl) / n_samples, squared_error / n_samples, kl / n_samples

In [None]:
# Huber-style (Smooth L1) alternative to the squared-error VAE loss: less sensitive
# to outlying protein values.
def loss_function_huber(recon_x, x_batch, mu, logvar, beta=1.0, smooth_l1_beta=1.0):
    """Beta-weighted VAE loss with a Smooth L1 reconstruction term, averaged per sample.

    Args:
        recon_x:        (B, P) reconstructed batch.
        x_batch:        (B, P) original batch.
        mu:             (B, L) latent means.
        logvar:         (B, L) latent log-variances.
        beta:           weight on the KL term (annealed by the training loop).
        smooth_l1_beta: transition point of the Smooth L1 loss - errors smaller than
                        this are penalised quadratically, larger ones linearly
                        (default 1.0, the previous fixed behaviour).

    Returns:
        ``(total, reconstruction, kl)``: Smooth L1 summed over all elements, the
        Gaussian KL summed over batch and latent dims, and
        ``reconstruction + beta * kl`` - each divided by the batch size B.
    """
    batch_size = x_batch.size(0)
    # Reconstruction loss (sum over batch & features)
    recon_loss = nn.SmoothL1Loss(reduction='sum', beta=smooth_l1_beta)(recon_x, x_batch)
    # KL divergence to N(0, I) (sum over batch & latent dims)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # normalise by batch size so the loss scale does not depend on B
    total_l = (recon_loss + beta * kld) / batch_size
    return total_l, recon_loss / batch_size, kld / batch_size

In [None]:
# Set up the model. The condition does not need to be added to the dimensions here:
# CVAE already adds the one extra input column for the PBRM1 condition ("+ 1" in
# its first encoder and decoder layers).
# Seeding immediately before construction makes the weight initialisation (and, when
# the notebook runs top to bottom, the subsequent stochastic training) reproducible.
MODEL_SEED = 0
torch.manual_seed(MODEL_SEED)
model  = CVAE(X.shape[1], latent_dim=100)

In [None]:
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)  # Adam over all CVAE parameters

# KL annealing: the KL weight ramps linearly up to 1.0 over `warmup` epochs, then stays at 1.0.
def beta_schedule(epoch, warmup=50):
    """Return the KL weight for ``epoch`` under a linear warm-up.

    Args:
        epoch: current epoch (the training loop counts from 1).
        warmup: number of epochs over which the weight rises to 1.0; a value
            <= 0 disables annealing (the weight is always 1.0).

    Returns:
        ``epoch / warmup`` clipped to the range [0.0, 1.0].
    """
    if warmup <= 0:
        return 1.0  # previously a ZeroDivisionError for warmup == 0
    return max(0.0, min(1.0, epoch / warmup))

In [None]:
# Number of passes over the training set in the training loop below.
n_epochs = 250

In [None]:
# Train the CVAE with KL annealing, recording batch-averaged training and validation
# losses for every epoch. Each batch loss is already per sample (summed over proteins /
# latent dims and divided by the batch size inside loss_function_huber).
C_JITTER_STD = 0.0   # > 0 (e.g. 0.02) jitters the training condition values as augmentation
n_proteins = X.shape[1]

total_losses = []
recon_losses = []
kl_losses = []

val_total_losses = []
val_recon_losses = []
val_kl_losses = []


def run_epoch(data_loader, beta, train):
    """One pass over ``data_loader``; returns batch-averaged (total, reconstruction, KL) losses.

    With ``train=True`` the model is put in training mode and the optimizer steps after
    every batch; with ``train=False`` it is put in eval mode and gradients are disabled.

    Raises:
        ValueError: if the loader yields no batches (the averages would be undefined).
    """
    n_batches = len(data_loader)
    if n_batches == 0:
        raise ValueError("data loader yields no batches; cannot average the epoch loss")
    # Set the mode on every pass so that an earlier model.eval() (e.g. from re-running
    # analysis cells) can never leak into training.
    model.train(train)
    sum_total = sum_recon = sum_kl = 0.0
    with torch.set_grad_enabled(train):
        for x_b, c_b in data_loader:
            if train and C_JITTER_STD > 0:
                c_b = torch.clamp(c_b + torch.randn_like(c_b) * C_JITTER_STD, 0.0, 1.0)
            x_rec_b, mu_b, logvar_b = model(x_b, c_b)
            loss_b, recon_b, kl_b = loss_function_huber(x_rec_b, x_b, mu_b, logvar_b, beta=beta)
            if train:
                optimizer.zero_grad()
                loss_b.backward()
                optimizer.step()
            sum_total += loss_b.item()
            sum_recon += recon_b.item()
            sum_kl += kl_b.item()
    return sum_total / n_batches, sum_recon / n_batches, sum_kl / n_batches


for epoch in range(1, n_epochs + 1):
    β = beta_schedule(epoch)   # linear ramp 0→1 over the first 50 epochs
    train_total, train_recon, train_kl = run_epoch(train_loader, β, train=True)
    val_total, val_recon, val_kl = run_epoch(val_loader, β, train=False)

    total_losses.append(train_total)
    recon_losses.append(train_recon)
    kl_losses.append(train_kl)

    val_total_losses.append(val_total)
    val_recon_losses.append(val_recon)
    val_kl_losses.append(val_kl)

    # Loss_per_element: per-sample reconstruction loss spread over all proteins.
    print(f"Epoch {epoch:03d} | β={β:.2f} | Total_Loss={train_total:.2f} | Loss={train_recon:.2f} "
          f"| Loss_per_element={train_recon / n_proteins:.4f} | KL={train_kl:.2f} "
          f"| Val_Total_Loss={val_total:.2f}")

# Return the model to training mode after the final validation pass.
model.train()

In [None]:
# Loss curves for training and validation: total and reconstruction losses share one
# figure; the KL term gets its own because of its different scale.
fig, ax = plt.subplots(figsize=(10, 4))
for series, label in [(total_losses, "Total Loss"),
                      (recon_losses, 'Reconstruction Loss'),
                      (val_total_losses, "Total Loss (Validation)"),
                      (val_recon_losses, "Reconstruction Loss (Validation)")]:
    ax.plot(series, label=label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Total & Reconstruction Loss")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(kl_losses, label="KL Divergence", color="orange")
ax.plot(val_kl_losses, label="KL Divergence (Validation)", color='blue')
ax.set_xlabel("Epoch")
ax.set_ylabel("KL Loss")
ax.set_title("KL Divergence Over Epochs")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:
# Knockout simulation on a single validation batch (quick sanity check):
# encode each sample with its observed condition c_val to get its background z,
# then decode the same z with c = 0, the lowest observed PBRM1 level.
# (The previous version looped over every batch but kept only the last result.)
model.eval()
first_batch = next(iter(val_loader), None)
if first_batch is None:
    raise ValueError("val_loader is empty; nothing to simulate")
x_val, c_val = first_batch
with torch.no_grad():
    mu, logvar = model.encode(x_val, c_val)
    z = model.reparameterize(mu, logvar)

    # force full knockout by setting c = 0 for every sample
    batch_size = x_val.size(0)
    c_knock = torch.zeros(batch_size, 1)

    # simulated KO profiles for this batch, one row per input sample
    x_ko = model.dec(torch.cat([z, c_knock], dim=1))

assert x_ko.shape == x_val.shape, (x_ko.shape, x_val.shape)

In [None]:
# Simulate a PBRM1 knockout for every validation sample. The same latent draw z is
# decoded twice: at c = 0 (knockout) and at the observed condition c_val (reconstruction).
all_ko = []
all_orig = []
all_recon = []
model.eval()
with torch.no_grad():
    for x_val, c_val in val_loader:
        mu, logvar = model.encode(x_val, c_val)
        z = model.reparameterize(mu, logvar)
        c_knock = torch.zeros(x_val.size(0), 1)
        x_ko = model.dec(torch.cat([z, c_knock], dim=1))
        x_recon = model.dec(torch.cat([z, c_val], dim=1))
        all_ko.append(x_ko.cpu())
        all_orig.append(x_val.cpu())
        all_recon.append(x_recon.cpu())

all_ko = torch.cat(all_ko, dim=0).numpy()        # simulated KO, decoded at c = 0
all_orig = torch.cat(all_orig, dim=0).numpy()    # original (input) profiles
# NOTE(review): all_ko - all_orig also contains the model's reconstruction error;
# all_ko - all_recon isolates the effect of changing only the condition.
all_recon = torch.cat(all_recon, dim=0).numpy()  # same z decoded at the observed condition

In [None]:
# Protein labels in the same column order as X (both come from dropping PBRM1 from scaled_data).
scaled_data_no_pbrm1 = scaled_data.drop(columns="PBRM1")
protein_names = list(scaled_data_no_pbrm1.columns)
og_values = scaled_data_no_pbrm1.to_numpy()

In [None]:
# Label the original and simulated-KO validation profiles (each N_val x n_proteins) with protein names.
df_orig, df_ko = (pd.DataFrame(profiles, columns=protein_names) for profiles in (all_orig, all_ko))

In [None]:
# Mean change per protein (simulated KO minus original), sorted ascending so the largest
# decreases come first; delta_means.tail(...) would show the largest increases.
delta_means = df_ko.mean(axis=0).sub(df_orig.mean(axis=0)).sort_values()
print(delta_means.head(50))   # the 50 most negative shifts
#print(delta_means.tail(10))   # top 10 increases

In [None]:
# Per-protein mean abundance in the original vs simulated-KO validation profiles.
mean_orig, mean_ko = df_orig.mean(axis=0), df_ko.mean(axis=0)

# Summary table: both means and their difference (KO - original).
df_summary = pd.DataFrame({'mean_orig': mean_orig, 'mean_ko': mean_ko})
df_summary['diff'] = df_summary['mean_ko'] - df_summary['mean_orig']

# Sorted ascending by diff (most down-regulated first), protein names moved into a 'protein' column.
df_summary_sorted = df_summary.sort_values(by='diff').rename_axis('protein').reset_index()

# Show the 50 proteins most down-regulated by the simulated KO.
df_summary_sorted.head(50)


In [None]:
# Per-protein test of whether the simulated-KO mean differs from the original mean.
# Row i of df_ko and of df_orig is the same validation sample (both were collected in
# the same loop), so the samples are paired and a paired t-test is used. Because
# thousands of proteins are tested at once, Benjamini-Hochberg adjusted p-values
# ('p_adj') are reported alongside the raw ones.
from scipy.stats import ttest_ind, ttest_rel  # ttest_ind: unpaired alternative


def bh_adjust(pvals):
    """Benjamini-Hochberg FDR-adjusted p-values for a 1-D array of p-values.

    NaN p-values (from degenerate tests) stay NaN and are not counted as tests.
    """
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    finite = ~np.isnan(p)
    m = int(finite.sum())
    if m == 0:
        return adjusted
    p_finite = p[finite]
    order = np.argsort(p_finite)
    scaled = p_finite[order] * m / np.arange(1, m + 1)
    # make the adjusted values monotone from the largest p-value down, then cap at 1
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    restored = np.empty(m)
    restored[order] = np.minimum(scaled, 1.0)
    adjusted[finite] = restored
    return adjusted


paired = ttest_rel(df_ko.to_numpy(), df_orig.to_numpy(), axis=0)
paired_p = np.asarray(paired.pvalue, dtype=float)
p_values = dict(zip(df_ko.columns, paired_p))

# Results table indexed by protein, most significant first
p_values_df = pd.DataFrame({'p_value': paired_p, 'p_adj': bh_adjust(paired_p)},
                           index=df_ko.columns)
p_values_df = p_values_df.sort_values(by='p_value')


In [None]:
# Attach each protein's p-values to its mean-shift summary.
# Guarded so that re-running the cell does not reset the index a second time, which
# would create a duplicate 'protein' column and make the merge fail.
if 'protein' not in p_values_df.columns:
    p_values_df = p_values_df.reset_index().rename(columns={'index': 'protein'})
df_summary_sf = pd.merge(df_summary_sorted, p_values_df, on='protein')
df_summary_sf


In [None]:
#saving the data
#df_summary_sorted.to_csv("protein_shift_summary.csv", index=False)

In [None]:
# Save the per-protein KO shift summary together with its significance values.
# index=False: the merged table's RangeIndex carries no information and would
# otherwise be written as an unnamed first column.
df_summary_sf.to_csv("CVAE_PBRM1_ko_results(basic).csv", index=False)
#df_orig.to_csv("orig_proteindata_testsample.csv")
#df_ko.to_csv("ko_proteindata_testsample.csv")

In [None]:
#plotting protein abundance changes for top 15 proteins
import seaborn as sns
import matplotlib.patches as mpatches


In [None]:
# Bar plot of the N_TOP most down-regulated proteins (df_summary_sorted is sorted by
# ascending diff), with PBAF-complex members highlighted.
N_TOP = 15
# .copy() so that adding 'is_pbaf' below modifies an independent table
# (avoids pandas' SettingWithCopyWarning on a head() slice).
top_15_prots = df_summary_sorted.head(N_TOP).copy()

# PBAF complex subunits
PBAF = ('ARID2', 'PHF10', 'BRD7', 'PBRM1', 'SMARCC1', 'SMARCC2', 'SMARCE1', 'SMARCB1',
        'SMARCD1', 'SMARCD2', 'SMARCD3', 'SMARCA2', 'SMARCA4', 'BCL7A',
        'BCL7B', 'BCL7C', 'ACTB', 'ACTL6A')

top_15_prots['is_pbaf'] = top_15_prots['protein'].isin(PBAF)

# bar colours keyed by is_pbaf
palette = {True: "steelblue",   # PBAF proteins
           False: "darkgrey"}   # all other proteins

plt.figure(figsize=(10, 6))
sns.barplot(
    data=top_15_prots,
    x="protein", y="diff",
    hue="is_pbaf",
    dodge=False,            # one bar per protein, coloured by hue rather than split side-by-side
    palette=palette,
    edgecolor='black'
)
pbaf_patch   = mpatches.Patch(color='steelblue', label='PBAF proteins')
other_patch  = mpatches.Patch(color='darkgrey', label='Other proteins')
# explicit legend replaces the default True/False hue legend
plt.legend(handles=[pbaf_patch, other_patch],
           title="", loc='lower right', frameon=False)
plt.xlabel("Protein")
plt.ylabel("Abundance Change")
plt.title(f"Top {N_TOP} Downregulated Proteins After Simulated PBRM1 KO")
plt.xticks(rotation=45, ha="right")
plt.show()

In [None]:
# Display the validation split created by random_split earlier in the notebook.
val_ds