# 2D UNETR Sanity Test

Code source: not recorded (TODO: add attribution for the original implementation).

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import torch
from torch import nn
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader

from segmentation_dataset_2D import SegmentationDataset2D, load_sanity_dataset

### Load Sanity Dataset Object

In [None]:
# Load the sanity dataset and take the first sample's image (X) and mask (Y).
sanity_data = load_sanity_dataset()

X, Y = sanity_data[0][0], sanity_data[0][1]
print('MRI shape', X.shape)
print('Mask shape',Y.shape)

# Show the image and its mask side by side; only the mask forces a gray colormap.
preview_panels = (
    (1, 'MRI', X, None),
    (2, 'Segmentation Mask', Y, 'gray'),
)

plt.figure(figsize=(10, 5))
for column, panel_title, tensor, colormap in preview_panels:
    plt.subplot(1, 2, column)
    plt.imshow(F.to_pil_image(tensor), cmap=colormap)
    plt.title(panel_title)
    plt.axis('off')

plt.show()

### UNETR Model

### Sanity Test

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Bare expression so the notebook displays the selected device.
device

In [None]:
# Instantiate the model on the selected device and run a single forward pass
# on the sanity sample as a smoke test before any training.
net = UNETR().to(device)

# The input has to live on the same device as the model's weights, otherwise
# the forward pass raises a device-mismatch error when CUDA is in use.
# eval() + no_grad(): this pass is only for visualisation, so no autograd
# graph is needed and any dropout / batch-norm layers (if the model has
# them) should behave deterministically.
net.eval()
with torch.no_grad():
    y_prob = net(X.unsqueeze(0).to(device)).to('cpu').squeeze()

print(X.shape)
print(y_prob.shape)
print(y_prob.min(), y_prob.max())

# Input, ground truth, raw model output, and the output rounded to the
# nearest integer, shown left to right.
panels = (
    (141, 'MRI', X, None),
    (142, 'Segmentation Mask', Y, 'gray'),
    (143, 'Output Prob', y_prob, 'gray'),
    (144, 'Output', torch.round(y_prob), 'gray'),
)

plt.figure(figsize=(15, 5))
for position, title, image, cmap in panels:
    plt.subplot(position)
    plt.imshow(F.to_pil_image(image), cmap=cmap)
    plt.title(title)
    plt.axis('off')

plt.show()


In [None]:
def get_optimizer(net):
    """Return an SGD optimizer over ``net``'s parameters (lr 0.05, momentum 0.9)."""
    sgd_settings = {"lr": 0.05, "momentum": 0.9}
    return torch.optim.SGD(net.parameters(), **sgd_settings)

def train(data_loader, net, optimizer, weight=(0.9, 0.1)):
    """Run one optimisation pass over every batch in ``data_loader``.

    Each batch is moved to the device holding ``net``'s parameters, scored
    with binary cross-entropy, and followed by one optimizer step.

    Args:
        data_loader: Iterable of batches; element 0 of each batch is the
            input tensor and element 1 the target tensor.
        net: Model to train. Its output is passed straight to ``nn.BCELoss``,
            which expects values in [0, 1].
        optimizer: Optimizer that updates ``net``'s parameters.
        weight: Currently unused; kept so existing callers keep working.

    Returns:
        The mean of the per-batch loss values as a float, or ``None`` when
        ``data_loader`` yields no batches.
    """
    # Use the model's own device instead of a notebook-level global so the
    # inputs always end up where the weights are.
    target_device = next(net.parameters()).device
    criterion = nn.BCELoss()

    batch_losses = []
    for data in data_loader:
        inputs = data[0].to(target_device)
        targets = data[1].to(target_device)

        loss = criterion(net(inputs), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # The return sits outside the loop on purpose: returning from inside it
    # would stop after the first batch.
    if not batch_losses:
        return None
    return sum(batch_losses) / len(batch_losses)

In [None]:
# Overfit the network on the sanity dataset, recording one loss value per
# call to train().
EPOCH = 50

net.train()
optimizer = get_optimizer(net)
sanity_loader = DataLoader(sanity_data, batch_size=1, num_workers=0, shuffle=False)

loss_graph = []
for e in range(EPOCH):
    loss = train(sanity_loader, net, optimizer)
    loss_graph.append(loss)
    print(f"Epoch: {e} Loss: {loss}")

# Plot the recorded loss values against their index.
plt.figure(figsize=(6, 3))
plt.plot(np.arange(EPOCH), loss_graph)
plt.title('Training loss for sanity check')
plt.xlabel('Iterations')
plt.ylabel('Loss value')
plt.show()

In [None]:
# Re-run the sanity sample through the trained model and plot the result.
# The input has to live on the same device as the model's weights, otherwise
# the forward pass raises a device-mismatch error when CUDA is in use.
# eval() + no_grad(): the model was left in training mode by the training
# cell; for this visualisation pass no autograd graph is needed and any
# dropout / batch-norm layers (if the model has them) should run in
# inference mode.
net.eval()
with torch.no_grad():
    y_prob = net(X.unsqueeze(0).to(device)).to('cpu').squeeze()

print(X.shape)
print(y_prob.shape)
print(y_prob.min(), y_prob.max())

# Input, ground truth, raw model output, and the output rounded to the
# nearest integer, shown left to right.
panels = (
    (141, 'MRI', X, None),
    (142, 'Segmentation Mask', Y, 'gray'),
    (143, 'Output Prob', y_prob, 'gray'),
    (144, 'Output', torch.round(y_prob), 'gray'),
)

plt.figure(figsize=(15, 5))
for position, title, image, cmap in panels:
    plt.subplot(position)
    plt.imshow(F.to_pil_image(image), cmap=cmap)
    plt.title(title)
    plt.axis('off')

plt.show()