In [1]:
import time
from collections import OrderedDict
from pathlib import Path

import torch
from lib.structures.field_list import collect

from lib import utils, logger, config, modeling, solver, data

%load_ext autoreload
%autoreload 2

## Setup

In [2]:
# Merge the experiment file into the global config before anything reads from it.
config.merge_from_file('configs/front3d_train_3d.yaml')

# Build the network and move it to the device named in the config.
device = torch.device(config.MODEL.DEVICE)
model = modeling.PanopticReconstruction()
model.to(device, non_blocking=True)

model.log_model_info()
# NOTE(review): presumably freezes part of the network -- confirm in lib.modeling.
model.fix_weights()

# Optimizer: Adam with every hyper-parameter taken from config.SOLVER.
adam_betas = (config.SOLVER.BETA_1, config.SOLVER.BETA_2)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.SOLVER.BASE_LR,
    betas=adam_betas,
    weight_decay=config.SOLVER.WEIGHT_DECAY,
)

# Learning-rate schedule; warm-up is configured with zero iterations and factor 1.
scheduler = solver.WarmupMultiStepLR(
    optimizer,
    config.SOLVER.STEPS,
    config.SOLVER.GAMMA,
    warmup_factor=1,
    warmup_iters=0,
    warmup_method="linear",
)


-----------------------------------------
unet_output_channels 16
unet_fetures 16


In [3]:
# Snapshot of the current weights plus a count of the parameters that will be optimised.
model_dict = model.state_dict()
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of Trainable Parameters: {}".format(pytorch_total_params))

# Checkpoints are written to / resumed from ./output.
output_path = Path('output')
checkpointer = utils.DetectronCheckpointer(model, optimizer, scheduler, output_path)

# Load the checkpoint (if any) and keep whatever extra data it returned.
checkpoint_data = checkpointer.load()

# Extra key/value pairs that are passed along with every checkpointer.save() call.
checkpoint_arguments = {"iteration": 0}

# Only merge the loaded data when requested and when something was actually returned;
# dict.update(None) would raise a TypeError.
if config.SOLVER.LOAD_SCHEDULER and checkpoint_data:
    checkpoint_arguments.update(checkpoint_data)

# TODO: move to checkpointer?
if config.MODEL.PRETRAIN2D:
    # map_location keeps the load working when the file was saved on a different
    # device than the one configured here (e.g. GPU checkpoint on a CPU-only run).
    # NOTE: torch.load unpickles the file -- only point PRETRAIN2D at trusted checkpoints.
    pretrain_2d = torch.load(config.MODEL.PRETRAIN2D, map_location=device)
    # This runs after checkpointer.load(), so these weights are the ones left in the model.
    model.load_state_dict(pretrain_2d["model"])

# Dataloader
dataloader = data.setup_dataloader(config.DATASETS.TRAIN)

Number of Trainable Parameters: 10367060
Number of Trainable Parameters: 10367060


## Training

In [None]:
# Training loop: one optimisation step per batch, with periodic checkpoints.
# `results` of the last processed batch stays in the notebook namespace and is
# consumed by the visualisation cells further down.
print(len(dataloader))
model.switch_training()

# Width the progress line is padded to, so that a shorter line fully overwrites
# the previous one after the carriage return.
PROGRESS_WIDTH = 60

iteration = 0

for idx, (image_ids, targets) in enumerate(dataloader):
    assert targets is not None, "error during data loading"

    # Get input images
    images = collect(targets, "color")

    # Pass through model
    # try:
    losses, results = model(images, targets)
    # except Exception as e:
    #     print(e, "skipping", image_ids[0])
    #     del targets, images
    #     continue

    # Accumulate total loss. Starts as a plain float and only becomes a tensor once
    # at least one finite loss term has been added.
    total_loss = 0.0
    log_meters = OrderedDict()

    for loss_group in losses.values():
        for loss_name, loss in loss_group.items():
            # Skip non-tensor entries and NaN / +-inf losses so that a single bad
            # term does not poison the whole update.
            if torch.is_tensor(loss) and torch.isfinite(loss):
                total_loss += loss
                log_meters[loss_name] = loss.item()

    # Loss backpropagation, optimizer & scheduler step
    optimizer.zero_grad()

    if torch.is_tensor(total_loss):
        total_loss.backward()
        optimizer.step()
        scheduler.step()
        log_meters["total"] = total_loss.item()
    else:
        # No usable loss term for this batch: nothing to back-propagate.
        log_meters["total"] = total_loss

    # Minkowski Engine recommendation
    torch.cuda.empty_cache()

    # Keep the stored iteration in sync with the loop counter; otherwise every
    # checkpoint would record the initial value regardless of its file name.
    checkpoint_arguments["iteration"] = iteration

    # Save checkpoint
    if iteration % config.SOLVER.CHECKPOINT_PERIOD == 0:
        checkpointer.save(f"model_{iteration:07d}", **checkpoint_arguments)

    last_training_stage = model.set_current_training_stage(iteration)

    # Save additional checkpoint after hierarchy level
    if last_training_stage is not None:
        checkpointer.save(f"model_{last_training_stage}_{iteration:07d}", **checkpoint_arguments)
        logger.info(f"Finish {last_training_stage} hierarchy level")

    iteration += 1

    # Live progress on a single line; every 10th iteration is kept by ending the
    # line with a newline. The float from log_meters is printed with a fixed
    # precision and padded so no stale characters remain from a longer line.
    progress = "\riteration: {}, total_loss: {:.6f}".format(iteration, log_meters["total"])
    print(progress.ljust(PROGRESS_WIDTH), end="\n" if iteration % 10 == 0 else "")

    # Debug hook: uncomment to stop after a handful of batches.
    # if idx>4:
    #     break


