**This Google Colab notebook implements a U-Net designed by the Allen Institute** (https://github.com/AllenCellModeling/pytorch_fnet)

Typically, training a model requires access to GPUs. Google provides free GPUs through Google Colaboratory, so I'll use that for training. The trained model parameters can then be saved and used to run predictions locally; that step shouldn't require a GPU.

Google Colab is a Jupyter Notebook style platform. To enable the GPU, click "Runtime --> Change runtime type" and select GPU as the hardware accelerator.



**<font color='green'>Parameters</font>**



*IMG_WIDTH*, *IMG_HEIGHT*
  
  - image dimensions; you can choose dimensions that differ from those of your input images, and the code will resize the images to the new dimensions

*batch_size*
 

  - optimization is done in batches so that the data fit within the memory constraints of the hardware. Choose a *batch_size* that gives representative subsets of the data: each update uses the gradient averaged over one batch, so non-representative batches will mislead the gradient descent and lead to non-optimal results (i.e., decrease our likelihood of reaching a reasonable minimum of the loss function)
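
To make the averaging argument concrete, here is a minimal pure-Python sketch (not part of the notebook's pipeline). It assumes a toy one-parameter model `w * x` with a squared-error loss; the data values and the helper name `mean_gradient` are made up for illustration. With equally sized batches, the average of the per-batch gradients equals the gradient over the full dataset.

```python
def mean_gradient(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Toy data (hypothetical): y = 2x, deliberately sorted so each batch is unrepresentative
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
batch_size = 2

full_grad = mean_gradient(w, xs, ys)
batch_grads = [
    mean_gradient(w, xs[i:i + batch_size], ys[i:i + batch_size])
    for i in range(0, len(xs), batch_size)
]
avg_batch_grad = sum(batch_grads) / len(batch_grads)

print("full gradient:     ", full_grad)
print("per-batch gradients:", batch_grads)
print("average of batches: ", avg_batch_grad)
```

The two batch gradients differ a lot from each other because the sorted data make each batch a poor sample of the whole set. Shuffling (as `DataLoader(..., shuffle=True)` does) and a larger *batch_size* both reduce that spread.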

*learning_rate*

  - determines the step size of the gradient descent. Typical values range from 0.0001 (slow learning) to 1 (fast learning). *learning_rate* is usually determined heuristically; there is no universal formula
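
A rough illustration of the slow/fast trade-off, assuming plain gradient descent on the toy function f(w) = w² rather than on the U-Net loss. The helper `gradient_descent` and the three rates are chosen only for demonstration.

```python
def gradient_descent(learning_rate, steps=20, w0=1.0):
    """Minimize f(w) = w^2 (gradient 2w) and return the final w."""
    w = w0
    for _ in range(steps):
        w -= learning_rate * 2 * w
    return w

for lr in (0.0001, 0.1, 1.5):
    print(f"learning_rate={lr}: w after 20 steps = {gradient_descent(lr)}")
```

With the smallest rate, w barely moves from its starting value; with 0.1 it gets close to the minimum at 0; with 1.5 every step overshoots and the iterate diverges.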

*numclasses*

  - This model is designed for 2 classes: foreground and background. So *numclasses*=2

*mult_chan*

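  - Vestige of the Allen Institute implementation: the factor by which the first level multiplies the number of input channels. This notebook has no separate output convolution, so the network returns *n_in_channels* × *mult_chan* channels. With 1 channel of image data and a single-channel mask, use *mult_chan* = 1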
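
A small sketch of how the channel count evolves in the recursive `MyModel` class defined further down, assuming (as in that class) that the top level multiplies by *mult_chan* and every nested level is built with a factor of 2. `channels_per_level` is a hypothetical helper, not part of the notebook.

```python
def channels_per_level(n_in_channels, mult_chan, depth):
    """Output channels of the double-conv block at each level, top to bottom."""
    channels = []
    n_channels = n_in_channels
    factor = mult_chan              # only the top level uses mult_chan
    for _ in range(depth + 1):      # depth downsampling steps -> depth + 1 levels
        n_channels *= factor
        channels.append(n_channels)
        factor = 2                  # nested levels are created with mult_chan=2
    return channels

print(channels_per_level(n_in_channels=1, mult_chan=1, depth=6))
```

The first entry is also the number of channels the network outputs, which is why *mult_chan* = 1 pairs with a single-channel mask.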

*epochs*

  - Training duration. In general, keep learning until you either can't or won't! No universal formula for this either

*betas*

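  - Exponential decay rates for Adam's running averages of the gradient and of the squared gradient. PyTorch's defaults are (0.9, 0.999); the parameter cell below uses (0.5, 0.999). We can generally leave these alone.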
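
For readers curious what the two coefficients do, here is a minimal sketch of the textbook Adam update for a single scalar parameter. It is an illustration under simplifying assumptions (no weight decay, no AMSGrad), not PyTorch's implementation; `adam_step` is a hypothetical helper and f(w) = w² is a toy objective.

```python
import math

def adam_step(w, grad, m, v, t, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update for a scalar parameter (t counts steps from 1)."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad        # running average of the gradient
    v = beta2 * v + (1 - beta2) * grad ** 2   # running average of the squared gradient
    m_hat = m / (1 - beta1 ** t)              # bias correction for the zero start
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    grad = 2 * w                              # gradient of f(w) = w^2
    w, m, v = adam_step(w, grad, m, v, t, betas=(0.5, 0.999))
    print(f"step {t}: w = {w}")
```

A smaller first coefficient (0.5 instead of 0.9) makes the gradient average forget old gradients faster, so the update direction reacts more quickly to changes.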

*depth*

  - How many downsampling levels deep should the U-Net be? Typically 4-6. *My current implementation fails after something like depth=8, which I have to fix. Another vestige from the Allen Institute implementation*
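
One plausible explanation for the failure at large depth, assuming 128×128 inputs and the 2×2, stride-2 downsampling convolutions used in `MyModel`: each level halves the image side, so a 128-pixel side reaches 1 pixel after 7 levels and cannot be halved an eighth time. The helper below is a hypothetical sketch of that bookkeeping, not a confirmed diagnosis. It also rejects odd sides, since a map that is halved and then doubled again would no longer match its skip connection in size.

```python
def feature_map_sizes(img_size, depth):
    """Side length of the feature map at each level, top to bottom."""
    sizes = [img_size]
    for level in range(1, depth + 1):
        current = sizes[-1]
        if current < 2 or current % 2 != 0:
            raise ValueError(
                f"cannot downsample a {current}x{current} map at level {level}")
        sizes.append(current // 2)
    return sizes

print(feature_map_sizes(128, 6))

try:
    feature_map_sizes(128, 8)
except ValueError as err:
    print("depth=8 fails:", err)
```

Under these assumptions, a deeper network would need a larger *IMG_WIDTH* and *IMG_HEIGHT* (for example 256 for depth 8).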

*contrast/brightness/saturation*

  - Values for data augmentation. E.g., *contrast = 0.1* scales the contrast of each training image by a random factor between 0.9 and 1.1 (a change of up to 10% in either direction). This helps us prevent overfitting (kind of like translational invariance, but more abstract)
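
The sketch below shows the idea in plain Python on a flat list of pixel intensities. These are simplified stand-ins for what a color-jitter transform does, written for illustration only; the function names and the assumption that intensities lie in [0, 1] are mine, not the notebook's.

```python
import random

def jitter_factor(strength, rng=random):
    """Sample a factor uniformly from [max(0, 1 - strength), 1 + strength]."""
    return rng.uniform(max(0.0, 1.0 - strength), 1.0 + strength)

def adjust_brightness(pixels, factor):
    """Scale intensities (assumed in [0, 1]) and clip back into range."""
    return [min(1.0, max(0.0, factor * p)) for p in pixels]

def adjust_contrast(pixels, factor):
    """Blend each pixel with the image mean; factor 1.0 leaves the image unchanged."""
    mean = sum(pixels) / len(pixels)
    return [min(1.0, max(0.0, factor * p + (1.0 - factor) * mean)) for p in pixels]

rng = random.Random(0)               # fixed seed so the run is repeatable
pixels = [0.2, 0.4, 0.6, 0.8]
factor = jitter_factor(0.1, rng)
print("factor:", factor)
print("contrast-adjusted:  ", adjust_contrast(pixels, factor))
print("brightness-adjusted:", adjust_brightness(pixels, factor))
```

Because a new factor is drawn every time an image is loaded, the network rarely sees exactly the same pixel values twice.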

*optimizer_fx*

  - Choose your optimization scheme. 
    - *Adam*
    - *SGD*

*loss_fx*

  - Choose your loss function. 
    - *BCEWithLogitsLoss*
    - *BCELoss*
    - *CrossEntropyLoss*
    - *NLLLoss*
    - *MSELoss*
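
A note on choosing between the first two options: the model defined below ends in a convolution followed by batch normalization, with no sigmoid, so its outputs are raw scores (logits). *BCEWithLogitsLoss* applies the sigmoid internally, whereas *BCELoss* expects probabilities. The pure-Python sketch below (helper names are mine) checks on a few made-up values that the two formulations agree once the sigmoid is accounted for.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce(prob, target):
    """Binary cross-entropy on a probability in (0, 1)."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

def bce_with_logits(logit, target):
    """The same loss computed from the raw logit in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

examples = [(2.0, 1.0), (-1.5, 0.0), (0.3, 1.0)]   # (logit, target) pairs
for logit, target in examples:
    print(logit, target, bce(sigmoid(logit), target), bce_with_logits(logit, target))
```

If *BCELoss* is selected instead, the model outputs would need to pass through a sigmoid first.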

In [0]:
#  First, mount Google Drive. This lets us store data on Google Drive 
#  and access it (relatively quickly) from inside this notebook
from google.colab import drive
drive.mount('/content/drive')

In [0]:
#@title libraries
from tqdm import tqdm
import torch.nn.functional as F
from torch.autograd import Variable
import torch
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
from skimage.io import imread
from skimage.transform import resize
from IPython.display import clear_output
from albumentations import (HorizontalFlip, ShiftScaleRotate, Normalize,
                            Resize, Compose, GaussNoise)
# ToTensor is not available in recent albumentations releases; ToTensorV2 replaces it
from albumentations.pytorch import ToTensorV2

In [0]:
# # PARAMETERS

# normalization = True
IMG_WIDTH = 128
IMG_HEIGHT = 128

batch_size = 32
learning_rate = 0.001

numclasses = 2
mult_chan = 1
epochs = 750
depth = 6
betas = (0.5, 0.999)  # taken from paper
contrast = 0.1
brightness = 0.1
saturation = 0.1

optimizer_fx = 'Adam'     
loss_fx = 'BCEWithLogitsLoss'

# path to folder that houses subfolders (folder for signal and folder for ground truth)
TRAIN_PATH = '/content/drive/My Drive/DeepLearningProject/Training Data/Mito/'
# signal subfolder
TRAIN_IMAGES_PATH = TRAIN_PATH + 'MitoSignal/'
# ground-truth subfolder
TRAIN_MASKS_PATH = TRAIN_PATH + 'MitoMask/'

# path to save model parameters after optimization
MODEL_SAVE_PATH = '/content/drive/My Drive/DeepLearningProject/'
# filename 
MODEL_NAME = ('modelmito_' + str(IMG_WIDTH) + '_'+ str(batch_size) + '_' +
 str(learning_rate) + '_' + str(epochs) + '_' + str(depth) + '_' + str(contrast)[1:3] + 
 str(brightness)[1:4] + str(saturation)[1:4] + '_' + loss_fx + '_' + optimizer_fx +'.pt')
# print filename
print("MODEL_NAME: {}".format(MODEL_NAME))


MODEL_NAME: modelmito_128_32_0.001_750_6_.1.1.1_BCEWithLogitsLoss_Adam.pt


In [0]:
#@title read images

# GRAB PATHS TO TRAINING SET IMAGES
train_ids = next(os.walk(TRAIN_IMAGES_PATH))[2]
train_images = np.zeros((len(train_ids), IMG_WIDTH, IMG_HEIGHT,1), dtype=np.float32)
# np.bool was removed from recent NumPy releases; the builtin bool is equivalent
train_masks = np.zeros((len(train_ids), IMG_WIDTH, IMG_HEIGHT,1), dtype=bool)

# READ IN SIGNAL FOR TRAINING
n = 0
num_val_chosen = 0 
for filename in sorted(os.listdir(TRAIN_IMAGES_PATH)):
  img = imread(os.path.join(TRAIN_IMAGES_PATH, filename))
  img = resize(img, (IMG_WIDTH, IMG_HEIGHT), mode='constant', preserve_range=True)
  img -= np.mean(img)
  img /= np.std(img)
  img[img < 0] = 0
  train_images[n,:,:,0] = img
  n+=1
  if np.mod(n,10) == 0:
    clear_output()
    print("Loaded {}/{} of signal images".format(n,len(os.listdir(TRAIN_IMAGES_PATH))))

# READ IN MASKS FOR TRAINING
n = 0
for filename in sorted(os.listdir(TRAIN_MASKS_PATH)):
  img = imread(os.path.join(TRAIN_MASKS_PATH, filename))
  img = resize(img, (IMG_WIDTH, IMG_HEIGHT), mode='constant', preserve_range=True)
  train_masks[n,:,:,0] = img
  n+=1
  if np.mod(n,10) == 0:
    clear_output()
    print("Loaded {}/{} of mask images".format(n,len(os.listdir(TRAIN_MASKS_PATH))))



In [0]:
#@title define dataset class and assign images


# DEFINE DATASET CLASS
class FormsDataset(Dataset):
    def __init__(self, images, masks, num_classes: int, transforms=None):
        self.images = images
        self.masks = masks
        self.num_classes = num_classes
        self.transforms = transforms
    
    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        image = image.astype(np.float32)
        mask = mask.astype(bool)  # np.bool was removed from recent NumPy releases

        if self.transforms is not None:
          image = self.transforms(image)

        return image, mask
    
    def __len__(self):
        return len(self.images)

#@title data augmentation

my_transforms = transforms.Compose([transforms.ToPILImage(),
                                    transforms.Grayscale(num_output_channels=1),
                                    transforms.ColorJitter(brightness=brightness,contrast=contrast,saturation=saturation),
                                    transforms.ToTensor()])


In [0]:
#@title define model layers
# define model first; based on paper
class SubNet2Conv(torch.nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(n_in,  n_out, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(n_out)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(n_out, n_out, kernel_size=3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(n_out)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x

class MyModel(torch.nn.Module):
  def __init__(self, n_in_channels, mult_chan=2,depth=1):
    super().__init__()

    n_out_channels = n_in_channels * mult_chan
    self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels)

    if depth > 0:
        self.sub_2conv_less = SubNet2Conv(2*n_out_channels, n_out_channels)
        self.conv_down = torch.nn.Conv2d(n_out_channels, n_out_channels, 2, stride=2)
        self.bn0 = torch.nn.BatchNorm2d(n_out_channels)
        self.relu0 = torch.nn.ReLU()
            
        self.convt = torch.nn.ConvTranspose2d(2*n_out_channels, n_out_channels, kernel_size=2, stride=2)
        self.bn1 = torch.nn.BatchNorm2d(n_out_channels)
        self.relu1 = torch.nn.ReLU()
        self.sub_u = MyModel(n_out_channels,mult_chan=2,depth=depth-1)

        # self.relu2 = torch.nn.ReLU()
    self.depth = depth

  def forward(self, x):
    if self.depth == 0:
        return self.sub_2conv_more(x)
    else:  # depth > 0
        x_2conv_more = self.sub_2conv_more(x)
        x_conv_down = self.conv_down(x_2conv_more)
        x_bn0 = self.bn0(x_conv_down)
        x_relu0 = self.relu0(x_bn0)
        x_sub_u = self.sub_u(x_relu0)
        x_convt = self.convt(x_sub_u)
        x_bn1 = self.bn1(x_convt)
        x_relu1 = self.relu1(x_bn1)
        x_cat = torch.cat((x_2conv_more, x_relu1), 1)  
        x_2conv_less = self.sub_2conv_less(x_cat)
    return x_2conv_less


In [0]:
#@title model summary and init

# # CONNECT TO GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# # INITIALIZE AND ADD MODEL TO GPU
model = MyModel(n_in_channels=1,mult_chan = mult_chan,depth=depth).to(device)

# # SUMMARY
summary(model,input_size=(1,IMG_WIDTH,IMG_HEIGHT))

In [0]:
#@title training example image

ind = np.random.randint(0,len(train_images))
fig, ax = plt.subplots(1,2,figsize=(8,4))
ax[0].imshow(train_images[ind,:,:,0])
ax[0].set_title('Signal')
ax[1].imshow(train_masks[ind,:,:,0])
ax[1].set_title('Ground truth')

In [0]:
#@title train model

train_dataset = FormsDataset(train_images, train_masks, numclasses, transforms=my_transforms)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)


# Use gpu for training if available else use cpu

# Here is the loss and optimizer definition

# criterion = torch.nn.BCELoss()
# m = torch.nn.LogSoftmax(dim=1)

# Select the loss by exact, case-insensitive name (a substring test would
# also match partial names such as 'loss')
loss_name = loss_fx.lower()
if loss_name == 'bcewithlogitsloss':
  criterion = torch.nn.BCEWithLogitsLoss()
elif loss_name == 'bceloss':
  criterion = torch.nn.BCELoss()
elif loss_name == 'nllloss':
  criterion = torch.nn.NLLLoss()
elif loss_name == 'mseloss':
  criterion = torch.nn.MSELoss()
elif loss_name == 'crossentropyloss':
  criterion = torch.nn.CrossEntropyLoss()
else:
  # unrecognized name: fall back to the default loss
  criterion = torch.nn.BCEWithLogitsLoss()

# Select the optimizer the same way
optimizer_name = optimizer_fx.lower()
if optimizer_name == 'adam':
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                               betas=betas)
elif optimizer_name == 'sgd':
  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
else:
  # the original branch was left empty; raising here is an assumed choice
  raise ValueError(f"Unknown optimizer_fx: {optimizer_fx}")


# m = torch.nn.Sigmoid()
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=1,gamma=0.1)

# The training loop
total_steps = len(train_data_loader)
iter = 0
print(f"{epochs} epochs, {total_steps} total_steps per epoch")

for epoch in range(epochs):
    for i, (images, masks) in enumerate(train_data_loader, 1):

        images = images.to(device)
        masks = masks.type(torch.LongTensor)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        masks = masks.permute(0,3,1,2)
        loss = criterion(outputs,masks.type_as(outputs))

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    iter +=1
    if np.mod(iter,10) == 0:
      print(f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{total_steps}], Loss: {loss.item():.4f}")


In [0]:
#@title save model after training
# SAVE MODEL
torch.save(model.state_dict(),MODEL_SAVE_PATH + MODEL_NAME)

In [0]:
#@title check predictions on training data


#prep training data
# DEFINE DATASET CLASS
class ValDataset(Dataset):
    def __init__(self, images, num_classes: int, transforms=None):
        self.images = images
        self.num_classes = num_classes
        self.transforms = transforms
    
    def __getitem__(self, idx):
        image = self.images[idx]
        image = image.astype(np.float32)
        image = image
    
        return image
    
    def __len__(self):
        return len(self.images)

# read validation data
val_dataset = ValDataset(np.moveaxis(train_images,-1,1),numclasses,None)
val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

val_images = []
for i, images in enumerate(val_data_loader, 1):
  images = images.to(device).type(torch.FloatTensor)
  val_images.append(images)

  
# VISUALIZE A CELL FROM THE TRAINING SET AND SEE MODEL'S PREDICTION
cell = np.random.randint(0,len(val_images))

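# Switch to evaluation mode so BatchNorm uses its running statistics,
# and skip gradient tracking since we are only predicting
model.eval()
with torch.no_grad():
  pred = model(val_images[cell].to(device))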

fig, ax = plt.subplots(2,3,figsize=(12,6))
ax[0,0].imshow(train_images[cell,:,:,0])
ax[0,0].set_title("Signal")
ax[0,1].imshow(pred[0,0,:,:].cpu().detach())
ax[0,1].set_title("Prediction")
ax[0,2].imshow(train_masks[cell,:,:,0])
ax[0,2].set_title("Ground-truth")

ax[1,0].hist(train_images[cell,:,:,0]);
ax[1,1].hist(pred[0,0,:,:].cpu().detach());
ax[1,2].hist(train_masks[cell,:,:,0]);

