In [1]:
import torch

import contextlib
import itertools as it
import functools as ft
from io import StringIO
from IPython.display import display, HTML
from ipywidgets import widgets
import torchviz


#def 

In [2]:
@contextlib.contextmanager
def html_tag(t, *, dest, style=None):
    """Wrap whatever is written to `dest` inside the block in an HTML element.

    The opening tag is written on entry and the closing tag on normal exit.

    Args:
        t: tag name, e.g. "td".
        dest: object with a `write` method that receives the markup.
        style: optional inline CSS, emitted as a `style="..."` attribute.
    """
    attribute = f'style="{style}"' if style is not None else ""
    # The separating space is emitted even without a style, giving e.g. "<tr >".
    opening, closing = f"<{t} {attribute}>", f"</{t}>"
    dest.write(opening)
    yield
    dest.write(closing)

def table_html(table, styles="", default_td_style="box-shadow: 4px 2px 5px grey;"):
    """Return the HTML markup of a table whose cells carry inline styles.

    Args:
        table: iterable of rows; each row is an iterable of strings.
        styles: either a single CSS string used for every cell, or an iterable
            with one entry per row. A row entry is itself either a CSS string
            (used for every cell of that row) or an iterable of per-cell CSS
            strings.
        default_td_style: CSS prepended to the style of every cell.

    Returns:
        The markup as a string. Rows and cells are paired with their styles
        using `zip`, so output stops at the shorter of the two.
    """
    buffer = StringIO()
    open_tag = ft.partial(html_tag, dest=buffer)
    row_styles = it.repeat(styles) if isinstance(styles, str) else styles
    table_css = "table-layout: fixed; border-spacing: 3px; border-collapse: separate"
    with open_tag("table", style=table_css):
        with open_tag("tbody"):
            for cells, row_css in zip(table, row_styles):
                # A plain string means "same style for the whole row".
                cell_styles = it.repeat(row_css) if isinstance(row_css, str) else row_css
                with open_tag("tr"):
                    for text, cell_css in zip(cells, cell_styles):
                        with open_tag("td", style=f"{default_td_style} {cell_css}"):
                            buffer.write(text)

    return buffer.getvalue()

# Palette of evenly spaced pastel hues used to color-code tensor rows.
n_colors = 10
# Derive each hue from its index so the palette always has exactly `n_colors`
# entries, even when `n_colors` does not divide 360 (a `range` step of
# `360 // n_colors` would then yield extra colors, or a zero step).
colors = [f"hsl({(360 * k) // n_colors}, 70%, 80%)" for k in range(n_colors)]
colors_styles = [f"background-color: {color}" for color in colors]

def get_style_from_color_indices(color_list):
    match color_list:
        case []:
            return "background-color: white"
        case [c]:
            c_= colors[c  % len(colors)]
            return f"background-color: {c_}"
        case [*cs]:
            n_colors = len(color_list)
            percentages = [f"{k/n_colors:.0%}" for k in range(n_colors +1)]
            s = StringIO()
            s.write("background: linear-gradient( to bottom")
            for c, p0, p1 in zip(cs, percentages[:-1], percentages[1:]):
                c_ = colors[c  % len(colors)]
                s.write(f", {c_} {p0} {p1}")
            s.write(")")
            return s.getvalue()


def display_table(table, styles=""):
    """Render `table` as a styled HTML table in the notebook output.

    Args:
        table: iterable of rows, each an iterable of strings.
        styles: forwarded unchanged to `table_html`: a single CSS string for
            every cell, or an iterable with one style entry per row.
    """
    # The default has to be a string: `table_html` zips the rows with `styles`,
    # so an empty tuple would silently render a table with no rows at all.
    display(HTML(table_html(table, styles)))

def display_tensor(tensor):
    """Display a tensor as an HTML table with one background color per row.

    0-dim and 1-dim tensors are shown as a single row. Floating-point values
    are formatted with one decimal; other dtypes use `str`.

    Args:
        tensor: a torch tensor with at most 2 dimensions.
    """
    if tensor.ndim == 0:
        tensor = tensor[None, None]
    elif tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)
    if torch.is_floating_point(tensor):
        cells = [[f"{i.item():.01f}" for i in r] for r in tensor]
    else:
        cells = [[str(i.item()) for i in r] for r in tensor]
    # Cycle through the palette: the table builder zips rows with styles, so
    # passing the finite palette directly would drop every row past its length.
    display_table(cells, it.cycle(colors_styles))

