In [None]:
import pickle
from spcount.Species import Species
from spcount.Query import Query
import sys
import os

In [None]:


# Load the pickled intermediate result that sits next to the output prefix.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# open pickles you produced yourself. The absolute path below is specific to
# one machine; adjust it before running elsewhere.
output_prefix = '/scratch/vickers_lab/projects/20220417_bacteria_genome/nonhost_genome/refseq_bacteria_table/result/RA_4893_2'
pickle_path = output_prefix + '.pickle'
with open(pickle_path, 'rb') as fin:
  myobj = pickle.load(fin)

In [None]:
# Show the taxonomy record stored for the first key of the map.
species_taxonomy_map = myobj["species_taxonomy_map"]
keys = list(species_taxonomy_map.keys())
species_taxonomy_map[keys[0]]


In [None]:
# Show the sample -> query count mapping of the first entry in the species list.
species_list = myobj["species_list"]
species_list[0].sample_query_count

In [None]:
# Merge species whose query sets are exactly equal. The first species of each
# group stays the primary and collects the others in identical_species; the
# others are only flagged via is_identical, they are not removed from
# species_list here.
# NOTE(review): the early `break` below assumes species_list is ordered by
# query_count so that equal counts are adjacent -- confirm the ordering of the
# pickled list.

# Reset the bookkeeping first so the cell gives the same result when re-run.
for entry in species_list:
  entry.is_identical = False
  entry.identical_species = []

species_total = len(species_list)
for idx in range(species_total - 1):
  primary = species_list[idx]
  if primary.is_identical:
    continue
  if idx % 100 == 0:
    print(f"checking identical: {idx + 1} / {species_total}")
  for other_idx in range(idx + 1, species_total):
    other = species_list[other_idx]
    if other.is_identical:
      continue
    # A different query count rules out equality for this primary.
    if primary.query_count != other.query_count:
      break
    if primary.queries_set == other.queries_set:
      other.is_identical = True
      primary.identical_species.append(other)

old_len = species_total
new_len = sum(1 for entry in species_list if not entry.is_identical)
print(f"{old_len - new_len} identical species were found")


In [None]:
# Flag every species that an earlier species in the list "contains", strip
# the flagged species from the queries that reference them, and drop them
# from species_list.
#
# Species already merged as an identical duplicate (is_identical) are not
# compared on their own; they follow whatever happens to their primary.
#
# NOTE(review): Species.contains, Query.remove_species and
# Query.estimate_count live in the spcount package, not in this notebook.
# The code below assumes contains() is a subset test on the query sets and
# that remove_species() expects species names -- confirm against the package.
# With the skip fixed, duplicates of a *surviving* primary now stay in
# species_list and in their queries (the output cells skip them through
# is_identical); confirm that the estimated counts handle such duplicates as
# intended.
species_total = len(species_list)
for i1 in range(0, species_total - 1):
  container = species_list[i1]
  if container.is_subset or container.is_identical:
    continue
  if i1 % 100 == 0:
    print(f"checking subset: {i1+1} / {species_total}")
  for i2 in range(i1+1, species_total):
    candidate = species_list[i2]
    # Fix: the original tested species_list[i1].is_identical here, which is
    # always False at this point (checked above), so identical duplicates
    # were never skipped.
    if candidate.is_subset or candidate.is_identical:
      continue
    if container.contains(candidate):
      candidate.is_subset = True
      # Duplicates have the same query set as the candidate, so they are
      # subsets too and must leave species_list together with it.
      for duplicate in candidate.identical_species:
        duplicate.is_subset = True

# Remove subset species from the queries that reference them.
for species in species_list:
  # Duplicates are covered by their primary's name list, which keeps every
  # name from being removed more than once.
  if not species.is_subset or species.is_identical:
    continue
  # Fix: identical_species holds Species objects, so pass their names rather
  # than mixing a name string with objects.
  all_names = [species.name] + [dup.name for dup in species.identical_species]
  for qlist in species.queries.values():
    for q in qlist:
      q.remove_species(all_names)

old_len = len(species_list)
species_list = [sv for sv in species_list if not sv.is_subset]
new_len = len(species_list)
print(f"{old_len - new_len} subset were removed")


In [None]:
# Sample identifiers; the table-writing cells join them into the header row.
samples = myobj['samples']
samples

