In [None]:
'''
This cell loads the model from the config file and initializes the viewer
'''
%matplotlib widget
import torch
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from pathlib import Path
import numpy as np
from nerfstudio.viewer.viewer import Viewer
from nerfstudio.configs.base_config import ViewerConfig
import cv2
from torchvision.transforms import ToTensor
from PIL import Image
from typing import List
# Trained DiG checkpoint to load. Every candidate below carries the same note:
# "with garfield, patch size 14, with denoising, 48->64 dim".
# config = Path("outputs/nerfgun2/dig/2024-03-20_113021/config.yml")
# config = Path("outputs/boops_mug/dig/2024-03-20_110937/config.yml")
config = Path("outputs/garfield_plushie/dig/2024-03-20_114231/config.yml")

# eval_setup returns a 4-tuple; only the pipeline (second element) is used here.
_, pipeline, _, _ = eval_setup(config)
pipeline.eval()
dino_loader = pipeline.datamanager.dino_dataloader

# Attach a viewer to the loaded pipeline, passing the config's parent directory.
v = Viewer(
    ViewerConfig(default_composite_depth=False),
    config.parent,
    pipeline.datamanager.get_datapath(),
    pipeline,
)

In [None]:
"""
This cell defines simple pose optimizers for learning rigid transform offsets given a Gaussian model, a start pose, and a starting view
"""
from lerf.dig import DiGModel
from lerf.data.utils.dino_dataloader import DinoDataloader
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig, CameraOptimizer
from copy import deepcopy
from torchvision.transforms.functional import resize
import torchvision
from nerfstudio.cameras.lie_groups import exp_map_SE3
def get_vid_frame(cap, timestamp):
    """
    Reads the frame at `timestamp` seconds from an OpenCV video capture, as RGB.

    Args:
        cap: an opened cv2.VideoCapture (only get/set/read are used).
        timestamp: time in seconds. The frame index is int(timestamp * fps),
            clamped to [0, frame_count - 1].

    Returns:
        The decoded frame, converted from OpenCV's BGR channel order to RGB.

    Raises:
        IOError: if the frame cannot be read (e.g. an empty or unreadable video).
    """
    fps = cap.get(cv2.CAP_PROP_FPS)
    last_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) - 1)
    # Clamp on both ends: past-the-end timestamps map to the last frame, and negative
    # timestamps (or an empty video, where last_frame == -1) never produce a negative seek.
    frame_number = max(0, min(int(timestamp * fps), last_frame))

    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    success, frame = cap.read()
    if not success or frame is None:
        # Without this check, cvtColor would fail on None with an unhelpful message.
        raise IOError(f"could not read frame {frame_number} (t={timestamp}s) from video")
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
class Cam2ObjOptimizer:
    """
    Optimizes a single camera-pose correction applied to `init_c2o`, so that the blurred
    DINO feature render of `dig_model` matches the PCA DINO features of a target frame.

    The correction is held by a nerfstudio CameraOptimizer ("SO3xR3" mode, L2 penalties
    set to 0). Its parameters are the only thing the Adam optimizer updates.
    """

    def __init__(self, dig_model: "DiGModel", dino_loader: "DinoDataloader", init_c2o: "Cameras"):
        """
        Args:
            dig_model: Gaussian model to render; switched to eval mode here.
            dino_loader: provides `get_pca_feats` for target frames.
            init_c2o: initial camera-to-object transform (in dig_model's coordinates).
                It is deep-copied and moved to CUDA, so the caller's camera is untouched.
        """
        self.dig_model = dig_model
        self.cam_optimizer = CameraOptimizer(
            CameraOptimizerConfig(mode="SO3xR3", trans_l2_penalty=0.0, rot_l2_penalty=0.0), 1, 'cuda'
        )
        self.dino_loader = dino_loader
        self.init_c2o = deepcopy(init_c2o).to('cuda')
        self.dig_model.eval()
        self.optimizer = torch.optim.Adam(list(self.cam_optimizer.parameters()), lr=0.02)
        self.blur = torchvision.transforms.GaussianBlur(kernel_size=[13, 13]).cuda()

    def set_frame(self, rgb_frame: torch.Tensor):
        """
        Sets the target frame that subsequent `step` calls fit against.

        Args:
            rgb_frame: HxWxC image tensor. It is stored as-is. Its PCA DINO features,
                computed from a 1xCxHxW view of it, are moved to CUDA and squeezed.
        """
        self.rgb_frame = rgb_frame
        batched = rgb_frame.permute(2, 0, 1).unsqueeze(0)
        self.frame_pca_feats = self.dino_loader.get_pca_feats(batched).cuda().squeeze()

    def step(self, niters=1):
        """
        Runs `niters` Adam updates on the camera correction.

        Returns:
            The model outputs rendered in the last iteration, i.e. before that
            iteration's optimizer update.

        Raises:
            ValueError: if niters < 1 (there would be no render to return).
        """
        if niters < 1:
            raise ValueError(f"niters must be >= 1, got {niters}")
        dig_outputs = None
        for _ in range(niters):
            self.optimizer.zero_grad()
            # Work on a fresh copy each iteration so the stored initial camera never changes.
            c2o = deepcopy(self.init_c2o)
            c2o.camera_to_worlds.requires_grad = True
            c2o.metadata = {'cam_idx': 0}
            self.cam_optimizer.apply_to_camera(c2o)
            dig_outputs = self.dig_model.get_outputs(c2o)
            dino_feats = self.blur(dig_outputs["dino"].permute(2, 0, 1)).permute(1, 2, 0)
            # NOTE(review): the original author marked this comparison "THIS IS BAD WE NEED
            # TO FIX THIS" without saying why. It is presumably the resize of the frame's PCA
            # features to the render resolution — confirm.
            frame_feats = resize(
                self.frame_pca_feats.permute(2, 0, 1), (dino_feats.shape[0], dino_feats.shape[1])
            ).permute(1, 2, 0).contiguous()
            loss = (frame_feats - dino_feats).norm()
            loss.backward()
            self.optimizer.step()
        return dig_outputs
