Colab-ready: VQE codon+MFE co-optimization with GU (Turner) + RY-CZ-RY ansatz

Includes: dot-bracket generation & arc plot from VQE-optimized quartet variables

Install necessary packages (uncomment if running in a fresh Colab)

!pip -q install qiskit qiskit-aer scipy matplotlib

In [None]:
from qiskit_aer import AerSimulator
from qiskit_aer import Aer
from qiskit import QuantumCircuit, transpile

from typing import List, Dict, Tuple, Callable, Optional
import itertools, math, time
import numpy as np
import matplotlib.pyplot as plt

# Qiskit & SciPy
from qiskit import QuantumCircuit, transpile
from qiskit.circuit import Parameter
# from qiskit.providers.aer import AerSimulator
from scipy.optimize import minimize

In [None]:
# -------------------------
# Codon table & default CAI-like weights (toy)
# -------------------------
# One-letter amino-acid code -> list of synonymous RNA codons (20 amino acids;
# there is no entry for stop codons). The order inside each list matters:
# DEFAULT_CODON_WEIGHTS below derives a codon's toy weight from its list position.
CODON_TABLE = {
    'A': ['GCC','GCU','GCA','GCG'],
    'R': ['CGG','AGA','AGG','CGC','CGA','CGU'],
    'N': ['AAC','AAU'],
    'D': ['GAC','GAU'],
    'C': ['UGC','UGU'],
    'Q': ['CAG','CAA'],
    'E': ['GAG','GAA'],
    'G': ['GGC','GGG','GGA','GGU'],
    'H': ['CAC','CAU'],
    'I': ['AUC','AUU','AUA'],
    'L': ['CUG','CUC','CUU','UUG','UUA','CUA'],
    'K': ['AAG','AAA'],
    'M': ['AUG'],
    'F': ['UUC','UUU'],
    'P': ['CCC','CCU','CCA','CCG'],
    'S': ['AGC','UCC','UCU','AGU','UCA','UCG'],
    'T': ['ACC','ACA','ACU','ACG'],
    'W': ['UGG'],
    'Y': ['UAC','UAU'],
    'V': ['GUG','GUC','GUU','GUA'],
}

# Toy weights, not measured codon usage: the weight falls by 0.05 per list
# position (1.00, 0.95, 0.90, 0.85) and wraps every four entries because of
# `i % 4`, so the 5th and 6th codons of a six-codon family get 1.00 and 0.95.
# NOTE: the loop variables (aa, codons, i, c) remain defined at module level.
DEFAULT_CODON_WEIGHTS: Dict[str,float] = {}
for aa, codons in CODON_TABLE.items():
    for i,c in enumerate(codons):
        DEFAULT_CODON_WEIGHTS[c] = 1.0 - 0.05*(i % 4)


In [None]:

# -------------------------
# Utilities
# -------------------------
def prefilter_codons_for_peptide(peptide_seq: str, codon_weights: Dict[str,float], top_k: int = 2) -> List[List[str]]:
    """Shortlist the best-weighted synonymous codons for every residue.

    Args:
        peptide_seq: One-letter amino-acid string; it is upper-cased first.
        codon_weights: Codon -> weight. Codons missing from the mapping are
            ranked with weight 0.0.
        top_k: Maximum number of codons kept per residue (applied as a slice).

    Returns:
        One list per residue, holding that residue's codons from
        ``CODON_TABLE`` sorted by descending weight and cut to ``top_k``.
        Codons with equal weight keep their ``CODON_TABLE`` order.

    Raises:
        ValueError: If a residue has no entry in ``CODON_TABLE``.
    """
    def weight_of(codon: str) -> float:
        return codon_weights.get(codon, 0.0)

    shortlist: List[List[str]] = []
    for residue in peptide_seq.upper():
        synonyms = CODON_TABLE.get(residue)
        if synonyms is None:
            raise ValueError(f"Unknown amino acid '{residue}' in peptide sequence.")
        # sorted() is stable, so ties stay in table order even with reverse=True.
        ranked = sorted(synonyms, key=weight_of, reverse=True)
        shortlist.append(ranked[:top_k])
    return shortlist

def flatten_var_index_map(codon_choices: List[List[str]]) -> Tuple[Dict[Tuple[int,int],int], List[Tuple[int,int]]]:
    """Assign one consecutive variable index to every (position, choice) slot.

    Slots are numbered position by position, and within a position in choice
    order.

    Returns:
        ``(var_map, var_list)`` where ``var_map[(pos, cidx)]`` is the flat
        index and ``var_list[index]`` is the matching ``(pos, cidx)`` tuple.
    """
    slots = [
        (pos, cidx)
        for pos, choices in enumerate(codon_choices)
        for cidx in range(len(choices))
    ]
    index_of = {slot: flat for flat, slot in enumerate(slots)}
    return index_of, slots

def join_codons_to_seq(codons: List[str]) -> str:
    """Concatenate codon strings, in order, into one sequence string."""
    sequence = "".join(codons)
    return sequence

def revcomp_rna(seq: str) -> str:
    """Return the reverse complement of an RNA string.

    Only upper-case A, U, G and C are complemented (A<->U, G<->C); every other
    character is left as it is and simply ends up in its reversed position.
    """
    partner_of = {ord('A'): 'U', ord('U'): 'A', ord('G'): 'C', ord('C'): 'G'}
    complemented = seq.translate(partner_of)
    return ''.join(reversed(complemented))