In [None]:
# Write one tab-separated row per primary species: the species name (with the
# names of its merged duplicates appended, comma-separated) followed by its
# query count in every sample; a sample without an entry is written as "0".
with open(output_prefix + ".species.query.count", "wt") as fout:
  fout.write("Feature\t" + "\t".join(samples) + "\n")
  for entry in species_list:
    # Merged duplicates are reported on their primary's row.
    if entry.is_identical:
      continue
    row_names = [entry.name]
    row_names.extend(dup.name for dup in entry.identical_species)
    counts = entry.sample_query_count
    cells = [str(counts[sample]) if sample in counts else "0" for sample in samples]
    fout.write(",".join(row_names) + "\t" + "\t".join(cells) + "\n")


In [None]:
# Sanity check: every query attached to a remaining species must also be
# present in the pickled query collection.
query_list = set(myobj['query_list'])
for entry in species_list:
  for sample_queries in entry.queries.values():
    assert query_list.issuperset(sample_queries)


In [None]:
# Run estimate_count() on every query (the method is defined in spcount.Query).
for q in query_list:
  q.estimate_count()


In [None]:
# Run sum_estimated_count() on every remaining species (the method is defined
# in spcount.Species).
for entry in species_list:
  entry.sum_estimated_count()


In [None]:
# Write one tab-separated row per primary species: the species name (with the
# names of its merged duplicates appended, comma-separated) followed by its
# estimated count in every sample, formatted with two decimals; a sample
# without an entry is written as "0".
with open(output_prefix + ".species.estimated.count", "wt") as fout:
  fout.write("Feature\t" + "\t".join(samples) + "\n")
  for entry in species_list:
    # Merged duplicates are reported on their primary's row.
    if entry.is_identical:
      continue
    row_names = [entry.name]
    row_names.extend(dup.name for dup in entry.identical_species)
    estimates = entry.sample_estimated_count
    cells = [f"{estimates[sample]:.2f}" if sample in estimates else "0" for sample in samples]
    fout.write(",".join(row_names) + "\t" + "\t".join(cells) + "\n")


In [None]:
# Roll the species up to each higher taxonomy level and write two tables per
# level (query counts and estimated counts), one row per category, ordered by
# total query count from largest to smallest.
levels = ['genus', 'family', 'order', 'class', 'phylum']
for level in levels:
  # Group the species by their category name at this level. A Species object
  # serves as the category container and keeps its members in
  # identical_species.
  cat_map = {}
  for member in species_list:
    cat_name = species_taxonomy_map[member.name][level]
    bucket = cat_map.get(cat_name)
    if bucket is None:
      bucket = Species(cat_name)
      cat_map[cat_name] = bucket
    bucket.identical_species.append(member)

  cats = list(cat_map.values())
  for cat in cats:
    members = cat.identical_species

    # Collect the queries of all members in a set per sample, so a query
    # shared by several members is counted once.
    cat.sample_query_count = {}
    for sample in samples:
      sample_queries = set()
      for member in members:
        if sample in member.queries:
          sample_queries.update(member.queries[sample])
      cat.sample_query_count[sample] = sum(q.count for q in sample_queries)
    cat.query_count = sum(cat.sample_query_count.values())

    # Estimated counts are added up over the members.
    cat.estimated_count = sum(member.estimated_count for member in members)
    cat.sample_estimated_count = {
      sample: sum(
        member.sample_estimated_count[sample] if sample in member.sample_estimated_count else 0
        for member in members
      )
      for sample in samples
    }

  cats.sort(key=lambda cat: cat.query_count, reverse=True)

  # Both tables share one layout; only the file suffix, the source mapping and
  # the number formatting differ.
  tables = (
    (".query.count", "sample_query_count", str),
    (".estimated.count", "sample_estimated_count", "{:.2f}".format),
  )
  for suffix, attr_name, to_text in tables:
    with open(output_prefix + "." + level + suffix, "wt") as fout:
      fout.write("Feature\t" + "\t".join(samples) + "\n")
      for cat in cats:
        values = getattr(cat, attr_name)
        cells = [to_text(values[sample]) if sample in values else "0" for sample in samples]
        fout.write(cat.name + "\t" + "\t".join(cells) + "\n")

