In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tree as t
import re
import Bio as b

In [None]:
refname = "/Users/rfeld/Documents/Research/SPATIAL/spatial_24/ref/chr21.fa"

# Read the FASTA file line by line (whitespace stripped) and drop the first
# line, which is the header.
with open(refname, 'r') as f:
    ref = [line.strip() for line in f][1:]

# Group consecutive sequence lines into blocks of 16 lines each
# (960 bp per block, assuming 60 bp per FASTA line -- TODO confirm line width).
lines_per_block = 16
num_blocks = len(ref) // lines_per_block  # number of complete blocks

# Step through *all* sequence lines 16 at a time. The previous loop only
# iterated up to `num_blocks` line indices, so it grouped roughly 1/16 of the
# reference. The final slice is the (possibly shorter) trailing block.
new_ref = [
    "".join(ref[i:i + lines_per_block])
    for i in range(0, len(ref), lines_per_block)
]

ref = new_ref
ref_idx = list(range(len(new_ref)))

In [None]:
# Copy-number event sizing. `cn_length_mean / resolution` is the mean (in
# blocks) of the exponential draw used when generating an event.
min_cn_length = 1_000  # in bp
cn_length_mean = 5_000_000
resolution = 960

In [None]:
# CURRENT
# Block size and minimum event size; the mean event size is left at whatever
# value an earlier cell assigned (its reassignment is commented out here).
min_cn_length = 1_000  # in bp
# cn_length_mean = 5000000
resolution = 960

