In [1]:
import argparse
import datetime
import os
import re
import time
import yaml

import multiprocessing as mp

from Bio import GenBank
from Bio.Seq import Seq
from functools import wraps
from tqdm import tqdm

from rna_folding import find_top_hairpins

# Load run-time settings once, at import time, from config.yaml in the
# current working directory.
with open("config.yaml", "r") as file:
    config = yaml.safe_load(file)

# Constants
# Each key below is mandatory: a missing one raises KeyError on import.
PATH_TO_OUTPUT = config["path_to_output"]
MIN_GAP_LENGTH = config["min_gap_length"]
MAX_GAP_LENGTH = config["max_gap_length"]
MIN_FLANKS_LENGTH = config["min_flanks_length"]
MAX_FLANKS_LENGTH = config["max_flanks_length"]
MFE_THRESHOLD_HPT = config["mfe_threshold_hpt"]
MFE_THRESHOLD_HPA = config["mfe_threshold_hpa"]

# Passed as `bar_format` to every tqdm progress bar in this file.
BAR_FORMAT = config["bar_format"]

# Optional key: when it is absent, null or 0, use every available CPU.
# `.get` is used so that "not specified" really falls back instead of
# raising KeyError.
NUM_PROCESSES = (
    config.get("num_processes") or mp.cpu_count()
)  # take maximum available if not specified


class IntervaledFeature:
    """An interval together with the feature(s) that occupy it.

    Attributes:
        interval: The interval object passed in, stored as-is. Callers in
            this file pass a mutable ``[start, stop]`` list.
        feature_list: Features attached to the interval; starts with the
            single ``feature`` given to the constructor.
        feature_num: Number of attached features, kept as a separate counter.
        feature_lengths: ``feature.length`` of each attached feature, in the
            same order as ``feature_list``.
    """

    def __init__(self, interval, feature):
        # `feature` must expose a `.length` attribute.
        self.interval = interval
        self.feature_list = [feature]
        self.feature_num = 1  # exactly one feature at construction time
        self.feature_lengths = [feature.length]

    def __repr__(self):
        shown = ("interval", "feature_list", "feature_num", "feature_lengths")
        body = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"IntervaledFeature({body})"


class IntervaledGap:
    """An uncovered interval plus the features found on either side of it.

    Attributes:
        interval: The gap's interval, stored as passed in.
        features_left: Object describing the left flank, or None.
        features_right: Object describing the right flank, or None.
    """

    def __init__(self, interval, features_left, features_right):
        self.interval, self.features_left, self.features_right = (
            interval,
            features_left,
            features_right,
        )

    def __repr__(self):
        return (
            "IntervaledGap("
            f"interval={self.interval}, "
            f"features_left={self.features_left}, "
            f"features_right={self.features_right})"
        )


def timeit(func):
    """Decorator that prints how long each call to ``func`` took.

    The elapsed wall-clock time is printed as ``Total execution time H:MM:SS``
    (fractional seconds are cut off) after ``func`` returns. The wrapped
    function's return value is passed through unchanged. Nothing is printed
    if ``func`` raises.
    """

    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        began = datetime.datetime.now()
        outcome = func(*args, **kwargs)
        elapsed = datetime.datetime.now() - began
        # str(timedelta) looks like "0:00:01.234567"; keep the part before the dot.
        whole_seconds, _, _ = str(elapsed).partition(".")
        print(f"Total execution time {whole_seconds}")
        return outcome

    return timeit_wrapper


def add_output_suffix(file_path):
    """Return ``file_path`` with ``_output`` inserted before the extension.

    Only the final path component is changed; the directory part is kept.
    A name without an extension simply gets ``_output`` appended.
    """
    parent, leaf = os.path.split(file_path)
    stem, extension = os.path.splitext(leaf)
    return os.path.join(parent, stem + "_output" + extension)


