<a href="https://colab.research.google.com/github/ynaowusu/protein-folding-quantum-algorithms/blob/main/proteinfolding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [71]:
!pip install qiskit #this is just me intalling qiskit into our notebook
!pip install matplotlib plotly #since it says we need a 3d structure to simulate the lattice and any other 3d elements
!pip install numpy
!pip install qiskit-aer




In [75]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit.quantum_info import SparsePauliOp
from qiskit_aer import Aer  # Fixed import
from scipy.optimize import minimize
from typing import List, Dict, Tuple
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit.quantum_info import SparsePauliOp
from qiskit_aer import Aer  # Fixed import
from scipy.optimize import minimize
from typing import List, Dict, Tuple

"""
proteins fold in 3D space and here we're working with a simplified lattice model w tetrahedral directions. There are two kinds of alternating sites in the protein chain: 'A' and 'B'.
every turn corresponds to a direction in 3D space either x,y or z
The vectors  [1, 1, 1]) are unit steps in different directions based on the protein structure model.
"""

class SparseDiamondEncoder:
    """Encodes/decodes backbone turns on a diamond (tetrahedral) lattice."""

    # sublattice A basis (even sites)
    _A: Dict[int,np.ndarray] = {
        0: np.array([+1,+1,+1]),    # |00> This is for the first round of qubits represents x,y,z
        1: np.array([+1,-1,-1]),    # |10>
        2: np.array([-1,+1,-1]),    # |01>
        3: np.array([-1,-1,+1]),    # |11>
    }
    # sublattice B basis (odd sites)
    _B: Dict[int,np.ndarray] = {
        0: np.array([+1,+1,-1]),    # |00>
        1: np.array([+1,-1,+1]),    # |10>
        2: np.array([-1,+1,+1]),    # |01>
        3: np.array([-1,-1,-1]),    # |11>
    }

    @staticmethod
    def site_parity(i):
        """Return 'A' if turn index i is even (sublattice A), else 'B'."""
        return 'A' if i % 2 == 0 else 'B'

    @classmethod
    def block_to_vec(cls, block, i):
        """
        Decode one-hot 4-bit string (e.g. '1000') at turn index i into a 3D step vector.
        Raises ValueError if block is not one-hot of length 4.
        """
        if len(block) != 4 or block.count('1') != 1:
            raise ValueError(f"Invalid block {block}, must be one-hot length 4.")
        axis = ['1000','0100','0010','0001'].index(block)
        if cls.site_parity(i) == 'A':
            table = cls._A
        else:
            table = cls._B
        return table[axis]

    @classmethod
    def bitstring_to_turns(cls, bitstr, N):
        """
        Convert full measured bitstring (length=4*(N-3)) into list of N-1 turn vectors:
        - t1 = (+1,+1,+1)
        - t2 = (+1,-1,-1)
        - t3..t_{N-1} from blocks of 4 bits
        """
        expected = 4*(N-3)
        if len(bitstr) != expected:
            raise ValueError(f"Expected bitstring length {expected}, got {len(bitstr)}")
        turns: List[np.ndarray] = []
        # fixed first two turns
        turns.append(cls._A[0])  # t1
        turns.append(cls._A[1])  # t2
        # decode remaining
        for k, start in enumerate(range(0, expected, 4), start=3):
            block = bitstr[start:start+4]
            turns.append(cls.block_to_vec(block, k))
        return turns

    @classmethod
    def turns_to_valid_bitstring(cls, N):
        """Generate a valid one-hot encoded bitstring for testing."""
        n_variable_turns = N - 3
        bitstring = ""
        for i in range(n_variable_turns):
            # Create a one-hot block (randomly pick one position to be '1')
            block = ['0'] * 4
            block[np.random.randint(4)] = '1'
            bitstring += ''.join(block)
        return bitstring


