In [None]:
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from models import KeyPointVAE
from utils.util_func import reparameterize
import matplotlib.pyplot as plt
# from matplotlib.widgets import Slider, Button
# import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import numpy as np
import cv2

from utils.util_func import create_tracked_kp_video, plot_matched_keypoints , plot_keypoints_on_image, plot_bb_on_image_from_masks
from utils.util_func import create_masks_fast

## Define parameters / Load Model

In [None]:

# Path to the trained checkpoint that is loaded into the model below; adjust to your setup.
path_to_model_ckpt = './checkpoints/dlp_calvin_d_d_gauss_pointnetpp_best.pth'
# Index into path_to_images of the sample image used to test keypoint association.
# It is clamped to the last list entry before being used.
image_idx = 6
path_to_images = ['./checkpoints/sample_images/calvin/calvin_scene_A.png',
                          './checkpoints/sample_images/calvin/calvin_scene_B.png',
                          './checkpoints/sample_images/calvin/calvin_scene_C.png',
                          './checkpoints/sample_images/calvin/calvin_scene_D.png',
                          './checkpoints/sample_images/calvin/calvin_scene_D_2.png',
                          './checkpoints/sample_images/calvin/calvin_scene_D_3.png',
                          './checkpoints/sample_images/calvin/calvin_scene_D_4.png']

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters forwarded to the KeyPointVAE constructor below.
# NOTE(review): presumably these have to match the values the checkpoint was
# trained with -- confirm against the training configuration.
use_logsoftmax = False
pad_mode = 'replicate'
sigma = 0.1
dropout = 0.0
n_kp = 1
# Value range of keypoint coordinates; used later to rescale them to [0, 1].
kp_range = (-1, 1)
kp_activation = "tanh"
# Passed to the model and also used below to binarise the keypoint masks.
mask_threshold = 0.2
learn_order = False
# Images are resized to image_size x image_size before being fed to the model.
image_size = 128
# Number of image channels (passed to the model as cdim).
ch = 3
enc_channels = [32, 64, 128, 256]
prior_channels = (16, 32, 64)
# NOTE(review): imwidth and crop are not referenced anywhere else in the visible notebook.
imwidth = 160
crop = 16
n_kp_enc = 10 
n_kp_prior = 20 
use_object_enc = True
use_object_dec = True
learned_feature_dim = 10
patch_size = 16
anchor_s = 0.25
dec_bone = "gauss_pointnetpp"
exclusive_patches = False

# Build the model from the hyperparameters defined in this cell and move it to `device`.
model = KeyPointVAE(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels,
                    image_size=image_size, n_kp=n_kp, learned_feature_dim=learned_feature_dim,
                    use_logsoftmax=use_logsoftmax, pad_mode=pad_mode, sigma=sigma,
                    dropout=dropout, dec_bone=dec_bone, patch_size=patch_size, n_kp_enc=n_kp_enc,
                    n_kp_prior=n_kp_prior, kp_range=kp_range, kp_activation=kp_activation,
                    mask_threshold=mask_threshold, use_object_enc=use_object_enc,
                    exclusive_patches=exclusive_patches, use_object_dec=use_object_dec, anchor_s=anchor_s,
                    learn_order=learn_order).to(device)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(path_to_model_ckpt, map_location=device)
# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of silently running with partially initialised weights.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("warning: checkpoint does not fully match the model "
          f"(missing keys: {list(load_result.missing_keys)}, "
          f"unexpected keys: {list(load_result.unexpected_keys)})")
model.eval()
print("loaded model from checkpoint")

# Clamp the requested index so it never runs past the end of path_to_images.
image_idx = min(image_idx, len(path_to_images) - 1)
path_to_image = path_to_images[image_idx]
im = Image.open(path_to_image)

# Force three channels and the square resolution the model was constructed with.
im = im.convert('RGB')
im = im.resize((image_size, image_size), Image.BICUBIC)
trans = transforms.ToTensor()
data = trans(im)
# Add a leading batch dimension and move the tensor to the model's device.
data = data.unsqueeze(0).to(device)
# `x` is the name used for the model input in the inference code below.
x = data
# NOTE(review): logvar_threshold is not referenced again in the visible code.
logvar_threshold = 13.0

print('Input image')
plt.imshow(im)
plt.axis('off') 
plt.show()

