In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from glob import glob
from os.path import join
from scipy.stats import mode
import csv
%matplotlib notebook

from sklearn.metrics import accuracy_score,roc_auc_score
from moment_kernels import *

In [2]:
# path to raw npz files
# NOTE(review): absolute, machine-specific path; the Dataset class below globs this
# directory for files named '*_lr_{lr}.npz'
data_dir = '/nafs/dtward/allen/npz_files/'

In [3]:
class Dataset(torch.utils.data.Dataset):
    '''
    Random square patches cut from per-slice npz files, with remapped labels.

    Every file in ``data_dir`` matching ``*_lr_{lr}.npz`` is treated as one slice
    that stores an array ``I`` (network input) and an array ``L`` (raw labels).
    Both are indexed as ``[:, row, col]``.  Each slice contributes ``n_per_slice``
    items; an item is a ``size`` x ``size`` window at a uniformly random position,
    so the same index generally yields a different patch on every call.

    Parameters
    ----------
    data_dir : str
        Directory containing the npz files.
    lr : int
        Value substituted into the ``*_lr_{lr}.npz`` filename pattern.
    n_per_slice : int
        Number of random patches served per slice.
    size : int
        Side length of the square patch in pixels.
    level : str
        Path to a csv file with one header row followed by rows whose column 0
        is an integer id, column 1 a name and column 2 a bracketed,
        whitespace separated list of raw label values that map to that id.

    Attributes
    ----------
    files : list of str
        Sorted npz paths, one per slice.
    label_mapper : dict
        Raw label value -> id from the level csv.
    label_names : dict
        Id -> name from the level csv.
    I_, L_, L__ : arrays
        The most recently returned input patch, raw label patch and remapped
        label patch (kept so they can be inspected after indexing).
    '''
    def __init__(self,data_dir=data_dir,lr=0,n_per_slice=100,size=64,
                 level='/nafs/dtward/allen/rois/categories.csv'
                ):
        self.data_dir = data_dir
        self.lr = lr
        self.files = sorted(glob(join(data_dir,f'*_lr_{lr}.npz')))
        self.n_per_slice = n_per_slice
        self.size = size
        # index of the slice currently cached in self.I / self.L (-1: nothing loaded yet)
        self.current_slice = -1

        # get a map from the level
        label_mapper = {}
        label_names = {}
        # newline='' lets the csv module handle line breaks inside quoted fields itself
        with open(level,newline='') as f:
            reader = csv.reader(f)
            next(reader,None) # skip the header row
            for line in reader:
                if not line:
                    # blank row, nothing to map
                    continue
                group = int(line[0])
                # column 2 looks like "[1 2 3 ...]", possibly wrapped over several lines
                raw = line[2].replace('[','').replace(']','')
                for l in raw.split():
                    label_mapper[int(l)] = group
                label_names[group] = line[1]
        self.label_mapper = label_mapper
        self.label_names = label_names

    def __len__(self):
        '''Total number of items: n_per_slice for every slice file.'''
        return self.n_per_slice * len(self.files)

    def __getitem__(self,i):
        '''
        Return ``(I_, L__)``: a random input patch from slice ``i // n_per_slice``
        and the matching label patch remapped through ``label_mapper``.

        Raises IndexError for an out of range index, ValueError if the slice is
        smaller than the patch, and KeyError if the patch contains a raw label
        that is missing from ``label_mapper``.
        '''
        n = len(self)
        if i < 0:
            i += n
        if not 0 <= i < n:
            raise IndexError(f'index {i} out of range for dataset of length {n}')

        slice_number = i//self.n_per_slice
        if slice_number != self.current_slice:
            # load a slice; the context manager closes the npz file handle
            # once both arrays have been read into memory
            with np.load(self.files[slice_number]) as data:
                self.I = data['I']
                self.L = data['L']
            self.current_slice = slice_number

        # get a region
        nrow = self.I.shape[1]
        ncol = self.I.shape[2]
        if nrow < self.size or ncol < self.size:
            raise ValueError(f'slice of size {nrow}x{ncol} is smaller than patch size {self.size}')

        # cutout; randint's upper bound is exclusive, so +1 makes the last valid
        # offset reachable and allows slices exactly as large as the patch
        r = np.random.randint(0,nrow-self.size+1)
        c = np.random.randint(0,ncol-self.size+1)
        I_ = self.I[:,r:r+self.size,c:c+self.size]
        L_ = self.L[:,r:r+self.size,c:c+self.size]

        # remap labels: look up each distinct raw value once, then scatter back
        uniq,inverse = np.unique(L_,return_inverse=True)
        mapped = np.array([self.label_mapper[u] for u in uniq])
        L__ = mapped[inverse].reshape(L_.shape)

        self.I_ = I_
        self.L_ = L_
        self.L__ = L__
        return I_,L__
        
        
    