def update_tag(tag, b_rel_start, b_length, copies, resolution=960):
    """Splice an existing tag the same way a copy-number event splices the genome.

    Parameters
    ----------
    tag : list[int]
        Tag vector aligned with the genome *before* the event.
    b_rel_start : int
        Index of the first affected block (block units, not base pairs).
    b_length : int
        Number of affected blocks (block units, not base pairs).
    copies : int
        Number of copies of the region to keep (0 removes it).
    resolution : int
        Unused; kept so existing call sites that pass it keep working.

    Returns
    -------
    list[int]
        New tag with tag[b_rel_start : b_rel_start + b_length] replaced by
        `copies` repetitions of itself.
    """
    # The arguments are always interpreted as block indices. The previous
    # "auto-convert if >= resolution" heuristic misread any block index or
    # length of 960 or more as base pairs, which spliced the tag at the wrong
    # place and left it misaligned with the genome.
    region = copies * tag[b_rel_start : b_rel_start + b_length]
    return tag[:b_rel_start] + region + tag[b_rel_start + b_length :]

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Apply one copy-number event to a block list and build its tag.

    Parameters
    ----------
    b_rel_idx : list[int]
        Block list before the event.
    b_rel_start : int
        Index of the first affected block.
    b_length : int
        Number of affected blocks.
    copies : int
        Number of copies of the region to keep (0 deletes it).

    Returns
    -------
    b_new_idx : list[int]
        Block list with the region replaced by `copies` repetitions of itself.
    tag : list[int]
        0/1 vector with the same length as `b_new_idx`. For copies > 0 the
        1s cover the replaced region. For a deletion the 1s mark the blocks
        on either side of the removed region, falling back to the first /
        last block when the deletion touches that end of the list.
    """
    region = copies * b_rel_idx[b_rel_start:b_rel_start + b_length]
    b_new_idx = b_rel_idx[:b_rel_start] + region + b_rel_idx[b_rel_start + b_length:]

    if copies == 0:
        # for tracking deletions, put one on each edge of deleted region
        tag = [0] * len(b_rel_idx)
        # #NB weird-ish behavior for edge cases, but should be interpretable
        if b_rel_start > 0:
            tag[b_rel_start - 1] = 1
        else:
            tag[0] = 1
        if b_rel_start + b_length < len(b_rel_idx):
            tag[b_rel_start + b_length] = 1
        else:
            tag[-1] = 1
        # The two unconditional writes that used to follow here repeated the
        # assignments without the bounds checks: they raised IndexError when
        # the deletion reached the end of the list and set tag[-1] when it
        # started at index 0.
        tag = tag[:b_rel_start] + tag[b_rel_start + b_length:]
    else:
        tag = [0] * len(b_rel_idx[:b_rel_start]) + [1] * len(region) + [0] * \
            len(b_rel_idx[b_rel_start + b_length:])

    return b_new_idx, tag

def generate_random_mutation(b_rel_idx):
    """Draw one random copy-number event for the given block list.

    Reads the module-level `cn_length_mean` and `resolution`.

    Parameters
    ----------
    b_rel_idx : list[int]
        Current block list of the allele.

    Returns
    -------
    dict
        "copies" (0 for a deletion, otherwise 2 or more), "prev_start_idx" /
        "prev_end_idx" (1-based inclusive base-pair span in the current
        genome), "region_length" and "total_length" (base pairs), and
        "parent_idx" (a copy of `b_rel_idx`).

    Raises
    ------
    ValueError
        If `b_rel_idx` is empty, since no event can be placed.
    """
    num_regions = len(b_rel_idx)
    if num_regions == 0:
        # Without this guard the start-position draw below fails inside numpy
        # with an error that does not mention the real cause.
        raise ValueError("cannot generate a mutation on an empty genome (no blocks left)")

    # Event size in blocks: exponential draw, clamped to [1, num_regions].
    CN_size = max(min(round(np.random.exponential(cn_length_mean / resolution)), num_regions), 1)
    # FIXME argparse this
    CN_type = np.random.binomial(1, 0.5)
    #FIXME check behavior
    # Type 1 is an amplification with at least two copies; type 0 is a deletion.
    CN_copies = (np.random.geometric(0.5) + 1) if CN_type == 1 else 0
    b_start_idx = np.random.randint(num_regions - CN_size + 1)
    b_end_idx = b_start_idx + CN_size

    mutation_dict = {
        "copies": CN_copies,
        "prev_start_idx": b_start_idx * resolution + 1,
        "prev_end_idx": b_end_idx * resolution,
        "region_length": CN_size * resolution,
        "total_length": CN_copies * CN_size * resolution,
        "parent_idx": b_rel_idx.copy()}
    return mutation_dict

In [None]:
# CURRENT
def n_diploid_cell_random(b_rel_idx_allele0, b_rel_idx_allele1, n0, n1,
                          prev_mutations=None, prev_tags=None, resolution=960):
    """Generate mutations and update inherited tags for a diploid genome.

    Parameters
    ----------
    b_rel_idx_allele0, b_rel_idx_allele1 : list[int]
        Starting block lists for allele 0 and allele 1.
    n0, n1 : int
        Number of new mutations to draw on allele 0 and allele 1.
    prev_mutations : dict[int, list[dict]] or None
        Inherited mutation records per allele; defaults to none.
    prev_tags : dict[int, list[list[int]]] or None
        Inherited tag vectors per allele; defaults to none.
    resolution : int
        Base pairs per block.

    Returns
    -------
    mutations : dict[int, list[dict]]
        Per allele, the inherited records followed by the new ones.
    tags : dict[int, list[list[int]]]
        Per allele, the inherited tags (re-spliced for every new event)
        followed by the tags of the new events.
    final_blocks : dict[int, list[int]]
        Per allele, the block list after all new events.
    """
    mutations = {}
    tags = {}
    final_blocks = {}

    if prev_mutations is None:
        prev_mutations = {0: [], 1: []}
    if prev_tags is None:
        prev_tags = {0: [], 1: []}

    for allele, initial_blocks, n_mut in [(0, b_rel_idx_allele0, n0), (1, b_rel_idx_allele1, n1)]:
        allele_mutations = prev_mutations[allele].copy()
        allele_tags = [tag.copy() for tag in prev_tags[allele]]
        mutation_blocks = initial_blocks.copy()

        # Remember where the new entries will start. Slicing with [-n_mut:]
        # selects the WHOLE list when n_mut == 0, which made the summary loop
        # below overwrite inherited records (dicts shared with the parent).
        first_new_mutation = len(allele_mutations)
        first_new_tag = len(allele_tags)

        for _ in range(n_mut):
            md = generate_random_mutation(mutation_blocks.copy())
            md["allele"] = allele

            b_start_idx = (md["prev_start_idx"] - 1) // resolution
            b_length = md["region_length"] // resolution

            mutation_blocks, new_tag = mutate(
                md["parent_idx"].copy(), b_start_idx, b_length, md["copies"]
            )

            # Update all prior tags to align with the new genome structure
            allele_tags = [
                update_tag(tag, b_start_idx, b_length, md["copies"], resolution)
                for tag in allele_tags
            ]
            allele_tags.append(new_tag)
            allele_mutations.append(md)

        # Assign tag-based location summaries to the newly drawn mutations only
        for tag, md in zip(allele_tags[first_new_tag:], allele_mutations[first_new_mutation:]):
            if sum(tag) == 0:
                md["tumor_tag_start_bp"] = None
                md["tumor_tag_end_bp"] = None
                md["ref_min_bp"] = None
                md["ref_max_bp"] = None
            else:
                first = tag.index(1)
                last = len(tag) - 1 - tag[::-1].index(1)
                md["tumor_tag_start_bp"] = first * resolution + 1
                md["tumor_tag_end_bp"] = (last + 1) * resolution

                # first <= last always holds here, so only the upper bound
                # against the final genome needs checking.
                if last >= len(mutation_blocks):
                    md["ref_min_bp"] = None
                    md["ref_max_bp"] = None
                else:
                    ref_blocks = mutation_blocks[first : last + 1]
                    md["ref_min_bp"] = min(ref_blocks) * resolution + 1
                    md["ref_max_bp"] = (max(ref_blocks) + 1) * resolution

        mutations[allele] = allele_mutations
        tags[allele] = allele_tags
        final_blocks[allele] = mutation_blocks

    return mutations, tags, final_blocks


In [None]:
#CURRENT 
def append_mutation_log(mutations_by_allele, filepath, human_readable=False):
    """Append every allele's mutation records to a plain-text log.

    Parameters
    ----------
    mutations_by_allele : dict[int, list[dict]]
        Mutation records keyed by allele; alleles are written in sorted order.
    filepath : str
        Log file to append to (opened in "a" mode).
    human_readable : bool
        True writes one labelled "field:<TAB>value" block per mutation;
        False writes one tab-separated row per mutation.
        Fields missing from a record are written as empty strings.
    """
    columns = (
        "allele",
        "copies",
        "prev_start_idx",
        "prev_end_idx",
        "region_length",
        "total_length",
        "tumor_tag_start_bp",
        "tumor_tag_end_bp",
        "ref_min_bp",
        "ref_max_bp",
    )

    with open(filepath, "a") as log:
        for allele_id in sorted(mutations_by_allele):
            records = mutations_by_allele[allele_id]

            if not human_readable:
                # Compact form: allele header, then one row per record.
                log.write(f"# Allele {allele_id}\n")
                for record in records:
                    cells = (str(record.get(col, "")) for col in columns)
                    log.write("\t".join(cells) + "\n")
                continue

            # Labelled form: blank-line separated blocks, numbered from 1.
            log.write(f"\n# Allele {allele_id}\n")
            for number, record in enumerate(records, start=1):
                log.write(f"\n[Mutation {number}]\n")
                for col in columns:
                    value = record.get(col, "")
                    log.write(f"{col}:\t{value}\n")

In [None]:
# Toy run on a 100-block genome: two new mutations on allele 0, none on allele 1.
cn_length_mean = 5  # lowering for toy examples

b_rel_idx0, b_rel_idx1 = list(range(100)), list(range(100))
mutations, tags, b_new_idx = n_diploid_cell_random(b_rel_idx0, b_rel_idx1, n0=2, n1=0)

In [None]:
# Append the toy run's mutation records to log.txt as labelled, human-readable blocks.
append_mutation_log(mutations, "log.txt", human_readable=True)

In [None]:
from copy import deepcopy

def apply_mutation_sequence(blocks, tags, mutations, resolution=960):
    """
    Replay a list of mutations, in order, on a genome and on its tag vectors.

    Parameters
    ----------
    blocks : list[int]
        Starting block list (genome). Not modified; a deep copy is used.
    tags : list[list[int]]
        Tag vectors to keep aligned with the genome. Not modified.
    mutations : list[dict]
        Mutation records; each supplies "prev_start_idx", "region_length"
        (both in base pairs) and "copies".
    resolution : int
        Base pairs per block, used to convert the record's coordinates.

    Returns
    -------
    new_blocks : list[int]
        Genome after every mutation has been applied.
    new_tags : list[list[int]]
        The input tags re-spliced for every mutation, followed by one new
        tag per applied mutation.
    """
    genome = deepcopy(blocks)
    tag_vectors = [deepcopy(vec) for vec in tags]

    for event in mutations:
        # Convert the base-pair record to block coordinates.
        start_block = (event["prev_start_idx"] - 1) // resolution
        n_blocks = event["region_length"] // resolution
        n_copies = event["copies"]

        # Splice the genome and get the tag describing this event.
        genome, event_tag = mutate(genome, start_block, n_blocks, n_copies)

        # Re-splice every earlier tag, then append the new event's tag.
        respliced = []
        for vec in tag_vectors:
            respliced.append(update_tag(vec, start_block, n_blocks, n_copies, resolution))
        respliced.append(event_tag)
        tag_vectors = respliced

    return genome, tag_vectors

# This function can now be used to replay all inherited mutations and tags on top of the parent's genome
# before new mutations are generated and appended.


In [None]:
# The bare name `Phylo` is used below but the notebook's import cell only has
# `import Bio as b`, so this cell failed with NameError on a fresh kernel.
from Bio import Phylo

# Parameters
resolution = 960
cn_length_mean = 5000000
geom_p = 0.5  # For number of mutations per cell
diploid_blocks = list(range(50000))  # e.g., 48Mb genome (50,000 blocks)


tree_path = "/Users/rfeld/Documents/Research/SPATIAL/output/experiments/exp40/sequences/tree.nwk"
tree = Phylo.read(tree_path, "newick")
tree.ladderize()

# Build parent map for upward tracking
parent_map = {}
for clade in tree.find_clades(order="level"):
    for child in clade.clades:
        parent_map[child] = clade

# Walk tree and generate random mutations per cell
cumulative_genomes = {}
mutation_records = {}

for node in tree.find_clades(order="level"):
    if not node.name:
        continue

    # Start from the parent's genome when it has been simulated already;
    # otherwise (root, or parent skipped for having no name) start unmutated.
    parent = parent_map.get(node)
    if parent is not None and parent.name in cumulative_genomes:
        parent_genome = cumulative_genomes[parent.name]
    else:
        parent_genome = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}

    # Draw number of mutations and split them between the two alleles
    n_total = np.random.geometric(p=geom_p)
    n0 = np.random.binomial(n_total, 0.5)
    n1 = n_total - n0

    muts, tags, final_blocks = n_diploid_cell_random(parent_genome[0], parent_genome[1], n0, n1)
    cumulative_genomes[node.name] = final_blocks
    mutation_records[node.name] = muts

In [None]:
# CURRENT
from Bio import Phylo

# Parameters
resolution = 960
cn_length_mean = 5000000
geom_p = 0.5
diploid_blocks = list(range(50000))  # reference genome blocks

# Load tree and build parent map
tree_path = "/Users/rfeld/Documents/Research/SPATIAL/output/experiments/exp40/sequences/tree.nwk"
tree = Phylo.read(tree_path, "newick")
tree.ladderize()

parent_map = {}
for clade in tree.find_clades(order="level"):
    for child in clade.clades:
        parent_map[child] = clade

# Initialize storage (all keyed by cell name, then by allele)
cumulative_genomes = {}
mutation_records = {}
tags_by_cell = {}

for node in tree.find_clades(order="level"):
    if not node.name:
        continue

    parent = parent_map.get(node)
    if parent and parent.name in cumulative_genomes:
        parent_genome = cumulative_genomes[parent.name]
    else:
        parent_genome = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}

    # Handle founder: no mutations
    if parent is None:
        cumulative_genomes[node.name] = {
            0: parent_genome[0].copy(),
            1: parent_genome[1].copy()
        }
        mutation_records[node.name] = {0: [], 1: []}
        tags_by_cell[node.name] = {0: [], 1: []}
        continue

    # Draw number of mutations
    n_total = np.random.geometric(p=geom_p)
    n0 = np.random.binomial(n_total, 0.5)
    n1 = n_total - n0

    # Generate mutations
    muts, tags, final_blocks = n_diploid_cell_random(
        parent_genome[0], parent_genome[1], n0, n1,
        prev_mutations=mutation_records[parent.name],
        prev_tags=tags_by_cell[parent.name],
        resolution=resolution)

    # n_diploid_cell_random seeds its per-allele lists from prev_mutations /
    # prev_tags, so `muts` and `tags` are already cumulative (inherited + new)
    # and the inherited tags have been re-spliced for the new events.
    # Concatenating the parent's lists again duplicated every inherited record
    # and prepended stale parent tags whose lengths no longer match this genome.
    mutation_records[node.name] = muts
    tags_by_cell[node.name] = tags

    cumulative_genomes[node.name] = final_blocks

# Validate the entire mutation pipeline
# validate_mutation_pipeline(
#     tree=tree,
#     parent_map=parent_map,
#     mutation_records=mutation_records,
#     tags=tags_by_cell,
#     cumulative_genomes=cumulative_genomes,
#     resolution=960
# )


In [None]:
# CURRENT
def validate_mutation_pipeline_verbose(tree, parent_map, mutation_records, tags, cumulative_genomes, resolution=960):
    """
    Verbose validation of mutation pipeline.
    Reports detailed mismatch info to help debug tag/genome alignment issues.

    Parameters
    ----------
    tree : object with find_clades(order=...) yielding nodes that have .name
    parent_map : dict
        Unused here; kept for a signature parallel to validate_mutation_pipeline.
    mutation_records, tags, cumulative_genomes : dict
        Keyed by node name, then by allele (0 / 1).
    resolution : int
        Base pairs per block.

    Returns the string "All mutation pipeline checks passed." when no check
    raises. A tag/genome length mismatch is printed but does not raise.

    Raises
    ------
    AssertionError
        Missing record keys or out-of-range / inverted coordinates.
    IndexError
        A mutation has no tag at the same position.
    """
    expected_keys = {
        "allele", "copies", "prev_start_idx", "prev_end_idx", "region_length", "total_length",
        "tumor_tag_start_bp", "tumor_tag_end_bp", "ref_min_bp", "ref_max_bp", "parent_idx"
    }

    for node in tree.find_clades(order="level"):
        if not node.name:
            continue

        for allele in [0, 1]:
            muts = mutation_records.get(node.name, {}).get(allele, [])
            block_list = cumulative_genomes.get(node.name, {}).get(allele, [])
            tag_list = tags.get(node.name, {}).get(allele, [])

            for i, md in enumerate(muts):
                # Check the record's keys before reading any of them, so a
                # missing field is reported by this assertion instead of a
                # bare KeyError from the lookups further down.
                missing_keys = expected_keys - md.keys()
                assert not missing_keys, f"{node.name} allele {allele} mutation {i}: missing keys {missing_keys}"

                if i >= len(tag_list):
                    print(f"[ERROR] {node.name} allele {allele} mutation {i}: tag is missing entirely (len(tags) = {len(tag_list)})")
                    raise IndexError(f"{node.name} allele {allele}: tag index {i} out of bounds")

                tag = tag_list[i]
                if len(tag) != len(block_list):
                    print(f"\n[ERROR] {node.name} allele {allele} mutation {i}: tag/genome length mismatch")
                    print(f"  ↪ Tag length    = {len(tag)}")
                    print(f"  ↪ Genome length = {len(block_list)}")
                    print(f"  ↪ Tag start     = {tag[:5]} ... {tag[-5:]}")
                    print(f"  ↪ Blocks start  = {block_list[:5]} ... {block_list[-5:]}")
                    print(f"  ↪ Mutation: copies={md['copies']}, region={md['region_length']} bp")
                    print(f"             prev_start={md['prev_start_idx']}, prev_end={md['prev_end_idx']}")

                if md["tumor_tag_start_bp"] is not None:
                    assert 1 <= md["tumor_tag_start_bp"] <= len(block_list) * resolution, \
                        f"{node.name} allele {allele} mutation {i}: tumor_tag_start_bp out of range"
                    assert md["tumor_tag_start_bp"] <= md["tumor_tag_end_bp"], \
                        f"{node.name} allele {allele} mutation {i}: tumor start > end"

                if md["ref_min_bp"] is not None:
                    assert md["ref_min_bp"] <= md["ref_max_bp"], \
                        f"{node.name} allele {allele} mutation {i}: ref_min > ref_max"

    return "All mutation pipeline checks passed."


In [None]:
def validate_mutation_pipeline(tree, parent_map, mutation_records, tags, cumulative_genomes, resolution=960):
    """
    Run consistency checks on simulated diploid mutation records, tags, and genome propagation.
    Raises AssertionError with descriptive message if any check fails.
    """
    expected_keys = {
        "allele", "copies", "prev_start_idx", "prev_end_idx", "region_length", "total_length",
        "tumor_tag_start_bp", "tumor_tag_end_bp", "ref_min_bp", "ref_max_bp", "parent_idx"
    }

    for node in tree.find_clades(order="level"):
        if not node.name:
            continue

        # Validate genome propagation from parent
        parent = parent_map.get(node)
        if parent and parent.name in cumulative_genomes:
            for allele in [0, 1]:
                parent_blocks = cumulative_genomes[parent.name][allele]
                this_blocks = cumulative_genomes[node.name][allele]
                assert isinstance(this_blocks, list), f"{node.name} allele {allele}: genome not list"
                assert all(isinstance(x, int) for x in this_blocks), f"{node.name} allele {allele}: genome has non-integer"
                assert len(this_blocks) >= 0, f"{node.name} allele {allele}: genome has negative length"

        if node.name not in mutation_records:
            continue

        for allele in [0, 1]:
            muts = mutation_records[node.name].get(allele, [])
            block_list = cumulative_genomes[node.name][allele]
            tag_list = tags_by_cell[node.name].get(allele, [])

            for i, md in enumerate(muts):
                tag = tag_list[i]
                assert len(tag) == len(block_list), f"{node.name} allele {allele} mutation {i}: tag/genome length mismatch"

                if md["tumor_tag_start_bp"] is not None:
                    assert 1 <= md["tumor_tag_start_bp"] <= len(block_list) * resolution, \
                        f"{node.name} allele {allele} mutation {i}: tumor_tag_start_bp out of range"
                    assert md["tumor_tag_start_bp"] <= md["tumor_tag_end_bp"], \
                        f"{node.name} allele {allele} mutation {i}: tumor start > end"

                if md["ref_min_bp"] is not None:
                    assert md["ref_min_bp"] <= md["ref_max_bp"], \
                        f"{node.name} allele {allele} mutation {i}: ref_min > ref_max"

                missing_keys = expected_keys - md.keys()
                assert not missing_keys, f"{node.name} allele {allele} mutation {i}: missing keys {missing_keys}"

    return "All mutation pipeline checks passed."



In [None]:
validate_mutation_pipeline(
    tree=tree,
    parent_map=parent_map,
    mutation_records=mutation_records,
    tags=tags,
    cumulative_genomes=cumulative_genomes,
    resolution=960
)

In [None]:
# Toy example: apply the same two deletions to a genome and to a tag vector
# using a plain slice-and-splice update, then report the resulting lengths.

# Initial genome and tag
initial_blocks = list(range(10))         # genome: [0, 1, 2, ..., 9]
initial_tag = [0] * len(initial_blocks)  # tag: same length


def mutate_blocks(blocks, start, length, copies):
    """Replace blocks[start:start+length] with `copies` repetitions of itself."""
    stop = start + length
    return blocks[:start] + copies * blocks[start:stop] + blocks[stop:]


def update_tag_naive(tag, start, length, copies):
    """Splice a tag vector with exactly the same rule mutate_blocks uses."""
    stop = start + length
    return tag[:start] + copies * tag[start:stop] + tag[stop:]


# First mutation: delete 3 blocks starting from index 2 (i.e., remove 2,3,4)
mutated_blocks1 = mutate_blocks(initial_blocks, start=2, length=3, copies=0)
mutated_tag1 = update_tag_naive(initial_tag, start=2, length=3, copies=0)

# Second mutation: delete 2 blocks starting from index 5 of the mutated genome.
# Genome and tag were spliced with identical arguments, so they have the same
# length going into this step.
mutated_blocks2 = mutate_blocks(mutated_blocks1, start=5, length=2, copies=0)
mutated_tag2 = update_tag_naive(mutated_tag1, start=5, length=2, copies=0)

# Collect the lengths at every stage for the summary below
len_initial_blocks = len(initial_blocks)
len_mutated_blocks1, len_mutated_tag1 = len(mutated_blocks1), len(mutated_tag1)
len_mutated_blocks2, len_mutated_tag2 = len(mutated_blocks2), len(mutated_tag2)

(len_initial_blocks, len_mutated_blocks1, len_mutated_tag1,
 len_mutated_blocks2, len_mutated_tag2,
 mutated_blocks2, mutated_tag2)


In [None]:
# Display the toy genome after the first deletion (blocks 2, 3 and 4 removed).
mutated_blocks1
# mutated_tag1

In [None]:
# NOTE(review): `b_rel_start` is first assigned in the following cell, so on a
# fresh kernel this cell raises NameError unless that cell has already been run.
b_rel_start

In [None]:
resolution = 960
copies = 0  # zero copies: the region is deleted

# Event span in 1-based, inclusive base-pair coordinates.
prev_start, prev_end = 38246401, 40285440
region_length = prev_end - prev_start + 1  # = 2039040

# The same span expressed in block units.
b_rel_start = (prev_start - 1) // resolution  # should be 39840
b_length = region_length // resolution  # should be 2124

parent_blocks = list(range(50000))  # reference diploid blocks
parent_tag = [0] * len(parent_blocks)  # corresponding tag, aligned

In [None]:
# MUTATE genome: apply the deletion to a copy of the parent genome
new_blocks, new_tag = mutate(parent_blocks.copy(), b_rel_start, b_length, copies)

# UPDATE tags — inherit + mutate
inherited_tags = [parent_tag.copy()]  # assume 1 inherited tag
updated_tags = [
    update_tag(inherited, b_rel_start, b_length, copies, resolution)
    for inherited in inherited_tags
] + [new_tag]


In [None]:
# Print the length of every tag (inherited ones first, the new event's tag last)
# next to the genome length; aligned tags have the same length as the genome.
for tag in updated_tags:
    print("Tag length:", len(tag))

print("Genome length:", len(new_blocks))

In [None]:
# Level-order walk: rebuild each cell's inherited state by replaying the
# parent's mutation records on the unmutated genome, then add new mutations.
for node in tree.find_clades(order="level"):
    if not node.name:
        continue

    parent = parent_map.get(node)

    # Founder node: no mutations, unmutated genome on both alleles
    if parent is None:
        mutation_records[node.name] = {0: [], 1: []}
        tags_by_cell[node.name] = {0: [], 1: []}
        cumulative_genomes[node.name] = {
            0: diploid_blocks.copy(),
            1: diploid_blocks.copy(),
        }
        continue

    # Reconstruct genome and tags by applying inherited mutations
    inherited_mutations = mutation_records[parent.name]
    inherited_blocks, inherited_tags = {}, {}
    for allele in (0, 1):
        replayed_blocks, replayed_tags = apply_mutation_sequence(
            diploid_blocks, [], inherited_mutations[allele]
        )
        inherited_blocks[allele] = replayed_blocks
        inherited_tags[allele] = replayed_tags

    # Draw number of new mutations and split them between the alleles
    n_total = np.random.geometric(p=geom_p)
    n0 = np.random.binomial(n_total, 0.5)
    n1 = n_total - n0

    # Apply new mutations to reconstructed genome
    muts, tags, final_blocks = n_diploid_cell_random(
        inherited_blocks[0],
        inherited_blocks[1],
        n0,
        n1,
        prev_mutations=inherited_mutations,
        prev_tags=inherited_tags,
        resolution=resolution,
    )

    # Save all state
    cumulative_genomes[node.name] = final_blocks
    mutation_records[node.name] = muts
    tags_by_cell[node.name] = tags


In [None]:
# Pass the per-cell tag store. `tags` at notebook level is only the per-allele
# result of the most recent n_diploid_cell_random call, so looking cell names
# up in it finds nothing and every cell with mutations fails with
# "tag index 0 out of bounds".
validate_mutation_pipeline_verbose(
    tree=tree,
    parent_map=parent_map,
    mutation_records=mutation_records,
    tags=tags_by_cell,
    cumulative_genomes=cumulative_genomes,
    resolution=960
)

In [None]:
from Bio import Phylo
import numpy as np
from io import StringIO

# --- Parameters ---
# Base pairs per genome block.
resolution = 960
# Mean copy-number event length; divided by `resolution` to get the mean in blocks.
cn_length_mean = 5000000
# Success probability of the geometric draw for the number of new mutations per cell.
geom_p = 0.5
# Unmutated block list each allele of the founder starts from.
diploid_blocks = list(range(50000))
# Newick tree embedded as a string so the cell does not depend on a local file.
nwk_string = "(((cell1:0.10730759618761926,(cell6:0.04420179819059644,cell7:0.04420179819059644)ancestor7:0.06310579799702282)ancestor3:0.15119724784519406,(cell2:0.17496885779309207,(cell8:0.1106682423281029,((cell11:0.01877210928914326,cell12:0.01877210928914326)ancestor11:0.06924918172854262,(cell13:0.02132101972393934,(cell14:0.00240299287976051,cell15:0.00240299287976051)ancestor13:0.01891802684417883)ancestor12:0.06670027129374655)ancestor10:0.02264695131041702)ancestor8:0.06430061546498916)ancestor4:0.08353598623972125)ancestor1:0.7160958934610564,((cell3:0.06903946747075308,(cell9:0.02475272583115519,cell10:0.02475272583115519)ancestor9:0.04428674163959789)ancestor5:0.587440203881981,(cell4:0.08567206940767744,cell5:0.08567206940767744)ancestor6:0.5708076019450566)ancestor2:0.31812106614113556)founder;"

def update_tag(tag, b_rel_start, b_length, copies, resolution=960):
    """Splice a tag for an event given in base-pair units.

    `b_rel_start` and `b_length` are floor-divided by `resolution` to obtain
    the first affected block and the number of affected blocks; that slice of
    `tag` is then replaced by `copies` repetitions of itself.
    Returns the new tag list; `tag` itself is not modified.
    """
    first_block = b_rel_start // resolution
    stop_block = first_block + b_length // resolution
    return tag[:first_block] + copies * tag[first_block:stop_block] + tag[stop_block:]

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Apply one copy-number event to a block list and build its tag.

    Returns (b_new_idx, tag). `b_new_idx` is the block list with the region
    replaced by `copies` repetitions of itself. `tag` is a 0/1 list of the
    same length: for non-zero `copies` the 1s cover the replaced region; for
    a deletion they mark the blocks flanking the removed region, clamped to
    the first / last block at the ends.
    """
    stop = b_rel_start + b_length
    prefix, suffix = b_rel_idx[:b_rel_start], b_rel_idx[stop:]
    region = copies * b_rel_idx[b_rel_start:stop]
    b_new_idx = prefix + region + suffix

    if copies != 0:
        tag = [0] * len(prefix) + [1] * len(region) + [0] * len(suffix)
        return b_new_idx, tag

    # Deletion: mark both flanks on a full-length tag, then cut out the
    # deleted span so the tag lines up with the shortened genome.
    total = len(b_rel_idx)
    tag = [0] * total
    left_flank = b_rel_start - 1 if b_rel_start > 0 else 0
    right_flank = min(stop, total - 1)
    tag[left_flank] = 1
    tag[right_flank] = 1
    return b_new_idx, tag[:b_rel_start] + tag[stop:]

