In [None]:
%matplotlib widget
import torch
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from pathlib import Path
import numpy as np
from nerfstudio.viewer.viewer import Viewer
from nerfstudio.configs.base_config import ViewerConfig

# Choose which trained run to inspect by pointing `config` at its config.yml.
# Alternative runs are kept commented out; exactly one assignment should be active.
# config = Path("outputs/garfield_plushie/dig/2024-03-12_120409/config.yml")#appearance embed stride 14

# config = Path("outputs/garfield_plushie/dig/2024-03-12_124947/config.yml")#no appearance embed and patch stride 7
config = Path("outputs/garfield_plushie/dig/2024-03-12_125949/config.yml")#no appearance embed and patch stride 14


# config = Path("outputs/boops_mug/dig/2024-03-12_131401/config.yml")
# eval_setup returns a 4-tuple; only its second element (the pipeline) is used here.
_,pipeline,_,_ = eval_setup(config)
# Handle to the datamanager's DINO loader; later cells call its get_pca_feats on video frames.
dino_loader = pipeline.datamanager.dino_dataloader
# A Viewer is only created when the pipeline exposes a `garfield_pipeline` attribute.
# NOTE(review): `v` is otherwise undefined, so nothing later may rely on it unconditionally.
if hasattr(pipeline,"garfield_pipeline"):
    v = Viewer(ViewerConfig(default_composite_depth=False),config.parent,pipeline.datamanager.get_datapath(),pipeline)

In [None]:
from PIL import Image
from torchvision.transforms import ToTensor
import moviepy.editor as mpy
# Open the motion video whose frames will be matched against the rendered features.
video_path = Path("garfield_move.mp4")
# video_path = Path("boops_lift.MOV")
motion_clip = mpy.VideoFileClip(str(video_path))
# For clips flagged with a 90 degree rotation, resize to the swapped (reversed) size and
# clear the flag. NOTE(review): presumably a workaround for portrait phone footage - confirm.
if motion_clip.rotation == 90:
    motion_clip = motion_clip.resize(motion_clip.size[::-1])
    motion_clip.rotation = 0
# Grab one reference frame at t = 7.2 (moviepy's get_frame takes a time in seconds).
frame = motion_clip.get_frame(7.2)
# Despite its name, `pil_image` is the tensor produced by ToTensor (channels first;
# it is permuted to channels-last before being shown below).
pil_image = ToTensor()(Image.fromarray(frame))

# Alternative: use a still image instead of a video frame.
# image_path = Path("/home/justin/nerfstudio/data/colorful_mugs_colmap/images/frame_00010.jpg")
# image_path = Path("/home/justin/nerfstudio/data/louvre_statue/images/frame_00170.png")
# pil_image = ToTensor()(Image.open(image_path))
# Feature map for the frame: a batch dim is added for the loader and squeezed away after.
# NOTE(review): assumed to come back as (h, w, feat_dim) - it is indexed that way later; confirm.
img_pca_feats = dino_loader.get_pca_feats(pil_image.unsqueeze(0)).cuda().squeeze()
# Pull one training camera plus its data dict and render the model outputs for it.
cam,data = pipeline.datamanager.next_train(0)
outputs = pipeline.model.get_outputs_for_camera(cam)
# Stack rendered and image features as rows so both share one 3-component PCA basis
# (this keeps the colours of the two PCA panels comparable).
which_to_rgb_pca = torch.cat([outputs['dino'].view(-1,img_pca_feats.shape[-1]),img_pca_feats.view(-1,img_pca_feats.shape[-1])],dim=0)
# torch.pca_lowrank returns (U, S, V); only V is kept as the projection to 3 components.
_,_,rgb_pca = torch.pca_lowrank(which_to_rgb_pca.view(-1,which_to_rgb_pca.shape[-1]), q=3, niter=30)
from nerfstudio.utils.colormaps import apply_pca_colormap

# Panels: [0] video frame, [1] training image, [2] frame features (PCA), [3] rendered features (PCA).
fig,axs = plt.subplots(1,4,figsize=(20,10))
axs[0].imshow(pil_image.permute(1,2,0).cpu().numpy())
axs[1].imshow(data["image"].cpu().numpy())
axs[2].imshow(apply_pca_colormap(img_pca_feats,rgb_pca).cpu().numpy())
axs[3].imshow(apply_pca_colormap(outputs["dino"],rgb_pca).cpu().numpy())
# Filled by the click handler with (row, col) tuples; reset every time this cell runs.
click_coords = []
def onclick(event):
    """Record clicks on the rendered-feature panel as (row, col) pixel indices.

    Clicks on any panel other than the fourth one (index 3) are ignored.
    Accepted clicks are appended to the module-level ``click_coords`` list.

    Args:
        event: matplotlib mouse event; ``inaxes``, ``xdata`` and ``ydata`` are read.
    """
    # Check if the click was on one of the axes
    for i, ax in enumerate(axs):
        if event.inaxes == ax:
            if i == 3:
                # imshow centres pixel k at data coordinate k (it spans [k-0.5, k+0.5)),
                # so round to the nearest integer; truncating would select the previous
                # pixel for clicks in the lower/right half of a pixel.
                click_coords.append((int(round(event.ydata)),int(round(event.xdata))))
            break
