In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numba
import numba.typed

from pathlib import Path
from tqdm import trange, tqdm
import sys
import seaborn as sns
import scipy.stats
import os
import pysam
import pprint
import pickle
import joblib
import polars as pl

# This is needed before pybedtools to make sure bedtools is imported on sanger JupyterHub 
os.environ["PATH"] += ":" + os.path.join(sys.prefix, "bin")
import pybedtools

pd.set_option('display.max_rows', 1000)


In [3]:
sys.path.append(str(Path(os.getcwd()).parent))
from src import liftover, annotate, diagnostics, inference

In [4]:
# Autosomes chr1..chr22, then the full set including the sex chromosomes.
aut_chrom_names = ["chr" + str(n) for n in range(1, 23)]
chrom_names = [*aut_chrom_names, "chrX", "chrY"]

# Read the classified events

In [5]:
# Samples to analyse. Commented-out entries are deliberately excluded (reason given inline).
sample_ids = [
    "PD50477f",
    # "PD50508bf", -- ignore; merged two sampling dates just for phasing, but should be analyzed separately
    "PD50519d",
    # "PD47269d", -- don't use, not there
    "PD50508f",
    # "PD50511e", -- don't use, likely mixture
    "PD50523b",
    # "PD48473b", -- don't use, not there
    "PD50521b",
    "PD50508b",
    # "PD50521be", -- ignore; merged two sampling dates just for phasing, but should be analyzed separately
    "PD46180c",
    # "PD50502f", -- don't use, likely mixture
    "PD50521e",
    # "PD50511e_SS",  --- don't use
    "PD50489e",
]

# Root of the per-sample chromosome-length scaffold outputs (hifiasm 0.19.5 + RagTag, per the path).
scaffolds_dir = Path("/lustre/scratch126/casm/team154pc/sl17/03.sperm/02.results/01.read_alignment/01.ccs/04.hifiasm/02.hifiasm_0.19.5-r592/02.chromosome_length_scaffolds")

# One classified-reads parquet per (sample, autosome), ordered sample-major.
classified_paths = [
    scaffolds_dir / sample_id / "reads" / f"{chrom}_RagTag.certainty_0.95.classified_reads.parquet"
    for sample_id in sample_ids
    for chrom in aut_chrom_names
]
cls_df = pl.concat([pl.read_parquet(str(path)) for path in classified_paths])

In [6]:
# Keep only classified reads we trust: not flagged `has_common_transition`
# (presumably a transition shared with other reads -- confirm upstream) and at
# least 3x minimum coverage on both haplotypes between the transitions.
# Rows where any condition is null are dropped, as with filter().
is_trusty = (
    ~pl.col("has_common_transition")
    & (pl.col("min_coverage_between_transitions_hap1") >= 3)
    & (pl.col("min_coverage_between_transitions_hap2") >= 3)
)
trusty_cls_df = cls_df.filter(is_trusty)

# Utils

In [7]:
def hist_transitions(df, upto=2):
    """Fraction of reads with exactly 0, 1, ..., `upto` transitions.

    Parameters
    ----------
    df : DataFrame-like
        Anything indexable by ``"n_transitions"`` that yields an integer
        array-like (e.g. a frame returned by ``generate_call_set``).
    upto : int, default 2
        Largest transition count to report.

    Returns
    -------
    numpy.ndarray of length ``upto + 1``
        Entry k is the fraction of reads with exactly k transitions, among
        reads with at most ``upto`` transitions (entries sum to 1). Reads with
        more than ``upto`` transitions are excluded. If no read qualifies, the
        entries are NaN (numpy's density normalisation divides by zero).
    """
    # Edges at half-integers give every integer count its own unit-width bin.
    # With integer edges (np.arange(upto + 2)) numpy's last bin is the closed
    # interval [upto, upto + 1]. That silently counted reads with upto + 1
    # transitions as having `upto`.
    edges = np.arange(upto + 2) - 0.5
    # Unit-width bins, so density=True yields plain fractions.
    return np.histogram(df["n_transitions"], bins=edges, density=True)[0]

# Functions to build the call set for one or more samples