def generate_random_mutation(b_rel_idx):
    """Draw one random copy-number event for the given block list.

    Reads the module-level `cn_length_mean` and `resolution`.

    Returns a dict with "copies" (0 for a deletion, otherwise 2 or more),
    "prev_start_idx" / "prev_end_idx" (1-based inclusive base-pair span in
    the current genome), "region_length" and "total_length" (base pairs),
    and "parent_idx" (a copy of `b_rel_idx`).

    Raises ValueError if `b_rel_idx` is empty, since no event can be placed.
    """
    num_regions = len(b_rel_idx)
    if num_regions == 0:
        # Without this guard the start-position draw below fails inside numpy
        # with an error that does not mention the real cause.
        raise ValueError("cannot generate a mutation on an empty genome (no blocks left)")

    # Event size in blocks: exponential draw, clamped to [1, num_regions].
    CN_size = max(min(round(np.random.exponential(cn_length_mean / resolution)), num_regions), 1)
    CN_type = np.random.binomial(1, 0.5)
    # Type 1 is an amplification with at least two copies; type 0 is a deletion.
    CN_copies = (np.random.geometric(0.5) + 1) if CN_type == 1 else 0
    b_start_idx = np.random.randint(num_regions - CN_size + 1)
    b_end_idx = b_start_idx + CN_size

    return {
        "copies": CN_copies,
        "prev_start_idx": b_start_idx * resolution + 1,
        "prev_end_idx": b_end_idx * resolution,
        "region_length": CN_size * resolution,
        "total_length": CN_copies * CN_size * resolution,
        "parent_idx": b_rel_idx.copy()
    }

def apply_mutation_sequence(start_blocks, start_tags, mutation_list):
    """Replay mutation records, in order, on a genome and its tag vectors.

    Reads the module-level `resolution` to convert each record's base-pair
    coordinates to block units.

    Parameters
    ----------
    start_blocks : list[int]
        Genome to start from (copied, not modified).
    start_tags : list[list[int]]
        Tags aligned with `start_blocks` (copied, not modified).
    mutation_list : list[dict]
        Records providing "prev_start_idx", "region_length" and "copies".

    Returns
    -------
    (blocks, tags) : the genome after all events, and the input tags
    re-spliced for every event followed by one new tag per event.
    """
    blocks = start_blocks.copy()
    tags = [tag.copy() for tag in start_tags]
    for md in mutation_list:
        b_start_idx = (md["prev_start_idx"] - 1) // resolution
        b_length = md["region_length"] // resolution
        # Apply each event to the running genome. Restarting from the
        # record's stored "parent_idx" snapshot ignored `start_blocks` as
        # soon as the list was non-empty.
        blocks, new_tag = mutate(blocks, b_start_idx, b_length, md["copies"])
        # update_tag divides by its own `resolution` argument, so pass the
        # same value used for the multiplication here.
        tags = [
            update_tag(tag, b_start_idx * resolution, b_length * resolution, md["copies"], resolution)
            for tag in tags
        ]
        tags.append(new_tag)
    return blocks, tags

def n_diploid_cell_random(b_rel_idx_allele0, b_rel_idx_allele1, n0, n1,
                          prev_mutations=None, prev_tags=None, resolution=960):
    """Draw new mutations for both alleles and keep inherited tags aligned.

    Parameters
    ----------
    b_rel_idx_allele0, b_rel_idx_allele1 : list[int]
        Starting block lists for allele 0 and allele 1.
    n0, n1 : int
        Number of new mutations to draw on allele 0 and allele 1.
    prev_mutations, prev_tags : dict or None
        Inherited records / tags per allele; default to none.
    resolution : int
        Base pairs per block.

    Returns
    -------
    (mutations, tags, final_blocks) : dicts keyed by allele. The first two
    hold the inherited entries followed by the new ones; the third holds the
    block list after all new events.
    """
    summary_keys = ["tumor_tag_start_bp", "tumor_tag_end_bp", "ref_min_bp", "ref_max_bp"]
    mutations = {}
    tags = {}
    final_blocks = {}

    if prev_mutations is None:
        prev_mutations = {0: [], 1: []}
    if prev_tags is None:
        prev_tags = {0: [], 1: []}

    for allele, initial_blocks, n_mut in [(0, b_rel_idx_allele0, n0), (1, b_rel_idx_allele1, n1)]:
        allele_mutations = prev_mutations[allele].copy()
        allele_tags = [tag.copy() for tag in prev_tags[allele]]
        mutation_blocks = initial_blocks.copy()

        # Remember where the new entries will start. Slicing with [-n_mut:]
        # selects the WHOLE list when n_mut == 0, which made the summary loop
        # below overwrite inherited records (dicts shared with the parent).
        first_new_mutation = len(allele_mutations)
        first_new_tag = len(allele_tags)

        for _ in range(n_mut):
            md = generate_random_mutation(mutation_blocks.copy())
            md["allele"] = allele
            b_start_idx = (md["prev_start_idx"] - 1) // resolution
            b_length = md["region_length"] // resolution
            mutation_blocks, new_tag = mutate(md["parent_idx"].copy(), b_start_idx, b_length, md["copies"])
            # update_tag divides by its own `resolution` argument; pass the
            # value used for the multiplication so the two always agree.
            allele_tags = [
                update_tag(tag, b_start_idx * resolution, b_length * resolution, md["copies"], resolution)
                for tag in allele_tags
            ]
            allele_tags.append(new_tag)
            allele_mutations.append(md)

        # Location summaries for the newly drawn mutations only.
        for tag, md in zip(allele_tags[first_new_tag:], allele_mutations[first_new_mutation:]):
            if sum(tag) == 0:
                md.update({k: None for k in summary_keys})
            else:
                first = tag.index(1)
                last = len(tag) - 1 - tag[::-1].index(1)
                if first >= len(mutation_blocks):
                    md.update({k: None for k in summary_keys})
                else:
                    # Clamp the tumor coordinates to the final genome length.
                    md["tumor_tag_start_bp"] = max(1, min((first * resolution + 1), len(mutation_blocks) * resolution))
                    md["tumor_tag_end_bp"] = min((last + 1) * resolution, len(mutation_blocks) * resolution)
                    ref_blocks = mutation_blocks[first : last + 1]
                    md["ref_min_bp"] = min(ref_blocks) * resolution + 1
                    md["ref_max_bp"] = (max(ref_blocks) + 1) * resolution

        mutations[allele] = allele_mutations
        tags[allele] = allele_tags
        final_blocks[allele] = mutation_blocks

    return mutations, tags, final_blocks

# --- Main Tree Walk ---
tree = Phylo.read(StringIO(nwk_string), "newick")
tree.ladderize()

# child clade -> parent clade
parent_map = {}
for clade in tree.find_clades(order="level"):
    for child in clade.clades:
        parent_map[child] = clade

# Per-cell state, keyed by cell name and then by allele
cumulative_genomes = {}
mutation_records = {}
tags_by_cell = {}

for node in tree.find_clades(order="level"):
    if not node.name:
        continue

    parent = parent_map.get(node)
    if parent is None:
        # Root of the tree: unmutated genome, no records, no tags
        mutation_records[node.name] = {0: [], 1: []}
        tags_by_cell[node.name] = {0: [], 1: []}
        cumulative_genomes[node.name] = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}
        continue

    # Rebuild the parent's genome and tags by replaying its mutation records
    inherited_mutations = mutation_records[parent.name]
    inherited_blocks, inherited_tags = {}, {}
    for allele in (0, 1):
        replayed_blocks, replayed_tags = apply_mutation_sequence(
            diploid_blocks, [], inherited_mutations[allele]
        )
        inherited_blocks[allele] = replayed_blocks
        inherited_tags[allele] = replayed_tags

    # Number of new mutations for this cell, split between the alleles
    n_total = np.random.geometric(p=geom_p)
    n0 = np.random.binomial(n_total, 0.5)
    n1 = n_total - n0

    muts, tags, final_blocks = n_diploid_cell_random(
        inherited_blocks[0],
        inherited_blocks[1],
        n0,
        n1,
        prev_mutations=inherited_mutations,
        prev_tags=inherited_tags,
        resolution=resolution,
    )

    cumulative_genomes[node.name] = final_blocks
    mutation_records[node.name] = muts
    tags_by_cell[node.name] = tags

