In [1]:
from gs_renderer_4d import Renderer, MiniCam
from dataset_4d import SparseDataset
import os
import tqdm
import numpy as np
import torch

from cam_utils import orbit_camera, OrbitCamera
from guidance.sd_utils import StableDiffusion


class trainer:
    """Optimise a time-dependent Gaussian-splatting model under 2D diffusion guidance.

    ``opt`` is the run configuration (loaded from YAML in the notebook). The fields read
    directly by this class are: ``device``, ``sh_degree``, ``num_pts``, ``W``, ``H``,
    ``radius``, ``fovy``, ``prompt``, ``ref_size``, ``batch_size``, ``elevation``,
    ``invert_bg_prob``, ``density_start_iter``, ``density_end_iter``,
    ``densification_interval`` and ``densify_grad_threshold``.
    ``gaussians.training_setup(opt)`` may read further fields.
    """

    def __init__(self, opt) -> None:
        """Build renderer, Gaussians, optimizer, orbit camera and text-conditioned guidance.

        Args:
            opt: configuration object; see the class docstring for the fields read.
        """
        # initialize options
        self.opt = opt
        self.device = self.opt.device

        # initialize renderer and gaussians
        self.renderer = Renderer(sh_degree=self.opt.sh_degree)
        self.renderer.initialize(num_pts=self.opt.num_pts)
        self.renderer.gaussians.training_setup(self.opt)

        # The optimizer is taken from the gaussians after training_setup() has run.
        self.optimizer = self.renderer.gaussians.optimizer

        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        # initialize sd. replace with your own diffusion model if necessary.
        # If enable_sd is set to False, train() expects the caller to attach a
        # `guidance_zero123` object; none is created here.
        self.enable_sd = True
        self.guidance_sd = StableDiffusion(self.device)
        self.guidance_sd.get_text_embeds([self.opt.prompt], negative_prompts=[''])

    def save(self, save_path):
        """Save the model into ``save_path`` (created if missing).

        Writes the Gaussians to ``<save_path>/model.ply`` and calls
        ``save_deformation(save_path)`` for the deformation state.
        """
        auto_path = save_path
        os.makedirs(auto_path, exist_ok=True)
        ply_path = os.path.join(auto_path, 'model.ply')
        self.renderer.gaussians.save_ply(ply_path)
        self.renderer.gaussians.save_deformation(auto_path)

    def load(self, load_path):
        """Load a model previously written by ``save``.

        Calls ``load_model(load_path)`` (presumably the counterpart of
        ``save_deformation`` — TODO confirm), then loads ``<load_path>/model.ply``.
        """
        auto_path = load_path
        ply_path = os.path.join(auto_path, 'model.ply')
        self.renderer.gaussians.load_model(auto_path)
        self.renderer.gaussians.load_ply(ply_path)

    def render(self, frame_id, elevation, azimuth, radius, num_frames=30):
        """Render frame ``frame_id`` from an orbit viewpoint at ``opt.ref_size`` resolution.

        Args:
            frame_id: frame index; the camera time is ``frame_id / num_frames``.
            elevation: orbit-camera elevation passed to ``orbit_camera``.
            azimuth: orbit-camera azimuth passed to ``orbit_camera``.
            radius: orbit-camera radius passed to ``orbit_camera``.
            num_frames: number of frames in the sequence (default 30, as in ``train``).

        Returns:
            The rendered image with a leading batch dimension of 1
            (``[1, 3, H, W]`` in ``[0, 1]`` per the renderer's output convention).
        """
        pose = orbit_camera(elevation, azimuth, radius)
        cam = MiniCam(
            pose,
            self.opt.ref_size,
            self.opt.ref_size,
            self.cam.fovy,
            self.cam.fovx,
            self.cam.near,
            self.cam.far,
        )
        cam.time = float(frame_id / num_frames)
        # use stage='coarse' for static rendering, use stage='fine' for dynamic rendering
        out = self.renderer.render(cam, stage='fine')
        return out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]

    @staticmethod
    def _merge_view_stats(viewspace_point_tensor_list, radii_list, visibility_filter_list):
        """Combine per-view densification statistics of one batch.

        Args:
            viewspace_point_tensor_list: per-view screen-space point tensors after
                ``backward()``; a view whose ``.grad`` is None contributes zero.
            radii_list: per-view radii, each with a leading dimension of 1.
            visibility_filter_list: per-view boolean masks, each with a leading dimension of 1.

        Returns:
            ``(grad_sum, radii, visibility_filter)``: the element-wise sum of the view
            gradients, the per-Gaussian maximum radius over views, and a mask that is True
            where the Gaussian was visible in at least one view.
        """
        grad_sum = torch.zeros_like(viewspace_point_tensor_list[0])
        for view_points in viewspace_point_tensor_list:
            if view_points.grad is not None:
                grad_sum = grad_sum + view_points.grad
        radii = torch.cat(radii_list, dim=0).max(dim=0).values
        visibility_filter = torch.cat(visibility_filter_list, dim=0).any(dim=0)
        return grad_sum, radii, visibility_filter

    def train(self, iters=10000, num_frames=30):
        """Run ``iters`` guidance-driven optimisation steps with adaptive densification.

        Each step renders ``opt.batch_size`` random views and sums the guidance loss over
        them. Views use an elevation of ``opt.elevation`` plus an offset in [-30, 30), an
        azimuth in [-180, 180), and a frame index in [0, num_frames). The step then takes
        one optimizer step. Inside [density_start_iter, density_end_iter] it updates the
        densification statistics from all views of the batch, and densifies/prunes every
        ``opt.densification_interval`` steps.

        Args:
            iters: number of optimisation steps (default 10000).
            num_frames: number of frames the render time is sampled from (default 30).

        Raises:
            AttributeError: if ``enable_sd`` is False and no ``guidance_zero123`` is attached.
        """
        self.step = 0

        # sampling range of the elevation offset and the training render size
        min_ver = -30
        max_ver = 30
        render_resolution = 512

        for _ in tqdm.tqdm(range(iters)):
            self.step += 1
            self.renderer.gaussians.update_learning_rate(self.step)
            loss = 0

            vers, hors, radiis, poses = [], [], [], []
            images = []
            viewspace_point_tensor_list, radii_list, visibility_filter_list = [], [], []

            for _view in range(self.opt.batch_size):
                # sample time, vertical & horizontal angle (radius offset is fixed at 0)
                ver = np.random.randint(min_ver, max_ver)
                hor = np.random.randint(-180, 180)
                radius = 0
                self.t = np.random.randint(0, num_frames)
                self.time = self.t / num_frames

                vers.append(torch.tensor(self.opt.elevation + ver, device=self.device).unsqueeze(dim=0))
                hors.append(torch.tensor(hor, device=self.device).unsqueeze(dim=0))
                radiis.append(torch.tensor(self.opt.radius + radius, device=self.device).unsqueeze(dim=0))

                pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
                poses.append(pose)

                cur_cam = MiniCam(
                    pose,
                    render_resolution,
                    render_resolution,
                    self.cam.fovy,
                    self.cam.fovx,
                    self.cam.near,
                    self.cam.far,
                )
                cur_cam.time = self.time

                # white background, switched to black with probability opt.invert_bg_prob
                bg_color = torch.tensor(
                    [1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0],
                    dtype=torch.float32,
                    device=self.device,
                )
                # use stage='coarse' for static rendering, use stage='fine' for dynamic rendering
                out = self.renderer.render(cur_cam, bg_color=bg_color, stage='fine')

                # per-view values needed for densification
                radii_list.append(out["radii"].unsqueeze(0))
                visibility_filter_list.append(out["visibility_filter"].unsqueeze(0))
                viewspace_point_tensor_list.append(out["viewspace_points"])

                images.append(out["image"].unsqueeze(0))  # [1, 3, H, W] in [0, 1]

            images_batch = torch.cat(images, dim=0)
            poses_batch = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)

            if self.enable_sd:
                # guidance loss. replace with your own diffusion model if necessary.
                sd_loss = self.guidance_sd.train_step(images_batch, step_ratio=None, poses=poses_batch)
                loss = loss + sd_loss
            else:
                guidance_zero123 = getattr(self, 'guidance_zero123', None)
                if guidance_zero123 is None:
                    raise AttributeError(
                        "enable_sd is False but no guidance_zero123 model is attached; "
                        "set trainer.guidance_zero123 before calling train()"
                    )
                vers_batch = torch.cat(vers, dim=0).cpu().numpy()
                hors_batch = torch.cat(hors, dim=0).cpu().numpy()
                radii_batch = torch.cat(radiis, dim=0).cpu().numpy()
                # guidance loss.
                zero123_loss = guidance_zero123.train_step(
                    images_batch, vers_batch, hors_batch, radii_batch, step_ratio=None
                )
                loss = loss + zero123_loss

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # densifications. Adaptive densification is used here.
            if self.opt.density_start_iter <= self.step <= self.opt.density_end_iter:
                # Statistics must cover every view of the batch: gradients are summed,
                # radii take the per-Gaussian max, and a Gaussian counts as visible if
                # any view saw it.
                viewspace_point_tensor_grad, radii, visibility_filter = self._merge_view_stats(
                    viewspace_point_tensor_list, radii_list, visibility_filter_list
                )
                gaussians = self.renderer.gaussians
                gaussians.max_radii2D[visibility_filter] = torch.max(
                    gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
                )
                gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)
                if self.step % self.opt.densification_interval == 0:
                    # NOTE(review): min_opacity / extent / max_screen_size are unexplained
                    # constants carried over unchanged; consider moving them into opt.
                    gaussians.densify_and_prune(
                        self.opt.densify_grad_threshold, min_opacity=0.01, extent=1, max_screen_size=2
                    )


In [2]:
from omegaconf import OmegaConf

# Load the run configuration, build the trainer from it and run the optimisation loop.
CONFIG_PATH = './configs/image_4d_m.yaml'
opt = OmegaConf.load(CONFIG_PATH)

train = trainer(opt)
train.train()

feature_dim: 128
Number of points at initialisation :  10000


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  1%|          | 72/10000 [00:19<44:48,  3.69it/s]  


In [1]:
# Scratch check of the degree -> radian conversion: 0 degrees is 0 radians.
import numpy as np

np.radians(0)

0.0