In [None]:
import torch 
from deepchopper import  remove_intervals_and_keep_left, smooth_label_region, summary_predict, get_label_region

In [None]:
from collections import Counter
from dataclasses import dataclass
from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from deepchopper.utils import alignment_predict, highlight_target, highlight_targets

**TODO**:

- [ ] Summarize how many reads are chopped vs. not chopped
- [ ] Summarize whether chops are internal or terminal
- [ ] Count reads whose chop has only one interval
- [ ] Summarize chop interval sizes

In [None]:
# cnn_data_folder  = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_16-17-16/predicts/0/0.pt"
# hyena_data_folder = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-30-28/predicts/0/0.pt"

In [None]:
def id2seq(ids: list[int]):
    """Decode model token ids back into a DNA string.

    Token ids 7, 8, 9, 10, 11 map to "A", "C", "G", "T", "N"; any other id raises ``KeyError``.
    """
    table = {7: "A", 8: "C", 9: "G", 10: "T", 11: "N"}
    return "".join(table[c] for c in ids)


def majority_voting(labels, window_size):
    """Smooth per-base labels by majority vote over a centered sliding window.

    An even ``window_size`` is bumped to the next odd value so every window has a central token.
    Windows that run past the end of the sequence are padded with label ``1`` (so a tail like
    "11100" becomes "11111"), which biases the end of the read toward label 1. Ties are resolved
    by ``max`` over the window's distinct values, i.e. the first maximal value in set order.

    Args:
        labels: per-base labels (list, tuple or 1-D numpy array).
        window_size: size of the voting window.

    Returns:
        list of smoothed labels, same length as ``labels``.
    """
    if window_size % 2 == 0:
        window_size += 1
    half_window = window_size // 2

    # Work on a list so slices support .count() and list padding even for numpy input.
    labels = list(labels)
    n_labels = len(labels)

    smoothed_labels = []
    for i in range(n_labels):
        start = max(0, i - half_window)
        end = min(n_labels, i + half_window + 1)
        window = labels[start:end]

        # Pad the truncated right edge with 1s: 11100 will be 11111
        if end == n_labels:
            window += [1] * (i + half_window + 1 - end)

        smoothed_labels.append(max(set(window), key=window.count))

    return smoothed_labels


def ascii_values_to_string(ascii_values):
    """Join a sequence of ASCII code points into a string."""
    return "".join(chr(value) for value in ascii_values)


def convert_id_str(ids):
    """Lazily decode encoded id rows of the form ``[length, flag, *ascii_codes, ...]`` into strings."""
    return (ascii_values_to_string(i[2 : i[0] + 2]) for i in ids)


@dataclass
class FqRecord:
    """A single FASTQ entry: header line, bases and per-base quality string."""

    id: str  # full header line, written verbatim (callers include the leading "@")
    seq: str
    qual: str

    def to_str(self):
        """Render the entry as the four FASTQ lines, without a trailing newline."""
        return "{}\n{}\n+\n{}".format(self.id, self.seq, self.qual)


