In [None]:
import os
import subprocess
from collections import defaultdict
import math
import pandas as pd
from Bio import SeqIO

In [None]:
def parse_reference_fasta(assembly_id):
    """Summarise the reference FASTA of one assembly.

    Reads ``<output_directory>/reference_genomes/<assembly_id>.fasta``.

    Args:
        assembly_id: Identifier used as the FASTA file stem.

    Returns:
        tuple: ``(genome_length, genome_ids)`` where ``genome_length`` is the
        summed length of all records in the file and ``genome_ids`` is the
        list of record IDs in file order.
    """
    fasta_path = os.path.join(output_directory, 'reference_genomes', f'{assembly_id}.fasta')

    record_lengths = []
    record_ids = []
    with open(fasta_path, "r") as fasta_handle:
        for seq_record in SeqIO.parse(fasta_handle, "fasta"):
            record_ids.append(seq_record.id)
            record_lengths.append(len(seq_record.seq))

    return sum(record_lengths), record_ids

In [None]:
def get_mapping_stats(assembly_id, genome_id=None):
    """Collect the summary-number (``SN``) section of ``samtools stats``.

    Runs ``samtools stats`` on
    ``<output_directory>/bam_files/<assembly_id>.sorted.bam`` and parses every
    line that starts with ``SN``.

    Args:
        assembly_id: Stem of the sorted BAM file to inspect.
        genome_id: Optional region argument passed to ``samtools stats``
            after the BAM path; when ``None`` the whole BAM is summarised.

    Returns:
        dict: Maps each summary attribute name (e.g. ``'reads mapped'``) to
        its value converted to ``float``.

    Raises:
        subprocess.CalledProcessError: If ``samtools stats`` exits non-zero.
    """
    bam_file = os.path.join(output_directory, 'bam_files', f"{assembly_id}.sorted.bam")

    command = ['samtools', 'stats', bam_file]
    if genome_id is not None:
        command.append(genome_id)

    # A single checked call replaces the samtools | grep | cut pipeline: no
    # dangling pipes or un-reaped child processes, and a samtools failure
    # surfaces immediately instead of producing empty output.
    stats_res = subprocess.run(command,
                               check=True,
                               universal_newlines=True,
                               stdout=subprocess.PIPE)

    mapping_stats = dict()
    for line in stats_res.stdout.splitlines():
        # Equivalent of `grep '^SN'`.
        if not line.startswith('SN'):
            continue
        # Equivalent of `cut -f 2-`: drop the leading "SN" column.
        _, _, summary = line.partition('\t')
        attribute, separator, value = summary.partition(':\t')
        if not separator:
            # Not an "attribute:<TAB>value" line; nothing to record.
            continue
        # Anything after the first tab of the value (a trailing comment) is ignored.
        mapping_stats[attribute] = float(value.split('\t')[0])

    return mapping_stats

In [None]:
def get_expected_coverage(genome_length, reads_mapped, genome_totol_count):
    """Estimate the breadth of coverage expected from randomly placed reads.

    The genome is treated as ``N = genome_length / mean_mapping_length`` bins,
    where ``mean_mapping_length = genome_totol_count / reads_mapped``, and the
    ``reads_mapped`` reads as ``x`` independent draws over those bins.

    Args:
        genome_length: Total reference length.
        reads_mapped: Number of mapped reads (``x``).
        genome_totol_count: Sum of per-position depths.

    Returns:
        tuple: ``(expected_coverage, std)``. ``expected_coverage`` is the
        expected number of occupied bins divided by ``N``; ``std`` is the
        square root of the variance of the occupied-bin count (it is not
        divided by ``N``).

    Raises:
        ZeroDivisionError: If ``reads_mapped`` or ``genome_totol_count`` is 0.
    """
    mean_mapping_length = genome_totol_count/reads_mapped

    N = genome_length/mean_mapping_length
    x = reads_mapped

    # Probability that one given bin is missed by all x reads.
    miss_probability = (1-1/N)**x

    expected_M = N*(1-miss_probability)
    variance = N*miss_probability + (N**2)*(1-1/N)*((1-2/N)**x)-(N**2)*((1-1/N)**(2*x))
    # The terms above nearly cancel, so rounding can leave a tiny negative
    # value where the exact variance is ~0; clamp so sqrt cannot raise.
    std = math.sqrt(max(variance, 0.0))
    expected_coverage = expected_M/N

    return expected_coverage, std

