In [None]:

# NOTE(review): X is only created further down (`X = torch.linalg.solve(A, B)`), so this
# cell raises NameError on a fresh kernel — run it after that cell or move it below it.
X.shape

In [None]:
import torch
import hydra
from lib.utils.config import load_config
from lib.model.flame import FLAME
from lib.data.datamodule import DPHMDataModule
from lib.data.scheduler import CoarseToFineScheduler, FinetuneScheduler
from lib.model.flame import FLAME
from lib.model.logger import FlameLogger
from lib.model.loss import calculate_point2plane

# Load the "optimize" config with single-frame overrides, then build data, renderer,
# model, optimizer and schedulers and wire the logger into them.
cfg = load_config(
    "optimize",
    overrides=[
        "optimizer=gauss_newton",
        "optimizer.lin_solver.dim=6",
        "joint_trainer.init_idxs=[0]",
        "joint_trainer.max_iters=1",
        "joint_trainer.max_optims=1",
        "joint_trainer.scheduler.milestones=[0]",
        "joint_trainer.scheduler.params=[[global_pose,transl]]",
        "joint_trainer.coarse2fine.milestones=[0]",
        "joint_trainer.coarse2fine.scales=[8]",
        "sequential_trainer=null",
    ],
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): load_intrinsics, Camera and Rasterizer are used here but never imported in
# this notebook — add their `lib...` imports before running on a fresh kernel.
K = load_intrinsics(data_dir=cfg.data.data_dir, return_tensor="pt")
camera = Camera(
    K=K,
    width=cfg.data.width,
    height=cfg.data.height,
    near=cfg.data.near,
    far=cfg.data.far,
)
rasterizer = Rasterizer(width=camera.width, height=camera.height)
datamodule: DPHMDataModule = hydra.utils.instantiate(cfg.data, device=device)
logger: FlameLogger = hydra.utils.instantiate(cfg.logger)
model: FLAME = hydra.utils.instantiate(cfg.model).to(device)
# The optimizer must exist here because optimizer.init_logger is called below.
# NOTE(review): assumes cfg.optimizer (chosen by the "optimizer=gauss_newton" override) is
# directly instantiable — confirm against the config.
optimizer = hydra.utils.instantiate(cfg.optimizer)
# coarse2fine receives camera/rasterizer below; scheduler is later asked to configure the
# optimizer. NOTE(review): the overrides above target joint_trainer.*, while these read
# cfg.scheduler.* — confirm that is the intended config section.
coarse2fine: CoarseToFineScheduler = hydra.utils.instantiate(cfg.scheduler.coarse2fine)
scheduler: FinetuneScheduler = hydra.utils.instantiate(cfg.scheduler.finetune)

datamodule.setup()
model.init_renderer(camera=camera, rasterizer=rasterizer)
coarse2fine.init_scheduler(camera=camera, rasterizer=rasterizer)
model.init_logger(logger=logger)
optimizer.init_logger(logger=logger)

In [None]:
# fetch single batch
iter_step = 0
# Apply the coarse-to-fine and finetune schedules for this step before building the loader.
# (The setup cell names these `coarse2fine` and `scheduler`.)
coarse2fine.schedule_dataset(datamodule=datamodule, iter_step=iter_step)
scheduler.param_groups(model, iter_step=iter_step)
dataloader = datamodule.train_dataloader()
batch = next(iter(dataloader))

In [None]:
# Hand the model, the fetched batch and the optimizer to the scheduler for step `iter_step`;
# the call's return value (if any) is shown as the cell output.
scheduler.configure_optimizer(
    model=model,
    batch=batch,
    iter_step=iter_step,
    optimizer=optimizer,
)

In [None]:
import inspect

signature = inspect.signature(model.forward)
# Signature.parameters is an ordered mapping keyed by argument name, so its keys are the names.
param_names = list(signature.parameters)
param_names

In [None]:
import torch

# Random dense square system A x = B used by the solver experiments in later cells.
torch.manual_seed(0)  # reproducible A and B across kernel restarts
n = 700  # system size; later cells also read `n`
A = torch.rand((n, n))
B = torch.rand(n)
X = torch.linalg.solve(A, B)  # reference solution
x = torch.zeros(n, requires_grad=True)  # unknown for the iterative solvers