In [8]:
def generate_call_set(focal_sample_ids, take_every=1, bootstrap=False, cutoff=None, min_snps=0, seed=None):
    """Build the weighted per-read table used for maximum-likelihood inference.

    Two sources of reads are stacked for the requested samples:

    * trusted classified reads from the module-level ``trusty_cls_df``
      (excluding class ``"CNCO"``), each with ``weight`` 1;
    * candidate-read patterns for every autosome in ``aut_chrom_names``,
      thinned to every ``take_every``-th row and given ``weight`` =
      ``take_every``, so each kept row stands in for the rows skipped.

    Parameters
    ----------
    focal_sample_ids : list of str
        Sample IDs to include.
    take_every : int, default 1
        Keep every n-th row of each candidate-pattern file (1 keeps all rows).
    bootstrap : bool, default False
        If True, resample the combined rows uniformly with replacement, keeping
        the same number of rows.
    cutoff : number or None, default None
        If truthy, reads with exactly two transitions are kept only when the
        distance from the SNP at ``idx_transitions[0]`` to the SNP at
        ``idx_transitions[1] + 1`` is below ``cutoff``. Reads with any other
        number of transitions are always kept. Note that ``0`` disables the
        filter, like ``None``.
    min_snps : int, default 0
        Drop reads with fewer than this many entries in ``snp_positions_on_read``.
    seed : int or None, default None
        Seed for the bootstrap resample. ``None`` keeps the unseeded
        (non-reproducible) behaviour.

    Returns
    -------
    polars.DataFrame
        Columns ``read_length``, ``snp_positions_on_read``, ``idx_transitions``,
        ``sample_id``, ``read_name`` and ``weight``, plus the derived
        ``prob_inside``, ``n_transitions`` and ``rounded_prob_inside``.
    """
    certainty = "0.95"
    scaffolds_dir = Path("/lustre/scratch126/casm/team154pc/sl17/03.sperm/02.results/01.read_alignment/01.ccs/04.hifiasm/02.hifiasm_0.19.5-r592/02.chromosome_length_scaffolds/")
    columns = ["read_length", "snp_positions_on_read", "idx_transitions", "sample_id", "read_name"]

    # Lazily scan every (sample, autosome) candidate-pattern file and thin it
    # out before materialising, so only the kept rows are loaded.
    sampled_scans = [
        pl.scan_parquet(
            scaffolds_dir / f"{focal_sample_id}" / "reads"
            / f"{t2t_chrom}_RagTag.certainty_{certainty}.candidate_reads.patterns.parquet"
        ).gather_every(take_every)
        for focal_sample_id in focal_sample_ids
        for t2t_chrom in aut_chrom_names
    ]
    all_sampled_reads = pl.concat(sampled_scans).collect(streaming=True)

    # Stack the trusted events (weight 1) with the thinned candidates (weight
    # take_every, compensating for the subsampling above).
    both_df = pl.concat([
        (trusty_cls_df
            .filter(pl.col("class") != "CNCO")
            .filter(pl.col("sample_id").is_in(focal_sample_ids))
            .select(columns)
            .with_columns(weight=1)
        ),
        (all_sampled_reads
            .select(columns)
            .with_columns(weight=take_every)
        ),
    ])

    snps = pl.col("snp_positions_on_read")
    idx = pl.col("idx_transitions")

    # Require a minimal number of SNPs on the read.
    both_df = both_df.filter(snps.list.len() >= min_snps)

    if cutoff:
        # Span from the SNP at the first transition to the SNP just after the
        # second one -- presumably an upper bound on the tract length.
        span = snps.list.get(idx.list.get(1) + 1) - snps.list.get(idx.list.get(0))
        both_df = both_df.filter(
            pl.when(idx.list.len() != 2).then(pl.lit(True)).otherwise(span < cutoff)
        )

    if bootstrap:
        both_df = both_df.sample(n=len(both_df), with_replacement=True, seed=seed)

    read_length = pl.col("read_length")
    return (both_df
        .with_columns(
            # 1 minus the two flanks outside [first SNP, last SNP], as a fraction
            # of the read (assumes SNP positions are offsets from the read start).
            prob_inside=1 - (snps.list.get(0) + read_length - snps.list.get(-1)) / read_length,
            n_transitions=idx.list.len(),
        )
        .with_columns(
            # Floor to a multiple of 0.1. Scaling by 10 before flooring avoids
            # float floor-division error (0.3 // 0.1 == 2.0 put values on a bin
            # edge into the bin below). Labels stay k * 0.1, as before.
            rounded_prob_inside=(pl.col("prob_inside") * 10).floor() * 0.1,
        )
    )


