In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from gromo.tools import *
import tensorly as tl
tl.set_backend('pytorch')

In [2]:
def assert_close(a, b, msg="", atol=1e-8, rtol=1e-5):
    assert a.shape == b.shape, f"{msg} (shape mismatch : {a.shape} != {b.shape})"
    assert torch.allclose(a, b, atol=atol, rtol=rtol), f"{msg} (||.||_inf = {torch.max(torch.abs(a - b)):2e}, ||.||_0 = {torch.sum(torch.abs(a - b)) / a.numel():2e} %, ||.||_2 = {torch.norm(a - b) / torch.norm(b):2e})"

# Lemma

Conv o AvgPool = AvgPool o Conv

In [3]:
# Hyperparameters
cin = 4
cout = 5
ĥd = 3
wd = 2
h = 22
w = 31
px = 5
py = 6
n = 10

# Create a random image
x = torch.randn(n, cin, h, w)

# Create the layers
conv = nn.Conv2d(cin, cout, wd)
pool = nn.AvgPool2d((px, py), stride=(1, 1))

# Forward pass
y1 = conv(pool(x))
y2 = pool(conv(x))

# Check the shapes
assert y1.shape == y2.shape, f"Shapes are different: {y1.shape=} != {y2.shape=}"

assert_close(y1, y2, "There is no commutativity", atol=1e-6)

# Conv to conv

## Define the data

In [4]:
# Hyperparameters
c0 = 5
c1 = 6
c2 = 7

hd1 = 3
wd1 = 5
hd2 = 5
wd2 = 3

h0 = 9
w0 = 10

# h0 = 21
# w0 = 30

n = 2

# Create a random image

x = torch.randn(n, c0, h0, w0)
pre_conv = nn.Conv2d(c0, c0, (3, 3), padding=1, bias=False)
x = pre_conv(x) # add unisotropicity

# Create the layers
conv1 = nn.Conv2d(c0, c1, (hd1, wd1), bias=False)
assert conv1.weight.shape == (c1, c0, hd1, wd1), f"Wrong shape: {conv1.weight.shape=}"
assert torch.allclose(torch.nn.functional.conv2d(input=x, 
                                                 weight=conv1.weight, 
                                                 bias=conv1.bias,
                                                 stride=conv1.stride, 
                                                 padding=conv1.padding, 
                                                 dilation=conv1.dilation), 
                                                 conv1(x)), "Error in the convolution"

conv2 = nn.Conv2d(c1, c2, (hd2, wd2), bias=False)

# Forward pass
print(f"Start with {x.shape=}")
print(f"Apply {conv1=} (output shape: {conv1(x).shape})")
print(f"Apply {conv2=} (output shape: {conv2(conv1(x)).shape})")

h1, w1 = conv1(x).shape[-2:]
h2, w2 = conv2(conv1(x)).shape[-2:]


y_th = conv2(conv1(x))

Start with x.shape=torch.Size([2, 5, 9, 10])
Apply conv1=Conv2d(5, 6, kernel_size=(3, 5), stride=(1, 1), bias=False) (output shape: torch.Size([2, 6, 7, 6]))
Apply conv2=Conv2d(6, 7, kernel_size=(5, 3), stride=(1, 1), bias=False) (output shape: torch.Size([2, 7, 3, 4]))


In [5]:
x_unfolded = torch.nn.functional.unfold(x, (hd1, wd1), padding=conv1.padding, stride=conv1.stride, dilation=conv1.dilation)
assert x_unfolded.shape == (n, c0 * hd1 * wd1, h1 * w1)

$\hat{B} \in (n, C[-1] dd, HW)$

In [6]:
x_unfolded_flat = x_unfolded.permute(0, 2, 1).flatten(end_dim=1)
assert x_unfolded_flat.shape == (n * h1 * w1, c0 * hd1 * wd1), f"{x_unfolded_flat.shape=}"

In [7]:
theta1_flat = conv1.weight.flatten(start_dim=1)
assert theta1_flat.shape == (c1, c0 * hd1 * wd1)

In [8]:
y_th_half = conv1(x)
assert y_th_half.shape == (n, c1, h1, w1)

y_th_half_flat = y_th_half.flatten(start_dim=2)
assert y_th_half_flat.shape == (n, c1, h1 * w1)

In [9]:
y_f_half1 = torch.einsum("iam, ca -> icm", x_unfolded, theta1_flat)
assert y_f_half1.shape == (n, c1, h1 * w1)

