# Mutagenesis Protocol
## Created using ChatGPT

### This script generates an efficient, step-by-step mutagenesis protocol that builds enzyme variants with a minimal, non-redundant set of mutagenesis steps. A databank is created to store every variant that you produce, and the labels used in the databank can be printed out.
The script essentially builds a directed acyclic graph (DAG) of variant derivations, and for each target variant, it:

   - Starts from the deepest reachable node (max existing subset).

   - Extends it minimally by adding only the mutations that are missing.

   - Never rebuilds any variant it already made.

As a result, the total number of mutagenesis steps across all variants is kept to a minimum, and common intermediate variants are shared between targets.

## Connect to your Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install reportlab
!pip install primer3-py

Collecting reportlab
  Downloading reportlab-4.4.3-py3-none-any.whl.metadata (1.7 kB)
Downloading reportlab-4.4.3-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: reportlab
Successfully installed reportlab-4.4.3
Collecting primer3-py
  Downloading primer3_py-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Downloading primer3_py-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m35.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: primer3-py
Successfully installed primer3-py-2.2.0


# In the first step:
 - Paste the wild-type sequence (including 30 flanking bases on each side)
    - you will be notified if the 30 flanking bases are missing
 - Define the variants
    - each mutation has to be provided as a string such as 'S19V'
    - multiple mutations have to be separated by commas
    - each construct has to be put in square brackets: ['S19V', 'S30V', 'Q424V', 'A431E'],
    - constructs have to be separated by commas
 - Give the number of mutations that you want to introduce simultaneously (this script supports 1 or 2)
 - Provide a prefix for your variants; it will be used in the databank
 - If you want to print labels (189-cell sheet by Hermes), you can define the starting column (0-6) and row (0-26)

## Primer Design

In [7]:
import os
import re
import csv
import json
from typing import List, Tuple, Dict
from collections import defaultdict
import primer3


