In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    def __init__(
        self,
        init_features: int = 64,
        depth: int = 4,
        size: int = 256,
        mode: str = "point",  # "point", "normal", "point_normal"
        max_weight: float = 100.0,
        device: str = "cuda",
    ):
        super().__init__()

        self.max_weight = max_weight        

        self.mode = mode
        assert mode in ["point", "normal", "point_normal"]
        out_channels = 1
        in_channels = 6
        if mode == "point_normal":
            in_channels = 12        

        self.depth = depth
        self.size = size
        features = init_features

        # Contracting Path (Encoder)
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(depth):
            self.encoders.append(UNet._block(in_channels, features))
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
            in_channels = features
            features *= 2

        # Bottleneck
        self.bottleneck = UNet._block(features // 2, features)

        # Expansive Path (Decoder)
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for i in range(depth):
            features //= 2
            self.upconvs.append(
                nn.ConvTranspose2d(
                    features * 2,
                    features,
                    kernel_size=2,
                    stride=2,
                )
            )
            self.decoders.append(UNet._block(features * 2, features))

        # Final Convolution
        self.conv = nn.Conv2d(
            in_channels=features,
            out_channels=out_channels,
            kernel_size=1,
        )

        self.to(device)

    def forward(
        self,
        s_point: torch.Tensor,
        s_normal: torch.Tensor,
        t_point: torch.Tensor,
        t_normal: torch.Tensor,
    ):
        # prepare input
        if self.mode == "point_normal":  # (B, H, W, 12)
            x = torch.cat([s_point, s_normal, t_point, t_normal], dim=-1)
        elif self.mode == "point":
            x = torch.cat([s_point, t_point], dim=-1)  # (B, H, W, 6)
        elif self.mode == "normal":
            x = torch.cat([s_normal, t_normal], dim=-1)  # (B, H, W, 6)
        else:
            raise AttributeError(f"No {self.mode} that works.")
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
        B, C, H, W = x.shape 
        x = self._pad(x, height=H, width=W)

        # B, H, W, C
        encoders_output = []
        for i in range(self.depth):
            x = self.encoders[i](x)
            encoders_output.append(x)
            x = self.pools[i](x)

        bottleneck = self.bottleneck(x)

        for i in range(self.depth):
            x = self.upconvs[i](bottleneck if i == 0 else x)
            enc_output = encoders_output[-(i + 1)]
            x = torch.cat((x, enc_output), dim=1)
            x = self.decoders[i](x)
        x = torch.exp(self.conv(x))
        x = self._unpad(x, height=H, width=W)
        x = x.permute(0, 2, 3, 1)  # (B, W, H, 1)

        return x 

    @staticmethod
    def _block(in_channels: int, features: int):
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=features,
                out_channels=features,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(inplace=True),
        )

    def _pad(self, x: torch.Tensor, height: int, width: int):
        # Desired output dimensions
        target_height = self.size
        target_width = self.size

        # Calculate padding for height and width
        pad_height = target_height - height
        pad_width = target_width - width

        # Pad equally on both sides
        padding = [
            pad_width // 2,
            pad_width - pad_width // 2,
            pad_height // 2,
            pad_height - pad_height // 2,
        ]  # (left, right, top, bottom)

        # Apply padding
        return F.pad(x, padding)

    def _unpad(self, x: torch.Tensor, height: int, width: int):
        # Desired output dimensions
        target_height = self.size
        target_width = self.size

        # Calculate padding for height and width
        pad_height = target_height - height
        pad_width = target_width - width

        # Slice back to the original shape (135, 240)
        start_height = pad_height // 2
        end_height = start_height + height 

        start_width = pad_width // 2
        end_width = start_width +  width

        return x[:, :, start_height:end_height, start_width:end_width]


model = UNet(init_features=32, depth=4, mode="point_normal", device="cpu")
# print(unet)

# Example input tensor (batch_size, channels, height, width)
s_point = torch.zeros((1, 135, 240, 3))
t_point = torch.zeros((1, 135, 240, 3))
s_normal = torch.zeros((1, 135, 240, 3))
t_normal = torch.zeros((1, 135, 240, 3))
out = model(
    s_point =s_point,
    t_point=t_point,
    s_normal=s_normal,
    t_normal=t_normal,
)
out["bottleneck"].shape

torch.Size([1, 512, 16, 16])

In [73]:
# Example of usage
model = UNet(in_channels=1, out_channels=1, init_features=32)
# print(model)

# Example input tensor (batch_size, channels, height, width)
x = torch.zeros((1, 1, 135, 240))
out = model(x)
out["weight"].shape

