In [61]:
import torch
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

In [62]:
def load_data(out_dir, split, category):
    """Load the saved features, labels and contig list for one split/category.

    Three files are read from ``out_dir``, all sharing the prefix
    ``{split}_{category}``: ``*_features.pt`` and ``*_labels.pt`` (via
    ``torch.load``) and ``*_contigs.json`` (via ``json.load``).

    Args:
        out_dir: Directory containing the files (``str`` or ``Path``).
        split: Split name used in the file prefix, e.g. ``'test'``.
        category: Category name used in the file prefix, e.g. ``'archaea'``.

    Returns:
        Tuple ``(features, labels, contigs)`` in that order.

    Raises:
        FileNotFoundError: If any of the three files does not exist.
    """
    prefix = Path(out_dir) / f"{split}_{category}"

    # NOTE: torch.load unpickles its input, so only point this at trusted files.
    # map_location="cpu" lets files written during a GPU run load on any machine
    # and keeps the tensors compatible with the `.numpy()` calls made downstream.
    features = torch.load(f"{prefix}_features.pt", map_location="cpu")
    labels = torch.load(f"{prefix}_labels.pt", map_location="cpu")

    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(f"{prefix}_contigs.json", encoding="utf-8") as f:
        contigs = json.load(f)
    return features, labels, contigs

In [63]:
def analyze_feature_shape(features: torch.Tensor) -> None:
    """Print the shape of ``features`` on a single ``[INFO]`` line."""
    message = f"[INFO] Features shape: {features.shape}"
    print(message)

In [64]:
def analyze_label_distribution(labels: torch.Tensor) -> None:
    """Print the number of contigs carrying each distinct label value.

    One line is printed per distinct value, in the order ``unique`` returns them.
    """
    distinct, tallies = torch.unique(labels, return_counts=True)
    label_values = distinct.tolist()
    label_counts = tallies.tolist()

    print("[INFO] Label distribution:")
    for position in range(len(label_values)):
        label = label_values[position]
        count = label_counts[position]
        print(f"  - Label {label}: {count} contigs")

In [65]:
def summarize_metadata_stats(features: torch.Tensor) -> None:
    """Print min, max, mean and standard deviation of the metadata columns.

    Columns 0-3 of ``features`` are reported, in order, as GC content, mean
    genome length, log contig length and ambiguity rate. The statistics are
    computed with NumPy, so the standard deviation is the population one (ddof=0).
    """
    column_names = (
        "GC content",
        "Mean genome length",
        "Log contig length",
        "Ambiguity rate",
    )
    print("[INFO] Summary statistics for metadata features:")
    # Convert the metadata slice once; each column below is a view into it.
    metadata_array = features[:, :4].numpy()
    for position, name in enumerate(column_names):
        col = metadata_array[:, position]
        print(f"  - {name}: min={col.min():.4f}, max={col.max():.4f}, mean={col.mean():.4f}, std={col.std():.4f}")

In [66]:
def plot_gc_distribution(features: torch.Tensor, figure_dir: Path):
    """Save a 50-bin histogram of the GC-content column as a PDF.

    Column 0 of ``features`` is plotted and the figure is written to
    ``figure_dir / "gc_content_distribution.pdf"``, replacing any existing
    file with that name.

    Args:
        features: 2-D feature tensor; column 0 is read as the normalized GC content.
        figure_dir: Output directory (``str`` or ``Path``); created if missing.
    """
    figure_dir = Path(figure_dir)
    # savefig does not create missing directories, so make sure the target exists.
    figure_dir.mkdir(parents=True, exist_ok=True)

    gc_values = features[:, 0].numpy()

    # Draw on a dedicated figure instead of pyplot's implicit "current" one, so a
    # figure the caller already has open is neither drawn on nor cleared.
    fig, ax = plt.subplots()
    try:
        ax.hist(gc_values, bins=50)
        ax.set_title("GC Content Distribution")
        ax.set_xlabel("Normalized GC Content")
        ax.set_ylabel("Frequency")
        fig.savefig(figure_dir / "gc_content_distribution.pdf")
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

