# Mutagenesis Protocol
## Created using ChatGPT

### The core of the script is designed to generate an efficient, step-by-step mutagenesis protocol that builds enzyme variants using minimal and non-redundant mutagenesis steps. A databank will be created to store all variants that you produce. Labels used in the databank can be printed out.
The script essentially builds a directed acyclic graph (DAG) of variant derivations, and for each target variant, it:

   - Starts from the deepest reachable node (max existing subset).

   - Extends it minimally by adding only the mutations that are missing.

   - Never rebuilds any variant it already made.

So the number of total mutation steps across all variants is minimized, and common intermediate variants are shared between targets.

## Connect to your Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install reportlab
!pip install primer3-py

Collecting reportlab
  Downloading reportlab-4.4.3-py3-none-any.whl.metadata (1.7 kB)
Downloading reportlab-4.4.3-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: reportlab
Successfully installed reportlab-4.4.3
Collecting primer3-py
  Downloading primer3_py-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Downloading primer3_py-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m37.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: primer3-py
Successfully installed primer3-py-2.2.0


# In the first step:
 - Paste the wildtype sequence (including 30 bases at flanking sites)
    - you will be notified when the 30 bases are not added
 - Define the variants
    - mutations have to be provided as string like 'S19V'
    - multiple mutations have to be separated by comma
    - each construct needs to be put into square brackets: ['S19V', 'S30V', 'Q424V', 'A431E'],
    - constructs have to be separated by comma
 - Give the number of mutations that you want to add simultaneously (1 to 2 mutations are possible in this script)
 - Provide a prefix for your variants, this will be used for the databank
 - If you want to print labels (189 cell sheet by Hermes) you can define the starting column (0-6) and row (0-26)

## Primer Design

In [17]:
import os
import re
import csv
import json
from typing import List, Tuple, Dict
from pathlib import Path
from collections import defaultdict, Counter

import primer3
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.platypus import Table, TableStyle, SimpleDocTemplate, Paragraph, Spacer, PageBreak
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import mm
from reportlab.pdfgen import canvas

