In [None]:
# NOTE(review): `X` is only assigned in a later cell (X = torch.linalg.solve(A, B)),
# so on a fresh kernel this cell raises NameError unless that cell has been run first.
X.shape

In [None]:
import torch
import hydra
from lib.utils.config import load_config
from lib.model.flame import FLAME
from lib.data.datamodule import DPHMDataModule
from lib.trainer.logger import FlameLogger
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.camera import Camera
from lib.utils.config import set_configs

# Compose the "optimize" configuration with a set of overrides (Gauss-Newton
# optimizer, joint trainer limited to max_iters=1 / max_optims=1 with a single
# milestone, sequential_trainer set to null) and pass the result straight
# through set_configs.
cfg = set_configs(
    load_config(
        "optimize",
        overrides=[
            "optimizer=gauss_newton",
            "joint_trainer.init_idxs=[0]",
            "joint_trainer.max_iters=1",
            "joint_trainer.max_optims=1",
            "joint_trainer.scheduler.milestones=[0]",
            "joint_trainer.scheduler.params=[[global_pose,transl]]",
            "joint_trainer.coarse2fine.milestones=[0]",
            "joint_trainer.coarse2fine.scales=[8]",
            "sequential_trainer=null",
        ],
    )
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Camera built from the dataset intrinsics and the image/clipping settings in cfg.data.
K = load_intrinsics(data_dir=cfg.data.data_dir, return_tensor="pt")
camera = Camera(
    K=K,
    width=cfg.data.width,
    height=cfg.data.height,
    near=cfg.data.near,
    far=cfg.data.far,
)
rasterizer = Rasterizer(width=camera.width, height=camera.height)
# Fixed: the keyword was misspelled `devie`, so the selected device was never
# passed under the intended name.
# NOTE(review): assumes DPHMDataModule accepts a `device` argument — confirm.
datamodule: DPHMDataModule = hydra.utils.instantiate(cfg.data, device=device)
# logger: FlameLogger = hydra.utils.instantiate(cfg.logger)
model: FLAME = hydra.utils.instantiate(cfg.model).to(device)
# coarse2fine = hydra.utils.instantiate(cfg.joint_trainer.coarse2fine)
# scheduler = hydra.utils.instantiate(cfg.joint_trainer.scheduler)
# optimizer = hydra.utils.instantiate(cfg.optimizer)

# datamodule.setup()
# Hand the camera and rasterizer to the model; later cells use model.renderer.
model.init_renderer(camera=camera, rasterizer=rasterizer)
# coarse2fine.init_scheduler(camera=camera, rasterizer=rasterizer)
# model.init_logger(logger=logger)
# optimizer.init_logger(logger=logger)

In [None]:
from lib.utils.mesh import vertex_normals

# Run the model once with its default arguments; the result is indexed by key below.
out = model()
vertices = out["vertices"]
# Face tensor stored on the model (presumably integer vertex indices per triangle — TODO confirm).
faces = model.faces.data
# Rasterize the mesh with the model's renderer; the timing cells read
# `fragments.vertices_idx` and `fragments.bary_coords` from this result.
fragments = model.renderer.rasterize(vertices, faces)


def infer():
    """Run the notebook-level ``model`` with its default arguments.

    Returns:
        Whatever ``model()`` returns. The value used to be discarded; returning
        it lets callers inspect it and keeps the computation observable when
        the function is wrapped in ``torch.compile`` (a traced function with no
        outputs may have its work dropped as dead code, which would make the
        benchmark meaningless).
    """
    return model()


def normals(_vertices, _faces):
    """Compute vertex normals by delegating to ``vertex_normals``.

    Args:
        _vertices: Vertex tensor forwarded unchanged to ``vertex_normals``.
        _faces: Face tensor forwarded unchanged to ``vertex_normals``.

    Returns:
        The result of ``vertex_normals(_vertices, _faces)``. Previously the
        result was dropped, so callers got ``None`` and a ``torch.compile``
        trace of this function had no outputs to keep the work alive.
    """
    return vertex_normals(_vertices, _faces)


def interpolate(
    vertices_idx: torch.Tensor,
    bary_coords: torch.Tensor,
    attributes: torch.Tensor,
):
    """Interpolate attributes through the model's renderer.

    Thin wrapper around ``model.renderer.interpolate`` so the same call can be
    benchmarked both eagerly and through ``torch.compile``.

    Args:
        vertices_idx: Forwarded unchanged to the renderer.
        bary_coords: Forwarded unchanged to the renderer.
        attributes: Forwarded unchanged to the renderer.

    Returns:
        The renderer's interpolation result. It used to be discarded, which
        left callers with ``None`` and gave a compiled trace no outputs.
    """
    return model.renderer.interpolate(vertices_idx, bary_coords, attributes)


def timed(fn):
    """Call ``fn()`` once and measure how long it took.

    On a CUDA machine the measurement uses CUDA events and synchronizes before
    reading the elapsed time, so asynchronously launched kernels are included.
    Without CUDA (the notebook falls back to ``device = "cpu"``) CUDA events
    are unavailable, so a wall-clock timer is used instead.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        Tuple ``(result, elapsed_ms)`` with the return value of ``fn`` and the
        elapsed time in milliseconds.
    """
    if not torch.cuda.is_available():
        import time

        begin = time.perf_counter()
        result = fn()
        # Convert seconds to milliseconds to match ``Event.elapsed_time``.
        return result, (time.perf_counter() - begin) * 1000.0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded events to complete before querying the duration.
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


# Compiled counterparts of the three wrappers above, used to compare eager and
# compiled timings. torch.compile is lazy: the actual compilation happens on
# the first call of each wrapper, not here.
infer_opt = torch.compile(infer, mode="default")
normals_opt = torch.compile(normals, mode="default")
interpolate_opt = torch.compile(interpolate, mode="default")

In [None]:
from lib.utils.mesh import vertex_normals

# vertex_normals(torch.rand((1, 5023, 3)).cuda(), torch.randn((9976, 3)).cuda())
# Compute vertex normals for the mesh produced above; as the last expression
# of the cell the result is displayed.
vertex_normals(vertices, faces)

In [None]:
from lib.trainer.timer import TimeTracker

# Number of timed repetitions per variant.
N_ITER = 100
tracker = TimeTracker()

# Warm-up: the first call of a torch.compile'd function triggers compilation,
# which would otherwise be averaged into the measured iterations. The eager
# variant is warmed up too so both are measured under the same conditions.
interpolate(
    vertices_idx=fragments.vertices_idx,
    bary_coords=fragments.bary_coords,
    attributes=vertices,
)
interpolate_opt(
    vertices_idx=fragments.vertices_idx,
    bary_coords=fragments.bary_coords,
    attributes=vertices,
)

# NOTE(review): the tracker labels say "model_inference", but the timed call is
# `interpolate`; labels are kept unchanged here — consider renaming.
# NOTE(review): presumably TimeTracker handles GPU synchronization itself;
# if it does not, asynchronous CUDA launches make these timings unreliable — confirm.
for _ in range(N_ITER):
    tracker.start("model_inference")
    interpolate(
        vertices_idx=fragments.vertices_idx,
        bary_coords=fragments.bary_coords,
        attributes=vertices,
    )
    tracker.stop()
for _ in range(N_ITER):
    tracker.start("model_opt_inference")
    interpolate_opt(
        vertices_idx=fragments.vertices_idx,
        bary_coords=fragments.bary_coords,
        attributes=vertices,
    )
    tracker.stop()
# NOTE(review): if print_summary() prints by itself and returns None, the outer
# print() emits an extra "None" — verify against TimeTracker.
print(tracker.print_summary())

In [None]:
# Time a single eager call and a single compiled call of the interpolation
# wrapper on the same rasterizer fragments, printing the elapsed time that
# `timed` reports for each (eager first, compiled second).
for interp_fn in (interpolate, interpolate_opt):
    print(
        timed(
            lambda: interp_fn(
                vertices_idx=fragments.vertices_idx,
                bary_coords=fragments.bary_coords,
                attributes=vertices,
            )
        )[1]
    )

In [None]:
# Print the elapsed time reported by `timed` for one eager and one compiled
# vertex-normal computation, in that order.
for normals_fn in (normals, normals_opt):
    print(timed(lambda: normals_fn(vertices, faces))[1])

In [None]:
# fetch single batch
# NOTE(review): `c2fs` and `fts` are not assigned anywhere in the visible
# notebook (the setup cell only contains commented-out `coarse2fine` and
# `scheduler` lines, and `datamodule.setup()` is commented out as well), so
# this cell depends on hidden kernel state — confirm where they come from.
iter_step = 0
c2fs.schedule_dataset(datamodule=datamodule, iter_step=iter_step)
fts.param_groups(model, iter_step=iter_step)
# Pull the first batch from the training dataloader.
dataloader = datamodule.train_dataloader()
batch = next(iter(dataloader))

In [None]:
# NOTE(review): `scheduler` is only instantiated in a commented-out line of the
# setup cell, and the only live assignment to `optimizer` in the visible
# notebook is the LBFGS instance in a later cell — this cell relies on hidden
# kernel state; confirm which objects are intended here.
scheduler.configure_optimizer(
    optimizer=optimizer,
    model=model,
    batch=batch,
    iter_step=iter_step,
)

In [None]:
import inspect

# Introspect the forward method and list its parameter names in declaration
# order; the mapping returned by `.parameters` is keyed by parameter name.
signature = inspect.signature(model.forward)
param_names = list(signature.parameters)
param_names

In [None]:
import torch

# Size of the square system A x = B. Defined once and reused for every tensor
# below; the optimisation cell further down also reads `n`, which previously
# existed only as a commented-out line.
n = 700
A = torch.rand((n, n))
B = torch.rand(n)
# Direct solution of A X = B.
X = torch.linalg.solve(A, B)
# Zero-initialised unknown that tracks gradients.
x = torch.zeros((n), requires_grad=True)
# def foo(x):
#     return (A @ x - B)
# J = torch.autograd.functional.jacobian(foo, x)
# J

In [None]:
import torch

# Random 2x2 input for the Jacobian experiments below (no seed is set, so the
# values differ between runs).
inputs = torch.rand(2, 2)

In [None]:
import torch
from torch.autograd.functional import jacobian


def exp_reducer(x):
    """Exponentiate ``x`` element-wise and sum the result along dimension 1."""
    exponentiated = torch.exp(x)
    return torch.sum(exponentiated, dim=1)


# jacobian(exp_reducer, inputs, strategy="forward-mode", vectorize=True)
# Reverse-mode Jacobian of exp_reducer at `inputs`, computed in vectorized
# fashion; create_graph=True requests a differentiable result. As the last
# expression of the cell, the Jacobian is displayed.
jacobian(
    exp_reducer, inputs, strategy="reverse-mode", vectorize=True, create_graph=True
)

In [None]:
# Fresh random 2x2 leaf tensor that tracks gradients from creation.
inputs = torch.rand(2, 2, requires_grad=True)

In [None]:
from torch.func import jacrev, vmap, jacfwd


def f(x):
    """Take the element-wise sine of ``x`` and sum over the last dimension."""
    sines = torch.sin(x)
    return torch.sum(sines, dim=-1)


# Per-row Jacobian of torch.exp: jacrev differentiates exp for one row and vmap
# maps that over the leading dimension of `inputs`.
# NOTE(review): `f` defined above is not used here — torch.exp is
# differentiated instead; confirm this is intended.
v = vmap(jacrev(torch.exp))(inputs)
v

In [None]:
# Display the current `inputs` tensor.
inputs

In [None]:
# Forward-mode, vectorized Jacobian of exp_reducer at `inputs`, for comparison
# with the reverse-mode call in the earlier cell.
jacobian(exp_reducer, inputs, strategy="forward-mode", vectorize=True)

In [None]:
# Show the installed PyTorch version.
torch.__version__

In [None]:
from tqdm import tqdm
# Minimise the least-squares objective ||A x - B||^2 with L-BFGS.
# NOTE(review): 100000 outer steps is a long run for L-BFGS; consider lowering.
max_steps = 100000
# Size the unknown from A so this cell does not depend on a separate `n`
# variable (it was only present as a commented-out line).
x = torch.zeros(A.shape[1], requires_grad=True)
# optimizer = torch.optim.Adam([x], lr=1.0)
# Fixed: the parameter list was missing its closing bracket (`LBFGS([x)`),
# which is a syntax error.
optimizer = torch.optim.LBFGS([x])


def closure():
    """Zero the gradients, evaluate the squared residual of A x = B, and backpropagate.

    Returns:
        The scalar loss tensor, as ``LBFGS.step`` requires.
    """
    optimizer.zero_grad()
    F = A @ x - B
    loss = torch.pow(F, 2).sum()
    loss.backward()
    return loss


for step in tqdm(range(max_steps)):
    # LBFGS.step requires a closure because it may re-evaluate the objective
    # several times per step; calling step() without one raises a TypeError.
    loss = optimizer.step(closure)
    # print(f"{step}) {loss}")
print(f"{step}) {loss}")

In [None]:
# Solve A x = B by explicitly inverting A (for comparison with the other solve variants).
A.inverse() @ B

In [None]:
# Same system solved through the normal equations: (A^T A)^{-1} (A^T B).
(A.T @ A).inverse() @ (A.T @ B)

In [None]:
# Direct solve of A x = B without forming an explicit inverse.
torch.linalg.solve(A, B)

In [None]:
import torch

# Two leaf tensors whose gradients are accumulated by the backward pass.
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)
Q = 3 * a**3 - b**2
# Backpropagate from the vector Q with a weight of 1 per element; this gives
# the same gradients as differentiating the scalar Q.sum().
external_grad = torch.tensor([1.0, 1.0])
Q.backward(gradient=external_grad)

In [None]:
# Experiment with tensor version counters and in-place modification.
a = torch.randn(5, requires_grad=True)
b = 2 * a
c = b**2  # replace this with c = b + 2 and the autograd error will go away
print(b._version)
# Out-of-place add: `b` is rebound to a new tensor; the tensor that `c` was
# computed from is left untouched.
b = b + 1
print(b._version)
b += 1  # inplace operation!
print(b._version)
# NOTE(review): because `b` was rebound above, the in-place add modifies the
# new tensor rather than the one used to compute `c`, so the backward call
# below presumably would not raise as written — confirm if the goal is to
# reproduce the in-place autograd error.
# c.sum().backward()