In [24]:
import torch 
from deepchopper import  remove_intervals_and_keep_left, smooth_label_region, summary_predict, get_label_region

In [3]:
from deepchopper.utils import alignment_predict, highlight_target, highlight_targets

In [4]:
from pathlib import Path 

In [178]:
# Prediction batches (``predicts/0/0.pt``) written by DeepChopper eval runs.
#
# Earlier CNN eval batches, kept for provenance:
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-09_13-56-19/predicts/0/0.pt
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_12-58-22/predicts/0/0.pt  (real data)
# Eval run dir in use for the CNN batch:
#   /projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs/2024-04-12_15-53-06
#
# Checkpoints:
#   cnn:   /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-07_12-01-37/checkpoints/epoch_036_f1_0.9914.ckpt
#   heyna: /projects/b1171/ylk4626/project/DeepChopper/logs/train/runs/2024-04-09_20-13-03/checkpoints/epoch_007_f1_0.9931.ckpt
_EVAL_RUNS = "/projects/b1171/ylk4626/project/DeepChopper/logs/eval/runs"

# NOTE: despite the "_folder" suffix these are paths to single .pt files.
cnn_data_folder = f"{_EVAL_RUNS}/2024-04-12_15-53-06/predicts/0/0.pt"
heyna_data_folder = f"{_EVAL_RUNS}/2024-04-12_14-19-46/predicts/0/0.pt"

In [183]:
def test_smooth(data_folder):
    """Load one saved prediction batch and filter it with ``summary_predict``.

    Args:
        data_folder: path to a ``.pt`` file (a file, despite the name) holding a
            mapping with ``'prediction'``, ``'target'`` and ``'seq'`` tensors.

    Returns:
        ``(true_prediction, true_seq, true_label)``. Both filtering calls use
        ``-100`` as the value passed to ``summary_predict`` (presumably the
        ignore index). The label returned is the one from the ``seq`` call,
        exactly as before.
    """
    # map_location="cpu" so the `.numpy()` calls below also work for batches saved from a GPU run.
    prediction = torch.load(data_folder, map_location="cpu")
    target = prediction['target'].numpy()  # shared by both filtering calls
    true_prediction, _ = summary_predict(prediction['prediction'].argmax(-1).numpy(), target, -100)
    true_seq, true_label = summary_predict(prediction['seq'].numpy(), target, -100)
    return true_prediction, true_seq, true_label


def id2seq(ids: list[int]):
    """Translate model token ids back into a nucleotide string.

    Ids 7, 8, 9, 10, 11 map to 'A', 'C', 'G', 'T', 'N' respectively; any other
    id raises ``KeyError``.
    """
    id_to_base = dict(zip(range(7, 12), "ACGTN"))
    return "".join([id_to_base[token] for token in ids])


def majority_voting(labels, window_size):
    """Smooth a label sequence by majority vote over a centred sliding window.

    Each position is replaced by the most frequent label in the window around it.
    Windows are truncated at the sequence ends. An even ``window_size`` is bumped
    to the next odd value so every window has a central token.

    Ties can occur in truncated edge windows, or with three or more distinct
    labels. A tie keeps the centre token's own label when that label is among the
    most frequent. Otherwise the tied label that appears first in the window wins.
    This makes the result deterministic and independent of hash/``set`` ordering.

    Args:
        labels: sequence of hashable labels. Lists and tuples are accepted, as is
            anything with ``tolist()`` (e.g. a numpy array or torch tensor).
        window_size: window width, must be >= 0 (0 and 1 mean "no smoothing").

    Returns:
        list of smoothed labels, same length as ``labels``.

    Raises:
        ValueError: if ``window_size`` is negative.
    """
    from collections import Counter

    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")
    # Ensure window size is odd to have a central token
    if window_size % 2 == 0:
        window_size += 1
    half_window = window_size // 2

    # Sliced ndarrays/tensors have no `.count`; normalise to a plain list first.
    values = labels.tolist() if hasattr(labels, "tolist") else list(labels)

    smoothed_labels = []
    for i, centre in enumerate(values):
        # Slicing past the end is safe in Python, so only the start needs clamping.
        window = values[max(0, i - half_window): i + half_window + 1]
        counts = Counter(window)
        best = max(counts.values())
        if counts[centre] == best:
            smoothed_labels.append(centre)
        else:
            smoothed_labels.append(next(label for label in window if counts[label] == best))

    return smoothed_labels

def ascii_values_to_string(ascii_values):
    """Concatenate a sequence of integer code points into a ``str`` via ``chr``."""
    return "".join(map(chr, ascii_values))

