# This notebook shows how to train the new-birth networks.

Set the working directory to the "Single_Particle_Tracking" folder.

Download and unzip the [data and intermediate results](https://drive.google.com/open?id=1AO6du609gYup2mcyKIWEqU5dH5p8Fa4K) into the working directory.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from __future__ import print_function
import matplotlib.pyplot as plt
from skimage import io
import random 
import numpy as np

from skimage import measure
from skimage.measure import label 
import scipy as sp
import scipy.ndimage.morphology

In [None]:
# Load the blurry observation stack (obs.tif) and the ground-truth label stack (label.tif)
# from the unzipped data folder; both are cast to float32.
from skimage import io

folder = './data_and_pre_calculated_results/'
obs = io.imread(f'{folder}obs.tif').astype('float32')
labels = io.imread(f'{folder}label.tif').astype('float32')
img_size = obs.shape[-1]  # frame width; the plotting cells reshape frames to img_size x img_size

print(obs.shape, labels.shape, sep='\n')

In [None]:
# plot raw data: top row = observations, bottom row = ground-truth label maps,
# with every labelled pixel marked (yellow x) and annotated with its particle id

import random
n = 5
i_base = random.randint(0, 200)
plt.figure(figsize=(24, 10))
for i in range(n):
    frame = i + i_base

    ax = plt.subplot(2, n, i + 1)
    plt.imshow(obs[frame, :, :].reshape(img_size, img_size), interpolation='none', cmap='jet')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('obs. i=' + str(frame))
    plt.colorbar()

    ax = plt.subplot(2, n, i + n + 1)
    label_frame = labels[frame, :, :]
    plt.imshow(label_frame.reshape(img_size, img_size), interpolation='none', cmap='jet')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('true. i=' + str(frame))
    plt.colorbar()
    # (row, col) of every non-background pixel, plus the label value stored there
    rows, cols = np.nonzero(label_frame)
    for r, c, particle_id in zip(rows, cols, label_frame[rows, cols]):
        plt.text(c, r, str(particle_id.astype('int')), color=[.7, .7, .7])
        plt.plot(c, r, 'yx')

    plt.xlim(0, img_size - 1)
    plt.ylim(img_size - 1, 0)

In [None]:
# normalize data: split the frames into a training and a test range, then scale
# every frame so that its brightest pixel equals 1

T_train = 9000  # number of frames chosen to train
T_all = 10000   # total number frames used (train+test)

x_train = obs[0:T_train, :, :]
x_train_before_normalization = x_train.copy()
print(x_train.shape)
y_train = labels[0:T_train, :, :]
print(y_train.shape)

x_test = obs[T_train:T_all, :, :]
x_test_before_normalization = x_test.copy()
print(x_test.shape)
y_test = labels[T_train:T_all, :, :]
print(y_test.shape)

# astype returns fresh copies, so the in-place division below never touches `obs`.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

# normalization: divide each frame by its own maximum, vectorized over all frames.
# A frame whose maximum is 0 would produce 0/0 = NaN, so it is divided by 1 instead.
for frames in (x_train, x_test):
    peak = frames.max(axis=(1, 2), keepdims=True)
    frames /= np.where(peak == 0, 1, peak)

In [None]:
# plot data after normalization: top row = normalized training frames, bottom row =
# training labels with every labelled pixel marked (yellow x) and annotated with its id.
# The bottom-row image and its annotations both read from y_train, so they always agree.
import random
n = 5
i_base = random.randint(0, 200)
plt.figure(figsize=(24, 10))
for i in range(n):
    frame = i + i_base
    label_frame = y_train[frame, :, :]

    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_train[frame, :, :], interpolation='none', cmap='jet')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('obs. i=' + str(frame))
    plt.colorbar()

    ax = plt.subplot(2, n, i + n + 1)
    plt.imshow(label_frame, interpolation='none', cmap='jet')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('true. i=' + str(frame))
    plt.colorbar()
    # (row, col) of every non-background pixel, plus the label value stored there
    rows, cols = np.nonzero(label_frame)
    for r, c, particle_id in zip(rows, cols, label_frame[rows, cols]):
        plt.text(c, r, str(int(particle_id)), color=[.7, .7, .7])
        plt.plot(c, r, 'yx')

    plt.xlim(0, img_size - 1)
    plt.ylim(img_size - 1, 0)