# === PrimerGenerator class (primer design) ===
class PrimerGenerator:
    def __init__(self, flank_size_bases=30):
        # Standard genetic code
        self.codon_table = {
            'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
            'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
            'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
            'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
            'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
            'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
            'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
            'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
            'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
            'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
            'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
            'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
            'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
            'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
            'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
            'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
        }

        self.stop_codons = {"TAA", "TAG", "TGA"}

        # Reverse codon table for amino acid to codon conversion
        self.aa_to_codons = {}
        for codon, aa in self.codon_table.items():
            if aa not in self.aa_to_codons:
                self.aa_to_codons[aa] = []
            self.aa_to_codons[aa].append(codon)

        # Constants for flanking sizes
        self.flank_size_bases = flank_size_bases
        self.flank_size_codons = flank_size_bases // 3

        # Primer parameters
        self.MIN_PRIMER_LEN = 18
        self.TARGET_TM = 68.0
        self.MIN_TM = 55.0  # minimum acceptable melting temperature

    def check_sequence_validity(self, dna_sequence: str, flank_size: int = 30) -> bool:
        if len(dna_sequence) % 3 != 0:
            print(f"Error: Sequence length {len(dna_sequence)} is not a multiple of 3.")
            return False
        if len(dna_sequence) < (flank_size * 2 + 6):
            print("Error: Sequence too short for flanking regions.")
            return False

        start_codon_pos = dna_sequence[flank_size:flank_size + 3]
        if start_codon_pos != "ATG":
            print(f"Error: No ATG start codon after first {flank_size} bases, found '{start_codon_pos}'.")
            return False

        stop_codon_pos = dna_sequence[-(flank_size + 3):-flank_size]
        if stop_codon_pos not in self.stop_codons:
            print(f"Error: No valid stop codon before last {flank_size} bases, found '{stop_codon_pos}'.")
            return False

        print("Sequence check passed ✓: correct ATG start and stop codon positions.")
        return True

    @staticmethod
    def _parse_mutation(mut_str):
        m = re.match(r'([A-Z])(\d+)([A-Z])', mut_str.strip())
        if not m:
            raise ValueError(f"❌ Invalid mutation format: {mut_str}")
        return m.group(1), int(m.group(2)), m.group(3)

    def validate_mutations_against_sequence(self, dna_sequence: str, variant_list):
        """
        Validate that each mutation matches the wildtype amino acid
        at the specified position (1-based AA index within coding region).
        Accounts for flanking bases on both sides.
        """
        for variant in variant_list:
            for mut in variant:
                original_aa, pos, _ = self._parse_mutation(mut)

                # Adjust amino acid position by flanking codons
                adj_pos_codons = pos + self.flank_size_codons

                # Get codon from DNA (0-based indexing)
                codon_start = (adj_pos_codons - 1) * 3
                codon_seq = dna_sequence[codon_start:codon_start + 3]

                # Translate codon to AA
                wt_aa = self.codon_table.get(codon_seq)
                if wt_aa != original_aa:
                    raise ValueError(
                        f"❌ Mutation {mut} does not match WT residue '{wt_aa}' "
                        f"at coding AA position {pos} (codon {codon_seq}, adjusted index {adj_pos_codons})"
                    )

    def reverse_complement(self, sequence: str) -> str:
        complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G'}
        return ''.join(complement[base] for base in sequence[::-1])

    def calculate_tm(self, sequence: str) -> float:
        if not sequence:
            return 0.0
        try:
            return round(primer3.calc_tm(sequence), 2)
        except Exception as e:
            print(f"Error calculating Tm for sequence '{sequence}': {e}")
            return 0.0

    def get_best_codon(self, amino_acid: str, original_codon: str) -> str:
        if amino_acid in self.aa_to_codons:
            codons = self.aa_to_codons[amino_acid]
            if original_codon in codons:
                return original_codon
            return codons[0]
        return "NNN"

    def merge_close_mutations(self, mutations: List[Tuple[str, int, str]], max_distance_nt: int = 21) -> List[List[Tuple[str, int, str]]]:
        sorted_muts = sorted(mutations, key=lambda x: x[1])
        merged = []
        current_group = [sorted_muts[0]]
        for mut in sorted_muts[1:]:
            last_position = current_group[-1][1]
            if (mut[1] - last_position) * 3 <= max_distance_nt:
                current_group.append(mut)
            else:
                merged.append(current_group)
                current_group = [mut]
        merged.append(current_group)
        return merged

    def generate_primers_for_construct(self, dna_sequence: str, mutations: List[Tuple[str, int, str]]) -> List[Dict]:
        primer_sets = []
        mutation_groups = self.merge_close_mutations(mutations)

        for group in mutation_groups:
            codon_starts = [((m[1] - 1) + self.flank_size_codons) * 3 for m in group]
            codon_ends = [start + 3 for start in codon_starts]
            min_start = max(0, min(codon_starts) - 6)
            max_end = min(len(dna_sequence), max(codon_ends) + 6)

            mutated_seq = list(dna_sequence)
            for m in group:
                aa, pos, new_aa = m
                idx = ((pos - 1) + self.flank_size_codons) * 3
                original_codon = dna_sequence[idx:idx + 3]
                new_codon = self.get_best_codon(new_aa, original_codon)
                if self.codon_table.get(original_codon, 'X') != aa:
                    print(f"Warning: Original amino acid mismatch at position {pos}")
                mutated_seq[idx:idx + 3] = list(new_codon)
            mutated_seq = ''.join(mutated_seq)

            best_overlap_start = min_start
            best_overlap_end = max_end
            best_overlap_tm = float('inf')
            overlap_found = False

            for start_offset in range(-3, 4):
                for length in range(14, 17):
                    test_start = max(0, min_start + start_offset)
                    test_end = min(len(dna_sequence), test_start + length)
                    if test_start <= min(codon_starts) and test_end >= max(codon_ends):
                        test_seq = mutated_seq[test_start:test_end]
                        test_tm = self.calculate_tm(test_seq)
                        if test_tm < 60 and test_tm < best_overlap_tm:
                            best_overlap_start = test_start
                            best_overlap_end = test_end
                            best_overlap_tm = test_tm
                            overlap_found = True

            if not overlap_found:
                best_overlap_start = min_start
                best_overlap_end = min(min_start + 15, len(dna_sequence))
                best_overlap_tm = self.calculate_tm(mutated_seq[best_overlap_start:best_overlap_end])

            overlap_seq = mutated_seq[best_overlap_start:best_overlap_end]

            forward_overlap = overlap_seq
            extension_start = best_overlap_end
            extension_end = min(len(dna_sequence), extension_start + 15)
            forward_extension = mutated_seq[extension_start:extension_end]
            forward_primer = forward_overlap + forward_extension

            while len(forward_primer) < self.MIN_PRIMER_LEN and extension_end < len(dna_sequence):
                extension_end += 1
                forward_extension = mutated_seq[extension_start:extension_end]
                forward_primer = forward_overlap + forward_extension

            forward_tm = self.calculate_tm(forward_primer)

            while forward_tm < self.TARGET_TM and extension_end < min(len(dna_sequence), extension_start + 35):
                extension_end += 1
                forward_extension = mutated_seq[extension_start:extension_end]
                forward_primer = forward_overlap + forward_extension
                forward_tm = self.calculate_tm(forward_primer)

            while len(forward_primer) > 45:
                extension_end -= 1
                forward_extension = mutated_seq[extension_start:extension_end]
                forward_primer = forward_overlap + forward_extension
                forward_tm = self.calculate_tm(forward_primer)

            if forward_tm < self.MIN_TM:
                print(f"Warning: Forward primer Tm for mutations {group} is low ({forward_tm}°C). Consider manual optimization or PCR adjustments.")

            reverse_overlap = self.reverse_complement(overlap_seq)
            reverse_extension_end = best_overlap_start
            reverse_extension_start = max(0, reverse_extension_end - 15)
            reverse_extension_template = mutated_seq[reverse_extension_start:reverse_extension_end]
            reverse_extension = self.reverse_complement(reverse_extension_template)
            reverse_primer = reverse_overlap + reverse_extension

            while len(reverse_primer) < self.MIN_PRIMER_LEN and reverse_extension_start > 0:
                reverse_extension_start -= 1
                reverse_extension_template = mutated_seq[reverse_extension_start:reverse_extension_end]
                reverse_extension = self.reverse_complement(reverse_extension_template)
                reverse_primer = reverse_overlap + reverse_extension

            reverse_tm = self.calculate_tm(reverse_primer)

            while reverse_tm < self.TARGET_TM and reverse_extension_start > max(0, reverse_extension_end - 35):
                reverse_extension_start -= 1
                reverse_extension_template = mutated_seq[reverse_extension_start:reverse_extension_end]
                reverse_extension = self.reverse_complement(reverse_extension_template)
                reverse_primer = reverse_overlap + reverse_extension
                reverse_tm = self.calculate_tm(reverse_primer)

            while len(reverse_primer) > 45:
                reverse_extension_start += 1
                reverse_extension_template = mutated_seq[reverse_extension_start:reverse_extension_end]
                reverse_extension = self.reverse_complement(reverse_extension_template)
                reverse_primer = reverse_overlap + reverse_extension
                reverse_tm = self.calculate_tm(reverse_primer)

            if reverse_tm < self.MIN_TM:
                print(f"Warning: Reverse primer Tm for mutations {group} is low ({reverse_tm}°C). Consider manual optimization or PCR adjustments.")

            tm_diff = abs(forward_tm - reverse_tm)
            max_tm_diff = 3.0

            if tm_diff > max_tm_diff:
                if forward_tm < reverse_tm:
                    while tm_diff > max_tm_diff and extension_end < min(len(dna_sequence), extension_start + 35):
                        extension_end += 1
                        forward_extension = mutated_seq[extension_start:extension_end]
                        forward_primer = forward_overlap + forward_extension
                        forward_tm = self.calculate_tm(forward_primer)
                        tm_diff = abs(forward_tm - reverse_tm)
                else:
                    while tm_diff > max_tm_diff and reverse_extension_start > max(0, reverse_extension_end - 35):
                        reverse_extension_start -= 1
                        reverse_extension_template = mutated_seq[reverse_extension_start:reverse_extension_end]
                        reverse_extension = self.reverse_complement(reverse_extension_template)
                        reverse_primer = reverse_overlap + reverse_extension
                        reverse_tm = self.calculate_tm(reverse_primer)
                        tm_diff = abs(forward_tm - reverse_tm)

            primer_sets.append({
                'mutations': [f"{m[0]}{m[1]}{m[2]}" for m in group],
                'overlap_sequence': overlap_seq,
                'overlap_tm': best_overlap_tm,
                'forward_primer': forward_primer,
                'reverse_primer': reverse_primer,
                'forward_tm': forward_tm,
                'reverse_tm': reverse_tm,
                'forward_length': len(forward_primer),
                'reverse_length': len(reverse_primer),
                'overlap_length': len(overlap_seq),
                'tm_difference': abs(forward_tm - reverse_tm)
            })

        return primer_sets

    def save_primer_list(self, primer_data: List[Dict], output_dir=None, filename="primer_list.txt"):
        if output_dir is None:
            output_dir = os.path.expanduser("~/Mutagenesis")
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        filepath_txt = os.path.join(output_dir, filename)
        with open(filepath_txt, 'w') as f:
            f.write("Site-Directed Mutagenesis Primer List\n")
            f.write("=" * 120 + "\n\n")
            header = ("Primer Name".ljust(20) +
                      "Primer Sequence".ljust(60) +
                      "Len".rjust(6) +
                      "Tm (°C)".rjust(9) +
                      "OverlapLen".rjust(11) +
                      "OverlapTm".rjust(10) +
                      "Notes".rjust(30) + "\n")
            f.write(header)
            f.write("-" * 120 + "\n")

            total_primers = 0

            for entry in primer_data:
                primer_name_base = "_".join(entry['mutations'])
                notes = ",".join(entry['mutations'])
                f.write(f"{primer_name_base + '_for':<20}")
                f.write(f"{entry['forward_primer']:<60}")
                f.write(f"{len(entry['forward_primer']):>6}")
                f.write(f"{entry['forward_tm']:>9.1f}")
                f.write(f"{entry['overlap_length']:>11}")
                f.write(f"{entry['overlap_tm']:>10.1f}")
                f.write(f"{notes:>30}\n")

                f.write(f"{primer_name_base + '_rev':<20}")
                f.write(f"{entry['reverse_primer']:<60}")
                f.write(f"{len(entry['reverse_primer']):>6}")
                f.write(f"{entry['reverse_tm']:>9.1f}")
                f.write(f"{entry['overlap_length']:>11}")
                f.write(f"{entry['overlap_tm']:>10.1f}")
                f.write(f"{notes:>30}\n")

                f.write("-" * 120 + "\n")

                total_primers += 2

            f.write("\nSummary:\n")
            f.write("=" * 50 + "\n")
            f.write(f"Total primer entries: {total_primers // 2}\n")

        print(f"Primer list saved to {filepath_txt}")

        filepath_csv = os.path.join(output_dir, "primer_list.csv")
        with open(filepath_csv, 'w', newline='') as csvfile:
            csvwriter = csv.writer(csvfile)
            csvwriter.writerow([
                "Primer Name", "Primer Sequence", "Length", "Tm (°C)",
                "Overlap Length", "Overlap Tm", "Mutations"
            ])
            for entry in primer_data:
                notes = ",".join(entry['mutations'])
                csvwriter.writerow([
                    "_".join(entry['mutations']) + "_for",
                    entry['forward_primer'],
                    len(entry['forward_primer']),
                    entry['forward_tm'],
                    entry['overlap_length'],
                    entry['overlap_tm'],
                    notes
                ])
                csvwriter.writerow([
                    "_".join(entry['mutations']) + "_rev",
                    entry['reverse_primer'],
                    len(entry['reverse_primer']),
                    entry['reverse_tm'],
                    entry['overlap_length'],
                    entry['overlap_tm'],
                    notes
                ])

        print(f"Primer CSV saved to {filepath_csv}")

    def save_primer_data_json(self, primer_data: List[Dict], output_dir=None, filename="primer_list.json"):
        if output_dir is None:
            output_dir = os.path.expanduser("~/Mutagenesis")
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        filepath_json = os.path.join(output_dir, filename)
        with open(filepath_json, 'w') as f:
            json.dump(primer_data, f, indent=2)

        print(f"Primer JSON saved to {filepath_json}")

