In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    def __init__(
        self,
        init_features: int = 64,
        depth: int = 4,
        size: int = 256,
        mode: str = "point",  # "point", "normal", "point_normal"
        max_weight: float = 100.0,
        device: str = "cuda",
    ):
        super().__init__()

        self.max_weight = max_weight        

        self.mode = mode
        assert mode in ["point", "normal", "point_normal"]
        out_channels = 1
        in_channels = 6
        if mode == "point_normal":
            in_channels = 12        

        self.depth = depth
        self.size = size
        features = init_features

        # Contracting Path (Encoder)
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(depth):
            self.encoders.append(UNet._block(in_channels, features))
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
            in_channels = features
            features *= 2

        # Bottleneck
        self.bottleneck = UNet._block(features // 2, features)

        # Expansive Path (Decoder)
        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for i in range(depth):
            features //= 2
            self.upconvs.append(
                nn.ConvTranspose2d(
                    features * 2,
                    features,
                    kernel_size=2,
                    stride=2,
                )
            )
            self.decoders.append(UNet._block(features * 2, features))

        # Final Convolution
        self.conv = nn.Conv2d(
            in_channels=features,
            out_channels=out_channels,
            kernel_size=1,
        )

        self.to(device)

    def forward(
        self,
        s_point: torch.Tensor,
        s_normal: torch.Tensor,
        t_point: torch.Tensor,
        t_normal: torch.Tensor,
    ):
        # prepare input
        if self.mode == "point_normal":  # (B, H, W, 12)
            x = torch.cat([s_point, s_normal, t_point, t_normal], dim=-1)
        elif self.mode == "point":
            x = torch.cat([s_point, t_point], dim=-1)  # (B, H, W, 6)
        elif self.mode == "normal":
            x = torch.cat([s_normal, t_normal], dim=-1)  # (B, H, W, 6)
        else:
            raise AttributeError(f"No {self.mode} that works.")
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
        B, C, H, W = x.shape 
        x = self._pad(x, height=H, width=W)

        # B, H, W, C
        encoders_output = []
        for i in range(self.depth):
            x = self.encoders[i](x)
            encoders_output.append(x)
            x = self.pools[i](x)

        bottleneck = self.bottleneck(x)

        for i in range(self.depth):
            x = self.upconvs[i](bottleneck if i == 0 else x)
            enc_output = encoders_output[-(i + 1)]
            x = torch.cat((x, enc_output), dim=1)
            x = self.decoders[i](x)
        x = torch.exp(self.conv(x))
        x = self._unpad(x, height=H, width=W)
        x = x.permute(0, 2, 3, 1)  # (B, W, H, 1)

        return x 

    @staticmethod
    def _block(in_channels: int, features: int):
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=features,
                out_channels=features,
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(inplace=True),
        )

    def _pad(self, x: torch.Tensor, height: int, width: int):
        # Desired output dimensions
        target_height = self.size
        target_width = self.size

        # Calculate padding for height and width
        pad_height = target_height - height
        pad_width = target_width - width

        # Pad equally on both sides
        padding = [
            pad_width // 2,
            pad_width - pad_width // 2,
            pad_height // 2,
            pad_height - pad_height // 2,
        ]  # (left, right, top, bottom)

        # Apply padding
        return F.pad(x, padding)

    def _unpad(self, x: torch.Tensor, height: int, width: int):
        # Desired output dimensions
        target_height = self.size
        target_width = self.size

        # Calculate padding for height and width
        pad_height = target_height - height
        pad_width = target_width - width

        # Slice back to the original shape (135, 240)
        start_height = pad_height // 2
        end_height = start_height + height 

        start_width = pad_width // 2
        end_width = start_width +  width

        return x[:, :, start_height:end_height, start_width:end_width]


model = UNet(init_features=32, depth=4, mode="point_normal", device="cpu")
# print(unet)

