# This notebook shows training of conditional probability networks.

Set the working directory to the "Single_Particle_Tracking" folder.

Download and unzip the [data and intermediate results](https://drive.google.com/open?id=1AO6du609gYup2mcyKIWEqU5dH5p8Fa4K) into the working directory.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from __future__ import print_function
import matplotlib.pyplot as plt
from skimage import io
import random 
import numpy as np

from skimage import measure
from skimage.measure import label 
import scipy as sp
import scipy.ndimage.morphology

In [None]:
# Read the blurry observation stack (obs.tif) and the ground-truth label stack (label.tif),
# both converted to float32.

folder = './data_and_pre_calculated_results/'

from skimage import io
obs = io.imread(f'{folder}obs.tif').astype(np.float32)
labels = io.imread(f'{folder}label.tif').astype(np.float32)
img_size = obs.shape[-1]  # length of the last image axis

print(obs.shape, labels.shape, sep='\n')

In [None]:
# Plot n consecutive raw frames (top row) above their ground-truth label images (bottom row),
# starting from a randomly drawn frame index.
import random
n = 5
i_base = random.randint(0,200)
plt.figure(figsize=(24, 10))
for i in range(n):
    frame = i + i_base

    # top row: observation
    ax = plt.subplot(2, n, i+1)
    plt.imshow(obs[frame].reshape(img_size, img_size), interpolation='none', cmap='jet')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_title('obs. i='+str(frame))
    plt.colorbar()

    # bottom row: label image
    ax = plt.subplot(2, n, i+n+1)
    plt.imshow(labels[frame].reshape(img_size, img_size), interpolation='none', cmap='jet')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_title('true. i='+str(frame))
    plt.colorbar()

    # annotate every non-zero label pixel with its integer value and a yellow cross
    true_idx = np.where(labels[frame] != 0)
    true_idx_label = labels[frame][true_idx]
    for row, col, particle in zip(true_idx[0], true_idx[1], true_idx_label):
        plt.text(col, row, str(particle.astype('int')), color=[.7,.7,.7])#color='magenta' )
        plt.plot(col, row, 'yx')

    # keep image orientation (y axis pointing down) after the extra plot calls
    plt.xlim(0,img_size-1)
    plt.ylim(img_size-1,0)

In [None]:
# Split the stack into training / test sets and scale every frame by its own maximum.

T_train = 9000   # frames [0, T_train) form the training set
T_all = 10000    # frames [T_train, T_all) form the test set

x_train = obs[0:T_train,:,:]
x_train_before_normalization = x_train.copy()
print(x_train.shape)
y_train = labels[0:T_train,:,:]
print(y_train.shape)

x_test = obs[T_train:T_all,:,:]
x_test_before_normalization = x_test.copy()
print(x_test.shape)
y_test = labels[T_train:T_all,:,:]
print(y_test.shape)


def _normalize_per_frame(stack):
    """Return a float32 copy of ``stack`` with each frame divided by its own maximum.

    ``stack`` is indexed as (frame, row, column). Frames whose maximum is 0 are
    left unchanged instead of being turned into NaN/inf by a division by zero.
    """
    scaled = stack.astype('float32')  # astype copies, so the source array is not modified
    frame_max = scaled.max(axis=(1, 2), keepdims=True)
    frame_max[frame_max == 0] = 1
    return scaled / frame_max


# normalization
x_train = _normalize_per_frame(x_train)
x_test = _normalize_per_frame(x_test)

In [None]:
# Plot n consecutive normalized training frames (top row) above their label images (bottom row),
# starting from a randomly drawn frame index.
import random
n = 5
i_base = random.randint(0,200)
plt.figure(figsize=(24, 10))
for i in range(n):
    frame = i + i_base

    # top row: normalized observation
    ax = plt.subplot(2, n, i+1)
    plt.imshow(x_train[frame].reshape(img_size, img_size), interpolation='none', cmap='jet')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_title('obs. i='+str(frame))
    plt.colorbar()

    # bottom row: label image
    ax = plt.subplot(2, n, i+n+1)
    plt.imshow(y_train[frame].reshape(img_size, img_size), interpolation='none', cmap='jet')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_title('true. i='+str(frame))
    plt.colorbar()

    # annotate every non-zero label pixel with its integer value and a yellow cross
    true_idx = np.where(labels[frame] != 0)
    true_idx_label = labels[frame][true_idx]
    for row, col, particle in zip(true_idx[0], true_idx[1], true_idx_label):
        plt.text(col, row, str(particle.astype('int')), color=[.7,.7,.7])#color='magenta' )
        plt.plot(col, row, 'yx')

    # keep image orientation (y axis pointing down) after the extra plot calls
    plt.xlim(0,img_size-1)
    plt.ylim(img_size-1,0)

# prepare NN input and output

In [None]:
# Compute inputs and targets for the conditional probability neural network.
#
# One sample is built for every frame t that has two neighbours on each side:
#   features 0-4 : observations at t-2, t-1, t, t+1, t+2
#   feature  5   : mask of all particles present at t-1
#   feature  6   : mask (at t) of the particles drawn as "already sampled"
#   feature  7   : mask (at t-1) of the single particle picked to be sampled next
#   target       : mask (at t) of that picked particle

Ntrain = len(range(2,x_train.shape[0])) - 2 # frames 2 .. N-3: the first two and last two frames are dropped

nfeatures = 8 # 5 observation frames + 3 masks

NN_inputs = np.zeros([Ntrain, nfeatures, img_size, img_size, 1])
NN_outputs = np.zeros([Ntrain, img_size, img_size, 1])


for t in range(2,x_train.shape[0]-2):

    if t%1000==0:
        print(t)

    # observations: t-2, t-1, t, t+1, t+2
    images = x_train[t-2:t+3,:,:]

    # particles mask in t-1 (history)
    t_minus_1_mask = np.zeros([img_size, img_size])
    t_minus_1_mask[np.where(y_train[t-1,:,:]!=0)] = 1
    t_minus_1_values = np.unique(y_train[t-1,:,:])
    t_minus_1_particles = t_minus_1_values[t_minus_1_values!=0] # particles at t-1
    nparticles = len(t_minus_1_particles) # number of particles at t-1

    # for sampling at t, from the pool of particles at t-1, randomly draw a subset of particles as
    # "sampled particles"; the rest are particles still to sample.
    sampled_binary_mask = np.random.binomial(1, .5, nparticles)
    sampled_particles = t_minus_1_particles[np.where(sampled_binary_mask==1)]
    to_sample_particles = t_minus_1_particles[np.where(sampled_binary_mask==0)]
    sampled_number = len(sampled_particles)

    # sampled at t
    t_mask_sampled = np.zeros([img_size, img_size])
    for particle_id in sampled_particles:
        particle_pixels = np.where(y_train[t,:,:] == particle_id)
        # The selection can hold zero, one or several pixels, so reduce it with np.any
        # before testing; a bare `if array != 0` is ambiguous for those sizes.
        if np.any(t_mask_sampled[particle_pixels] != 0):
            raise ValueError('sampled two')
        t_mask_sampled[particle_pixels] = 1

    # sample one new particle id for sampling at t
    if len(to_sample_particles) != 0:
        id_to_sampling = np.random.randint(len(to_sample_particles))
        particle_id_to_sampling = to_sample_particles[id_to_sampling]

        t_minus_1_mask_sampling = np.zeros([img_size, img_size])
        t_minus_1_mask_sampling[np.where(y_train[t-1,:,:] == particle_id_to_sampling)] = 1

        # target (at t)
        t_mask_sampling = np.zeros([img_size, img_size])
        t_mask_sampling[np.where(y_train[t,:,:] == particle_id_to_sampling)] = 1

    else:
        # every particle was already drawn as "sampled": nothing left to sample, so both masks stay empty
        t_minus_1_mask_sampling = np.zeros([img_size, img_size])
        t_mask_sampling = np.zeros([img_size, img_size])

    # expand dimension so the masks can be stacked with the observation frames
    t_minus_1_mask = np.expand_dims(t_minus_1_mask,axis=0)
    t_mask_sampled = np.expand_dims(t_mask_sampled,axis=0)
    t_minus_1_mask_sampling = np.expand_dims(t_minus_1_mask_sampling,axis=0)
    t_mask_sampling  = np.expand_dims(t_mask_sampling,axis=0)

    inputs = np.concatenate((images, t_minus_1_mask, t_mask_sampled, t_minus_1_mask_sampling), axis=0)

    # sample index is t-2 because frames 0 and 1 are skipped
    NN_inputs[t-2,:,:,:,0] = inputs
    NN_outputs[t-2,:,:,0] = t_mask_sampling
   

In [None]:
# Report the target shape, then the input shape, one per line.
print(NN_outputs.shape, NN_inputs.shape, sep='\n')

In [None]:
# Insert a length-1 axis at position 1 so the targets have the same number of axes as NN_inputs.
NN_outputs1 = NN_outputs[:, np.newaxis]
NN_outputs1.shape

In [None]:
# Store both arrays for the PyTorch training stage. The trailing length-1 axis is moved
# to position 1 before writing: axes (0, 1, 2, 3, 4) -> (0, 4, 1, 2, 3).
np.save(folder+'NN_inputs_sampling', NN_inputs.transpose(0, 4, 1, 2, 3))
np.save(folder+'NN_outputs1_sampling', NN_outputs1.transpose(0, 4, 1, 2, 3))

In [None]:
# Smallest value stored in the target array.
NN_outputs1.min()

# plot NN input and output

In [None]:
def label_im(im, marker_info, marker_label, if_text=0):
    """Mark every non-zero pixel of a 2-D image on the current matplotlib axes.

    Parameters
    ----------
    im : 2-D numpy array
        Each non-zero entry gets one marker drawn at (column, row).
    marker_info : str
        Matplotlib format string forwarded to ``plt.plot`` (e.g. ``'yx'``).
    marker_label : str
        Legend label. It is attached to the first marker only, so one call
        contributes a single legend entry.
    if_text : bool or int, optional
        When truthy, the pixel value is also written next to each marker in white.

    Notes
    -----
    Both axis limits are set to ``[0, im.shape[0] - 1]`` with the y axis pointing
    down. A legend is requested only when at least one marker was drawn, which
    avoids matplotlib's "no labelled artists" warning on an empty image.
    """
    sz = im.shape[0]

    rows, cols = np.nonzero(im)
    for k, (row, col) in enumerate(zip(rows, cols)):
        if if_text:
            plt.text(col, row, im[row, col], color='w')
        if k == 0:
            plt.plot(col, row, marker_info, label=marker_label)
        else:
            plt.plot(col, row, marker_info)

    if len(rows) > 0:
        plt.legend()
    plt.xlim(0, sz-1)
    plt.ylim(sz-1, 0)

In [None]:
# visualize NN input and output (for debug)
import random
img_size = NN_inputs.shape[2]  # image side length, read from the stored inputs instead of hard-coding it
n = 5
i_base = random.randint(0,100) #72
t_frame = i_base + 2  # sample k was built from frame t = k + 2 (the first two frames are skipped)

# ground-truth label images of the two frames this sample relates to
true_t_minus_1 =  y_train[t_frame-1,:,:]
true_t =  y_train[t_frame,:,:]

# features 0-4: the five observation frames of the sample
plt.figure(figsize=(24, 8))
for i in range(n):
    ax = plt.subplot(1, n, i+1)
    plt.imshow(NN_inputs[i_base,i,:,:,0].reshape(img_size, img_size),interpolation='none', cmap='jet')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.set_title('sample=%d. frame=%d'%(i_base, i))
    plt.colorbar()

# features 5-7 and the target: one figure each, overlaid with the true particle
# positions at t-1 (yellow x) and t (green +)
mask_panels = [
    (NN_inputs[i_base,5,:,:,0], 't-1 mask'),
    (NN_inputs[i_base,6,:,:,0], 't sampled mask'),
    (NN_inputs[i_base,7,:,:,0], 't-1 sampling mask'),
    (NN_outputs1[i_base,0,:,:,0], 'output: t sampling mask'),
]
for panel, panel_title in mask_panels:
    plt.figure()
    plt.imshow(panel.reshape(img_size, img_size),interpolation='none', cmap='jet')
    # report the actual frame index of the sample rather than a fixed number
    plt.title('%s, t=%d'%(panel_title, t_frame))
    label_im(true_t_minus_1, 'yx', 't-1')
    label_im(true_t, 'g+','t')

In [2]:
import torch
import torch.nn as nn
from torchinfo import summary
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from dataset import ParticleDataset
from model import SamplingNet

# Define Dataloader

In [3]:
# Seed torch before the random split is drawn.
torch.manual_seed(0)
full_dataset = ParticleDataset(
    './data_and_pre_calculated_results/',
    'NN_inputs_sampling.npy',
    'NN_outputs1_sampling.npy',
)

# 95 % of the samples go to training, the remainder to testing
train_size = int(0.95 * len(full_dataset))
test_size = len(full_dataset) - train_size
trainset, testset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

# batches of 10; only the training loader reshuffles its samples
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=10, shuffle=True, num_workers=8
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=10, shuffle=False, num_workers=8
)

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Define Model

