# Prep sequences and metadata

In [None]:
from datetime import datetime
from dateutil.relativedelta import relativedelta

import Bio.SeqIO

import pandas as pd

def date_to_decimal_year(date_str):
    """Return the decimal-year form of a ``YYYY-MM-DD`` date string.

    The fractional part is the number of whole days elapsed since
    January 1 of that year divided by the number of days in that year,
    so January 1 maps exactly to the integer year (e.g. "2020-01-01" ->
    2020.0) and leap years divide by 366.

    Raises ValueError (from ``datetime.strptime``) if ``date_str`` does not
    match ``%Y-%m-%d``.
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    # tm_yday is 1-based, so subtract 1 to get days elapsed since Jan 1.
    days_elapsed = parsed.timetuple().tm_yday - 1
    days_in_year = (datetime(parsed.year + 1, 1, 1) - datetime(parsed.year, 1, 1)).days
    return parsed.year + days_elapsed / days_in_year

In [None]:
# Pull wildcards, inputs, outputs, and params off the `snakemake` object
# that Snakemake provides to notebooks run via its `notebook:` directive.
sequence_set = snakemake.wildcards.sequence_set

# inputs: the FASTA files to read (iterated later) and the metadata CSV
input_fastas, input_metadata_csv = snakemake.input.fastas, snakemake.input.metadata

# outputs: the combined unaligned FASTA and the filtered metadata CSV
unaligned_seqs_fasta, output_metadata_csv = (
    snakemake.output.unaligned_seqs,
    snakemake.output.metadata,
)

# params controlling column selection, sequence-length bounds, and exclusions
metadata_req_cols, length_range, drop_accessions = (
    snakemake.params.metadata_req_cols,
    snakemake.params.length_range,
    snakemake.params.drop_accessions,
)

print(f"Analyzing {sequence_set=}")

In [None]:
# Load the metadata and keep only the required columns. Normalize accessions
# by dropping any ".N" version suffix, and add a decimal-year `num_date`
# column. Then remove the accessions listed in `drop_accessions` and write the
# result to CSV.
metadata = pd.read_csv(input_metadata_csv)
print(f"Initial {len(metadata)=}")

# Explicit raises (not `assert`) so these checks still run under `python -O`.
missing_cols = set(metadata_req_cols) - set(metadata.columns)
if missing_cols:
    raise ValueError(f"metadata missing required columns: {missing_cols}")

# Both `accession` and `date` are used below, so both must be among the
# columns that are kept.
for required_col in ["accession", "date"]:
    if required_col not in metadata_req_cols:
        raise ValueError(f"{required_col=} not in {metadata_req_cols=}")

# `list(...)` so that a tuple of names selects columns rather than being
# interpreted by pandas as a single key.
metadata = metadata[list(metadata_req_cols)].assign(
    accession=lambda x: x["accession"].str.split(".").str[0],
    num_date=lambda x: x["date"].map(date_to_decimal_year),
)

# Checked after version stripping, so two versions of the same accession
# (e.g. "X.1" and "X.2") are reported as duplicates.
duplicated_accessions = metadata["accession"][metadata["accession"].duplicated()]
if len(duplicated_accessions):
    raise ValueError(
        f"Duplicate accessions in metadata: {duplicated_accessions.unique().tolist()}"
    )

print(f"{drop_accessions=}")
# Drop entries are compared against version-stripped accessions, so they must
# not carry a version suffix themselves.
versioned_drops = [acc for acc in drop_accessions if "." in acc]
if versioned_drops:
    raise ValueError(f"drop_accessions must not include versions: {versioned_drops}")

metadata = metadata.query("accession not in @drop_accessions")
print(f"Final {len(metadata)=}")

print(f"Writing to {output_metadata_csv=}")
# "%.7g" applies to every float column written, including `num_date`.
metadata.to_csv(output_metadata_csv, index=False, float_format="%.7g")

In [None]:
# Read every input FASTA and normalize each record's accession the same way as
# the metadata (strip the ".N" version suffix). Skip dropped accessions and
# validate the rest, then write one unaligned FASTA whose accession set must
# equal the metadata's.
drop_accessions_set = set(drop_accessions)  # built once for O(1) lookups
metadata_accessions = set(metadata["accession"])
seqs = []
seq_accessions = set()
for fasta in input_fastas:
    print(f"Reading sequences from {fasta}")
    for seq in Bio.SeqIO.parse(fasta, "fasta"):
        accession = seq.id.split(".")[0]
        if accession in drop_accessions_set:
            continue
        if accession not in metadata_accessions:
            raise ValueError(f"{accession=} not in {metadata_accessions=}")
        if accession in seq_accessions:
            raise ValueError(f"Duplicate {accession=}")
        if not (length_range[0] <= len(seq) <= length_range[1]):
            raise ValueError(f"{accession=} has {len(seq)=}")
        seq_accessions.add(accession)
        # Each record is a complete ">id\nSEQ\n" string, newline-terminated.
        seqs.append(f">{accession}\n{str(seq.seq)}\n")

if seq_accessions != metadata_accessions:
    raise ValueError(
        f"{(seq_accessions - metadata_accessions)=}\n"
        f"{(metadata_accessions - seq_accessions)=}"
    )

print(f"Writing {len(seqs)=} sequences to {unaligned_seqs_fasta=}")
with open(unaligned_seqs_fasta, "w") as f:
    # Records already end in "\n"; concatenate them directly so no blank
    # lines appear between FASTA records.
    f.write("".join(seqs))