In [2]:
import torch 
from deepchopper import  remove_intervals_and_keep_left, smooth_label_region, summary_predict, get_label_region

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

In [3]:
from deepchopper.utils import alignment_predict, highlight_target, highlight_targets

In [4]:
from pathlib import Path 

In [5]:
# Model checkpoints:
#   cnn   : /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-07_12-01-37/checkpoints/epoch_036_f1_0.9914.ckpt
#   heyna : /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-09_20-13-03/checkpoints/epoch_007_f1_0.9931.ckpt
#           ("heyna" is spelled as in the variable names below; presumably the Hyena-based model - TODO confirm)
#
# Data file:
#   data/eval/real_data/dorado_without_trim_fqs/K562.fastq_chunks/K562.fastq_0.parquet
#
# heyna results, chunks 0 1 2 3 4 5 6:
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-50-48
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-30-28

# NOTE: despite the "_folder" suffix, both names hold the path of a single
# saved batch file (predicts/0/0.pt), not a directory.
cnn_data_folder  = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_16-17-16/predicts/0/0.pt"
heyna_data_folder = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-30-28/predicts/0/0.pt"

In [144]:
def test_smooth(data_folder):
    """Load one saved prediction batch and run it through ``summary_predict``.

    ``data_folder`` is handed directly to ``torch.load``, so despite its name it
    must be the path of a saved prediction file containing the keys
    ``'prediction'``, ``'target'`` and ``'seq'``.

    Returns a ``(predictions, sequences, labels)`` tuple: the first element of
    ``summary_predict`` applied to the arg-maxed predictions, and both elements
    of ``summary_predict`` applied to the sequences. ``-100`` is forwarded as the
    third argument (presumably the ignore value in ``target`` - confirm against
    ``summary_predict``).
    """
    batch = torch.load(data_folder)
    targets = batch['target'].numpy()
    predicted_classes = batch['prediction'].argmax(-1).numpy()

    # The labels returned by the first call are discarded; the ones returned
    # with the sequences are what the caller gets.
    true_predcition, _ = summary_predict(predicted_classes, targets, -100)
    true_seq, true_label = summary_predict(batch['seq'].numpy(), targets, -100)
    return true_predcition, true_seq, true_label


def id2seq(ids: list[int]):
    """Decode token ids into a nucleotide string.

    Ids 7, 8, 9, 10 and 11 stand for "A", "C", "G", "T" and "N" respectively.
    Any other id raises ``KeyError``.
    """
    base_for_id = dict(zip(range(7, 12), "ACGTN"))
    return ''.join(base_for_id[token] for token in ids)


