In [None]:
import torch
from pytorch_semifield_conv import (
    BroadcastSemifield,
    GenericConv2D,
    QuadraticKernelSpectral2D,
)
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import trange

In [None]:
# Make the project root the working directory, so that relative paths such as
# "./report/figures/..." used further down resolve no matter where the
# notebook was launched from. The root is recognised by its "src/models" folder.
from pathlib import Path
import os

if not Path("./src/models").is_dir():
    # Walk upwards from the current directory and take the nearest ancestor
    # that contains "src/models"; None means no ancestor qualifies.
    project_root = next(
        (
            candidate
            for candidate in Path.cwd().parents
            if (candidate / "src/models").is_dir()
        ),
        None,
    )
    if project_root is None:
        raise FileNotFoundError("Can't find project root")
    os.chdir(project_root)

# Sanity check: whichever branch ran, we must now be at the project root.
assert Path("./src/models").is_dir()

In [None]:
# Side length of the square random input images created further down
# (torch.rand((100, 1, resolution, resolution))).
resolution = 100
# Size argument passed to QuadraticKernelSpectral2D for both the target and
# the learned kernel.
kernel_size = 11

In [None]:
# Operator passed as `conv=` to every GenericConv2D built in the cells below.
# NOTE(review): presumably the max-plus ("tropical max") semifield, i.e. a
# morphological dilation as the variable name suggests; what `.dynamic()`
# selects is not visible from this notebook - confirm against the
# pytorch_semifield_conv documentation.
dilation = BroadcastSemifield.tropical_max().dynamic()

In [None]:
# Pick an interesting-looking kernel: it learns for other seeds as well
# The seed fixes the random initialisation, so the same target is produced on
# every run.
torch.manual_seed(7)
# Ground-truth kernel whose output the learned kernel is fitted to. Gradients
# are switched off on it, so only the learned kernel is ever optimised.
# NOTE(review): the {"var": ..., "theta": ...} options are interpreted by
# QuadraticKernelSpectral2D; presumably they select how the parameters are
# initialised - confirm against the library.
target_kernel = QuadraticKernelSpectral2D(
    1, 1, kernel_size, {"var": "skewed", "theta": "uniform"}
).requires_grad_(False)
# Print the (name, parameter) pairs held in `.covs`, one per line.
print(*target_kernel.covs.named_parameters(), sep="\n")

In [None]:
# Kernel to be trained, built with the library's default initialisation (no
# init options given). The training cell further down re-creates it with the
# same seed, so this instance only feeds the "initial" plots and loss.
torch.manual_seed(0)
learned_kernel = QuadraticKernelSpectral2D(1, 1, kernel_size)
# Print the (name, parameter) pairs held in `.covs`, one per line.
print(*learned_kernel.covs.named_parameters(), sep="\n")

In [None]:
# Reproducible random input: 100 single-channel images of
# resolution x resolution values drawn uniformly from [0, 1).
torch.manual_seed(0)
inp = torch.rand(100, 1, resolution, resolution)

# Reference output: the input passed through the frozen target kernel.
target_conv = GenericConv2D(kernel=target_kernel, conv=dilation)
target_out = target_conv(inp)

In [None]:
# Output of the still-untrained kernel on the same input, kept for the
# "before training" plot and the initial loss value.
initial_conv = GenericConv2D(kernel=learned_kernel, conv=dilation)
initial_out = initial_conv(inp)

In [None]:
# Make "Spectral_r" matplotlib's default colormap; the imshow call in
# plot_data does not pass a cmap, so it picks this one up.
plt.set_cmap("Spectral_r");

In [None]:
# Squared error summed (not averaged) over every element, so the magnitude of
# the reported loss scales with batch size and image resolution.
loss = torch.nn.MSELoss(reduction="sum")

