In [None]:
import numpy as np
import torch
import torch.utils.cpp_extension
from torch import Tensor, jit, nn
from tqdm.autonotebook import tqdm, trange

In [None]:
# JIT-compile spectral_norm.cpp and load the resulting shared library into the
# process (is_python_module=False: loaded for its side effects only).
torch.utils.cpp_extension.load(
    name="spectral_norm",
    sources=["spectral_norm.cpp"],
    is_python_module=False,
    verbose=True,
)
# Bind the operator (looked up under torch.ops.custom) to a short name and show it.
spectral_norm = torch.ops.custom.spectral_norm
print(spectral_norm)

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
@jit.script
def experiment(device: str = "cpu", atol: float = 1e-7, rtol: float = 1e-4) -> None:
    """Compare ``spectral_norm`` with ``torch.linalg.matrix_norm`` on one random matrix.

    Args:
        device: Device string on which the random test matrix is created.
        atol: Absolute tolerance of the comparison against the reference value.
        rtol: Relative tolerance (scaled by the reference value) of the comparison.

    Raises:
        AssertionError: If the two values differ by ``atol + rtol * s_ref`` or more,
            or if the computed value is not strictly positive.
    """
    # Use a separate name: TorchScript does not allow rebinding the `str`
    # argument `device` to a value of a different type.
    target = torch.device(device)
    # Draw both dimensions from [1, 512) and convert them to Python ints so that
    # they can be used as sizes and summed in the failure message.
    shape = torch.randint(1, 512, (2,))
    m = int(shape[0])
    n = int(shape[1])
    A = torch.randn(m, n, device=target)
    s = spectral_norm(A, maxiter=None, atol=1e-8, rtol=1e-5)
    s_ref = torch.linalg.matrix_norm(A, ord=2)
    assert (
        abs(s_ref - s) < atol + rtol * s_ref
    ), f"s={s}  {abs(s_ref - s)/s_ref} m+n={m + n}"
    assert s > 0, s

In [None]:
# Repeat the randomized check 10,000 times; any failing assertion aborts the loop.
for k in trange(10_000):
    experiment()

In [None]:
# Run a single randomized check with the default arguments.
experiment()

In [None]:
# Draw a random two-entry shape; torch.randint samples from [1, 512).
shape = torch.Size(torch.randint(1, 512, (2,)))

In [None]:
# NOTE: the bare `shape` below is not the last expression of the cell, so the
# notebook does not display it.
shape
# Standard-normal matrix of the drawn shape, created on `device`.
A = torch.randn(shape, device=device)

In [None]:
# Pass `experiment` (already decorated with @jit.script at its definition)
# through jit.script once more and call the result with default arguments.
jit.script(experiment)()

In [None]:
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm, trange

In [None]:
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm, trange


def alg_convergene(lists, tol=1e-5) -> list[int]:
    """Return, per residual history, the first index whose value is below ``tol``.

    Args:
        lists: Iterable of residual sequences, one per tracked quantity.
        tol: Threshold; an entry counts as converged when it is strictly
            smaller than this value.

    Returns:
        One Python int per sequence: the position of the first entry below
        ``tol``, or ``-1`` when no entry is below ``tol`` (including an empty
        sequence).
    """
    first_hits = []
    for residuals in lists:
        below = np.array(residuals) < tol
        # np.argmax yields 0 for an all-False mask (and raises on an empty one),
        # which would be indistinguishable from "converged at iteration 0".
        first_hits.append(int(np.argmax(below)) if below.any() else -1)
    return first_hits

In [None]:
maxiter = 1000
m, n = 128, 256
A = torch.randn(m, n)
# Reference leading singular triplet from a reduced SVD. torch.linalg.svd
# returns V^H, so transpose it back to keep the `V` used below.
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
V = Vh.T
u_true = U[:, 0]
v_true = V[:, 0]
s_true = S[0]

R = torch.arange(maxiter)

# Shared random starting vectors so both iterations begin from the same point.
u0 = torch.randn(m)
v0 = torch.randn(n)