def modify_first_line(file_path, output_path):
    """Copy ``file_path`` to ``output_path``, patching the LOCUS header line.

    If the first line contains ``LOCUS`` and the substring ``bp``, the text
    ``"0 "`` is inserted immediately before the first ``bp`` on that line.
    All other lines are copied unchanged. A file whose first line lacks
    either marker, or an empty file, is copied as-is.
    """
    with open(file_path, "r") as source:
        lines = source.readlines()

    header = lines[0] if lines else ""
    if "LOCUS" in header:
        # partition splits at the first "bp"; `marker` is empty when absent.
        head, marker, tail = header.partition("bp")
        if marker:
            lines[0] = head + "0 " + marker + tail

    with open(output_path, "w") as target:
        target.writelines(lines)


def insert_before_origin(file_path, insert_strings):
    """Insert lines into ``file_path`` just before the first ``ORIGIN``.

    ``insert_strings`` are joined with newlines, terminated with one more
    newline, and written in front of the first occurrence of the text
    ``ORIGIN``. The file is rewritten in place. If ``ORIGIN`` does not occur,
    the file is left untouched. Calling this twice inserts the lines twice.
    """
    with open(file_path, "r") as source:
        before, marker, after = source.read().partition("ORIGIN")

    if not marker:
        return

    # Build the inserted text before reopening the file for writing, so a
    # failure here cannot leave the file truncated.
    inserted = "\n".join(insert_strings) + "\n"

    with open(file_path, "w") as target:
        target.write(before + inserted + marker + after)


def merge_intervals(intervaled_features):
    """Merge overlapping or touching intervals into single entries.

    The input is processed in order of interval start. An entry opens a new
    group when it starts strictly after the current group ends; otherwise it
    is folded into the current group: the group's end is extended, and the
    entry's first feature (with its length) is appended unless that feature
    is already in the group.

    Note: groups are the input objects themselves. The first entry of each
    group is reused and mutated in place (``interval[1]``, ``feature_list``,
    ``feature_num``, ``feature_lengths``), so ``interval`` must be mutable.

    Args:
        intervaled_features: Iterable of IntervaledFeature-like objects.

    Returns:
        List of merged entries, ordered by interval start.
    """
    ordered = sorted(intervaled_features, key=lambda item: item.interval[0])
    merged = []

    progress = tqdm(ordered, desc="Merging intervals...", bar_format=BAR_FORMAT)
    for candidate in progress:
        current = merged[-1] if merged else None

        # Nothing collected yet, or a real gap before this candidate.
        if current is None or current.interval[1] < candidate.interval[0]:
            merged.append(candidate)
            continue

        current.interval[1] = max(current.interval[1], candidate.interval[1])

        # Only the candidate's first feature is carried over, and only once.
        first_feature = candidate.feature_list[0]
        if first_feature in current.feature_list:
            continue

        current.feature_list.append(first_feature)
        current.feature_num += 1
        current.feature_lengths.append(candidate.feature_lengths[0])

    return merged


def find_uncovered_intervals(interval, intervaled_features):
    """Find the stretches inside ``interval`` that no feature covers.

    Both bounds of ``interval`` are treated as covered: the first possible
    gap starts at ``start + 1`` and the last possible gap ends at ``end - 1``.
    Each gap is reported as ``(first_uncovered, last_uncovered)``.

    Features ending at or before ``start`` are skipped but remembered as a
    possible left flank; scanning stops at the first feature that starts at
    or after ``end``.

    Args:
        interval: ``(start, end)`` pair delimiting the region to scan.
        intervaled_features: Objects with an ``interval`` of ``(start, end)``.
            The list is not modified; a sorted copy is used.

    Returns:
        List of IntervaledGap. ``features_left`` is the previously visited
        feature (or None if there is none); ``features_right`` is the feature
        that closes the gap, or None for the trailing gap before ``end``.
    """
    start, end = interval
    # Sort a copy rather than the caller's list.
    ordered_features = sorted(intervaled_features, key=lambda x: x.interval)

    uncovered_intervals = []
    current_start = start  # right-most position known to be covered so far
    prev_feature = None

    for feature in tqdm(
        ordered_features,
        desc="Finding uncovered intervals...",
        bar_format=BAR_FORMAT,
    ):
        i_start, i_end = feature.interval
        if i_end <= start:
            prev_feature = feature
            continue
        if i_start >= end:
            break

        # A gap exists only if at least one position lies strictly between
        # the covered end and this feature's start.
        if i_start > current_start and (i_start - 1) != current_start:
            uncovered_interval = (current_start + 1, i_start - 1)
            features_left = (
                prev_feature
                if prev_feature is not None
                and prev_feature.interval[1] <= uncovered_interval[0]
                else None
            )
            # The gap ends at i_start - 1, so this feature always starts
            # beyond it and is the right flank.
            features_right = feature
            uncovered_intervals.append(
                IntervaledGap(uncovered_interval, features_left, features_right)
            )

        current_start = max(current_start, i_end)
        prev_feature = feature

    # Trailing gap. Require at least one uncovered position: with the former
    # `current_start < end` test, coverage ending at end - 1 produced the
    # inverted interval (end, end - 1).
    if current_start < end - 1:
        uncovered_interval = (current_start + 1, end - 1)
        features_left = (
            prev_feature
            if prev_feature is not None
            and prev_feature.interval[1] <= uncovered_interval[0]
            else None
        )
        features_right = None
        uncovered_intervals.append(
            IntervaledGap(uncovered_interval, features_left, features_right)
        )

    return uncovered_intervals