In [5]:
# Build the network on the selected device and print a torchinfo layer summary.
# The summary is requested for an input of shape (10, 1, 8, 60, 60); the leading 10 equals the
# DataLoader batch size. Presumably the remaining axes are (channel, 8 stacked feature maps,
# 60x60 pixels) — TODO confirm against SamplingNet.
model = SamplingNet(last_dim=8).to(device)
summary(model, input_size=(10, 1, 8, 60, 60))

Layer (type:depth-idx)                        Output Shape              Param #
SamplingNet                                   --                        --
├─Sequential: 1-1                             [10, 1, 1, 60, 60]        --
│    └─ConvBLSTM: 2-1                         [10, 64, 8, 60, 60]       --
│    │    └─ConvLSTM: 3-1                     [10, 32, 8, 60, 60]       38,144
│    │    └─ConvLSTM: 3-2                     [10, 32, 8, 60, 60]       38,144
│    └─BatchNorm3d: 2-2                       [10, 64, 8, 60, 60]       128
│    └─Dropout3d: 2-3                         [10, 64, 8, 60, 60]       --
│    └─ConvBLSTM: 2-4                         [10, 80, 8, 60, 60]       --
│    │    └─ConvLSTM: 3-3                     [10, 40, 8, 60, 60]       149,920
│    │    └─ConvLSTM: 3-4                     [10, 40, 8, 60, 60]       149,920
│    └─BatchNorm3d: 2-5                       [10, 80, 8, 60, 60]       160
│    └─Dropout3d: 2-6                         [10, 80, 8, 60, 60]       --