In [4]:
# training data: lr=0 files (the constructor default), labels grouped by divisions.csv
dataset = Dataset(level='/nafs/dtward/allen/rois/divisions.csv')
# test data: lr=1 files, same label grouping
dataset_test = Dataset(level='/nafs/dtward/allen/rois/divisions.csv',lr=1)

In [5]:
# number of label ids read from the level csv (the networks below use this as their output channel count)
len(dataset.label_names)

26

In [6]:
# id -> name mapping read from the level csv
dataset.label_names

{0: 'unassigned',
 1: 'HY',
 2: 'Isocortex',
 3: 'lfbs',
 4: 'P',
 5: 'MB',
 6: 'TH',
 7: 'HPF',
 8: 'STR',
 9: 'mfbs',
 10: 'cm',
 11: 'cbf',
 12: 'VL',
 13: 'MY',
 14: 'CB',
 15: 'eps',
 16: 'V3',
 17: 'CTXsp',
 18: 'AQ',
 19: 'V4',
 20: 'OLF',
 21: 'c',
 22: 'PAL',
 23: 'brain-unassigned',
 24: 'fiber tracts-unassigned',
 25: 'scwm'}

In [7]:
# one random RGB color per label id, used only for display; unseeded, so the colors differ between runs
colors = np.random.rand(len(dataset.label_names),3)

In [8]:
# draw one random patch (index 2 falls in the first slice); this also sets dataset.L_ and dataset.L__ used in the next cell
x,l = dataset[2]

In [9]:
# raw label values of the patch drawn above (index 0 along the leading axis)
fig,ax = plt.subplots()
ax.imshow(dataset.L_[0])

# the same patch after remapping, colored with the random per-label colors
fig,ax = plt.subplots()
ax.imshow(colors[dataset.L__[0]])

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f058d70daf0>

In [10]:
# label ids in ascending order, so entry i of `weights` belongs to id keys[i]
keys = sorted(dataset.label_names.keys())
# weight 0 for names that contain 'unassigned' without being exactly 'unassigned'
# (e.g. 'brain-unassigned'), weight 1 for every other label
weights = 1.0 - np.array([
    name != 'unassigned' and 'unassigned' in name
    for name in map(dataset.label_names.get, keys)
])
weights

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 0., 0., 1.])

In [11]:
# shuffle=False keeps the indices sequential, so Dataset.__getitem__ loads each npz file
# once per pass instead of reloading whenever consecutive indices fall in different slices
data_loader = torch.utils.data.DataLoader(dataset,batch_size=8,shuffle=False)
data_loader_test = torch.utils.data.DataLoader(dataset_test,batch_size=8,shuffle=False)

In [12]:
# grab the first training batch only; x and l are reused below to smoke-test the networks
for x,l in data_loader:
    break