# Register the click handler on this figure's canvas, then display the figure.
# Clicks are only delivered with an interactive backend.
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
import cv2
from nerfstudio.utils.colormaps import apply_float_colormap
def get_matches(feats1,feats2,feats1_coords):
    """Find, for each query location in feats1, the closest feature in feats2.

    Args:
        feats1: feature map indexed as feats1[row, col, :].
        feats2: feature map searched for the nearest (L2) feature.
        feats1_coords: iterable of (row, col) indices into feats1.

    Returns:
        (matches, dist_maps): matches holds one (row, col) location in feats2 per
        query; dist_maps holds the matching per-location distance maps, each with
        a trailing singleton dimension.
    """
    def _nearest(query):
        # Distance from the query feature to every location of feats2.
        dmap = (query - feats2).norm(dim=-1,keepdim=True)
        flat_idx = torch.argmin(dmap).cpu().numpy()
        # Unflatten the argmin using the map's second dimension as the row length.
        n_cols = dmap.shape[1]
        return (flat_idx//n_cols,flat_idx%n_cols), dmap

    results = [_nearest(feats1[pc[0],pc[1],:]) for pc in feats1_coords]
    matches = [location for location, _ in results]
    dist_maps = [dmap for _, dmap in results]
    return matches,dist_maps

def plot_matches(img1, img2, coords1, coords2, distance_maps = None):
    """
    Draw matches between two images on a single side-by-side canvas.

    Args:
        img1, img2: image tensors; each is moved to the CPU, scaled by 255 and cast
            to uint8, so float values in [0, 1] are expected (out-of-range values
            are clipped rather than wrapped).
        coords1, coords2: sequences of (row, col) positions given as fractions of
            the (h, w) of img1 / img2 respectively; paired element by element.
        distance_maps: optional sequence of distance tensors. Each one is
            colormapped, resized to img1's size and appended to the right of the
            two images. All maps share one normalisation (the global maximum).

    Returns:
        numpy uint8 array of the RGB canvas (not a PIL image) with a line, two
        end markers and an index label drawn for every match.
    """
    if distance_maps is None:
        distance_maps = []
    img1 = img1.cpu().numpy()
    img2 = img2.cpu().numpy()
    h1,w1 = img1.shape[:2]
    h2,w2 = img2.shape[:2]
    #convert coords to pixel coords (cv2 expects (x, y), hence the swap)
    coords1 = [(int(c[1]*w1),int(c[0]*h1)) for c in coords1]
    coords2 = [(int(c[1]*w2),int(c[0]*h2)) for c in coords2]
    #concatenate the images using PIL; clip first so the uint8 cast cannot wrap around
    img1 = Image.fromarray((np.clip(img1,0,1)*255).astype(np.uint8))
    img2 = Image.fromarray((np.clip(img2,0,1)*255).astype(np.uint8))
    W = w1 + w2 + w1*len(distance_maps)
    new_img = Image.new('RGB',(W,max(h1,h2)))
    new_img.paste(img1,(0,0))
    new_img.paste(img2,(w1,0))
    if distance_maps:
        # The normaliser is the same for every map, so compute it once up front.
        max_val = max(dmap.max().item() for dmap in distance_maps)
        # Guard against all-zero maps, which would otherwise divide by zero.
        scale = max_val if max_val > 0 else 1.0
    for i,d in enumerate(distance_maps):
        d = apply_float_colormap(d/scale).cpu().numpy()
        d = (d*255).astype(np.uint8)
        d = cv2.resize(d,(w1,h1))
        d = Image.fromarray(d)
        new_img.paste(d,(w1 + w2 + w1*i,0))
    new_img = np.array(new_img)
    #draw lines between the matches
    for i,(c1,c2) in enumerate(zip(coords1,coords2)):
        # img2 sits to the right of img1, so shift its x coordinate by img1's width
        c2 = (c2[0]+w1,c2[1])
        cv2.line(new_img,c1,c2,(255,0,0),2)
        #draw a small x at each end of the line
        cv2.drawMarker(new_img,c1,(0,255,0),markerType=cv2.MARKER_CROSS,markerSize=10,thickness=1)
        cv2.drawMarker(new_img,c2,(0,255,0),markerType=cv2.MARKER_CROSS,markerSize=10,thickness=1)
        #draw a small number next to the startpoint of the line
        cv2.putText(new_img,str(i),c1,cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,255),3)
    return new_img

