This notebook processes the results of the compression task on CelebA, where the input set contains three images combined with various weightings.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import cadgan
import cadgan.kernel as kernel
import cadgan.glo as glo
import cadgan.main as main
import cadgan.net.net as net
import cadgan.gen as gen
import cadgan.plot as plot
import cadgan.embed as embed
import cadgan.util as util

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [None]:
# Global plotting style: 18pt font, 2pt-wide lines, and Type 42 (TrueType)
# fonts embedded in PDF/PS output instead of Type 3 fonts.
font = dict(size=18)

plt.rc('font', **font)
plt.rc('lines', linewidth=2)
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
# Pick the compute device: the GPU when one is available.
# Replace `True` with `False` in the condition to force the CPU.
if True and torch.cuda.is_available():
    use_cuda = True
    device = torch.device("cuda")
    tensor_type = torch.cuda.FloatTensor
else:
    use_cuda = False
    device = torch.device("cpu")
    tensor_type = torch.FloatTensor
# Newly created tensors default to this float32 tensor type on the chosen device.
torch.set_default_tensor_type(tensor_type)

## Plot results

In [None]:
# Experiment variant to load. Other variants used before: 'interpolation_v3',
# 'interpolation_full'.
# NOTE(review): the second path component is fixed to 'interpolation_full'
# regardless of case_name — confirm this is intended.
case_name = 'interpolation_v5'
root_results_folder = glo.result_folder(
    '3imgs_compression', 'interpolation_full', case_name)
print(root_results_folder)
# Only warn here; the following cells will fail if the folder is missing.
if not os.path.exists(root_results_folder):
    print('Path does not exist: {}'.format(root_results_folder))

Get the input images

In [None]:
import glob
import skimage
    
def get_np_input_images(root_results_folder):
    """
    Return a numpy stack of the input images.

    Images are read from the ``input_images`` folder of the first run
    sub-directory of the first ``face_interpolation_test_*`` case folder.
    Both are chosen in sorted order, so the result does not depend on the
    order glob returns entries. This assumes, without checking, that all
    case folders share the same input images.

    * root_results_folder: folder that contains the
        ``face_interpolation_test_*`` case folders.

    Returns an array whose first axis indexes the input images, ordered by
    file name.
    Raises FileNotFoundError if no case folder, run folder, or input image
    can be found.
    """
    case_folders = sorted(
        glob.glob(os.path.join(root_results_folder, 'face_interpolation_test_*')))
    if not case_folders:
        raise FileNotFoundError(
            'No face_interpolation_test_* folder in {}'.format(root_results_folder))

    # Only directories can contain an input_images folder.
    run_folders = sorted(
        p for p in glob.glob(os.path.join(case_folders[0], '*')) if os.path.isdir(p))
    if not run_folders:
        raise FileNotFoundError('No run folder in {}'.format(case_folders[0]))

    input_folder = os.path.join(run_folders[0], 'input_images')
    # glob.glob does not return a sorted list; sort by file name for a stable order.
    input_fpaths = sorted(
        glob.glob(os.path.join(input_folder, '*')), key=os.path.basename)
    if not input_fpaths:
        raise FileNotFoundError('No input image in {}'.format(input_folder))

    # Import the submodule explicitly: `import skimage` alone does not make
    # `skimage.io` available in every scikit-image version.
    import skimage.io
    return np.stack([skimage.io.imread(fp) for fp in input_fpaths], axis=0)

In [None]:
import skimage
# Load the input images and show each one in its own figure.
in_imgs = get_np_input_images(root_results_folder)
for input_img in in_imgs:
    plt.figure()
    plt.imshow(input_img)