# Run the model once on the sample image (no gradients needed) and derive the
# top-5 keypoints and the bounding-box image from its outputs.
with torch.no_grad():
        # When True, the posterior means are used directly instead of sampling.
        deterministic = True

        # The same image is passed as both the input and the prior input.
        model_output = model(x, x_prior=x)
        mu_p = model_output['kp_p']
        gmap = model_output['gmap']
        mu = model_output['mu']
        logvar = model_output['logvar']
        rec_x = model_output['rec']
        mu_features = model_output['mu_features']
        logvar_features = model_output['logvar_features']
        # object stuff
        dec_objects_original = model_output['dec_objects_original']
        cropped_objects_original = model_output['cropped_objects_original']
        '''
        enc_out = model.encode_all(data, return_heatmap=True, deterministic=deterministic)
        mu, logvar, kp_heatmap, mu_features, logvar_features, obj_on, order_weights = enc_out
        '''
        # NOTE(review): z and z_features are not read again in the visible code.
        if deterministic:
            z = mu
            z_features = mu_features
        else:
            z = reparameterize(mu, logvar)
            z_features = reparameterize(mu_features, logvar_features)
        # top-k: keep the 5 keypoints with the smallest summed log-variance
        logvar_sum = logvar.sum(-1)
        logvar_topk = torch.topk(logvar_sum, k=5, dim=-1, largest=False)
        indices = logvar_topk[1]  # [batch_size, topk]
        # Column vector of batch indices so the fancy indexing below selects per sample.
        batch_indices = torch.arange(mu.shape[0]).view(-1, 1).to(mu.device)
        topk_kp = mu[batch_indices, indices]
        # bounding boxes
        # NOTE(review): the last keypoint is excluded here; presumably it is not an
        # object keypoint (e.g. background) -- confirm against the model definition.
        masks = create_masks_fast(mu[:, :-1].detach(), anchor_s=model.anchor_s, feature_dim=x.shape[-1])
        # Binarise the masks at mask_threshold.
        masks = torch.where(masks < mask_threshold, 0.0, 1.0)
        # NOTE(review): bb_scores and hard_threshold are only used by the commented-out
        # NMS variants below, not by the active code.
        bb_scores = -1 * logvar_sum
        hard_threshold = bb_scores.mean()

        # print(masks[0].shape)
        # print(masks[0])
        # print(x[0].shape)
        # print(bb_scores[0].shape)
        # print(bb_scores[0])
        # img_with_masks_nms = plot_bb_on_image_from_masks_nms(masks[0], x[0], scores=bb_scores[0])
        img_with_masks = plot_bb_on_image_from_masks(masks[0], x[0])
        #img_with_masks_nms, nms_ind = plot_bb_on_image_batch_from_masks_nms(masks, x,scores=bb_scores)
# NOTE(review): the next four assignments rebind `x` from the input tensor to a
# NumPy array, and none of N / xmin / xmax / x is read again in the visible code.
N = mu.shape[1]
xmin = 0
xmax = image_size

x = np.linspace(xmin, xmax, N)

# Clamp keypoints to kp_range, then rescale them linearly from kp_range to [0, 1].
mu = mu.clamp(kp_range[0], kp_range[1])
# Clamped but not yet rescaled copy.
original_mu = mu.clone()
mu = (mu - kp_range[0]) / (kp_range[1] - kp_range[0])
# Pixel coordinates of all keypoints except the last one.
# NOTE(review): index 1 is read as x and index 0 as y, which implies the model
# stores coordinates as (y, x) -- confirm.
xvals = mu[0, :-1, 1].data.cpu().numpy() * (image_size - 1)
yvals = mu[0, :-1, 0].data.cpu().numpy() * (image_size - 1)
if learned_feature_dim > 0:
    # Last feature dimension of every keypoint of the first batch element.
    feature_1_vals = mu_features[0, :, -1].data.cpu().numpy()
    # set up a plot for topk

# Same clamp + rescale for the top-k keypoints selected above.
topk_kp = topk_kp.clamp(kp_range[0], kp_range[1])
topk_kp = (topk_kp - kp_range[0]) / (kp_range[1] - kp_range[0])
# NOTE(review): xvals / yvals / feature_1_vals / xvals_topk / yvals_topk are not
# read again in the visible code.
xvals_topk = topk_kp[0, :, 1].data.cpu().numpy() * (image_size - 1)
yvals_topk = topk_kp[0, :, 0].data.cpu().numpy() * (image_size - 1)