300


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  padding_offsets = difference // 2
  coords = coords // tensor_stride


iteration: 10, total_loss: 8.237290382385254
iteration: 11, total_loss: 8.946187973022461

  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]
  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]
  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]


iteration: 20, total_loss: 118.40057373046875
iteration: 30, total_loss: 109.09566497802734
iteration: 36, total_loss: 97.508529663085944

  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]
  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]
  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]
  predicted_coordinates[:, 1:] = predicted_coordinates[:, 1:] // prediction.tensor_stride[0]


iteration: 40, total_loss: 464.70416259765625
iteration: 50, total_loss: 280.65881347656255
iteration: 60, total_loss: 224.05183410644532
iteration: 70, total_loss: 212.21794128417978
iteration: 80, total_loss: 169.01382446289062
iteration: 90, total_loss: 163.88758850097656
iteration: 100, total_loss: 166.9730987548828
iteration: 110, total_loss: 161.50872802734375
iteration: 111, total_loss: 147.83642578125

## Get color prediction for rendering

In [None]:
# Look at what the most recent training step returned before turning it into a mesh.
frustum_results = results['frustum']
print(results.keys())
print(frustum_results.keys())

# Geometry and colour predictions of the frustum; the next cell calls .dense() on them.
geometry_sparse_prediction = frustum_results['geometry']
rgb_sparse_prediction = frustum_results['rgb']

print("geometry_sparse shape: ", geometry_sparse_prediction.shape)
print("rgb_sparse shape: ", rgb_sparse_prediction.shape)

In [None]:
from lib.structures import DepthMap
import numpy as np
from typing import Tuple
from lib.structures.frustum import compute_camera2frustum_transform


# def adjust_intrinsic(intrinsic: np.array, intrinsic_image_dim: Tuple, image_dim: Tuple) -> np.array:
#     if intrinsic_image_dim == image_dim:
#         return intrinsic

#     intrinsic_return = np.copy(intrinsic)

#     height_after = image_dim[1]
#     height_before = intrinsic_image_dim[1]

#     width_after = image_dim[0]
#     width_before = intrinsic_image_dim[0]

#     intrinsic_return[0, 0] *= float(width_after) / float(width_before)
#     intrinsic_return[1, 1] *= float(height_after) / float(height_before)

#     # account for cropping/padding here
#     intrinsic_return[0, 2] *= float(width_after - 1) / float(width_before - 1)
#     intrinsic_return[1, 2] *= float(height_after - 1) / float(height_before - 1)

#     return intrinsic_return

# Requested dense size: two leading singleton dims followed by the configured grid
# dimensions (presumably batch and channel -- TODO confirm against the sparse tensor API).
grid_dimensions = config.MODEL.FRUSTUM3D.GRID_DIMENSIONS
dense_dimensions = torch.Size([1, 1] + grid_dimensions)
min_coordinates = torch.IntTensor([0, 0, 0]).to(device)
truncation = config.MODEL.FRUSTUM3D.TRUNCATION

# Get Dense Predictions. Geometry passes the truncation value as `default_value`;
# colour uses the call's own default.
geometry_dense, _, _ = geometry_sparse_prediction.dense(
    dense_dimensions, min_coordinates, default_value=truncation
)
rgb_dense, _, _ = rgb_sparse_prediction.dense(dense_dimensions, min_coordinates)

# Drop every singleton dimension.
geometry = geometry_dense.squeeze()
rgb = rgb_dense.squeeze()

# Quick sanity report: shapes and value extrema (printed as max, then min).
rgb_max, rgb_min = torch.max(rgb), torch.min(rgb)
geometry_max, geometry_min = torch.max(geometry), torch.min(geometry)