In [None]:
def calculate_depth(assembly_id, min_depth=1):
    """Compute observed and expected coverage of one assembly from its depth file.

    Reads ``<output_directory>/depth_files/<assembly_id>.depth``, a
    tab-separated file whose first column is a sequence ID and whose third
    column is an integer depth. Only rows whose sequence ID belongs to the
    assembly's reference FASTA and whose depth is at least ``min_depth``
    are counted.

    Args:
        assembly_id: Stem shared by the reference FASTA, BAM and depth files.
        min_depth: Minimum depth for a position to count as covered.

    Returns:
        tuple: ``(breadth_coverage, depth_coverage, expected_breadth_coverage)``.
        Breadth is covered positions / genome length, depth is summed depth /
        covered positions. All three are ``None`` when no depth was counted.
    """
    depth_file = os.path.join(output_directory, 'depth_files', f"{assembly_id}.depth")

    genome_length, genome_ids = parse_reference_fasta(assembly_id)
    mapping_stats = get_mapping_stats(assembly_id)
    reads_mapped = mapping_stats['reads mapped']

    # Set membership keeps the per-line check O(1) for multi-contig assemblies.
    genome_id_set = set(genome_ids)

    genome_pos_count = 0
    genome_totol_count = 0
    with open(depth_file, "r") as depth_handle:
        # Stream the file: depth files hold one line per position, so
        # readlines() would pull the whole file into memory.
        for line in depth_handle:
            fields = line.rstrip('\n').split("\t")
            if len(fields) < 3:
                # Skip blank or truncated lines instead of raising IndexError.
                continue

            genome_id = fields[0]
            position_depth = int(fields[2])

            if position_depth >= min_depth and genome_id in genome_id_set:
                genome_pos_count += 1
                genome_totol_count += position_depth

    if genome_totol_count == 0:
        breadth_coverage = None
        depth_coverage = None
        expected_breadth_coverage = None
    else:
        breadth_coverage = genome_pos_count/genome_length
        depth_coverage = genome_totol_count/genome_pos_count
        expected_breadth_coverage, std = get_expected_coverage(genome_length, reads_mapped, genome_totol_count)

    return breadth_coverage, depth_coverage, expected_breadth_coverage

In [None]:
def calculate_depth_merged(assembly_ids, min_depth=1):
    """Compute per-assembly coverage from the merged re-alignment depth file.

    Reads ``<output_directory>/depth_files/merged.depth`` (tab-separated:
    sequence ID in column 1, integer depth in column 3), accumulates covered
    positions and summed depth per sequence ID, then aggregates those per
    assembly using the sequence IDs found in each assembly's reference FASTA.
    Mapped-read counts are taken from ``samtools stats`` on the ``merged``
    BAM, queried once per sequence that has at least one covered position.

    Args:
        assembly_ids: Iterable of assembly identifiers to report on.
        min_depth: Minimum depth for a position to count as covered.

    Returns:
        tuple: ``(breadth_coverage_dict, depth_coverage_dict,
        expected_breadth_coverage_dict)``, each a ``defaultdict(float)`` keyed
        by assembly ID. Assemblies without any counted depth are absent, so a
        lookup yields ``0.0``.
    """
    depth_file = os.path.join(output_directory, 'depth_files', "merged.depth")

    genome_id_pos_count = defaultdict(int)
    genome_id_totol_count = defaultdict(int)

    with open(depth_file, "r") as depth_handle:
        # Stream line by line; the merged depth file has one row per position.
        for line in depth_handle:
            fields = line.rstrip('\n').split("\t")
            if len(fields) < 3:
                # Skip blank or truncated lines instead of raising IndexError.
                continue

            genome_id = fields[0]
            position_depth = int(fields[2])

            if position_depth >= min_depth:
                genome_id_pos_count[genome_id] += 1
                genome_id_totol_count[genome_id] += position_depth

    breadth_coverage_dict = defaultdict(float)
    depth_coverage_dict = defaultdict(float)
    reads_mapped_dict = defaultdict(int)
    expected_breadth_coverage_dict = defaultdict(float)

    for assembly_id in assembly_ids:
        genome_length, genome_ids = parse_reference_fasta(assembly_id)

        genome_pos_count = 0
        genome_totol_count = 0
        for genome_id in genome_ids:
            genome_pos_count += genome_id_pos_count[genome_id]
            genome_totol_count += genome_id_totol_count[genome_id]
            # Only query samtools for sequences that actually received coverage.
            if genome_id_pos_count[genome_id] > 0:
                mapping_stats = get_mapping_stats('merged', genome_id)
                reads_mapped_dict[assembly_id] += mapping_stats['reads mapped']

        if genome_totol_count > 0:
            breadth_coverage_dict[assembly_id] = genome_pos_count/genome_length
            depth_coverage_dict[assembly_id] = genome_totol_count/genome_pos_count
            expected_breadth_coverage, std = get_expected_coverage(genome_length, reads_mapped_dict[assembly_id], genome_totol_count)
            expected_breadth_coverage_dict[assembly_id] = expected_breadth_coverage

    return breadth_coverage_dict, depth_coverage_dict, expected_breadth_coverage_dict

