In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.rcParams['pdf.use14corefonts'] = True

import numba
import numba.typed

from pathlib import Path
import tqdm
import sys
import seaborn as sns
import scipy.stats
import os
import joblib
import io
import itertools
import warnings

from brokenaxes import brokenaxes


import polars as pl
pl.Config.set_tbl_rows(100);
pl.Config.set_fmt_str_lengths(50);


In [3]:
# Append the parent of the current working directory to sys.path; presumably this
# is what makes the `src` package imported below resolvable.
# NOTE(review): the result depends on the directory the kernel was started in — confirm.
sys.path.append(str(Path(os.getcwd()).parent))
from src import liftover, annotate, diagnostics, inference
# NOTE(review): the star import presumably supplies names that are used later without
# a visible definition (e.g. aut_chrom_names, rahbari_sample_ids) — confirm in src/IDs.py.
from src.IDs import *


# Create callset to infer from

## Blood

In [4]:
# Sample IDs whose reads are loaded for the blood data set below.
ceph_good_samples = [
    "NA12878",
    "NA12879",
    "NA12881",
    "NA12882",
    "NA12886",
    "NA12892",
    "200084",
    "200085",
    "200102",
    "200104",
]

In [5]:
%%time
# Scan one annotated-reads parquet file per (sample, chromosome) pair and concatenate
# the scans. pl.scan_parquet is lazy, so no read data is loaded in this cell.
blood_scans = []
for sample_id in tqdm.tqdm(ceph_good_samples):
    for chrom in aut_chrom_names:
        blood_scans.append(
            pl.scan_parquet(
                f"/lustre/scratch122/tol/projects/sperm/results/CEPH_20241119/read_analysis/{sample_id}/{sample_id}/reads/{chrom}/all_reads_structure_annotated.parquet"
            )
        )

blood_reads_df = pl.concat(blood_scans)

100%|██████████| 10/10 [00:00<00:00, 390.77it/s]

CPU times: user 10.3 ms, sys: 11.7 ms, total: 22.1 ms
Wall time: 44.9 ms





In [6]:
%%time
# Build the blood call set from the scanned reads. The keyword settings are the same
# ones used for the sperm call set further down. The recorded run took about
# 16 minutes of wall time.
blood_callset_df = inference.generate_call_set(
    blood_reads_df,
    ceph_good_samples,
    take_every=10,
    min_snps=3,
    sample_every=1,
)

CPU times: user 15min 37s, sys: 7min 53s, total: 23min 31s
Wall time: 16min 23s


In [7]:
# Show the (rows, columns) shape of the blood call set as a quick size check.
blood_callset_df.shape

(4318196, 19)

## Sperm

In [8]:
%%time
# Scan one annotated-reads parquet file per (sample, chromosome) pair and concatenate
# the scans. pl.scan_parquet is lazy, so no read data is loaded in this cell.
sperm_scans = []
for focal_sample_id in tqdm.tqdm(rahbari_sample_ids):
    for chrom in aut_chrom_names:
        sperm_scans.append(
            pl.scan_parquet(
                f"/lustre/scratch126/casm/team154pc/sl17/03.sperm/02.results/01.read_alignment/01.ccs/04.hifiasm/02.hifiasm_0.19.5-r592/02.chromosome_length_scaffolds/{focal_sample_id}/reads/{chrom}_RagTag.certainty_0.95.all_reads_structure_annotated.parquet"
            )
        )

sperm_reads_df = pl.concat(sperm_scans)

100%|██████████| 9/9 [00:00<00:00, 3423.92it/s]

CPU times: user 5.81 ms, sys: 0 ns, total: 5.81 ms
Wall time: 5.25 ms





In [9]:
%%time
# Build the sperm call set from the scanned reads. The keyword settings are the same
# ones used for the blood call set above. The recorded run took about 13 minutes of
# wall time.
sperm_callset_df = inference.generate_call_set(
    sperm_reads_df,
    rahbari_sample_ids,
    take_every=10,
    min_snps=3,
    sample_every=1,
)

