In [303]:
import pandas as pd
from scipy import stats
from sklearn.preprocessing import StandardScaler
from scipy import stats
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from statsmodels.nonparametric.kernel_density import KDEMultivariate
# from sklearn.model_selection import GroupShuffleSplit
from sklearn.model_selection import StratifiedGroupKFold
import seaborn as sns
import numpy as np
import re
import os
import sys; print(sys.executable)

/Users/gc3045/miniconda3/bin/python


In [304]:

%load_ext autoreload
%autoreload 2

from helpers.utils import learn_kde
from helpers.utils import score_with_kdes
from helpers.utils import reshape_to_cmat, _site_slice
import helpers.utils 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [305]:
codebook_fname = "/Users/gc3045/git/laml2-experiments/real_data/PEtracer/PEtracer raw data/edit_codebook.csv"
validation_fname = "/Users/gc3045/git/laml2-experiments/real_data/PEtracer/PEtracer raw data/preedited_merfish_invitro_alleles.csv"
out_dir = "/Users/gc3045/git/laml2-experiments/real_data/PEtracer/inputs/"


In [306]:
validation_df = pd.read_csv(validation_fname)
codebook = pd.read_csv(codebook_fname)

dfw = validation_df.copy().fillna("unedited")
codebook_df = codebook.copy().fillna("unedited")

In [307]:
codebook_df

Unnamed: 0,site,edit,bit
0,HEK3,GATAG,r25
1,HEK3,AATCG,r26
2,HEK3,GCAAG,r27
3,HEK3,GCGCC,r28
4,HEK3,CTTTG,r29
5,HEK3,ATCAA,r30
6,HEK3,CTCTC,r31
7,HEK3,ATTTA,r32
8,EMX1,GGACA,r33
9,EMX1,ACAAT,r34


In [308]:
id_cols    = ["cellBC", "intID", "clone"]
pieces = [
    _site_slice(dfw, "HEK3", id_cols, codebook_df),
    _site_slice(dfw, "EMX1", id_cols, codebook_df),
    _site_slice(dfw, "RNF2", id_cols, codebook_df),
]
long_df = pd.concat(pieces, axis=0, ignore_index=True)

In [309]:
long_df

Unnamed: 0,cellBC,intID,clone,target_site,pet_state,seq_state,brightest_state,pet_prob,feature_0,feature_1,feature_2,feature_3,feature_4,feature_5,feature_6,feature_7,feature_8
0,4T1_preedited-1057,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999968,217.510,217.293,179.428,137.317,266.385,203.037,2130.065,343.644,247.875
1,4T1_preedited-1120,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.998348,381.193,1051.053,397.811,592.112,1058.296,211.300,3782.266,1129.117,798.350
2,4T1_preedited-1143,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999994,447.105,514.259,181.789,366.912,617.272,203.037,5035.697,709.823,373.422
3,4T1_preedited-1153,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999737,262.551,402.394,194.774,206.525,537.598,206.578,2346.477,414.466,422.514
4,4T1_preedited-1162,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999808,1261.121,1508.173,310.458,1648.905,1928.272,246.714,9536.407,1950.001,1693.274
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131800,4T1_preedited-8630,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999471,218.098,139.293,1132.592,273.628,125.128,141.711,148.886,197.135,181.789
131801,4T1_preedited-8638,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999980,634.978,184.150,4390.856,626.125,216.022,377.897,534.379,750.765,211.300
131802,4T1_preedited-8639,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.991395,241.436,182.970,1024.936,212.464,147.556,175.766,169.810,514.676,174.706
131803,4T1_preedited-8644,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999988,293.748,168.804,2760.626,358.131,173.526,252.664,181.077,445.029,158.180