def display_storage(tensor):
    """Display the storage behind `tensor` as a one-row table.

    Every storage slot is colored with the first-dimension index of each
    tensor element that maps onto it, so slots can be matched to rows. Slots
    not referenced by the tensor stay white; slots referenced by several
    elements get one color band per element.

    Args:
        tensor: a torch tensor; a 0-dim tensor is shown without coloring.
    """
    storage = tensor.storage()
    storage_colors = [[] for _ in storage]
    # A view may start in the middle of its storage; without the offset the
    # colors would land on the wrong slots (e.g. for `u[1:]`).
    offset = tensor.storage_offset()
    strides = tensor.stride()
    for index in it.product(*[range(i) for i in tensor.shape]):
        if not len(index): continue # 0-dim tensor
        first_index, *_ = index
        storage_index = offset + sum(i * s for i, s in zip(index, strides))
        storage_colors[storage_index].append(first_index)
    styles = [[get_style_from_color_indices(c) for c in storage_colors]]
    if torch.is_floating_point(tensor):
        cells = [[f"{i:.01f}" for i in storage]]
    else:
        cells = [[str(i) for i in storage]]
    return display_table(cells, styles)

    


class DisplayColumns:
    """Collect notebook output into columns shown side by side.

    Each `column()` block runs inside its own `widgets.Output`; `display()`
    then shows all of them in a single `widgets.HBox`.
    """

    def __init__(self):
        # Output widgets, in the order the columns were opened.
        self.outputs = []

    @contextlib.contextmanager
    def column(self):
        """Open a new column; the body of the `with` block runs inside it."""
        panel = widgets.Output()
        self.outputs.append(panel)
        with panel:
            yield

    def display(self):
        """Show every column created so far in one horizontal box."""
        row = widgets.HBox(self.outputs)
        display(row)

#display_storage(w)

# Tensor Structure

In [3]:
def print_characterestics(tensor):
    """Print the stride, the shape and the contiguity flag of `tensor`."""
    report = {
        "stride: ": tensor.stride(),
        "shape: ": tuple(tensor.shape),
        "contiguous? ": tensor.is_contiguous(),
    }
    for label, value in report.items():
        print(label, value)

def display_tensor_and_storage(tensor):
    """Show a tensor and its underlying storage in two side-by-side columns.

    A 1-dim tensor is first turned into a single-row 2-dim tensor, and that
    2-dim version is what both columns display.
    """
    layout = DisplayColumns()
    if tensor.ndim == 1:
        tensor = tensor[None, :]
    print()
    panels = (("tensor:", display_tensor), ("storage:", display_storage))
    for caption, render in panels:
        with layout.column():
            print(caption)
            render(tensor)
    layout.display()

In [5]:
# Base example: a 1-D uint8 tensor holding 0..11; later cells derive their tensors from it.
t = torch.arange(12, dtype=torch.uint8)
print_characterestics(t)
display_tensor_and_storage(t)

stride:  (1,)
shape:  (12,)
contiguous?  True



HBox(children=(Output(), Output()))

In [6]:
# Arrange the 12 elements of t as 3 rows of 4.
u = t.reshape((3, 4))
print_characterestics(u)
display_tensor_and_storage(u)

stride:  (4, 1)
shape:  (3, 4)
contiguous?  True



HBox(children=(Output(), Output()))

In [7]:
# Transpose of u: the printed stride becomes (1, 4) and the result is reported as non-contiguous.
u_transpose = u.T
print_characterestics(u_transpose)
display_tensor_and_storage(u_transpose)

stride:  (1, 4)
shape:  (4, 3)
contiguous?  False



HBox(children=(Output(), Output()))

## Slicing

In [8]:
# Keep the first two rows of u (all columns).
u_sliced = u[:2, :]
print_characterestics(u_sliced)
display_tensor_and_storage(u_sliced)

stride:  (4, 1)
shape:  (2, 4)
contiguous?  True