@dataclass
class Predict:
    prediction: list[int]
    seq: str
    id: str
    is_trucation: bool
    qual: str | None = None

    def is_terminal(self, *, threshold=10, smooth=False, window_size=None) -> bool:
        pass

    @property
    def is_polya(self) -> bool:
        pass

    @property
    def prediction_region(self):
        return get_label_region(self.prediction)

    def smooth_prediction_region(self, window_size):
        return get_label_region(self.smooth_label(window_size))

    def smooth_label(self, window_size):
        return majority_voting(self.prediction, window_size)

    @property
    def seq_len(self):
        return len(self.true_seq)

    def fetch_qual(self, fq_records):
        self.qual = fq_records[self.id].qual

    @property
    def qual_array(self):
        if self.qual is None:
            raise ValueError("no qual, please fetch qual first")
        return [ord(c) - 33 for c in list(self.qual)]

    def vis_qual_static(self, start: int | None = None, end: int | None = None, figure_size=(20, 1)):
        if self.qual is None:
            raise ValueError("no qual, please fetch qual first")

        start = 0 if start is None else start
        end = len(self.seq) if end is None else end

        qual = np.array([ord(c) - 33 for c in list(self.qual[start:end])]).reshape(1, -1)
        seq = list(self.seq[start:end])

        # Creating the heatmap
        fig, ax = plt.subplots(figsize=figure_size)  # Set a wide figure to accommodate the sequence
        cax = ax.imshow(qual, aspect="auto", cmap="viridis")
        cbar = plt.colorbar(cax, ax=ax, orientation="vertical")
        cbar.set_label("Value")
        # Setting up the sequence as x-axis labels
        ax.set_xticks(np.arange(len(seq)))
        ax.set_xticklabels(seq, rotation=90)  # Rotate labels for better readability
        # Remove y-axis labels as there's only one row
        ax.set_yticks([])
        ax.set_title(f"{self.id}: {start}-{end}")
        plt.close()

    def print_seq(self, *, smooth=False, smooth_window_size: int | None = None):
        regions = self.chop_intervals(smooth=smooth, smooth_window_size=smooth_window_size)

        print(f"id     : {self.id}")
        print(f"regions: {regions}")
        highlight_targets(self.seq, regions)

    def compare_smooth(self, smooth_window_size: int):
        regions = self.prediction_region

        window_size = smooth_window_size
        smooth_regions = self.smooth_prediction_region(window_size)

        print(f"id      : {self.id}")
        print(f"original: {regions}")
        print(f"smooth  : {smooth_regions}")
        highlight_targets(self.seq, regions)
        highlight_targets(self.seq, smooth_regions)

    def chop_intervals(self, *, smooth: bool, smooth_window_size: int | None) -> list[tuple[int, int]]:
        if smooth:
            if smooth_window_size is None:
                raise ValueError("please provide window size")
            window_size = smooth_window_size
            regions = self.smooth_prediction_region(window_size)
        else:
            regions = self.prediction_region
        return regions

    def to_fqs_record(self, intervals: list[tuple[int, int]]):
        if self.qual is None:
            raise ValueError("no qual, please fetch qual first")

        assert len(self.qual) == len(self.seq)

        seqs, saved_intervals = remove_intervals_and_keep_left(self.seq, intervals)
        quals, saved_intervals = remove_intervals_and_keep_left(self.qual, intervals)

        assert len(seqs) == len(quals)
        for ind, (seq, qual) in enumerate(zip(seqs, quals, strict=True)):
            record_id = f"@{self.id}|{saved_intervals[ind][0], saved_intervals[ind][1]}"
            yield FqRecord(id=record_id, seq=seq, qual=qual)

    def smooth_and_select_intervals(
        self,
        smooth_window_size: int,
        min_interval_length: int,
        approved_interval_nums: int = 1,
    ) -> list[tuple[int, int]]:
        chop_intervals = self.chop_intervals(smooth=True, smooth_window_size=smooth_window_size)

        results = []
        for interval in chop_intervals:
            if interval[1] - interval[0] > min_interval_length:
                results.append(interval)

        if len(results) > approved_interval_nums:
            return []

        return results


class BatchPredict:
    def __init__(self, batch_prediction, smooth_window_size=9):
        self.smooth_window_size = smooth_window_size
        self.data_path = batch_prediction

        self.data = torch.load(batch_prediction)
        self.batch_size = self.data["seq"].shape[0]
        self.batch_predicts = []

        predictions, _labels = summary_predict(
            self.data["prediction"].argmax(-1).numpy(), self.data["target"].numpy(), -100
        )
        seqs, _labels = summary_predict(self.data["seq"].numpy(), self.data["target"].numpy(), -100)

        for idx in range(len(predictions)):
            self.batch_predicts.append(
                Predict(
                    prediction=predictions[idx],
                   seq=id2seq(seqs[idx]),
                    id=ascii_values_to_string(self.data["id"][idx][2 : self.data["id"][idx][0] + 2]),
                    is_trucation=bool(self.data["id"][idx][1]),
                )
            )

    def __repr__(self):
        return f"{__class__.__name__}(batch_size={self.batch_size}, data={self.data_path})"

    def print_all_seq(self, *, smooth=False, smooth_window_size: int | None = None):
        for predict in self.batch_predicts:
            if smooth:
                window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
                regions = predict.smooth_prediction_region(window_size)
            else:
                regions = predict.prediction_region

            print(f"id     : {predict.id}")
            print(f"regions: {regions}")
            highlight_targets(predict.seq, regions)

    def compare_smooth(self, smooth_window_size: int | None = None):
        for predict in self.batch_predicts:
            regions = predict.prediction_region

            window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
            smooth_regions = predict.smooth_prediction_region(window_size)

            print(f"id      : {predict.id}")
            print(f"original: {regions}")
            print(f"smooth  : {smooth_regions}")
            highlight_targets(predict.seq, regions)
            highlight_targets(predict.seq, smooth_regions)

    def __len__(self):
        return self.batch_size

    def __getitem__(self, idx):
        return self.batch_predicts[idx]

    def __iter__(self):
        return iter(self.batch_predicts)

    def __next__(self):
        return next(self)