In [None]:
import re
def get_candidate_weights(root_results_folder):
    """
    Return the candidate input weight vectors parsed from the case folder names.

    Each case folder is named ``face_interpolation_test_<w1>_<w2>_<w3>``.
    The three weights are kept as the exact strings in the folder name, so
    they can be formatted back into the same folder name.

    * root_results_folder: folder that contains the case folders.

    Returns a numpy array of size n x 3 (of type string), where n is the
    number of case folders. Rows are ordered by folder name.
    Raises ValueError if a ``face_interpolation_test*`` entry cannot be
    parsed, or if there is none.
    """
    # Every weight may have a multi-digit integer part. fullmatch ensures the
    # whole name is consumed, so no weight is silently truncated.
    pat = re.compile(
        r'face_interpolation_test_(\d+(?:\.\d+)?)_(\d+(?:\.\d+)?)_(\d+(?:\.\d+)?)')
    case_folders = sorted(
        glob.glob(os.path.join(root_results_folder, 'face_interpolation_test*')))
    list_weights = []
    for case_folder in case_folders:
        folder_name = os.path.basename(case_folder)
        m = pat.fullmatch(folder_name)
        if m is None:
            raise ValueError(
                'Cannot parse weights from folder name: {}'.format(folder_name))
        list_weights.append(np.array(m.groups()))
    if not list_weights:
        raise ValueError(
            'No face_interpolation_test* folder in {}'.format(root_results_folder))
    return np.stack(list_weights, axis=0)

In [None]:
# Weight vectors (as strings) parsed from the case folder names, one row per case.
np_str_weights = get_candidate_weights(root_results_folder=root_results_folder)
np_str_weights

In [None]:
def get_output_img(root_results_folder, weights, verbose=False):
    """
    Return the output image (assume there is only one) for one vector of
    input weights.

    The image is read from the highest-numbered iteration folder under
    ``<root>/face_interpolation_test_<w1>_<w2>_<w3>/<run>/output_images``.
    <run> is the first run sub-directory in sorted order.

    * root_results_folder: folder that contains the case folders.
    * weights: a numpy array with one dimension. Each value is a string that
        is formatted verbatim into the case folder name.
    * verbose: if True, print the path of the image being loaded.

    Raises FileNotFoundError if the case folder, a run folder, or a numeric
    iteration folder is missing. Raises ValueError if the last iteration
    folder does not contain exactly one file.
    """
    subfolder = 'face_interpolation_test_{}_{}_{}'.format(*weights)
    case_path = os.path.join(root_results_folder, subfolder)
    if not os.path.isdir(case_path):
        raise FileNotFoundError('Case folder does not exist: {}'.format(case_path))

    run_folders = sorted(
        p for p in glob.glob(os.path.join(case_path, '*')) if os.path.isdir(p))
    if not run_folders:
        raise FileNotFoundError('No run folder in {}'.format(case_path))

    output_folder = os.path.join(run_folders[0], 'output_images')
    # Iteration folders are named by their integer iteration; skip anything else.
    iter_fpaths = [fp for fp in glob.glob(os.path.join(output_folder, '*'))
                   if os.path.basename(fp).isdigit()]
    if not iter_fpaths:
        raise FileNotFoundError('No iteration folder in {}'.format(output_folder))
    # Highest iteration, compared numerically (so '10' comes after '9').
    last_iter_folder = max(iter_fpaths, key=lambda fp: int(os.path.basename(fp)))

    out_fpaths = glob.glob(os.path.join(last_iter_folder, '*'))
    # Expect exactly one output image. Raise explicitly, since `assert` is
    # stripped under `python -O`.
    if len(out_fpaths) != 1:
        raise ValueError('Expected exactly one output image in {}, found {}'.format(
            last_iter_folder, len(out_fpaths)))

    # load the image
    if verbose:
        print('Loading output image: {}'.format(out_fpaths[0]))
    # Import the submodule explicitly: `import skimage` alone does not make
    # `skimage.io` available in every scikit-image version.
    import skimage.io
    return skimage.io.imread(out_fpaths[0])
    

In [None]:
# Quick check of get_output_img() on the first candidate weight vector.
img = get_output_img(root_results_folder, weights=np_str_weights[0], verbose=False)
plt.imshow(img)

