# Learning label model conditioned on disagreements

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score
from CRF import run_crf, compute_dice
from label_model import run_seg_label_model, generate_data, flying_squid, binarize_conditionals, compute_baselines


## Notebook overview

This notebook estimates segmentation masks given the noisy outputs of labeling functions. We currently assume five labeling functions are available, generating five noisy segmentation masks per training image.

We first include an example with synthetic data to make sure everything is working, then show an example with real data. 


## Test PGM estimation with synthetic data

Here, we include an example with synthetic data. This should validate the code is working. An example with real data is in the next section.

### Generate synthetic data according to proposed PGM

In [None]:
# Dataset size and choice of generative model
n_train = 500000  # Number of samples in synthetic train set
n_test = 100000   # Number of samples in synthetic held-out set
use_comp = True   # True: complex, segmentation-like PGM; False: simple, classification-like PGM

# Structural parameters, do not change
dep_nodes = [0, 1, 2]  # Nodes among which we check for disagreements
ind_nodes = [3, 4]     # Nodes we keep conditionally independent for triplet method
all_nodes = dep_nodes + ind_nodes
n_conds = 4            # Number of disagreement conditions

# Draw random parameters that define the ground truth data distribution.
# theta_lam_lam is left unset for both kinds of PGM.
theta_lam_lam = None
if use_comp:
    # Complex PGM: every parameter is a shared base value plus a
    # per-condition offset drawn uniformly from [-std_conds, std_conds)
    std_conds = 0.2
    theta_y = .05 * np.random.randn() + np.random.uniform(-std_conds, std_conds, (n_conds, 1))
    theta_lam_y_ind = (
        np.random.uniform(0.2, 0.6, len(ind_nodes))
        + np.random.uniform(-std_conds, std_conds, (n_conds, len(ind_nodes)))
    )
    theta_lam_y_cond = np.random.uniform(0.2, 0.6) + np.random.uniform(-std_conds, std_conds, n_conds)
else:
    # Standard WS PGM: one accuracy parameter per node, no conditional terms
    theta_y = .05 * np.random.randn(1, 1)
    theta_lam_y_ind = np.random.uniform(.1, 1, (len(all_nodes), 1))
    theta_lam_y_cond = None
thetas = [theta_y, theta_lam_y_ind, theta_lam_y_cond, theta_lam_lam]

# Sample train/test sets from the distribution given by the canonical parameters
sample_matrix, sample_matrix_test, lst, pmf = generate_data(
    n_train,
    n_test,
    theta_y,
    theta_lam_y_ind,
    theta_lam_y_cond,
    theta_lam_lam,
    m=len(all_nodes),
    v=len(all_nodes) + 1,
    comp=use_comp,
)

# Visual sanity checks of the generated distribution
fig, axs = plt.subplots(1, 3, figsize=(18, 5))

# Left: pmf against the share of non-final columns equal to 1 in each row of lst
axs[0].plot(np.mean(np.asarray(lst)[:, :-1] == 1, 1), pmf, 'r.')
axs[0].set_title('PMF for each joint')
axs[0].set_ylabel('probability')
axs[0].set_xlabel('percent of lambdas voting 1')

# Middle: pmf plotted in index order
axs[1].plot(pmf, 'b.')
axs[1].set_title('PMF for each joint, second visualization')
axs[1].set_ylabel('probability')
axs[1].set_xlabel('assignment')

# Right: histogram of the last column of the sampled train matrix
axs[2].hist(sample_matrix[:, -1])
axs[2].set_title('Class balance')
fig.show()

### Aggregate pixels with label model

In [None]:
# Every column but the last holds labeling-function outputs; the last column is the label
L_train = sample_matrix[:, :-1]  # N x 5
L_dev, Y_dev = sample_matrix_test[:, :-1], sample_matrix_test[:, -1]  # M x 5, M

# Fit the label model and collect its estimates for both splits
(est_thetas, est_pmf, P_train, P_dev) = run_seg_label_model(
    L_train,
    L_dev,
    Y_dev,
)

### Compare to previous weak supervision model

In [None]:
# Fraction of dev labels equal to 1, used to build the class balance passed as `cb`
empirical_p_Y = np.mean(Y_dev == 1)
class_balance = np.asarray([1 - empirical_p_Y, empirical_p_Y])

fs_model, fs_ind_cond = flying_squid(
    sample_matrix[:, :-1],
    sample_matrix_test[:, :-1],
    m=5,
    v=1,
    cb=class_balance,
)
print('Done with flying squid model.')

### Compare accuracies from binarized conditional probs on dev set

In [None]:
# Binarized predictions of the proposed model and the FlyingSquid baseline
proposed_Y = binarize_conditionals(P_dev)
fs_ind_Y = binarize_conditionals(fs_ind_cond)

# First output of compute_baselines is reported as the majority vote
mv_Y, _, _ = compute_baselines(sample_matrix_test[:, :-1])

# Each of the five labeling-function columns on its own
lf1, lf2, lf3, lf4, lf5 = [
    binarize_conditionals(sample_matrix_test[:, col]) for col in range(5)
]