def filter_intervals_by_length(intervals, min_length, max_length):
    """Keep the gaps whose length lies within ``[min_length, max_length]``.

    A gap's length is computed as ``end - start + 1`` from its ``interval``.
    Input order is preserved.
    """
    progress = tqdm(
        intervals, desc="Filtering intervals by length...", bar_format=BAR_FORMAT
    )
    return [
        gap
        for gap in progress
        if min_length <= (gap.interval[1] - gap.interval[0]) + 1 <= max_length
    ]


def filter_intervals_by_flanking_legth(intervals, min_flanks_length, max_flanks_length):
    """Keep the gaps whose two flanking features both have acceptable length.

    For each gap the last length recorded on ``features_left`` and the first
    length recorded on ``features_right`` are used (presumably the features
    closest to the gap). Each length, plus one, must lie within
    ``[min_flanks_length, max_flanks_length]``.

    Gaps with a missing flank (``features_left`` or ``features_right`` is
    None, as happens at the edges of the scanned region) are dropped rather
    than raising AttributeError: with no flank they cannot meet the
    constraint.

    Returns:
        List of the gaps that pass, in input order.
    """
    filtered_intervals = []
    for intervaled_gap in tqdm(
        intervals,
        desc="Filtering intervals by flanking genes...",
        bar_format=BAR_FORMAT,
    ):
        left = intervaled_gap.features_left
        right = intervaled_gap.features_right
        if left is None or right is None:
            continue

        length_left = left.feature_lengths[-1]
        length_right = right.feature_lengths[0]

        if (
            min_flanks_length <= length_left + 1 <= max_flanks_length
            and min_flanks_length <= length_right + 1 <= max_flanks_length
        ):
            filtered_intervals.append(intervaled_gap)

    return filtered_intervals


def create_uncovered_intervals_str(intervals):
    """Format each gap as a ``gap`` feature line of the form ``start..end``.

    Args:
        intervals: Iterable of objects whose ``interval`` holds (start, end).

    Returns:
        List of strings, one per gap, in input order.
    """
    gap_lines = []
    for gap in intervals:
        first, last = gap.interval[0], gap.interval[1]
        gap_lines.append(f"     gap             {first}..{last}")
    return gap_lines


def create_hairpins_str(hairpins, type_of_hairpin, complement=False):
    """Format hairpins as multi-line feature entries.

    Each hairpin is indexed positionally: ``[0]`` and ``[1]`` form the
    ``start..end`` location, ``[2]`` is written as ``/sequence``, ``[3]`` as
    ``/structure`` and ``[4]`` as ``/mfe``.

    Args:
        hairpins: Iterable of indexable hairpin records.
        type_of_hairpin: Text used as the feature key on the first line.
        complement: When truthy, the location is wrapped in ``complement(...)``.

    Returns:
        List with one four-line string per hairpin (no trailing newline).
    """

    def location_of(hairpin):
        span = f"{hairpin[0]}..{hairpin[1]}"
        return f"complement({span})" if complement else span

    return [
        "\n".join(
            (
                f"     {type_of_hairpin}             {location_of(hairpin)}",
                f"                     /sequence={hairpin[2]}",
                f"                     /structure={hairpin[3]}",
                f"                     /mfe={hairpin[4]}",
            )
        )
        for hairpin in hairpins
    ]