from gsplat._torch_impl import quat_to_rotmat
from scipy.spatial.transform import Rotation as Rot
from typing import Literal
class RigidGroupOptimizer():
    """
    Optimizes one rigid pose offset per group of Gaussians. The target is for renders of
    `dig_model` from the fixed camera `init_c2o` to match a target frame: RGB, PCA DINO
    features, or both, as selected by `loss_type`.

    Groups are optimized independently, with no skeletal constraints. Each row of
    `pose_deltas` is one group's offset:
      * rot_type 'quat': [tx, ty, tz, qw, qx, qy, qz], initialized to the identity.
      * rot_type 'SE3':  a 6-vector fed to `exp_map_SE3`, initialized to zeros.
    Transforms act in the model's coordinates as R @ x + t, so rotation is about the origin.
    """
    rot_type: Literal['quat','SE3'] = 'quat'
    loss_type: Literal['dino','rgb','both'] = 'both'

    def __init__(self, dig_model: "DiGModel", dino_loader: "DinoDataloader", init_c2o: "Cameras", group_masks: List[torch.Tensor]):
        """
        Args:
            dig_model: Gaussian model whose gauss_params['means'] / ['quats'] are transformed.
            dino_loader: provides `get_pca_feats` for target frames.
            init_c2o: fixed render camera; deep-copied and moved to CUDA.
            group_masks: one boolean mask over the Gaussians per rigid group.

        Raises:
            ValueError: if `rot_type` is neither 'quat' nor 'SE3'.
        """
        self.dig_model = dig_model
        # Detach all the params to avoid the retain_graph issue.
        self.dig_model.gauss_params['means'] = self.dig_model.gauss_params['means'].detach()
        self.dig_model.gauss_params['quats'] = self.dig_model.gauss_params['quats'].detach()
        self.dino_loader = dino_loader
        self.group_masks = group_masks
        self.init_c2o = deepcopy(init_c2o).to('cuda')
        num_groups = len(group_masks)
        if self.rot_type == 'SE3':
            deltas = torch.zeros(num_groups, 6, dtype=torch.float32, device='cuda')
        elif self.rot_type == 'quat':
            deltas = torch.zeros(num_groups, 7, dtype=torch.float32, device='cuda')
            deltas[:, 3] = 1.0  # identity quaternion in (w, x, y, z) order
        else:
            raise ValueError(f"unknown rot_type {self.rot_type!r}; expected 'quat' or 'SE3'")
        self.pose_deltas = torch.nn.Parameter(deltas)
        lr = .03 if self.loss_type == 'dino' else .01
        self.optimizer = torch.optim.Adam([self.pose_deltas], lr=lr)
        # Snapshot of the untransformed Gaussians; reset_transforms restores from these.
        self.init_means = dig_model.gauss_params['means'].detach().clone()
        self.init_quats = dig_model.gauss_params['quats'].detach().clone()
        self.blur = torchvision.transforms.GaussianBlur(kernel_size=[5,5]).cuda()

    def step(self, niter = 1):
        """
        Runs `niter` Adam updates on `pose_deltas` against the frame set by `set_frame`.

        Returns:
            The outputs rendered in the last iteration, with that iteration's pre-update
            transforms applied. Before returning, the Gaussians are reset to their initial
            means/quats.

        Raises:
            ValueError: if niter < 1, or if loss_type is not 'dino', 'rgb' or 'both'.
            RuntimeError: "Lost tracking", when the render has no 'dino' output.
        """
        if niter < 1:
            raise ValueError(f"niter must be >= 1, got {niter}")
        if self.loss_type not in ('dino', 'rgb', 'both'):
            raise ValueError(f"unknown loss_type {self.loss_type!r}")
        self.dig_model.eval()
        dig_outputs = None
        for _ in range(niter):
            self.optimizer.zero_grad()
            self.apply_to_model()
            dig_outputs = self.dig_model.get_outputs(self.init_c2o)
            if 'dino' not in dig_outputs:
                self.reset_transforms()
                raise RuntimeError("Lost tracking")
            loss = self._compute_loss(dig_outputs)
            loss.backward()
            self.optimizer.step()
            self.reset_transforms()
        return dig_outputs

    def _compute_loss(self, dig_outputs):
        """
        Computes the mean absolute error between the render and the current target frame.

        'rgb' and 'dino' use a single term. 'both' averages the two terms.
        """
        terms = []
        if self.loss_type in ('rgb', 'both'):
            terms.append((self.rgb_frame - dig_outputs["rgb"]).abs().mean())
        if self.loss_type in ('dino', 'both'):
            dino_feats = self.blur(dig_outputs["dino"].permute(2, 0, 1)).permute(1, 2, 0).contiguous()
            # NOTE(review): the original author marked this feature comparison "THIS IS BAD
            # WE NEED TO FIX THIS" without saying why — confirm what needs fixing.
            terms.append((self.frame_pca_feats - dino_feats).abs().mean())
        return sum(terms) / len(terms)

    def apply_to_model(self):
        """
        Resets the Gaussians, then applies each group's current pose delta to its mask.
        """
        transforms = self.get_3x4s()
        self.reset_transforms()
        for mask, H in zip(self.group_masks, transforms):
            self.apply_H_to_group(mask, H)

    def reset_transforms(self):
        """Restores means and quats to clones of the snapshot taken in __init__."""
        with torch.no_grad():
            self.dig_model.gauss_params['means'] = self.init_means.clone()
            self.dig_model.gauss_params['quats'] = self.init_quats.clone()

    def get_3x4s(self):
        """
        Returns the per-group [R | t] transforms as one (num_groups, 3, 4) tensor.

        In 'SE3' mode, the result is whatever `exp_map_SE3` returns; it is assumed to be
        (num_groups, 3, 4) — confirm.

        Raises:
            ValueError: if rot_type is neither 'quat' nor 'SE3'.
        """
        if self.rot_type == 'quat':
            translations = self.pose_deltas[:, :3, None]
            # quat_to_rotmat handles a batch of (w, x, y, z) quaternions.
            rotations = quat_to_rotmat(self.pose_deltas[:, 3:])
            return torch.cat([rotations, translations], dim=-1)
        elif self.rot_type == 'SE3':
            return exp_map_SE3(self.pose_deltas)
        raise ValueError(f"unknown rot_type {self.rot_type!r}; expected 'quat' or 'SE3'")

    def apply_H_to_group(self, group_mask:torch.Tensor, H:torch.Tensor):
        """
        Applies the rigid transform H = [R | t] (3x4) to the Gaussians selected by group_mask.

        Quaternions are rotated under no_grad via a scipy round-trip, so no gradient flows
        through them. Means are transformed differentiably (R @ x + t), so pose gradients
        reach `pose_deltas` through the means only. An empty mask is a no-op.
        """
        if not bool(group_mask.any()):
            return
        params = self.dig_model.gauss_params
        with torch.no_grad():
            rotated = torch.matmul(H[:3, :3], quat_to_rotmat(params['quats'][group_mask]))
            # scipy returns (x, y, z, w); the model stores quaternions as (w, x, y, z).
            xyzw = Rot.from_matrix(rotated.detach().cpu().numpy()).as_quat()
            params['quats'][group_mask] = torch.as_tensor(
                xyzw[:, [3, 0, 1, 2]], dtype=torch.float32, device=params['quats'].device
            )
        moved_means = H[:3, 3] + torch.matmul(H[:3, :3], params['means'].T).T
        params['means'] = torch.where(group_mask[..., None], moved_means, params['means'])

    def set_frame(self, rgb_frame:torch.Tensor):
        """
        Sets the target frame that subsequent `step` calls fit against.

        Args:
            rgb_frame: HxWxC image tensor. Both the frame and its PCA DINO features are
                resized to init_c2o's (height, width), so they line up with the render.
        """
        with torch.no_grad():
            target_hw = (self.init_c2o.height, self.init_c2o.width)
            self.rgb_frame = resize(rgb_frame.permute(2, 0, 1), target_hw).permute(1, 2, 0).contiguous()
            pca_feats = self.dino_loader.get_pca_feats(rgb_frame.permute(2, 0, 1).unsqueeze(0), keep_cuda=True).squeeze()
            self.frame_pca_feats = resize(pca_feats.permute(2, 0, 1), target_hw).permute(1, 2, 0).contiguous()