# -------------------------
# Turner-2004 GU stack table & helix-end penalty (exact entries from NNDB Turner-2004 GU page)
# (Compact subset used here; extend as needed)
# -------------------------
# Nearest-neighbour stacking parameters used by the surrogate energy functions.
# Stack tables are keyed by (top_dinuc, bottom_dinuc): `top` is two consecutive
# bases of one strand, `bottom` holds the partners of top[0] and top[1] in that
# order (this is how mfe_surrogate_turner_gu builds its lookup keys).
# NOTE(review): the numeric values are a hand-copied subset and have not been
# checked here against the published tables — confirm before relying on them.
TURNER2004 = {
    "GU_stack_dG37": {
        # (top_dinuc, bottom_dinuc) -> ΔG37 (kcal/mol)
        ("AG","UU"): -0.55, ("UU","GA"): -0.55,
        ("AU","UG"): -1.36, ("GU","UA"): -1.36,
        ("CG","GU"): -1.41, ("UG","GC"): -2.11,
        ("GG","UC"): -1.53, ("UC","GG"): -1.53,
        ("GU","CG"): -2.51, ("GC","UG"): -1.27,
        ("GA","UU"): -1.27, ("UU","AG"): -1.27,
        ("GG","UU"): +0.47, ("UU","GG"): +0.47,
        # GU->GU special handled separately (prediction set to -0.5)
    },
    # Flat penalty added once per helix end that closes with an AU or GU pair.
    "helix_end_penalty_dG37": {
        "AU_or_GU": +0.45
    },
    # Total energy assigned to a GC / GU / GU / GC run of four stacked pairs;
    # the surrogate replaces the three individual stack terms with this total.
    "favorable_tri_stack_dG37": {
        "GU_UG_triplet_total": -4.12
    },
    # Representative Watson–Crick stacks (subset — extend if you want full WC table)
    # Same (top_dinuc, bottom_dinuc) -> ΔG37 (kcal/mol) layout as the GU table.
    "WC_stack_dG37": {
        ("GC","CG"):-3.42, ("CG","GC"):-2.36, ("AU","UA"):-0.93, ("UA","AU"):-1.10,
        ("GG","CC"):-2.35, ("CC","GG"):-2.35
    }
}

def bp_type(b1: str, b2: str) -> str:
    """Classify an ordered pair of bases.

    Returns 'AU', 'GC' or 'GU' (orientation-independent) for the three
    recognised pair kinds, and 'MM' (mismatch) for anything else, including
    lower-case input.
    """
    pair_classes = {
        ('A', 'U'): 'AU', ('U', 'A'): 'AU',
        ('G', 'C'): 'GC', ('C', 'G'): 'GC',
        ('G', 'U'): 'GU', ('U', 'G'): 'GU',
    }
    return pair_classes.get((b1, b2), 'MM')


In [None]:

# -------------------------
# Surrogate MFE using GU stacks + terminal penalties (reverse-complement duplex proxy)
# -------------------------
def mfe_surrogate_turner_gu(seq: str) -> float:
    """Score a sequence with a heuristic stacking-energy proxy.

    This is not a folding algorithm. The sequence is aligned index-by-index
    against its own reverse complement (``rc = revcomp_rna(seq)``) and table
    lookups on that alignment are summed. The terms, in accumulation order:

    1. ``-0.5`` for every G or C in the sequence.
    2. Watson-Crick stack terms for each dinucleotide window whose
       ``(seq[i:i+2], rc[i:i+2])`` key is in ``TURNER2004["WC_stack_dG37"]``.
    3. GU stack terms from ``TURNER2004["GU_stack_dG37"]``, also trying the
       key ``(bot[::-1], top[::-1])``; a ``GU``/``UG`` window is forced to -0.5.
    4. A correction that makes each GC/GU/GU/GC window total the tri-stack
       value from the table.
    5. A helix-end penalty for each end whose aligned pair is AU or GU.

    Args:
        seq: RNA sequence; it is upper-cased before scoring.

    Returns:
        The accumulated score as a float (more negative = more stacking terms).
    """
    seq = seq.upper()
    rc = revcomp_rna(seq)
    dG = 0.0
    # baseline GC count contribution (legacy)
    dG += -0.5 * (seq.count('G') + seq.count('C'))
    # WC stacks (subset)
    for i in range(len(seq)-1):
        top = seq[i:i+2]
        bot = rc[i:i+2]
        # try exact WC mapping first
        val = TURNER2004["WC_stack_dG37"].get((top, bot))
        if val is not None:
            dG += val
    # GU stacks and special cases
    for i in range(len(seq)-1):
        top = seq[i:i+2]
        bot = rc[i:i+2]
        val = TURNER2004["GU_stack_dG37"].get((top, bot))
        # try symmetric equiv (swap/reverse)
        if val is None:
            val = TURNER2004["GU_stack_dG37"].get((bot[::-1], top[::-1]))
        # GU->GU override to -0.5 if exact GU/UG tandem is encountered
        # (applied even when a table value was found above)
        if top=="GU" and bot=="UG":
            val = -0.5
        if val is not None:
            dG += val
    # trio special: approximate distribute powerful tri-stack total if motif found
    tri_total = TURNER2004["favorable_tri_stack_dG37"]["GU_UG_triplet_total"]
    for i in range(len(seq)-3):
        # pattern GC GU UG CG in paired context
        bp0 = bp_type(seq[i], rc[i])
        bp1 = bp_type(seq[i+1], rc[i+1])
        bp2 = bp_type(seq[i+2], rc[i+2])
        bp3 = bp_type(seq[i+3], rc[i+3])
        if bp0=='GC' and bp1=='GU' and bp2=='GU' and bp3=='GC':
            # compute existing sum and adjust
            # NOTE: only direct GU-table hits are counted here; the reversed-key
            # lookup and the WC stack terms added by the loops above are not
            # subtracted back out.
            s01 = TURNER2004["GU_stack_dG37"].get((seq[i:i+2], rc[i:i+2]), 0.0)
            s12 = TURNER2004["GU_stack_dG37"].get((seq[i+1:i+3], rc[i+1:i+3]), 0.0)
            if seq[i+1:i+3]=="GU" and rc[i+1:i+3]=="UG": s12 = -0.5
            s23 = TURNER2004["GU_stack_dG37"].get((seq[i+2:i+4], rc[i+2:i+4]), 0.0)
            already = (s01 or 0.0) + (s12 or 0.0) + (s23 or 0.0)
            adjust = tri_total - already
            dG += adjust
    # terminal ends
    if len(seq)>0:
        if bp_type(seq[0], rc[0]) in ('AU','GU'):
            dG += TURNER2004["helix_end_penalty_dG37"]["AU_or_GU"]
        if bp_type(seq[-1], rc[-1]) in ('AU','GU'):
            dG += TURNER2004["helix_end_penalty_dG37"]["AU_or_GU"]
    return float(dG)


