In [None]:
def plot_mesh(img, verts, faces, R, t, K):
    """Overlay a polygon mesh, projected through a pinhole camera, on an image.

    Vertices are rotated by ``R``, translated by ``t``, multiplied by the
    intrinsics ``K`` and perspective-divided to get 2-D pixel positions.
    Each face is then drawn as a semi-transparent polygon over ``img``.

    Args:
        img: image passed to ``plt.imshow``.
        verts: vertex array; it is transposed before the rotation, so
            presumably shape (N, 3) -- TODO confirm against caller.
        faces: iterable of faces, each an iterable of 1-based vertex
            indices (every index is decremented before lookup).
        R: rotation matrix left-multiplied with ``verts.T``
            (presumably 3x3).
        t: translation added after rotation; must broadcast against
            ``R @ verts.T`` (presumably shape (3, 1) -- verify).
        K: camera intrinsics matrix (presumably 3x3).
    """
    ax = plt.subplot(111)
    plt.imshow(img)

    # Camera transform, then homogeneous -> pixel coordinates via the
    # perspective divide by the third (depth) component.
    verts_2d = np.matmul(K, np.matmul(R, verts.T) + t).T
    verts_2d = verts_2d[:, :2] / verts_2d[:, 2, None]

    # Face indices are 1-based, hence ``i_vertex - 1``.
    # ``closed`` is passed by keyword: positional use was deprecated by
    # Matplotlib and newer releases reject it.
    patches = [
        Polygon([verts_2d[i_vertex - 1] for i_vertex in face], closed=True)
        for face in faces
    ]
    p = PatchCollection(patches, cmap=matplotlib.cm.jet, alpha=0.4)
    ax.add_collection(p)
    plt.show()

In [None]:
def show_keypoints(image, key_pts):
    """Draw ``image`` and overlay ``key_pts`` as small magenta dots.

    ``key_pts`` must support ``[:, 0]`` (x) and ``[:, 1]`` (y) indexing.
    """
    xs = key_pts[:, 0]
    ys = key_pts[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=20, marker='.', c='m')

def show_all_keypoints(image, predicted_key_pts, gt_pts=None, fileName=None, plot=False):
    """Draw an image with predicted and (optionally) ground-truth keypoints.

    Predicted points are drawn in magenta, ground-truth points in green.

    Args:
        image: image drawn with ``cmap='gray'`` (the colormap only affects
            single-channel images).
        predicted_key_pts: points indexable as ``[:, 0]`` (x) / ``[:, 1]`` (y).
        gt_pts: optional ground-truth points with the same indexing;
            skipped when None.
        fileName: when not None, the current figure is saved as
            ``OutputKP/<fileName>``. The ``OutputKP`` directory is created
            if it does not exist yet.
        plot: True connects the points with lines (labelled 'Predicted' /
            'True'); False draws them as scattered dots.
    """
    import os

    plt.imshow(image, cmap='gray')

    if plot:
        plt.plot(predicted_key_pts[:, 0], predicted_key_pts[:, 1], c='m', label='Predicted')
        # ground truth in green
        if gt_pts is not None:
            plt.plot(gt_pts[:, 0], gt_pts[:, 1], c='g', label='True')
    else:
        plt.scatter(predicted_key_pts[:, 0], predicted_key_pts[:, 1], s=20, marker='.', c='m')
        # ground truth in green
        if gt_pts is not None:
            plt.scatter(gt_pts[:, 0], gt_pts[:, 1], s=20, marker='.', c='g')

    if fileName is not None:
        # savefig does not create missing directories; without this the
        # first save into a fresh checkout raises FileNotFoundError.
        os.makedirs("OutputKP", exist_ok=True)
        plt.savefig("OutputKP/{}".format(fileName))
        