In [None]:
import torch
# 2x2 uniform [0, 1) sample used as the evaluation point for the jacobian experiments.
inputs = torch.rand(size=(2, 2))

In [None]:
import torch
from torch.autograd.functional import jacobian

def exp_reducer(x):
    """Exponentiate element-wise, then sum along dim 1 (one value per row for 2-D input)."""
    return torch.sum(torch.exp(x), dim=1)
# Reverse-mode Jacobian of exp_reducer at `inputs`; create_graph=True makes the result differentiable.
jacobian(exp_reducer, inputs, create_graph=True, vectorize=True, strategy="reverse-mode")

In [None]:
# Fresh 2x2 leaf tensor that tracks gradients (requires_grad_ flips the flag in place).
inputs = torch.rand(2, 2).requires_grad_()

In [None]:
from torch.func import jacrev, vmap, jacfwd
def f(x):
    """Return the sum of sin(x) over the last dimension."""
    # NOTE(review): not used below — the vmap cell differentiates torch.exp directly.
    return torch.sin(x).sum(-1)
# Per-row Jacobians of element-wise exp: vmap maps over dim 0 of `inputs`, and each row
# yields a diagonal Jacobian matrix.
v = vmap(jacrev(torch.exp), in_dims=0)(inputs)
v

In [None]:
# Show `inputs`; when cells run in order this is the tensor created with requires_grad=True.
inputs

In [None]:
# Forward-mode Jacobian of exp_reducer; this strategy requires vectorize=True.
jacobian(exp_reducer, inputs, vectorize=True, strategy="forward-mode")

In [None]:
# Installed PyTorch version (torch.func, used above, only exists in PyTorch 2.0 and later).
torch.__version__

In [None]:
from tqdm.auto import tqdm

max_steps = 100000
# Size the unknown from A so this cell does not depend on a separately defined `n`.
x = torch.zeros(A.shape[1], requires_grad=True)
# Separate name so the `optimizer` used by the setup/scheduler cells is not overwritten.
lbfgs = torch.optim.LBFGS([x])


def closure():
    """Re-evaluate the least-squares loss ||A x - B||^2 and its gradient for LBFGS."""
    lbfgs.zero_grad()
    residual = A @ x - B
    loss = torch.pow(residual, 2).sum()
    loss.backward()
    return loss


for step in tqdm(range(max_steps)):
    # LBFGS may evaluate the objective several times per step, so step() requires a closure.
    loss = lbfgs.step(closure)
print(f"{step}) {loss}")

In [None]:
# Solution via an explicit inverse, shown for comparison; torch.linalg.solve is the preferred route.
torch.linalg.inv(A) @ B

In [None]:
# Normal-equations solution (A^T A)^{-1} A^T B; forming A^T A squares the condition number.
torch.linalg.inv(A.T @ A) @ (A.T @ B)

In [None]:
# Direct solve of A x = B without forming A^{-1}; reference to compare the other approaches against.
torch.linalg.solve(A, B)

In [None]:
import torch

# Autograd on a vector-valued Q: backward() on a non-scalar needs an explicit upstream gradient.
a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)
Q = 3*a**3 - b**2
external_grad = torch.tensor([1., 1.])  # weight every element of Q by 1
Q.backward(gradient=external_grad)  # same gradients as Q.sum().backward()

# Analytic gradients: dQ/da = 9 a^2, dQ/db = -2 b
assert torch.equal(a.grad, 9 * a.detach() ** 2)
assert torch.equal(b.grad, -2 * b.detach())

In [None]:
# Autograd version counters: pow saves its input `b` for backward, so modifying that same
# tensor in place afterwards makes backward fail.
a = torch.randn(5, requires_grad=True)
b = 2 * a
c = b ** 2  # replace this with c = b + 2 and the autograd error will go away
print(b._version)
b_shifted = b + 1  # out-of-place: a new tensor; b (and its version counter) is untouched
print(b._version)
b += 1  # inplace operation!
print(b._version)
try:
    c.sum().backward()
except RuntimeError as err:
    # Expected: the tensor saved by pow was modified in place after c was computed.
    print(f"autograd error: {err}")