def validate_mutation_pipeline_verbose(tree, parent_map, mutation_records, tags, cumulative_genomes, resolution=960):
    """Check every cell's mutation records, tags and genome for consistency.

    `mutation_records`, `tags` and `cumulative_genomes` are keyed by node
    name, then by allele (0 / 1). `parent_map` is unused here. Prints the
    tumor start coordinate of each mutation it checks.

    Returns "All mutation pipeline checks passed." when nothing raises.

    Raises
    ------
    AssertionError
        Missing record keys or out-of-range / inverted coordinates.
    IndexError
        A mutation has no tag at the same position.
    ValueError
        A tag's length differs from the genome's length.
    """
    expected_keys = {
        "allele", "copies", "prev_start_idx", "prev_end_idx", "region_length", "total_length",
        "tumor_tag_start_bp", "tumor_tag_end_bp", "ref_min_bp", "ref_max_bp", "parent_idx"
    }

    for node in tree.find_clades(order="level"):
        if not node.name:
            continue

        for allele in [0, 1]:
            muts = mutation_records.get(node.name, {}).get(allele, [])
            block_list = cumulative_genomes.get(node.name, {}).get(allele, [])
            tag_list = tags.get(node.name, {}).get(allele, [])

            for i, md in enumerate(muts):
                # Check the record's keys before reading any of them, so a
                # missing field is reported by this assertion instead of a
                # bare KeyError from the lookups further down.
                missing_keys = expected_keys - md.keys()
                assert not missing_keys, f"{node.name} allele {allele} mutation {i}: missing keys {missing_keys}"

                if i >= len(tag_list):
                    raise IndexError(f"{node.name} allele {allele}: tag index {i} out of bounds")

                tag = tag_list[i]
                if len(tag) != len(block_list):
                    raise ValueError(f"{node.name} allele {allele} mutation {i}: tag/genome length mismatch")

                if md["tumor_tag_start_bp"] is not None:
                    print(f"  [DEBUG] {node.name} allele {allele} mutation {i}:")
                    print(f"    tumor_tag_start_bp: {md['tumor_tag_start_bp']}")
                    print(f"    len(block_list) * resolution: {len(block_list) * resolution}")
                    assert 1 <= md["tumor_tag_start_bp"] <= len(block_list) * resolution, \
                        f"{node.name} allele {allele} mutation {i}: tumor_tag_start_bp out of range"

                if md["ref_min_bp"] is not None:
                    assert md["ref_min_bp"] <= md["ref_max_bp"], \
                        f"{node.name} allele {allele} mutation {i}: ref_min > ref_max"

    return "All mutation pipeline checks passed."

In [None]:
# Display the per-cell mutation records (cell name -> allele -> list of records).
mutation_records

In [None]:
from Bio import Phylo
import numpy as np
from io import StringIO

# --- Parameters ---
# Base pairs per genome block.
resolution = 960
# Mean copy-number event length; divided by `resolution` to get the mean in blocks.
cn_length_mean = 1000000  # lowered to make non-overlapping more feasible
# NOTE(review): presumably the geometric parameter for mutations per cell, as in earlier cells.
geom_p = 0.5
# NOTE(review): presumably the unmutated starting block list, as in earlier cells.
diploid_blocks = list(range(50000))

# Newick tree embedded as a string.
nwk_string = "(((cell1:0.10730759618761926,(cell6:0.04420179819059644,cell7:0.04420179819059644)ancestor7:0.06310579799702282)ancestor3:0.15119724784519406,(cell2:0.17496885779309207,(cell8:0.1106682423281029,((cell11:0.01877210928914326,cell12:0.01877210928914326)ancestor11:0.06924918172854262,(cell13:0.02132101972393934,(cell14:0.00240299287976051,cell15:0.00240299287976051)ancestor13:0.01891802684417883)ancestor12:0.06670027129374655)ancestor10:0.02264695131041702)ancestor8:0.06430061546498916)ancestor4:0.08353598623972125)ancestor1:0.7160958934610564,((cell3:0.06903946747075308,(cell9:0.02475272583115519,cell10:0.02475272583115519)ancestor9:0.04428674163959789)ancestor5:0.587440203881981,(cell4:0.08567206940767744,cell5:0.08567206940767744)ancestor6:0.5708076019450566)ancestor2:0.31812106614113556)founder;"

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Return a new block list with one region replaced by `copies` repeats of itself.

    b_rel_idx   -- sequence of block indices describing the current genome
    b_rel_start -- position (in blocks) where the affected region begins
    b_length    -- number of blocks in the affected region
    copies      -- how many times the region appears in the result
                   (0 deletes it, 1 leaves the genome unchanged)

    The input sequence is not modified.
    """
    region_stop = b_rel_start + b_length
    repeated = copies * b_rel_idx[b_rel_start:region_stop]
    return b_rel_idx[:b_rel_start] + repeated + b_rel_idx[region_stop:]

def generate_random_mutation(num_regions, existing_intervals):
    """Draw a random copy-number event that avoids the given block intervals.

    num_regions        -- current genome length in blocks
    existing_intervals -- iterable of half-open (start, end) block intervals
                          that the new event must not overlap

    Up to 100 candidates are drawn. Each candidate has a size (in blocks) from
    an exponential with mean cn_length_mean / resolution, clamped to
    [1, max(1, num_regions - 1)]; a fair coin picks gain or loss (a gain gets
    1 + Geometric(0.5) copies, a loss gets 0); the start block is uniform over
    the positions where the region fits.

    Returns a dict with "copies", the 1-based bp coordinates "prev_start_idx"
    and "prev_end_idx", "region_length" and "total_length" in bp, and the
    block-level "interval" (start, end).

    Raises ValueError when all 100 candidates overlap an existing interval.
    """
    size_cap = max(1, num_regions - 1)
    attempts_left = 100  # attempt multiple times to avoid infinite loop
    while attempts_left > 0:
        attempts_left -= 1

        drawn_size = round(np.random.exponential(cn_length_mean / resolution))
        n_blocks = max(min(drawn_size, size_cap), 1)

        is_gain = np.random.binomial(1, 0.5) == 1
        if is_gain:
            n_copies = np.random.geometric(0.5) + 1
        else:
            n_copies = 0

        n_start_positions = num_regions - n_blocks + 1
        if n_start_positions > 0:
            first_block = np.random.randint(n_start_positions)
        else:
            first_block = 0
        end_block = first_block + n_blocks

        # Two half-open intervals collide when each starts before the other ends.
        collides = any(lo < end_block and first_block < hi for (lo, hi) in existing_intervals)
        if collides:
            continue

        return {
            "copies": n_copies,
            "prev_start_idx": first_block * resolution + 1,
            "prev_end_idx": end_block * resolution,
            "region_length": n_blocks * resolution,
            "total_length": n_copies * n_blocks * resolution,
            "interval": (first_block, end_block)
        }
    raise ValueError("Failed to generate non-overlapping mutation after 100 tries.")

def apply_mutation_and_trace(blocks, mutation):
    """Apply one mutation record to `blocks` and report the affected bp ranges.

    blocks   -- current genome as a list of reference block indices
    mutation -- dict with "prev_start_idx" (1-based bp), "region_length" (bp)
                and "copies"

    Returns (new_blocks, (tumor_start_bp, tumor_end_bp), (ref_min_bp, ref_max_bp)).
    Both ranges are (None, None) for deletions (copies == 0) and when the start
    block lies at or beyond the end of the mutated genome. The reference range
    alone is (None, None) when the affected slice of `blocks` is empty.
    """
    first_block = (mutation["prev_start_idx"] - 1) // resolution
    n_blocks = mutation["region_length"] // resolution
    n_copies = mutation["copies"]

    new_blocks = mutate(blocks, first_block, n_blocks, n_copies)

    untraceable = (None, None)
    if n_copies == 0 or first_block >= len(new_blocks):
        return new_blocks, untraceable, untraceable

    # Range occupied by the repeated region in the mutated genome, clipped to its end.
    tumor_start_bp = max(1, first_block * resolution + 1)
    tumor_end_bp = min((first_block + n_blocks * n_copies) * resolution, len(new_blocks) * resolution)

    # Reference coordinates spanned by the blocks that were copied.
    source_blocks = blocks[first_block : first_block + n_blocks]
    if source_blocks:
        ref_range = (min(source_blocks) * resolution + 1, (max(source_blocks) + 1) * resolution)
    else:
        ref_range = untraceable

    return new_blocks, (tumor_start_bp, tumor_end_bp), ref_range

def simulate_cell_mutations(blocks, n_mut):
    """Draw and apply `n_mut` mutually non-overlapping mutations to `blocks`.

    Each mutation is generated against the current genome length, applied
    immediately, and annotated with "tumor_tag_start_bp", "tumor_tag_end_bp",
    "ref_min_bp" and "ref_max_bp". The generator's "interval" key is moved out
    of the record into the list of intervals later draws must avoid.

    Returns (final_blocks, list_of_mutation_dicts).
    """
    records = []
    taken_intervals = []
    current = blocks
    for _ in range(n_mut):
        record = generate_random_mutation(len(current), taken_intervals)
        taken_intervals.append(record.pop("interval"))
        current, (tumor_start, tumor_end), (ref_min, ref_max) = apply_mutation_and_trace(current, record)
        record["tumor_tag_start_bp"] = tumor_start
        record["tumor_tag_end_bp"] = tumor_end
        record["ref_min_bp"] = ref_min
        record["ref_max_bp"] = ref_max
        records.append(record)
    return current, records

def simulate_diploid_cell(parent_blocks, n0, n1, inherited_mutations):
    """Add new mutations to both alleles of a cell, starting from its parent's genome.

    parent_blocks       -- {allele: block list} of the parent genome. It must
                           already include every mutation in
                           `inherited_mutations` (the traversal loop builds it
                           by replaying them on a diploid copy).
    n0, n1              -- number of new mutations for allele 0 and allele 1
    inherited_mutations -- {allele: list of mutation dicts} carried by the parent

    Returns (all_mutations, final_blocks), both keyed by allele. all_mutations
    is the inherited list followed by the newly drawn mutations.
    """
    all_mutations = {}
    final_blocks = {}
    for allele, n_mut in ((0, n0), (1, n1)):
        inherited = inherited_mutations[allele].copy()
        # Start directly from the parent genome. The previous loop re-applied
        # each inherited mutation to the pristine `diploid_blocks`, which threw
        # away parent_blocks and kept only the last inherited mutation.
        blocks, new_muts = simulate_cell_mutations(parent_blocks[allele], n_mut)
        all_mutations[allele] = inherited + new_muts
        final_blocks[allele] = blocks
    return all_mutations, final_blocks

# --- Tree traversal ---
# Parse the newick string and simulate mutations for every named clade.
tree = Phylo.read(StringIO(nwk_string), "newick")
tree.ladderize()
# Map each child clade to its parent clade.
parent_map = {child: clade for clade in tree.find_clades(order="level") for child in clade.clades}
cumulative_genomes = {}  # node name -> {allele: block list}
mutation_records = {}    # node name -> {allele: list of mutation dicts}
# Level order is relied on below: a parent's records must exist before its children are visited.
for node in tree.find_clades(order="level"):
    if not node.name:
        continue
    parent = parent_map.get(node)
    if parent is None:
        # Root: unmutated diploid genome, no mutations.
        cumulative_genomes[node.name] = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}
        mutation_records[node.name] = {0: [], 1: []}
        continue
    # NOTE(review): an unnamed parent would make this lookup fail; all clades in nwk_string are named.
    inherited_mutations = mutation_records[parent.name]
    # Rebuild the parent genome by replaying its recorded mutations on a fresh diploid copy.
    parent_blocks = {}
    for allele in [0, 1]:
        blocks = diploid_blocks.copy()
        for md in inherited_mutations[allele]:
            blocks, _, _ = apply_mutation_and_trace(blocks, md)
        parent_blocks[allele] = blocks
    # Total number of new mutations is geometric; each is assigned to an allele by a fair coin.
    n_total = np.random.geometric(p=geom_p)
    n0 = np.random.binomial(n_total, 0.5)
    n1 = n_total - n0
    muts, final_blocks = simulate_diploid_cell(parent_blocks, n0, n1, inherited_mutations)
    mutation_records[node.name] = muts
    cumulative_genomes[node.name] = final_blocks

def validate_pipeline(tree, mutation_records, cumulative_genomes, resolution):
    """Sanity-check the bp annotations of every recorded mutation.

    tree               -- object whose find_clades(order="level") yields nodes with .name
    mutation_records   -- {node name: {allele: list of mutation dicts}}
    cumulative_genomes -- {node name: {allele: block list}}
    resolution         -- base pairs per block

    For each mutation with a non-None "tumor_tag_start_bp", asserts that the
    start lies within the node's genome and does not exceed
    "tumor_tag_end_bp". The reference range is checked only when both bounds
    are present, because an untraceable range is stored as (None, None) and
    comparing None values would raise TypeError instead of validating.

    Returns "Validation passed."; raises AssertionError on the first violation.
    """
    for node in tree.find_clades(order="level"):
        if not node.name:
            continue
        for allele in [0, 1]:
            muts = mutation_records[node.name][allele]
            genome_bp = len(cumulative_genomes[node.name][allele]) * resolution
            for i, md in enumerate(muts):
                start_bp = md.get("tumor_tag_start_bp")
                if start_bp is None:
                    # Deletions and untraceable events carry no tumor coordinates.
                    continue
                assert 1 <= start_bp <= genome_bp, f"{node.name} allele {allele} mutation {i}: tumor_tag_start_bp out of range"
                assert start_bp <= md["tumor_tag_end_bp"], f"{node.name} allele {allele} mutation {i}: tumor start > end"
                ref_min_bp = md["ref_min_bp"]
                ref_max_bp = md["ref_max_bp"]
                if ref_min_bp is not None and ref_max_bp is not None:
                    assert ref_min_bp <= ref_max_bp, f"{node.name} allele {allele} mutation {i}: ref min > max"
    return "Validation passed."

# Run the sanity checks on the simulation above; the returned status string is the cell output.
validate_pipeline(tree, mutation_records, cumulative_genomes, resolution)


In [None]:
from Bio import Phylo
import numpy as np
from io import StringIO

# --- Parameters ---
# Base pairs per block; block indices are converted to bp by multiplying with this.
resolution = 960
# Mean copy-number event length in bp (scaled to blocks via cn_length_mean / resolution).
cn_length_mean = 1000000
# Not referenced by the code in this cell (mutation counts come from n_manual_mutations).
geom_p = 0.5
# Unmutated genome: block indices 0..49999; both alleles start from a copy of this.
diploid_blocks = list(range(50000))

nwk_string = "(cell1:0.1)founder;"  # simplified tree for testing
n_manual_mutations = 1  # manually specify mutations per allele

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Return a new block list with one region replaced by `copies` repeats of itself.

    b_rel_idx   -- sequence of block indices describing the current genome
    b_rel_start -- position (in blocks) where the affected region begins
    b_length    -- number of blocks in the affected region
    copies      -- how many times the region appears in the result
                   (0 deletes it, 1 leaves the genome unchanged)

    The input sequence is not modified.
    """
    region_stop = b_rel_start + b_length
    repeated = copies * b_rel_idx[b_rel_start:region_stop]
    return b_rel_idx[:b_rel_start] + repeated + b_rel_idx[region_stop:]