In [10]:
assert_close(y_f_half1, y_th_half_flat, "Error in the first convolution", atol=1e-6)

In [11]:
y_f_half_tl = tl.tenalg.mode_dot(x_unfolded, theta1_flat, mode=1)
assert y_f_half_tl.shape == (n, c1, h1 * w1)
assert_close(y_f_half_tl, y_th_half_flat, "Error in the mode_dot", atol=1e-6)

In [12]:
y_f_half2 = x_unfolded_flat @ theta1_flat.T
assert y_f_half2.shape == (n * h1 * w1, c1)
y_half2 = y_f_half2.reshape(n, h1 * w1, c1).permute(0, 2, 1)
assert_close(y_half2, y_th_half_flat, "Error in the matrix multiplication", atol=1e-6)

$Conv_{\Theta_1}(B) \sim \hat{B}^{iam} \Theta_1^{ca} = \hat{B} \times_1 \Theta_1 \sim \hat{B}_F \times \Theta_1^T$



## Optimal update

In [13]:
s = x_unfolded_flat.T @ x_unfolded_flat
assert s.shape == (c0 * hd1 * wd1, c0 * hd1 * wd1)


y_th_half_flat_2 = y_th_half_flat.permute(0, 2, 1).flatten(end_dim=1)
assert y_th_half_flat_2.shape == (n * h1 * w1, c1)

m = y_th_half_flat_2.T @ x_unfolded_flat
assert m.shape == (c1, c0 * hd1 * wd1)

In [14]:
s1 = torch.einsum("iam, ibm -> ab", x_unfolded, x_unfolded)
assert s1.shape == (c0 * hd1 * wd1, c0 * hd1 * wd1)

n1 = torch.einsum("iam, icm -> ca", x_unfolded, y_th_half_flat)
assert n1.shape == (c1, c0 * hd1 * wd1)

In [15]:
assert_close(n1, m, "Error in the einsum", atol=1e-6)
assert_close(s1, s, "Error in the einsum", atol=1e-6)

In [16]:
w_star = torch.linalg.solve(s, m.T)
assert w_star.shape == (c0 * hd1 * wd1, c1)
w_star_r = w_star.reshape(c0, hd1, wd1, c1).permute(3, 0, 1, 2)

conv_star = nn.Conv2d(c0, c1, (hd1, wd1), bias=False)
conv_star.weight.data = w_star_r
torch.norm(conv_star(x) - conv1(x))

tensor(1.4402e-05, grad_fn=<LinalgVectorNormBackward0>)

## Double convolution as a linear transformation

We show that $(Conv_1 \circ Conv_2)(X) = X \times \Theta$ with the correct reshape.

In [17]:
# theta = K1 K2 = A B
theta = torch.einsum("cdkl, ecxy -> dklxye", conv1.weight, conv2.weight)
assert theta.shape == (c0, hd1, wd1, hd2, wd2, c2)

theta = theta.flatten(3, 4).flatten(0, 2)
assert theta.shape == (c0 * hd1 * wd1, hd2 * wd2, c2), f"{theta.shape=}"

$\Theta^{dklxye} = W^{cdkl} W[+1]_{ecxy} \in (C[-1] dd, d[+1]d[+1], C[+1])$ 

In [18]:
t1 = compute_mask_tensor_t((h0, w0), conv1)
assert t1.shape == (h1 * w1, hd1 * wd1, h0 * w0)
del t1

t2 = compute_mask_tensor_t((h1, w1), conv2)
assert t2.shape == (h2 * w2, hd2 * wd2, h1 * w1)

tt2 = torch.einsum("mbl, mdk -> bldk", t2, t2)
assert tt2.shape == (hd2 * wd2, h1 * w1, hd2 * wd2, h1 * w1)

old_tt2 = torch.einsum('xkl, xkm->lm', t2, t2)
assert old_tt2.shape == (h1 * w1, h1 * w1)

$T \in (H[+1]W[+1], d[+1]d[+1], HW)$

$(TT)^{bldk} = T^{mbl} T_{mdk} \in (d[+1]d[+1], HW, d[+1]d[+1], HW)$

In [19]:
# T X
t2x = torch.einsum("ial, mbl -> imab", x_unfolded, t2)
assert t2x.shape == (n, h2 * w2, c0 * hd1 * wd1, hd2 * wd2)

In [20]:
t2x_bis = torch.zeros((n, h2 * w2, c0 * hd1 * wd1, hd2 * wd2))

