# Code that generates the figures from Denis et al. 2025

Not all figures were generated with the code below; some were generated using R scripts and Snakemake pipelines.

## Figure 3

### Figure 3A

In [None]:
import pandas as pd
import polars as pl
from pathlib import Path
import pytaxonkit
from Bio import SeqIO
from collections import defaultdict
import polars.selectors as cs
import seaborn.objects as so
from seaborn import axes_style
from seaborn import plotting_context

# Read every depth/breadth table of the BWA mapping and stack them into one frame.
# The sample and the taxonomic level are taken from the 1st and 3rd dot-separated
# fields of the file name (assumes names of the form <sample>.<x>.<level>...).
folder_depth = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/depth_breadth/")

sorted_files = sorted(folder_depth.glob("*depth_breadth.tsv"))

depth_tables = []
for file in sorted_files:
    table = pl.read_csv(file, separator="\t")
    name_fields = file.stem.split('.')
    depth_tables.append(
        table.with_columns(
            pl.col('Breadth').cast(pl.Float32),
            pl.col('Depth avg').cast(pl.Float32),
            pl.col('Depth median').cast(pl.Float32),
            sample=pl.lit(name_fields[0]),
            level=pl.lit(name_fields[2]),
        )
    )

depth_df = (
    pl.concat(depth_tables)
    .rename({"Contig": "seqid"})
    # rows of the 'unclassified' level are not used in this figure
    .filter(~pl.col('level').eq('unclassified'))
    # key identifying one (reference genome, sample) pair
    .with_columns(seqid_sample=pl.col('seqid') + pl.lit('_') + pl.col('sample'))
)

# First threshold: (genome, sample) pairs whose 'total' breadth is above 0.5
sample_genome_gt50 = (
    depth_df
    .filter(pl.col('level').eq('total'), pl.col('Breadth').gt(0.5))
    .get_column('seqid_sample')
    .to_list()
)

depth_df = depth_df.filter(pl.col('seqid_sample').is_in(sample_genome_gt50))

# Identifiers of the 20 samples of the study (prefix ASM/BSM/HSM/ZSM + 3 digits).
# The same list is redefined further down, where it is used to guarantee one column per sample.
required_columns = ['ASM001', 'ASM002', 'ASM003', 'BSM001', 'BSM002', 
                    'HSM001', 'HSM002', 'HSM003', 'HSM004', 'ZSM005', 
                    'ZSM025', 'ZSM027', 'ZSM028', 'ZSM031', 'ZSM101', 
                    'ZSM102', 'ZSM103', 'ZSM214', 'ZSM216', 'ZSM219']
                    
# NOTE(review): the frame built below is neither assigned nor the last statement of the
# cell, so it is never displayed and has no effect. It looks like a leftover manual check
# of reference LR597635 in sample HSM001 (the `is_in(required_columns)` filter is redundant
# with the `eq('HSM001')` one) — confirm and consider removing.
depth_df.sort(['sample', 'Breadth'], descending=[True, False]).filter(pl.col('seqid').eq('LR597635')).filter(
    pl.col('sample').is_in(required_columns),
    # pl.col('level').is_in(['species', 'total']),
    pl.col('sample').eq('HSM001')
)

# ICTV metadata table, with the accession and host columns renamed to short names.
metadata = pl.read_csv(
    '/mnt/archgen/microbiome_coprolite/aVirus/03-data/refdbs/ICTV/ICTV_database/20240209/ICTV_metadata.tsv',
    separator="\t"
).rename(
    {
        'Virus GENBANK accession': 'seqid',
        'Host source': 'host',
    }
)

# Accession -> taxid lookup read from the centrifuge database map
# (each line must hold exactly two tab-separated fields).
seq2taxid = {}
with open('/mnt/archgen/microbiome_coprolite/aVirus/03-data/refdbs/centrifuge_db/ICTV_decoy/seqid2taxid.map') as map_handle:
    for record in map_handle:
        accession, tax_identifier = record.strip().split('\t')
        seq2taxid[accession] = tax_identifier

# Accessions that are absent from the map receive the placeholder taxid '0'.
taxid_lookup = pl.col('seqid').replace_strict(seq2taxid, default='0')

metadata = metadata.with_columns(taxid=taxid_lookup).rename({'taxid': 'taxonomy_id'})

# Second threshold: (genome, sample) pairs whose species-level breadth is above 0.05
sample_genome_species_gt5 = depth_df.filter(
    pl.col('level').eq('species'),
    pl.col('Breadth').gt(0.05)
)['seqid_sample'].to_list()

# Keep the 'total' rows of the pairs that also pass the species-level threshold
depth_df_total = depth_df.filter(
    pl.col('level').eq('total'),
    pl.col('seqid_sample').is_in(sample_genome_species_gt5)
)

# Annotate each reference genome with its host and taxonomy
depth_df_total = depth_df_total.join(metadata[['seqid', 'host', 'Species', 'Genus', 'taxonomy_id']].unique('seqid'), on='seqid', how='left')

# References that are not described in the metadata are dropped
depth_df_total = depth_df_total.filter(~pl.col('taxonomy_id').is_null())

# For each genus/sample pair, keep only the seqid with the highest breadth value.
# As the mapping was run with k=50, reads can map to more than one representative of the
# same genus, so only the best-covered species of the genus is kept.
# `unique()` defaults to keep='any' and maintain_order=False, which does not guarantee that
# the first row of the sorted frame survives; keep='first' with maintain_order=True makes
# the "highest breadth" selection explicit and deterministic.
depth_df_total = (
    depth_df_total
    .sort(['Genus', 'sample', 'Breadth'], descending=True)
    .unique(['Genus', 'sample'], keep='first', maintain_order=True)
    .sort(['Species'])
)

# One row per taxonomy_id, one column per sample, holding the breadth (0 when absent)
depth_df_total = depth_df_total.pivot(
    index=['taxonomy_id'],
    on='sample',
    values='Breadth',
    aggregate_function='max',
).fill_null(0)

# Bring the taxonomy/host annotation back after the pivot
depth_df_total = depth_df_total.join(metadata[['seqid', 'host', 'Species', 'Genus', 'taxonomy_id']].unique('taxonomy_id'), on='taxonomy_id', how='left')

depth_df_total = depth_df_total.drop_nulls(subset=['host'])

# Every sample of the study must end up as a column, even when nothing mapped to it
required_columns = [
    'ASM001', 'ASM002', 'ASM003',
    'BSM001', 'BSM002',
    'HSM001', 'HSM002', 'HSM003', 'HSM004',
    'ZSM005', 'ZSM025', 'ZSM027', 'ZSM028', 'ZSM031',
    'ZSM101', 'ZSM102', 'ZSM103', 'ZSM214', 'ZSM216', 'ZSM219',
]

# Samples missing from the pivoted table are added as all-zero columns
for sample_id in required_columns:
    if sample_id in depth_df_total.columns:
        continue
    depth_df_total = depth_df_total.with_columns(pl.lit(0).alias(sample_id))

# Fixed column order: annotation first, then the samples
annotation_columns = ['seqid', 'Species', 'Genus', 'taxonomy_id', 'host']
depth_df_total = depth_df_total.select(annotation_columns + required_columns)

# Presence/absence matrix: 1 when the breadth in the sample is above 0.5, else 0
sample_breadth = pl.col("^[ABHZ]SM[0-9][0-9][0-9]$")
depth_bool = depth_df_total.with_columns(sample_breadth.gt(0.5) * 1).sort('Species')



# Build the table behind the stacked bar plot: for each species, the number of samples
# per site in which it is present.
df_presence = depth_bool.rename(
    {
        'Species': 'name',
    }
)

# Keep viruses whose host is annotated as bacteria or archaea and go from wide
# (one column per sample) to long format (one row per species/sample).
df_presence_prokaryote = df_presence.filter(pl.col('host').is_in(['bacteria', 'archaea'])).select(['name','^[ABHZ]SM[0-9][0-9][0-9]$' ]).unpivot(
    on=cs.numeric(),
    index='name',
    value_name='presence',
    variable_name='site',
)

# The first three characters of the sample name (ASM/BSM/HSM/ZSM) are used as site code
df_presence_prokaryote = df_presence_prokaryote.with_columns(
    pl.col('site').str.head(3)
)

# Total number of samples each species is present in (species never present are dropped)
df_presence_prokaryote_species = df_presence_prokaryote.group_by(['name']).agg(pl.col('presence').sum()).filter(pl.col('presence') > 0).sort('presence', descending=True)

df_presence_prokaryote_species = df_presence_prokaryote_species.rename(
    {
        'presence': 'total',
    }
)

# Number of samples per species and per site
df_presence_prokaryote = df_presence_prokaryote.group_by(['name', 'site']).agg(pl.col('presence').sum()).filter(pl.col('presence') > 0).sort('presence', descending=True)

# Restrict to the 50 species with the highest total and attach their per-site counts
df_presence_prokaryote = df_presence_prokaryote_species.top_k(by='total', k=50).join(
    df_presence_prokaryote, left_on='name', right_on='name', how='left'
)

# Human-readable site names and axis labels for the plot
df_presence_prokaryote = df_presence_prokaryote.with_columns(
    pl.col('site').replace_strict(
        {
            'ASM': 'Arid West Cave',
            'BSM': 'Boomerang Shelter',
            'HSM': 'Hallstatt',
            'ZSM': 'Zape',
        }
    )
).rename(
    {
        'site': 'Site',
        'presence': 'Number of samples the species is present in',
        'name': 'Species',
    }
).sort(['total','Site'], descending=[True, False])

# phiX174 is removed from the table.
# NOTE(review): this happens after top_k, so fewer than 50 species remain when phiX174 is in the top 50.
df_presence_prokaryote = df_presence_prokaryote.filter(~pl.col('Species').eq('Sinsheimervirus phiX174'))