# Full inference on all samples

In [15]:
%%time
# Thin candidate reads 1-in-100; each kept candidate read carries weight 100.
all_df = generate_call_set(focal_sample_ids=sample_ids, take_every=100)

CPU times: user 40.3 s, sys: 4.5 s, total: 44.8 s
Wall time: 12.5 s


In [25]:
# Per-read inputs as numpy arrays, then the (lower, upper) pairs.
# NOTE(review): the five pairs look like bounds for the 5-element parameter
# vector printed as "Current:" -- confirm their meaning in `inference`.
res = inference.maximum_likelihood_all_reads(
    *[
        all_df[column].to_numpy()
        for column in ("read_length", "snp_positions_on_read", "idx_transitions", "weight")
    ],
    [0, 0.5],
    [0.8, 1],
    [1, 400],
    [1, 8000],
    [1e-7, 1e-7],
)

Current:	[   0.25     0.945  200.5   4000.5      0.   ]
Current:	[   0.25     0.945  200.5   4000.5      0.   ]
Current:	[   0.25     0.945  200.5   4000.5      0.   ]
Current:	[   0.245    0.974  188.711 4041.785    0.   ]
Current:	[   0.243    1.     181.637 4066.556    0.   ]
Current:	[   0.243    1.     181.637 4066.556    0.   ]
Current:	[   0.243    1.     181.637 4066.556    0.   ]
Current:	[   0.243    1.     181.637 4066.556    0.   ]
Current:	[   0.228    1.     162.147 4016.303    0.   ]
Current:	[   0.228    1.     162.147 4016.303    0.   ]
Current:	[   0.228    1.     162.147 4016.303    0.   ]
Current:	[   0.23     1.     138.434 4269.247    0.   ]
Current:	[   0.212    0.999  126.082 4098.992    0.   ]
Current:	[   0.212    0.999  126.082 4098.992    0.   ]
Current:	[   0.194    0.997   84.511 4499.954    0.   ]
Current:	[   0.194    0.997   84.511 4499.954    0.   ]
Current:	[   0.174    0.996   12.178 4714.51     0.   ]
Current:	[   0.167    0.996   35.378 4406.936   

In [22]:
# Parameter vector found by the fit stored in `res`.
# NOTE(review): the execution count (In [22]) shows this ran before the fitting
# cell (In [25]), whose printed progress is cut off; the value displayed comes
# from an earlier fit. Re-run after the fit completes.
res.x

array([1.20272546e-01, 9.93070706e-01, 2.39709390e+01, 1.10317996e+03,
       1.00000000e-07])

In [23]:
# Full optimiser result. The `final_simplex` field in the output suggests a
# Nelder-Mead fit returning a scipy-style OptimizeResult -- confirm in `inference`.
res

       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 35907.18314695208
             x: [ 1.203e-01  9.931e-01  2.397e+01  1.103e+03  1.000e-07]
           nit: 281
          nfev: 484
 final_simplex: (array([[ 1.203e-01,  9.931e-01, ...,  1.103e+03,
                         1.000e-07],
                       [ 1.203e-01,  9.931e-01, ...,  1.103e+03,
                         1.000e-07],
                       ...,
                       [ 1.203e-01,  9.931e-01, ...,  1.103e+03,
                         1.000e-07],
                       [ 1.203e-01,  9.931e-01, ...,  1.103e+03,
                         1.000e-07]]), array([ 3.591e+04,  3.591e+04,  3.591e+04,  3.591e+04,
                        3.591e+04,  3.591e+04]))