In [1]:
!source ~/anaconda3/etc/profile.d/conda.sh

In [2]:
!conda create -n torch_env python=3.10 -y --quiet
# !conda run -n torch_env pip install torch --index-url https://download.pytorch.org/whl/cpu

!conda run -n torch_env pip install torch==2.3.0 torchtext==0.18.0 torchquantum

!conda run -n torch_env pip install amazon-braket-sdk amazon-braket-pennylane-plugin
!conda run -n torch_env pip install ipykernel
!conda run -n torch_env python -m ipykernel install --user --name torch_env --display-name "PyTorch Env"

Channels:
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... done

## Package Plan ##

  environment location: /home/ec2-user/anaconda3/envs/torch_env

  added / updated specs:
    - python=3.10


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    zstd-1.5.7                 |       h3691f8a_5         537 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         537 KB

The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge 
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu 
  bzip2              conda-forge/linux-64::bzip2-1.0.8-hda65f42_8 
  ca-certificates    conda-forge/noarch::ca-certificates-2025.11.12-hbd8a1cb_0 
  ld_impl_linux-64   conda-forge

In [3]:
import sys
print(sys.executable)

/home/ec2-user/anaconda3/envs/torch_env/bin/python


In [4]:
## ! /home/ec2-user/anaconda3/envs/Braket/bin/python -m pip install torch

In [5]:
# quixer_braket_sdk.py
import itertools
import math
from typing import Any, Dict, Tuple

import numpy as np
import torch
from torch.types import Device

from braket.circuits import Circuit
from braket.devices import LocalSimulator


In [6]:
# -------------------------
# 1) Ansatz 14 -> Braket Circuit (numeric params)
# -------------------------
def ansatz_14_braket_circuit(n_qubits: int, params: torch.Tensor, layers: int = 1) -> Circuit:
    """
    Build a Braket Circuit implementing ansatz 14 using *numeric* params.
    params: 1D tensor length = 4 * n_qubits * layers
    """
    if params.numel() != 4 * n_qubits * layers:
        raise ValueError(f"params length must be {4*n_qubits*layers}, got {params.numel()}")

    circ = Circuit()
    p_iter = iter(params.tolist())

    for _ in range(layers):
        # First RY layer
        for q in range(n_qubits):
            angle = next(p_iter)
            circ.ry(q, angle)

        # First CRX layer (reverse order)
        # Decompose CRX(control,target,theta) as:
        # Rx(target, theta/2) - CNOT(control,target) - Rx(target, -theta/2) - CNOT(control,target)
        for i in range(n_qubits - 1, -1, -1):
            control = i
            target = (i + 1) % n_qubits
            theta = next(p_iter)
            circ.rx(target, theta / 2.0)
            circ.cnot(control, target)
            circ.rx(target, -theta / 2.0)
            circ.cnot(control, target)

        # Second RY layer
        for q in range(n_qubits):
            angle = next(p_iter)
            circ.ry(q, angle)

        # Second CRX layer (different neighbor direction)
        order = [n_qubits - 1] + list(range(n_qubits - 1))
        for i in order:
            control = i
            target = (i - 1) % n_qubits
            theta = next(p_iter)
            circ.rx(target, theta / 2.0)
            circ.cnot(control, target)
            circ.rx(target, -theta / 2.0)
            circ.cnot(control, target)

    return circ


In [7]:
# -------------------------
# 2) Simulator helpers: statevector and unitary construction with caching
# -------------------------
class BraketUnitaryCache:
    """
    Cache mapping (n_qubits, params_tuple, layers) -> torch.Tensor (dim x dim complex)
    """
    def __init__(self):
        self._cache: Dict[Tuple[int, Tuple[float, ...], int], torch.Tensor] = {}

    def get(self, n_qubits: int, params: torch.Tensor, layers: int):
        key = (n_qubits, tuple(float(x) for x in params.tolist()), layers)
        return self._cache.get(key, None)

    def set(self, n_qubits: int, params: torch.Tensor, layers: int, U: torch.Tensor):
        key = (n_qubits, tuple(float(x) for x in params.tolist()), layers)
        self._cache[key] = U


_unitary_cache = BraketUnitaryCache()


def circuit_statevector(circ: Circuit, n_qubits: int, device: LocalSimulator) -> np.ndarray:
    """
    Run a Braket circuit (ending with state_vector() result type) on the LocalSimulator
    and return the resulting statevector as a numpy complex array of shape (2**n_qubits,).
    """
    # Append the state_vector result request
    circ_with_sv = circ.state_vector()
    # shots=0 requests exact statevector from LocalSimulator
    task = device.run(circ_with_sv, shots=0)
    result = task.result()
    # Braket returns result_types list whose first entry was the StateVector with `.value`
    # This access pattern is consistent with SDK examples: result.result_types[0].value
    sv = np.array(result.result_types[0].value, dtype=np.complex128)
    return sv


def circuit_unitary(circ: Circuit, n_qubits: int, device: LocalSimulator, cache: BraketUnitaryCache = None) -> torch.Tensor:
    """
    Construct the numeric unitary for a numeric circuit by running it on each computational basis input.
    Returns a torch.complex64 tensor shape (dim, dim).
    Uses caching if provided.
    """
    dim = 2**n_qubits
    cols = []
    for k in range(dim):
        prep = Circuit()
        # prepare |k> by applying X on qubits with bit=1
        for q in range(n_qubits):
            if ((k >> q) & 1) == 1:
                prep.x(q)
        total = prep + circ
        sv = circuit_statevector(total, n_qubits, device)
        cols.append(sv)
    U = np.column_stack(cols)  # shape (dim, dim)
    return torch.tensor(U, dtype=torch.complex64)


def get_or_build_unitary(n_qubits: int, params: torch.Tensor, layers: int, device: LocalSimulator, cache: BraketUnitaryCache):
    """
    Return cached unitary or build and cache it.
    """
    if cache is not None:
        cached = cache.get(n_qubits, params, layers)
        if cached is not None:
            return cached

    circ = ansatz_14_braket_circuit(n_qubits, params, layers=layers)
    U = circuit_unitary(circ, n_qubits, device, cache=cache)

    if cache is not None:
        cache.set(n_qubits, params, layers, U)
    return U