In [310]:
# check that there is only one state for each (clone, intID, target_site)
ok = dfw.groupby(id_cols)["HEK3"].nunique().le(1).all()
print("One pet_state per (clone,intID,HEK3)?", ok)
ok = dfw.groupby(id_cols)["EMX1"].nunique().le(1).all()
print("One pet_state per (clone,intID,EMX1)?", ok)
ok = dfw.groupby(id_cols)["RNF2"].nunique().le(1).all()
print("One pet_state per (clone,intID,RNF2)?", ok)

One pet_state per (clone,intID,HEK3)? True
One pet_state per (clone,intID,EMX1)? True
One pet_state per (clone,intID,RNF2)? True


In [311]:
ok = long_df.groupby(id_cols + ["target_site"])['pet_state'].nunique().le(1).all()
print("One pet_state per (clone,intID,target_site)?", ok)
ok = long_df.groupby(id_cols + ["target_site"])['seq_state'].nunique().le(1).all()
print("One seq_state per (clone,intID,target_site)?", ok)

One pet_state per (clone,intID,target_site)? True
One seq_state per (clone,intID,target_site)? True


In [312]:
unique_cassettes = long_df['intID'].unique()
cassette_mapping = {id_: i for i, id_ in enumerate(unique_cassettes)}
long_df["cassette_idx"] = long_df["intID"].map(cassette_mapping)

site_mapping = {'RNF2': 0, 'HEK3': 1, 'EMX1': 2}
long_df['target_idx'] = long_df['target_site'].map(site_mapping) + len(site_mapping) * long_df["cassette_idx"]

In [313]:
print("Range of target_idx:", min(long_df['target_idx']), max(long_df['target_idx']))
print("Range of cassette_idx:", min(long_df['cassette_idx']), max(long_df['cassette_idx']))

Range of target_idx: 0 95
Range of cassette_idx: 0 31


In [314]:
site_states = (
    long_df[['target_site', 'seq_state']]
    .drop_duplicates()
    .groupby('target_site')['seq_state']
    .apply(list)
)

In [315]:
# save a codebook
rows = []
for site, labels in site_states.items():
    # ensure 'unedited' is 0, others get 1..8 sorted alphabetically
    labels_sorted = sorted([l for l in labels if l != "unedited"])
    ordered = ["unedited"] + labels_sorted
    for i, lbl in enumerate(ordered):
        rows.append({"site": site, "label": lbl, "genotype": i})

label_codebook = pd.DataFrame(rows).sort_values(["site", "genotype"], ignore_index=True)

In [316]:
label_codebook.to_csv(out_dir + "/label_codebook.csv", index=False) 

In [317]:
for col in ['seq', 'brightest', 'pet']:
    long_df = long_df.merge(
        label_codebook.rename(columns={'site': 'target_site', 'label': f'{col}_state', 'genotype': f'{col}_geno'}),
        on=['target_site', f'{col}_state'],
        how='left'
    )

In [318]:
long_df

Unnamed: 0,cellBC,intID,clone,target_site,pet_state,seq_state,brightest_state,pet_prob,feature_0,feature_1,...,feature_4,feature_5,feature_6,feature_7,feature_8,cassette_idx,target_idx,seq_geno,brightest_geno,pet_geno
0,4T1_preedited-1057,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999968,217.510,217.293,...,266.385,203.037,2130.065,343.644,247.875,0,1,4,4,4
1,4T1_preedited-1120,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.998348,381.193,1051.053,...,1058.296,211.300,3782.266,1129.117,798.350,0,1,4,4,4
2,4T1_preedited-1143,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999994,447.105,514.259,...,617.272,203.037,5035.697,709.823,373.422,0,1,4,4,4
3,4T1_preedited-1153,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999737,262.551,402.394,...,537.598,206.578,2346.477,414.466,422.514,0,1,4,4,4
4,4T1_preedited-1162,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999808,1261.121,1508.173,...,1928.272,246.714,9536.407,1950.001,1693.274,0,1,4,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131800,4T1_preedited-8630,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999471,218.098,139.293,...,125.128,141.711,148.886,197.135,181.789,31,93,8,8,8
131801,4T1_preedited-8638,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999980,634.978,184.150,...,216.022,377.897,534.379,750.765,211.300,31,93,8,8,8
131802,4T1_preedited-8639,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.991395,241.436,182.970,...,147.556,175.766,169.810,514.676,174.706,31,93,8,8,8
131803,4T1_preedited-8644,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999988,293.748,168.804,...,173.526,252.664,181.077,445.029,158.180,31,93,8,8,8


