In [None]:
import torch
import matplotlib.pyplot as plt

from scipy.stats import linregress

In [None]:
# Output directory of the AnaCal run for a given source density. The trailing
# "/" is required because file paths below are built by string concatenation.
DENSITY = 240
results_dir = f"anacal_results_density{DENSITY}/"

In [None]:
# NOTE: torch.load unpickles the file -- only use it on trusted result files.
catalogs = torch.load(results_dir + "catalogs.pt")

# The true shear of each image is read from the first entry of its catalog
# (presumably every entry of a catalog shares the same applied shear -- TODO confirm).
first_entries = [cat[0] for cat in catalogs]
true_shear1 = torch.tensor([entry['g1'] for entry in first_entries])
true_shear2 = torch.tensor([entry['g2'] for entry in first_entries])

In [None]:
# One true-shear value per simulated image, so its length is the image count.
num_images = int(true_shear1.size(0))
print(f"Number of images: {num_images}")

In [None]:
def _load_per_image(name):
    """Load ``results_dir + name + ".pt"`` and keep its first ``num_images`` entries.

    Slicing alone silently returns fewer entries when the file is shorter than
    ``num_images``. That would only show up later as a less informative error,
    so it is checked here, naming the offending file.

    NOTE: torch.load unpickles the file -- only use it on trusted result files.
    """
    values = torch.load(results_dir + name + ".pt")[:num_images]
    assert len(values) == num_images, (
        f"{name}.pt has {len(values)} entries, expected {num_images}"
    )
    return values

e1_sum = _load_per_image("e1_sum")
e1g1_sum = _load_per_image("e1g1_sum")
e2_sum = _load_per_image("e2_sum")
e2g2_sum = _load_per_image("e2g2_sum")
num_detections = _load_per_image("num_detections")

In [None]:
# Per-image mean ellipticity over that image's detections.
# NOTE(review): an image with zero detections yields NaN here.
e1_avg = e1_sum / num_detections
e2_avg = e2_sum / num_detections

# Global response: pool the e*g sums over every detection in every image.
total_detections = num_detections.sum()
R1 = e1g1_sum.sum() / total_detections
R2 = e2g2_sum.sum() / total_detections
print(f"Estimated R1 = {R1}")
print(f"Estimated R2 = {R2}")

# Response-corrected per-image shear estimates.
est_shear1 = e1_avg / R1
est_shear2 = e2_avg / R2

In [None]:
# Estimated vs. true shear, one panel per component. Points on the dashed
# y = x line correspond to an unbiased estimate.
AXIS_LIMIT = 0.075  # both axes of each panel span [-AXIS_LIMIT, AXIS_LIMIT]

fig, ax = plt.subplots(1, 2, figsize=(9, 4))

for component, panel_ax, g_true, g_est in zip(
    (1, 2), ax, (true_shear1, true_shear2), (est_shear1, est_shear2)
):
    panel_ax.scatter(g_true, g_est, color='forestgreen', alpha=0.25)
    panel_ax.axline((0, 0), slope=1, linestyle='dashed', color='black')
    # Raw f-strings: "\g" and "\w" are invalid escapes in ordinary string
    # literals (SyntaxWarning since Python 3.12). LaTeX braces are doubled.
    panel_ax.set_xlabel(rf"$\gamma_{component}$")
    panel_ax.set_ylabel(rf"$\widehat{{\gamma}}_{component}$")
    panel_ax.set_xlim(-AXIS_LIMIT, AXIS_LIMIT)
    panel_ax.set_ylim(-AXIS_LIMIT, AXIS_LIMIT)

fig.tight_layout()

In [None]:
# Fit est = (1 + m) * true + c for each component. c is the additive bias and
# m the multiplicative bias (slope - 1). Uncertainties are quoted as 3 standard errors.
lr1 = linregress(true_shear1.flatten().cpu().numpy(), est_shear1.flatten().cpu().numpy())
lr2 = linregress(true_shear2.flatten().cpu().numpy(), est_shear2.flatten().cpu().numpy())

# Every number uses the same 6-decimal format; the original left the slope
# uncertainty unformatted. Components are separated by a blank line.
bias_reports = [
    f"Shear {component}:\n"
    f"c ± 3SE = {fit.intercept:.6f} ± {3 * fit.intercept_stderr:.6f}, "
    f"m ± 3SE = {fit.slope - 1:.6f} ± {3 * fit.stderr:.6f}"
    for component, fit in ((1, lr1), (2, lr2))
]
print("\n\n".join(bias_reports))