def convert_id_str(ids):
    """Lazily decode rows of integer codes into id strings.

    Each row is read as ``[length, c_1, ..., c_length, ...]``. Element 0 is taken
    as the id length, and the next ``length`` values are decoded with
    ``ascii_values_to_string``.

    NOTE(review): the sample ``id`` tensor displayed in this notebook starts its
    rows with values such as 102 ('f') and 48 ('0'), which look like characters
    rather than lengths. Confirm this layout against whatever writes ``'id'``.

    Returns:
        A generator of strings. It can be consumed only once, so wrap it in
        ``list(...)`` to keep the result.
    """
    # int() first: with an int8 tensor row, `row[0] + 1` stays int8 and wraps at 127.
    return (ascii_values_to_string(row[1: int(row[0]) + 1]) for row in ids)

class BatchPredict:
    def __init__(self, batch_prediction, smooth_window_size=9):
        self.smooth_window_size = smooth_window_siz
        self.data  = torch.load(batch_prediction)
        self.batch_size = self.data['seq'].shape[0]
        self.true_predcition, self.true_label = summary_predict(self.data['prediction'].argmax(-1).numpy(), self.data['target'].numpy(), -100)
        self.true_seq, _true_label = summary_predict(self.data['seq'].numpy(), self.data['target'].numpy(), -100)
        self.true_id  = convert_id_str(self.data['id'])

    def __repr__(self):
        return f"{__class__.__name__}(batch_size={self.batch_size})"

    @staticmethod
    def prediction_region(predict):
        return get_label_region(predict)

    @staticmethod
    def smooth_region(predict, smooth_window_size: int):
        return majority_voting(predict, smooth_window_size)

    def print_all_seq(self, *, smooth=False, smooth_window_size: int| None =None):
        for idx, seq in enumerate(self.true_seq):
            if smooth:
                window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
                regions = self.prediction_region(self.smooth_region(self.true_predcition[idx], window_size)) 
            else:
                regions = self.prediction_region(self.true_predcition[idx])

            print(f"{regions}")
            # highlight_targets("".join((str(i) for i in self.true_predcition[idx])), regions)
            highlight_targets(id2seq(seq), regions)

    def compare_smooth(self,  smooth_window_size: int| None = None):
        for idx, seq_ids in enumerate(self.true_seq):
            regions = self.prediction_region(self.true_predcition[idx])

            window_size = smooth_window_size if smooth_window_size is not None else self.smooth_window_size
            smooth_regions = self.prediction_region(self.smooth_region(self.true_predcition[idx], window_size)) 

            print(f"original: {regions}")
            print(f"smooth  : {smooth_regions}")
            seq_bases = id2seq(seq_ids)
            highlight_targets(seq_bases, regions)
            highlight_targets(seq_bases, smooth_regions)

In [184]:
# Load both prediction batches (heyna model and CNN model) for side-by-side inspection.
heyna_prediction = BatchPredict(heyna_data_folder)
cnn_prediction = BatchPredict(cnn_data_folder)

In [185]:
# Display the batch summary (rendered by BatchPredict.__repr__).
cnn_prediction

BatchPredict(batch_size=6)

In [189]:
# Raw `id` tensor of the CNN batch (int8 per the output); rows appear zero-padded on the right.
cnn_prediction.data['id'].cpu()


tensor([[102,  57, 100,  ...,   0,   0,   0],
        [ 48,  48,  98,  ...,   0,   0,   0],
        [ 48,  55,  51,  ...,   0,   0,   0],
        [ 48,  99, 101,  ...,   0,   0,   0],
        [ 49,  50,  56,  ...,   0,   0,   0],
        [ 50,  48,  48,  ...,   0,   0,   0]], dtype=torch.int8)

In [186]:
# Raw vs majority-vote-smoothed regions for every read in the CNN batch (default window).
cnn_prediction.compare_smooth()

original: [(1356, 1357), (1571, 1573), (1574, 1576), (1912, 1978)]
smooth  : [(1912, 1978)]


original: [(1060, 1122)]
smooth  : [(1060, 1122)]


original: [(646, 709)]
smooth  : [(646, 709)]


original: [(896, 902), (910, 963)]
smooth  : [(896, 902), (910, 963)]


original: [(1251, 1337)]
smooth  : [(1251, 1337)]


original: [(778, 833)]
smooth  : [(778, 833)]


In [177]:
# Raw vs majority-vote-smoothed regions for every read in the heyna batch (default window).
heyna_prediction.compare_smooth()

original: []
smooth  : []


original: [(935, 999)]
smooth  : [(935, 999)]


original: [(1029, 1030), (1031, 1090)]
smooth  : [(1030, 1090)]