# Species whose genome passes the first threshold (frame already restricted to breadth > 0.5)
species_annotation = metadata[['seqid', 'host', 'Species', 'Genus', 'taxonomy_id']].unique('seqid')

mapped_species_first_threshold = (
    depth_df
    .filter(pl.col('level').eq('total'))
    .join(species_annotation, on='seqid', how='left')
    .filter(~pl.col('taxonomy_id').is_null())
    .get_column('Species')
    .unique()
    .to_list()
)

# Species still present in the final presence table
mapped_species_second_threshold = df_presence_prokaryote.get_column('Species').unique().to_list()

In [None]:
import plotly.graph_objects as go
from collections import Counter

# Extract genus and species
def get_genus_species(species_list):
    """Split species names into ``(genus, species)`` tuples.

    The first word is taken as the genus when it ends with ``virus`` and the
    remaining words form the species part. Otherwise the last word is the
    species part and everything before it is the genus. The entry
    "Sinsheimervirus phiX174" is printed and left out of the result.

    Parameters
    ----------
    species_list : iterable of str
        Species names made of whitespace-separated words.

    Returns
    -------
    list of tuple of (str, str)
        One ``(genus, species)`` pair per retained name, in input order.
    """
    pairs = []

    for full_name in species_list:
        if full_name == "Sinsheimervirus phiX174":
            print(full_name)
            continue

        first_word, *other_words = full_name.split()

        if first_word.endswith('virus'):
            genus, epithet = first_word, ' '.join(other_words)
        else:
            # Names such as "<host> phage <name>": only the last word is the species part
            *leading_words, epithet = full_name.split()
            genus = ' '.join(leading_words)

        pairs.append((genus, epithet))

    return pairs

taxonomy_species = df_presence_prokaryote['Species'].unique().to_list()

# (genus, species) pairs for the three species sets compared in the diagram
taxonomy_genus_species = get_genus_species(taxonomy_species)
mapped_first_threshold = get_genus_species(mapped_species_first_threshold)
mapped_second_threshold = get_genus_species(mapped_species_second_threshold)

# Genera of the taxonomy set that have no species passing the first threshold
lost_genus = {genus_name for genus_name, _ in taxonomy_genus_species} - {genus_name for genus_name, _ in mapped_first_threshold}

# Group species by genus
taxonomy_by_genus = defaultdict(list)
first_threshold_by_genus = defaultdict(list)
second_threshold_by_genus = defaultdict(list)

# Genera that are entirely lost are pooled under a single "Other" entry
for genus, species in taxonomy_genus_species:
    bucket = "Other" if genus in lost_genus else genus
    taxonomy_by_genus[bucket].append(species)

for pair_list, grouped in (
    (mapped_first_threshold, first_threshold_by_genus),
    (mapped_second_threshold, second_threshold_by_genus),
):
    for genus, species in pair_list:
        grouped[genus].append(species)

# Prepare node and link information.
# Node layout: [genus (start) ...] + ["Lost" (middle)] + [genus (middle) ...] + ["Mapped", "Lost" (final)]
genera = list(taxonomy_by_genus.keys())
nodes = (
    genera +
    ["Lost"] +
    genera +
    ["Mapped", "Lost"]
)

genus_start_index = {genus: i for i, genus in enumerate(genera)}
lost_middle_index = len(genera)
genus_middle_index = {genus: i + len(genera) + 1 for i, genus in enumerate(genera)}
mapped_index = len(genera) + len(genera) + 1
lost_final_index = len(genera) + len(genera) + 2

sources = []
targets = []
values = []

# Number of species sent to "Lost (Middle)"; the same amount must leave that node
lost_first_total = 0

# Add flows for each genus
for genus, species_list in taxonomy_by_genus.items():
    genus_start_idx = genus_start_index[genus]
    genus_middle_idx = genus_middle_index[genus]

    # Sets give constant-time membership tests; counts are still taken over species_list
    first_hits = set(first_threshold_by_genus.get(genus, []))
    second_hits = set(second_threshold_by_genus.get(genus, []))

    # First threshold transitions
    mapped_first_count = sum(1 for s in species_list if s in first_hits)
    lost_first_count = len(species_list) - mapped_first_count
    lost_first_total += lost_first_count

    if mapped_first_count > 0:
        sources.append(genus_start_idx)  # From genus start
        targets.append(genus_middle_idx)  # To genus in middle
        values.append(mapped_first_count)

    if lost_first_count > 0:
        sources.append(genus_start_idx)  # From genus start
        targets.append(lost_middle_index)  # To "Lost (Middle)"
        values.append(lost_first_count)

    # Second threshold transitions for species mapped at first threshold
    mapped_second_count = sum(1 for s in species_list if s in second_hits)
    not_mapped_second_count = sum(1 for s in species_list if s in first_hits and s not in second_hits)

    if mapped_second_count > 0:
        sources.append(genus_middle_idx)  # From genus in middle
        targets.append(mapped_index)     # To "Mapped (Final)"
        values.append(mapped_second_count)

    if not_mapped_second_count > 0:
        sources.append(genus_middle_idx)  # From genus in middle
        targets.append(lost_final_index)  # To "Lost (Final)"
        values.append(not_mapped_second_count)

# Add the flow for species lost at the first threshold.
# The value is the sum of what entered "Lost (Middle)". Subtracting
# len(mapped_species_first_threshold) from the taxonomy total is only equal to it when
# every mapped species belongs to the taxonomy set, and becomes negative otherwise.
if lost_first_total > 0:
    sources.append(lost_middle_index)  # From "Lost (Middle)"
    targets.append(lost_final_index)   # To "Lost (Final)"
    values.append(lost_first_total)

# Create the Sankey diagram
# Node colours follow the order of `nodes`: genus (start), "Lost" (middle), genus (middle),
# then "Mapped" and "Lost" (final).
fig = go.Figure(go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=nodes,
        color=["rgb(184,146,96)"] * len(genera) + ["rgb(177,0,67)"] +
              ["rgb(99,149,0)"] * len(genera) + ["rgb(99,149,0)", "rgb(177,0,67)"]
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        # Links are coloured by their target: green towards a genus (middle) or "Mapped",
        # red towards either "Lost" node.
        color=[
            "rgba(99,149,0, 0.5)" if t in genus_middle_index.values() else
            "rgba(177,0,67, 0.5)" if t == lost_middle_index else
            "rgba(99,149,0, 0.5)" if t == mapped_index else
            "rgba(177,0,67, 0.5)" for t in targets
        ]
    )
))

fig.update_layout(
    title_text="Sankey Diagram with Intermediate Genus Retention and Thresholds",
    font_size=10,
    width=500,
    height=600
)

fig.show()
# Export the figure as PDF (static image export is delegated to plotly's `write_image`)
fig.write_image("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data/plot/sankey_diagram.comparisonProfilevsMapping.nophiX.pdf")


### Figure 3B

In [None]:
# Stacked horizontal bars: one bar per species, one colour per site.
# Uses `df_presence_prokaryote` built in the Figure 3A cell.
# NOTE(review): `p` is only assigned here; nothing in this cell displays or saves it.
p = (
    so.Plot(df_presence_prokaryote, y="Species", x="Number of samples the species is present in", color="Site")
    .add(so.Bar(), so.Stack())
    .theme(axes_style("whitegrid") | plotting_context("talk") | {"grid.linestyle": ":", "pdf.fonttype": 42})
    .scale(color=so.Nominal(values=["#930c02", "#5d8300","#008f93", "#755092"], #["#de1e10", "#7cae00","#00bfc4", "#9c6ac2"]
                            order=["Arid West Cave", "Boomerang Shelter", "Hallstatt", "Zape"]), #["ASM", "BSM", "HSM", "ZSM"]
            )
    .layout(size=(10, 9)) # width, height
)

### Figure 3C - Depth of the reads on the genome MG711460 (Mushuvirus mushu)

In [None]:
import pysam
import os
import polars as pl
from Bio import SeqIO
import seaborn as sns
import seaborn.objects as so
import matplotlib.pyplot as plt

# Per-position read depth of sample ZSM103 on the reference MG711460.
bam = "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/sam2lca/ZSM103.viruses_decoy.sam2lca.total.sorted.bam"
reference = "MG711460"

# Reference sequence, used to fill in positions without any read and to spot differing bases
sequence = SeqIO.index("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data/top_viruses/viruses.fasta", "fasta")[reference].seq

# The BAM file is opened in a `with` block so that the handle is closed once the
# coverage has been extracted (it was previously left open).
with pysam.AlignmentFile(bam, "rb") as bamfile:
    reference_size = bamfile.header.get_reference_length(reference)

    # Four arrays (A, C, G, T) with the number of reads supporting each base at each position
    coverage = bamfile.count_coverage(reference, start=0, end=reference_size)

coverage = [cov.tolist() for cov in coverage]

coverage = pl.DataFrame(coverage, schema=['A', 'C', 'G', 'T']).with_columns(
    total = pl.col('A') + pl.col('C') + pl.col('G') + pl.col('T'),
    # 0-based position along the reference: an explicit row index instead of a Python
    # `range` wrapped in `pl.lit`
    pos = pl.int_range(pl.len()),
)

coverage = coverage.with_columns(
    # Most supported base at each position; the index returned by arg_max is turned
    # into the matching column name (the first four columns are A, C, G, T)
    pl.concat_list(pl.col(['A', 'C', 'G', 'T']))
    .list.arg_max()
    .map_elements(lambda x: coverage.columns[x], return_dtype=pl.Utf8)
    .alias('max_column'),
    real_sequence = pl.Series(list(str(sequence))),
)

# Positions without any read carry no information: use the reference base there
coverage = coverage.with_columns(
    pl.when(
        pl.col('total').eq(0),
    ).then(pl.col('real_sequence')).otherwise(pl.col('max_column')).alias('max_column')
)

