# In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.sparse import diags, identity, kron, lil_matrix
from scipy.sparse.linalg import spsolve

from normal_m import UNet

# Choose the compute device once for the whole script: CUDA when torch can
# see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fix NumPy's global RNG seed for reproducibility (nothing in this chunk
# draws from it; presumably it guards NumPy randomness used elsewhere).
np.random.seed(0)
def predict(image_paths, net=None, size=(1024, 1024)):
    """Run the network on a stack of grayscale images and return its output.

    Each path is opened with PIL, converted to grayscale mode 'L', resized
    to ``size`` and cast to float32. Pixel values stay in the 0-255 range,
    because no normalisation happens here. The images become the channels
    of a single ``[1, C, H, W]`` batch placed on ``device``.

    Args:
        image_paths: image file paths, one per input channel, in the order
            the network expects. The ``__main__`` section passes the
            0/90/180/270 captures of one group.
        net: model to run. Defaults to the module-level ``model`` global
            that the ``__main__`` section assigns.
        size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        The network output as a numpy array with the batch axis removed.
        The caller annotates it as ``[3, H, W]`` normals (TODO confirm
        against ``UNet``).

    Raises:
        ValueError: if ``image_paths`` is empty.
    """
    paths = list(image_paths)
    if not paths:
        raise ValueError("predict() needs at least one image path")

    if net is None:
        net = model  # module-level global, bound in the __main__ section

    channels = []
    for img_path in paths:
        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(img_path) as img:
            gray = img.convert('L').resize(size)
        arr = np.asarray(gray, dtype=np.float32)
        # Optional: keep only the central 512x512 region of a 1024x1024 image.
        # arr = arr[256:768, 256:768]
        channels.append(arr)

    # One channel per input image, plus a leading batch axis of size 1.
    input_tensor = torch.from_numpy(np.stack(channels, axis=0)).unsqueeze(0).to(device)

    with torch.no_grad():
        output = net(input_tensor)
    return output.squeeze(0).cpu().numpy()

def poisson_reconstruct(normal_map, epsilon=1e-8):
    """Integrate a normal map into a depth map via a Poisson equation.

    The normals are converted into the gradient field
    ``(nx / nz, ny / nz)``. The divergence of that field, computed with
    central differences via ``np.gradient``, is the right-hand side passed
    to ``poisson_solver_function``.

    NOTE(review): this assumes the network's normals use a sign/axis
    convention under which ``(nx/nz, ny/nz)`` is the intended gradient
    field. TODO confirm against the training data.

    Args:
        normal_map: array where ``normal_map[0]``, ``normal_map[1]`` and
            ``normal_map[2]`` are the x, y and z normal components, each a
            2-D ``[H, W]`` array.
        epsilon: ``nz`` values with magnitude below this are clamped to
            ``±epsilon`` (sign preserved). This avoids division by zero or
            near-zero values.

    Returns:
        The ``[H, W]`` depth array returned by ``poisson_solver_function``.
    """
    nx = normal_map[0]
    ny = normal_map[1]
    nz = normal_map[2]

    # Clamp near-zero nz (not just exact zeros) away from zero while
    # keeping its sign, so grazing normals do not create huge gradients.
    # np.copysign(epsilon, 0.0) is +epsilon, so exact zeros become +epsilon.
    nz = np.where(np.abs(nz) < epsilon, np.copysign(epsilon, nz), nz)

    # Any remaining non-finite gradient (e.g. from NaN normals) contributes 0.
    fx = np.nan_to_num(nx / nz, nan=0.0, posinf=0.0, neginf=0.0)
    fy = np.nan_to_num(ny / nz, nan=0.0, posinf=0.0, neginf=0.0)

    # Divergence of the gradient field: d(fx)/dx along columns, d(fy)/dy along rows.
    divergence = np.gradient(fx, axis=1) + np.gradient(fy, axis=0)

    return poisson_solver_function(divergence)

def _second_difference(m):
    """Return the sparse m x m tridiagonal matrix tridiag(-1, 2, -1)."""
    off = -np.ones(m - 1)
    return diags([off, np.full(m, 2.0), off], [-1, 0, 1], shape=(m, m))

def poisson_solver_function(f):
    """Solve the discrete Poisson system ``A u = f`` on the pixel grid of ``f``.

    ``A`` is the 5-point stencil: 4 on the diagonal and -1 for each of the
    left/right/up/down neighbours that exist. Values outside the grid are
    treated as zero (Dirichlet boundary), so ``A`` is the negative discrete
    Laplacian.

    Args:
        f: 2-D ``[h, w]`` right-hand side.

    Returns:
        The solution ``u`` reshaped to ``(h, w)``.

    Raises:
        ValueError: if ``f`` is not 2-D or has an empty dimension.
    """
    h, w = f.shape
    if h == 0 or w == 0:
        raise ValueError("poisson_solver_function() needs a non-empty 2-D array")

    # Pixels are flattened row-major (idx = x + y * w). Horizontal coupling
    # therefore acts inside each length-w row block, and vertical coupling
    # links entries w apart. The Kronecker form builds exactly that matrix
    # in O(h*w), instead of filling a lil_matrix from a Python double loop.
    A = (
        kron(identity(h), _second_difference(w), format="csr")
        + kron(_second_difference(h), identity(w), format="csr")
    ).tocsr()

    # Solve the linear system; flatten() copies, so f itself is untouched.
    u = spsolve(A, f.flatten())
    return u.reshape((h, w))

def visualize_result(predicted_depth, true_depth, epoch):
    """Display the predicted and ground-truth depth maps side by side.

    Both panels use the 'viridis' colormap with axes hidden. ``epoch`` only
    appears in the left panel's title. No shared vmin/vmax is given, so
    each panel is colour-scaled independently. The figure is shown with
    ``plt.show()``.
    """
    _, panels = plt.subplots(1, 2, figsize=(12, 6))
    contents = (
        (predicted_depth, f'Predicted Depth {epoch}'),
        (true_depth, 'True Depth'),
    )
    for ax, (depth, title) in zip(panels, contents):
        ax.imshow(depth, cmap='viridis')
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Checkpoint epochs to evaluate; each is loaded from model/unet_epoch_<n>.pth.
    model_nums = [55000]
    # Groups to evaluate. Group 6 is excluded (relevant only if the range is
    # widened); narrow the list, e.g. to [11], to run a single group.
    skip_groups = {6}
    group_nums = [g for g in range(1, 5) if g not in skip_groups]
    # Directory holding the "<group>-<angle>.png" inputs; set to "test" to
    # use the images under test/ instead.
    image_dir = "dataset/image"

    # Load every checkpoint once up front rather than once per group.
    # NOTE(review): this keeps all listed checkpoints on `device` at the
    # same time; fine for one checkpoint, revisit if the list grows.
    models = {}
    for model_num in model_nums:
        net = UNet()
        net.load_state_dict(
            torch.load(f"model/unet_epoch_{model_num}.pth", map_location=device)
        )
        models[model_num] = net.to(device).eval()

    for num in group_nums:
        image_paths = [f"{image_dir}/{num}-{angle}.png" for angle in (0, 90, 180, 270)]
        true_depth = np.load(f"dataset_transformed/depth/{num}.npy")
        for model_num in model_nums:
            # predict() reads the module-level name `model`, so bind it here.
            model = models[model_num]
            normal_map = predict(image_paths)  # [3, H, W] per the model's output
            predicted_depth = poisson_reconstruct(normal_map)
            # np.save appends ".npy", producing e.g. "1-55000.npy".
            np.save(f"{num}-{model_num}", predicted_depth)
            visualize_result(predicted_depth, true_depth, model_num)
