In [None]:
import numpy as np
import os
from PIL import Image

from scipy.ndimage import shift
from skimage import io

import matplotlib.pyplot as plt
%matplotlib inline
from astropy.io import fits

import ipywidgets as widgets
from ipywidgets import interactive

from timepix_geometry_correction.config import default_config_timepix1 as config


In [None]:
# Raw Siemens-star acquisition in FITS format
# (an alternative TIFF path exists at data/siemens_star.tif).
image1 = os.path.join("data", "raw_siemens_star.fits")
assert os.path.exists(image1), f"File does not exist: {image1}"

# np.array copies the primary HDU's pixels into an independent in-memory
# array, so the data stays usable once the context manager closes the file.
with fits.open(image1) as hdul:
    primary_hdu = hdul[0]
    data_images1 = np.array(primary_hdu.data)

In [None]:
# Quick look at the raw frame before any geometry correction.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax.imshow(data_images1, cmap="gray")

In [None]:
chip_size = (256, 256)  # (height, width) of a single chip, in pixels
print(config)
print(f"{np.shape(data_images1)=}")


# Applying shift correction based on the configuration

In [None]:
def apply_shift_correction(image, shift_config, chip_size=(256, 256)):
    """
    Re-align the four chips of a 2 x 2 chip mosaic using per-chip offsets.

    chip2 (top-left quadrant) is the fixed reference and is copied unchanged.
    chip1 (top-right), chip3 (bottom-left) and chip4 (bottom-right) are each
    shifted by (yoffset, xoffset) with cubic-spline interpolation
    (``scipy.ndimage.shift``, ``order=3``). Pixels uncovered by a shift are
    filled with 0 (scipy's default constant mode), which leaves gaps at the
    chip boundaries.

    Parameters
    ----------
    image : 2D array-like
        Raw detector image. It is NOT modified.
    shift_config : dict
        Mapping with keys 'chip1', 'chip3' and 'chip4', each holding
        'xoffset' and 'yoffset' in pixels.
    chip_size : tuple of int, optional
        (height, width) of one chip; the quadrant boundaries lie at these
        row / column indices. Defaults to (256, 256).

    Returns
    -------
    new_image : 2D float ndarray
        Shift-corrected image with the same shape as ``image``. NaN and
        +/-inf input pixels are treated as 0.
    """
    h, w = chip_size

    # Work on a native float64 copy so the caller's array is never mutated,
    # and so the cubic interpolation is not rounded back to an integer (or
    # byte-swapped FITS) input dtype.
    work = np.array(image, dtype=float)
    work[~np.isfinite(work)] = 0

    new_image = np.zeros(work.shape)

    # chip 2 (fixed reference)
    new_image[:h, :w] = work[:h, :w]

    # (row slice, column slice) of every chip that gets shifted
    moving_chips = {
        'chip1': (slice(0, h), slice(w, None)),
        'chip3': (slice(h, None), slice(0, w)),
        'chip4': (slice(h, None), slice(w, None)),
    }
    for chip, (rows, cols) in moving_chips.items():
        # scipy expects the shift in (row, column) = (y, x) order
        chip_shift = (shift_config[chip]['yoffset'], shift_config[chip]['xoffset'])
        new_image[rows, cols] = shift(work[rows, cols], shift=chip_shift, order=3)

    return new_image

In [None]:
# Snapshot the raw frame BEFORE correcting it, so that `raw_image` is the true
# input even if the correction step modifies the array it is given.
raw_image = data_images1.copy()
corrected = apply_shift_correction(data_images1, config)

fig, ax = plt.subplots(ncols=1, figsize=(10, 10))
ax.imshow(corrected, cmap='viridis')
ax.set_title("Corrected Image")

In [None]:

zoom_size = 15
# Initial zoom window: element 0 is used as the starting row (y) and
# element 2 as the starting column (x); the other two are not used here.
region_zoomed = [248, 442, 245, 260]

# Slider limits come from the image itself rather than assuming 512 x 512.
n_rows, n_cols = corrected.shape[:2]


def plot_zoomed(y=100, x=100, zoom_size=15):
    """
    Show a zoom_size x zoom_size window of `corrected` next to the full image.

    Parameters
    ----------
    y, x : int
        Row and column of the window's top-left pixel.
    zoom_size : int
        Side length of the window in pixels (clipped at the image border).
    """
    window = corrected[y: y + zoom_size, x: x + zoom_size]
    win_h, win_w = window.shape[:2]

    fig, axs = plt.subplots(ncols=2, figsize=(15, 10))
    # extent labels the zoomed panel in full-image pixel coordinates, so the
    # grid lines can be read against the full image.
    axs[0].imshow(window, extent=(x - 0.5, x + win_w - 0.5, y + win_h - 0.5, y - 0.5))
    axs[0].grid(True)

    axs[1].imshow(corrected)
    # imshow centres pixel (row, col) on coordinate (col, row), so pixel edges
    # sit at half-integers; offset by 0.5 to outline exactly the zoomed window.
    rectangle = plt.Rectangle((x - 0.5, y - 0.5), win_w, win_h, edgecolor='red', facecolor='none')
    axs[1].add_patch(rectangle)
    plt.show()


