In [8]:
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build and train the NeRF model."""

import os
import time

import mindspore as md
import numpy as np
from engine import RendererWithCriterion, test_net, train_net
from tqdm import tqdm

from data.load_llff import load_llff_data
from models import VolumeRenderer
from utils.config import get_config
from utils.engine_utils import context_setup, create_nerf
from utils.ray import generate_rays
from utils.results_handler import save_image, save_video
from utils.sampler import sample_grid_2d


def _load_dataset(config):
    """Load the dataset selected by ``config.dataset_type`` and derive its splits and depth bounds.

    Args:
        config: Parsed configuration object (from ``get_config``).

    Returns:
        ``(images, poses, hwf, i_train, i_val, i_test, near, far)`` where the three split index
        collections are 1-D numpy arrays, or ``None`` if ``config.dataset_type`` is unsupported.
    """
    if config.dataset_type == "blender":
        # Local import: the blender loader is only needed for this dataset type.
        from data.load_blender import load_blender_data  # pylint: disable=import-outside-toplevel

        images, poses, render_poses, hwf, i_split = load_blender_data(config.data_dir, config.half_res,
                                                                      config.test_skip)
        print("Loaded blender", images.shape, render_poses.shape, hwf, config.data_dir)
        i_train, i_val, i_test = i_split
        near = 2.0
        far = 6.0

        if config.white_bkgd:
            # Alpha-composite onto a white background: rgb * alpha + (1 - alpha).
            images = images[..., :3] * images[..., -1:] + (1.0 - images[..., -1:])
        else:
            images = images[..., :3]

    elif config.dataset_type == "llff":
        images, poses, bds, render_poses, i_test = load_llff_data(
            config.data_dir,
            config.factor,
            recenter=True,
            bd_factor=0.75,
            spherify=config.spherify,
        )
        # (H, W, focal) is read from the last column of the first pose; only the 3x4 part is kept as pose.
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print("Loaded llff", images.shape, render_poses.shape, hwf, config.data_dir)

        if config.llff_hold > 0:
            print("Auto LLFF holdout,", config.llff_hold)
            i_test = np.arange(images.shape[0])[::config.llff_hold]
        # The loader may return a scalar index; normalise to a 1-D array so `.tolist()` and `in` work.
        i_test = np.asarray(i_test).reshape(-1)

        i_val = i_test
        # Every view that is neither a test nor a validation view is used for training.
        i_train = np.setdiff1d(np.arange(int(images.shape[0])), np.concatenate([i_test, i_val]))

        print("DEFINING BOUNDS")
        # NOTE(review): this unconditionally overrides the configured value, so the NDC branch below
        # is unreachable; presumably NDC rays are not supported by this implementation -- confirm.
        config.no_ndc = True
        if config.no_ndc:
            near = float(np.min(bds)) * 0.9
            far = float(np.max(bds)) * 1.0
        else:
            near = 0.0
            far = 1.0
        print("NEAR FAR", near, far)

    else:
        print("Unknown dataset type", config.dataset_type, "exiting")
        return None

    i_train, i_val, i_test = (np.asarray(idx).reshape(-1) for idx in (i_train, i_val, i_test))
    return images, poses, hwf, i_train, i_val, i_test, near, far


def _write_config(config, out_dir):
    """Record every attribute of ``config`` as ``name = value`` lines in ``out_dir/configs.txt``."""
    with open(os.path.join(out_dir, "configs.txt"), "w+", encoding="utf-8") as config_f:
        config_f.writelines(f"{k} = {v}\n" for k, v in vars(config).items())


def train_pipeline(config, out_dir):
    """Train a NeRF model: data preparation, model and optimizer preparation, and model training.

    Args:
        config: Parsed configuration object (from ``get_config``).
        out_dir: Directory for the recorded configuration, checkpoints and test renders;
            created if missing.

    Returns:
        None. Returns early if the dataset type is unsupported or if the starting iteration
        reported by ``create_nerf`` already reaches ``config.cap_n_iters``.
    """
    md.set_seed(1)

    print(">>> Loading dataset")
    dataset = _load_dataset(config)
    if dataset is None:
        return
    images, poses, hwf, i_train, i_val, i_test, near, far = dataset

    print(f"TRAIN views: {i_train}\nTEST views: {i_test}\nVAL views: {i_val}")

    # Cast intrinsics to the right types
    cap_h, cap_w, focal = hwf
    cap_h, cap_w = int(cap_h), int(cap_w)

    # Set up the results directory and record the configuration actually used
    print(">>> Saving checkpoints and results in", out_dir)
    os.makedirs(out_dir, exist_ok=True)
    _write_config(config, out_dir)

    print(">>> Creating models")
    # The optimizer returned by create_nerf is discarded; an Adam optimizer over the renderer's
    # trainable parameters is built below instead.
    start_iter, _, model_coarse, model_fine, embed_fn, embed_dirs_fn = create_nerf(config, out_dir)
    global_steps = start_iter

    renderer = VolumeRenderer(
        config.chunk,
        config.cap_n_samples,
        config.cap_n_importance,
        config.net_chunk,
        config.white_bkgd,
        model_coarse,
        model_fine,
        embed_fn,
        embed_dirs_fn,
        near,
        far,
    )

    renderer_with_criterion = RendererWithCriterion(renderer)
    optimizer = md.nn.Adam(
        params=renderer.trainable_params(),
        learning_rate=config.l_rate,
        beta1=0.9,
        beta2=0.999,
    )

    train_renderer = md.nn.TrainOneStepCell(renderer_with_criterion, optimizer)
    train_renderer.set_train()

    print(">>> Start training")

    cap_n_rand = config.cap_n_rand

    # Wrap training data as MindSpore tensors (placed on the device configured by context_setup)
    images = md.Tensor(images)
    poses = md.Tensor(poses)

    # Nothing left to do if the restored checkpoint already reached the iteration budget
    cap_n_iters = config.cap_n_iters
    if start_iter >= cap_n_iters:
        return

    train_model(config, out_dir, images, poses, i_train, i_test, cap_h, cap_w, focal, start_iter, optimizer,
                global_steps, renderer, train_renderer, cap_n_rand, cap_n_iters)


def train_model(config, out_dir, images, poses, i_train, i_test, cap_h, cap_w, focal, start_iter, optimizer,
                global_steps, renderer, train_renderer, cap_n_rand, cap_n_iters):
    """Optimise the renderer from ``start_iter`` until ``cap_n_iters`` iterations have been run in total.

    Each iteration does the following:

    * picks a random training view;
    * builds a batch of rays: ``cap_n_rand`` randomly sampled pixels, or every pixel when
      ``cap_n_rand`` is None;
    * performs one training step via ``train_net``.

    Every ``config.i_ckpt`` steps the renderer is checkpointed to ``out_dir``. Every
    ``config.i_testset`` steps the test views are rendered and saved. A non-positive interval
    disables the corresponding action.

    Args:
        config: Parsed configuration object.
        out_dir: Directory receiving checkpoints and test renders.
        images: Tensor of images, indexed by view.
        poses: Tensor of camera poses, indexed by view.
        i_train: Indices of the training views.
        i_test: Indices of the test views (array, list or scalar).
        cap_h: Image height.
        cap_w: Image width.
        focal: Focal length.
        start_iter: Iteration the run resumes from (0 for a fresh run).
        optimizer: Optimizer passed through to ``train_net``.
        global_steps: Step counter at the start of the loop (the caller passes ``start_iter``).
        renderer: Volume renderer that is checkpointed and used for testing.
        train_renderer: Training cell performing one optimisation step.
        cap_n_rand: Number of rays sampled per step, or None to use every pixel.
        cap_n_iters: Total number of iterations to reach.
    """
    # Plain list of test indices, usable for tensor indexing whatever type the caller passed.
    test_idx = np.asarray(i_test).reshape(-1).tolist()
    ckpt_every = config.i_ckpt
    test_every = config.i_testset
    stack = md.ops.Stack(axis=0)

    # Iterate only over the remaining steps. `initial`/`total` keep the bar aligned for resumed runs,
    # and tqdm advances the bar itself while iterating, so no manual `update()` is needed.
    with tqdm(range(start_iter, cap_n_iters), initial=start_iter, total=cap_n_iters) as p_bar:
        for _ in p_bar:
            step = global_steps + 1
            p_bar.set_description(f"Iter {step:d}")

            # Start time of the current iteration
            time_0 = time.time()

            img_i = int(np.random.choice(i_train))
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            rays_o, rays_d = generate_rays(cap_h, cap_w, focal, md.Tensor(pose))  # (cap_h, cap_w, 3) each
            if cap_n_rand is not None:
                sampled_rows, sampled_cols = sample_grid_2d(cap_h, cap_w, cap_n_rand)
                rays_o = rays_o[sampled_rows, sampled_cols]  # (cap_n_rand, 3)
                rays_d = rays_d[sampled_rows, sampled_cols]  # (cap_n_rand, 3)
                target_s = target[sampled_rows, sampled_cols]  # (cap_n_rand, 3)
            else:
                # No ray sub-sampling requested: train on every pixel of the view.
                rays_o = rays_o.reshape((-1, 3))
                rays_d = rays_d.reshape((-1, 3))
                target_s = target.reshape((-1, target.shape[-1]))
            batch_rays = stack([rays_o, rays_d])

            loss, psnr = train_net(config, global_steps, train_renderer, optimizer, batch_rays, target_s)

            p_bar.set_postfix(time=time.time() - time_0, loss=loss, psnr=psnr)

            # Save training states
            if ckpt_every > 0 and step % ckpt_every == 0:
                path = os.path.join(out_dir, f"{step:06d}.tar")
                # NOTE(review): the file is named after `step` but stores `global_steps` (= step - 1);
                # confirm this matches how create_nerf restores the starting iteration.
                md.save_checkpoint(
                    save_obj=renderer,
                    ckpt_file_name=path,
                    append_dict={"global_steps": global_steps},
                    async_save=True,
                )
                p_bar.write(f"Saved checkpoints at {path}")

            # Save testing results
            if test_every > 0 and step % test_every == 0:
                test_save_dir = os.path.join(out_dir, f"test_{step:06d}")
                os.makedirs(test_save_dir, exist_ok=True)

                p_bar.write(f"Testing (iter={step}):")

                test_time, test_loss, test_psnr = test_net(
                    cap_h,
                    cap_w,
                    focal,
                    renderer,
                    md.Tensor(poses[test_idx]),
                    images[test_idx],
                    # Bind this iteration's values as defaults so the callbacks never see later values.
                    on_progress=lambda j, img, save_dir=test_save_dir: save_image(j, img, save_dir),
                    on_complete=lambda imgs, it=step, save_dir=test_save_dir: save_video(it, imgs, save_dir),
                )

                p_bar.write(
                    f"Testing results: [ Mean Time: {test_time:.4f}s, Loss: {test_loss:.4f}, PSNR: {test_psnr:.4f} ]")

            global_steps += 1


def latest_experiment_number(base_dir, exp_name):
    """Return the largest ``N`` among entries of ``base_dir`` named exactly ``f"{exp_name}_{N}"``.

    Only entries that start with ``exp_name + "_"`` and continue with a purely decimal suffix are
    counted. Other experiments whose names merely contain ``exp_name`` are therefore ignored, as are
    non-numeric suffixes. Returns 0 when no entry matches.
    """
    prefix = exp_name + "_"
    numbers = [
        int(fn[len(prefix):])
        for fn in os.listdir(base_dir)
        if fn.startswith(prefix) and fn[len(prefix):].isdecimal()
    ]
    return max(numbers, default=0)


def main_():
    """Parse the configuration, set up the execution context, choose the output directory and train."""
    config = get_config()

    # Device / execution-mode setup; `config.mode` names an attribute of `mindspore.context`.
    context_setup(config.gpu, config.device, getattr(md.context, config.mode))

    # Root directory holding all experiments
    base_dir = config.base_dir
    os.makedirs(base_dir, exist_ok=True)

    # Experiments are stored as `<dataset_type>_<name>_<number>`. Reuse the latest number
    # (so a checkpoint can be reloaded) unless `no_reload` asks for a fresh directory.
    exp_name = config.dataset_type + "_" + config.name
    exp_num = latest_experiment_number(base_dir, exp_name)
    if config.no_reload:
        exp_num += 1

    out_dir = os.path.join(base_dir, f"{exp_name}_{exp_num}")

    train_pipeline(config, out_dir)

# Run training only when executed directly (script or notebook cell), not when imported as a module.
# NOTE(review): in the captured cell output, argument parsing exits with status 2. This is presumably
# because the required `--name` was missing or the kernel's own argv was passed — confirm how
# get_config reads its arguments when running inside Jupyter.
if __name__ == "__main__":
    main_()


usage: __main__.py [-h] [--config CONFIG] --name NAME [--base_dir BASE_DIR]
                   [--data_dir DATA_DIR] [--cap_n_iters CAP_N_ITERS]
                   [--net_depth NET_DEPTH] [--net_width NET_WIDTH]
                   [--net_depth_fine NET_DEPTH_FINE]
                   [--net_width_fine NET_WIDTH_FINE] [--cap_n_rand CAP_N_RAND]
                   [--l_rate L_RATE] [--l_rate_decay L_RATE_DECAY]
                   [--chunk CHUNK] [--net_chunk NET_CHUNK] [--no_batching]
                   [--no_reload] [--gpu GPU] [--device {GPU,CPU,Ascend}]
                   [--cap_n_samples CAP_N_SAMPLES]
                   [--cap_n_importance CAP_N_IMPORTANCE] [--perturb PERTURB]
                   [--use_view_dirs] [--i_embed I_EMBED]
                   [--multi_res MULTI_RES] [--multi_res_views MULTI_RES_VIEWS]
                   [--raw_noise_std RAW_NOISE_STD] [--render_only]
                   [--render_test] [--render_factor RENDER_FACTOR]
                   [--dataset_type DATASET_

SystemExit: 2