class PrimerGenerator:
    """Site-directed mutagenesis primer generator with improved 3'-extension handling."""

    def __init__(self, flank_size_bases=30):
        # Standard genetic code
        self.codon_table = {
            'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
            'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
            'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
            'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
            'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
            'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
            'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
            'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
            'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
            'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
            'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
            'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
            'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
            'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
            'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
            'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
        }
        self.stop_codons = {"TAA", "TAG", "TGA"}

        # Reverse codon table
        self.aa_to_codons = defaultdict(list)
        for codon, aa in self.codon_table.items():
            self.aa_to_codons[aa].append(codon)

        # Constants
        self.flank_size_bases = flank_size_bases
        self.flank_size_codons = flank_size_bases // 3
        self.MIN_PRIMER_LEN = 30
        self.MIN_EXTENSION_LEN = 15  # Minimum 3'-extension length
        self.TARGET_TM = 68.0
        self.MIN_TM = 55.0
        self.MAX_TM_DIFF = 3.0
        self.MAX_OVERLAP_LEN = 15  # Hard cap on overlap length

    def check_sequence_validity(self, dna_sequence: str, flank_size: int = 30) -> bool:
        """Validate DNA sequence structure and reading frame."""
        if len(dna_sequence) % 3 != 0:
            print(f"Error: Sequence length {len(dna_sequence)} is not a multiple of 3.")
            return False

        if len(dna_sequence) < (flank_size * 2 + 6):
            print("Error: Sequence too short for flanking regions.")
            return False

        # Check start codon
        start_codon_pos = dna_sequence[flank_size:flank_size + 3]
        if start_codon_pos != "ATG":
            print(f"Error: No ATG start codon after first {flank_size} bases, found '{start_codon_pos}'.")
            return False

        # Check stop codon
        stop_codon_pos = dna_sequence[-(flank_size + 3):-flank_size]
        if stop_codon_pos not in self.stop_codons:
            print(f"Error: No valid stop codon before last {flank_size} bases, found '{stop_codon_pos}'.")
            return False

        print("Sequence check passed ✓: correct ATG start and stop codon positions.")
        return True

    @staticmethod
    def _parse_mutation(mut_str: str) -> Tuple[str, int, str]:
        """Parse mutation string like 'A123B' into (original_aa, position, new_aa)."""
        match = re.match(r'([A-Z])(\d+)([A-Z])', mut_str.strip())
        if not match:
            raise ValueError(f"❌ Invalid mutation format: {mut_str}")
        return match.group(1), int(match.group(2)), match.group(3)

    def validate_mutations_against_sequence(self, dna_sequence: str, variant_list: List[List[str]]) -> None:
        """Validate mutations against wildtype sequence."""
        for variant in variant_list:
            for mut in variant:
                original_aa, pos, _ = self._parse_mutation(mut)

                # Adjust position for flanking regions
                adj_pos_codons = pos + self.flank_size_codons
                codon_start = (adj_pos_codons - 1) * 3
                codon_seq = dna_sequence[codon_start:codon_start + 3]

                wt_aa = self.codon_table.get(codon_seq, 'X')
                if wt_aa != original_aa:
                    raise ValueError(
                        f"❌ Mutation {mut} does not match WT residue '{wt_aa}' "
                        f"at coding AA position {pos} (codon {codon_seq})"
                    )

    def reverse_complement(self, sequence: str) -> str:
        """Return reverse complement of DNA sequence."""
        complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G'}
        return ''.join(complement[base] for base in sequence[::-1])

    def calculate_tm(self, sequence: str) -> float:
        """Calculate melting temperature using primer3."""
        if not sequence:
            return 0.0
        try:
            return round(primer3.calc_tm(sequence), 2)
        except Exception as e:
            print(f"Error calculating Tm for sequence '{sequence}': {e}")
            return 0.0

    def get_best_codon(self, amino_acid: str, original_codon: str) -> str:
        """Get optimal codon for amino acid, preferring original if available."""
        codons = self.aa_to_codons.get(amino_acid, ["NNN"])
        return original_codon if original_codon in codons else codons[0]

    @staticmethod
    def flatten_mutations(mutation_groups):
        flat = []
        for group in mutation_groups:
            if isinstance(group, (list, tuple)) and all(isinstance(i, tuple) and len(i) == 3 for i in group):
                flat.extend(group)
            else:
                flat.append(group)
        return flat

    def merge_close_mutations(self, mutations: List[Tuple[str, int, str]], max_distance_nt: int = 21) -> List[List[Tuple[str, int, str]]]:
        sorted_muts = sorted(mutations, key=lambda x: x[1])
        if not sorted_muts:
            return []

        merged = []
        current_group = [sorted_muts[0]]

        for mut in sorted_muts[1:]:
            last_position = current_group[-1][1]
            if (mut[1] - last_position) * 3 <= max_distance_nt:
                current_group.append(mut)
            else:
                merged.append(current_group)
                current_group = [mut]
        merged.append(current_group)
        return merged

    def _create_mutated_sequence(self, dna_sequence: str, mutations: List[Tuple[str, int, str]]) -> str:
        """Create mutated DNA sequence from list of mutations."""
        mutated_seq = list(dna_sequence)

        for aa, pos, new_aa in mutations:
            idx = ((pos - 1) + self.flank_size_codons) * 3
            original_codon = dna_sequence[idx:idx + 3]
            new_codon = self.get_best_codon(new_aa, original_codon)

            # Verify original amino acid
            if self.codon_table.get(original_codon, 'X') != aa:
                print(f"Warning: Original amino acid mismatch at position {pos}")

            mutated_seq[idx:idx + 3] = list(new_codon)

        return ''.join(mutated_seq)

    def _find_optimal_overlap(self, mutated_seq, codon_starts, codon_ends):
        """Choose the overlap window shared by the forward and reverse primer.

        Every window of a fixed length along ``mutated_seq`` that contains at
        least one mutated codon start is scored, and the best-scoring window
        is kept. Windows of ``MAX_OVERLAP_LEN`` bases are tried first; if
        none of them leaves at least 4 bases on both sides of the mutations,
        windows one and two bases longer are tried; if that also yields no
        candidate, a window centred between the first and last mutation is
        returned.

        Args:
            mutated_seq: Template sequence with the mutations already applied.
            codon_starts: 0-based nucleotide offsets at which the mutated
                codons start.
            codon_ends: Not read by this method.

        Returns:
            Tuple ``(start, end, tm)``: the half-open window ``[start, end)``
            within ``mutated_seq`` and the value ``calculate_tm`` reports
            for it.
        """
        # Scoring parameters
        min_inside_flank = 4               # flanks shorter than this are penalised
        crit_flank_limit = 2               # windows with a flank shorter than this are discarded
        reverse_flank_weight = 2.0         # reward per base before the first mutation
        forward_flank_weight = 1.5         # reward per base after the last mutated codon
        wt_inside_weight = 0.5             # reward per non-mutated base inside the window
        penalty_per_missing_flank_nt = 50  # penalty per base a flank falls short of min_inside_flank
        upstream_bias_weight = 0.2         # penalty per base of window start offset

        mutation_positions = sorted(set(codon_starts))
        seq_len = len(mutated_seq)
        target_tm = self.TARGET_TM

        def search_len(curr_len):
            # Score every window of exactly curr_len bases and return the best
            # one as (start, end, tm, (forward_flank, reverse_flank), score).
            # start/end/tm and both flanks are None when no window qualifies.
            best_score = -1e9
            best_start, best_end, best_tm = None, None, None
            best_flanks = (None, None)
            for start_pos in range(0, seq_len - curr_len + 1):
                end_pos = start_pos + curr_len
                # Mutated codons whose first base lies inside the window
                muts_inside = [m for m in mutation_positions if start_pos <= m < end_pos]
                if not muts_inside:
                    continue

                first_inside = min(muts_inside)
                last_inside = max(muts_inside)
                # Bases between the window start and the first mutated codon
                reverse_flank = first_inside - start_pos
                # Bases between the end of the last mutated codon and the
                # window end; negative when that codon runs past the window
                forward_flank = end_pos - (last_inside + 3)

                # reject critical short flanks
                if reverse_flank < crit_flank_limit or forward_flank < crit_flank_limit:
                    continue

                count = len(muts_inside)
                ext_span = self._extension_mutation_span(mutation_positions, start_pos, end_pos)
                overlap_seq = mutated_seq[start_pos:end_pos]
                overlap_tm = self.calculate_tm(overlap_seq)

                total_mut_nt = count * 3
                total_wt_nt = curr_len - total_mut_nt

                flank_penalty = 0
                if forward_flank < min_inside_flank:
                    flank_penalty += (min_inside_flank - forward_flank) * penalty_per_missing_flank_nt
                if reverse_flank < min_inside_flank:
                    flank_penalty += (min_inside_flank - reverse_flank) * penalty_per_missing_flank_nt

                # The number of covered mutations dominates (1000 each); the
                # remaining terms rank windows that cover the same number.
                score = (
                    count * 1000
                    + (forward_flank * forward_flank_weight)
                    + (reverse_flank * reverse_flank_weight)
                    + (total_wt_nt * wt_inside_weight)
                    - (ext_span * 10)
                    - flank_penalty
                    - abs(overlap_tm - target_tm)
                    - (start_pos * upstream_bias_weight)
                )

                # Strict '>' keeps the first (most upstream) window on ties
                if score > best_score:
                    best_score = score
                    best_start, best_end, best_tm = start_pos, end_pos, overlap_tm
                    best_flanks = (forward_flank, reverse_flank)

            return best_start, best_end, best_tm, best_flanks, best_score

        # --- Stage 1: try exact MAX_OVERLAP_LEN ---
        # Accepted only when both flanks reach min_inside_flank
        best_start, best_end, best_tm, (f_flank, r_flank), _ = search_len(self.MAX_OVERLAP_LEN)
        if best_start is not None and f_flank >= min_inside_flank and r_flank >= min_inside_flank:
            return best_start, best_end, best_tm

        # --- Stage 2: allow +1 and +2 bases, pick best scoring ---
        # No minimum-flank requirement here beyond crit_flank_limit
        best_len_candidate = None
        best_len_score = -1e9
        for curr_len in [self.MAX_OVERLAP_LEN + 1, self.MAX_OVERLAP_LEN + 2]:
            cand_start, cand_end, cand_tm, _, cand_score = search_len(curr_len)
            if cand_start is not None and cand_score > best_len_score:
                best_len_candidate = (cand_start, cand_end, cand_tm)
                best_len_score = cand_score

        if best_len_candidate is not None:
            return best_len_candidate

        # --- Fallback: center between first/last mutation ---
        # NOTE(review): min()/max() raise ValueError here when codon_starts is empty.
        first_mut = min(mutation_positions)
        last_mut = max(mutation_positions)
        center = (first_mut + last_mut) // 2
        best_start = max(0, center - self.MAX_OVERLAP_LEN // 2)
        best_end = min(seq_len, best_start + self.MAX_OVERLAP_LEN)
        best_tm = self.calculate_tm(mutated_seq[best_start:best_end])
        return best_start, best_end, best_tm



    def _extension_mutation_span(self, mutation_positions, overlap_start, overlap_end):
        """
        Helper: compute span of mutations outside the overlap.
        Smaller = better.
        """
        outside = [pos for pos in mutation_positions if pos < overlap_start or pos >= overlap_end]
        if not outside:
            return 0
        return max(outside) - min(outside)

    def _build_primer(self, mutated_seq: str, overlap_seq: str, extension_start: int,
                      extension_direction: int, is_reverse: bool = False) -> Tuple[str, float]:
        """Build primer ensuring min 3'-extension length."""

        if is_reverse:
            extension_end = extension_start
            extension_start = max(0, extension_end - self.MIN_EXTENSION_LEN)
            extension_template = mutated_seq[extension_start:extension_end]
            extension = self.reverse_complement(extension_template)
            primer = self.reverse_complement(overlap_seq) + extension
        else:
            extension_end = min(len(mutated_seq), extension_start + self.MIN_EXTENSION_LEN)
            extension = mutated_seq[extension_start:extension_end]
            primer = overlap_seq + extension

        tm = self.calculate_tm(primer)
        return primer, tm

    def _extend_primer_one_base(self, mutated_seq: str, current_primer: str, overlap_seq: str,
                              extension_start: int, extension_end: int,
                              is_reverse: bool) -> Tuple[str, float, int, int]:

        MAX_PRIMER_LEN = 45  # define your max length here or use class attribute

        # Calculate current primer length
        current_len = len(current_primer)

        # Prevent extension if max length reached
        if current_len >= MAX_PRIMER_LEN:
            return current_primer, self.calculate_tm(current_primer), extension_start, extension_end

        if is_reverse:
            if extension_start == 0:
                # can't extend beyond start of sequence
                return current_primer, self.calculate_tm(current_primer), extension_start, extension_end

            extension_start = max(0, extension_start - 1)
            extension_template = mutated_seq[extension_start:extension_end]
            extension = self.reverse_complement(extension_template)
            primer = self.reverse_complement(overlap_seq) + extension
        else:
            if extension_end >= len(mutated_seq):
                # can't extend beyond end of sequence
                return current_primer, self.calculate_tm(current_primer), extension_start, extension_end

            extension_end = min(len(mutated_seq), extension_end + 1)
            extension = mutated_seq[extension_start:extension_end]
            primer = overlap_seq + extension

        tm = self.calculate_tm(primer)
        return primer, tm, extension_start, extension_end

    def count_nt_changes(self, wt_seq: str, mutations: List[Tuple[str, int, str]]) -> int:
        """Return the count of nucleotide changes vs wildtype after applying `mutations`."""
        mutated_seq = self._create_mutated_sequence(wt_seq, mutations)
        return sum(1 for a,b in zip(wt_seq, mutated_seq) if a != b)

    def split_mutations_by_max_nt_changes(
        self, wt_seq: str, mutation_group: List[Tuple[str, int, str]], max_nt_changes: int = 6
    ) -> List[List[Tuple[str, int, str]]]:
        """
        Split mutation group so that each group does not exceed max_nt_changes.
        Ensures codon-aware splits. Mutations are grouped as required to stay under the threshold.
        """
        sorted_mutations = sorted(mutation_group, key=lambda m: m[1])
        split_groups = []
        current_group = []
        for mut in sorted_mutations:
            test_list = current_group + [mut]
            if self.count_nt_changes(wt_seq, test_list) > max_nt_changes:
                if current_group:
                    split_groups.append(current_group)
                current_group = [mut]
            else:
                current_group.append(mut)
        if current_group:
            split_groups.append(current_group)
        return split_groups

    def generate_primers_for_construct(self, dna_sequence: str, mutations: List[Tuple[str, int, str]]) -> List[Dict]:
        """Design overlapping forward/reverse primer pairs for one construct.

        The mutations are flattened, clustered by proximity with
        ``merge_close_mutations`` and each cluster is split with
        ``split_mutations_by_max_nt_changes`` so that one primer pair
        introduces at most 6 nucleotide changes. For every resulting
        sub-group an overlap window is chosen, both primers are started with
        the minimum 3'-extension and then lengthened one base at a time to
        raise and balance their melting temperatures.

        Args:
            dna_sequence: Wild-type template including the flanks.
            mutations: ``(original_aa, position, new_aa)`` tuples, optionally
                nested one level in lists/tuples.

        Returns:
            One dictionary per primer pair with the keys ``set_id``,
            ``depends_on``, ``mutations``, ``all_covered_mutations``,
            ``overlap_sequence``, ``overlap_tm``, ``forward_primer``,
            ``reverse_primer``, ``forward_tm``, ``reverse_tm``,
            ``forward_length``, ``reverse_length``, ``overlap_length`` and
            ``tm_difference``.
        """
        mutations = self.flatten_mutations(mutations)
        primer_sets = []
        mutation_groups = self.merge_close_mutations(mutations)

        MAX_PRIMER_LEN = 45
        MAX_MUT_NT = 6  # Maximum nucleotide changes per primer pair

        for group in mutation_groups:
            # Split groups if nucleotide changes exceed max allowed per primer
            split_groups = self.split_mutations_by_max_nt_changes(dna_sequence, group, MAX_MUT_NT)
            # Mutations of this cluster already handled by earlier primer pairs;
            # reset for every cluster, so the template used below never carries
            # mutations from other clusters.
            previous_sub_group = []

            # The loop index i is not used in the body.
            for i, sub_group in enumerate(split_groups):
                # Determine codon starts/ends for current mutations
                codon_starts_sub = [((m[1] - 1) + self.flank_size_codons) * 3 for m in sub_group]
                codon_ends_sub = [start + 3 for start in codon_starts_sub]

                # --- Accumulate previous mutations ---
                all_applied_muts = previous_sub_group + sub_group
                # Create mutated sequence including all previous mutations
                mutated_seq = self._create_mutated_sequence(dna_sequence, all_applied_muts)

                # Find optimal overlap for these mutations
                overlap_start, overlap_end, _ = self._find_optimal_overlap(mutated_seq, codon_starts_sub, codon_ends_sub)

                overlap_seq = mutated_seq[overlap_start:overlap_end]

                # Forward primer initial design: overlap plus the minimum
                # extension downstream of it (the slice is clipped at the
                # template end even though forward_end itself is not)
                forward_start = overlap_end
                forward_end = overlap_end + self.MIN_EXTENSION_LEN
                forward_primer = overlap_seq + mutated_seq[forward_start:forward_end]
                forward_tm = self.calculate_tm(forward_primer)

                # Reverse primer initial design: reverse complement of the
                # overlap followed by that of the minimum extension upstream
                reverse_end = overlap_start
                reverse_start = max(0, reverse_end - self.MIN_EXTENSION_LEN)
                reverse_template = mutated_seq[reverse_start:reverse_end]
                reverse_primer = self.reverse_complement(overlap_seq) + self.reverse_complement(reverse_template)
                reverse_tm = self.calculate_tm(reverse_primer)

                # Balance primer Tm with incremental extension: keep going
                # while the primers differ by more than MAX_TM_DIFF or either
                # is below TARGET_TM, for at most max_attempts iterations
                tm_diff = abs(forward_tm - reverse_tm)
                attempts = 0
                max_attempts = 30
                while (
                    (tm_diff > self.MAX_TM_DIFF) or
                    (forward_tm < self.TARGET_TM) or
                    (reverse_tm < self.TARGET_TM)
                ) and attempts < max_attempts:
                    # Extend the forward primer when it has the lower Tm and
                    # room to grow; otherwise extend the reverse primer if it
                    # has room; stop once neither can be lengthened
                    if forward_tm < reverse_tm and len(forward_primer) < MAX_PRIMER_LEN:
                        forward_primer, forward_tm, forward_start, forward_end = self._extend_primer_one_base(
                            mutated_seq, forward_primer, overlap_seq, forward_start, forward_end, is_reverse=False)
                    elif len(reverse_primer) < MAX_PRIMER_LEN:
                        reverse_primer, reverse_tm, reverse_start, reverse_end = self._extend_primer_one_base(
                            mutated_seq, reverse_primer, overlap_seq, reverse_start, reverse_end, is_reverse=True)
                    else:
                        break
                    tm_diff = abs(forward_tm - reverse_tm)
                    attempts += 1

                # Combine previous_sub_group + sub_group uniquely by mutation identity (aa, pos, new_aa)
                all_mut_set = { (m[0], m[1], m[2]) for m in previous_sub_group + sub_group }
                all_mut_list = sorted(all_mut_set, key=lambda x: x[1])

                # Create a unique ID string for this primer set from all covered mutations
                set_id = "_".join([f"{m[0]}{m[1]}{m[2]}" for m in all_mut_list])

                # Determine dependency: for the first primer set, no dependencies
                if len(primer_sets) == 0:
                    depends_on = []
                else:
                    # Depend on the immediate previous primer set
                    # NOTE(review): this also links the first set of a new
                    # cluster to the last set of the previous cluster —
                    # confirm that this is intended.
                    depends_on = [primer_sets[-1]['set_id']]

                primer_sets.append({
                    'set_id': set_id,
                    'depends_on': depends_on,
                    'mutations': [f"{m[0]}{m[1]}{m[2]}" for m in sub_group],  # current subgroup mutations
                    'all_covered_mutations': [f"{m[0]}{m[1]}{m[2]}" for m in all_mut_list],  # total mutations in primer sequence
                    'overlap_sequence': overlap_seq,
                    'overlap_tm': self.calculate_tm(overlap_seq),
                    'forward_primer': forward_primer,
                    'reverse_primer': reverse_primer,
                    'forward_tm': forward_tm,
                    'reverse_tm': reverse_tm,
                    'forward_length': len(forward_primer),
                    'reverse_length': len(reverse_primer),
                    'overlap_length': len(overlap_seq),
                    'tm_difference': tm_diff
                })

                # Later sub-groups of this cluster build on these mutations
                previous_sub_group += sub_group

        return primer_sets



    def save_primer_list(self, primer_data: List[Dict], output_dir: str = None, filename: str = "primer_list.txt") -> None:
        """Save primer data to text and CSV files."""
        if output_dir is None:
            output_dir = os.path.expanduser("~/Mutagenesis")

        os.makedirs(output_dir, exist_ok=True)

        # Save text format
        filepath_txt = os.path.join(output_dir, filename)
        with open(filepath_txt, 'w') as f:
            f.write("Site-Directed Mutagenesis Primer List\n")
            f.write("=" * 120 + "\n\n")

            header = (f"{'Primer Name':<20}{'Primer Sequence':<60}{'Len':>6}"
                      f"{'Tm (°C)':>9}{'OverlapLen':>11}{'OverlapTm':>10}{'Notes':>30}\n")
            f.write(header)
            f.write("-" * 120 + "\n")

            for entry in primer_data:
                primer_base = "_".join(entry['all_covered_mutations'])
                notes = ",".join(entry['all_covered_mutations'])

                # Forward primer
                f.write(f"{primer_base + '_for':<20}")
                f.write(f"{entry['forward_primer']:<60}")
                f.write(f"{entry['forward_length']:>6}")
                f.write(f"{entry['forward_tm']:>9.1f}")
                f.write(f"{entry['overlap_length']:>11}")
                f.write(f"{entry['overlap_tm']:>10.1f}")
                f.write(f"{notes:>30}\n")

                # Reverse primer
                f.write(f"{primer_base + '_rev':<20}")
                f.write(f"{entry['reverse_primer']:<60}")
                f.write(f"{entry['reverse_length']:>6}")
                f.write(f"{entry['reverse_tm']:>9.1f}")
                f.write(f"{entry['overlap_length']:>11}")
                f.write(f"{entry['overlap_tm']:>10.1f}")
                f.write(f"{notes:>30}\n")
                f.write("-" * 120 + "\n")

            f.write(f"\nSummary:\nTotal primer pairs: {len(primer_data)}\n")

        # Save CSV format
        filepath_csv = os.path.join(output_dir, "primer_list.csv")
        with open(filepath_csv, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([
                "Primer Name", "Primer Sequence", "Length", "Tm (°C)",
                "Overlap Length", "Overlap Tm", "Mutations"
            ])

            for entry in primer_data:
                all_muts = ",".join(entry['all_covered_mutations'])
                base_name = "_".join(entry['all_covered_mutations'])
                writer.writerow([
                    f"{base_name}_for", entry['forward_primer'], entry['forward_length'],
                    entry['forward_tm'], entry['overlap_length'], entry['overlap_tm'], all_muts
                ])
                writer.writerow([
                    f"{base_name}_rev", entry['reverse_primer'], entry['reverse_length'],
                    entry['reverse_tm'], entry['overlap_length'], entry['overlap_tm'], all_muts
                ])

        print(f"Primer files saved:\n  - {filepath_txt}\n  - {filepath_csv}")

    def save_primer_data_json(self, primer_data: List[Dict], output_dir: str = None, filename: str = "primer_list.json") -> None:
        """Save primer data as JSON."""
        if output_dir is None:
            output_dir = os.path.expanduser("~/Mutagenesis")

        os.makedirs(output_dir, exist_ok=True)
        filepath_json = os.path.join(output_dir, filename)

        with open(filepath_json, 'w') as f:
            json.dump(primer_data, f, indent=2)

        print(f"Primer JSON saved to {filepath_json}")

## Mutagenesis Protocol

In [18]:
import networkx as nx
import re
import json
from pathlib import Path
from collections import Counter, defaultdict
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
from reportlab.pdfgen import canvas
from reportlab.lib.units import mm

class MutagenesisProtocol:
    """Plan stepwise mutagenesis for a list of target variants and write the outputs.

    ``run()`` reads ``primer_list.json`` from the output directory, derives for
    every target variant a chain of intermediate variants (starting from the
    largest already-known variant whose mutations are a subset of the target),
    writes a protocol PDF and a label PDF, and updates
    ``protocols/variant_databank.json``.
    """

    def __init__(
        self,
        variant_input,
        max_mutations_per_step,
        variant_prefix,
        output_dir_path,
        start_col=2,
        start_row=16,
        list_existing_as_steps=False
    ):
        """Store the configuration and create the ``protocols`` output folder.

        Args:
            variant_input: List of target variants, each a list of mutation
                strings such as ``'S19V'``.
            max_mutations_per_step: Maximum number of primer groups combined
                into one mutagenesis step (see ``run``).
            variant_prefix: Prefix of the IDs assigned in the databank.
            output_dir_path: Directory holding ``primer_list.json``; the
                ``protocols`` sub-folder is created inside it.
            start_col: Zero-based column of the first label on the first page.
            start_row: Zero-based row of the first label on the first page.
            list_existing_as_steps: When true, the stored databank is ignored
                and planning starts from the wildtype only.
        """
        self.variant_input = variant_input
        self.max_mutations_per_step = max_mutations_per_step
        self.variant_prefix = variant_prefix
        self.start_col = start_col
        self.start_row = start_row
        self.mutagenesis_dir = Path(output_dir_path)
        self.primer_json_path = self.mutagenesis_dir / "primer_list.json"
        self.output_dir = self.mutagenesis_dir / "protocols"
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.pdf_path = self.output_dir / "mutagenesis_protocol.pdf"
        self.databank_file = self.output_dir / "variant_databank.json"
        self.list_existing_as_steps = list_existing_as_steps

    @staticmethod
    def mutation_pos(mutation):
        """Return the first integer found in ``mutation`` (19 for 'S19V'), or 0 if none."""
        match = re.search(r'\d+', mutation)
        return int(match.group()) if match else 0

    @staticmethod
    def variant_key(mutations):
        """Return the mutations joined by ',' after sorting them by position."""
        return ','.join(sorted(mutations, key=MutagenesisProtocol.mutation_pos))

    @staticmethod
    def extract_variant_number(name):
        """Return N from a name containing 'variant N', or -1 when absent."""
        match = re.search(r'variant (\d+)', name)
        return int(match.group(1)) if match else -1

    @staticmethod
    def normalize_mut(m):
        """Return the mutation string stripped of surrounding whitespace and upper-cased."""
        return m.strip().upper()

    @staticmethod
    def _grid_table_style(*extra_commands):
        """Return the table style shared by all PDF tables, plus optional extra commands."""
        return TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            *extra_commands,
        ])

    def get_base_variant(self, current, existing_variants):
        """Return the largest key of ``existing_variants`` that is a subset of ``current``.

        Args:
            current: Iterable of the target's mutation strings.
            existing_variants: Mapping of mutation tuples to variant names.

        Returns:
            The mutation tuple with the most entries whose mutations are all in
            ``current``; ``()`` when nothing else qualifies. Ties keep the first
            candidate in iteration order.
        """
        best = ()
        for var in existing_variants:
            if set(var).issubset(current) and len(var) > len(best):
                best = var
        return best

    def design_primer_pathways(self, primer_list, target_variants):
        """Choose, for every target variant, an ordered list of primer sets.

        Only primers whose ``all_covered_mutations`` are a subset of the target
        are considered. In each round the unused primer covering the most
        still-missing mutations is selected (ties: more covered mutations
        first); its ``depends_on`` primers are added before it.

        Args:
            primer_list: Primer records loaded from ``primer_list.json``; each
                needs ``set_id``, ``mutations`` and ``all_covered_mutations``
                and may list ``depends_on`` set IDs.
            target_variants: List of mutation lists, e.g.
                ``[["S19V"], ["P423F", "Q424V"]]``.

        Returns:
            One list of step dicts per target variant.

        Raises:
            RuntimeError: If the remaining mutations cannot be covered by any
                unused primer that introduces only target mutations.
        """
        all_workflows = []
        # The ID lookup does not depend on the target, so build it once.
        primer_map = {p['set_id']: p for p in primer_list}

        for variant_idx, target_mutations in enumerate(target_variants):
            print(f"\nProcessing variant {variant_idx + 1}: {', '.join(target_mutations)}")

            steps = []
            target_muts = set(self.normalize_mut(m) for m in target_mutations)
            used_primers = []
            remaining_muts = target_muts.copy()
            in_progress = set()  # primers whose dependencies are being resolved

            def add_primer_with_deps(primer_id):
                # A primer already scheduled, or currently being resolved
                # (cyclic 'depends_on'), must not be visited again.
                if primer_id in used_primers or primer_id in in_progress:
                    return
                in_progress.add(primer_id)
                for dep in primer_map[primer_id].get('depends_on', []):
                    add_primer_with_deps(dep)
                in_progress.discard(primer_id)

                used_primers.append(primer_id)
                primer_info = primer_map[primer_id]
                step_muts = [self.normalize_mut(m) for m in primer_info['all_covered_mutations']]
                mutations_added = [self.normalize_mut(m) for m in primer_info['mutations']]
                steps.append({
                    'primer_id': primer_id,
                    'primer_info': primer_info,
                    'all_covered_mutations': step_muts,
                    'mutations_added': mutations_added
                })
                remaining_muts.difference_update(mutations_added)

            while remaining_muts:
                # Find primers that only introduce mutations we want (subset of target)
                # and cover the most remaining mutations
                candidates = []
                for p in primer_list:
                    # Re-selecting a used primer cannot shrink remaining_muts
                    # and would make this loop spin forever.
                    if p['set_id'] in used_primers:
                        continue

                    primer_all_muts = set(self.normalize_mut(m) for m in p['all_covered_mutations'])

                    # CRITICAL: Only consider primers whose all_covered_mutations are subset of target
                    if not primer_all_muts.issubset(target_muts):
                        continue

                    new_coverage = primer_all_muts & remaining_muts
                    if new_coverage:
                        candidates.append((p['set_id'], len(new_coverage), primer_all_muts, len(primer_all_muts)))

                if not candidates:
                    print(f"❌ Cannot cover remaining mutations: {remaining_muts}")
                    print("❌ No primers available that only introduce target mutations")
                    raise RuntimeError("Cannot cover all mutations with available primers that don't introduce unwanted mutations.")

                # Sort by: coverage of remaining mutations (desc), then total mutations (desc)
                candidates.sort(key=lambda x: (x[1], x[3]), reverse=True)
                best_primer_id = candidates[0][0]

                print(f"  Selected primer {best_primer_id} (covers {candidates[0][1]} remaining mutations)")
                add_primer_with_deps(best_primer_id)

            # Print the workflow
            print(f"\nWorkflow for variant {variant_idx + 1}:")
            print("Step  |  Primer Set  |  Mutations Added  |  All Covered")
            print("-" * 70)
            for step_num, step in enumerate(steps, 1):
                added_str = ', '.join(step['mutations_added'])
                covered_str = ', '.join(step['all_covered_mutations'])
                print(f"{step_num:2d}    |  {step['primer_id']:15s}  |  {added_str:15s}  |  {covered_str}")

            all_workflows.append(steps)

        return all_workflows

    def generate_pdf(
        self,
        protocol_by_round,
        final_variants,
        filename,
        existing_input_variants,
        variant_to_label_map,
        variant_to_final_variant,
        used_variants
    ):
        """Write the protocol PDF.

        Args:
            protocol_by_round: Mapping of round key to rows
                ``[new_name, parent_name, 'mut, mut, ...']``.
            final_variants: Accepted for interface compatibility; not used here.
            filename: Path of the PDF to create.
            existing_input_variants: Mapping of input label to the name of the
                variant that already exists; listed on a page of its own.
            variant_to_label_map: Mapping of input label to its target mutations.
            variant_to_final_variant: Mapping of input label to the name of the
                final variant.
            used_variants: Mapping of mutation tuples to variant names, listed
                in the closing "All Produced Variants" table.
        """
        doc = SimpleDocTemplate(filename, pagesize=A4)
        elements = []
        styles = getSampleStyleSheet()
        cell_style = ParagraphStyle('mut_cell', fontSize=9, leading=11)

        if existing_input_variants:
            elements.append(Paragraph("<b>Pre-existing Variants in Databank</b>", styles['Heading2']))
            elements.append(Spacer(1, 12))
            data = [['Input Label', 'Existing Variant Label']]
            for input_label, existing_label in sorted(existing_input_variants.items()):
                data.append([input_label, existing_label or '(unknown)'])
            table = Table(data, colWidths=[150, 250])
            table.setStyle(self._grid_table_style())
            elements.append(table)
            elements.append(PageBreak())

        if not protocol_by_round:
            # Nothing to plan: explain why and finish the document early.
            elements.append(Paragraph("<b>No new variants generated.</b>", styles['Heading2']))
            elements.append(Spacer(1, 12))
            if not existing_input_variants:
                elements.append(Paragraph("No variants were generated because no input variants were provided.", styles['Normal']))
            else:
                elements.append(Paragraph("All input variants already exist in the databank.", styles['Normal']))
            doc.build(elements)
            return

        elements.append(Paragraph("<b>Final Variants Overview</b>", styles['Heading2']))
        final_table_data = [['Input Label', 'Final Variant', 'Mutations']]
        for label in sorted(variant_to_label_map, key=self.extract_variant_number):
            sorted_mutations = sorted(variant_to_label_map[label], key=self.mutation_pos)
            mut_paragraph = Paragraph(', '.join(sorted_mutations), cell_style)
            final_table_data.append([label, variant_to_final_variant[label], mut_paragraph])
        table = Table(final_table_data, colWidths=[100, 100, 260], repeatRows=1)
        table.setStyle(self._grid_table_style())
        elements.append(table)
        elements.append(Spacer(1, 12))

        elements.append(Paragraph("<b>Step-by-Step Mutagenesis Protocol (Grouped by Round)</b>", styles['Heading2']))
        for i, round_num in enumerate(sorted(protocol_by_round), start=1):
            # Rows are ordered by the variant number found in the parent name.
            steps = sorted(protocol_by_round[round_num], key=lambda r: self.extract_variant_number(r[1]))
            elements.append(Paragraph(f"<b>Step {i}</b>", styles['Heading3']))
            table_data = [['New Variant', 'Parent Variant', 'Mutations Added']]
            for row in steps:
                sorted_muts = ', '.join(sorted(row[2].split(', '), key=self.mutation_pos))
                table_data.append([row[0], row[1], Paragraph(sorted_muts, cell_style)])
            step_table = Table(table_data, colWidths=[100, 120, 160], repeatRows=1)
            step_table.setStyle(self._grid_table_style())
            elements.append(step_table)
            elements.append(Spacer(1, 10))

        elements.append(PageBreak())
        elements.append(Paragraph("<b>All Produced Variants</b>", styles['Heading2']))
        all_variants_data = [['Variant Name', 'Mutations']]
        for muts, name in used_variants.items():
            mut_text = ', '.join(muts) if muts else '(none)'
            mut_paragraph = Paragraph(mut_text, cell_style)
            all_variants_data.append([name, mut_paragraph])
        table = Table(all_variants_data, colWidths=[140, 320], repeatRows=1)
        table.setStyle(self._grid_table_style(
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('LEFTPADDING', (0, 0), (-1, -1), 5),
            ('RIGHTPADDING', (0, 0), (-1, -1), 5),
        ))
        elements.append(table)

        doc.build(elements)

    def generate_label_pdf(self, variants, start_col, start_row, base_dir):
        """Write a label sheet PDF with the databank IDs relevant to ``variants``.

        A databank entry is printed when its mutation set is a subset of at
        least one input variant. Labels are laid out on a 7 x 27 grid; on the
        first page every position before (``start_row``, ``start_col``) is
        left empty, later pages start at the top-left position.

        Args:
            variants: Iterable of mutation lists (the input variants).
            start_col: Zero-based column of the first label on page one.
            start_row: Zero-based row of the first label on page one.
            base_dir: Directory containing ``protocols/variant_databank.json``;
                the PDF is written to ``protocols/herma_10900_labels.pdf``.
        """
        base_dir = Path(base_dir)
        variant_json_file = base_dir / "protocols" / "variant_databank.json"
        pdf_output_file = base_dir / "protocols" / "herma_10900_labels.pdf"

        # Sheet geometry; presumably that of the label sheet named in the
        # output file name - verify against the physical sheet.
        _, page_height = A4
        label_width = 25.3 * mm
        label_height = 9.9 * mm
        horizontal_spacing = 2.5 * mm
        top_margin = 13.5 * mm
        left_margin = 10.3 * mm
        columns = 7
        rows = 27
        font_size = 5.5

        with open(variant_json_file, "r") as f:
            variant_data = json.load(f)

        def mutations_str_to_set(mutation_str):
            if mutation_str == '(none)':
                return set()
            return set(mutation_str.split(','))

        input_mutation_sets = [set(v) for v in variants]
        selected_variants = [
            label
            for label, mutation_str in variant_data.items()
            if any(mutations_str_to_set(mutation_str).issubset(inp_set) for inp_set in input_mutation_sets)
        ]

        def variant_sort_key(label):
            # Order by the trailing number of the ID; IDs without one sort first.
            match = re.search(r'\D+(\d+)$', label)
            return int(match.group(1)) if match else 0

        selected_variants = sorted(selected_variants, key=variant_sort_key)
        c = canvas.Canvas(str(pdf_output_file), pagesize=A4)
        c.setFont("Helvetica", font_size)

        label_index = 0
        total_labels = len(selected_variants)
        page_number = 1

        while label_index < total_labels:
            for row in range(rows):
                for col in range(columns):
                    # Only the first page honours the requested start position.
                    if page_number == 1:
                        if row < start_row or (row == start_row and col < start_col):
                            continue
                    if label_index >= total_labels:
                        break

                    x = left_margin + col * (label_width + horizontal_spacing)
                    y = page_height - top_margin - (row + 1) * label_height

                    label = selected_variants[label_index]
                    text_x = x + label_width / 2
                    text_y = y + label_height / 2 - font_size / 2

                    c.drawCentredString(text_x, text_y, label)
                    label_index += 1

                if label_index >= total_labels:
                    break

            if label_index < total_labels:
                c.showPage()
                # The font has to be set again after starting a new page.
                c.setFont("Helvetica", font_size)
                page_number += 1

        c.save()
        print(f"✅ PDF with labels generated: {pdf_output_file}")

    def _read_databank(self, warn_on_invalid=False):
        """Return the databank stored on disk; ``{}`` when missing, empty or invalid JSON."""
        if not self.databank_file.exists():
            return {}
        try:
            with open(self.databank_file, "r") as f:
                content = f.read().strip()
            return json.loads(content) if content else {}
        except json.JSONDecodeError:
            if warn_on_invalid:
                print("⚠️ Warning: Databank file invalid. Starting fresh.")
            return {}

    def _push_undo_snapshot(self):
        """Append the current on-disk databank to the module-level ``undo_stack``, if one exists."""
        # The notebook UI defines `undo_stack` as a module-level list. Looking
        # it up at call time keeps this class usable when that cell has not
        # been executed (previously a NameError).
        stack = globals().get("undo_stack")
        if isinstance(stack, list):
            # A freshly parsed dict shares nothing with the databank that is
            # modified below, so no extra copy is required.
            stack.append(self._read_databank())

    def _select_primer_groups(self, target, current_mutations, primer_groups, mutation_to_groups):
        """Return the primer group IDs needed to reach ``target`` from ``current_mutations``.

        For each target mutation the largest group is chosen among those that
        contain it, introduce only target mutations and add at least one
        mutation missing from ``current_mutations``. IDs keep the order in
        which the target mutations select them, without duplicates.
        """
        target_norm = {self.normalize_mut(m) for m in target}
        current_norm = {self.normalize_mut(m) for m in current_mutations}

        selected = []
        for mut in target:
            usable = [
                gid for gid in mutation_to_groups.get(self.normalize_mut(mut), [])
                if target_norm.issuperset(primer_groups[gid])
                and not current_norm.issuperset(primer_groups[gid])
            ]
            if not usable:
                continue
            # Every usable group lies fully inside the target, so "most target
            # mutations covered" is simply the group size.
            best_gid = max(usable, key=lambda gid: len(primer_groups[gid]))
            if best_gid not in selected:
                selected.append(best_gid)
        return selected

    def _assign_databank_ids(self, variant_databank, existing_variants):
        """Add every variant not yet stored to ``variant_databank`` under a new ID (in place)."""
        prefix = self.variant_prefix
        used_numbers = []
        for v_id in variant_databank:
            if not v_id.startswith(prefix):
                continue
            # Parse the whole numeric suffix: looking at the last two
            # characters only would read 'X100' as 0 and hand out an ID that
            # is already taken.
            suffix = v_id[len(prefix):]
            if suffix.isdecimal():
                used_numbers.append(int(suffix))
        next_id_num = max(used_numbers, default=0) + 1

        stored = set(variant_databank.values())
        for muts in existing_variants:
            mutation_str = ','.join(sorted(muts, key=self.mutation_pos))
            if mutation_str in stored:
                continue
            variant_databank[f"{prefix}{next_id_num:02d}"] = mutation_str
            stored.add(mutation_str)
            next_id_num += 1

    def run(self):
        """Plan all input variants, then write the protocol PDF, databank and label PDF.

        Raises:
            FileNotFoundError: If ``primer_list.json`` is missing from the
                output directory.
        """
        # Check primer JSON exists
        if not self.primer_json_path.exists():
            raise FileNotFoundError(f"Primer JSON not found: {self.primer_json_path}")

        with open(self.primer_json_path, "r") as f:
            primer_data = json.load(f)

        # primer_groups: group index -> position-sorted mutations of that primer set
        # mutation_to_groups: mutation -> indices of the groups introducing it
        mutation_to_groups = {}
        primer_groups = {}
        for idx, primer_set in enumerate(primer_data):
            group_muts = [self.normalize_mut(m) for m in primer_set["mutations"]]
            primer_groups[idx] = sorted(group_muts, key=self.mutation_pos)
            for mut in group_muts:
                mutation_to_groups.setdefault(mut, []).append(idx)

        # SKIP DATABANK mode starts fresh: only the wildtype is "existing".
        if self.list_existing_as_steps:
            variant_databank = {}
        else:
            variant_databank = self._read_databank(warn_on_invalid=True)

        existing_variants = {(): 'wildtype'}
        for key, mutation_str in variant_databank.items():
            muts = tuple(mutation_str.split(',')) if mutation_str != '(none)' else ()
            existing_variants[muts] = key

        # Within each variant, mutations shared by many variants come first.
        all_mutations = [mut for v in self.variant_input for mut in v]
        mutation_freq = Counter(all_mutations)
        variants = [sorted(v, key=lambda x: -mutation_freq[x]) for v in self.variant_input]

        final_variants = defaultdict(list)
        variant_to_label_map = {}
        variant_to_final_variant = {}
        protocol_by_round = defaultdict(list)
        used_variants = {}
        existing_input_variants = {}

        # A non-positive step size would never consume a group and loop forever.
        step_size = max(1, self.max_mutations_per_step)

        for idx, target_mutations in enumerate(variants):
            target = tuple(target_mutations)
            variant_key_str = self.variant_key(target)
            variant_label = f"variant {idx+1}"

            # The databank is empty in skip mode, so this only triggers when
            # the stored databank is in use.
            if variant_key_str in variant_databank.values():
                # Just note it in the pre-existing list and skip real planning
                for muts, name in existing_variants.items():
                    if variant_key_str == self.variant_key(muts):
                        existing_input_variants[variant_label] = name
                        break
                continue

            base_variant = self.get_base_variant(target, existing_variants)
            current_mutations = list(base_variant)
            base_name = existing_variants[base_variant]

            pending_group_ids = self._select_primer_groups(
                target, current_mutations, primer_groups, mutation_to_groups
            )

            while pending_group_ids:
                batch_ids = pending_group_ids[:step_size]
                del pending_group_ids[:step_size]
                batch_muts = [m for gid in batch_ids for m in primer_groups[gid]]

                test_mutations = sorted(
                    set(map(self.normalize_mut, current_mutations)) | set(batch_muts),
                    key=self.mutation_pos
                )
                test_sorted = tuple(test_mutations)

                if test_sorted in existing_variants:
                    # The intermediate already exists: continue from it rather
                    # than from the previous parent, so later steps do not
                    # lose this batch's mutations.
                    base_name = existing_variants[test_sorted]
                else:
                    mutation_count = len(test_sorted)
                    new_name = f"{variant_label}.{mutation_count}"
                    existing_variants[test_sorted] = new_name
                    used_variants[test_sorted] = new_name
                    protocol_by_round[mutation_count].append([new_name, base_name, ', '.join(batch_muts)])
                    base_name = new_name
                current_mutations = test_mutations

            final_variants[base_name].append(target)
            variant_to_label_map[variant_label] = target
            variant_to_final_variant[variant_label] = base_name
            used_variants[tuple(sorted(map(self.normalize_mut, target), key=self.mutation_pos))] = base_name

        self.generate_pdf(
            protocol_by_round, final_variants, str(self.pdf_path),
            existing_input_variants, variant_to_label_map,
            variant_to_final_variant, used_variants
        )

        self._assign_databank_ids(variant_databank, existing_variants)

        # Keep the previous file contents for the notebook's undo button
        # before they are replaced.
        self._push_undo_snapshot()

        # The databank is written in skip mode as well: the label PDF below is
        # generated from this file.
        # NOTE(review): in skip mode this replaces the stored databank with
        # the entries planned in this run only - confirm that is intended.
        with open(self.databank_file, "w") as f:
            json.dump(variant_databank, f, indent=2)

        print(f"✅ PDF generated: {self.pdf_path}")
        print(f"📁 Databank updated: {self.databank_file}")

        self.generate_label_pdf(self.variant_input, self.start_col, self.start_row, self.mutagenesis_dir)

