In [1]:
import numpy as np
import torch, os
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

In [2]:
class Block(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int=1, padding: int=1):
        super(Block, self).__init__()

        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self._relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self._relu(out)
        return out

In [39]:
res = 32
model_name = "pc_sample_full_res_32.pt"
output_path = "./saved_models"
model_path = os.path.join(output_path, model_name)

In [47]:
visualization = {}
def hook_fn(m, i, o):
  # print(m)
  print("------------Input Grad------------")

  for x in range(2):
    if x == 0:
      visualization['E2'] = x
    else:
      visualization['E2'] = o

  for grad in i:
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")

  print("------------Output Grad------------")
  for grad in o:  
    try:
      print(grad.shape)
      # print(grad)
    except AttributeError: 
      print ("None found for Gradient")
  print("\n")

In [48]:
class Encoder(nn.Module):
    def __init__(self, res):
        super(Encoder, self).__init__()
        
        self.E_2 = Block(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        if res == 32:
            self.E_3 = Block(in_channels=128, out_channels=256, kernel_size=4, stride=4, padding=1)
        else:
            # self.E_3 = Block(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
            self.E_3 = Block(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
    
    def forward(self, input, res, model_path):
        E_2_output = self.E_2(input)
        print("E2 output - ")
        print(E_2_output.shape)
        print("\n-----------")

        if res == 32:
            E_3_output = self.E_3(E_2_output)

            print("E3 output - ")
            print(E_3_output.shape)
            print("\n-----------")

        if res == 8:
            # x = torch.load(model_path)
            # # print(x)
            # for key in x:
            #     if key == "E_2.conv1.weight":
            #         E_2_res32 = x["E_2.conv1.weight"]
            #         print(E_2_res32.shape)
            #         break
            # ------------------------------

            prev_E2_output = visualization['E2']

            print("E_2_output - ", E_2_output.shape, ", prev_E2_output", prev_E2_output.shape, "\n")
                    
            E2_cat = torch.cat((prev_E2_output, E_2_output), dim=1)
            print("E2_cat output - ")
            print(E2_cat.shape)
            print("\n-----------")

            E_3_output = self.E_3(E2_cat)
            print("E3_output  - ", E_3_output.shape, "\n")

        return "Done"

In [50]:

model = Encoder(res)
if res == 32:
    model.E_2.register_forward_hook(hook_fn)
# print(model)

In [51]:
input = torch.rand(1, 64, 16, 16, 16)
a = model(input, res, model_path)
# print(a)
# print(a.shape)

------------Input Grad------------
torch.Size([1, 64, 16, 16, 16])
------------Output Grad------------
torch.Size([128, 8, 8, 8])


E2 output - 
torch.Size([1, 128, 8, 8, 8])

-----------
E3 output - 
torch.Size([1, 256, 2, 2, 2])

-----------


In [None]:
visualization['E2']

In [64]:
torch.save(model, model_path)
# torch.onxx.export(model, input, os.path.join(output_path, "res_32.onxx"), export_params=True)

In [67]:
# model.load_state_dict(torch.load(model_path))

x = torch.load(model_path)
# print(x)
# for key in x:
#     print(key)
#     print(x[key].shape)
for name, param in x.named_parameters():
    print(name, ':', param.shape)



E_2.conv1.weight : torch.Size([128, 64, 3, 3, 3])
E_2.conv1.bias : torch.Size([128])
E_3.conv1.weight : torch.Size([256, 128, 4, 4, 4])
E_3.conv1.bias : torch.Size([256])


### For res=32

E2 output - 
`torch.Size([1, 128, 8, 8, 8])`

E3 output - 
`torch.Size([1, 256, 2, 2, 2])`

----

### For res=8

E2 output - 
torch.Size([1, 128, 8, 8, 8])

E_2_output -  torch.Size([1, 128, 8, 8, 8]) , prev_E2_output torch.Size([1, 128, 8, 8, 8]) 

E2_cat output - 
torch.Size([1, 256, 8, 8, 8])

E3_output  -  torch.Size([1, 256, 4, 4, 4]) 