#print(topk_kp)
#print(data.shape)
#img_with_kp_topk = plot_keypoints_on_image_batch(topk_kp, data)
# Show, in order: reconstruction, top-5 keypoints, all keypoints, bounding boxes
# and the decoded objects image.
print(rec_x.shape)
print('Reconstruction')
# squeeze() drops the singleton batch dimension; permute moves channels last for imshow.
plt.imshow(rec_x.cpu().squeeze().permute(1, 2, 0).numpy())
plt.axis('off') 
plt.show()

print('top 5 keypoints')
img_with_kp_topk = plot_keypoints_on_image(topk_kp[0], data[0])
plt.imshow(img_with_kp_topk)
plt.axis('off') 
plt.show()


print('Keypoints')
# All keypoints except the last one. NOTE(review): mu was already rescaled to
# [0, 1] earlier in this cell, so with kp_range = (-1, 1) this clamp changes nothing.
img_with_kp = plot_keypoints_on_image(mu[:, :-1][0].clamp(min=kp_range[0], max=kp_range[1]), data[0])
plt.imshow(img_with_kp)
plt.axis('off') 
plt.show()

'''
img_with_kp_p = plot_keypoints_on_image(mu_p[0], data[0])
plt.imshow(img_with_kp_p)
plt.axis('off') 
plt.show()
'''
print('BB')
plt.imshow(img_with_masks)
plt.axis('off')
plt.show()