In [319]:
ok = long_df.groupby(["target_site"])["seq_geno"].nunique()
print("Num seq_geno per (clone,intID,target_site)?", ok)

Num seq_geno per (clone,intID,target_site)? target_site
EMX1    9
HEK3    9
RNF2    9
Name: seq_geno, dtype: int64


In [320]:
# check that there is only one state for each (clone, intID, target_site)
ok = long_df.groupby(["clone", "target_site"])["seq_geno"].nunique()
print("Num seq_geno per (clone,intID,target_site)?", ok)

Num seq_geno per (clone,intID,target_site)? clone  target_site
1      EMX1           6
       HEK3           5
       RNF2           6
2      EMX1           8
       HEK3           7
       RNF2           8
3      EMX1           1
       HEK3           1
       RNF2           1
4      EMX1           1
       HEK3           1
       RNF2           1
Name: seq_geno, dtype: int64


In [321]:
long_df['seq_geno'].nunique()

9

#### Score

In [322]:
df = long_df.copy()

In [323]:
y = df["seq_geno"]
groups = df["cellBC"]
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)
train_idx, test_idx = next(sgkf.split(X=df, y=y, groups=groups))
df_train = df.iloc[train_idx].copy()
df_test  = df.iloc[test_idx].copy()

In [324]:
# check that there is only one state for each (clone, intID, target_site)
ok = df_train.groupby(["target_site"])["seq_geno"].nunique()
print("Num seq_geno per (clone,intID,target_site)?", ok)

Num seq_geno per (clone,intID,target_site)? target_site
EMX1    9
HEK3    9
RNF2    9
Name: seq_geno, dtype: int64


In [325]:
# check that there is only one state for each (clone, intID, target_site)
ok = df_test.groupby(["target_site"])["seq_geno"].nunique()
print("Num seq_geno per (clone,intID,target_site)?", ok)

Num seq_geno per (clone,intID,target_site)? target_site
EMX1    9
HEK3    9
RNF2    9
Name: seq_geno, dtype: int64


In [326]:
feature_cols = [col for col in df_train.columns if col.startswith('feature')]

In [328]:
all_kde_dict_X = {}
all_scales = {}
all_pca = {}
preds_X_full = pd.DataFrame(index=df_test.index, columns=["pred_label","argmax_logpdf"])
scores_X_tables = []

