In [19]:
import pandas as pd
import itertools as it
import multiprocessing as mp
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

In [None]:
# Notebook configuration. The empty Path("") values are placeholders and must be
# filled in before any later cell is run.

# Absolute path to the root of the cell matching workflow (cell_matching);
# matched_barcodes.tsv is read from here and the result CSVs are written here.
workflow_dir = Path("")
# Absolute path to the scores directory holding the replicate_*.tsv score tables.
results_dir = Path("")
# Used to slice every score table as .iloc[:n_cells, n_cells:]; adjust to the data set.
n_cells = 760
# Replicates 1..n_bootstrap (inclusive) are evaluated; adjust to the number available.
n_bootstrap = 250

In [None]:
# Ground-truth lookup built from the matched-barcode table: the value in the
# second column is the key, the value in the first column is its match.
gt = {}
with open(workflow_dir / 'matched_barcodes.tsv','r') as ih: #adjust as needed
    # Skip the header row. The result is deliberately not bound to `_`, which
    # IPython uses for "last output".
    next(ih, None)
    for line in ih:
        stripped = line.strip()
        if not stripped:
            # Blank lines (e.g. a trailing newline at end of file) carry no
            # mapping and would otherwise fail on parts[1].
            continue
        parts = stripped.split("\t")
        gt[parts[1]] = parts[0]
len(gt)

736320

In [None]:
def resolve_replicate(rep)->list[dict[str,str|int|float]]:
    """Compute top-n recall and tie statistics for one bootstrap replicate.

    For every tool/metric combination the score table
    ``results_dir/replicate_{rep}_{tool}_{metric}.tsv`` is read and cut down to
    its first ``n_cells`` rows and the columns from position ``n_cells`` on.
    Each column is ranked by absolute score (largest first); a column counts as
    recalled at ``n`` when its ground-truth partner ``gt[column]`` is among the
    first ``n`` row labels of that ranking.

    Args:
        rep: Replicate number used in the file names and to seed the random
            baseline.

    Returns:
        One record per (tool, metric, top_n) with the keys ``metric``, ``tool``,
        ``top_n``, ``recall`` and min/max/median/mean of the number of rows
        tied for the best absolute score per column, followed by one record per
        top_n for a random-ranking baseline (``metric="random"``,
        ``tool="other"``).

    Raises:
        KeyError: If a column label of a score table is missing from ``gt``.
    """
    top_ns = [1,5,10,25,50,100]
    metrics = ["cosine","pearson","spearman","kendall"]
    tools = ["rna","copykat","copyvae"]

    res = []
    for tool in tools:
        for metric in metrics:
            df = pd.read_csv(results_dir/f'replicate_{rep}_{tool}_{metric}.tsv',sep='\t',index_col=0).iloc[:n_cells,n_cells:]
            recalls = {x:0 for x in top_ns}
            duplicates = []
            for c in df.columns:
                # Absolute scores are computed once per column and reused for
                # the ranking, the maximum and the tie count.
                abs_scores = df[c].apply(abs)
                ranked_index = abs_scores.sort_values(ascending=False).index
                max_score = abs_scores.max()
                # Number of rows sharing the best absolute score (1 == unique winner).
                duplicates.append(int((abs_scores >= max_score).sum()))
                truth = gt[c]
                for n in top_ns:
                    recalls[n] += truth in ranked_index[:n]
            # Tie statistics describe the whole table, so they are identical
            # for every top_n record of this tool/metric pair.
            tie_stats = {
                "min_duplicates" : min(duplicates),
                "max_duplicates" : max(duplicates),
                "median_duplicates" : np.median(duplicates),
                "mean_duplicates" : np.mean(duplicates)
            }
            for n in top_ns:
                res.append({
                    "metric":metric,
                    "tool" : tool,
                    "top_n" : n,
                    "recall" : recalls[n] / df.shape[1],
                    **tie_stats
                })
        print(rep,tool)

    # Random baseline. Only the row/column labels matter here, so the table
    # read last in the loops above (tools[-1], metrics[-1]) is reused on
    # purpose. Each column gets its own seed (rep + column position) so the
    # shuffles differ between columns but are reproducible.
    randoms = {x:0 for x in top_ns}
    for cnt, c in enumerate(df.columns):
        # Shuffling does not depend on the values, so no abs() is needed.
        random_calls = df[c].sample(frac=1,random_state=rep+cnt).index
        truth = gt[c]
        for n in top_ns:
            randoms[n] += truth in random_calls[:n]
    for n in top_ns:
        res.append({
            "metric":"random",
            "tool" : "other",
            "top_n" : n,
            "recall" : randoms[n] / df.shape[1],
            "min_duplicates" : 1,
            "max_duplicates" : 1,
            "median_duplicates" : 1,
            "mean_duplicates" : 1
        })
    return res