torch.Size([1, 1, 256, 256])

131072

In [45]:
import torch
import torch.nn.functional as F

# Input tensor of shape (batch_size, channels, height, width)
x = torch.randn((1, 1, 135, 240), requires_grad=True)


# Desired output dimensions
target_height = 256
target_width = 256

# Calculate padding for height and width
pad_height = target_height - x.shape[2]
pad_width = target_width - x.shape[3]

# Pad equally on both sides
padding = [pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2]  # (left, right, top, bottom)

# Apply padding
x_padded = F.pad(x, padding)

print("Padded shape:", x_padded.shape)  # Should be (1, 1, 256, 256)


Padded shape: torch.Size([1, 1, 256, 256])


In [46]:
import torch
import torch.nn.functional as F

# Input tensor of shape (batch_size, channels, height, width)
x = torch.randn((1, 1, 135, 240), requires_grad=True)

# Desired output dimensions
target_height = 256
target_width = 256

# Calculate padding for height and width
pad_height = target_height - x.shape[2]
pad_width = target_width - x.shape[3]

# Pad equally on both sides
padding = [pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2]  # (left, right, top, bottom)

# Apply padding
x_padded = F.pad(x, padding)

print("Padded shape:", x_padded.shape)  # Should be (1, 1, 256, 256)

# Slice back to the original shape (135, 240)
start_height = pad_height // 2
end_height = start_height + x.shape[2]

start_width = pad_width // 2
end_width = start_width + x.shape[3]

x_original_shape = x_padded[:, :, start_height:end_height, start_width:end_width]

print("Shape after slicing:", x_original_shape.shape)  # Should be (1, 1, 135, 240)


Padded shape: torch.Size([1, 1, 256, 256])
Shape after slicing: torch.Size([1, 1, 135, 240])


In [47]:
x_original_shape

tensor([[[[ 1.0182, -0.2634,  1.2838,  ..., -0.5184,  0.2893,  0.1558],
          [-0.3295, -0.1763, -0.4001,  ..., -0.4550, -0.6738, -0.6903],
          [-0.5878, -0.4613, -0.0079,  ...,  1.0670, -1.4411,  0.4294],
          ...,
          [-0.2986,  1.8018, -0.8441,  ..., -1.4393,  1.9679,  0.3780],
          [-0.7267, -1.0729, -0.5690,  ..., -0.0377,  1.1575, -0.7883],
          [-0.6154,  0.8982, -1.1257,  ...,  0.6682,  1.5402,  0.4965]]]],
       grad_fn=<SliceBackward0>)

In [48]:
x

tensor([[[[ 1.0182, -0.2634,  1.2838,  ..., -0.5184,  0.2893,  0.1558],
          [-0.3295, -0.1763, -0.4001,  ..., -0.4550, -0.6738, -0.6903],
          [-0.5878, -0.4613, -0.0079,  ...,  1.0670, -1.4411,  0.4294],
          ...,
          [-0.2986,  1.8018, -0.8441,  ..., -1.4393,  1.9679,  0.3780],
          [-0.7267, -1.0729, -0.5690,  ..., -0.0377,  1.1575, -0.7883],
          [-0.6154,  0.8982, -1.1257,  ...,  0.6682,  1.5402,  0.4965]]]],
       requires_grad=True)

In [56]:
import torch

xs = []
for i in range(120):
    x = torch.load(f"/home/borth/GuidedResearch/data/dphm_kinect/christoph_mouthmove/params/{i:05}.pt")
    xs.append(x["transl"])
xs = torch.cat(xs)
(xs[1:] - xs[:-1]).mean()

tensor(3.2209e-06)

In [50]:
xs.mean()

tensor(-0.1526)

In [35]:
import torch
from lib.optimizer.solver import PytorchCholeskySolver, LinearSystemSolver, PytorchSolver, PytorchLSTSQSolver

x = torch.load("/home/borth/GuidedResearch/temp/_tracked_innocenzo_fulgintl_rotatemouth/linsys/0000000.pt")
A = x["A"].requires_grad_(True)
b = x["b"].requires_grad_(True)
solver = PytorchLSTSQSolver()
x_gt, _ = solver(A, b)
x_gt.mean().backward()
print(A.grad.mean(), A.grad.median(), A.grad.max(), A.grad.min())
print(b.grad.mean(), b.grad.median(), b.grad.max(), b.grad.min())


tensor(-2.5607e-07) tensor(3.2350e-07) tensor(0.0004) tensor(-0.0004)
tensor(0.0004) tensor(-1.6172e-06) tensor(0.0100) tensor(-0.0008)


