In [None]:
'''
This cell loads the model from the config file and initializes the viewer
'''
%matplotlib widget
import torch
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from pathlib import Path
import numpy as np
from nerfstudio.viewer.viewer import Viewer
from nerfstudio.configs.base_config import ViewerConfig

# config = Path("outputs/garfield_plushie/dig/2024-03-12_153919/config.yml")#denoised with patch stride 7 and enable_pe
# config = Path("outputs/garfield_plushie/dig/2024-03-12_155712/config.yml")#denoised with stride 7 and no enable_pe
# config = Path("outputs/garfield_plushie/dig/2024-03-12_160428/config.yml")#non-denoised patch size 7

# config = Path("outputs/nerfgun/dig/2024-03-12_162138/config.yml")#denoised nerfgun, patch stride 7
# config = Path("outputs/nerfgun/dig/2024-03-12_163729/config.yml") #denoised nerfgun, patch stride 14

# config = Path("outputs/nerfgun2/dig/2024-03-12_165027/config.yml")#nerfgun2 stride 14, denoised
# config = Path("outputs/nerfgun2/dig/2024-03-12_171730/config.yml")#stride 7
config = Path("outputs/nerfgun2/dig/2024-03-13_151932/config.yml")#stride 7, with garfield

# config = Path("outputs/boops_mug/dig/2024-03-13_141111/config.yml")#denoise stride 7
# Load everything described by the selected config. eval_setup returns a 4-tuple;
# only its second element (the pipeline) is used in this notebook.
_,pipeline,_,_ = eval_setup(config)
# Call eval() on the pipeline (presumably switches it to inference mode like a torch module — TODO confirm).
pipeline.eval()
# Keep a handle on the datamanager's DINO loader; later cells call its get_pca_feats on video frames.
dino_loader = pipeline.datamanager.dino_dataloader
# Create the viewer for this pipeline, passing the config's directory and the datamanager's data path.
# The result is bound to `v` so the viewer object stays referenced for the rest of the session.
v = Viewer(ViewerConfig(default_composite_depth=False),config.parent,pipeline.datamanager.get_datapath(),pipeline)

In [None]:
'''
This cell loads the video, picks a random train cam, and shows the PCA of the video frame's features next to the PCA of the features rendered from that train view.
You can then click on the PCA features in the rendered view (the rightmost panel) to pick keypoints to match in the video.
'''
from PIL import Image
from torchvision.transforms import ToTensor
import cv2
import moviepy.editor as mpy

# Video whose frames are matched against the rendered view (path is relative to the working directory).
video_path = Path("garfield_move.mp4")
# Stop here with a clear failure if the file is missing, before handing the path to OpenCV.
assert video_path.exists()
# Open the video once; get_vid_frame seeks inside this shared capture object on every call.
motion_clip = cv2.VideoCapture(str(video_path.absolute()))
def get_vid_frame(cap,timestamp):
    """
    Return the frame of a video at a given time, in RGB channel order.

    Args:
        cap: an opened cv2.VideoCapture (it is queried with get(), positioned with
            set() and read with read()). Its read position is changed by this call.
        timestamp: position in seconds. It is converted to a frame index with the
            video's reported fps and clamped to [0, frame_count - 1], so times
            before the start or past the end return the first / last frame.

    Returns:
        The decoded frame after BGR -> RGB conversion.

    Raises:
        RuntimeError: if the capture could not deliver the requested frame
            (for example an empty or unreadable video).
    """
    fps = cap.get(cv2.CAP_PROP_FPS)
    last_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)-1)

    # Frame index for the timestamp, clamped on both sides so negative timestamps
    # (or an empty video, where last_frame is -1) never produce a negative index.
    frame_number = max(0, min(int(timestamp * fps), last_frame))

    # Seek, then decode the frame at that position.
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    success, frame = cap.read()
    if not success or frame is None:
        # Without this check the failure would surface later as an opaque error
        # from cvtColor being handed None.
        raise RuntimeError(
            f"Could not read frame {frame_number} (timestamp {timestamp}s) from the video"
        )

    # OpenCV decodes to BGR; the rest of the notebook works in RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Grab the video frame at t = 7.2 s as the reference frame for this cell.
frame = get_vid_frame(motion_clip,7.2)
# Despite its name, pil_image is a tensor (the output of ToTensor); it is permuted
# to channel-last below before being displayed.
pil_image = ToTensor()(Image.fromarray(frame))