In [13]:
class AvgPoolTranspose2d(torch.nn.Module):
    '''
    Upsampling layer that is the transpose (adjoint) of non-overlapping average pooling.

    Every input pixel is repeated over a ``factor[0]`` x ``factor[1]`` block
    (rows x columns) and divided by the block area, which is the adjoint of
    ``torch.nn.AvgPool2d(factor)`` with its default stride.

    Parameters
    ----------
    factor : int or pair of int
        Upsampling factor for the last two dimensions (rows, columns).
        An int is used for both.

    Raises
    ------
    ValueError
        If ``factor`` is not a positive int or a pair of positive ints.
    '''
    def __init__(self,factor):
        super().__init__()
        if isinstance(factor,int):
            factor = (factor,factor)
        factor = tuple(int(f) for f in factor)
        if len(factor) != 2 or min(factor) < 1:
            raise ValueError('factor must be a positive int or a pair of positive ints')
        self.factor = factor
    def forward(self,x):
        '''Upsample the last two dimensions of x by self.factor.'''
        frow,fcol = self.factor
        x = torch.repeat_interleave(x,fcol,dim=-1)
        x = torch.repeat_interleave(x,frow,dim=-2)
        # to be the transpose of averaging, divide by the block area (4 for a (2,2) factor)
        return x/(frow*fcol)
    

In [14]:
# build a quick net
class Net(torch.nn.Module):
    '''
    Small U-Net-like CNN for per-pixel classification.

    Encoder: five stages of 3x3 convolution -> 2x2 average pooling -> ReLU,
    doubling the channel count at each stage (c0, 2*c0, ..., 16*c0).
    Decoder: five stages of AvgPoolTranspose2d upsampling -> concatenation
    with the encoder feature map of the same resolution (the raw input at
    the last stage) -> 3x3 convolution, with a ReLU after every convolution
    except the last one.  The output holds ``cout`` unnormalized scores per pixel
    at the input resolution.

    Each encoder stage halves the spatial size and each concatenation needs
    matching sizes, so the input height and width must be multiples of 32
    (the notebook uses 64x64 patches).

    Parameters
    ----------
    cin : int
        Number of input channels (default 501, the value used in this notebook).
    cout : int or None
        Number of output classes.  None (default) uses
        ``len(dataset.label_names)`` from the notebook-level ``dataset``.
    c0 : int
        Number of channels after the first convolution (default 16).

    Attribute names (c1, p1, s1, ..., c1_) are part of the state_dict keys and
    are kept unchanged so existing checkpoints still load.
    '''
    def __init__(self,cin=501,cout=None,c0=16):
        super().__init__()
        if cout is None:
            cout = len(dataset.label_names)

        k = 3 # kernel size
        p = 1 # padding that keeps the spatial size for k = 3

        # downsampling
        self.c1 = torch.nn.Conv2d(cin,c0,k,padding=p)
        self.p1 = torch.nn.AvgPool2d((2,2))
        self.s1 = torch.nn.ReLU()

        self.c2 = torch.nn.Conv2d(c0,c0*2,k,padding=p)
        self.p2 = torch.nn.AvgPool2d((2,2))
        self.s2 = torch.nn.ReLU()

        self.c3 = torch.nn.Conv2d(c0*2,c0*4,k,padding=p)
        self.p3 = torch.nn.AvgPool2d((2,2))
        self.s3 = torch.nn.ReLU()

        self.c4 = torch.nn.Conv2d(c0*4,c0*8,k,padding=p)
        self.p4 = torch.nn.AvgPool2d((2,2))
        self.s4 = torch.nn.ReLU()

        self.c5 = torch.nn.Conv2d(c0*8,c0*16,k,padding=p)
        self.p5 = torch.nn.AvgPool2d((2,2))
        self.s5 = torch.nn.ReLU()

        # upsampling; each conv takes the upsampled features plus the skip connection
        self.p5_ = AvgPoolTranspose2d((2,2))
        self.c5_ = torch.nn.Conv2d(c0*(16+8),c0*8,k,padding=p)
        self.s5_ = torch.nn.ReLU()

        self.p4_ = AvgPoolTranspose2d((2,2))
        self.c4_ = torch.nn.Conv2d(c0*(8+4),c0*4,k,padding=p)
        self.s4_ = torch.nn.ReLU()

        self.p3_ = AvgPoolTranspose2d((2,2))
        self.c3_ = torch.nn.Conv2d(c0*(4+2),c0*2,k,padding=p)
        self.s3_ = torch.nn.ReLU()

        self.p2_ = AvgPoolTranspose2d((2,2))
        self.c2_ = torch.nn.Conv2d(c0*(2+1),c0*1,k,padding=p)
        self.s2_ = torch.nn.ReLU()

        # last stage concatenates the raw input and has no nonlinearity (outputs logits)
        self.p1_ = AvgPoolTranspose2d((2,2))
        self.c1_ = torch.nn.Conv2d(c0+cin,cout,k,padding=p)

    def forward(self,x):
        '''
        Map an input of shape (..., cin, H, W) to scores of shape (..., cout, H, W).
        The size comments below are for a 64x64 input.
        '''
        # stack of skip connections, popped in reverse order while upsampling
        x0 = [x] # 64x64
        x = self.s1(self.p1(self.c1(x)))
        x0.append(x) # 32x32
        x = self.s2(self.p2(self.c2(x)))
        x0.append(x) # 16x16
        x = self.s3(self.p3(self.c3(x)))
        x0.append(x) # 8x8
        x = self.s4(self.p4(self.c4(x)))
        x0.append(x) # 4x4
        x = self.s5(self.p5(self.c5(x)))
        # 2x2

        # now upsampling: upsample, concat along the channel axis (-3), conv, nonlinearity
        x = self.p5_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c5_(x)
        x = self.s5_(x)

        x = self.p4_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c4_(x)
        x = self.s4_(x)

        x = self.p3_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c3_(x)
        x = self.s3_(x)

        x = self.p2_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c2_(x)
        x = self.s2_(x)

        x = self.p1_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c1_(x)

        return x

