In [88]:
import numpy as np
import torch
import random

In [37]:
def get_new_anchor(nonzero_pos, anchors):

    num_nonzero = len(nonzero_pos[0])
    
    # throw error if there are no possible anchor positions left
    if len(anchors) == num_nonzero:
        raise ValueError('All possible anchor positions have been already selected.')
    
    # select any segment position at random
    rand_nonzero_idx = np.random.randint(num_nonzero)
    i = nonzero_pos[0][rand_nonzero_idx]
    j = nonzero_pos[1][rand_nonzero_idx]
    anchor = [i,j]
    
    # ensure anchor has not already been selected
    while anchor in anchors:
        rand_nonzero_idx = np.random.randint(num_nonzero)
        i = nonzero_pos[0][rand_nonzero_idx]
        j = nonzero_pos[1][rand_nonzero_idx]
        anchor = [i,j]
    
    return anchor, rand_nonzero_idx

In [38]:
def position_is_in_image(pos, img_rows, img_cols):
    c1 = pos[0] >= 0 and pos[0] < img_rows
    c2 = pos[1] >= 0 and pos[1] < img_cols

    position_is_in_image = c1 and c2
    return position_is_in_image

In [39]:
def position_is_background(pos, nonzero_pos):
    
    position_is_background = True
    
    # get all segment row positions that match pos[0]
    matches_row_idx = np.where(nonzero_pos[0] == pos[0])[0] # extract array from results

    # check all row position matches to see if matching column position exists
    for each in matches_row_idx:
        if nonzero_pos[1][each] == pos[1]:
            position_is_background = False
        
    return position_is_background

In [40]:
def get_radius_positions(anchor, radius):
    # get (i,j) position of upper left corner at given radius from anchor even if invalid
    ul_i = anchor[0] - radius
    ul_j = anchor[1] - radius

    # set the length of the square formed at the current radius away from anchor 
    square_length = 3 + 2 * (radius - 1)

    # calc all valid positions that are radius pixels away from the anchor
    # ignore any positions that are segment pixels or invalid positions
    radius_positions = [] 
    for i_offset in range(square_length):
            if i_offset == 0 or i_offset == square_length-1:  # first row or last row

                # every position in first or last row is radius pixels from anchor
                for j_offset in range(square_length):
                    pos = [ul_i+i_offset,ul_j+j_offset]
                    radius_positions.append(pos)

            else: # any row except first or last

                # only first and last columns are radius pixels from anchor
                for j_offset in [0,square_length-1]:
                    pos = [ul_i+i_offset,ul_j+j_offset]
                    radius_positions.append(pos)
                    
    return radius_positions

In [41]:
def get_new_negative(nonzero_pos, anchor, img_rows, img_cols):
    # check for background pixels using increasing radius approach
    # closest background pixel to anchor to be the hard negative
    
    negative = None
    radius=1 # start search 1 pixel away from anchor
    
    while negative is None:

        radius_positions = get_radius_positions(anchor, radius)
#         print(radius_positions)
                        
        # check all radius positions for valid candidates
        candidates = []
        for pos in radius_positions:
            is_valid = position_is_in_image(pos, img_rows, img_cols)
            if is_valid:
                is_background = position_is_background(pos, nonzero_pos)
                if is_background:
                    candidates.append(pos)
                        
        # select candidate at random if any exist
        if len(candidates) > 0:
            negative = random.choice(candidates)
        else:
            # no candidates exist at that radius
            # increment radius and search again
            radius += 1
            
            # sanity check
            if radius > img_rows:
                raise ValueError('Target does not have any valid negative samples.')
                        
    return negative        

In [42]:
def get_new_positive(nonzero_pos, anchor_idx):
    num_nonzero = len(nonzero_pos[0])
    
    if num_nonzero < 2:
        raise ValueError('Target does not have enough segment pixels to choose a positive sample.')
    
    # choose any segment pixel except anchor as positive
    rand_positive_idx=np.random.randint(num_nonzero)
    while rand_positive_idx == anchor_idx:
        rand_positive_idx=np.random.randint(num_nonzero)

    # convert index to an (i,j) pixel position
    i = nonzero_pos[0][rand_positive_idx]
    j = nonzero_pos[1][rand_positive_idx]
    positive = [i,j]
    return positive

In [82]:
def get_triplet_indices(target, n):
    
    # for storing n triplets
    anchors = []
    positives = []
    negatives = []
    
    # extract dimensions of target image
    img_rows = target.shape[0]
    img_cols = target.shape[1]
    
    # get indices of all nonzero positions (segment pixels)
    nonzero_pos = np.nonzero(target.numpy())
    
    for i in range(n):
        # choose as anchor a segment pixel which has not yet been chosen to be an anchor
        curr_anchor, curr_anchor_idx = get_new_anchor(nonzero_pos, anchors)
        
        # calculate closest background pixel to be hard negative for this anchor
        curr_negative = get_new_negative(nonzero_pos, curr_anchor, img_rows, img_cols)
        
        # choose as positive a segment pixel which is not the current anchor
        curr_positive = get_new_positive(nonzero_pos, curr_anchor_idx)
        
        # save triplet
        anchors.append(curr_anchor)
        negatives.append(curr_negative)
        positives.append(curr_positive)
        
    return np.array(anchors), np.array(negatives), np.array(positives)
    