# --- Variant f: v is updated from the freshly updated u --------------------
u = u0.clone()
v = v0.clone()

f_s = []  # |s - s_true|
f_u = []  # distance of u to +/- u_true
f_v = []  # distance of v to +/- v_true
f_r = []  # right residual ||A v - s u||
f_l = []  # left residual ||A^T u - s v||
f_x = []  # change in v between iterations, ||v - v_old||
f_y = []  # change in u between iterations, ||u - u_old||
sign_s = []

for k in tqdm(R):
    u_old = u
    u = A.mv(v)
    u /= u.norm()
    # Here s = (A v) . (A v / ||A v||) = ||A v||, so its sign is expected to be
    # +1; this is verified by the assertion after the loops.
    s = A.mv(v).dot(u)
    sign_s.append(s.sign())

    v_old = v
    v = A.T.mv(u)
    v /= v.norm()
    s = A.mv(v).dot(u)
    f_r.append((A.mv(v) - s * u).norm())
    f_x.append((v - v_old).norm())

    f_l.append((A.T.mv(u) - s * v).norm())
    f_y.append((u - u_old).norm())

    # Singular vectors are only defined up to sign, hence the min over +/-.
    f_v.append(min((v - v_true).norm(), (v + v_true).norm()))
    f_u.append(min((u - u_true).norm(), (u + u_true).norm()))
    f_s.append(abs(s - s_true))

# --- Variant g: u and v are both updated from the previous iterate ---------
u = u0.clone()
v = v0.clone()
g_s = []  # | |s| - s_true |
g_u = []  # distance of u to +/- u_true
g_v = []  # distance of v to +/- v_true
g_r = []  # right residual ||A v - s u||
g_l = []  # left residual ||A^T u - s v||

for k in tqdm(R):
    u_old = u
    v_old = v
    u = A.mv(v_old)
    u /= u.norm()

    v = A.T.mv(u_old)
    v /= v.norm()
    s = A.mv(v).dot(u)

    g_v.append(min((v - v_true).norm(), (v + v_true).norm()))
    g_u.append(min((u - u_true).norm(), (u + u_true).norm()))
    g_s.append(abs(abs(s) - s_true))
    g_r.append((A.mv(v) - s * u).norm())
    g_l.append((A.T.mv(u) - s * v).norm())

assert (np.array(sign_s) == 1).all()


# First iteration at which each recorded quantity is reported as converged.
print(alg_convergene([f_s, f_u, f_v, f_r, f_l, f_x, f_y]))
print(alg_convergene([g_s, g_u, g_v, g_r, g_l]))

In [None]:
# Plot every recorded quantity on log-log axes. Solid lines are the f_* lists,
# dotted lines the g_* lists; matching colours mark the same quantity.
# Iterations are shown 1-based because index 0 cannot be drawn on a log axis.
iterations = R + 1
fig, ax = plt.subplots(figsize=(16, 10))
series = [
    (f_u, "-r", "f_u"),
    (f_v, "-b", "f_v"),
    (f_s, "-g", "f_s"),
    (f_l, "-k", "f_l"),
    (f_r, "-y", "f_r"),
    (f_x, None, "f_x (||v - v_old||)"),
    (f_y, None, "f_y (||u - u_old||)"),
    (g_u, ":r", "g_u"),
    (g_v, ":b", "g_v"),
    (g_s, ":g", "g_s"),
    (g_l, ":k", "g_l"),
    (g_r, ":y", "g_r"),
]
for values, fmt, label in series:
    # Series without a format string keep matplotlib's default line style/colour.
    args = (iterations, values) if fmt is None else (iterations, values, fmt)
    ax.loglog(*args, label=label)
ax.set(xlabel="iteration", ylabel="residual")
ax.legend();

In [None]:
# Small 5x7 standard-normal example matrix plus two random vectors whose
# lengths match its row and column counts.
A = torch.randn(5, 7)
u0 = torch.randn(5)
v0 = torch.randn(7)