In [37]:
torch.linalg.solve(A, b)

tensor([ 3.2022e-02, -1.6498e-02,  1.6422e-02,  8.3287e-03,  1.8885e-01,
        -1.1039e-01, -2.6702e-01,  5.2789e-02, -9.4564e-02, -1.7879e-01,
        -2.8307e-02, -2.4503e-01, -2.4458e-01,  6.2571e-02, -2.1742e-01,
         5.7680e-02, -3.0225e-01, -1.0088e-01, -1.5052e-03, -8.2054e-03,
        -3.3691e-01, -2.4137e-02, -1.6685e-01,  5.2742e-01, -1.1675e-01,
        -4.3037e-01,  4.1128e-01, -1.2757e-01, -4.6779e-01, -2.1549e-01,
        -2.5377e-01,  2.5678e-01,  1.4873e-01,  2.8678e-01,  5.6243e-02,
         1.8464e-02, -2.5776e-03, -2.2784e-01, -5.6792e-03, -5.7189e-01,
         2.1492e-01, -1.1935e-01, -2.1601e-01, -2.4235e-01,  3.4376e-01,
         2.6661e-01, -1.9733e-01,  4.8864e-03,  1.6744e-01,  1.0434e-01,
        -3.0153e-04,  3.2145e-03,  3.1760e-03,  1.7339e-04,  3.7922e-05,
         7.3675e-05, -2.1500e-04, -1.6379e-03, -5.2934e-03],
       grad_fn=<LinalgSolveExBackward0>)

In [39]:
torch.linalg.cond(A)

tensor(59782876., grad_fn=<SqueezeBackward1>)

In [40]:
A

tensor([[ 8.1731e-03,  1.1354e-03, -3.3658e-05,  ...,  1.0128e-01,
         -5.0421e-03, -5.1304e-03],
        [ 1.1354e-03,  2.1057e-03, -6.0995e-05,  ...,  1.2770e-02,
          1.9029e-03,  1.7827e-04],
        [-3.3658e-05, -6.0995e-05,  2.6531e-04,  ..., -6.9914e-04,
         -6.3254e-03, -5.7062e-03],
        ...,
        [ 1.0128e-01,  1.2770e-02, -6.9914e-04,  ...,  5.1271e+00,
         -9.5530e-02, -3.2283e-01],
        [-5.0421e-03,  1.9029e-03, -6.3254e-03,  ..., -9.5530e-02,
          2.2648e+00,  1.0380e-01],
        [-5.1304e-03,  1.7827e-04, -5.7062e-03,  ..., -3.2283e-01,
          1.0380e-01,  1.9050e+00]], requires_grad=True)

In [41]:
A.diag()

tensor([8.1731e-03, 2.1057e-03, 2.6531e-04, 9.0030e-04, 4.0633e-04, 4.1603e-04,
        1.4641e-04, 3.8512e-04, 1.4818e-04, 1.8286e-04, 2.0913e-04, 1.7092e-04,
        6.2570e-05, 2.0327e-04, 8.8231e-05, 1.3383e-04, 1.7895e-04, 5.1430e-05,
        1.7467e-04, 7.6409e-05, 5.9546e-05, 5.3258e-05, 3.6550e-05, 6.1496e-05,
        4.3800e-05, 3.9916e-05, 3.2244e-05, 3.4983e-05, 2.7346e-05, 3.5999e-05,
        3.9655e-05, 2.7206e-05, 4.1970e-05, 3.4752e-05, 3.0459e-05, 4.0527e-05,
        4.2741e-05, 2.3114e-05, 2.0822e-05, 3.6277e-05, 2.9260e-05, 2.3219e-05,
        2.9221e-05, 2.6685e-05, 2.1644e-05, 2.3351e-05, 1.9196e-05, 2.4129e-05,
        1.9122e-05, 1.9676e-05, 4.0554e+01, 7.4488e+00, 2.3953e+01, 1.1929e+03,
        8.0555e+02, 3.1684e+03, 5.1271e+00, 2.2648e+00, 1.9050e+00],
       grad_fn=<DiagonalBackward0_copy>)

In [44]:
torch.linalg.cond(A + torch.eye(A.))

tensor(3369.5950, grad_fn=<SqueezeBackward1>)

In [45]:
torch.linalg.cond(A)

tensor(59782876., grad_fn=<SqueezeBackward1>)

In [63]:
torch.linalg.cond(A + 1.0 * torch.diag_embed(A.diag()))

tensor(51664336., grad_fn=<SqueezeBackward1>)