display_zoom = interactive(
    plot_zoomed,
    y=widgets.IntSlider(min=0, max=n_rows - zoom_size, step=1, value=region_zoomed[0]),
    x=widgets.IntSlider(min=0, max=n_cols - zoom_size, step=1, value=region_zoomed[2]),
    zoom_size=widgets.IntSlider(min=5, max=50, step=1, value=zoom_size),
)
display(display_zoom)


In [None]:
# Side by side: shift-corrected frame (left) and raw frame (right).
fig, axs = plt.subplots(ncols=2, figsize=(15, 10))
panels = ((corrected, "Corrected Image"), (raw_image, "Original Image"))
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)

In [None]:
# let's make sure the counts are not affected by the shift correction.
# nansum counts NaN pixels as 0, the same way apply_shift_correction treats them.
total_raw = np.nansum(raw_image)
total_corrected = np.nansum(corrected)
print("Total counts in original image:", total_raw)
print("Total counts in corrected image:", total_corrected)

## Let's check that pixel values are preserved by the shift correction (per-chip region differences)

In [None]:
box_width = 200
box_height = 200

# One comparison box per chip: top-left corner (x0, y0) in the RAW image and
# its outline colour. chip2 is the fixed reference chip.
chip_boxes = {
    'chip1': {'x0': 300, 'y0': 10, 'color': 'red'},
    'chip2': {'x0': 50, 'y0': 10, 'color': 'blue'},
    'chip3': {'x0': 50, 'y0': 300, 'color': 'green'},
    'chip4': {'x0': 300, 'y0': 300, 'color': 'yellow'},
}


def chip_pixel_offset(chip):
    """Return the chip's (dx, dy) shift, rounded to whole pixels so it can be used for slicing.

    chip2 is the fixed reference chip and is never shifted.
    """
    if chip == 'chip2':
        return 0, 0
    return (int(round(config[chip]['xoffset'])),
            int(round(config[chip]['yoffset'])))


def box_slices(x0, y0):
    """(row, column) slices of a box_height x box_width box with top-left corner (x0, y0)."""
    return slice(y0, y0 + box_height), slice(x0, x0 + box_width)


fig, axs = plt.subplots(ncols=2, figsize=(15, 10))
axs[0].imshow(corrected, cmap='viridis')
axs[0].set_title("Corrected Image")
axs[1].imshow(raw_image, cmap='viridis')
axs[1].set_title("Original Image")

# display the differences between the two images in the selected regions
fig2, diff_axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 15))

# diff_axs.flat order is [0,0], [0,1], [1,0], [1,1] -> chip1, chip2, chip3, chip4
for diff_ax, (chip, box) in zip(diff_axs.flat, chip_boxes.items()):
    dx, dy = chip_pixel_offset(chip)
    x_raw, y_raw = box['x0'], box['y0']
    # a positive offset moves the chip content towards higher indices
    x_cor, y_cor = x_raw + dx, y_raw + dy

    # plt.Rectangle takes (xy, width, height)
    axs[0].add_patch(plt.Rectangle((x_cor, y_cor), box_width, box_height,
                                   edgecolor=box['color'], facecolor='none'))
    axs[1].add_patch(plt.Rectangle((x_raw, y_raw), box_width, box_height,
                                   edgecolor=box['color'], facecolor='none'))

    difference = raw_image[box_slices(x_raw, y_raw)] - corrected[box_slices(x_cor, y_cor)]
    diff_ax.set_title(f"Difference from {chip} region")
    im = diff_ax.imshow(difference, cmap='viridis')
    plt.colorbar(im, ax=diff_ax, shrink=0.5)