assert t2.shape == (h2 * w2, hd2 * wd2, h1 * w1)
assert x_unfolded.shape == (n, c0 * hd1 * wd1, h1 * w1)

for i in range(n):
    for j in range(h2 * w2):
        t2x_bis[i, j] = x_unfolded[i].detach() @ t2[j].T

assert torch.allclose(t2x, t2x_bis, atol=1e-6)

$X^{imab} = \hat{B}^{ial} T_{mbl} \in (n, H[+1]W[+1], C[-1]dd, d[+1]d[+1])$

$X[i, j] = \hat{B}[i] \times T[j]^T$

In [21]:
y_f_th = y_th.permute(0, 2, 3, 1).flatten(1, 2)
assert y_f_th.shape == (n, h2 * w2, c2), f"{y_f_th.shape=}"

$Y \in (n, H[+1]W[+1], C[+1])$

### With einsum

In [22]:
x_theta = torch.einsum("imab, abc -> imc", t2x, theta)
assert x_theta.shape == (n, h2 * w2, c2)

assert_close(y_f_th, x_theta, "The formula is not correct", atol=1e-6)

In [23]:
x_theta_2 = torch.einsum("ial, mbl, abc -> imc", x_unfolded, t2, theta)
assert x_theta_2.shape == (n, h2 * w2, c2)

assert_close(y_f_th, x_theta_2, "The formula is not correct", atol=1e-6)

In [24]:
x_theta_3 = torch.zeros((n, c2, h2 * w2))

assert t2x.shape == (n, h2 * w2, c0 * hd1 * wd1, hd2 * wd2)
assert conv2.weight.shape == (c2, c1, hd2, wd2)
assert conv1.weight.shape == (c1, c0, hd1, wd1)
alpha_flat = conv1.weight.flatten(start_dim=1)
assert alpha_flat.shape == (c1, c0 * hd1 * wd1)




for i in range(n):
    for m in range(c2):
        for j in range(h2 * w2):
            for k in range(c1):
                omega = conv2.weight[m, k].flatten()
                assert omega.shape == (hd2 * wd2,)
                x_theta_3[i, m, j] += omega.T @ t2x[i, j].T @ alpha_flat[k]


assert_close(x_theta_3.permute(0, 2, 1), y_f_th, "The formula is not correct", atol=1e-6)

  x_theta_3[i, m, j] += omega.T @ t2x[i, j].T @ alpha_flat[k]


$Conv_2(Conv_1(X)) \sim X^{imab} \Theta^{abc} = \hat{B}^{ial} T_{mbl} \Theta^{abc}$

### With matrix multiplication

In [25]:
theta_f = theta.flatten(0, 1)

assert theta_f.shape == (c0 * hd1 * wd1 * hd2 * wd2, c2), f"{theta_f.shape=}"
x_f = t2x.flatten(0, 1)
x_f = x_f.flatten(start_dim=1)
assert x_f.shape == (n * h2 * w2, c0 * hd1 * wd1 * hd2 * wd2), f"{x_f.shape=}"

y_f = x_f @ theta_f
assert y_f.shape == (n * h2 * w2, c2)

assert_close(y_f, y_f_th.flatten(0, 1), "The formula is not correct", atol=1e-6)

$X_F \in (n H[+1]W[+1], C[-1]ddd[+1]d[+1])$

$\Theta_F \in (C[-1]ddd[+1]d[+1], C[+1])$

$Y_F \in (n H[+1]W[+1], C[+1])$ 

$Conv_2(Conv_1(X)) \sim X_F \Theta_F$

## Compute S

### From the definition with matrix multiplication

$S := X_F^T X_F \in (C[-1]ddd[+1]d[+1], C[-1]ddd[+1]d[+1])$

In [26]:
s_f = x_f.t() @ x_f
assert s_f.shape == (c0 * hd1 * wd1 * hd2 * wd2, c0 * hd1 * wd1 * hd2 * wd2)

In [29]:
%%timeit
s = torch.einsum("imab, imcd -> abcd", t2x, t2x)
assert s.shape == (c0 * hd1 * wd1, hd2 * wd2, c0 * hd1 * wd1, hd2 * wd2)
s = s.flatten(2, 3).flatten(0, 1)
assert s.shape == (c0 * hd1 * wd1 * hd2 * wd2, c0 * hd1 * wd1 * hd2 * wd2)
assert_close(s, s_f, "The formula is not correct", atol=1e-6)