def T(t_i: np.ndarray, t_j: np.ndarray) -> int:
    """
    Indicator: 1 if t_i and t_j lie on same lattice axis or opposite, else 0.
    """
    return int(np.array_equal(t_i, t_j) or np.array_equal(t_i, -t_j))


"""
Hamiltonian defines the cost Hamiltonian terms for a protein fold on a diamond lattice:

1. Growth Constraint (H_gc): penalises consecutive backtracking bonds.
2. Overlap Penalty (H_ov): penalises any pair of beads occupying the same site.
3. Contact Reward (H_ct): rewards non-bonded nearest neighbours.
4. Chirality Constraint (H_ch): enforces correct handedness per sublattice.
5. Interaction Stub (H_in): placeholder for shell-based interaction qubits.

Each method returns a classical float energy for the decoded turn list.
"""

class Hamiltonian:
    def __init__(
        self,
        N: int,  # Number of amino acids
        back_penalty_weight: float = 30.0,
        overlap_weight: float = 50.0,
        contact_reward: float = 10.0,
        chirality_weight: float = 40.0,
        interaction_weights=None,
    ):
        """
        Parameters:
        -----------
        N : Number of amino acids in the protein
        turns : List of N-1 3D step vectors (np.ndarray).
        back_penalty_weight : λ_back for growth constraint.
        overlap_weight : λ_ov for overlap penalty.
        contact_reward : ε for contact reward (>0).
        chirality_weight : λ_chir for chirality penalty.
        """
        self.N = N
        self.total_qubits = 4 * (N - 3)  # Fixed: added total_qubits attribute
        self.λ_back = back_penalty_weight
        self.λ_ov = overlap_weight
        self.ε = contact_reward
        self.λ_chir = chirality_weight
        self.interaction_weights = interaction_weights or {}

    def calculate_energy(self, turns):
        """Calculate total energy for given turns."""
        return (
            self.growth_constraint(turns)
            + self.overlap_penalty(turns)
            + self.contact_reward(turns)  # Fixed method name
            + self.chirality_constraint(turns)
            + self.interaction_constraint(turns)
        )

    def total_energy(self, bitstring):
        """
        Calculate energy from bitstring by first decoding to turns.
        This is what CVARVQE calls.
        """
        try:
            turns = SparseDiamondEncoder.bitstring_to_turns(bitstring, self.N)
            return self.calculate_energy(turns)
        except ValueError:
            # If bitstring is invalid, return high penalty
            return 1000.0

    def growth_constraint(self, turns):
        """
        H_gc = λ_back * Σ_{i=3..N-1} T(t_i, t_{i+1}).
        """
        H_gc = 0.0
        for i in range(2, len(turns) - 1):
            H_gc += self.λ_back * T(turns[i], turns[i + 1])
        return H_gc

    def overlap_penalty(self, turns):
        """
        H_ov = λ_ov * Σ overlap penalty for beads at same position
        """
        pos = [(0, 0, 0)]
        for t in turns:
            prev = np.array(pos[-1])
            pos.append(tuple(prev + t))

        H_ov = 0.0
        # Check for overlaps
        for i in range(len(pos)):
            for j in range(i + 1, len(pos)):
                if pos[i] == pos[j]:
                    H_ov += self.λ_ov
        return H_ov

    def contact_reward(self, turns):  # Fixed method name
        """
        H_ct = -ε * Σ contact rewards for non-bonded nearest neighbors
        """
        pos = [(0, 0, 0)]
        for t in turns:
            prev = np.array(pos[-1])
            pos.append(tuple(prev + t))

        H_ct = 0.0
        # threshold for nearest neighbor on diamond lattice
        nn_dist = np.sqrt(3)
        tol = 1e-6
        for i in range(len(pos)):
            # |i-j|>2 for non-bonded neighbors
            for j in range(i + 3, len(pos)):
                d = np.linalg.norm(np.array(pos[i]) - np.array(pos[j]))
                if abs(d - nn_dist) < tol:
                    H_ct += -self.ε
        return H_ct

    def chirality_constraint(self, turns):
        """
        H_ch = λ_chir * Σ_{i=2..N-2} indicator of wrong handedness;
        uses scalar triple product and sublattice parity.
        """
        H_ch = 0.0
        for i in range(2, len(turns) - 1):
            if i >= len(turns) - 1:
                break
            v_prev = turns[i - 1]
            v_curr = turns[i]
            v_side = turns[i + 1] if i + 1 < len(turns) else turns[i]
            triple = np.dot(np.cross(v_prev, v_curr), v_side)
            parity = (1 - (-1) ** i) // 2
            # expect triple>0 on A (parity=0), <0 on B (parity=1)
            if not ((parity == 0 and triple > 0) or (parity == 1 and triple < 0)):
                H_ch += self.λ_chir
        return H_ch

    def interaction_constraint(self, turns):
        """Interaction constraint based on distance shells."""
        H_in = 0.0
        # build bead positions
        pos = [(0, 0, 0)]
        for t in turns:
            prev = np.array(pos[-1])
            pos.append(tuple(prev + t))
        tol = 1e-6
        for shell, weight in self.interaction_weights.items():
            for i in range(len(pos)):
                for j in range(i + 3, len(pos)):
                    d = np.linalg.norm(np.array(pos[i]) - np.array(pos[j]))
                    if abs(d - shell) < tol:
                        H_in += weight
        return H_in