In [None]:

# -------------------------
# Build MFE quadratic surrogate (baseline + single-subst deltas + pairwise)
# -------------------------
def build_mfe_quadratic_surrogate(codon_choices: List[List[str]],
                                  baseline_choice_index: Optional[int],
                                  mfe_func: Callable[[str], float],
                                  do_pairwise: bool = True,
                                  verbose: bool = False):
    """Fit a quadratic model of ``mfe_func`` around a baseline codon assignment.

    The baseline uses, at every position, choice ``baseline_choice_index``
    clamped to the last available choice (choice 0 when the index is None).

    Args:
        codon_choices: Candidate codons per position.
        baseline_choice_index: Baseline choice index, or None for index 0.
        mfe_func: Maps a sequence string to an energy; results are memoised
            per distinct sequence within one call.
        do_pairwise: If False, all pairwise coefficients are 0.0 and no
            double-substitution sequence is evaluated.
        verbose: Print progress and every evaluated delta / interaction.

    Returns:
        ``(a, b, M0)`` where ``M0`` is the baseline energy,
        ``a[(i, ci)]`` is the energy change of swapping position ``i`` alone
        to choice ``ci`` (0.0 for the baseline choice), and
        ``b[(i, ci, j, cj)]`` for ``i < j`` is
        ``E(i and j swapped) - E(i swapped) - E(j swapped) + M0``.
    """
    memo: Dict[str,float] = {}

    def energy_of(codons: List[str]) -> float:
        # Evaluate mfe_func at most once per distinct concatenated sequence.
        key = join_codons_to_seq(codons)
        if key not in memo:
            memo[key] = mfe_func(key)
        return memo[key]

    n_pos = len(codon_choices)
    if baseline_choice_index is None:
        base_idx_per_pos = [0] * n_pos
    else:
        base_idx_per_pos = [min(baseline_choice_index, len(choices) - 1)
                            for choices in codon_choices]

    baseline_codons = [choices[pick] for choices, pick in zip(codon_choices, base_idx_per_pos)]

    def with_swaps(*swaps: Tuple[int, str]) -> List[str]:
        # Copy of the baseline with the given (position, codon) substitutions.
        variant = list(baseline_codons)
        for position, replacement in swaps:
            variant[position] = replacement
        return variant

    if verbose:
        print("Baseline seq:", join_codons_to_seq(baseline_codons))
        print("Computing baseline MFE (surrogate with GU) ...")
    M0 = energy_of(baseline_codons)

    # Single-substitution deltas.
    a: Dict[Tuple[int,int], float] = {}
    for i, choices in enumerate(codon_choices):
        for ci, codon in enumerate(choices):
            if ci == base_idx_per_pos[i]:
                a[(i,ci)] = 0.0
                continue
            single = energy_of(with_swaps((i, codon)))
            a[(i,ci)] = single - M0
            if verbose:
                print(f" pos {i} choice {codon}: mfe {single:.3f}  delta {a[(i,ci)]:.3f}")

    # Pairwise interaction terms; every (i < j, ci, cj) key is always present.
    b: Dict[Tuple[int,int,int,int], float] = {}
    if do_pairwise and verbose:
        print("Computing pairwise interactions (this can be slow)...")
    for i in range(n_pos):
        for j in range(i+1, n_pos):
            for ci, codon_i in enumerate(codon_choices[i]):
                i_is_base = ci == base_idx_per_pos[i]
                for cj, codon_j in enumerate(codon_choices[j]):
                    j_is_base = cj == base_idx_per_pos[j]
                    if not do_pairwise or (i_is_base and j_is_base):
                        b[(i,ci,j,cj)] = 0.0
                        continue
                    mfe_pair = energy_of(with_swaps((i, codon_i), (j, codon_j)))
                    mfe_i = M0 if i_is_base else energy_of(with_swaps((i, codon_i)))
                    mfe_j = M0 if j_is_base else energy_of(with_swaps((j, codon_j)))
                    interaction = mfe_pair - mfe_i - mfe_j + M0
                    b[(i,ci,j,cj)] = interaction
                    if verbose:
                        print(f" pair ({i},{codon_i})({j},{codon_j}): mfe_pair {mfe_pair:.3f}, interaction {interaction:.4f}")

    return a, b, M0


In [None]:

# -------------------------
# Quartet generator (candidate adjacent base-pair stacks) & Turner embedded quartet QUBO builder
# -------------------------
def generate_quartets(seq: str, min_loop_len: int = 3) -> List[Tuple[int,int]]:
    """List candidate stacked-pair "quartets" of a sequence.

    A quartet ``(i, j)`` stands for the pair ``(i, j)`` stacked on the pair
    ``(i+1, j-1)``; both must be AU, GC or GU pairs (either orientation,
    upper-case only). The loop-length constraint is applied to the outer pair
    only: ``j >= i + min_loop_len + 1``.

    Returns:
        Quartets ordered by ``i``, then by ``j``.
    """
    pairable = {(x, y) for x, y in ("AU", "UA", "GC", "CG", "GU", "UG")}
    length = len(seq)
    return [
        (i, j)
        for i in range(length - 1)
        for j in range(i + min_loop_len + 1, length)
        if (seq[i], seq[j]) in pairable and (seq[i+1], seq[j-1]) in pairable
    ]