# Prepare NN inputs and outputs for the new-birth networks

In [None]:
# Building the arrays takes a while. To skip it, load the pre-computed arrays instead.
# They are saved channel-first, so transpose them back as in the commented-out cell of
# the "plot NN input and output" section:
#   NN_inputs = np.load(folder+'NN_inputs_newbirth.npy')
#   NN_outputs = np.load(folder+'NN_outputs_newbirth.npy')
#
# Each sample is centred on a training frame t and stored in row k = t - HALF_WINDOW:
#   NN_inputs[k]  : normalized frames t-2..t+2, plus one mask channel marking the pixels of
#                   particles whose label already occurred before frame t ("sampled")
#   NN_outputs[k] : mask of the pixels of particles whose label first occurs at frame t
# Only centre frames with a full window are used, so every row holds a real sample.
# Label 0 is treated as background.

HALF_WINDOW = 2  # frames of context on each side of the centre frame
n_frames = x_train.shape[0]
Ntrain = max(n_frames - 2 * HALF_WINDOW, 0)  # number of centre frames with a full window

# channels: 2 * HALF_WINDOW + 1 frames + 1 sampled-particle mask = 6
NN_inputs = np.zeros([Ntrain, 2 * HALF_WINDOW + 2, img_size, img_size, 1])
NN_outputs = np.zeros([Ntrain, img_size, img_size, 1])

# First frame in which each particle label occurs. A label present in frame t is new-born
# exactly when first_seen[label] == t. Computing this once avoids rescanning y_train[:t]
# for every particle of every frame, which is quadratic in the number of frames.
first_seen = {}
for t in range(n_frames):
    for value in np.unique(y_train[t]):
        if value != 0:
            first_seen.setdefault(value, t)

for t in range(HALF_WINDOW, n_frames - HALF_WINDOW):
    if t % 1000 == 0:
        print(t)
    frame_labels = y_train[t, :, :]
    new_labels = [v for v in np.unique(frame_labels) if v != 0 and first_seen[v] == t]

    # A pixel holds a single label value, so the new-born and sampled masks cannot overlap.
    new_borns_img = np.isin(frame_labels, new_labels)
    sampled_mask = (frame_labels != 0) & ~new_borns_img

    row = t - HALF_WINDOW
    NN_inputs[row, :-1, :, :, 0] = x_train[t - HALF_WINDOW:t + HALF_WINDOW + 1, :, :]
    NN_inputs[row, -1, :, :, 0] = sampled_mask   # bool -> 0.0 / 1.0
    NN_outputs[row, :, :, 0] = new_borns_img

In [None]:
# Shapes of the prepared arrays: targets first, then inputs.
print(NN_outputs.shape, NN_inputs.shape, sep='\n')

In [None]:
# Insert a singleton axis after the sample axis of the target masks (for the NN output),
# so they can be saved in the same 5-D channel-first layout as the inputs.
NN_outputs1 = NN_outputs[:, np.newaxis]
NN_outputs1.shape

In [None]:
# Save both arrays channel-first. The trailing singleton axis moves to position 1:
#   inputs  (N, 6, H, W, 1) -> (N, 1, 6, H, W)
#   outputs (N, 1, H, W, 1) -> (N, 1, 1, H, W)
# np.save appends the '.npy' extension to the file names.
for name, arr in (('NN_inputs_newbirth', NN_inputs), ('NN_outputs_newbirth', NN_outputs1)):
    np.save(folder + name, np.moveaxis(arr, -1, 1))

# Plot NN inputs and outputs

In [None]:
# folder = './data_and_pre_calculated_results/'
# NN_inputs = np.load(folder+'NN_inputs_newbirth.npy')
# NN_outputs = np.load(folder+'NN_outputs_newbirth.npy')
# NN_inputs = np.transpose(NN_inputs, (0, 2, 3, 4, 1))
# NN_outputs = np.transpose(NN_outputs, (0, 2, 3, 4, 1))[:, 0, :, :, :]

In [None]:
# print(NN_outputs.shape)
# print(NN_inputs.shape)

