# UV to XY interpolation

## Aim of the notebook: to demonstrate how to perform UV to XY interpolation for image distortion using shaders.

**Unity:** in a shader we typically:

- Start from UV coordinates: (u, v) ranging from 0 to 1, representing normalized texture coordinates.
- Apply a distortion to those coordinates to create a distorted grid.
- Sample the original image using the distorted coordinates.

**In this notebook:** we create the distorted output by looking up, for each output pixel, where it should come from in the input image (reverse mapping).

In [1]:
#Importing of libraries for data analysis

import matplotlib
import pandas as pd
import numpy as np
import scipy as sp
from scipy import interpolate
from scipy.ndimage import map_coordinates
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image

## 1. Core interpolation function

### Inputs

- img: original image (numpy array)
- distorted_grid_df: DataFrame with columns ['uv_x', 'uv_y', 'distorted_x', 'distorted_y']
- output_resolution: (height, width) of the output image

### Outputs

- Distorted image

In [None]:
def uv_to_xy_interpolation(img, distorted_grid_df, output_resolution = None):
    """
    Perform UV to XY interpolation for image distortion (reverse mapping).

    For every pixel of the output image, the UV coordinate of the source image
    that lands on that position is interpolated from the control points in
    distorted_grid_df, and the source image is then sampled at that coordinate.

    Parameters:
        img (numpy array): Input image, either 2-D (grayscale) or 3-D (height, width, channels).
        distorted_grid_df (DataFrame): DataFrame containing distorted grid coordinates with columns ['uv_x', 'uv_y', 'distorted_x', 'distorted_y'].
        output_resolution: (height, width) for output image. If None, the height and width of img are used.

    Returns:
        numpy array: Distorted image with shape (height, width) for a 2-D input or
        (height, width, channels) for a 3-D input. Output pixels that fall outside
        the convex hull of the control points are set to 0.

    Raises:
        ValueError: If img is neither 2-D nor 3-D.
    """

    img = np.asarray(img)
    if img.ndim not in (2, 3):
        raise ValueError("img must be a 2-D (grayscale) or 3-D (color) array")

    # Defining the output resolution -> defaults to the resolution of the source image

    if output_resolution is None:
        output_height, output_width = img.shape[:2]
    else:
        output_height, output_width = output_resolution

    # Defining the boundaries of the distorted coordinates space -> how far is the distortion going to extend for?

    min_x, max_x = distorted_grid_df['distorted_x'].min(), distorted_grid_df['distorted_x'].max()
    min_y, max_y = distorted_grid_df['distorted_y'].min(), distorted_grid_df['distorted_y'].max()

    # Create the output grid in the distorted space

    x_range = np.linspace(min_x, max_x, output_width)
    y_range = np.linspace(min_y, max_y, output_height)
    xs, ys = np.meshgrid(x_range, y_range)
    # Flip the y-coordinates so that output row 0 corresponds to max_y.
    # NOTE(review): uv_y = 0 is mapped to source row 0 below, so an identity grid
    # produces a vertically flipped image -- confirm this orientation is intended.
    ys = np.flipud(ys)

    # Interpolate to find corresponding UV coordinates for each point in the distorted grid
    # -> Reverse mapping: for each distorted position, we find the original UV.
    # NaN marks positions outside the convex hull of the control points; filling with 0
    # would silently make them sample the source pixel at UV (0, 0).

    control_points = distorted_grid_df[['distorted_x', 'distorted_y']].values
    uv_x_interpolated = sp.interpolate.griddata(control_points, distorted_grid_df['uv_x'].values, (xs, ys), method='cubic', fill_value=np.nan)
    uv_y_interpolated = sp.interpolate.griddata(control_points, distorted_grid_df['uv_y'].values, (xs, ys), method='cubic', fill_value=np.nan)
    outside_hull = np.isnan(uv_x_interpolated) | np.isnan(uv_y_interpolated)

    # Convert the UV coordinates to pixel coordinates in the original image
    # (NaNs are replaced by a valid placeholder and masked out after sampling)

    pixel_x = np.where(outside_hull, 0.0, uv_x_interpolated) * (img.shape[1] - 1)
    pixel_y = np.where(outside_hull, 0.0, uv_y_interpolated) * (img.shape[0] - 1)
    sample_coordinates = [pixel_y, pixel_x]  # map_coordinates expects (row, column) order

    def sample_plane(plane):
        # Sample one 2-D plane; the output takes the (height, width) shape of the coordinate arrays
        sampled = map_coordinates(plane, sample_coordinates, mode='constant', cval = 0.0)
        sampled[outside_hull] = 0
        return sampled

    # Sample the original image at these coordinates, handling grayscale and color images

    if img.ndim == 2: # Grayscale image: only height and width
        return sample_plane(img)

    # Color image: every channel is sampled independently and stacked back together
    return np.stack([sample_plane(img[:, :, c]) for c in range(img.shape[2])], axis=-1)

## 2. Aid function: creation of regular grid from points -> Used for the core interpolation function

### Creates a regular (undistorted) grid df from defined points

### Inputs

- grid_resolution: Resolution of the grid, originally defined for the dimensions of an Amsler Grid (8x8)

### Outputs

- grid_df: DataFrame for the regular grid that can subsequently be plotted