for site in df_train['target_site'].unique():
    print(site)
    site_df_train = df_train.loc[df_train['target_site'] == site]
    site_df_test = df_test.loc[df_test['target_site'] == site]

    all_kde_dict_X[site] = {}
    all_scales[site] = {}
    all_pca[site] = {}
    site_labels = site_df_train["seq_geno"].unique()
    
    for lab in site_labels:
        m_train = site_df_train["seq_geno"] == lab
        print("label", lab, sum(m_train))

        X_block = site_df_train[feature_cols].loc[m_train]        
        scaler = StandardScaler()
        X_block = scaler.fit_transform(X_block.values.copy())
        pca = PCA(n_components=9, random_state=0, svd_solver='full', whiten=False)
        X_block = pca.fit_transform(X_block.copy())
    
        all_scales[site][lab] = scaler
        all_pca[site][lab] = pca
        all_kde_dict_X[site][lab] = learn_kde(X_block)

    # test
    pred_cols = ["pred_label", "argmax_logpdf"]
    collect_logpdf = []
    for lab in site_labels:
        scaler = all_scales[site][lab]
        pca = all_pca[site][lab]
        kde = all_kde_dict_X[site][lab]
        X_test = scaler.transform(site_df_test[feature_cols])
        X_test = pca.transform(X_test)
        jacobian = np.sum(np.log(np.maximum(scaler.scale_, 1e-12)))
        
        logpdf = np.log(np.maximum(kde.pdf(X_test), 1e-300)) - jacobian
        collect_logpdf.append(logpdf)

    L = np.column_stack(collect_logpdf)
    site_labels_arr = np.array(site_labels, dtype=object)
    idx_max = np.argmax(L, axis=1)
    pred_labels = site_labels_arr[idx_max]
    argmax_logpdf = L[np.arange(L.shape[0]), idx_max]
    
    preds_X = pd.DataFrame({
        "pred_label": pred_labels,
        "argmax_logpdf": argmax_logpdf,
    }, index=site_df_test.index)

    scores_X = pd.DataFrame(L, index=site_df_test.index, columns=[f'state{lab}_prob' for lab in site_labels])

    # write predictions into the preallocated frame using the boolean mask
    preds_X_full.loc[df_test['target_site'] == site, pred_cols] = preds_X[pred_cols].to_numpy()
    
    scores_X.index = site_df_test.index
    scores_X = scores_X.add_prefix(f"{site}::")
    scores_X_tables.append(scores_X)
    print("Scored X...")

HEK3
label 4 4437
label 5 8804
label 2 4671
label 0 6914
label 6 4464
label 1 1069
label 7 3377
label 3 514
label 8 1103




Scored X...
EMX1
label 2 4377
label 1 6368
label 7 5164
label 0 6914
label 8 4901
label 5 3286
label 3 3315
label 6 472
label 4 556




Scored X...
RNF2
label 2 3370
label 6 4922
label 7 2858
label 0 6914
label 5 2759
label 8 7239
label 4 1004
label 1 4142
label 3 2145




Scored X...


In [329]:
out = (
    df_test[["intID", "clone", "target_site","seq_geno","brightest_geno", "pet_geno"]]
    .join(preds_X_full.add_prefix("X_"))
)

In [330]:
out

Unnamed: 0,intID,clone,target_site,seq_geno,brightest_geno,pet_geno,X_pred_label,X_argmax_logpdf
0,intID1011,2,HEK3,4,4,4,4,-52.349765
15,intID1011,2,HEK3,4,4,4,4,-51.767551
20,intID1011,2,HEK3,4,4,4,4,-55.647344
35,intID1011,2,HEK3,4,4,4,4,-56.321716
45,intID1011,2,HEK3,4,4,4,4,-52.794195
...,...,...,...,...,...,...,...,...
131790,intID974,2,RNF2,8,8,8,8,-57.856428
131791,intID974,2,RNF2,8,8,8,8,-73.49166
131794,intID974,2,RNF2,8,8,8,8,-64.095896
131795,intID974,2,RNF2,8,8,8,8,-57.005516


In [331]:
out['X_argmax_logpdf'].astype('float').describe()

count    25746.000000
mean       -56.805482
std         13.028140
min       -738.195918
25%        -57.363105
50%        -54.745537
75%        -53.355046
max        -49.935402
Name: X_argmax_logpdf, dtype: float64

In [359]:
methods = {
    "brightest_geno": "brightest_state",
    "pet_geno": "pet_state",
    "X_pred_label": "X_pred_label",
}

def compare_metrics(y_true: pd.Series, y_pred: pd.Series, missing_token="missing"):
    # valid (non-missing) mask
    mask = y_true.notna() & y_pred.notna()
    if missing_token is not None:
        mask &= (y_true != missing_token) & (y_pred != missing_token)
    n = int(mask.sum())
    if n == 0:
        return {"true_agree": float("nan"), "true_disagree": float("nan"), "n_valid": 0}
    agree_mask = (y_true[mask] == y_pred[mask])
    true_agree = agree_mask.mean()
    true_disagree = (~agree_mask).mean()   # 1 - true_agree
    return {"true_agree": float(true_agree), "true_disagree": float(true_disagree), "n_valid": n}