# Alternative inputs: load a still image from disk instead of a video frame.
# image_path = Path("/home/justin/nerfstudio/data/colorful_mugs_colmap/images/frame_00010.jpg")
# image_path = Path("/home/justin/nerfstudio/data/louvre_statue/images/frame_00170.png")
# pil_image = ToTensor()(Image.open(image_path))
# Features of the frame from the DINO loader: a batch dimension is added for the call
# and squeezed away afterwards; the result is moved to the GPU.
img_pca_feats = dino_loader.get_pca_feats(pil_image.unsqueeze(0)).cuda().squeeze()

# Fetch a training camera with its data and render the model outputs for it.
cam,data = pipeline.datamanager.next_train(0)
outputs = pipeline.model.get_outputs_for_camera(cam)
# Stack the rendered features and the frame features (flattened to rows of feature vectors)
# so that one shared projection is fitted to both and their colours are comparable.
which_to_rgb_pca = torch.cat([outputs['dino'].view(-1,img_pca_feats.shape[-1]),img_pca_feats.view(-1,img_pca_feats.shape[-1])],dim=0)
# Keep only the third result of pca_lowrank (the 3 projection directions) for colouring.
_,_,rgb_pca = torch.pca_lowrank(which_to_rgb_pca.view(-1,which_to_rgb_pca.shape[-1]), q=3, niter=30)
from nerfstudio.utils.colormaps import apply_pca_colormap

# Panels, left to right: video frame, training image, frame features, rendered features.
fig,axs = plt.subplots(1,4,figsize=(20,5))
axs[0].imshow(pil_image.permute(1,2,0).cpu().numpy())
axs[1].imshow(data["image"].cpu().numpy())
axs[2].imshow(apply_pca_colormap(img_pca_feats,rgb_pca).cpu().numpy())
axs[3].imshow(apply_pca_colormap(outputs["dino"],rgb_pca).cpu().numpy())
# Clicked keypoints, stored as (row, col); filled by the click handler and read by later cells.
click_coords = []
def onclick(event):
    """
    Matplotlib button-press handler that records keypoints clicked in the last panel.

    Only clicks inside axs[3] (the rendered-feature panel) are kept; each one is
    appended to the module-level click_coords list as a (row, col) pixel index.
    Clicks on any other panel, or outside all axes, are ignored.
    """
    if event.inaxes != axs[3]:
        return
    if event.xdata is None or event.ydata is None:
        return
    # imshow centres pixel k on data coordinate k, so pixel k covers [k-0.5, k+0.5).
    # Adding 0.5 before truncating rounds to the pixel actually under the cursor;
    # plain int() would pick the previous pixel for clicks in its lower/right half.
    row = int(event.ydata + 0.5)
    col = int(event.xdata + 0.5)
    click_coords.append((row, col))
# Route mouse clicks on the figure to onclick, then display the figure.
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

In [None]:
'''
This cell defines some useful functions for finding matches in frames and visualizing them
'''
    
from nerfstudio.utils.colormaps import apply_float_colormap
def get_matches(feats1,feats2,feats1_coords):
    """
    For each query location in feats1, find the closest feature vector in feats2.

    Args:
        feats1: feature tensor indexed as feats1[row, col, :].
        feats2: feature tensor whose last dimension holds the feature vector
            (presumably laid out (rows, cols, channels) — the flat argmin is
            unravelled using the size of dimension 1).
        feats1_coords: iterable of (row, col) indices into feats1.

    Returns:
        (matches, dist_maps): matches[k] is the (row, col) in feats2 with the
        smallest L2 distance to the k-th query; dist_maps[k] is the full distance
        map for that query, with a trailing singleton dimension.
    """
    matches = []
    dist_maps = []
    for pc in feats1_coords:
        query = feats1[pc[0], pc[1], :]
        # L2 distance from the query vector to every vector in feats2
        dist_map = torch.norm(query - feats2, dim=-1, keepdim=True)
        dist_maps.append(dist_map)
        # flat index of the nearest vector, unravelled into (row, col)
        flat_best = dist_map.argmin().cpu().numpy()
        best_row, best_col = divmod(flat_best, dist_map.shape[1])
        matches.append((best_row, best_col))
    return matches, dist_maps

