In [None]:
import os
import subprocess
from collections import defaultdict
import pickle
import pandas as pd
from Bio import SeqIO

In [None]:
def build_mappings(mapping_f="/scratch0/Advait/nucl_gb.accession2taxid",
                   sequence_ids=None,
                   output_f="dbs/taxid2seqid.pickle",
                   total_lines=304977224,
                   progress_every=10000000):
    """Build a taxid -> [sequence ID] index from a tab-separated mapping table and pickle it.

    The first line of ``mapping_f`` is treated as a header and skipped. Every
    other line must contain exactly four tab-separated fields; the second field
    is used as the sequence ID and the third as the taxid.

    Parameters
    ----------
    mapping_f : str
        Path to the accession-to-taxid mapping table.
    sequence_ids : collection of str, optional
        Only sequence IDs contained in this collection are indexed (a dict such
        as the output of ``build_record_dict`` works too). When omitted, a
        module-level ``sequence_ids`` is used for backward compatibility; if
        none exists, ``ValueError`` is raised.
    output_f : str
        Destination of the pickled ``defaultdict(list)``. Missing parent
        directories are created.
    total_lines : int
        Expected line count; only used in the progress messages.
    progress_every : int
        Print a progress line every ``progress_every`` data lines (0 disables).

    Returns
    -------
    collections.defaultdict
        Maps taxid (str) to the list of matching sequence IDs, in file order.

    Raises
    ------
    ValueError
        If no ``sequence_ids`` filter is available, or a line does not have
        exactly four tab-separated fields.
    """
    if sequence_ids is None:
        # Earlier versions silently read a notebook-global `sequence_ids`;
        # keep honouring it, but fail with a clear message instead of NameError.
        sequence_ids = globals().get("sequence_ids")
        if sequence_ids is None:
            raise ValueError("build_mappings() requires sequence_ids "
                             "(e.g. the keys of build_record_dict()).")
    # Membership is tested once per line of a very large file: make it O(1).
    if not isinstance(sequence_ids, (set, frozenset, dict)):
        sequence_ids = set(sequence_ids)

    mapping_dict = defaultdict(list)
    with open(mapping_f, "r") as mapping:
        next(mapping, None)  # skip the header line (tolerates an empty file)
        for count, line in enumerate(mapping):
            _, sequence_id, taxid, _ = line.rstrip("\r\n").split("\t")

            if sequence_id in sequence_ids:
                mapping_dict[taxid].append(sequence_id)

            if progress_every and count % progress_every == 0:
                print(f"{count}/{total_lines}")

    output_dir = os.path.dirname(output_f)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_f, "wb") as handle:
        pickle.dump(mapping_dict, handle)

    return mapping_dict

In [None]:
def build_record_dict(sequences_db_f):
    """Index the records of a FASTA file by their ID.

    Parameters
    ----------
    sequences_db_f : str
        Path to a FASTA file readable by ``Bio.SeqIO.parse(..., "fasta")``.

    Returns
    -------
    dict
        Maps ``record.id`` to its ``SeqRecord``. When an ID occurs more than
        once, only the first record is kept and a warning is printed.
    """
    record_dict = {}
    count = 0
    for count, record in enumerate(SeqIO.parse(sequences_db_f, "fasta"), start=1):
        # setdefault keeps the first record seen for a duplicated ID.
        record_dict.setdefault(record.id, record)

    n_duplicates = count - len(record_dict)
    if n_duplicates:
        print(f"WARNING: {n_duplicates} duplicated record(s) found among {count} sequences.")

    return record_dict

In [None]:
def local_taxid_fetch(taxid, output_directory, taxid2seqid_dict, record_dict):
    """Write all locally available sequences of one taxid to a FASTA file.

    Parameters
    ----------
    taxid : int or str
        Taxonomy ID; looked up in ``taxid2seqid_dict`` as ``str(taxid)``.
    output_directory : str
        Base output folder; ``<output_directory>/reference_genomes`` is created
        (including missing parents) if needed.
    taxid2seqid_dict : Mapping[str, list[str]]
        taxid -> sequence IDs. It is not modified.
    record_dict : Mapping[str, SeqRecord]
        sequence ID -> record.

    Returns
    -------
    tuple
        ``(taxid, reference_name, "N/A", "N/A", "N/A", "N/A", downloaded)``,
        matching the metadata columns built by the caller. ``reference_name`` is
        ``"taxid_<taxid>"`` when a FASTA was written, otherwise ``"N/A"``.

    Raises
    ------
    KeyError
        If a mapped sequence ID is missing from ``record_dict``.
    """
    reference_genome_path = os.path.join(output_directory, 'reference_genomes')
    # makedirs also creates a missing output_directory and is safe on reruns.
    os.makedirs(reference_genome_path, exist_ok=True)

    # .get() neither raises for an unmapped taxid on a plain dict nor inserts
    # an empty entry into a defaultdict loaded from the mapping pickle.
    sequence_ids = taxid2seqid_dict.get(str(taxid), [])
    records = [record_dict[sequence_id] for sequence_id in sequence_ids]

    if not records:
        return taxid, "N/A", "N/A", "N/A", "N/A", "N/A", False

    reference_name = f"taxid_{taxid}"
    with open(os.path.join(reference_genome_path, f"{reference_name}.fasta"), "w") as output_handle:
        SeqIO.write(records, output_handle, "fasta")
    return taxid, reference_name, "N/A", "N/A", "N/A", "N/A", True