In [15]:
# build a quick net
class EqNet(torch.nn.Module):
    '''
    U-Net-like network with the same layout as ``Net`` but built from the
    moment_kernels layers ``ScalarToScalar``, ``Downsample``, ``Upsample`` and
    ``ScalarSigmoid`` in place of Conv2d, AvgPool2d, AvgPoolTranspose2d and ReLU.
    NOTE(review): these layers are presumably the equivariant counterparts
    (the notebook calls this the "equivariant" network) - confirm against
    moment_kernels.

    Encoder: five stages of ScalarToScalar -> Downsample -> ScalarSigmoid with
    channel counts c0, 2*c0, ..., 16*c0.  Decoder: five stages of Upsample ->
    concatenation with the encoder feature map of the same resolution (the raw
    input at the last stage) -> ScalarToScalar, with a ScalarSigmoid after
    every stage except the last one.

    Parameters
    ----------
    cin : int
        Number of input channels (default 501, the value used in this notebook).
    cout : int or None
        Number of output classes.  None (default) uses
        ``len(dataset.label_names)`` from the notebook-level ``dataset``.
    c0 : int
        Number of channels after the first layer (default 16).

    Attribute names (c1, p1, s1, ..., c1_) are part of the state_dict keys and
    are kept unchanged so existing checkpoints still load.
    '''
    def __init__(self,cin=501,cout=None,c0=16):
        super().__init__()
        if cout is None:
            cout = len(dataset.label_names)

        k = 3 # kernel size
        p = 1 # padding

        # downsampling
        self.c1 = ScalarToScalar(in_channels = cin, out_channels=c0, kernel_size=k, padding=p)
        self.p1 = Downsample()
        self.s1 = ScalarSigmoid()

        self.c2 = ScalarToScalar(in_channels = c0, out_channels=c0*2, kernel_size=k, padding=p)
        self.p2 = Downsample()
        self.s2 = ScalarSigmoid()

        self.c3 = ScalarToScalar(in_channels = c0*2, out_channels=c0*4, kernel_size=k, padding=p)
        self.p3 = Downsample()
        self.s3 = ScalarSigmoid()

        self.c4 = ScalarToScalar(in_channels = c0*4, out_channels=c0*8, kernel_size=k, padding=p)
        self.p4 = Downsample()
        self.s4 = ScalarSigmoid()

        self.c5 = ScalarToScalar(in_channels = c0*8, out_channels=c0*16, kernel_size=k, padding=p)
        self.p5 = Downsample()
        self.s5 = ScalarSigmoid()

        # upsampling; each layer takes the upsampled features plus the skip connection
        self.p5_ = Upsample()
        self.c5_ = ScalarToScalar(in_channels = c0*16+c0*8, out_channels=c0*8, kernel_size=k, padding=p)
        self.s5_ = ScalarSigmoid()

        self.p4_ = Upsample()
        self.c4_ = ScalarToScalar(in_channels = c0*8+c0*4, out_channels=c0*4, kernel_size=k, padding=p)
        self.s4_ = ScalarSigmoid()

        self.p3_ = Upsample()
        self.c3_ = ScalarToScalar(in_channels = c0*4+c0*2, out_channels=c0*2, kernel_size=k, padding=p)
        self.s3_ = ScalarSigmoid()

        self.p2_ = Upsample()
        self.c2_ = ScalarToScalar(in_channels = c0*2+c0, out_channels=c0, kernel_size=k, padding=p)
        self.s2_ = ScalarSigmoid()

        # last stage concatenates the raw input and has no nonlinearity (outputs logits)
        self.p1_ = Upsample()
        self.c1_ = ScalarToScalar(in_channels = c0+cin, out_channels=cout, kernel_size=k, padding=p)

    def forward(self,x):
        '''
        Map an input with cin channels to per-pixel scores with cout channels.
        The size comments below are for a 64x64 input.
        '''
        # stack of skip connections, popped in reverse order while upsampling
        x0 = [x] # 64x64
        x = self.s1(self.p1(self.c1(x)))
        x0.append(x) # 32x32
        x = self.s2(self.p2(self.c2(x)))
        x0.append(x) # 16x16
        x = self.s3(self.p3(self.c3(x)))
        x0.append(x) # 8x8
        x = self.s4(self.p4(self.c4(x)))
        x0.append(x) # 4x4
        x = self.s5(self.p5(self.c5(x)))
        # 2x2

        # now upsampling: upsample, concat along the channel axis (-3), conv, nonlinearity
        x = self.p5_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c5_(x)
        x = self.s5_(x)

        x = self.p4_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c4_(x)
        x = self.s4_(x)

        x = self.p3_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c3_(x)
        x = self.s3_(x)

        x = self.p2_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c2_(x)
        x = self.s2_(x)

        x = self.p1_(x)
        x = torch.concatenate((x,x0.pop()),-3)
        x = self.c1_(x)

        return x