## Mutagenesis Protocol

In [19]:
import re
import json
from pathlib import Path
from collections import Counter, defaultdict
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
from reportlab.pdfgen import canvas
from reportlab.lib.units import mm

class MutagenesisProtocol:
    def __init__(
        self,
        variant_input,
        max_mutations_per_step,
        variant_prefix,
        output_dir_path,
        start_col=2,
        start_row=16,
        list_existing_as_steps=False
    ):
        self.variant_input = variant_input
        self.max_mutations_per_step = max_mutations_per_step
        self.variant_prefix = variant_prefix
        self.start_col = start_col
        self.start_row = start_row
        self.mutagenesis_dir = Path(output_dir_path)
        self.primer_json_path = self.mutagenesis_dir / "primer_list.json"
        self.output_dir = self.mutagenesis_dir / "protocols"
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.pdf_path = self.output_dir / "mutagenesis_protocol.pdf"
        self.databank_file = self.output_dir / "variant_databank.json"
        self.list_existing_as_steps = list_existing_as_steps

    @staticmethod
    def mutation_pos(mutation):
        matches = re.findall(r'\d+', mutation)
        return int(matches[0]) if matches else 0

    @staticmethod
    def variant_key(mutations):
        return ','.join(sorted(mutations, key=MutagenesisProtocol.mutation_pos))

    @staticmethod
    def extract_variant_number(name):
        match = re.search(r'variant (\d+)', name)
        return int(match.group(1)) if match else -1

    @staticmethod
    def normalize_mut(m):
        return m.strip().upper()

    def get_base_variant(self, current, existing_variants):
        best = ()
        for var in existing_variants:
            if set(var).issubset(current) and len(var) > len(best):
                best = var
        return best

    def generate_pdf(
        self,
        protocol_by_round,
        final_variants,
        filename,
        existing_input_variants,
        variant_to_label_map,
        variant_to_final_variant,
        used_variants
    ):
        doc = SimpleDocTemplate(filename, pagesize=A4)
        elements = []
        styles = getSampleStyleSheet()
        cell_style = ParagraphStyle('mut_cell', fontSize=9, leading=11)

        if existing_input_variants:
            elements.append(Paragraph("<b>Pre-existing Variants in Databank</b>", styles['Heading2']))
            elements.append(Spacer(1, 12))
            data = [['Input Label', 'Existing Variant Label']]
            for input_label, existing_label in sorted(existing_input_variants.items()):
                data.append([input_label, existing_label or '(unknown)'])
            table = Table(data, colWidths=[150, 250])
            table.setStyle(TableStyle([
                ('BACKGROUND', (0, 0), (-1, 0), colors.lightgrey),
                ('GRID', (0, 0), (-1, -1), 0.5, colors.grey)
            ]))
            elements.append(table)
            elements.append(PageBreak())

        if not protocol_by_round:
            elements.append(Paragraph("<b>No new variants generated.</b>", styles['Heading2']))
            elements.append(Spacer(1, 12))
            if not existing_input_variants:
                elements.append(Paragraph("No variants were generated because no input variants were provided.", styles['Normal']))
            else:
                elements.append(Paragraph("All input variants already exist in the databank.", styles['Normal']))
            doc.build(elements)
            return

        elements.append(Paragraph("<b>Final Variants Overview</b>", styles['Heading2']))
        final_table_data = [['Input Label', 'Final Variant', 'Mutations']]
        for label in sorted(variant_to_label_map, key=self.extract_variant_number):
            sorted_mutations = sorted(variant_to_label_map[label], key=self.mutation_pos)
            mut_paragraph = Paragraph(', '.join(sorted_mutations), cell_style)
            final_table_data.append([label, variant_to_final_variant[label], mut_paragraph])
        table = Table(final_table_data, colWidths=[100, 100, 260], repeatRows=1)
        table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey)
        ]))
        elements.append(table)
        elements.append(Spacer(1, 12))

        elements.append(Paragraph("<b>Step-by-Step Mutagenesis Protocol (Grouped by Round)</b>", styles['Heading2']))
        for i, round_num in enumerate(sorted(protocol_by_round), start=1):
            steps = sorted(protocol_by_round[round_num], key=lambda r: self.extract_variant_number(r[1]))
            elements.append(Paragraph(f"<b>Step {i}</b>", styles['Heading3']))
            table_data = [['New Variant', 'Parent Variant', 'Mutations Added']]
            for row in steps:
                sorted_muts = ', '.join(sorted(row[2].split(', '), key=self.mutation_pos))
                table_data.append([row[0], row[1], Paragraph(sorted_muts, cell_style)])
            step_table = Table(table_data, colWidths=[100, 120, 160], repeatRows=1)
            step_table.setStyle(TableStyle([
                ('BACKGROUND', (0, 0), (-1, 0), colors.lightgrey),
                ('GRID', (0, 0), (-1, -1), 0.5, colors.grey)
            ]))
            elements.append(step_table)
            elements.append(Spacer(1, 10))

        elements.append(PageBreak())
        elements.append(Paragraph("<b>All Produced Variants</b>", styles['Heading2']))
        all_variants_data = [['Variant Name', 'Mutations']]
        for muts, name in used_variants.items():
            mut_text = ', '.join(muts) if muts else '(none)'
            mut_paragraph = Paragraph(mut_text, cell_style)
            all_variants_data.append([name, mut_paragraph])
        table = Table(all_variants_data, colWidths=[140, 320], repeatRows=1)
        table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('LEFTPADDING', (0, 0), (-1, -1), 5),
            ('RIGHTPADDING', (0, 0), (-1, -1), 5),
        ]))
        elements.append(table)

        doc.build(elements)

    def generate_label_pdf(self, variants, start_col, start_row, base_dir):
        base_dir = Path(base_dir)
        variant_json_file = base_dir / "protocols" / "variant_databank.json"
        pdf_output_file = base_dir / "protocols" / "herma_10900_labels.pdf"

        page_width, page_height = A4
        label_width = 25.3 * mm
        label_height = 9.9 * mm
        horizontal_spacing = 2.5 * mm
        top_margin = 13.5 * mm
        left_margin = 10.3 * mm
        columns = 7
        rows = 27
        font_size = 5.5

        def variant_to_key(mutations):
            return ','.join(sorted(mutations, key=self.mutation_pos))

        with open(variant_json_file, "r") as f:
            variant_data = json.load(f)

        def mutations_str_to_set(mutation_str):
            if mutation_str == '(none)':
                return set()
            return set(mutation_str.split(','))

        input_mutation_sets = [set(v) for v in variants]
        selected_variants = []
        for label, mutation_str in variant_data.items():
            variant_mut_set = mutations_str_to_set(mutation_str)
            if any(variant_mut_set.issubset(inp_set) for inp_set in input_mutation_sets):
                selected_variants.append(label)

        def variant_sort_key(label):
            match = re.search(r'\D+(\d+)$', label)
            return int(match.group(1)) if match else 0

        selected_variants = sorted(selected_variants, key=variant_sort_key)
        c = canvas.Canvas(str(pdf_output_file), pagesize=A4)
        c.setFont("Helvetica", font_size)

        label_index = 0
        total_labels = len(selected_variants)
        page_number = 1

        current_start_col = start_col
        current_start_row = start_row

        while label_index < total_labels:
            for row in range(rows):
                for col in range(columns):
                    if page_number == 1:
                        if row < current_start_row or (row == current_start_row and col < current_start_col):
                            continue
                    if label_index >= total_labels:
                        break

                    x = left_margin + col * (label_width + horizontal_spacing)
                    y = page_height - top_margin - (row + 1) * label_height

                    label = selected_variants[label_index]
                    text_x = x + label_width / 2
                    text_y = y + label_height / 2 - font_size / 2

                    c.drawCentredString(text_x, text_y, label)
                    label_index += 1

                if label_index >= total_labels:
                    break

            if label_index < total_labels:
                c.showPage()
                c.setFont("Helvetica", font_size)
                page_number += 1
                current_start_row = 0
                current_start_col = 0

        c.save()
        print(f"✅ PDF with labels generated: {pdf_output_file}")

    def run(self):
        # Check primer JSON exists
        if not self.primer_json_path.exists():
            raise FileNotFoundError(f"Primer JSON not found: {self.primer_json_path}")

        with open(self.primer_json_path, "r") as f:
            primer_data = json.load(f)

        mutation_to_groups = {}
        primer_groups = {}
        for idx, primer_set in enumerate(primer_data):
            group_muts = [self.normalize_mut(m) for m in primer_set["mutations"]]
            primer_groups[idx] = sorted(group_muts, key=self.mutation_pos)
            for mut in group_muts:
                mutation_to_groups.setdefault(mut, []).append(idx)

        if self.databank_file.exists():
            try:
                with open(self.databank_file, "r") as f:
                    content = f.read().strip()
                    variant_databank = json.loads(content) if content else {}
            except json.JSONDecodeError:
                print("⚠️ Warning: Databank file invalid. Starting fresh.")
                variant_databank = {}
        else:
            variant_databank = {}

        all_mutations = [mut for v in self.variant_input for mut in v]
        mutation_freq = Counter(all_mutations)
        variants = [sorted(v, key=lambda x: -mutation_freq[x]) for v in self.variant_input]

        existing_variants = {(): 'wildtype'}
        for key, mutation_str in variant_databank.items():
            muts = tuple(mutation_str.split(',')) if mutation_str != '(none)' else ()
            existing_variants[muts] = key

        final_variants = defaultdict(list)
        variant_to_label_map = {}
        variant_to_final_variant = {}
        protocol_by_round = defaultdict(list)
        used_variants = {}
        existing_input_variants = {}

        for idx, target_mutations in enumerate(variants):
            target = tuple(target_mutations)
            variant_key_str = self.variant_key(target)
            variant_label = f"variant {idx+1}"

            if variant_key_str in variant_databank.values():
                if not self.list_existing_as_steps:
                    # Just note it in the pre-existing list and skip real planning
                    for muts, name in existing_variants.items():
                        if variant_key_str == self.variant_key(muts):
                            existing_input_variants[variant_label] = name
                            break
                    continue
                else:
                    # Force regenerate a full protocol path as if it's new
                    # Remove the exact match from existing_variants so planning will run
                    for muts in list(existing_variants.keys()):
                        if self.variant_key(muts) == variant_key_str:
                            del existing_variants[muts]
                    # do NOT 'continue' here — fall through to full mutation batching logic

            base_variant = self.get_base_variant(target, existing_variants)
            current_mutations = list(base_variant)
            base_name = existing_variants[base_variant]

            target_group_ids = []
            for mut in target:
                mut_norm = self.normalize_mut(mut)
                possible_gids = mutation_to_groups.get(mut_norm, [])
                if possible_gids:
                    filtered_gids = [
                        gid for gid in possible_gids
                        if all(m in map(self.normalize_mut, target) for m in primer_groups[gid])
                        and any(m not in map(self.normalize_mut, current_mutations) for m in primer_groups[gid])
                    ]
                    if filtered_gids:
                        best_gid = max(
                            filtered_gids,
                            key=lambda gid: (
                                len([m for m in primer_groups[gid] if self.normalize_mut(m) in map(self.normalize_mut, target)]),
                                len(primer_groups[gid])
                            )
                        )
                        if best_gid not in target_group_ids:
                            target_group_ids.append(best_gid)

            needed_group_ids = []
            for gid in target_group_ids:
                group_muts = primer_groups[gid]
                if not all(m in map(self.normalize_mut, current_mutations) for m in group_muts):
                    needed_group_ids.append(gid)

            while needed_group_ids:
                batch_ids = needed_group_ids[:self.max_mutations_per_step]
                batch_muts = []
                for gid in batch_ids:
                    batch_muts.extend(primer_groups[gid])

                test_mutations = sorted(
                    set(map(self.normalize_mut, current_mutations)) | set(batch_muts),
                    key=self.mutation_pos
                )
                test_sorted = tuple(test_mutations)

                if test_sorted not in existing_variants:
                    mutation_count = len(test_sorted)
                    new_name = f"{variant_label}.{mutation_count}"
                    existing_variants[test_sorted] = new_name
                    used_variants[test_sorted] = new_name
                    protocol_by_round[mutation_count].append([new_name, base_name, ', '.join(batch_muts)])
                    base_name = new_name
                    current_mutations = test_mutations

                for gid in batch_ids:
                    needed_group_ids.remove(gid)

            final_variants[base_name].append(target)
            variant_to_label_map[variant_label] = target
            variant_to_final_variant[variant_label] = base_name
            used_variants[tuple(sorted(map(self.normalize_mut, target), key=self.mutation_pos))] = base_name

        self.generate_pdf(
            protocol_by_round, final_variants, str(self.pdf_path),
            existing_input_variants, variant_to_label_map,
            variant_to_final_variant, used_variants
        )

        existing_ids = [int(v_id[-2:]) for v_id in variant_databank if v_id.startswith(self.variant_prefix)]
        next_id_num = max(existing_ids, default=0) + 1
        for muts, name in existing_variants.items():
            mutation_str = ','.join(sorted(muts, key=self.mutation_pos))
            if mutation_str not in variant_databank.values():
                new_id = f"{self.variant_prefix}{next_id_num:02d}"
                variant_databank[new_id] = mutation_str
                next_id_num += 1

        with open(self.databank_file, "w") as f:
            json.dump(variant_databank, f, indent=2)

        print(f"✅ PDF generated: {self.pdf_path}")
        print(f"📁 Databank updated: {self.databank_file}")

        self.generate_label_pdf(self.variant_input, self.start_col, self.start_row, self.mutagenesis_dir)