def plot_matches(img1, img2, coords1, coords2, distance_maps = None):
    """
    Draw img1 and img2 side by side and connect each matched pair of points.

    Args:
        img1, img2: image tensors indexed (h, w, ...). They are multiplied by 255 and
            cast to uint8, so values are presumably floats in [0, 1] — TODO confirm.
        coords1, coords2: sequences of (row, col) pairs given as fractions of the
            (h, w) of img1 and img2 respectively; pairs are matched by position.
        distance_maps: optional sequence of distance tensors. Each is divided by the
            largest value found across all maps, colormapped, resized to img1's size
            and pasted to the right of img2. None (the default) or an empty sequence
            draws no maps.

    Returns:
        The annotated canvas as a numpy array (not a PIL image).
    """
    if distance_maps is None:
        distance_maps = []
    img1 = img1.cpu().numpy()
    img2 = img2.cpu().numpy()
    h1,w1 = img1.shape[:2]
    h2,w2 = img2.shape[:2]
    #convert fractional (row, col) coords to (x, y) pixel coords as OpenCV expects
    coords1 = [(int(c[1]*w1),int(c[0]*h1)) for c in coords1]
    coords2 = [(int(c[1]*w2),int(c[0]*h2)) for c in coords2]
    #concatenate the images using PIL
    img1 = Image.fromarray((img1*255).astype(np.uint8))
    img2 = Image.fromarray((img2*255).astype(np.uint8))
    W = w1 + w2 + w1*len(distance_maps)
    new_img = Image.new('RGB',(W,max(h1,h2)))
    new_img.paste(img1,(0,0))
    new_img.paste(img2,(w1,0))
    if len(distance_maps) > 0:
        # Shared normaliser, computed once: every map uses the same scale so their
        # colours are comparable. Fall back to 1.0 when all distances are zero to
        # avoid dividing by zero.
        max_val = max(dmap.max().item() for dmap in distance_maps) or 1.0
    for i,d in enumerate(distance_maps):
        d = apply_float_colormap(d/max_val).cpu().numpy()
        d = (d*255).astype(np.uint8)
        d = cv2.resize(d,(w1,h1))
        d = Image.fromarray(d)
        # maps are laid out left to right after the two images
        new_img.paste(d,(w1 + w2 + w1*i,0))
    new_img = np.array(new_img)
    #draw lines between the matches
    for i,(c1,c2) in enumerate(zip(coords1,coords2)):
        # img2 sits to the right of img1, so shift its x coordinate by img1's width
        c2 = (c2[0]+w1,c2[1])
        cv2.line(new_img,c1,c2,(255,0,0),2)
        #draw a small x at each end of the line
        cv2.drawMarker(new_img,c1,(0,255,0),markerType=cv2.MARKER_CROSS,markerSize=10,thickness=1)
        cv2.drawMarker(new_img,c2,(0,255,0),markerType=cv2.MARKER_CROSS,markerSize=10,thickness=1)
        #draw a small number next to the startpoint of the line
        cv2.putText(new_img,str(i),c1,cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,255),3)
    return new_img

def plot_distribution(img1, img2, coords1, distance_maps, cutoff = 3.0):
    """
    Draw img1 and img2 side by side and overlay, on img2, where each query point matches.

    Args:
        img1, img2: image tensors indexed (h, w, ...). They are multiplied by 255 and
            cast to uint8, so values are presumably floats in [0, 1] — TODO confirm.
        coords1: sequence of (row, col) query points given as fractions of img1's (h, w).
        distance_maps: one distance tensor per query point; each is resized to img2's
            size before thresholding.
        cutoff: distances above this value are not drawn.

    Returns:
        The annotated canvas as a numpy array. Each query gets a random colour, used
        both for its overlay on img2 and for its marker on img1.
    """
    img1 = img1.cpu().numpy()
    img2 = img2.cpu().numpy()
    h1,w1 = img1.shape[:2]
    h2,w2 = img2.shape[:2]
    #convert fractional (row, col) coords to (x, y) pixel coords as OpenCV expects
    coords1 = [(int(c[1]*w1),int(c[0]*h1)) for c in coords1]
    #concatenate the images using PIL
    img1 = Image.fromarray((img1*255).astype(np.uint8))
    img2 = Image.fromarray((img2*255).astype(np.uint8))
    W = w1 + w2
    new_img = Image.new('RGB',(W,max(h1,h2)))
    new_img.paste(img1,(0,0))
    new_img.paste(img2,(w1,0))
    #generate a random color for each distance map
    colors = np.random.randint(0,255,(len(distance_maps),3))
    for i,d in enumerate(distance_maps):
        d = d.cpu().numpy()
        upscaled_d = cv2.resize(d,(w2,h2))
        within_cutoff = upscaled_d<=cutoff
        alpha_img = np.zeros((h2,w2),dtype=np.uint8)
        # Opacity falls off with distance: a perfect match (distance 0) is fully
        # opaque and the overlay fades to transparent at the cutoff. In a paste mask
        # 255 means "take the overlay" and 0 means "keep the background", so scaling
        # alpha directly with distance would hide the best matches.
        alpha_img[within_cutoff] = (1.0 - upscaled_d[within_cutoff]/cutoff)*255
        distribution_img = np.zeros((h2,w2,3),dtype=np.uint8)
        distribution_img[within_cutoff] = colors[i]
        #paste the distribution image onto the img2 half of new_img weighted by alpha
        new_img.paste(Image.fromarray(distribution_img),(w1,0),Image.fromarray(alpha_img,mode='L'))
    new_img = np.array(new_img)
    #draw a marker at each query point in the same color as its overlay
    for i,c in enumerate(coords1):
        color = (int(colors[i][0]),int(colors[i][1]),int(colors[i][2]))
        cv2.drawMarker(new_img,c,color,markerType=cv2.MARKER_CROSS,markerSize=20,thickness=10)

    return new_img

