In [None]:
import os
import subprocess
import sys

from collections import defaultdict
from Bio import Entrez
import time

import pandas as pd
import json
from ete3 import NCBITaxa

In [None]:
sys.path.insert(0, '../utils')
from reference_finder import download_reference_genome, unpack, cat_reference_genome
from alignment import run_minimap2, sort_samfile, calculate_depth

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
def get_species_taxid(taxid, ncbi_taxa_db, valid_kingdom=frozenset({2, 4751, 2157, 10239})):
    """Resolve a taxid to the species-rank taxid in its lineage.

    Only lineages that contain at least one of ``valid_kingdom`` are accepted
    (default: 2 = Bacteria, 4751 = Fungi, 2157 = Archaea, 10239 = Viruses).

    Parameters
    ----------
    taxid : int
        NCBI taxonomy ID to resolve.
    ncbi_taxa_db : ete3.NCBITaxa
        Taxonomy database providing ``get_lineage`` and ``get_rank``.
    valid_kingdom : iterable of int, optional
        Top-level taxids, at least one of which must appear in the lineage.
        Any iterable (set, frozenset, list, tuple) is accepted.

    Returns
    -------
    int or None
        The taxid in the lineage whose rank is ``'species'``, or ``None`` when
        the lineage is empty/missing, lies outside ``valid_kingdom``, or has no
        species-rank node (e.g. a genus-level assignment).

    Raises
    ------
    ValueError
        Propagated from ``ncbi_taxa_db.get_lineage`` — presumably for taxids
        unknown to the local database (confirm for the installed ETE3 version).
    """
    lineage = ncbi_taxa_db.get_lineage(taxid)
    # get_lineage may return None (ETE3 does so for a falsy taxid such as 0);
    # without this guard set(None) below would raise TypeError.
    if not lineage:
        return None
    # isdisjoint accepts any iterable, so a list of kingdom taxids works too.
    if set(lineage).isdisjoint(valid_kingdom):
        return None
    taxid2rank_dict = ncbi_taxa_db.get_rank(lineage)
    # Walk the lineage in its own order rather than the rank dict's order so
    # the result does not depend on how get_rank builds its mapping.
    for lineage_taxid in lineage:
        if taxid2rank_dict.get(lineage_taxid) == 'species':
            return lineage_taxid
    return None

## Load ETE3 NCBITaxa

In [None]:
# Handle to the NCBI taxonomy database used for every lineage/rank/name lookup
# in this notebook.
# NOTE(review): with no arguments ETE3 uses its default database location and
# may download the NCBI taxonomy dump on first use — confirm.
ncbi_taxa_db = NCBITaxa()

In [None]:
# ncbi_taxa_db.update_taxonomy_database()  # uncomment to refresh the local NCBI taxonomy database

In [None]:
# Top-level NCBI taxids accepted as target kingdoms.
valid_kingdom = [
    2,      # Bacteria
    4751,   # Fungi
    2157,   # Archaea
    10239,  # Viruses
]

## Filter SeqScreen Taxonomic Assignments

In [None]:
# SeqScreen output directory for this sample. Set SEQSCREEN_OUTPUT_DIR to point
# at another run or machine without editing the notebook; the default is the
# original analysis location.
seqscreen_output = os.environ.get(
    "SEQSCREEN_OUTPUT_DIR",
    "/home/Users/yl181/seqscreen_nano/output_datasets/ZymoBIOMICS.STD.Even.ont.seqscreen",
)

In [None]:
# Tab-separated SeqScreen taxonomic assignment table; its `taxid` column drives
# the species counting below.
classification_result_df = pd.read_csv(
    os.path.join(
        seqscreen_output,
        'taxonomic_identification',
        'taxonomic_assignment',
        'taxonomic_results.txt',
    ),
    sep='\t',
)

In [None]:
# Each row of the SeqScreen table is counted as one read.
# len() is used instead of unpacking `.shape` into `_`: assigning `_` at the
# notebook's top level shadows IPython's last-output variable and stops
# `_`/`__`/`___` from updating.
total_read_count = len(classification_result_df)

In [None]:
# Count reads per species-rank taxid.
# - taxid_species_lookup caches each distinct raw taxid -> species taxid (or None).
# - unknown_taxids caches taxids whose lookup raised ValueError, so a bad taxid
#   repeated across many reads is queried only once.
# - error_count counts rows with a non-numeric taxid or a failed lookup.
taxid_count_dict = defaultdict(int)
taxid_species_lookup = dict()
unknown_taxids = set()
error_count = 0
for raw_taxid in classification_result_df['taxid']:
    try:
        taxid = int(raw_taxid)
    except (TypeError, ValueError):
        # e.g. NaN or a non-numeric label in the taxid column
        error_count += 1
        continue

    if taxid in unknown_taxids:
        error_count += 1
        continue

    if taxid not in taxid_species_lookup:
        try:
            # Pass the configured kingdoms explicitly (as a set) so editing
            # `valid_kingdom` in the config cell actually takes effect.
            taxid_species_lookup[taxid] = get_species_taxid(
                taxid, ncbi_taxa_db, set(valid_kingdom)
            )
        except ValueError:
            unknown_taxids.add(taxid)
            error_count += 1
            continue

    species_taxid = taxid_species_lookup[taxid]
    if species_taxid is not None:
        taxid_count_dict[species_taxid] += 1