In [186]:
target = torch.tensor([[0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 1.],
       [1., 1., 0., 0., 1.]])
print(target.shape)
target

torch.Size([5, 5])


tensor([[0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1.],
        [1., 1., 0., 0., 1.]])

In [84]:
# target = np.zeros((5,5))
# print(target.shape)
# target

In [85]:
# target = np.ones((5,5))
# print(target.shape)
# target

In [86]:
# target = np.zeros((5,5))
# target[0][0]=1
# print(target.shape)
# target

In [195]:
n=2
target = torch.tensor(target)
triplet_indices = get_triplet_indices(target, n)
triplet_indices

  target = torch.tensor(target)


(array([[4, 4],
        [3, 0]]),
 array([[3, 3],
        [2, 1]]),
 array([[1, 4],
        [0, 3]]))

In [235]:
output = torch.rand((2,5,5))
output

tensor([[[0.6375, 0.5182, 0.4396, 0.0555, 0.0057],
         [0.5666, 0.8391, 0.4668, 0.0356, 0.8924],
         [0.3725, 0.6685, 0.8238, 0.8782, 0.9070],
         [0.2593, 0.5689, 0.4267, 0.0236, 0.8909],
         [0.9445, 0.1930, 0.5347, 0.1236, 0.3778]],

        [[0.4523, 0.5304, 0.7748, 0.9978, 0.2531],
         [0.9391, 0.7552, 0.1014, 0.0545, 0.8908],
         [0.0469, 0.3538, 0.1531, 0.5613, 0.4489],
         [0.0234, 0.8388, 0.2468, 0.6098, 0.6700],
         [0.7683, 0.1252, 0.7088, 0.4645, 0.6962]]])

In [252]:
def get_output_triplets(output, triplet_indices):
#     print(output.shape)
    
    # unpack
    anchors, negatives, positives = triplet_indices
    
    output_anchors = []
    output_negatives = []
    output_positives = []
    
    for j,k in anchors:
        output_anchor = output[:,j,k]
        output_anchors.append(output_anchor)
        
    for j,k in negatives:
        output_negative = output[:,j,k]
        output_negatives.append(output_negative)
        
    for j,k in positives:
        output_positive = output[:,j,k]
        output_positives.append(output_positive)

    # convert lists to tensors
    tensor_anchors = torch.stack(output_anchors).T
    tensor_positives = torch.stack(output_negatives).T
    tensor_negatives = torch.stack(output_positives).T

    return tensor_anchors, tensor_positives, tensor_negatives

In [253]:
tensor_triplets = get_output_triplets(output, triplet_indices)
tensor_triplets

(tensor([[0.3778, 0.2593],
         [0.6962, 0.0234]]),
 tensor([[0.0236, 0.6685],
         [0.6098, 0.3538]]),
 tensor([[0.8924, 0.0555],
         [0.8908, 0.9978]]))

In [289]:
def triplet_loss(tensor_triplets, margin=0.2):
    
    # unpack
    tensor_anchors,tensor_negatives, tensor_positives = tensor_triplets
    
    diff_anchor_positives = tensor_anchors - tensor_positives
    norm_anchor_positives = torch.linalg.norm(diff_anchor_positives)
#     print(norm_anchor_positives)
    
    diff_anchor_negatives = tensor_anchors - tensor_negatives
    norm_anchor_negatives = torch.linalg.norm(diff_anchor_negatives)
#     print(norm_anchor_negatives)
    
    option1 = norm_anchor_positives - norm_anchor_negatives + margin
    option2 = 0.0
#     print(option1)
    
    loss = max(option1, option2)
    return float(loss)

In [290]:
triplet_loss(tensor_triplets)

0.697445273399353

In [291]:
targets = torch.stack([target,target])
print(targets.shape)
targets

torch.Size([2, 5, 5])


tensor([[[0., 0., 1., 1., 1.],
         [0., 0., 0., 1., 1.],
         [0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 1.],
         [1., 1., 0., 0., 1.]],

        [[0., 0., 1., 1., 1.],
         [0., 0., 0., 1., 1.],
         [0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 1.],
         [1., 1., 0., 0., 1.]]])

In [309]:
def loss(output, targets, n=1, margin=0.5):
    batch_loss = 0.0
    
    # for each output in the mini-batch, calculate triplet_loss
    for i in range(len(output)):
        target_i = targets[i]
        output_i = output[i]
#         print(output_i.shape)
        if len(output_i.shape) < 3:
            # output tensor must have at least 3 dimensions for function to work
            output_i = torch.unsqueeze(output_i, 0)
#         print(output_i.shape)
        
        
        
        # pixel indices
        triplet_indices_i = get_triplet_indices(target_i, n)
#         print(triplet_indices_i)
        
        # output pixels
        tensor_triplets_i = get_output_triplets(output_i, triplet_indices_i)
#         print(tensor_triplets_i)
        
        # calculate loss
        loss_i = triplet_loss(tensor_triplets_i, margin)
        batch_loss += loss_i
        
    return batch_loss

In [310]:
loss(output, targets)

0.7193681001663208