if __name__ == "__main__":
    # Evaluate every replicate in parallel. The context manager shuts the
    # worker pool down once all results are in, so no worker processes are
    # left behind when the cell is re-run.
    with mp.Pool() as p:
        per_replicate = p.map(resolve_replicate,range(1,n_bootstrap+1))
    # Flatten the per-replicate record lists into one list of records.
    results = [record for records in per_replicate for record in records]
    recall_df = pd.DataFrame(results)


11327 rna
21
 45rna   
5512  rnarna
rna20 rnarna3
 rna
39rna 176097rna   
rnarna
40
 rnarna


 rna
42 rna
63
3862 rna 
rna
19 rna 
57rna

30 rna58 rna
11 rna 
rna
528
4659 43  rna 
rnarna
49rna rna

32 rna
 rna

3724 rna
26  rna31
 rna
1451  rnarna

rna

35291018  rnarna 61rna  
rnarna

644  rnarna

4836  rnarna

64 rna
56
 rna
15 54rna
 rna
23 rna
412 rna
50 rna
33 rna
 rna
5 rna
28 rna
22 rna
34 rna
4 rna
47 rna
53 rna
25 rna
16 rna
27 copykat
43 copykat
26 copykat
64 copykat
2 copykat
3411  copykat
31 copykat
4522  copykat
603 copykat
copykat
copykat44
36 copykat
  copykatcopykat
8 copykat
23 copykat
33 copykat1618
 
 copykatcopykatcopykat

30 37
9 copykat
40 58copykat
 copykat
 5 62copykat 
copykat35
14  copykat3248copykat
copykat
 47
copykat12 copykat
452  copykatcopykat

3824 copykat  copykat



46 copykat copykat53 copykat
25 1copykat
 copykat
57 copykatcopykat
21 copykat

42
 copykat

59 copykat
55 copykat
50 copykat
28 copykat
10 copykat
7 copykat49 39copykat
 copykat15 17copy

In [22]:
# Inspect the recall records collected from all replicates
# (tool/metric/top_n rows plus the random baseline rows).
recall_df

Unnamed: 0,metric,tool,top_n,recall,min_duplicates,max_duplicates,median_duplicates,mean_duplicates
0,cosine,rna,1,0.005263,1,1,1.0,1.0
1,cosine,rna,5,0.011842,1,1,1.0,1.0
2,cosine,rna,10,0.018421,1,1,1.0,1.0
3,cosine,rna,25,0.034211,1,1,1.0,1.0
4,cosine,rna,50,0.059211,1,1,1.0,1.0
...,...,...,...,...,...,...,...,...
19495,random,other,5,0.013158,1,1,1.0,1.0
19496,random,other,10,0.019737,1,1,1.0,1.0
19497,random,other,25,0.038158,1,1,1.0,1.0
19498,random,other,50,0.061842,1,1,1.0,1.0


In [None]:
# Persist the recall table in the workflow root; the row index is not written.
recall_df.to_csv(workflow_dir/"cell2cell_recall.csv",index=False)