## User Interface

If the code cell below is collapsed, expand it and fill out the form it displays.

Then click the **Run Complete Workflow** button to start the script.

### Create widget

In [19]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path
import re
import copy

# === 1. INTERACTIVE FORM WIDGETS ===
# Wildtype DNA sequence; the run handler strips it, removes newlines and upper-cases it.
# NOTE(review): the instruction text is set as the initial *value* (not only as the
# placeholder), so it is submitted as the sequence unless the user replaces it.
sequence_input = widgets.Textarea(
    value='Enter wildtype DNA sequence here...',
    placeholder='Paste your DNA sequence...',
    description='Sequence:',
    layout=widgets.Layout(width='80%', height='100px')
)

# One variant per line, mutations separated by commas (pre-filled with an example).
variant_input_box = widgets.Textarea(
    value='S19V, S30V, Q424V, A431E\nK3Q, K36R, D76Q, L167A, Q424V\n',
    placeholder='Enter mutations per variant, one variant per line',
    description='Variants:',
    layout=widgets.Layout(width='80%', height='100px')
)

# Passed to MutagenesisProtocol as max_mutations_per_step.
max_mut_dropdown = widgets.Dropdown(
    options=[1, 2],
    value=1,
    description='Max muts/step:'
)

# Prefix of the IDs assigned to new databank entries.
variant_prefix_input = widgets.Text(
    value='KWE_TA_A',
    description='Var prefix:'
)

