In [18]:
import warnings
warnings.filterwarnings('ignore')
import bioframe as bf
from gpn.data import load_table, Genome
import numpy as np
import polars as pl
from scipy.special import softmax
from scipy.stats import entropy
from tqdm import tqdm

In [2]:
NUCLEOTIDES = list("ACGT")
chrom = "22"
model = "multiz100way/89/128/64/True/defined.phastCons.percentile-75_0.05_0.001/medium/0.1/42/30000/True/True/True"

In [3]:
anc_path = "../../results/ancestral/homo_sapiens_ancestor_GRCh38/homo_sapiens_ancestor_22.fa"

In [4]:
anc_seq = Genome(anc_path)._genome["ANCESTOR_for_chromosome:GRCh38:22:1:50818468:1"].upper()
anc_seq = np.frombuffer(anc_seq.encode("ascii"), dtype="S1")
anc_seq

array([b'.', b'.', b'.', ..., b'.', b'.', b'.'], dtype='|S1')

In [5]:
# https://ftp.ensembl.org/pub/release-109/fasta/ancestral_alleles/

In [6]:
annotation = load_table("../../results/annotation.gtf.gz")[["chrom", "start", "end", "feature"]]
annotation = annotation[annotation.chrom==chrom]
annotation

Unnamed: 0,chrom,start,end,feature
3261452,22,11827522,11910358,gene
3261453,22,11827522,11910358,transcript
3261454,22,11827522,11827658,exon
3261455,22,11832760,11832915,exon
3261456,22,11867114,11867145,exon
...,...,...,...,...
3332625,22,27750775,27750914,three_prime_utr
3332626,22,27750063,27750228,three_prime_utr
3332627,22,27750677,27791883,transcript
3332628,22,27791746,27791883,exon


In [7]:
annotation = annotation.replace({
    "five_prime_utr": "5' UTR",
    "three_prime_utr": "3' UTR",
})

In [8]:
annotation.feature.value_counts()

feature
exon              34347
CDS               18387
transcript         5351
3' UTR             4148
5' UTR             3538
start_codon        2098
stop_codon         1885
gene               1412
Selenocysteine       12
Name: count, dtype: int64

In [9]:
features = [
    "CDS",
    "5' UTR",
    "3' UTR",
]

feature_intervals = {
    feature: bf.merge(annotation[annotation.feature==feature]).drop(columns="n_intervals") 
    for feature in features
}

In [10]:
path = f"../../results/positions/{chrom}/processed_logits/{model}.parquet"
V = (
    pl.read_parquet(path)
)
V = V.with_columns(anc=anc_seq[V["pos"]-1])
V = V.with_columns(anc=pl.col("anc").cast(str))
V = V.with_columns(
    pl.DataFrame(softmax(V.select(NUCLEOTIDES), axis=1), schema=NUCLEOTIDES)
)
for c in ["ref", "anc"]:
    V = (
        V.with_columns(
            pl.when(pl.col(c) == "A").then(pl.col("A"))
            .when(pl.col(c) == "C").then(pl.col("C"))
            .when(pl.col(c) == "G").then(pl.col("G"))
            .when(pl.col(c) == "T").then(pl.col("T"))
            .alias(f"prob_{c}")
        )
        .with_columns(
            (1 / pl.col(f"prob_{c}")).alias(f"perplexity_{c}"),
            (pl.col(f"prob_{c}") == pl.max_horizontal(NUCLEOTIDES))
            .alias(f"accuracy_{c}").cast(float)
        )
    )
intervals = (
    V.with_columns(
        (pl.col("pos") - 1).alias("start"),
        pl.col("pos").alias("end")
    )
    .select(["chrom", "start", "end"])
    .to_pandas()
)
for f, x in tqdm(feature_intervals.items()):
    V = V.with_columns(
        pl.from_pandas(bf.count_overlaps(intervals, x)["count"] > 0)
        .alias(f)
    )
V

100%|██████████| 3/3 [02:34<00:00, 51.43s/it]


