In [1]:
import pandas as pd
import pickle
import numpy as np
from Constants import Const
import torch
import torch.nn as nn
import cv2
import glob
import json

In [2]:
# import sparseconvnet as scn
import open3d as o3d
import open3d.ml.torch as ml3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


2023-05-12 08:36:26.310429: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-12 08:36:26.362721: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def load_pointclouds(pc_dir='../data/',max_num = 10000,bad_ids=Const.bad_ids):
    files = glob.glob(pc_dir+'pclouds_*.json')
    all_files = []
    for file in files:
        with open(file,'r') as f:
            test = json.load(f)
        if int(test['patient_id']) not in bad_ids:
            all_files.append(test)
        if len(all_files) > max_num:
            break
    return all_files
# pointclouds = load_pointclouds()
# len(pointclouds)

In [6]:
def get_mdasi_vector(pid,mdasi,symptoms  = None,threshold=5,confounders = None):
    if symptoms is None:
        symptoms = ['drymouth','pain','swallow','choke','voice','mucus','mucositis']
    if confounders is None:
        confounders = ['hpv']
    entry = mdasi['data'].get(str(pid))
    results = [entry[s+'_6M'] for s in symptoms]
    if threshold is not None:
        results = [int(r >= threshold) for r in results]
    else:
        results = [r/10 for r in results]
    for confounder in confounders:
        results.append(entry[confounder])
    return np.array(results).astype(bool)

def extract_dist(patient_id,dists,organs=None,gtvs=['gtv','gtvn']):
    if organs is None:
        organs = dists['colOrder']
    colOrder = dists['colOrder']
    oindexes = [colOrder.index(o) for o in organs]
    patient = dists['distances'][str(patient_id)]
    d = [[p[i] for i in oindexes]for ii,p in enumerate(patient) if dists['rowOrder'][ii] in gtvs]
    d = np.array(d).min(axis=0)
    return d

def downsample_pointcloud(pcloud,k=1,max_points = 5000,pad=True):
    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(pcloud['coordinates'])
    #save dose as rgb color for when they're discretized
    pc.colors = o3d.utility.Vector3dVector(np.stack([pcloud['dose_values'] for c in range(3)],axis=-1))
    pc = pc.voxel_down_sample(k)#to make it uniform for the convolution
    if len(pc.points) > max_points:
        pc = pc.farthest_point_down_sample(max_points)
    points = np.asarray(pc.points)
    colors = np.asarray(pc.colors)[:,0]
    if points.shape[0] < max_points and pad:
        diff = max_points -points.shape[0] 
        padd = np.zeros((diff,3))
        points = np.concatenate([points,padd],axis=0)
        colors = np.concatenate([colors,padd[:,0]],axis=0)
    return {'coordinates': points,'dose_values': colors}

def get_patient_stuff(pentry,
                dists_data,
                mdasi_data,
                organs = Const.organ_list,
                gtvs=['gtv','gtvn'],
                downsample_k=None,
                max_points=2000,
                pad=True,
                symptoms=None,
                torchify=True,
                requires_grad=True,
               ):
    #processes patient, returns a list of [pointcloud coordinates,pointcloud dose values, distances, mdasi_vector]
    #pointclouds and distacnes are in the predefined order, distacnes is min gtv-organ distances currently
    #pointcloud stuff is a list of tensors (one for each organ)
    cpc = pentry['contour_pointclouds']
    pid = pentry['patient_id']
    distances = extract_dist(pid,dists_data,organs=organs)
    mdasi_stuff = get_mdasi_vector(pid,mdasi_data,symptoms=symptoms)
    padsize = 10 if not pad else max_points
    placeholder = {'coordinates': np.zeros((padsize,3)),'dose_values': np.zeros((padsize,1))}
    coords = []
    values = []
    max_val = 0
    for o in gtvs + organs:
        vals = cpc.get(o)
        if vals is not None and (downsample_k is not None or vals['coordinates'].shape[0] > max_points or pad):
            mpoints = max_points if vals in organs else 10000000 #don't downsample tumors
            vals = downsample_pointcloud(vals,downsample_k,mpoints,pad)
        else:
            vals = placeholder
        coords.append(vals['coordinates'])
        values.append(vals['dose_values'].reshape(-1,1))
        if o in gtvs:
            max_val = max(max_val,np.max(vals['dose_values']))
    dose_scale = 100
    if max_val > 13000:
        dose_scale = 1000
    if max_val > 130000:
        dose_scale = 10000
    values = [v/dose_scale for v in values]
    if torchify:
        coords = [torch.FloatTensor(c) for c in coords]
        values = [torch.FloatTensor(v) for v in values]
        if pad:
            coords = torch.stack(coords)
            values = torch.stack(values)