In [16]:
# non equivariant Unet-like architecture
# NOTE(review): the optimizer below is built from eqnet.parameters() only, so the
# training loop in this notebook does not update this model
net = Net()

In [17]:
# equivariant Unet-like network; this is the model the optimizer and training loop below use
eqnet = EqNet()

In [18]:
# smoke test: run the sample batch through the equivariant net and check the output shape
lhat = eqnet(x)
print(lhat.shape)

torch.Size([8, 26, 64, 64])


In [19]:
# same smoke test for the plain CNN; the output shape should match the equivariant net's
lhat_ = net(x)
print(lhat_.shape)

torch.Size([8, 26, 64, 64])


In [20]:
# shape of the equivariant net output computed above
lhat.shape

torch.Size([8, 26, 64, 64])

In [21]:
# Adam with default settings; only eqnet's parameters are registered, so only eqnet is trained
optimizer = torch.optim.Adam(eqnet.parameters())

In [22]:
# use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [23]:
# per-pixel cross entropy with per-class weights; `weights` (built above) is 0 for the
# 'brain-unassigned' and 'fiber tracts-unassigned' classes and 1 for all others
loss = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights,dtype=torch.float32,device=device))


In [24]:
# get some metrics
def compute_dice(ltruesave,lpredictsave,labels=None):
    '''
    Returns dice score for each structure, and a weighted mean over structures.

    Parameters
    ----------
    ltruesave : array or sequence of arrays
        True labels, one entry per example along the first axis.
    lpredictsave : array or sequence of arrays
        Predicted labels, one entry per example along the first axis.  An
        example may differ from the true one by singleton axes (for instance
        true labels of shape (1,H,W) against predictions of shape (H,W));
        the two are broadcast against each other.
    labels : sequence or None
        Label values to score.  None (default) uses the unique values in
        ltruesave.

    Returns
    -------
    dice : array of shape (len(labels),)
        Per label dice score 2*|true & predicted| / (|true| + |predicted| + 1e-6),
        averaged over examples.  A label absent from both the truth and the
        prediction of an example contributes 0 for that example.
    dicemean : float
        Mean of dice over labels, each label weighted by its total number of
        correctly predicted pixels.  0.0 when no pixel is predicted correctly
        (every dice score is 0 in that case) or when there are no examples.

    TODO
    probably should return weights for averaging across structures, later
    '''
    if labels is None:
        labels = np.unique(ltruesave)
    dice = np.zeros(len(labels))
    dicesum = np.zeros(len(labels))
    count = 0
    # for each example
    # look at volume and intersection
    for ltrue,lpredict in zip(ltruesave,lpredictsave):
        for lcount,l in enumerate(labels):
            istrue = ltrue == l
            ispredict = lpredict == l
            # sum over every axis so each quantity is a scalar,
            # whether or not the example carries a singleton channel axis
            ntrue = np.sum(istrue)
            npredict = np.sum(ispredict)
            nboth = np.sum(istrue * ispredict)
            dice[lcount] += 2*nboth / (ntrue + npredict + 1e-6)
            dicesum[lcount] += nboth
        count += 1
    if count:
        dice /= count

    # to average, weight each label by its number of correctly predicted pixels
    total = np.sum(dicesum)
    if total > 0:
        weights = dicesum / total
        dicemean = np.sum(dice*weights)
    else:
        # no overlap anywhere: all dice scores are 0, and so is their mean
        dicemean = 0.0

    return dice,dicemean