In [None]:
def grid_nodistortion(grid_resolution=(8,8)):
    """
    Build a regular (undistorted) grid of UV points.

    Parameters:
        grid_resolution: (rows, cols) number of grid points along each axis.

    Returns:
        DataFrame with columns ['uv_x', 'uv_y', 'x', 'y'], one row per grid point,
        ordered row by row. The 'x' and 'y' columns repeat the UV values because
        no distortion is applied.
    """
    n_rows, n_cols = grid_resolution

    # Row-major ordering: the horizontal axis cycles fastest, the vertical axis slowest
    u_values = np.tile(np.linspace(0, 1, n_cols), n_rows)
    v_values = np.repeat(np.linspace(0, 1, n_rows), n_cols)

    return pd.DataFrame({'uv_x': u_values, 'uv_y': v_values, 'x': u_values, 'y': v_values})

## 3. Aid function: creation of distorted grid from points -> Used for the core interpolation function

### Creates distorted grid df from control points -> User defines how specific UV positions map to distorted XY positions

### Inputs

- grid_resolution: (rows, cols)
- distortion_function: function that takes (uv_x, uv_y) and returns (distorted_x, distorted_y)

### Outputs

- df with columns ['uv_x', 'uv_y', 'distorted_x', 'distorted_y']

In [None]:
def create_distorted_grid(grid_resolution=(10, 10), distortion_function = None):
    """
    Build a grid of UV control points together with their distorted positions.

    Parameters:
        grid_resolution: (rows, cols) number of control points along each axis.
        distortion_function: callable taking (uv_x, uv_y) and returning
            (distorted_x, distorted_y). If None, the identity mapping is used.

    Returns:
        DataFrame with columns ['uv_x', 'uv_y', 'distorted_x', 'distorted_y'],
        one row per control point, ordered row by row.
    """
    rows, cols = grid_resolution

    if distortion_function is None:
        # No distortion requested -> every point maps onto itself
        def distortion_function(u, v):
            return u, v

    # Create UV grid

    uv_x = np.linspace(0, 1, cols)
    uv_y = np.linspace(0, 1, rows)
    uv_grid_x, uv_grid_y = np.meshgrid(uv_x, uv_y)

    # Flatten for df

    uv_x_flat = uv_grid_x.flatten()
    uv_y_flat = uv_grid_y.flatten()

    # Apply distortion point by point; the reshape keeps the two-column layout
    # even when the grid is empty

    distorted_points = np.array(
        [distortion_function(u, v) for u, v in zip(uv_x_flat, uv_y_flat)],
        dtype=float,
    ).reshape(-1, 2)

    # Create dataframe for storage of values

    grid_df = pd.DataFrame({'uv_x': uv_x_flat, 'uv_y': uv_y_flat, 'distorted_x': distorted_points[:, 0], 'distorted_y': distorted_points[:, 1]})

    return grid_df

## 4. Exemplar distortion function: barrel distortion

### Applies a barrel distortion to specific UV coordinates


In [None]:
def barrel_distortion(u, v, k=0.3, center=(0.5, 0.5)):
    """
    Apply a radial distortion to a UV coordinate.

    The offset of (u, v) from the distortion center is scaled by
    (1 + k * r**2), where r is the distance of the point from the center.

    Parameters:
        u, v: UV coordinates of the point.
        k: Distortion coefficient (0 leaves the point unchanged).
        center: (cx, cy) center of the distortion in UV space.

    Returns:
        (distorted_x, distorted_y) tuple. The center maps onto itself.
    """
    cx, cy = center
    dx, dy = u - cx, v - cy
    # r_distorted / r simplifies to (1 + k * r**2), which needs neither a square
    # root nor a special case for r == 0
    scale = 1 + k * (dx**2 + dy**2)
    return cx + dx * scale, cy + dy * scale

## 5. Grids plotting

### Plots the regular grid and the distorted grid overlaid on the same axes for comparison purposes

In [None]:
def plot_grids(original_grid_df, distorted_grid_df):
    """
    Plot the regular grid and the distorted grid on the same axes.

    Parameters:
        original_grid_df (DataFrame): Regular grid; columns 'uv_x' and 'uv_y' are plotted in blue.
        distorted_grid_df (DataFrame): Distorted grid; columns 'distorted_x' and 'distorted_y' are plotted in red.
    """
    plt.figure(figsize = (8,8))
    # Both coordinates must come from the same frame that was passed in
    plt.scatter(original_grid_df['uv_x'], original_grid_df['uv_y'], color='blue', label='Original Grid')
    plt.scatter(distorted_grid_df['distorted_x'], distorted_grid_df['distorted_y'], color='red', label='Distorted Grid')
    plt.title('Regular grid vs. Distorted grid')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.axis('equal') # Same scale on both axes so the distortion is not visually skewed
    plt.grid(alpha = 0.3)
    plt.show()

## 6. Images plotting

### Performs a comparative plotting of the original image and the distorted image side by side

In [None]:
def plot_images(original_img, distorted_img):
    """
    Show the original image and the distorted image side by side.

    Parameters:
        original_img: Image to display in the left panel.
        distorted_img: Image to display in the right panel.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((original_img, 'Original Image'), (distorted_img, 'Distorted Image'))
    for ax, (image, title) in zip(axes, panels):
        # cmap only affects 2-D (grayscale) data; matplotlib ignores it for RGB(A) images
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

## 7. Main workflow

### Runs the full procedure: grid generation, deformation and plotting

In [None]:
if __name__ == "__main__":
    