In [None]:
import torch 
import numpy as np
from torch import optim,nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler 
import matplotlib.pyplot as plt
from keras.utils import to_categorical

In [None]:
class Unet3D(nn.Module):
    """Shallow 3-D U-Net with a single down-sampling stage and one skip connection.

    The network maps a 5-D tensor with 4 channels to a tensor with 3 channels
    and the same spatial size. The output holds element-wise log-sigmoid
    values, so callers recover probabilities with ``torch.exp``.

    The 2x max-pool followed by the 2x up-sample only restores the original
    spatial size when every spatial dimension is even; otherwise the
    concatenation with the skip tensor fails.
    """

    def __init__(self):
        super().__init__()

        # All 3x3x3 convolutions share stride 1 / padding 1, which keeps the
        # spatial size unchanged.
        same_conv = dict(kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))

        # Full-resolution encoder: 4 -> 32 -> 64 channels.
        self.conv1 = nn.Conv3d(4, 32, **same_conv)
        self.conv2 = nn.Conv3d(32, 64, **same_conv)

        # Half-resolution stage: 64 -> 64 -> 128 channels.
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
        self.conv3 = nn.Conv3d(64, 64, **same_conv)
        self.conv4 = nn.Conv3d(64, 128, **same_conv)

        # Decoder: up-sample, concatenate with the 64-channel skip tensor
        # (128 + 64 = 192 input channels), then refine.
        self.upsample1 = nn.Upsample(scale_factor=(2, 2, 2))
        self.conv5 = nn.Conv3d(128 + 64, 64, **same_conv)
        self.conv6 = nn.Conv3d(64, 64, **same_conv)

        # 1x1x1 projection onto the 3 output channels.
        self.fconv = nn.Conv3d(64, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))

    def forward(self, x):
        """Return per-voxel log-sigmoid scores for the 3 output channels."""
        # Encoder at full resolution; its output doubles as the skip tensor.
        skip = F.relu(self.conv2(F.relu(self.conv1(x))))

        # Half-resolution processing.
        coarse = F.relu(self.conv3(self.maxpool1(skip)))
        coarse = F.relu(self.conv4(coarse))

        # Back to full resolution and merge with the skip tensor on the
        # channel axis (up-sampled features first, skip features second).
        merged = torch.cat([self.upsample1(coarse), skip], dim=1)

        decoded = F.relu(self.conv6(F.relu(self.conv5(merged))))
        return F.logsigmoid(self.fconv(decoded))

In [None]:
# Build the network and move its parameters to the GPU.
# model.cuda() is the last expression of the cell, so the notebook echoes the
# module's repr (the layer listing shown in the cell output).
model = Unet3D()
model.cuda()