In [25]:
# Train eqnet, and after every epoch evaluate it on the test loader.
# Per epoch we record the mean training loss, pixel accuracy and weighted mean dice
# (train and test), refresh the live figure, and checkpoint the model whenever the
# test accuracy improves.
nepochs = 200
Esave = []              # mean training loss per epoch
accuracytestsave = []   # test pixel accuracy per epoch
accuracytrainsave = []  # train pixel accuracy per epoch
dicetestsave = []       # test weighted mean dice per epoch
dicetrainsave = []      # train weighted mean dice per epoch
fig = plt.figure(figsize=(8,5))
ax = fig.subplot_mosaic([['all','accuracy','image','true'],['this','dice','error','predicted']])

# eqnet is the model being trained (the optimizer was built from eqnet.parameters()),
# so it is the model that has to be on `device` together with the batches and the loss weights
eqnet = eqnet.to(device)
count = 0 # number of training batches seen so far, used to throttle plotting
for e in range(nepochs):
    Esave_ = []
    ltruesave = []
    lpredictsave = []
    # NOTE(review): probsave is filled but never read in this cell, and it holds the
    # softmax output of every batch of the epoch - consider dropping it if memory is tight
    probsave = []
    for x,l in data_loader:
        x = x.to(device)
        ltruesave.append( l )
        l = l.to(device)
        optimizer.zero_grad()

        lhat = eqnet(x)

        # l has a singleton channel axis; the loss expects (batch, row, col) class indices
        E = loss(lhat,l[:,0])

        E.backward()
        optimizer.step()

        Esave_.append(E.item())
        probsave.append( torch.softmax(lhat,-3).clone().detach().cpu() )
        lpredictsave.append(torch.argmax(lhat,-3).clone().detach().cpu())

        # refresh the per-batch panels every 10 batches
        if not count%10:
            ltrue0 = l[0,0].detach().cpu().numpy()
            lpredict0 = torch.argmax(lhat[0],0).detach().cpu().numpy()

            ax['this'].cla()
            ax['this'].plot(Esave_)
            ax['this'].set_title('Loss this epoch')

            ax['image'].cla()
            ax['image'].imshow(x[0,0].detach().cpu().numpy())
            ax['image'].set_title('image')

            ax['true'].cla()
            ax['true'].imshow(colors[ltrue0],interpolation='none',vmin=0,vmax=len(dataset.label_mapper)-1)
            ax['true'].set_title('true')

            ax['error'].cla()
            ax['error'].imshow(ltrue0 != lpredict0,interpolation='none',vmin=0,vmax=1)
            ax['error'].set_title('error')

            ax['predicted'].cla()
            ax['predicted'].imshow(colors[lpredict0],interpolation='none',vmin=0,vmax=len(dataset.label_mapper)-1)
            ax['predicted'].set_title('predicted')

            fig.canvas.draw()
        count += 1
    Esave.append(np.mean(Esave_))
    accuracytrainsave.append( accuracy_score(torch.concatenate(ltruesave).ravel(),torch.concatenate(lpredictsave).ravel()) )
    dicetrain,dicetrainmean = compute_dice(torch.concatenate(ltruesave,0).numpy(),torch.concatenate(lpredictsave,0).numpy(),labels=np.arange(len(dataset.label_names)))
    dicetrainsave.append(dicetrainmean )


    ax['all'].cla()
    ax['all'].plot(Esave)
    ax['all'].set_title('Loss all epochs')

    # test pass: same model as the one being trained, in eval mode and without gradients
    ltruesave = []
    lpredictsave = []
    probsave = []
    with torch.no_grad():
        eqnet.eval()
        for x,l in data_loader_test:
            ltruesave.append( l )
            x = x.to(device)
            lhat = eqnet(x).cpu()
            probsave.append( torch.softmax(lhat,-3) )
            lpredictsave.append( torch.argmax(lhat,-3) )

        # josef is computing
        # hausdorff dice iou and loss
        # NOTE only accuracy and dice are reported for the test set below
        eqnet.train()

    accuracytestsave.append( accuracy_score(torch.concatenate(ltruesave).ravel(),torch.concatenate(lpredictsave).ravel()) )
    dicetest,dicetestmean = compute_dice(torch.concatenate(ltruesave,0).numpy(),torch.concatenate(lpredictsave,0).numpy(),labels=np.arange(len(dataset.label_names)))
    dicetestsave.append(dicetestmean )

    # if test accuracy improves, save the model
    if e == 0 or accuracytestsave[-1] > np.max(accuracytestsave[:-1]):
        torch.save(eqnet.state_dict(),'eqnet.pth')


    ax['accuracy'].cla()
    ax['accuracy'].plot(accuracytrainsave,label='train')
    ax['accuracy'].plot(accuracytestsave,label='test')
    ax['accuracy'].set_title('accuracy')
    ax['accuracy'].legend()

    ax['dice'].cla()
    ax['dice'].plot(dicetrainsave,label='train')
    ax['dice'].plot(dicetestsave,label='test')
    #ax['dice'].set_xticklabels(list(dataset.label_names.values()),rotation=45,ha='right')
    ax['dice'].set_title('Dice')
    ax['dice'].legend()

    #fig.tight_layout()
    fig.canvas.draw()

