In [37]:
# Einstein-notation examples with einops and einx

import torch
from einops import rearrange, einsum
import einx

In [4]:
# Seed the global RNG so the random image batch is the same on every run.
torch.manual_seed(42)

# Random batch laid out as (batch, height, width, channel).
images = torch.randn(64, 128, 128, 3)
# Ten evenly spaced dimming factors covering 0.0 through 1.0.
dim_by = torch.linspace(0.0, 1.0, 10)

print(images.shape, dim_by.shape, sep="\n")

torch.Size([64, 128, 128, 3])
torch.Size([10])


In [5]:
# Pad the 1-D factor vector with singleton axes so that it lines up with a
# 5-D (batch, dim_value, height, width, channel) tensor for broadcasting.
dim_value = rearrange(
    dim_by,
    "dim_value -> 1 dim_value 1 1 1",
)
print(dim_value.shape)

torch.Size([1, 10, 1, 1, 1])


In [6]:
# Insert a singleton axis right after the batch axis; it is the slot that the
# dimming factors will broadcast into.
images_rearr = rearrange(
    images,
    "b height width channel -> b 1 height width channel",
)
print(images_rearr.shape)

torch.Size([64, 1, 128, 128, 3])


In [7]:
# `*` on tensors is element-wise multiplication with broadcasting (the
# operator form of torch.mul): every singleton axis of one operand is
# stretched to the size of the matching axis of the other, so each image is
# multiplied by each dimming factor.
dimmed_images = torch.mul(images_rearr, dim_value)
print(dimmed_images.shape)

torch.Size([64, 10, 128, 128, 3])


In [10]:
# Same computation in a single einops call. Every input axis also appears in
# the output pattern, so nothing is summed away: the result holds each image
# scaled by each dimming factor along the new `dim_value` axis.
dimmed_images = einsum(
    images,
    dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel",
)
print(dimmed_images.shape)

torch.Size([64, 10, 128, 128, 3])


In [24]:
# Input batch in (batch, height, width, channel) layout.
channels_last = torch.randn(64, 32, 32, 3)
# Square matrix acting on the flattened pixel axis (height * width entries).
B = torch.randn(32*32, 32*32)

# Collapse height and width into a single pixel axis. `view` reinterprets the
# existing storage with a new shape (no copy); the -1 entry is inferred from
# the remaining element count.
channels_last_flat = channels_last.view(
    channels_last.size(0), -1, channels_last.size(3)
)
print(channels_last_flat.shape)

torch.Size([64, 1024, 3])
torch.Size([64, 3, 1024])


In [25]:
# Plain PyTorch: swap the flattened-pixel axis and the channel axis, giving
# (batch, channel, pixel).
channels_first_flat = channels_last_flat.transpose(1, 2)
print(channels_first_flat.shape)

# Read the spatial size from the tensor itself instead of repeating the
# literal 32, so later cells that use `height` / `width` stay correct if the
# shape of `channels_last` is ever changed.
height, width = channels_last.shape[1:3]

# einops: the same layout change in one step, straight from the 4-D tensor.
# Grouping `(height width)` in the output pattern flattens the two spatial
# axes into one pixel axis.
channels_first = rearrange(
    channels_last,
    "batch height width channel -> batch channel (height width)",
)
print(channels_first.shape)

torch.Size([64, 3, 1024])
torch.Size([64, 3, 1024])


In [30]:
# Plain PyTorch: batched matrix product over the trailing pixel axis. B is
# transposed so that its second axis is the one contracted with the pixels.
channels_first_flat_transformed = torch.matmul(channels_first_flat, B.T)
print(channels_first_flat_transformed.shape)

# einops: the contraction is spelled out by name. `pixel_in` appears in both
# inputs but not in the output, so it is the axis that gets summed over.
channels_first_transformed = einsum(
    channels_first,
    B,
    "batch channel pixel_in, pixel_out pixel_in -> batch channel pixel_out",
)
print(channels_first_transformed.shape)

torch.Size([64, 3, 1024])
torch.Size([64, 3, 1024])


In [31]:
# Swap axes 1 and 2 again so the channel axis is last: (batch, pixel, channel).
channels_last_flat_transformed = torch.transpose(channels_first_flat_transformed, 1, 2)
print(channels_last_flat_transformed.shape)

torch.Size([64, 1024, 3])


In [32]:
# Plain PyTorch: unfold the flattened pixel axis back into the original
# 4-D shape. `*channels_last.shape` unpacks the size into separate arguments.
# The result is kept under its own name so it is not lost when the einops
# version is assigned below.
channels_last_transformed_torch = channels_last_flat_transformed.view(*channels_last.shape)
print(channels_last_transformed_torch.shape)

# einops: split the pixel axis back into (height, width) and move the channel
# axis to the end in a single step. The axis lengths must be passed because
# the pattern alone cannot tell how to factor the grouped axis.
channels_last_transformed = rearrange(
    channels_first_transformed,
    "batch channel (height width) -> batch height width channel",
    height=height,
    width=width,
)
print(channels_last_transformed.shape)

# The two routes are meant to compute the same thing; check that they do.
# Tolerances are loose because the matmul and einsum paths may accumulate the
# float32 sums in a different order.
torch.testing.assert_close(
    channels_last_transformed,
    channels_last_transformed_torch,
    rtol=1e-4,
    atol=1e-3,
)

torch.Size([64, 32, 32, 3])
torch.Size([64, 32, 32, 3])


In [39]:
# einx: the whole pipeline (flatten, contract with B, unflatten) as one named
# contraction applied directly to the channels-last tensor. The two adjacent
# string literals are joined by Python into a single pattern. `col_in` and
# `col_out` provide the lengths used to split B's two grouped axes.
channels_last_transformed = einx.dot(
    "batch row_in col_in channel, (row_out col_out) (row_in col_in)"
    "-> batch row_out col_out channel",
    channels_last,
    B,
    col_in=width,
    col_out=width,
)
print(channels_last_transformed.shape)

torch.Size([64, 32, 32, 3])