# Zero-based position of the first label on the first sheet; the upper bounds
# match the 7-column x 27-row grid used by generate_label_pdf.
start_col_input = widgets.BoundedIntText(
    value=0,
    min=0, max=6, step=1,
    description='Start col:'
)

start_row_input = widgets.BoundedIntText(
    value=0,
    min=0, max=26, step=1,
    description='Start row:'
)

# Forwarded to MutagenesisProtocol as list_existing_as_steps.
list_existing_checkbox = widgets.Checkbox(
    value=False,
    description='Skip variant databank',
    disabled=False
)

run_button = widgets.Button(description='Run Complete Workflow', button_style='success')
# Output area that receives everything printed by the run handler.
output = widgets.Output()

# Undo stack to store previous states of variant_databank.json
# (MutagenesisProtocol.run() appends the file's previous contents before rewriting it).
undo_stack = []

undo_button = widgets.Button(description='Undo Last Change', button_style='warning')
undo_output = widgets.Output()

def on_undo_clicked(b):
    """Write the most recent snapshot from ``undo_stack`` back to the databank file.

    Args:
        b: The clicked button (supplied by ipywidgets; unused).
    """
    with undo_output:
        clear_output()
        # Guard clause: nothing recorded, nothing to restore.
        if not undo_stack:
            print("No undo steps available.")
            return

        previous_state = undo_stack.pop()
        databank_file = Path("/content/drive/My Drive/Mutagenesis/protocols/variant_databank.json")
        with open(databank_file, 'w') as handle:
            json.dump(previous_state, handle, indent=2)
        print("✅ Undid last change, variant databank restored.")

