In [14]:
import itertools
import sys
import random
import time
from collections import defaultdict
import heapq


In [15]:
###RADING FASTA###
def read_fasta(filename):
    """
    Read the first sequence from a FASTA file
    Returns the sequence as a string (uppercase)
    """
    with open(filename, 'r') as f:
        lines = f.readlines()

    seq_lines =[line.strip() for line in lines if not line.startswith('>')]
    genome_seq = "".join(seq_lines).upper()
    return genome_seq


In [16]:
###GENERATING READS###
def Generate_error_free_reads(genome_seq, N, l):
    """
    in this function, given genome sequence we are generating N error-free reads of length l.
    the reads are randomly sampled across the genome to achieve uniform coverage (on average)
    :param genome_seq: string, the full PhiX genome
    :param N: int, the number of reads
    :param l: int, the length of each read
    :return: list of read strings
    """
    Genome_len = len(genome_seq)
    circular_seq = genome_seq + genome_seq
    reads =[]

    for i in range(N):
        start_position = random.randint(0, Genome_len - 1)
        curr_read = circular_seq[start_position : start_position + l]
        reads.append(curr_read)
    return reads

def mismatch_base(base):
    """
    given a base return a different base
    """
    bases =['A', 'C', 'G', 'T']
    bases.remove(base)
    return random.choice(bases)

def generate_error_prone_reads(genome_seq, N, l, p):
    """
    in this function, given genome sequence we are generating N error-prone reads of length l.
    where each base is mutated with probability p.
    :param genome_seq: string, the full PhiX genome
    :param N: int, the number of reads
    :param l: int, the length of each read
    :param p: float, the probability of mutating each base
    :return: list of read strings
    """
    Genome_len = len(genome_seq)
    circular_seq = genome_seq + genome_seq
    reads = []
    for i in range(N):
        start_position = random.randint(0, Genome_len - 1)
        curr_read_list = list(circular_seq[start_position: start_position + l])

        for j in range(l):
            if random.random() < p:
                curr_read_list[j] = mismatch_base(curr_read_list[j])
        reads.append("".join(curr_read_list))
    return reads

def generate_reads(genome_seq, N, l, p=0.0):
    """
    If p=0, return error-free reads.
    Otherwise, return error-prone reads.
    """
    if p == 0.0:
        return Generate_error_free_reads(genome_seq, N, l)
    else:
        return generate_error_prone_reads(genome_seq, N, l, p)

In [17]:
###K-MER INDEXING###
def kmers_extraction(read,k):
    """
    in this function we are extracting all k-mers from a given read
    and returning a list of tuples (k-mer, starting_position)
    """
    kmers =[]
    for i in range(len(read)-k +1):
        kmer=read[i:i+k]
        kmers.append((kmer,i))
    return kmers

def kmer_index_build(reads,k):
    """
    this function creates a dictionary that maps k-mers
    to the reads that contains them, returns a dictionary
    where keys are k-mers and values are set of reads indices.
    """
    kmer_index={}
    for i, read in enumerate(reads):
        kmers = kmers_extraction(read,k)

        for kmer,_ in kmers:
            if kmer not in kmer_index:
                kmer_index[kmer] = set()
            kmer_index[kmer].add(i)
    return kmer_index



In [18]:
### one-pass EROOR CORRECTION###
def build_kmer_frequency(reads, k):
    """
    in this function we calculate how many times each k-mer appears in the entire set of
    reads. we'll return a dictionary where the keys are k-mers and the values are their total counts
    across all reads
    """
    count = defaultdict(int)
    for r in reads:
        for i in range(len(r)-k+1):
            kmer =r[i:i+k]
            count[kmer] += 1

    return count