#             coords.requires_grad_(requires_grad)
#             values.requires_grad_(requires_grad)
        distances = torch.FloatTensor(distances)
#         distances.requires_grad_(requires_grad)
        mdasi_stuff = torch.LongTensor(mdasi_stuff) #change if I ever do continuous mdasi features
    output = [coords,values,distances,mdasi_stuff]
    return output

In [None]:
def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'

class DicomDataset(torch.utils.data.Dataset):
    
    def __init__(self,
                 pc_files=None,#list of precompute patient dose pointclouds
                 dist_json=None, #file with tumor-organ distances for each patient
                 mdasi_json=None, #symptoms
                 bad_ids = Const.bad_ids,#skip these
                 dist_file = None,#if dist json is NOne, where to load from
                 mdasi_file = None,#if mdasi_json is none, where to load
                 symptoms =  ['drymouth','pain','choke','mucus'],
                 confounders=[],
                 organs = None,
                 gtvs=None,
                 max_pc_points=1000,
                 voxel_size=1,
                 pad = False,
                 shuffle_on_init=True,
                 outcome_weight_scale=.5,
                 requires_grad=True,
                ):
        self.symptoms = symptoms
        self.n_symptoms = len(symptoms)
        self.max_pc_points = max_pc_points
        self.voxel_size = voxel_size
        
        if organs is None:
            organs = Const.organ_list[:]
        if gtvs is None:
            gtvs = ['gtv','gtvn']
        self.organs = organs
        self.gtvs = gtvs
        self.confounders = confounders
        self.pad = pad
        
        if pc_files is None:
            pc_files = load_pointclouds()
    
        if dist_json is None:
            if dist_file is None:
                dist_file = Const.small_dist_json
            with open(dist_file,'r') as f:
                dist_json = json.load(f)
        
        if mdasi_json is None:
            if mdasi_file is None:
                mdasi_file = '../data/dicom_mdasi.json'
            with open(mdasi_file,'r') as f:
                mdasi_json = json.load(f)
        
        self.coords = []
        self.values = []
        self.distances = []
        self.mdasi = []
        for file in pc_files:
            [c,v,d,m] = get_patient_stuff(file,
                                                               dist_json,
                                                               mdasi_json,
                                                               pad=self.pad,
                                                               organs=self.organs,
                                                               gtvs=self.gtvs,
                                                               downsample_k = voxel_size,
                                                               max_points = self.max_pc_points,
                                                               symptoms=self.symptoms,
                                                              requires_grad=requires_grad,
                                                              )
            self.coords.append(c)
            self.values.append(v)
            self.distances.append(d)
            self.mdasi.append(m)
        self.normalizer = self.get_normalizer(self.values,self.gtvs+self.organs) 
        if shuffle_on_init:
            self.shuffle()
        else:
            self.distances = torch.stack(self.distances)
            self.mdasi = torch.stack(self.mdasi)

        frequencies = self.mdasi.sum(axis=0)/self.mdasi.shape[0]
        outcome_weights = 1/frequencies
        outcome_weights = outcome_weights/outcome_weights.max()
        self.outcome_weights = outcome_weights**outcome_weight_scale
        self.require_grad = requires_grad
            
    def shuffle(self):
        temp = list(zip(self.coords,self.values,self.distances,self.mdasi))
        np.random.shuffle(temp)
        c,v,d,m = zip(*temp)
        self.coords = list(c)
        self.values = list(v)
        self.distances = list(d)
        self.mdasi = list(m)
        self.distances = torch.stack(self.distances)
        self.mdasi = torch.stack(self.mdasi)
        
    def get_normalizer(self,val_lists,organs):
        allvals = []
        for patient in val_lists:
            for l in patient:
                if len(l) > 2:
                    allvals.append(l)
        allvals = torch.concat(allvals)
        normalize = lambda x: (x - allvals.mean())/allvals.std()
        return normalize
    
    def __len__(self):
            return self.mdasi.shape[0]
        
    def __getitem__(self,idx):
        c = self.coords[idx]
        v = [self.normalizer(vv) for vv in self.values[idx]]
        d = self.distances[idx]
        m = self.mdasi[idx]
        return [c,v,d], m
    