def generate_random_mutation(num_regions):
    """Draw one random copy-number event for a genome of `num_regions` blocks.

    The size (in blocks) comes from an exponential with mean
    cn_length_mean / resolution, clamped to [1, max(1, num_regions - 1)].
    A fair coin picks gain or loss: a gain gets 1 + Geometric(0.5) copies,
    a loss gets 0. The start block is uniform over the positions where the
    region fits.

    Returns a dict with "copies", the 1-based bp coordinates "prev_start_idx"
    and "prev_end_idx", and "region_length" / "total_length" in bp.
    """
    size_cap = max(1, num_regions - 1)
    drawn_size = round(np.random.exponential(cn_length_mean / resolution))
    n_blocks = max(min(drawn_size, size_cap), 1)

    is_gain = np.random.binomial(1, 0.5) == 1
    if is_gain:
        n_copies = np.random.geometric(0.5) + 1
    else:
        n_copies = 0

    n_start_positions = num_regions - n_blocks + 1
    if n_start_positions > 0:
        first_block = np.random.randint(n_start_positions)
    else:
        first_block = 0
    end_block = first_block + n_blocks

    return {
        "copies": n_copies,
        "prev_start_idx": first_block * resolution + 1,
        "prev_end_idx": end_block * resolution,
        "region_length": n_blocks * resolution,
        "total_length": n_copies * n_blocks * resolution
    }

def apply_mutation_and_trace(blocks, mutation):
    """Apply one mutation record to `blocks` and report the affected bp ranges.

    blocks   -- current genome as a list of reference block indices
    mutation -- dict with "prev_start_idx" (1-based bp), "region_length" (bp)
                and "copies"

    Returns (new_blocks, (tumor_start_bp, tumor_end_bp), (ref_min_bp, ref_max_bp)).
    Both ranges are (None, None) for deletions (copies == 0) and when the start
    block lies at or beyond the end of the mutated genome. The reference range
    alone is (None, None) when the affected slice of `blocks` is empty.
    """
    first_block = (mutation["prev_start_idx"] - 1) // resolution
    n_blocks = mutation["region_length"] // resolution
    n_copies = mutation["copies"]
    new_blocks = mutate(blocks, first_block, n_blocks, n_copies)

    untraceable = (None, None)
    if n_copies == 0 or first_block >= len(new_blocks):
        return new_blocks, untraceable, untraceable

    # Range occupied by the repeated region in the mutated genome, clipped to its end.
    tumor_start_bp = max(1, first_block * resolution + 1)
    tumor_end_bp = min((first_block + n_blocks * n_copies) * resolution, len(new_blocks) * resolution)

    # Reference coordinates spanned by the blocks that were copied.
    source_blocks = blocks[first_block : first_block + n_blocks]
    if source_blocks:
        ref_range = (min(source_blocks) * resolution + 1, (max(source_blocks) + 1) * resolution)
    else:
        ref_range = untraceable
    return new_blocks, (tumor_start_bp, tumor_end_bp), ref_range

def simulate_cell_mutations(blocks, n_mut):
    """Draw and apply `n_mut` random mutations to `blocks`, one after another.

    Each mutation is generated against the current genome length, applied
    immediately, and annotated with "tumor_tag_start_bp", "tumor_tag_end_bp",
    "ref_min_bp" and "ref_max_bp".

    Returns (final_blocks, list_of_mutation_dicts).
    """
    records = []
    current = blocks
    for _ in range(n_mut):
        record = generate_random_mutation(len(current))
        current, (tumor_start, tumor_end), (ref_min, ref_max) = apply_mutation_and_trace(current, record)
        record["tumor_tag_start_bp"] = tumor_start
        record["tumor_tag_end_bp"] = tumor_end
        record["ref_min_bp"] = ref_min
        record["ref_max_bp"] = ref_max
        records.append(record)
    return current, records

def simulate_diploid_cell(parent_blocks, n0, n1, inherited_mutations):
    """Add new mutations to both alleles of a cell, starting from its parent's genome.

    parent_blocks       -- {allele: block list} of the parent genome. It must
                           already include every mutation in
                           `inherited_mutations` (the traversal loop builds it
                           by replaying them on a diploid copy).
    n0, n1              -- number of new mutations for allele 0 and allele 1
    inherited_mutations -- {allele: list of mutation dicts} carried by the parent

    Returns (all_mutations, final_blocks), both keyed by allele. all_mutations
    is the inherited list followed by the newly drawn mutations.
    """
    all_mutations = {}
    final_blocks = {}
    for allele, n_mut in ((0, n0), (1, n1)):
        inherited = inherited_mutations[allele].copy()
        # Start directly from the parent genome. The previous loop re-applied
        # each inherited mutation to the pristine `diploid_blocks`, which threw
        # away parent_blocks and kept only the last inherited mutation.
        blocks, new_muts = simulate_cell_mutations(parent_blocks[allele], n_mut)
        all_mutations[allele] = inherited + new_muts
        final_blocks[allele] = blocks
    return all_mutations, final_blocks

# --- Tree traversal ---
# Parse the newick string and simulate mutations for every named clade.
tree = Phylo.read(StringIO(nwk_string), "newick")
tree.ladderize()
# Map each child clade to its parent clade.
parent_map = {child: clade for clade in tree.find_clades(order="level") for child in clade.clades}
cumulative_genomes = {}  # node name -> {allele: block list}
mutation_records = {}    # node name -> {allele: list of mutation dicts}
# Level order is relied on below: a parent's records must exist before its children are visited.
for node in tree.find_clades(order="level"):
    if not node.name:
        continue
    parent = parent_map.get(node)
    if parent is None:
        # Root: unmutated diploid genome, no mutations.
        cumulative_genomes[node.name] = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}
        mutation_records[node.name] = {0: [], 1: []}
        continue
    inherited_mutations = mutation_records[parent.name]
    # Rebuild the parent genome by replaying its recorded mutations on a fresh diploid copy.
    parent_blocks = {}
    for allele in [0, 1]:
        blocks = diploid_blocks.copy()
        for md in inherited_mutations[allele]:
            blocks, _, _ = apply_mutation_and_trace(blocks, md)
        parent_blocks[allele] = blocks
    # Fixed number of new mutations on each allele (no random draw in this variant).
    muts, final_blocks = simulate_diploid_cell(parent_blocks, n_manual_mutations, n_manual_mutations, inherited_mutations)
    mutation_records[node.name] = muts
    cumulative_genomes[node.name] = final_blocks

def validate_pipeline(tree, mutation_records, cumulative_genomes, resolution):
    """Sanity-check the bp annotations of every recorded mutation.

    tree               -- object whose find_clades(order="level") yields nodes with .name
    mutation_records   -- {node name: {allele: list of mutation dicts}}
    cumulative_genomes -- {node name: {allele: block list}}
    resolution         -- base pairs per block

    For each mutation with a non-None "tumor_tag_start_bp", asserts that the
    start lies within the node's genome and does not exceed
    "tumor_tag_end_bp". The reference range is checked only when both bounds
    are present, because an untraceable range is stored as (None, None) and
    comparing None values would raise TypeError instead of validating.

    Returns "Validation passed."; raises AssertionError on the first violation.
    """
    for node in tree.find_clades(order="level"):
        if not node.name:
            continue
        for allele in [0, 1]:
            muts = mutation_records[node.name][allele]
            genome_bp = len(cumulative_genomes[node.name][allele]) * resolution
            for i, md in enumerate(muts):
                start_bp = md.get("tumor_tag_start_bp")
                if start_bp is None:
                    # Deletions and untraceable events carry no tumor coordinates.
                    continue
                assert 1 <= start_bp <= genome_bp, f"{node.name} allele {allele} mutation {i}: tumor_tag_start_bp out of range"
                assert start_bp <= md["tumor_tag_end_bp"], f"{node.name} allele {allele} mutation {i}: tumor start > end"
                ref_min_bp = md["ref_min_bp"]
                ref_max_bp = md["ref_max_bp"]
                if ref_min_bp is not None and ref_max_bp is not None:
                    assert ref_min_bp <= ref_max_bp, f"{node.name} allele {allele} mutation {i}: ref min > max"
    return "Validation passed."

# Run the sanity checks on the simulation above; the returned status string is the cell output.
validate_pipeline(tree, mutation_records, cumulative_genomes, resolution)


In [None]:
from Bio import Phylo
import numpy as np
from io import StringIO

# --- Parameters ---
# Base pairs per block; block indices are converted to bp by multiplying with this.
resolution = 960
# Mean copy-number event length in bp (scaled to blocks via cn_length_mean / resolution).
cn_length_mean = 10000
# Not referenced by the code in this cell (mutation counts come from n_manual_mutations).
geom_p = 0.5
# Unmutated genome: block indices 0..49999; both alleles start from a copy of this.
diploid_blocks = list(range(50000))

# nwk_string = "(cell1:0.1)founder;"  # simplified tree for testing
nwk_string = "(((cell1:0.10730759618761926,(cell6:0.04420179819059644,cell7:0.04420179819059644)ancestor7:0.06310579799702282)ancestor3:0.15119724784519406,(cell2:0.17496885779309207,(cell8:0.1106682423281029,((cell11:0.01877210928914326,cell12:0.01877210928914326)ancestor11:0.06924918172854262,(cell13:0.02132101972393934,(cell14:0.00240299287976051,cell15:0.00240299287976051)ancestor13:0.01891802684417883)ancestor12:0.06670027129374655)ancestor10:0.02264695131041702)ancestor8:0.06430061546498916)ancestor4:0.08353598623972125)ancestor1:0.7160958934610564,((cell3:0.06903946747075308,(cell9:0.02475272583115519,cell10:0.02475272583115519)ancestor9:0.04428674163959789)ancestor5:0.587440203881981,(cell4:0.08567206940767744,cell5:0.08567206940767744)ancestor6:0.5708076019450566)ancestor2:0.31812106614113556)founder;"