# Per-site metrics
def site_metrics(g):
    out = {}
    for col, base in methods.items():
        m = compare_metrics(g["seq_geno"], g[col])
        out[f"{base}_true_agree"] = m["true_agree"]
        out[f"{base}_true_disagree"] = m["true_disagree"]
        out[f"{base}_n_valid"] = m["n_valid"]
    return pd.Series(out)

acc_by_site = (
    out.groupby("target_site", dropna=False)
       .apply(site_metrics)
       .reset_index()
)

# Overall (pooled across sites)
overall = {}
for col, base in methods.items():
    m = compare_metrics(out["seq_geno"], out[col])
    overall[f"{base}_true_agree"] = m["true_agree"]
    overall[f"{base}_true_disagree"] = m["true_disagree"]
    overall[f"{base}_n_valid"] = m["n_valid"]

overall_row = pd.DataFrame([{"target_site": "Total", **overall}])
acc_by_site_with_total = pd.concat([acc_by_site, overall_row], ignore_index=True)

acc_by_site_with_total


  .apply(site_metrics)


Unnamed: 0,target_site,brightest_state_true_agree,brightest_state_true_disagree,brightest_state_n_valid,pet_state_true_agree,pet_state_true_disagree,pet_state_n_valid,X_pred_label_true_agree,X_pred_label_true_disagree,X_pred_label_n_valid
0,EMX1,0.973666,0.026334,8582.0,0.977162,0.022838,8582.0,0.958867,0.041133,8582.0
1,HEK3,0.971801,0.028199,8582.0,0.980075,0.019925,8582.0,0.977278,0.022722,8582.0
2,RNF2,0.986949,0.013051,8582.0,0.987765,0.012235,8582.0,0.979026,0.020974,8582.0
3,Total,0.977472,0.022528,25746.0,0.981667,0.018333,25746.0,0.971724,0.028276,25746.0


In [334]:
acc_by_site

Unnamed: 0,target_site,brightest_state_acc,pet_state_acc,X_pred_label_acc
0,EMX1,0.973666,0.977162,0.958867
1,HEK3,0.971801,0.980075,0.977278
2,RNF2,0.986949,0.987765,0.979026


In [335]:
acc_by_clone = (
    out.groupby("clone", dropna=False)
      .apply(lambda g: pd.Series({
          new_name: accuracy(g["seq_geno"], g[col])
          for col, new_name in methods.items()
      }))
      .reset_index())
