### Workflow

1. Create a class that inherits from nn.Module and implement the following methods:
   + `__init__` (put the convolutional blocks here)
   + `forward` (this will take a properly formatted T1-w image as input and, once trained, output an approximate T2-w image of the same shape as the input image)
1. Verify that the forward pass works by running an image from the training set through the network. Check that the input shape is the same as the output shape.
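The shape check in step 2 can also be run without touching the data files by pushing a random tensor of the same size through the model. This is a minimal sketch. It assumes the volume size that appears later in the notebook (90 × 120 × 120 after reordering). `check_shape_preserved` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def check_shape_preserved(model: nn.Module, spatial_shape=(90, 120, 120)) -> bool:
    """Run a random (batch=1, channel=1) volume through `model` and compare shapes."""
    x = torch.randn(1, 1, *spatial_shape)  # N x C x D x H x W
    with torch.no_grad():  # no gradients are needed for a shape check
        y = model(x)
    return tuple(y.shape) == tuple(x.shape)


# A padded 3x3x3 convolution keeps both the spatial size and the single channel.
print(check_shape_preserved(nn.Conv3d(1, 1, kernel_size=3, padding=1)))  # True
# An unpadded convolution trims one voxel from each border.
print(check_shape_preserved(nn.Conv3d(1, 1, kernel_size=3)))  # False
```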

#### Notes

* The loss function will compare the network's predicted T2-w image with the true T2-w image of the same subject.
* You may not need a loss function in this module.
* Goal is simply to send an image through a forward pass of the network. Training happens in the next module.
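Once training starts, the loss will score the predicted T2-w volume against the true T2-w volume voxel by voxel. The two tensors must therefore have the same shape. This is a minimal sketch that assumes mean squared error (`nn.MSELoss`) as the criterion. The notebook does not fix a choice, and `nn.L1Loss` is an equally plausible alternative. Constant tensors stand in for real images.

```python
import torch
import torch.nn as nn

# Stand-ins for a predicted and a true T2-w volume, shape N x C x D x H x W.
predicted_t2 = torch.zeros(1, 1, 4, 4, 4)
true_t2 = torch.ones(1, 1, 4, 4, 4)

criterion = nn.MSELoss()  # assumption: voxel-wise mean squared error
loss = criterion(predicted_t2, true_t2)
print(loss.item())  # 1.0: every voxel differs by exactly 1
```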

#### Questions

* How would you reformulate this problem as a GAN? Would a GAN necessarily work better with less training data, and why or why not?
>   Note that there are other techniques to accomplish this image transformation task (e.g., generative adversarial networks), but since we have a large set of paired data, we can use supervised techniques like we are doing here.
* Should the network downsample and then upsample, or keep things simple and preserve the same spatial dimensions throughout?
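The two options in the last question can be compared by their output shapes. This is a minimal sketch, not a recommended architecture, and the layer widths are arbitrary. The down/up path uses `MaxPool3d` followed by `ConvTranspose3d`. That design assumes each spatial dimension is even, which holds for 90, 120 and 120.

```python
import torch
import torch.nn as nn

# Option A: keep every layer at full resolution (padding=1 preserves size).
same_size = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv3d(8, 1, kernel_size=3, padding=1),
)

# Option B: halve the resolution, then restore it (encoder-decoder style).
down_up = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool3d(kernel_size=2),                        # D, H, W -> D/2, H/2, W/2
    nn.Conv3d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.ConvTranspose3d(8, 8, kernel_size=2, stride=2),  # back to D, H, W
    nn.Conv3d(8, 1, kernel_size=1),                     # 8 channels -> 1
)

x = torch.randn(1, 1, 90, 120, 120)
with torch.no_grad():
    print(same_size(x).shape)  # torch.Size([1, 1, 90, 120, 120])
    print(down_up(x).shape)    # torch.Size([1, 1, 90, 120, 120])
```

The down/up variant sees a larger spatial context per output voxel at the cost of extra layers. The same-size variant is simpler but only "sees" as far as its stacked kernels reach.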


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [16]:
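### Model Architecture
class MRConvNet(nn.Module):
    def __init__(self, nChans=(16,)):  # tuple default avoids a shared mutable list
        super().__init__()
        # 3x3x3 convolution: 1 input channel (T1-w) -> nChans[0] feature maps.
        # No padding yet, so each spatial dimension shrinks by 2.
        self.conv1 = nn.Conv3d(1, nChans[0], 3)

    def forward(self, x):
        # x: (N, 1, D, H, W) -> (N, nChans[0], D-2, H-2, W-2)
        x = F.relu(self.conv1(x))
        return x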
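The single layer above is small enough to count its parameters by hand. A `Conv3d(1, 16, 3)` has 16 filters, each with 1 × 3 × 3 × 3 weights plus one bias, which gives 16 × 27 + 16 = 448. The short check below re-creates the same layer on its own.

```python
import torch.nn as nn

conv1 = nn.Conv3d(1, 16, 3)  # same layer as MRConvNet.conv1
n_params = sum(p.numel() for p in conv1.parameters())
print(n_params)  # 448 = 16 * (1 * 3 * 3 * 3) weights + 16 biases
```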



In [17]:
### instantiate model and try out the forward() pass. 
mcn = MRConvNet()

In [27]:
# Load a raw image
import nibabel as nib

def load_mr_image(subj, ttype):
    suff = ttype.upper()
    if ttype == 't2': 
        suff = f'{suff}_reg'
    return nib.load(f'./data/small/{ttype}/{subj}-{suff}_fcm.nii.gz')

SUBJ = 'IXI102-HH-1416'

image = load_mr_image(SUBJ, 't1')
image_data = image.get_fdata()

print(type(image_data))
print(image_data.shape)


<class 'numpy.ndarray'>
(120, 120, 90)


In [28]:
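# Reorder the NumPy volume's axes before converting it to a tensor (next cell)

# NumPy volume: H x W x D, where D is the third (size-90) axis.
# Conv3d expects N x C x D x H x W, so move D to the front: D x H x W.
image_data = image_data.transpose((2, 0, 1))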
print(image_data.shape)


(90, 120, 120)


In [40]:
tensor = torch.from_numpy(image_data)
tensor = tensor.unsqueeze(0)  # add channel dim: (1, D, H, W)
tensor = tensor.unsqueeze(0)  # add batch dim: (1, 1, D, H, W)
tensor.shape

torch.Size([1, 1, 90, 120, 120])
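The axis reordering, the two `unsqueeze` calls and the later `.float()` cast can be folded into one helper. This is a minimal sketch, and `volume_to_input` is a hypothetical name. It uses `np.moveaxis`, which here is equivalent to `transpose((2, 0, 1))`. It casts to `float32` because `get_fdata()` returns `float64` by default, while `nn.Conv3d` weights are `float32`.

```python
import numpy as np
import torch


def volume_to_input(volume: np.ndarray) -> torch.Tensor:
    """Convert an (H, W, D) NIfTI volume into an (N=1, C=1, D, H, W) float32 tensor."""
    reordered = np.moveaxis(volume, -1, 0)  # (H, W, D) -> (D, H, W)
    tensor = torch.from_numpy(np.ascontiguousarray(reordered))
    return tensor.float()[None, None]       # add batch and channel dims


fake_volume = np.zeros((120, 120, 90))      # float64, like get_fdata()
x = volume_to_input(fake_volume)
print(x.shape, x.dtype)  # torch.Size([1, 1, 90, 120, 120]) torch.float32
```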

In [38]:
# Consider trying to load part of an image from the dataset here. 
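One way to follow up on this is to crop a random sub-volume (patch), which speeds up forward passes while experimenting. This is a minimal sketch. `random_patch` and the 32-voxel patch size are hypothetical choices and not part of the dataset code. The helper applies the same crop to the T1 and T2 tensors, which keeps a training pair spatially aligned.

```python
import torch


def random_patch(t1: torch.Tensor, t2: torch.Tensor, size=32, generator=None):
    """Crop the same random size**3 patch from two (N, C, D, H, W) tensors."""
    d, h, w = t1.shape[-3:]
    # Pick a random corner so the patch fits entirely inside the volume.
    z = int(torch.randint(0, d - size + 1, (1,), generator=generator))
    y = int(torch.randint(0, h - size + 1, (1,), generator=generator))
    x = int(torch.randint(0, w - size + 1, (1,), generator=generator))
    crop = (..., slice(z, z + size), slice(y, y + size), slice(x, x + size))
    return t1[crop], t2[crop]


t1 = torch.randn(1, 1, 90, 120, 120)
t2 = torch.randn(1, 1, 90, 120, 120)
p1, p2 = random_patch(t1, t2)
print(p1.shape, p2.shape)  # torch.Size([1, 1, 32, 32, 32]) torch.Size([1, 1, 32, 32, 32])
```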

In [43]:
# Call the module directly rather than mcn.forward(...) so that any registered hooks also run.
out = mcn(tensor.float())
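A pure shape check does not need gradient tracking. This is a minimal sketch that assumes no training happens yet. `model.eval()` switches layers such as dropout or batch norm to inference behaviour. `MRConvNet` has neither, so the call is a habit here rather than a requirement. `torch.no_grad()` avoids storing activations for backpropagation. An `nn.Sequential` stand-in is used so the block runs on its own.

```python
import torch
import torch.nn as nn

# Stand-in for MRConvNet: one unpadded 3x3x3 convolution with 16 output channels.
model = nn.Sequential(nn.Conv3d(1, 16, 3), nn.ReLU())
x = torch.randn(1, 1, 90, 120, 120)

model.eval()
with torch.no_grad():
    out = model(x)

print(out.shape)          # torch.Size([1, 16, 88, 118, 118])
print(out.requires_grad)  # False: no autograd graph was recorded
```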

In [44]:
out.shape

torch.Size([1, 16, 88, 118, 118])
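The trimming follows from the output-size formula in the `nn.Conv3d` documentation, applied to each spatial dimension with dilation 1: out = floor((in + 2·padding − kernel_size) / stride) + 1. With kernel 3, stride 1 and no padding, 90 becomes 88 and 120 becomes 118. With `padding=1`, the size is unchanged. The small helper below (hypothetical name `conv_out_size`) checks this arithmetic.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of one spatial dimension of nn.Conv3d (dilation = 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


print([conv_out_size(s, 3) for s in (90, 120, 120)])             # [88, 118, 118]
print([conv_out_size(s, 3, padding=1) for s in (90, 120, 120)])  # [90, 120, 120]
```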

In [45]:
### Left off here
# - get back to 1 output channel (not 16)
# - zero-pad the input (the `padding` argument to Conv3d, e.g. padding=1 for a 3x3x3 kernel) so the kernel doesn't trim the borders (90 x 120 x 120 -> 88 x 118 x 118 above)
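Both to-do items can be handled together. `padding=1` keeps a 3 × 3 × 3 convolution from trimming the borders, and a final 1 × 1 × 1 convolution maps the 16 feature channels back to one output channel. This is a minimal sketch, and `MRConvNetPadded` is a hypothetical name. Leaving the last layer without a ReLU is an assumption that lets the output take any real value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MRConvNetPadded(nn.Module):
    """Hypothetical revision of MRConvNet: same spatial size in and out, 1 channel out."""

    def __init__(self, n_chans=(16,)):
        super().__init__()
        # padding=1 with a 3x3x3 kernel preserves D, H and W.
        self.conv1 = nn.Conv3d(1, n_chans[0], kernel_size=3, padding=1)
        # 1x1x1 convolution collapses the feature channels into a single T2-w estimate.
        self.out = nn.Conv3d(n_chans[0], 1, kernel_size=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.out(x)  # no activation on the output layer (assumption)


model = MRConvNetPadded()
x = torch.randn(1, 1, 90, 120, 120)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 90, 120, 120])
```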