print("input shape: ", images.shape)
print("rgb: {}".format(rgb.shape))
print("rgb values: [{},{}]".format(rgb_max, rgb_min))
print("geometry: {}".format(geometry.shape))
print("geometry values: [{},{}]".format(geometry_max, geometry_min))


# # Generate Mesh and Render
# # Prepare intrinsic matrix.
# color_image_size = (320, 240)
# depth_image_size = (160, 120)
# front3d_intrinsic = np.array(config.MODEL.PROJECTION.INTRINSIC)
# front3d_intrinsic = adjust_intrinsic(front3d_intrinsic, color_image_size, depth_image_size)
# front3d_intrinsic = torch.from_numpy(front3d_intrinsic).to(device).float()

# print('\n camera_instrinsics: ', front3d_intrinsic)
# camera2frustum = compute_camera2frustum_transform(front3d_intrinsic.cpu(), torch.tensor(images.size()) / 2.0,
#                                                       config.MODEL.PROJECTION.DEPTH_MIN,
#                                                       config.MODEL.PROJECTION.DEPTH_MAX,
#                                                       config.MODEL.PROJECTION.VOXEL_SIZE)

# camera2frustum[:3, 3] += (torch.tensor([256, 256, 256]) - torch.tensor([231, 174, 187])) / 2
# frustum2camera = torch.inverse(camera2frustum)
# print("frustum2camera: ", frustum2camera)

## Use marching cubes to generate mesh

In [None]:
import marching_cubes as mc

# Copy the dense volumes and convert them to NumPy arrays before handing them to
# marching_cubes_color. The colour volume's first axis is moved to the end.
distance_field = geometry.clone().detach().cpu().numpy()
colors = rgb.clone().permute(1, 2, 3, 0).detach().cpu().numpy()

# Each returned vertex row is split below: the first three columns are used as the
# position, the remaining columns as the vertex colour.
packed_vertices, triangles_i = mc.marching_cubes_color(distance_field, colors, 1.0, truncation)
vertices_i = packed_vertices[..., :3]
colors_i = packed_vertices[..., 3:]

# Back to torch; the colours get a leading dimension of size one.
vertices = torch.from_numpy(vertices_i.astype(np.float32))
triangles = torch.from_numpy(triangles_i.astype(np.int64))
colors_rgb = torch.from_numpy(colors_i.astype(np.float32)).unsqueeze(0)

print("vertices shape: {}".format(vertices.shape))
print("colors shape: {}".format(colors_rgb.shape))
print("triangles shape: {}".format(triangles.shape))

In [None]:
# import numpy as np

# import torch
# from torchmcubes import marching_cubes, grid_interp

# # Grid data
# N = 128
# x, y, z = np.mgrid[:N, :N, :N]
# x = (x / N).astype('float32')
# y = (y / N).astype('float32')
# z = (z / N).astype('float32')

# # Implicit function (metaball)
# f0 = (x - 0.35) ** 2 + (y - 0.35) ** 2 + (z - 0.35) ** 2
# f1 = (x - 0.65) ** 2 + (y - 0.65) ** 2 + (z - 0.65) ** 2
# u = 1.0 / f0 + 1.0 / f1
# rgb = np.stack((x, y, z), axis=-1)
# rgb = np.transpose(rgb, axes=(3, 2, 1, 0)).copy()
# print(rgb.shape)
# print(u.shape)
# # Test
# u = torch.from_numpy(u).cuda()
# rgb = torch.from_numpy(rgb).cuda()
# verts, faces = marching_cubes(u, 15.0)
# colros = grid_interp(rgb, verts)

# verts = verts.cpu().numpy()
# faces = faces.cpu().numpy()
# colors = colors.cpu().numpy()

In [None]:
# from pytorch3d.ops import cubify
# meshes = cubify(geometry.unsqueeze(0), thresh=1.0)
# print(type(meshes[0]))
from torchmcubes import marching_cubes, grid_interp
# Run torchmcubes' marching cubes on the dense geometry volume with threshold 1.0,
# then look up a colour for every resulting vertex from the dense colour volume.
iso_level = 1.0
verts, faces = marching_cubes(geometry, iso_level)
colors = grid_interp(rgb, verts)

In [None]:
# Min-max normalise the vertex colours so they can be used directly as RGB values.
color_min = torch.min(colors)
color_span = torch.max(colors) - color_min
# Constant colours give a zero span; clamp the denominator so the result is 0
# instead of NaN in that case. For any non-degenerate span the values are unchanged.
colors = (colors - color_min) / color_span.clamp_min(torch.finfo(colors.dtype).eps)

