In [24]:
import torch 
from deepchopper import  remove_intervals_and_keep_left, smooth_label_region, summary_predict, get_label_region

In [3]:
from deepchopper.utils import alignment_predict, highlight_target, highlight_targets

In [4]:
from pathlib import Path 

In [204]:
# cnn check point /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-07_12-01-37/checkpoints/epoch_036_f1_0.9914.ckpt 
# heyna check point  /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-09_20-13-03/checkpoints/epoch_007_f1_0.9931.ckpt

# poe eval ckpt_path=/projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-08_23-19-20/checkpoints/epoch_005_f1_0.9933.ckpt  model=hyena +data.predict_data_path=data/eval/real_data/dorado_without_trim_fqs/K562.fastq_chunks/K562.fastq_0.parquet trainer=gpu

cnn_data_folder  = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_16-17-16/predicts/0/0.pt"
heyna_data_folder = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_14-19-46/predicts/0/0.pt"

In [240]:
def test_smooth(data_folder):
    prediction = torch.load(data_folder)
    true_predcition, true_label = summary_predict(prediction['prediction'].argmax(-1).numpy(), prediction['target'].numpy(), -100)
    true_seq, true_label = summary_predict(prediction['seq'].numpy(), prediction['target'].numpy(), -100)
    return true_predcition, true_seq, true_label


def id2seq(ids: list[int]):
    # A', 'C', 'G', 'T', 'N'
    table =  {
        7: "A", 
        8: "C",
        9: "G",
        10: "T",
        11: "N"}
    return ''.join((table[c] for c in ids))


def majority_voting(labels, window_size):
    # Ensure window size is odd to have a central token
    if window_size % 2 == 0:
        window_size += 1

    half_window = window_size // 2
    smoothed_labels = []
    
    for i in range(len(labels)):
        # Extract the context window
        start = max(0, i - half_window)
        end = min(len(labels), i + half_window + 1)
        window = labels[start:end]     
        # Choose the most common label in the window
        most_common = max(set(window), key=window.count)
        smoothed_labels.append(most_common)

    return smoothed_labels

def ascii_values_to_string(ascii_values):
    return ''.join(chr(value) for value in ascii_values)

def convert_id_str(ids):  
    return (ascii_values_to_string(i[1: i[0]+1]) for i in ids)    

class BatchPredict:
    def __init__(self, batch_prediction, smooth_window_size=9):
        self.smooth_window_size = smooth_window_size
        self.data  = torch.load(batch_prediction)
        self.batch_size = self.data['seq'].shape[0]
        self.true_predcition, self.true_label = summary_predict(self.data['prediction'].argmax(-1).numpy(), self.data['target'].numpy(), -100)
        self.true_seq, _true_label = summary_predict(self.data['seq'].numpy(), self.data['target'].numpy(), -100)
        self.true_id  = list(convert_id_str(self.data['id']))

    def __repr__(self):
        return f"{__class__.__name__}(batch_size={self.batch_size})"

    @staticmethod
    def prediction_region(predict):
        return get_label_region(predict)

    @staticmethod
    def smooth_region(predict, smooth_window_size: int):
        return majority_voting(predict, smooth_window_size)

    def print_all_seq(self, *, smooth=False, smooth_window_size: int| None =None):
        for idx, seq in enumerate(self.true_seq):
            if smooth:
                window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
                regions = self.prediction_region(self.smooth_region(self.true_predcition[idx], window_size)) 
            else:
                regions = self.prediction_region(self.true_predcition[idx])

            print(f"id:      {self.true_id[idx]}")
            print(f"regions: {regions}")
            # highlight_targets("".join((str(i) for i in self.true_predcition[idx])), regions)
            highlight_targets(id2seq(seq), regions)

    def compare_smooth(self,  smooth_window_size: int| None = None):
        for idx, seq_ids in enumerate(self.true_seq):
            regions = self.prediction_region(self.true_predcition[idx])

            window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
            smooth_regions = self.prediction_region(self.smooth_region(self.true_predcition[idx], window_size)) 

            print(f"id:       {self.true_id[idx]}")
            print(f"original: {regions}")
            print(f"smooth  : {smooth_regions}")
            seq_bases = id2seq(seq_ids)
            highlight_targets(seq_bases, regions)
            highlight_targets(seq_bases, smooth_regions)

class BatchPredicts:
    def __init__(self, batch_folder, smooth_window_size=9):
        batch_folder = Path(batch_folder)
        self.batch_predicts =  (BatchPredict(batch, smooth_window_size=smooth_window_size)  for batch in (batch_folder).glob("*.pt"))

    def __len__(self):
        return len(self.batch_predicts)

    def __getitem__(self, index):
         return self.batch_predicts[index]

    def __iter__(self):
        return self.batch_predicts

    def __next__(self): 
        return next(self.batch_predicts)

In [241]:
aa = BatchPredicts("/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_16-17-16/predicts/0")

In [247]:
ss =  next(aa)

In [248]:
ss.compare_smooth()