In [None]:
def fill_chip_gaps(image, shift_config, chip_size=(256, 256)):
    """
    Fill gaps between shifted chips using linear interpolation.

    After shift correction, zero-filled gaps appear at chip boundaries.
    This function fills them using:
      - Linear interpolation for row and column gaps
      - Bilinear interpolation for the corner intersection

    Gap widths are derived from the chip offsets. Fractional offsets are
    rounded up, because a partially uncovered pixel is still a gap.
    Negative offsets are clamped to 0, because they move a chip away from
    the shared boundary and leave no gap there.

    Parameters
    ----------
    image : 2D numpy array
        The shift-corrected image with gaps. It is not modified. The image is
        assumed to be a 2 x 2 chip mosaic of shape (2 * height, 2 * width).
    shift_config : dict
        Chip configuration with xoffset / yoffset for 'chip1', 'chip3' and
        'chip4'.
    chip_size : tuple
        (height, width) of each chip.

    Returns
    -------
    filled : 2D numpy array
        Float copy of the image with inter-chip gaps filled by interpolation.
    """
    filled = np.array(image, dtype=float)
    h, w = chip_size

    def _gap_pixels(chip, key):
        # Whole pixels left empty at the boundary by this chip's offset.
        return max(0, int(np.ceil(shift_config[chip][key])))

    def _weights(n):
        # Fractions 1/(n+1) ... n/(n+1) for n gap pixels between two known edges.
        return np.arange(1, n + 1) / (n + 1)

    # Offsets that define the gap sizes at each boundary
    x_gap_top = _gap_pixels('chip1', 'xoffset')    # vertical gap width  (top half)
    y_gap_left = _gap_pixels('chip3', 'yoffset')   # horizontal gap height (left half)
    x_gap_bot = _gap_pixels('chip4', 'xoffset')    # vertical gap width  (bottom half)
    y_gap_right = _gap_pixels('chip4', 'yoffset')  # horizontal gap height (right half)

    max_x_gap = max(x_gap_top, x_gap_bot)
    max_y_gap = max(y_gap_left, y_gap_right)

    # 1. Vertical gap, top half (between chip 2 and chip 1):
    #    cols w .. w+x_gap_top-1, rows 0 .. h-1
    if x_gap_top > 0:
        t = _weights(x_gap_top)                   # shape (g,)
        left = filled[0:h, w - 1, None]           # shape (h, 1)
        right = filled[0:h, w + x_gap_top, None]
        filled[0:h, w:w + x_gap_top] = (1 - t) * left + t * right

    # 2. Horizontal gap, left half (between chip 2 and chip 3):
    #    rows h .. h+y_gap_left-1, cols 0 .. w-1
    if y_gap_left > 0:
        t = _weights(y_gap_left)[:, None]         # shape (g, 1)
        above = filled[h - 1, 0:w]
        below = filled[h + y_gap_left, 0:w]
        filled[h:h + y_gap_left, 0:w] = (1 - t) * above + t * below

    # 3. Horizontal gap, right half (between chip 1 and chip 4):
    #    rows h .. h+y_gap_right-1, cols w+max_x_gap .. 2w-1
    #    (the corner columns are handled separately in step 5)
    if y_gap_right > 0:
        t = _weights(y_gap_right)[:, None]
        cols = slice(w + max_x_gap, 2 * w)
        above = filled[h - 1, cols]
        below = filled[h + y_gap_right, cols]
        filled[h:h + y_gap_right, cols] = (1 - t) * above + t * below

    # 4. Vertical gap, bottom half (between chip 3 and chip 4):
    #    cols w .. w+x_gap_bot-1, rows h+max_y_gap .. 2h-1
    #    (the corner rows are handled separately in step 5)
    if x_gap_bot > 0:
        t = _weights(x_gap_bot)
        rows = slice(h + max_y_gap, 2 * h)
        left = filled[rows, w - 1, None]
        right = filled[rows, w + x_gap_bot, None]
        filled[rows, w:w + x_gap_bot] = (1 - t) * left + t * right

    # 5. Corner intersection (bilinear interpolation):
    #    rows h .. h+max_y_gap-1, cols w .. w+max_x_gap-1
    if max_x_gap > 0 and max_y_gap > 0:
        tl = filled[h - 1, w - 1]                    # chip 2 corner
        tr = filled[h - 1, w + max_x_gap]            # chip 1 side
        bl = filled[h + max_y_gap, w - 1]            # chip 3 side
        br = filled[h + max_y_gap, w + max_x_gap]    # chip 4 corner

        ty = _weights(max_y_gap)[:, None]            # shape (gy, 1)
        tx = _weights(max_x_gap)[None, :]            # shape (1, gx)
        filled[h:h + max_y_gap, w:w + max_x_gap] = (
            tl * (1 - tx) * (1 - ty) +
            tr * tx * (1 - ty) +
            bl * (1 - tx) * ty +
            br * tx * ty
        )

    return filled


# Interpolate across the zero-valued inter-chip gaps left by the shift correction.
gap_filled = fill_chip_gaps(image=corrected, shift_config=config, chip_size=chip_size)

In [None]:
# Compact summary instead of dumping the whole array repr.
{
    "shape": gap_filled.shape,
    "min": float(np.nanmin(gap_filled)),
    "max": float(np.nanmax(gap_filled)),
    "zero_pixels": int(np.count_nonzero(gap_filled == 0)),
}

In [None]:
# Gap-filled result (left) next to the shift-corrected input with its gaps (right).
fig, axs = plt.subplots(ncols=2, figsize=(15, 10))
panels = [
    (gap_filled, "Gap-Filled Image"),
    (corrected, "Shift-Corrected Image (with gaps)"),
]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)