In [8]:
# -------------------------
# 3) LCU & QSVT implementations (classical simulation)
# -------------------------
# def apply_linear_combination_of_unitaries_braket(
#     initial_states: torch.Tensor,  # [batch, dim]
#     pqc_parameters: torch.Tensor,  # [batch, n_tokens, n_pqc_params]
#     n_qubits: int,
#     lcu_coefficients: torch.Tensor,  # [n_tokens] complex
#     device: LocalSimulator,
#     layers: int = 1,
#     cache: BraketUnitaryCache = None,
# ) -> torch.Tensor:
#     """
#     Apply LCU: sum_w b_w U(w) |psi>  where U(w) is unitary given by PQC params for token w.
#     Returns [batch, dim] complex tensor.
#     """

#     batch = initial_states.shape[0]
#     n_tokens = pqc_parameters.shape[1]
#     dim = 2**n_qubits

#     # Flatten parameter sets across batch x tokens -> (batch*n_tokens, n_params)
#     flat_params = pqc_parameters.view(-1, pqc_parameters.shape[-1])

#     # Build or fetch unitaries for each flattened param vector
#     unitaries = []
#     for params_vec in flat_params:
#         U = get_or_build_unitary(n_qubits, params_vec, layers, device, cache)
#         unitaries.append(U)
#     # stack and reshape to [batch, n_tokens, dim, dim]
#     unitaries = torch.stack(unitaries, dim=0).view(batch, n_tokens, dim, dim)

#     # Expand initial states to [batch, n_tokens, dim]
#     expanded = initial_states.unsqueeze(1).expand(-1, n_tokens, -1)
#     # apply unitaries: [batch, n_tokens, dim] = einsum btij,btj->bti
#     evolved = torch.einsum("btij,btj->bti", unitaries, expanded)

#     # ensure lcu_coefficients complex dtype and shape [n_tokens]
#     lcu_coeffs = lcu_coefficients.to(torch.complex64).view(1, n_tokens, 1)
#     weighted = evolved * lcu_coeffs
#     summed = weighted.sum(dim=1)  # [batch, dim]
#     return summed

def apply_linear_combination_of_unitaries_braket(
    initial_states: torch.Tensor,       # [batch, dim]
    pqc_parameters: torch.Tensor,       # [batch, n_tokens, n_pqc_params] or [batch, n_pqc_params]
    n_qubits: int,
    lcu_coefficients: torch.Tensor,     # either [n_tokens] or [batch] complex
    device: LocalSimulator,
    layers: int = 1,
    cache: BraketUnitaryCache = None,
) -> torch.Tensor:
    batch, dim = initial_states.shape

    # Ensure pqc_parameters has 3 dims: [batch, n_tokens, n_params]
    if pqc_parameters.ndim == 2:
        pqc_parameters = pqc_parameters.unsqueeze(1)  # [batch, 1, n_params]

    batch_check, n_tokens, n_params = pqc_parameters.shape
    assert batch_check == batch, "Batch size mismatch"

    # Flatten parameters for unitary generation
    flat_params = pqc_parameters.reshape(-1, n_params)  # [batch*n_tokens, n_params]

    # Generate unitaries
    unitaries = [get_or_build_unitary(n_qubits, params, layers, device, cache) 
                 for params in flat_params]
    unitaries = torch.stack(unitaries, dim=0).reshape(batch, n_tokens, 2**n_qubits, 2**n_qubits)

    # Expand initial states to [batch, n_tokens, dim]
    expanded_states = initial_states.unsqueeze(1).expand(batch, n_tokens, 2**n_qubits)

    # Apply unitaries
    evolved_states = torch.einsum("btij,btj->bti", unitaries, expanded_states)

    # Handle lcu_coefficients: either [n_tokens] or [batch]
    if lcu_coefficients.numel() == n_tokens:
        lcu_coeffs = lcu_coefficients.to(torch.complex64).view(1, n_tokens, 1)
    elif lcu_coefficients.numel() == batch:
        lcu_coeffs = lcu_coefficients.to(torch.complex64).view(batch, 1, 1)
    else:
        raise ValueError(f"Unexpected lcu_coefficients size {lcu_coefficients.shape}")

    weighted_states = evolved_states * lcu_coeffs
    summed_states = weighted_states.sum(dim=1)  # sum over tokens
    return summed_states


def apply_qsvt_and_lcu_braket(
    initial_states: torch.Tensor,
    pqc_parameters: torch.Tensor,
    n_qubits: int,
    lcu_coefficients: torch.Tensor,
    qsvt_polynomial_coefficients: torch.Tensor,
    device: LocalSimulator,
    layers: int = 1,
    cache: BraketUnitaryCache = None,
) -> torch.Tensor:
    """
    Apply polynomial via repeated LCU application (classical-simulated QSVT).
    """
    accumulated_state = qsvt_polynomial_coefficients[0] * initial_states
    monomial_state = initial_states
    for c in qsvt_polynomial_coefficients[1:]:
        monomial_state = apply_linear_combination_of_unitaries_braket(
            monomial_state,
            pqc_parameters,
            n_qubits,
            lcu_coefficients,
            device,
            layers=layers,
            cache=cache,
        )
        accumulated_state = accumulated_state + c * monomial_state

    norm = torch.linalg.vector_norm(qsvt_polynomial_coefficients, ord=1)
    return accumulated_state / norm


In [9]:
# -------------------------
# 4) Pauli expectation helpers (from statevector)
# -------------------------
# def single_qubit_reduced_density(statevec: torch.Tensor, n_qubits: int, target: int) -> torch.Tensor:
#     """
#     Compute reduced 2x2 density matrix for a single qubit `target` by reshaping and tracing out others.
#     statevec: [dim] complex torch tensor
#     """
#     dim = 2**n_qubits
#     rho = torch.outer(statevec, torch.conj(statevec))  # [dim, dim]
#     # reshape to (2,)*n_qubits x (2,)*n_qubits and permute target to front
#     shape = [2] * (2 * n_qubits)
#     rho_reshaped = rho.view(*shape)
#     axes = list(range(n_qubits))
#     perm = [target] + [i for i in axes if i != target]
#     perm_full = perm + [p + n_qubits for p in perm]
#     rho_perm = rho_reshaped.permute(*perm_full)
#     rho2 = rho_perm.contiguous().view(2, 2, -1, -1)
#     reduced = torch.einsum("aabb->ab", rho2)  # trace out the other systems
#     return reduced