# Define optimizer and loss function

In [6]:
# optimizer
# optimizer: RMSprop over all model parameters (learning rate 1e-4, weight decay 1e-6)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0001, weight_decay=1e-6)
# loss: binary cross-entropy; nn.BCELoss expects the model outputs to already be probabilities in [0, 1]
criterion = nn.BCELoss()

# Training Model

In [7]:
# Train for a fixed number of epochs; after each epoch run the held-out set through the model.
# Per-batch loss and pixel-wise metrics (threshold 0.5) are shown on the progress bars.
num_epoches = 10
for epoch in range(num_epoches):  # loop over the dataset multiple times

    # training
    model.train()
    train_bar = tqdm(enumerate(trainloader), total=len(trainloader), leave=False)

    for i, (data, targets) in train_bar:
        data = data.to(device)
        targets = targets.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # binarize predictions and flatten both arrays for the sklearn metrics
        pred = outputs.gt(0.5).float().cpu().numpy().flatten()
        targets_np = targets.cpu().numpy().flatten()

        train_bar.set_description(f'Epoch [{epoch+1}/{num_epoches}]')
        # zero_division=0: a batch may contain no positive pixel or no positive prediction
        train_bar.set_postfix(
            loss=loss.item(),
            acc=accuracy_score(targets_np, pred),
            f1=f1_score(targets_np, pred, zero_division=0),
            prec=precision_score(targets_np, pred, zero_division=0),
            recall=recall_score(targets_np, pred, zero_division=0),
        )

    # validating
    model.eval()
    test_bar = tqdm(enumerate(testloader), total=len(testloader), leave=False)

    with torch.no_grad():
        for i, (data, targets) in test_bar:
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)
            # compute the loss of THIS validation batch; reusing `loss` here would show
            # the stale value of the last training batch
            val_loss = criterion(outputs, targets)
            pred = outputs.gt(0.5).float().cpu().numpy().flatten()
            targets_np = targets.cpu().numpy().flatten()

            test_bar.set_description(f'Epoch [{epoch+1}/{num_epoches}]')
            test_bar.set_postfix(
                loss=val_loss.item(),
                acc=accuracy_score(targets_np, pred),
                f1=f1_score(targets_np, pred, zero_division=0),
                prec=precision_score(targets_np, pred, zero_division=0),
                recall=recall_score(targets_np, pred, zero_division=0),
            )

                                                                                                                            