# Positions where the most supported base differs from the reference
# (only used by the commented-out layer of the plot below)
diff_base = coverage.filter(
    ~pl.col('max_column').eq(pl.col('real_sequence')),
)

p = (
    so.Plot(data=coverage, x="pos", y="total")
    .add(so.Area(color=".5"))
    # .add(so.Bar(edgewidth=1, alpha=1), data=diff_base, x="pos", y="total", color="max_column")
    .theme({"axes.facecolor": "w", "axes.edgecolor": "slategray"})
    .layout(size=(15, 2))  # in inches * (dpi / 100)
)

p.save("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/depth_breadth/Mushu_coverage.pdf", bbox_inches='tight')

## Figure 4

### Figure 4A

In [None]:
import polars as pl
from pathlib import Path
import seaborn.objects as so
from seaborn import axes_style
from seaborn import plotting_context
import seaborn as sns
import matplotlib.pyplot as plt

# Breadth/depth of the reference genomes computed from the mapped reads.
# The sample and the taxonomic level come from the 1st and 3rd dot-separated fields of
# the file name (assumes names of the form <sample>.<x>.<level>...).
folder_depth = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/depth_breadth/")

sorted_files = sorted(folder_depth.glob("*depth_breadth.tsv"))

read_tables = []
for file in sorted_files:
    table = pl.read_csv(file, separator="\t")
    name_fields = file.stem.split('.')
    read_tables.append(
        table.with_columns(
            pl.col('Breadth').cast(pl.Float32),
            pl.col('Depth avg').cast(pl.Float32),
            pl.col('Depth median').cast(pl.Float32),
            sample=pl.lit(name_fields[0]),
            level=pl.lit(name_fields[2]),
        )
    )

# The "_reads" suffix distinguishes these columns from the contig-based ones
read_column_names = {
    "Contig": "seqid",
    "Breadth": "Breadth_reads",
    "Depth avg": "Depth_avg_reads",
    "Depth median": "Depth_median_reads",
}

depth_df = (
    pl.concat(read_tables)
    .rename(read_column_names)
    .filter(~pl.col('level').eq('unclassified'))
    .with_columns(seqid_sample=pl.col('seqid') + pl.lit('_') + pl.col('sample'))
)

# Only the 'total' level is compared with the contigs; reference J02482 is excluded
depth_df_reads = depth_df.filter(
    pl.col('level').eq('total'),
    ~pl.col('seqid').eq('J02482'),
)

# Breadth of the reference genomes computed from the assembled contigs (vclust output).
folder_depth = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/vclust/depth_breadth/")

# Alternative input that was used before: "*vclust.ani.depth_breadth.tsv"
sorted_files = sorted(folder_depth.glob("*_coverage.tsv"))

contig_tables = []
for file in sorted_files:
    table = pl.read_csv(file, separator="\t").rename(
        {
            "coverage": "Breadth",
            "reference": "Contig",
        }
    )
    name_fields = file.stem.split('.')
    contig_tables.append(
        table.with_columns(
            # the coverage column is divided by 100 to be on the same scale as the read breadth
            pl.col('Breadth').cast(pl.Float32) / 100,
            sample=pl.lit(name_fields[0]),
            # the assembler is the 3rd field for these files (it was the 2nd for the cmseq files)
            software=pl.lit(name_fields[2]),
        )
    )

depth_df = (
    pl.concat(contig_tables)
    .rename({"Contig": "seqid", "Breadth": "Breadth_contigs"})
    .with_columns(seqid_sample=pl.col('seqid') + pl.lit('_') + pl.col('sample'))
)

# Reference J02482 is excluded
depth_df_contigs = depth_df.filter(~pl.col('seqid').eq('J02482'))

# One row per contig-based (reference, sample, assembler) entry, with the read-based
# breadth/depth of the same reference/sample attached. The left join keeps every
# contig-based row; a missing read-based breadth is replaced by 0.
depth_concat = depth_df_contigs.select(
        ['seqid_sample', 'Breadth_contigs', 'software']
    ).join(
        depth_df_reads.select(
            ['seqid_sample', 'Breadth_reads', 'Depth_avg_reads', 'Depth_median_reads']
        ), 
        on='seqid_sample', how='left'
    ).with_columns(
        pl.col('Breadth_reads').fill_null(0),
        pl.col('Breadth_contigs').fill_null(0),
    )

# vclust alignments between the assembled contigs (query) and the reference genomes.
folder_vclust = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/vclust/")

sorted_files = sorted(folder_vclust.glob("*parquet"))

vclust_tables = []
for file in sorted_files:
    table = pl.read_parquet(file)
    name_fields = file.stem.split('.')
    vclust_tables.append(
        table.with_columns(
            sample=pl.lit(name_fields[0]),
            software=pl.lit(name_fields[2]),
        )
    )

vclust_df = (
    pl.concat(vclust_tables)
    # keys used to join with the breadth table (reference) and with pyDamage (query)
    .with_columns(
        seqid_sample=pl.col('reference') + pl.lit('_') + pl.col('sample'),
        query_sample=pl.col('query') + pl.lit('_') + pl.col('sample'),
    )
    .rename({"reference": "seqid"})
)

# Hits outside reference J02482 whose query name starts with 'NODE'
vclust_df_contigs = vclust_df.filter(
    ~pl.col('seqid').eq('J02482'),
    pl.col('query').str.starts_with('NODE'),
)

# pyDamage results of the assembled contigs, one file per sample.
pydamage_df = []

folder_pydamage = Path("/mnt/archgen/microbiome_coprolite/palaeofaeces_ecoevol/04-analysis/pydamage/")

sorted_files = sorted(folder_pydamage.glob("[ABHZ]*.pydamage.tsv.gz"))

for file in sorted_files:
    pydamage_df.append(pl.read_csv(file, separator="\t"))
    pydamage_df[-1] = pydamage_df[-1].with_columns(
        pl.col('CtoT-0').cast(pl.Float32),
        pl.col('damage_model_pmax').cast(pl.Float32),
        # `file.stem` only strips the final ".gz", so fields 0 and 1 are the first two
        # dot-separated parts of the file name
        sample=pl.lit(file.stem.split('.')[0]),
        software=pl.lit(file.stem.split('.')[1]),
    )

pydamage_df = pl.concat(pydamage_df)

# The pyDamage "reference" is the contig, i.e. the "query" of the vclust table;
# the 'CtoT-0' column is the value used as damage estimate.
pydamage_df = pydamage_df.rename(
    {
       "reference" : "query",
        "CtoT-0":"Damage",
        # "damage_model_pmax":"Damage",
    }
)

# Join key shared with vclust_df; contigs with a damage of 0 are dropped
pydamage_df = pydamage_df.with_columns(
    query_sample=pl.col('query') + pl.lit('_') + pl.col('sample'),
).filter(
    pl.col('Damage').gt(0)
)

# Attach the pyDamage estimate of each contig (query) to its vclust hit
df_damage = vclust_df.join(pydamage_df, on='query_sample', how='left')

df_damage = df_damage.with_columns(
    seqid_sample = pl.col('seqid') + pl.lit('_') + pl.col('sample')
)

# Hits whose contig has no damage estimate are dropped
df_damage = df_damage.filter(
    ~pl.col('Damage').is_null(),
)

# Mean damage of all the contigs aligned to the same reference in the same sample
df_damage = df_damage.group_by(['seqid_sample']).agg(pl.col('Damage').mean())

depth_concat = depth_concat.join(df_damage, on='seqid_sample', how='left')

# The key is built as "<seqid>_<sample>", so the sample is the LAST '_'-separated field.
# Taking field 1 returned part of the accession whenever the seqid itself contains an
# underscore. Assumes sample names contain no underscore, which holds for the
# ASM/BSM/HSM/ZSM identifiers listed in this notebook.
depth_concat = depth_concat.with_columns(
    sample=pl.col('seqid_sample').str.split('_').list.last(),
)


To calculate the average read length per reference (one output table per BAM file):

```bash
for bam in *viruses_decoy.cleaned.processed.sorted.bam ; do echo "--------------"$bam"-----------" ; samtools view -F4 $bam | awk 'BEGIN { print "Reference\tAverage_Read_Length" } $3!="*" { tot[$3] += length($10); count[$3]++ } END { for(ref in tot) { printf "%s\t%.2f\n", ref, tot[ref]/count[ref] } }' > $(basename $bam .bam).avg.reads_length.tsv ; done
```

In [None]:
# Average length of the reads mapped to each reference (tables written by the bash loop above).
folder_depth = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/BAM_files")

sorted_files = sorted(folder_depth.glob("*.viruses_decoy.cleaned.processed.sorted.avg.reads_length.tsv"))

read_length_tables = []
for file in sorted_files:
    table = pl.read_csv(file, separator="\t").with_columns(
        pl.col('Average_Read_Length').cast(pl.Float32),
        sample=pl.lit(file.stem.split('.')[0]),
    )
    read_length_tables.append(table.rename({"Reference": "seqid"}))

reads_len_df = pl.concat(read_length_tables).with_columns(
    seqid_sample=pl.col('seqid') + pl.lit('_') + pl.col('sample'),
)

# Reference J02482 is excluded
reads_len_df_contigs = reads_len_df.filter(~pl.col('seqid').eq('J02482'))

# Add the average read length to the breadth comparison table
average_length = reads_len_df_contigs.select(['seqid_sample', 'Average_Read_Length'])
depth_concat = depth_concat.join(average_length, on='seqid_sample', how='left')