def naive_error_correction(reads, k=5, threshold=2):
    """
    this is a single pass naive error correction:
    first we build a global k-mer frequency across all reads,the for each read
    for each k-mer region: if the frequency <threshold, single base changes is applied to find a higher
    frequency k-mer, if a beneficial substitution found apply it
    :return corrected reads
    """
    frequency_map = build_kmer_frequency(reads, k)
    corrected_reads = []

    # iterate over each read
    for org_read in reads:
        read_list = list(org_read)
        read_length = len(org_read)

        #slide a window of length k across the read
        for start_position in range(read_length - k + 1):
            curr_kmer =org_read[start_position:start_position + k]
            curr_freq = frequency_map[curr_kmer]

            #consider substitutions only if current k-mer is below threshold
            if curr_freq < threshold:
                best_improvement = curr_freq
                best_subtitution = None

                #attempt single base substitutions whithin this k-mer
                for offset in range(k):
                    old_base = read_list[start_position + offset]
                    for alternate_base in 'ACGT':
                        if alternate_base == old_base:
                            continue
                        #temporarily substitute
                        read_list[start_position + offset] = alternate_base

                        #form a new k-mer and check frequency
                        new_kmer = "".join(read_list[start_position:start_position+k])
                        new_freq = frequency_map[new_kmer]
                        if new_freq > best_improvement:
                            best_improvement = new_freq
                            best_subtitution = (start_position + offset, alternate_base)

                    # Revert the single base substitution before moving on
                    read_list[start_position + offset] = old_base

                if best_subtitution is not None:
                    index_in_read, new_base = best_subtitution
                    read_list[index_in_read] = new_base
        #re-assemble the corrected read
        corrected_reads.append("".join(read_list))

    return corrected_reads

In [19]:
###overlap###
def overlap(read1, read2, min_overlap=1):
    """
    given two reads, Return the length of the maximum overlap between
    the suffix of read1 and the prefix of read2. If no overlap of at least min_overlap return 0
    """
    maximum_len = min(len(read1), len(read2))
    for length in range(maximum_len, min_overlap -1, -1):
        if read1.endswith(read2[:length]):
            return length
    return 0

def candidate_overlaps(reads,kmer_index,k):
    """
    this function identifies pairs of reads that are likely to overlap
    by checking for shared k-mers. so instead of comparing all reads pairwise
    it quickly filters candidates.
    :param reads: list of read strings
    :param kmer_index: dictionary that maps k-mers to their index
    :param k: k-mers length
    :return: Dictionary where keys are read indices and values are
             set of candidate overlapping reads
    """
    candidates ={}

    for i,read in enumerate(reads):
        suffix =read[-k:]
        prefix =read[:k]

        for j in kmer_index.get(suffix,[]):
            if i!=j:
                if i not in candidates:
                    candidates[i]=set()
                candidates[i].add(j)

        for j in kmer_index.get(prefix,[]):
            if i!=j:
                if j not in candidates:
                    candidates[j]=set()
                candidates[j].add(i)

    return candidates

def build_overlap_edges(reads, candidates, min_ovl=5):
    """
    For each candidate pair, compute overlap length.
    Return list of edges as [(-ovl_len, i, j)] for a max-heap.
    """
    edges = []
    for i, candidate_set in candidates.items():
        for j in candidate_set:
            overlap_len = overlap(reads[i], reads[j], min_ovl)
            if overlap_len > 0:
                # store as negative so a min heap can pop the largest overlap first
                edges.append((-overlap_len, i, j))
    return edges



In [20]:
def global_greedy_assemble(reads, edges, min_ovl=5):
    """
    Use a max-heap of all edges, pick largest overlap, merge reads, update edges, repeat.
    """

    active_indices = set(range(len(reads)))

    heapq.heapify(edges)

    while edges:
        negative_ovl, i, j = heapq.heappop(edges)
        overlap_len = -negative_ovl

        if overlap_len < min_ovl:
            break

        #if either reads is no linger active,skip
        if i not in active_indices or j not in active_indices:
            continue

        # Merge the two reads
        merged_read = reads[i] + reads[j][overlap_len:]

        #remove the old reads from active status
        active_indices.remove(i)
        active_indices.remove(j)

        new_idx = len(reads)
        reads.append(merged_read)
        active_indices.add(new_idx)

        # Now compute the new overlaps between this new read and all others
        new_edges = []
        for a in list(active_indices):
            if a == new_idx:
                continue

            ov1 = overlap(reads[new_idx], reads[a], min_ovl)
            if ov1 > 0:
                new_edges.append((-ov1, new_idx, a))

            ov2 = overlap(reads[a], reads[new_idx], min_ovl)
            if ov2 > 0:
                new_edges.append((-ov2, a, new_idx))

        for edge in new_edges:
            heapq.heappush(edges, edge)

    final_contigs = [reads[idx] for idx in active_indices]
    return final_contigs