def create_configuration_circuit(N):
    """
    Allocate 4*(N-3) config qubits in a register, apply Hadamard to each,
    and attach classical bits for measurement.
    Returns (qc, cfg, meas).
    """
    n_cfg = 4*(N-3)
    cfg = QuantumRegister(n_cfg, 'cfg')
    meas = ClassicalRegister(n_cfg, 'meas')
    qc = QuantumCircuit(cfg, meas)
    qc.h(cfg)
    return qc, cfg, meas


def decode_measurement(counts, N):
    """
    From measurement counts, pick the most frequent bitstring and decode into turn vectors.
    """
    bitstr = max(counts, key=counts.get)
    return SparseDiamondEncoder.bitstring_to_turns(bitstr, N)


class CVARVQE:
    def __init__(self, hamiltonian, alpha=0.1):
        # self is just the initializer
        # the hamiltonian calculates the amount of energy it takes to form any given protein configuration.
        # this is important because the optimizer needs a way to evaluate how "good" or "bad" a quantum circuit's result is.
        # alpha = 0.1 means CVaR will only average the best 10% of outputs.
        self.hamiltonian = hamiltonian  # Save the energy calculator
        self.alpha = alpha              # Only use the best 10% of folds
        self.n_qubits = hamiltonian.total_qubits  # Know how many qubits to use in the circuit
        self.backend = Aer.get_backend('qasm_simulator')  # Initialize the simulator backend

    def create_ansatz(self, params):
        qc = QuantumCircuit(self.n_qubits)
        # This creates a quantum circuit with the number of qubits based on the protein

        # Step 1: Put each qubit into superposition using Hadamard gates
        for i in range(self.n_qubits):
            qc.h(i)  # Hadamard gate turns each qubit into a mix of 0 and 1

        # Step 2: Add a rotation gate to each qubit (this is how we "teach" the circuit how to fold the protein)
        param_idx = 0  # Keeps track of which parameter we're using from the list

        for i in range(self.n_qubits):
            qc.ry(params[param_idx], i)  # apply RY (Y-axis) rotation to qubit i using params[param_idx]
            param_idx += 1  # Move to the next angle in the list

        # Step 3: Add entangling gates
        for i in range(self.n_qubits - 1):  # what this is is that it's making a CNOT gate to make an entanglement between two qubits
            qc.cx(i, i + 1)  # this entanglement lets qubits share information which is important because in proteins, one fold affects nearby folds.

        # Step 4: More rotation gates
        for i in range(self.n_qubits):
            qc.ry(params[param_idx], i)  # this is a repeat of the other line of code that just rotates the qubit again after the cnot entangled gate is applied
            param_idx += 1  # the reason it's rotated again is so that the protein can fold in more complex ways.

        return qc  # Return the final circuit

    def evaluate_energy(self, params, n_shots=1000):
        # this function will be based around evaluating the energy of the circuit.
        # it takes in the list of angles from a circuit, n_shots is for how many times the circuit should run and it returns an energy score.
        qc = self.create_ansatz(params)  # this is used to build a quantum circuit with the angles from the previous function
        qc.add_register(ClassicalRegister(self.n_qubits))  # this classical register is a storage unit designed to hold classical bits, which represent binary states or bitstring of (0 and 1s)
        qc.measure_all()  # this tells the quantum circuit to measure every qubit so a bitstring is outputted like 0101 or 10101
        # basically what this says is measure all the incoming qubits then place them in the classical register

        # Use the initialized backend to run the circuit
        job = self.backend.run(qc, shots=n_shots)
        result = job.result()  # this gets the data from the run

        counts = result.get_counts()  # this gives you the output in dictionary form.
        # as an output it could like this [{'000': 50, '001': 100, '010': 250, '011': 600}] That means Bitstring 000 came up 50 times, 001 100 times etc.
        # overall it basically means get the outcome probabilities from the first circuit run, so I can loop through each bitstring and score it based on how well it folded the protein.

        energies = []  # This is a list to store energy values for each output

        for bitstring, count in counts.items():  # this loops through each result and its probability
            # binary_str = format(bitstring, f'0{self.n_qubits}b') #what this is basically doing is taking the integer bitstring and turn it into a nice, full-length binary string (like '0101') that matches the number of qubits we're using.

            # here's an example of what happens If self.n_qubits = 4, and bitstring = 3, then: format(3, '04b') → '0011'

            energy = self.hamiltonian.total_energy(bitstring)  # so with this line, what it's doing is calling the hamiltonian to calculate the energy the bitstring takes. so for example if we have a bitstring of like 0011, the energy used to make the fold is 2.5. it takes 2.5 units to hold that shape.

            energies.extend([energy] * count)  # so this is adding energy values to a list multiple times, depending on how often that folding showed up in the simulation.
            # for example Let's say this fold showed up 300 times (count = 300), You add [energy] * 300 → so this energy gets added 300 times and so now the list reflects how often each folding happened

        energies = np.array(energies)  # what this is doing is that it converts the list of energies into a numPy array so we can use numPy functions like sort() and mean().
        energies.sort()  # it sorts all of the energy values from ascending order from low to high

        cutoff_idx = int(len(energies) * self.alpha)  #
        if cutoff_idx == 0:
            cutoff_idx = 1  # Always keep at least 1

        # Average the best energies to get our CVaR value
        cvar_energy = np.mean(energies[:cutoff_idx])

        return cvar_energy  # Lower is better

    def optimize(self, initial_params=None, maxiter=50):
        """
        Try different parameter values to find the ones that give the lowest CVaR energy.

        initial_params: Starting guess. If not given, we pick random angles.
        maxiter: How many times to try new parameters.

        Returns the best parameters and their corresponding energy.
        """
        # If we don't have a starting guess, create random parameters
        if initial_params is None:
            initial_params = np.random.uniform(0, 2 * np.pi, 2 * self.n_qubits)

        # Use a classical algorithm (COBYLA) to minimize the energy
        result = minimize(
            self.evaluate_energy,         # What we're trying to minimize
            initial_params,               # Starting guess
            method='COBYLA',              # Optimization algorithm
            options={'maxiter': maxiter}  # How many steps to take
        )

        return result.x, result.fun  # Return best params and best energy