In [None]:
# Evaluate the custom op on the small example with loose tolerances (1e-2 each).
spectral_norm(A, maxiter=10_000, atol=10**-2, rtol=10**-2)

In [None]:
# Reference value: the matrix 2-norm of the same matrix, for comparison.
torch.linalg.matrix_norm(A, ord=2)

In [None]:
# Read the C++ source of the custom op from op.cpp.
with open("op.cpp", "r", encoding="utf8") as file:
    op_source = file.read()

# Compile the source in-process and link it against the OpenCV core/imgproc
# libraries; is_python_module=False loads the library for its side effects only.
torch.utils.cpp_extension.load_inline(
    name="op_with_autograd",
    cpp_sources=op_source,
    extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
    is_python_module=False,
    verbose=True,
)

# The op is looked up under the `custom` namespace of torch.ops.
print(torch.ops.custom.op_with_autograd)

In [None]:
op_with_autograd = torch.ops.custom.op_with_autograd

# Place the inputs on the notebook-wide `device` (CUDA when available, CPU
# otherwise) instead of calling .cuda() unconditionally, which raises on
# machines without CUDA. Values are still sampled on the CPU first, as before.
a = torch.randn(3, 3).to(device)
b = torch.tensor(3).to(device)
c = torch.randn(3, 3).to(device)

op_with_autograd(a, b, c)

In [None]:
import torch
import torch.utils.cpp_extension

# Load the custom module
# Builds custom_module.cpp with -O3 into the "build" directory.
# NOTE(review): with is_python_module=False, `load` is documented to return
# nothing, so `custom_module` is presumably None here — confirm.
custom_module = torch.utils.cpp_extension.load(
    name="AlexNet",
    sources=["custom_module.cpp"],
    is_python_module=False,
    verbose=True,
    build_directory="build",
    extra_cflags=["-O3"],
)

In [None]:
# List the attributes exposed by the `CustomModule` entry of torch.ops.custom.
dir(torch.ops.custom.CustomModule)

In [None]:
torch.ops.

In [None]:
# Instantiate the module
# NOTE(review): `custom_module` is the result of `load(..., is_python_module=False)`,
# which is documented to return nothing, so this attribute access presumably
# fails — confirm how CustomModule is registered and look it up there.
model = custom_module.CustomModule()

# test custom op serialization

In [None]:
class Foo(nn.Module):
    """Thin ``nn.Module`` wrapper around the custom ``op_with_autograd`` operator."""

    def forward(self, x: Tensor, y: int, z: Tensor) -> Tensor:
        """Pass the three arguments to ``op_with_autograd`` and return its result."""
        result = op_with_autograd(x, y, z)
        return result


# NOTE(review): `b` is a 0-d tensor while `forward` annotates `y` as int —
# confirm against the operator's schema.
module = Foo()
module(a, b, c)

In [None]:
# Compile the module with TorchScript, save it to disk, and check that the
# scripted version can still be called with the same inputs.
scripted = jit.script(module)
scripted.save("scripted_module.pt")
scripted(a, b, c)

In [None]:
# Reload the saved TorchScript module and call it with the same inputs.
loaded = jit.load("scripted_module.pt")
loaded(a, b, c)

## test backward

In [None]:
import torch

# Load a prebuilt shared library of custom operators into the process.
torch.ops.load_library("libcustom_ops.so")

In [None]:
# Look up the `opa` operator in the `custom` namespace (displayed as cell output).
torch.ops.custom.opa

In [None]:
# Show which shared libraries have been loaded through torch.ops.load_library.
torch.ops.loaded_libraries

In [None]:
# Print the attribute listing of each custom-op namespace, `my_ops` first.
for namespace in (torch.ops.my_ops, torch.ops.custom):
    print(dir(namespace))

In [None]:
import torch.utils.cpp_extension

# JIT-compile op.cpp, linking against the OpenCV core/imgproc libraries, and
# load the shared library for its side effects (is_python_module=False).
torch.utils.cpp_extension.load(
    name="warp_perspective",
    sources=["op.cpp"],
    extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
    is_python_module=False,
    verbose=True,
)