In [67]:
def plot_contig_length_distribution(features: torch.Tensor, figure_dir: Path, log_scale: float = 20.0):
    """Save a 100-bin histogram of contig lengths (in bp) as a PDF.

    Column 2 of ``features`` is multiplied by ``log_scale`` and passed through
    ``expm1`` to recover lengths, which are then plotted with a log-scaled
    count axis. The figure is written to
    ``figure_dir / "contig_length_distribution.pdf"``, replacing any existing
    file with that name.

    Args:
        features: 2-D feature tensor; column 2 is read as the normalized log length.
        figure_dir: Output directory (``str`` or ``Path``); created if missing.
        log_scale: Divisor that was applied to ``log1p(length)`` when the feature
            was built. The default of 20 mirrors the previously hard-coded value;
            NOTE(review): confirm it matches the feature-extraction code.
    """
    figure_dir = Path(figure_dir)
    # savefig does not create missing directories, so make sure the target exists.
    figure_dir.mkdir(parents=True, exist_ok=True)

    # Undo the log1p(length) / log_scale normalization.
    log_lengths = features[:, 2].numpy() * log_scale
    lengths = np.expm1(log_lengths)

    # Draw on a dedicated figure instead of pyplot's implicit "current" one, so a
    # figure the caller already has open is neither drawn on nor cleared.
    fig, ax = plt.subplots()
    try:
        ax.hist(lengths, bins=100, log=True)
        ax.set_title("Contig Length Distribution")
        ax.set_xlabel("Length (bp)")
        ax.set_ylabel("Frequency (log scale)")
        fig.savefig(figure_dir / "contig_length_distribution.pdf")
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

In [68]:
def validate_kmer_sums(features: torch.Tensor, tolerance: float = 0.05):
    """Report how closely each row's k-mer frequency vector sums to 1.

    Every column from index 4 onward is treated as part of the k-mer vector.
    The mean and standard deviation of the per-row sums are printed, followed
    by the number of rows whose sum is not within ``1 ± tolerance``.

    Args:
        features: 2-D feature tensor; columns ``4:`` hold the k-mer frequencies.
        tolerance: Allowed absolute deviation of a row sum from 1.
    """
    kmer_vectors = features[:, 4:]
    kmer_sums = kmer_vectors.sum(dim=1).numpy()
    mean_sum = np.mean(kmer_sums)
    std_sum = np.std(kmer_sums)

    lower, upper = 1 - tolerance, 1 + tolerance
    # Count rows that are NOT inside [lower, upper]. Testing for membership and
    # negating (rather than testing `< lower` / `> upper`) also flags NaN sums,
    # which fail every comparison and would otherwise pass as valid.
    in_range = (kmer_sums >= lower) & (kmer_sums <= upper)
    outliers = np.sum(~in_range)

    print("[INFO] K-mer frequency vector sums:")
    print(f"  - Mean: {mean_sum:.4f}, Std: {std_sum:.4f}")
    print(f"  - Outliers (sum < {lower:.2f} or > {upper:.2f}): {outliers} out of {len(kmer_sums)}")

In [69]:
def analyze_features(out_dir, split, category, figure_dir):
    """Run every feature sanity check for one split/category pair.

    Loads the saved tensors, prints the feature shape, label distribution,
    metadata statistics and k-mer sum validation, and saves the GC-content and
    contig-length histograms.

    Args:
        out_dir: Directory holding the ``{split}_{category}_*`` input files.
        split: Split name, e.g. ``'test'`` or ``'val'``.
        category: Category name, e.g. ``'archaea'``.
        figure_dir: Base directory for figures. The PDFs are written to the
            sub-directory ``figure_dir / f"{split}_{category}"``.
    """
    print(f"[INFO] Analyzing features for {split} split of {category}")
    # The contig list is loaded alongside the tensors but is not needed here.
    features, labels, _ = load_data(out_dir, split, category)

    # The plot helpers always write the same file names, so successive calls with
    # a shared figure_dir would overwrite each other's PDFs. Give every
    # split/category pair its own sub-directory instead.
    run_figure_dir = Path(figure_dir) / f"{split}_{category}"
    run_figure_dir.mkdir(parents=True, exist_ok=True)

    analyze_feature_shape(features)
    analyze_label_distribution(labels)
    summarize_metadata_stats(features)
    plot_gc_distribution(features, run_figure_dir)
    plot_contig_length_distribution(features, run_figure_dir)
    validate_kmer_sums(features)

