In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found below the Kaggle input directory.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
#!conda install -c conda-forge pillow -y
#!conda install -c conda-forge pydicom -y
#!conda install -c conda-forge gdcm -y
#!pip install pylibjpeg pylibjpeg-libjpeg
#!pip install pydicom

In [None]:
!pip install --upgrade numpy==1.20.0

In [None]:
%%bash
conda install -c conda-forge gdcm -y

In [None]:
!pip install pylibjpeg pylibjpeg-libjpeg

In [None]:
import cv2
import io
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import argparse
import pandas as pd
from PIL import Image
import pydicom
import torch.nn as nn
import seaborn as sns
import warnings

In [None]:
# Train and test data directories of the competition dataset, given relative
# to the notebook's working directory.
TRAIN_DIR = '../input/siim-covid19-detection/train'
TEST_DIR = '../input/siim-covid19-detection/test'

In [None]:
# Load the image-level and the study-level training annotation tables.
train = pd.read_csv('../input/siim-covid19-detection/train_image_level.csv')
train_study = pd.read_csv('../input/siim-covid19-detection/train_study_level.csv')
#train_study.head()

In [None]:
# Keep only the part of each id that precedes the first underscore.
train_study['id'] = [study_id.split('_')[0] for study_id in train_study['id']]

# Replace the descriptive class column names with numeric-string class ids.
class_column_ids = {
    'Negative for Pneumonia': '0',
    'Typical Appearance': '1',
    "Indeterminate Appearance": '2',
    "Atypical Appearance": "3",
}
train_study.rename(columns=class_column_ids, inplace=True)

In [None]:
def get_label(row):
    """
    Return the integer class id of one study-level row.

    The row is expected to hold the one-hot class columns '0', '1', '2' and
    '3'; the class whose column equals 1 is returned as an int.

    Raises:
        ValueError: If the row does not have exactly one class set to 1.
    """
    hits = [int(c) for c in ('0', '1', '2', '3') if row[c] == 1]
    if len(hits) != 1:
        raise ValueError(
            f"Expected exactly one positive class per study, found {hits}"
        )
    return hits[0]


# Build the labels from the function's return value instead of appending to a
# global list inside apply(): apply() gives no guarantee that the function is
# called exactly once per row, so side effects can produce a list of the
# wrong length.
labels = train_study.apply(get_label, axis=1).tolist()
train_study.drop(columns=['0', '1','2', '3'], inplace=True)
train_study['label'] = labels
#train_study.head()