In [None]:
# Load every output image (one per input weight vector, in the same order as
# np_str_weights) and stack them along a new first axis.
list_out_imgs = [get_output_img(root_results_folder, w) for w in np_str_weights]
out_imgs = np.stack(list_out_imgs, axis=0)

In [None]:
# Show the input images, one figure each.
print('Input images')
for input_img in in_imgs:
    plt.figure()
    plt.imshow(input_img)

In [None]:
# Show each output image, titled with the input weights that produced it.
for i, oimg in enumerate(out_imgs):
    w_str = np_str_weights[i]
    plt.figure(figsize=(3, 3))
    plt.imshow(oimg)
    plt.title('in weights: {}, {}, {}'.format(*w_str[:3]))

In [None]:
def proj2d(x):
    """
    Map a weight vector on the 2-dimensional simplex to a point in an
    equilateral triangle on 2d with unit-length sides.

    The lower-left corner is at the origin: e1=(1,0,0) maps to (0, 0),
    e2=(0,1,0) to (1, 0) and e3=(0,0,1) to (0.5, sqrt(3)/2).

    Method: x - e1 is expressed, by least squares, in the basis
    v1 = e2 - e1, v2 = e3 - e1. The two coefficients are then used with the
    2d basis (1, 0), (cos 60deg, sin 60deg). The mapping is exact when x
    sums to 1.

    * x: array-like of 3 numbers (e.g. a 1-d numpy array or a list).

    Returns a 2 x 1 numpy array holding the 2d (horizontal, vertical)
    coordinates.
    """
    xcol = np.asarray(x, dtype=np.float64).reshape(3, 1)
    v1 = np.array([[-1, 1, 0.0]]).T  # e2 - e1
    v2 = np.array([[-1, 0.0, 1]]).T  # e3 - e1
    A = np.hstack((v1, v2))
    e1 = np.array([[1, 0, 0.0]]).T
    # Least-squares coefficients of (x - e1) in the basis {v1, v2}.
    P = np.linalg.inv(A.T.dot(A)).dot(A.T)
    pro = P.dot(xcol - e1)

    # Reconstruct with new 2d axes separated by 60 degrees.
    e1_2d = np.array([[1, 0]]).T
    r2 = np.array([[np.cos(np.pi / 3), np.sin(np.pi / 3)]]).T
    R = np.hstack((e1_2d, r2))
    return R.dot(pro)

In [None]:
# Sanity check: project the weight vector (0.5, 0.25, 0.25) onto the triangle.
w3d = np.array([4.0, 2.0, 2.0]) / 8
proj2d(x=w3d)

In [None]:
# Numeric (float64) copy of the string weights, used to position the plotted images.
candidate_weights = np_str_weights.astype(float)