def dicom_tt_split(all_files=None,test_ratio=.3,shuffle_split=False,shuffle_on_epoch = True,batch_size=10,**kwargs):
    if all_files is None:
        all_files = load_pointclouds()
    if shuffle_split:
        np.random.shuffle(all_files)
    test_index = int(len(all_files)*test_ratio)
    test_pcs = all_files[0:test_index]
    train_pcs = all_files[test_index:]
    train_dl = DicomDataset(pc_files=train_pcs,**kwargs)
    test_dl = DicomDataset(pc_files=test_pcs,**kwargs)
    device = get_device()
    def collate(batch):
        coords = []
        vals = []
        dists = []
        mdasi = []
        for [[c,v,d],m] in batch:
            coords.append([cc.to(device) for cc in c])
            vals.append([vv.to(device) for vv in v])
            dists.append(d) #first value is # of channels, currentl assume only dose
            mdasi.append(m)
        ybatch = torch.stack(mdasi).to(device)
        dists = torch.stack(dists).to(device)
        xbatch = [coords,vals,dists]
        return xbatch,ybatch
    train_dl = torch.utils.data.DataLoader(train_dl,
                                           batch_size=batch_size,
                                           shuffle=shuffle_on_epoch,
                                           drop_last=False,
                                           num_workers=3,
                                           pin_memory=True,
                                           collate_fn=collate)
    test_dl = torch.utils.data.DataLoader(test_dl,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=False,
                                          num_workers=3,
                                           pin_memory=True,
                                          collate_fn=collate)
    return train_dl, test_dl

[train_data,test_data] = dicom_tt_split()
next(iter(train_data))

In [8]:
class SparseConvBlock(nn.Module):
    #runs a batch of convolutions on a single pointcloud
    def __init__(self,
                 input_features,
                 dims=3,
                 filter_sizes=[32,32],
                 kernel_sizes=[20,10],
                 voxel_size=2,
                 continuous=True,
                ):
        nn.Module.__init__(self)
        self.dims = dims
        self.features=input_features
        self.blockBatchNorm = nn.BatchNorm1d(filter_sizes[-1])
        inputs = [input_features] + filter_sizes
        makeConv = ml3d.layers.ContinuousConv
        if not continuous:
            makeConv = ml3d.layers.SparseConv
        self.convolutions = nn.ModuleList([makeConv(inputs[i],filter_sizes[i],[kernel_sizes[i] for k in range(dims)]) for i in range(len(filter_sizes))])
        self.voxel_pool = ml3d.layers.VoxelPooling(position_fn='center',feature_fn='max')
        #for tracking organ positions
        self.voxel_pool_nn = ml3d.layers.VoxelPooling(position_fn='center',feature_fn='nearest_neighbor')
        self.voxel_size=voxel_size
        
    def forward(self,positions,features):
        x = self.convolutions[0](features,positions,positions,1)
        for conv in self.convolutions[1:]:
            x = conv(x,positions,positions,1)
