In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.

import torch
from torch import Tensor, jit, nn
from torch.linalg import matrix_norm
from torch.nn import functional as F
from torch.optim import SGD

from linodenet.lib import singular_triplet
from linodenet.models.encoders.invertible_layers import (
    LinearContraction,
    iResNetBlock,
    iSequential,
)

# Enable autograd anomaly detection for every cell below. This is a debugging
# aid: PyTorch documents that it adds overhead to forward/backward passes, so
# turn it off again once the experiments in this notebook are stable.
torch.autograd.set_detect_anomaly(True)

## Test simple LinearContraction

In [None]:
# Sizes used throughout the notebook: N rows for the batch X, m input features,
# n output features.
N = 32
m = n = 256

# Draw the inputs before building the model so the random stream is consumed
# in a fixed order: single vector x first, then the batch X (X is not used by
# the cells shown here).
x = torch.randn(m)
X = torch.randn(N, m)
model = jit.script(LinearContraction(m, n))

# Memory footprint in bytes = number of elements * bytes per element, summed
# separately over parameters and over buffers.
mem_params = sum(p.nelement() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
print(f"{(mem_params + mem_bufs) // 1024**2} MiB")

# Serialize the scripted module (it is not loaded back in this cell).
jit.save(model, "LinearContraction.pt")

optim = SGD(model.parameters(), lr=0.5)

# Manual debugging aid: print model.cached_weight before and after calling
# model.reset_cache() to watch the cache being refreshed.

# Initial state: sigma, then the spectral norm (ord=2) of the raw weight and
# of the cached weight, one value per line.
print(
    model.sigma,
    matrix_norm(model.weight, ord=2),
    matrix_norm(model.cached_weight, ord=2),
    sep="\n",
)

In [None]:
# Three SGD steps on the negated output norm, i.e. each step pushes the norm
# of model(x) upwards. After every step the cache is reset and the same three
# diagnostics as above are printed.
for k in range(3):
    model.zero_grad(set_to_none=True)
    # Alternative objective that bypasses forward():
    #   -F.linear(x, model.cached_weight).norm()
    y = model(x).norm().neg()
    y.backward()
    optim.step()
    print(f"{k=} {y.item()} ============ ")
    # NOTE(review): reset_cache() presumably recomputes cached_weight / sigma
    # from the freshly updated weight -- confirm against LinearContraction.
    model.reset_cache()
    print(
        model.sigma,
        matrix_norm(model.weight, ord=2),
        matrix_norm(model.cached_weight, ord=2),
        sep="\n",
    )

# model.reset_cache()
# print(model.sigma)
# print(matrix_norm(model.weight, ord=2))
# print(matrix_norm(model.cached_weight, ord=2))

## Test Sequential

In [None]:
# Two stacked contractions (m -> n -> m), compiled with TorchScript in one go.
model = jit.script(
    nn.Sequential(
        LinearContraction(m, n),
        LinearContraction(n, m),
    )
)

# Bytes held by parameters and buffers (elements * bytes per element).
mem_params = sum(p.nelement() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
print(f"{(mem_params + mem_bufs) // 1024**2} MiB")

# Serialize the scripted stack to disk.
jit.save(model, "sequential.pt")

# Fresh optimizer bound to the new model's parameters.
optim = SGD(model.parameters(), lr=0.5)


def reset_caches(module):
    """Call ``reset_cache()`` on every submodule of ``module`` that defines it.

    Parameters
    ----------
    module
        Object exposing ``modules()``, an iterable over itself and its
        submodules (e.g. an ``nn.Module`` or a scripted module).

    Returns
    -------
    None
    """
    # Walk the module that was passed in, not whatever the notebook-global
    # `model` happens to be bound to when the function is called.
    for submodule in module.modules():
        if hasattr(submodule, "reset_cache"):
            submodule.reset_cache()


def show_params(module):
    """Print diagnostics for every ``LinearContraction`` found inside ``module``.

    For each matching submodule three lines are printed: its ``sigma``, the
    spectral norm (``ord=2``) of ``weight`` and the spectral norm of
    ``cached_weight``.

    Parameters
    ----------
    module
        Object exposing ``modules()``, an iterable over itself and its
        submodules.

    Returns
    -------
    None
    """
    # Walk the module that was passed in, not the notebook-global `model`.
    for submodule in module.modules():
        # Prefer `original_name` when the object has one (the scripted modules
        # in this notebook are matched through it); otherwise fall back to the
        # plain Python class name so eager modules work too.
        name = getattr(submodule, "original_name", None) or type(submodule).__name__
        if name == "LinearContraction":
            print(submodule.sigma)
            print(matrix_norm(submodule.weight, ord=2))
            print(matrix_norm(submodule.cached_weight, ord=2))


# Baseline printout for the freshly scripted Sequential model, taken before
# any optimizer step has been applied to it.
show_params(model)

In [None]:
# Three SGD steps on the negated output norm of the two-layer stack. After
# each step every cache is reset and the per-layer diagnostics are printed.
for k in range(3):
    model.zero_grad(set_to_none=True)
    y = model(x).norm().neg()
    y.backward()
    optim.step()
    print(f"{k=} ============ {y.item()}")
    # Reset first, then show: the printout should reflect the updated weights.
    reset_caches(model)
    show_params(model)
    # model.reset_cache()
    # print(model.sigma)
    # print(matrix_norm(model.weight, ord=2))
    # print(matrix_norm(model.cached_weight, ord=2))

## Test iResNetBlock

In [None]:
def surgery(model):
    """Rebind the layers of ``model.inverse.block`` to those of ``model.block``.

    Layers are paired positionally. For each pair, the attributes ``weight``,
    ``bias``, ``cached_weight``, ``sigma``, ``u`` and ``v`` of the inverse-side
    layer are reassigned to the very same objects held by the forward-side
    layer. Runs under ``torch.no_grad()`` and prints a banner first.

    Parameters
    ----------
    model
        Object with ``block`` and ``inverse.block`` attributes, each exposing
        ``modules()``.

    Returns
    -------
    None
    """
    print("Applying Surgery!!!")
    # Attributes to tie together, in the order they are reassigned.
    tied_attributes = ("weight", "bias", "cached_weight", "sigma", "u", "v")
    with torch.no_grad():
        # [1:] drops the first entry yielded by modules() (the container).
        sources = list(model.block.modules())[1:]
        targets = list(model.inverse.block.modules())[1:]

        for source, target in zip(sources, targets):
            for attribute in tied_attributes:
                setattr(target, attribute, getattr(source, attribute))

In [None]:
# Residual branch: two stacked contractions (m -> n -> m).
inner_model = nn.Sequential(
    LinearContraction(m, n),
    LinearContraction(n, m),
)

# Wrap the branch in an iResNetBlock and compile it with TorchScript.
model = jit.script(iResNetBlock(inner_model))

# Round-trip through disk: everything below operates on the *loaded* module,
# not on the one that was just scripted.
jit.save(model, "iREsNetBlock.pt")
model = jit.load("iREsNetBlock.pt")

# Bytes held by parameters and buffers (elements * bytes per element).
mem_params = sum(p.nelement() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
print(f"{(mem_params + mem_bufs) // 1024**2} MiB")
optim = SGD(model.parameters(), lr=0.5)


def reset_caches(module):
    """Call ``reset_cache()`` on every submodule of ``module`` that defines it.

    Parameters
    ----------
    module
        Object exposing ``modules()``, an iterable over itself and its
        submodules (e.g. an ``nn.Module`` or a scripted module).

    Returns
    -------
    None
    """
    # Iterate the argument rather than the notebook-global `model`, so the
    # helper also works when handed a sub-block.
    for submodule in module.modules():
        if hasattr(submodule, "reset_cache"):
            submodule.reset_cache()


def show_params(module):
    """Print diagnostics for every ``LinearContraction`` found inside ``module``.

    For each matching submodule three lines are printed: its ``sigma``, the
    spectral norm (``ord=2``) of ``weight`` and the spectral norm of
    ``cached_weight``.

    Parameters
    ----------
    module
        Object exposing ``modules()``, an iterable over itself and its
        submodules. Passing a sub-block restricts the printout to that block.

    Returns
    -------
    None
    """
    # Iterate the argument rather than the notebook-global `model`; otherwise
    # calls on different sub-blocks would all print the same full-model dump.
    for submodule in module.modules():
        # Match on `original_name` when present, else on the Python class name.
        name = getattr(submodule, "original_name", False) or getattr(
            submodule.__class__, "__name__", False
        )
        if name == "LinearContraction":
            print(submodule.sigma)
            print(matrix_norm(submodule.weight, ord=2))
            print(matrix_norm(submodule.cached_weight, ord=2))


# Baseline printout for the iResNetBlock as loaded back from disk, taken
# before any optimizer step has been applied to it.
show_params(model)

In [None]:
# Three SGD steps on the negated output norm of the iResNetBlock. After each
# step all caches are reset, then the forward block and the inverse's block
# are reported under separate banners.
for k in range(3):
    model.zero_grad(set_to_none=True)
    y = model(x).norm().neg()
    y.backward()
    optim.step()
    print(f"{k=} ============ {y.item()}")
    reset_caches(model)
    for banner, part in (
        ("~~~~ Encoder Params ~~~~~", model.block),
        ("~~~~ Decoder Params ~~~~~", model.inverse.block),
    ):
        print(banner)
        show_params(part)

## Check that encoder.weight `is` decoder.weight

In [None]:
# Gather the layers on both sides of the block; [1:] drops the first entry
# yielded by modules() (the container itself).
encoder_layers = list(model.block.modules())[1:]
# The decoder side lives under `model.inverse.block`. Reading it from
# `model.block` again would compare the encoder with itself and make the
# identity check below trivially true.
decoder_layers = list(model.inverse.block.modules())[1:]

In [None]:
# Guard against a vacuous pass: zip() stops at the shorter list and yields
# nothing for empty lists, so check the pairing itself first.
assert encoder_layers, "no encoder layers collected; nothing would be checked"
assert len(encoder_layers) == len(decoder_layers), (
    f"layer count mismatch: {len(encoder_layers)} encoder vs "
    f"{len(decoder_layers)} decoder layers"
)

# Every paired layer must hold the *same* tensor objects (identity, not
# equality). NOTE(review): `surgery` also ties `bias`, which is not checked
# here -- confirm whether that omission is intentional.
for idx, (layer, other) in enumerate(zip(encoder_layers, decoder_layers)):
    assert layer.weight is other.weight, f"layer {idx}: weight is not shared"
    assert layer.sigma is other.sigma, f"layer {idx}: sigma is not shared"
    assert (
        layer.cached_weight is other.cached_weight
    ), f"layer {idx}: cached_weight is not shared"
    assert layer.u is other.u, f"layer {idx}: u is not shared"
    assert layer.v is other.v, f"layer {idx}: v is not shared"