chrom,pos,ref,A,C,G,T,anc,prob_ref,perplexity_ref,accuracy_ref,prob_anc,perplexity_anc,accuracy_anc,CDS,5' UTR,3' UTR
str,i64,str,f32,f32,f32,f32,str,f32,f32,f64,f32,f32,f64,bool,bool,bool
"""22""",10510065,"""A""",0.677927,0.101803,0.118962,0.101308,""".""",0.677927,1.475085,1.0,,,,false,false,false
"""22""",10510066,"""A""",0.678113,0.104195,0.11466,0.103032,""".""",0.678113,1.47468,1.0,,,,false,false,false
"""22""",10510067,"""T""",0.107364,0.133747,0.101849,0.657041,""".""",0.657041,1.521976,1.0,,,,false,false,false
"""22""",10510068,"""A""",0.680645,0.096985,0.123955,0.098416,""".""",0.680645,1.469195,1.0,,,,false,false,false
"""22""",10510069,"""T""",0.120914,0.108757,0.121446,0.648883,""".""",0.648883,1.54111,1.0,,,,false,false,false
"""22""",10510070,"""G""",0.541647,0.093849,0.229348,0.135156,""".""",0.229348,4.360189,0.0,,,,false,false,false
"""22""",10510071,"""T""",0.11857,0.059215,0.754159,0.068056,""".""",0.068056,14.693797,0.0,,,,false,false,false
"""22""",10510072,"""A""",0.721582,0.080565,0.113504,0.084349,""".""",0.721582,1.385843,1.0,,,,false,false,false
"""22""",10510073,"""T""",0.084662,0.097041,0.083532,0.734765,""".""",0.734765,1.36098,1.0,,,,false,false,false
"""22""",10510074,"""T""",0.069883,0.069918,0.068198,0.792001,""".""",0.792001,1.262625,1.0,,,,false,false,false


In [11]:
V = V.with_columns(
    pl.when(pl.col("CDS")).then(pl.lit("CDS"))
    .when(pl.col("5' UTR")).then(pl.lit("5' UTR"))
    .when(pl.col("3' UTR")).then(pl.lit("3' UTR"))
    .otherwise(pl.lit("Non-exonic"))
    .alias("Region")
)
V["Region"].value_counts()

Region,count
str,u32
"""Non-exonic""",37423927
"""3' UTR""",785066
"""5' UTR""",189036
"""CDS""",755604


In [12]:
df1 = V.group_by("Region").agg(
    #pl.mean("perplexity_ref").alias("mean perplexity"),
    pl.mean("accuracy_ref").alias("Accuracy"),
    pl.median("perplexity_ref").alias("Median perplexity"),
).sort("Accuracy", descending=True)
df1

Region,Accuracy,Median perplexity
str,f64,f32
"""CDS""",0.915296,1.032106
"""5' UTR""",0.836999,1.645802
"""3' UTR""",0.821391,1.687933
"""Non-exonic""",0.690292,1.977608


In [13]:
print(df1.to_pandas().to_latex(index=False, escape="latex", float_format="%.3f"))

\begin{tabular}{lrr}
\toprule
Region & Accuracy & Median perplexity \\
\midrule
CDS & 0.915 & 1.032 \\
5' UTR & 0.837 & 1.646 \\
3' UTR & 0.821 & 1.688 \\
Non-exonic & 0.690 & 1.978 \\
\bottomrule
\end{tabular}



In [14]:
V_diff = V.filter(
    pl.col("anc").is_in(NUCLEOTIDES),
    pl.col("ref") != pl.col("anc"),
)
len(V_diff)

245242

In [15]:
df2 = V_diff.group_by("Region").agg(
    pl.mean("accuracy_ref").alias("Accuracy predicting reference"),
    pl.mean("accuracy_anc").alias("Accuracy predicting ancestral"),
).sort("Accuracy predicting ancestral", descending=True)
df2

Region,Accuracy predicting reference,Accuracy predicting ancestral
str,f64,f64
"""CDS""",0.227335,0.679959
"""5' UTR""",0.177312,0.678742
"""3' UTR""",0.245011,0.60324
"""Non-exonic""",0.268406,0.515225


In [16]:
print(df2.to_pandas().to_latex(index=False, escape="latex", float_format="%.3f"))

\begin{tabular}{lrr}
\toprule
Region & Accuracy predicting reference & Accuracy predicting ancestral \\
\midrule
CDS & 0.227 & 0.680 \\
5' UTR & 0.177 & 0.679 \\
3' UTR & 0.245 & 0.603 \\
Non-exonic & 0.268 & 0.515 \\
\bottomrule
\end{tabular}



