In [1]:
import torch
import torch.nn as nn

SEED = 0  # global RNG seed; change here to get a different (but reproducible) run

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# for reproducibility: seed torch's random number generators
torch.manual_seed(SEED)
# keep cuDNN autotuning off: the autotuner may select different algorithms
# from run to run, which breaks reproducibility
torch.backends.cudnn.benchmark = False
# restrict cuDNN to deterministic algorithms
torch.backends.cudnn.deterministic = True

In [10]:
# TODO: read data — placeholder cell, no loading code has been written yet


In [12]:
class Encoder(nn.Module):
    """3D convolutional encoder that halves every spatial dimension.

    One stride-2 ``Conv3d`` (kernel 3, padding 1) followed by ``BatchNorm3d``
    and ``ReLU``. Each spatial size ``S`` becomes ``ceil(S / 2)``, e.g.
    ``N x in_channels x 145 x 174 x 145 --> N x latent_channels x 73 x 87 x 73``.

    Args:
        in_channels: number of channels of the input volume (default 1).
        latent_channels: number of feature maps produced by the convolution.
            ``BatchNorm3d`` is sized from the same value so the two layers can
            never disagree. Default 1, which matches the 1-channel input the
            ``Decoder`` in this notebook is built for.
    """

    def __init__(self, in_channels: int = 1, latent_channels: int = 1):
        super().__init__()

        self.encode = nn.Sequential(
            # N x in_channels x D x H x W
            #   --> N x latent_channels x ceil(D/2) x ceil(H/2) x ceil(W/2)
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=latent_channels,
                kernel_size=3,
                stride=2,
                padding=1,
            ),
            # num_features must equal the conv's out_channels, otherwise
            # forward() fails with a running-stats size mismatch
            nn.BatchNorm3d(latent_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x`` (N x in_channels x D x H x W) into the latent volume."""
        return self.encode(x)

class Decoder(nn.Module):
    """3D convolutional decoder that roughly doubles every spatial dimension.

    A stride-2 ``ConvTranspose3d`` (+ ``BatchNorm3d`` + ``ReLU``) upsamples the
    latent volume. A stride-1 ``Conv3d`` (+ ``BatchNorm3d``) then smooths the
    result; it is intended to reduce checkerboard artifacts from the
    transposed convolution. Each spatial size ``S`` becomes
    ``2 * S - 1 + output_padding``. With the default ``output_padding=(0, 1, 0)``:
    ``N x latent_channels x 73 x 87 x 73 --> N x out_channels x 145 x 174 x 145``.

    Args:
        latent_channels: channels of the incoming latent volume (default 1).
        out_channels: channels of the reconstructed volume (default 1).
        output_padding: extra size added to one side of each output dimension
            of the transposed conv. The default ``(0, 1, 0)`` restores the
            even height 174 from 87.
    """

    def __init__(
        self,
        latent_channels: int = 1,
        out_channels: int = 1,
        output_padding=(0, 1, 0),
    ):
        super().__init__()

        self.decode = nn.Sequential(
            # N x latent_channels x D x H x W
            #   --> N x out_channels x (2D-1+op_d) x (2H-1+op_h) x (2W-1+op_w)
            nn.ConvTranspose3d(
                in_channels=latent_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=output_padding,
            ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),

            # layer to remove checkerboard artifacts (shape-preserving)
            nn.Conv3d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),  # this comma was missing, making the whole cell a SyntaxError
            # NOTE(review): BatchNorm on the final output re-normalizes the
            # reconstruction per batch; confirm this matches the target's scale.
            nn.BatchNorm3d(out_channels),
        )

    def forward(self, feature):
        """Decode ``feature`` (N x latent_channels x D x H x W) into a volume."""
        return self.decode(feature)