def single_qubit_reduced_density(statevec: torch.Tensor, n_qubits: int, target: int) -> torch.Tensor:
    """
    Compute reduced 2x2 density matrix for a single qubit `target` by tracing out other qubits.
    """
    # Full density matrix
    rho = torch.outer(statevec, torch.conj(statevec))  # [2^n, 2^n]

    # Reshape to (2,2,...,2) for n_qubits input & output axes
    shape = [2] * (2 * n_qubits)
    rho_reshaped = rho.view(*shape)

    # Permute target qubit to the front for both input/output axes
    axes = list(range(n_qubits))
    perm = [target] + [i for i in axes if i != target]  # input axes
    perm_full = perm + [p + n_qubits for p in perm]     # output axes
    rho_perm = rho_reshaped.permute(*perm_full)

    # Compute dimensions of "rest" qubits
    rest_dim = 2**(n_qubits - 1)
    rho2 = rho_perm.contiguous().view(2, rest_dim, 2, rest_dim)

    # Trace out other qubits
    reduced = torch.einsum("abab->ab", rho2)
    return reduced



def measure_all_x_y_z_from_statevector(states: torch.Tensor, n_qubits: int, device: Device) -> torch.Tensor:
    """
    states: [batch, dim] complex
    returns: [batch, 3*n_qubits] real expectations (order: for each qubit X,Y,Z)
    """
    sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64, device=device)
    sy = torch.tensor([[0, -1j], [1j, 0]], dtype=torch.complex64, device=device)
    sz = torch.tensor([[1, 0], [0, -1]], dtype=torch.complex64, device=device)

    batch = states.shape[0]
    res = []
    for b in range(batch):
        state = states[b]
        evs = []
        for q in range(n_qubits):
            reduced = single_qubit_reduced_density(state, n_qubits, q)
            ev_x = torch.trace(reduced @ sx).real
            ev_y = torch.trace(reduced @ sy).real
            ev_z = torch.trace(reduced @ sz).real
            evs.extend([ev_x, ev_y, ev_z])
        res.append(torch.stack(evs))
    return torch.stack(res, dim=0)  # [batch, 3*n_qubits]