# Scatter plot of the read-based breadth (x) against the contig-based breadth (y),
# coloured by assembler, restricted to rows that have a damage estimate.
# The grey line is the diagonal y = x.
# NOTE(review): `p` is only assigned here; nothing in this cell displays or saves it.
p = (
    so.Plot(depth_concat.filter(~pl.col("Damage").is_null()), x="Breadth_reads", y="Breadth_contigs", color="software")
    .add(so.Dot())
    .add(so.Line(color="grey"), data=pl.DataFrame({"x": [0, 1], "y": [0, 1]}), x="x", y="y")
    # .scale(color="crest")
    .theme(axes_style("whitegrid") | plotting_context("talk") | {"grid.linestyle": ":", "pdf.fonttype": 42})
    .layout(size=(10, 10)) # width, height
    .limit(x=(0, 1), y=(0, 1))
)

### Figure 4B

#### metaSPAdes

In [15]:
import pandas as pd
import polars as pl
from pathlib import Path
import pytaxonkit
from Bio import SeqIO
from collections import defaultdict
import polars.selectors as cs
import seaborn.objects as so
from seaborn import axes_style
from seaborn import plotting_context

# Breadth of the reference genomes computed from the metaSPAdes contigs (vclust output).
folder_depth = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/vclust/depth_breadth/")

# Alternative input that was used before: "*vclust.ani.depth_breadth.tsv"
sorted_files = sorted(folder_depth.glob("*_coverage.tsv"))

contig_tables = []
for file in sorted_files:
    table = pl.read_csv(file, separator="\t").rename(
        {
            "coverage": "Breadth",
            "reference": "Contig",
        }
    )
    name_fields = file.stem.split('.')
    contig_tables.append(
        table.with_columns(
            # the coverage column is divided by 100 before the 0.5 threshold is applied
            pl.col('Breadth').cast(pl.Float32) / 100,
            sample=pl.lit(name_fields[0]),
            # the assembler is the 3rd field for these files (it was the 2nd for the cmseq files)
            software=pl.lit(name_fields[2]),
        )
    )

depth_df = (
    pl.concat(contig_tables)
    .rename({"Contig": "seqid"})
    # reference J02482 is excluded
    .filter(~pl.col('seqid').eq('J02482'))
    .with_columns(seqid_sample=pl.col('seqid') + pl.lit('_') + pl.col('sample'))
    # references covered over more than half of their length by metaSPAdes contigs
    .filter(
        pl.col('Breadth').gt(0.5),
        pl.col('software').eq('metaspades'),
    )
)

# Identifiers of the samples of the study
required_columns = [
    'ASM001', 'ASM002', 'ASM003',
    'BSM001', 'BSM002',
    'HSM001', 'HSM002', 'HSM003', 'HSM004',
    'ZSM005', 'ZSM025', 'ZSM027', 'ZSM028', 'ZSM031',
    'ZSM101', 'ZSM102', 'ZSM103', 'ZSM214', 'ZSM216', 'ZSM219',
]
                    


# ICTV metadata table, with the accession and host columns renamed to short names.
metadata = pl.read_csv(
    '/mnt/archgen/microbiome_coprolite/aVirus/03-data/refdbs/ICTV/ICTV_database/20240209/ICTV_metadata.tsv',
    separator="\t"
).rename(
    {
        'Virus GENBANK accession': 'seqid',
        'Host source': 'host',
    }
)

# Accession -> taxid lookup read from the centrifuge database map
# (each line must hold exactly two tab-separated fields).
seq2taxid = {}
with open('/mnt/archgen/microbiome_coprolite/aVirus/03-data/refdbs/centrifuge_db/ICTV_decoy/seqid2taxid.map') as map_handle:
    for record in map_handle:
        accession, tax_identifier = record.strip().split('\t')
        seq2taxid[accession] = tax_identifier

# Accessions that are absent from the map receive the placeholder taxid '0'.
taxid_lookup = pl.col('seqid').replace_strict(seq2taxid, default='0')

metadata = metadata.with_columns(taxid=taxid_lookup).rename({'taxid': 'taxonomy_id'})

# Annotate each reference genome with its host and taxonomy
depth_df_total = depth_df.join(metadata[['seqid', 'host', 'Species', 'Genus', 'taxonomy_id']].unique('seqid'), on='seqid', how='left')

# References that are not described in the metadata are dropped
depth_df_total = depth_df_total.filter(~pl.col('taxonomy_id').is_null())

# To attribute the coverage to only one species, keep the seqid with the highest breadth
# value among the seqids that belong to the same genus (per sample).
# `unique()` defaults to keep='any' and maintain_order=False, which does not guarantee that
# the first row of the sorted frame survives; keep='first' with maintain_order=True makes
# the "highest breadth" selection explicit and deterministic.
depth_df_total = (
    depth_df_total
    .sort(['Genus', 'sample', 'Breadth'], descending=True)
    .unique(['Genus', 'sample'], keep='first', maintain_order=True)
    .sort(['Species'])
)

# One row per taxonomy_id, one column per sample, holding the breadth (0 when absent)
depth_df_total = depth_df_total.pivot(
    index=['taxonomy_id'],
    on='sample',
    values='Breadth',
    aggregate_function='max',
).fill_null(0)

# Bring the taxonomy/host annotation back after the pivot
depth_df_total = depth_df_total.join(metadata[[
    'seqid', 'host', 'Species', 'Genus', 'taxonomy_id'
    ]].unique('taxonomy_id'), on='taxonomy_id', how='left')

depth_df_total = depth_df_total.drop_nulls(subset=['host'])

# Every sample of the study must end up as a column, even when no contig covered a genome in it
required_columns = [
    'ASM001', 'ASM002', 'ASM003',
    'BSM001', 'BSM002',
    'HSM001', 'HSM002', 'HSM003', 'HSM004',
    'ZSM005', 'ZSM025', 'ZSM027', 'ZSM028', 'ZSM031',
    'ZSM101', 'ZSM102', 'ZSM103', 'ZSM214', 'ZSM216', 'ZSM219',
]

# Samples missing from the pivoted table are added as all-zero columns
for sample_id in required_columns:
    if sample_id in depth_df_total.columns:
        continue
    depth_df_total = depth_df_total.with_columns(pl.lit(0).alias(sample_id))

# Fixed column order: annotation first, then the samples
annotation_columns = ['seqid', 'Species', 'Genus', 'taxonomy_id', 'host']
depth_df_total = depth_df_total.select(annotation_columns + required_columns)

# Presence/absence matrix: 1 when the breadth in the sample is above 0.5, else 0
sample_breadth = pl.col("^[ABHZ]SM[0-9][0-9][0-9]$")
depth_bool = depth_df_total.with_columns(sample_breadth.gt(0.5) * 1).sort('Species')


# Build the table behind the stacked bar plot: for each species, the number of samples
# per site in which it is present (metaSPAdes contigs).
df_presence = depth_bool.rename(
    {
        'Species': 'name',
    }
)

# Keep viruses whose host is annotated as bacteria or archaea and go from wide
# (one column per sample) to long format (one row per species/sample).
df_presence_prokaryote = df_presence.filter(
    pl.col('host').is_in(['bacteria', 'archaea'])
    ).select(['name','^[ABHZ]SM[0-9][0-9][0-9]$' ]).unpivot(
    on=cs.numeric(),
    index='name',
    value_name='presence',
    variable_name='site',
)

# The first three characters of the sample name (ASM/BSM/HSM/ZSM) are used as site code
df_presence_prokaryote = df_presence_prokaryote.with_columns(
    pl.col('site').str.head(3)
)

# Total number of samples each species is present in (species never present are dropped)
df_presence_prokaryote_species = df_presence_prokaryote.group_by(['name']).agg(pl.col('presence').sum()).filter(pl.col('presence') > 0).sort('presence', descending=True)

df_presence_prokaryote_species = df_presence_prokaryote_species.rename(
    {
        'presence': 'total',
    }
)

# Number of samples per species and per site
df_presence_prokaryote = df_presence_prokaryote.group_by(['name', 'site']).agg(pl.col('presence').sum()).filter(pl.col('presence') > 0).sort('presence', descending=True)

# Restrict to the 50 species with the highest total and attach their per-site counts
df_presence_prokaryote = df_presence_prokaryote_species.top_k(by='total', k=50).join(
    df_presence_prokaryote, left_on='name', right_on='name', how='left'
)

# Human-readable site names and axis labels for the plot
df_presence_prokaryote = df_presence_prokaryote.with_columns(
    pl.col('site').replace_strict(
        {
            'ASM': 'Arid West Cave',
            'BSM': 'Boomerang Shelter',
            'HSM': 'Hallstatt',
            'ZSM': 'Zape',
        }
    )
).rename(
    {
        'site': 'Site',
        'presence': 'Number of samples the species is present in',
        'name': 'Species',
    }
).sort(['total','Site'], descending=[True, False])

# phiX174 is removed from the table by species name
df_presence_prokaryote = df_presence_prokaryote.filter(~pl.col('Species').eq('Sinsheimervirus phiX174'))

# Stacked horizontal bars: one bar per species, one colour per site.
# NOTE(review): `p` is only assigned here; nothing in this cell displays or saves it.
p = (
    # so.Plot(df_presence_prokaryote, y="name", x="presence", color="site")
    so.Plot(df_presence_prokaryote, y="Species", x="Number of samples the species is present in", color="Site")
    .add(so.Bar(), so.Stack())
    .theme(axes_style("whitegrid") | plotting_context("talk") | {"grid.linestyle": ":", "pdf.fonttype": 42})
    .scale(color=so.Nominal(values=["#930c02", "#5d8300","#008f93", "#755092"], #["#de1e10", "#7cae00","#00bfc4", "#9c6ac2"]
                            order=["Arid West Cave", "Boomerang Shelter", "Hallstatt", "Zape"]), #["ASM", "BSM", "HSM", "ZSM"]
            )
    .layout(size=(10, 9)) # width, height
)


#### MEGAHIT