acc_by_clone

  .apply(lambda g: pd.Series({


Unnamed: 0,clone,brightest_state_acc,pet_state_acc,X_pred_label_acc
0,1,0.973676,0.98053,0.974533
1,2,0.975917,0.98399,0.971812
2,3,0.993711,0.987945,0.977987
3,4,0.985366,0.977778,0.958537


### Save clone 2 to file now

In [336]:
chosen_clone = 2

In [337]:
df_clean = df_test.copy()

In [338]:
use_scores = scores_X_tables
use_preds = preds_X_full

In [339]:
df_clean['kde_geno'] = use_preds['pred_label'].astype('Int64')

In [340]:
for i in range(9):
    col = f"state{i}_prob"
    if col not in df_clean.columns:
        df_clean[col] = pd.NA  # initialize

In [341]:
for site, df_scores in zip(df_clean['target_site'].unique(), use_scores):
    idx = df_scores.index  

    for col in df_scores.columns:
        new_col_name = col.split('::')[1]
        df_clean.loc[idx, new_col_name] = df_scores[col].values
        

In [342]:
df_clean

Unnamed: 0,cellBC,intID,clone,target_site,pet_state,seq_state,brightest_state,pet_prob,feature_0,feature_1,...,kde_geno,state0_prob,state1_prob,state2_prob,state3_prob,state4_prob,state5_prob,state6_prob,state7_prob,state8_prob
0,4T1_preedited-1057,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999968,217.510,217.293,...,4,-64.033193,-147.289725,-204.084212,-387.693111,-52.349765,-95.375683,-75.685569,-95.364985,-164.620164
15,4T1_preedited-1272,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999851,225.200,152.910,...,4,-94.883994,-99.610095,-69.928517,-141.008022,-51.767551,-80.762566,-122.351472,-74.303792,-78.839048
20,4T1_preedited-1290,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.907947,172.470,309.039,...,4,-63.948311,-65.232135,-66.200589,-67.607856,-55.647344,-66.220694,-71.652909,-64.186983,-59.47787
35,4T1_preedited-1632,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.982948,365.813,288.114,...,4,-94.357771,-84.979908,-83.868771,-99.952759,-56.321716,-70.109624,-121.903806,-85.975444,-77.39516
45,4T1_preedited-1684,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.987427,177.963,97.379,...,4,-64.258332,-63.167009,-63.646967,-75.101319,-52.794195,-61.500904,-67.676778,-64.387139,-60.426465
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131790,4T1_preedited-8536,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.995711,569.790,233.729,...,8,-78.044388,-66.945181,-74.242861,-70.015387,-93.599128,-84.447135,-87.023088,-184.767366,-57.856428
131791,4T1_preedited-8539,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999992,1022.081,689.382,...,8,-223.108919,-743.844916,-738.195918,-744.416538,-743.16738,-742.672101,-741.708892,-741.749935,-73.49166
131794,4T1_preedited-8573,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,1.000000,465.167,185.330,...,8,-134.61144,-743.844916,-738.195918,-744.416538,-743.16738,-742.672101,-741.708892,-741.749935,-64.095896
131795,4T1_preedited-8598,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999999,272.823,134.571,...,8,-86.186164,-309.670206,-738.195918,-240.217808,-743.16738,-260.490184,-722.655914,-741.749935,-57.005516


In [343]:
chosen_clone_mask = df_clean['clone'] == chosen_clone
df_clean = df_clean.loc[chosen_clone_mask]

In [344]:
df_clean.columns

Index(['cellBC', 'intID', 'clone', 'target_site', 'pet_state', 'seq_state',
       'brightest_state', 'pet_prob', 'feature_0', 'feature_1', 'feature_2',
       'feature_3', 'feature_4', 'feature_5', 'feature_6', 'feature_7',
       'feature_8', 'cassette_idx', 'target_idx', 'seq_geno', 'brightest_geno',
       'pet_geno', 'kde_geno', 'state0_prob', 'state1_prob', 'state2_prob',
       'state3_prob', 'state4_prob', 'state5_prob', 'state6_prob',
       'state7_prob', 'state8_prob'],
      dtype='object')

In [345]:
disagree_mask = df_clean['seq_geno'] != df_clean['kde_geno']

In [346]:
df_clean.loc[disagree_mask, ['cellBC', 'intID', 'clone', 'target_site', 'pet_state', 'seq_state',
       'brightest_state', 'pet_prob',  'seq_geno', 'brightest_geno',
       'pet_geno', 'kde_geno', 'state0_prob', 'state1_prob', 'state2_prob',
       'state3_prob', 'state4_prob', 'state5_prob', 'state6_prob',
       'state7_prob', 'state8_prob']]

Unnamed: 0,cellBC,intID,clone,target_site,pet_state,seq_state,brightest_state,pet_prob,seq_geno,brightest_geno,...,kde_geno,state0_prob,state1_prob,state2_prob,state3_prob,state4_prob,state5_prob,state6_prob,state7_prob,state8_prob
47,4T1_preedited-1696,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.835430,4,4,...,2,-394.865416,-339.499631,-226.740742,-619.75197,-239.183932,-645.020428,-586.146937,-298.019859,-232.48528
657,4T1_preedited-8539,intID1011,2,HEK3,CTTTG,CTCTC,CTTTG,0.650073,4,5,...,5,-64.258876,-63.286289,-68.921873,-67.51927,-59.854852,-59.699805,-68.673667,-63.835181,-63.891229
11146,4T1_preedited-2212,intID1250,2,HEK3,GCAAG,AATCG,CTTTG,0.329595,1,5,...,4,-53.172149,-53.52497,-53.424067,-53.461725,-52.292727,-52.38288,-53.032484,-52.400216,-54.085839
11625,4T1_preedited-775,intID1250,2,HEK3,AATCG,AATCG,AATCG,0.691293,1,1,...,2,-61.521226,-60.507644,-57.333873,-63.672173,-67.103383,-64.513976,-63.971373,-60.545546,-59.095082
15287,4T1_preedited-1843,intID1294,2,HEK3,GCAAG,CTTTG,GATAG,0.273701,5,6,...,7,-56.051681,-55.799302,-56.422821,-57.529179,-55.485917,-55.650219,-55.868894,-55.347352,-57.055127
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
122803,4T1_preedited-8573,intID493,2,RNF2,TGCCA,TGCCA,TGCCA,1.000000,7,7,...,0,-62.510425,-743.844916,-738.195918,-744.416538,-743.16738,-742.672101,-741.708892,-62.998366,-743.192258
126816,4T1_preedited-7119,intID642,2,RNF2,TCCAA,ACTTA,TCCAA,0.991890,3,6,...,0,-230.449295,-472.249326,-451.492232,-400.05287,-743.16738,-616.038376,-375.303851,-474.247459,-305.840482
131498,4T1_preedited-5700,intID974,2,RNF2,TTCCT,TTCCT,ACTCC,0.402281,8,2,...,2,-57.490239,-57.264034,-54.980241,-56.795551,-57.59447,-57.37333,-58.6506,-58.757238,-55.254772
131553,4T1_preedited-6189,intID974,2,RNF2,TCCAA,TTCCT,TCCAA,0.993320,8,6,...,6,-139.64633,-743.844916,-290.724444,-744.416538,-743.16738,-259.434701,-71.530462,-741.749935,-118.973909


In [347]:
df_clean['pet_geno'].nunique()

9

In [348]:
df_clean['seq_geno'].unique()

array([4, 2, 1, 5, 7, 3, 8, 6])

In [349]:
# check that the cassette_idx and target_idx are redone appropriately
# need 3 outputs: 1) petracer geno_df 2) kde_argmax geno 3) kde_scores
df_clean['kde_geno'].nunique(), df_clean['seq_geno'].nunique()

(9, 8)

In [350]:
petracer_id_cols = ["cellBC", "intID", "clone", "x", "y", "z", "target_site", "target_idx"]
petracer_call_cols = ["pet_state", "seq_state", "brightest_state", "pet_prob"]

In [351]:
df_clean['target_idx'].nunique(), min(df_clean['target_idx']), max(df_clean['target_idx'])

(60, 0, 95)

In [352]:
unique_cassettes = df_clean['intID'].unique()
cassette_mapping = {id_: i for i, id_ in enumerate(unique_cassettes)}
df_clean["cassette_idx"] = df_clean["intID"].map(cassette_mapping)

site_mapping = {'RNF2': 0, 'HEK3': 1, 'EMX1': 2}
df_clean['target_idx'] = df_clean['target_site'].map(site_mapping) + len(site_mapping) * df_clean["cassette_idx"]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_clean["cassette_idx"] = df_clean["intID"].map(cassette_mapping)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_clean['target_idx'] = df_clean['target_site'].map(site_mapping) + len(site_mapping) * df_clean["cassette_idx"]


In [353]:
df_clean['target_idx'].nunique(), min(df_clean['target_idx']), max(df_clean['target_idx'])

(60, 0, 59)

In [354]:
df_clean

Unnamed: 0,cellBC,intID,clone,target_site,pet_state,seq_state,brightest_state,pet_prob,feature_0,feature_1,...,kde_geno,state0_prob,state1_prob,state2_prob,state3_prob,state4_prob,state5_prob,state6_prob,state7_prob,state8_prob
0,4T1_preedited-1057,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999968,217.510,217.293,...,4,-64.033193,-147.289725,-204.084212,-387.693111,-52.349765,-95.375683,-75.685569,-95.364985,-164.620164
15,4T1_preedited-1272,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.999851,225.200,152.910,...,4,-94.883994,-99.610095,-69.928517,-141.008022,-51.767551,-80.762566,-122.351472,-74.303792,-78.839048
20,4T1_preedited-1290,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.907947,172.470,309.039,...,4,-63.948311,-65.232135,-66.200589,-67.607856,-55.647344,-66.220694,-71.652909,-64.186983,-59.47787
35,4T1_preedited-1632,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.982948,365.813,288.114,...,4,-94.357771,-84.979908,-83.868771,-99.952759,-56.321716,-70.109624,-121.903806,-85.975444,-77.39516
45,4T1_preedited-1684,intID1011,2,HEK3,CTCTC,CTCTC,CTCTC,0.987427,177.963,97.379,...,4,-64.258332,-63.167009,-63.646967,-75.101319,-52.794195,-61.500904,-67.676778,-64.387139,-60.426465
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131790,4T1_preedited-8536,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.995711,569.790,233.729,...,8,-78.044388,-66.945181,-74.242861,-70.015387,-93.599128,-84.447135,-87.023088,-184.767366,-57.856428
131791,4T1_preedited-8539,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999992,1022.081,689.382,...,8,-223.108919,-743.844916,-738.195918,-744.416538,-743.16738,-742.672101,-741.708892,-741.749935,-73.49166
131794,4T1_preedited-8573,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,1.000000,465.167,185.330,...,8,-134.61144,-743.844916,-738.195918,-744.416538,-743.16738,-742.672101,-741.708892,-741.749935,-64.095896
131795,4T1_preedited-8598,intID974,2,RNF2,TTCCT,TTCCT,TTCCT,0.999999,272.823,134.571,...,8,-86.186164,-309.670206,-738.195918,-240.217808,-743.16738,-260.490184,-722.655914,-741.749935,-57.005516


##### Save character matrix

In [355]:
tbl = df_clean[['target_idx', 'cellBC', 'kde_geno']]
tbl.columns = ['target_site', 'cell_name', 'pred_label']
kde_inputs_argmax = reshape_to_cmat(tbl).T

In [356]:
kde_inputs_argmax.to_csv(out_dir + "/petracer_clone2_kde_character_matrix.csv")

##### Save PETracer probabilities now

In [357]:
petracer_df = df_clean.copy()
petracer_path = os.path.join(out_dir, f"petracer_clone{chosen_clone}_petracer_genotypes.csv")
petracer_df.to_csv(petracer_path, index=False)

##### Save the kde scores

In [358]:
state_cols = [c for c in df_clean.columns if re.match(r"state\d+_prob", c)]
obs_matrix = df_clean[["cellBC", "cassette_idx", "target_idx"] + state_cols].copy().rename(columns={"cellBC": "cell_name", "target_idx": "target_site"})

obs_matrix["cassette_idx"] = 0
obs_matrix = obs_matrix.sort_values(by="target_site", ascending=True)
petracer_path = os.path.join(out_dir, f"petracer_clone{chosen_clone}_kde_scores.csv")
obs_matrix.to_csv(petracer_path, index=False)