#         x = self.blockBatchNorm(x.transpose(1,2)).transpose(1,2)
#         positions, x = self.voxel_pool(positions,x,self.voxel_size)
        return positions,x

class SparseCNN(nn.Module):
    #full cnn for a single pointcloud
    def __init__(self,
                 input_features,
                 dims=3,
                 filter_sizes=[[32,32],[32,32]],
                 kernel_sizes=[[20,10],[20,10]],
                 continuous=True,
                 pool_size = 2,
                ):
        nn.Module.__init__(self)
        self.dims = dims
        self.features=input_features
        curr_size = input_features
        self.blocks = []
        self.inputNorm = nn.BatchNorm1d(input_features)
        curr_pool_size = pool_size
        for [filterblock, kernelblock] in zip(filter_sizes,kernel_sizes):
            block = SparseConvBlock(curr_size,dims,filter_sizes=filterblock,kernel_sizes=kernelblock,continuous=continuous,voxel_size=curr_pool_size)
            self.blocks.append( block)
            curr_size = filterblock[-1]
            curr_pool_size += 1
        self.blocks = nn.ModuleList(self.blocks)
        
    def voxel_pool_nn(self,positions,features):
        #for when I do a convolution and need to keep track of organ positions
#         for block in self.blocks:
#             positions,features = block.voxel_pool_nn(positions,features,block.voxel_size)
        return positions, features
        
    def forward(self,positions,features):
        x = features#self.inputNorm(features)
        for block in self.blocks:
            positions,x = block(positions,x)
        return positions,x
    
class OrganNet(nn.Module):
    
    def __init__(self,
                 n_classes=8,
                 n_organs=50,
                 dist_dims=[1000],
                 fc_dims=[500,500],
                 fc_dropout = .5,
                 sparse_cnn=None,
                 organ_dropout=.5,
                 **cnn_args
                ):
        super(OrganNet, self).__init__()
        if sparse_cnn is None:
            sparse_cnn = SparseCNN(1,*cnn_args)
        self.cnn = sparse_cnn
        self.dist_layers = nn.ModuleList([nn.LazyLinear(dd) for dd in dist_dims])
        self.fc_layers = nn.ModuleList([nn.LazyLinear(dd) for dd in fc_dims])
        self.fc_dropout = nn.Dropout(p=fc_dropout)
#         self.full_dropout = torch.nn.Dropout(1)
        self.organ_dropout = organ_dropout
        self.activation = nn.ReLU()
        self.final_linear = nn.Linear(fc_dims[-1],n_classes)
        self.softmax = nn.Softmax()
        self.n_organs= 50
        self.dummy_param = nn.Parameter(torch.empty(0))
        
    def forward(self,x,train=False):
        [positions,features,distances] = x
