In [1]:
!pip install pennylane pennylane-lightning pennylane-lightning[gpu] cotengra quimb equinox jaxtyping --upgrade



In [2]:
import matplotlib.pyplot as plt
import sys
import pandas as pd
import jax
import time

from jaxtyping import Array, Float, Int, PyTree
import jax.numpy as jnp
import numpy as np

import equinox as eqx

import optax  # optimization using jax

import torch  # https://pytorch.org
import torchvision  # https://pytorch.org

import pennylane as qml
import pennylane.numpy as pnp

import functools, itertools

from jax.lib import xla_bridge

#jax.config.update("jax_enable_x64", True)

def set_jax_platform():
    # Check if TPU is available
    try:
        tpu_backend = xla_bridge.get_backend('tpu')
        if tpu_backend and tpu_backend.device_count() > 0:
            # Set platform to TPU
            jax.config.update('jax_platform_name', 'tpu')
            print("Set platform to TPU")
            return
    except RuntimeError:
        pass  # No TPU found, move on to check for GPU

    # Check if GPU is available
    try:
      gpu_backend = xla_bridge.get_backend('gpu')
      if gpu_backend and gpu_backend.device_count() > 0:
          # Set platform to CUDA (GPU)
          jax.config.update('jax_platform_name', 'gpu')
          print("Set platform to GPU")
    except RuntimeError:
          # Set platform to CPU
          jax.config.update('jax_platform_name', 'cpu')
          print("Set platform to CPU")

# Call the function to set the platform
set_jax_platform()

seed = 1701
jrng_key = jax.random.PRNGKey(seed)
torch.manual_seed(seed)
print(jax.devices())

Set platform to GPU
[cuda(id=0)]


In [3]:
# some training configuration
DATA = "CIFAR10" # "MNIST", "FashionMNIST", "CIFAR10"
REPLACEMENT_LVL = 2 # 0, 1, or 2
BATCH_SIZE = 200
EPOCHS = 100
LEARNING_RATE = 3E-4
REPEATS = 5

In [4]:
# utilities
ket = {
    '0':jnp.array([1,0]),
    '1':jnp.array([0,1]),
    '+':(jnp.array([1,0]) + jnp.array([0,1]))/jnp.sqrt(2),
    '-':(jnp.array([1,0]) - jnp.array([0,1]))/jnp.sqrt(2)
}

pauli = {
    'I':jnp.array([[1,0],[0,1]]),
    'X':jnp.array([[0,1],[1,0]]),
    'Y':jnp.array([[0, -1j],[1j, 0]]),
    'Z':jnp.array([[1,0],[0,-1]])
}

def tensor_product(*args):
  input_list = [a for a in args]
  return functools.reduce(jnp.kron, input_list)

def multi_qubit_identity(n_qubits:int)->jnp.ndarray:
  assert n_qubits>0
  if n_qubits == 1:
    return pauli['I']
  else:
    return tensor_product(*[pauli['I'] for _ in range(n_qubits)])

def pauli_dict_func(key):
    return pauli[key]

def pauli_dict_func_multiple_keys(keys):
    return list(map(pauli_dict_func, keys))

def pauli_string_tensor_prod(pauli_string:str):
    paulis_char = list(pauli_string)
    paulis_mat = pauli_dict_func_multiple_keys(paulis_char)
    return tensor_product(*paulis_mat)

def generate_nqubit_pauli_strings(n_qubits:int):
    assert n_qubits>0
    pauli_labels = ['I', 'X', 'Y', 'Z']
    pauli_strings = []
    for labels in itertools.product(pauli_labels, repeat=n_qubits):
        pauli_str = "".join(labels)
        if pauli_str != 'I'*n_qubits:
            pauli_strings.append(pauli_str)
    return pauli_strings

def generate_pauli_tensor_list(pauli_strings:list):
    return list(map(pauli_string_tensor_prod, pauli_strings))

su4_generators = generate_pauli_tensor_list(
    generate_nqubit_pauli_strings(2)
)

su32_generators = generate_pauli_tensor_list(
    generate_nqubit_pauli_strings(5)
)

su8_generators = generate_pauli_tensor_list(
    generate_nqubit_pauli_strings(3)
)

su16_generators = generate_pauli_tensor_list(
    generate_nqubit_pauli_strings(4)
)