7.96 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
%%timeit
# WARNING: This is very memory consuming
s2 = torch.einsum("ial, mbl, ick, mdk -> abcd", x_unfolded, t2, x_unfolded, t2)
assert s2.shape == (c0 * hd1 * wd1, hd2 * wd2, c0 * hd1 * wd1, hd2 * wd2)
s2 = s2.flatten(2, 3).flatten(0, 1)
assert s2.shape == (c0 * hd1 * wd1 * hd2 * wd2, c0 * hd1 * wd1 * hd2 * wd2)
assert_close(s2, s_f, "The formula is not correct", atol=1e-5)

64.4 ms ± 4.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


$S^{abcd} = X^{imab} X^{imcd} \in (C[-1]dd, d[+1]d[+1], C[-1]dd, d[+1]d[+1])$

$S^{abcd} = \hat{B}^{ial} T_{mbl} \hat{B}^{ick} T_{mdk} \in (C[-1]dd, d[+1]d[+1], C[-1]dd, d[+1]d[+1])$

In [29]:
assert x_unfolded.shape == (n, c0 * hd1 * wd1, h1 * w1)
assert old_tt2.shape == (h1 * w1, h1 * w1)
old_s = torch.einsum('ial, lm, ibm -> ab', x_unfolded, old_tt2, x_unfolded)
assert old_s.shape == (c0 * hd1 * wd1, c0 * hd1 * wd1)

In [30]:
assert x_unfolded.shape == (n, c0 * hd1 * wd1, h1 * w1)
assert t2.shape == (h2 * w2, hd2 * wd2, h1 * w1)
b_t = torch.einsum("ial, jbl -> ijab", x_unfolded, t2)
assert b_t.shape == (n, h2 * w2, c0 * hd1 * wd1, hd2 * wd2)
olds_s2 = torch.einsum('ijab, ijcb -> ac', b_t, b_t)
assert olds_s2.shape == (c0 * hd1 * wd1, c0 * hd1 * wd1)
assert_close(olds_s2, old_s, "The formula is not correct", atol=1e-5, rtol=1e-2)

In [31]:
assert_close(olds_s2, old_s, "The formula is not correct", atol=1e-5, rtol=1e-2)

In [32]:
old_s3 = torch.einsum('ial, jbl, icl, jbl -> ac', x_unfolded, t2, x_unfolded, t2)
assert old_s3.shape == (c0 * hd1 * wd1, c0 * hd1 * wd1)

assert_close(old_s3, old_s, "The formula is not correct", atol=1e-5)

$B^t_{i, j} := T_j (B^c_i)^T$ or $(B^t)^{ij} = T^{jbl} (B^c)^{ial}$

$S := \sum_{i = 1}^n \sum_{j = 1}^{H[+1]W[+1]} (B^t_{i, j})^T (B^t_{i, j}) \in (C[-1] dd, C[-1] dd) = (B^t)^{ijab} (B^t)^{ijcb}$

$S^{ac} = (B^c)^{ial} T_{jbl} (B^c)^{icl} T_{jbl} \in (C[-1] dd, C[-1] dd)$


## Compute N
(named M in the code)

$N := X_F^T Y_F \in (C[-1]ddd[+1]d[+1], C[+1])$

In [33]:
m_f = x_f.t() @ y_f
assert m_f.shape == (c0 * hd1 * wd1 * hd2 * wd2, c2) 

In [41]:
assert y_f_th.shape == (n, h2 * w2, c2)
assert t2x.shape == (n, h2 * w2, c0 * hd1 * wd1, hd2 * wd2)

In [34]:
m = torch.einsum("ixab, ixc -> abc", t2x, y_f_th)
assert m.shape == (c0 * hd1 * wd1, hd2 * wd2, c2)
m = m.flatten(0, 1)
assert m.shape == (c0 * hd1 * wd1 * hd2 * wd2, c2)
assert_close(m, m_f, "The formula is not correct", atol=1e-4)

In [35]:
m2 = torch.einsum("ial, xbl, ixc -> abc", x_unfolded, t2, y_f_th)
assert m2.shape == (c0 * hd1 * wd1, hd2 * wd2, c2)
m2 = m2.flatten(0, 1)
assert m2.shape == (c0 * hd1 * wd1 * hd2 * wd2, c2)
assert_close(m2, m_f, "The formula is not correct", atol=1e-4)