In [None]:
def plot_data(
    data: torch.Tensor,
    batch: int = 0,
    channel: int = 0,
    ax: "plt.Axes | None" = None,
    title: str = "",
    vmin: float = 0,
    vmax: float = 1,
    save_to: "str | None" = None,
):
    """Show one 2-D slice of a 4-D tensor as an image.

    Args:
        data: 4-D tensor; the slice ``data[batch, channel]`` is displayed.
        batch: Index along the first dimension.
        channel: Index along the second dimension.
        ax: Axes to draw on. When omitted, a new 3x3 inch figure is created
            (at 500 dpi if ``save_to`` is given, default dpi otherwise).
        title: Title shown above the image.
        vmin: Lower limit of the colour scale.
        vmax: Upper limit of the colour scale.
        save_to: If given (and non-empty), path the figure is written to.

    Raises:
        AssertionError: If ``data`` is not 4-dimensional.
    """
    assert data.ndim == 4, f"expected a 4-D tensor, got {data.ndim} dimensions"
    # force=True lets tensors that require grad (e.g. the output of a
    # trainable kernel) or that live on another device be converted too.
    image = data[batch, channel].numpy(force=True)
    if ax is None:
        _, ax = plt.subplots(
            layout="compressed", dpi=500 if save_to else None, figsize=(3, 3)
        )
    ax.imshow(image, vmin=vmin, vmax=vmax)
    ax.axis("off")
    ax.set_title(title, font="Latin Modern Roman", fontsize=16)
    if save_to:
        # Save the figure that owns `ax`. plt.savefig would write pyplot's
        # *current* figure instead, which is the wrong one whenever the caller
        # supplies an Axes from a figure that is not current.
        ax.figure.savefig(save_to)


# Kernel plots share one colour scale, [-1, 0]; the data plots keep
# plot_data's default [0, 1] range.
kernel_limits = {"vmin": -1, "vmax": 0}

# Before-training overview: input, target kernel and its output, then the
# untrained kernel and its output.
plot_data(inp, title="Input data")
plot_data(target_kernel(), title="Target kernel", **kernel_limits)
plot_data(target_out, title="Target output")
plot_data(learned_kernel(), title="Initial kernel", **kernel_limits)
plot_data(initial_out, title="Initial output")

# Baseline the training loop has to improve on.
print("Initial loss:", loss(initial_out, target_out).item())

In [None]:
# Fit a freshly initialised kernel to the target output with Adam, saving
# figures of the kernel before, during and after training.
n_steps = 50  # optimisation steps; also reported in the final figure title
plot_every = 10  # save an intermediate kernel figure every this many steps

# Re-create the learned kernel with the same seed as the earlier cell, so
# re-running this cell always restarts training from the same initial state.
torch.manual_seed(0)
learned_kernel = QuadraticKernelSpectral2D(1, 1, kernel_size)
optim = torch.optim.Adam(learned_kernel.parameters(), lr=0.05)

print("TARGET", *target_kernel.covs.named_parameters(), sep="\n")
print("INITIAL", *learned_kernel.covs.named_parameters(), sep="\n")

plot_data(
    target_kernel(),
    title="Target kernel",
    vmin=-1,
    vmax=0,
    save_to="./report/figures/poc_target.png",
)
plot_data(
    learned_kernel(),
    title="Initial (skewed, theta=0) kernel",
    vmin=-1,
    vmax=0,
    save_to="./report/figures/poc_init.png",
)

for i in trange(n_steps, unit="steps", desc="Fitting on random data"):
    # Snapshot the kernel at every `plot_every`-th step. Step 0 is skipped
    # because the initial kernel has already been saved above.
    if i > 0 and i % plot_every == 0:
        plot_data(
            learned_kernel(),
            title=f"Training step {i}",
            vmin=-1,
            vmax=0,
            save_to=f"./report/figures/poc_step_{i}.png",
        )
    # NOTE(review): the GenericConv2D wrapper is rebuilt on every step; it
    # could presumably be constructed once before the loop - confirm it keeps
    # no per-construction state before hoisting it.
    output = GenericConv2D(kernel=learned_kernel, conv=dilation)(inp)
    optim.zero_grad()
    cur_loss = loss(output, target_out)
    cur_loss.backward()
    optim.step()

plot_data(
    learned_kernel(),
    title=f"Learned kernel, after {n_steps} steps",
    vmin=-1,
    vmax=0,
    save_to="./report/figures/poc_result.png",
)
print("RESULT", *learned_kernel.covs.named_parameters(), sep="\n")

In [28]:
# Turn on cuDNN's convolution autotuner and print the configured benchmark
# limit.
# NOTE(review): this cell comes after all the computation above, so on a
# fresh Restart & Run All it cannot influence any of it; together with the
# out-of-sequence execution count it looks like a leftover experiment -
# confirm whether it should move to the configuration cells or be removed.
torch.backends.cudnn.benchmark = True
print(torch.backends.cudnn.benchmark_limit)

None