In [None]:
def get_accs(rep)->dict[str,float]:
    """Compute the best-match accuracy of every matching strategy for one replicate.

    For each method/tool pair the score table
    ``results_dir/replicate_{rep}_{tool}_{method}.tsv`` is read and, per column,
    the row with the largest absolute score is taken as the predicted match.
    A random assignment ("random_other") and the predictions read from
    ``replicate_{rep}_macrodna.tsv`` ("macrodna_other") are added as further
    columns. Every column is then compared against the ground truth ``gt``.

    Args:
        rep: Replicate number; used in the file names and as the random seed.

    Returns:
        Mapping from column name (``"{method}_{tool}"``, ``"random_other"``,
        ``"macrodna_other"`` and ``"GroundTruth"``) to the fraction of non-missing
        entries in that column that equal the ground truth.
    """
    methods = ["cosine","pearson","spearman","kendall"]
    tools = ["rna","copyvae","copykat"]
    matches = []
    # NOTE(review): `gen` is never used below; the random baseline is seeded
    # through `random_state=rep` instead.
    gen = np.random.default_rng(rep)
    for method in methods:
        for tool in tools:
            # Keep the first n_cells rows and the columns from position n_cells on.
            # Presumably rows are DNA cells and columns RNA cells (column labels
            # are used as keys of `gt` further down) -- TODO confirm.
            df = pd.read_csv(results_dir/f'replicate_{rep}_{tool}_{method}.tsv',sep='\t',index_col=0).iloc[:n_cells,n_cells:]
            # Largest absolute score per column.
            best_score = df.apply(lambda x:max(abs(x)),axis=0)
            # Mark every cell whose absolute score equals its column's best score,
            # then (on the transposed frame) replace True by the row label and
            # False by NA before transposing back. This relies on
            # DataFrame.replace accepting a per-column dict as the replacement
            # value -- verify when upgrading pandas.
            is_best_match = df.apply(lambda x : abs(x) == best_score,axis=1).T.replace(True,{col:col for col in df.T.columns}).replace(False,pd.NA).T
            # Per column: the first non-missing row label, cut at the first "_";
            # the string "NA" when the column has no best match at all.
            cell_matches = is_best_match.apply(lambda x:[k.split("_")[0] for k in x if not pd.isna(k)][0] if x.count() > 0 else "NA",axis=0)
            matches.append(pd.Series(cell_matches,name=f"{method}_{tool}").sort_index())
    # Random baseline: one row label drawn with replacement for every column.
    # It reuses `df` from the last loop iteration, so only that table's labels
    # are used.
    random = pd.Series([x.split("_")[0] for x in pd.Series(df.index).sample(n=len(df.columns),random_state=rep,replace=True)],index = df.columns,name="random_other").sort_index()
    matches.append(random)
    # NOTE(review): `clonealign_assignments` is not defined anywhere in the
    # visible notebook; it has to exist as a path-like global (a directory
    # containing the macrodna result files) before this function is called.
    madna_df = pd.read_csv(clonealign_assignments/f"replicate_{rep}_macrodna.tsv",sep="\t",index_col=0)
    madna_df["macrodna_other"] = madna_df["predict_cell"].apply(lambda x: x.split("_")[0])
    matches.append(madna_df["macrodna_other"])
    # Align all prediction series on their index labels.
    matches_df = pd.concat(matches,axis=1)
    matches_df["GroundTruth"] = [gt[x] for x in matches_df.index]
    # NOTE(review): no column named "CellLine" is created in this function, so
    # the filter below removes nothing; "GroundTruth" stays in the result
    # (compared with itself) and is dropped by a later cell.
    # The denominator counts non-missing entries; the "NA" placeholder string
    # is a regular value and therefore counts as a (wrong) prediction.
    return {c : sum(matches_df[c]==matches_df["GroundTruth"]) / matches_df[c].count() for c in matches_df.columns if c != "CellLine"}
if __name__ == "__main__":
    # Evaluate every replicate in parallel. The context manager shuts the
    # worker pool down once all results are in, so no worker processes are
    # left behind when the cell is re-run.
    with mp.Pool() as p:
        res = p.map(get_accs,range(1,n_bootstrap+1))
    # One row per replicate, one column per matching strategy.
    top_df = pd.DataFrame(res)
    

Process ForkPoolWorker-254:
Process ForkPoolWorker-253:
Process ForkPoolWorker-224:
Process ForkPoolWorker-239:
Process ForkPoolWorker-193:
Process ForkPoolWorker-198:
Process ForkPoolWorker-213:
Process ForkPoolWorker-234:
Process ForkPoolWorker-223:
Process ForkPoolWorker-246:
Process ForkPoolWorker-205:
Process ForkPoolWorker-195:
Process ForkPoolWorker-255:
Process ForkPoolWorker-214:
Process ForkPoolWorker-204:
Process ForkPoolWorker-241:
Process ForkPoolWorker-237:
Process ForkPoolWorker-209:
Process ForkPoolWorker-226:
Process ForkPoolWorker-200:
Process ForkPoolWorker-217:
Process ForkPoolWorker-208:
Process ForkPoolWorker-247:
Process ForkPoolWorker-215:
Process ForkPoolWorker-232:
Process ForkPoolWorker-229:
Process ForkPoolWorker-199:
Process ForkPoolWorker-203:
Process ForkPoolWorker-225:
Process ForkPoolWorker-201:
Process ForkPoolWorker-251:
Process ForkPoolWorker-230:
Process ForkPoolWorker-243:
Process ForkPoolWorker-197:
Process ForkPoolWorker-240:
Process ForkPoolWork

In [25]:
# Inspect the accuracies: one row per replicate, one column per matching strategy.
top_df

