In [None]:

import torch
import numpy as np
import trimesh

from scipy.spatial.transform import Rotation

from dust3r.inference import inference, load_model
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images, rgb
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

import matplotlib.pyplot as pl
pl.ion()  # turn on matplotlib's interactive mode

torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
batch_size = 1  # assigned again, to the same value, in the model-loading cell
import pandas as pd
import open3d as o3d

# NOTE(review): later cells use `Image`, `OneFormerProcessor`,
# `OneFormerForUniversalSegmentation` and `F`, none of which is imported in any
# visible cell, so they raise NameError on a fresh kernel. They presumably come
# from PIL, transformers and torch.nn.functional -- confirm and import them here.

In [None]:
# Colour table used by the point-colouring cell further down: that cell takes row i
# (positional index) for class id i and parses its "Color_Code (R,G,B)" string.
# NOTE(review): hard-coded absolute path -- presumably the ADE20K palette; confirm.
ade20K = pd.read_csv(
    "/scratch2/yuxili/interiorDesign/color_coding_semantic_segmentation_classes - Sheet1.csv"
)

In [None]:
# Class id -> class name, ids 0..149 in list order (the trailing comment on each
# row gives the id range of that row). The strings are kept exactly as they were,
# including the trailing space in "bed ".
# NOTE(review): no other visible cell reads `label_dict`.
label_dict = dict(
    enumerate(
        [
            "wall", "building", "sky", "floor", "tree",  # 0-4
            "ceiling", "road", "bed ", "windowpane", "grass",  # 5-9
            "cabinet", "sidewalk", "person", "earth", "door",  # 10-14
            "table", "mountain", "plant", "curtain", "chair",  # 15-19
            "car", "water", "painting", "sofa", "shelf",  # 20-24
            "house", "sea", "mirror", "rug", "field",  # 25-29
            "armchair", "seat", "fence", "desk", "rock",  # 30-34
            "wardrobe", "lamp", "bathtub", "railing", "cushion",  # 35-39
            "base", "box", "column", "signboard", "chest of drawers",  # 40-44
            "counter", "sand", "sink", "skyscraper", "fireplace",  # 45-49
            "refrigerator", "grandstand", "path", "stairs", "runway",  # 50-54
            "case", "pool table", "pillow", "screen door", "stairway",  # 55-59
            "river", "bridge", "bookcase", "blind", "coffee table",  # 60-64
            "toilet", "flower", "book", "hill", "bench",  # 65-69
            "countertop", "stove", "palm", "kitchen island", "computer",  # 70-74
            "swivel chair", "boat", "bar", "arcade machine", "hovel",  # 75-79
            "bus", "towel", "light", "truck", "tower",  # 80-84
            "chandelier", "awning", "streetlight", "booth", "television receiver",  # 85-89
            "airplane", "dirt track", "apparel", "pole", "land",  # 90-94
            "bannister", "escalator", "ottoman", "bottle", "buffet",  # 95-99
            "poster", "stage", "van", "ship", "fountain",  # 100-104
            "conveyer belt", "canopy", "washer", "plaything", "swimming pool",  # 105-109
            "stool", "barrel", "basket", "waterfall", "tent",  # 110-114
            "bag", "minibike", "cradle", "oven", "ball",  # 115-119
            "food", "step", "tank", "trade name", "microwave",  # 120-124
            "pot", "animal", "bicycle", "lake", "dishwasher",  # 125-129
            "screen", "blanket", "sculpture", "hood", "sconce",  # 130-134
            "vase", "traffic light", "tray", "ashcan", "fan",  # 135-139
            "pier", "crt screen", "plate", "monitor", "bulletin board",  # 140-144
            "shower", "radiator", "glass", "clock", "flag",  # 145-149
        ]
    )
)