undo_button.on_click(on_undo_clicked)

# === 2. BUTTON FUNCTION ===
def on_run_button_clicked(b):
    """Run primer design followed by protocol generation with the form's values.

    Reads the widgets, validates the sequence and the mutations, designs the
    primers for every variant (identical forward/reverse pairs are kept once),
    saves the primer lists and finally runs ``MutagenesisProtocol``. All
    printed output is captured by the ``output`` widget.

    Args:
        b: The clicked button (supplied by ipywidgets; unused).
    """
    with output:
        clear_output()

        # Parse inputs. Removing *all* whitespace also copes with '\r\n' line
        # endings, tabs and spaces inside a pasted sequence.
        seq = ''.join(sequence_input.value.split()).upper()

        # One variant per line; empty tokens (trailing comma, blank or
        # whitespace-only line) are dropped instead of becoming '' mutations.
        variant_list = []
        for line in variant_input_box.value.splitlines():
            muts = [token.strip() for token in line.split(',') if token.strip()]
            if muts:
                variant_list.append(muts)

        max_mutations = max_mut_dropdown.value
        vp = variant_prefix_input.value.strip()
        start_col = start_col_input.value
        start_row = start_row_input.value

        print("=== RUNNING MUTAGENESIS WORKFLOW ===")
        print(f"Sequence length: {len(seq)}")
        print(f"Variants: {variant_list}")
        print(f"Max mutations/step: {max_mutations}")
        print(f"Variant prefix: {vp}")
        print(f"Label START_COL: {start_col}, START_ROW: {start_row}\n")

        # ---- Primer design ----
        # A single generator and a single validity check: the same check was
        # previously run twice on two separately constructed generators.
        pg = PrimerGenerator(flank_size_bases=30)
        if not pg.check_sequence_validity(seq, flank_size=30):
            print("❌ Sequence failed validation.")
            return
        try:
            pg.validate_mutations_against_sequence(seq, variant_list)
        except ValueError as e:
            print(str(e))
            return

        all_primer_data = []
        existing_pairs = set()  # (forward, reverse) primer pairs already collected
        for muts in variant_list:
            tuples = []
            for mut_str in muts:
                m = re.match(r'([A-Z])(\d+)([A-Z])', mut_str.strip())
                if m:
                    tuples.append((m.group(1), int(m.group(2)), m.group(3)))
            for primer_set in pg.generate_primers_for_construct(seq, tuples):
                key = (primer_set['forward_primer'], primer_set['reverse_primer'])
                if key not in existing_pairs:
                    all_primer_data.append(primer_set)
                    existing_pairs.add(key)

        output_dir = Path("/content/drive/My Drive/Mutagenesis")
        pg.save_primer_list(all_primer_data, output_dir=output_dir)
        pg.save_primer_data_json(all_primer_data, output_dir=output_dir)
        print("✅ Primer design complete.\n")

        # ---- Protocol generation ----
        protocol = MutagenesisProtocol(
            variant_list,
            max_mutations,
            vp,
            output_dir,
            start_col,
            start_row,
            list_existing_as_steps=list_existing_checkbox.value
        )
        protocol.run()

        print("\n🎉 Workflow finished! Check your Google Drive for outputs.")