def majority_voting(labels, window_size):
    """Smooth a label sequence with a sliding-window majority vote.

    Each output position receives the most frequent label found in the window
    centred on that position. Windows are clipped at both ends of the
    sequence, so positions near the edges vote over fewer labels.

    Args:
        labels: Iterable of hashable labels. Objects exposing ``tolist()``
            (e.g. numpy arrays) are converted through it; anything else is
            copied with ``list()``. The input is never modified.
        window_size: Width of the voting window. Even values are bumped to the
            next odd number so the window has a central position.

    Returns:
        A new list with one smoothed label per input position.

    Raises:
        ValueError: If ``window_size`` is negative.
    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    # Ensure window size is odd to have a central token.
    if window_size % 2 == 0:
        window_size += 1
    half_window = window_size // 2

    # Work on a plain list: slicing and ``.count`` below need list semantics,
    # which arrays, deques and other sequences do not all provide.
    labels = labels.tolist() if hasattr(labels, "tolist") else list(labels)
    total = len(labels)

    smoothed_labels = []
    for i in range(total):
        window = labels[max(0, i - half_window):min(total, i + half_window + 1)]
        # On a tie, ``max`` keeps the first tied label it meets while iterating
        # the set of labels present in the window.
        smoothed_labels.append(max(set(window), key=window.count))

    return smoothed_labels

def ascii_values_to_string(ascii_values):
    """Turn an iterable of character codes into the string they spell."""
    return ''.join(map(chr, ascii_values))

def convert_id_str(ids):
    """Lazily decode a collection of encoded ids into strings.

    For each encoded id, element 0 is read as the number of character codes and
    the codes themselves start at index 2 (element 1 is skipped here). Returns
    a generator, so nothing is decoded until it is iterated.
    """
    return (
        ascii_values_to_string(encoded[2: encoded[0] + 2])
        for encoded in ids
    )

from dataclasses import dataclass

@dataclass
class Predict:
    """Prediction data for a single read, plus helpers to derive label regions."""

    # Labels handed to get_label_region / majority_voting by the methods below.
    true_prediction: list[int]
    # Sequence string for the read.
    true_seq: str
    # Identifier string for the read.
    true_id: str
    # Flag stored alongside the id (name kept as-is; presumably "truncation").
    is_trucation: bool

    @property
    def prediction_region(self):
        """Regions reported by ``get_label_region`` for the raw prediction."""
        return get_label_region(self.true_prediction)

    def smooth_label(self, window_size):
        """Return the prediction after majority-vote smoothing with ``window_size``."""
        return majority_voting(self.true_prediction, window_size)

    def smooth_prediction_region(self, window_size):
        """Regions reported by ``get_label_region`` for the smoothed prediction."""
        smoothed = self.smooth_label(window_size)
        return get_label_region(smoothed)

    def is_terminal(self, *, threshold=10, smooth=False, window_size=None) -> bool:
        """Not implemented yet: ignores its arguments and returns ``None``."""
        return None

    @property
    def is_polya(self) -> bool:
        """Not implemented yet: always evaluates to ``None``."""
        return None

    
class BatchPredict:
    def __init__(self, batch_prediction, smooth_window_size=9):
        self.smooth_window_size = smooth_window_size
        self.data  = torch.load(batch_prediction)
        self.batch_size = self.data['seq'].shape[0]  
        self.batch_predicts = []
        true_predictions, _true_labels = summary_predict(self.data['prediction'].argmax(-1).numpy(), self.data['target'].numpy(), -100)
        true_seqs, _true_label = summary_predict(self.data['seq'].numpy(), self.data['target'].numpy(), -100)
        for idx in range(len(true_predictions)):
            self.batch_predicts.append(Predict(
                true_prediction=true_predictions[idx],
                true_seq= id2seq(true_seqs[idx]),
                true_id=ascii_values_to_string(self.data['id'][idx][2: self.data['id'][idx][0]+2]),
                is_trucation=bool(self.data['id'][idx][1])))
                                       
    def __repr__(self):
        return f"{__class__.__name__}(batch_size={self.batch_size})"

    def print_all_seq(self, *, smooth=False, smooth_window_size: int| None =None):
        for predict in self.batch_predicts:
            if smooth:
                window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
                regions =  predict.smooth_prediction_region(window_size) 
            else:
                regions = predict.prediction_region

            print(f"id     : {predict.true_id}")
            print(f"regions: {regions}")
            # highlight_targets("".join((str(i) for i in self.true_predcition[idx])), regions)
            highlight_targets(predict.true_seq, regions)

    def compare_smooth(self,  smooth_window_size: int| None = None):
        for predict in self.batch_predicts:
            regions = predict.prediction_region
            
            window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
            smooth_regions = predict.smooth_prediction_region(window_size) 

            print(f"id      : {predict.true_id}")
            print(f"original: {regions}")
            print(f"smooth  : {smooth_regions}")
            highlight_targets(predict.true_seq, regions)
            highlight_targets(predict.true_seq, smooth_regions)

    def __len__(self):
        return batch_size 

    def __getitem__(self, idx):
        return self.batch_predicts[idx]
         

class BatchPredicts:
    """One-shot iterator over the ``*.pt`` batch files found in a folder.

    Each file is wrapped in a ``BatchPredict`` only when the iterator reaches
    it, so a batch is loaded at the moment it is requested. Once exhausted the
    iterator stays empty; create a new instance to go over the folder again.

    Args:
        batch_folder: Directory containing the ``*.pt`` batch files.
        smooth_window_size: Forwarded to every ``BatchPredict``.
    """

    def __init__(self, batch_folder, smooth_window_size=9):
        batch_folder = Path(batch_folder)
        # Path.glob yields entries in no guaranteed order; sorting the paths
        # makes the batch order the same on every run and filesystem.
        # The order is that of the Path objects, so "10.pt" sorts before "2.pt".
        batch_files = sorted(batch_folder.glob("*.pt"))
        self.batch_predicts = (
            BatchPredict(batch_file, smooth_window_size=smooth_window_size)
            for batch_file in batch_files
        )

    def __iter__(self):
        # An iterator returns itself from __iter__.
        return self

    def __next__(self):
        return next(self.batch_predicts)


In [152]:
# Smoothed labels (window 17) of the read at index 7 of batch `ss`, joined into one string.
# `ss` is a BatchPredict (see the `ss = next(aa)` cell), so the labels are reached
# through its per-read Predict item rather than an attribute on the batch itself.
"".join(str(label) for label in ss[7].smooth_label(17))

In [145]:
# Wrap the *.pt batch files of this eval run in a lazy iterator; every batch it
# yields uses a smoothing window of 17 unless a call overrides it.
data_folder = Path("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_20-30-28/predicts/0")
aa = BatchPredicts(data_folder, smooth_window_size=17)

In [153]:
# ss.print_all_seq(smooth=True)

In [160]:
# Take the next batch from the iterator and print, for each read, the raw and the
# smoothed regions. Every run of this cell advances `aa`, so re-running it shows
# a different batch.
ss = next(aa)
ss.compare_smooth()

id      : 66eb10f4-a13a-4f6f-a22f-299da3c91658
original: [(283, 347)]
smooth  : [(283, 347)]


id      : 307e9731-6b6e-4e85-a6ae-542926cf642f
original: []
smooth  : []


id      : f202f535-19d4-4b6f-9bff-98e064fe7631
original: [(428, 429), (430, 431), (668, 732)]
smooth  : [(668, 732)]


id      : 2a45f5a0-0ace-4d24-bd72-1b03a3eb0e44
original: []
smooth  : []


id      : 0312cf81-e852-4d86-81a0-7f0a54731d5e
original: [(640, 697)]
smooth  : [(640, 697)]


id      : 034da6c4-10d1-44bf-9c7e-bef0fd4132fe
original: [(1214, 1290)]
smooth  : [(1214, 1290)]


id      : 039f550a-1e06-468c-a0e0-ce2a90801bfb
original: []
smooth  : []


id      : 0969fbb6-ade1-45df-a1b1-601815b24afd
original: [(847, 920)]
smooth  : [(847, 920)]


id      : d5cbcdc5-787f-419f-a86a-b0ee8d829c46
original: [(784, 855)]
smooth  : [(784, 855)]


id      : d6007987-549c-4941-ad5e-bd9496442be0
original: [(845, 847), (848, 849), (854, 888)]
smooth  : [(851, 888)]


id      : 3b938543-54ad-4b76-84ac-a70fc1e8e456
original: []
smooth  : []


id      : 3af31cd3-98fc-41d1-9c2f-cd40b44f4a6d
original: [(519, 597), (604, 605)]
smooth  : [(519, 599)]


In [51]:
# Load the single CNN batch file. The heyna equivalent is kept here, disabled:
# heyna_prediction = BatchPredict(heyna_data_folder)
cnn_prediction = BatchPredict(cnn_data_folder)

In [222]:
# Display the batch's repr (class name and batch size).
cnn_prediction

BatchPredict(batch_size=6)

In [149]:
# cnn_prediction.compare_smooth()

In [150]:
# cnn_prediction.print_all_seq(smooth=False)

In [27]:
# NOTE(review): cp/cl were not defined in any visible cell. Loading them from the
# CNN batch mirrors the `hp, hs, hl = test_smooth(...)` cell and is inferred from
# the naming - confirm.
cp, cs, cl = test_smooth(cnn_data_folder)
# majority_voting needs an explicit window size; 9 matches BatchPredict's default.
alignment_predict(majority_voting(cp[0], 9), cl[0])

In [12]:
# Predictions, sequences and labels of the heyna batch file.
hp, hs, hl = test_smooth(heyna_data_folder)

In [44]:
# NOTE(review): this call fails with the ValueError shown below
# ("zip() argument 2 is shorter than argument 1"), which suggests cp[0] and hp[0]
# have different lengths - confirm that both come from the same read before comparing.
alignment_predict(cp[0], hp[0])

ValueError: zip() argument 2 is shorter than argument 1

In [26]:
# NOTE(review): `true_predcition` is not assigned at top level in any visible cell
# (it only exists as a local inside test_smooth) - confirm where it comes from.
smooth_label_region(true_predcition[0], 1, 1, 1)

[(403, 404),
 (407, 408),
 (410, 411),
 (413, 414),
 (1749, 1750),
 (2130, 2146),
 (2148, 2211)]

In [28]:
# Same call with the last argument set to 2 instead of 1; the output below no
# longer lists the one-position regions (presumably a minimum region length -
# confirm against smooth_label_region).
smooth_label_region(true_predcition[0], 1, 1, 2)

[(2130, 2146), (2148, 2211)]