def su32_op(
    params:jnp.ndarray
):
    generator = jnp.einsum("i, ijk - >jk", params, jnp.asarray(su32_generators))
    return jax.scipy.linalg.expm(1j*generator)

def su4_op(
    params:jnp.ndarray
):
    generator = jnp.einsum("i, ijk - >jk", params, jnp.asarray(su4_generators))
    return jax.scipy.linalg.expm(1j*generator)

def su8_op(
    params:jnp.ndarray
):
    generator = jnp.einsum("i, ijk - >jk", params, jnp.asarray(su8_generators))
    return jax.scipy.linalg.expm(1j*generator)

def su16_op(
    params:jnp.ndarray
):
    generator = jnp.einsum("i, ijk - >jk", params, jnp.asarray(su16_generators))
    return jax.scipy.linalg.expm(1j*generator)

def measure_sv(
    state:jnp.ndarray,
    observable:jnp.ndarray
    ):
  """
  Measure a statevector with a Hermitian observable.
  Note: No checking Hermitianicity of the observable or whether the observable
  has all real eigenvalues or not
  """
  expectation_value = jnp.dot(jnp.conj(state.T), jnp.dot(observable, state))
  return jnp.real(expectation_value)

def measure_dm(
    rho:jnp.ndarray,
    observable:jnp.ndarray
):
  """
  Measure a density matrix with a Hermitian observable.
  Note: No checking Hermitianicity of the observable or whether the observable
  has all real eigenvalues or not.
  """
  product = jnp.dot(rho, observable)

  # Calculate the trace, which is the sum of diagonal elements
  trace = jnp.trace(product)

  # The expectation value should be real for physical observables
  return jnp.real(trace)

# assuming the input patch (hermitianized) has shape (c, h, w)
# assuming the input set statevectors has shape (c, 2**n)
# assuming we have a list of (state, observable) pairs
vmap_measure_sv_ob_pairs = jax.vmap(lambda pair: measure_sv(pair[0], pair[1]), in_axes=0, out_axes=0)
# assuming the input set desnity matrices has shape (c, 2**n, 2**n)
# assuming we have a list of (rho, observable) pairs
vmap_measure_dm_ob_pairs = jax.vmap(lambda pair: measure_dm(pair[0], pair[1]), in_axes=0, out_axes=0)

# vmap through different observables
vmap_measure_sv = jax.vmap(measure_sv, in_axes=(None, 0), out_axes=0)
vmap_measure_dm = jax.vmap(measure_dm, in_axes=(None, 0), out_axes=0)

def bitstring_to_state(bitstring:str):
  """
  Convert a bit string, like '0101001' or '+-+-101'
  to a statevector. Each character in the bitstring must be among
  0, 1, + and -
  """
  assert len(bitstring)>0
  for c in bitstring:
    assert c in ['0', '1', '+', '-']
  single_qubit_states = [ket[c] for c in bitstring]
  return tensor_product(*single_qubit_states)


# utilities for the flipped quanvolution kernel
def extract_patches(image, patch_size, stride, padding=None):
    """
    Extracts patches from an image with multiple input channels and optional custom padding.

    Args:
        image (jnp.ndarray): Input image tensor of shape (in_channels, height, width).
        patch_size (int): Size of the square patches to extract.
        stride (int): Stride between patches.
        padding (tuple): Padding value(s) for each dimension.

    Returns:
        jnp.ndarray: Tensor of extracted patches of shape (num_patches, in_channels, patch_size, patch_size).
    """

    in_channels, height, width = image.shape[-3], image.shape[-2], image.shape[-1]

    pad_h, pad_w = padding if padding is not None else (0, 0)


    image = jnp.pad(image, [(0, 0), (pad_h, pad_h), (pad_w, pad_w)], mode='constant') if padding is not None else image


    _, height, width = image.shape


    num_patches_h = (height - patch_size) // stride + 1
    num_patches_w = (width - patch_size) // stride + 1

    patch_indices = [(i, j) for i in range(num_patches_h) for j in range(num_patches_w)]

    patches = jnp.stack([image[:, i*stride:i*stride+patch_size, j*stride:j*stride+patch_size]
                         for i, j in patch_indices])

    return patches


def generate_2q_param_state(theta):
  state = bitstring_to_state('00')
  state = jnp.dot(
      su4_op(theta),
      state
  )
  return state