In [None]:
def _convert_scene_output_to_glb(outdir, out_name, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
    """Export the masked, coloured points of all views as a single point-cloud file.

    For every view the points selected by that view's mask are paired with the
    image values at the same positions, concatenated across views, wrapped in a
    `trimesh.PointCloud` and exported. The output format is whatever trimesh
    derives from `out_name`; despite the function name nothing forces ".glb".

    Args:
        outdir: directory to write into; created if it does not exist.
        out_name: file name (with extension) inside `outdir`.
        imgs: per-view images, indexable by the per-view masks.
        pts3d: per-view 3D points, indexable by the per-view masks.
        mask: per-view boolean masks selecting the points to keep.
        focals, cams2world: only their lengths are checked here.
        cam_size, cam_color, as_pointcloud, transparent_cams: accepted for
            interface compatibility; this implementation does not use them
            (no cameras or meshes are added, only the point cloud is written).

    Returns:
        The path of the exported file.
    """
    import os  # local import: the notebook's import cell does not bring in `os`

    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)

    # keep only the masked points of each view, with the matching image values
    pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
    col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
    pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))

    # Join with the OS separator so `outdir` works with or without a trailing
    # slash, and make sure the directory exists before exporting.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    outfile = os.path.join(outdir, out_name)
    pct.export(outfile)
    return outfile