n_manual_mutations = 1  # manually specify mutations per allele

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Return a new block list with one region replaced by `copies` repeats of itself.

    b_rel_idx   -- sequence of block indices describing the current genome
    b_rel_start -- position (in blocks) where the affected region begins
    b_length    -- number of blocks in the affected region
    copies      -- how many times the region appears in the result
                   (0 deletes it, 1 leaves the genome unchanged)

    The input sequence is not modified.
    """
    region_stop = b_rel_start + b_length
    repeated = copies * b_rel_idx[b_rel_start:region_stop]
    return b_rel_idx[:b_rel_start] + repeated + b_rel_idx[region_stop:]

def generate_random_mutation(num_regions):
    """Draw one random copy-number event for a genome of `num_regions` blocks.

    The size (in blocks) comes from an exponential with mean
    cn_length_mean / resolution, clamped to [1, max(1, num_regions - 1)].
    A fair coin picks gain or loss: a gain gets 1 + Geometric(0.5) copies,
    a loss gets 0. The start block is uniform over the positions where the
    region fits.

    Returns a dict with "copies", the 1-based bp coordinates "prev_start_idx"
    and "prev_end_idx", and "region_length" / "total_length" in bp.
    """
    size_cap = max(1, num_regions - 1)
    drawn_size = round(np.random.exponential(cn_length_mean / resolution))
    n_blocks = max(min(drawn_size, size_cap), 1)

    is_gain = np.random.binomial(1, 0.5) == 1
    if is_gain:
        n_copies = np.random.geometric(0.5) + 1
    else:
        n_copies = 0

    n_start_positions = num_regions - n_blocks + 1
    if n_start_positions > 0:
        first_block = np.random.randint(n_start_positions)
    else:
        first_block = 0
    end_block = first_block + n_blocks

    return {
        "copies": n_copies,
        "prev_start_idx": first_block * resolution + 1,
        "prev_end_idx": end_block * resolution,
        "region_length": n_blocks * resolution,
        "total_length": n_copies * n_blocks * resolution
    }

def apply_mutation_and_trace(blocks, mutation):
    """Apply one mutation record to `blocks` and report the affected bp ranges.

    blocks   -- current genome as a list of reference block indices
    mutation -- dict with "prev_start_idx" (1-based bp), "region_length" (bp)
                and "copies"

    Returns (new_blocks, (tumor_start_bp, tumor_end_bp), (ref_min_bp, ref_max_bp)).
    Both ranges are (None, None) for deletions (copies == 0) and when the start
    block lies at or beyond the end of the mutated genome. The reference range
    alone is (None, None) when the affected slice of `blocks` is empty.
    """
    first_block = (mutation["prev_start_idx"] - 1) // resolution
    n_blocks = mutation["region_length"] // resolution
    n_copies = mutation["copies"]
    new_blocks = mutate(blocks, first_block, n_blocks, n_copies)

    untraceable = (None, None)
    if n_copies == 0 or first_block >= len(new_blocks):
        return new_blocks, untraceable, untraceable

    # Range occupied by the repeated region in the mutated genome, clipped to its end.
    tumor_start_bp = max(1, first_block * resolution + 1)
    tumor_end_bp = min((first_block + n_blocks * n_copies) * resolution, len(new_blocks) * resolution)

    # Reference coordinates spanned by the blocks that were copied.
    source_blocks = blocks[first_block : first_block + n_blocks]
    if source_blocks:
        ref_range = (min(source_blocks) * resolution + 1, (max(source_blocks) + 1) * resolution)
    else:
        ref_range = untraceable
    return new_blocks, (tumor_start_bp, tumor_end_bp), ref_range

def simulate_cell_mutations(blocks, n_mut):
    """Draw and apply `n_mut` random mutations to `blocks`, one after another.

    Each mutation is generated against the current genome length, applied
    immediately, and annotated with "tumor_tag_start_bp", "tumor_tag_end_bp",
    "ref_min_bp" and "ref_max_bp".

    Returns (final_blocks, list_of_mutation_dicts).
    """
    records = []
    current = blocks
    for _ in range(n_mut):
        record = generate_random_mutation(len(current))
        current, (tumor_start, tumor_end), (ref_min, ref_max) = apply_mutation_and_trace(current, record)
        record["tumor_tag_start_bp"] = tumor_start
        record["tumor_tag_end_bp"] = tumor_end
        record["ref_min_bp"] = ref_min
        record["ref_max_bp"] = ref_max
        records.append(record)
    return current, records

def simulate_diploid_cell(parent_blocks, n0, n1, inherited_mutations):
    """Add new mutations to both alleles of a cell, starting from its parent's genome.

    parent_blocks       -- {allele: block list} of the parent genome. It must
                           already include every mutation in
                           `inherited_mutations` (the traversal loop builds it
                           by replaying them on a diploid copy).
    n0, n1              -- number of new mutations for allele 0 and allele 1
    inherited_mutations -- {allele: list of mutation dicts} carried by the parent

    Returns (all_mutations, final_blocks), both keyed by allele. all_mutations
    is the inherited list followed by the newly drawn mutations, so replaying
    it on the diploid genome reproduces final_blocks.
    """
    all_mutations = {}
    final_blocks = {}
    for allele, n_mut in ((0, n0), (1, n1)):
        inherited = inherited_mutations[allele].copy()
        # Start directly from the parent genome. The previous loop re-applied
        # each inherited mutation to the pristine `diploid_blocks`, which threw
        # away parent_blocks and kept only the last inherited mutation.
        blocks, new_muts = simulate_cell_mutations(parent_blocks[allele], n_mut)
        all_mutations[allele] = inherited + new_muts
        final_blocks[allele] = blocks
    return all_mutations, final_blocks

def replay_mutations_on_diploid(diploid_blocks, mutation_list, resolution):
    """Rebuild a genome by applying recorded mutations, in order, to a diploid copy.

    diploid_blocks -- unmutated block list (copied, never modified)
    mutation_list  -- mutation dicts with "prev_start_idx" (1-based bp),
                      "region_length" (bp) and "copies"
    resolution     -- base pairs per block, used to convert bp to block units

    Returns the resulting block list.
    """
    genome = diploid_blocks.copy()
    for record in mutation_list:
        first_block = (record["prev_start_idx"] - 1) // resolution
        n_blocks = record["region_length"] // resolution
        genome = mutate(genome, first_block, n_blocks, record["copies"])
    return genome

def print_mutation_log(node_name, mutation_records, resolution):
    """Print every field of every mutation recorded for `node_name`, per allele.

    node_name        -- key into mutation_records
    mutation_records -- {node name: {allele: list of mutation dicts}}
    resolution       -- accepted for signature symmetry; not used here
    """
    print(f"\n--- Mutation Log for {node_name} ---")
    per_allele = mutation_records[node_name]
    for allele in (0, 1):
        print(f"\nAllele {allele}:")
        for number, record in enumerate(per_allele[allele], start=1):
            print(f"  Mutation {number}:")
            for field in record:
                print(f"    {field}: {record[field]}")

def log_genome_lengths(cumulative_genomes):
    """Print the number of blocks in each allele of every node's genome.

    cumulative_genomes -- {node name: {allele: block list}}
    """
    print("\n--- Genome Lengths ---")
    for name in cumulative_genomes:
        per_allele = cumulative_genomes[name]
        for allele in (0, 1):
            n_blocks = len(per_allele[allele])
            print(f"{name} allele {allele}: {n_blocks} blocks")

def check_mutation_feasibility(mutation_list, genome_len, resolution):
    """Warn about mutations whose region ends past a genome of `genome_len` blocks.

    mutation_list -- mutation dicts with "region_length" (bp) and
                     "prev_start_idx" (1-based bp)
    genome_len    -- genome length in blocks to compare against
    resolution    -- base pairs per block

    Prints one warning line per offending mutation; returns nothing.
    """
    for position, record in enumerate(mutation_list):
        n_blocks = record["region_length"] // resolution
        first_block = (record["prev_start_idx"] - 1) // resolution
        end_block = first_block + n_blocks
        if end_block > genome_len:
            print(f"Warning: Mutation {position} extends beyond genome length!")

def validate_reconstruction(tree, mutation_records, cumulative_genomes, resolution):
    """Check that replaying each node's mutation list reproduces its stored genome.

    For every named clade and both alleles, the recorded mutations are replayed
    on the module-level `diploid_blocks` and compared with the entry in
    `cumulative_genomes`.

    Returns "Replay validation passed."; raises AssertionError on the first mismatch.
    """
    named_nodes = (clade for clade in tree.find_clades(order="level") if clade.name)
    for node in named_nodes:
        for allele in (0, 1):
            recorded = mutation_records[node.name][allele]
            stored_genome = cumulative_genomes[node.name][allele]
            replayed = replay_mutations_on_diploid(diploid_blocks, recorded, resolution)
            assert replayed == stored_genome, f"Mismatch in {node.name} allele {allele}"
    return "Replay validation passed."

# Parse the newick string and simulate mutations for every named clade.
tree = Phylo.read(StringIO(nwk_string), "newick")
tree.ladderize()
# Map each child clade to its parent clade.
parent_map = {child: clade for clade in tree.find_clades(order="level") for child in clade.clades}
cumulative_genomes = {}  # node name -> {allele: block list}
mutation_records = {}    # node name -> {allele: list of mutation dicts}
# Level order is relied on below: a parent's records must exist before its children are visited.
for node in tree.find_clades(order="level"):
    if not node.name:
        continue
    parent = parent_map.get(node)
    if parent is None:
        # Root: unmutated diploid genome, no mutations.
        cumulative_genomes[node.name] = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}
        mutation_records[node.name] = {0: [], 1: []}
        continue
    inherited_mutations = mutation_records[parent.name]
    # Rebuild the parent genome by replaying its recorded mutations on a fresh diploid copy.
    parent_blocks = {}
    for allele in [0, 1]:
        blocks = diploid_blocks.copy()
        for md in inherited_mutations[allele]:
            blocks, _, _ = apply_mutation_and_trace(blocks, md)
        parent_blocks[allele] = blocks
    # Fixed number of new mutations on each allele (no random draw in this variant).
    muts, final_blocks = simulate_diploid_cell(parent_blocks, n_manual_mutations, n_manual_mutations, inherited_mutations)
    mutation_records[node.name] = muts
    cumulative_genomes[node.name] = final_blocks

# Replay check first (raises AssertionError on mismatch), then the diagnostic printouts.
validate_reconstruction(tree, mutation_records, cumulative_genomes, resolution)
log_genome_lengths(cumulative_genomes)
for name in mutation_records:
    print_mutation_log(name, mutation_records, resolution)
    for allele in [0, 1]:
        # Compared against the unmutated genome length, not the node's own genome length.
        check_mutation_feasibility(mutation_records[name][allele], len(diploid_blocks), resolution)


In [None]:
from Bio import Phylo
import numpy as np
from io import StringIO

# --- Parameters ---
# Base pairs per block; block indices are converted to bp by multiplying with this.
resolution = 960
# Mean copy-number event length in bp (scaled to blocks via cn_length_mean / resolution).
cn_length_mean = 1000000
# Not referenced by the code in this cell (mutation counts come from n_manual_mutations).
geom_p = 0.5
# Unmutated genome: block indices 0..49999; both alleles start from a copy of this.
diploid_blocks = list(range(50000))

# Newick tree; the traversal below uses the clade names as dictionary keys.
nwk_string = "(((cell1:0.10730759618761926,(cell6:0.04420179819059644,cell7:0.04420179819059644)ancestor7:0.06310579799702282)ancestor3:0.15119724784519406,(cell2:0.17496885779309207,(cell8:0.1106682423281029,((cell11:0.01877210928914326,cell12:0.01877210928914326)ancestor11:0.06924918172854262,(cell13:0.02132101972393934,(cell14:0.00240299287976051,cell15:0.00240299287976051)ancestor13:0.01891802684417883)ancestor12:0.06670027129374655)ancestor10:0.02264695131041702)ancestor8:0.06430061546498916)ancestor4:0.08353598623972125)ancestor1:0.7160958934610564,((cell3:0.06903946747075308,(cell9:0.02475272583115519,cell10:0.02475272583115519)ancestor9:0.04428674163959789)ancestor5:0.587440203881981,(cell4:0.08567206940767744,cell5:0.08567206940767744)ancestor6:0.5708076019450566)ancestor2:0.31812106614113556)founder;"
# Number of new mutations drawn per allele at every non-root node.
n_manual_mutations = 1

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Return a new block list with one region replaced by `copies` repeats of itself.

    b_rel_idx   -- sequence of block indices describing the current genome
    b_rel_start -- position (in blocks) where the affected region begins
    b_length    -- number of blocks in the affected region
    copies      -- how many times the region appears in the result
                   (0 deletes it, 1 leaves the genome unchanged)

    The input sequence is not modified.
    """
    region_stop = b_rel_start + b_length
    repeated = copies * b_rel_idx[b_rel_start:region_stop]
    return b_rel_idx[:b_rel_start] + repeated + b_rel_idx[region_stop:]

def generate_random_mutation(num_regions):
    """Draw one random copy-number event for a genome of `num_regions` blocks.

    The size (in blocks) comes from an exponential with mean
    cn_length_mean / resolution, clamped to [1, max(1, num_regions - 1)].
    A fair coin picks gain or loss: a gain gets 1 + Geometric(0.5) copies,
    a loss gets 0. The start block is uniform over the positions where the
    region fits.

    Returns a dict with "copies", the 1-based bp coordinates "prev_start_idx"
    and "prev_end_idx", and "region_length" / "total_length" in bp.
    """
    size_cap = max(1, num_regions - 1)
    drawn_size = round(np.random.exponential(cn_length_mean / resolution))
    n_blocks = max(min(drawn_size, size_cap), 1)

    is_gain = np.random.binomial(1, 0.5) == 1
    if is_gain:
        n_copies = np.random.geometric(0.5) + 1
    else:
        n_copies = 0

    n_start_positions = num_regions - n_blocks + 1
    if n_start_positions > 0:
        first_block = np.random.randint(n_start_positions)
    else:
        first_block = 0
    end_block = first_block + n_blocks

    return {
        "copies": n_copies,
        "prev_start_idx": first_block * resolution + 1,
        "prev_end_idx": end_block * resolution,
        "region_length": n_blocks * resolution,
        "total_length": n_copies * n_blocks * resolution
    }

def apply_mutation_and_trace(blocks, mutation):
    """Apply one mutation record to `blocks` and return the new block list.

    The record's bp fields "prev_start_idx" (1-based) and "region_length" are
    converted to block units with the module-level `resolution`. Despite the
    name, this variant returns only the mutated genome, with no trace ranges.
    """
    first_block = (mutation["prev_start_idx"] - 1) // resolution
    n_blocks = mutation["region_length"] // resolution
    return mutate(blocks, first_block, n_blocks, mutation["copies"])

def simulate_cell_mutations(blocks, n_mut):
    """Draw and apply `n_mut` random mutations to `blocks`, one after another.

    Each mutation is generated against the current genome length and applied
    before the next one is drawn.

    Returns (final_blocks, list_of_mutation_dicts).
    """
    records = []
    current = blocks
    for _ in range(n_mut):
        record = generate_random_mutation(len(current))
        current = apply_mutation_and_trace(current, record)
        records.append(record)
    return current, records

