In [1]:
import torch
import numpy as np
import os

In [8]:
cd "/content/drive/MyDrive/Decomposer"

/content/drive/MyDrive/Decomposer


In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from torch import nn
import torch.nn.functional as F
import layers

In [4]:
def scaled_dot_product_attention(q,k,v,mask=None):
    """
    q : query matrix (..,len_q,depth)
    k : key matrix (..,len_k,depth)
    v : value matrix (..,len_v,depth_v) len_v == len_k
    mask : mask to be used before softmax (..,len_q,len_k)
    """
    
    qk = torch.matmul(q,torch.transpose(k,-2,-1)) # (..,len_q,len_k)
    
    # scaling the qk
    dk = torch.tensor(k.shape[-1]).float()
    scaled_qk = qk / torch.sqrt(dk)
    
    if mask is not None:
        scaled_qk += (mask * -1e9)
        
    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    
    attention_weights = torch.nn.functional.softmax(scaled_qk,dim = -1) #(..,len_q,len_k)
    
    output = torch.matmul(attention_weights,v) #(len_q,depth_v)
    
    return output,attention_weights
    
    

    
    
    

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self,d_model,num_heads,input_dims):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model #output feature dimension
        self.input_dims = input_dims
        
        #d_model is the depth of each head hence it should be divisible by number of heads
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        
        self.wq = torch.nn.Linear(input_dims,d_model)
        print("weights shape:", self.wq.weight.data.shape)
        
        self.wk = torch.nn.Linear(input_dims,d_model)
        self.wv = torch.nn.Linear(input_dims,d_model)
        
        self.projector = torch.nn.Linear(d_model,d_model)
        
    
    def split_heads(self,x,batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len(N), depth)
        """
        x = torch.reshape(x,(batch_size,-1,self.num_heads,self.depth))
        return torch.transpose(x,2,1)
    
    def forward(self,v,k,q,mask):
        batch_size = q.size()[0]
        
        q = self.wq(q) #(batch_size,seq_len(N),d_model)
        k = self.wk(k) #(batch_size,seq_len(N),d_model)
        v = self.wv(v) #(batch_size,seq_len(N),d_model)
        
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, len_v, depth)
        
        # scaled_attention.shape == (batch_size, num_heads, len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, len_q, len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        
        scaled_attention = torch.transpose(scaled_attention,2,1)  # (batch_size, len_q, num_heads, depth)
        
        concat_attention = torch.reshape(scaled_attention,(batch_size, -1, self.d_model))   # (batch_size,len_q, d_model)

        
        output = self.projector(concat_attention)  # (batch_size,len_q, d_model)

        return output, attention_weights



        

        
        

In [6]:
# temp_mha = MultiHeadAttention(d_model=512, num_heads=8,input_dims=7)
# y = torch.rand((1, 1024, 7))
# out,att = temp_mha(y,y,y,mask=None)
# out.shape,att.shape

In [7]:
class point_wise_feed_forward_network(nn.Module):
    def __init__(self,d_model,dff,input_dims):
        super(point_wise_feed_forward_network, self).__init__()
        self.d_model = d_model # output feature dimension
        self.dff = dff # intermediate feature dimension
        self.input_dims = input_dims 
        
        self.dense1 = torch.nn.Linear(self.input_dims,self.dff) # (batch_size, seq_len(N), dff)
        self.dense2 = torch.nn.Linear(self.dff,self.d_model)  # (batch_size, seq_len(N), d_model)
        
        
    def forward(self,x):
        x = F.relu(self.dense1(x))
        output = self.dense2(x)
        return output

In [8]:
# sample_ffn = point_wise_feed_forward_network(512, 2048,512)
# sample_ffn(torch.rand((64, 50, 512))).shape


In [9]:
class Encoder(nn.Module):
    def __init__(self,d_model,dff,num_heads,input_dims,if_ffn=False,if_layer_norm=False,
                if_dropout = False,dropout_p = 0.5):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.dff = dff
        self.num_heads = num_heads
        self.input_dims = input_dims
        self.if_ffn = if_ffn
        self.if_layer_norm = if_layer_norm
        self.if_dropout = if_dropout
        self.dropout_p = dropout_p

        
        self.mha = MultiHeadAttention(d_model=self.d_model, num_heads=self.num_heads,input_dims = self.input_dims)
        if self.if_ffn:
            self.ffn = point_wise_feed_forward_network(d_model = self.d_model,dff=self.dff,input_dims = self.input_dims)
        if self.if_layer_norm:
            self.layernorm1 = torch.nn.LayerNorm(eps = 1e-6)
            self.layernorm2 = torch.nn.LayerNorm(eps = 1e-6)
        if self.if_dropout:
            self.dropout1 = torch.nn.Dropout(self.dropout_p)
            self.dropout2 = torch.nn.Dropout(self.dropout_p)
        
    
    def forward(self,x,mask):
        out, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
        if self.if_dropout:
            out = self.dropout1(out)
        if self.if_layer_norm:
            out = self.layernorm1(x + out)
            
        
        if self.if_ffn:
            out = self.ffn(out)
            if self.if_dropout:
                out = self.dropout2(out)
            if self.if_layer_norm:
                out = self.layernorm2(x + out)
        return out
           

In [10]:
sample_encoder_layer = Encoder(d_model=512, num_heads=8, dff=2048,input_dims=512)

sample_encoder_layer_output = sample_encoder_layer(
    torch.rand((64, 43, 512)), None)
sample_encoder_layer_output.size()

weights shape: torch.Size([512, 512])


torch.Size([64, 43, 512])

In [11]:
class Decomposer(nn.Module):
    def __init__(self,d_model_list=[128,256,512],dff=1024,num_heads=4,input_dims=3,if_ffn=False,
                 if_layer_norm=False,if_dropout = False,dropout_p = 0.5,mlp=[16,32,64],n_neighbors=32,
                no_of_classes = 10,if_first_hd=True,first_hd_dim=64,if_group_norm=True,if_avg_pool=True):
        super(Decomposer, self).__init__()
        self.d_model_list = d_model_list
        self.dff = dff
        self.num_heads = num_heads
        self.input_dims = input_dims
        self.if_ffn = if_ffn
        self.if_layer_norm = if_layer_norm
        self.if_dropout = if_dropout
        self.dropout_p = dropout_p
        self.mlp = mlp
        self.n_neighbors = n_neighbors
        self.no_of_classes  = no_of_classes
        self.if_first_hd = if_first_hd
        self.first_hd_dim = first_hd_dim
        self.if_group_norm = if_group_norm
        self.if_avg_pool = if_avg_pool
        
        
        
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        if self.if_first_hd:
            input_dims = self.first_hd_dim + 3
        else:
            input_dims = self.input_dims 
        last_channel = input_dims 
        for out_channel in self.mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        
        self.encoders = nn.ModuleList()
        input_dims = self.mlp[-1]
        for i in range(len(self.d_model_list)):
            self.encoders.append(layers.Encoder(d_model = self.d_model_list[i],
                                         dff = self.dff,
                                         num_heads = self.num_heads,
                                         input_dims = input_dims,
                                         if_ffn = self.if_ffn,
                                        if_layer_norm = self.if_layer_norm,
                                        if_dropout = self.if_dropout,
                                        dropout_p = self.dropout_p
                                        ))
            input_dims = self.d_model_list[i]
        if self.if_first_hd:
            self.first_hd_conv = torch.nn.Conv1d(self.input_dims,self.first_hd_dim,1)
            self.first_hd_bn = nn.BatchNorm1d(self.first_hd_dim)
        self.classifier_layer = torch.nn.Conv1d(self.d_model_list[-1],self.no_of_classes,1)
        
            
            
    def forward(self,pc,mask):
        pc = pc.permute(0, 2, 1) #(B,C+d,N) --> (B,N,C+d)
        B,N,D = pc.size()
        
        xyz = pc[:,:,:3]
        print(xyz.shape)
        if D>3:
            feat = pc[:,:,3:]
        
        pc = pc.permute(0, 2, 1) #(B,C+d,N) --> (B,N,C+d)
        if self.if_first_hd:
            pc = self.first_hd_conv(pc)
            pc = self.first_hd_bn(pc)
            feat = pc.permute(0, 2, 1)
        
        
        print(pc.shape)
        print(xyz.shape)
        knn_idx = layers.knn(xyz,self.n_neighbors)
        grouped_xyz = layers.index_points(xyz,knn_idx)
        grouped_feat = layers.index_points(feat,knn_idx)
        
        if self.if_group_norm:
            grouped_xyz_norm = layers.group_norm(xyz,grouped_xyz)
        else:
            grouped_xyz_norm = grouped_xyz
        print(grouped_xyz_norm.shape)
        if D>3:
            feat_xyz = torch.cat((grouped_xyz_norm,grouped_feat),dim=3)
        else:
            feat_xyz = grouped_xyz
        
        feat_xyz = feat_xyz.permute(0,3,2,1)
        print(feat_xyz.shape)
        for i,conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            feat_xyz = F.relu(bn(conv(feat_xyz)))
        print("after pointnet",feat_xyz.shape)
        
        if self.if_avg_pool:
            feat_xyz = torch.mean(feat_xyz,dim=2)
            
        else:
            feat_xyz = torch.max(feat_xyz)
        print(feat_xyz.shape)   
        feat_xyz = feat_xyz.permute(0,2,1)
        print(feat_xyz.shape)
        for i,encoder in enumerate(self.encoders):
            feat_xyz = encoder(feat_xyz,mask)
        feat_xyz = feat_xyz.permute(0,2,1)
        output = F.log_softmax(self.classifier_layer(feat_xyz),dim=1)
        
        output = output.permute(0,2,1)
        
        return output
        
            
        
        
        
        
        
        
        

        
        
        
        
        
    

In [12]:
net = Decomposer(input_dims=7)
for i in range(5):
  pc = torch.randn((16,7,1024+i))
# from dataloader import 
  out = net(pc,None)
  out.shape

torch.Size([16, 1024, 3])
torch.Size([16, 64, 1024])
torch.Size([16, 1024, 3])
torch.Size([16, 1024, 32, 3])
torch.Size([16, 67, 32, 1024])
after pointnet torch.Size([16, 64, 32, 1024])
torch.Size([16, 64, 1024])
torch.Size([16, 1024, 64])
torch.Size([16, 1025, 3])
torch.Size([16, 64, 1025])
torch.Size([16, 1025, 3])
torch.Size([16, 1025, 32, 3])
torch.Size([16, 67, 32, 1025])
after pointnet torch.Size([16, 64, 32, 1025])
torch.Size([16, 64, 1025])
torch.Size([16, 1025, 64])
torch.Size([16, 1026, 3])
torch.Size([16, 64, 1026])
torch.Size([16, 1026, 3])
torch.Size([16, 1026, 32, 3])
torch.Size([16, 67, 32, 1026])
after pointnet torch.Size([16, 64, 32, 1026])
torch.Size([16, 64, 1026])
torch.Size([16, 1026, 64])
torch.Size([16, 1027, 3])
torch.Size([16, 64, 1027])
torch.Size([16, 1027, 3])
torch.Size([16, 1027, 32, 3])
torch.Size([16, 67, 32, 1027])
after pointnet torch.Size([16, 64, 32, 1027])
torch.Size([16, 64, 1027])
torch.Size([16, 1027, 64])
torch.Size([16, 1028, 3])
torch.Size([16