def get_3D_model_from_scene(outdir, out_name, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    Gather images, focals, poses, 3D points and confidence masks from a
    reconstructed scene and hand them to `_convert_scene_output_to_glb`.

    Returns None when `scene` is None; otherwise returns whatever
    `_convert_scene_output_to_glb` returns.
    """
    if scene is None:
        return None

    # optional post-processing; each step's return value replaces `scene`
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # optimized values, moved to CPU
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()

    # 3D points from depthmaps, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())

    # the threshold is written to the scene *before* the masks are read back
    threshold = torch.tensor(min_conf_thr)
    scene.min_conf_thr = float(scene.conf_trf(threshold))
    msk = to_numpy(scene.get_masks())

    export_options = dict(
        as_pointcloud=as_pointcloud,
        transparent_cams=transparent_cams,
        cam_size=cam_size,
    )
    return _convert_scene_output_to_glb(
        outdir, out_name, rgbimg, pts3d, msk, focals, cams2world, **export_options
    )


In [None]:
# ---- reconstruction settings ----
model_path = "/scratch2/yuxili/interiorDesign/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
image_path = "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/room.png"
device = "cuda"
batch_size = 1
# global-alignment optimisation settings
schedule = "cosine"
lr = 0.01
niter = 300

model = load_model(model_path, device)

# load_images can take a list of images or a directory.
# The same image is passed twice here. Alternative inputs that were tried before:
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/0.png",
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/1.png",
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/3.png",
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/4.png",
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/6.png",
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/7.png",
#   "/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/inpaint/8.png",
images = load_images([image_path, image_path], size=512)

# pairwise inference over the complete, symmetrized pair graph
pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
output = inference(pairs, model, device, batch_size=batch_size)

# global alignment of the pairwise predictions
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

In [None]:
# Embed the intrinsics of view 0 and view 1 in the top-left 3x3 corner of a 4x4
# matrix whose remaining entries are those of the identity (so element [3, 3] is 1
# and the rest of the last row/column is 0).
intrinsics1 = np.eye(4)
intrinsics1[:3, :3] = scene.get_intrinsics()[0].detach().cpu().numpy()

intrinsics2 = np.eye(4)
intrinsics2[:3, :3] = scene.get_intrinsics()[1].detach().cpu().numpy()

# element-wise mean of the two matrices
intrinsics = 0.5 * (intrinsics1 + intrinsics2)

In [None]:
# retrieve useful values from scene:
imgs = scene.imgs
focals = scene.get_focals()
poses = scene.get_im_poses()  # NOTE: `poses` is rebound to np.eye(4) in the projection cell
pts3d = scene.get_pts3d()
confidence_masks = scene.get_masks()
cams2world = scene.get_im_poses()  # same getter as `poses` above

In [None]:
# Export the reconstructed scene as a point cloud named "test_render_2_inpaint.ply"
# under "../../point_cloud/". Depth cleaning is enabled, sky masking is not.
# `as_pointcloud`, `transparent_cams` and `cam_size` are forwarded to the export
# helper, whose visible body does not use them.
# NOTE(review): the cells below read ".../room.ply" and "../../point_cloud/test.ply",
# not the file written here -- confirm which file is meant to be used downstream.
get_3D_model_from_scene(
    "../../point_cloud/",
    "test_render_2_inpaint.ply",
    scene,
    min_conf_thr=3,
    as_pointcloud=True,
    mask_sky=False,
    clean_depth=True,
    transparent_cams=False,
    cam_size=0.05,
)

In [None]:
# Load a saved point cloud and filter it with Open3D's radius outlier removal
# (nb_points=50, radius=0.1). The call's two results are unpacked into `cl`
# (written to disk in the next cell) and `ind` (not used by any visible cell).
# NOTE: the name `point_cloud` is rebound further down to an instance of the
# notebook's own `PointCloud` class.
point_cloud = o3d.io.read_point_cloud("/scratch2/yuxili/interiorDesign/output/livingroom/2024-04-10-19-56-08-experiment/room.ply")
cl, ind = point_cloud.remove_radius_outlier(nb_points=50, radius=0.1)


In [None]:
# Write the outlier-filtered cloud to "output.ply" in the current working directory.
# NOTE: a later cell writes the semantically coloured cloud to the same file name,
# overwriting this file.
o3d.io.write_point_cloud("output.ply", cl)

In [None]:
class PointCloud:
    """Points and colours of a point-cloud file, exposed as NumPy arrays."""

    def __init__(self, point_cloud_path):
        """Read `point_cloud_path` with Open3D and keep its points and colours.

        Sets `points`, `colors` and `num_points` (the number of rows of `points`).
        """
        loaded = o3d.io.read_point_cloud(point_cloud_path)
        self.points = np.asarray(loaded.points)
        self.colors = np.asarray(loaded.colors)
        self.num_points = self.points.shape[0]

    def get_homogeneous_coordinates(self):
        """Return `points` with one extra trailing column filled with ones."""
        ones_column = np.ones((self.num_points, 1))
        return np.concatenate((self.points, ones_column), axis=-1)

In [None]:
# Rebinds `point_cloud` (an Open3D cloud in an earlier cell) to the notebook's own
# PointCloud wrapper, which holds the points and colours as NumPy arrays.
# NOTE(review): this reads "test.ply", while the export cell above writes
# "test_render_2_inpaint.ply" -- confirm the intended file.
point_cloud = PointCloud("../../point_cloud/test.ply")

In [None]:
# poses = cams2world[0].cpu().detach().numpy()
poses = np.eye(4)  # identity transform instead of the estimated camera pose
X = point_cloud.get_homogeneous_coordinates()
n_points = point_cloud.num_points
depth = scene.get_depthmaps()[0].cpu().detach().numpy()  # not used further in this cell
intrinsic = intrinsics

# Integer pixel coordinates per point; rows that are never assigned stay (0, 0).
projected_points = np.zeros((n_points, 2), dtype=int)
visible_points_view = np.zeros((n_points), dtype=bool)  # allocated but never filled here
print("[INFO] Computing the visible points in each view.")

# STEP 1: project every homogeneous point with intrinsic @ poses.
projected_points_not_norm = (intrinsic @ poses @ X.T).T

# Only points whose third coordinate is non-zero can be de-homogenised.
mask = projected_points_not_norm[:, 2] != 0

# Divide the first two coordinates by the third for those points. Storing the
# result in the integer array drops the fractional part.
projected_points[mask] = (
    projected_points_not_norm[mask, :2] / projected_points_not_norm[mask, 2:3]
)

In [None]:
# NOTE(review): `Image`, `OneFormerProcessor`, `OneFormerForUniversalSegmentation`
# and `F` are not imported in any visible cell (presumably PIL.Image, the
# transformers OneFormer classes and torch.nn.functional) -- confirm and import
# them in the import cell, otherwise this cell fails on a fresh kernel.
# NOTE(review): this is a "bedroom" image, while the reconstruction cell used a
# "livingroom" image -- confirm that this is the intended input.
image = Image.open(
    "/scratch2/yuxili/interiorDesign/output/bedroom/2024-04-02-13-25-53-experiment/bedroom.png"
)
cache_dir = "/scratch2/yuxili/interiorDesign/huggingface/"

processor = OneFormerProcessor.from_pretrained(
    "shi-labs/oneformer_ade20k_swin_large", cache_dir=cache_dir
)
# NOTE: rebinding `model` replaces the DUSt3R model loaded in the reconstruction cell.
model = OneFormerForUniversalSegmentation.from_pretrained(
    "shi-labs/oneformer_ade20k_swin_large", cache_dir=cache_dir
)

# Semantic Segmentation
semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    semantic_outputs = model(**semantic_inputs)
# pass through image_processor for postprocessing; image.size is (width, height),
# reversed here to (height, width)
predicted_semantic_map = processor.post_process_semantic_segmentation(
    semantic_outputs, target_sizes=[image.size[::-1]]
)[0]

# Resize the label map to 512x512. The map holds class ids, so it must be resampled
# with nearest-neighbour: bilinear blending of two ids produces a third, unrelated
# id along every class boundary. (`align_corners` is not accepted with "nearest".)
predicted_semantic_map = predicted_semantic_map.float()  # Convert tensor to float
predicted_semantic_map = predicted_semantic_map.unsqueeze(0).unsqueeze(0)
downsampled_map = F.interpolate(predicted_semantic_map, size=(512, 512), mode="nearest")
predicted_semantic_map = downsampled_map.squeeze(0).squeeze(0).cpu().detach().numpy()

In [None]:
# Recolour every point with the palette colour of the semantic class found at the
# pixel it projects to.
pc_color = point_cloud.colors  # same array object as point_cloud.colors; edited in place

# projected_points[:, 0] indexes the map's second axis (column) and [:, 1] its
# first axis (row) -- the same element the former `.T[point[0], point[1]]` picked.
cols = projected_points[:, 0]
rows = projected_points[:, 1]
n_rows, n_cols = predicted_semantic_map.shape

# Points projecting outside the map have no label. Without this check negative
# coordinates would silently wrap around and too-large ones raise IndexError.
in_view = (cols >= 0) & (cols < n_cols) & (rows >= 0) & (rows < n_rows)

class_ids = predicted_semantic_map[rows[in_view], cols[in_view]].astype(int)

# Parse the "(R,G,B)" string once per class that actually occurs instead of once
# per point; row i of the table (positional) belongs to class id i.
unique_ids, inverse = np.unique(class_ids, return_inverse=True)
unique_colors = np.array(
    [
        [
            int(element)
            for element in ade20K.iloc[int(class_id)]["Color_Code (R,G,B)"]
            .strip("()")
            .split(",")
        ]
        for class_id in unique_ids
    ]
).reshape(-1, 3)

pc_color[in_view] = unique_colors[inverse.reshape(-1)]
pc_color[~in_view] = 0  # unlabeled points are painted black

point_cloud.colors = pc_color

In [None]:
import numpy as np
import open3d as o3d

# Assuming 'points' and 'colors' are your arrays containing the point coordinates and colors
points = np.asarray(point_cloud.points)
raw_colors = np.asarray(point_cloud.colors, dtype=float)

# Open3D expects colours in [0, 1]. Values above 1 are taken to be 8-bit RGB and
# divided by 255; values already within [0, 1] are kept as they are. Dividing by
# the data's own maximum instead would shift every colour whenever that maximum is
# not 255 and would produce NaN when all colours are 0.
peak = raw_colors.max() if raw_colors.size else 0.0
colors = raw_colors / 255.0 if peak > 1.0 else raw_colors.copy()

# Create a PointCloud object
pcd = o3d.geometry.PointCloud()

# Assign the points and colors to the PointCloud object
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(colors)

# Save the point cloud to a file
# NOTE: an earlier cell also writes "output.ply"; this call overwrites that file.
o3d.io.write_point_cloud("output.ply", pcd)

In [None]:
# Rasterise the point colours into an image indexed as [projected x, projected y],
# i.e. transposed with respect to the usual [row, column] image layout.
# NOTE(review): the 517x517 size is kept from the original cell; the semantic map
# is resized to 512x512 elsewhere -- confirm the intended size.
clolor_map = np.zeros((517, 517, 3), dtype=np.uint8)

first_idx = projected_points[:, 0]
second_idx = projected_points[:, 1]
# Skip points that fall outside the canvas: negative indices would wrap around
# silently and too-large ones raise IndexError.
inside = (
    (first_idx >= 0)
    & (first_idx < clolor_map.shape[0])
    & (second_idx >= 0)
    & (second_idx < clolor_map.shape[1])
)

rgb_values = np.asarray(point_cloud.colors, dtype=float)
# Scale to 0-255 only when the colours are still in [0, 1]. After the semantic
# colouring cell they are already 0-255 and multiplying again would overflow uint8.
if rgb_values.size and rgb_values.max() <= 1.0:
    rgb_values = rgb_values * 255

# When several points land on the same pixel, one of them wins (NumPy does not
# promise which).
clolor_map[first_idx[inside], second_idx[inside], :] = rgb_values[inside]

In [None]:
# `plt` is the same module that the first cell imports as `pl`; the plotting cells
# below rely on the `plt` name bound here.
from matplotlib import pyplot as plt

# Display the rasterised point colours built in the previous cell.
plt.imshow(clolor_map)

In [None]:
# print(projected_points[:, 0].min())
# print(projected_points[:, 0].max())
# print(projected_points[:, 1].min())
# print(projected_points[:, 1].max())

In [None]:
# Display the semantic label map transposed, which matches the [x, y] indexing
# used for `clolor_map`. Relies on `plt` being imported in an earlier cell.
plt.imshow(predicted_semantic_map.T)

In [None]:
# def _convert_scene_output_to_glb(
#     outdir,
#     out_name,
#     imgs,
#     pts3d,
#     mask,
#     focals,
#     cams2world,
#     cam_size=0.05,
#     cam_color=None,
#     as_pointcloud=False,
#     transparent_cams=False,
# ):
#     assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
#     pts3d = to_numpy(pts3d)
#     imgs = to_numpy(imgs)
#     focals = to_numpy(focals)
#     cams2world = to_numpy(cams2world)

#     scene = trimesh.Scene()

#     pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
#     col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
#     pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))

#     pct.export(outdir + out_name)
#     scene.add_geometry(pct)

#     # add each camera
#     for i, pose_c2w in enumerate(cams2world):
#         if isinstance(cam_color, list):
#             camera_edge_color = cam_color[i]
#         else:
#             camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
#         add_scene_cam(
#             scene,
#             pose_c2w,
#             camera_edge_color,
#             None if transparent_cams else imgs[i],
#             focals[i],
#             imsize=imgs[i].shape[1::-1],
#             screen_width=cam_size,
#         )


# def get_3D_model_from_scene(
#     outdir,
#     out_name,
#     scene,
#     min_conf_thr=3,
#     as_pointcloud=False,
#     mask_sky=False,
#     clean_depth=False,
#     transparent_cams=False,
#     cam_size=0.05,
# ):
#     """
#     extract 3D_model (glb file) from a reconstructed scene
#     """
#     if scene is None:
#         return None
#     # post processes
#     if clean_depth:
#         scene = scene.clean_pointcloud()
#     if mask_sky:
#         scene = scene.mask_sky()

#     # get optimized values from scene
#     rgbimg = scene.imgs
#     focals = scene.get_focals().cpu()
#     cams2world = scene.get_im_poses().cpu()
#     # 3D pointcloud from depthmap, poses and intrinsics
#     pts3d = to_numpy(scene.get_pts3d())
#     scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
#     msk = to_numpy(scene.get_masks())
#     return _convert_scene_output_to_glb(
#         outdir,
#         out_name,
#         rgbimg,
#         pts3d,
#         msk,
#         focals,
#         cams2world,
#         as_pointcloud=as_pointcloud,
#         transparent_cams=transparent_cams,
#         cam_size=cam_size,
#     )