def simulate_diploid_cell(parent_blocks, n0, n1, inherited_mutations, logf, name):
    """Add new mutations to both alleles of a cell and log what was done.

    parent_blocks       -- {allele: block list} of the parent genome. It must
                           already include every mutation in
                           `inherited_mutations` (the traversal loop builds it
                           by replaying them on a diploid copy).
    n0, n1              -- number of new mutations for allele 0 and allele 1
    inherited_mutations -- {allele: list of mutation dicts} carried by the parent
    logf                -- writable text handle receiving the log lines
    name                -- node name used in the log lines

    Returns (all_mutations, final_blocks), both keyed by allele. all_mutations
    is the inherited list followed by the newly drawn mutations.
    """
    all_mutations = {}
    final_blocks = {}
    for allele, n_mut in ((0, n0), (1, n1)):
        blocks = parent_blocks[allele]
        inherited = inherited_mutations[allele].copy()
        logf.write(f"\n{name} allele {allele}: Starting with {len(blocks)} blocks\n")
        # Inherited mutations are NOT re-applied here: parent_blocks already
        # contains them, and applying them a second time duplicated or deleted
        # the same regions twice in every descendant.
        blocks, new_muts = simulate_cell_mutations(blocks, n_mut)
        all_mutations[allele] = inherited + new_muts
        final_blocks[allele] = blocks
        logf.write(f"{name} allele {allele}: {n_mut} new mutations, final length: {len(blocks)}\n")
        for i, md in enumerate(new_muts):
            logf.write(f"  Mutation {i+1}: start={md['prev_start_idx']}, len={md['region_length']}, copies={md['copies']}\n")
    return all_mutations, final_blocks

# Parse the newick string and simulate mutations for every named clade, logging to log.txt.
tree = Phylo.read(StringIO(nwk_string), "newick")
tree.ladderize()
# Map each child clade to its parent clade.
parent_map = {child: clade for clade in tree.find_clades(order="level") for child in clade.clades}
cumulative_genomes = {}  # node name -> {allele: block list}
mutation_records = {}    # node name -> {allele: list of mutation dicts}

# "w" mode: log.txt in the working directory is overwritten on every run.
with open("log.txt", "w") as logf:
    # Level order is relied on below: a parent's records must exist before its children are visited.
    for node in tree.find_clades(order="level"):
        if not node.name:
            continue
        logf.write(f"\nProcessing {node.name}...\n")
        parent = parent_map.get(node)
        if parent is None:
            # Root: unmutated diploid genome, no mutations.
            cumulative_genomes[node.name] = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}
            mutation_records[node.name] = {0: [], 1: []}
            continue
        inherited_mutations = mutation_records[parent.name]
        # Rebuild the parent genome by replaying its recorded mutations on a fresh diploid copy.
        parent_blocks = {}
        for allele in [0, 1]:
            blocks = diploid_blocks.copy()
            for md in inherited_mutations[allele]:
                blocks = apply_mutation_and_trace(blocks, md)
            parent_blocks[allele] = blocks
        muts, final_blocks = simulate_diploid_cell(parent_blocks, n_manual_mutations, n_manual_mutations, inherited_mutations, logf, node.name)
        mutation_records[node.name] = muts
        cumulative_genomes[node.name] = final_blocks


In [None]:
from Bio import Phylo
import numpy as np
from io import StringIO

# --- Parameters ---
# Base pairs per block; block indices are converted to bp by multiplying with this.
resolution = 960
# Mean copy-number event length in bp (scaled to blocks via cn_length_mean / resolution).
cn_length_mean = 1000000
# Not referenced by the code in this cell (mutation counts come from n_manual_mutations).
geom_p = 0.5
# Unmutated genome: block indices 0..49999; both alleles start from a copy of this.
diploid_blocks = list(range(50000))

# Simple tree for testing FASTA correctness
nwk_string = "(cell1:0.1)founder;"
# Number of new mutations requested for allele 0 in the traversal below.
n_manual_mutations = 1