def quartet_energy_from_turner_embedded(seq: str, i: int, j: int, ua_terminal_penalty: float = 0.0, stack_tables: Optional[dict] = None) -> float:
    """Look up the stacking energy of the quartet (i, j) / (i+1, j-1).

    The stack tables are keyed by ``(top_dinuc, bottom_dinuc)``, where ``top``
    is two consecutive bases of one strand and ``bottom`` holds their partners
    in the same order. For a quartet that is ``top = seq[i] + seq[i+1]`` and
    ``bottom = seq[j] + seq[j-1]``. The same stack read from the other strand,
    ``(bottom[::-1], top[::-1])``, is tried as a fallback. Watson-Crick
    entries take precedence over GU entries.

    Args:
        seq: Sequence the quartet indices refer to (not case-normalised).
        i, j: Outer pair of the quartet; the inner pair is ``(i+1, j-1)``.
        ua_terminal_penalty: Added once when the outer or the inner pair is
            AU or UA.
        stack_tables: Optional replacement for ``TURNER2004``; a mapping with
            "WC_stack_dG37" and "GU_stack_dG37" sub-tables. Defaults to the
            module-level table.

    Returns:
        The table energy (0.0 when no entry matches) plus any AU penalty.
    """
    tables = TURNER2004 if stack_tables is None else stack_tables

    top = seq[i] + seq[i+1]
    bottom = seq[j] + seq[j-1]
    same_stack_other_strand = (bottom[::-1], top[::-1])

    val = None
    for table_name in ("WC_stack_dG37", "GU_stack_dG37"):
        table = tables.get(table_name, {})
        val = table.get((top, bottom))
        if val is None:
            val = table.get(same_stack_other_strand)
        if val is not None:
            break
    if val is None:
        val = 0.0

    # The AU test is about base pairs, so it uses pair strings, not dinucleotides.
    outer_pair = seq[i] + seq[j]
    inner_pair = seq[i+1] + seq[j-1]
    if outer_pair in ('AU','UA') or inner_pair in ('AU','UA'):
        val += ua_terminal_penalty
    return float(val)

def build_qubo_from_turner_embedded(seq: str, reward_stack: float = -0.5, ua_penalty: float = 0.0, cross_penalty: float = 20.0, min_loop_len: int = 3, verbose: bool = False):
    """Build a QUBO over the quartet variables of ``seq``.

    One binary variable is created per quartet from ``generate_quartets``.

    * Linear term: the quartet's table energy (looked up with no AU penalty).
    * Stacking term: ``reward_stack`` on each pair of quartets ``(i, j)`` and
      ``(i+1, j-1)`` that both exist. When ``ua_penalty`` is non-zero and the
      inner pair of ``(i, j)`` is AU/UA, ``ua_penalty`` is added to the linear
      term of ``(i, j)`` and subtracted from that stacking term.
    * Crossing term: ``cross_penalty`` on each pair of quartets whose outer
      pairs interleave (``a < c < b < d`` or ``c < a < d < b``).

    No term is added for two quartets that merely share a base.

    Returns:
        ``(Q_lin, Q_quad, constant, var_map, quartets)`` where ``Q_quad`` is
        keyed by ``(low_index, high_index)``, ``constant`` is always 0.0 and
        ``var_map`` maps each quartet to its variable index.
    """
    quartets = generate_quartets(seq, min_loop_len=min_loop_len)
    var_map = {quartet: index for index, quartet in enumerate(quartets)}
    Q_lin = dict.fromkeys(range(len(quartets)), 0.0)
    Q_quad = {}
    constant = 0.0

    use_ua_term = ua_penalty != 0.0
    for (i, j), vid in var_map.items():
        # Linear energy of the quartet itself.
        energy = quartet_energy_from_turner_embedded(seq, i, j, ua_terminal_penalty=0.0)
        Q_lin[vid] += float(energy)
        if verbose:
            print(f"quartet {(i,j)} var {vid} eq {energy:.3f}")

        # Coupling with the quartet nested directly inside, if it exists.
        inner_vid = var_map.get((i+1, j-1))
        if inner_vid is None:
            continue
        key = (min(vid, inner_vid), max(vid, inner_vid))
        Q_quad[key] = Q_quad.get(key, 0.0) + float(reward_stack)
        if use_ua_term and (seq[i+1] + seq[j-1]).upper() in ('AU','UA'):
            Q_lin[vid] += float(ua_penalty)
            Q_quad[key] = Q_quad.get(key, 0.0) - float(ua_penalty)

    # Penalise every unordered pair of quartets whose outer pairs cross.
    indexed = list(var_map.items())
    for position, ((a, b), vid) in enumerate(indexed):
        for (c, d), vj in indexed[position+1:]:
            crosses = (a < c < b < d) or (c < a < d < b)
            if not crosses:
                continue
            key = (vid, vj)
            Q_quad[key] = Q_quad.get(key, 0.0) + float(cross_penalty)
            if verbose:
                print(f"cross penalty between {(a,b)} var {vid} and {(c,d)} var {vj}: +{cross_penalty}")
    return Q_lin, Q_quad, constant, var_map, quartets


In [None]:

# -------------------------
# Build combined QUBO: codon-choice vars followed by quartet vars (if any)
# Returns combined_var_map mapping global_var_idx -> ('codon',(pos,cidx)) or ('quartet',(i,j))
# -------------------------
def build_qubo_combined(codon_choices: List[List[str]], codon_weights: Dict[str,float], w_codon: float, w_mfe: float, gamma: float, baseline_choice_index: Optional[int], mfe_func: Callable[[str], float], do_pairwise_surrogate: bool, use_turner_pdf: bool = True, turner_kwargs: dict = None, verbose: bool = False):
    """Assemble one QUBO over codon-choice bits followed by quartet bits.

    Codon part (indices ``0 .. n_codons-1``):

    * ``w_codon * -log(weight)`` per codon, with weights floored at 1e-12.
    * ``w_mfe`` times the quadratic surrogate from
      ``build_mfe_quadratic_surrogate`` (baseline energy goes to the constant).
    * A one-hot penalty ``gamma * (sum(bits at position) - 1)**2`` per
      position, expanded into linear, quadratic and constant parts.

    Quartet part (only when ``use_turner_pdf``): the QUBO from
    ``build_qubo_from_turner_embedded`` scaled by ``w_mfe``, with its indices
    shifted by ``n_codons``. The quartets are generated for the sequence made
    of the *first* codon choice at every position, whichever codons are later
    selected.

    Returns:
        ``(Q_lin, Q_quad, constant, combined_var_map, var_map)`` where
        ``combined_var_map[idx]`` is ``('codon', (pos, cidx))`` or
        ``('quartet', (i, j))`` and ``var_map`` maps ``(pos, cidx)`` to its
        codon-bit index.
    """
    var_map, var_list = flatten_var_index_map(codon_choices)
    n_codons = len(var_list)
    Q_lin = dict.fromkeys(range(n_codons), 0.0)
    Q_quad = {}
    constant = 0.0

    def add_quadratic(u: int, v: int, amount: float) -> None:
        # Accumulate onto the (low, high)-ordered key.
        key = (min(u, v), max(u, v))
        Q_quad[key] = Q_quad.get(key, 0.0) + amount

    # Codon-adaptation term.
    for (pos, ci), vid in var_map.items():
        weight = max(1e-12, codon_weights.get(codon_choices[pos][ci], 1e-12))
        Q_lin[vid] += w_codon * (-math.log(weight))

    # Surrogate energy term on the codon bits.
    a_dict, b_dict, M0 = build_mfe_quadratic_surrogate(codon_choices, baseline_choice_index, mfe_func, do_pairwise=do_pairwise_surrogate, verbose=verbose)
    constant += w_mfe * M0
    for slot, delta in a_dict.items():
        Q_lin[var_map[slot]] += w_mfe * delta
    for (i, ci, j, cj), coupling in b_dict.items():
        vi = var_map[(i,ci)]
        vj = var_map[(j,cj)]
        if vi == vj:
            Q_lin[vi] += w_mfe * coupling
        else:
            add_quadratic(vi, vj, w_mfe * coupling)

    # One-hot constraint: exactly one codon bit per position.
    for p, choices in enumerate(codon_choices):
        slots = [var_map[(p, cidx)] for cidx in range(len(choices))]
        for vid in slots:
            Q_lin[vid] += gamma * (-1.0)
        for v1, v2 in itertools.combinations(slots, 2):
            add_quadratic(v1, v2, gamma * 2.0)
        constant += gamma * 1.0

    combined_var_map = {vid: ('codon', var_list[vid]) for vid in range(n_codons)}
    if use_turner_pdf:
        # Quartet bits are appended after the codon bits.
        seq_ref = ''.join(choices[0] for choices in codon_choices)
        Q_lin_t, Q_quad_t, const_t, _, quartets = build_qubo_from_turner_embedded(seq_ref, **(turner_kwargs or {}))
        for qidx, val in Q_lin_t.items():
            gid = n_codons + qidx
            Q_lin[gid] = Q_lin.get(gid, 0.0) + w_mfe * val
            combined_var_map[gid] = ('quartet', quartets[qidx])
        for (i, j), val in Q_quad_t.items():
            key = (n_codons + i, n_codons + j)
            Q_quad[key] = Q_quad.get(key, 0.0) + w_mfe * val
        constant += w_mfe * const_t

    return Q_lin, Q_quad, constant, combined_var_map, var_map


In [None]:

# -------------------------
# Cost table enumeration
# -------------------------
def qubo_to_cost_table(Q_lin: Dict[int,float], Q_quad: Dict[Tuple[int,int],float], constant: float):
    """Enumerate the QUBO cost of every bit assignment.

    The variable count is one more than the largest index found among the
    ``Q_lin`` keys and the *second* element of the ``Q_quad`` keys.

    Returns:
        ``(cost_table, n_vars)``. Keys of ``cost_table`` are bitstrings in
        which character ``k`` is the value of variable ``k``.

    Raises:
        RuntimeError: If there are no variables, or more than 22 of them.
    """
    highest_linear = max(Q_lin.keys()) if Q_lin else -1
    highest_quad = max((pair[1] for pair in Q_quad), default=-1) if Q_quad else -1
    n_vars = max(highest_linear, highest_quad) + 1
    if n_vars <= 0:
        raise RuntimeError("No variables")
    if n_vars > 22:
        raise RuntimeError(f"Too many variables to enumerate exactly (n_vars={n_vars}). Reduce top_k.")

    cost_table: Dict[str,float] = {}
    for state in range(2**n_vars):
        # Character k is bit k of `state`, i.e. the reversed binary literal.
        bits = ''.join('1' if (state >> k) & 1 else '0' for k in range(n_vars))
        cost = constant
        for var, flag in enumerate(bits):
            if flag == '1':
                cost += Q_lin.get(var, 0.0)
        for (u, v), coeff in Q_quad.items():
            if bits[u] == '1' and bits[v] == '1':
                cost += coeff
        cost_table[bits] = cost
    return cost_table, n_vars

# -------------------------
# Two-local RY-CZ-RY ansatz
# -------------------------
def two_local_y_cz_ansatz(num_qubits=3, reps=2):
    """Build a parameterised RY - CZ - RY circuit.

    Each of the ``reps`` repetitions applies an RY rotation on every qubit, a
    chain of CZ gates between neighbouring qubits (q, q+1), and a second RY
    rotation on every qubit. Every rotation has its own parameter, named
    ``θ_0 .. θ_{2*num_qubits*reps - 1}`` in order of use, and the applied
    angle is twice the parameter value.

    Returns:
        The ``QuantumCircuit`` with ``num_qubits * 2 * reps`` free parameters.
    """
    circuit = QuantumCircuit(num_qubits)
    angles = iter([Parameter(f"θ_{i}") for i in range(num_qubits * 2 * reps)])

    def rotation_layer():
        # One RY per qubit, consuming the next unused parameters in order.
        for qubit in range(num_qubits):
            circuit.ry(2 * next(angles), qubit)

    for _ in range(reps):
        rotation_layer()
        for left in range(num_qubits - 1):
            circuit.cz(left, left + 1)
        rotation_layer()
    return circuit