In [None]:
def gather_all_predicts(chunk_result_path: list[Path], smooth_window_size: int = 9):
    """Lazily yield every ``Predict`` from all ``*.pt`` batch files in each chunk directory.

    Args:
        chunk_result_path: directories (``Path`` or ``str``) containing prediction batches.
        smooth_window_size: default smoothing window stored on each ``BatchPredict``. The
            default matches ``BatchPredict``'s; previously the type ``int`` itself was the default.

    Batches that fail to load are reported and skipped, so one bad file does not stop the scan.
    """
    for chunk_path in map(Path, chunk_result_path):
        print(f"load chunk data from {chunk_path}")
        # sorted(): glob order is filesystem-dependent; a fixed order makes "first N" slices reproducible.
        for batch in sorted(chunk_path.glob("*.pt")):
            try:
                result = BatchPredict(batch, smooth_window_size=smooth_window_size)
            except Exception as e:  # deliberate best effort: report and continue
                print(f"fail to load batch {batch}: {e}")
            else:
                yield from result

In [None]:
from needletail import parse_fastx_file, NeedletailError, reverse_complement, normalize_seq

def collect_fq_records(file: Path):
    """Parse a FASTA/FASTQ file with needletail into a ``{record.id: record}`` dict.

    Args:
        file: path to the file (``Path`` or ``str``).

    Returns:
        dict of records keyed by id; a later record with the same id replaces an earlier one.
        On a ``NeedletailError`` a message naming the file is printed and the records parsed
        so far are returned (best effort).
    """
    file = Path(file)

    result = {}
    try:
        for record in parse_fastx_file(file.as_posix()):
            result[record.id] = record
    except NeedletailError:
        print(f"Invalid Fastq file: {file}")

    return result

In [None]:
import pysam
def collect_sam_records(file: Path):
    """Map read name -> primary alignment record for every read returned by ``fetch()``.

    Args:
        file: path (``Path`` or ``str``) to an indexed BAM file; ``fetch()`` without a region
            relies on the index.

    Returns:
        dict mapping ``query_name`` to its primary ``pysam.AlignedSegment``. Secondary and
        supplementary records are skipped: they share the read name and would otherwise
        overwrite the primary record. Their CIGARs may also be hard-clipped and so do not
        describe the full read that the soft-clip analysis compares against. Chimeric partners
        stay visible through the primary record's ``SA`` tag.
    """
    if not isinstance(file, Path):
        file = Path(file)

    result = {}
    # The context manager closes the file handle; the collected AlignedSegments remain usable.
    with pysam.AlignmentFile(file.as_posix(), "rb") as samfile:
        for read in samfile.fetch():
            if read.is_secondary or read.is_supplementary:
                continue
            result[read.query_name] = read

    return result

In [None]:
# Alignment records for the VCaP reads, keyed by read name (see collect_sam_records).
# NOTE(review): absolute cluster path — consider a configurable DATA_DIR constant.
sam_records = collect_sam_records("/projects/b1171/ylk4626/project/DeepChopper/data/eval/real_data/dorado_without_trim_fqs/VCaP.bam")

In [None]:
# NOTE(review): test_chunks is not referenced anywhere else in this notebook.
test_chunks = [Path("/projects/b1171/ylk4626/project/DeepChopper/tests/data/eval/chunk0"),
               Path("/projects/b1171/ylk4626/project/DeepChopper/tests/data/eval/chunk1")]       