In [11]:
import pandas as pd
import polars as pl
from pathlib import Path
import pytaxonkit
from Bio import SeqIO
from collections import defaultdict
import polars.selectors as cs
import seaborn.objects as so
from seaborn import axes_style
from seaborn import plotting_context

# Breadth of the reference genomes computed from the MEGAHIT contigs (vclust output).
depth_df = []

folder_depth = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/vclust/depth_breadth/")

# sorted_files = sorted(folder_depth.glob("*vclust.ani.depth_breadth.tsv"))
sorted_files = sorted(folder_depth.glob("*_coverage.tsv"))

for file in sorted_files:
    depth_df.append(pl.read_csv(file, separator="\t"))
    depth_df[-1] = depth_df[-1].rename(
        {
            "coverage": "Breadth",
            "reference": "Contig",
        }
    ).with_columns(
        # the coverage column is divided by 100 before the 0.5 threshold is applied
        pl.col('Breadth').cast(pl.Float32) / 100,
        # pl.col('Depth avg').cast(pl.Float32),
        # pl.col('Depth median').cast(pl.Float32),
        sample=pl.lit(file.stem.split('.')[0]),
        # software=pl.lit(file.stem.split('.')[1]), # [1] for cmseq
        software=pl.lit(file.stem.split('.')[2]), # [2] for vclust direct
    )

depth_df = pl.concat(depth_df)

depth_df = depth_df.rename(
    {
        "Contig": "seqid",
    }
)

# Reference J02482 is excluded. The filtered frame used to be stored only in
# `depth_df_contigs`, so `depth_df` itself was never filtered, unlike in the metaSPAdes
# cell; the filter is now applied to `depth_df`.
depth_df = depth_df.filter(
    ~pl.col('seqid').eq('J02482')
)

# Kept bound to the same filtered frame as before, in case another cell reads it
depth_df_contigs = depth_df

depth_df = depth_df.with_columns(
    seqid_sample=pl.col('seqid') + pl.lit('_') + pl.col('sample'),
)

# References covered over more than half of their length by MEGAHIT contigs
depth_df = depth_df.filter(
    pl.col('Breadth').gt(0.5),
    pl.col('software').eq('megahit'),
)


# Identifiers of the samples of the study
required_columns = ['ASM001', 'ASM002', 'ASM003', 'BSM001', 'BSM002',
                    'HSM001', 'HSM002', 'HSM003', 'HSM004', 'ZSM005',
                    'ZSM025', 'ZSM027', 'ZSM028', 'ZSM031', 'ZSM101',
                    'ZSM102', 'ZSM103', 'ZSM214', 'ZSM216', 'ZSM219']
                    


# ICTV metadata table, with the accession and host columns renamed to short names.
metadata = pl.read_csv(
    '/mnt/archgen/microbiome_coprolite/aVirus/03-data/refdbs/ICTV/ICTV_database/20240209/ICTV_metadata.tsv',
    separator="\t"
).rename(
    {
        'Virus GENBANK accession': 'seqid',
        'Host source': 'host',
    }
)

# Accession -> taxid lookup read from the centrifuge database map
# (each line must hold exactly two tab-separated fields).
seq2taxid = {}
with open('/mnt/archgen/microbiome_coprolite/aVirus/03-data/refdbs/centrifuge_db/ICTV_decoy/seqid2taxid.map') as map_handle:
    for record in map_handle:
        accession, tax_identifier = record.strip().split('\t')
        seq2taxid[accession] = tax_identifier

# Accessions that are absent from the map receive the placeholder taxid '0'.
taxid_lookup = pl.col('seqid').replace_strict(seq2taxid, default='0')

metadata = metadata.with_columns(taxid=taxid_lookup).rename({'taxid': 'taxonomy_id'})

# Annotate each reference genome with its host and taxonomy
depth_df_total = depth_df.join(metadata[['seqid', 'host', 'Species', 'Genus', 'taxonomy_id']].unique('seqid'), on='seqid', how='left')

# References that are not described in the metadata are dropped
depth_df_total = depth_df_total.filter(~pl.col('taxonomy_id').is_null())

# To attribute the coverage to only one species, keep the seqid with the highest breadth
# value among the seqids that belong to the same genus (per sample).
# `unique()` defaults to keep='any' and maintain_order=False, which does not guarantee that
# the first row of the sorted frame survives; keep='first' with maintain_order=True makes
# the "highest breadth" selection explicit and deterministic.
depth_df_total = (
    depth_df_total
    .sort(['Genus', 'sample', 'Breadth'], descending=True)
    .unique(['Genus', 'sample'], keep='first', maintain_order=True)
    .sort(['Species'])
)

# One row per taxonomy_id, one column per sample, holding the breadth (0 when absent)
depth_df_total = depth_df_total.pivot(
    index=['taxonomy_id'],
    on='sample',
    values='Breadth',
    aggregate_function='max',
).fill_null(0)

# Bring the taxonomy/host annotation back after the pivot
depth_df_total = depth_df_total.join(metadata[[
    'seqid', 'host', 'Species', 'Genus', 'taxonomy_id'
    ]].unique('taxonomy_id'), on='taxonomy_id', how='left')

depth_df_total = depth_df_total.drop_nulls(subset=['host'])

# Every sample of the study must end up as a column, even when no contig covered a genome in it
required_columns = [
    'ASM001', 'ASM002', 'ASM003',
    'BSM001', 'BSM002',
    'HSM001', 'HSM002', 'HSM003', 'HSM004',
    'ZSM005', 'ZSM025', 'ZSM027', 'ZSM028', 'ZSM031',
    'ZSM101', 'ZSM102', 'ZSM103', 'ZSM214', 'ZSM216', 'ZSM219',
]

# Samples missing from the pivoted table are added as all-zero columns
for sample_id in required_columns:
    if sample_id in depth_df_total.columns:
        continue
    depth_df_total = depth_df_total.with_columns(pl.lit(0).alias(sample_id))

# Fixed column order: annotation first, then the samples
annotation_columns = ['seqid', 'Species', 'Genus', 'taxonomy_id', 'host']
depth_df_total = depth_df_total.select(annotation_columns + required_columns)

# Presence/absence matrix: 1 when the breadth in the sample is above 0.5, else 0
sample_breadth = pl.col("^[ABHZ]SM[0-9][0-9][0-9]$")
depth_bool = depth_df_total.with_columns(sample_breadth.gt(0.5) * 1).sort('Species')


# Build the table behind the stacked bar plot: for each species, the number of samples
# per site in which it is present (MEGAHIT contigs).
df_presence = depth_bool.rename(
    {
        'Species': 'name',
    }
)

# Keep viruses whose host is annotated as bacteria or archaea and go from wide
# (one column per sample) to long format (one row per species/sample).
df_presence_prokaryote = df_presence.filter(
    pl.col('host').is_in(['bacteria', 'archaea'])
    ).select(['name','^[ABHZ]SM[0-9][0-9][0-9]$' ]).unpivot(
    on=cs.numeric(),
    index='name',
    value_name='presence',
    variable_name='site',
)

# The first three characters of the sample name (ASM/BSM/HSM/ZSM) are used as site code
df_presence_prokaryote = df_presence_prokaryote.with_columns(
    pl.col('site').str.head(3)
)

# Total number of samples each species is present in (species never present are dropped)
df_presence_prokaryote_species = df_presence_prokaryote.group_by(['name']).agg(pl.col('presence').sum()).filter(pl.col('presence') > 0).sort('presence', descending=True)

df_presence_prokaryote_species = df_presence_prokaryote_species.rename(
    {
        'presence': 'total',
    }
)

# Number of samples per species and per site
df_presence_prokaryote = df_presence_prokaryote.group_by(['name', 'site']).agg(pl.col('presence').sum()).filter(pl.col('presence') > 0).sort('presence', descending=True)

# Restrict to the 50 species with the highest total and attach their per-site counts
df_presence_prokaryote = df_presence_prokaryote_species.top_k(by='total', k=50).join(
    df_presence_prokaryote, left_on='name', right_on='name', how='left'
)

# Human-readable site names and axis labels for the plot
df_presence_prokaryote = df_presence_prokaryote.with_columns(
    pl.col('site').replace_strict(
        {
            'ASM': 'Arid West Cave',
            'BSM': 'Boomerang Shelter',
            'HSM': 'Hallstatt',
            'ZSM': 'Zape',
        }
    )
).rename(
    {
        'site': 'Site',
        'presence': 'Number of samples the species is present in',
        'name': 'Species',
    }
).sort(['total','Site'], descending=[True, False])

# phiX174 is removed from the table.
# NOTE(review): this happens after top_k, so fewer than 50 species remain when phiX174 is in the top 50.
df_presence_prokaryote = df_presence_prokaryote.filter(~pl.col('Species').eq('Sinsheimervirus phiX174'))

# Stacked horizontal bars: one bar per species, one colour per site.
# NOTE(review): `p` is only assigned here; nothing in this cell displays or saves it.
p = (
    # so.Plot(df_presence_prokaryote, y="name", x="presence", color="site")
    so.Plot(df_presence_prokaryote, y="Species", x="Number of samples the species is present in", color="Site")
    .add(so.Bar(), so.Stack())
    .theme(axes_style("whitegrid") | plotting_context("talk") | {"grid.linestyle": ":", "pdf.fonttype": 42})
    .scale(color=so.Nominal(values=["#930c02", "#5d8300","#008f93", "#755092"], #["#de1e10", "#7cae00","#00bfc4", "#9c6ac2"]
                            order=["Arid West Cave", "Boomerang Shelter", "Hallstatt", "Zape"]), #["ASM", "BSM", "HSM", "ZSM"]
            )
    .layout(size=(10, 9)) # width, height
)


### Figure 4C

Used the snakemake pipeline `PLOT_contigs_over_reference`

## Figure 5

### Figure 5A

In [None]:
import polars as pl
from pathlib import Path
import gzip
from Bio import SeqIO