## Widget

Open the cell if closed and fill out the widget below!

Click the **Button** below to start the script.

In [21]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path
import re

# === 1. INTERACTIVE FORM WIDGETS ===
sequence_input = widgets.Textarea(
    value='Enter wildtype DNA sequence here...',
    placeholder='Paste your DNA sequence...',
    description='Sequence:',
    layout=widgets.Layout(width='80%', height='100px')
)

variant_input_box = widgets.Textarea(
    value='S19V, S30V, Q424V, A431E\nK3Q, K36R, D76Q, L167A, Q424V\n',
    placeholder='Enter mutations per variant, one variant per line',
    description='Variants:',
    layout=widgets.Layout(width='80%', height='100px')
)

max_mut_dropdown = widgets.Dropdown(
    options=[1, 2],
    value=1,
    description='Max muts/step:'
)

variant_prefix_input = widgets.Text(
    value='KWE_TA_A',
    description='Var prefix:'
)

start_col_input = widgets.BoundedIntText(
    value=0,
    min=0, max=6, step=1,
    description='Start col:'
)

start_row_input = widgets.BoundedIntText(
    value=0,
    min=0, max=26, step=1,
    description='Start row:'
)

list_existing_checkbox = widgets.Checkbox(
    value=False,
    description='Skip variant databank',
    disabled=False
)