In [None]:
# FASTQ records for the same reads, keyed by read id; used below to attach quality strings to predictions.
fq_records = collect_fq_records(Path("/projects/b1171/ylk4626/project/DeepChopper/data/eval/real_data/dorado_without_trim_fqs/VCaP.fastq"))

In [None]:
len(fq_records)

In [None]:
# cnn checkpoint:   /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-07_12-01-37/checkpoints/epoch_036_f1_0.9914.ckpt
# hyena checkpoint: /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-09_20-13-03/checkpoints/epoch_007_f1_0.9931.ckpt


## K562

# data/eval/real_data/dorado_without_trim_fqs/K562.fastq_chunks/K562.fastq_0.parquet

# hyena result: chunk 0 1 2 3 4 5 6
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-50-48
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-30-28
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-01-16
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-17-21
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-28-45
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-39-48
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-51-10


# hyena_results = [
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-50-48/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-30-28/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-01-16/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-17-21/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-28-45/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-39-48/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_22-51-10/predicts/0"),
# ]
# # cnn result

# cnn_results = [
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_11-54-02/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_11-59-59/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_12-07-39/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_12-15-06/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_12-20-36/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_12-26-11/predicts/0"),
#     Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-14_12-31-16/predicts/0"),
# ]

In [None]:
## VCaP

# hyena 
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_15-59-13
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_16-32-10
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_22-42-00
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_23-23-31
# /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-16_00-26-11
#

# Prediction output directories of the VCaP hyena evaluation runs.
# Not yet included (kept commented out in the original list):
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-16_11-18-42
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-16_12-10-35
hyena_results = [
    Path(result_dir)
    for result_dir in (
        "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_15-59-13/predicts/0",
        "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_16-32-10/predicts/0",
        "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_22-42-00/predicts/0",
        "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-15_23-23-31/predicts/0",
        "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-16_00-26-11/predicts/0",
        "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-16_10-38-12/predicts/0",
    )
]

In [None]:
data_points = 20000
all_predicts = []
idx = 0
for p in gather_all_predicts(hyena_results[:1], smooth_window_size=11):
    idx+=1
    if idx <= data_points:
        all_predicts.append(p)
    else:
        break

In [None]:
len(all_predicts)

In [None]:
# Attach each read's quality string from the FASTQ (fetch_qual raises KeyError for an unknown read id).
for p in all_predicts:
    p.fetch_qual(fq_records)

In [None]:
# Split reads by (a) having any raw predicted interval and (b) having any smoothed interval
# (window 11, length > 5; approved_interval_nums=999 is effectively no cap) that survives filtering.
all_predicts_with_chop  = []
all_predicts_smooth_with_chop = []
smooth_intervals = {}  # read id -> smoothed, filtered intervals

for p in all_predicts:
    if len(p.prediction_region) > 0:
        all_predicts_with_chop.append(p)
        
    smooth_regions = p.smooth_and_select_intervals(smooth_window_size=11, min_interval_length=5, approved_interval_nums=999)

    if len(smooth_regions) > 0:
        all_predicts_smooth_with_chop.append(p)

    smooth_intervals[p.id] = smooth_regions

In [None]:
len(all_predicts_with_chop)

In [None]:
len(all_predicts_smooth_with_chop)

In [None]:
internal_chop_predicts = [] 

for p in all_predicts_smooth_with_chop:        
    reg = smooth_intervals[p.id]
    for r in reg:
        if r[1] / len(p.seq) < 0.7:
            internal_chop_predicts.append(p)
            highlight_targets(p.seq, reg)
            break
    # p.compare_smooth(11)

In [None]:
len(internal_chop_predicts)

In [None]:
# Raw vs. smoothed intervals for the first 100 reads that keep a smoothed chop.
for p in all_predicts_smooth_with_chop[:100]:
    reg = smooth_intervals[p.id]
    oreg = p.prediction_region

    print(f"original: {oreg}")
    print(f"smooth: {reg}")

    highlight_targets(p.seq, oreg)
    highlight_targets(p.seq, reg)

In [None]:
# Quality heatmap (±10 bp) around the first smoothed chop of the first read that has one;
# indexing [0] into a read without a smoothed chop would raise IndexError.
example_predict = next(p for p in all_predicts if len(p.smooth_prediction_region(11)) > 0)
example_region = example_predict.smooth_prediction_region(11)[0]
example_predict.vis_qual_static(
    max(example_region[0] - 10, 0),  # clamp: a negative start would wrap to the read's tail
    min(example_region[1] + 10, len(example_predict.seq)),
)