### Load widgets

In [20]:
# Register the workflow handler on the run button
run_button.on_click(on_run_button_clicked)

# Render the form top to bottom: inputs, run button with its output area,
# then the undo button with its output area
display(
    sequence_input,
    variant_input_box,
    max_mut_dropdown,
    variant_prefix_input,
    start_col_input,
    start_row_input,
    list_existing_checkbox,
    run_button,
    output,
    undo_button,
    undo_output,
)

Textarea(value='Enter wildtype DNA sequence here...', description='Sequence:', layout=Layout(height='100px', w…

Textarea(value='S19V, S30V, Q424V, A431E\nK3Q, K36R, D76Q, L167A, Q424V\n', description='Variants:', layout=La…

Dropdown(description='Max muts/step:', options=(1, 2), value=1)

Text(value='KWE_TA_A', description='Var prefix:')

BoundedIntText(value=0, description='Start col:', max=6)

BoundedIntText(value=0, description='Start row:', max=26)

Checkbox(value=False, description='Skip variant databank')

Button(button_style='success', description='Run Complete Workflow', style=ButtonStyle())

Output()



Output()

## Load variant_databank.json to delete entries

### Create widgets

In [23]:
import json
from pathlib import Path
import ipywidgets as widgets
from ipywidgets import Layout, HBox, Label, Checkbox

# Folder and file of the variant databank JSON on the mounted Drive
output_dir = Path("/content/drive/My Drive/Mutagenesis/protocols")
variant_databank_path = output_dir.joinpath("variant_databank.json")

def load_variant_databank():
    """Return the databank stored at ``variant_databank_path``.

    Returns:
        The parsed JSON content, or ``{}`` when the file is missing, empty or
        not valid JSON. This mirrors how ``MutagenesisProtocol.run`` treats
        the same file; an empty file used to raise ``JSONDecodeError`` here.
    """
    if not variant_databank_path.exists():
        return {}
    with open(variant_databank_path, "r") as f:
        content = f.read().strip()
    if not content:
        return {}
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        print("⚠️ Warning: Databank file invalid. Starting fresh.")
        return {}

def save_variant_databank(variant_dict):
    """Overwrite the databank file with ``variant_dict`` as indented JSON."""
    with variant_databank_path.open("w") as handle:
        json.dump(variant_dict, handle, indent=2)

# Load the databank once, when this cell is executed; the checkbox list below
# is built from this in-memory snapshot.
variant_databank = load_variant_databank()

def create_variant_checkboxes():
    """Build one ``(checkbox, row)`` pair per entry of ``variant_databank``.

    Returns:
        A list of tuples in databank order: the bare checkbox and the ``HBox``
        that shows it next to a "<id>: <mutations>" label.
    """
    def build_row(var_id, muts):
        # '(none)' is shown as readable text.
        shown_muts = 'no mutations' if muts == '(none)' else muts
        tick = Checkbox(
            value=False,
            indent=False,
            layout=Layout(width='30px', margin='0 5px 0 0')
        )
        caption = Label(
            value=f"{var_id}: {shown_muts}",
            layout=Layout(margin='0', width='auto', min_width='300px')
        )
        line = HBox(
            [tick, caption],
            layout=Layout(align_items='center', justify_content='flex-start', gap='6px', margin='0')
        )
        return tick, line

    return [build_row(var_id, muts) for var_id, muts in variant_databank.items()]

# Build the deletion UI a single time; the delete handler later swaps the
# contents of these containers in place.
variant_widgets = create_variant_checkboxes()
checkboxes = [pair[0] for pair in variant_widgets]
checkbox_box = widgets.VBox([pair[1] for pair in variant_widgets])
delete_button = widgets.Button(description="Delete Selected Variants", button_style="danger")
# NOTE(review): this rebinds the global name `output`, which the run-workflow
# handler also uses for its Output widget - after this cell has run, that
# handler prints into this widget instead.
output = widgets.Output()

def on_delete_clicked(b):
    """Delete the ticked variants from the databank file and refresh the list.

    The file is re-read before deleting, so entries written since this cell
    was executed (for example by the workflow) are preserved instead of being
    overwritten by the older in-memory snapshot. An entry is only removed
    when the file still holds what its checkbox displayed.

    Args:
        b: The clicked button (supplied by ipywidgets; unused).
    """
    with output:
        output.clear_output()

        # Checkbox i was built from the i-th entry of the in-memory snapshot.
        to_delete = [
            var_id
            for cb, var_id in zip(checkboxes, list(variant_databank))
            if cb.value
        ]
        if not to_delete:
            print("No variants selected for deletion.")
            return

        current = load_variant_databank()
        removed, skipped = [], []
        for var_id in to_delete:
            if var_id in current and current[var_id] == variant_databank.get(var_id):
                del current[var_id]
                removed.append(var_id)
            else:
                skipped.append(var_id)

        if removed:
            save_variant_databank(current)
            print(f"Deleted variants: {', '.join(removed)}")
        if skipped:
            print(f"Skipped (changed or missing in the databank file): {', '.join(skipped)}")

        # Keep the same dict object (create_variant_checkboxes reads it) but
        # bring it in line with the file.
        variant_databank.clear()
        variant_databank.update(current)

        # Refresh checkboxes
        new_variant_widgets = create_variant_checkboxes()
        checkbox_box.children = [row for _, row in new_variant_widgets]
        checkboxes[:] = [cb for cb, _ in new_variant_widgets]

### Load widgets

In [24]:
# Assign event handler
delete_button.on_click(on_delete_clicked)

# Import the function itself. `import IPython.display as display` would rebind
# the name `display` to the module, so `display(...)` calls in other cells
# fail once this cell has run.
from IPython.display import display

# Display widgets only - no code here except display calls
display(widgets.Label("Select variants to delete from the databank:"))
display(checkbox_box)
display(delete_button)
display(output)

Label(value='Select variants to delete from the databank:')

VBox(children=(HBox(children=(Checkbox(value=False, indent=False, layout=Layout(margin='0 5px 0 0', width='30p…

Button(button_style='danger', description='Delete Selected Variants', style=ButtonStyle())

Output()