In [None]:

# -------------------------
# Statevector eval + CVaR
# -------------------------
def evaluate_params_statevector(ansatz, param_values: np.ndarray, cost_table: Dict[str,float], alpha: float = 0.2, backend=None, verbose: bool = False):
    """Simulate the bound ansatz and return the CVaR of the cost distribution.

    Args:
        ansatz: Parameterised circuit.
        param_values: Values bound to the circuit's parameters.
        cost_table: Bitstring -> cost, where character ``k`` of the bitstring
            is qubit ``k``. States missing from the table are ignored.
        alpha: Probability mass of the lowest-cost tail that is averaged.
        backend: Simulator to run on; a statevector ``AerSimulator`` is
            created when None. Pass one in to avoid rebuilding it per call.
        verbose: Accepted for interface compatibility; not used.

    Returns:
        ``(value, items)``. ``items`` is a list of ``(cost, probability,
        bitstring)`` for every state with non-zero probability, sorted by
        ascending cost. ``value`` is the probability-weighted mean cost of the
        lowest-cost ``alpha`` mass, or the plain expectation over ``items``
        when no mass could be collected (e.g. ``alpha <= 0``).

    Raises:
        RuntimeError: If the simulation result holds no state data.
    """
    try:
        bound = ansatz.assign_parameters(param_values)
    except Exception:
        # Best-effort fallback kept as-is; presumably for Qiskit releases that
        # only offer bind_parameters — TODO confirm it is still needed.
        bound = ansatz.bind_parameters(list(param_values))
    if backend is None:
        backend = AerSimulator(method='statevector')

    circ = bound.copy()
    try:
        circ.save_statevector()
    except Exception:
        try:
            from qiskit.providers.aer.library import save_statevector
            circ.append(save_statevector(), [])
        except Exception:
            # Deliberately non-fatal: if no save instruction could be added,
            # the missing result is reported by the RuntimeError below.
            pass
    tcirc = transpile(circ, backend)
    res = backend.run(tcirc).result()
    data0 = res.data(0)

    # Test each candidate against None explicitly: chaining them with `or`
    # asks for the truth value of an array, which raises for NumPy arrays.
    state = None
    for key in ('statevector', 'state_vector', 'density_matrix'):
        candidate = data0.get(key)
        if candidate is not None:
            state = candidate
            break
    if state is None:
        raise RuntimeError("statevector not found in result data")

    amplitudes = np.asarray(state, dtype=complex)
    if amplitudes.ndim == 2:
        # Density matrix: basis-state probabilities are its diagonal.
        basis_probs = np.real(np.diagonal(amplitudes))
    else:
        basis_probs = np.abs(amplitudes) ** 2
    n_qubits = max(len(basis_probs) - 1, 0).bit_length()

    probs: Dict[str,float] = {}
    for idx, p in enumerate(basis_probs):
        p = float(p)
        if p == 0.0:
            continue
        # Reverse so that character k of the key is qubit k.
        probs[format(idx, f'0{n_qubits}b')[::-1]] = p
    total_prob = sum(probs.values())
    if total_prob > 0 and abs(total_prob - 1.0) > 1e-9:
        for key in list(probs.keys()):
            probs[key] /= total_prob

    items = [(cost_table[b], p, b) for b, p in probs.items() if b in cost_table]
    items.sort(key=lambda x: x[0])

    # Collect probability mass from the cheapest states until alpha is reached.
    cum = 0.0
    weighted = 0.0
    for cost, p, _ in items:
        take = min(p, max(0.0, alpha - cum))
        weighted += cost * take
        cum += take
        if cum >= alpha - 1e-12:
            break
    if cum == 0.0:
        expectation = sum(cost * p for cost, p, _ in items)
        return expectation, items
    return weighted / cum, items


In [None]:

# -------------------------
# Manual VQE COBYLA optimizing CVaR
# -------------------------
def manual_vqe_optimize(ansatz, cost_table: Dict[str,float], alpha: float = 0.2, maxiter: int = 80, verbose: bool = True, seed: Optional[int] = None):
    """Minimise the CVaR of ``cost_table`` over the ansatz parameters with COBYLA.

    Args:
        ansatz: Parameterised circuit.
        cost_table: Bitstring -> cost, as produced by ``qubo_to_cost_table``.
        alpha: CVaR tail mass passed to ``evaluate_params_statevector``.
        maxiter: COBYLA iteration limit.
        verbose: Print the objective at every evaluation and a final summary.
        seed: Optional seed for the random starting point. When None the
            start is drawn from NumPy's global random state, as before.

    Returns:
        ``(best_bitstring, best_cost, scipy_result, history)`` where
        ``history`` lists the objective value of every evaluation.

        NOTE: ``best_bitstring`` is the lowest-cost state that has *any*
        non-zero probability in the final state, however small; it is not the
        most probable state.

    Raises:
        RuntimeError: If the final state has no entry present in ``cost_table``.
    """
    n_params = len(list(ansatz.parameters))
    if verbose: print("Ansatz has", n_params, "parameters; starting COBYLA.")
    history = []

    # Build the simulator once; the objective is evaluated many times and
    # would otherwise construct a fresh backend on every call.
    backend = AerSimulator(method='statevector')

    def obj(x):
        val, _ = evaluate_params_statevector(ansatz, x, cost_table, alpha=alpha, backend=backend, verbose=False)
        history.append(float(val))
        if verbose:
            print("obj -> CVaR", float(val))
        return float(val)

    if seed is None:
        x0 = np.random.uniform(0, 2*np.pi, size=n_params)
    else:
        x0 = np.random.default_rng(seed).uniform(0, 2*np.pi, size=n_params)
    res = minimize(obj, x0, method='COBYLA', options={'maxiter': maxiter})
    final_cvar, items = evaluate_params_statevector(ansatz, res.x, cost_table, alpha=alpha, backend=backend, verbose=False)
    if len(items) == 0:
        raise RuntimeError("No observed states in final statevector.")
    # items is sorted by ascending cost, so the first entry is the cheapest.
    best_cost, prob, best_b = items[0]
    if verbose:
        print("Optimization finished. Best cost:", best_cost, "best_bitstring:", best_b)
    return best_b, best_cost, res, history