In [None]:
def merge_reference_fasta(assembly_ids):
    """Concatenate the reference FASTA files of several assemblies.

    Every record of ``<output_directory>/reference_genomes/<assembly_id>.fasta``
    is collected, in the order of ``assembly_ids``, and written to
    ``<output_directory>/reference_genomes/merged.fasta``.

    Args:
        assembly_ids: Iterable of assembly identifiers to merge.
    """
    genomes_dir = os.path.join(output_directory, 'reference_genomes')

    # All inputs are read before the output is opened.
    merged_records = []
    for assembly_id in assembly_ids:
        source_path = os.path.join(genomes_dir, f'{assembly_id}.fasta')
        with open(source_path, "r") as fasta_handle:
            merged_records.extend(SeqIO.parse(fasta_handle, "fasta"))

    merged_path = os.path.join(genomes_dir, 'merged.fasta')
    with open(merged_path, "w") as merged_handle:
        SeqIO.write(merged_records, merged_handle, "fasta")

In [None]:
def cal_ani(assembly_id):
    """Fraction of called consensus bases that match the reference.

    For every record of ``<output_directory>/reference_genomes/<assembly_id>.fasta``
    that is present in the module-level ``consensus_record_dict``, each
    reference base is compared with the consensus base at the same index.
    Consensus positions equal to ``'N'`` are excluded from the denominator.
    Records missing from the consensus are printed and skipped.

    Args:
        assembly_id: Stem of the reference FASTA file.

    Returns:
        float: ``matched / compared`` bases, or ``0.0`` when no base could be
        compared (no shared record, or the consensus is all ``'N'``).

    Raises:
        IndexError: If a consensus sequence is shorter than its reference.
    """
    reference_fasta = os.path.join(output_directory, 'reference_genomes', f'{assembly_id}.fasta')

    total_count = 0
    matched_count = 0

    with open(reference_fasta, "r") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            if record.id not in consensus_record_dict:
                print(assembly_id, record.id, record.description)
                continue

            # Resolve the consensus once per record rather than once per base.
            consensus_seq = str(consensus_record_dict[record.id].seq)
            for idx, base in enumerate(str(record.seq)):
                consensus_base = consensus_seq[idx]
                if consensus_base != 'N':
                    total_count += 1
                    if consensus_base == base:
                        matched_count += 1

    # Avoid ZeroDivisionError when nothing was comparable.
    if total_count == 0:
        return 0.0

    return matched_count/total_count

In [None]:
# Root directory of the pipeline outputs. The helper functions defined above
# read this module-level name when they are called, not when they are defined.
output_directory = "/home/Users/yl181/seqscreen_nano/ZymoBIOMICS.STD.Even.ont.minimap2.breadth50"

In [None]:
# Load the reference metadata table stored directly under output_directory.
reference_metadata = pd.read_csv(os.path.join(output_directory, 'reference_metadata.csv'))
# The helper functions resolve their inputs from three sub-directories of
# output_directory: reference_genomes/, depth_files/ and bam_files/.

In [None]:
# Minimum per-position depth for a position to count as covered.
min_depth = 1

In [None]:
# Keep only the assemblies flagged as downloaded. The explicit copy makes this
# an independent frame, so the column assignments in later cells do not write
# into a slice of reference_metadata (SettingWithCopyWarning).
downloaded_assemblies = reference_metadata[reference_metadata['Downloaded']].copy()

breadth_coverage_list = []
depth_coverage_list = []
expected_breadth_coverage_list = []
coverage_score = []
for assembly_id in downloaded_assemblies['Assembly Accession ID']:
    breadth_coverage, depth_coverage, expected_breadth_coverage = calculate_depth(assembly_id, min_depth=min_depth)
    breadth_coverage_list.append(breadth_coverage)
    depth_coverage_list.append(depth_coverage)
    expected_breadth_coverage_list.append(expected_breadth_coverage)

    # calculate_depth returns None for all three values when nothing mapped;
    # score those assemblies 0 (as the re-alignment cell does) instead of
    # raising TypeError on None / None.
    if expected_breadth_coverage:
        coverage_score.append(breadth_coverage/expected_breadth_coverage)
    else:
        coverage_score.append(0)