In [None]:
def plot_triangle(in_imgs, candidate_weights, out_imgs, img_zoom=0.25,
                  axis_margin=0.15, figsize=(8, 8), verbose=False,
                  input_imgs_distance=0.12, plot_triangle=True):
    """
    Draw each output image inside an equilateral triangle at the 2d position
    of its input weight vector (via proj2d). The three input images are drawn
    just outside the three corners.

    * in_imgs: sequence of at least 3 images. in_imgs[i] is drawn near corner
        i, where the corners are (0, 0), (1, 0) and (0.5, sqrt(3)/2).
    * candidate_weights: numeric array of shape n x 3. Row i gives the
        position of out_imgs[i].
    * out_imgs: sequence of n output images, aligned with candidate_weights.
    * img_zoom: size of the images. Higher = larger.
    * axis_margin: space added around the triangle when setting the axis limits.
    * figsize: size of the newly created figure.
    * verbose: if True, print each weight vector and the 2d point it maps to.
    * input_imgs_distance: distance to add to the input images at the 3
        corners, so that they are pushed further away from the center.
    * plot_triangle: if True, draw the three edges of the triangle.

    Returns the matplotlib Axes that was drawn on.
    Raises ValueError if out_imgs and candidate_weights differ in length, or
    if fewer than 3 input images are given.
    """
    from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)

    # Validate before creating a figure, so no half-drawn plot is left behind.
    n_out = candidate_weights.shape[0]
    if len(out_imgs) != n_out:
        raise ValueError(
            'Expected one output image per weight vector; got {} images for {} '
            'weight vectors'.format(len(out_imgs), n_out))
    if len(in_imgs) < 3:
        raise ValueError('Expected 3 input images; got {}'.format(len(in_imgs)))

    height = np.sqrt(3 / 4.0)  # height of an equilateral triangle with unit sides
    corners = np.array([[0, 0], [1, 0], [0.5, height]])
    center = np.array([0.5, height / 2])

    fig = plt.figure(figsize=figsize)
    ax = fig.subplots()

    # Fix the display limits to see everything
    ax.set_xlim(0 - axis_margin, 1.0 + axis_margin)
    ax.set_ylim(0 - axis_margin, height + axis_margin)
    ax.set_aspect('equal')

    if plot_triangle:
        # Draw on ax explicitly rather than on pyplot's implicit "current" axes.
        ax.plot([0, 0.5], [0, height], 'k')
        ax.plot([0.5, 1], [height, 0], 'k')
        ax.plot([0, 1], [0, 0], 'k')

    # Go through each weight vector and plot its output image at the projected point.
    for i in range(n_out):
        Wi = candidate_weights[i]
        Wi_2d = proj2d(Wi)
        xy = (Wi_2d[0, 0], Wi_2d[1, 0])
        if verbose:
            print('({:.2f}, {:.2f}, {:.2f}) |-> ({:.2f}, {:.2f})'.format(
                Wi[0], Wi[1], Wi[2],
                xy[0], xy[1]))

        imagebox = OffsetImage(out_imgs[i], zoom=img_zoom)
        imagebox.image.axes = ax
        ab = AnnotationBbox(imagebox, xy,
                            xybox=(1, -1),
                            boxcoords="offset points",
                            pad=0.0,
                            arrowprops=dict(
                                arrowstyle="->",
                                connectionstyle="angle,angleA=0,angleB=90,rad=3")
                            )
        ax.add_artist(ab)

    # Plot the three input images at the three corners of the triangle.
    # Do this last so that the input images are on top.
    for i in range(3):
        # Push the image away from the center along the center->corner direction.
        dir_away = corners[i] - center
        dir_away = dir_away / np.linalg.norm(dir_away) * input_imgs_distance

        imbox = OffsetImage(in_imgs[i], zoom=img_zoom)
        imbox.image.axes = ax
        ab = AnnotationBbox(imbox, corners[i] + dir_away,
                            xybox=(1, -1),
                            boxcoords="offset points",
                            pad=0.3,
                            arrowprops=dict(
                                arrowstyle="->",
                                connectionstyle="angle,angleA=0,angleB=90,rad=3")
                            )
        ax.add_artist(ab)

        # annotate text x1, x2, x3
        ax.annotate('x{}'.format(i + 1), xy=corners[i], fontsize=12)

    return ax

In [None]:
# Final figure: output images in the weight triangle, edges hidden, axes off.
# (A smaller variant used img_zoom=0.24 and figsize=(11, 11).)
ax = plot_triangle(
    in_imgs, candidate_weights, out_imgs,
    img_zoom=0.28,
    figsize=(13, 13),
    axis_margin=0.2,
    input_imgs_distance=0.12,
    plot_triangle=False,
)
ax.axis('off')

# Save the same figure as both PDF and PNG.
save_fname = 'm3_triangle_{}'.format(case_name)
for ext in ('.pdf', '.png'):
    plt.savefig(save_fname + ext, bbox_inches='tight', dpi=300)

## Candidate list of weight vectors