CPU times: user 7min 58s, sys: 4min 17s, total: 12min 16s
Wall time: 13min 2s


In [10]:
# Show the (rows, columns) shape of the sperm call set as a quick size check.
sperm_callset_df.shape

(1969931, 19)

# Run the joint maximum-likelihood inference

In [14]:
%%time
# Constants that were previously repeated inline as `1e-8` and `5000`.
# NOTE(review): 1e-8 presumably rescales bp distances by a per-bp rate — confirm.
BP_SCALE = 1e-8
# NOTE(review): the same 5000 is used for `read_margin_in_bp` and (scaled) for the two
# constant per-read lists below; presumably they describe the same margin — confirm
# against the signature of inference.maximum_likelihood_all_reads_joint.
READ_MARGIN_BP = 5000


def _joint_inference_inputs(callset_df):
    """Build the seven per-read positional inputs taken from one call set.

    The tuple is ordered exactly as the values are passed for each tissue:
    ``read_length``, ``high_quality_snp_positions``,
    ``high_quality_snps_idx_transitions``, ``between_high_quality_snps_bp``
    multiplied by ``BP_SCALE``, two numba typed lists that each hold
    ``READ_MARGIN_BP * BP_SCALE`` once per read, and ``weight``.
    """
    n_reads = len(callset_df)
    scaled_margin = READ_MARGIN_BP * BP_SCALE
    return (
        callset_df["read_length"].to_numpy(),
        callset_df["high_quality_snp_positions"].to_numpy(),
        callset_df["high_quality_snps_idx_transitions"].to_numpy(),
        callset_df["between_high_quality_snps_bp"].to_numpy() * BP_SCALE,
        numba.typed.List(np.repeat(scaled_margin, n_reads)),
        numba.typed.List(np.repeat(scaled_margin, n_reads)),
        callset_df["weight"].to_numpy(),
    )


# Bind the result to a name so it stays available after this long-running cell.
# Blood inputs are passed first, then sperm inputs, as in the original call.
# NOTE: in the recorded output the optimiser stopped at the function-evaluation limit
# (success: False), so those estimates are not from a converged fit.
joint_fit_result = inference.maximum_likelihood_all_reads_joint(
    *_joint_inference_inputs(blood_callset_df),
    *_joint_inference_inputs(sperm_callset_df),

    q_range_sperm = (0.01, 0.5),
    q_range_blood = (1e-10, 0.5),

    m_range_sperm = (1e-10, 1-1e-10),
    m_range_blood = (1e-10, 1-1e-10),

    GC_tract_mean_range = (1, 1000),
    GC_tract_mean2_range = (100, 10000),

    # Identical lower and upper bound; the recorded estimate for this entry is 1.0.
    prob_factor_range_sperm = (1, 1),
    prob_factor_range_blood = (1e-10, 1),

    read_margin_in_bp = READ_MARGIN_BP,

    # Optional explicit starting point, currently disabled:
    # x0 = [
    #     0.1, 1e-4, 
    #     0.98, 0.01,     
    #     30, 1000, 
    #     1.0, 1e-2,
    # ],
)

# Keep the result as the cell's last expression so it is still displayed.
joint_fit_result



