<a href="https://colab.research.google.com/github/vidaurridante3/pointnet/blob/main/Pointnet_outline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
# Point Net Backbone
class PointNetBackbone(nn.Module):
    """PointNet-style point-cloud encoder with transformer encoders between the shared MLPs.

    Pipeline (all tensors are batch-first, i.e. (batch, num_points, channels)):

    1. ``tencoder1``: transformer encoder over the raw point coordinates
       (``d_model=3``).
    2. Shared MLP 1: pointwise 1x1 convolutions 3 -> 64 -> 64 -> 64.
    3. ``tencoder2``: transformer encoder over the 64-d point features, after
       adding a learned embedding of each point's coordinates.
    4. Shared MLP 2: pointwise 1x1 convolutions 64 -> 64 -> 128 -> 1024.
    5. Max pool over the points (padded points excluded) to get one global
       feature vector per cloud.

    Every stage is pointwise or attention-based and the final reduction is a
    max, so the global feature does not depend on the order of the points,
    and clouds with any number of points are accepted.
    """

    def __init__(self):

        super(PointNetBackbone, self).__init__()

        # transformer encoders.  torch.nn has no ``Transformer_Encoder``; use
        # the stock encoder stack.  ``norm_first=True`` is PyTorch's spelling
        # of pre-norm (``normalize_before``); ``drop_path=0`` was a no-op and
        # has no torch equivalent, so it is dropped.
        self.tencoder1 = self._make_encoder(d_model=3)
        self.tencoder2 = self._make_encoder(d_model=64)

        # learned positional embedding: maps each point's xyz to 64 channels
        # so it can be added to the features fed into tencoder2
        self.pos_embed_feature = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )

        # shared MLP 1
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)

        # shared MLP 2
        self.conv4 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv5 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1)
        self.conv6 = nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1)

    @staticmethod
    def _make_encoder(d_model, num_layers=4, nhead=1, dropout=0.2):
        """Return a batch-first, pre-norm stack of ``num_layers`` encoder layers."""
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation='relu',
            norm_first=True,
            batch_first=True,
        )
        return nn.TransformerEncoder(layer, num_layers=num_layers)

    @staticmethod
    def _shared_mlp(features, convs):
        """Apply each 1x1 conv + ReLU in ``convs`` to every point independently.

        ``features`` is (batch, num_points, channels); Conv1d expects channels
        in dim 1, so transpose in and back out.
        """
        out = features.transpose(1, 2)
        for conv in convs:
            out = F.relu(conv(out))
        return out.transpose(1, 2)

    def forward(self, x: "NestedTensor"):
        """Encode a batch of point clouds into global feature vectors.

        Args:
            x: either a NestedTensor-like object exposing ``.tensors`` of shape
                (batch, num_points, 3) and a boolean ``.mask`` of shape
                (batch, num_points), or a plain (batch, num_points, 3) tensor
                with no padding.  The mask is used as PyTorch's
                ``src_key_padding_mask``: True marks a padded point that is
                ignored (NOTE(review): confirm the producer of ``x`` uses this
                convention).

        Returns:
            Tensor of shape (batch, 1, 1024): for each cloud, the per-channel
            max over its non-padded points.
        """
        if isinstance(x, torch.Tensor):
            data, mask = x, None
        else:
            data, mask = x.tensors, x.mask

        # first encoder over raw coordinates, then shared MLP 1
        output = self.tencoder1(data, src_key_padding_mask=mask)
        output = self._shared_mlp(output, (self.conv1, self.conv2, self.conv3))

        # add coordinate embedding, then second encoder (result must be kept)
        output = output + self.pos_embed_feature(data)
        output = self.tencoder2(output, src_key_padding_mask=mask)

        # shared MLP 2 lifts every point to 1024 channels
        output = self._shared_mlp(output, (self.conv4, self.conv5, self.conv6))

        # symmetric max pool over the points axis.  Features are post-ReLU
        # (>= 0), so zeroing padded points means they can never exceed a real
        # point's value.
        if mask is not None:
            output = output.masked_fill(mask.unsqueeze(-1), 0.0)
        return output.max(dim=1, keepdim=True).values