In [19]:
V = V.with_columns(
    entropy=entropy(V.select(NUCLEOTIDES), base=2, axis=1)
)
V

chrom,pos,ref,A,C,G,T,anc,prob_ref,perplexity_ref,accuracy_ref,prob_anc,perplexity_anc,accuracy_anc,CDS,5' UTR,3' UTR,Region,entropy
str,i64,str,f32,f32,f32,f32,str,f32,f32,f64,f32,f32,f64,bool,bool,bool,str,f32
"""22""",10510065,"""A""",0.677927,0.101803,0.118962,0.101308,""".""",0.677927,1.475085,1.0,,,,false,false,false,"""Non-exonic""",1.41576
"""22""",10510066,"""A""",0.678113,0.104195,0.11466,0.103032,""".""",0.678113,1.47468,1.0,,,,false,false,false,"""Non-exonic""",1.416055
"""22""",10510067,"""T""",0.107364,0.133747,0.101849,0.657041,""".""",0.657041,1.521976,1.0,,,,false,false,false,"""Non-exonic""",1.467612
"""22""",10510068,"""A""",0.680645,0.096985,0.123955,0.098416,""".""",0.680645,1.469195,1.0,,,,false,false,false,"""Non-exonic""",1.406799
"""22""",10510069,"""T""",0.120914,0.108757,0.121446,0.648883,""".""",0.648883,1.54111,1.0,,,,false,false,false,"""Non-exonic""",1.490927
"""22""",10510070,"""G""",0.541647,0.093849,0.229348,0.135156,""".""",0.229348,4.360189,0.0,,,,false,false,false,"""Non-exonic""",1.676943
"""22""",10510071,"""T""",0.11857,0.059215,0.754159,0.068056,""".""",0.068056,14.693797,0.0,,,,false,false,false,"""Non-exonic""",1.177065
"""22""",10510072,"""A""",0.721582,0.080565,0.113504,0.084349,""".""",0.721582,1.385843,1.0,,,,false,false,false,"""Non-exonic""",1.289668
"""22""",10510073,"""T""",0.084662,0.097041,0.083532,0.734765,""".""",0.734765,1.36098,1.0,,,,false,false,false,"""Non-exonic""",1.254029
"""22""",10510074,"""T""",0.069883,0.069918,0.068198,0.792001,""".""",0.792001,1.262625,1.0,,,,false,false,false,"""Non-exonic""",1.067291


In [28]:
(
    V.filter(pl.col("anc").is_in(NUCLEOTIDES))
    .with_columns((pl.col("ref") == pl.col("anc")).alias("Ref-Anc match"))
    .group_by("Ref-Anc match")
    .agg(
        pl.mean("entropy").alias("mean entropy"),
    )
)

Ref-Anc match,mean entropy
bool,f32
False,1.608392
True,1.498389


In [None]:
# ecdfplot might be better?

In [21]:
df3 = (
    V.filter(pl.col("anc").is_in(NUCLEOTIDES))
    .with_columns((pl.col("ref") == pl.col("anc")).alias("Ref-Anc match"))
    .select(["Ref-Anc match", "entropy"])
    .to_pandas()
)
df3

Unnamed: 0,Ref-Anc match,entropy
0,True,0.191883
1,True,0.301253
2,True,0.550957
3,True,0.344833
4,True,0.405797
...,...,...
32279465,True,1.975786
32279466,True,1.858189
32279467,True,1.591479
32279468,True,1.829296


In [22]:
#sns.ecdfplot(data=df3, x="perplexity", hue="Ref-Anc match")  # maybe subsample?


KeyboardInterrupt



In [27]:
df3.groupby("Ref-Anc match").entropy.mean()

Ref-Anc match
False    1.608394
True     1.500958
Name: entropy, dtype: float32

In [31]:
from scipy.stats import mannwhitneyu
_, p = mannwhitneyu(
    df3[df3["Ref-Anc match"]].entropy,
    df3[~df3["Ref-Anc match"]].entropy,
    alternative="less",
)
p

0.0

In [32]:
from scipy.stats import mannwhitneyu
_, p = mannwhitneyu(
    df3[df3["Ref-Anc match"]].entropy,
    df3[~df3["Ref-Anc match"]].entropy,
    alternative="greater",
)
p

1.0