In [10]:
# -------------------------
# 5) QuixerBraket model (torch.nn.Module)
# -------------------------
class Quixer(torch.nn.Module):
    def __init__(
        self,
        n_qubits: int,
        n_tokens: int,
        qsvt_polynomial_degree: int,
        n_ansatz_layers: int,
        vocabulary_size: int,
        embedding_dimension: int,
        dropout: float,
        batch_size: int,
        device: Device,
        braket_device: LocalSimulator = None,
    ):
        super().__init__()
        self.n_tokens = n_tokens
        self.n_qubits = n_qubits
        if qsvt_polynomial_degree <= 0:
            raise ValueError("qsvt_polynomial_degree must be > 0")
        self.degree = qsvt_polynomial_degree
        self.device = device
        self.n_ansatz_layers = n_ansatz_layers

        if n_tokens == 0:
            raise ValueError("n_tokens must be non-zero")
        self.n_ctrl_qubits = int(math.log2(n_tokens))

        self.n_pqc_parameters = 4 * n_qubits * n_ansatz_layers
        self.embedding_dimension = embedding_dimension

        # Embedding & angle mapper
        self.embedding = torch.nn.Embedding(vocabulary_size, self.embedding_dimension)
        torch.nn.init.xavier_uniform_(self.embedding.weight)
        self.embedding_to_angles = torch.nn.Linear(self.embedding_dimension, self.n_pqc_parameters)
        self.dropout = torch.nn.Dropout(dropout)

        # Braket device (LocalSimulator by default)
        self.braket_device = braket_device if braket_device is not None else LocalSimulator()

        # Parameters (trainable)
        self.qsvt_polynomial_coefficients = torch.nn.Parameter(torch.rand(self.degree + 1))
        # complex LCU coefficients: store as two real tensors and combine when needed,
        # but for simplicity store as complex dtype parameter (PyTorch supports it).
        self.lcu_coefficients = torch.nn.Parameter(torch.rand(n_tokens, dtype=torch.complex64))

        # Final PQC parameters
        self.quantum_feedforward_parameters = torch.nn.Parameter(torch.rand(self.n_pqc_parameters))

        self.nr_of_measurements = 3 * n_qubits
        self.output_feedforward = torch.nn.Sequential(
            torch.nn.Linear(self.nr_of_measurements, self.embedding_dimension),
            torch.nn.ReLU(),
            torch.nn.Linear(self.embedding_dimension, vocabulary_size),
        )

        # unitary cache
        self.unitary_cache = _unitary_cache

        self.output_layer = torch.nn.Linear(2**self.n_qubits, vocabulary_size)

    # def forward(self, x: torch.Tensor):
    #     """
    #     x: [batch, n_tokens] integer token ids
    #     returns: (output_logits [batch, vocab_size], mean_postselection_prob scalar)
    #     """
    #     batch_size = x.shape[0]

    #     # LCU coefficients repeated per batch and normalized (1-norm)
    #     lcu_coefficients = self.lcu_coefficients.repeat(batch_size, 1)
    #     lcu_coefficients = torch.nn.functional.normalize(lcu_coefficients, p=1, dim=1)

    #     # Get embeddings & PQC angles: [batch, n_tokens, n_pqc_parameters]
    #     x_emb = self.embedding(x)
    #     pqc_angles = self.embedding_to_angles(self.dropout(x_emb))

    #     # Initial |0> states for each batch
    #     initial_states = torch.zeros(batch_size, 2**self.n_qubits, dtype=torch.complex64, device=self.device)
    #     initial_states[:, 0] = 1.0 + 0.0j

    #     # Apply QSVT + LCU (classical simulation using Braket)
    #     # Note: LCU coefficients in apply_qsvt_and_lcu_braket expect shape [n_tokens] complex;
    #     # in our implementation lcu_coefficients is [batch, n_tokens] so pick first row since identical per batch
        # qsvt_lcu_state = apply_qsvt_and_lcu_braket(
        #     initial_states,
        #     pqc_angles,
        #     self.n_qubits,
        #     lcu_coefficients[0],  # assume same across batch (like your original code)
        #     self.qsvt_polynomial_coefficients,
        #     self.braket_device,
        #     layers=self.n_ansatz_layers,
        #     cache=self.unitary_cache,
        # )
        
        # # Normalize and set as current state for final feedforward PQC application
        # norms = torch.linalg.vector_norm(qsvt_lcu_state, dim=-1, keepdim=True)
        # normalized_states = qsvt_lcu_state / (norms + 1e-12)

        # # Build unitary for final feedforward PQC (single param vector)
        # feed_U = get_or_build_unitary(self.n_qubits, self.quantum_feedforward_parameters, 1, self.braket_device, self.unitary_cache)
        # # Apply (matrix-vector): feed_U @ psi  (feed_U has shape [dim, dim])
        # # Note on orientation: our `circuit_unitary` built columns as U |k> so standard matrix.
        # final_states = torch.einsum("ij,bj->bi", feed_U, normalized_states)

        # # Measure expectations (X,Y,Z per qubit) from final_states
        # expectation_values = measure_all_x_y_z_from_statevector(final_states, self.n_qubits, self.device)

        # output_logits = self.output_feedforward(expectation_values)

        # # Postselection probabilities (norms before normalization)
        # final_probabilities = norms.squeeze(-1)  # [batch]
        # return output_logits, torch.mean(final_probabilities)

    # def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    #     """
    #     Forward pass for Quixer model.
        
    #     Args:
    #         x: [batch, window_size] of token indices
    
    #     Returns:
    #         logits: [batch, vocab_size]
    #         mean_postselection_prob: scalar, average norm after LCU
    #     """
    #     # Ensure x is 2D: [batch, seq_len]
    #     if x.dim() == 1:
    #         x = x.unsqueeze(0)
    #     batch, seq_len = x.shape
    
    #     # 1. Embeddings
    #     x_emb = self.embedding(x)  # [batch, seq_len, embedding_dim]
    #     pqc_angles = self.embedding_to_angles(self.dropout(x_emb))  # [batch, seq_len, n_pqc_params]
    
    #     # 2. Prepare initial quantum states
    #     dim = 2 ** self.n_qubits
    #     initial_states = torch.zeros(batch, dim, dtype=torch.complex64, device=x.device)
    #     initial_states[:, 0] = 1.0 + 0.0j  # |0> for each batch

    #     # 3. LCU coefficients
    #     # Make 2D: [1, n_tokens] for normalization, then expand to batch
    #     lcu_coeffs = self.lcu_coefficients.unsqueeze(0)  # [1, n_tokens]
    #     lcu_coeffs = torch.nn.functional.normalize(lcu_coeffs, p=1, dim=1)  # normalize across tokens
    #     lcu_coeffs = lcu_coeffs.expand(batch, -1)  # [batch, n_tokens]
    
    #     # 4. Apply QSVT + LCU (classical simulation)
    #     qsvt_lcu_state = apply_qsvt_and_lcu_braket(
    #         initial_states,
    #         pqc_angles,
    #         self.n_qubits,
    #         lcu_coeffs,
    #         self.qsvt_polynomial_coefficients,
    #         self.braket_device,
    #         layers=self.n_ansatz_layers,
    #         cache=self.unitary_cache,
    #     )
    
    #     # 5. Normalize LCU-evolved states and compute average norm
    #     norms = torch.linalg.vector_norm(qsvt_lcu_state, dim=-1, keepdim=True)  # [batch, 1]
    #     normalized_states = qsvt_lcu_state / (norms + 1e-12)
    #     mean_postselection_prob = norms.mean()

    #     # 6. Apply final feedforward PQC
    #     feed_U = get_or_build_unitary(
    #         self.n_qubits,
    #         self.quantum_feedforward_parameters,
    #         1,
    #         self.braket_device,
    #         self.unitary_cache,
    #     )  # [dim, dim]
    
    #     final_states = torch.einsum("ij,bj->bi", feed_U, normalized_states)  # [batch, dim]
    
    #     # 7. Classical projection to vocab logits
    #     # logits = self.output_layer(final_states)  # [batch, vocab_size]
    #     logits = self.output_layer(final_states.real)
    
    #     return logits, mean_postselection_prob

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for Quixer model.
    
        Args:
            x: [batch, seq_len] of token indices
    
        Returns:
            logits: [batch, vocab_size]
            mean_postselection_prob: scalar tensor
        """
        batch = x.shape[0]
    
        # 1. Embedding lookup
        x_emb = self.embedding(x)  # [batch, seq_len, embedding_dim]
        pqc_angles = self.embedding_to_angles(self.dropout(x_emb))  # [batch, seq_len, n_pqc_params]
    
        # 2. Prepare initial quantum states |0>
        dim = 2**self.n_qubits
        initial_states = torch.zeros(batch, dim, dtype=torch.complex64, device=x.device)
        initial_states[:, 0] = 1.0 + 0.0j
    
        # 3. Compute LCU coefficients
        # If single set per model, expand to batch
        lcu_coeffs = torch.nn.functional.normalize(
            self.lcu_coefficients, p=1, dim=-1
        )  # [n_tokens] or [1, n_tokens]
        if lcu_coeffs.dim() == 1:
            lcu_coeffs = lcu_coeffs.unsqueeze(0).expand(batch, -1)  # [batch, n_tokens]
    
        # 4. Apply QSVT + LCU
        qsvt_lcu_state = apply_qsvt_and_lcu_braket(
            initial_states,
            pqc_angles,
            self.n_qubits,
            lcu_coeffs[0],  # assume same across batch
            self.qsvt_polynomial_coefficients,
            self.braket_device,
            layers=self.n_ansatz_layers,
            cache=self.unitary_cache,
        )
    
        # 5. Normalize state
        norms = torch.linalg.vector_norm(qsvt_lcu_state, dim=-1, keepdim=True)
        normalized_states = qsvt_lcu_state / (norms + 1e-12)
        mean_postselection_prob = norms.mean()
    
        # 6. Apply final feedforward PQC
        feed_U = get_or_build_unitary(
            self.n_qubits,
            self.quantum_feedforward_parameters,
            layers=self.n_ansatz_layers,
            device=self.braket_device,
            cache=self.unitary_cache,
        )  # [dim, dim]
    
        # Apply to normalized states
        final_states = torch.einsum("ij,bj->bi", feed_U, normalized_states)  # [batch, dim]
    
        # 7. Classical projection to vocab logits
        # Convert complex to real for nn.Linear
        logits = self.output_layer(final_states.real)  # [batch, vocab_size]
    
        return logits, mean_postselection_prob



In [11]:
# -------------------------
# 6) Small example usage
# -------------------------
if __name__ == "__main__":
    # small example: 2 qubits, 2 tokens, tiny vocab
    n_qubits = 2
    n_tokens = 2
    qsvt_degree = 2
    n_layers = 1
    vocab = 10
    emb_dim = 8
    dropout = 0.0
    batch_size = 1
    device = torch.device("cpu")

    model = Quixer(
        n_qubits=n_qubits,
        n_tokens=n_tokens,
        qsvt_polynomial_degree=qsvt_degree,
        n_ansatz_layers=n_layers,
        vocabulary_size=vocab,
        embedding_dimension=emb_dim,
        dropout=dropout,
        batch_size=batch_size,
        device=device,
    )

    # dummy input: batch of 1, sequence length n_tokens
    x = torch.randint(0, vocab, (batch_size, n_tokens))
    logits, mean_prob = model(x)
    print("logits:", logits)
    print("mean postselection prob:", mean_prob)

logits: tensor([[-0.0835, -0.4767, -0.4800, -0.4976, -0.5488,  0.7021,  0.6041,  0.0558,
         -0.1709, -0.2785]], grad_fn=<AddmmBackward0>)
mean postselection prob: tensor(0.7684, grad_fn=<MeanBackward0>)


In [12]:
import sys
print(sys.executable)

/home/ec2-user/anaconda3/envs/torch_env/bin/python


In [13]:
# !pip cache purge

In [14]:
# !{sys.executable} -m pip uninstall -y torchtext
# !pip uninstall -y torch

In [15]:
import sys
# !{sys.executable} -m pip install "torchtext==0.17.1"
# !{sys.executable} -m pip install --force-reinstall "torch==2.9.1" "torchtext==0.17.1" --index-url https://download.pytorch.org/whl/cpu

# !conda run -n torch_env pip uninstall -y torch torchtext
# !conda run -n torch_env pip uninstall -y torch torchtext
# !conda run -n torch_env pip install torch==2.3.0 torchtext==0.18.0 amazon-braket-sdk torchquantum

In [16]:
!conda run -n torch_env pip install datasets

Collecting datasets
  Using cached datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Using cached pyarrow-22.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Collecting pandas (from datasets)
  Using cached pandas-2.3.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting httpx<1.0.0 (from datasets)
  Using cached httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.6.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Collecting huggingface-hub<2.0,>=0.25.0 (from datasets)
  Downloading huggingface_hub-1.1.7-py3-none-any.whl.metadata (13 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets)
  Using cached aiohttp-3.13.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (8.1 kB)
Collecting anyio (from httpx<1.0.0->datas

In [17]:
# !pip install torchtext
!conda run -n torch_env pip install torchdata

Collecting torchdata
  Using cached torchdata-0.11.0-py3-none-any.whl.metadata (6.3 kB)
Using cached torchdata-0.11.0-py3-none-any.whl (61 kB)
Installing collected packages: torchdata
Successfully installed torchdata-0.11.0



In [18]:
########### Quixer/quixer/setuptraining.py
import random
import os
import time
import math
from tqdm import tqdm
from typing import Any, Optional, Tuple, Callable
from pathlib import Path

import numpy as np
import random

import torch
from torch.types import Device
from torch.nn.modules.loss import _Loss
import torchtext

# from quixer.quixer_model import Quixer
# from quixer.baseline_models import Transformer, LSTM, FNet
from torchtext.datasets import PennTreebank
from collections import Counter

from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer


def epoch_time(start_time: float, end_time: float) -> Tuple[float, float]:
    """
    Computes time elapsed in minutes and seconds when given two UNIX timestamps
    with the starting time and ending time.

    Args:
      start_time: Starting time as a UNIX timestamp.
      end_time: End time as a UNIX timestamp.
    """
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


def batchify_s2s(
    data: torch.Tensor,
    batch_size: int,
    window_size: int,
    pad_token_id: int,
    device: Device,
) -> torch.Tensor:
    """
    Takes in a sequence of token IDs as a torch tensor `data` and returns a torch tensor containing
    the training data with shape `[number of batches + window_size, batch_size]`.

    Each batch is represented by `window_size` contiguous rows in the returned tensor and
    can be extracted using the `get_batch_s2s` function.

    A sequence of pad tokens of length `window_size-1` is prepended to the data so as to
    provide a context window for the first token.

    Args:
      data: A 1D torch tensor containing a sequence of token IDs.
      batch_size: The number of sequences each batch should have.
      window_size: How many tokens are considered in each context window (each of which is a sequence in the batch).
      pad_token_id: The ID of the pad token, as supplied by the tokenizer.
      device: Torch device the returned tensor is to be created on.

    Returns:
      Tensor containing data for each batch prepared for a next token prediction language
      modelling task.
    """
    batch_nr_of_elements = batch_size * window_size
    nr_of_batches = (data.size(0) - 1) // batch_nr_of_elements

    # Discard tokens at the end of the data that do not fill a whole batch
    batched_data = (
        data[: nr_of_batches * batch_nr_of_elements]
        .view(batch_nr_of_elements, nr_of_batches)
        .T
    )

    # Data for the first batch
    window_data = torch.cat(
        (
            # Adds a sequence of pad tokens of length `window_size-1`
            # to provide a context window for the first token.
            torch.full((window_size, 1), pad_token_id, device=device),
            # Context for the first row of tokens in `batched_data`
            batched_data[-window_size:, :-1],
        ),
        dim=1,
    )

    return torch.cat((window_data, batched_data))


def get_batch_s2s(
    source: torch.Tensor, i: int, window_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the `i`th batch; expects one of the tensors returned by `setup_dataset`.

    Args:
      source: Tensor containing data.
      i: Index of the batch.
      window_size: Context window size.
    Returns:
      The `i`th batch.
    """
    return source[i : i + window_size].T, source[i + window_size]