vmap_generate_2q_param_state = jax.vmap(generate_2q_param_state, in_axes=0, out_axes = 0)

# FlippedQuanv3x3 kernel
def single_kernel_op(thetas, patch):
  # patch has shape (c_in, h, w)
  # thetas has shape (c_in, 4^2-1) for SU4 gates
  n_theta = thetas.shape[0]
  n_channel = patch.shape[0]
  assert n_theta == n_channel, "Thetas and patch must have the same number of channels."
  states = vmap_generate_2q_param_state(thetas)
  patch = jnp.pad(patch, [(0,0),(0,1),(0,1)], mode='constant')
  herm_patch = (jnp.einsum("ijk->ikj", patch)+patch)/2
  channel_out = vmap_measure_sv_ob_pairs([states, herm_patch])
  return jnp.sum(channel_out, axis = 0)/n_theta

vmap_single_kernel_op_through_extracted_patches = jax.vmap(single_kernel_op, in_axes=(None, 0), out_axes=0)

# For multiple channel output
# parameter has shape (c_out, c_in, 4**2-1) for SU4 gates
vmap_vmap_single_kernel_op_through_extracted_patches = jax.vmap(vmap_single_kernel_op_through_extracted_patches, in_axes=(0, None), out_axes=0)

# Quantum version of the linear layer
# Realized with data reuploading and Hamiltonian embedding
# for input dimension D
# the quantum linear layer is a n = ceil(log4(D+1))-qubit quantum circuit
# Both the data encoding and the parameterised cirucit are achieved via the SU(2^n) unitary

def data_encode_unitary(padded_data, t):
    #original_dim = 4**7 # fix to 7 qubits #padded_data.shape[-1]
    #new_dim = jnp.sqrt(original_dim).astype(jnp.int_)
    data = jnp.reshape(padded_data, (2**7, 2**7))
    generator = (data + jnp.einsum('...jk->...kj', data))/2
    return jax.scipy.linalg.expm(1.0j*t*generator)

def su_n(params, pauli_string_tensor_list):
    paulis = jnp.asarray(pauli_string_tensor_list)
    generator = jnp.einsum("i, ijk -> jk", params, paulis)
    return jax.scipy.linalg.expm(1.0j*generator)

def linear_layer_func(
        padded_data,
        params,
        pauli_string_tensor_list,
        observables,
        n_qubits
):
    n_rep = params.shape[0]
    state = bitstring_to_state("+" * n_qubits)
    data_unitary = data_encode_unitary(padded_data, 1.0/n_rep)
    for i in range(n_rep):
        state = jnp.dot(data_unitary, state)
        state = jnp.dot(su_n(params[i], pauli_string_tensor_list), state)
    return vmap_measure_sv(state, observables)

vmap_batch_linear_layer_func = jax.vmap(linear_layer_func, in_axes=(0, None, None, None, None), out_axes=0)

In [5]:
# models
class FlippedQuanv3x3(eqx.Module):
  weight: jax.Array
  bias: jax.Array
  stride: int
  pad: tuple|None
  pad_h: int
  pad_w: int

  def __init__(self, in_channels, out_channels, stride, padding, key):
    wkey, bkey = jax.random.split(key,2)
    self.weight = jax.random.normal(shape=[out_channels, in_channels, 15], key=wkey)
    self.bias = jax.random.normal(shape=[out_channels, 1], key=bkey)
    self.stride = stride
    self.pad = padding
    self.pad_h, self.pad_w = padding if padding is not None else (0,0)

  def __call__(self, x):
    # x has shape ( ,c_in, h, w)
    # weight has shape (c_out, c_in, 15)
    # bias has shape (c_out, 1)
    c_in, h_in, w_in = x.shape[-3], x.shape[-2], x.shape[-1]
    patches = extract_patches(x, patch_size=3, stride=self.stride, padding=self.pad)
    h_out = (h_in-3+2*self.pad_h)//self.stride +1
    w_out = (w_in-3+2*self.pad_w)//self.stride +1
    out = vmap_vmap_single_kernel_op_through_extracted_patches(self.weight, patches)
    out = out + self.bias
    return out.reshape((-1, h_out, w_out))