In [None]:
# visualize neural network input and output of one random sample (for debug)
# Top row: the observation channels 0..4 of the sample.
# Bottom row: its sampled-particle mask (last input channel) and its new-born target.
import random
n = 5
i_base = random.randint(0, 100)
plt.figure(figsize=(15, 6))
for i in range(n):
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(NN_inputs[i_base, i, :, :, 0], interpolation='none', cmap='jet')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('sample %d, ch. %d' % (i_base, i))
    plt.colorbar()

# These two panels do not depend on i, so they are drawn once, outside the loop.
ax = plt.subplot(2, n, n + 1)
plt.imshow(NN_inputs[i_base, -1, :, :, 0], interpolation='none', cmap='jet')
ax.set_title('sampled mask (input ch. %d)' % (NN_inputs.shape[1] - 1))

ax = plt.subplot(2, n, n + 2)
plt.imshow(NN_outputs[i_base, :, :, 0], interpolation='none', cmap='jet')
ax.set_title('new-born target');

In [1]:
import torch
import torch.nn as nn
from torchinfo import summary
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from dataset import ParticleDataset
from model import SamplingNet

# Define data loaders

In [2]:
# Seed torch's global generator so the random 95% / 5% train/test split below is reproducible.
# NOTE(review): consecutive samples share most of their input frames (5-frame windows), so a
# random per-sample split puts near-duplicate windows in both sets. A contiguous split in time
# would give a more honest test score. Consider changing this.
torch.manual_seed(0)
full_dataset = ParticleDataset('./data_and_pre_calculated_results/', 'NN_inputs_newbirth.npy', 'NN_outputs_newbirth.npy')
n_total = len(full_dataset)
train_size = int(0.95 * n_total)
test_size = n_total - train_size
trainset, testset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

loader_kwargs = dict(batch_size=10, num_workers=8)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

In [3]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Define Model

In [4]:
# Build the network on the selected device and print a layer summary for one dummy batch.
# input_size (10, 1, 6, 60, 60) mirrors the loader batch size (10) and the saved
# channel-first sample layout (1, 6, H, W). H = W = 60 is assumed here; TODO: confirm it matches img_size.
model = SamplingNet(last_dim=6)
model = model.to(device)
summary(model, input_size=(10, 1, 6, 60, 60))

Layer (type:depth-idx)                        Output Shape              Param #
SamplingNet                                   --                        --
├─Sequential: 1-1                             [10, 1, 1, 60, 60]        --
│    └─ConvBLSTM: 2-1                         [10, 64, 6, 60, 60]       --
│    │    └─ConvLSTM: 3-1                     [10, 32, 6, 60, 60]       38,144
│    │    └─ConvLSTM: 3-2                     [10, 32, 6, 60, 60]       38,144
│    └─BatchNorm3d: 2-2                       [10, 64, 6, 60, 60]       128
│    └─Dropout3d: 2-3                         [10, 64, 6, 60, 60]       --
│    └─ConvBLSTM: 2-4                         [10, 80, 6, 60, 60]       --
│    │    └─ConvLSTM: 3-3                     [10, 40, 6, 60, 60]       149,920
│    │    └─ConvLSTM: 3-4                     [10, 40, 6, 60, 60]       149,920
│    └─BatchNorm3d: 2-5                       [10, 80, 6, 60, 60]       160
│    └─Dropout3d: 2-6                         [10, 80, 6, 60, 60]       --


# Define optimizer and loss function

In [5]:
# optimizer
# loss: binary cross-entropy between per-pixel predictions and the 0/1 new-born targets
# (nn.BCELoss expects the model output to already be a probability in [0, 1])
criterion = nn.BCELoss()
# optimizer: RMSprop with a small learning rate and light L2 weight decay
optimizer = torch.optim.RMSprop(model.parameters(), weight_decay=1e-6, lr=0.0001)

# Training

In [6]:
num_epoches = 10


def _binary_metrics(y_true, y_pred):
    """Pixel-wise accuracy, F1, precision and recall for flattened 0/1 label arrays."""
    return dict(
        acc=accuracy_score(y_true, y_pred),
        f1=f1_score(y_true, y_pred, zero_division=0),
        prec=precision_score(y_true, y_pred, zero_division=0),
        recall=recall_score(y_true, y_pred, zero_division=0),
    )


