Skip to content

Commit

Permalink
Refactored sensor module.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpf committed May 17, 2024
1 parent b0f6e8f commit 7fcbb89
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 2,249 deletions.
52 changes: 32 additions & 20 deletions gprof_nn/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,11 @@ def load_tbs_1d_xtrack_sim(
tbs = training_data.simulated_brightness_temperatures.data

tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32)
tbs_full[:, sensor.gmi_channels] = tbs
tbs_full[:, sensor.gprof_channel_indices] = tbs

biases = training_data.brightness_temperature_biases.data
biases_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32)
biases_full[:, sensor.gmi_channels] = biases
biases_full[:, sensor.gprof_channel_indices] = biases

biases = (
biases_full /
Expand All @@ -476,7 +476,7 @@ def load_tbs_1d_xtrack_sim(
)
)

return torch.tensor(tbs_full - biases)
return torch.tensor(tbs_full + biases)


def load_tbs_1d_conical_sim(
Expand Down Expand Up @@ -506,7 +506,7 @@ def load_tbs_1d_conical_sim(
tbs = training_data.simulated_brightness_temperatures.data
biases = training_data.brightness_temperature_biases.data

tbs = tbs - biases
tbs = tbs + biases
return torch.tensor(tbs)


Expand All @@ -529,10 +529,10 @@ def load_tbs_1d_xtrack_other(
"""
tbs = training_data["brightness_temperatures"].data
tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32)
tbs_full[:, sensor.gmi_channels] = tbs
tbs_full[:, sensor.gprof_channel_indices] = tbs
angles = training_data["earth_incidence_angle"].data
angles_full = np.nan * np.zeros_like(tbs_full)
angles_full[:, sensor.gprof_channels] = angles[..., None]
angles_full[:, sensor.gprof_channel_indices] = angles[..., None]

tbs = torch.tensor(tbs_full.astype("float32"))
angles = torch.tensor(angles_full.astype("float32"))
Expand All @@ -558,10 +558,10 @@ def load_tbs_1d_conical_other(
"""
tbs = training_data["brightness_temperatures"].data
tbs_full = np.nan * np.ones(tbs.shape[:-1] + (15,), dtype="float32")
tbs_full[:, sensor.gprof_channels] = tbs
tbs_full[:, sensor.gprof_channel_indices] = tbs
angles = training_data["earth_incidence_angle"].data
angles_full = np.nan * np.ones(tbs.shape[:-1] + (15,), dtype="float32")
angles_full[:, sensor.gprof_channels] = angles
angles_full[:, sensor.gprof_channel_indices] = angles
tbs = torch.tensor(tbs_full.astype("float32"))
angles = torch.tensor(angles_full.astype("float32"))
return tbs, angles
Expand Down Expand Up @@ -893,7 +893,10 @@ def load_training_data_3d_gmi(
"ancillary_data": anc
}

scene = scene.transpose("levels", "scans", "pixels", ...)
dims = ("scans", "pixels")
if "levels" in scene.dims:
dims = ("levels",) + dims
scene = scene.transpose(*dims, ...)
y = {}
for target in targets:
# MRMS collocations don't contain all targets.
Expand Down Expand Up @@ -983,23 +986,25 @@ def load_training_data_3d_xtrack_sim(
angs = sensor.viewing_geometry.get_earth_incidence_angles()
angs = angs[j_start:j_end]
angs = np.repeat(angs.reshape(1, -1), height, axis=0)
weights = calculate_interpolation_weights(np.abs(angs), sensor.angles)

angles = scene.angles.data
weights = calculate_interpolation_weights(np.abs(angs), angles)
weights = np.repeat(weights.reshape(1, -1, sensor.n_angles), height, axis=0)
weights = calculate_interpolation_weights(np.abs(angs), scene.angles.data)
weights = calculate_interpolation_weights(np.abs(angs), angles)

# Calculate brightness temperatures
# Calculate brightness temperatures
tbs_sim = scene.simulated_brightness_temperatures.data
tbs_sim = interpolate(tbs_sim, weights)
tb_biases = scene.brightness_temperature_biases.data
tbs = tbs_sim - tb_biases
tbs = tbs_sim + tb_biases

full_shape = tbs_sim.shape[:2] + (15,)
tbs_full = np.nan * np.ones(full_shape, dtype="float32")
tbs_full[:, :, sensor.gmi_channels] = tbs
tbs_full[:, :, sensor.gprof_channel_indices] = tbs
tbs_full = torch.permute(torch.tensor(tbs_full), (2, 0, 1))

angs_full = np.nan * np.ones(full_shape, dtype="float32")
angs_full[:, :, sensor.gmi_channels] = angs[..., None]
angs_full[:, :, sensor.gprof_channel_indices] = angs[..., None]
angs_full = torch.permute(torch.tensor(angs_full), (2, 0, 1))

anc = torch.tensor(np.stack(
Expand All @@ -1023,7 +1028,13 @@ def load_training_data_3d_xtrack_sim(
y[target] = empty
continue

data = scene[target].permute(("levels", "scans", "pixels", "angles"))
dims = ("scans", "pixels")
if "levels" in scene[target].dims:
dims = ("levels",) + dims
if "angles" in scene[target].dims:
dims = dims + ("angles",)
data = scene[target].transpose(*dims)

data = data.data.astype("float32")

if "angles" in scene[target].dims:
Expand Down Expand Up @@ -1106,12 +1117,13 @@ def load_training_data_3d_conical_sim(
# Calculate brightness temperatures
tbs_sim = scene.simulated_brightness_temperatures.data
tb_biases = scene.brightness_temperature_biases.data
tbs = torch.tensor(tbs_sim - tb_biases, dtype=torch.float32)
tbs = torch.tensor(tbs_sim + tb_biases, dtype=torch.float32)
tbs = torch.permute(tbs, (2, 0, 1))

angs_full = np.broadcast_to(EIA_GMI.astype("float32")[0][..., None, None], tbs.shape).copy()
gprof_indices = sensor.gprof_channel_indices
for ind in range(15):
if ind not in sensor.gmi_channels:
if ind not in gprof_indices:
angs_full[ind] = np.nan
angs_full = torch.tensor(angs_full)

Expand Down Expand Up @@ -1207,7 +1219,7 @@ def load_training_data_3d_other(
full_shape = tbs.shape[:2] + (15,)
if tbs.shape != full_shape:
tbs_full = np.nan * np.ones(full_shape, dtype="float32")
tbs_full[:, :, sensor.gmi_channels] = tbs
tbs_full[:, :, sensor.gprof_channel_indices] = tbs
else:
tbs_full = tbs
tbs_full = torch.permute(torch.tensor(tbs_full), (2, 0, 1))
Expand All @@ -1217,7 +1229,7 @@ def load_training_data_3d_other(
angs = angs[..., None]
if tbs.shape != full_shape:
angs_full = np.nan * np.ones(full_shape, dtype="float32")
angs_full[:, :, sensor.gmi_channels] = angs
angs_full[:, :, sensor.gprof_channel_indices] = angs
else:
angs_full = angs
angs_full = torch.permute(torch.tensor(angs_full), (2, 0, 1))
Expand Down
Loading

0 comments on commit 7fcbb89

Please sign in to comment.