# Example input tensor (batch_size, channels, height, width)
s_point = torch.zeros((1, 135, 240, 3))
t_point = torch.zeros((1, 135, 240, 3))
s_normal = torch.zeros((1, 135, 240, 3))
t_normal = torch.zeros((1, 135, 240, 3))
out = model(
    s_point =s_point,
    t_point=t_point,
    s_normal=s_normal,
    t_normal=t_normal,
)
out["bottleneck"].shape

torch.Size([1, 512, 16, 16])

In [73]:
# Example of usage
model = UNet(in_channels=1, out_channels=1, init_features=32)
# print(model)

# Example input tensor (batch_size, channels, height, width)
x = torch.zeros((1, 1, 135, 240))
out = model(x)
out["weight"].shape

torch.Size([1, 1, 256, 256])

131072

In [45]:
import torch
import torch.nn.functional as F

# Input tensor of shape (batch_size, channels, height, width)
x = torch.randn((1, 1, 135, 240), requires_grad=True)


# Desired output dimensions
target_height = 256
target_width = 256

# Calculate padding for height and width
pad_height = target_height - x.shape[2]
pad_width = target_width - x.shape[3]

# Pad equally on both sides
padding = [pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2]  # (left, right, top, bottom)

# Apply padding
x_padded = F.pad(x, padding)

print("Padded shape:", x_padded.shape)  # Should be (1, 1, 256, 256)


Padded shape: torch.Size([1, 1, 256, 256])


In [46]:
import torch
import torch.nn.functional as F

# Input tensor of shape (batch_size, channels, height, width)
x = torch.randn((1, 1, 135, 240), requires_grad=True)

# Desired output dimensions
target_height = 256
target_width = 256

# Calculate padding for height and width
pad_height = target_height - x.shape[2]
pad_width = target_width - x.shape[3]

# Pad equally on both sides
padding = [pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2]  # (left, right, top, bottom)

# Apply padding
x_padded = F.pad(x, padding)

print("Padded shape:", x_padded.shape)  # Should be (1, 1, 256, 256)

# Slice back to the original shape (135, 240)
start_height = pad_height // 2
end_height = start_height + x.shape[2]

start_width = pad_width // 2
end_width = start_width + x.shape[3]

x_original_shape = x_padded[:, :, start_height:end_height, start_width:end_width]

print("Shape after slicing:", x_original_shape.shape)  # Should be (1, 1, 135, 240)


Padded shape: torch.Size([1, 1, 256, 256])
Shape after slicing: torch.Size([1, 1, 135, 240])


In [47]:
x_original_shape

tensor([[[[ 1.0182, -0.2634,  1.2838,  ..., -0.5184,  0.2893,  0.1558],
          [-0.3295, -0.1763, -0.4001,  ..., -0.4550, -0.6738, -0.6903],
          [-0.5878, -0.4613, -0.0079,  ...,  1.0670, -1.4411,  0.4294],
          ...,
          [-0.2986,  1.8018, -0.8441,  ..., -1.4393,  1.9679,  0.3780],
          [-0.7267, -1.0729, -0.5690,  ..., -0.0377,  1.1575, -0.7883],
          [-0.6154,  0.8982, -1.1257,  ...,  0.6682,  1.5402,  0.4965]]]],
       grad_fn=<SliceBackward0>)

In [48]:
x

tensor([[[[ 1.0182, -0.2634,  1.2838,  ..., -0.5184,  0.2893,  0.1558],
          [-0.3295, -0.1763, -0.4001,  ..., -0.4550, -0.6738, -0.6903],
          [-0.5878, -0.4613, -0.0079,  ...,  1.0670, -1.4411,  0.4294],
          ...,
          [-0.2986,  1.8018, -0.8441,  ..., -1.4393,  1.9679,  0.3780],
          [-0.7267, -1.0729, -0.5690,  ..., -0.0377,  1.1575, -0.7883],
          [-0.6154,  0.8982, -1.1257,  ...,  0.6682,  1.5402,  0.4965]]]],
       requires_grad=True)