# Matplotlib rcParams overrides for this section; handed to
# sns.set_theme(rc=custom_style) right after this dict.
# Commented-out keys are options that are currently disabled.
custom_style = {"text.color": "#131516",
                "svg.fonttype": "none",
                # "font.family": "sans-serif",
                # "font.weight": "light",
                "axes.spines.right": False,
                "axes.spines.top": False,
                # "axes.spines.bottom": False,
                'xtick.bottom': False,
                "pdf.fonttype": 42
                }

import seaborn as sns ; sns.set_theme(style="ticks", rc=custom_style)  # for plot styling
import matplotlib.pyplot as plt
import seaborn.objects as so

### Genomes excluded from the benchmark: judged to be either contamination or,
### for some poorly defined proviruses, to contain bacterial DNA.
### The values are matched against `species_id` (the part of a contig name
### before the first underscore) in the filters below.

# presumably bacterial sequences present in the ICTV reference set — TODO confirm
all_bacteria_ICTV = [
    "AE006468",
    "AF049230",
    "BD143114",
    "BD269513",
    "BX897699",
    "CP000031",
    "CP000830",
    "CP001312",
    "CP001357",
    "CP006891",
    "CP014526",
    "CP015418",
    "CP019275",
    "CP023680",
    "CP023686",
    "CP038625",
    "CP052639",
    "JAEILC010000038",
    "KK213166",
    "M18706",
    "QUVN01000024",
    "U68072",
    "U96748",
]
# presumably proviruses whose boundaries are not well defined, so the record
# may include host (bacterial) sequence — TODO confirm
provirus_not_well_defined = [
    "AY319521",
    "EF462197",
    "EF462198",
    "EF710638",
    "EF710642",
    "FJ184280",
    "FJ188381",
    "HG424323",
    "HM208303",
    "J02013",
    "JQ347801",
    "K02712",
    "KF147927",
    "KF183314",
    "KF183315",
    "KP972568",
    "KX232515",
    "KX452695",
    "MK075003",
    "V01201",
]
# NOTE(review): the criterion for "unverified" is not stated in this notebook;
# document where this list comes from.
unverified = [
    "DQ188954",
    "HM543472",
    "JQ407224",
    "KC008572",
    "KC626021",
    "KF302037",
    "KF360047",
    "KF938901",
    "KJ641726",
    "KM233624",
    "KM389459",
    "KM982402",
    "KP843857",
    "KR862307",
    "KU343148",
    "KU343149",
    "KU343150",
    "KU343151",
    "KU343152",
    "KU343153",
    "KU343154",
    "KU343155",
    "KU343156",
    "KU343160",
    "KU343161",
    "KU343162",
    "KU343163",
    "KU343164",
    "KU343165",
    "KU343169",
    "KU343170",
    "KU343171",
    "KU672593",
    "KU752557",
    "KX098515",
    "KX228196",
    "KX228197",
    "KX228198",
    "KX363561",
    "KX452696",
    "KX452698",
    "KX656670",
    "KX656671",
    "KX989546",
    "KY450753",
    "KY487839",
    "KY608967",
    "KY742649",
    "MG459218",
    "MG551742",
    "MG599035",
    "MH791395",
    "MH791402",
    "MH791405",
    "MH791410",
    "MH791412",
    "MH918795",
    "MH925094",
    "MH992121",
    "MK033136",
    "MK050014",
    "MK415316",
    "MK415317",
    "MK474470",
    "MK780203",
    "MN545971",
    "MN871450",
    "MN871491",
    "MN871495",
    "MN871498",
    "MN928506",
    "MT360681",
    "MT360682",
    "MW325771",
    "MW685514",
    "MW685515",
]

# Single exclusion list used by every `species_id` filter in this section.
all_unwanted = all_bacteria_ICTV + provirus_not_well_defined + unverified

# Jaeger phage predictions: read every table matched by the glob, derive the
# species id from the contig name (text before the first underscore), and drop
# the excluded genomes.
jeager_df = (
    pl.read_csv(
        "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/simulation_genomad/jeager/*/*/*_default_phages_jaeger.tsv",
        separator="\t",
    )
    .with_columns(species_id=pl.col('contig_id').str.split('_').list.get(0))
    .filter(pl.col('species_id').is_in(all_unwanted).not_())
)

def _load_genomad_virus_summaries(run_dir: Path, unwanted: list) -> pl.DataFrame:
    """
    Load and stack every geNomad `*_virus_summary.tsv` found under `run_dir`.

    Parameters:
    -----------
    run_dir : Path
        Folder searched with the pattern `*/*_summary/*_virus_summary.tsv`
    unwanted : list
        `species_id` values to drop from the result

    Returns:
    --------
    pl.DataFrame
        All summaries concatenated, with `seq_name` renamed to `contig_id`, a
        `species_id` column (text before the first underscore of the sequence
        name) and `length` cast to Int64.

    Raises:
    -------
    FileNotFoundError
        If no summary file matches under `run_dir`.
    """
    # Sorted so the row order does not depend on filesystem listing order.
    summary_files = sorted(run_dir.glob("*/*_summary/*_virus_summary.tsv"))
    if not summary_files:
        raise FileNotFoundError(f"No geNomad virus summary found under {run_dir}")

    tables = [
        pl.read_csv(summary_file, separator="\t")
        .with_columns(
            # Cast in both runs (previously only the default run did) so every
            # table carries the same `length` dtype before concatenation.
            pl.col('length').cast(pl.Int64),
            species_id=pl.col('seq_name').str.split('_').list.get(0),
        )
        .rename({"seq_name": "contig_id"})
        for summary_file in summary_files
    ]

    return pl.concat(tables).filter(~pl.col('species_id').is_in(unwanted))


path_genomad = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/simulation_genomad/genomad")

genomad_df = _load_genomad_virus_summaries(path_genomad, all_unwanted)

path_genomad_default = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/simulation_genomad/genomad_default")

genomad_default_df = _load_genomad_virus_summaries(path_genomad_default, all_unwanted)

path2contigs = Path("/mnt/archgen/microbiome_coprolite/aVirus/03-data/simulated_viral_contigs")

# One pass over the simulated contig FASTA files collects both the per-genome
# contig count and the per-contig records. Each gzip was previously read twice,
# and contigs were counted by counting ">" bytes, which also counts any ">"
# inside a header line.
list_contigs_num = []
contig_records = []