def initialise_weights(model: torch.nn.Module) -> None:
    """
    Initialises model weights.
    """

    def _init_weights(m):
        if type(m) == torch.nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, torch.nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    model.apply(_init_weights)


# def setup_dataset(
#     device: Device, batch_size: int, window_size: int
# ) -> Tuple[torchtext.vocab.Vocab, Tuple[torch.Tensor, torch.Tensor, torch.Tensor], int]:
#     """
#     Downloads and tokenizes the Penn TreeBank dataset, and then sets it up for a
#     next-word prediction task.

#     Args:
#       device: Device to store dataset on.
#       batch_size: Size of the batches.
#       window_size: Size of the context window.

#     Returns:
#       Vocabulary represented by a torchtext.vocab.Vocab instance along with
#       three torch tensors containing the training, validation and test data.
#     """

#     # Download dataset from the Hugging Face Hub / load dataset
#     # raw_dset = load_dataset("ptb_text_only")
#     raw_dset = load_dataset("ptb_text_only", "penn_treebank")

#     # Get training data in PyArrow format
#     train_iter = raw_dset["train"].data[0]
#     # Convert from arrow array to native Python list
#     train_iter = [s.as_py() for s in train_iter]

#     # Get torchtext tokenizer
#     tokenizer = get_tokenizer("basic_english")

#     vocab = build_vocab_from_iterator(
#         map(tokenizer, train_iter), specials=["<pad>", "<unk>", "<eos>"]
#     )
#     # Define unknown word as the default index to use
#     vocab.set_default_index(vocab["<unk>"])