<IPython.core.display.Javascript object>

KeyboardInterrupt: 

In [None]:
# save the weights of the plain CNN
# NOTE(review): in the training cell above only eqnet is optimized, so as written `net`
# still holds the weights it had before that cell - confirm this checkpoint comes from a
# run in which `net` was actually trained
torch.save(net.state_dict(),'UnetCNN_v00.pth')

In [None]:
# think about soft hausdorff
# hausdorff is based on a cdf of distances
# a maximum
# but for a 95th quantile we could use a cdf
# now we can use a weighted cdf
# still doesn't quite seem appropriate
# how about this
# compute a histogram

In [None]:
# hmm
# I think I was implicitly thinking of a 1 v rest hausdorff
# maybe I should do a 1v1 hausdorff, what would that look like?
# for each pair of labels I'd compute a set based on which is bigger
# then do a set distance
# BUT
# then take a weighted average
#

In [None]:
# what about if I somehow sweep a threshold from 0 to 1
# then the mask will change, likely start with the whole roi full
# then slowly move to the whole roi empty

## **Taking a look at predictions from equivariant and non equivariant segmentation models**

In [270]:
# equivariant model in eval mode, restored from a saved checkpoint
# (note that from here on the name `net` refers to an EqNet, not the plain CNN)
net = EqNet().to(device).eval()
# strict loading (the default) fails loudly if the checkpoint does not match the
# architecture, instead of silently leaving some layers at their random initialization
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust
net.load_state_dict(torch.load('best_eqUnetCNN_v00.pth', map_location = device))

