First install the repo and requirements.

In [None]:
%pip --quiet install git+https://github.com/wilson-labs/cola.git

# Quick Start

We now showcase the basic functionality of CoLA. We'll start by showing how to define different types of Linear ops, then show how to perform basic arithmetic with them, and conclude by applying some linear algebra operations (such as solves or log determinants) to them.

We'll work with torch.tensors in this example, but the same code can be run using JAX arrays (jnp.ndarrays). 

In [1]:
import cola as co
import torch
torch.manual_seed(21)

<torch._C.Generator at 0x7fd8dbf60df0>

## Creating a Linear Operator

You can find several predefined Linear ops under cola.ops. We'll illustrate three basic cases: Dense, Diagonal and Tridiagonal.

A Dense Linear Operator is nothing more than a wrapper on a dense matrix, where the wrapper defines a matmat function $v \mapsto Av$ and holds several attributes such as dtype and shape.

Let's start by defining a dense matrix and a vector to act upon. Below we show the entries of the matrix $A$, of the vector $v$ and the result of $Av$.

In [2]:
N = 3
A_dense = torch.randn(N, N)
vec = torch.randn(N)
print(A_dense)
print(vec)
print(A_dense @ vec)

tensor([[ 0.1081, -0.4376, -0.7697],
        [-0.1929, -0.3626, -2.8451],
        [ 1.4435,  0.4976,  0.6542]])
tensor([ 0.0754, -1.0767,  0.1269])
tensor([ 0.3816,  0.0147, -0.3438])


To create a Dense operator simply run:

In [3]:
A = co.ops.Dense(A_dense)
print(type(A))

<class 'cola.ops.Dense'>


The operator now has a dtype and a shape attribute. More importantly, it can act on the vector $v$ and produces the same result as above.

In [4]:
print(f"Dtype: {A.dtype} | Shape: {A.shape}")
print(A @ vec)

Dtype: torch.float32 | Shape: torch.Size([3, 3])
tensor([ 0.3816,  0.0147, -0.3438])
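To make the idea of "a wrapper on a dense matrix" concrete, here is a minimal sketch in NumPy. The class `DenseSketch` is a hypothetical stand-in written for illustration, not CoLA's actual `Dense` implementation; it only mirrors the pieces described above (a matmat rule plus `dtype` and `shape` attributes). The matrix and vector entries are copied from the rounded printout, so the result agrees with the output above only to about three decimals.

```python
import numpy as np

class DenseSketch:
    """Hypothetical stand-in for a dense linear operator (not CoLA's class)."""

    def __init__(self, matrix):
        self.matrix = np.asarray(matrix)
        self.shape = self.matrix.shape
        self.dtype = self.matrix.dtype

    def matmat(self, X):
        # Apply the operator to a vector or to the columns of a matrix.
        return self.matrix @ X

    def __matmul__(self, X):
        # Lets us write `op @ v`, as with the operators in this notebook.
        return self.matmat(X)

# Entries copied from the rounded printout of A_dense and vec above.
A_np = np.array([[0.1081, -0.4376, -0.7697],
                 [-0.1929, -0.3626, -2.8451],
                 [1.4435, 0.4976, 0.6542]])
v_np = np.array([0.0754, -1.0767, 0.1269])

A_op = DenseSketch(A_np)
out = A_op @ v_np
print(f"Dtype: {A_op.dtype} | Shape: {A_op.shape}")
print(out)
```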


To define a Diagonal Linear Operator we only have to pass the diagonal, as shown below. We can reconstruct the dense matrix by calling the to_dense() method.

In [5]:
diagonal = torch.tensor([1., 2., 3])
D = co.ops.Diagonal(diagonal)
print(D.to_dense())

tensor([[1., 0., 0.],
        [0., 2., 0.],
        [0., 0., 3.]])


We follow a similar procedure for a Tridiagonal Linear Operator, where we now provide the diagonal but also the lower and upper bands of the matrix.

In [6]:
upper_band = torch.tensor([[1., 1.]]).T
lower_band = torch.tensor([[-1., -1.]]).T
diagonal = torch.tensor([[3., 3., 3.]]).T
T = co.ops.Tridiagonal(lower_band, diagonal, upper_band)
print(T.to_dense())

tensor([[ 3.,  1.,  0.],
        [-1.,  3.,  1.],
        [ 0., -1.,  3.]])


It is worth noting that for both the Diagonal and the Tridiagonal ops, the cost of a matrix-vector multiply (MVM) is no longer $O(N^2)$ but rather $O(N)$.

Overall, the different types of predefined Linear ops in CoLA have different requirements, but they usually hold a representation that is either much sparser than the dense one or has faster MVMs (or both).
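The $O(N)$ cost can be seen by writing the two MVMs by hand. The sketch below uses NumPy and two hypothetical helper functions, `diagonal_matvec` and `tridiagonal_matvec`, which are illustrative only and say nothing about how CoLA implements these operators internally. Each touches every stored entry once, and neither ever forms the $N \times N$ matrix.

```python
import numpy as np

def diagonal_matvec(diag, v):
    # One multiply per entry: O(N) work and O(N) storage.
    return diag * v

def tridiagonal_matvec(lower, diag, upper, v):
    # lower and upper hold N - 1 entries each; diag holds N entries.
    out = diag * v
    out[:-1] += upper * v[1:]   # superdiagonal contribution
    out[1:] += lower * v[:-1]   # subdiagonal contribution
    return out

# Same bands as the Tridiagonal example above.
lower = np.array([-1., -1.])
diag = np.array([3., 3., 3.])
upper = np.array([1., 1.])
v = np.array([1., 2., 3.])

# Dense reference, built only to check the banded result.
T_dense = np.diag(diag) + np.diag(upper, k=1) + np.diag(lower, k=-1)
fast = tridiagonal_matvec(lower, diag, upper, v)
print(fast)
print(T_dense @ v)
```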

## Doing binary operations with Linear ops

CoLA provides an interface for combining Linear ops that mirrors how you would combine matrices. For example, to sum two Linear ops we simply write:

In [7]:
DT = D + T
print(DT.to_dense())

tensor([[ 4.,  1.,  0.],
        [-1.,  5.,  1.],
        [ 0., -1.,  6.]])


Linear ops can be combined much further. For example, the following code creates a new linear operator $B = A(D - T) + \mu I$, regularized by $\mu$.

In [8]:
from cola.linear_algebra import I_like
mu = 1e-6
B = A @ (D - T)
B += mu * I_like(B)
print(B)

Densediag(tensor([1., 2., 3.]))+-1*Tridiagonal+1e-06*


Under the hood, the operator $B$ is defined lazily and knows how to apply its constituent Linear ops to any vector $v$.
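One way to picture lazy composition is an object that stores only a matvec rule and builds new rules out of old ones. The sketch below is an assumption-laden toy in NumPy: the class `LazyOp` and the helper `from_dense` are hypothetical names invented here, and CoLA's real operator classes are richer than this. It shows that $B = A(D - T) + \mu I$ can be applied to a vector without ever assembling $B$ as a matrix.

```python
import numpy as np

class LazyOp:
    """Hypothetical lazy operator: stores a matvec rule and a shape only."""

    def __init__(self, matvec, shape):
        self.matvec = matvec
        self.shape = shape

    def __add__(self, other):
        return LazyOp(lambda v: self.matvec(v) + other.matvec(v), self.shape)

    def __sub__(self, other):
        return LazyOp(lambda v: self.matvec(v) - other.matvec(v), self.shape)

    def __matmul__(self, other):
        if isinstance(other, LazyOp):
            # Operator product: apply `other` first, then `self`.
            shape = (self.shape[0], other.shape[1])
            return LazyOp(lambda v: self.matvec(other.matvec(v)), shape)
        return self.matvec(other)  # operator applied to a vector

    def __rmul__(self, scalar):
        return LazyOp(lambda v: scalar * self.matvec(v), self.shape)

def from_dense(M):
    # Hypothetical helper: wrap a dense array as a LazyOp.
    return LazyOp(lambda v: M @ v, M.shape)

A_np = np.array([[0.1081, -0.4376, -0.7697],
                 [-0.1929, -0.3626, -2.8451],
                 [1.4435, 0.4976, 0.6542]])
D_np = np.diag([1., 2., 3.])
T_np = np.array([[3., 1., 0.], [-1., 3., 1.], [0., -1., 3.]])
I_np = np.eye(3)
mu = 1e-6
v = np.array([0.0754, -1.0767, 0.1269])

B_lazy = from_dense(A_np) @ (from_dense(D_np) - from_dense(T_np)) + mu * from_dense(I_np)
lazy_result = B_lazy @ v
dense_result = (A_np @ (D_np - T_np) + mu * I_np) @ v
print(np.allclose(lazy_result, dense_result))
```

Applying `B_lazy` costs one MVM per constituent operator, which is where structured operators such as Diagonal and Tridiagonal pay off.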

## Computing solves and log determinants

To solve the linear system $Bx=v$ we use the inverse function. This function lazily defines $B^{-1}$, so applying it to $v$ yields the solution $x=B^{-1}v$. The inverse of $B$ is never computed explicitly; writing $B^{-1}$ is simply how linear solves are expressed in CoLA.

In [9]:
B_inv = co.inverse(B)
soln = B_inv @ vec
abs_res = torch.linalg.norm(B @ soln - vec)
print(f"{abs_res:1.3e}")

6.366e-08
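The "inverse as a deferred solve" idea can be sketched as follows. The class `LazyInverseSketch` is hypothetical and uses a dense NumPy solve as a stand-in; which solver CoLA actually dispatches to is not covered here. The point is only that `@` triggers a solve of $Bx=v$ rather than a multiplication by a stored $B^{-1}$. The inputs are the rounded values printed earlier, and the arithmetic is in float64, so the residual will differ from the float32 figure above.

```python
import numpy as np

class LazyInverseSketch:
    """Hypothetical lazy inverse: keeps B and solves B x = v on demand."""

    def __init__(self, B):
        self.B = B

    def __matmul__(self, v):
        # Solve the system directly; B^{-1} is never materialized.
        return np.linalg.solve(self.B, v)

A_np = np.array([[0.1081, -0.4376, -0.7697],
                 [-0.1929, -0.3626, -2.8451],
                 [1.4435, 0.4976, 0.6542]])
D_np = np.diag([1., 2., 3.])
T_np = np.array([[3., 1., 0.], [-1., 3., 1.], [0., -1., 3.]])
B_np = A_np @ (D_np - T_np) + 1e-6 * np.eye(3)
v = np.array([0.0754, -1.0767, 0.1269])

B_inv = LazyInverseSketch(B_np)
x = B_inv @ v
residual = np.linalg.norm(B_np @ x - v)
print(f"{residual:1.3e}")
```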


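The log determinant operation applies to PSD Linear Operators. One such operator would be $C=(D + T)(D + T)^{*}$, which is symmetric and diagonally dominant. The cell below applies logdet to the simpler diagonal operator $D$, which is also PSD, so the result is $\log\det D = \log(1 \cdot 2 \cdot 3) = \log 6 \approx 1.7918$.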

In [10]:
co.logdet(D)  # Finish adding the implementation for a PSD op

tensor(1.7918)
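As a cross-check of the value above, and as a sketch of what the log determinant of $C=(D + T)(D + T)^{*}$ would be, the following NumPy snippet works with dense arrays. It does not use CoLA; it relies only on `np.linalg.slogdet` and on the identity $\log\det C = 2\log|\det(D+T)|$, which holds because $\det(MM^{*}) = |\det M|^2$.

```python
import numpy as np

D_np = np.diag([1., 2., 3.])
T_np = np.array([[3., 1., 0.], [-1., 3., 1.], [0., -1., 3.]])

# log det of a diagonal matrix is the sum of the logs of its entries.
logdet_D = np.log(np.diag(D_np)).sum()

# Build C = (D + T)(D + T)^T densely; it is symmetric PSD by construction.
M = D_np + T_np
C_np = M @ M.T
sign_C, logdet_C = np.linalg.slogdet(C_np)
_, logabsdet_M = np.linalg.slogdet(M)

print(f"logdet(D) = {logdet_D:.4f}")
print(f"logdet(C) = {logdet_C:.4f}")
print(np.isclose(logdet_C, 2 * logabsdet_M))
```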