def visualize_matches(pix_coords, feats1, feats2,rgb1,rgb2,include_distance_maps = False):
    """Match the given feats1 locations into feats2 and draw the result.

    Args:
        pix_coords: (row, col) indices into feats1 to find matches for.
        feats1, feats2: feature maps handed to get_matches.
        rgb1, rgb2: images handed to plot_matches to draw on.
        include_distance_maps: when true, the per-query distance maps are drawn too.

    Returns:
        Whatever plot_matches returns for the two images and the matched points.
    """
    def _as_fractions(points, feats):
        # Express each index as a fraction of the feature-map size, using the
        # centre of the feature cell (+0.5) rather than its top-left corner.
        n_rows, n_cols = feats.shape[0], feats.shape[1]
        return [((p[0]+0.5)/n_rows,(p[1]+0.5)/n_cols) for p in points]

    matches,distance_maps = get_matches(feats1, feats2, pix_coords)
    coords_norm1 = _as_fractions(pix_coords, feats1)
    coords_norm2 = _as_fractions(matches, feats2)
    if include_distance_maps:
        maps_to_draw = distance_maps
    else:
        maps_to_draw = []
    return plot_matches(rgb1, rgb2, coords_norm1, coords_norm2, maps_to_draw)


# Render a video that tracks the clicked points through the motion clip.
gif_frames = []
# Snapshot the clicks so later clicks do not change this run.
pix_coords = list(click_coords)
t_start = 1
t_end = 8
fps = 10
# endpoint=False gives samples exactly 1/fps apart over [t_start, t_end), which matches
# the fps used to encode the output below. With the endpoint included the spacing would
# be slightly larger than 1/fps and the last sample would land exactly on t_end.
for t in np.linspace(t_start,t_end,(t_end-t_start)*fps,endpoint=False):
    frame = motion_clip.get_frame(t)
    # NOTE: `pil_image` and `img_pca_feats` are deliberately left as module-level names;
    # after the loop they hold the last processed frame, and `img_pca_feats` is read
    # again by the static match visualisation further down.
    pil_image = ToTensor()(Image.fromarray(frame))
    img_pca_feats = dino_loader.get_pca_feats(pil_image.unsqueeze(0)).cuda().squeeze()
    img = visualize_matches(pix_coords, outputs["dino"], img_pca_feats, outputs['rgb'], pil_image.permute(1,2,0))
    gif_frames.append(img)
#save with moviepy
out_clip = mpy.ImageSequenceClip(gif_frames, fps=fps)
out_clip.write_videofile("garfield_matches.mp4", fps=fps)

In [None]:
from matplotlib.patches import ConnectionPatch
# Query locations, given as (row, col) indices into img_pca_feats.
pix_coords = [(10,30),(27,30)]
pix_matches = []
# squeeze=False keeps diff_axs two-dimensional even for a single query location,
# so the diff_axs[row, i] indexing below always works.
_,diff_axs = plt.subplots(2,len(pix_coords),figsize=(20,10),squeeze=False)
for i,pc in enumerate(pix_coords):
    img_pca_feat = img_pca_feats[pc[0],pc[1],:]
    # Top row: distance from the query feature to every rendered feature.
    distance_img = (outputs['dino'] - img_pca_feat).pow(2).sum(dim=-1).sqrt()[...,None]
    # Bottom row: distance from the query feature to every feature of its own image.
    img_distance = (img_pca_feats - img_pca_feat).pow(2).sum(dim=-1).sqrt()
    diff_axs[0,i].imshow(distance_img.cpu().numpy().squeeze(),vmin = 10,vmax=30)
    diff_axs[1,i].imshow(img_distance.cpu().numpy(),vmin = 10,vmax=30)
    # Unflatten the argmin into (row, col) using the map's second dimension as row length.
    argmin = torch.argmin(distance_img).cpu().numpy()
    argmin_coords = (argmin//distance_img.shape[1],argmin%distance_img.shape[1])
    pix_matches.append(argmin_coords)
fig,axs = plt.subplots(1,2,figsize=(20,10))
#visualize the query points on the image features and their matches on the rendered features
axs[0].imshow(apply_pca_colormap(img_pca_feats,rgb_pca).cpu().numpy())
axs[1].imshow(apply_pca_colormap(outputs["dino"],rgb_pca).cpu().numpy())
for (src_row,src_col),(dst_row,dst_col) in zip(pix_coords,pix_matches):
    # One random colour per match, shared by the connecting line and both markers.
    color = np.random.rand(3)
    # Points are given as (x, y) = (col, row) in each axes' data coordinates.
    con = ConnectionPatch(xyA=(src_col,src_row), xyB=(dst_col,dst_row),
                           coordsA="data", coordsB="data",
                      axesA=axs[0], axesB=axs[1], color=color)
    fig.add_artist(con)
    # Use `color=` rather than `c=`: a bare RGB triple passed as `c` is ambiguous
    # with per-point values and makes matplotlib emit a warning.
    axs[0].scatter(src_col,src_row,color=color,s=300)
    axs[1].scatter(dst_col,dst_row,color=color,s=300)
plt.show()