In [None]:
import seaborn as sns
import numpy as np

def get_data_hist_for_num_of_intervals(all_predicts, * ,smooth:bool = False,
    smooth_window_size:int|None= None, 
    min_interval_length: int| None= None, 
    approved_interval_nums: int|None = None):

    number = [ ]
    for predict in all_predicts:
        if smooth:
            number.append(len(predict.smooth_and_select_intervals(smooth_window_size=smooth_window_size, 
                                            min_interval_length=min_interval_length, 
                                            approved_interval_nums=approved_interval_nums)))
        else:
            number.append(len(predict.prediction_region))

    return number

In [None]:
# Per chopped read: number of raw predicted intervals vs. number of smoothed, filtered intervals.
plot_data = [len(p.prediction_region) for p in all_predicts_with_chop]
smooth_plot_data = [len(smooth_intervals[p.id]) for p in all_predicts_with_chop]
    

In [None]:
# data  = get_data_hist_for_num_of_intervals(cnn_all_predicts)
# smooth_data= get_data_hist_for_num_of_intervals(cnn_all_predicts, smooth=True, smooth_window_size=9, min_interval_length=1, approved_interval_nums=999)

In [None]:
def vis_hist_for_num_of_intervals(data, figsize=(10, 6), title=None, xlabel="Value", ax=None):
    """Plot a discrete histogram of ``data`` with a KDE overlay.

    Args:
        data: 1-D sequence of values (e.g. interval counts or chop lengths).
        figsize: size of the new figure; ignored when ``ax`` is given.
        title: axes title (``None`` leaves it empty).
        xlabel: x-axis label (defaults to the previous fixed label "Value").
        ax: existing matplotlib Axes to draw on; a new figure is created when ``None``.

    Returns:
        The Axes the histogram was drawn on.
    """
    if ax is None:
        _fig, ax = plt.subplots(figsize=figsize)
    sns.histplot(data, kde=True, color="green", line_kws={"linewidth": 2}, discrete=True, ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency")
    return ax

In [None]:
vis_hist_for_num_of_intervals(plot_data, title="Original Intervals")

In [None]:
vis_hist_for_num_of_intervals(smooth_plot_data, title="Smooth Intervals")

In [None]:
vis_hist_for_num_of_intervals(np.array(plot_data) - np.array(smooth_plot_data), title="Diff with/withou Smooth")

In [None]:
idx = 0
ploas = 0
only_one = 0
for p in all_predicts_with_chop:
    smooth_region = smooth_intervals[p.id]
    
    if len(smooth_region) == 1:
        only_one +=1 
        ploa_counter = Counter(p.seq[smooth_region[0][0] -10: smooth_region[0][0]])
        if ploa_counter.get("A", 0 ) >= 3:
            ploas +=1
        else:
            print(p.id)
            highlight_targets(p.seq, smooth_region)
        if smooth_region[0][1] / len(p.seq) < 0.7:
            idx +=1
            # print(p.id)
            # highlight_targets(p.true_seq, smooth_region)
                

In [None]:
idx

In [None]:
ploas 

In [None]:
only_one

In [None]:
plot_region_size_data = []
for p in all_predicts_with_chop:
    smooth_region = smooth_intervals[p.id]
    
    if len(smooth_region) == 1:
        size = smooth_region[0][1] - smooth_region[0][0]
        plot_region_size_data.append(size)

In [None]:
vis_hist_for_num_of_intervals(plot_region_size_data, title="Chop Size of clean data (smooth)")

In [None]:
plot_region_size_data = []
for p in all_predicts_with_chop:
    smooth_region = smooth_intervals[p.id]
    
    if len(smooth_region) == 1:
        size = smooth_region[0][1] - smooth_region[0][0]
        plot_region_size_data.append(size)

In [None]:
more_than_one_smooth_preditions = []

for p in all_predicts_with_chop:
    smooth_region = smooth_intervals[p.id]
    
    if len(smooth_region) > 1:
        more_than_one_smooth_preditions.append(p)

In [None]:
len(more_than_one_smooth_preditions)

In [None]:
import re 
from textwrap import wrap 

from collections import defaultdict 

def verify_result_with_sam_records(predicts, smooth_intervals, sam_records, interval_threshold: float = 0.7, overlap_threshold: int = 20):
    """Count how often smoothed chop intervals line up with aligner soft-clip boundaries.

    Each interval in ``smooth_intervals[predict.id]`` is classified as *internal* when
    ``interval_end / len(predict.seq) <= interval_threshold``, and as *terminal* otherwise.
    Soft-clip lengths are read from the ends of the SAM record's CIGAR and flipped to read
    orientation for reverse-strand alignments.

    An internal interval is *supported* when either of these holds:
    - the left clip ends within ``overlap_threshold`` bases of the interval end, or
    - the right clip starts within ``overlap_threshold`` bases of the interval start.

    A terminal interval is supported only by the right-clip condition.

    Args:
        predicts: iterable of objects with ``.id`` and ``.seq``.
        smooth_intervals: read id -> list of ``(start, end)`` intervals.
        sam_records: read id -> alignment record with ``.cigarstring`` and ``.is_reverse``.
        interval_threshold: end-position ratio separating internal from terminal chops.
        overlap_threshold: maximum distance (bases) between an interval edge and a clip edge.

    Returns:
        ``defaultdict(int)`` with these keys:
        - ``"unmapped"``: reads without a record or CIGAR.
        - ``"internal_total"`` / ``"terminal_total"``: intervals examined per class.
        - ``"internal"`` / ``"terminal"``: intervals supported by a soft clip.
    """
    pat_left_s = re.compile(r"^(\d+)S")
    pat_right_s = re.compile(r"(\d+)S$")

    overlap_results = defaultdict(int)

    for predict in predicts:
        predict_read = sam_records.get(predict.id, None)
        if predict_read is None or predict_read.cigarstring is None:
            overlap_results["unmapped"] += 1
            continue

        cigar = predict_read.cigarstring
        left_mat = pat_left_s.search(cigar)
        right_mat = pat_right_s.search(cigar)
        ls_len = int(left_mat.group(1)) if left_mat else 0
        rs_len = int(right_mat.group(1)) if right_mat else 0

        # The CIGAR is in reference orientation; predict.seq is in read orientation.
        if predict_read.is_reverse:
            ls_len, rs_len = rs_len, ls_len

        seq_len = len(predict.seq)
        right_clip_start = seq_len - rs_len

        for start, end in smooth_intervals[predict.id]:
            left_hit = ls_len != 0 and abs(end - ls_len) < overlap_threshold
            right_hit = rs_len != 0 and abs(start - right_clip_start) < overlap_threshold

            if end / seq_len <= interval_threshold:
                overlap_results["internal_total"] += 1
                if left_hit or right_hit:
                    overlap_results["internal"] += 1
            else:
                overlap_results["terminal_total"] += 1
                if right_hit:
                    overlap_results["terminal"] += 1

    return overlap_results

def wrap_str(ostr, width):
    """Hard-wrap ``ostr`` into lines of at most ``width`` characters, joined with newlines."""
    wrapped_lines = wrap(ostr, width)
    return "\n".join(wrapped_lines)
    
def show_sam_record(predict, smooth_intervals, sam_records):
    """Print a read's smoothed chop intervals next to its alignment details.

    For ``predict`` this prints:
    - each interval in ``smooth_intervals[predict.id]`` with its length and mean Phred quality,
    - the read with those intervals highlighted,
    - the SAM record from ``sam_records`` (reference id, MAPQ, span, strand, CIGAR),
    - the soft-clipped bases at each end, in read orientation,
    - the soft clips of every alignment listed in the ``SA`` tag.

    Reads absent from ``sam_records`` (or without a CIGAR) are reported as not mapped.
    """
    pat_left_s = re.compile(r"^(\d+)S")
    pat_right_s = re.compile(r"(\d+)S$")

    seq_len = len(predict.seq)
    txt_width = 120

    print(f"\nread id {predict.id} seq len: {seq_len}")

    intervals = smooth_intervals[predict.id]
    # Decode qualities once (the property rebuilds the list on every access); only needed if there are intervals.
    qual_array = predict.qual_array if intervals else []
    for interval in intervals:
        quals = qual_array[interval[0] : interval[1]]
        average_qual = sum(quals) / len(quals) if quals else float("nan")
        print(f"smooth interval : {interval} len: {interval[1] - interval[0]}     {average_qual=}")

    highlight_targets(predict.seq, intervals)

    predict_read = sam_records.get(predict.id, None)
    if predict_read is None or predict_read.cigarstring is None:
        print("the read is not mapped")
        return

    print(f"{predict_read.reference_id=} {predict_read.mapping_quality=}")
    print(f"{predict_read.reference_start=} {predict_read.reference_end=}")
    print(f"{predict_read.is_reverse=}")

    print(f"cigar: {wrap_str(predict_read.cigarstring, txt_width)}")

    left_mat = pat_left_s.search(predict_read.cigarstring)
    right_mat = pat_right_s.search(predict_read.cigarstring)

    ls_len = int(left_mat.group(1)) if left_mat else None
    rs_len = int(right_mat.group(1)) if right_mat else None

    # CIGAR soft clips are in reference orientation, but predict.seq is in read orientation:
    # for a reverse-strand alignment the CIGAR's left clip is the read's right end.
    if predict_read.is_reverse:
        ls_len, rs_len = rs_len, ls_len

    if ls_len is not None:
        print(f"ls: 0-{ls_len}  \n {wrap_str(predict.seq[:ls_len], txt_width)}")

    if rs_len is not None:
        print(f"rs: {seq_len-rs_len}-{seq_len} \n {wrap_str(predict.seq[-rs_len:], txt_width)}")

    if predict_read.has_tag("SA"):
        print("has sa")
        # SA:Z holds "rname,pos,strand,CIGAR,mapQ,NM;" entries; drop the empty piece after the final ';'.
        for aln in filter(None, predict_read.get_tag("SA").split(";")):
            chr_sa, pos_sa, strand_sa, cigar_sa, mapq_sa, _nm_sa = aln.split(",")

            left_mat = pat_left_s.search(cigar_sa)
            right_mat = pat_right_s.search(cigar_sa)

            l_s_len = left_mat.group(1) if left_mat else ""
            r_s_len = right_mat.group(1) if right_mat else ""

            tgt_key = f"{predict_read.query_name}\t{l_s_len=}\t{r_s_len=}"
            print(f"chimeric : {tgt_key}\t{chr_sa}:{pos_sa}({strand_sa}) mapq={mapq_sa}")
            

In [None]:
# Inspect the first 50 reads flagged as having an internal chop against their alignments.
for p in internal_chop_predicts[:50]:
    show_sam_record(p, smooth_intervals, sam_records)

In [None]:
# Reads that still have more than one chop interval after smoothing.
for p in more_than_one_smooth_preditions:
    show_sam_record(p, smooth_intervals, sam_records)

In [None]:
# NOTE(review): prints a full report for every chopped read — the output can be very large.
for p in all_predicts_with_chop:
    show_sam_record(p, smooth_intervals, sam_records)

In [None]:
# Example SA tag from a mapped read that has one. A bare `read` is undefined at top level:
# it only exists as a local variable inside collect_sam_records.
example_sa_read = next((r for r in sam_records.values() if r.has_tag("SA")), None)
example_sa_read.get_tag("SA") if example_sa_read is not None else None

In [None]:
# Show up to 11 reads (after skipping the first 200 chopped reads) whose smoothed prediction has
# exactly one interval, each with a ±10 bp quality heatmap around the chop.
single_chop_examples = (p for p in all_predicts_with_chop[200:] if len(smooth_intervals[p.id]) == 1)
for p in islice(single_chop_examples, 11):
    smooth_region = smooth_intervals[p.id]
    print(p.id)
    highlight_targets(p.seq, smooth_region)
    # Clamp the padded window to the read so a chop near either end does not wrap around.
    p.vis_qual_static(max(smooth_region[0][0] - 10, 0), min(smooth_region[0][1] + 10, len(p.seq)))

In [None]:
smooth_label_region(true_predcition[0], 1, 1, 1)

In [None]:
smooth_label_region(true_predcition[0], 1, 1, 2)