In [2]:
import torch
import numpy as np

In [3]:
from spherenet import SphereConv2D, SphereMaxPool2D
from torch import nn
import torch.nn.functional as F

class SphereNet(nn.Module):
    """Convolutional feature extractor built from spherical convolutions.

    Thirteen ``SphereConv2D`` layers are grouped into five blocks of 2, 2, 3,
    3 and 3 layers whose output widths are 64, 128, 256, 512 and 512 channels.
    Blocks 1-4 are each followed by one shared stride-2 ``SphereMaxPool2D``
    and a ReLU; the output of block 5 is returned unchanged.

    No classifier head is attached: the final pool / flatten / 
    ``nn.Linear(14400, 10)`` steps are disabled, so ``forward`` yields a
    feature map rather than class scores.

    NOTE(review): there is no activation between the convolutions inside a
    block -- confirm that this is intentional.
    """

    def __init__(self):
        super().__init__()

        # Block 1: single-channel input -> 64 channels.
        self.conv1_1 = SphereConv2D(1, 64, stride=1)
        self.conv1_2 = SphereConv2D(64, 64, stride=1)

        # Block 2: 64 -> 128 channels.
        self.conv2_1 = SphereConv2D(64, 128, stride=1)
        self.conv2_2 = SphereConv2D(128, 128, stride=1)

        # Block 3: 128 -> 256 channels.
        self.conv3_1 = SphereConv2D(128, 256, stride=1)
        self.conv3_2 = SphereConv2D(256, 256, stride=1)
        self.conv3_3 = SphereConv2D(256, 256, stride=1)

        # Block 4: 256 -> 512 channels.
        self.conv4_1 = SphereConv2D(256, 512, stride=1)
        self.conv4_2 = SphereConv2D(512, 512, stride=1)
        self.conv4_3 = SphereConv2D(512, 512, stride=1)

        # Block 5: stays at 512 channels.
        self.conv5_1 = SphereConv2D(512, 512, stride=1)
        self.conv5_2 = SphereConv2D(512, 512, stride=1)
        self.conv5_3 = SphereConv2D(512, 512, stride=1)

        # One pooling module is reused after each of the first four blocks.
        self.pool = SphereMaxPool2D(stride=2)

    def forward(self, x):
        """Run ``x`` through the five conv blocks and return the feature map.

        Args:
            x: input tensor; the first convolution is built for 1 input
                channel.

        Returns:
            The output of ``conv5_3`` (512 channels), after four pool + ReLU
            stages applied to the earlier blocks.
        """
        pooled_blocks = (
            (self.conv1_1, self.conv1_2),
            (self.conv2_1, self.conv2_2),
            (self.conv3_1, self.conv3_2, self.conv3_3),
            (self.conv4_1, self.conv4_2, self.conv4_3),
        )
        for block in pooled_blocks:
            for conv in block:
                x = conv(x)
            # Pool first, then rectify -- same order for every pooled block.
            x = F.relu(self.pool(x))

        # The last block is neither pooled nor rectified.
        for conv in (self.conv5_1, self.conv5_2, self.conv5_3):
            x = conv(x)
        return x

In [4]:
from PIL import Image
import torchvision.transforms.functional as tfun

def rgb2gray(rgb):
    """Collapse an RGB(A) array to a single grayscale channel.

    The first three entries along the last axis are combined as
    ``0.2989 * R + 0.5870 * G + 0.1140 * B``; any further channels (for
    example alpha) are ignored.

    Args:
        rgb: array whose last axis holds at least the R, G and B values.

    Returns:
        Array with the last axis removed, holding the weighted sum.
    """
    luma_weights = [0.2989, 0.5870, 0.1140]
    colour_planes = rgb[..., :3]
    return np.dot(colour_planes, luma_weights)

# Load one equirectangular RGB frame, shrink it to 224x224, convert it to
# grayscale and wrap it as a (batch, channel, height, width) tensor.
path = 'camera_04a287849657478ea774727e5bff5202_office_3_frame_equirectangular_domain_rgb.png'
image = Image.open(path)
image = image.resize((224, 224))

pixels = np.asarray(image, dtype=np.float32)
# NOTE(review): rgb2gray multiplies by Python-float weights, so its result is
# presumably float64 even though `pixels` is float32 -- confirm the model
# accepts that dtype.
data = tfun.to_tensor(rgb2gray(pixels))
data.unsqueeze_(0)  # in place: prepend the batch axis
print(data.shape)

torch.Size([1, 1, 224, 224])


In [5]:
# Instantiate the network on the CPU and run the prepared image through it once.
# NOTE(review): autograd is left enabled for this forward pass; consider
# torch.no_grad() if the output is only being inspected.
device = torch.device('cpu')
spheremodel = SphereNet()
spheremodel = spheremodel.to(device)
out = spheremodel(data.to(device))

  new_theta = theta + arctan(x*sin(v) / (rho*cos(phi)*cos(v) - y*sin(phi)*sin(v)))


In [6]:
# Inspect the feature-map shape: four stride-2 pools take the 224-pixel input
# down to 14 per side, and the last conv block emits 512 channels.
out.shape

torch.Size([1, 512, 14, 14])

In [7]:
# List every tensor stored in the saved checkpoint together with its shape.
#
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only open checkpoints that come from a trusted source.
#
# map_location='cpu' lets the file load on a CPU-only machine even when the
# checkpoint was written from a GPU; only the shapes are needed here.
state = torch.load('sphere_model.pkl', map_location='cpu')
for k in sorted(state):
    v = state[k]
    print(k, v.shape)
# NOTE(review): the keys in this checkpoint (conv1, conv2, fc) do not match
# the layer names of the SphereNet class defined in this notebook
# (conv1_1 ... conv5_3), so it cannot be loaded into that model as-is.

conv1.bias torch.Size([32])
conv1.weight torch.Size([32, 1, 3, 3])
conv2.bias torch.Size([64])
conv2.weight torch.Size([64, 32, 3, 3])
fc.bias torch.Size([10])
fc.weight torch.Size([10, 14400])
