In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os

from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt

import time
import pickle

from datetime import datetime
from pathlib import Path

In [None]:
class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) one.

    A tensor of shape (N, *rest) becomes (N, prod(rest)); for example the
    5-D output of a Conv3d stack becomes a 2-D (N, features) matrix that can
    be fed to nn.Linear.
    """

    def forward(self, input):
        # reshape rather than view: view raises on non-contiguous tensors
        # (e.g. after transpose/permute); reshape copies only when it must.
        return input.reshape(input.size(0), -1)
    
class CNN(nn.Module):
    """3-D convolutional network assembled from per-layer argument tuples.

    Every tuple in ``layers`` except the last describes one convolutional
    stage and is unpacked into :meth:`append_conv_layer`
    (``in_channels, out_channels[, kernel_size[, activation[, padding]]]``).
    The last tuple is unpacked into ``nn.Linear``
    (``in_features, out_features[, bias]``) and forms the head, preceded by a
    flatten and a dropout layer. With no tuples the model is an empty
    ``nn.Sequential`` (identity).

    Args:
        *layers: argument tuples as described above.
        dropout: dropout probability applied right before the final linear
            layer (keyword-only; defaults to the original 0.25).
    """

    def __init__(self, *layers, dropout=0.25):
        super().__init__()

        # Collected as a plain list so append_conv_layer can extend it during
        # construction; wrapped in nn.Sequential once every layer is added.
        self.model = []

        if layers:
            *conv_specs, head_spec = layers
            for spec in conv_specs:
                self.append_conv_layer(*spec)
            # nn.Flatten keeps dim 0 and flattens the rest (like the file's
            # Flatten helper) and also accepts non-contiguous tensors.
            self.model.append(nn.Flatten())
            self.model.append(nn.Dropout(dropout))
            self.model.append(nn.Linear(*head_spec))

        self.model = nn.Sequential(*self.model)

    def append_conv_layer(self, in_channels, out_channels, kernel_size=3,
                          activation=None, padding=1):
        """Append one Conv3d -> activation -> MaxPool3d -> BatchNorm3d stage.

        Intended to be called from ``__init__`` while ``self.model`` is still
        a plain list.

        Args:
            in_channels: channels entering the convolution.
            out_channels: channels produced by the convolution; also the
                feature count of the BatchNorm3d.
            kernel_size: Conv3d kernel size.
            activation: module applied after the convolution. Defaults to a
                new ``nn.ReLU()`` per call so stages (and separate models)
                never share one module instance.
            padding: Conv3d zero-padding. The default of 1 keeps the spatial
                size unchanged only when kernel_size is 3.
        """
        if activation is None:
            activation = nn.ReLU()

        self.model.append(nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding))
        self.model.append(activation)
        # ceil_mode=True rounds odd spatial sizes up instead of dropping the
        # trailing slice, so every stage at least halves (rounding up) D/H/W.
        self.model.append(nn.MaxPool3d(2, stride=2, ceil_mode=True))
        self.model.append(nn.BatchNorm3d(out_channels))

    def forward(self, x):
        """Run ``x`` through the assembled layer stack.

        Presumably ``x`` is (N, C, D, H, W) as Conv3d expects whenever conv
        stages exist — TODO confirm against callers.
        """
        return self.model(x)

1 2


In [None]:
# Reached only if every cell above (imports and model classes) ran without
# error; the message confirms that, e.g. when this notebook is pulled in via
# %run from another notebook (presumed usage — confirm).
print("Imported CNN model.")