# Longest image side, in pixels, of the camera used for matching.
MATCH_RESOLUTION = 250

# Keep the training camera's other parameters, but place it at the viewer's current pose.
# NOTE(review): the pose comes from the interactive viewer, so the result depends on
# where the viewer camera was when this cell ran.
train_cam_pose, data = pipeline.datamanager.next_train(0)
view_cam_pose = pipeline.viewer_control.get_camera(200, None, 0)
train_cam_pose.camera_to_worlds = view_cam_pose.camera_to_worlds
train_cam_pose.rescale_output_resolution(
    MATCH_RESOLUTION / max(train_cam_pose.width, train_cam_pose.height)
)

# Preview the starting render.
outputs = pipeline.model.get_outputs_for_camera(train_cam_pose)
plt.imshow(outputs["rgb"].cpu().numpy())
plt.show()

# A single rigid group covering every Gaussian (despite its name, the mask is all True).
num_gaussians = pipeline.model.means.shape[0]
halfmask = torch.ones(num_gaussians, dtype=torch.bool, device='cuda')
optimizer = RigidGroupOptimizer(pipeline.model, dino_loader, train_cam_pose, [halfmask])
rgb_renders = []

In [None]:
import moviepy.editor as mpy
import tqdm
video_path = Path("garfield_move.mp4")
assert video_path.exists(), f"input video not found: {video_path.absolute()}"
start = 4   # seconds into the video where tracking begins
end = 8     # seconds into the video where tracking ends
fps = 30    # frames/second sampled from the video and written to the output clip
WARMUP_ROUNDS = 10   # optimizer.step calls on the first frame before tracking starts
WARMUP_ITERS = 15    # iterations per warm-up call
TRACK_ITERS = 10     # iterations per tracked video frame