Current:	[   0.255    0.25     0.5      0.5    500.5   5050.       1.       0.5  ]	59389.51855748413
Current:	[   0.268    0.25     0.5      0.5    500.5   5050.       1.       0.5  ]	58767.48130098777
Current:	[   0.255    0.263    0.5      0.5    500.5   5050.       1.       0.5  ]	59361.036058592894
Current:	[   0.255    0.25     0.525    0.5    500.5   5050.       1.       0.5  ]	59099.303263054826
Current:	[   0.255    0.25     0.5      0.525  500.5   5050.       1.       0.5  ]	59372.99082426613
Current:	[   0.255    0.25     0.5      0.5    525.525 5050.       1.       0.5  ]	59490.357789591435
Current:	[   0.255    0.25     0.5      0.5    500.5   5302.5      1.       0.5  ]	59482.722159094366
Current:	[   0.255    0.25     0.5      0.5    500.5   5050.       1.       0.5  ]	59389.51855748413
Current:	[   0.255    0.25     0.5      0.5    500.5   5050.       1.       0.525]	59417.07715617629
Current:	[   0.258    0.253    0.506    0.506  475.475 5113.125    1.       0.506]	5907

       message: Maximum number of function evaluations has been exceeded.
       success: False
        status: 1
           fun: 49941.714078042685
             x: [ 3.311e-01  2.207e-01  8.508e-01  8.251e-01  5.175e+01
                  2.478e+02  1.000e+00  9.868e-01]
           nit: 1102
          nfev: 1600
 final_simplex: (array([[ 3.311e-01,  2.207e-01, ...,  1.000e+00,
                         9.868e-01],
                       [ 3.290e-01,  2.162e-01, ...,  1.000e+00,
                         9.892e-01],
                       ...,
                       [ 3.251e-01,  2.240e-01, ...,  1.000e+00,
                         1.000e+00],
                       [ 3.282e-01,  2.200e-01, ...,  1.000e+00,
                         9.904e-01]]), array([ 4.994e+04,  4.994e+04,  4.994e+04,  4.994e+04,
                        4.994e+04,  4.994e+04,  4.994e+04,  4.994e+04,
                        4.994e+04]))

# Load a saved inference result

In [4]:
import glob
import pickle

# Pickled result of a joint-inference run; the file name records
# sample_every=1, bootstrap=0, rep=0.
joint_inference_pickle_path = "/lustre/scratch126/casm/team154pc/sl17/03.sperm/02.results/08.tract_length/joint_inference.sample_every=1.bootstrap=0.rep=0.pcl"

# NOTE(security): pickle.load can execute arbitrary code while unpickling; only load
# files that come from a trusted source.
# The `with` block closes the file handle as soon as loading finishes; the original
# `pickle.load(open(...))` never closed it explicitly.
with open(joint_inference_pickle_path, "rb") as pickle_file:
    res = pickle.load(pickle_file)

In [5]:
# Display the unpickled result object.
res

       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 42617.66938108271
             x: [ 8.958e-02  8.744e-03  8.830e-01  1.117e-02  1.000e+01
                  1.427e+02  1.000e+00  1.393e-02]
           nit: 845
          nfev: 1225
 final_simplex: (array([[ 8.958e-02,  8.744e-03, ...,  1.000e+00,
                         1.393e-02],
                       [ 8.957e-02,  8.747e-03, ...,  1.000e+00,
                         1.393e-02],
                       ...,
                       [ 8.957e-02,  8.745e-03, ...,  1.000e+00,
                         1.393e-02],
                       [ 8.958e-02,  8.744e-03, ...,  1.000e+00,
                         1.393e-02]]), array([ 4.262e+04,  4.262e+04,  4.262e+04,  4.262e+04,
                        4.262e+04,  4.262e+04,  4.262e+04,  4.262e+04,
                        4.262e+04]))

In [6]:
# Print each entry of res.x next to a label.
# NOTE(review): the labels reuse the names of the `*_range` keyword arguments of the
# joint fit, although the printed values are single numbers; the label order is
# presumably the order of the optimiser's parameter vector — confirm.
labels = [
    "q_range_sperm",
    "q_range_blood",
    "m_range_sperm",
    "m_range_blood",
    "GC_tract_mean_range",
    "GC_tract_mean2_range",
    "prob_factor_range_sperm",
    "prob_factor_range_blood",
]

# zip stops at the shorter of the two sequences.
for label, estimate in zip(labels, res.x):
    print(label, estimate)

q_range_sperm 0.08957949375402405
q_range_blood 0.008744323474881258
m_range_sperm 0.8829616052536022
m_range_blood 0.01117324608328251
GC_tract_mean_range 10.0
GC_tract_mean2_range 142.69203547322655
prob_factor_range_sperm 1.0
prob_factor_range_blood 0.01393081588315273