$N^{abc} = X^{ixab} Y^{ixc} \in (C[-1]dd, d[+1]d[+1], C[+1])$

$N^{abc} = \hat{B}^{ial} T_{xbl} \Theta^{ixc} \in (C[-1]dd, d[+1]d[+1], C[+1])$

In [36]:
assert t2.shape == (h2 * w2, hd2 * wd2, h1 * w1)
assert x_unfolded.shape == (n, c0 * hd1 * wd1, h1 * w1)
assert y_f_th.shape == (n, h2 * w2, c2)

old_n = torch.einsum('xkl, ial, ixc -> ack', t2, x_unfolded, y_f_th)

assert old_n.shape == (c0 * hd1 * wd1, c2, hd2 * wd2)
old_n = old_n.flatten(1, 2)
assert old_n.shape == (c0 * hd1 * wd1, c2 * hd2 * wd2)
# x : h2 * w2
# k: hd2 * wd2
# l: h1 * w1
# i: n
# a: c0 * hd1 * wd1

In [37]:
assert b_t.shape == (n, h2 * w2, c0 * hd1 * wd1, hd2 * wd2)
old_n2 = torch.einsum('ijm, ijak -> amk', y_f_th, b_t)
assert old_n2.shape == (c0 * hd1 * wd1, c2, hd2 * wd2)
assert_close(old_n2.flatten(1, 2), old_n, "The formula is not correct", atol=1e-6)

## Add optimal neurons

### New S, M

In [38]:
sm = s @ m
assert sm.shape == (c0 * hd1 * wd1 * hd2 * wd2, c2)
sm.shape, sm.numel()

(torch.Size([1125, 7]), 7875)

In [39]:
alpha, omega, sigma = compute_optimal_added_parameters(s, m, numerical_threshold=1e-15, statistical_threshold=1e-15)
k = sigma.shape[0]
assert k == min(c0 * hd1 * wd1 * hd2 * wd2, c2)
assert alpha.shape == (c0 * hd1 * wd1 * hd2 * wd2, k)
assert omega.shape == (k, c2)
sigma.shape, sigma

