In [2]:
import numpy as np
import torch.nn as nn
from s2cnn.nn.soft.so3_conv import SO3Convolution
from s2cnn.nn.soft.s2_conv import S2Convolution
from s2cnn.nn.soft.so3_integrate import so3_integrate
from s2cnn.ops.so3_localft import near_identity_grid as so3_near_identity_grid
from s2cnn.ops.s2_localft import near_identity_grid as s2_near_identity_grid
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
import gzip, pickle
import numpy as np
from torch.autograd import Variable
from torch.distributions import Normal

In [3]:
def load_data(path, batch_size):
    """Load a gzipped, pickled image dataset and wrap it for PyTorch.

    The file at ``path`` must unpickle to a mapping with ``"train"`` and
    ``"test"`` entries, each of which holds an ``"images"`` and a
    ``"labels"`` NumPy array.  Images are indexed as 3-D arrays, receive a
    singleton axis at position 1 and are cast to float32; labels are cast to
    int64.

    Args:
        path: Location of the gzip-compressed pickle file.
        batch_size: Batch size used by both returned loaders.

    Returns:
        Tuple ``(train_loader, test_loader, train_dataset, test_dataset)``.
        Both loaders shuffle their dataset.
    """

    def _as_dataset(split):
        # One conversion shared by the train and the test split.
        images = torch.from_numpy(
            split["images"][:, None, :, :].astype(np.float32))
        labels = torch.from_numpy(split["labels"].astype(np.int64))
        return data_utils.TensorDataset(images, labels)

    # SECURITY: unpickling can execute arbitrary code -- only open files
    # that come from a trusted source.
    with gzip.open(path, 'rb') as f:
        dataset = pickle.load(f)

    # The unused full-tensor mean/std computation of the earlier version was
    # dropped: its results were never read.
    train_dataset = _as_dataset(dataset["train"])
    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = _as_dataset(dataset["test"])
    test_loader = data_utils.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader, train_dataset, test_dataset