HBox(children=(Output(), Output()))

In [9]:
# Keep the first two columns of every row of u: the stride stays (4, 1) but
# the result is reported as non-contiguous.
u_sliced2 = u[:, :2]
print_characterestics(u_sliced2)
display_tensor_and_storage(u_sliced2)

stride:  (4, 1)
shape:  (3, 2)
contiguous?  False



HBox(children=(Output(), Output()))

In [10]:
# Take every third element of t; the printed stride is (3,).
t_skipping = t[::3]
print_characterestics(t_skipping)
display_tensor_and_storage(t_skipping)

stride:  (3,)
shape:  (4,)
contiguous?  False



HBox(children=(Output(), Output()))

## Broadcasting

In [11]:
# Indexing with None inserts a new dimension of size 1 (here: shape (12,) -> (12, 1)).
v = t[:, None]
print_characterestics(v)
display_tensor_and_storage(v)

stride:  (1, 1)
shape:  (12, 1)
contiguous?  True



HBox(children=(Output(), Output()))

In [12]:
# Expand the size-1 dimension of v to 5 columns; the printed stride along that dimension is 0.
w = v.expand(12, 5)
print_characterestics(w)
display_tensor_and_storage(w)

stride:  (1, 0)
shape:  (12, 5)
contiguous?  False



HBox(children=(Output(), Output()))

In [13]:
# Transpose of the expanded tensor: the zero stride moves to the first dimension.
w_transpose = w.T
print_characterestics(w_transpose)
display_tensor_and_storage(w_transpose)

stride:  (0, 1)
shape:  (5, 12)
contiguous?  False



HBox(children=(Output(), Output()))

In [14]:
# Flatten the two-row slice of u into a single dimension of 8 elements.
u_sliced_flat = u_sliced.flatten()
print_characterestics(u_sliced_flat)
display_tensor_and_storage(u_sliced_flat)

stride:  (1,)
shape:  (8,)
contiguous?  True



HBox(children=(Output(), Output()))

In [15]:
# Flatten the transposed broadcast tensor: the result has 60 elements and is reported as contiguous.
w_transpose_flat = w_transpose.flatten()
print_characterestics(w_transpose_flat)
display_tensor_and_storage(w_transpose_flat)

stride:  (1,)
shape:  (60,)
contiguous?  True



HBox(children=(Output(), Output()))

## Masking and indexing

Fancy indexing!



### Indexing with an array

Remember `t_skipping`


In [16]:
# Show t_skipping again (defined earlier as t[::3]).
print_characterestics(t_skipping)
display_tensor_and_storage(t_skipping)

stride:  (3,)
shape:  (4,)
contiguous?  False



HBox(children=(Output(), Output()))


Say I want elements 0, 2, and 3  of `t_skipping`

I can just write

In [17]:
# Index with a list of positions to pick elements 0, 2 and 3.
indexed = t_skipping[[0, 2, 3]]
display_tensor_and_storage(indexed)




HBox(children=(Output(), Output()))