class DataReUploadingLinear(eqx.Module):
    weight: jax.Array
    bias: jax.Array
    n_reps:int
    in_dim:int
    out_dim:int
    n_qubits:int

    def __init__(self, n_reps, key):
        #assert 2 ** n_qubits >= out_dim
        #assert 4 ** n_qubits >= in_dim
        wkey, bkey = jax.random.split(key, 2)
        self.in_dim = 16 * 28 * 28
        self.out_dim = 10
        self.n_qubits = 7
        # in_dim is 16 * 28 * 28
        # out_dim is 10
        self.n_reps = n_reps
        param_dim = 4 ** self.n_qubits - 1
        self.weight = jax.random.normal(shape=[self.n_reps, param_dim], key=wkey)
        self.bias = jax.random.normal(shape=[self.out_dim], key=bkey)

    def generate_observables(self):
        observables = []
        for i in range(self.out_dim):
            temp_bitstring = '{0:b}'.format(i).zfill(self.n_qubits)
            ob = jnp.outer(bitstring_to_state(temp_bitstring), bitstring_to_state(temp_bitstring))
            observables.append(ob)
        return jnp.asarray(observables)

    def get_pauli_string_tensor_list(self):
        return generate_pauli_tensor_list(generate_nqubit_pauli_strings(self.n_qubits))

    #def get_pad_size(self):
    #    return 4 ** self.n_qubits - self.in_dim

    def __call__(self, x):
        # x has size (batchsize, in_dim)
        # pad x
        x = jnp.pad(x, (0, 4 ** self.n_qubits - self.in_dim))

        out = linear_layer_func(
            padded_data=x,
            params=self.weight,
            pauli_string_tensor_list=self.get_pauli_string_tensor_list(),
            observables=self.generate_observables(),
            n_qubits=self.n_qubits
        )
        out = out + self.bias
        return out

class HybridNet(eqx.Module):
    #in_channels:int
    replacement_lvl:int
    layers: list
    def __init__(self, replacement_lvl, key):
        assert replacement_lvl in [0,1,2]
        #self.in_channels = in_channels
        self.replacement_lvl = replacement_lvl
        key1, key2, key3 = jax.random.split(key, 3)
        if self.replacement_lvl == 0:
            self.layers = [
                eqx.nn.Conv2d(3, 32, kernel_size=3, padding=0, key=key1),
                eqx.nn.Conv2d(32, 16, kernel_size=3, padding=0, key=key2),
                jnp.ravel,
                eqx.nn.Linear(16 * 28 * 28, 10, key=key3),
            ]
        elif self.replacement_lvl == 1:
            self.layers = [
                FlippedQuanv3x3(3, 32, stride=1, padding=(0, 0), key=key1),
                FlippedQuanv3x3(32, 16, stride=1, padding=(0, 0), key=key2),
                jnp.ravel,
                eqx.nn.Linear(16 * 28 * 28, 10, key=key3),
            ]
        elif self.replacement_lvl == 2:
            self.layers = [
                FlippedQuanv3x3(3, 32, stride=1, padding=(0, 0), key=key1),
                FlippedQuanv3x3(32, 16, stride=1, padding=(0, 0), key=key2),
                jnp.ravel,
                DataReUploadingLinear(n_reps=1, key=key3),
            ]
        else:
            raise ValueError("replacement_lvl should be 0, 1, or 2")

    def __call__(self, x:Float[Array, "3 h w"])->Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

In [6]:
test_img = jnp.stack([jnp.arange(32*32*1, dtype=jnp.float_).reshape((32, 32))]*3*BATCH_SIZE, axis = 0).reshape((BATCH_SIZE,3,32,32))
print(test_img.shape)
for rpl_lvl in [0,1,2]:
  print(f"Replacement lvl={rpl_lvl}")
  start = time.time()
  model = HybridNet(replacement_lvl=rpl_lvl, key=jrng_key)
  print(model)
  test_out = jax.vmap(model)(test_img)
  print(test_out.shape)
  end = time.time()
  print(f"Time taken: {end-start}")

  test_img = jnp.stack([jnp.arange(32*32*1, dtype=jnp.float_).reshape((32, 32))]*3*BATCH_SIZE, axis = 0).reshape((BATCH_SIZE,3,32,32))