Max difference: 1.91e-06,% of non-zero elements: 2.41%
  warn(


AssertionError: 

In [82]:
def conv_from_alpha_omega(alpha: torch.Tensor, 
                          omega: torch.Tensor, 
                          conv1: torch.nn.Conv2d, 
                          conv2: torch.nn.Conv2d,
                          k: int) -> tuple[torch.nn.Conv2d, torch.nn.Conv2d]:
    """
    Create a convolutional layer from the alpha and omega parameters

    Parameters
    ----------
    alpha: torch.Tensor
        tensor alpha in (c0 * hd1 * wd1 * hd2 * wd2, k)
    omega: torch.Tensor
        tensor omega in (k, c2)
    conv1: torch.nn.Conv2d
        convolutional layer 1
    conv2: torch.nn.Conv2d
        convolutional layer 2

    Returns
    -------
    torch.nn.Conv2d
        convolutional layer
    """
    c1, c0, hd1, wd1 = conv1.weight.shape
    c2, _, hd2, wd2 = conv2.weight.shape
    assert c1 == _, f"{c1=} != {_=}"
    max_k = min(c0 * hd1 * wd1 * hd2 * wd2, c2)
    assert alpha.shape == (c0 * hd1 * wd1 * hd2 * wd2, max_k), f"{alpha.shape=}"
    assert omega.shape == (max_k, c2)
    k = min(max_k, k)

    theta = alpha @ omega
    assert theta.shape == (c0 * hd1 * wd1 * hd2 * wd2, c2)
    theta = theta.reshape((c0 * hd1 * wd1, hd2 * wd2, c2)).permute(0, 2, 1).flatten(1, 2)
    assert theta.shape == (c0 * hd1 * wd1, c2 * hd2 * wd2)

    u, s, v = torch.linalg.svd(theta, full_matrices=False)
    assert (torch.all(s >= 0))
    s = torch.sqrt(s[:k])
    u = u[:, :k] * s
    v = v[:k, :] * s.unsqueeze(1)
    
    new_conv1 = torch.nn.Conv2d(c0, k, (hd1, wd1), bias=False)
    new_conv1.weight.data = u.reshape((k, c0, hd1, wd1))
    new_conv2 = torch.nn.Conv2d(k, c2, (hd2, wd2), bias=False)
    new_conv2.weight.data = v.reshape((c2, k, hd2, wd2))

    return new_conv1, new_conv2

In [83]:
new_conv1, new_conv2 = conv_from_alpha_omega(alpha, omega, conv1, conv2, k)

new_y = new_conv2(new_conv1(x))
torch.norm(new_y - y_th)

tensor(1295.1821, grad_fn=<LinalgVectorNormBackward0>)

### Old S, M

In [78]:
sn_old = old_s @ old_n
assert sn_old.shape == (c0 * hd1 * wd1, c2 * hd2 * wd2)
sn_old.shape, sn_old.numel()

(torch.Size([75, 105]), 7875)

In [79]:
old_alpha, old_omega, sigma = compute_optimal_added_parameters(old_s, old_n, numerical_threshold=1e-15, statistical_threshold=1e-15)
k = sigma.shape[0]
assert k == min(c2 * hd2 * wd2, c0 * hd1 * wd1), f"{k=}"
assert old_alpha.shape == (c0 * hd1 * wd1, k), f"{old_alpha.shape=}"
assert old_omega.shape == (k, c2 * hd2 * wd2)
sigma.shape, sigma

(torch.Size([75]),
 tensor([20.3776, 17.9205, 17.6068, 17.3225, 17.0303, 16.3034, 15.2867, 15.2516,
         13.6711, 12.6789, 12.1093, 11.8605, 11.6597, 11.2274, 10.5018, 10.2002,
          9.9270,  9.7805,  9.5504,  8.8780,  8.5152,  8.1829,  8.0237,  7.9264,
          7.7054,  7.2579,  7.1677,  6.8433,  6.6986,  6.4866,  6.4120,  6.0317,
          5.9407,  5.8310,  5.5817,  5.3840,  5.2577,  5.1800,  5.0048,  4.7238,
          4.6380,  4.3988,  4.3635,  4.1891,  4.0053,  3.8942,  3.7571,  3.6106,
          3.5652,  3.5043,  3.3150,  3.1713,  3.0611,  2.9485,  2.6393,  2.5713,
          2.5231,  2.4216,  2.3493,  2.2053,  2.0556,  1.8755,  1.8494,  1.7431,
          1.6724,  1.6123,  1.5173,  1.3643,  1.2467,  1.2000,  1.0787,  0.9900,
          0.8586,  0.7888,  0.6594], grad_fn=<IndexBackward0>))

In [80]:
def conv_from_alpha_omega_old(alpha: torch.Tensor, 
                              omega: torch.Tensor, 
                              conv1: torch.nn.Conv2d, 
                              conv2: torch.nn.Conv2d,
                              k: int) -> tuple[torch.nn.Conv2d, torch.nn.Conv2d]:
    """
    Create a convolutional layer from the alpha and omega parameters

    Parameters
    ----------
    alpha: torch.Tensor
        tensor alpha in (c0 * hd1 * wd1 * hd2 * wd2, k)
    omega: torch.Tensor
        tensor omega in (k, c2)
    conv1: torch.nn.Conv2d
        convolutional layer 1
    conv2: torch.nn.Conv2d
        convolutional layer 2

    Returns
    -------
    torch.nn.Conv2d
        convolutional layer
    """
    c1, c0, hd1, wd1 = conv1.weight.shape
    c2, _, hd2, wd2 = conv2.weight.shape
    assert c1 == _, f"{c1=} != {_=}"
    max_k = min(c0 * hd1 * wd1, c2 * hd2 * wd2)
    assert alpha.shape == (c0 * hd1 * wd1, max_k), f"{alpha.shape=}"
    assert omega.shape == (max_k, c2 * hd2 * wd2)
    k = min(k, max_k)

    alpha = alpha[:, :k]
    omega = omega[:k, :]

    new_conv1 = torch.nn.Conv2d(c0, k, (hd1, wd1), bias=False)
    new_conv1.weight.data = alpha.reshape((k, c0, hd1, wd1))

    new_conv2 = torch.nn.Conv2d(k, c2, (hd2, wd2), bias=False)
    new_conv2.weight.data = omega.reshape((c2, k, hd2, wd2))

    return new_conv1, new_conv2


In [81]:
new_conv1_old, new_conv2_old = conv_from_alpha_omega_old(old_alpha, old_omega, conv1, conv2, k)

new_y_old = new_conv2_old(new_conv1_old(x))
torch.norm(new_y_old - y_th)

tensor(99.2078, grad_fn=<LinalgVectorNormBackward0>)