This operation makes a copy into a *new storage* (it couldn't work by just changing the strides / shape).

In a more realistic case, say I want to do *message passing*: I have a matrix of node features and a list of edges.

In [18]:
# Draw from a dedicated, seeded generator so that the random node features and
# edges (and every output derived from them) are identical on each run.
rng = torch.Generator().manual_seed(0)
nodes = torch.randn(5, 7, generator=rng)  # one row per node
edges = torch.randint(5, size=(4, 2), generator=rng)  # pairs of row indices into nodes
print("nodes: ")
display_tensor_and_storage(nodes)
print()
print("edges: ")
display_tensor_and_storage(edges)

nodes: 



HBox(children=(Output(), Output()))


edges: 



HBox(children=(Output(), Output()))

In [19]:
# For every edge, gather the row of nodes whose index is in column 1 of edges.
tensor_edge_j = nodes[edges[:, 1]]
display_tensor_and_storage(tensor_edge_j)




HBox(children=(Output(), Output()))

Now we need to send it through an MLP, and aggregate into `edge[0]`… which we will see how to do later

### Masking (indexing with a boolean array)

Say I now want my node labels to be bounded above by $0.5$, so I want to set everything bigger than $0.5$ to $0.5$.

I can easily make a boolean array indicating where the node label values are above $0.5$:

In [20]:
threshold = 0.5
# Element-wise comparison: the mask is True wherever the node value exceeds the threshold.
threshold_mask = nodes > threshold
print("nodes: ")
display_tensor_and_storage(nodes)
print()
print("threshold mask: ")
display_tensor_and_storage(threshold_mask)

nodes: 



HBox(children=(Output(), Output()))


threshold mask: 



HBox(children=(Output(), Output()))

If I want to get all the values that are above $0.5$, I can use this array to index!

In [21]:
# Index nodes with the boolean mask to select the entries above the threshold.
above_threshold_values = nodes[threshold_mask]

display_tensor_and_storage(above_threshold_values)




HBox(children=(Output(), Output()))

# Operations

## Broadcasting


In [22]:
# A 0-dim tensor: stride and shape both print as empty tuples. Only its storage is displayed here.
three = torch.tensor(3)
print_characterestics(three)
display_storage(three)

stride:  ()
shape:  ()
contiguous?  True


0
3


What happens if I call `torch.mul(torch.tensor(3), u)`?

Broadcasting! First, `three` gets padded with dimensions of length `1` so that it has the same number of dimensions as `u`.

In [23]:
# Add two dimensions of size 1 so that three has as many dimensions as u.
three_padded = three[None, None]
print_characterestics(three_padded)
display_tensor_and_storage(three_padded)

stride:  (1, 1)
shape:  (1, 1)
contiguous?  True



HBox(children=(Output(), Output()))

Then `three` gets "broadcasted" (replicated) across the dimensions equal to 1, so that it matches the dimensions of `u`

In [24]:
# Expand to the shape of u; the printed stride is (0, 0).
three_broadcasted = three_padded.expand(u.shape)
print_characterestics(three_broadcasted)
display_tensor_and_storage(three_broadcasted)

stride:  (0, 0)
shape:  (3, 4)
contiguous?  False



HBox(children=(Output(), Output()))

Now that `three` and `u` have the same shape, the multiplication can be applied element by element.

**Note:** This would work even if `three` was not a 0-dim tensor. In fact, this is what happens when you add a vector and a matrix, for example. In the general case, the rules are as follows.

*example: do an operation on a tensor of shape `[5, 12, 1, 5]` and one of shape `[3, 5]`*  

1. If one of the two tensors has fewer dimensions, pad this one at the start with dimensions of size `1`. *we now have `[5, 12, 1, 5]` and `[1, 1, 3, 5]`*

2. For all dimensions

   - if they are the same, do nothing
   - if they are different, but one of them is `1`, broadcast that one to match the other
   - otherwise, fail
   
   *We now have `[5, 12, 3, 5]` and `[5, 12, 3, 5]`*

## Dispatch

Somewhere in the torch source code is a file called `native_functions.yaml` containing

```yaml
...

- func: mul.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: mul.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: mul_sparse
    SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr
    MkldnnCPU: mkldnn_mul
    ZeroTensor: mul_zerotensor
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor
  tags: [core, pointwise]

...

- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: mul_out
    MPS: mul_out_mps
    SparseCPU: mul_out_sparse_cpu
    SparseCUDA: mul_out_sparse_cuda
    SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr
    MkldnnCPU: mkldnn_mul_out
  tags: pointwise
  # For C++ only, until we have conversion from C++ numbers to Tensor

```

which basically tells `torch` that when calling `torch.mul` on `CPU` or `CUDA`, it should call the `mul_out` function

## GPU and asynchronicity

### What is a CUDA Tensor?

- A tensor whose **storage** lives in gpu memory
- The stride / shape are still in CPU memory! No need to access the GPU to read them.

In [30]:
# Copy t into pinned (page-locked) host memory; see torch.Tensor.pin_memory.
# NOTE(review): this cell's execution count (30) is higher than the next
# cell's (28) -- re-run the notebook top to bottom before sharing.
t_pinned = t.pin_memory()


In [28]:
# Move t to the first CUDA device and check which device its storage reports
# (this cell needs a CUDA-capable machine to run).
t_gpu = t.to("cuda:0")
t_gpu.storage().device

device(type='cuda', index=0)