Unet3D(
  (conv1): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (maxpool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv4): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (upsample1): Upsample(scale_factor=(2.0, 2.0, 2.0), mode=nearest)
  (conv5): Conv3d(192, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv6): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (fconv): Conv3d(64, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)

In [None]:
# Restore the trained weights into the model created above.
# map_location='cpu' deserialises the checkpoint tensors into host memory no
# matter which device they were saved from; load_state_dict then copies them
# into the model's existing parameters, so the model stays on its current device.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load('/content/AutoTumourModel.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
def standardize(image):
    """Z-score every 2-D slice ``image[c, :, :, z]`` of a 4-D array on its own.

    Each slice, selected by an index on axis 0 and an index on axis 3, is
    shifted to zero mean and divided by its standard deviation. A slice whose
    standard deviation is exactly zero is left as all zeros.

    Returns a new float64 array with the same shape as ``image``; the input
    is not modified.
    """
    result = np.zeros(image.shape)
    n_first, n_last = image.shape[0], image.shape[3]

    # np.ndindex walks (c, z) pairs with c as the outer index.
    for c, z in np.ndindex(n_first, n_last):
        shifted = image[c, :, :, z] - np.mean(image[c, :, :, z])
        spread = np.std(shifted)
        if spread == 0:
            # Constant slice: keep the zeros already in `result`.
            continue
        result[c, :, :, z] = shifted / spread

    return result

In [None]:
def entire_scan(image, label, model):
    """Segment a whole scan patch by patch and return a 4-channel uint8 mask.

    The scan is cut into non-overlapping 160 x 160 x 16 patches. Border
    patches are zero-padded to the full patch size. Each patch is passed
    through ``standardize`` and then through ``model``. The model output is
    treated as log-probabilities: it is exponentiated and thresholded at 0.5.

    Parameters
    ----------
    image : 4-D array indexed ``[x, y, z, channel]``; the channel axis is
        moved to the front before the patch is fed to the model and must fit
        the 4-channel patch buffer.
    label : integer-valued reference segmentation. It is one-hot encoded with
        4 classes and fixes the spatial shape of the result.
    model : callable module returning 3 output channels for a
        ``(1, 4, 160, 160, 16)`` float tensor. It is switched to eval mode.

    Returns
    -------
    uint8 array of shape ``label.shape + (4,)``. Channel 0 is the background
    channel of the one-hot encoded ``label``; channels 1-3 are the
    thresholded model predictions.

    NOTE(review): channel 0 comes from the reference label, not from the
    model - confirm this is intended before using the result for evaluation.
    """
    patch_x, patch_y, patch_z = 160, 160, 16

    model.eval()
    device = _model_device(model)

    # Encode the reference label first: it defines the output grid.
    model_label_reformatted = to_categorical(label, num_classes=4).astype(np.uint8)
    out_x, out_y, out_z = model_label_reformatted.shape[:3]

    # Prediction buffer sized from the data instead of a fixed 320x320x160
    # grid. It covers both the scan and the label grid, so a scan smaller
    # than the label is still zero-filled and a larger one is cropped.
    size_x, size_y, size_z = image.shape[:3]
    model_label = np.zeros(
        (3, max(size_x, out_x), max(size_y, out_y), max(size_z, out_z)),
        dtype=np.uint8,
    )

    # Pure inference: no autograd graph is needed for any patch.
    with torch.no_grad():
        for x in range(0, size_x, patch_x):
            for y in range(0, size_y, patch_y):
                for z in range(0, size_z, patch_z):
                    # Channels-first view of the (possibly truncated) patch.
                    p = np.moveaxis(
                        image[x:x + patch_x, y:y + patch_y, z:z + patch_z], 3, 0
                    )
                    dx, dy, dz = p.shape[1:]

                    # Zero-pad to the fixed patch size, then normalise the
                    # padded patch.
                    patch = np.zeros((4, patch_x, patch_y, patch_z))
                    patch[:, :dx, :dy, :dz] = p
                    patch = standardize(patch)

                    batch = torch.from_numpy(patch).float().unsqueeze(0).to(device)
                    prob = torch.exp(model(batch))

                    # Keep only the voxels backed by real data and binarise.
                    mask = (prob[0, :, :dx, :dy, :dz] > 0.5).cpu().numpy()
                    model_label[:, x:x + dx, y:y + dy, z:z + dz] = mask

    # Overwrite the three foreground channels with the predictions.
    model_label_reformatted[:, :, :, 1:4] = np.moveaxis(
        model_label[:, :out_x, :out_y, :out_z], 0, 3
    )

    return model_label_reformatted


def _model_device(model):
    """Return the device on which input patches must be placed for ``model``."""
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        return first_param.device
    # Parameter-less model: keep the previous behaviour (CUDA) when a GPU exists.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
import nibabel as nib

In [None]:
# Read the scan from its NIfTI file; get_fdata() yields the voxel data as a
# floating-point array and np.array(...) stores an independent copy in `img`.
# presumably 4-D, indexed [x, y, z, channel], as entire_scan expects - TODO confirm
img = np.array(nib.load('/content/BRATS_001.nii.gz').get_fdata())

In [None]:
# Read the reference segmentation from its NIfTI file as a floating-point array.
# presumably 3-D with class ids 0-3 on the same grid as `img` - TODO confirm
label = np.array(nib.load('/content/label.nii.gz').get_fdata())

In [None]:
# Run patch-wise inference over the whole scan. `y` is the uint8 array returned
# by entire_scan: label-derived background channel plus three predicted channels.
y = entire_scan(img,label,model)

In [None]:
# Persist the result to 'y_pred.pt' in the current working directory.
# NOTE(review): `y` is a NumPy array, not a tensor, so torch.save simply pickles
# the ndarray; np.save may be the better fit - confirm what downstream loaders expect.
torch.save(y,'y_pred.pt')