# The operator is looked up under the `my_ops` namespace.
print(torch.ops.my_ops.warp_perspective)

In [None]:
# Look up `op_with_autograd` in the `my_ops` namespace (displayed as cell output).
torch.ops.my_ops.op_with_autograd

In [None]:
def f(x: Tensor, warp: "Tensor | None" = None) -> Tensor:
    """Apply the custom ``my_ops.warp_perspective`` operator to ``x``.

    Args:
        x: Tensor passed as the operator's first argument.
        warp: Tensor passed as the operator's second argument. Defaults to a
            3x3 identity matrix, matching the call made in ``compute``.

    Returns:
        The operator's result.
    """
    # The operator is declared and called elsewhere in this notebook with two
    # arguments (image, warp); the previous one-argument call did not match.
    if warp is None:
        warp = torch.eye(3)
    return torch.ops.my_ops.warp_perspective(x, warp)

In [None]:
import torch
import torch.utils.cpp_extension

# Minimal stand-in operator: returns a copy of `image` and registers itself as
# my_ops::warp_perspective.
op_source = """
#include <torch/script.h>

torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
  return image.clone();
}

TORCH_LIBRARY(my_ops, m) {
  m.def("warp_perspective", &warp_perspective);
}
"""

# The inline source only includes <torch/script.h>, so no OpenCV link flags are
# passed; requiring them would make the build fail on machines without OpenCV.
torch.utils.cpp_extension.load_inline(
    name="warp_perspective",
    cpp_sources=op_source,
    is_python_module=False,
    verbose=True,
)

print(torch.ops.my_ops.warp_perspective)

In [None]:
# Print the operator handle looked up under the `my_ops` namespace.
print(torch.ops.my_ops.warp_perspective)

In [None]:
# List the attributes of the `my_ops` namespace.
dir(torch.ops.my_ops)

In [None]:
import torch

# Load the shared library from the build directory and print the operator handle.
torch.ops.load_library("build/libwarp_perspective.so")
print(torch.ops.my_ops.warp_perspective)

In [None]:
import torch.utils.cpp_extension

# JIT-compile op.cpp, linking against the OpenCV core/imgproc libraries, and
# load the shared library for its side effects (is_python_module=False).
# NOTE(review): this repeats an earlier cell of the notebook verbatim.
torch.utils.cpp_extension.load(
    name="warp_perspective",
    sources=["op.cpp"],
    extra_ldflags=["-lopencv_core", "-lopencv_imgproc"],
    is_python_module=False,
    verbose=True,
)

# The operator is looked up under the `my_ops` namespace.
print(torch.ops.my_ops.warp_perspective)

In [None]:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

# setuptools build description for the warp_perspective extension: a single
# C++ extension linked against opencv_core and opencv_imgproc, built without
# the Python ABI suffix in the library file name.
# NOTE(review): setup() normally takes its command from the process's command
# line, so this presumably belongs in a setup.py rather than a notebook cell —
# confirm.
setup(
    name="warp_perspective",
    ext_modules=[
        CppExtension(
            "warp_perspective",
            ["example_app/warp_perspective/op.cpp"],
            libraries=["opencv_core", "opencv_imgproc"],
        )
    ],
    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)

In [None]:
def compute(x, y, z):
    """Warp ``x`` with a 3x3 identity matrix, then return ``warped @ y + relu(z)``."""
    identity = torch.eye(3)
    warped = torch.ops.my_ops.warp_perspective(x, identity)
    product = warped.matmul(y)
    return product + torch.relu(z)

In [None]:
# Example inputs whose shapes are compatible with `compute`: (4, 8) @ (8, 5) + (4, 5).
example_shapes = ((4, 8), (8, 5), (4, 5))
inputs = [torch.randn(*dims) for dims in example_shapes]
# Trace `compute` on the example inputs and print the recorded graph.
trace = torch.jit.trace(compute, inputs)
print(trace.graph)