In [1]:
import nibabel as nib
import imageio
import torch
import matplotlib.pyplot as plt
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose
from pathlib import Path
import numpy as np
from fastcore.utils import *

from data import dataset, dataloaders, transforms 
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [51]:
from torch import nn

class Net(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        self.encoder = nn.Sequential(
            self.conv_block(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, output_size=(60,60,45)),
            self.conv_block(in_channels=16, out_channels=32, kernel_size=3, stride=1, output_size=(30,30,23)),
        )
        self.decoder = nn.Sequential(
            self.deconv_block(32, 16, kernel_size=3, stride=2,padding=(1, 1, 1),  output_padding=1),
            self.deconv_block(16, 1, kernel_size=3, stride=2, padding=(1, 1, 2), output_padding=1)
        )
        # fill this in, add conv. layers here
    
    def conv_block(self, in_channels, out_channels, kernel_size, stride, output_size):
        return nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride),
            nn.ReLU(),
            nn.AdaptiveMaxPool3d(output_size),
        )
    def deconv_block(self, in_channels, out_channels, kernel_size, stride, padding, output_padding):
        return nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride, padding=padding, output_padding=output_padding),
            nn.ReLU(),
        )
    def forward(self, x):
        x = self.encoder(x)
#         print(x.shape)
        x = self.decoder(x)
#         print(x.shape)
        x = torch.sigmoid(x)
        return x

In [52]:
model = Net()

In [53]:
model

Net(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (1): ReLU()
      (2): AdaptiveMaxPool3d(output_size=(60, 60, 45))
    )
    (1): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (1): ReLU()
      (2): AdaptiveMaxPool3d(output_size=(30, 30, 23))
    )
  )
  (decoder): Sequential(
    (0): Sequential(
      (0): ConvTranspose3d(32, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose3d(16, 1, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 2), output_padding=(1, 1, 1))
      (1): ReLU()
    )
  )
)

In [54]:
data_dir = Path("/media/wwymak/Storage/ixi_brain_images/small")
transform=Compose([
   transforms.ToTensor()
])
nifti_dataloader = dataloaders.get_dataloader(source_directory=data_dir/'t1',target_directory=data_dir/'t2', transforms=transform)

test_batch = next(iter(nifti_dataloader))

In [55]:
output = model(test_batch['t1'].float())

torch.Size([8, 32, 30, 30, 23])
torch.Size([8, 1, 120, 120, 90])


```d_out = (d_in -1) * stride - 2 * padding + dilation *( kernel_size -1) + output_padding + 1
d_out = 29 * 2 - 2 * 1 + 1 * 2 + 1 ```