Unnamed: 0,cosine_rna,cosine_copyvae,cosine_copykat,pearson_rna,pearson_copyvae,pearson_copykat,spearman_rna,spearman_copyvae,spearman_copykat,kendall_rna,kendall_copyvae,kendall_copykat,random_other,macrodna_other,GroundTruth
0,0.005263,0.000000,0.013158,0.006579,0.002632,0.015789,0.006579,0.001316,0.011842,0.003947,0.001316,0.013158,0.000000,0.009211,1.0
1,0.001316,0.000000,0.010526,0.003947,0.002632,0.014474,0.003947,0.001316,0.014474,0.001316,0.001316,0.014474,0.000000,0.007895,1.0
2,0.005263,0.000000,0.010526,0.006579,0.002632,0.017105,0.007895,0.001316,0.015789,0.003947,0.001316,0.015789,0.001316,0.007895,1.0
3,0.005263,0.001316,0.010526,0.003947,0.002632,0.018421,0.003947,0.001316,0.013158,0.002632,0.001316,0.018421,0.002632,0.006579,1.0
4,0.005263,0.000000,0.011842,0.007895,0.003947,0.015789,0.005263,0.001316,0.013158,0.002632,0.001316,0.011842,0.001316,0.010526,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
245,0.002632,0.000000,0.013158,0.003947,0.003947,0.014474,0.005263,0.001316,0.013158,0.005263,0.001316,0.014474,0.003947,0.007895,1.0
246,0.005263,0.000000,0.009211,0.006579,0.001316,0.014474,0.007895,0.001316,0.014474,0.003947,0.001316,0.014474,0.000000,0.009211,1.0
247,0.003947,0.000000,0.013158,0.005263,0.002632,0.017105,0.003947,0.001316,0.013158,0.003947,0.001316,0.014474,0.000000,0.010526,1.0
248,0.003947,0.000000,0.010526,0.006579,0.003947,0.014474,0.003947,0.001316,0.014474,0.002632,0.001316,0.015789,0.000000,0.006579,1.0


In [26]:
# Reshape the per-replicate accuracies into long format. Each original column
# name has the form "<method>_<tool>" and is split into its two parts; "tool"
# is derived first because it still needs the combined name in "method".
long_df = (
    top_df
    .drop(columns="GroundTruth")
    .melt(var_name="method",value_name="accuracy")
    .assign(
        tool=lambda frame: frame["method"].map(lambda label: label.split("_")[1]),
        method=lambda frame: frame["method"].map(lambda label: label.split("_")[0]),
    )
)
long_df

Unnamed: 0,method,accuracy,tool
0,cosine,0.005263,rna
1,cosine,0.001316,rna
2,cosine,0.005263,rna
3,cosine,0.005263,rna
4,cosine,0.005263,rna
...,...,...,...
3495,macrodna,0.007895,other
3496,macrodna,0.009211,other
3497,macrodna,0.010526,other
3498,macrodna,0.006579,other


In [None]:
# Score the clonealign assignments: look up the ground-truth partner of every
# "RNA" entry and compare it with the assigned "DNA" entry.
clonealign_df = pd.read_csv(results_dir / "clonealign_assignments",index_col = 0)
clonealign_df["gt"] = [gt[rna_barcode] for rna_barcode in clonealign_df["RNA"]]
clonealign_df["correct"] = clonealign_df["DNA"].eq(clonealign_df["gt"])
# Number of correct assignments.
sum(clonealign_df["correct"])

1

In [30]:
# Append the clonealign accuracy (a single value, not bootstrapped) to the
# long-format table. Any clonealign row left over from an earlier execution of
# this cell is dropped first, so re-running the cell does not add duplicates.
long_df = pd.concat(
    [
        long_df[long_df["method"] != "clonealign"],
        pd.DataFrame({"method":["clonealign"],"accuracy":[sum(clonealign_df["correct"])/len(clonealign_df["correct"])],"tool":["other"]})
    ]
)
long_df

Unnamed: 0,method,accuracy,tool
0,cosine,0.005263,rna
1,cosine,0.001316,rna
2,cosine,0.005263,rna
3,cosine,0.005263,rna
4,cosine,0.005263,rna
...,...,...,...
3496,macrodna,0.009211,other
3497,macrodna,0.010526,other
3498,macrodna,0.006579,other
3499,macrodna,0.010526,other


In [31]:
# Sanity check: show only the rows belonging to the random baseline.
long_df[long_df["method"]=="random"]

Unnamed: 0,method,accuracy,tool
3000,random,0.000000,other
3001,random,0.000000,other
3002,random,0.001316,other
3003,random,0.002632,other
3004,random,0.001316,other
...,...,...,...
3245,random,0.003947,other
3246,random,0.000000,other
3247,random,0.000000,other
3248,random,0.000000,other


In [None]:
# Persist the long-format accuracy table in the workflow root; the row index is not written.
long_df.to_csv(workflow_dir / "cell2cell_accuracy.csv",index=False)

In [None]:
# Distribution of best-match accuracy per tool, one box per method.
# The data comes from `long_df` (the table with the "tool", "accuracy" and
# "method" columns built above); the previously referenced `stats_df` is not
# defined anywhere in this notebook.
fig,ax = plt.subplots(figsize=(16,9))
sns.boxplot(data=long_df,x="tool",y="accuracy",hue="method",ax=ax,fill=False)
ax.set(title="Cell-to-cell matching accuracy",xlabel="tool",ylabel="accuracy")
plt.show()