for fasta_path in sorted(path2contigs.glob("*.fna.gz")):
    # The species id is the file name up to the first dot.
    genome_id = fasta_path.name.split(".")[0]
    n_contigs = 0
    with gzip.open(fasta_path, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            n_contigs += 1
            # Rewrite the FASTA id so it can be matched against the contig ids
            # reported in the geNomad and Jaeger tables.
            contig_name = record.id.replace("+", "F").replace(":-", "_R").replace(":", "_").replace("-", "_")
            contig_records.append((genome_id, contig_name, len(record.seq)))
    list_contigs_num.append((genome_id, n_contigs))

keep_wanted_species = ~pl.col("species_id").is_in(all_unwanted)

# NOTE: in this frame the column named "length" holds the NUMBER of contigs per
# genome (name kept for compatibility with the cells below).
df_contigs_num = pl.DataFrame(
    list_contigs_num, schema=["species_id", "length"], orient="row"
).filter(keep_wanted_species)

# NOTE(review): the per-contig table is now filtered with the same exclusion
# list as the predictions. Before, contigs of excluded genomes stayed in this
# table while their predictions had been dropped, so they were all reported as
# "not detected" downstream.
df_contigs_name = pl.DataFrame(
    contig_records, schema=["species_id", "contig_id", "length"], orient="row"
).filter(keep_wanted_species)

# Number of predicted-viral contigs per genome, for each tool.
jeager_species_count = jeager_df.group_by("species_id").agg(
    count_jeager=pl.count("species_id")
)
genomad_species_count = genomad_df.group_by("species_id").agg(
    count_genomad=pl.count("species_id")
)

# Attach both counts to the per-genome table; genomes without any prediction
# get 0, and the count columns are stored as Int64.
df_contigs_num = (
    df_contigs_num
    .join(jeager_species_count, on="species_id", how="left")
    .join(genomad_species_count, on="species_id", how="left")
    .fill_null(0)
    .with_columns(pl.col("count_jeager", "count_genomad").cast(pl.Int64))
)

# One boolean column per method: True when the contig id appears in that
# method's prediction table.
df_contigs_name = df_contigs_name.with_columns(
    **{
        method: pl.col("contig_id").is_in(predictions["contig_id"])
        for method, predictions in (
            ("genomad", genomad_df),
            ("jeager", jeager_df),
            ("genomad_default", genomad_default_df),
        )
    }
)

from seaborn import axes_style, plotting_context

# Reshape to long format, one row per (contig, method): `method` holds the name
# of the detection column ("genomad", "genomad_default" or "jeager") and
# `viral` holds its boolean value.
df_long = df_contigs_name.unpivot(
    on=["genomad", "genomad_default", "jeager"],
    index=["species_id", "contig_id", "length"],
    variable_name="method",
    value_name="viral",
)

# Histogram of contig length (30 bins), colored by method, with one facet
# column per value of `viral`.
length_hist_theme = axes_style("whitegrid") | plotting_context("talk") | {"grid.linestyle": ":", "pdf.fonttype": 42}

p = (
    so.Plot(df_long, x="length", color="method")
    .add(so.Bars(), so.Hist(bins=30))
    .facet(col="viral")
    .theme(length_hist_theme)
    .layout(size=(10, 6))
)




### Figure 5B

In [None]:
import polars as pl
import matplotlib.pyplot as plt
from typing import List

def plot_contig_matches(blast_tsv: str, outfig: str, list_viral: List[str], reference_id: str = "MG711460", reference_len: int = 36636):
    """
    Plot contig matches against a reference sequence

    Every alignment is drawn as a horizontal segment spanning `rstart`..`rend`
    on the row of its contig: blue when the contig is listed in `list_viral`,
    red otherwise. The figure is written to `outfig`.

    Parameters:
    -----------
    blast_tsv : str
        Path to TSV file containing blast results with columns query, reference, rstart, rend
    outfig : str
        Path to save the output figure
    list_viral : List[str]
        Name of the contigs detected as viral
    reference_id : str
        ID of the reference sequence to plot matches for (default: MG711460)
    reference_len : int
        Length of the grey reference bar (default: 36636). When `reference_id`
        is changed, pass the matching length as well.

    Returns:
    --------
    None
    """
    # Set membership instead of scanning the list once per alignment row.
    viral_contigs = set(list_viral)

    # Read blast results and keep the alignments against the chosen reference
    df = pl.read_csv(blast_tsv, separator="\t").filter(pl.col("reference").eq(reference_id))

    fig, ax = plt.subplots(figsize=(15, 5))

    # Reference bar, drawn just below the first contig row (y = -0.9)
    ax.hlines(y=-0.9, xmin=0, xmax=reference_len, color='grey', linestyle='-', linewidth=5)

    # One row per contig, in sorted name order
    unique_queries = sorted(df['query'].unique())
    y_positions = {query: i for i, query in enumerate(unique_queries)}

    # One segment per alignment
    for row in df.to_dicts():
        segment_color = 'blue' if row['query'] in viral_contigs else 'red'
        ax.hlines(y_positions[row['query']], row['rstart'], row['rend'], linewidth=5, color=segment_color)

    # Customize plot
    ax.grid(True, linestyle=':', alpha=0.3)
    ax.set_ylabel('Contigs')
    ax.set_xlabel('Position on reference')
    ax.set_title(f'Contig matches on {reference_id}')
    ax.set_yticks(range(len(unique_queries)))
    ax.set_yticklabels(unique_queries)
    fig.savefig(outfig, bbox_inches='tight')
    return


In [None]:
# geNomad virus summary for the BSM001 megahit assembly; its `seq_name` column
# provides the contigs drawn as viral.
genomad_megahit_BSM001 = pl.read_csv(
    "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/genomad/BSM001-megahit/BSM001-megahit_summary/BSM001-megahit_virus_summary.tsv",
    separator="\t",
)
viral_contigs_BSM001 = genomad_megahit_BSM001.get_column('seq_name').unique().to_list()

# Draw the BSM001 megahit contigs against the default reference.
plot_contig_matches(
    "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/vclust/BSM001.contigs.megahit.aln.tsv",
    outfig="/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/genomad/BSM001.contigs.megahit.genomad.pdf",
    list_viral=viral_contigs_BSM001,
)

### Figure 5C

In [None]:
import polars as pl
from pathlib import Path
import seaborn.objects as so
from seaborn import axes_style, plotting_context
import matplotlib.pyplot as plt
import os

# Matplotlib rcParams overrides for this section; handed to
# sns.set_theme(rc=custom_style) right after this dict.
# Commented-out keys are options that are currently disabled.
custom_style = {"text.color": "#131516",
                "svg.fonttype": "none",
                # "font.family": "sans-serif",
                # "font.weight": "light",
                "axes.spines.right": False,
                "axes.spines.top": False,
                # "axes.spines.bottom": False,
                'xtick.bottom': False,
                "pdf.fonttype": 42
                }

import seaborn as sns ; sns.set_theme(style="ticks", rc=custom_style)  # for plot styling

def plot_checkv_quality(files, genomad_folder, decoy):
    """
    Count contigs per CheckV quality tier and assembler and draw them as bars.

    For every CheckV table in `files`, the contigs listed in the matching
    PyDamage result are kept and combined with the vclust alignments and the
    geNomad virus summary of the same run. Contigs whose retained vclust hit is
    a `decoy` reference are relabelled "Bacterial contamination". Two bar
    layers are drawn on one axis: all retained contigs (semi-transparent) and
    those with a retained vclust hit (opaque). The figure is saved as
    `checkv_quality.pdf` inside `genomad_folder`.

    Parameters:
    -----------
    files : iterable of pathlib.Path
        CheckV `quality_summary.tsv` files; each parent folder must be named
        `<sample>-<software>`
    genomad_folder : str
        Folder holding `pydamage/<sample>-<software>/pydamage_filtered_results.csv`,
        `<sample>-<software>/vclust/<sample>.contigs.<software>.aln.tsv` and
        `<sample>-<software>/<sample>-<software>_summary/<sample>-<software>_virus_summary.tsv`
    decoy : collection of str
        Reference ids treated as bacterial decoys

    Returns:
    --------
    tuple
        (matplotlib Axes, combined vclust frame, combined CheckV frame,
        list of contig names flagged as contamination)

    Raises:
    -------
    ValueError
        If none of the runs in `files` has a PyDamage result file.
    """
    # The original list contained "N/A" twice; one entry is enough.
    null_markers = ["NA", "NaN", "nan", "N/A", "n/a"]
    quality_order = ['Complete', 'High-quality', 'Medium-quality', 'Low-quality', 'Not-determined', 'Bacterial contamination']

    dfs = []
    dfs_vclust = []
    genomad_summaries = []

    for file in files:
        # Get sample and software from parent folder name
        sample, software = file.parent.name.split('-')
        run_name = f"{sample}-{software}"
        run_tags = [pl.lit(sample).alias('sample'), pl.lit(software).alias('software')]

        pydamage_csv = os.path.join(genomad_folder, f"pydamage/{run_name}/pydamage_filtered_results.csv")
        vclust_csv = os.path.join(genomad_folder, f"{run_name}/vclust/{sample}.contigs.{software}.aln.tsv")
        genomad_summary = os.path.join(genomad_folder, f"{run_name}/{run_name}_summary/{run_name}_virus_summary.tsv")

        # Runs without a PyDamage result have no ancient contigs: skip them
        if not os.path.exists(pydamage_csv):
            continue

        list_ancient = pl.read_csv(pydamage_csv)['reference'].to_list()

        # CheckV quality table, restricted to ancient contigs
        df = (
            pl.read_csv(file, separator="\t", null_values=null_markers)
            .fill_null(0)
            .with_columns(run_tags)
            .filter(pl.col('contig_id').is_in(list_ancient))
            .with_columns(pl.col('proviral_length').cast(pl.Int64))
        )
        dfs.append(df)

        # vclust alignments: ancient NODE* queries with > 90 % identity against
        # a non-NODE reference, first retained hit per query, joined with the
        # CheckV columns of that contig
        df_vclust = (
            pl.read_csv(vclust_csv, separator="\t")
            .with_columns(run_tags)
            .filter(
                pl.col('query').is_in(list_ancient),
                pl.col('pident').gt(90),
                pl.col("query").str.starts_with('NODE'),
                ~pl.col('reference').str.starts_with('NODE'),
            )
            .unique(subset=['query'], maintain_order=True)
            .join(df, left_on='query', right_on='contig_id', how='left')
        )
        dfs_vclust.append(df_vclust)

        # geNomad virus summary, restricted to ancient contigs
        df_genomad = (
            pl.read_csv(genomad_summary, separator="\t", null_values=null_markers)
            .fill_null(0)
            .with_columns(run_tags)
            .filter(pl.col('seq_name').is_in(list_ancient))
            .with_columns(pl.col('length').cast(pl.Int64))
            .rename({"seq_name": "contig_id"})
        )
        genomad_summaries.append(df_genomad)

    if not dfs:
        raise ValueError(
            f"No CheckV table with a matching PyDamage result was found under {genomad_folder}"
        )

    # Combine all dataframes
    combined_df = pl.concat(dfs)
    combined_vclust_df = pl.concat(dfs_vclust)
    combined_genomad_summary = pl.concat(genomad_summaries)

    # Contigs whose retained vclust hit is one of the decoy references.
    # NOTE(review): contigs are matched by name only, across all samples and
    # assemblers; confirm that contig names cannot collide between runs.
    nodes_contamination = (
        combined_vclust_df
        .filter(pl.col('reference').is_in(list(decoy)))['query']
        .unique()
        .to_list()
    )

    def quality_or_contamination(id_column):
        # "Bacterial contamination" for flagged contigs, else the CheckV tier.
        # Written as when/then so an empty contamination list is a no-op.
        return (
            pl.when(pl.col(id_column).is_in(nodes_contamination))
            .then(pl.lit("Bacterial contamination"))
            .otherwise(pl.col('checkv_quality'))
            .alias('checkv_quality_contamination')
        )

    combined_vclust_df = combined_vclust_df.with_columns(
        quality_or_contamination('query')
    ).filter(
        # presumably the phiX174 reference; consistent with the Sinsheimervirus
        # filter applied to the CheckV frame below — TODO confirm
        ~pl.col("reference").eq("J02482"),
    )

    combined_df = combined_df.with_columns(
        quality_or_contamination('contig_id')
    ).join(
        combined_genomad_summary,
        on=['contig_id', 'sample', 'software'],
        how='left'
    ).filter(
        ~pl.col("taxonomy").str.contains("Sinsheimervirus"),
    )

    # Plot without the objects interface
    fig, ax = plt.subplots()

    # Layer 1 (alpha 0.5): all retained contigs. Layer 2 (alpha 1): contigs
    # with a retained vclust hit, drawn on top.
    for layer_df, layer_alpha in ((combined_df, 0.5), (combined_vclust_df, 1)):
        layer_counts = (
            layer_df
            .group_by(['software', 'checkv_quality_contamination'])
            .len()
            .sort(['software', 'checkv_quality_contamination'])
        )
        sns.barplot(
            data=layer_counts,
            x='len',
            y='checkv_quality_contamination',
            hue='software',
            dodge=True,
            order=quality_order,
            alpha=layer_alpha,
            ax=ax,
        )

    ax.set_xscale("log")
    ax.set_xlabel("Number of contigs")
    ax.set_ylabel("CheckV quality")
    # Only the upper limit is set: 0 is not a valid bound on a log axis and
    # matplotlib ignored it (with a warning) in the previous plt.xlim(0, 27000).
    ax.set_xlim(right=27000)

    # Label every bar group; indexing containers 0..3 raised IndexError when
    # fewer than four groups were drawn.
    for container in ax.containers:
        ax.bar_label(container, fontsize=10, fmt='%.0f')

    fig.savefig(os.path.join(
        genomad_folder, "checkv_quality.pdf"
    ), bbox_inches='tight')

    return ax, combined_vclust_df, combined_df, nodes_contamination




In [None]:
from Bio import SeqIO

# Sequence ids found in the virus+decoy FASTA but not in the virus-only FASTA.
viral_decoy = SeqIO.index(
    "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data/top_viruses/viruses_decoy.fasta",
    "fasta",
)
viral = SeqIO.index(
    "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data/top_viruses/viruses.fasta",
    "fasta",
)
list_decoy = set(viral_decoy.keys()).difference(viral.keys())

# Some decoy sequences might themselves be viruses and/or prophages: take the
# sequences reported in the geNomad virus summary of the decoy set, keep the
# part of each name before the first "|", and remove them from the decoy set.
putative_phage_in_decoy = (
    pl.read_csv(
        "/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/decoy/conservative/all_50_bacterial_decoy_summary/all_50_bacterial_decoy_virus_summary.tsv",
        separator="\t",
    )
    .get_column('seq_name')
    .unique()
    .to_list()
)
putative_phage_in_decoy = [seq_name.split('|')[0] for seq_name in putative_phage_in_decoy]

list_decoy = list_decoy.difference(putative_phage_in_decoy)

In [None]:
# CheckV results for the geNomad run stored under `genomad/`.
analysis_foler = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/checkv_genomad/")

p, df_vclust, df_genomad, contam = plot_checkv_quality(
    files=analysis_foler.glob("*/quality_summary.tsv"),
    genomad_folder='/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/genomad',
    decoy=list_decoy,
)

In [None]:
# CheckV results for the geNomad run stored under `genomad_default/`.
analysis_foler = Path("/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/checkv_genomad_default/")

p, df_vclust_default, df_genomad_default, contam_default = plot_checkv_quality(
    files=analysis_foler.glob("*/quality_summary.tsv"),
    genomad_folder='/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/genomad_default',
    decoy=list_decoy,
)

### Figure SX

This figure reuses dataframes created for Figure 5C. The cells below use `df_genomad_default` and `df_vclust_default`; to produce the same figure for geNomad run with relaxed parameters, replace them with `df_genomad` and `df_vclust`.

In [None]:
# Split the `;`-separated taxonomy string into one column per rank.
# "Unclassified" is first rewritten as eight separators, i.e. nine empty ranks.
lineage_ranks = ['root', 'realm', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species']

df_genomad_default = df_genomad_default.with_columns(
    pl.col('taxonomy').replace_strict(
        old="Unclassified",
        new=";"*8,
        default=pl.col('taxonomy')
    )
).with_columns(
    [
        pl.col('taxonomy')
        .str.split(';')
        # Lineages with fewer than nine fields would make list.get raise an
        # out-of-bounds error; return null instead and store it as "", the same
        # value an empty rank already has.
        .list.get(rank_position, null_on_oob=True)
        .fill_null("")
        .alias(rank_name)
        for rank_position, rank_name in enumerate(lineage_ranks)
    ]
)

In [None]:
# Rank columns produced by the taxonomy split, from `root` down to `species`.
taxonomy_fields = ['root', 'realm', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species']

# Last non-empty rank of each row = deepest taxonomic level that was assigned.
non_empty_ranks = pl.concat_list(taxonomy_fields).list.eval(
    pl.element().filter(pl.element() != "")
)
# Contigs that also appear as a query in the vclust frame are left out.
has_vclust_hit = pl.col("contig_id").is_in(df_vclust_default['query'])

df_genomad_default = df_genomad_default.with_columns(
    last_taxonomy=non_empty_ranks.list.last()
).filter(has_vclust_hit.not_())

# Frequency table of the deepest rank per assembler, limited to contigs that
# CheckV rated "Complete" or "High-quality".
taxonomy_counts_default = (
    df_genomad_default
    .filter(pl.col("checkv_quality").is_in(["Complete", "High-quality"]))
    .group_by(["software", "last_taxonomy"])
    .agg(count=pl.len())
    .sort("count", descending=True)
)

# Dodged bars: one bar per taxon and assembler.
taxonomy_theme = axes_style("whitegrid") | plotting_context("talk") | {"grid.linestyle": ":", "pdf.fonttype": 42}

p = (
    so.Plot(taxonomy_counts_default, y="last_taxonomy", x="count", color="software")
    .add(so.Bar(), so.Dodge())
    .theme(taxonomy_theme)
    .layout(size=(10, 8))
)

### Figure 5D

In [None]:
import polars as pl
import matplotlib.pyplot as plt
from typing import List
import matplotlib.colors as mcolors

def plot_contig_bins(blast_tsv: str, outfig: str, bin_df: pl.DataFrame, reference_id: str = "MG711460", reference_len: int = 36636):
    """
    Plot contig bins against a reference sequence

    Every alignment is drawn as a horizontal segment spanning `rstart`..`rend`
    on the row of its contig, colored by the bin the contig belongs to and
    labelled "Cluster <bin>". Contigs without a retained bin are grey and
    labelled "Cluster N/A". The figure is written to `outfig`.

    Parameters:
    -----------
    blast_tsv : str
        Path to TSV file containing blast results with columns query, reference, rstart, rend
    outfig : str
        Path to save the output figure
    bin_df : pl.DataFrame
        Binning table with a `contig` and a `bin` column. Only bins that
        contain more than one contig are used.
    reference_id : str
        ID of the reference sequence to plot matches for (default: MG711460)
    reference_len : int
        Length of the grey reference bar (default: 36636)

    Returns:
    --------
    None
    """

    # Read blast results and keep the alignments against the chosen reference
    df = pl.read_csv(blast_tsv, separator="\t").filter(pl.col("reference").eq(reference_id))

    # Keep bins holding more than one contig, restricted to aligned contigs
    aligned_bins = bin_df.filter(
        pl.len().over("bin") > 1
    ).filter(
        pl.col("contig").is_in(df['query'].unique())
    )

    # contig -> bin value. rows_by_key returned one-element tuples here, which
    # made the labels read "Cluster (3,)" instead of "Cluster 3".
    bin_dict = dict(zip(aligned_bins["contig"].to_list(), aligned_bins["bin"].to_list()))

    # Get unique clusters in sorted order (for reproducibility)
    unique_clusters = sorted(set(bin_dict.values()))
    num_clusters = len(unique_clusters)

    # plt.cm.get_cmap was removed in matplotlib 3.9; plt.get_cmap takes the same
    # (name, lut) arguments. max(..., 1) avoids resampling to zero colors when
    # no aligned contig has a bin.
    cmap = plt.get_cmap("Paired" if num_clusters <= 12 else "tab20", max(num_clusters, 1))

    # Map each unique cluster to a hex color using the colormap
    color_map = {cluster: mcolors.to_hex(cmap(i)) for i, cluster in enumerate(unique_clusters)}

    # Map each contig name to its corresponding color
    contig_colors = {contig: color_map[cluster] for contig, cluster in bin_dict.items()}

    fig, ax = plt.subplots(figsize=(15, 5))

    # Reference bar, drawn just below the first contig row (y = -0.9)
    ax.hlines(y=-0.9, xmin=0, xmax=reference_len, color='grey', linestyle='-', linewidth=5)

    # One row per contig, in sorted name order
    unique_queries = sorted(df['query'].unique())
    y_positions = {query: i for i, query in enumerate(unique_queries)}

    # One colored segment and one centered label per alignment
    for row in df.to_dicts():
        y_pos = y_positions[row['query']]
        ax.hlines(y_pos, row['rstart'], row['rend'], linewidth=5, color=contig_colors.get(row['query'], 'grey'))

        name_cluster = f"Cluster {bin_dict.get(row['query'], 'N/A')}"
        ax.text(row['rstart'] + (row['rend'] - row['rstart'])/2, y_pos + 0.5, name_cluster, fontsize=8, ha='center')

    # Customize plot
    ax.grid(True, linestyle=':', alpha=0.3)
    ax.set_ylabel('Contigs')
    ax.set_xlabel('Position on reference')
    ax.set_title(f'Contig matches on {reference_id}')
    ax.set_yticks(range(len(unique_queries)))
    ax.set_yticklabels(unique_queries)
    fig.savefig(outfig, bbox_inches='tight')
    return

Here is an example of how to plot a specific case:

In [None]:
# Run to plot: which sample/assembler, and which binning result.
sample, software = "BSM001", "metaspades"
method, binning = "assembled", "semibin"
bin_file = "contig_bins.tsv"
site = sample[:3]

# Binning table for that run.
bin_table_path = f"/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/binning/{method}/{software}/{binning}/{sample}-{software}/{bin_file}"
bin_df = pl.read_csv(bin_table_path, separator="\t")

# Draw the contigs of that run against the default reference, colored by bin.
plot_contig_bins(
    blast_tsv=f"/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/vclust/{sample}.contigs.{software}.aln.tsv",
    outfig=f"/mnt/archgen/microbiome_coprolite/aVirus/05-results/taxonomic_annotation/empirical_data_bwa/genomad/{sample}.contigs.{software}.binning.{method}.{binning}.pdf",
    bin_df=bin_df,
)