In [76]:



# Test everything
if __name__ == "__main__":
    print("this is just a test psssss")

    # Test 1: SparseDiamondEncoder with valid bitstring
    print("\n1. Testing SparseDiamondEncoder...")
    N = 6
    valid_bitstring = SparseDiamondEncoder.turns_to_valid_bitstring(N)
    print(f"Valid test bitstring: {valid_bitstring}")
    try:
        turns = SparseDiamondEncoder.bitstring_to_turns(valid_bitstring, N)
        print(f"Decoded turns: {[tuple(t) for t in turns]}")
        print("✓ SparseDiamondEncoder working!")
    except Exception as e:
        print(f"✗ SparseDiamondEncoder failed: {e}")

    # Test 2: Hamiltonian
    print("\n2. Testing Hamiltonian...")
    try:
        ham = Hamiltonian(N)
        energy = ham.total_energy(valid_bitstring)
        print(f"Total energy for test configuration: {energy}")
        print("✓ Hamiltonian working!")
    except Exception as e:
        print(f"✗ Hamiltonian failed: {e}")

    # Test 3: CVARVQE
    print("\n3. Testing CVARVQE...")
    try:
        ham = Hamiltonian(N)
        vqe = CVARVQE(ham)

        # Test circuit creation
        test_params = np.random.uniform(0, 2*np.pi, 2 * ham.total_qubits)
        circuit = vqe.create_ansatz(test_params)
        print(f"Created VQE circuit with {circuit.num_qubits} qubits")

        # Test energy evaluation
        energy = vqe.evaluate_energy(test_params, n_shots=100)
        print(f"Test energy evaluation: {energy}")
        print("✓ CVARVQE working!")

    except Exception as e:
        print(f"✗ CVARVQE failed: {e}")

    # Test 4: Quick optimization
    print("\n4. Running quick optimization...")
    try:
        ham = Hamiltonian(N=5, contact_reward=5.0)  # Smaller protein
        vqe = CVARVQE(ham, alpha=0.2)

        print("Starting optimization (this may take a moment)...")
        best_params, best_energy = vqe.optimize(maxiter=10)  # Quick test

        print(f"Best energy found: {best_energy:.2f}")
        print(f"Best parameters: {best_params[:3]}...")  # Show first 3

        # Get the best configuration
        qc = vqe.create_ansatz(best_params)
        qc.add_register(ClassicalRegister(vqe.n_qubits))
        qc.measure_all()

        job = vqe.backend.run(qc, shots=1000)
        result = job.result()
        counts = result.get_counts()
        best_bitstring = max(counts, key=counts.get)

        print(f"Best configuration bitstring: {best_bitstring}")

        # Decode and visualize
        best_turns = SparseDiamondEncoder.bitstring_to_turns(best_bitstring, ham.N)
        print(f"Best folding turns: {[tuple(t) for t in best_turns]}")

        print("✓ Optimization complete!")

    except Exception as e:
        print(f"✗ Optimization failed: {e}")


this is just a test psssss

1. Testing SparseDiamondEncoder...
Valid test bitstring: 000101001000
Decoded turns: [(np.int64(1), np.int64(1), np.int64(1)), (np.int64(1), np.int64(-1), np.int64(-1)), (np.int64(-1), np.int64(-1), np.int64(-1)), (np.int64(1), np.int64(-1), np.int64(-1)), (np.int64(1), np.int64(1), np.int64(-1))]
✓ SparseDiamondEncoder working!

2. Testing Hamiltonian...
Total energy for test configuration: 30.0
✓ Hamiltonian working!

3. Testing CVARVQE...
Created VQE circuit with 12 qubits
Test energy evaluation: 1000.0
✓ CVARVQE working!

4. Running quick optimization...
Starting optimization (this may take a moment)...
Best energy found: 1000.00
Best parameters: [3.91026778 2.81393351 1.6452246 ]...
Best configuration bitstring: 01101001 00000000
✗ Optimization failed: Expected bitstring length 8, got 17