In [None]:
# Attach the per-assembly metrics computed in the previous cell as new columns.
for column_name, column_values in (
    ('Breadth Coverage', breadth_coverage_list),
    ('Expected Coverage', expected_breadth_coverage_list),
    ('Coverage Score', coverage_score),
    ('Depth Coverage', depth_coverage_list),
):
    downloaded_assemblies[column_name] = column_values

In [None]:
# Persist the coverage table (row order unchanged, index column omitted).
coverage_csv_path = os.path.join(output_directory, 'coverage.csv')
downloaded_assemblies.to_csv(coverage_csv_path, index=False)

In [None]:
# Parse the ground-truth composition file. Each non-blank line is expected to
# look like "<taxon name> - <percentage>"; percentages are stored as fractions.
zymo_theoretical_abundance = dict()
with open('/home/Users/yl181/seqscreen_nano/ZymoBIOMICS.STD.refseq.v2/theoretical_composition.txt', 'r') as ground_truth_f:
    for line in ground_truth_f:
        fields = line.strip().split(" - ")
        if len(fields) < 2:
            # Blank (e.g. trailing) or malformed line: skip it instead of
            # raising IndexError.
            continue
        tax_name = fields[0]
        abundance = float(fields[1])/100
        zymo_theoretical_abundance[tax_name] = abundance

In [None]:
# Display the parsed taxon -> fractional abundance mapping for a visual check.
zymo_theoretical_abundance

In [None]:
# Label each detected species: "TP" when it is part of the theoretical
# composition (with its expected abundance), otherwise "FP" with abundance 0.
theoretical_abundance = []
labels = []
for taxname in downloaded_assemblies['Species']:
    is_expected = taxname in zymo_theoretical_abundance
    theoretical_abundance.append(zymo_theoretical_abundance[taxname] if is_expected else 0)
    labels.append("TP" if is_expected else "FP")

In [None]:
# Attach the ground-truth abundance and TP/FP label to every assembly row.
for column_name, column_values in (
    ('Theoretical Abundance', theoretical_abundance),
    ('Labels', labels),
):
    downloaded_assemblies[column_name] = column_values

In [None]:
# Compact view holding only the identification, coverage and label columns.
result_columns = [
    'Taxonomy ID',
    'Species',
    'Breadth Coverage',
    'Depth Coverage',
    'Theoretical Abundance',
    'Labels',
]
result_df = downloaded_assemblies[result_columns]

In [None]:
# Print every expected taxon that has no downloaded assembly.
# The species set is built once instead of on every loop iteration.
detected_species = set(downloaded_assemblies['Species'])
for taxname in zymo_theoretical_abundance:
    if taxname not in detected_species:
        print(taxname)
    

In [None]:
# Show the assemblies ranked by coverage score, hiding bookkeeping columns.
coverage_ranked = downloaded_assemblies.sort_values(['Coverage Score'], ascending=False)
coverage_ranked.drop(['Source Database', 'Is Representative', 'Downloaded'], axis=1)

In [None]:
# Overwrite coverage.csv with the rows ordered by descending coverage score.
coverage_ranked_for_export = downloaded_assemblies.sort_values(['Coverage Score'], ascending=False)
coverage_ranked_for_export.to_csv(os.path.join(output_directory, 'coverage.csv'), index=False)

## Filtering by Coverage Score

In [None]:
# Keep assemblies whose coverage score reaches the 0.7 cut-off, then point the
# helper functions at the output directory used for the re-alignment run.
passes_score_cutoff = downloaded_assemblies['Coverage Score'] >= 0.7
filtered_assemblies = list(downloaded_assemblies.loc[passes_score_cutoff, 'Assembly Accession ID'])
output_directory = "/home/Users/yl181/seqscreen_nano/ZymoBIOMICS.STD.Even.ont.minimap2.score70"

## Filtering by Breadth Coverage

In [None]:
# filtered_assemblies = list(downloaded_assemblies[downloaded_assemblies['Breadth Coverage'] >= 0.5]['Assembly Accession ID'])
# output_directory = "/home/Users/yl181/seqscreen_nano/ZymoBIOMICS.STD.Even.ont.minimap2.breadth50"

## Re-alignment

In [None]:
# Write reference_genomes/merged.fasta (under the current output_directory)
# from the reference FASTA files of the assemblies that passed the filter.
merge_reference_fasta(filtered_assemblies)

### Run the following command:

In [None]:
# Run the re-alignment script against the merged reference.
# An argument list avoids shell interpretation of output_directory, and
# check=True stops the notebook here if the script fails, instead of letting
# the following cells read stale or missing BAM/depth files.
subprocess.run(['python', 'utils/re-alignment.py',
                '-i', '../ZymoBIOMICS.STD.Even.ont.raw_sequences/ERR3152364.downsampled.fastq',
                '-o', output_directory,
                '-t', '40'],
               check=True)

In [None]:
# Recompute coverage for every downloaded assembly from the merged
# re-alignment. Uses the notebook-level min_depth setting (as the first
# coverage pass does) rather than a hard-coded literal.
breadth_coverage_dict, depth_coverage_dict, expected_breadth_coverage_dict = calculate_depth_merged(list(downloaded_assemblies['Assembly Accession ID']), min_depth=min_depth)

In [None]:
# Flatten the per-assembly dictionaries into lists aligned with the rows of
# downloaded_assemblies. The coverage score is 0 when the expected breadth is 0.
breadth_coverage_list = []
depth_coverage_list = []
expected_breadth_coverage_list = []
coverage_score = []
for assembly_id in downloaded_assemblies['Assembly Accession ID']:
    observed_breadth = breadth_coverage_dict[assembly_id]
    expected_breadth = expected_breadth_coverage_dict[assembly_id]

    breadth_coverage_list.append(observed_breadth)
    depth_coverage_list.append(depth_coverage_dict[assembly_id])
    expected_breadth_coverage_list.append(expected_breadth)
    coverage_score.append(observed_breadth/expected_breadth if expected_breadth != 0 else 0)

In [None]:
# Store the re-alignment metrics next to the first-pass ones:
# BC2 = breadth, EC2 = expected breadth, CS2 = coverage score, DC2 = depth.
for column_name, column_values in (
    ('BC2', breadth_coverage_list),
    ('EC2', expected_breadth_coverage_list),
    ('CS2', coverage_score),
    ('DC2', depth_coverage_list),
):
    downloaded_assemblies[column_name] = column_values

In [None]:
# Show the assemblies ranked by the re-alignment coverage score.
hidden_columns = ['Source Database', 'Is Representative', 'Downloaded', 'Assembly Level', 'Organism of Assembly']
realignment_ranked = downloaded_assemblies.sort_values(['CS2'], ascending=False)
realignment_ranked.drop(hidden_columns, axis=1)

In [None]:
# Save the table, ordered by the first-pass coverage score, as re-alignment.csv.
realignment_csv_path = os.path.join(output_directory, 're-alignment.csv')
realignment_export = downloaded_assemblies.sort_values(['Coverage Score'], ascending=False)
realignment_export.to_csv(realignment_csv_path, index=False)

In [None]:
# Build a consensus FASTA from the merged BAM with `samtools consensus`.
threads=20
merged_bam = os.path.join(output_directory, 'bam_files', 'merged.sorted.bam')
consensus_fasta = os.path.join(output_directory, 'merged_consensus.fasta')

consensus_command = ['samtools', 'consensus']
consensus_command += ['--show-ins', 'no']
consensus_command += ['--show-del', 'yes']
consensus_command += ['-a']
consensus_command += ['--threads', str(threads)]
consensus_command += [merged_bam]
consensus_command += ['-o', consensus_fasta]

subprocess.run(consensus_command, check=True)

In [None]:
# Index the consensus records by record ID; cal_ani reads this module-level dict.
consensus_fasta_path = os.path.join(output_directory, 'merged_consensus.fasta')
consensus_records = SeqIO.parse(consensus_fasta_path, "fasta")
consensus_record_dict = SeqIO.to_dict(consensus_records)

In [None]:
# Consensus ANI per assembly; assemblies whose re-alignment coverage score is 0
# get an ANI of 0 without calling cal_ani.
ani_list = []
for cs2_value, assembly_id in zip(downloaded_assemblies['CS2'],
                                  downloaded_assemblies['Assembly Accession ID']):
    ani_list.append(cal_ani(assembly_id) if cs2_value != 0 else 0)

In [None]:
# Attach the per-assembly ANI values (same row order as downloaded_assemblies).
downloaded_assemblies['Consensus ANI'] = ani_list

In [None]:
# Rewrite re-alignment.csv, now including the 'Consensus ANI' column,
# ordered by the first-pass coverage score.
final_csv_path = os.path.join(output_directory, 're-alignment.csv')
final_export = downloaded_assemblies.sort_values(['Coverage Score'], ascending=False)
final_export.to_csv(final_csv_path, index=False)

In [None]:
# Display the final table with first-pass, re-alignment and ANI columns.
downloaded_assemblies