In [None]:
from os import listdir, walk
from skimage import exposure
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class CovLungDataset(Dataset):
    """
    Dataset yielding one equalized, resized DICOM image per study.

    Args:
        dir_path (str):          Root directory containing one sub-directory
                                 per study id.
        labels_data (DataFrame): Table with an 'id' column (study directory
                                 name) and a 'label' column.
        transforms (callable):   Optional albumentations-style callable that
                                 is invoked as transforms(image=...) and
                                 returns a dict with an 'image' entry.
        new_size (tuple):        Target (height, width) of the output image.

    Each item is a dict {'image': ..., 'label': ...}. Without transforms the
    image is a float array with values in [0, 1]; with transforms it is
    whatever the transform returns for the uint8 version of that image.
    """

    def __init__(self, dir_path, labels_data, transforms=None, new_size=(512, 512)):
        self.dir_path = dir_path
        self.labels_data = labels_data
        self.new_size = new_size
        self.transforms = transforms

    def __len__(self):
        """Number of studies in the label table."""
        return len(self.labels_data)

    def __getitem__(self, idx):
        """Load, equalize and resize the first image of study number idx."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.labels_data.iloc[idx]
        image_dir = row['id']
        img_label = row['label']

        # Layout is <dir_path>/<study id>/<sub-directory>/<file>. Only the
        # first sub-directory and the first file are used; sorting makes that
        # choice deterministic, since directory listing order is arbitrary.
        study_path = os.path.join(self.dir_path, image_dir)
        series_path = os.path.join(study_path, sorted(listdir(study_path))[0])
        file_names = sorted(next(walk(series_path))[2])
        path_to_img = os.path.join(series_path, file_names[0])

        # read image
        data = pydicom.dcmread(path_to_img)
        image = data.pixel_array
        image = exposure.equalize_hist(image)

        # cv2.resize needs an OpenCV interpolation flag. PIL's Image.LANCZOS
        # has the integer value of cv2.INTER_LINEAR, so passing it silently
        # selected bilinear interpolation instead of Lanczos.
        good_height, good_width = self.new_size
        image = cv2.resize(
            image, (good_width, good_height), interpolation=cv2.INTER_LANCZOS4
        )
        # Lanczos can overshoot slightly outside [0, 1]; clip so the uint8
        # conversion below cannot wrap around.
        image = np.clip(image, 0.0, 1.0)

        # data augmentation
        if self.transforms:
            # doesn't work on floats
            image = (image * 255).astype(np.uint8)
            image = self.transforms(image=image)['image']

        sample = {'image': image, 'label': img_label}
        return sample


In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Augmentation pipeline handed to CovLungDataset: random brightness/contrast
# change and random shift/scale/rotate (each applied with probability 0.5,
# constant border fill for the geometric transform), followed by conversion
# to a torch tensor.
transform = A.Compose([A.RandomBrightnessContrast(brightness_limit=[-0.2, 0.2], 
                                                  contrast_limit=[-0.2, 0.2], 
                                                  p=0.5),
                       A.ShiftScaleRotate(scale_limit=[-0.1, 0.3], 
                                          shift_limit=0.1, 
                                          rotate_limit=20, 
                                          border_mode=cv2.BORDER_CONSTANT,
                                          p=0.5),
                       # reshape image of size (k, n, 1) into (1, k, n)
                       ToTensorV2(p=1.0)
                      ])


In [None]:


# Training dataset: study ids and labels from train_study, images read from
# TRAIN_DIR, augmented with `transform` and resized to 512 x 512.
transformed_train_dataset = CovLungDataset(dir_path=TRAIN_DIR,
                                     labels_data=train_study[['id', 'label']],
                                     transforms=transform,
                                     new_size=(512, 512))


# Batches of 8 in table order (shuffle is disabled), loaded by 2 workers.
train_dataloader = DataLoader(transformed_train_dataset, batch_size=8, shuffle=False, num_workers=2)



In [None]:

# Test dataset, built like the training one but reading from TEST_DIR.
# NOTE(review): labels_data is still train_study, so the ids looked up under
# TEST_DIR are training study ids, and the random augmentation pipeline is
# applied as well -- confirm whether a test-specific id table and a
# non-augmenting transform were intended here.
transformed_test_dataset = CovLungDataset(dir_path=TEST_DIR,
                                     labels_data=train_study[['id', 'label']],
                                     transforms=transform,
                                     new_size=(512, 512))


test_dataloader = DataLoader(transformed_test_dataset, batch_size=8, shuffle=False, num_workers=2)


In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np


# Fetch one batch from the training loader. The builtin next() is used because
# DataLoader iterators no longer provide a .next() method in current PyTorch.
dataiter = iter(train_dataloader)
sample = next(dataiter)
#sample = next(iter(transformed_train_dataset))
images = sample['image'] 
labels = sample['label']

# 2 x 4 grid: one cell per image of the batch (the loader's batch size is 8)
fig = plt.figure(figsize=(15., 25.))
grid = ImageGrid(fig, 111, 
                 nrows_ncols=(2, 4),
                 axes_pad=.4,
                 )

# Human-readable names of the four class ids
labels_dict = {0: 'Negative for Pneumonia',  1: 'Typical Appearance',  2: 'Indeterminate Appearance',  3: 'Atypical Appearance'}

for ax, im, label in zip(grid, images, labels):
    # Move the first axis last for matplotlib (assumes channel-first images
    # as produced by the tensor conversion in the transform -- TODO confirm)
    im = np.transpose(im.numpy(), (1, 2, 0))
    ax.imshow(im, cmap='gray')
    ax.set_title(labels_dict[label.item()], fontsize=12)

plt.show()


In [None]:
# Command-line style options for training and for the MPS model; every option
# has a default so the parser can be used without any arguments.
# NOTE(review): the default of --data_path ('/input/...') differs from the
# '../input/...' paths used when loading the data -- confirm which is intended.
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--l2_reg', type=float, default=0, help='L2 regularisation')
parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation')
parser.add_argument('--data_path', type=str, default='/input/siim-covid19-detection',help='Path to data.')
parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension')
parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels')
parser.add_argument('--dense_net', action='store_true', default=False, help='Using Dense Net model')

In [None]:
# Parse an explicit empty argument list, so every option takes its default.
args = parser.parse_args([])

In [None]:
# Miscellaneous initialization
torch.manual_seed(0)
start_time = time.time()

# MPS parameters
bond_dim = 20
adaptive_mode = False
periodic_bc = False

# Training parameters
#num_train = 2000
#num_test = 1000
batch_size = 100
num_epochs = args.num_epochs
learn_rate = args.lr
l2_reg = args.l2_reg


batch_size = args.batch_size
bond_dim = args.bond_dim


# LoTeNet parameters
adaptive_mode = False 
periodic_bc   = False

kernel = 2 # Stride along spatial dimensions
output_dim = 4 # output dimension
 
feature_dim = 2

#logFile = time.strftime("%Y%m%d_%H_%M")+'.txt'
#makeLogFile(logFile)

normTensor = 0.5*torch.ones(args.nChannel)
### Data processing and loading....

In [None]:
def svd_flex(tensor, svd_string, max_D=None, cutoff=1e-10, sv_right=True, sv_vec=None):
    """
    Split an input tensor into two pieces using a SVD across some partition

    Args:
        tensor (Tensor):    Pytorch tensor with at least two indices

        svd_string (str):   String of the form 'init_str->left_str,right_str',
                            where init_str describes the indices of tensor, and
                            left_str/right_str describe those of the left and
                            right output tensors. The characters of left_str
                            and right_str form a partition of the characters in
                            init_str, but each contain one additional character
                            representing the new bond which comes from the SVD

                            Reversing the terms in svd_string to the left and
                            right of '->' gives an ein_string which can be used
                            to multiply both output tensors to give a (low rank
                            approximation) of the input tensor

        cutoff (float):     A truncation threshold which eliminates any
                            singular values which are strictly less than cutoff

        max_D (int):        A maximum allowed value for the new bond. If max_D
                            is specified, the returned tensors always have a
                            new bond of exactly this size: the SVD is cut down
                            to the max_D largest singular values, or padded
                            with zeros when fewer are available

        sv_right (bool):    The SVD gives two orthogonal matrices and a matrix
                            of singular values. sv_right=True merges the SV
                            matrix with the right output, while sv_right=False
                            merges it with the left output

        sv_vec (Tensor):    Pytorch vector with length max_D, which is modified
                            in place to return the vector of singular values

    Returns:
        left_tensor (Tensor),
        right_tensor (Tensor):  Tensors whose indices are described by the
                                left_str and right_str parts of svd_string

        bond_dim:               The dimension of the new bond appearing from
                                the cutoff in our SVD. Note that this generally
                                won't match the dimension of left_/right_tensor
                                at this mode, which is padded with zeros
                                whenever max_D is specified

    Raises:
        TypeError:      If sv_vec is given but its shape differs from that of
                        the (truncated or padded) singular value vector
        RuntimeError:   If every singular value is below cutoff
    """

    def prod(int_list):
        output = 1
        for num in int_list:
            output *= num
        return output

    with torch.no_grad():
        # Parse svd_string into init_str, left_str, and right_str
        svd_string = svd_string.replace(" ", "")
        init_str, post_str = svd_string.split("->")
        left_str, right_str = post_str.split(",")

        # Check formatting of init_str, left_str, and right_str
        assert all([c.islower() for c in init_str + left_str + right_str])
        assert len(set(init_str + left_str + right_str)) == len(init_str) + 1
        assert len(set(init_str)) + len(set(left_str)) + len(set(right_str)) == len(
            init_str
        ) + len(left_str) + len(right_str)

        # Get the special character representing our SVD-truncated bond
        bond_char = set(left_str).intersection(set(right_str)).pop()
        left_part = left_str.replace(bond_char, "")
        right_part = right_str.replace(bond_char, "")

        # Permute our tensor into something that can be viewed as a matrix
        ein_str = f"{init_str}->{left_part+right_part}"
        tensor = torch.einsum(ein_str, [tensor]).contiguous()

        left_shape = list(tensor.shape[: len(left_part)])
        right_shape = list(tensor.shape[len(left_part) :])
        left_dim, right_dim = prod(left_shape), prod(right_shape)

        tensor = tensor.view([left_dim, right_dim])

        # Get SVD and format so that left_mat * diag(svs) * right_mat = tensor
        left_mat, svs, right_mat = torch.svd(tensor)
        svs, _ = torch.sort(svs, descending=True)
        right_mat = torch.t(right_mat)

        # Decrease or increase our tensor sizes in the presence of max_D
        if max_D and len(svs) > max_D:
            svs = svs[:max_D]
            left_mat = left_mat[:, :max_D]
            right_mat = right_mat[:max_D]
        elif max_D and len(svs) < max_D:
            # Pad with zeros up to max_D. new_zeros keeps the dtype and device
            # of the SVD factors; plain torch.zeros would silently convert
            # float64 or GPU inputs to default-dtype CPU tensors
            copy_svs = svs.new_zeros([max_D])
            copy_svs[: len(svs)] = svs
            copy_left = left_mat.new_zeros([left_mat.size(0), max_D])
            copy_left[:, : left_mat.size(1)] = left_mat
            copy_right = right_mat.new_zeros([max_D, right_mat.size(1)])
            copy_right[: right_mat.size(0)] = right_mat
            svs, left_mat, right_mat = copy_svs, copy_left, copy_right

        # If given as input, copy singular values into sv_vec
        if sv_vec is not None and svs.shape == sv_vec.shape:
            sv_vec[:] = svs
        elif sv_vec is not None and svs.shape != sv_vec.shape:
            raise TypeError(
                f"sv_vec.shape must be {list(svs.shape)}, but is "
                f"currently {list(sv_vec.shape)}"
            )

        # Find the truncation point relative to our singular value cutoff
        truncation = 0
        for s in svs:
            if s < cutoff:
                break
            truncation += 1
        if truncation == 0:
            raise RuntimeError(
                "SVD cutoff too large, attempted to truncate "
                "tensor to bond dimension 0"
            )

        # Perform the actual truncation
        if max_D:
            # Keep the bond at size max_D and zero out everything past the
            # truncation point
            svs[truncation:] = 0
            left_mat[:, truncation:] = 0
            right_mat[truncation:] = 0
        else:
            # If max_D wasn't given, set it to the truncation index
            max_D = truncation
            svs = svs[:truncation]
            left_mat = left_mat[:, :truncation]
            right_mat = right_mat[:truncation]

        # Merge the singular values into the appropriate matrix
        if sv_right:
            right_mat = torch.einsum("l,lr->lr", [svs, right_mat])
        else:
            left_mat = torch.einsum("lr,r->lr", [left_mat, svs])

        # Reshape the matrices to make them proper tensors
        left_tensor = left_mat.view(left_shape + [max_D])
        right_tensor = right_mat.view([max_D] + right_shape)

        # Finally, permute the indices into the desired order
        if left_str != left_part + bond_char:
            left_tensor = torch.einsum(
                f"{left_part+bond_char}->{left_str}", [left_tensor]
            )
        if right_str != bond_char + right_part:
            right_tensor = torch.einsum(
                f"{bond_char+right_part}->{right_str}", [right_tensor]
            )

        return left_tensor, right_tensor, truncation


def init_tensor(shape, bond_str, init_method):
    """
    Initialize a tensor with a given shape

    Args:
        shape:       The shape of our output parameter tensor.

        bond_str:    The bond string describing our output parameter tensor,
                     which is used in 'random_eye' initialization method.
                     The characters 'l' and 'r' are used to refer to the
                     left or right virtual indices of our tensor, and are
                     both required to be present for the random_eye and
                     min_random_eye initialization methods.

        init_method: The method used to initialize the entries of our tensor.
                     This can be either a string, or else a tuple whose first
                     entry is an initialization method and whose remaining
                     entries are specific to that method. In each case, std
                     will always refer to a standard deviation for a random
                     normal random component of each entry of the tensor.
                     When only a string is given, std defaults to 1e-9.

                     Allowed options are:
                        * ('random_eye', std): Initialize each tensor input
                            slice close to the identity
                        * ('random_zero', std): Initialize each tensor input
                            slice close to the zero matrix
                        * ('min_random_eye', std, init_dim): Initialize each
                            tensor input slice close to a truncated identity
                            matrix, whose truncation leaves init_dim unit
                            entries on the diagonal. If init_dim is larger
                            than either of the bond dimensions, then init_dim
                            is capped at the smaller bond dimension.

    Returns:
        tensor (Tensor): The initialized tensor, with the requested shape

    Raises:
        ValueError: If init_method is unknown, or if 'min_random_eye' is given
                    as a bare string (it needs the tuple form with init_dim)
    """
    # Unpack init_method if it is a tuple
    if not isinstance(init_method, str):
        init_str = init_method[0]
        std = init_method[1]
        if init_str == "min_random_eye":
            init_dim = init_method[2]

        init_method = init_str
    else:
        std = 1e-9
        if init_method == "min_random_eye":
            # There is no default for init_dim, so fail with a clear message
            # rather than with an undefined-variable error further down
            raise ValueError(
                "'min_random_eye' must be given as a tuple "
                "('min_random_eye', std, init_dim)"
            )

    # Check that bond_str is properly sized and doesn't have repeat indices
    assert len(shape) == len(bond_str)
    assert len(set(bond_str)) == len(bond_str)

    if init_method not in ["random_eye", "min_random_eye", "random_zero"]:
        raise ValueError(f"Unknown initialization method: {init_method}")

    if init_method in ["random_eye", "min_random_eye"]:
        bond_chars = ["l", "r"]
        assert all([c in bond_str for c in bond_chars])

        # Initialize our tensor slices as identity matrices which each fill
        # some or all of the initially allocated bond space
        if init_method == "min_random_eye":

            # The dimensions for our initial identity matrix. These will each
            # be init_dim, unless init_dim exceeds one of the bond dimensions,
            # in which case it is capped at the smaller bond dimension. The
            # identity block is always square, so both entries of bond_dims
            # must be the (possibly capped) init_dim
            full_dims = [shape[bond_str.index(c)] for c in bond_chars]
            init_dim = min([init_dim] + full_dims)
            bond_dims = [init_dim, init_dim]

            eye_shape = [init_dim if c in bond_chars else 1 for c in bond_str]
            expand_shape = [
                init_dim if c in bond_chars else shape[i]
                for i, c in enumerate(bond_str)
            ]

        elif init_method == "random_eye":
            eye_shape = [
                shape[i] if c in bond_chars else 1 for i, c in enumerate(bond_str)
            ]
            expand_shape = shape
            bond_dims = [shape[bond_str.index(c)] for c in bond_chars]

        eye_tensor = torch.eye(bond_dims[0], bond_dims[1]).view(eye_shape)
        eye_tensor = eye_tensor.expand(expand_shape)

        # Write the identity block into the leading corner of a zero tensor.
        # The index is a tuple of slices; a list of slices relies on legacy
        # sequence-indexing behaviour
        tensor = torch.zeros(shape)
        tensor[tuple(slice(dim) for dim in expand_shape)] = eye_tensor

        # Add on a bit of random noise
        tensor += std * torch.randn(shape)

    elif init_method == "random_zero":
        tensor = std * torch.randn(shape)

    return tensor


### OLDER MISCELLANEOUS FUNCTIONS ###   # noqa: E266


def onehot(labels, max_value):
    """
    One-hot encode a batch of integer labels

    Every label is taken from {0, 1,..., max_value-1}. The result is a float
    tensor of shape [len(labels), max_value] whose row i holds a single 1.0
    in the column given by labels[i] and zeros elsewhere.
    """
    encoded = torch.zeros([len(labels), max_value])

    # Each row is a view into `encoded`, so writing to it fills the output
    for row, cls in zip(encoded, labels):
        row[cls] = 1.0

    return encoded


def joint_shuffle(input_data, input_labels):
    """
    Apply one and the same shuffle to a data tensor and its label tensor

    Both tensors are shuffled along their first axis after re-seeding NumPy's
    global random generator with 0, so every datum keeps its label. Regular
    and CUDA tensors are both accepted, but the two inputs must live on the
    same kind of device. CPU inputs are shuffled in place, because the NumPy
    arrays being shuffled share their memory.

    Returns:
        (data, labels): The shuffled tensors, moved back to the GPU when the
                        inputs were CUDA tensors
    """
    assert input_data.is_cuda == input_labels.is_cuda
    on_gpu = input_data.is_cuda

    # NumPy can only see host memory
    host_data = input_data.cpu() if on_gpu else input_data
    host_labels = input_labels.cpu() if on_gpu else input_labels
    arrays = (host_data.numpy(), host_labels.numpy())

    # Re-seeding before each shuffle gives both arrays the same permutation
    for array in arrays:
        np.random.seed(0)
        np.random.shuffle(array)

    shuffled = tuple(torch.from_numpy(array) for array in arrays)
    if on_gpu:
        shuffled = tuple(item.cuda() for item in shuffled)

    return shuffled


def load_HV_data(length):
    """
    Output a toy "horizontal/vertical" data set of black and white
    images with size length x length. Each image contains a single
    horizontal or vertical stripe, set against a background
    of the opposite color. The labels associated with these images
    are either 0 (horizontal stripe) or 1 (vertical stripe).

    In its current version, this returns two data sets, a training
    set with 75% of the images and a test set with 25% of the
    images.

    Args:
        length (int): Side length of the square images

    Returns:
        train_images, train_labels, test_images, test_labels: Float32 image
        tensors of shape [n, length, length] and int64 label tensors of
        shape [n]

    Note:
        NumPy's global random generator is re-seeded with 0 for the shuffle.
    """
    num_images = 4 * (2 ** (length - 1) - 1)
    num_patterns = num_images // 2
    split = num_images // 4

    if length > 14:
        print(
            "load_HV_data will generate {} images, "
            "this could take a while...".format(num_images)
        )

    images = np.empty([num_images, length, length], dtype=np.float32)
    # np.int was only an alias of the builtin int; it is deprecated since
    # NumPy 1.20 and removed in 1.24, so an explicit 64-bit type is used
    labels = np.empty(num_images, dtype=np.int64)

    # Used to generate the stripe pattern from integer i below
    template = "{:0" + str(length) + "b}"

    for i in range(1, num_patterns + 1):
        pattern = template.format(i)
        pattern = [int(s) for s in pattern]

        for j, val in enumerate(pattern):
            # Horizontal stripe pattern
            images[2 * i - 2, j, :] = val
            # Vertical stripe pattern
            images[2 * i - 1, :, j] = val

        labels[2 * i - 2] = 0
        labels[2 * i - 1] = 1

    # Shuffle and partition into training and test sets. Re-seeding before
    # each shuffle applies the same permutation to images and labels
    np.random.seed(0)
    np.random.shuffle(images)
    np.random.seed(0)
    np.random.shuffle(labels)

    train_images, train_labels = images[split:], labels[split:]
    test_images, test_labels = images[:split], labels[:split]

    return (
        torch.from_numpy(train_images),
        torch.from_numpy(train_labels),
        torch.from_numpy(test_images),
        torch.from_numpy(test_labels),
    )

In [None]:
class Contractable:
    """
    Container for tensors with labeled indices and a global batch size

    The labels for our indices give some high-level knowledge of the tensor
    layout, and permit the contraction of pairs of indices in a more
    systematic manner. However, much of the actual heavy lifting is done
    through specific contraction routines in different subclasses

    Attributes:
        tensor (Tensor):    A Pytorch tensor whose first index is a batch
                            index. Sub-classes of Contractable may put other
                            restrictions on tensor
        bond_str (str):     A string whose letters each label a separate mode
                            of our tensor, and whose length equals the order
                            (number of modes) of our tensor
        global_bs (int):    The batch size associated with all Contractables.
                            This is shared between all Contractable instances
                            and allows for automatic expanding of tensors
    """

    # The global batch size
    global_bs = None

    def __init__(self, tensor, bond_str):
        """
        Wrap tensor together with the index labels in bond_str

        A tensor given without a batch index is expanded along a new leading
        batch index of size global_bs. A tensor given with a batch index
        (bond_str starts with 'b') sets global_bs to its batch size.

        Raises:
            RuntimeError: If a batch index has to be added but no global
                          batch size has been set yet
            ValueError:   If bond_str does not start with 'b' or its length
                          does not match the order of tensor
        """
        shape = list(tensor.shape)
        num_dim = len(shape)
        str_len = len(bond_str)

        global_bs = Contractable.global_bs
        batch_dim = tensor.size(0)

        # Expand along a new batch dimension if needed
        if ("b" not in bond_str and str_len == num_dim) or (
            "b" == bond_str[0] and str_len == num_dim + 1
        ):
            if global_bs is not None:
                tensor = tensor.unsqueeze(0).expand([global_bs] + shape)
            else:
                raise RuntimeError(
                    "No batch size given and no previous " "batch size set"
                )
            if bond_str[0] != "b":
                bond_str = "b" + bond_str

        # Check for correct formatting in bond_str
        elif bond_str[0] != "b" or str_len != num_dim:
            # Every piece is an f-string so that the offending bond string is
            # actually shown, not a literal '{bond_str}' placeholder
            raise ValueError(
                f"Length of bond string '{bond_str}' "
                f"({len(bond_str)}) must match order of "
                f"tensor ({len(shape)})"
            )

        # Set the global batch size if it is unset or needs to be updated
        elif global_bs is None or global_bs != batch_dim:
            Contractable.global_bs = batch_dim

        # Check that global batch size agrees with input tensor's first dim
        # NOTE(review): this branch cannot be reached, because the branch
        # above already handles every case where global_bs != batch_dim
        elif global_bs != batch_dim:
            raise RuntimeError(
                f"Batch size previously set to {global_bs}"
                ", but input tensor has batch size "
                f"{batch_dim}"
            )

        # Set the defining attributes of our Contractable
        self.tensor = tensor
        self.bond_str = bond_str

    def __mul__(self, contractable, rmul=False):
        """
        Multiply with another contractable along a linear index

        The default behavior is to multiply the 'r' index of this instance
        with the 'l' index of contractable, matching the batch ('b')
        index of both, and take the outer product of other indices.
        If rmul is True, the roles are swapped: the 'r' index of contractable
        is multiplied with the 'l' index of this instance.

        Returns NotImplemented for operands this generic routine does not
        handle, so that Python can fall back to the other operand's method.
        """
        # This method works for general Core subclasses besides Scalar (no 'l'
        # and 'r' indices), composite contractables (no tensor attribute), and
        # MatRegion (multiplication isn't just simple index contraction)
        if (
            isinstance(contractable, Scalar)
            or not hasattr(contractable, "tensor")
            or type(contractable) is MatRegion
        ):
            return NotImplemented

        tensors = [self.tensor, contractable.tensor]
        bond_strs = [list(self.bond_str), list(contractable.bond_str)]
        lowercases = [chr(c) for c in range(ord("a"), ord("z") + 1)]

        # Reverse the order of tensors if needed
        if rmul:
            tensors = tensors[::-1]
            bond_strs = bond_strs[::-1]

        # Check that bond strings are in proper format
        for i, bs in enumerate(bond_strs):
            assert bs[0] == "b"
            assert len(set(bs)) == len(bs)
            assert all([c in lowercases for c in bs])
            assert (i == 0 and "r" in bs) or (i == 1 and "l" in bs)

        # Get used and free characters
        used_chars = set(bond_strs[0]).union(bond_strs[1])
        free_chars = [c for c in lowercases if c not in used_chars]

        # Rename overlapping indices in the bond strings (except 'b', 'l', 'r')
        specials = ["b", "l", "r"]
        for i, c in enumerate(bond_strs[1]):
            if c in bond_strs[0] and c not in specials:
                bond_strs[1][i] = free_chars.pop()

        # Combine right bond of left tensor and left bond of right tensor
        sum_char = free_chars.pop()
        bond_strs[0][bond_strs[0].index("r")] = sum_char
        bond_strs[1][bond_strs[1].index("l")] = sum_char
        specials.append(sum_char)

        # Build bond string of output tensor: batch index first, then all
        # non-special indices, then the surviving 'l' and 'r' indices
        out_str = ["b"]
        for bs in bond_strs:
            out_str.extend([c for c in bs if c not in specials])
        out_str.append("l" if "l" in bond_strs[0] else "")
        out_str.append("r" if "r" in bond_strs[1] else "")

        # Build the einsum string for this operation
        bond_strs = ["".join(bs) for bs in bond_strs]
        out_str = "".join(out_str)
        ein_str = f"{bond_strs[0]},{bond_strs[1]}->{out_str}"

        # Contract along the linear dimension to get an output tensor
        out_tensor = torch.einsum(ein_str, [tensors[0], tensors[1]])

        # Return our output tensor wrapped in an appropriate class
        if out_str == "br":
            return EdgeVec(out_tensor, is_left_vec=True)
        elif out_str == "bl":
            return EdgeVec(out_tensor, is_left_vec=False)
        elif out_str == "blr":
            return SingleMat(out_tensor)
        elif out_str == "bolr":
            return OutputCore(out_tensor)
        else:
            return Contractable(out_tensor, out_str)

    def __rmul__(self, contractable):
        """
        Multiply with another contractable along a linear index
        """
        return self.__mul__(contractable, rmul=True)

    def reduce(self):
        """
        Return the contractable without any modification

        reduce() can be any method which returns a contractable. This is
        trivially possible for any contractable by returning itself
        """
        return self


class ContractableList(Contractable):
    """
    A list of contractables which can all be multiplied together in order

    Calling reduce on a ContractableList instance will first reduce every item
    to a linear contractable, and then contract everything together
    """

    def __init__(self, contractable_list):
        """
        Store a nonempty list of Contractable instances

        Raises:
            ValueError: If the input is not a list, is empty, or contains an
                        item which is not a Contractable
        """
        # Check that input list is nonempty and has contractables as entries.
        # Emptiness is tested by length; an identity test against a fresh []
        # literal is always False and would let empty lists through
        if not isinstance(contractable_list, list) or len(contractable_list) == 0:
            raise ValueError("Input to ContractableList must be nonempty list")
        for i, item in enumerate(contractable_list):
            if not isinstance(item, Contractable):
                raise ValueError(
                    "Input items to ContractableList must be "
                    f"Contractable instances, but item {i} is not"
                )

        self.contractable_list = contractable_list

    def __mul__(self, contractable, rmul=False):
        """
        Multiply a contractable by everything in ContractableList in order

        With rmul=False this evaluates item_0 * ... * item_n * contractable,
        with rmul=True it evaluates contractable * item_0 * ... * item_n.
        """
        # The input cannot be a composite contractable
        assert hasattr(contractable, "tensor")

        # Keep the operand wrapped: the items only know how to multiply with
        # other contractables, not with raw tensors
        output = contractable

        # Multiply by everything in ContractableList, in the correct order
        if rmul:
            # The input sits on the left, so absorb items front to back
            for item in self.contractable_list:
                output = output * item
        else:
            # The input sits on the right, so absorb items back to front
            for item in self.contractable_list[::-1]:
                output = item * output

        return output

    def __rmul__(self, contractable):
        """
        Multiply another contractable by everything in ContractableList
        """
        return self.__mul__(contractable, rmul=True)

    def reduce(self, parallel_eval=False):
        """
        Reduce all the contractables in list before multiplying them together

        The stored list is left untouched, so reduce() can be called again.
        """
        # Work on a copy, since the loop below deletes entries as it goes
        c_list = list(self.contractable_list)
        # For parallel_eval, reduce all contractables in c_list
        if parallel_eval:
            c_list = [item.reduce() for item in c_list]

        # Multiply together all the contractables. This multiplies in right to
        # left order, but certain inefficient contractions are unsupported.
        # If we encounter an unsupported operation, then try multiplying from
        # the left end of the list instead
        while len(c_list) > 1:
            try:
                c_list[-2] = c_list[-2] * c_list[-1]
                del c_list[-1]
            except TypeError:
                c_list[1] = c_list[0] * c_list[1]
                del c_list[0]

        return c_list[0]


class MatRegion(Contractable):
    """
    A run of adjacent matrices which get multiplied together as one unit

    The tensor defining a MatRegion has shape [batch_size, num_mats, D, D];
    the batch index may be left out (shape [num_mats, D, D]) when the global
    batch size is already known
    """

    def __init__(self, mats):
        mat_shape = list(mats.shape)
        rank_ok = len(mat_shape) in (3, 4)
        # The last two modes must describe square matrices
        if not (rank_ok and mat_shape[-2] == mat_shape[-1]):
            raise ValueError(
                "MatRegion tensors must have shape "
                "[batch_size, num_mats, D, D], or [num_mats,"
                " D, D] if batch size has already been set"
            )

        super().__init__(mats, bond_str="bslr")

    def __mul__(self, edge_vec, rmul=False):
        """
        Push an edge vector through every matrix of the region, one by one

        With rmul=False the vector sits on the right and matrices are applied
        starting from the last one; with rmul=True the vector sits on the
        left and matrices are applied starting from the first one. Returns
        NotImplemented for anything that is not an EdgeVec.
        """
        if not isinstance(edge_vec, EdgeVec):
            return NotImplemented

        stacked = self.tensor
        count = stacked.size(1)

        # Insert a dummy mode so the vector can take part in batched matmuls
        pad_dim = 1 if rmul else 2
        state = edge_vec.tensor.unsqueeze(pad_dim)
        factors = [piece.squeeze(1) for piece in torch.chunk(stacked, count, 1)]

        # Repeated matrix-vector products, in the order fixed by rmul
        if rmul:
            for factor in factors:
                state = torch.bmm(state, factor)
        else:
            for factor in reversed(factors):
                state = torch.bmm(factor, state)

        # Drop the dummy mode again and hand back a proper EdgeVec
        return EdgeVec(state.squeeze(pad_dim), is_left_vec=rmul)

    def __rmul__(self, edge_vec):
        product = self.__mul__(edge_vec, rmul=True)
        return product

    def reduce(self):
        """
        Collapse the whole region into a single SingleMat

        Neighbouring matrices are multiplied pairwise in batched fashion, so
        the number of matrices is roughly halved on every pass of the loop
        """
        mats = self.tensor
        remaining, _ = list(mats.shape)[1:3]

        # Keep merging neighbouring pairs until one matrix is left
        while remaining > 1:
            num_pairs = remaining // 2
            paired_end = 2 * num_pairs

            left_factors = mats[:, 0:paired_end:2]
            right_factors = mats[:, 1:paired_end:2]
            # An odd count leaves one matrix over, carried to the next pass
            unpaired = mats[:, paired_end:]

            merged = torch.einsum("bslu,bsur->bslr", [left_factors, right_factors])
            mats = torch.cat([merged, unpaired], 1)

            remaining = num_pairs + remaining % 2

        # Only one matrix per batch entry remains, so wrap it as a SingleMat
        return SingleMat(mats.squeeze(1))


class OutputCore(Contractable):
    """
    One MPS core which carries a single output index
    """

    def __init__(self, tensor):
        # Accept a batched core (4 modes) or an unbatched one (3 modes)
        num_modes = len(tensor.shape)
        if num_modes != 3 and num_modes != 4:
            raise ValueError(
                "OutputCore tensors must have shape [batch_size, "
                "output_dim, D_l, D_r], or else [output_dim, D_l,"
                " D_r] if batch size has already been set"
            )

        super().__init__(tensor, bond_str="bolr")


class SingleMat(Contractable):
    """
    A batch of matrices tied to one location of our MPS
    """

    def __init__(self, mat):
        # Accept a batched matrix (3 modes) or an unbatched one (2 modes)
        num_modes = len(mat.shape)
        if num_modes != 2 and num_modes != 3:
            raise ValueError(
                "SingleMat tensors must have shape [batch_size, "
                "D_l, D_r], or else [D_l, D_r] if batch size "
                "has already been set"
            )

        super().__init__(mat, bond_str="blr")


class OutputMat(Contractable):
    """
    An output core associated with an edge of our MPS

    Args:
        mat:         Tensor with shape [batch_size, D, output_dim], or
                     [D, output_dim] if batch size has already been set
        is_left_mat: True when the matrix sits on the left edge of the MPS
                     (giving it a right-facing bond), False for the right edge
    """

    def __init__(self, mat, is_left_mat):
        # Check the input shape
        if len(mat.shape) not in [2, 3]:
            raise ValueError(
                "OutputMat tensors must have shape [batch_size, "
                "D, output_dim], or else [D, output_dim] if "
                "batch size has already been set"
            )

        # OutputMats on left edge will have a right-facing bond, and vice versa
        bond_str = "b" + ("r" if is_left_mat else "l") + "o"
        super().__init__(mat, bond_str=bond_str)

    def __mul__(self, edge_vec, rmul=False):
        """
        Multiply with an edge vector along the shared linear index

        Raises:
            TypeError: If `edge_vec` is not an EdgeVec. The previous
                       `raise NotImplemented` also ended up as a TypeError
                       (NotImplemented is not an exception class), but only by
                       accident and with a misleading message. The exception
                       type is kept on purpose: ContractableList.reduce
                       catches TypeError to retry the contraction from the
                       left end of its list.
        """
        if not isinstance(edge_vec, EdgeVec):
            raise TypeError(
                "OutputMat can only be multiplied with an EdgeVec, "
                f"not {type(edge_vec).__name__}"
            )

        return super().__mul__(edge_vec, rmul)

    def __rmul__(self, edge_vec):
        """
        Reflected multiplication, delegating to __mul__ with rmul=True
        """
        return self.__mul__(edge_vec, rmul=True)


class EdgeVec(Contractable):
    """
    A batch of vectors living on one edge of our MPS

    Every EdgeVec belongs to an edge of an MPS, so the is_left_vec flag must
    be given: True for a vector on the left edge, False for a vector on the
    right edge
    """

    def __init__(self, vec, is_left_vec):
        # Accept a batched vector (2 modes) or an unbatched one (1 mode)
        num_modes = len(vec.shape)
        if num_modes != 1 and num_modes != 2:
            raise ValueError(
                "EdgeVec tensors must have shape "
                "[batch_size, D], or else [D] if batch size "
                "has already been set"
            )

        # A left-edge vector has a right-facing bond, and the other way round
        side_char = "r" if is_left_vec else "l"
        super().__init__(vec, bond_str="b" + side_char)

    def __mul__(self, right_vec):
        """
        Batched inner product between this vector and another EdgeVec

        Returns NotImplemented when the other operand is not an EdgeVec
        """
        if not isinstance(right_vec, EdgeVec):
            return NotImplemented

        # Shape the operands as row and column matrices for a batched matmul
        row = self.tensor.unsqueeze(1)
        col = right_vec.tensor.unsqueeze(2)
        num_batch = row.size(0)

        products = torch.bmm(row, col)

        # One number per batch entry is left, so wrap the result as a Scalar
        return Scalar(products.view([num_batch]))


class Scalar(Contractable):
    """
    A batch of scalars

    Args:
        scalar: Torch tensor with shape [batch_size], or with shape [] or [1]
                if batch size has already been set. A zero-dimensional tensor
                is reshaped to shape [1].

    Raises:
        ValueError: If `scalar` has more than one mode.
    """

    def __init__(self, scalar):
        # Add dummy dimension if we have a torch scalar. The emptiness test
        # must compare by value: `shape is []` is never true, because the
        # literal creates a brand-new list object on every evaluation
        shape = list(scalar.shape)
        if len(shape) == 0:
            scalar = scalar.view([1])
            shape = [1]

        # Check the input shape
        if len(shape) != 1:
            raise ValueError(
                "input scalar must be a torch tensor with shape "
                "[batch_size], or [] or [1] if batch size has "
                "been set"
            )

        super().__init__(scalar, bond_str="b")

    def __mul__(self, contractable):
        """
        Multiply a contractable by our scalar and return the result

        The scaling is done with an einsum over the contractable's own
        bond_str, so that bond_str has to contain the batch index "b".
        """
        scalar = self.tensor
        tensor = contractable.tensor
        bond_str = contractable.bond_str

        ein_string = f"{bond_str},b->{bond_str}"
        out_tensor = torch.einsum(ein_string, [tensor, scalar])

        # Wrap the result in the same class the input contractable belongs to
        # NOTE(review): subclasses are rebuilt from the tensor alone, which
        # cannot work for EdgeVec and OutputMat since their constructors
        # require a second positional argument - confirm whether scalars are
        # ever multiplied with those
        contract_class = type(contractable)
        if contract_class is not Contractable:
            return contract_class(out_tensor)
        else:
            return Contractable(out_tensor, bond_str)

    def __rmul__(self, contractable):
        # Scalar multiplication is commutative
        return self.__mul__(contractable)

In [None]:
#from Contractable import (
#SingleMat,
#    MatRegion,
#    OutputCore,
#    ContractableList,
#    EdgeVec,
#    OutputMat,
#)


class TI_MPS(nn.Module):
    """
    Sequence MPS which converts input of arbitrary length to a single output vector

    A single core tensor of shape [bond_dim, bond_dim, feature_dim] is shared
    by all positions of the input sequence.

    Args:
        output_dim:    Size of the vectors output by the model
        bond_dim:      Dimension of the bonds connecting adjacent sites
        feature_dim:   Size of the local feature space each input value is
                       embedded into (default: 2)
        parallel_eval: Passed on to ContractableList.reduce when contracting
                       the network (default: False)
        fixed_ends:    Passed as fixed_vec / fixed_mat to the initial vector
                       and terminal matrix modules (default: False)
        init_std:      Noise scale handed to init_tensor (default: 1e-9)
        use_bias:      Whether a bias matrix is used (default: True)
        fixed_bias:    If True the bias matrix is a fixed identity buffer,
                       otherwise it is a trainable parameter (default: True)
    """

    def __init__(
        self,
        output_dim,
        bond_dim,
        feature_dim=2,
        parallel_eval=False,
        fixed_ends=False,
        init_std=1e-9,
        use_bias=True,
        fixed_bias=True,
    ):
        super().__init__()

        # Initialize the core tensor defining our model
        # This tensor holds all of the trainable parameters of our model
        tensor = init_tensor(
            bond_str="lri",
            shape=[bond_dim, bond_dim, feature_dim],
            init_method=("random_zero", init_std),
        )
        self.register_parameter(name="core_tensor", param=nn.Parameter(tensor))

        # Define our initial vector and terminal matrix modules
        assert isinstance(fixed_ends, bool)
        self.init_vector = InitialVector(bond_dim, fixed_vec=fixed_ends)
        self.terminal_mat = TerminalOutput(bond_dim, output_dim, fixed_mat=fixed_ends)

        # Set the bias matrix
        if use_bias:
            # bias_mat is identity when fixed_bias=True, near-identity otherwise
            if fixed_bias:
                bias_mat = torch.eye(bond_dim)
                self.register_buffer(name="bias_mat", tensor=bias_mat)
            else:
                bias_mat = init_tensor(
                    bond_str="lr",
                    shape=[bond_dim, bond_dim],
                    init_method=("random_eye", init_std),
                )
                self.register_parameter(name="bias_mat", param=nn.Parameter(bias_mat))
        else:
            self.bias_mat = None

        # Set the rest of our TI_MPS attributes
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.bond_dim = bond_dim
        self.parallel_eval = parallel_eval
        self.use_bias = use_bias
        self.fixed_bias = fixed_bias
        self.feature_map = None

    def forward(self, input_data):
        """
        Converts batch input tensor into a batch output tensor

        Args:
            input_data: Anything accepted by format_input, e.g. a tensor of
                        shape [batch_size, length, feature_dim].

        Returns:
            Tensor of shape [batch_size, output_dim].
        """

        # Reformat our input to a batch format, padding with zeros as needed
        batch_input = self.format_input(input_data)
        batch_size = batch_input.size(0)
        seq_len = batch_input.size(1)

        # Build up a contractable_list as EdgeVec + MatRegion + OutputMat
        # The shared core is expanded (not copied) along a new sequence mode
        expanded_core = self.core_tensor.expand(
            [seq_len, self.bond_dim, self.bond_dim, self.feature_dim]
        )
        input_region = InputRegion(
            expanded_core,
            use_bias=self.use_bias,
            fixed_bias=self.fixed_bias,
            bias_mat=self.bias_mat,
            ephemeral=True,
        )
        contractable_list = [input_region(batch_input)]

        # Prepend an EdgeVec and append an OutputMat
        contractable_list = [self.init_vector()] + contractable_list
        contractable_list.append(self.terminal_mat())

        # Wrap contractable_list as a ContractableList instance
        contractable_list = ContractableList(contractable_list)

        # Contract everything in contractable_list
        output = contractable_list.reduce(parallel_eval=self.parallel_eval)
        batch_output = output.tensor

        # Check shape before returning output values
        assert output.bond_str == "bo"
        assert batch_output.size(0) == batch_size
        assert batch_output.size(1) == self.output_dim

        return batch_output

    def format_input(self, input_data):
        """
        Converts input list of sequences into a single batch sequence tensor.

        If input is already a batch tensor, it is embedded if needed and
        otherwise returned unchanged. Otherwise, convert input list into a
        batch sequence with shape [batch_size, length, feature_dim].

        If self.use_bias = self.fixed_bias = True, then sequences of different
        lengths can be used, in which case shorter sequences are padded with
        zeros at the end, making the batch tensor length equal to the length
        of the longest input sequence. The padded batch is allocated on the
        device of the first sequence.

        Args:
            input_data: A tensor of shape [batch_size, length] or
            [batch_size, length, feature_dim], or a list of length batch_size,
            whose i'th item is a tensor of shape [length_i, feature_dim] or
            [length_i]. If self.use_bias or self.fixed_bias are False, then
            length_i must be the same for all i.

        Raises:
            ValueError: If sequences of different lengths are given but
                        padding is not possible, or if input_data is neither
                        a tensor nor a list.
        """
        feature_dim = self.feature_dim

        # If we get a batch tensor, just embed it and/or return it unchanged
        if isinstance(input_data, torch.Tensor):
            if len(input_data.shape) == 2:
                input_data = self.embed_input(input_data)

            # Check to make sure shape is alright
            shape = input_data.shape
            assert len(shape) == 3
            assert shape[2] == feature_dim

            return input_data

        # Collate the input list into a single batch tensor
        elif isinstance(input_data, list):
            # Check formatting, require that input sequences are either all
            # unembedded or all pre-embedded
            num_modes = len(input_data[0].shape)
            assert num_modes in [1, 2]
            assert all(
                isinstance(s, torch.Tensor) and len(s.shape) == num_modes
                for s in input_data
            )
            assert num_modes == 1 or all(s.size(1) == feature_dim for s in input_data)

            # Check that all the sequences are the same length or can be padded
            max_len = max(s.size(0) for s in input_data)
            can_pad = self.use_bias and self.fixed_bias
            if not can_pad and any(s.size(0) != max_len for s in input_data):
                raise ValueError(
                    "To process input_data as list of sequences "
                    "with different lengths, must have self.use_bias="
                    "self.fixed_bias=True (currently self.use_bias="
                    f"{self.use_bias}, self.fixed_bias={self.fixed_bias})"
                )

            # Pad the sequences with zeros (if needed), return as batch tensor
            if can_pad:
                batch_size = len(input_data)
                full_size = [batch_size, max_len, feature_dim]
                # Allocate next to the data, so sequences that already live on
                # a GPU are not silently collected into a CPU tensor
                batch_input = torch.zeros(
                    full_size[: num_modes + 1], device=input_data[0].device
                )

                # Copy each sequence into batch_input
                for i, seq in enumerate(input_data):
                    batch_input[i, : seq.size(0)] = seq
            else:
                batch_input = torch.stack(input_data)

            # Embed everything (if needed) and return the batch tensor
            if len(batch_input.shape) == 2:
                batch_input = self.embed_input(batch_input)

            return batch_input

        else:
            raise ValueError(
                "input_data must either be Tensor with shape"
                "[batch_size, length] or [batch_size, length, feature_dim], "
                "or list of Tensors with shapes [length_i, feature_dim] or "
                "[length_i]"
            )

    def embed_input(self, input_data):
        """
        Embed pixels of input_data into separate local feature spaces

        Args:
            input_data (Tensor):    Input with shape [batch_size, length].

        Returns:
            embedded_data (Tensor): Input embedded into a tensor with shape
                                    [batch_size, length, feature_dim]. With
                                    the default embedding, the result lives on
                                    the same device as input_data.

        Raises:
            RuntimeError: If no custom feature map is registered and
                          self.feature_dim is not 2.
        """
        assert len(input_data.shape) == 2

        # Get relevant dimensions
        batch_dim, length = input_data.shape
        feature_dim = self.feature_dim
        embedded_shape = [batch_dim, length, feature_dim]

        # Apply a custom embedding map if it has been defined by the user
        if self.feature_map is not None:
            f_map = self.feature_map
            embedded_data = torch.stack(
                [torch.stack([f_map(x) for x in batch]) for batch in input_data]
            )

            # Make sure our embedded input has the desired size
            assert list(embedded_data.shape) == embedded_shape

        # Otherwise, use a simple linear embedding map with feature_dim = 2
        else:
            if self.feature_dim != 2:
                raise RuntimeError(
                    f"self.feature_dim = {feature_dim}, but "
                    "default feature_map requires self.feature_dim = 2"
                )
            # Keep the embedding on the device of the input; a bare
            # torch.empty(shape) would always allocate on the CPU
            embedded_data = torch.empty(embedded_shape, device=input_data.device)

            embedded_data[:, :, 0] = input_data
            embedded_data[:, :, 1] = 1 - input_data

        return embedded_data

    def register_feature_map(self, feature_map):
        """
        Register a custom feature map to be used for embedding input data

        Args:
            feature_map (function): Takes a single scalar input datum and
                                    returns an embedded representation of the
                                    image. The output size of the function must
                                    match self.feature_dim. If feature_map=None,
                                    then the feature map will be reset to a
                                    simple default linear embedding

        Raises:
            ValueError: If feature_map returns a tensor whose shape is not
                        [self.feature_dim].
        """
        if feature_map is not None:
            # Test to make sure feature_map outputs vector of proper size
            test_out = feature_map(torch.tensor(0))
            assert isinstance(test_out, torch.Tensor)

            out_shape, needed_shape = list(test_out.shape), [self.feature_dim]
            if out_shape != needed_shape:
                raise ValueError(
                    "Given feature_map returns values with shape "
                    f"{list(out_shape)}, but should return "
                    f"values of size {list(needed_shape)}"
                )

        self.feature_map = feature_map


class MPS(nn.Module):
    """
    Tunable MPS model giving mapping from fixed-size data to output vector

    Model works by first converting each 'pixel' (local data) to feature
    vector via a simple embedding, then contracting embeddings with inputs
    to each MPS cores. The resulting transition matrices are contracted
    together along bond dimensions (i.e. hidden state spaces), with output
    produced via an uncontracted edge of an additional output core.

    MPS model permits many customizable behaviors, including custom
    'routing' of MPS through the input, choice of boundary conditions
    (meaning the model can act as a tensor train or a tensor ring),
    GPU-friendly parallel evaluation, and an experimental mode to support
    adaptive choice of bond dimensions based on singular value spectrum.

    Args:
        input_dim:       Number of 'pixels' in the input to the MPS
        output_dim:      Size of the vectors output by MPS via output core
        bond_dim:        Dimension of the 'bonds' connecting adjacent MPS
                         cores, which act as hidden state spaces of the
                         model. In adaptive mode, bond_dim instead
                         specifies the maximum allowed bond dimension
        feature_dim:     Size of the local feature spaces each pixel is
                         embedded into (default: 2)
        periodic_bc:     Whether MPS has periodic boundary conditions (i.e.
                         is a tensor ring) or open boundary conditions
                         (i.e. is a tensor train) (default: False)
        parallel_eval:   Whether contraction of tensors is performed in a
                         serial or parallel fashion. The former is less
                         expensive for open boundary conditions, but
                         parallelizes more poorly (default: False)
        label_site:      Location in the MPS chain where output is placed
                         (default: input_dim // 2)
        path:            List or Tensor specifying a path through the input
                         data which MPS is 'routed' along. For example,
                         choosing path=[0, 1, ..., input_dim-1] gives a
                         standard in-order traversal (behavior when
                         path=None), while path=[0, 2, ..., input_dim-1]
                         specifies an MPS accepting input only from
                         even-valued input pixels (default: None)
        init_std:        Size of the Gaussian noise used in default
                         near-identity initialization (default: 1e-9)
        initializer:     Pytorch initializer for custom initialization of
                         MPS cores, with None specifying default
                         near-identity initialization (default: None)
        use_bias:        Whether to use trainable bias matrices in MPS
                         cores, which are initialized near the zero matrix.
                         Always disabled in adaptive mode (default: True)
        adaptive_mode:   Whether MPS is trained with experimental adaptive
                         bond dimensions selection (default: False)
        cutoff:          Singular value cutoff controlling bond dimension
                         adaptive selection (default: 1e-10)
        merge_threshold: Number of inputs before adaptive MPS shifts its
                         merge state once, with two shifts leading to the
                         update of all bond dimensions (default: 2000)
    """

    # TODO: Support arbitrary initializers
    # TODO: Clean up the current treatment of initialization
    # TODO: Resolve weirdness with fixed bias and initialization choice
    # TODO: Add function to convert to canonical form
    # TODO: Fix issue of no training when use_bias=False

    def __init__(
        self,
        input_dim,
        output_dim,
        bond_dim,
        feature_dim=2,
        periodic_bc=False,
        parallel_eval=False,
        label_site=None,
        path=None,
        init_std=1e-9,
        initializer=None,
        use_bias=True,
        adaptive_mode=False,
        cutoff=1e-10,
        merge_threshold=2000,
    ):
        super().__init__()

        if label_site is None:
            label_site = input_dim // 2
        assert label_site >= 0 and label_site <= input_dim

        # Using bias matrices in adaptive_mode is too complicated, so I'm
        # disabling it here
        if adaptive_mode:
            use_bias = False

        # Our MPS is made of two InputRegions separated by an OutputSite.
        module_list = []
        init_args = {
            "bond_str": "slri",
            "shape": [label_site, bond_dim, bond_dim, feature_dim],
            "init_method": (
                "min_random_eye" if adaptive_mode else "random_zero",
                init_std,
                output_dim,
            ),
        }

        # The first input region
        if label_site > 0:
            tensor = init_tensor(**init_args)

            module_list.append(InputRegion(tensor, use_bias=use_bias, fixed_bias=False))

        # The output site
        tensor = init_tensor(
            shape=[output_dim, bond_dim, bond_dim],
            bond_str="olr",
            init_method=(
                "min_random_eye" if adaptive_mode else "random_eye",
                init_std,
                output_dim,
            ),
        )
        module_list.append(OutputSite(tensor))

        # The other input region
        if label_site < input_dim:
            init_args["shape"] = [
                input_dim - label_site,
                bond_dim,
                bond_dim,
                feature_dim,
            ]
            tensor = init_tensor(**init_args)
            module_list.append(InputRegion(tensor, use_bias=use_bias, fixed_bias=False))

        # Initialize linear_region according to our adaptive_mode specification
        if adaptive_mode:
            self.linear_region = MergedLinearRegion(
                module_list=module_list,
                periodic_bc=periodic_bc,
                parallel_eval=parallel_eval,
                cutoff=cutoff,
                merge_threshold=merge_threshold,
            )

            # Initialize the list of bond dimensions, which starts out constant
            self.bond_list = bond_dim * torch.ones(input_dim + 2, dtype=torch.long)
            if not periodic_bc:
                self.bond_list[0], self.bond_list[-1] = 1, 1

            # Initialize the list of singular values, which start out at -1
            self.sv_list = -1.0 * torch.ones([input_dim + 2, bond_dim])

        else:
            self.linear_region = LinearRegion(
                module_list=module_list,
                periodic_bc=periodic_bc,
                parallel_eval=parallel_eval,
            )
        assert len(self.linear_region) == input_dim

        if self._has_custom_path(path):
            assert isinstance(path, (list, torch.Tensor))
            assert len(path) == input_dim

        # Set the rest of our MPS attributes
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bond_dim = bond_dim
        self.feature_dim = feature_dim
        self.periodic_bc = periodic_bc
        self.adaptive_mode = adaptive_mode
        self.label_site = label_site
        self.path = path
        self.use_bias = use_bias
        self.cutoff = cutoff
        self.merge_threshold = merge_threshold
        self.feature_map = None

    @staticmethod
    def _has_custom_path(path):
        """
        Tell whether `path` describes a custom routing through the input

        A plain truth test cannot be used here, because it raises for Tensor
        paths holding more than one element. Tensors therefore count as a
        path whenever they are non-empty, while every other value keeps its
        usual truthiness (so None and [] both mean 'no custom path').
        """
        if path is None:
            return False
        if isinstance(path, torch.Tensor):
            return path.numel() > 0
        return bool(path)

    def forward(self, input_data):
        """
        Embed our data and pass it to an MPS with a single output site

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim] or
                                 [batch_size, input_dim, feature_dim]. In the
                                 former case, the data points are turned into
                                 2D vectors using a default linear feature map.

                                 When using a user-specified path, the size of
                                 the second tensor mode need not exactly equal
                                 input_dim, since the path variable is used to
                                 slice a certain subregion of input_data. This
                                 can be used to define multiple MPS 'strings',
                                 which act on different parts of the input.

        Returns:
            The output of self.linear_region. When that module returns a
            tuple, only its first entry is returned and the other two are
            used to update self.bond_list and self.sv_list.
        """
        # For custom paths, rearrange our input into the desired order
        if self._has_custom_path(self.path):
            path_inputs = [input_data[:, site_num] for site_num in self.path]
            input_data = torch.stack(path_inputs, dim=1)

        # Embed our input data before feeding it into our linear region
        input_data = self.embed_input(input_data)
        output = self.linear_region(input_data)

        # If we got a tuple as output, then use the last two entries to
        # update our bond dimensions and singular values
        if isinstance(output, tuple):
            output, new_bonds, new_svs = output

            assert len(new_bonds) == len(self.bond_list)
            assert len(new_bonds) == len(new_svs)
            for i, bond_dim in enumerate(new_bonds):
                if bond_dim != -1:
                    # An updated bond must come with real singular values, not
                    # the integer placeholder -1. Compare by value and guard
                    # with isinstance, since `!= -1` on a tensor of singular
                    # values would not give a single truth value
                    sv = new_svs[i]
                    assert not (isinstance(sv, int) and sv == -1)
                    self.bond_list[i] = bond_dim
                    self.sv_list[i] = sv

        return output

    def embed_input(self, input_data):
        """
        Embed pixels of input_data into separate local feature spaces

        Args:
            input_data (Tensor):    Input with shape [batch_size, input_dim], or
                                    [batch_size, input_dim, feature_dim]. In the
                                    latter case, the data is assumed to already
                                    be embedded, and is returned unchanged.

        Returns:
            embedded_data (Tensor): Input embedded into a tensor with shape
                                    [batch_size, input_dim, feature_dim]

        Raises:
            ValueError:   If pre-embedded input has the wrong feature size.
            RuntimeError: If no custom feature map is registered and
                          self.feature_dim is not 2.
        """
        assert len(input_data.shape) in [2, 3]
        assert input_data.size(1) == self.input_dim

        # If input already has a feature dimension, return it as is
        if len(input_data.shape) == 3:
            if input_data.size(2) != self.feature_dim:
                raise ValueError(
                    f"input_data has wrong shape to be unembedded "
                    "or pre-embedded data (input_data.shape = "
                    f"{list(input_data.shape)}, feature_dim = {self.feature_dim})"
                )
            return input_data

        # Apply a custom embedding map if it has been defined by the user
        if self.feature_map is not None:
            f_map = self.feature_map
            embedded_data = torch.stack(
                [torch.stack([f_map(x) for x in batch]) for batch in input_data]
            )

            # Make sure our embedded input has the desired size
            assert embedded_data.shape == torch.Size(
                [input_data.size(0), self.input_dim, self.feature_dim]
            )

        # Otherwise, use a simple linear embedding map with feature_dim = 2
        else:
            if self.feature_dim != 2:
                raise RuntimeError(
                    f"self.feature_dim = {self.feature_dim}, "
                    "but default feature_map requires self.feature_dim = 2"
                )

            embedded_data = torch.stack([input_data, 1 - input_data], dim=2)

        return embedded_data

    def register_feature_map(self, feature_map):
        """
        Register a custom feature map to be used for embedding input data

        Args:
            feature_map (function): Takes a single scalar input datum and
                                    returns an embedded representation of the
                                    image. The output size of the function must
                                    match self.feature_dim. If feature_map=None,
                                    then the feature map will be reset to a
                                    simple default linear embedding

        Raises:
            ValueError: If feature_map returns a tensor whose shape is not
                        [self.feature_dim].
        """
        if feature_map is not None:
            # Test to make sure feature_map outputs vector of proper size
            out_shape = feature_map(torch.tensor(0)).shape
            needed_shape = torch.Size([self.feature_dim])
            if out_shape != needed_shape:
                raise ValueError(
                    "Given feature_map returns values of size "
                    f"{list(out_shape)}, but should return "
                    f"values of size {list(needed_shape)}"
                )

        self.feature_map = feature_map

    def core_len(self):
        """
        Returns the number of cores, which is at least the required input size
        """
        return self.linear_region.core_len()

    def __len__(self):
        """
        Returns the number of input sites, which equals the input size
        """
        return self.input_dim


class LinearRegion(nn.Module):
    """
    List of modules which feeds input to each module and returns reduced output

    forward() slices the input along the site axis, hands each submodule as
    many sites as its __len__ reports, and reduces the contractables returned
    by the submodules into a single tensor.

    Args:
        module_list (list):   Nonempty list of nn.Module instances. Each one
                              must implement __len__ (input sites consumed)
                              and core_len (cores held)
        periodic_bc (bool):   If True, the left and right bonds of the reduced
                              chain are traced over. If False, the chain is
                              capped with fixed boundary vectors
        parallel_eval (bool): Passed to ContractableList.reduce in the open
                              boundary case
        module_states:        Accepted for interface compatibility; unused

    Raises:
        ValueError: If module_list is not a nonempty list, or if any of its
                    entries is not an nn.Module
    """

    def __init__(
        self, module_list, periodic_bc=False, parallel_eval=False, module_states=None
    ):
        # Check that module_list is a nonempty list whose entries are Pytorch
        # modules. Emptiness is tested by length: an identity comparison
        # against a fresh `[]` literal can never be true.
        if not isinstance(module_list, list) or len(module_list) == 0:
            raise ValueError("Input to LinearRegion must be nonempty list")
        for i, item in enumerate(module_list):
            if not isinstance(item, nn.Module):
                raise ValueError(
                    "Input items to LinearRegion must be PyTorch "
                    f"Module instances, but item {i} is not"
                )
        super().__init__()

        # Wrap as a ModuleList for proper parameter registration
        self.module_list = nn.ModuleList(module_list)
        self.periodic_bc = periodic_bc
        self.parallel_eval = parallel_eval

    def forward(self, input_data):
        """
        Contract input with list of MPS cores and return the reduced tensor

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                                   feature_dim]

        Returns:
            Tensor: With periodic boundary conditions, the trace of the
                    reduced contractable over its left and right bonds. With
                    open boundary conditions, the tensor of the contractable
                    obtained after capping both ends with boundary vectors
        """
        # Check that input_data has the correct shape
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        periodic_bc = self.periodic_bc
        parallel_eval = self.parallel_eval
        lin_bonds = ["l", "r"]

        # Boundary vectors are created on the same device as the input so
        # that intermediate vectors follow the data onto a GPU
        device = input_data.device

        # For each module, pull out the number of pixels needed and call that
        # module's forward() method, putting the result in contractable_list
        ind = 0
        contractable_list = []
        for module in self.module_list:
            mod_len = len(module)
            if mod_len == 1:
                # Single-site modules receive a [batch_size, feature_dim] slice
                mod_input = input_data[:, ind]
            else:
                mod_input = input_data[:, ind : (ind + mod_len)]
            ind += mod_len

            contractable_list.append(module(mod_input))

        # For periodic boundary conditions, reduce contractable_list and
        # trace over the left and right indices to get our output
        if periodic_bc:
            contractable_list = ContractableList(contractable_list)
            # NOTE(review): this branch always reduces with parallel_eval=True
            # and ignores self.parallel_eval; kept as-is, confirm it is intended
            contractable = contractable_list.reduce(parallel_eval=True)

            # Unpack the output (atomic) contractable
            tensor, bond_str = contractable.tensor, contractable.bond_str
            assert all(c in bond_str for c in lin_bonds)

            # Build einsum string for the trace of tensor: both linear bonds
            # get the same label on the input side and are dropped on output
            in_str, out_str = "", ""
            for c in bond_str:
                if c in lin_bonds:
                    in_str += "l"
                else:
                    in_str += c
                    out_str += c
            ein_str = in_str + "->" + out_str

            # Return the trace over left and right indices
            return torch.einsum(ein_str, [tensor])

        # For open boundary conditions, add dummy edge vectors to
        # contractable_list and reduce everything to get our output
        else:
            # Get the dimension of left and right bond indices
            end_items = [contractable_list[i] for i in [0, -1]]
            bond_strs = [item.bond_str for item in end_items]
            bond_inds = [bs.index(c) for (bs, c) in zip(bond_strs, lin_bonds)]
            bond_dims = [
                item.tensor.size(ind) for (item, ind) in zip(end_items, bond_inds)
            ]

            # Build dummy end vectors (first entry one, rest zero) and insert
            # them at the ends of our list
            end_vecs = [torch.zeros(dim, device=device) for dim in bond_dims]

            for vec in end_vecs:
                vec[0] = 1
            contractable_list.insert(0, EdgeVec(end_vecs[0], is_left_vec=True))
            contractable_list.append(EdgeVec(end_vecs[1], is_left_vec=False))

            # Multiply together everything in contractable_list
            contractable_list = ContractableList(contractable_list)
            output = contractable_list.reduce(parallel_eval=parallel_eval)

            return output.tensor

    def core_len(self):
        """
        Returns the number of cores, which is at least the required input size
        """
        return sum([module.core_len() for module in self.module_list])

    def __len__(self):
        """
        Returns the number of input sites, which is the required input size
        """
        return sum([len(module) for module in self.module_list])


class MergedLinearRegion(LinearRegion):
    """
    Dynamic variant of LinearRegion that periodically rearranges its submodules

    Two merged versions of the module list are kept, one per merge offset, as
    self.module_list_0 and self.module_list_1. self.module_list points at the
    one selected by self.offset. Once the number of inputs seen reaches
    merge_threshold, forward() unmerges the active list, flips the offset and
    merges again.

    Args:
        module_list (list):    Nonempty list of unmerged nn.Module cores
        periodic_bc (bool):    Passed through to LinearRegion
        parallel_eval (bool):  Passed through to LinearRegion
        cutoff (float):        Cutoff handed to each core's _unmerge routine
        merge_threshold (int): Number of inputs seen between merge-state flips
    """

    def __init__(
        self,
        module_list,
        periodic_bc=False,
        parallel_eval=False,
        cutoff=1e-10,
        merge_threshold=2000,
    ):
        # Initialize a LinearRegion with our given module_list
        super().__init__(module_list, periodic_bc, parallel_eval)

        # Initialize attributes self.module_list_0 and self.module_list_1
        # using the unmerged self.module_list, then redefine the latter in
        # terms of one of the former lists
        self.offset = 0
        self._merge(offset=self.offset)
        self._merge(offset=(self.offset + 1) % 2)
        self.module_list = getattr(self, f"module_list_{self.offset}")

        # Initialize variables used during switching
        self.input_counter = 0
        self.merge_threshold = merge_threshold
        self.cutoff = cutoff

    def forward(self, input_data):
        """
        Contract input with list of MPS cores and return result as contractable

        MergedLinearRegion keeps an input counter of the number of inputs, and
        when this exceeds its merge threshold, triggers an unmerging and
        remerging of its parameter tensors.

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                                   feature_dim]

        Returns:
            The output of LinearRegion.forward. On a call that flipped the
            merge state, a tuple (output, bond_list, sv_list) is returned
            instead, carrying the lists produced by _unmerge
        """
        # If we've hit our threshold, flip the merge state of our tensors
        if self.input_counter >= self.merge_threshold:
            bond_list, sv_list = self._unmerge(cutoff=self.cutoff)
            self.offset = (self.offset + 1) % 2
            self._merge(offset=self.offset)
            self.input_counter -= self.merge_threshold

            # Point self.module_list to the appropriate merged module
            self.module_list = getattr(self, f"module_list_{self.offset}")
        else:
            bond_list, sv_list = None, None

        # Increment our counter and call the LinearRegion's forward method
        self.input_counter += input_data.size(0)
        output = super().forward(input_data)

        # If we flipped our merge state, then return the bond_list and output
        if bond_list:
            return output, bond_list, sv_list
        else:
            return output

    @torch.no_grad()
    def _merge(self, offset):
        """
        Convert unmerged modules in self.module_list to merged counterparts

        Calling _merge (or _unmerge) directly can cause undefined behavior,
        but see MergedLinearRegion.forward for intended use

        This proceeds by first merging all unmerged cores internally, then
        merging lone cores when possible during a second sweep. The result is
        stored in self.module_list_{offset}: created on first use, updated in
        place afterwards.

        Args:
            offset (int): Either 0 or 1, the site parity at which merged
                          pairs start
        """
        assert offset in [0, 1]

        unmerged_list = self.module_list

        # Merge each core internally and add the results to merged_list
        site_num = offset
        merged_list = []
        for core in unmerged_list:
            assert not isinstance(core, MergedInput)
            assert not isinstance(core, MergedOutput)

            # Apply internal merging routine if our core supports it
            if hasattr(core, "_merge"):
                merged_list.extend(core._merge(offset=site_num % 2))
            else:
                merged_list.append(core)

            site_num += core.core_len()

        # Merge pairs of cores when possible (currently only with
        # InputSites), making sure to respect the offset for merging.
        while True:
            mod_num, site_num = 0, 0
            combined_list = []

            while mod_num < len(merged_list) - 1:
                left_core, right_core = merged_list[mod_num : mod_num + 2]
                new_core = self.combine(left_core, right_core, merging=True)

                # If cores aren't combinable, move our sliding window by 1
                if new_core is None or offset != site_num % 2:
                    combined_list.append(left_core)
                    mod_num += 1
                    site_num += left_core.core_len()

                # If we get something new, move to the next distinct pair
                else:
                    assert (
                        new_core.core_len()
                        == left_core.core_len() + right_core.core_len()
                    )
                    combined_list.append(new_core)
                    mod_num += 2
                    site_num += new_core.core_len()

            # Add the last core if there's nothing to merge it with. This is
            # done after the sliding window loop so that a list holding a
            # single core keeps that core instead of being emptied.
            if mod_num == len(merged_list) - 1:
                combined_list.append(merged_list[mod_num])

            # We're finished when merged_list remains unchanged
            if len(combined_list) == len(merged_list):
                break
            else:
                merged_list = combined_list

        # Finally, update the appropriate merged module list
        list_name = f"module_list_{offset}"
        # If the merged module list hasn't been set yet, initialize it
        if not hasattr(self, list_name):
            setattr(self, list_name, nn.ModuleList(merged_list))

        # Otherwise, do an in-place update so that all tensors remain
        # properly registered with whatever optimizer we use
        else:
            module_list = getattr(self, list_name)
            assert len(module_list) == len(merged_list)
            for i in range(len(module_list)):
                assert module_list[i].tensor.shape == merged_list[i].tensor.shape
                module_list[i].tensor[:] = merged_list[i].tensor

    @torch.no_grad()
    def _unmerge(self, cutoff=1e-10):
        """
        Convert merged modules to unmerged counterparts

        Calling _unmerge (or _merge) directly can cause undefined behavior,
        but see MergedLinearRegion.forward for intended use

        This proceeds by first unmerging all merged cores internally, then
        combining lone cores where possible. The unmerged cores are then
        rescaled so that their norms are roughly equal, and stored as
        self.module_list.

        Args:
            cutoff (float): Passed to each core's own _unmerge routine

        Returns:
            tuple: (bond_list, sv_list), the values reported by the cores'
                   _unmerge routines, with -1 as a placeholder where a core
                   reported nothing
        """
        list_name = f"module_list_{self.offset}"
        merged_list = getattr(self, list_name)

        # Unmerge each core internally and add results to unmerged_list
        unmerged_list, bond_list, sv_list = [], [-1], [-1]
        for core in merged_list:

            # Apply internal unmerging routine if our core supports it
            if hasattr(core, "_unmerge"):
                new_cores, new_bonds, new_svs = core._unmerge(cutoff)
                unmerged_list.extend(new_cores)
                # The leading entry repeats the boundary already recorded
                bond_list.extend(new_bonds[1:])
                sv_list.extend(new_svs[1:])
            else:
                assert not isinstance(core, InputRegion)
                unmerged_list.append(core)
                bond_list.append(-1)
                sv_list.append(-1)

        # Combine all combinable pairs of cores. This occurs in several
        # passes, and for now acts nontrivially only on InputSite instances
        while True:
            mod_num = 0
            combined_list = []

            while mod_num < len(unmerged_list) - 1:
                left_core, right_core = unmerged_list[mod_num : mod_num + 2]
                new_core = self.combine(left_core, right_core, merging=False)

                # If cores aren't combinable, move our sliding window by 1
                if new_core is None:
                    combined_list.append(left_core)
                    mod_num += 1

                # If we get something new, move to the next distinct pair
                else:
                    combined_list.append(new_core)
                    mod_num += 2

            # Add the last core if there's nothing to combine it with. This is
            # done after the sliding window loop so that a list holding a
            # single core keeps that core instead of being emptied.
            if mod_num == len(unmerged_list) - 1:
                combined_list.append(unmerged_list[mod_num])

            # We're finished when unmerged_list remains unchanged
            if len(combined_list) == len(unmerged_list):
                break
            else:
                unmerged_list = combined_list

        # Find the average (log) norm of all of our cores
        log_norms = []
        for core in unmerged_list:
            log_norms.append([torch.log(norm) for norm in core.get_norm()])
        log_scale = sum([sum(ns) for ns in log_norms])
        log_scale /= sum([len(ns) for ns in log_norms])

        # Now rescale all cores so that their norms are roughly equal
        scales = [[torch.exp(log_scale - n) for n in ns] for ns in log_norms]
        for core, these_scales in zip(unmerged_list, scales):
            core.rescale_norm(these_scales)

        # Add our unmerged module list as a new attribute and return
        # the updated bond dimensions
        self.module_list = nn.ModuleList(unmerged_list)
        return bond_list, sv_list

    def combine(self, left_core, right_core, merging):
        """
        Combine a pair of cores into a new core using context-dependent rules

        Depending on the types of left_core and right_core, along with whether
        we're currently merging (merging=True) or unmerging (merging=False),
        either return a new core, or None if no rule exists for this context

        Args:
            left_core (nn.Module):  Left member of the neighboring pair
            right_core (nn.Module): Right member of the neighboring pair
            merging (bool):         True during _merge, False during _unmerge

        Returns:
            A MergedOutput (merging an OutputSite with an InputSite), an
            InputRegion (unmerging, joining an InputRegion with an
            InputSite), or None when the pair isn't combinable
        """

        # Combine an OutputSite with a stray InputSite, return a MergedOutput
        if merging and (
            (isinstance(left_core, OutputSite) and isinstance(right_core, InputSite))
            or (isinstance(left_core, InputSite) and isinstance(right_core, OutputSite))
        ):

            left_site = isinstance(left_core, InputSite)
            if left_site:
                new_tensor = torch.einsum(
                    "lui,our->olri", [left_core.tensor, right_core.tensor]
                )
            else:
                new_tensor = torch.einsum(
                    "olu,uri->olri", [left_core.tensor, right_core.tensor]
                )
            return MergedOutput(new_tensor, left_output=(not left_site))

        # Combine an InputRegion with a stray InputSite, return an InputRegion
        elif not merging and (
            (isinstance(left_core, InputRegion) and isinstance(right_core, InputSite))
            or (
                isinstance(left_core, InputSite) and isinstance(right_core, InputRegion)
            )
        ):

            # The lone site gains a leading site axis so it can be
            # concatenated with the region's stacked cores
            left_site = isinstance(left_core, InputSite)
            if left_site:
                left_tensor = left_core.tensor.unsqueeze(0)
                right_tensor = right_core.tensor
            else:
                left_tensor = left_core.tensor
                right_tensor = right_core.tensor.unsqueeze(0)

            assert left_tensor.shape[1:] == right_tensor.shape[1:]
            new_tensor = torch.cat([left_tensor, right_tensor])

            return InputRegion(new_tensor)

        # If this situation doesn't belong to the above cases, return None
        else:
            return None

    def core_len(self):
        """
        Returns the number of cores, which is at least the required input size
        """
        return sum([module.core_len() for module in self.module_list])

    def __len__(self):
        """
        Returns the number of input sites, which is the required input size
        """
        return sum([len(module) for module in self.module_list])


class InputRegion(nn.Module):
    """
    A run of adjacent MPS input cores held in one tensor, bond_str = 'slri'

    Args:
        tensor (Tensor):   Stacked cores with four modes, whose second and
                           third modes (the bonds) have equal size
        use_bias (bool):   Whether forward() adds bias_mat to the matrices
                           obtained from contracting cores with inputs
        fixed_bias (bool): If True the bias is stored as a buffer, otherwise
                           as a trainable parameter
        bias_mat (Tensor): Optional bias with 2 or 3 modes. When use_bias is
                           set and no bias is given, an identity matrix is
                           used
        ephemeral (bool):  If True, both the core tensor and the bias are
                           stored as buffers rather than parameters
    """

    def __init__(
        self, tensor, use_bias=True, fixed_bias=True, bias_mat=None, ephemeral=False
    ):
        super().__init__()

        # The core tensor needs four modes, with square component matrices
        assert len(tensor.shape) == 4
        assert tensor.size(1) == tensor.size(2)
        bond_dim = tensor.size(1)

        # Prepare the bias, bringing it to three modes
        if use_bias:
            assert bias_mat is None or isinstance(bias_mat, torch.Tensor)
            if bias_mat is None:
                bias_mat = torch.eye(bond_dim).unsqueeze(0)

            num_modes = bias_mat.dim()
            assert num_modes in [2, 3]
            if num_modes == 2:
                bias_mat = bias_mat.unsqueeze(0)

        # The core tensor is a buffer when ephemeral, a parameter otherwise
        cores = tensor.contiguous()
        if ephemeral:
            self.register_buffer(name="tensor", tensor=cores)
        else:
            self.register_parameter(name="tensor", param=nn.Parameter(cores))

        # The bias is only trainable when neither ephemeral nor fixed
        if ephemeral or fixed_bias:
            self.register_buffer(name="bias_mat", tensor=bias_mat)
        else:
            self.register_parameter(name="bias_mat", param=nn.Parameter(bias_mat))

        self.use_bias = use_bias
        self.fixed_bias = fixed_bias

    def forward(self, input_data):
        """
        Contract input with MPS cores and return result as a MatRegion

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                                   feature_dim]
        """
        cores = self.tensor

        # Shape checks on the incoming data
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        assert input_data.size(2) == cores.size(3)

        # One matrix per (batch element, site) pair
        mats = torch.einsum("slri,bsi->bslr", [cores, input_data])
        if not self.use_bias:
            return MatRegion(mats)

        # Broadcast the bias over the batch axis and add it
        bias = self.bias_mat.unsqueeze(0)
        return MatRegion(mats + bias.expand_as(mats))

    def _merge(self, offset):
        """
        Merge neighboring pairs of cores, returning the resulting core list

        offset (0 or 1) is the index of the first core taking part in a
        merged pair. Cores left without a partner at either end are returned
        as InputSite instances around a MergedInput holding the merged pairs,
        so the returned list has between 1 and 3 entries. An empty region
        returns [None].
        """
        assert offset in [0, 1]
        num_sites = self.core_len()

        # Empty regions can show up through the recursive calls below
        if num_sites == 0:
            return [None]

        odd_length = num_sites % 2 == 1

        # Peel off unpaired end cores until offset=0 and the length is even
        if offset == 1 and odd_length:
            pieces = [self[0], self[1:]._merge(offset=0)[0]]
        elif offset == 1:
            pieces = [self[0], self[1:-1]._merge(offset=0)[0], self[-1]]
        elif odd_length:
            pieces = [self[:-1]._merge(offset=0)[0], self[-1]]

        # Base case: pair up every even-indexed core with its right neighbor
        else:
            all_cores = self.tensor
            left_cores = all_cores[0::2]
            right_cores = all_cores[1::2]
            assert len(left_cores) == len(right_cores)

            # Contract the shared bond while keeping both input modes
            pairs = torch.einsum("slui,surj->slrij", [left_cores, right_cores])
            pieces = [MergedInput(pairs)]

        # Drop the None placeholders coming from empty sub-regions
        return [piece for piece in pieces if piece is not None]

    def __getitem__(self, key):
        """
        Index along the site axis

        A slice gives back an InputRegion, an integer gives back an InputSite
        """
        assert isinstance(key, (int, slice))

        selected = self.tensor[key]
        if isinstance(key, slice):
            return InputRegion(selected)
        return InputSite(selected)

    def get_norm(self):
        """
        Returns a list with the norm of each core in the region
        """
        norms = []
        for core in self.tensor:
            norms.append(torch.norm(core))
        return norms

    @torch.no_grad()
    def rescale_norm(self, scale_list):
        """
        Multiply each core in place by its entry of scale_list

        The i'th core is updated as tensor_i <- scale_list[i] * tensor_i
        """
        assert len(scale_list) == len(self.tensor)

        for site, factor in enumerate(scale_list):
            core = self.tensor[site]
            core *= factor

    def core_len(self):
        return len(self)

    def __len__(self):
        return self.tensor.size(0)


class MergedInput(nn.Module):
    """
    A run of merged MPS cores, where every core consumes two input data

    A MergedInput is only created from cores that have already been
    contracted together, so an existing merged tensor (bond_str = 'slrij')
    must be supplied
    """

    def __init__(self, tensor):
        # Five modes are needed, with matching bond sizes and input sizes
        dims = tensor.shape
        assert len(dims) == 5
        assert dims[1] == dims[2]
        assert dims[3] == dims[4]

        super().__init__()

        # Assigning an nn.Parameter registers it as a trainable parameter
        self.tensor = nn.Parameter(tensor.contiguous())

    def forward(self, input_data):
        """
        Contract input with merged MPS cores and return result as a MatRegion

        Args:
            input_data (Tensor): Input with shape [batch_size, input_dim,
                                 feature_dim], where input_dim must be even
                                 (each merged core takes 2 inputs)
        """
        merged = self.tensor

        # Shape checks on the incoming data
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        assert input_data.size(2) == merged.size(3)
        assert input_data.size(1) % 2 == 0

        # Even sites feed the first input mode, odd sites the second
        even_inputs = input_data[:, 0::2]
        odd_inputs = input_data[:, 1::2]

        # Absorb the odd (right-hand) inputs first, then the even ones
        half_done = torch.einsum("slrij,bsj->bslri", [merged, odd_inputs])
        mats = torch.einsum("bslri,bsi->bslr", [half_done, even_inputs])

        return MatRegion(mats)

    def _unmerge(self, cutoff=1e-10):
        """
        Split every merged core in two and collect them in an InputRegion

        Each merged core is handed to svd_flex together with the current bond
        size as maximum dimension and the given cutoff.

        Returns:
            tuple: A one-element list holding the new InputRegion, a list of
                   the bond dimensions reported by svd_flex, and a list of
                   the vectors passed to svd_flex as sv_vec. Both lists use
                   -1 as placeholder between entries
        """
        merged = self.tensor
        svd_string = "lrij->lui,urj"
        max_D = merged.size(1)

        split_cores = []
        bond_list = [-1]
        sv_list = [-1]
        for pair_core in merged:
            sv_vec = torch.empty(max_D)
            left_half, right_half, bond_dim = svd_flex(
                pair_core, svd_string, max_D, cutoff, sv_vec=sv_vec
            )

            split_cores.extend([left_half, right_half])
            bond_list.extend([bond_dim, -1])
            sv_list.extend([sv_vec, -1])

        # Stack the individual cores back along a leading site axis
        return [InputRegion(torch.stack(split_cores))], bond_list, sv_list

    def get_norm(self):
        """
        Returns a list with the norm of each merged core
        """
        norms = []
        for core in self.tensor:
            norms.append(torch.norm(core))
        return norms

    @torch.no_grad()
    def rescale_norm(self, scale_list):
        """
        Multiply each merged core in place by its entry of scale_list

        The i'th core is updated as tensor_i <- scale_list[i] * tensor_i
        """
        assert len(scale_list) == len(self.tensor)

        for position, factor in enumerate(scale_list):
            core = self.tensor[position]
            core *= factor

    def core_len(self):
        return len(self)

    def __len__(self):
        """
        Returns the number of input sites, which is twice the number of cores
        """
        num_cores = self.tensor.size(0)
        return 2 * num_cores


class InputSite(nn.Module):
    """
    One MPS core consuming one input datum, bond_str = 'lri'
    """

    def __init__(self, tensor):
        super().__init__()
        # Assigning an nn.Parameter registers it as a trainable parameter
        self.tensor = nn.Parameter(tensor.contiguous())

    def forward(self, input_data):
        """
        Contract input with MPS core and return result as a SingleMat

        Args:
            input_data (Tensor): Input with shape [batch_size, feature_dim]
        """
        core = self.tensor

        # Shape checks on the incoming data
        assert len(input_data.shape) == 2
        assert input_data.size(1) == core.size(2)

        # One matrix per batch element
        return SingleMat(torch.einsum("lri,bi->blr", [core, input_data]))

    def get_norm(self):
        """
        Returns a one-element list holding the norm of the core tensor
        """
        core_norm = torch.norm(self.tensor)
        return [core_norm]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """
        Multiply the core in place by `scale`

        `scale` may be given directly or wrapped in a one-element list
        """
        factor = scale
        if isinstance(factor, list):
            assert len(factor) == 1
            factor = factor[0]

        self.tensor *= factor

    def core_len(self):
        return 1

    def __len__(self):
        return 1


class OutputSite(nn.Module):
    """
    One MPS core without input, carrying a single output index,
    bond_str = 'olr'
    """

    def __init__(self, tensor):
        super().__init__()
        # Assigning an nn.Parameter registers it as a trainable parameter
        self.tensor = nn.Parameter(tensor.contiguous())

    def forward(self, input_data):
        """
        Wrap the core tensor as an OutputCore contractable

        input_data is accepted for a uniform call signature but isn't used
        """
        core = self.tensor
        return OutputCore(core)

    def get_norm(self):
        """
        Returns a one-element list holding the norm of the core tensor
        """
        core_norm = torch.norm(self.tensor)
        return [core_norm]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """
        Multiply the core in place by `scale`

        `scale` may be given directly or wrapped in a one-element list
        """
        factor = scale
        if isinstance(factor, list):
            assert len(factor) == 1
            factor = factor[0]

        self.tensor *= factor

    def core_len(self):
        return 1

    def __len__(self):
        return 0


class MergedOutput(nn.Module):
    """
    Merged MPS core consuming one input datum and carrying the output index

    A MergedOutput is only created from an input core and an output core
    that have already been contracted together, so the merged tensor
    (bond_str = 'olri') must be supplied

    Args:
        tensor (Tensor):    Value that our merged core is initialized to
        left_output (bool): True when the output core sits to the left of the
                            input core, False when it sits to the right
    """

    def __init__(self, tensor, left_output):
        # The merged tensor needs exactly four modes
        assert len(tensor.shape) == 4
        super().__init__()

        # Assigning an nn.Parameter registers it as a trainable parameter
        self.tensor = nn.Parameter(tensor.contiguous())
        self.left_output = left_output

    def forward(self, input_data):
        """
        Contract input with input index of core and return an OutputCore

        Args:
            input_data (Tensor): Input with shape [batch_size, feature_dim]
        """
        merged = self.tensor

        # Shape checks on the incoming data
        assert len(input_data.shape) == 2
        assert input_data.size(1) == merged.size(3)

        # The input mode is absorbed, leaving a batch of output cores
        return OutputCore(torch.einsum("olri,bi->bolr", [merged, input_data]))

    def _unmerge(self, cutoff=1e-10):
        """
        Split the merged core into an OutputSite and an InputSite

        The split is delegated to svd_flex, which receives the current bond
        size as maximum dimension along with the given cutoff.

        Returns:
            tuple: The two new sites in left-to-right order, then
                   [-1, bond_dim, -1] with the bond dimension reported by
                   svd_flex, then [-1, sv_vec, -1] with the vector passed to
                   svd_flex as sv_vec
        """
        merged = self.tensor
        output_first = self.left_output

        # The side holding the output core fixes both the splitting pattern
        # and which bond's size bounds the newly created index
        if output_first:
            svd_string, bond_axis = "olri->olu,uri", 2
        else:
            svd_string, bond_axis = "olri->our,lui", 1

        max_D = merged.size(bond_axis)
        sv_vec = torch.empty(max_D)
        output_core, input_core, bond_dim = svd_flex(
            merged, svd_string, max_D, cutoff, sv_vec=sv_vec
        )

        # Order the new sites the way they sit in the chain
        if output_first:
            sites = [OutputSite(output_core), InputSite(input_core)]
        else:
            sites = [InputSite(input_core), OutputSite(output_core)]

        return (sites, [-1, bond_dim, -1], [-1, sv_vec, -1])

    def get_norm(self):
        """
        Returns a one-element list holding the norm of the core tensor
        """
        core_norm = torch.norm(self.tensor)
        return [core_norm]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """
        Multiply the core in place by `scale`

        `scale` may be given directly or wrapped in a one-element list
        """
        factor = scale
        if isinstance(factor, list):
            assert len(factor) == 1
            factor = factor[0]

        self.tensor *= factor

    def core_len(self):
        return 2

    def __len__(self):
        return 1


class InitialVector(nn.Module):
    """
    Boundary vector of ones and zeros used as initial vector within the MPS

    All bond_dim entries are one unless fill_dim is given, in which case only
    the first fill_dim entries are one and the remaining ones are zero.

    With fixed_vec=True the vector is stored as a buffer; with
    fixed_vec=False it is registered as a trainable model parameter.
    """

    def __init__(self, bond_dim, fill_dim=None, fixed_vec=True, is_left_vec=True):
        super().__init__()

        # Start from all ones, then zero out everything past fill_dim
        vec = torch.ones(bond_dim)
        if fill_dim is not None:
            assert 0 <= fill_dim <= bond_dim
            vec[fill_dim:] = 0

        # Gradients are only tracked for a trainable vector
        vec.requires_grad_(not fixed_vec)
        if fixed_vec:
            self.register_buffer(name="vec", tensor=vec)
        else:
            self.register_parameter(name="vec", param=nn.Parameter(vec))

        assert isinstance(is_left_vec, bool)
        self.is_left_vec = is_left_vec

    def forward(self):
        """
        Wrap the stored vector as an EdgeVec contractable
        """
        vec, is_left = self.vec, self.is_left_vec
        return EdgeVec(vec, is_left)

    def core_len(self):
        return 1

    def __len__(self):
        return 0


class TerminalOutput(nn.Module):
    """
    Output matrix at end of chain to transmute virtual state into output vector

    The matrix has shape [bond_dim, output_dim] and starts out as a
    rectangular identity. With fixed_mat=True it is kept exactly as that
    identity and stored as a buffer. With fixed_mat=False (the default),
    Gaussian noise divided by bond_dim is added and the result is registered
    as a trainable model parameter.

    `is_left_mat` must be a bool and is passed on to OutputMat by forward().
    """

    def __init__(self, bond_dim, output_dim, fixed_mat=False, is_left_mat=False):
        super().__init__()

        # A fixed transducer with more outputs than bond dimensions has no
        # initialization scheme here, so that combination is rejected
        if fixed_mat and output_dim > bond_dim:
            raise ValueError(
                "With fixed_mat=True, TerminalOutput currently "
                "only supports initialization for bond_dim >= "
                "output_dim, but here bond_dim="
                f"{bond_dim} and output_dim={output_dim}"
            )

        identity = torch.eye(bond_dim, output_dim)
        if not fixed_mat:
            # Perturb the identity a little to help with training
            noisy = identity + torch.randn_like(identity) / bond_dim

            noisy.requires_grad = True
            self.register_parameter("mat", nn.Parameter(noisy))
        else:
            identity.requires_grad = False
            self.register_buffer("mat", identity)

        assert isinstance(is_left_mat, bool)
        self.is_left_mat = is_left_mat

    def forward(self):
        """
        Wrap our stored matrix and its orientation flag in an OutputMat
        """
        return OutputMat(self.mat, self.is_left_mat)

    def core_len(self):
        """
        Always 1 for this module
        """
        return 1

    def __len__(self):
        """
        Always 0 for this module
        """
        return 0

In [None]:
# Build the MPS classifier.
# input_dim=512 ** 2 matches the flattened width the training cell reshapes
# each batch to (`inputs.view([batch_size, 512 ** 2])`); output_dim=4 is the
# number of scores per sample that the training cell takes the argmax over.
# bond_dim, adaptive_mode and periodic_bc are not defined in this cell.
mps = MPS(
    input_dim=512 ** 2,
    output_dim=4,
    bond_dim=bond_dim,
    adaptive_mode=adaptive_mode,
    periodic_bc=periodic_bc,
)

In [None]:
# Cross-entropy loss applied to the MPS scores; Adam updates every parameter
# returned by mps.parameters(). l2_reg is passed as Adam's weight_decay.
# learn_rate and l2_reg are not defined in this cell.
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=learn_rate, weight_decay=l2_reg)

In [None]:
# Summarize the MPS and optimizer settings, one setting per line
print(
    f"Maximum MPS bond dimension = {bond_dim}",
    f" * {'Adaptive' if adaptive_mode else 'Fixed'} bond dimensions",
    f" * {'Periodic' if periodic_bc else 'Open'} boundary conditions",
    f"Using Adam w/ learning rate = {learn_rate:.1e}",
    sep="\n",
)
# The regularization line is only shown when a penalty is actually in use
if l2_reg > 0:
    print(f" * L2 regularization = {l2_reg:.2e}")
print()

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the training cell calls mps(inputs) with no visible
# .to(device) on the model or the batches -- confirm `device` is applied.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
%%bash
# Install gdcm from the conda-forge channel (-y skips the confirmation prompt).
# Same command as the gdcm install cell near the top of the notebook.
conda install -c conda-forge gdcm -y

In [None]:
# Subset sizes: the samplers draw from dataset indices range(num_train) and
# range(num_test), and num_batches is derived from these counts.
num_train = 500
num_test = 200

In [None]:
# Wrap the COVID datasets in dataloaders restricted to the first
# num_train / num_test indices, visited in random order
samplers = {
    split: torch.utils.data.SubsetRandomSampler(range(count))
    for split, count in (("train", num_train), ("test", num_test))
}
loaders = {
    "train": torch.utils.data.DataLoader(
        transformed_train_dataset,
        batch_size=batch_size,
        sampler=samplers["train"],
        drop_last=True,
    ),
    "test": torch.utils.data.DataLoader(
        transformed_test_dataset,
        batch_size=batch_size,
        sampler=samplers["test"],
        drop_last=True,
    ),
}
# drop_last=True discards any partial batch, so only full batches are counted
num_batches = {
    "train": num_train // batch_size,
    "test": num_test // batch_size,
}

In [None]:
# Bare expression: the notebook shows the dataset object's repr as cell output
transformed_train_dataset

In [None]:
# Let's start training!
# Each epoch makes one pass over the "train" loader with parameter updates,
# then one gradient-free pass over the "test" loader to report accuracy.
# NOTE: `start_time` (used in the runtime report) is not set in this cell.
for epoch_num in range(1, num_epochs + 1):
    running_loss = 0.0
    running_acc = 0.0

    # Iterate the batched loader, not the raw dataset: the reshape below needs
    # batch_size samples at a time, and the averages printed after the loop
    # divide by num_batches["train"], which counts the loader's full batches
    for inputs, labels in loaders["train"]:
        inputs, labels = inputs.view([batch_size, 512 ** 2]), labels.data

        # Call our MPS to get logit scores and predictions
        scores = mps(inputs)
        _, preds = torch.max(scores, 1)

        # Compute the loss and accuracy, add them to the running totals
        loss = loss_fun(scores, labels)
        with torch.no_grad():
            accuracy = torch.sum(preds == labels).item() / batch_size
            # .item() keeps the running total a plain Python float
            running_loss += loss.item()
            running_acc += accuracy

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"### Epoch {epoch_num} ###")
    print(f"Average loss:           {running_loss / num_batches['train']:.4f}")
    print(f"Average train accuracy: {running_acc / num_batches['train']:.4f}")

    # Evaluate accuracy of MPS classifier on the validation set
    with torch.no_grad():
        running_acc = 0.0

        for inputs, labels in loaders["test"]:
            # Same flattened width as training (and as the MPS input_dim)
            inputs, labels = inputs.view([batch_size, 512 ** 2]), labels.data

            # Call our MPS to get logit scores and predictions
            scores = mps(inputs)
            _, preds = torch.max(scores, 1)
            running_acc += torch.sum(preds == labels).item() / batch_size

    # The held-out split is stored under the "test" key of loaders/num_batches
    print(f"Validation accuracy:          {running_acc / num_batches['test']:.4f}")
    print(f"Runtime so far:         {int(time.time()-start_time)} sec\n")