def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Return a new block list with one region replaced by `copies` repeats of itself.

    b_rel_idx   -- sequence of block indices describing the current genome
    b_rel_start -- position (in blocks) where the affected region begins
    b_length    -- number of blocks in the affected region
    copies      -- how many times the region appears in the result
                   (0 deletes it, 1 leaves the genome unchanged)

    The input sequence is not modified.
    """
    region_stop = b_rel_start + b_length
    repeated = copies * b_rel_idx[b_rel_start:region_stop]
    return b_rel_idx[:b_rel_start] + repeated + b_rel_idx[region_stop:]

def generate_random_mutation(num_regions):
    """Draw one random copy-number event for a genome of `num_regions` blocks.

    The size (in blocks) comes from an exponential with mean
    cn_length_mean / resolution, clamped to [1, max(1, num_regions - 1)].
    A fair coin picks gain or loss: a gain gets 1 + Geometric(0.5) copies,
    a loss gets 0. The start block is uniform over the positions where the
    region fits.

    Returns a dict with "copies", the 1-based bp coordinates "prev_start_idx"
    and "prev_end_idx", "region_length" / "total_length" in bp, and the same
    event in block units as "start_idx" and "length_blocks".
    """
    size_cap = max(1, num_regions - 1)
    drawn_size = round(np.random.exponential(cn_length_mean / resolution))
    n_blocks = max(min(drawn_size, size_cap), 1)

    is_gain = np.random.binomial(1, 0.5) == 1
    if is_gain:
        n_copies = np.random.geometric(0.5) + 1
    else:
        n_copies = 0

    n_start_positions = num_regions - n_blocks + 1
    if n_start_positions > 0:
        first_block = np.random.randint(n_start_positions)
    else:
        first_block = 0
    end_block = first_block + n_blocks

    return {
        "copies": n_copies,
        "prev_start_idx": first_block * resolution + 1,
        "prev_end_idx": end_block * resolution,
        "region_length": n_blocks * resolution,
        "total_length": n_copies * n_blocks * resolution,
        "start_idx": first_block,
        "length_blocks": n_blocks
    }

def apply_mutation_and_trace(blocks, mutation):
    """Apply one mutation record to `blocks` using its block-unit fields.

    Reads "start_idx", "length_blocks" and "copies" from the record and
    returns the mutated block list; no bp conversion or tracing happens here.
    """
    first_block = mutation["start_idx"]
    n_blocks = mutation["length_blocks"]
    n_copies = mutation["copies"]
    return mutate(blocks, first_block, n_blocks, n_copies)

def simulate_cell_mutations(blocks, n_mut):
    """Draw and apply `n_mut` random mutations to `blocks`, one after another.

    Each mutation is generated against the current genome length and applied
    before the next one is drawn.

    Returns (final_blocks, list_of_mutation_dicts).
    """
    records = []
    current = blocks
    for _ in range(n_mut):
        record = generate_random_mutation(len(current))
        current = apply_mutation_and_trace(current, record)
        records.append(record)
    return current, records

def simulate_diploid_cell(parent_blocks, n0, n1, inherited_mutations, logf, name):
    """Add new mutations to both alleles of a cell and log what was done.

    parent_blocks       -- {allele: block list} of the parent genome. It must
                           already include every mutation in
                           `inherited_mutations` (the traversal loop builds it
                           by replaying them on a diploid copy).
    n0, n1              -- number of new mutations for allele 0 and allele 1
    inherited_mutations -- {allele: list of mutation dicts} carried by the parent
    logf                -- writable text handle receiving the log lines
    name                -- node name used in the log lines

    Returns (all_mutations, final_blocks), both keyed by allele. all_mutations
    is the inherited list followed by the newly drawn mutations.
    """
    all_mutations = {}
    final_blocks = {}
    for allele, n_mut in ((0, n0), (1, n1)):
        blocks = parent_blocks[allele]
        inherited = inherited_mutations[allele].copy()
        logf.write(f"\n{name} allele {allele}: Starting with {len(blocks)} blocks\n")
        # Inherited mutations are NOT re-applied here: parent_blocks already
        # contains them, and applying them a second time duplicated or deleted
        # the same regions twice in every descendant.
        blocks, new_muts = simulate_cell_mutations(blocks, n_mut)
        all_mutations[allele] = inherited + new_muts
        final_blocks[allele] = blocks
        logf.write(f"{name} allele {allele}: {n_mut} new mutations, final length: {len(blocks)}\n")
        for i, md in enumerate(new_muts):
            logf.write(f"  Mutation {i+1}: start_idx={md['start_idx']}, "
                       f"length_blocks={md['length_blocks']}, "
                       f"copies={md['copies']}\n")
    return all_mutations, final_blocks

# Parse the newick string and simulate mutations for every named clade, logging to log.txt.
tree = Phylo.read(StringIO(nwk_string), "newick")
tree.ladderize()
# Map each child clade to its parent clade.
parent_map = {child: clade for clade in tree.find_clades(order="level") for child in clade.clades}
cumulative_genomes = {}  # node name -> {allele: block list}
mutation_records = {}    # node name -> {allele: list of mutation dicts}

# "w" mode: log.txt in the working directory is overwritten on every run.
with open("log.txt", "w") as logf:
    # Level order is relied on below: a parent's records must exist before its children are visited.
    for node in tree.find_clades(order="level"):
        if not node.name:
            continue
        logf.write(f"\nProcessing {node.name}...\n")
        parent = parent_map.get(node)
        if parent is None:
            # Root: unmutated diploid genome, no mutations.
            cumulative_genomes[node.name] = {0: diploid_blocks.copy(), 1: diploid_blocks.copy()}
            mutation_records[node.name] = {0: [], 1: []}
            continue
        inherited_mutations = mutation_records[parent.name]
        # Rebuild the parent genome by replaying its recorded mutations on a fresh diploid copy.
        parent_blocks = {}
        for allele in [0, 1]:
            blocks = diploid_blocks.copy()
            for md in inherited_mutations[allele]:
                blocks = apply_mutation_and_trace(blocks, md)
            parent_blocks[allele] = blocks
        # Only allele 0 receives new mutations here (n1 is passed as 0).
        muts, final_blocks = simulate_diploid_cell(
            parent_blocks,
            n_manual_mutations,
            0,
            inherited_mutations,
            logf,
            node.name
        )
        mutation_records[node.name] = muts
        cumulative_genomes[node.name] = final_blocks

# Output the mutation data for cell1
mutation_data = mutation_records["cell1"]
mutation_data


In [None]:
# Number of blocks in allele 0 of cell1's simulated genome.
len(cumulative_genomes["cell1"][0])

In [None]:
from Bio import Phylo
from io import StringIO

def blocks_to_fasta(blocks, ref_blocks, output_path, label):
    """Write the genome described by `blocks` as a single-record FASTA file.

    blocks      -- iterable of indices into ref_blocks
    ref_blocks  -- sequence of reference block strings
    output_path -- file to create or overwrite
    label       -- text placed after '>' on the header line

    The whole sequence is written on one line (no wrapping), followed by a
    newline.
    """
    with open(output_path, "w") as handle:
        handle.write(f">{label}\n")
        handle.writelines(ref_blocks[idx] for idx in blocks)
        handle.write("\n")

def get_genome_blocks_for_node(node_name, mutation_records, diploid_blocks, resolution=960):
    """Rebuild a node's block list by replaying its mutations on a diploid copy.

    node_name        -- key into mutation_records
    mutation_records -- {node name: list of mutation dicts}; each dict needs
                        "prev_start_idx" (1-based bp), "region_length" (bp)
                        and "copies"
    diploid_blocks   -- unmutated block list (copied, never modified)
    resolution       -- base pairs per block

    Returns the mutated block list.
    """
    genome = diploid_blocks.copy()
    for record in mutation_records[node_name]:
        first_block = (record["prev_start_idx"] - 1) // resolution
        n_blocks = record["region_length"] // resolution
        n_copies = record["copies"]
        region_stop = first_block + n_blocks
        repeated = n_copies * genome[first_block:region_stop]
        genome = genome[:first_block] + repeated + genome[region_stop:]
    return genome

def extract_path_mutations(tree, mutation_records, target):
    """Collect the mutation records of every named node on the path to `target`.

    tree             -- Bio.Phylo tree (uses find_any and get_path)
    mutation_records -- {node name: iterable of mutation records}
    target           -- name of the clade whose lineage is wanted

    Returns one flat list, ordered from the node nearest the root down to the
    target. Nodes without an entry in mutation_records contribute nothing.

    NOTE(review): entries are passed straight to list.extend, so a per-allele
    dict such as {0: [...], 1: [...]} would contribute its keys rather than
    its mutation dicts. Confirm the record layout callers pass in.
    """
    # Get all mutation dicts along the path to the leaf
    clade = tree.find_any(name=target)
    path = tree.get_path(clade)
    node_names = [n.name for n in path if n.name]
    # get_path already ends with the target clade, so appending it again
    # unconditionally counted the target's mutations twice. It is only added
    # when the path does not contain it (e.g. when the target is the root).
    if target not in node_names:
        node_names.append(target)
    all_mutations = []
    for name in node_names:
        all_mutations.extend(mutation_records.get(name, []))
    return all_mutations

def validate_fasta_from_reference(fasta_path, expected_blocks, ref_blocks):
    """Assert that a FASTA file holds exactly the sequence implied by `expected_blocks`.

    fasta_path      -- FASTA file to read; its first line is treated as the header
    expected_blocks -- iterable of indices into ref_blocks
    ref_blocks      -- sequence of reference block strings

    All lines after the first are stripped and concatenated before comparison.
    Prints a confirmation on success; raises AssertionError on a mismatch.
    """
    with open(fasta_path, "r") as handle:
        next(handle, None)  # skip the header line, if the file has one
        actual_seq = "".join(row.strip() for row in handle)
    expected_seq = "".join(ref_blocks[idx] for idx in expected_blocks)
    assert actual_seq == expected_seq, "FASTA does not match expected genome!"
    print("✅ FASTA validation passed.")


In [None]:
# Number of reference blocks currently loaded.
# NOTE(review): within this part of the notebook `ref_blocks` is only assigned in later cells;
# confirm an earlier cell defines it, otherwise this fails on a fresh kernel.
len(ref_blocks)

In [None]:
# Setup: load the reference FASTA and regroup its sequence lines into blocks.
refname = "/Users/rfeld/Documents/Research/SPATIAL/spatial_24/ref/chr21.fa"
with open(refname, 'r') as f:
    # The first line is the FASTA header; every remaining line is sequence.
    ref = [line.strip() for line in f][1:]

# Join every 16 consecutive lines into one block
# (16 x 60 bp = 960 bp, assuming the file uses 60-bp lines).
lines_per_block = 16
num_blocks = len(ref) // lines_per_block  # number of complete blocks

# Step over *line* indices across the whole file. The previous loop used
# range(0, num_blocks, 16), which stopped after about 1/16 of the lines and
# silently dropped the rest of the reference. The final slice is simply
# shorter when the line count is not a multiple of 16, so the trailing
# partial block needs no special case.
new_ref = ["".join(ref[i:i + lines_per_block]) for i in range(0, len(ref), lines_per_block)]

ref = new_ref
ref_idx = list(range(len(new_ref)))

# Use the chunked reference as the block lookup table and re-parse the tree.
# Relies on nwk_string, mutation_records and diploid_blocks from earlier cells.
ref_blocks = new_ref
tree = Phylo.read(StringIO(nwk_string), "newick")

# Generate genome blocks
target_cell = "cell1"
allele = 0
mutation_list = mutation_records[target_cell][allele]
# get_genome_blocks_for_node expects {node name: list of mutations}, so the single
# allele's list is wrapped in a one-entry dict.
genome_blocks = get_genome_blocks_for_node(target_cell, {target_cell: mutation_list}, diploid_blocks)

# Write FASTA
output_fasta = f"{target_cell}_allele{allele}.fa"
blocks_to_fasta(genome_blocks, ref_blocks, output_fasta, f"{target_cell}_allele{allele}")

# Validate
# Reads the file back and compares it with the same genome_blocks used to write it.
validate_fasta_from_reference(output_fasta, genome_blocks, ref_blocks)


In [None]:
from Bio import SeqIO, Phylo
import numpy as np
from io import StringIO

# Read the reference FASTA, drop header lines, and cut the sequence into 960-bp blocks.
refname = "/Users/rfeld/Documents/Research/SPATIAL/spatial_24/ref/chr21.fa"
ref = []
with open(refname, 'r') as f:
    for line in f:
        if line.startswith(">"):
            # Header line, not sequence.
            continue
        ref.append(line.strip())

# Flatten to one list of characters so blocks are cut by position, not by file line.
ref = list("".join(ref))
block_size = 960
ref_blocks = [
    "".join(ref[offset:offset + block_size])
    for offset in range(0, len(ref), block_size)
]
# Unmutated genome: one index per reference block.
diploid_blocks = [*range(len(ref_blocks))]

# Setup replay logic
def mutate(b_rel_idx, b_rel_start, b_length, copies):
    """Return a new block list with one region replaced by `copies` repeats of itself.

    b_rel_idx   -- sequence of block indices describing the current genome
    b_rel_start -- position (in blocks) where the affected region begins
    b_length    -- number of blocks in the affected region
    copies      -- how many times the region appears in the result
                   (0 deletes it, 1 leaves the genome unchanged)

    The input sequence is not modified.
    """
    region_stop = b_rel_start + b_length
    before = b_rel_idx[:b_rel_start]
    after = b_rel_idx[region_stop:]
    repeated = copies * b_rel_idx[b_rel_start:region_stop]
    return before + repeated + after

def replay_mutations_on_diploid(diploid_blocks, mutation_list, resolution):
    """Rebuild a genome by applying recorded mutations, in order, to a diploid copy.

    diploid_blocks -- unmutated block list (copied, never modified)
    mutation_list  -- mutation dicts with "prev_start_idx" (1-based bp),
                      "region_length" (bp) and "copies"
    resolution     -- base pairs per block, used to convert bp to block units

    Returns the resulting block list.
    """
    genome = diploid_blocks.copy()
    for record in mutation_list:
        first_block = (record["prev_start_idx"] - 1) // resolution
        n_blocks = record["region_length"] // resolution
        genome = mutate(genome, first_block, n_blocks, record["copies"])
    return genome

# FASTA writing
def blocks_to_fasta(block_indices, ref_blocks, fasta_path, header="genome"):
    """Write the genome described by `block_indices` as a 60-column FASTA file.

    block_indices -- iterable of indices into ref_blocks
    ref_blocks    -- sequence of reference block strings
    fasta_path    -- file to create or overwrite
    header        -- text placed after '>' on the header line

    The concatenated sequence is wrapped at 60 characters per line.
    """
    line_width = 60
    with open(fasta_path, "w") as out:
        out.write(f">{header}\n")
        sequence = "".join(ref_blocks[idx] for idx in block_indices)
        out.writelines(
            sequence[offset:offset + line_width] + "\n"
            for offset in range(0, len(sequence), line_width)
        )

# FASTA validation
def validate_fasta_from_reference(fasta_path, genome_block_indices, ref_blocks):
    """Compare a single-record FASTA file with the sequence implied by block indices.

    fasta_path           -- FASTA file, parsed with SeqIO
    genome_block_indices -- iterable of indices into ref_blocks
    ref_blocks           -- sequence of reference block strings

    Prints a confirmation when the sequences are identical. Otherwise prints
    the length difference (if any) and the first 60-bp window that differs.
    Raises AssertionError if the file does not contain exactly one record.
    Returns nothing.
    """
    records = list(SeqIO.parse(fasta_path, "fasta"))
    assert len(records) == 1, "Expected one sequence"
    fasta_seq = str(records[0].seq)
    expected_seq = "".join(ref_blocks[i] for i in genome_block_indices)

    if fasta_seq == expected_seq:
        print("✅ FASTA matches expected mutated genome sequence.")
        return

    if len(fasta_seq) != len(expected_seq):
        print(f"❌ Length mismatch: expected {len(expected_seq)} bp, found {len(fasta_seq)} bp")

    # Scan to the longer of the two lengths. Scanning only len(expected_seq)
    # printed nothing when the file was longer than expected but shared the
    # expected prefix.
    scan_end = max(len(expected_seq), len(fasta_seq))
    for i in range(0, scan_end, 60):
        if fasta_seq[i:i+60] != expected_seq[i:i+60]:
            print(f"❌ Mismatch at position {i}")
            print("Expected:", expected_seq[i:i+60])
            print("Found:   ", fasta_seq[i:i+60])
            break

# Choose a cell and allele to test
target_cell = "cell1"
allele = 0
resolution = 960  # base pairs per block; same value as block_size used to cut ref_blocks above
output_fasta = f"{target_cell}_allele{allele}.fa"

# Load previous mutation_records (assumed loaded in session)
# replay and write fasta
# NOTE: mutation_records is not defined in this cell; it must come from an earlier simulation cell.
mutation_list = mutation_records[target_cell][allele]
genome_blocks = replay_mutations_on_diploid(diploid_blocks, mutation_list, resolution)
blocks_to_fasta(genome_blocks, ref_blocks, output_fasta, f"{target_cell}_allele{allele}")

# validate
# Reads the file back and compares it with the same genome_blocks used to write it.
validate_fasta_from_reference(output_fasta, genome_blocks, ref_blocks)


In [None]:
def print_fasta_slice(fasta_path, block_start, num_blocks, block_size=None):
    """Print a run of consecutive units from a FASTA file.

    By default a "block" is one non-header *text line* of the file, so the
    indices depend on how the file was line-wrapped. Pass ``block_size`` to
    index fixed-size blocks of the concatenated sequence instead (for example
    the notebook's ``resolution``), independent of line wrapping.

    Args:
        fasta_path: Path of the FASTA file to read.
        block_start: Index of the first unit to print.
        num_blocks: Number of consecutive units to print.
        block_size: Optional block length in bp. ``None`` (default) keeps the
            line-based indexing.

    Raises:
        ValueError: If ``block_size`` is given and smaller than 1.

    Each unit is shown truncated to its first 60 characters together with its
    full length. Indices outside ``0 .. n_units-1`` are reported as out of
    range; negative indices no longer wrap around to the end of the file or
    raise ``IndexError``.
    """
    if block_size is not None and block_size < 1:
        raise ValueError("block_size must be a positive integer")

    print(f"\n[FASTA {fasta_path} | Blocks {block_start} to {block_start + num_blocks - 1}]")
    with open(fasta_path, 'r') as f:
        lines = [line.strip() for line in f if not line.startswith('>')]

    if block_size is None:
        n_units = len(lines)

        def unit_at(i):
            return lines[i]
    else:
        sequence = "".join(lines)
        # Ceiling division: a trailing partial block still counts as a unit.
        n_units = -(-len(sequence) // block_size)

        def unit_at(i):
            return sequence[i * block_size:(i + 1) * block_size]

    for i in range(block_start, block_start + num_blocks):
        if 0 <= i < n_units:
            unit = unit_at(i)
            print(f"Block {i}: {unit[:60]}... ({len(unit)} bp)")
        else:
            print(f"Block {i}: [Out of range]")

def print_expected_duplicated_region(mutation_log, resolution):
    """Summarise every mutation whose ``copies`` value is greater than zero.

    For each such record the base-pair start and length are converted to
    block units with ``resolution`` and a short description is printed:
    which blocks form the original region and at which block the copied
    region(s) are expected to begin. Records with ``copies`` not greater
    than zero are skipped without reading their other fields.

    Args:
        mutation_log: Sequence of dict-like records with the keys
            ``'copies'``, ``'prev_start_idx'`` and ``'region_length'``.
        resolution: Block size (in bp) used for the conversion.
    """
    for ordinal, entry in enumerate(mutation_log, start=1):
        n_copies = entry['copies']
        if not (n_copies > 0):
            continue

        first_block = (entry['prev_start_idx'] - 1) // resolution
        span = entry['region_length'] // resolution
        past_region = first_block + span

        report = (
            f"\n[Mutation {ordinal}]",
            f"→ Duplicated {span} blocks starting at block index {first_block}",
            f"→ Duplicated {n_copies} times",
            f"→ Expect original region: blocks {first_block} to {past_region - 1}",
            f"→ Expect duplicated region(s) starting at: block {past_region}",
        )
        for text in report:
            print(text)


In [None]:
# Display the whole `mutation_records` object as this cell's output.
# NOTE(review): it is not defined in this cell and must already exist in the
# kernel; the output may be large if many cells/alleles were recorded.
mutation_records

In [None]:
# Parameters (same values as the cell that wrote the FASTA file)
target_cell = "cell1"
allele = 0
resolution = 960  # block size in bp used to convert mutation coordinates

# Paths
# Same naming pattern as the FASTA written by the replay cell.
fasta_path = f"{target_cell}_allele{allele}.fa"

# Show mutation log
# `mutation_records` must already exist in the kernel session.
mutation_log = mutation_records[target_cell][allele]
print_expected_duplicated_region(mutation_log, resolution)

# Choose where to inspect the FASTA (start near original or duplicated region)
# NOTE(review): print_fasta_slice indexes the non-header *lines* of the file,
# and blocks_to_fasta wraps sequence lines at 60 characters, whereas the
# indices printed by print_expected_duplicated_region are in units of
# `resolution` bp. With resolution = 960 one reported block spans 16 FASTA
# lines, so the two index spaces are not directly comparable.
# The start offsets below are hard-coded -- presumably chosen by hand from the
# mutation log output; verify they still point at the region of interest.
print_fasta_slice(fasta_path, block_start=3710, num_blocks=10)
print_fasta_slice(fasta_path, block_start=3810, num_blocks=10)

In [None]:
import random

# Parameters for the synthetic reference used in this self-contained check
# NOTE(review): no random seed is set in this cell, so the generated blocks
# differ between runs unless the `random` module was seeded earlier.
block_length = 60  # characters per dummy block
num_ref_blocks = 100  # number of dummy blocks in the reference

# Generate random 60bp A/G/T/C strings for a dummy reference
def random_block(length=60):
    """Return a random string of ``length`` characters drawn from A, G, T, C.

    Characters are sampled with replacement from the module-level ``random``
    generator, so results are reproducible only if it has been seeded.
    """
    alphabet = "AGTC"
    picks = random.choices(alphabet, k=length)
    return ''.join(picks)

# Build the synthetic reference: a list of random blocks (one string each).
dummy_reference = [random_block(block_length) for _ in range(num_ref_blocks)]

# Example mutation:
# Duplicate blocks 10 to 14 (inclusive) two additional times (total of 3 copies)
# NOTE: here `copies` counts *additional* copies, hence the `copies + 1`
# multipliers below. This differs from `mutate` earlier in the notebook, which
# multiplies the region by `copies` directly (so copies == 0 deletes it).
start_idx = 10
length_blocks = 5
copies = 2  # total = 3 copies

# Create mutated genome
# prefix + (region repeated copies + 1 times) + suffix, all as lists of blocks
original_region = dummy_reference[start_idx : start_idx + length_blocks]
mutated_genome = (
    dummy_reference[:start_idx] +
    original_region * (copies + 1) +
    dummy_reference[start_idx + length_blocks:]
)

# Now scan the genome to validate the pattern exists as expected
# Slide a window as long as the full repeated region over the mutated genome
# and record every block offset where the window equals the expected repeat.
found_matches = []
expected_sequence = original_region * (copies + 1)

for i in range(len(mutated_genome) - length_blocks * (copies + 1) + 1):
    window = mutated_genome[i : i + length_blocks * (copies + 1)]
    if window == expected_sequence:
        found_matches.append(i)

# Cell output: the match offsets (by construction the repeat starts at
# start_idx), the first two blocks of the region, and the mutated genome from
# start_idx for 2 + length_blocks * copies blocks.
found_matches, original_region[:2], mutated_genome[start_idx:start_idx + 2 + length_blocks * copies]
