In [113]:
import warnings
warnings.filterwarnings('ignore')
import bioframe as bf
from gpn.data import load_table
import polars as pl
from scipy.special import softmax
from tqdm import tqdm

In [27]:
NUCLEOTIDES = list("ACGT")
chrom = "22"
model = "multiz100way/89/128/64/True/defined.phastCons.percentile-75_0.05_0.001/medium/0.1/42/30000/True/True/True"

In [31]:
annotation = load_table("../../results/annotation.gtf.gz")[["chrom", "start", "end", "feature"]]
annotation = annotation[annotation.chrom==chrom]
annotation

Unnamed: 0,chrom,start,end,feature
3261452,22,11827522,11910358,gene
3261453,22,11827522,11910358,transcript
3261454,22,11827522,11827658,exon
3261455,22,11832760,11832915,exon
3261456,22,11867114,11867145,exon
...,...,...,...,...
3332625,22,27750775,27750914,three_prime_utr
3332626,22,27750063,27750228,three_prime_utr
3332627,22,27750677,27791883,transcript
3332628,22,27791746,27791883,exon


In [108]:
annotation = annotation.replace({
    "five_prime_utr": "5' UTR",
    "three_prime_utr": "3' UTR",
})

In [109]:
annotation.feature.value_counts()

feature
exon              34347
CDS               18387
transcript         5351
3' UTR             4148
5' UTR             3538
start_codon        2098
stop_codon         1885
gene               1412
Selenocysteine       12
Name: count, dtype: int64

In [110]:
features = [
    "CDS",
    "5' UTR",
    "3' UTR",
]

feature_intervals = {
    feature: bf.merge(annotation[annotation.feature==feature]).drop(columns="n_intervals") 
    for feature in features
}

feature_intervals