In [13]:
import re
from Bio import GenBank

In [46]:
from insertion_site_finder_2 import (
    modify_first_line,
    IntervaledFeature,
    merge_intervals,
    find_uncovered_intervals,
    filter_intervals_by_length,
    create_uncovered_intervals_str,
    insert_before_origin,
)

In [30]:
# Write a copy of the GenBank file with a patched LOCUS line, then parse
# every record from that copy.
modify_first_line(
    "Winter wheat cons/Bacillus halotolerans_AOP29/out/AOP29.gbk",
    "Winter wheat cons/Bacillus halotolerans_AOP29/out/AOP29_modified.gbk",
)

with open(
    "Winter wheat cons/Bacillus halotolerans_AOP29/out/AOP29_modified.gbk"
) as handle:
    record_list = list(GenBank.parse(handle))

In [31]:
# Shallow copy of the first record's features, in file order.
feature_list = list(record_list[0].features)

In [32]:
# Wrap every feature in an IntervaledFeature keyed by its [start, stop].
intervaled_features_list = []

for feature in feature_list:
    # `location` is a string here; strip the "complement(" wrapper and any
    # ")" so only "start..stop" is left, then split on "..".
    # NOTE(review): any other location form (e.g. "join(...)", "<1..200",
    # a single position) makes int() or the [1] index below raise.
    location = feature.location.replace("complement(", "").replace(")", "").split("..")
    start, stop = int(location[0]), int(location[1])
    # Stored as stop - start, i.e. one less than the number of positions in
    # start..stop; the filter functions add 1 when they use a length.
    feature.length = stop - start

    # A list (not a tuple) so that merge_intervals can extend the end in place.
    intervaled_features_list.append(IntervaledFeature([start, stop], feature))

In [33]:
# Number of intervals left after merging. The first entry is skipped here;
# a later cell uses it as the region to scan for gaps.
# NOTE(review): merge_intervals as defined in the first cell mutates the
# objects it is given, so this probe already alters entries of
# intervaled_features_list before the next cell merges them again.
len(merge_intervals(intervaled_features_list[1:]))

3522

In [34]:
# Merge overlapping/touching features; entry 0 is left out of the merge.
merged_intervals = merge_intervals(intervaled_features_list[1:])

In [35]:
# Merged entries that absorbed more than one feature.
[feature for feature in merged_intervals if feature.feature_num > 1]