In [None]:
def prepare_reference_genomes_offline(taxid_queries, output_directory, sequences_db_f, mapping_f, ncbi_taxa_db):
    """Build per-taxid reference FASTA files from local databases and a metadata table.

    For every taxid in ``taxid_queries`` the mapped sequences are written to
    ``<output_directory>/reference_genomes/taxid_<taxid>.fasta`` (when any are
    mapped) by ``local_taxid_fetch``. The summary table is also saved to
    ``<output_directory>/reference_metadata.csv``.

    Parameters
    ----------
    taxid_queries : iterable of int
        Taxonomy IDs to process.
    output_directory : str
        Output folder; created if missing.
    sequences_db_f : str
        FASTA file with all candidate sequences.
    mapping_f : str
        Pickled taxid -> [sequence ID] mapping (as written by ``build_mappings``).
    ncbi_taxa_db : object
        Provides ``get_taxid_translator(list_of_taxids) -> {taxid: name}``
        (e.g. ``ete3.NCBITaxa``).

    Returns
    -------
    pandas.DataFrame
        One row per queried taxid with download status and a 'Species' column;
        taxids the translator does not know get 'N/A'.
    """
    os.makedirs(output_directory, exist_ok=True)

    # SECURITY: pickle.load can execute arbitrary code; only load mapping
    # files you generated yourself.
    with open(mapping_f, 'rb') as handle:
        taxid2seqid_dict = pickle.load(handle)

    record_dict = build_record_dict(sequences_db_f)

    download_result = [
        local_taxid_fetch(taxid, output_directory, taxid2seqid_dict, record_dict)
        for taxid in taxid_queries
    ]

    reference_metadata = pd.DataFrame(download_result,
                                      columns=['Taxonomy ID',
                                               'Assembly Accession ID',
                                               'Source Database',
                                               'Is Representative',
                                               'Assembly Level',
                                               'Organism of Assembly',
                                               'Downloaded'])

    taxids = [int(taxid) for taxid in reference_metadata['Taxonomy ID']]
    # One batched lookup instead of a database query per taxid; unknown
    # taxids become "N/A" instead of aborting the run with KeyError.
    taxid2name = ncbi_taxa_db.get_taxid_translator(taxids) if taxids else {}
    reference_metadata['Species'] = [taxid2name.get(taxid, "N/A") for taxid in taxids]
    reference_metadata.to_csv(os.path.join(output_directory, 'reference_metadata.csv'), index=False)

    return reference_metadata

## Testing: build offline reference genomes for a sample set of taxids

In [None]:
# Input databases. The defaults are machine-specific absolute paths; set
# SEQSCREEN_SEQUENCES_DB / SEQSCREEN_TAXID2SEQID to point elsewhere without
# editing the notebook.
sequences_db_f = os.environ.get(
    "SEQSCREEN_SEQUENCES_DB",
    "/home/dbs/SeqScreenDB_21.4/bowtie2/blacklist.seqs.nt.fna",
)
# Pickled taxid -> [sequence ID] mapping, presumably produced by build_mappings().
mapping_f = os.environ.get(
    "SEQSCREEN_TAXID2SEQID",
    "/home/Users/yl181/seqscreen_nano/reference_finder/dbs/taxid2seqid.pickle",
)

In [None]:
from ete3 import NCBITaxa

# Local ete3 taxonomy database; override the machine-specific default with ETE3_TAXA_DB.
ete3db = os.environ.get(
    "ETE3_TAXA_DB",
    "/home/Users/yl181/seqscreen_nano/ete3_ncbi_taxonomy_db/taxa.sqlite",
)
# Fail fast: when dbfile is missing, NCBITaxa may try to download/build a fresh
# taxonomy database, which is not wanted in this offline workflow.
if not os.path.exists(ete3db):
    raise FileNotFoundError(f"ete3 taxonomy database not found: {ete3db}")
ncbi_taxa_db = NCBITaxa(dbfile=ete3db)

In [None]:
# Relative to the kernel's working directory. Create it up front so the
# 'reference_genomes' subfolder and the metadata CSV can be written into it.
output_directory = "test"
os.makedirs(output_directory, exist_ok=True)

In [None]:
# NCBI taxonomy IDs to build reference genomes for; their names are looked up
# through ncbi_taxa_db when the metadata table is assembled.
taxid_queries = [
    1280, 1613, 1351, 1423, 287, 5207, 1639,
    1642, 562, 1638, 28901, 1392, 623, 573,
    4932, 96241, 1590, 1352, 294, 1643, 176275,
]

In [None]:
# Run the offline pipeline on the sample taxids; the returned metadata table
# is kept under a name and shown as the cell output.
reference_metadata = prepare_reference_genomes_offline(
    taxid_queries=taxid_queries,
    output_directory=output_directory,
    sequences_db_f=sequences_db_f,
    mapping_f=mapping_f,
    ncbi_taxa_db=ncbi_taxa_db,
)
reference_metadata