In [1]:
import torch
from PIL import Image
import numpy as np
from depth_to_normal import Depth2Normal
import csv
import os
import matplotlib.pyplot as plt
import re

# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# so the notebook can still be executed on a machine without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Depth-to-normal converter, moved onto the selected device.
depth2normal = Depth2Normal().to(device)

In [2]:
# Dataset split to process; it selects both CSV file names below.
TRAIN_OR_TEST = "train"

# Dataset root directory; the two CSV locations are joined onto it when opened.
DIR_PATH = "/scratchdata/nyu_data/"
# CSV listing the existing samples (read as input).
DATA_FILE = "data/nyu2_{}.csv".format(TRAIN_OR_TEST)
# CSV that receives the rows extended with the generated outputs.
TARGET_FILE = "data/custom_{}.csv".format(TRAIN_OR_TEST)

# 4x4 float32 camera matrix: an identity matrix whose upper-left 3x3 part is
# filled in pinhole-style (presumably fx, fy on the diagonal and cx, cy in the
# third column).
# NOTE(review): fx and fy are identical here. If these are meant to be the
# NYU Depth v2 RGB camera intrinsics, the published fy is slightly different
# (about 519.47) -- confirm that using the same value is intentional.
INTRINSIC = np.eye(4, dtype=np.float32)
INTRINSIC[0, 0] = 518.8579
INTRINSIC[1, 1] = 518.8579
INTRINSIC[0, 2] = 325.5824
INTRINSIC[1, 2] = 253.7362

# In the cells shown, these two values are only copied into the output CSV;
# they are not applied to the depth maps themselves.
# Presumably: raw depth units per metre, and maximum depth in metres -- TODO confirm.
DEPTH_SCALE = 1000.0
MAX_DEPTH = 10.0

In [3]:
# For every row of the source CSV: load the depth image named in column 1,
# run it through `depth2normal`, save the resulting normal map and mask as
# image files next to the depth image, and extend the row with the two new
# relative paths plus the camera / depth constants.

# newline="" is what the csv module asks for when it is handed a file object.
with open(os.path.join(DIR_PATH, DATA_FILE), "r", newline="") as f:
    reader = csv.reader(f)
    data = list(reader)


def _derived_rel_path(depth_rel_path, kind):
    """Return the relative output path for `kind` ("normal" or "mask").

    The path is obtained by replacing every occurrence of "depth" in
    `depth_rel_path` with `kind`.

    Raises:
        ValueError: if `depth_rel_path` does not contain "depth". The derived
            path would then be the depth path itself, and saving to it would
            overwrite the input depth image.
    """
    derived = depth_rel_path.replace("depth", kind)
    if derived == depth_rel_path:
        raise ValueError(
            "Cannot derive a {} path from {!r}: it does not contain 'depth', "
            "so saving would overwrite the depth image.".format(kind, depth_rel_path)
        )
    return derived


for i, row in enumerate(data):
    depth_rel_path = row[1]

    # Derive the output locations once, from the relative path only, so that
    # the files written to disk and the paths recorded in the CSV always agree
    # (and DIR_PATH itself is never rewritten). This also fails before any
    # file is touched if the naming scheme does not apply.
    normal_rel_path = _derived_rel_path(depth_rel_path, "normal")
    mask_rel_path = _derived_rel_path(depth_rel_path, "mask")

    depth_path = os.path.join(DIR_PATH, depth_rel_path)
    normal_path = os.path.join(DIR_PATH, normal_rel_path)
    mask_path = os.path.join(DIR_PATH, mask_rel_path)

    # Camera matrix with a leading batch dimension.
    intrinsic = torch.tensor(INTRINSIC, dtype=torch.float32).to(device)
    intrinsic = intrinsic.unsqueeze(0)

    # Read the pixel data inside a context manager so the file handle is
    # released right away instead of lingering across thousands of iterations.
    with Image.open(depth_path) as depth_image:
        depth = torch.Tensor(np.array(depth_image)).to(device)
    # Add two leading singleton dimensions in front of the image array.
    depth = depth.unsqueeze(0).unsqueeze(0)

    # Keep only pixels strictly between the image's own minimum and maximum.
    # Presumably this discards missing (lowest) and saturated (highest)
    # readings -- confirm against the dataset's depth encoding.
    mask = (depth > depth.min()) & (depth < depth.max())

    normal, mask = depth2normal(depth, intrinsic, mask, 1)

    # Drop the leading singleton dimensions and move the first remaining axis
    # of the normal map to the end before handing it to PIL.
    normal = normal.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    mask = mask.squeeze(0).squeeze(0).cpu().numpy()

    # NOTE(review): values are scaled by 255 and cast to uint8. If
    # `depth2normal` returns components in [-1, 1], negative components will
    # not survive this cast as intended -- verify its output range.
    normal = Image.fromarray((normal * 255).astype(np.uint8))
    mask = Image.fromarray((mask * 255).astype(np.uint8))

    normal.save(normal_path)
    mask.save(mask_path)

    # Appended columns: normal path, mask path, INTRINSIC[0,0], INTRINSIC[1,1],
    # INTRINSIC[0,2], INTRINSIC[1,2], DEPTH_SCALE, MAX_DEPTH.
    row.extend([
        normal_rel_path,
        mask_rel_path,
        INTRINSIC[0, 0],
        INTRINSIC[1, 1],
        INTRINSIC[0, 2],
        INTRINSIC[1, 2],
        DEPTH_SCALE,
        MAX_DEPTH,
    ])

In [4]:
# Write the extended rows to the target CSV.
# The file is opened with newline="" as the csv module requires; without it,
# platforms that translate "\n" on write (e.g. Windows) turn the writer's
# "\r\n" row terminator into "\r\r\n", which shows up as blank rows.
with open(os.path.join(DIR_PATH, TARGET_FILE), "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(data)