In [None]:

# -------------------------
# Dot-bracket generator & plot
# -------------------------
def basepairs_from_quartets(quartet_list: List[Tuple[int,int]], bits_set: List[int]) -> List[Tuple[int,int]]:
    """Expand the selected quartets into their base pairs.

    Every index in ``bits_set`` that is valid for ``quartet_list`` selects a
    quartet ``(i, j)``, which contributes the pairs ``(i, j)`` and
    ``(i+1, j-1)``. Out-of-range indices are ignored, as are pairs with a
    negative member or with first >= second.

    Returns:
        The distinct remaining pairs in ascending order.
    """
    collected = set()
    for idx in bits_set:
        if not 0 <= idx < len(quartet_list):
            continue
        i, j = quartet_list[idx]
        for low, high in ((i, j), (i+1, j-1)):
            if low >= 0 and high >= 0 and low < high:
                collected.add((low, high))
    return sorted(collected)

def basepairs_to_dotbracket(seq: str, base_pairs: List[Tuple[int,int]]) -> str:
    """Render base pairs as a well-formed dot-bracket string.

    Pairs are processed in the given order. A pair is skipped when

    * either index lies outside ``seq``,
    * both indices are equal,
    * one of its bases is already used by an accepted pair, or
    * it crosses an accepted pair (plain parentheses cannot express that).

    A pair given as ``(j, i)`` with ``j > i`` is treated as ``(i, j)``.
    Because conflicting and crossing pairs are dropped rather than
    overwritten, the result is always balanced and decodes back to exactly
    the accepted pairs.

    Returns:
        A string of ``len(seq)`` characters drawn from ``.``, ``(`` and ``)``.
    """
    n = len(seq)
    db = ['.'] * n
    accepted: List[Tuple[int,int]] = []
    used = set()
    for i, j in base_pairs:
        lo, hi = (i, j) if i <= j else (j, i)
        if lo == hi or lo < 0 or hi >= n:
            continue
        if lo in used or hi in used:
            continue
        if any(a < lo < b < hi or lo < a < hi < b for a, b in accepted):
            continue
        accepted.append((lo, hi))
        used.update((lo, hi))
        db[lo] = '('
        db[hi] = ')'
    return ''.join(db)

def plot_secondary_structure(seq: str, dotbr: str):
    """Draw an arc diagram of a dot-bracket structure with matplotlib.

    Bases are plotted as labelled points on a horizontal line and every
    matched ``(`` / ``)`` is joined by an arc. An unmatched ``)`` is ignored
    and an unmatched ``(`` draws no arc. The figure is shown with
    ``plt.show()``; nothing is returned.
    """
    n = len(seq)

    # Recover pairs by matching each ')' with the most recent open '('.
    open_positions = []
    pairs = []
    for position, symbol in enumerate(dotbr):
        if symbol == '(':
            open_positions.append(position)
        elif symbol == ')' and open_positions:
            pairs.append((open_positions.pop(), position))

    fig, ax = plt.subplots(figsize=(max(6, n*0.25), 3))
    ax.scatter(np.arange(n), np.zeros(n), s=8, c='k')
    for position, base in enumerate(seq):
        ax.text(position, 0.05, base, ha='center', va='bottom', fontsize=9)

    theta = np.linspace(0, np.pi, 100)
    for start, end in pairs:
        radius = (end - start) / 2
        arc_x = start + radius * (1 - np.cos(theta))
        # Height is squashed so long-range arcs stay below 1.
        arc_y = radius * np.sin(theta) / (radius + 1)
        ax.plot(arc_x, arc_y, color='C0')

    ax.set_ylim(-0.5, max(1.0, n*0.05))
    ax.axis('off')
    plt.show()