[IntervaledFeature(interval=[16263, 16912], feature_list=[Feature(key='CDS', location='16263..16748'), Feature(key='CDS', location='16745..16912')], feature_num=2, feature_lengths=[485, 167]),
 IntervaledFeature(interval=[22048, 24287], feature_list=[Feature(key='CDS', location='complement(22048..23505)'), Feature(key='CDS', location='complement(23502..24287)')], feature_num=2, feature_lengths=[1457, 785]),
 IntervaledFeature(interval=[41961, 43605], feature_list=[Feature(key='CDS', location='41961..42221'), Feature(key='CDS', location='42218..43204'), Feature(key='CDS', location='complement(43201..43605)')], feature_num=3, feature_lengths=[260, 986, 404]),
 IntervaledFeature(interval=[61896, 64197], feature_list=[Feature(key='CDS', location='61896..63497'), Feature(key='CDS', location='63490..64197')], feature_num=2, feature_lengths=[1601, 707]),
 IntervaledFeature(interval=[65885, 68456], feature_list=[Feature(key='CDS', location='65885..67417'), Feature(key='CDS', location='67410..6

In [36]:
# Scan the interval of the first feature for stretches that none of the
# merged intervals cover.
uncovered_intervals = find_uncovered_intervals(
    intervaled_features_list[0].interval, merged_intervals
)

In [37]:
# Number of gaps found.
len(uncovered_intervals)

3493

In [38]:
# Display every gap with its flanking features (the whole list is shown).
uncovered_intervals

[IntervaledGap(interval=(914, 1595), features_left=IntervaledFeature(interval=[815, 914], feature_list=[Feature(key='assembly_gap', location='815..914')], feature_num=1, feature_lengths=[99]), features_right=IntervaledFeature(interval=[1595, 2173], feature_list=[Feature(key='CDS', location='1595..2173')], feature_num=1, feature_lengths=[578])),
 IntervaledGap(interval=(2173, 2215), features_left=IntervaledFeature(interval=[1595, 2173], feature_list=[Feature(key='CDS', location='1595..2173')], feature_num=1, feature_lengths=[578]), features_right=IntervaledFeature(interval=[2215, 2739], feature_list=[Feature(key='CDS', location='complement(2215..2739)')], feature_num=1, feature_lengths=[524])),
 IntervaledGap(interval=(2739, 2766), features_left=IntervaledFeature(interval=[2215, 2739], feature_list=[Feature(key='CDS', location='complement(2215..2739)')], feature_num=1, feature_lengths=[524]), features_right=IntervaledFeature(interval=[2766, 4295], feature_list=[Feature(key='CDS', locati

In [39]:
# Keep gaps whose length is between 0 and 150.
# NOTE(review): these bounds are hard-coded here rather than taken from the
# MIN_GAP_LENGTH / MAX_GAP_LENGTH config constants — confirm this is intended.
filtered_intervals = filter_intervals_by_length(uncovered_intervals, 0, 150)

In [40]:
# Number of gaps that passed the length filter.
len(filtered_intervals)

2406

In [41]:
# Spot-check one filtered gap: end minus start of entry 74. The entries are
# IntervaledGap objects, which cannot be indexed directly, so the
# coordinates are read from `.interval`.
filtered_intervals[74].interval[1] - filtered_intervals[74].interval[0]

146

In [42]:
# One "gap" feature line per filtered gap.
uncovered_intervals_str = create_uncovered_intervals_str(filtered_intervals)

In [43]:
# Display the formatted gap lines (the whole list is shown).
uncovered_intervals_str

['     gap             2173..2215',
 '     gap             2739..2766',
 '     gap             4295..4314',
 '     gap             5494..5500',
 '     gap             6078..6168',
 '     gap             7376..7393',
 '     gap             11409..11416',
 '     gap             12753..12789',
 '     gap             13052..13160',
 '     gap             16160..16263',
 '     gap             16912..16952',
 '     gap             17024..17132',
 '     gap             17962..18054',
 '     gap             20387..20428',
 '     gap             21702..21718',
 '     gap             22032..22048',
 '     gap             24287..24345',
 '     gap             26414..26558',
 '     gap             28546..28659',
 '     gap             30644..30771',
 '     gap             35814..35827',
 '     gap             36081..36106',
 '     gap             37137..37152',
 '     gap             37550..37669',
 '     gap             40805..40849',
 '     gap             41526..41571',
 '     gap             4

In [49]:
# Write the gap lines into the patched GenBank file, in front of ORIGIN.
# NOTE(review): insert_before_origin as defined in the first cell rewrites
# the file in place, so re-running this cell inserts the same lines again.
insert_before_origin(
    "Winter wheat cons/Bacillus halotolerans_AOP29/out/AOP29_modified.gbk",
    uncovered_intervals_str,
)

In [190]:
# Display the gaps again.
# NOTE(review): the saved output below shows plain (start, end) tuples,
# unlike the IntervaledGap objects shown earlier — it appears to come from
# an earlier version of the code.
uncovered_intervals

[(914, 1595),
 (2173, 2215),
 (2739, 2766),
 (4295, 4314),
 (4838, 5006),
 (5494, 5500),
 (6078, 6168),
 (7376, 7393),
 (8865, 9066),
 (9608, 9816),
 (10361, 10741),
 (11409, 11416),
 (12753, 12789),
 (13052, 13160),
 (14008, 14159),
 (15721, 15975),
 (16160, 16263),
 (16912, 16952),
 (17024, 17132),
 (17962, 18054),
 (19220, 19401),
 (20387, 20428),
 (21702, 21718),
 (22032, 22048),
 (24287, 24345),
 (26414, 26558),
 (28546, 28659),
 (30644, 30771),
 (32759, 32934),
 (34922, 35077),
 (35814, 35827),
 (36081, 36106),
 (37137, 37152),
 (37550, 37669),
 (39333, 39516),
 (40805, 40849),
 (41526, 41571),
 (41834, 41961),
 (43605, 43666),
 (44037, 44096),
 (45448, 45560),
 (46732, 46835),
 (47998, 48240),
 (48476, 48605),
 (49957, 50093),
 (50485, 50688),
 (51770, 51849),
 (52349, 52479),
 (53324, 53351),
 (53611, 53699),
 (54862, 54988),
 (56274, 56320),
 (56706, 56733),
 (57407, 57684),
 (58742, 58835),
 (60709, 60730),
 (61143, 61161),
 (61718, 61896),
 (64197, 64625),
 (65752, 65885),
 

In [88]:
# Count gaps whose end is exactly one past their start.
# The original test compared interval[1] - 1 with interval[1] itself, which
# can never be true, so the counter was always 0.
# NOTE(review): coordinates are read from `.interval` because
# find_uncovered_intervals returns IntervaledGap objects — confirm that the
# start/end comparison is the check that was intended.
wrong_covered = 0
for gap in uncovered_intervals:
    gap_start, gap_end = gap.interval
    if gap_end - 1 == gap_start:
        wrong_covered += 1

In [89]:
# Show the counter computed in the previous cell.
wrong_covered

0

In [65]:
# NOTE(review): `interval_list` is not assigned in any cell visible in this
# notebook; this cell relies on leftover kernel state.
len(interval_list)

4040

In [69]:
# Display the merged intervals.
# NOTE(review): the saved output below is a list of [start, end] pairs, not
# IntervaledFeature objects — it appears to predate the current code.
merged_intervals

[[1, 814],
 [815, 914],
 [1595, 2173],
 [2215, 2739],
 [2766, 4295],
 [4314, 4838],
 [5006, 5494],
 [5500, 6078],
 [6168, 7376],
 [7393, 8865],
 [9066, 9608],
 [9816, 10361],
 [10741, 11409],
 [11416, 12753],
 [12789, 13052],
 [13160, 14008],
 [14159, 15721],
 [15975, 16160],
 [16263, 16912],
 [16952, 17024],
 [17132, 17962],
 [18054, 19220],
 [19401, 20387],
 [20428, 21702],
 [21718, 22032],
 [22048, 24287],
 [24345, 26414],
 [26558, 28546],
 [28659, 30644],
 [30771, 32759],
 [32934, 34922],
 [35077, 35814],
 [35827, 36081],
 [36106, 37137],
 [37152, 37550],
 [37669, 39333],
 [39516, 40805],
 [40849, 41526],
 [41571, 41834],
 [41961, 43605],
 [43666, 44037],
 [44096, 45448],
 [45560, 46732],
 [46835, 47998],
 [48240, 48476],
 [48605, 49957],
 [50093, 50485],
 [50688, 51770],
 [51849, 52349],
 [52479, 53324],
 [53351, 53611],
 [53699, 54862],
 [54988, 56274],
 [56320, 56706],
 [56733, 57407],
 [57684, 58742],
 [58835, 60709],
 [60730, 61143],
 [61161, 61718],
 [61896, 64197],
 [64625, 

In [72]:
# Last entry of `interval_list` (a name not assigned in any visible cell).
interval_list[-1]

[4018183, 4018855]

In [58]:
passed_intervals_list = []  # NOTE(review): never filled in this cell
intersections = 0

for indx, interval in enumerate(interval_list):

    # Skip the first two entries. The original used a bare `next`, which is
    # a no-op expression, so index 0 was compared against
    # interval_list[-1] and index 1 against entry 0.
    if indx == 0 or indx == 1:
        continue

    start = interval[0]
    stop_prev = interval_list[indx - 1][1]

    # NOTE(review): despite its name, this counts entries that start
    # strictly AFTER the previous entry ends (i.e. no overlap). Confirm
    # whether `start <= stop_prev` was intended.
    if start > stop_prev:
        intersections += 1

In [59]:
# Show the counter computed in the previous cell.
intersections

3521

In [57]:
# Display `interval_list` in full (a name not assigned in any visible cell).
interval_list

[[1, 4018855],
 [1, 814],
 [815, 914],
 [1595, 2173],
 [2215, 2739],
 [2766, 4295],
 [4314, 4838],
 [5006, 5494],
 [5500, 6078],
 [6168, 7376],
 [7393, 8865],
 [9066, 9608],
 [9816, 10361],
 [10741, 11409],
 [11416, 12753],
 [12789, 13052],
 [13160, 14008],
 [14159, 15721],
 [15975, 16160],
 [16263, 16748],
 [16745, 16912],
 [16952, 17024],
 [17132, 17962],
 [18054, 19220],
 [19401, 20387],
 [20428, 21702],
 [21718, 22032],
 [22048, 23505],
 [23502, 24287],
 [24345, 26414],
 [26558, 28546],
 [28659, 30644],
 [30771, 32759],
 [32934, 34922],
 [35077, 35814],
 [35827, 36081],
 [36106, 37137],
 [37152, 37550],
 [37669, 39333],
 [39516, 40805],
 [40849, 41526],
 [41571, 41834],
 [41961, 42221],
 [42218, 43204],
 [43201, 43605],
 [43666, 44037],
 [44096, 45448],
 [45560, 46732],
 [46835, 47998],
 [48240, 48476],
 [48605, 49957],
 [50093, 50485],
 [50688, 51770],
 [51849, 52349],
 [52479, 53324],
 [53351, 53611],
 [53699, 54862],
 [54988, 56274],
 [56320, 56706],
 [56733, 57407],
 [57684, 58

In [48]:
# NOTE(review): `lengths_list` is not assigned in any cell visible in this
# notebook; this cell relies on leftover kernel state.
sum(lengths_list)

2049

In [45]:
# Sanity check: does feature 3 start after it ends?
# The two halves of the location are compared as integers; comparing the raw
# strings would order them character by character (e.g. "9" > "10").
start_text, stop_text = (
    feature_list[3].location.replace("complement(", "").replace(")", "").split("..")
)
int(start_text) > int(stop_text)

False

In [40]:
# List the attributes available on one parsed feature object.
dir(feature_list[3])

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'key',
 'location',
 'qualifiers']

In [46]:
# Booleans count as 0/1 when summed, so this counts the True values.
sum([False, True, True])

2

In [51]:
# str.replace returns the text unchanged when the substring is absent.
"complement".replace("(", "")

'complement'

In [2]:
def split_sequence(sequence, window_size=10e6, overlap=10e3):
    """Cut ``sequence`` into overlapping windows.

    Consecutive windows start ``window_size - overlap`` positions apart, so
    each shares ``overlap`` positions with the previous one. The last window
    is shortened to end at the end of the sequence, and no further windows
    are produced once the end has been reached.

    Args:
        sequence: Any sliceable sequence with a length (e.g. a str).
        window_size: Window length. Converted with ``int()`` so the float
            defaults can be used as slice bounds.
        overlap: Positions shared by neighbouring windows; converted with
            ``int()``. Must be smaller than ``window_size``.

    Returns:
        List of windows in order; empty for an empty sequence.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``window_size``
            (the window start would not advance).
    """
    # The defaults are floats (10e6, 10e3); slices need integers.
    window_size = int(window_size)
    overlap = int(overlap)

    step_size = window_size - overlap
    if step_size <= 0:
        raise ValueError("overlap must be smaller than window_size")

    total = len(sequence)
    subsequences = []
    for i in range(0, total, step_size):
        end_index = min(i + window_size, total)
        subsequences.append(sequence[i:end_index])
        if end_index == total:
            break
    return subsequences

In [4]:
# Example: 5-character windows with a 3-character overlap (step of 2).
split_sequence("AAAAATTTTTGGGGGGGGGGGGGGCCCCCCCC", 5, 3)

['AAAAA',
 'AAATT',
 'ATTTT',
 'TTTTG',
 'TTGGG',
 'GGGGG',
 'GGGGG',
 'GGGGG',
 'GGGGG',
 'GGGGG',
 'GGGGC',
 'GGCCC',
 'CCCCC',
 'CCCCC',
 'CCCC']

In [6]:
# Length of this sequence string.
len("UCAGCUCCUCGUGCGCGUGAUGAGAAGUAAGCGGAGGAGCGGA")

43