In [15]:
import sys
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from e3nn.non_linearities.rescaled_act import sigmoid, swish, tanh
from e3nn.non_linearities.norm_activation import NormActivation
from e3nn.batchnorm import BatchNorm
from e3nn.image.convolution import Convolution


def ConvBlock(Rs_in, Rs_out, size, stride, fpix):
    """Equivariant e3nn convolution followed by a norm-gated swish activation.

    Args:
        Rs_in, Rs_out: e3nn representation lists of the input / output fields.
        size: convolution kernel size. ``padding=size // 2`` only gives
            symmetric "same"-style padding when ``size`` is odd.
        stride: convolution stride.
        fpix: forwarded to ``Convolution`` as ``fuzzy_pixels``.

    Returns:
        ``nn.Sequential(Convolution, NormActivation)``.
    """
    conv = Convolution(
        Rs_in,
        Rs_out,
        size=size,
        stride=stride,
        padding=size // 2,
        bias=None,
        fuzzy_pixels=fpix,
    )
    gate = NormActivation(Rs_out, swish, normalization='norm')
    return nn.Sequential(conv, gate)

def ConvTBlock(Rs_in, Rs_out, size, fpix):
    return nn.Sequential(
    Convolution(Rs_in, Rs_out, size=size, stride=2, padding=size//2, bias=None, output_padding=1, transpose=True, fuzzy_pixels = fpix),
    #NormActivation(Rs_out, swish, normalization = 'componenet'),
    )

In [30]:
class VNet(nn.Module):
    """Equivariant encoder/decoder assembled from ``ConvBlock`` / ``ConvTBlock``.

    Encoder: four stride-2 ``ConvBlock`` stages ``dw1``..``dw4``.
    Decoder: four ``ConvTBlock`` stages ``up1``..``up4``, each followed by a
    stride-1 ``ConvBlock`` ``cd1``..``cd4``. There are no skip connections.

    Args:
        size: kernel size forwarded to every block (not the input grid size).
        n_reps: ``(n_vec, n_ten)``, the number of L=1 and L=2 fields carried
            in the hidden representations.
        inp_channels: number of scalar (L=0) input channels; the decoder output
            uses the same scalar layout.
    """

    def __init__(self, size, n_reps, inp_channels):
        super().__init__()
        n_vec, n_ten = n_reps
        fuzzy = True  # option to add noise to conv kernels

        io_rep = [(inp_channels, 0)]
        scalars = [0] * inp_channels
        nonscalars = [1] * n_vec + [2] * n_ten
        narrow = scalars + nonscalars        # m scalars + vectors + tensors
        wide = scalars * 2 + nonscalars      # 2m scalars + vectors + tensors

        # Encoder stages, registered in order dw1..dw4.
        down_reps = [(io_rep, narrow), (narrow, wide), (wide, wide), (wide, wide)]
        for stage, (rep_in, rep_out) in enumerate(down_reps, start=1):
            setattr(self, 'dw%d' % stage,
                    ConvBlock(list(rep_in), list(rep_out), size=size, stride=2, fpix=fuzzy))

        # Decoder stages, registered as up1, cd1, up2, cd2, ...:
        # upsampling (rep_in -> rep_mid), then refinement (rep_mid -> rep_out).
        up_reps = [(wide, wide, wide), (wide, wide, wide),
                   (wide, wide, narrow), (narrow, narrow, io_rep)]
        for stage, (rep_in, rep_mid, rep_out) in enumerate(up_reps, start=1):
            setattr(self, 'up%d' % stage,
                    ConvTBlock(list(rep_in), list(rep_mid), size=size, fpix=fuzzy))
            setattr(self, 'cd%d' % stage,
                    ConvBlock(list(rep_mid), list(rep_out), size=size, stride=1, fpix=fuzzy))

    def forward(self, x):
        """Return ``(decoded, latent)``; ``latent`` is the output of ``dw4``."""
        h = x
        for stage in range(1, 5):
            h = getattr(self, 'dw%d' % stage)(h)
        latent = h
        for stage in range(1, 5):
            h = getattr(self, 'up%d' % stage)(h)
            h = getattr(self, 'cd%d' % stage)(h)
        return h, latent

In [32]:
# Smoke test: push a random batch through VNet and check the shapes.
inp_size = 16      # voxels per side of the cubic input grid
inp_channels = 11  # scalar (L=0) input channels
n_reps = [3, 1]    # n_vectors (L=1), n_tensors (L=2)
k_size = 3         # convolution kernel size (odd)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

# torch.Tensor(*shape) returns uninitialised memory; use real random data.
x = torch.randn(2, inp_size, inp_size, inp_size, inp_channels, device=device)
print("Input size: {}".format(x.size()))

# VNet's first argument is the convolution kernel size, not the grid size.
model = VNet(k_size, n_reps, inp_channels).to(device)

with torch.no_grad():
    out, latent = model(x)
print("latent size: {}".format(latent.size()))
print("output size: {}".format(out.size()))
assert out.shape == x.shape, "decoder should restore the input grid and channel count"

Input size: torch.Size([2, 16, 16, 16, 11])


TypeError: conv3d() received an invalid combination of arguments - got (Tensor, Tensor, transpose=bool, output_padding=int, bias=NoneType, padding=int, stride=int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the keywords were incorrect: transpose, output_padding
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the keywords were incorrect: transpose, output_padding


In [3]:
import torch
import torch.nn as nn
import numpy as np
import pyvista as pv
from scipy import ndimage
from e3nn import rs
from e3nn.image.convolution import Convolution

class Simple(nn.Module):
    """Toy model: one e3nn convolution from a scalar field to scalar + vector fields.

    ``Rs_in=[0]`` is a single L=0 field; ``Rs_out=[0, 1]`` is one scalar plus one
    L=1 field (4 output channels). No ``padding`` is passed, so presumably each
    spatial side shrinks by ``size - 1``. TODO: confirm against e3nn defaults.
    """

    def __init__(self, fuzzy_pixels):
        super().__init__()
        kernel_size = 3
        equivariant_conv = Convolution(
            Rs_in=[0],
            Rs_out=[0, 1],
            size=kernel_size,
            steps=(1., 1., 1.),
            fuzzy_pixels=fuzzy_pixels,
        )
        # Wrapped in a Sequential so parameter names stay ``f.0.*``.
        self.f = torch.nn.Sequential(equivariant_conv)

    def forward(self, x):
        """Apply the convolution to a channels-last ``[batch, x, y, z, channel]`` tensor."""
        return self.f(x)

def rotate(inp, rotation_angle, axes=(2, 3), order=1, mode='nearest', cval=0.0):
    """Rotate a NumPy array in the plane spanned by ``axes``, keeping its shape.

    Args:
        inp: array to rotate. For a 5-D ``[batch, channel, x, y, z]`` array the
            default ``axes=(2, 3)`` rotates in the (x, y) plane.
        rotation_angle: rotation angle in degrees.
        axes, order, mode, cval: forwarded to ``scipy.ndimage.rotate``. The
            defaults (linear interpolation, edge values replicated outside the
            grid) reproduce the original behaviour.

    Returns:
        The rotated array, with the same shape as ``inp`` (``reshape=False``).
    """
    # ``scipy.ndimage.interpolation`` is a deprecated alias namespace; use the
    # public ``scipy.ndimage.rotate`` instead.
    return ndimage.rotate(inp,
                          angle=rotation_angle,
                          axes=axes,
                          reshape=False,
                          order=order,
                          mode=mode,
                          cval=cval)


def VoxPositions(dim, res):
    """Centres of every voxel of a cubic ``dim x dim x dim`` grid with spacing ``res``.

    Rows are in C order (first index slowest, last fastest): row
    ``i*dim*dim + j*dim + k`` holds voxel ``(i, j, k)`` at
    ``(i*res + res/2, j*res + res/2, k*res + res/2)``.

    Returns:
        float64 array of shape ``(dim**3, 3)``.
    """
    # np.indices enumerates the grid in C order, matching the former nested loops.
    idx = np.indices((dim, dim, dim), dtype=float).reshape(3, -1).T
    return idx * res + res / 2

if __name__ == '__main__':

    dim = 16
    chans = 1
    torch.manual_seed(1000)

    # Input: channels-first [batch, channel, x, y, z] grid with two unit voxels;
    # the off-centre one makes the rotation visible in the rendering.
    inp = torch.zeros(1, chans, dim, dim, dim)
    inp[:, :, dim // 2, dim // 2, dim // 2] = 1.0
    inp[:, :, dim // 2, 1, dim // 2] = 1.0

    model = Simple(fuzzy_pixels=True)
    model.eval()  # loop-invariant: set the mode once, not every frame

    # plotting setup
    pl = pv.Plotter(shape=(1, 3))
    pl.open_movie('Simple.mp4')
    fs = 12  # text font size
    # Arrow anchor points, sized from the first output grid (the unpadded
    # convolution makes the output smaller than the input).
    cents = None

    try:
        for i in range(180):
            angle_deg = 2 * i
            inpR = rotate(inp.numpy(), rotation_angle=float(angle_deg))
            inpR = torch.from_numpy(inpR).float()

            # The model works channels-last [batch, x, y, z, channel]; no
            # gradients are needed for visualisation.
            with torch.no_grad():
                outR = model(inpR.permute(0, 2, 3, 4, 1))
            outR = outR.permute(0, 4, 1, 2, 3)  # back to channels-first

            OutSca = outR[0, 0].numpy()               # channel 0: scalar output
            OutVec = outR[0, 1:4].flatten(1).numpy()  # channels 1-3: L=1 output, one column per voxel

            # NOTE(review): reorders the L=1 components as [2, 0, 1]; this assumes
            # e3nn's (y, z, x) component order -> (x, y, z). Confirm for the installed e3nn.
            vec = np.array([OutVec[2], OutVec[0], OutVec[1]]).T

            if cents is None:
                cents = VoxPositions(dim=OutSca.shape[0], res=1.0)

            pl.subplot(0, 0)
            pl.add_text("Input", position='lower_left', font_size=fs)
            pl.add_text("angle = " + str(angle_deg), position='upper_left', font_size=fs)
            pl.add_volume(inpR[0, 0].numpy(), cmap="viridis_r",
                          opacity="linear", show_scalar_bar=False)

            pl.subplot(0, 1)
            pl.add_text("Out Vector", position='lower_left', font_size=fs)
            pl.add_arrows(cents, vec, mag=20, show_scalar_bar=False)

            pl.subplot(0, 2)
            pl.add_text("Output", position='lower_left', font_size=fs)
            OutSca[OutSca < 0.01] = 0.0  # hide near-zero voxels in the volume render
            pl.add_volume(OutSca, cmap="viridis_r", show_scalar_bar=False)

            if i == 0:
                pl.show(auto_close=False)

            pl.write_frame()
            pl.clear()
    finally:
        # Release the plotter (and its movie writer) even if a frame fails.
        pl.close()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  out = mul_m_lm(Rs, sha, shz)
 does not have profile information (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484809535/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:104.)
  out = mul_m_lm(Rs, sha, shz)


ViewInteractiveWidget(height=768, layout=Layout(height='auto', width='100%'), width=1024)