In [70]:
# Directory the `{split}_{category}_*` feature, label and contig files are read from.
# Both paths are relative, so they resolve against the kernel's working directory.
out_dir = Path('../../results/outputs')
# Base directory the histogram PDFs are written to.
figures_dir = Path('../../figures/features')

In [71]:
# Archaea, test split. In the output below every k-mer vector sums to 1 within tolerance.
analyze_features(out_dir, split='test', category='archaea', figure_dir=figures_dir)

[INFO] Analyzing features for test split of archaea
[INFO] Features shape: torch.Size([367499, 260])
[INFO] Label distribution:
  - Label 1: 367499 contigs
[INFO] Summary statistics for metadata features:
  - GC content: min=0.2205, max=0.7332, mean=0.4780, std=0.0928
  - Mean genome length: min=0.0015, max=6.0353, mean=0.0085, std=0.0463
  - Log contig length: min=0.3108, max=0.7807, mean=0.4097, std=0.0509
  - Ambiguity rate: min=0.0000, max=0.0995, mean=0.0006, std=0.0033
[INFO] K-mer frequency vector sums:
  - Mean: 1.0000, Std: 0.0000
  - Outliers (sum < 0.95 or > 1.05): 0 out of 367499


In [72]:
# Archaea, validation split.
# NOTE(review): the output below reports 9460 of 93075 k-mer vectors summing outside
# [0.95, 1.05] (mean 0.8987). The mean is consistent with those rows being all-zero,
# but that is unverified — TODO inspect the offending contigs before using this split.
analyze_features(out_dir, split='val', category='archaea', figure_dir=figures_dir)

[INFO] Analyzing features for val split of archaea
[INFO] Features shape: torch.Size([93075, 260])
[INFO] Label distribution:
  - Label 1: 93075 contigs
[INFO] Summary statistics for metadata features:
  - GC content: min=0.2464, max=0.7280, mean=0.4620, std=0.0967
  - Mean genome length: min=0.0022, max=4.3615, mean=0.0083, std=0.0260
  - Log contig length: min=0.3108, max=0.7644, mean=0.4290, std=0.0388
  - Ambiguity rate: min=0.0000, max=0.0998, mean=0.0004, std=0.0029
[INFO] K-mer frequency vector sums:
  - Mean: 0.8987, Std: 0.3017
  - Outliers (sum < 0.95 or > 1.05): 9460 out of 93075


In [73]:
# Plasmid, test split.
# NOTE(review): the output below reports 1873 of 3821 k-mer vectors summing outside
# [0.95, 1.05] (mean 0.5113, std 0.4993). That looks like roughly half the rows having
# an all-zero k-mer vector, but it is unverified — TODO check the feature extraction.
analyze_features(out_dir, split='test', category='plasmid', figure_dir=figures_dir)

[INFO] Analyzing features for test split of plasmid
[INFO] Features shape: torch.Size([3821, 260])
[INFO] Label distribution:
  - Label 3: 3821 contigs
[INFO] Summary statistics for metadata features:
  - GC content: min=0.1635, max=0.7444, mean=0.4740, std=0.0987
  - Mean genome length: min=0.0010, max=2.3808, mean=0.0666, std=0.1392
  - Log contig length: min=0.3455, max=0.7341, mean=0.4908, std=0.0847
  - Ambiguity rate: min=0.0000, max=0.0957, mean=0.0002, std=0.0024
[INFO] K-mer frequency vector sums:
  - Mean: 0.5113, Std: 0.4993
  - Outliers (sum < 0.95 or > 1.05): 1873 out of 3821