#     def data_process(raw_text_iter) -> torch.Tensor:
#         """
#         Converts raw text into a flat Tensor of token indices.
#         """
#         data = [
#             torch.tensor(vocab(tokenizer(item)) + [vocab["eos"]], dtype=torch.long)
#             for item in raw_text_iter
#         ]
#         return torch.cat(tuple(filter(lambda t: t.numel() > 1, data))).to(device)

#     # Convert from arrow arrays to native Python lists
#     train_sents = [s.as_py() for s in raw_dset["train"].data[0]]
#     val_sents = [s.as_py() for s in raw_dset["validation"].data[0]]
#     test_sents = [s.as_py() for s in raw_dset["test"].data[0]]

#     # Flatten datasets into one long tokenised string each
#     train_flat = data_process(train_sents)
#     val_flat = data_process(val_sents)
#     test_flat = data_process(test_sents)

#     # Get padding token
#     PAD_TOKEN = vocab["<pad>"]

#     # Prepare data for a next-token prediction language modelling task
#     train_iter = batchify_s2s(train_flat, batch_size, window_size, PAD_TOKEN, device)
#     val_iter = batchify_s2s(val_flat, batch_size, window_size, PAD_TOKEN, device)
#     test_iter = batchify_s2s(test_flat, batch_size, window_size, PAD_TOKEN, device)

#     return vocab, (train_iter, val_iter, test_iter), PAD_TOKEN

# def setup_dataset(device, batch_size, window_size):
#     # Load modern HuggingFace PTB
#     # raw = load_dataset("ptb")
#     # raw = load_dataset("ptb_text_only", "penn_treebank")
#     raw = load_dataset("penn_treebank")

#     # Extract text
#     train_texts = [x["sentence"] for x in raw["train"]]
#     val_texts   = [x["sentence"] for x in raw["validation"]]
#     test_texts  = [x["sentence"] for x in raw["test"]]

#     # Tokenize simply by splitting
#     train_tokens = [t.split() for t in train_texts]
#     val_tokens   = [t.split() for t in val_texts]
#     test_tokens  = [t.split() for t in test_texts]

#     # Build vocab
#     counter = Counter()
#     for sent in train_tokens:
#         counter.update(sent)

#     vocab = {"<pad>": 0, "<unk>": 1}
#     for tok, _ in counter.most_common():
#         vocab[tok] = len(vocab)

#     PAD_TOK = vocab["<pad>"]

#     # Numericalize
#     def numericalize(tokens):
#         return torch.tensor([vocab.get(tok, 1) for tok in tokens], dtype=torch.long)

#     train_tensors = [numericalize(t) for t in train_tokens]
#     val_tensors   = [numericalize(t) for t in val_tokens]
#     test_tensors  = [numericalize(t) for t in test_tokens]

#     # Simple batching (same as legacy PTB examples)
#     def make_batches(seqs):
#         batches = []
#         for seq in seqs:
#             # chop into windows
#             for i in range(0, len(seq) - window_size):
#                 X = seq[i:i+window_size]
#                 y = seq[i+1:i+window_size+1]
#                 batches.append((X, y))
#         return batches

#     train_iter = make_batches(train_tensors)
#     val_iter   = make_batches(val_tensors)
#     test_iter  = make_batches(test_tensors)

#     return vocab, (train_iter, val_iter, test_iter), PAD_TOK

def setup_dataset(device, batch_size, window_size, ptb_dir="/home/ec2-user/SageMaker/quixer"):

    # Load raw PTB text from local files
    def load_file(name):
        path = os.path.join(ptb_dir, name)
        with open(path, "r") as f:
            return [line.strip() for line in f if line.strip()]

    train_texts = load_file("ptb.train.txt")
    val_texts = load_file("ptb.valid.txt")
    test_texts = load_file("ptb.test.txt")

    # Tokenize simply by splitting
    train_tokens = [t.split() for t in train_texts]
    val_tokens   = [t.split() for t in val_texts]
    test_tokens  = [t.split() for t in test_texts]

    # Build vocab
    counter = Counter()
    for sent in train_tokens:
        counter.update(sent)

    vocab = {"<pad>": 0, "<unk>": 1}
    for tok, _ in counter.most_common():
        vocab[tok] = len(vocab)

    PAD_TOK = vocab["<pad>"]

    # Numericalize
    def numericalize(tokens):
        return torch.tensor([vocab.get(tok, 1) for tok in tokens], dtype=torch.long)

    train_tensors = [numericalize(t) for t in train_tokens]
    val_tensors   = [numericalize(t) for t in val_tokens]
    test_tensors  = [numericalize(t) for t in test_tokens]

    # Simple batching (same as legacy PTB examples)
    def make_batches(seqs):
        batches = []
        for seq in seqs:
            # chop into windows
            for i in range(0, len(seq) - window_size):
                X = seq[i:i+window_size]
                y = seq[i+1:i+window_size+1]
                batches.append((X, y))
        return batches

    train_iter = make_batches(train_tensors)
    val_iter   = make_batches(val_tensors)
    test_iter  = make_batches(test_tensors)

    return vocab, (train_iter, val_iter, test_iter), PAD_TOK


def create_model(
    hyperparams: dict[str, Any], device: Device, vocabulary_size: int
) -> torch.nn.Module:
    """
    Selects and creates model based on hyperparameters passed.

    Args:
      hyperparams: Model hyperparameters.
      device: Device the model will be run on.
      vocabulary_size: Size of the vocabulary.
    Returns:
      An instance of a torch model based on the hyperparameters passed.
    """
    model_str = hyperparams["model"]
    model: torch.nn.Module
    if model_str == "Quixer":
        model = Quixer(
            n_qubits=hyperparams["qubits"],
            n_tokens=hyperparams["window"],
            qsvt_polynomial_degree=hyperparams["layers"],
            n_ansatz_layers=hyperparams["ansatz_layers"],
            vocabulary_size=vocabulary_size,
            embedding_dimension=hyperparams["dimension"],
            dropout=hyperparams["dropout"],
            batch_size=hyperparams["batch_size"],
            device=device,
        )
    else:
        raise ValueError(f"Unrecognized model: {model_str}")

    return model


# def train_epoch(
#     model: torch.nn.Module,
#     iterator: torch.Tensor,
#     optimizer: torch.optim.Optimizer,
#     loss_function: _Loss,
#     clip: float,
#     scheduler: Optional[torch.optim.lr_scheduler.LRScheduler],
#     window_size: int,
# ):
#     """
#     Runs training loop for one epoch.
#     """
#     model.train()