In [8]:
# save trained model
import os
os.makedirs('./ckps', exist_ok=True)  # make sure the checkpoint folder exists before writing into it
torch.save(model.state_dict(), './ckps/transition.pth')

In [None]:
# load trained model and evaluate it on the held-out set
# Use the already selected `device` (instead of a hard-coded .cuda()) so the cell also runs on CPU-only
# machines; map_location places the stored tensors on that same device.
model = SamplingNet(last_dim=8).to(device)
model.load_state_dict(torch.load('./ckps/transition.pth', map_location=device))
model.eval()
eval_criterion = nn.BCELoss()  # own loss object: this cell must not depend on the training cell having run
test_bar = tqdm(enumerate(testloader), total=len(testloader), leave=False)

with torch.no_grad():
    for i, (data, targets) in test_bar:
        data = data.to(device)
        targets = targets.to(device)
        outputs = model(data)
        # loss of the current evaluation batch
        eval_loss = eval_criterion(outputs, targets)
        # binarize predictions at 0.5 and flatten both arrays for the sklearn metrics
        pred = outputs.gt(0.5).float().cpu().numpy().flatten()
        targets_np = targets.cpu().numpy().flatten()

        test_bar.set_postfix(
            loss=eval_loss.item(),
            acc=accuracy_score(targets_np, pred),
            f1=f1_score(targets_np, pred, zero_division=0),
            prec=precision_score(targets_np, pred, zero_division=0),
            recall=recall_score(targets_np, pred, zero_division=0),
        )

                                                                                                        