for epoch in range(num_epoches):  # loop over the dataset multiple times

    # training
    model.train()
    train_bar = tqdm(enumerate(trainloader), total=len(trainloader), leave=False)

    for i, (data, targets) in train_bar:
        data = data.to(device)
        targets = targets.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        pred = outputs.gt(0.5).float().cpu().numpy().flatten()
        targets = targets.cpu().numpy().flatten()

        train_bar.set_description(f'Epoch [{epoch+1}/{num_epoches}]')
        train_bar.set_postfix(loss=loss.item(), **_binary_metrics(targets, pred))

    # validating: report this batch's validation loss (not the last training loss), and
    # print loss and metrics over the whole validation set once the pass is done, because
    # the progress bar is cleared (leave=False).
    model.eval()
    test_bar = tqdm(enumerate(testloader), total=len(testloader), leave=False)
    val_loss_sum, val_count = 0.0, 0
    all_targets, all_preds = [], []

    with torch.no_grad():
        for i, (data, targets) in test_bar:
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)
            val_loss = criterion(outputs, targets)
            # weight by batch size so the epoch mean is exact even if the last batch is smaller
            val_loss_sum += val_loss.item() * data.shape[0]
            val_count += data.shape[0]

            pred = outputs.gt(0.5).float().cpu().numpy().flatten()
            targets = targets.cpu().numpy().flatten()
            all_targets.append(targets)
            all_preds.append(pred)

            test_bar.set_description(f'Epoch [{epoch+1}/{num_epoches}]')
            test_bar.set_postfix(loss=val_loss.item(), **_binary_metrics(targets, pred))

    if val_count:
        epoch_metrics = _binary_metrics(np.concatenate(all_targets), np.concatenate(all_preds))
        print(f'Epoch [{epoch+1}/{num_epoches}] val_loss={val_loss_sum / val_count:.4f} '
              + ' '.join(f'{k}={v:.4f}' for k, v in epoch_metrics.items()))

                                                                                                                             

In [7]:
# save trained model; torch.save does not create missing directories, so make ./ckps first
import os

os.makedirs('./ckps', exist_ok=True)
torch.save(model.state_dict(), './ckps/new_birth.pth')

In [11]:
# load trained model and evaluate it on the held-out split.
# .to(device) and map_location=device let a checkpoint saved on a GPU be evaluated
# on a CPU-only machine (a hard-coded .cuda() would fail there).
model = SamplingNet(last_dim=6).to(device)
model.load_state_dict(torch.load('./ckps/new_birth.pth', map_location=device))
model.eval()
test_bar = tqdm(enumerate(testloader), total=len(testloader), leave=False)

test_loss_sum, test_count = 0.0, 0
all_targets, all_preds = [], []

with torch.no_grad():
    for i, (data, targets) in test_bar:
        data = data.to(device)
        targets = targets.to(device)
        outputs = model(data)
        # loss of this test batch; weighted by batch size for an exact overall mean
        test_loss = criterion(outputs, targets)
        test_loss_sum += test_loss.item() * data.shape[0]
        test_count += data.shape[0]

        pred = outputs.gt(0.5).float().cpu().numpy().flatten()
        targets = targets.cpu().numpy().flatten()
        all_targets.append(targets)
        all_preds.append(pred)

        test_bar.set_postfix(loss=test_loss.item(), acc=accuracy_score(targets, pred), f1=f1_score(targets, pred, zero_division=0), prec=precision_score(targets, pred, zero_division=0), recall=recall_score(targets, pred, zero_division=0))

# The progress bar is cleared (leave=False), so print the metrics over the whole test split.
if test_count:
    y_true, y_pred = np.concatenate(all_targets), np.concatenate(all_preds)
    print(f'test loss={test_loss_sum / test_count:.4f} '
          f'acc={accuracy_score(y_true, y_pred):.4f} '
          f'f1={f1_score(y_true, y_pred, zero_division=0):.4f} '
          f'prec={precision_score(y_true, y_pred, zero_division=0):.4f} '
          f'recall={recall_score(y_true, y_pred, zero_division=0):.4f}')

                                                                                                        