In [46]:
import matplotlib
import torch
import h5py
import torch.nn.functional as F

def interpolate_first_dimension(tensor, new_length):
    """
    Interpolates the first dimension of a 3D tensor to a new length.
    
    Args:
        tensor (torch.Tensor): The input tensor of shape (D, H, W).
        new_length (int): The desired length for the first dimension.
        
    Returns:
        torch.Tensor: The interpolated tensor of shape (new_length, H, W).
    """
    # Reshape the tensor to (1, 1, D, H, W) for interpolation
    tensor_reshaped = tensor.unsqueeze(0).unsqueeze(0)
    
    # Interpolate using F.interpolate
    interpolated_tensor_reshaped = F.interpolate(
        tensor_reshaped, 
        size=(new_length, tensor_reshaped.size(-2), tensor_reshaped.size(-1)), 
        mode='trilinear', 
        align_corners=True,
    )
    
    # Reshape back to original dimensions (new_length, H, W)
    interpolated_tensor = interpolated_tensor_reshaped.squeeze(0).squeeze(0)
    
    return interpolated_tensor

class NormalizedMSELoss(torch.nn.modules.loss._Loss):
    def __init__(self, reduce="mean"):
        super(NormalizedMSELoss, self).__init__()

        self.reduce = reduce

    def forward(self, x, y, mask):
        one_mask = mask != 0

        error_energy = torch.norm((x - y)*one_mask, p=2, dim=(-1, -2))
        field_energy = torch.norm(y*one_mask, p=2, dim=(-1, -2)) + 1e-6
        return (error_energy/field_energy).mean()

criterien = NormalizedMSELoss()
with h5py.File("../data/fdtd/raw_diff_res/mmi_3x3_L_swp_res-0000-p0.h5", "r") as f:
    p0_ground_truth = torch.from_numpy(
        f["Ez"][ : ][()]
    ).float()
with h5py.File("../data/fdtd/raw_diff_res/mmi_3x3_L_swp_res-0000-p1.h5", "r") as f:
    p1_ground_truth = torch.from_numpy(
        f["Ez"][ : ][()]
    ).float()
with h5py.File("../data/fdtd/raw_diff_res/mmi_3x3_L_swp_res-0000-p2.h5", "r") as f:
    p2_ground_truth = torch.from_numpy(
        f["Ez"][ : ][()]
    ).float()
target_size = int(round(p0_ground_truth.shape[-2]*0.75*(256/534.75))), int(round(p0_ground_truth.shape[-1]*0.75*(256/534.75)))
p0_ground_truth = torch.nn.functional.interpolate(p0_ground_truth.unsqueeze(0), target_size, mode="bilinear").squeeze(0)
p1_ground_truth = torch.nn.functional.interpolate(p1_ground_truth.unsqueeze(0), target_size, mode="bilinear").squeeze(0)
p2_ground_truth = torch.nn.functional.interpolate(p2_ground_truth.unsqueeze(0), target_size, mode="bilinear").squeeze(0)
# p0_ground_truth = interpolate_first_dimension(p0_ground_truth, 800)
# p1_ground_truth = interpolate_first_dimension(p1_ground_truth, 800)
# p2_ground_truth = interpolate_first_dimension(p2_ground_truth, 800)
# p0_ground_truth = p0_ground_truth[:720]
# p1_ground_truth = p1_ground_truth[:720]
# p2_ground_truth = p2_ground_truth[:720]

In [47]:
device_name = 'mmi_3x3_L_swp_res-000'
p0_error = []
p1_error = []
p2_error = []
for id in range(10):
    for port in range(3):

        file_name = f"../data/fdtd/raw_diff_res/{device_name}{id}-p{port}.h5"
        with h5py.File(file_name, "r") as f:
            data = torch.from_numpy(
                f["Ez"][ : ][()]
            ).float()
            print("this is the shape of data:", data.shape)
        # data = interpolate_first_dimension(data, 800)
        # data = data[:720]
        # interpolate the last two dimensions to the target size
        # data = torch.nn.functional.interpolate(data.unsqueeze(0), target_size, mode="bilinear").squeeze(0)
        # if port == 0:
        #     p0_error.append(criterien(data.unsqueeze(0), p0_ground_truth.unsqueeze(0), torch.ones_like(p0_ground_truth[-1].unsqueeze(0).unsqueeze(0))))
        # elif port == 1:
        #     p1_error.append(criterien(data.unsqueeze(0), p1_ground_truth.unsqueeze(0), torch.ones_like(p1_ground_truth[-1].unsqueeze(0).unsqueeze(0))))
        # elif port == 2:
        #     p2_error.append(criterien(data.unsqueeze(0), p2_ground_truth.unsqueeze(0), torch.ones_like(p2_ground_truth[-1].unsqueeze(0).unsqueeze(0))))

print("p0 error:", p0_error)
print("p1 error:", p1_error)
print("p2 error:", p2_error)

this is the shape of data: torch.Size([833, 620, 160])
this is the shape of data: torch.Size([833, 620, 160])
this is the shape of data: torch.Size([833, 620, 160])
this is the shape of data: torch.Size([818, 558, 144])


RuntimeError: The size of tensor a (818) must match the size of tensor b (833) at non-singleton dimension 1