In [None]:
def plot_weights_triangle(points3d):
    """
    Scatter-plot weight vectors on the 2-simplex inside the unit equilateral
    triangle used by proj2d.

    * points3d: iterable of length-3 weight vectors (e.g. an n x 3 array).
        May be empty, in which case only the triangle is drawn.

    Creates a new 6x6 figure; returns nothing.
    """
    projected = [proj2d(w) for w in points3d]
    height = np.sqrt(3 / 4.0)

    # draw triangle
    plt.figure(figsize=(6, 6))
    plt.plot([0, 0.5], [0, height], 'k')
    plt.plot([0.5, 1], [height, 0], 'k')
    plt.plot([0, 1], [0, 0], 'k')

    if projected:
        # n x 2 array of 2d positions, drawn as a single marker-only artist
        # instead of one artist per point.
        projCW = np.hstack(projected).T
        plt.plot(projCW[:, 0], projCW[:, 1], 'bo')

    plt.axis('square')

In [None]:
def recursive_partition_candidate_weights(depth, corners=None):
    """
    Recursively subdivide a triangle (in 3d weight space) into 4 sub-triangles
    via its edge midpoints, and return all the vertices produced.

    * depth: depth of the recursion (subpartitioning).
        depth <= 1 returns just the corners.
        depth == 2 returns the corners plus the 3 edge midpoints.
    * corners: 3x3 numpy array where each row specifies one corner (in 3d).
        Defaults to the identity, i.e. the vertices of the 2-simplex.

    Returns a numpy array with one point per row. Points shared between
    sub-triangles appear several times; see remove_duplicate_rows.
    """
    if corners is None:
        corners = np.eye(3)

    if depth <= 1:
        return corners

    C = corners
    # find 3 mid points on the three edges
    m0 = (C[0] + C[1]) / 2.0
    m1 = (C[0] + C[2]) / 2.0
    m2 = (C[1] + C[2]) / 2.0
    M = np.vstack([m0, m1, m2])

    if depth == 2:
        return np.vstack([C, M])

    # subpartitioning via recursion: one corner sub-triangle per original corner
    W0 = recursive_partition_candidate_weights(depth - 1, np.vstack([C[0], m0, m1]))
    W1 = recursive_partition_candidate_weights(depth - 1, np.vstack([m0, C[1], m2]))
    W2 = recursive_partition_candidate_weights(depth - 1, np.vstack([m1, m2, C[2]]))

    # middle region, whose corners are the three midpoints
    Wm = recursive_partition_candidate_weights(depth - 1, M)
    return np.vstack([C, M, W0, W1, W2, Wm])

def remove_duplicate_rows(A, tol=1e-6):
    """
    Return the rows of A with near-duplicates removed.

    Row j is dropped when some earlier row i < j lies within Euclidean
    distance tol of it. Surviving rows keep their original relative order.

    * A: n x d array-like (n may be 0).
    * tol: distance at or below which two rows count as duplicates.

    Returns a numpy array of shape m x d with m <= n.
    """
    A = np.asarray(A)
    # Pairwise Euclidean distances, n x n.
    D = np.linalg.norm(A[:, np.newaxis, :] - A[np.newaxis, :, :], axis=-1)
    # Entry (i, j) with i < j is True when row j duplicates an earlier row i.
    is_dup = np.triu(D <= tol, k=1).any(axis=0)
    return A[~is_dup]
    
def gen_candidate_weights(depth):
    """
    Generate candidate weight vectors on the 2-simplex from a recursive
    triangle subdivision of the given depth, with duplicate points removed.
    """
    all_points = recursive_partition_candidate_weights(depth)
    unique_points = remove_duplicate_rows(all_points)
    return unique_points

In [None]:
# Generate candidate weights from a depth-4 subdivision and plot them in the triangle.
weights = gen_candidate_weights(depth=4)
plot_weights_triangle(points3d=weights)

In [None]:
# Print each weight vector as "w1/w2/w3," (one per line).
for w1, w2, w3 in weights:
    print('{}/{}/{},'.format(w1, w2, w3))