In [None]:
# Keep species whose read count reaches at least `min_frac` of all reads
# (0.002 = 0.2%). The order follows insertion order of `taxid_count_dict`.
min_frac = 0.002

taxid_queries = [
    species_taxid
    for species_taxid, read_count in taxid_count_dict.items()
    if read_count >= min_frac * total_read_count
]

In [None]:
# Minimum number of reads a species needs to pass the `min_frac` filter.
total_read_count * min_frac

In [None]:
# Number of species that passed the abundance filter; one reference genome is
# requested for each of them below.
len(taxid_queries)

## Fetch Reference Genomes

In [None]:
# Root directory for minimap2-related outputs (NCBI downloads, reference
# genomes, metadata CSV). Set MINIMAP2_OUTPUT_DIR to redirect it without
# editing the notebook; the default is the original analysis location.
output_directory = os.environ.get(
    "MINIMAP2_OUTPUT_DIR",
    "/home/Users/yl181/seqscreen_nano/ZymoBIOMICS.STD.Even.ont.minimap2",
)

In [None]:
# Staging directory for raw NCBI downloads, nested under the output root.
working_dir = os.path.join(output_directory, "ncbi_downloads")

In [None]:
# Create the download staging directory, including any missing parents
# (os.mkdir would fail if `output_directory` did not exist yet). exist_ok
# keeps re-runs idempotent without a racy exists-then-create check.
os.makedirs(working_dir, exist_ok=True)

In [None]:
# Request one reference assembly per retained species into `working_dir`,
# keeping the metadata record returned for each (same order as taxid_queries).
download_result = list()
for taxid in taxid_queries:
    download_result.append(
        download_reference_genome(taxid, working_dir)
    )

In [None]:
# Process the files fetched into `working_dir`, writing results under
# `output_directory`.
# NOTE(review): presumably extracts the downloaded archives — exact behavior
# lives in ../utils/reference_finder.unpack; confirm.
unpack(working_dir, output_directory)

In [None]:
# One row per requested species. The column order must match the fields of each
# record returned by download_reference_genome (assumed 7 fields — confirm in
# reference_finder).
reference_metadata = pd.DataFrame(
    download_result,
    columns=[
        'Taxonomy ID',
        'Assembly Accession ID',
        'Source Database',
        'Is Representative',
        'Assembly Level',
        'Organism of Assembly',
        'Downloaded',
    ],
)

In [None]:
# Look up the NCBI name of each species taxid via ETE3 (one query per taxid;
# a missing taxid raises KeyError).
taxonomy_name = [
    ncbi_taxa_db.get_taxid_translator([species_taxid])[species_taxid]
    for species_taxid in reference_metadata['Taxonomy ID']
]

reference_metadata['Species'] = taxonomy_name

In [None]:
# Reorder columns so 'Species' sits right after 'Taxonomy ID'.
reference_metadata = reference_metadata[
    [
        'Taxonomy ID',
        'Species',
        'Assembly Accession ID',
        'Source Database',
        'Is Representative',
        'Assembly Level',
        'Organism of Assembly',
        'Downloaded',
    ]
]

In [None]:
# Persist the metadata table next to the other outputs (without the index),
# then display it.
reference_metadata.to_csv(
    os.path.join(output_directory, "reference_metadata.csv"),
    index=False,
)
reference_metadata

In [None]:
# Assemble the reference genome files for the listed assemblies under
# <output_directory>/reference_genomes.
# NOTE(review): presumably one FASTA per assembly — see ../utils/reference_finder.
cat_reference_genome(
    reference_metadata,
    output_directory,
    reference_genome_path=os.path.join(output_directory, "reference_genomes"),
)

In [None]:
# reference_genome_path = os.path.join(output_directory, 'reference_genomes')

In [None]:
# input_fastq = '/home/Users/yl181/seqscreen_nano/ZymoBIOMICS.STD.Even.ont.raw_sequences/ERR3152364.downsampled.fastq'

In [None]:
# downloaded_assemblies = reference_metadata[reference_metadata['Downloaded']]

# num_cores = 20

# for assembly_id in downloaded_assemblies['Assembly Accession ID']:
#     reference_fasta = os.path.join(reference_genome_path, f'{assembly_id}.fasta')
#     run_minimap2(input_fastq, reference_fasta, assembly_id, output_directory, threads=num_cores)
#     sort_samfile(assembly_id, output_directory, num_cores)
#     calculate_depth(assembly_id, output_directory)