(200, 3, 32, 32)
Replacement lvl=0
HybridNet(
  replacement_lvl=0,
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[32,3,3,3],
      bias=f32[32,1,1],
      in_channels=3,
      out_channels=32,
      kernel_size=(3, 3),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=True,
      padding_mode='ZEROS'
    ),
    Conv2d(
      num_spatial_dims=2,
      weight=f32[16,32,3,3],
      bias=f32[16,1,1],
      in_channels=32,
      out_channels=16,
      kernel_size=(3, 3),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=True,
      padding_mode='ZEROS'
    ),
    <wrapped function ravel>,
    Linear(
      weight=f32[10,12544],
      bias=f32[10],
      in_features=12544,
      out_features=10,
      use_bias=True
    )
  ]
)
(200, 10)
Time taken: 1.7253785133361816
Replacement lvl=1
HybridNet(
  replacement_lvl=1,
  layers=[
    FlippedQuanv3x3(
      weight=

In [7]:
# training utilities
start_compile = time.time()
@eqx.filter_jit
def compute_out(
    model:HybridNet,
    x:Float[Array, "batch 3 32 32"]
) -> Float[Array, "batch 10"]:
    return jax.vmap(model)(x)

model = HybridNet(replacement_lvl=2, key=jrng_key)
test_out = compute_out(model, test_img)
end = time.time()
print(f"Time taken: {end-start_compile}")
print(test_out.shape)

test_img2 = jnp.stack([jnp.arange(32*32*1, dtype=jnp.float_).reshape((32, 32))]*3*BATCH_SIZE, axis = 0).reshape((BATCH_SIZE,3,32,32))
print(test_img2.shape)
start = time.time()
test_out2 = compute_out(model, test_img2)
end = time.time()
print(f"Time taken: {end-start}")

Time taken: 696.3979589939117
(200, 10)
(200, 3, 32, 32)
Time taken: 0.12333202362060547


  test_img2 = jnp.stack([jnp.arange(32*32*1, dtype=jnp.float_).reshape((32, 32))]*3*BATCH_SIZE, axis = 0).reshape((BATCH_SIZE,3,32,32))


In [8]:
def loss_fn(
    pred_y:Float[Array, "batch 10"],
    y:Int[Array, "batch"]
):
  return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(pred_y, y))

def accuracy_fn(
    pred_y:Float[Array, "batch 10"],
    y:Int[Array, "batch"]
):
  pred = jnp.argmax(pred_y, axis=1)
  return jnp.sum(jnp.array(pred == y).astype(int)) / len(pred_y)

@eqx.filter_jit
def compute_loss(
    model: HybridNet, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
):
  pred_y = compute_out(model, x)
  return loss_fn(pred_y, y)

@eqx.filter_jit
def compute_accuracy(
    model: HybridNet, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
):
  pred_y = compute_out(model, x)
  return accuracy_fn(pred_y, y)

def evaluate(
    model: HybridNet,
    test_loader: torch.utils.data.DataLoader,
):
  avg_loss = 0
  avg_acc = 0
  for x, y in test_loader:
    x = x.numpy()
    y = y.numpy()
    avg_loss += compute_loss(model, x, y)
    avg_acc += compute_accuracy(model, x, y)
  return avg_loss / len(test_loader), avg_acc / len(test_loader)

In [9]:
# training utilities
def get_train_test_data(name = "MNIST"):
  assert name in ["MNIST", "CIFAR10", "FashionMNIST"]
  if name  == "CIFAR10":
    preprocess = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
  else:
    preprocess =  torchvision.transforms.Compose([
    torchvision.transforms.Pad(2),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
  if name == "MNIST":
    train = torchvision.datasets.MNIST(
        root="MNIST",
        train=True,
        transform=preprocess,
        download=True)
    test = torchvision.datasets.MNIST(
        root="MNIST",
        train=False,
        transform=preprocess,
        download=True)
  elif name == "CIFAR10":
    train = torchvision.datasets.CIFAR10(
        root="CIFAR10",
        train=True,
        transform=preprocess,
        download=True)
    test = torchvision.datasets.CIFAR10(
        root="CIFAR10",
        train=False,
        transform=preprocess,
        download=True)
  elif name == "FashionMNIST":
    train = torchvision.datasets.FashionMNIST(
        root="FashionMNIST",
        train=True,
        transform=preprocess,
        download=True)
    test = torchvision.datasets.FashionMNIST(
        root="FashionMNIST",
        train=False,
        transform=preprocess,
        download=True)
  else:
    raise ValueError("name should be MNIST, CIFAR10, or FashionMNIST")
  return train, test

def train(
      model: HybridNet,
      data_name: str,
      batchsize: int,
      epochs: int,
      optim: optax.GradientTransformation,
  ):

  train_loss = []
  train_acc = []
  test_loss = []
  test_acc = []

  opt_state = optim.init(eqx.filter(model, eqx.is_array))
  train_data, test_data = get_train_test_data(data_name)
  trainloader = torch.utils.data.DataLoader(
      train_data, batch_size=batchsize, shuffle=True)
  test_loader = torch.utils.data.DataLoader(
      test_data, batch_size=batchsize, shuffle=False)
  if data_name == "CIFAR10":
    @eqx.filter_jit
    def make_step(
      model: HybridNet,
      opt_state: PyTree,
      x: Float[Array, "batch 3 32 32"],
      y: Int[Array, "batch"],
    ):
      loss_value, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
      updates, opt_state = optim.update(grads, opt_state, model)
      model = eqx.apply_updates(model, updates)
      acc = compute_accuracy(model, x, y)
      return model, opt_state, loss_value, acc

    @eqx.filter_jit
    def make_test_step(
        model:HybridNet,
        x: Float[Array, "batch 3 32 32"],
        y: Int[Array, "batch"],
    ):
      out = compute_out(model, x)
      loss_value = loss_fn(out, y)
      acc = accuracy_fn(out, y)
      return loss_value, acc
  else:
    @eqx.filter_jit
    def make_step(
      model: HybridNet,
      opt_state: PyTree,
      x: Float[Array, "batch 1 32 32"],
      y: Int[Array, "batch"],
    ):
      loss_value, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
      updates, opt_state = optim.update(grads, opt_state, model)
      model = eqx.apply_updates(model, updates)
      acc = compute_accuracy(model, x, y)
      return model, opt_state, loss_value, acc

    @eqx.filter_jit
    def make_test_step(
        model:HybridNet,
        x: Float[Array, "batch 1 32 32"],
        y: Int[Array, "batch"],
    ):
      out = compute_out(model, x)
      loss_value = loss_fn(out, y)
      acc = accuracy_fn(out, y)
      return loss_value, acc


  for step in range(epochs):
    step_start = time.time()
    batch_train_loss = []
    batch_train_acc = []
    batch_test_loss = []
    batch_test_acc = []
    for batchidx, (x, y) in enumerate(trainloader):
      x = x.numpy()
      y = y.numpy()
      model, opt_state, loss_value_batch, acc_value_batch = make_step(model, opt_state, x, y)
      batch_train_loss.append(loss_value_batch)
      batch_train_acc.append(acc_value_batch)
    loss_value = np.mean(batch_train_loss)
    acc_value = np.mean(batch_train_acc)
    train_loss.append(loss_value)
    train_acc.append(acc_value)
    train_time = time.time() - step_start
    print(f"Train step {step}, loss = {loss_value:.4f}, acc = {acc_value:.4f}; train time = {train_time:.4f} seconds")

    for batchidx, (x, y) in enumerate(test_loader):
      x = x.numpy()
      y = y.numpy()
      loss_value_batch, acc_value_batch = make_test_step(model, x, y)
      batch_test_loss.append(loss_value_batch)
      batch_test_acc.append(acc_value_batch)

    test_loss_value, test_acc_value = np.mean(batch_test_loss), np.mean(batch_test_acc)
    test_loss.append(test_loss_value)
    test_acc.append(test_acc_value)
    test_time = time.time() - step_start - train_time
    print(f"Test loss = {test_loss_value:.4f}, acc = {test_acc_value:.4f}; test time = {test_time:.4f} seconds")
    print(f"Total time = {time.time() - step_start:.4f} seconds")

  return model, train_loss, train_acc, test_loss, test_acc

In [10]:
model = HybridNet(replacement_lvl=2, key = jrng_key)
optim = optax.sgd(LEARNING_RATE)
model, train_loss, train_acc, test_loss, test_acc = train(model, "CIFAR10", BATCH_SIZE, 10, optim)

Files already downloaded and verified
Files already downloaded and verified


  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


Train step 0, loss = 2.6655, acc = 0.1000; train time = 1081.4233 seconds


ValueError: Incompatible shapes for broadcasting: shapes=[(200, 32, 32), (200,)]