In [21]:
###circular trimming###
def trim_circular(contig, expected_length, min_search=50):
    """
    Attempt to trim any duplicated region in a circular genome.
    If 'contig' is longer than 'expected_length', we suspect
    there's a repeated wrap-around region.
    We search for the largest prefix that matches the suffix.
    min_search is a minimum overlap size to look for.
    Return a trimmed contig if duplication is found, else return original.
    """
    #only trim contigs longer then the expected_length (~5386 bp)
    if len(contig) <= expected_length:
        return contig

    excess_len = len(contig) - expected_length
    max_check = min(len(contig)//2, excess_len + 200)

    for check in range(max_check, min_search-1, -1):
        prefix_sub = contig[:check]
        suffix_sub = contig[-check:]
        if prefix_sub == suffix_sub:
            return contig[:-check]
    return contig

In [22]:
def compute_coverage(N, l, genome_len):
    """
    Compute average coverage: (N * l) / genome_length
    """
    return (N * l) / float(genome_len)


In [23]:
def basic_performance_metrics(contigs, reference):
    num_contigs = len(contigs)
    lengths = [len(c) for c in contigs]
    longest = max(lengths) if lengths else 0
    total_assembled_length = sum(lengths)
    return {
        "num_contigs": num_contigs,
        "longest_contig": longest,
        "total_assembled_length": total_assembled_length,
        "reference_length": len(reference),
    }

In [24]:
def main(fasta_file,
         N_values = [250, 500, 1000],
         l_values = [50, 100, 150],
         p_values = [0.0, 0.01],
         min_ovl =5):
    reference_genome = read_fasta(fasta_file)
    G = len(reference_genome)
    print(f"Reference genome length: {G}")

    print("\n=== K-mer Filter,Error Correction, Global Greedy Merge + Circular Trim ===")
    print(f"{'N':>6} {'l':>6} {'p':>6} {'Coverage':>8} {'#Contigs':>8} {'Longest':>8} "
      f"{'TotalAsm':>9} {'TimeBuild':>6} {'TimeAssemble':>8}")

    for (N, l, p) in itertools.product(N_values, l_values, p_values):
        coverage = compute_coverage(N, l, G)

        # 1) Generate reads
        t0 = time.time()
        reads = generate_reads(reference_genome, N, l, p)

        # 2) Optional error correction if p>0
        if p > 0.0:
            reads = naive_error_correction(reads, k=5, threshold=2)
        gen_time = time.time() - t0

        # 3) Build k-mer index & find candidate pairs
        t0 = time.time()
        k_size = max(5, l // 3)
        k_idx = kmer_index_build(reads, k_size)
        cands = candidate_overlaps(reads, k_idx, k_size)
        edges = build_overlap_edges(reads, cands, min_ovl)
        build_time = time.time() - t0

        # 4) Global Greedy assembly
        t0 = time.time()
        contigs = global_greedy_assemble(reads, edges, min_ovl)
        asm_time = time.time() - t0

        # 5) Circular trimming
        trimmed_contigs = []
        for c in contigs:
            trimmed_c = trim_circular(c, G, min_search=50)
            trimmed_contigs.append(trimmed_c)

        # 6) Performance
        metrics = basic_performance_metrics(trimmed_contigs, reference_genome)
        num_contigs = metrics["num_contigs"]
        longest_contig = metrics["longest_contig"]
        total_asm_len = metrics["total_assembled_length"]

        print(f"{N:>6} {l:>6} {p:>6} {coverage:>8.2f} {num_contigs:>8} {longest_contig:>8} "
          f"{total_asm_len:>9} {build_time:>6.2f} {asm_time:>8.2f}")


In [25]:
main("sequence.fasta")

Reference genome length: 5386

=== K-mer Filter,Error Correction, Global Greedy Merge + Circular Trim ===
     N      l      p Coverage #Contigs  Longest  TotalAsm TimeBuild TimeAssemble
   250     50    0.0     2.32       25      501      5073   0.03     2.82
   250     50   0.01     2.32      101      326      8419   0.04     1.68
   250    100    0.0     4.64        5     2871      5268   0.10     2.67
   250    100   0.01     4.64       94      580     16818   0.03     2.34
   250    150    0.0     6.96        1     5402      5402   0.03     6.95
   250    150   0.01     6.96      118      925     29196   0.11     5.24
   500     50    0.0     4.64        4     2068      5346   0.06     9.23
   500     50   0.01     4.64      124      390     12045   0.13     6.83
   500    100    0.0     9.28        1     5398      5398   1.03    13.76
   500    100   0.01     9.28      168      819     31021   0.07     9.63
   500    150    0.0    13.92        1     5386      5386   0.19    18.04