# visualize the output
# by default this shows a batch of 10 images
def visualize_output(data, test_outputs, gt_pts=None, batch_size=10, plot=False, savefig=False):
    """Plot predicted (and optionally ground-truth) keypoints for a batch.

    One figure is opened per sample (indices ``0 .. batch_size-1``); all
    figures are displayed together by a single ``plt.show()`` at the end.

    Args:
        data: indexable collection of sample dicts; ``data[i]['image']`` is
            drawn and ``data[i]['keypoints']`` supplies the ground truth.
        test_outputs: per-sample predictions; ``test_outputs[i].data.numpy()``
            must yield points indexable as ``[:, 0]`` / ``[:, 1]``
            (presumably CPU torch tensors -- TODO confirm).
        gt_pts: used only as a flag. When not None, ground truth is read
            from ``data[i]['keypoints']``, not from this argument.
        batch_size: number of samples to plot.
        plot: forwarded to ``show_all_keypoints`` (lines instead of dots).
        savefig: when True, each figure is saved as
            ``Rescaled-Output-<i>.png`` via ``show_all_keypoints``.
    """
    for i in range(batch_size):
        # One full-size axes per figure. A 1 x batch_size subplot grid inside
        # a per-sample figure would squeeze each image into a narrow slot
        # whose position shifts with i.
        plt.figure(figsize=(20, 10))
        sample = data[i]
        image = sample['image']

        # predictions -> numpy for plotting
        predicted_key_pts = test_outputs[i].data.numpy()

        ground_truth_pts = sample['keypoints'] if gt_pts is not None else None

        file = 'Rescaled-Output-{}.png'.format(i) if savefig else None
        show_all_keypoints(image, predicted_key_pts, ground_truth_pts, fileName=file, plot=plot)

    plt.show()
    
def visualize_test_output(test_images, test_outputs, gt_pts=None, batch_size=10, plot=False, savefig=False):
    """Plot predicted (and optionally ground-truth) keypoints on test images.

    One figure is opened per sample (indices ``0 .. batch_size-1``); all
    figures are displayed together by a single ``plt.show()`` at the end.

    Args:
        test_images: per-sample image tensors; ``test_images[i].data.numpy()``
            is transposed with axes ``(1, 2, 0)`` (leading axis moved last,
            e.g. torch CHW -> numpy HWC) and squeezed before display.
        test_outputs: per-sample predictions; ``.data.numpy()`` is squeezed
            and must then be indexable as ``[:, 0]`` / ``[:, 1]``.
        gt_pts: optional indexable ground truth; ``gt_pts[i]`` is squeezed
            and drawn in green. When None, no ground truth is drawn.
        batch_size: number of samples to plot.
        plot: forwarded to ``show_all_keypoints`` (lines instead of dots).
        savefig: when True, each figure is saved as ``Output-<i>.png`` via
            ``show_all_keypoints``.
    """
    for i in range(batch_size):
        plt.figure(figsize=(20, 10))

        # tensor -> numpy, moving the leading (channel) axis to the end
        image = np.transpose(test_images[i].data.numpy(), (1, 2, 0))

        predicted_key_pts = test_outputs[i].data.numpy()

        # Squeeze ground truth only when it exists: np.squeeze(None) returns
        # a 0-d object array, which is *not* None and would make
        # show_all_keypoints index it with [:, 0] and raise IndexError.
        ground_truth_pts = None
        if gt_pts is not None:
            ground_truth_pts = np.squeeze(gt_pts[i])

        file = 'Output-{}.png'.format(i) if savefig else None
        show_all_keypoints(np.squeeze(image), np.squeeze(predicted_key_pts), ground_truth_pts,
                           fileName=file, plot=plot)

    plt.show()

def drawKeyPoints(num_to_display, dataSet):
    """Show randomly chosen samples: ORB image beside the keypoint overlay.

    For each of ``num_to_display`` random samples, a figure is opened with
    two panels: ``sample['orb_image']`` on the left and ``sample['image']``
    with ``sample['keypoints']`` (via ``show_keypoints``) on the right. The
    image and keypoint shapes are printed for each sample.

    Args:
        num_to_display: number of random samples to show.
        dataSet: indexable, sized collection of sample dicts with keys
            'image', 'keypoints' and 'orb_image'.
    """
    plt.ion()  # interactive mode only needs to be switched on once

    for i in range(num_to_display):
        plt.figure(figsize=(20, 10))

        rand_i = np.random.randint(0, len(dataSet))
        sample = dataSet[rand_i]

        print(i, sample['image'].shape, sample['keypoints'].shape)

        # Each sample has its own figure, so a 1 x 2 grid is used. The old
        # 1 x (2*num_to_display) grid with slots i+1 / i+2 left each pair
        # tiny and misplaced, and consecutive samples' slot indices overlapped.
        ax = plt.subplot(1, 2, 1)
        ax.set_title('ORB Sample #{}'.format(i))
        plt.imshow(sample['orb_image'])

        ax = plt.subplot(1, 2, 2)
        ax.set_title('Sample #{}'.format(i))
        show_keypoints(sample['image'], sample['keypoints'])