<All keys matched successfully>

In [252]:
# first batch of the test loader; the prediction cells below take their sample from it
x, l = next(iter(data_loader_test))

In [271]:
# position of the sample inside the test batch
idx = 5

sample_img = x[idx].clone()
sample_label = l[idx, 0].clone()

print(sample_img.shape)
print(sample_label.shape)

# inference only: no gradients needed.  The batch comes from the loader on the CPU
# while the model may be on the GPU, so move the input to the model's device, add the
# batch axis the network was trained with, and bring the result back to the CPU
model_device = next(net.parameters()).device
with torch.no_grad():
    pred_ = net(sample_img[None].to(model_device))[0].cpu()

# softmax and argmax over the class axis
pred = torch.softmax(pred_, -3)
pred = torch.argmax(pred, -3)

# plot
plt.imshow(pred)
plt.savefig('pred.png')

torch.Size([501, 64, 64])
torch.Size([64, 64])


<IPython.core.display.Javascript object>

In [272]:
# ground-truth labels of the same patch, saved for comparison with the prediction
plt.imshow(sample_label)
plt.savefig('label.png')

<IPython.core.display.Javascript object>

In [264]:
# turn to list of unique labels
unique_labels = np.unique(sample_label)
print(unique_labels)

# associate it with the label names
label_names = [dataset.label_names[l] for l in unique_labels]
print(label_names)

unique_pred = np.unique(pred)
print(unique_pred)

# associate it with the label names
pred_names = [dataset.label_names[l] for l in unique_pred]
print(pred_names)

[ 0 14]
['unassigned', 'CB']
[ 0  2 10 11 14 20]
['unassigned', 'Isocortex', 'cm', 'cbf', 'CB', 'OLF']


In [265]:
# plain (non equivariant) CNN in eval mode, restored from a saved checkpoint; it stays on the CPU
cnn = Net().eval()
# strict loading (the default) fails loudly if the checkpoint does not match the
# architecture, instead of silently leaving some layers at their random initialization
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust
cnn.load_state_dict(torch.load('UnetCNN_v00.pth', map_location = device))

<All keys matched successfully>

In [266]:
# do the same on this model
pred_ = cnn(sample_img)

# softmax and argmax
pred = torch.softmax(pred_, -3)
pred = torch.argmax(pred, -3)

# plot
plt.imshow(pred)
plt.savefig('pred_cnn.png')

<IPython.core.display.Javascript object>

In [267]:
# the same ground-truth patch again, saved next to the CNN prediction
plt.imshow(sample_label)
plt.savefig('label_cnn.png')

<IPython.core.display.Javascript object>

In [268]:
# labels present in the current `pred` (the CNN prediction if the cell above was run last)
unique_pred = np.unique(pred)
print(unique_pred)

# associate it with the label names
pred_names = [dataset.label_names[l] for l in unique_pred]
print(pred_names)

[ 0 11 14]
['unassigned', 'cbf', 'CB']


In [269]:
#original image
plt.imshow(sample_img[0], cmap='gray')
plt.savefig('original.png')

<IPython.core.display.Javascript object>