print("verts shape: ", verts.shape)
print("faces shape: ", faces.shape)
print("colors shape: ", colors.shape)
# Report the range in the conventional [min, max] order.
print("colors range: [{}, {}] ".format(torch.min(colors), torch.max(colors)))


## Render Views

In [None]:
import os
import sys
import torch
import pytorch3d
# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, load_obj, load_ply

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import Textures

from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    look_at_rotation,
    FoVPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    HardPhongShader,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    OpenGLPerspectiveCameras, 
    
)

# Rendering device: first CUDA device when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Per-vertex colour texture. TexturesVertex replaces the deprecated generic
# `Textures(verts_rgb=...)` wrapper; it expects a leading batch dimension.
# To render the mesh from the `mc.marching_cubes_color` cell instead, build it from
# `vertices`, `triangles` and `colors_rgb` (which already has the batch dimension).
tex = TexturesVertex(verts_features=colors.unsqueeze(0))
mesh = Meshes(verts=[verts], faces=[faces], textures=tex).to(device)

# Centre the mesh at the origin and scale it so its largest absolute coordinate
# becomes 1. A separate name is used for the packed vertices so that the `verts`
# produced by the marching-cubes cell is not overwritten (keeps this cell re-runnable).
mesh_verts = mesh.verts_packed()
center = mesh_verts.mean(0)
scale = (mesh_verts - center).abs().max()
mesh.offset_verts_(-center)
# NOTE(review): the negative factor also mirrors the mesh through the origin --
# presumably intentional to flip the viewing orientation; confirm.
mesh.scale_verts_(-1.0 / float(scale));

In [None]:
# Multiple view rendering
from plot_image_grid import image_grid
# Number of different viewpoints from which the mesh is rendered.
num_views = 20

# One elevation / azimuth pair per viewpoint.
elev = torch.linspace(130, 200, num_views)
azim = torch.linspace(0, 360, num_views)

# Single point light at (0, 0, -3).
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Camera extrinsics for all viewpoints at distance 1.5 from the origin; the
# elevation and azimuth tensors give one rotation / translation per view.
R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim)

# FoVPerspectiveCameras is the current name of the deprecated
# OpenGLPerspectiveCameras alias.
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

# One arbitrarily chosen view (index 1) used to construct the renderer; the full
# camera batch is passed explicitly at render time below.
camera = FoVPerspectiveCameras(device=device, R=R[None, 1, ...],
                               T=T[None, 1, ...])

# Rasterization settings for visualisation only: 128x128 output, no blur, one face
# per pixel. bin_size and max_faces_per_bin are left at their defaults.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=1,
)

# Phong renderer composed of a rasterizer and a hard Phong shader; the mesh carries
# per-vertex colours (no UV texture image is involved).
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(
        device=device,
        cameras=camera,
        lights=lights
    )
)

# Repeat the mesh (and its textures) once per viewpoint.
meshes = mesh.extend(num_views)

# Render the mesh from each viewing angle.
target_images = renderer(meshes, cameras=cameras, lights=lights)

# Per-view RGB images (first three channels) and the matching single-view cameras,
# each a list of length num_views.
target_rgb = [target_images[i, ..., :3] for i in range(num_views)]
target_cameras = [FoVPerspectiveCameras(device=device, R=R[None, i, ...],
                                        T=T[None, i, ...]) for i in range(num_views)]



In [None]:
import matplotlib.pyplot as plt
# Show all rendered views on a grid. The row count is derived from the number of
# rendered images so that every view gets a slot even if num_views is changed
# (20 views -> the same 4 x 5 layout as before).
grid_cols = 5
grid_rows = -(-len(target_images) // grid_cols)  # ceiling division
# detach() is a no-op for tensors without gradients and keeps .numpy() from
# failing if the rendered images ever carry gradient information.
image_grid(target_images.detach().cpu().numpy(), rows=grid_rows, cols=grid_cols, rgb=True, show_axes=True)
plt.show()

# Deprecated stuff

In [None]:
# from lib.data import samplers, datasets, collate
# from lib.utils.imports import import_file
# from torch.utils import data

# def build_dataset(dataset_name) -> data.Dataset:
#     paths_catalog = import_file("lib.config.paths_catalog", config.PATHS_CATALOG, True)
#     dataset_catalog = paths_catalog.DatasetCatalog
#     print("dataset_catalog: ", dataset_catalog.get(dataset_name))
#     info = dataset_catalog.get(dataset_name)
#     factory = getattr(datasets, info.pop("factory"))
#     info["fields"] = config.DATASETS.FIELDS

#     # make dataset from factory
#     dataset = factory(**info)

#     return dataset

# dataset = build_dataset(config.DATASETS.TRAIN)

In [None]:
# print(len(dataset))
# print(dataset[0][1].get_field("color").shape)
# print(dataset[0][1].fields())