#         values = torch.zeros((distances.shape[0],self.n_organs),requires_grad=True)
        values = []
        for row,(plist, flist) in enumerate(zip(positions,features)): 
            vrow = []
            prow = []
            organ_indices=[]
            #basically for each patient there is a list of pointclouds and a list of pointcloud features
            #one for each organ. assumes in correct order with missing organs as an empty item
            #does a convolution block on all points, then max-pools points for each individual organ, and concatenates the results
            #final vector is (n_organs*features in last cnn block)
            for col,(p,f) in enumerate(zip(plist,flist)):
                if f.shape[0] < 2:
                    f = torch.zeros((3,1)).to(self.dummy_param.device)
                    p = torch.zeros((3,3)).to(self.dummy_param.device)
                #stacks feature s and checks what position they're in (corresponds to organ)
                #randomly set organs to zero during training for dropout
                if self.train and self.organ_dropout > 0 and torch.rand(1).item() < self.organ_dropout:
                    f = torch.zeros(f.shape).to(self.dummy_param.device)
                vrow.append(f)
                prow.append(p)
                col_idx = torch.zeros(f.shape).to(self.dummy_param.device)
                col_idx[:] = col
                organ_indices.append(col_idx)
            prow = torch.concatenate(prow)
            vrow = torch.concatenate(vrow)
            organ_indices = torch.concatenate(organ_indices)
            new_positions,new_features = self.cnn(prow,vrow)
            #cnn should have a helper function to do pooling using a nearest neighbors approach to track organ membership
            _, organ_indices = self.cnn.voxel_pool_nn(prow,organ_indices)
            organ_indices = organ_indices.view(-1)
            pooled_values = []
            #max pool organs seperately
            for oi in range(self.n_organs):
                idx = torch.argwhere(organ_indices == oi)
                #we lose some organs in the voxel pooling stage
                if idx.shape[0] > 0:
                    ovals = torch.max(new_features[idx],axis=0)[0]
                    pooled_values.append(ovals.view(-1))
                else:
                    pooled_values.append(torch.zeros(new_features.shape[-1]).to(self.dummy_param.device))
            pooled_values = torch.stack(pooled_values)
            values.append(pooled_values.view(-1))
        #batch size x (n_organs*n_featuers in last cnn block)
        values = torch.stack(values)
        #fully connected layers of just intra-organ distances
        dx = self.dist_layers[0](distances)
        for layer in self.dist_layers[1:]:
            dx = layer(dx)
            dx = self.activation(dx)
        #combined convolution and distances, then fully connected layers
        x = torch.cat((dx,values),axis=1)
        for layer in self.fc_layers:
            x = layer(x)
            x = self.activation(x)
        if train:
            x = self.fc_dropout(x)
        x = self.final_linear(x)
        x = self.activation(x)
        x = self.softmax(x)
        return x

test_data.dataset.mdasi.shape[1]

5

In [9]:
test_data.dataset.symptoms

['drymouth', 'pain', 'choke', 'mucus']

In [10]:
from sklearn.metrics import roc_auc_score , f1_score

def multi_bce_loss(ypred,target,weights=[1,2,.4]):
    nclasses = ypred.shape[1]
    bce = nn.BCELoss()
    total_loss = 0
    for i in range(nclasses):
        closs = bce(ypred[:,i],target[:,i].type(torch.FloatTensor))
        total_loss += weights[i]*closs.item()
        print(closs.item())
    return total_loss

def multiclass_metrics(ypred,target):
    nclasses = ypred.shape[1]
    ypred = ypred.detach().numpy().astype(float)
    target = target.detach().numpy().astype(float)
    aucs = []
    f1s = []
    for i in range(nclasses):
        if target[:,i].std() < .0001:
            auc_score = -1
            f1_scores = -1
        else:
            auc_score = roc_auc_score(target[:,i],ypred[:,i])
            f1_scores = f1_score(target[:,i],ypred[:].argmax(axis=1) == i)
        aucs.append(auc_score)
        f1s.append(f1_scores)
    return {'auc': aucs, 'f1': f1s}
multiclass_metrics(torch.rand(4,3),torch.LongTensor([[0,1,1],[1,0,0],[0,0,0],[1,1,1]]))

{'auc': [0.5, 0.25, 0.0], 'f1': [0.0, 0.0, 0.5]}

In [11]:
torch.concatenate([torch.rand(4,3),torch.rand(5,3)]).shape

torch.Size([9, 3])

In [12]:
def run_model(model=None,train_data=None,test_data=None,pc_files=None,batch_size=20,epochs = 100,lr=.000001,patience=5):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if train_data is None or test_data is None:
        train_data, test_data = dicom_tt_split(all_files=pc_files,batch_size=batch_size)
    n_classes = train_data.dataset.mdasi.shape[1]
    loss_weights = train_data.dataset.outcome_weights
    
    if model is None:
        model = OrganNet(n_classes= n_classes)
    model = model.to(device)