{'CDS':      chrom     start       end
 0       22  15528191  15529136
 1       22  15690077  15690709
 2       22  15695370  15695485
 3       22  15695644  15695818
 4       22  15697373  15697529
 ...    ...       ...       ...
 4522    22  50777951  50777981
 4523    22  50779054  50779121
 4524    22  50780180  50780226
 4525    22  50780708  50780718
 4526    22  50782187  50782294
 
 [4527 rows x 3 columns],
 "5' UTR":      chrom     start       end
 0       22  15690025  15690077
 1       22  16592550  16592810
 2       22  16808073  16808083
 3       22  16821606  16821699
 4       22  16825290  16825411
 ...    ...       ...       ...
 1264    22  50729571  50729660
 1265    22  50738203  50738235
 1266    22  50782294  50782351
 1267    22  50782768  50783286
 1268    22  50783390  50783667
 
 [1269 rows x 3 columns],
 "3' UTR":      chrom     start       end
 0       22  15697532  15697629
 1       22  15698661  15698768
 2       22  15700077  15700215
 3       22  15702685

In [111]:
path = f"../../results/positions/{chrom}/processed_logits/{model}.parquet"
V = pl.read_parquet(path)
V = (
    V.with_columns(
        pl.DataFrame(softmax(V.select(NUCLEOTIDES), axis=1), schema=NUCLEOTIDES)
    )
    .with_columns(
        pl.when(pl.col("ref") == "A").then(pl.col("A"))
        .when(pl.col("ref") == "C").then(pl.col("C"))
        .when(pl.col("ref") == "G").then(pl.col("G"))
        .when(pl.col("ref") == "T").then(pl.col("T"))
        .alias("ref_prob")
    )
    .with_columns(
        (1 / pl.col("ref_prob")).alias("perplexity"),
        (pl.col("ref_prob") == pl.max_horizontal(NUCLEOTIDES))
        .alias("accuracy").cast(float)
    )
)
V

chrom,pos,ref,A,C,G,T,ref_prob,perplexity,accuracy
str,i64,str,f32,f32,f32,f32,f32,f32,f64
"""22""",10510065,"""A""",0.677927,0.101803,0.118962,0.101308,0.677927,1.475085,1.0
"""22""",10510066,"""A""",0.678113,0.104195,0.11466,0.103032,0.678113,1.47468,1.0
"""22""",10510067,"""T""",0.107364,0.133747,0.101849,0.657041,0.657041,1.521976,1.0
"""22""",10510068,"""A""",0.680645,0.096985,0.123955,0.098416,0.680645,1.469195,1.0
"""22""",10510069,"""T""",0.120914,0.108757,0.121446,0.648883,0.648883,1.54111,1.0
"""22""",10510070,"""G""",0.541647,0.093849,0.229348,0.135156,0.229348,4.360189,0.0
"""22""",10510071,"""T""",0.11857,0.059215,0.754159,0.068056,0.068056,14.693797,0.0
"""22""",10510072,"""A""",0.721582,0.080565,0.113504,0.084349,0.721582,1.385843,1.0
"""22""",10510073,"""T""",0.084662,0.097041,0.083532,0.734765,0.734765,1.36098,1.0
"""22""",10510074,"""T""",0.069883,0.069918,0.068198,0.792001,0.792001,1.262625,1.0


In [112]:
intervals = (
    V.with_columns(
        (pl.col("pos") - 1).alias("start"),
        pl.col("pos").alias("end")
    )
    .select(["chrom", "start", "end"])
    .to_pandas()
)
intervals

Unnamed: 0,chrom,start,end
0,22,10510064,10510065
1,22,10510065,10510066
2,22,10510066,10510067
3,22,10510067,10510068
4,22,10510068,10510069
...,...,...,...
39153628,22,50808399,50808400
39153629,22,50808400,50808401
39153630,22,50808401,50808402
39153631,22,50808402,50808403


In [114]:
for f, x in tqdm(feature_intervals.items()):
    V = V.with_columns(
        pl.from_pandas(bf.count_overlaps(intervals, x)["count"] > 0)
        .alias(f)
    )
V

100%|██████████| 3/3 [03:30<00:00, 70.15s/it] 


chrom,pos,ref,A,C,G,T,ref_prob,perplexity,accuracy,CDS,5' UTR,3' UTR
str,i64,str,f32,f32,f32,f32,f32,f32,f64,bool,bool,bool
"""22""",10510065,"""A""",0.677927,0.101803,0.118962,0.101308,0.677927,1.475085,1.0,false,false,false
"""22""",10510066,"""A""",0.678113,0.104195,0.11466,0.103032,0.678113,1.47468,1.0,false,false,false
"""22""",10510067,"""T""",0.107364,0.133747,0.101849,0.657041,0.657041,1.521976,1.0,false,false,false
"""22""",10510068,"""A""",0.680645,0.096985,0.123955,0.098416,0.680645,1.469195,1.0,false,false,false
"""22""",10510069,"""T""",0.120914,0.108757,0.121446,0.648883,0.648883,1.54111,1.0,false,false,false
"""22""",10510070,"""G""",0.541647,0.093849,0.229348,0.135156,0.229348,4.360189,0.0,false,false,false
"""22""",10510071,"""T""",0.11857,0.059215,0.754159,0.068056,0.068056,14.693797,0.0,false,false,false
"""22""",10510072,"""A""",0.721582,0.080565,0.113504,0.084349,0.721582,1.385843,1.0,false,false,false
"""22""",10510073,"""T""",0.084662,0.097041,0.083532,0.734765,0.734765,1.36098,1.0,false,false,false
"""22""",10510074,"""T""",0.069883,0.069918,0.068198,0.792001,0.792001,1.262625,1.0,false,false,false


In [117]:
V = V.with_columns(
    pl.when(pl.col("CDS")).then(pl.lit("CDS"))
    .when(pl.col("5' UTR")).then(pl.lit("5' UTR"))
    .when(pl.col("3' UTR")).then(pl.lit("3' UTR"))
    .otherwise(pl.lit("Non-exonic"))
    .alias("Region")
)
V["Region"].value_counts()

Region,count
str,u32
"""3' UTR""",785066
"""Non-exonic""",37423927
"""5' UTR""",189036
"""CDS""",755604


In [119]:
V.group_by("Region").agg(
    pl.mean("perplexity").alias("mean perplexity"),
    pl.median("perplexity").alias("median perplexity"),
    pl.mean("accuracy").alias("mean accuracy"),
).sort("mean accuracy", descending=True)

Region,mean perplexity,median perplexity,mean accuracy
str,f32,f32,f64
"""CDS""",4.877965,1.032106,0.915296
"""5' UTR""",3.203826,1.645802,0.836999
"""3' UTR""",3.116771,1.687933,0.821391
"""Non-exonic""",2.528364,1.977608,0.690292