def visualize_matches(pix_coords, feats1, feats2, rgb1, rgb2, include_distance_maps = False):
    """
    Match the given feats1 pixel locations into feats2 and draw the result on the RGB images.

    pix_coords are (row, col) indices into feats1. Both the query points and their
    matches are converted to fractions of their feature map's size (using pixel
    centres, hence the +0.5) so they can be drawn on rgb1 / rgb2 whatever their
    resolution. Distance maps are appended to the picture only when
    include_distance_maps is true.

    Returns whatever plot_matches returns for these inputs.
    """
    def _as_fractions(points, feats):
        # pixel-centre position of each point as a fraction of the feature map's (rows, cols)
        return [((p[0]+0.5)/feats.shape[0],(p[1]+0.5)/feats.shape[1]) for p in points]

    matches,distance_maps = get_matches(feats1, feats2, pix_coords)
    coords_norm1 = _as_fractions(pix_coords, feats1)
    coords_norm2 = _as_fractions(matches, feats2)
    maps_to_draw = distance_maps if include_distance_maps else []
    return plot_matches(rgb1, rgb2, coords_norm1, coords_norm2, maps_to_draw)
    # alternative visualisation:
    # return plot_distribution(rgb1, rgb2, coords_norm1, distance_maps)

In [None]:
'''
This cell takes the click points from the above cell and visualizes matches/heatmaps in the provided video
'''
gif_frames = []
# Snapshot the clicked keypoints so later clicks do not change this run.
pix_coords = list(click_coords)
# Time window of the source video to process (seconds) and sampling rate of the output.
t_start = 3
t_end = 6
fps = 5
# One sample every 1/fps seconds over [t_start, t_end). endpoint=False keeps the spacing
# at exactly 1/fps so the clip written at `fps` plays back in real time; including the
# endpoint would stretch the spacing to (t_end - t_start) / (n - 1). The count is cast to
# int because linspace rejects a non-integer number of samples.
num_frames = int(round((t_end-t_start)*fps))
for t in np.linspace(t_start,t_end,num_frames,endpoint=False):
    frame = get_vid_frame(motion_clip,t)
    pil_image = ToTensor()(Image.fromarray(frame))
    img_pca_feats = dino_loader.get_pca_feats(pil_image.unsqueeze(0)).cuda().squeeze()
    # Match the clicked points of the rendered view into this video frame and draw them.
    img = visualize_matches(pix_coords, outputs["dino"], img_pca_feats, outputs['rgb'], pil_image.permute(1,2,0))
    gif_frames.append(img)
# save with moviepy
out_clip = mpy.ImageSequenceClip(gif_frames, fps=fps)
out_clip.write_videofile("nerfgun_matches_7.mp4", fps=fps)

#visualize the histogram of the difference maps across all frames
# fig,ax = plt.subplots(1,len(pix_coords),figsize=(20,5))
# all_dists = [[] for _ in pix_coords]
# for t in np.linspace(t_start,t_end,(t_end-t_start)*fps):
#     frame = get_vid_frame(motion_clip,t)
#     pil_image = ToTensor()(Image.fromarray(frame))
#     img_pca_feats = dino_loader.get_pca_feats(pil_image.unsqueeze(0)).cuda().squeeze()
#     matches,distance_maps = get_matches(outputs["dino"], img_pca_feats, pix_coords)
#     for i,d in enumerate(distance_maps):
#         all_dists[i].append(d.cpu().numpy().flatten())
# for i,d in enumerate(all_dists):
#     data = np.concatenate(d)
#     ax[i].hist(data,bins=200)