def compose_vis_frame(render_rgb, target_rgb):
    """
    Builds one visualization frame for the output video.

    The render goes on the left. On the right is a 50/50 blend of the render with the
    target frame, resized to the render's resolution.
    Assumes both inputs are HxWxC tensors with values in [0, 1]. The result is scaled by
    255, clipped, and returned as a uint8 numpy array (moviepy casts frames to uint8
    without clipping, so out-of-range values would otherwise wrap).
    """
    h, w = render_rgb.shape[0], render_rgb.shape[1]
    target_small = resize(target_rgb.permute(2, 0, 1), (h, w)).permute(1, 2, 0)
    blended = target_small * 0.5 + render_rgb * 0.5
    side_by_side = torch.concatenate([render_rgb, blended], dim=1)
    return np.clip(side_by_side.detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)


def load_target_frame(cap, timestamp):
    """Reads the frame at `timestamp` seconds as an HxWxC float tensor (via ToTensor) on CUDA."""
    return ToTensor()(Image.fromarray(get_vid_frame(cap, timestamp))).permute(1, 2, 0).cuda()


# Start from an empty list, so re-running this cell doesn't append to a previous run's frames.
rgb_renders = []
motion_clip = cv2.VideoCapture(str(video_path.absolute()))
try:
    target_frame_rgb = load_target_frame(motion_clip, start)
    optimizer.set_frame(target_frame_rgb)
    try:
        # Warm-up: fit the first frame. `outputs` initially holds the previous cell's preview render.
        for _ in tqdm.tqdm(range(WARMUP_ROUNDS)):
            rgb_renders.append(compose_vis_frame(outputs["rgb"], target_frame_rgb))
            outputs = optimizer.step(WARMUP_ITERS)
        # Tracking: re-fit on each sampled frame between start and end.
        for t in tqdm.tqdm(np.linspace(start, end, int((end - start) * fps))):
            target_frame_rgb = load_target_frame(motion_clip, t)
            optimizer.set_frame(target_frame_rgb)
            outputs = optimizer.step(TRACK_ITERS)
            rgb_renders.append(compose_vis_frame(outputs["rgb"], target_frame_rgb))
    except RuntimeError as err:
        # The optimizer signals lost tracking with RuntimeError("Lost tracking"). Other
        # RuntimeErrors land here too, so show the actual message. Frames rendered so far
        # are still written out below.
        print(f"Tracking stopped early: {err}")
finally:
    motion_clip.release()

# Save as an mp4.
out_clip = mpy.ImageSequenceClip(rgb_renders, fps=fps)
out_clip.write_videofile("both_garfield.mp4", fps=fps)