run_button = widgets.Button(description='Run Complete Workflow', button_style='success')
output = widgets.Output()

# === 2. BUTTON FUNCTION ===
def on_run_button_clicked(b):
    with output:
        clear_output()

        # Parse inputs correctly
        seq = sequence_input.value.strip().replace('\n', '').upper()
        variants_raw = variant_input_box.value.strip().split('\n')
        variant_list = [list(map(str.strip, line.split(','))) for line in variants_raw if line]

        max_mutations = max_mut_dropdown.value
        vp = variant_prefix_input.value.strip()
        start_col = start_col_input.value
        start_row = start_row_input.value

        print("=== RUNNING MUTAGENESIS WORKFLOW ===")
        print(f"Sequence length: {len(seq)}")
        print(f"Variants: {variant_list}")
        print(f"Max mutations/step: {max_mutations}")
        print(f"Variant prefix: {vp}")
        print(f"Label START_COL: {start_col}, START_ROW: {start_row}\n")

        # ---- Primer design ----
        pg = PrimerGenerator()
        if not pg.check_sequence_validity(seq, flank_size=30):
            print("❌ Sequence failed validation.")
            return

        pg = PrimerGenerator(flank_size_bases=30)
        if not pg.check_sequence_validity(seq, flank_size=30):
            return
        try:
            pg.validate_mutations_against_sequence(seq, variant_list)
        except ValueError as e:
            print(str(e))
            return

        all_primer_data = []
        existing_pairs = set()
        for muts in variant_list:
            tuples = []
            for mut_str in muts:
                m = re.match(r'([A-Z])(\d+)([A-Z])', mut_str.strip())
                if m:
                    tuples.append((m.group(1), int(m.group(2)), m.group(3)))
            for primer_set in pg.generate_primers_for_construct(seq, tuples):
                key = (primer_set['forward_primer'], primer_set['reverse_primer'])
                if key not in existing_pairs:
                    all_primer_data.append(primer_set)
                    existing_pairs.add(key)

        output_dir = Path("/content/drive/My Drive/Mutagenesis")
        pg.save_primer_list(all_primer_data, output_dir=output_dir)
        pg.save_primer_data_json(all_primer_data, output_dir=output_dir)
        print("✅ Primer design complete.\n")

        # ---- Protocol generation ----
        protocol = MutagenesisProtocol(
            variant_list,
            max_mutations,
            vp,
            output_dir,
            start_col,
            start_row,
            list_existing_as_steps=list_existing_checkbox.value
        )
        protocol.run()

        print("\n🎉 Workflow finished! Check your Google Drive for outputs.")

# Connect click event
run_button.on_click(on_run_button_clicked)

# Display UI
display(sequence_input, variant_input_box, max_mut_dropdown,
        variant_prefix_input, start_col_input, start_row_input, list_existing_checkbox,
        run_button, output)


Textarea(value='Enter wildtype DNA sequence here...', description='Sequence:', layout=Layout(height='100px', w…

Textarea(value='S19V, S30V, Q424V, A431E\nK3Q, K36R, D76Q, L167A, Q424V\n', description='Variants:', layout=La…

Dropdown(description='Max muts/step:', options=(1, 2), value=1)

Text(value='KWE_TA_A', description='Var prefix:')

BoundedIntText(value=0, description='Start col:', max=6)

BoundedIntText(value=0, description='Start row:', max=26)

Checkbox(value=False, description='Skip variant databank')

Button(button_style='success', description='Run Complete Workflow', style=ButtonStyle())

Output()