#     epoch_loss = 0

#     # n_batches = iterator.shape[0] - window_size
#     n_batches = len(iterator) - window_size

#     idxs = list(range(n_batches))
#     random.shuffle(idxs)

#     for ctr, batch_idx in tqdm(enumerate(idxs), total=n_batches):
#         x, y = get_batch_s2s(iterator, batch_idx, window_size)
#         optimizer.zero_grad()

#         yhat, norm_avg = model(x)

#         loss = loss_function(yhat, y)
#         loss.backward()

#         if clip:
#             torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

#         optimizer.step()
#         if scheduler:
#             scheduler.step()

#         epoch_loss += loss.item()

#     return epoch_loss / n_batches

# def train_epoch(
#     model: torch.nn.Module,
#     iterator,
#     optimizer: torch.optim.Optimizer,
#     loss_function: _Loss,
#     clip: float,
#     scheduler: Optional[torch.optim.lr_scheduler.LRScheduler],
#     window_size: int,
# ):
#     """
#     Runs training loop for one epoch.
#     Supports either:
#         - iterator as a 1D torch.Tensor (legacy sliding window)
#         - iterator as a list of (X, y) tuples (batch-prepared)
#     """
#     model.train()
#     epoch_loss = 0

#     # Determine if iterator is a tensor or list of tuples
#     if isinstance(iterator, torch.Tensor):
#         n_batches = len(iterator) - window_size
#         idxs = list(range(n_batches))
#         random.shuffle(idxs)

#         for ctr, batch_idx in tqdm(enumerate(idxs), total=n_batches):
#             x, y = get_batch_s2s(iterator, batch_idx, window_size)
#             optimizer.zero_grad()

#             yhat, norm_avg = model(x)
#             loss = loss_function(yhat, y)
#             loss.backward()

#             if clip:
#                 torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

#             optimizer.step()
#             if scheduler:
#                 scheduler.step()

#             epoch_loss += loss.item()

#     elif isinstance(iterator, list):
#         n_batches = len(iterator)
#         idxs = list(range(n_batches))
#         random.shuffle(idxs)

#         for ctr, batch_idx in tqdm(enumerate(idxs), total=n_batches):
#             # batch is already (X, y)
#             X, Y = iterator[batch_idx]
#             x = torch.tensor(X, dtype=torch.float32).T  # or torch.long if token IDs
#             y = torch.tensor(Y, dtype=torch.float32)    # or torch.long if classification

#             optimizer.zero_grad()
#             yhat, norm_avg = model(x)
#             loss = loss_function(yhat, y)
#             loss.backward()

#             if clip:
#                 torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

#             optimizer.step()
#             if scheduler:
#                 scheduler.step()

#             epoch_loss += loss.item()

#     else:
#         raise TypeError(f"Unsupported iterator type: {type(iterator)}")

#     return epoch_loss / n_batches

def train_epoch(
    model: torch.nn.Module,
    iterator: list[tuple[list[int], list[int]]],  # list of (X, Y) batches
    optimizer: torch.optim.Optimizer,
    loss_function: _Loss,
    clip: float,
    scheduler: Optional[torch.optim.lr_scheduler.LRScheduler],
    window_size: int,  # not used directly here since batches are pre-made
):
    """
    Runs training loop for one epoch using pre-batched sequences.

    Args:
        model: PyTorch model.
        iterator: List of (X, Y) batches.
        optimizer: Optimizer.
        loss_function: Loss function.
        clip: Gradient clipping value.
        scheduler: Learning rate scheduler (optional).
        window_size: sequence length (for legacy compatibility).
    """
    model.train()
    epoch_loss = 0

    n_batches = len(iterator)
    idxs = list(range(n_batches))
    random.shuffle(idxs)

    device = next(model.parameters()).device

    for batch_idx in tqdm(idxs, total=n_batches):
        X, Y = iterator[batch_idx]

        # Convert to tensors with correct dtype
        x = torch.tensor(X, dtype=torch.long).to(device)  # for nn.Embedding
        y = torch.tensor(Y, dtype=torch.long).to(device)  # for CrossEntropyLoss

        optimizer.zero_grad()
        yhat, norm_avg = model(x)

        loss = loss_function(yhat, y)
        loss.backward()

        if clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        if scheduler:
            scheduler.step()

        epoch_loss += loss.item()

    return epoch_loss / n_batches

def evaluate(
    model: torch.nn.Module,
    data: torch.Tensor,
    loss_function: _Loss,
    window_size: int,
) -> float:
    """
    Evaluates model on the supplied data.
    """

    model.eval()

    epoch_loss = 0

    n_batches = data.shape[0] - window_size

    with torch.no_grad():
        for batch_idx in tqdm(range(n_batches)):
            x, y = get_batch_s2s(data, batch_idx, window_size)

            yhat, _ = model(x)

            loss = loss_function(yhat, y)

            epoch_loss += loss.item()

    return epoch_loss / n_batches


def train_cycle(
    model: torch.nn.Module,
    hyperparams: dict[str, Any],
    train_iter: torch.Tensor,
    val_iter: torch.Tensor,
    test_iter: torch.Tensor,
) -> float:
    """
    Run a training cycle.

    Args:
      model: The model to train.
      hyperparams: The model hyperparameters.
      train_iter: Tensor containing training data returned by `setup_dataset` function.
      val_iter: Tensor containing validation data returned by `setup_dataset` function.
      test_iter: Tensor containing test data returned by `setup_dataset` function.
    """

    folder_path = Path("./trained_models")
    folder_path.mkdir(exist_ok=True, parents=True)
    checkpoint_fpath = (
        folder_path
        / f"q_transformer_lm_{hyperparams['model']}_{hyperparams['seed']}_{int(time.time())}.pt"
    )

    # Set up optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=hyperparams["lr"],
        weight_decay=hyperparams["wd"],
        eps=hyperparams["eps"],
    )

    # Set up learning rate scheduler
    scheduler = None
    if hyperparams["lr_sched"] == "cos":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=hyperparams["restart_epochs"]
        )

    loss_function = torch.nn.CrossEntropyLoss()

    def _evaluate(iter: torch.Tensor):
        return evaluate(model, iter, loss_function, hyperparams["window"])

    best_valid_loss = float("inf")
    for epoch in range(hyperparams["epochs"]):
        start_time = time.time()

        train_loss = train_epoch(
            model,
            train_iter,
            optimizer,
            loss_function,
            hyperparams["max_grad_norm"],
            scheduler,
            hyperparams["window"],
        )

        valid_loss = _evaluate(val_iter)

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_fpath)

        print(f"Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs}s")
        print(f"\tTrain Loss: {train_loss:.3f} | Train ppl: {math.exp(train_loss)}")
        print(f"\t Val. Loss: {valid_loss:.3f} |  Val. ppl: {math.exp(valid_loss)}")

    model.load_state_dict(torch.load(checkpoint_fpath))

    valid_loss = _evaluate(val_iter)
    test_loss = _evaluate(test_iter)

    print("FINAL TRAINED MODEL STATS:")
    print(f"\t Val. Loss: {valid_loss:.3f} |  Val. ppl: {math.exp(valid_loss)}")
    print(f"\t Test Loss: {test_loss:.3f} |  Test ppl: {math.exp(test_loss)}")

    return test_loss


