In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from unet import UNet
from load_data import BRATS_test
from utils import RandomCrop_test,ToTensor_test,DataLoader,test
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Loading Weights

In [None]:
# Checkpoint file; the weights live under its 'model_state_dict' key.
PATH = './weights/model.pth'
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# U-Net configured with 4 input channels and 4 output classes.
model = UNet(in_channels=4,n_classes=4, padding=True, up_mode='upsample').to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded on a CPU-only machine.
state_dict = torch.load(PATH, map_location=device)
# Result of load_state_dict is kept in `in_` (presumably the missing /
# unexpected key report of a torch.nn.Module -- confirm against UNet).
in_ = model.load_state_dict(state_dict['model_state_dict'])

# Loading Dataset

In [None]:
# Per-sample preprocessing pipeline: a (228, 144) crop followed by tensor
# conversion (as the helper names suggest -- see utils for the details).
test_transform = transforms.Compose([
    RandomCrop_test((228, 144)),
    ToTensor_test(),
])
transformed_dataset = BRATS_test(
    root_dir='../BRATS/Task01_BrainTumour',
    transform=test_transform,
)
# NOTE(review): shuffle=True means the batch contents differ from run to run,
# so the fixed sample indices used further down will not show the same images.
dataloader = DataLoader(
    transformed_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

# Prediction

In [None]:
# Run the model through the project's `test` helper; it hands back the
# predictions together with the corresponding inputs.
pred, X = test(model, dataloader, device)
# Detach from the autograd graph and copy to host memory as a NumPy array.
pred1 = pred.detach().cpu().numpy()
# Same conversion for the inputs, then reorder axes (0, 1, 2, 3) -> (0, 3, 2, 1)
# so that the original axis 1 ends up last.
X1 = X.detach().cpu().numpy().transpose(0, 3, 2, 1)

In [None]:
k = 2
# First-axis index into the prediction/input arrays (presumably the sample
# within the batch).
i = 2
# Prediction maps for channel 2 and channel 3 of that sample, transposed.
pred2 = pred1[i, 2].T
pred4 = pred1[i, 3].T

# Input Image 1

In [None]:
# Lay the four last-axis slices of input sample `i` side by side.
modalities = [X1[i, :, :, c] for c in range(4)]
plt.imshow(np.concatenate(modalities, axis=-1), cmap='gray')
# Post-processing: threshold at 0.5, scale the mask to 0/255 uint8, then apply
# a morphological opening with a 3x3 all-ones kernel.
open_kernel = np.ones([3, 3], np.uint8)
pred3 = cv2.morphologyEx(255 * (pred2 > 0.5).astype('uint8'), cv2.MORPH_OPEN, kernel=open_kernel)
pred5 = cv2.morphologyEx(255 * (pred4 > 0.5).astype('uint8'), cv2.MORPH_OPEN, kernel=open_kernel)

# Result without Post-Processing (left: non-enhancing, right: enhancing)

In [None]:
# Thresholded (> 0.5) channel-2 and channel-3 masks, joined along the last axis.
plt.imshow(np.concatenate([pred2 > 0.5, pred4 > 0.5], axis=-1), cmap='gray')

# Result with Post-Processing

In [None]:
# Opened masks rescaled from 0/255 back to 0/1, joined along the last axis.
plt.imshow(np.concatenate([pred3 / 255, pred5 / 255], axis=-1), cmap='gray')

# Input Image 2

In [None]:
# Switch to first-axis index 1 and re-extract the channel-2 / channel-3
# prediction maps, transposed as before.
i = 1
pred2 = pred1[i, 2].T
pred4 = pred1[i, 3].T

In [None]:
# Lay the four last-axis slices of input sample `i` side by side.
modalities = [X1[i, :, :, c] for c in range(4)]
plt.imshow(np.concatenate(modalities, axis=-1), cmap='gray')
# Post-processing: threshold at 0.5, scale the mask to 0/255 uint8, then apply
# a morphological opening with a 3x3 all-ones kernel.
open_kernel = np.ones([3, 3], np.uint8)
pred3 = cv2.morphologyEx(255 * (pred2 > 0.5).astype('uint8'), cv2.MORPH_OPEN, kernel=open_kernel)
pred5 = cv2.morphologyEx(255 * (pred4 > 0.5).astype('uint8'), cv2.MORPH_OPEN, kernel=open_kernel)

# Result without Post-Processing (left: non-enhancing, right: enhancing)

In [None]:
# Thresholded (> 0.5) channel-2 and channel-3 masks, joined along the last axis.
plt.imshow(np.concatenate([pred2 > 0.5, pred4 > 0.5], axis=-1), cmap='gray')

# Result with Post-Processing

In [None]:
# Opened masks rescaled from 0/255 back to 0/1, joined along the last axis.
plt.imshow(np.concatenate([pred3 / 255, pred5 / 255], axis=-1), cmap='gray')