print('Masks')
# First batch element of the model's 'dec_objects' output, channels moved last for imshow.
dec_objects = model_output['dec_objects']
plt.imshow(dec_objects[0].permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.show()
# dec_objects[:max_imgs, -3:]

## Define Functions

In [None]:

def show_img_with_kp(img, model, topk=None):
    """Plot the model's keypoints on `img` next to the decoded objects image.

    The image is converted to RGB, resized to ``image_size`` x ``image_size``
    and passed through the model once (same tensor as input and prior input).
    The keypoints are clamped to ``kp_range``, rescaled to [0, 1] and drawn
    with ``plot_keypoints_on_image``; the result is shown side by side with
    the first element of the model's ``'dec_objects'`` output.

    Args:
        img: PIL image to analyse.
        model: keypoint model; it is switched to eval mode.
        topk: optional positive int. ``None`` (default) draws every keypoint
            in ``model_output['mu']``, which is what this function has always
            displayed. An int draws only the ``topk`` keypoints with the
            smallest summed log-variance (capped at the number of keypoints).

    Returns:
        None. The combined image is displayed with matplotlib.
    """
    model.eval()

    rgb = img.convert('RGB').resize((image_size, image_size), Image.BICUBIC)
    data = transforms.ToTensor()(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        model_output = model(data, x_prior=data)

    keypoints = model_output['mu']
    if topk is not None:
        # Lowest summed log-variance first; never ask for more than exist.
        logvar_sum = model_output['logvar'].sum(-1)
        k = min(topk, logvar_sum.shape[-1])
        indices = torch.topk(logvar_sum, k=k, dim=-1, largest=False)[1]  # [batch_size, k]
        batch_indices = torch.arange(keypoints.shape[0], device=keypoints.device).view(-1, 1)
        keypoints = keypoints[batch_indices, indices]

    # Clamp to the valid coordinate range, then rescale linearly to [0, 1].
    keypoints = keypoints.clamp(kp_range[0], kp_range[1])
    keypoints = (keypoints - kp_range[0]) / (kp_range[1] - kp_range[0])

    img_with_kp = plot_keypoints_on_image(keypoints[0], data[0])
    dec_objects = model_output['dec_objects'][0].permute(1, 2, 0).cpu().numpy()

    # NOTE(review): the division assumes plot_keypoints_on_image returns 0-255
    # pixel values while dec_objects is already in [0, 1] -- confirm.
    combined_img = np.hstack((img_with_kp / 255.0, dec_objects))

    plt.imshow(combined_img)
    plt.axis('off')
    plt.show()

def _encode_image(img, model):
    """Resize a PIL image to the model input size and run one forward pass.

    Returns:
        (data, model_output): the (1, C, H, W) input tensor on `device` and
        the output mapping the model produced for it.
    """
    rgb = img.convert('RGB').resize((image_size, image_size), Image.BICUBIC)
    data = transforms.ToTensor()(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        model_output = model(data, x_prior=data)
    return data, model_output


def _keypoints_with_features(model_output):
    """Pair each keypoint position of the first batch element with its feature vector."""
    coords = model_output['mu'][0].cpu().numpy()
    features = model_output['mu_features'][0].cpu().numpy()
    # zip stops at the shorter sequence, so surplus entries of either are dropped.
    return list(zip(coords, features))


def _match_keypoints(kps_1, kps_2, sim_threshold=0.2):
    """Greedily match keypoints of two frames one-to-one.

    The cost of a pair is the Euclidean distance between the positions plus
    ``abs(1 - cosine_similarity)`` of the feature vectors (lower is better).
    Every keypoint of `kps_1` proposes its cheapest partner in `kps_2`;
    proposals with cost >= `sim_threshold` are dropped, and the rest are
    accepted in order of increasing cost as long as neither position has
    been used already.

    Returns:
        List of ``(kp_1, kp_2, cost)`` tuples sorted by cost.
    """
    proposals = []
    for kp_1, feature_1 in kps_1:
        best_match, best_sim = None, float('inf')
        for kp_2, feature_2 in kps_2:
            geometric_sim = np.linalg.norm(kp_1 - kp_2)
            denom = np.linalg.norm(feature_1) * np.linalg.norm(feature_2)
            # A zero-norm feature vector has no direction: treat it as
            # orthogonal instead of dividing by zero.
            cos_sim = np.dot(feature_1, feature_2) / denom if denom > 0 else 0.0
            total_sim = geometric_sim + abs(1 - cos_sim)
            if total_sim < best_sim:
                best_sim, best_match = total_sim, kp_2
        if best_sim < sim_threshold:
            proposals.append((kp_1, best_match, best_sim))

    # Cheapest proposals win when several keypoints want the same partner.
    proposals.sort(key=lambda match: match[2])
    matches = []
    used_1, used_2 = set(), set()
    for kp_1, kp_2, sim in proposals:
        key_1 = (kp_1[0], kp_1[1])
        key_2 = (kp_2[0], kp_2[1])
        if key_1 not in used_1 and key_2 not in used_2:
            matches.append((kp_1, kp_2, sim))
            used_1.add(key_1)
            used_2.add(key_2)
    return matches


def _normalize_kp(kp):
    """Rescale a keypoint array from ``kp_range`` to [0, 1], clipping out-of-range values."""
    return ((kp - kp_range[0]) / (kp_range[1] - kp_range[0])).clip(0, 1)


def _iter_consecutive_keypoints(image_array, model):
    """Yield ``(kps_prev, kps_next, data_next)`` for each pair of consecutive frames.

    Every frame is encoded only once: the keypoints of frame ``i + 1`` are
    reused as the "previous" keypoints of the following pair.
    NOTE(review): this assumes the eval-mode forward pass returns the same
    'mu' / 'mu_features' when a frame is fed twice -- confirm for the model.
    """
    kps_prev = None
    for i in range(len(image_array) - 1):
        if kps_prev is None:
            _, output_prev = _encode_image(Image.fromarray(np.uint8(image_array[i])), model)
            kps_prev = _keypoints_with_features(output_prev)
        data_next, output_next = _encode_image(Image.fromarray(np.uint8(image_array[i + 1])), model)
        kps_next = _keypoints_with_features(output_next)
        yield kps_prev, kps_next, data_next
        kps_prev = kps_next


def associate_kp(img1, img2, model, sim_threshold=0.2):
    """Match the keypoints of two images and visualise the associations.

    Shows one overview image produced by ``plot_matched_keypoints`` on the
    second image, followed by one side-by-side figure per matched pair
    (keypoint on image 1 | keypoint on image 2).

    Args:
        img1, img2: PIL images.
        model: keypoint model; it is switched to eval mode.
        sim_threshold: maximum matching cost for a pair to be kept.

    Returns:
        None. All results are displayed with matplotlib.
    """
    model.eval()

    data1, output_1 = _encode_image(img1, model)
    data2, output_2 = _encode_image(img2, model)
    matches = _match_keypoints(_keypoints_with_features(output_1),
                               _keypoints_with_features(output_2),
                               sim_threshold)

    matched_kp_tuple = [(_normalize_kp(kp_1), _normalize_kp(kp_2)) for kp_1, kp_2, _ in matches]
    image_matched = plot_matched_keypoints(matched_kp_tuple, data2[0])
    plt.imshow(image_matched)
    plt.axis('off')
    plt.show()

    for kp_1, kp_2, _ in matches:
        # A single-keypoint batch of shape (1, 2) with coordinates in [0, 1].
        kp_1_norm = torch.tensor(_normalize_kp(kp_1)).unsqueeze(0)
        kp_2_norm = torch.tensor(_normalize_kp(kp_2)).unsqueeze(0)
        img_with_kp_1 = plot_keypoints_on_image(kp_1_norm, data1[0])
        img_with_kp_2 = plot_keypoints_on_image(kp_2_norm, data2[0])

        plt.imshow(np.hstack((img_with_kp_1, img_with_kp_2)))
        plt.axis('off')
        plt.show()


def get_index(path, ann):
    """Return the first index value stored for the language annotation `ann`.

    Reads ``<path>/lang_annotations/auto_lang_ann.npy`` and scans the
    annotation texts in order. For the first text equal to `ann`, the first
    element of its ``info/indx`` entry is returned (presumably the start
    frame of the annotated segment -- confirm against the dataset docs).
    If no annotation matches, 0 is returned.
    """
    # NOTE: allow_pickle=True executes pickle data; only use trusted dataset files.
    annotations = np.load('{}/lang_annotations/auto_lang_ann.npy'.format(path), allow_pickle=True).item()
    for index_entry, text in zip(annotations["info"]["indx"], annotations["language"]["ann"]):
        if text == ann:
            return index_entry[0]
    return 0


# Older tracker that plots every frame and writes the video itself;
# track_keypoints_2 below is the newer variant.
def track_keypoints(image_array, model, video_path='tracked_keypoints_video.avi', fps=20,
                    sim_threshold=0.2):
    """Match keypoints between consecutive frames, plot them and write a video.

    Args:
        image_array: sequence of frames convertible with ``np.uint8``.
        model: keypoint model; it is switched to eval mode.
        video_path: output file of the XVID-encoded video.
        fps: frame rate of the output video.
        sim_threshold: maximum matching cost for a pair to be kept.

    Returns:
        List with one entry per consecutive frame pair, each a list of
        ``(kp_1, kp_2, cost)`` tuples in model coordinates.
    """
    model.eval()
    tracked_keypoints = []

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    # Frames are drawn on the resized model input.
    # NOTE(review): assumes plot_matched_keypoints keeps that resolution -- confirm.
    frame_size = (image_size, image_size)
    out = cv2.VideoWriter(video_path, fourcc, fps, frame_size)
    if not out.isOpened():
        print(f"warning: could not open video writer for '{video_path}'; frames will not be saved")

    print(len(image_array))
    try:
        for kps_prev, kps_next, data_next in _iter_consecutive_keypoints(image_array, model):
            matches = _match_keypoints(kps_prev, kps_next, sim_threshold)
            tracked_keypoints.append(matches)

            matched_kp_tuple = [(_normalize_kp(kp_1), _normalize_kp(kp_2)) for kp_1, kp_2, _ in matches]
            image_matched = plot_matched_keypoints(matched_kp_tuple, data_next[0])
            plt.imshow(image_matched)
            plt.axis('off')
            plt.show()
            # Reverse the channel order (RGB -> BGR) for OpenCV.
            out.write(image_matched[:, :, ::-1])
    finally:
        # Always finalise the video file, even if a frame fails.
        out.release()
    return tracked_keypoints


def track_keypoints_2(image_array, model, sim_threshold=0.2):
    """Match keypoints between consecutive frames and give each track a colour.

    A keypoint inherits the colour of the keypoint it was matched to in the
    previous frame pair; a keypoint without such a predecessor starts a new
    track with a random colour.

    Args:
        image_array: sequence of frames convertible with ``np.uint8``.
        model: keypoint model; it is switched to eval mode.
        sim_threshold: maximum matching cost for a pair to be kept.

    Returns:
        List with one dict per consecutive frame pair:
        ``{'kps': [{'kp_1', 'kp_2', 'color'}, ...], 'img': <tensor of the later frame>}``
        where the keypoints are rescaled to [0, 1].
    """
    model.eval()
    tracked_keypoints = []
    # Maps a normalised keypoint position (as tuple) to the colour of its track.
    keypoint_colors = {}

    for kps_prev, kps_next, data_next in _iter_consecutive_keypoints(image_array, model):
        frame_matches = []
        for kp_1, kp_2, _ in _match_keypoints(kps_prev, kps_next, sim_threshold):
            kp_1 = _normalize_kp(kp_1)
            kp_2 = _normalize_kp(kp_2)

            prev_kp = tuple(kp_1)
            next_kp = tuple(kp_2)
            if prev_kp in keypoint_colors:
                color = keypoint_colors[prev_kp]
            else:
                color = tuple(np.random.randint(0, 256, size=3))
            # Hand the colour on so the next frame pair can continue the track.
            keypoint_colors[next_kp] = color

            frame_matches.append({'kp_1': kp_1,
                                  'kp_2': kp_2,
                                  'color': color})

        tracked_keypoints.append({'kps': frame_matches,
                                  'img': data_next[0]})

    return tracked_keypoints


## Test keypoint association between 2 images

In [None]:
# NOTE(review): this reloads the same checkpoint that was already loaded when the
# model was created earlier in the notebook.
model.load_state_dict(torch.load(path_to_model_ckpt, map_location=device), strict=False)
print("loaded model from checkpoint")

# NOTE(review): hard-coded absolute local path -- adjust to your dataset location.
path = '/media/tim/E/datasets/task_D_D/validation'
target_ann = "pick up the pink block from the drawer"
# Index value of the first annotation matching target_ann (0 if none matches).
start_index = get_index(path, target_ann)
        
# Load two consecutive episode files, 40 and 41 steps after start_index, and
# take their 'rgb_static' image.
data_1 = np.load(f"{path}/episode_{start_index + 40:07d}.npz", allow_pickle=True)
img_arr_1 = data_1['rgb_static']
img_1 = Image.fromarray(np.uint8(img_arr_1))
data_2 = np.load(f"{path}/episode_{start_index + 41:07d}.npz", allow_pickle=True)
img_arr_2 = data_2['rgb_static']
        
img_2 = Image.fromarray(np.uint8(img_arr_2))
# Show the keypoints found on the first image, then the associations between both.
show_img_with_kp(img_1, model)        
associate_kp(img_1, img_2, model)

## Create Video Clip

In [None]:
# NOTE(review): hard-coded absolute local path -- adjust to your dataset location.
path = '/media/tim/E/datasets/task_D_D/validation'
target_ann = "open the drawer"
video_out = 'tracked_keypoints_video_1.avi'
video_length = 200

start_index = get_index(path, target_ann)
print(start_index)

# Collect the 'rgb_static' image of `video_length` consecutive episode files.
image_array = []
for i in range(0, video_length):
    # The context manager closes each .npz file right after its frame is read.
    # NOTE: allow_pickle=True executes pickle data; only use trusted dataset files.
    with np.load(f"{path}/episode_{start_index + i:07d}.npz", allow_pickle=True) as episode:
        image_array.append(episode['rgb_static'])

tracked_points = track_keypoints_2(image_array, model)

video_img_array = create_tracked_kp_video(tracked_points)
fourcc = cv2.VideoWriter_fourcc(*'XVID')
fps = 20
# Frames are drawn on the resized model input.
# NOTE(review): assumes create_tracked_kp_video keeps that resolution -- confirm.
frame_size = (image_size, image_size)
out = cv2.VideoWriter(video_out, fourcc, fps, frame_size)
if not out.isOpened():
    # Without this check a missing codec or unwritable path goes unnoticed and
    # the success message below would be printed for a video that was never written.
    raise RuntimeError(f"could not open video writer for '{video_out}'")
try:
    for img in video_img_array:
        # Reverse the channel order (RGB -> BGR) for OpenCV.
        out.write(img[:, :, ::-1])
finally:
    # Always finalise the video file, even if writing a frame fails.
    out.release()
print(f'done: created video in project {video_out}')
    