In [None]:
# -------------------------
# High-level pipeline: cooptimize and display VQE-optimized structure
# -------------------------
def cooptimize_peptide_10mer(peptide_seq: str,
                             top_k: int = 2,
                             codon_weights: Dict[str,float] = None,
                             w_codon: float = 1.0,
                             w_mfe: float = 1.0,
                             gamma: float = 40.0,
                             baseline_choice_index: Optional[int] = None,
                             do_pairwise_surrogate: bool = True,
                             use_turner_pdf: bool = True,
                             turner_kwargs: dict = None,
                             alpha: float = 0.2,
                             maxiter: int = 50,
                             reps: int = 3,
                             verbose: bool = True):
    """Run the full pipeline: codon shortlist -> QUBO -> VQE -> decode -> plot.

    Despite the function name, the peptide must contain exactly 5 residues.

    Args:
        peptide_seq: One-letter amino-acid string (stripped and upper-cased).
        top_k: Codons kept per residue by ``prefilter_codons_for_peptide``.
        codon_weights: Codon -> weight; ``DEFAULT_CODON_WEIGHTS`` when falsy.
        w_codon, w_mfe, gamma: Weights of the codon term, the energy terms and
            the one-hot penalty in ``build_qubo_combined``.
        baseline_choice_index: Baseline choice for the energy surrogate.
        do_pairwise_surrogate: Include pairwise surrogate terms.
        use_turner_pdf: Append quartet (structure) variables to the QUBO.
        turner_kwargs: Extra keyword arguments for
            ``build_qubo_from_turner_embedded``.
        alpha: CVaR tail mass.
        maxiter: COBYLA iteration limit.
        reps: Ansatz repetitions.
        verbose: Print intermediate information. The dot-bracket block and
            the plots are produced regardless.

    Returns:
        ``(chosen_codons, seq, best_cost, true_surrogate_mfe, opt_res, dotbr)``.

    Raises:
        ValueError: If the peptide does not have exactly 5 residues.
    """
    peptide_seq = peptide_seq.strip().upper()
    if len(peptide_seq) != 5:
        raise ValueError("Please supply a peptide sequence of length 5.")
    codon_weights = codon_weights or DEFAULT_CODON_WEIGHTS

    # 1) prefilter codons
    codon_choices = prefilter_codons_for_peptide(peptide_seq, codon_weights, top_k=top_k)
    if verbose:
        print("Prefiltered codon choices per position (top_k=%d):" % top_k)
        for i, choices in enumerate(codon_choices):
            print(i, choices)

    # 2) build combined QUBO (codon vars + quartet vars)
    Q_lin, Q_quad, const, combined_var_map, _codon_var_map = build_qubo_combined(
        codon_choices, codon_weights, w_codon, w_mfe, gamma, baseline_choice_index,
        mfe_surrogate_turner_gu, do_pairwise_surrogate, use_turner_pdf=use_turner_pdf,
        turner_kwargs=turner_kwargs or {}, verbose=verbose
    )

    # 3) enumerate cost table
    cost_table, n_vars = qubo_to_cost_table(Q_lin, Q_quad, const)
    if verbose:
        print("Enumerated cost_table: n_vars =", n_vars, "states =", len(cost_table))

    # 4) ansatz & VQE
    ansatz = two_local_y_cz_ansatz(num_qubits=n_vars, reps=reps)
    best_b, best_cost, opt_res, history = manual_vqe_optimize(ansatz, cost_table, alpha=alpha, maxiter=maxiter, verbose=verbose)

    # 5) decode the codon bits of the best bitstring, in variable-index order
    ordered_vars = sorted(combined_var_map.items())
    set_choices_per_pos: Dict[int, List[int]] = {}
    for vid, (kind, payload) in ordered_vars:
        if kind == 'codon' and best_b[vid] == '1':
            pos, cidx = payload
            set_choices_per_pos.setdefault(pos, []).append(cidx)

    chosen_codons = []
    for p, choices in enumerate(codon_choices):
        selected = set_choices_per_pos.get(p, [])
        if verbose and len(selected) != 1:
            # One-hot constraint violated: the decoded sequence then does not
            # correspond to the reported QUBO cost.
            print(f"Warning: position {p} has {len(selected)} codon bits set; "
                  "using the last set bit, or choice 0 if none.")
        cidx = selected[-1] if selected else 0
        chosen_codons.append(choices[cidx])

    seq = join_codons_to_seq(chosen_codons)
    true_surrogate_mfe = mfe_surrogate_turner_gu(seq)

    if verbose:
        print("\nRESULTS")
        print("-------")
        print("chosen codons:", chosen_codons)
        print("mRNA seq:", seq)
        print("Surrogate ΔG37 (with GU + ends):", true_surrogate_mfe)
        print("QUBO best cost:", best_cost)

    # 6) quartet bits -> base pairs -> dot-bracket & plot
    quartet_selected = [
        payload
        for vid, (kind, payload) in ordered_vars
        if kind == 'quartet' and vid < len(best_b) and best_b[vid] == '1'
    ]
    # The helper returns sorted, de-duplicated pairs with 0 <= a < b; only the
    # upper bound against the sequence length remains to be enforced.
    bp_list = [
        (a, b)
        for (a, b) in basepairs_from_quartets(quartet_selected, list(range(len(quartet_selected))))
        if b < len(seq)
    ]

    dotbr = basepairs_to_dotbracket(seq, bp_list)
    print("\nVQE-optimized (quartet-derived) secondary structure (dot-bracket):")
    print(seq)
    print(dotbr)

    # plot structure
    try:
        plot_secondary_structure(seq, dotbr)
    except Exception as e:
        print("Plot failed:", e)

    # plot optimization history
    if len(history)>0:
        plt.figure(figsize=(6,3))
        plt.plot(history, marker='o')
        plt.title("VQE CVaR history")
        plt.xlabel("eval")
        plt.ylabel("CVaR")
        plt.grid(True)
        plt.show()

    return chosen_codons, seq, best_cost, true_surrogate_mfe, opt_res, dotbr

# -------------------------
# Demo (runs when script executed)
# -------------------------
# Runs the pipeline once on a fixed 5-residue peptide. The results are bound to
# module-level names (chosen_codons, seq, best_cost, dG_sur, opt_res, dotbr).
if __name__ == "__main__":
    peptide_seq = "DRNKF"
    chosen_codons, seq, best_cost, dG_sur, opt_res, dotbr = cooptimize_peptide_10mer(
        peptide_seq=peptide_seq,
        # top_k=1 keeps a single codon per residue, so the codon part of the
        # problem has no freedom here; only the quartet bits are optimised.
        top_k=1,
        codon_weights=DEFAULT_CODON_WEIGHTS,
        w_codon=1.0,
        w_mfe=1.0,
        gamma=40.0,
        baseline_choice_index=None,
        do_pairwise_surrogate=True,
        use_turner_pdf=True,
        turner_kwargs=dict(reward_stack=-0.5, ua_penalty=0.0, cross_penalty=20.0, min_loop_len=3, verbose=False),
        alpha=0.2,
        maxiter=40,
        reps=2,
        verbose=True
    )
    print("\nDone. Dot-bracket (VQE):", dotbr)


In [None]:
# Calculate and display the MFE energy score
# NOTE: `seq` is only defined once the demo block under
# `if __name__ == "__main__":` has run; otherwise this raises NameError.
mfe_score = mfe_surrogate_turner_gu(seq)
print(f"MFE energy score (surrogate with GU + ends): {mfe_score}")