id:       5c26ac76-0eb3-4e08-86a9-ada9eda87631
original: [(2, 4)]
smooth  : []


id:       fca46a73-58ce-4ba2-a124-f63b7e25577e
original: [(204, 208), (209, 210), (225, 226), (227, 229), (230, 233), (478, 480), (768, 769), (770, 772), (831, 835), (837, 838), (844, 845), (846, 880), (882, 886), (887, 891), (892, 894)]
smooth  : [(205, 209), (227, 232), (833, 836), (845, 895)]


id:       fcc29fee-9643-408e-a908-1591b3af34ce
original: [(136, 137), (177, 191), (196, 198), (202, 234)]
smooth  : [(177, 191), (200, 234)]


id:       029d73a1-fb85-4346-ba68-8f0c31e136ae
original: [(975, 1039)]
smooth  : [(975, 1039)]


id:       21eeb936-c54b-4cc2-bd31-8e6532d5f87e
original: [(2019, 2082)]
smooth  : [(2019, 2082)]


id:       1b6a2ba5-36a9-4a62-8440-4798c5fdbed5
original: []
smooth  : []


In [221]:
# heyna_prediction = BatchPredict(heyna_data_folder)
cnn_prediction = BatchPredict(cnn_data_folder)

dict_keys(['prediction', 'target', 'seq', 'qual', 'id'])


In [222]:
cnn_prediction

BatchPredict(batch_size=6)

In [224]:
cnn_prediction.compare_smooth()

id:       648c05db-d8d4-4bba-83a5-0f75420ec680
original: [(350, 351), (355, 407)]
smooth  : [(354, 407)]


id:       358cb29d-0f95-4e14-a9bf-ecf3b71b0376
original: [(65, 69), (71, 78), (81, 84), (85, 89), (94, 96), (97, 98), (126, 149)]
smooth  : [(67, 88), (126, 149)]


id:       6d6d71ca-b490-42cf-9b88-e1ec11475d5d
original: [(3078, 3149)]
smooth  : [(3078, 3149)]


id:       6ea40c43-25ac-4226-850d-16e8ac787685
original: [(570, 649), (650, 666)]
smooth  : [(570, 666)]


id:       032203b6-dbb6-40ae-adf9-13aaa21c4920
original: []
smooth  : []


id:       e765cae0-6b1c-4abc-9e10-9f7458d4ac10
original: [(1160, 1232)]
smooth  : [(1160, 1232)]


In [177]:
heyna_prediction.compare_smooth()

original: []
smooth  : []


original: [(935, 999)]
smooth  : [(935, 999)]


original: [(1029, 1030), (1031, 1090)]
smooth  : [(1030, 1090)]


original: [(141, 166), (167, 173), (425, 426)]
smooth  : [(141, 173)]


original: [(97, 125), (126, 134)]
smooth  : [(97, 136)]


original: []
smooth  : []


In [218]:
cnn_prediction.print_all_seq(smooth=False)

id: 648c05db-d8d4-4bba-83a5-0f75420ec680
[(350, 351), (355, 407)]


id: 358cb29d-0f95-4e14-a9bf-ecf3b71b0376
[(65, 69), (71, 78), (81, 84), (85, 89), (94, 96), (97, 98), (126, 149)]


id: 6d6d71ca-b490-42cf-9b88-e1ec11475d5d
[(3078, 3149)]


id: 6ea40c43-25ac-4226-850d-16e8ac787685
[(570, 649), (650, 666)]


id: 032203b6-dbb6-40ae-adf9-13aaa21c4920
[]


id: e765cae0-6b1c-4abc-9e10-9f7458d4ac10
[(1160, 1232)]


In [219]:
cnn_prediction.print_all_seq(smooth=True)

id: 648c05db-d8d4-4bba-83a5-0f75420ec680
[(354, 407)]


id: 358cb29d-0f95-4e14-a9bf-ecf3b71b0376
[(67, 88), (126, 149)]


id: 6d6d71ca-b490-42cf-9b88-e1ec11475d5d
[(3078, 3149)]


id: 6ea40c43-25ac-4226-850d-16e8ac787685
[(570, 666)]


id: 032203b6-dbb6-40ae-adf9-13aaa21c4920
[]


id: e765cae0-6b1c-4abc-9e10-9f7458d4ac10
[(1160, 1232)]


In [27]:
alignment_predict(majority_voting(cp[0]), cl[0])

In [12]:
hp, hs, hl = test_smooth(heyna_data_foler)

In [44]:
alignment_predict(cp[0], hp[0])

ValueError: zip() argument 2 is shorter than argument 1

In [26]:
smooth_label_region(true_predcition[0], 1, 1, 1)

[(403, 404),
 (407, 408),
 (410, 411),
 (413, 414),
 (1749, 1750),
 (2130, 2146),
 (2148, 2211)]

In [28]:
smooth_label_region(true_predcition[0], 1, 1, 2)

[(2130, 2146), (2148, 2211)]