#     optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    optimizer = torch.optim.SGD(model.parameters(),lr=lr)
    bce = nn.BCELoss()

    def loss_fn(ypred,target):
        total_loss = 0
#         ypred = ypred.to(model.dummy_param.device)
#         target = target.to(model.dummy_param.device)
        target=target.type(torch.FloatTensor).to(device)
        for i in range(n_classes):
            temp = bce(ypred[:,i],target[:,i])
            total_loss += loss_weights[i]*temp
        return total_loss
        
    curr_loss_train = 0
    curr_loss_val = 0
    best_val_loss = 0
    steps_since_improvement  = 0
    steps_since_improvement
    
    def train_loop():
        running_loss = 0
        model.train(True)
        steps = 0
        for xbatch, ybatch in train_data:
            ypred = model(xbatch)
            loss = loss_fn(ypred,ybatch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            print('train loss',loss,end='\r')
        return running_loss/len(train_data.dataset)
    
    def val_loop():
        running_loss = 0
        model.train(False)
        steps = 0
        pred = []
        target = []
        for xbatch, ybatch in test_data:
            ypred = model(xbatch)
            pred.append(ypred)
            target.append(ybatch)
            loss = loss_fn(ypred,ybatch)
            running_loss += loss.item()
            print('val loss',loss,end='\r')
        target = torch.concatenate(target)
        pred = torch.concatenate(pred)
        return target, pred, running_loss/len(test_data.dataset)
    
    for i in range(epochs):
        curr_loss_train = train_loop()
        print('train loss',curr_loss_train)
        ytrue, ypred, val_loss = val_loop()
        print('val_loss',val_loss)
        metrics = multiclass_metrics(ypred.detach().cpu(),ytrue.detach().cpu())
        for k, v in metrics.items():
            print('val '+k, [i + str(np.round(ii,2)) for i,ii in zip(train_data.dataset.symptoms,v)])
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            steps_since_improvement =0
        else:
            steps_since_improvement += 1
        if steps_since_improvement > patience:
            break
    print('final losses',curr_loss_train,curr_loss_val,best_val_loss)
    return model

run_model(train_data=train_data,test_data=test_data)

  x = self.softmax(x)


train loss tensor(1.2813, device='cuda:0', grad_fn=<AddBackward0>)

KeyboardInterrupt: 

In [None]:
pnet = _ml3d.utils.get_module('model','Pointnet2MSG','torch')
pnet([test_l,test_f])

In [None]:
torch.cuda.is_available()

In [None]:
for i in train_data:
    print(len(i))
    break

In [None]:
# def downsample_pointcloud(pcloud,k=1,max_points = 5000):
#     pc = o3d.geometry.PointCloud()
#     pc.points = o3d.utility.Vector3dVector(pcloud['coordinates'])
#     #save dose as rgb color for when they're discretized
#     pc.colors = o3d.utility.Vector3dVector(np.stack([pcloud['dose_values'] for c in range(3)],axis=-1))
#     pc = pc.voxel_down_sample(k)#to make it uniform for the convolution
#     if len(pc.points) > max_points:
#         pc = pc.farthest_point_down_sample(max_points)
#     points = np.asarray(pc.points)
#     colors = np.asarray(pc.colors)[:,0]
#     if points.shape[0] < max_points:
#         diff = max_points -points.shape[0] 
#         padd = np.zeros((diff,3))
#         points = np.concatenate([points,padd],axis=0)
#         colors = np.concatenate([colors,padd[:,0]],axis=0)
#     return {'coordinates': points,'dose_values': colors}

# def compile_pointclouds(patient,
#                         batch_index=None, 
#                         organs = Const.organ_list + ['gtv','gtvn'],
#                         downsample_k=1,
#                         max_points=5000,
#                         add_organ_key = False,
#                        ):
#     cpc = patient['contour_pointclouds']
#     all_points = []
#     all_values = []
#     patient_max_dose = 0
#     for organ in organs:
#         entry = cpc.get(organ)
#         if entry is None:
#             continue
#         coords = np.array(entry['coordinates'])
#         values = np.array(entry['dose_values'])
#         patient_max_dose = max(values.max(),patient_max_dose)
#         all_points.append(coords)
#         all_values.append(values)
#     all_points = np.concatenate(all_points,axis=0)
#     all_values = np.concatenate(all_values,axis=0).ravel()
#     entry = {'coordinates': all_points,'dose_values':all_values}
#     if downsample_k > 0:
#         entry = downsample_pointcloud(entry,downsample_k,max_points)
#     if batch_index is not None:
#         ap = entry['coordinates']
#         bi = np.full((ap.shape[0],1),batch_index)
#         entry['coordinates'] =np.hstack((ap,bi))
#     entry['patient_id'] = patient['patient_id']
#     return entry

# class SparseCNN(nn.Module):
    
#     def __init__(self,dims=3,features=1,
#                  output_dim = 1,
#                  manifold_m = [32,32],
#                  manifold_size=[3,3],
#                  full_m = [32],
#                  full_size = 5,
#                  reps=1,
#                 ):
#         nn.Module.__init__(self)
#         self.dims = dims
#         self.features = features
    
#         self.input_layer = scn.InputLayer(dims,features) 
#         self.batchNorm = scn.BatchNormalization(features) # I don't actually know what input planes means
#         self.manifold = scn.SubmanifoldConvolution(dims,features,manifold_m[0],manifold_size[0],True)
#         self.manifold2 = scn.SubmanifoldConvolution(dims,manifold_m[0],manifold_m[1],manifold_size[1],True)

#         self.batchNorm2 = scn.BatchNormReLU(manifold_m[-1])
#         self.fullConv = scn.FullConvolution(dims,manifold_m[-1],full_m[0],full_size,1,True)
        
#         #with submanifold it will be last # of features, fullConvolution seems to be features*(full_size**3)
#         self.final_dim = full_m[-1]*(full_size**3)
#         self.sparseToDense = scn.SparseToDense(dims,full_m[-1])

#         self.linear = nn.LazyLinear(output_dim)
        
#     def forward(self,locs,features,batch_size):
#         x = self.input_layer((locs,features,batch_size))
#         x = self.batchNorm(x)
#         x = self.manifold(x)
#         x = self.manifold2(x)
    
#         x = self.batchNorm2(x)
#         x = self.fullConv(x)
#         x = self.sparseToDense(x)
#         x = x.view(x.shape[0],-1)
#         x = self.linear(x)
#         return x
    
# def SparseAutoEncoder(nn.Module):
#     def __init__(self,dims=3,features=1,
#                  output_dim = 1,
#                  manifold_m = [32,32,32,32],
#                  manifold_size=[3,3,3,3],
#                  full_m = [32],
#                  full_size = 5,
#                  reps=1,
#                 ):
#         nn.Module.__init__(self)
#         self.dims = dims
#         self.features = features
    
#         self.input_layer = scn.InputLayer(dims,features) 
#         self.batchNorm = scn.BatchNormalization(features) # I don't actually know what input planes means
#         self.manifold = scn.SubmanifoldConvolution(dims,features,manifold_m[0],manifold_size[0],True)
#         self.manifold2 = scn.SubmanifoldConvolution(dims,manifold_m[0],manifold_m[1],manifold_size[1],True)

#         self.batchNorm2 = scn.BatchNormReLU(manifold_m[-1])
        
#     def forward(self,locs,features,batch_size):
#         x = self.input_layer((locs,features,batch_size))
#         x = self.batchNorm(x)
#         x = self.manifold(x)
#         x = self.manifold2(x)
    
#         x = self.batchNorm2(x)
#         x = self.fullConv(x)
#         x = self.sparseToDense(x)
# #         x = x.view(x.shape[0],-1)
# #         x = self.linear(x)
#         return x
    
# model = SparseCNN()
# model = model.to(device)
# model(test_l,test_f,4).shape