In [None]:
from Bio import SeqIO
from pathlib import Path
import pandas as pd
import json
import os

In [None]:
def get_top_sequences(faa_path, top_n=10, min_len=100, max_len=2000):
    """Return the ``top_n`` longest FASTA records whose length is in [min_len, max_len].

    Records are read from ``faa_path`` via ``SeqIO.parse(..., "fasta")``.

    Returns:
        ``None`` if fewer than ``top_n`` records pass the length filter;
        otherwise a list of the ``top_n`` longest passing records, longest
        first. Records of equal length keep their order from the file,
        because the sort is stable.
    """
    def seq_len(record):
        return len(record.seq)

    # Apply the inclusive length window while streaming the file
    in_range = [
        record
        for record in SeqIO.parse(faa_path, "fasta")
        if min_len <= seq_len(record) <= max_len
    ]

    if len(in_range) < top_n:
        # Not enough usable sequences: tell the caller to skip this file
        return None

    in_range.sort(key=seq_len, reverse=True)
    return in_range[:top_n]

def process_all_faa(input_dir, output_dir, top_n=10, min_len=100, max_len=2000):
    """Select the longest in-range sequences from every ``*.faa`` file in a directory.

    For each ``input_dir/*.faa`` file, ``get_top_sequences`` picks the ``top_n``
    longest records with ``min_len <= length <= max_len``. They are written in
    FASTA format to ``output_dir`` under the same file name.

    Files with fewer than ``top_n`` qualifying records are skipped. Any output
    an earlier run left for such a file is deleted, so that ``output_dir``
    never holds a selection made with different thresholds.

    Raises:
        FileNotFoundError: if ``input_dir`` is not an existing directory.
        ValueError: if ``input_dir`` and ``output_dir`` are the same
            directory (the outputs would overwrite the source files).
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    if not input_dir.is_dir():
        raise FileNotFoundError(f"Input directory not found: {input_dir}")
    if input_dir.resolve() == output_dir.resolve():
        raise ValueError(
            f"output_dir must differ from input_dir ({input_dir}); "
            "writing there would overwrite the source .faa files"
        )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so that processing (and the printed log) is reproducible across runs
    faa_files = sorted(input_dir.glob("*.faa"))
    print(f"Found {len(faa_files)} .faa files")

    for faa_file in faa_files:
        output_file = output_dir / faa_file.name
        top_seqs = get_top_sequences(faa_file, top_n=top_n, min_len=min_len, max_len=max_len)
        if top_seqs is None:
            print(f"Skipped {faa_file.name}: fewer than {top_n} valid sequences")
            # A previous run with looser thresholds may have written this file.
            # Remove it so output_dir only reflects the current selection.
            if output_file.exists():
                output_file.unlink()
                print(f"Removed stale output {output_file}")
            continue

        SeqIO.write(top_seqs, output_file, "fasta")
        print(f"Saved {len(top_seqs)} sequences from {faa_file.name}")

# Example usage: keep the 10 longest sequences of length 100–2000 from each file
source_faa_dir = "../data/filtered_output"
selected_faa_dir = "../data/test_protein_sequences/selected_seqs_final"
process_all_faa(
    source_faa_dir,
    selected_faa_dir,
    top_n=10,
    min_len=100,
    max_len=2000,
)

In [None]:
# Path settings
label_file = "../data/temperature_data.tsv"  # TSV with "taxid" and "temperature" columns
faa_dir = Path("../data/test_protein_sequences/selected_seqs_final")  # Directory containing .faa files

# Fail loudly instead of silently writing an empty dataset
if not faa_dir.is_dir():
    raise FileNotFoundError(f"FASTA directory not found: {faa_dir}")

# Read taxid as text. If pandas inferred it, a single missing value would make
# the column float, and keys like "1234.0" would never match a filename prefix.
df = pd.read_csv(label_file, sep="\t", dtype={"taxid": str})
df["taxid"] = df["taxid"].str.strip()

# dict() below keeps the last row for a repeated taxid; make that visible
n_duplicate_taxids = int(df["taxid"].dropna().duplicated().sum())
if n_duplicate_taxids:
    print(f"⚠️ {n_duplicate_taxids} duplicate taxid rows in {label_file}; the last occurrence is used")

# Build a dict mapping taxid -> temperature (NaN where the label is missing)
temperature_dict = dict(zip(df["taxid"], df["temperature"]))

strain_level_data = {}

# Sorted so the run, and which file wins for a repeated taxid, is deterministic
for faa_file in sorted(faa_dir.glob("*.faa")):
    filename = faa_file.name
    # The taxid is the filename prefix before the first "_". Split the stem,
    # not the full name, so "<taxid>.faa" yields "<taxid>", not "<taxid>.faa".
    taxid = faa_file.stem.split("_")[0].strip()

    if taxid not in temperature_dict:
        print(f"⚠️ taxid {taxid} not found in labels, skipping {filename}")
        continue

    temperature = temperature_dict[taxid]
    if pd.isna(temperature):
        # float(NaN) would be written to JSON as the non-standard token NaN
        # and become a NaN training label; skip the strain instead.
        print(f"⚠️ taxid {taxid} has no temperature label, skipping {filename}")
        continue

    if taxid in strain_level_data:
        # Two files map to the same taxid; keep the first instead of silently overwriting
        print(f"⚠️ taxid {taxid} already loaded from another file, skipping {filename}")
        continue

    sequences = []
    for record in SeqIO.parse(faa_file, "fasta"):
        seq = str(record.seq).strip()
        if seq:
            sequences.append(seq)

    if not sequences:
        print(f"⚠️ No valid sequences in {filename}, skipping")
        continue

    strain_level_data[taxid] = {
        "sequences": sequences,
        "label": float(temperature)
    }

# Save as JSON
with open("strain_level_data.json", "w") as f:
    json.dump(strain_level_data, f, indent=2)

print(f"✅ Finished building dataset with {len(strain_level_data)} strains.")


In [None]:
# Load the strain-level dataset (taxid -> {"sequences": [...], "label": float})
with open("strain_level_data.json", "r") as f:
    strain_data_final = json.load(f)

# One record per sequence. Each record carries its strain's label and
# strain_id (presumably so later splits can group by strain — confirm downstream).
train_data_final = [
    {"sequence": seq, "label": entry["label"], "strain_id": strain_id}
    for strain_id, entry in strain_data_final.items()
    for seq in entry["sequences"]
]

# Persist the flat list for the training step
with open("train_data_final.json", "w") as f:
    json.dump(train_data_final, f, indent=2)

print(f" Finished building — total {len(train_data_final)} training samples")