In [4]:
class MLP(nn.Module):
    """Fully connected network: Linear layers separated by an activation.

    Args:
        H: Layer sizes ``[input, hidden..., output]``; at least two entries.
            A private copy is kept in ``self.H``.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            builds the activation.  It is instantiated once and that single
            instance is placed after every Linear layer except the last.

    Raises:
        ValueError: If ``H`` has fewer than two entries.
    """

    def __init__(self ,H = [1,10,1], activation = nn.ReLU):
        super(MLP, self).__init__()
        # Copy so that later mutation of the caller's list (or of the shared
        # default list) cannot change what this instance reports in self.H.
        self.H = list(H)
        if len(self.H) < 2:
            raise ValueError(
                "H needs at least an input and an output size, got {}".format(self.H))
        self.activation = activation()
        modules = []
        # Every consecutive pair except the last one gets Linear + activation.
        for input_dim, output_dim in zip(self.H[:-2], self.H[1:-1]):
            modules.append(nn.Linear(input_dim, output_dim))
            modules.append(self.activation)
        # The output layer stays linear (no activation after it).
        modules.append(nn.Linear(self.H[-2], self.H[-1]))
        self.module = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last axis must have size ``H[0]``)."""
        return self.module(x)

In [5]:

class S2ConvNet(nn.Module):
    """Spherical CNN: one S2 convolution, SO(3) convolutions, then an MLP head.

    Args:
        f_list: Feature counts; ``f_list[0]`` feeds the S2 convolution and
            each following entry is the output feature count of one layer.
        b_list: Bandwidths, paired element-wise with ``f_list``.
        mlp_dim: Sizes of the MLP layers that follow the convolutions.  The
            MLP input size is computed here and prepended.
        activation: Zero-argument callable building an activation module; a
            fresh instance follows every convolution and the class is also
            handed to the MLP.
        integrate: When True, the last feature map is reduced with
            ``so3_integrate`` before the MLP, so the MLP input size is
            ``f_list[-1]``.  When False (default) the feature map is
            flattened and the MLP input size is
            ``f_list[-1] * (2 * b_list[-1]) ** 3``.
    """

    def __init__(self, f_list = [1,10,10], b_list = [30,5,6], mlp_dim = [10],
                 activation = nn.ReLU, integrate = False):
        super(S2ConvNet, self).__init__()
        grid_s2 = s2_near_identity_grid()
        grid_so3 = so3_near_identity_grid()

        # Private copies: the defaults are shared mutable lists.
        self.f_list = list(f_list)
        self.b_list = list(b_list)
        self.integrate = integrate

        if integrate:
            # so3_integrate leaves one value per feature channel.
            mlp_input_dim = self.f_list[-1]
        else:
            # Flattened size of a (features, 2b, 2b, 2b) feature map.
            mlp_input_dim = self.f_list[-1] * (self.b_list[-1] * 2) ** 3
        self.mlp_dim = [mlp_input_dim] + list(mlp_dim)
        print(self.mlp_dim)
        self.activation = activation

        modules = []
        conv1 = S2Convolution(
            nfeature_in=self.f_list[0],
            nfeature_out=self.f_list[1],
            b_in=self.b_list[0],
            b_out=self.b_list[1],
            grid=grid_s2)
        modules.append(conv1)
        modules.append(self.activation())

        # Remaining (feature, bandwidth) transitions are SO(3) convolutions.
        for f_in, f_out, b_in, b_out in zip(self.f_list[1:-1], self.f_list[2:],
                                            self.b_list[1:-1], self.b_list[2:]):
            conv = SO3Convolution(
                nfeature_in=f_in,
                nfeature_out=f_out,
                b_in=b_in,
                b_out=b_out,
                grid=grid_so3)
            modules.append(conv)
            modules.append(self.activation())

        self.conv_module = nn.Sequential(*modules)

        self.mlp_module = MLP(H = self.mlp_dim, activation = self.activation)

    def forward(self, x):
        """Run the convolutions, reduce or flatten the result, apply the MLP.

        Assumes ``x`` is laid out the way ``S2Convolution`` expects for
        ``f_list[0]`` features at bandwidth ``b_list[0]`` -- TODO confirm
        against the s2cnn documentation.
        """
        x = self.conv_module(x)

        if self.integrate:
            x = so3_integrate(x)
        else:
            # .contiguous() keeps .view valid whatever memory layout the
            # last convolution/activation produced.
            x = x.contiguous().view(-1, self.mlp_dim[0])

        return self.mlp_module(x)
#a = S2ConvNet()

In [10]:
class S2DeconvNet(nn.Module):
    """Decoder: an MLP followed by a stack of SO(3) convolutions.

    The MLP output is reshaped to ``(..., f_list[0], 2b, 2b, 2b)`` with
    ``b = b_list[0]`` and then passed through one
    ``activation -> [Upsample] -> SO3Convolution`` stage per consecutive
    pair in ``f_list`` / ``b_list``.

    Args:
        f_list: Feature counts before the first stage and after every stage.
        b_list: Bandwidths, paired element-wise with ``f_list``.
        mlp_dim: MLP layer sizes starting with the input size; the size of
            the first feature map, ``f_list[0] * (2 * b_list[0]) ** 3``, is
            appended as the MLP output size.
        activation: Zero-argument callable building an activation module; a
            fresh instance precedes every convolution and the class is also
            handed to the MLP.
    """

    def __init__(self, f_list = [10,10,1], b_list = [5,10,30], mlp_dim = [10], activation = nn.ReLU):
        super(S2DeconvNet, self).__init__()
        grid_so3 = so3_near_identity_grid()

        # Private copies: the defaults are shared mutable lists.
        self.f_list = list(f_list)
        self.b_list = list(b_list)
        self.mlp_dim = list(mlp_dim)
        self.mlp_dim.append(self.f_list[0] * (self.b_list[0] * 2) ** 3)
        print(self.mlp_dim)
        self.activation = activation

        self.mlp_module = MLP(H = self.mlp_dim, activation = self.activation)

        modules = []
        for f_in, f_out, b_in, b_out in zip(self.f_list[:-1], self.f_list[1:],
                                            self.b_list[:-1], self.b_list[1:]):
            modules.append(self.activation())

            if b_in < b_out:
                # Grow the grid to 2*b_out samples per axis first; the
                # convolution then sees an input of bandwidth b_out.
                modules.append(torch.nn.Upsample(size=b_out*2, mode='nearest'))
                conv_b_in = b_out
            else:
                # No resampling happened, so the convolution must be told
                # the real input bandwidth (previously b_out was passed here
                # as well, which mismatches the input whenever b_in > b_out).
                conv_b_in = b_in

            conv = SO3Convolution(
                nfeature_in=f_in,
                nfeature_out=f_out,
                b_in=conv_b_in,
                b_out=b_out,
                grid=grid_so3)
            modules.append(conv)

        self.conv_module = nn.Sequential(*modules)

    def forward(self, x):
        """Decode ``x`` (last axis of size ``mlp_dim[0]``) into a feature map."""
        x = self.mlp_module(x)
        # Keep every leading axis and unfold the last one into
        # (features, 2b, 2b, 2b).
        lead_shape = x.size()[:-1]
        side = self.b_list[0] * 2
        x = x.view(*lead_shape, self.f_list[0], side, side, side)

        return self.conv_module(x)
    
#d = S2DeconvNet()

In [20]:
# Experiment configuration.
DEVICE_ID = 0
MNIST_PATH = "./mnist_example/s2_mnist.gz"
NUM_EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 5e-3

train_loader, test_loader, train_dataset, _ = load_data(
        MNIST_PATH, BATCH_SIZE)

# Selecting a CUDA device raises on CPU-only machines, so it is guarded the
# same way as every other CUDA call in this cell.
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(DEVICE_ID)

classifier = S2ConvNet()

print("#params", sum(p.numel() for p in classifier.parameters()))

criterion = nn.CrossEntropyLoss()
if use_cuda:
    classifier.cuda(DEVICE_ID)
    criterion = criterion.cuda(DEVICE_ID)

# The optimizer is created after the model has been moved to its device.
optimizer = torch.optim.Adam(
    classifier.parameters(),
    lr=LEARNING_RATE)

[17280, 10]
#params 180270


In [14]:
def _as_number(value):
    """Return a plain Python number for a one-element tensor/Variable.

    ``loss.data[0]`` only works on PyTorch 0.3; from 0.4 on the loss is a
    0-dim tensor and must be read with ``.item()``.  Values that are already
    plain numbers are returned unchanged.
    """
    if hasattr(value, "item"):
        return value.item()
    if hasattr(value, "data"):
        return value.data[0]
    return value


# The loader knows the true number of batches, including a final partial one.
num_batches = len(train_loader)

for epoch in range(NUM_EPOCHS):
    # ---- one pass over the training set ----
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images)
        labels = Variable(labels)

        if torch.cuda.is_available():
            images = images.cuda(DEVICE_ID)
            labels = labels.cuda(DEVICE_ID)

        optimizer.zero_grad()
        outputs = classifier(images)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        # '\r' + end="" rewrites a single progress line in place.
        print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
           epoch+1, NUM_EPOCHS, i+1, num_batches,
           _as_number(loss)), end="")
    print("")

    # ---- accuracy on the test set ----
    correct = 0
    total = 0
    for images, labels in test_loader:
        # volatile=True disables graph construction on PyTorch 0.3; newer
        # versions ignore it with a warning.
        images = Variable(images, volatile=True)
        if torch.cuda.is_available():
            images = images.cuda(DEVICE_ID)
            labels = labels.cuda(DEVICE_ID)

        outputs = classifier(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # Accumulate a Python number so the percentage below is a plain float.
        correct += _as_number((predicted == labels).sum())

    print('Test Accuracy: {0}'.format(100 * correct / total))

AssertionError: 

In [11]:
# Build the decoder with its default configuration and move it to the GPU
# only when one is present, matching the guarded CUDA usage elsewhere.
d = S2DeconvNet()
if torch.cuda.is_available():
    d = d.cuda()

[10, 10000]


In [12]:
# Six 10-dimensional latent vectors drawn i.i.d. from a standard normal.
# torch.randn replaces the deprecated Normal(...).sample_n(6) and the
# NumPy round-trip used to build the zero/one parameter tensors.
x = Variable(torch.randn(6, 10))
if torch.cuda.is_available():
    x = x.cuda()
x.size()

torch.Size([6, 10])

In [13]:
# Run the latent batch `x` through the decoder `d` and display the output size.
d(x).size()

torch.Size([6, 1, 60, 60, 60])

In [1]:
# Display the decoder; `d` must already exist in the kernel, i.e. the cell
# that assigns `d = S2DeconvNet()` has to be executed before this one.
d


NameError: name 'd' is not defined