# Report F1 against the last column of the test matrix for every predictor
named_preds = [
    ('Proposed', proposed_Y),
    ('Flying Squid', fs_ind_Y),
    ('Majority vote', mv_Y),
    ('LF1', lf1),
    ('LF2', lf2),
    ('LF3', lf3),
    ('LF4', lf4),
    ('LF5', lf5),
]
for pred_name, preds in named_preds:
    print('\t', pred_name, 'f1:', f1_score(sample_matrix_test[:, -1], preds[:]))
    

## Use PGM estimation with real data

To run with real data, you need to define:
- X_train: list of N training images; each image is size X x Y x D; each pixel should be [0,1]
- L_train: list of N training noisy labels; each noisy label is size 5 x X x Y x D; each pixel should be {0,1}
- X_dev: list of M dev set images; each image is size X x Y x D; each pixel should be [0,1]
- L_dev: list of M dev set noisy labels; each noisy label is size 5 x X x Y x D; each pixel should be {0,1}
- Y_dev: list of M dev set GT labels; each GT label is size X x Y x D; each pixel should be {0,1}

We will reformat the lists into following matrices:
- L_train_mat: matrix of all training noisy labels; matrix is size (N * X * Y * D) x 5
- L_dev_mat: matrix of all dev set noisy labels; matrix is size (M * X * Y * D) x 5
- Y_dev_mat: matrix of all dev set ground truth labels; matrix is size (M * X * Y * D)

Then use the label model + CRF code to generate the following results:
- Y_apx_train: list of N approximated labels for train set; each label is size X x Y x D
- Y_apx_dev: list of M approximated labels for dev set; each label is size X x Y x D

Note: if dev set labels are not available, the code below can be adapted to run without them.


In [None]:
# Start by reshaping the lists into matrices
def _to_signed_labels(mat):
    """Set every entry of `mat` that is below 1 to -1 and return the result.

    Boolean and unsigned-integer arrays cannot hold -1: a bool array stores it
    as True, and an unsigned array wraps it to the dtype's maximum or raises,
    depending on the NumPy version. Such arrays are first promoted to a dtype
    that can represent both their existing values and -1. Arrays of any other
    dtype are modified in place, as before.
    """
    if mat.dtype.kind in 'bu':  # 'b' = boolean, 'u' = unsigned integer
        mat = mat.astype(np.result_type(mat.dtype, np.int8))
    mat[mat < 1] = -1
    return mat


# Flatten each image's labels to 5 x (pixels), concatenate all images along the
# pixel axis, then transpose so that each row is one pixel.
L_train_mat = _to_signed_labels(np.hstack([l.reshape(5,-1) for l in L_train]).T)
L_dev_mat = _to_signed_labels(np.hstack([l.reshape(5,-1) for l in L_dev]).T)
Y_dev_mat = _to_signed_labels(np.squeeze(np.hstack([y.reshape(1,-1) for y in Y_dev]).T))


In [None]:
# Fit the label model on the per-pixel matrices (one row per pixel)
(est_thetas, est_pmf, P_train_mat, P_dev_mat) = run_seg_label_model(
    L_train_mat,
    L_dev_mat,
    Y_dev_mat,
)


In [None]:
# Reformat matrices into lists
def _split_by_image(flat_values, images):
    """Cut a per-pixel array into one array per image.

    Consecutive runs of `flat_values` are taken, one run per entry of `images`,
    each as long as that image's pixel count and reshaped to the image's shape.
    Assumes the pixels in `flat_values` are stored image after image, in the
    same order as `images` -- confirm against how the matrix was built.

    Returns a list with one reshaped array per image.
    """
    chunks = []
    offset = 0
    for image in images:
        # .size is the product of all dimensions, so this works for images of
        # any dimensionality rather than only 3-D ones
        n_pix = image.size
        chunks.append(flat_values[offset:offset + n_pix].reshape(image.shape))
        offset += n_pix
    return chunks


P_train = _split_by_image(P_train_mat, X_train)
P_dev = _split_by_image(P_dev_mat, X_dev)
    

In [None]:
# Pass the images and probabilistic labels of both splits to the CRF;
# the seed is fixed at 1
Y_apx_train, Y_apx_dev = run_crf(
    X_train,
    P_train,
    X_dev,
    P_dev,
    Y_dev,
    seed=1,
)

In [None]:
# Per-image dice on the dev set for each labeling function, the majority vote
# and the CRF prediction
all_dice = {name: [] for name in ['LF 0', 'LF 1', 'LF 2', 'LF 3', 'LF 4', 'MV', 'Pred']}

for img, weak, pred, strong in zip(X_dev, L_dev, Y_apx_dev, Y_dev):
    n_lfs = weak.shape[0]

    # Each labeling function scored against the ground truth
    for lf in range(n_lfs):
        all_dice['LF %d' % lf].append(compute_dice(strong, weak[lf]))

    # Majority vote: a pixel is on when the vote sum exceeds half the number of LFs
    mv = np.sum(weak, 0) > (n_lfs / 2)
    all_dice['MV'].append(compute_dice(strong, mv))

    # Final approximated label
    all_dice['Pred'].append(compute_dice(strong, pred))

# (name, mean dice, number of images) for every method
summary = [(k, np.mean(v), len(v)) for (k, v) in all_dice.items()]
print(summary)