def seed(SEED: int) -> None:
    """
    Sets the seed for Python's random module, numpy's RNG and torch's RNG.

    Args:
      SEED: integer specifying the seed
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)


def get_train_evaluate(device: Device) -> Callable:
    """
    Returns a function that runs the training cycle on a specified torch device.

    Args:
      device: Torch device

    Returns:
      Callable taking in a set of parameters as a dict and returning the value of the validation loss
      at the end of the training cycle.
    """

    def train_evaluate(parameterization: dict[str, Any]) -> float:
        """
        Train the model and return the test loss.
        """

        if "seed" not in parameterization:
            parameterization["seed"] = int.from_bytes(os.urandom(4), "big")

        # seed(parameterization["seed"])

        torch.manual_seed(parameterization["seed"])
        random.seed(parameterization["seed"])
        np.random.seed(parameterization["seed"])

        vocab, (train_iter, val_iter, test_iter), PAD_TOK = setup_dataset(
            device, parameterization["batch_size"], parameterization["window"]
        )

        model = create_model(parameterization, device, len(vocab))

        initialise_weights(model)

        model = model.to(device)

        valid_loss = train_cycle(
            model, parameterization, train_iter, val_iter, test_iter
        )

        return valid_loss

    return train_evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
import torch
import torchtext

print(torch.__version__)
print(torchtext.__version__)

2.3.0+cu121
0.18.0+cpu


In [20]:
!conda run -n torch_env pip install argparse

Collecting argparse
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0



In [None]:
import argparse
import torch

# from quixer.setup_training import get_train_evaluate


##################################################
# Default hyperparameters for each of the models #
##################################################

quixer_hparams = {
    "qubits": 6,
    "layers": 3,
    "ansatz_layers": 4,
    "window": 32,
    "epochs": 30,
    "restart_epochs": 30000,
    "dropout": 0.10,
    "lr": 0.002,
    "lr_sched": "cos",
    "wd": 0.0001,
    "eps": 1e-10,
    "batch_size": 32,
    "max_grad_norm": 5.0,
    "model": "Quixer",
    "print_iter": 50,
}


lstm_hparams = {
    "layers": 2,
    "window": 32,
    "residuals": False,
    "epochs": 30,
    "restart_epochs": 30000,
    "dropout": 0.30,
    "lr": 0.002,
    "lr_sched": "cos",
    "wd": 0.0001,
    "eps": 1e-10,
    "batch_size": 32,
    "max_grad_norm": 5.0,
    "print_iter": 50,
}


fnet_hparams = {
    "layers": 2,
    "window": 32,
    "epochs": 30,
    "restart_epochs": 30000,
    "dropout": 0.10,
    "lr": 0.002,
    "lr_sched": "cos",
    "wd": 0.0001,
    "eps": 1e-10,
    "batch_size": 32,
    "max_grad_norm": 5.0,
    "model": "FNet",
    "print_iter": 50,
}


transformer_hparams = {
    "layers": 1,
    "heads": 1,
    "window": 32,
    "epochs": 30,
    "restart_epochs": 30000,
    "dropout": 0.10,
    "lr": 0.001,
    "lr_sched": "cos",
    "wd": 0.0001,
    "eps": 1e-10,
    "batch_size": 32,
    "max_grad_norm": 5.0,
    "model": "Transformer",
    "print_iter": 50,
}

##################################################


# Embedding dimensions
classical_embedding_dimensions = [96, 128]
quantum_embedding_dimensions = [512]

# Dictionary defining available models along with associated hyperparameters
model_map = {
    "Quixer": (quixer_hparams, quantum_embedding_dimensions),
    "Transformer": (transformer_hparams, classical_embedding_dimensions),
    "LSTM": (lstm_hparams, classical_embedding_dimensions),
    "FNet": (fnet_hparams, classical_embedding_dimensions),
}
available_models = list(model_map.keys())

# Parse command line arguments
args = argparse.ArgumentParser(
    prog="Quixer", description="Runs the Quixer model and/or classical baselines"
)
args.add_argument(
    "-m",
    "--model",
    default="Quixer",
    choices=available_models,
    nargs="*",
    help="Model(s) to run.",
)
args.add_argument("-d", "--device", default="cpu", help="Device to run training on.")
# parsed = args.parse_args()
parsed = args.parse_args([])

device_name = parsed.device
models_to_run = parsed.model if type(parsed.model) is list else [parsed.model]

# Make algorithms deterministic for reproducibility
torch.backends.cudnn.deterministic = True


device = torch.device(device_name)
print(f"Running on device: {device}")

train_evaluate = get_train_evaluate(device)


# for model_name in models_to_run:
#     hyperparameters, embedding_dimensions = model_map[model_name]
#     for embedding_dimension in embedding_dimensions:
#         for seed in torch.randint(high=1000000, size=(10,)).tolist():
#             hyperparameters["model"] = model_name
#             hyperparameters["dimension"] = embedding_dimension
#             hyperparameters["seed"] = seed

#             train_evaluate(hyperparameters)

for model_name in models_to_run:
    hyperparameters, embedding_dimensions = model_map[model_name]
    for embedding_dimension in embedding_dimensions:
        # DON'T SHADOW seed() — rename the loop variable
        for run_seed in torch.randint(high=1000000, size=(10,)).tolist():
            hyperparameters["model"] = model_name
            hyperparameters["dimension"] = embedding_dimension
            hyperparameters["seed"] = run_seed

            train_evaluate(hyperparameters)

Running on device: cpu


  x = torch.tensor(X, dtype=torch.long).to(device)  # for nn.Embedding
  y = torch.tensor(Y, dtype=torch.long).to(device)  # for CrossEntropyLoss