original: [(141, 166), (167, 173), (425, 426)]
smooth  : [(141, 173)]


original: [(97, 125), (126, 134)]
smooth  : [(97, 136)]


original: []
smooth  : []


In [163]:
# Unsmoothed predicted regions per read, highlighted on the decoded bases.
# NOTE(review): the saved output lists 12 reads while the batch reports batch_size=6,
# so it looks stale (from an earlier data file) -- re-run before relying on it.
cnn_prediction.print_all_seq(smooth=False)

[]


[(329, 397)]


[(845, 849), (935, 999)]


[(458, 459), (460, 520)]


[(1019, 1090)]


[(926, 961), (962, 1012)]


[(8, 20), (23, 26), (27, 69), (70, 71), (74, 75), (76, 82), (85, 87), (100, 148), (151, 152), (153, 154), (160, 161), (163, 164), (169, 170), (171, 173), (218, 221), (226, 227), (233, 258), (260, 261), (262, 275), (283, 293), (296, 297), (299, 300), (302, 303), (304, 308), (309, 311), (316, 319), (321, 323), (325, 382), (383, 391), (392, 394), (402, 403), (404, 405), (406, 407), (410, 412), (417, 418), (419, 420), (430, 442)]


[(360, 415)]


[(77, 112), (113, 115), (116, 120), (123, 136)]


[(363, 365), (366, 371), (377, 389), (399, 427)]


[(215, 217), (305, 306), (313, 315), (317, 345), (346, 347), (348, 356)]


[(114, 120), (283, 284), (285, 286), (397, 398)]


In [157]:
# Smoothed predicted regions per read (window = cnn_prediction.smooth_window_size).
# NOTE(review): the saved output repeats identical regions (e.g. (935, 999) twice) and lists
# 12 reads for a batch of 6; it looks produced by an older version -- re-run to refresh.
cnn_prediction.print_all_seq(smooth=True)

[]


[(329, 397)]


[(935, 999), (935, 999)]


[(458, 459), (458, 459)]


[(1019, 1090)]


[(962, 1012), (962, 1012)]


[(70, 71), (70, 71), (70, 71), (70, 71), (70, 71), (70, 71), (76, 82), (163, 164), (163, 164), (163, 164), (163, 164), (163, 164), (163, 164), (163, 164), (163, 164), (163, 164), (233, 258), (299, 300), (299, 300), (299, 300), (299, 300), (299, 300), (299, 300), (299, 300), (299, 300), (299, 300), (321, 323), (321, 323), (321, 323), (321, 323), (321, 323), (402, 403), (402, 403), (402, 403), (402, 403), (406, 407), (406, 407)]


[(360, 415)]


[(77, 112), (77, 112), (77, 112), (77, 112)]


[(363, 365), (363, 365), (363, 365), (363, 365)]


[(346, 347), (346, 347), (346, 347), (346, 347), (346, 347), (346, 347)]


[(114, 120), (114, 120), (114, 120), (114, 120)]


In [27]:
# Compare the smoothed CNN prediction for read 0 against its label.
# Bind `cp`/`cl` here (mirroring `hp, hs, hl` for heyna) so the cell runs on a fresh kernel.
# majority_voting has no default window; 9 matches BatchPredict's default smoothing window.
cp, _, cl = test_smooth(cnn_data_folder)
alignment_predict(majority_voting(cp[0], 9), cl[0])

In [12]:
# Heyna batch after summary_predict filtering: predictions, token sequences, labels.
hp, hs, hl = test_smooth(heyna_data_folder)

In [44]:
# NOTE(review): the recorded ValueError says the two inputs differ in length. `cp` presumably
# comes from the CNN batch and `hp` from the heyna batch, which are different eval runs
# (different paths), so read 0 need not be the same read. TODO: align reads by id first,
# or compare each prediction against its own labels.
alignment_predict(cp[0], hp[0])

ValueError: zip() argument 2 is shorter than argument 1

In [26]:
# NOTE(review): `true_predcition` is not bound at top level in any visible cell (it is a local
# of test_smooth and an attribute of BatchPredict), so this relies on leftover kernel state.
# TODO: bind it explicitly, e.g. from one of the BatchPredict instances.
smooth_label_region(true_predcition[0], 1, 1, 1)

[(403, 404),
 (407, 408),
 (410, 411),
 (413, 414),
 (1749, 1750),
 (2130, 2146),
 (2148, 2211)]

In [28]:
# Same call with only the last argument changed (1 -> 2); compare with the previous output.
# NOTE(review): `true_predcition` is not bound at top level in any visible cell -- hidden state.
smooth_label_region(true_predcition[0], 1, 1, 2)

[(2130, 2146), (2148, 2211)]