In [None]:
# Cell 0: Installs
# !pip update -qq
# !pip install -y libvips libvips-dev
!pip install -q pyvips
!pip install -q cupy-cuda12x
!pip install -q "zarr>=2.16,<3.0"
!pip install -q numcodecs
!pip install -q dask[complete]
!pip install -q tqdm
!pip install -q numba
!pip install -q matplotlib
!pip install -q psutil
!pip install -q imageio

ERROR: unknown command "update"


Usage:   
  pip install [options] <requirement specifier> [package-index-options] ...
  pip install [options] -r <requirements file> [package-index-options] ...
  pip install [options] [-e] <vcs project url> ...
  pip install [options] [-e] <local project path> ...
  pip install [options] <archive url/path> ...

no such option: -y

[notice] A new release of pip is available: 24.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


^C


In [None]:
# Cell 1: Imports
import gc
import json
import os
import queue
import threading
import time
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import cupy as cp
import pyvips
import zarr
from numcodecs import Blosc
from tqdm import tqdm
import psutil
from numba import cuda, jit, prange
from numba.core import config
import imageio

# Silence every warning category for the whole session. This is broad: it also
# hides deprecation notices from CuPy/Numba/zarr. Consider narrowing it to the
# specific categories that are actually noisy.
warnings.simplefilter('ignore')
# Numba global fast-math switch turned off (presumably its default); the CPU
# kernels additionally pass fastmath=False to @jit explicitly.
config.FASTMATH = False

In [None]:
# Cell 1.5: Mount Google Drive
from google.colab import drive

# Colab-only: `google.colab` is unavailable outside Colab, so this cell will
# raise ImportError elsewhere. Presumably the Drive mount is used to persist
# checkpoints (see the message below) — confirm against the checkpoint cell.
drive.mount(mountpoint='/content/drive')
print("Google Drive mounted. Checkpoint directory configured in Cell 12.")

In [None]:
# Cell 2: Kernel Management and Colormap - Complex Function
# Complex CUDA kernel with double precision and Blue-Black-Red colormap
# Complex implicit function, GPU side (CUDA, double precision).
#
# For buffer pixel (idx, idy) the kernel evaluates, with
#   x = center_x + idx * scale,  y = center_y + idy * scale
# (the renderer passes the tile origin in center_x / center_y):
#   value = sin(1000 * x * y) + cos(r2) ** sin(100 * r) - 0.5,
#   r2 = x*x + y*y,  r = sqrt(r2 + 1e-10).
# The power term is real only for a positive base, or for a negative base with
# an integer exponent; any other case contributes 0. The value is then:
#   - faded toward 0 within `tile_overlap` pixels of the buffer edge,
#   - clamped to [-3, 3] and mapped to [0, 1],
#   - coloured Blue-Black-Red (0 -> blue, 0.5 -> black, 1 -> red).
#
# compute_implicit_function_cpu_double below is the Numba twin of this kernel.
# Keep the two in step, because the renderer cross-checks them.
implicit_function_kernel_double = cp.RawKernel(r'''
// Device function for Blue-Black-Red colormap
__device__ void get_target_style_colormap_rgb(double normalized_value,
                                             unsigned char* r_channel,
                                             unsigned char* g_channel,
                                             unsigned char* b_channel) {
    double r_float, g_float, b_float;

    // Define target colors (0-255 range, as doubles for interpolation)
    const double blue_r = 20.0,  blue_g = 60.0,  blue_b = 170.0;
    const double black_r = 0.0,  black_g = 0.0,  black_b = 0.0;
    const double red_r = 170.0,  red_g = 25.0,   red_b = 50.0;

    if (normalized_value < 0.5) {
        // Interpolate from Blue (at 0.0) to Black (at 0.5)
        double t = normalized_value * 2.0;

        r_float = blue_r * (1.0 - t) + black_r * t;
        g_float = blue_g * (1.0 - t) + black_g * t;
        b_float = blue_b * (1.0 - t) + black_b * t;
    } else {
        // Interpolate from Black (at 0.5) to Red (at 1.0)
        double t = (normalized_value - 0.5) * 2.0;

        r_float = black_r * (1.0 - t) + red_r * t;
        g_float = black_g * (1.0 - t) + red_g * t;
        b_float = black_b * (1.0 - t) + red_b * t;
    }

    // Clamp and convert to unsigned char with rounding
    *r_channel = (unsigned char)(fmax(0.0, fmin(255.0, r_float + 0.5)));
    *g_channel = (unsigned char)(fmax(0.0, fmin(255.0, g_float + 0.5)));
    *b_channel = (unsigned char)(fmax(0.0, fmin(255.0, b_float + 0.5)));
}

extern "C" __global__
void implicit_function_rgb_double(
    unsigned char* out_r,
    unsigned char* out_g,
    unsigned char* out_b,
    const double center_x,
    const double center_y,
    const double scale,
    const int width,
    const int height,
    const int tile_overlap
) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int idy = blockDim.y * blockIdx.y + threadIdx.y;

    if (idx >= width || idy >= height) return;

    double pixel_x = (double)idx;
    double pixel_y = (double)idy;

    double x_temp = pixel_x;
    double y_temp = pixel_y;

    double x = center_x + x_temp * scale;
    double y = center_y + y_temp * scale;

    double xy_product = x * y;
    double xy_scaled = xy_product * 1000.0;

    double term1 = sin(xy_scaled);

    double r_squared = x * x + y * y;
    double r = sqrt(r_squared + 1e-10);

    double cos_r_squared = cos(r_squared);
    double sin_100r = sin(100.0 * r);
    double term2 = 0.0;

    if (cos_r_squared > 0) {
        // Positive base: any real exponent is well defined.
        term2 = pow(cos_r_squared, sin_100r);
    } else if (cos_r_squared < 0 && fabs(sin_100r - round(sin_100r)) < 1e-6) {
        // Negative base is real only for an integer exponent n:
        // (-a)^n = (-1)^n * a^n. The exponent keeps its sign, so n == -1
        // gives 1/base (previously fabs() on the exponent gave base instead).
        int exp_int = (int)round(sin_100r);
        term2 = pow(fabs(cos_r_squared), (double)exp_int);
        if (exp_int % 2 != 0) term2 = -term2;
    }
    // Otherwise (zero base, or negative base with a non-integer exponent)
    // the power term stays 0.

    double value = term1 + term2 - 0.5;
    // NOTE(review): this spot was previously marked "BROKEN - NEEDS FIX"
    // without detail. The negative-base exponent sign above has been fixed.
    // If the marker referred to something else -- e.g. the edge fade below,
    // which also darkens pixels along the image border because renderers do
    // not crop there -- TODO re-check.

    // Fade toward 0 (black after colormapping) within tile_overlap pixels of
    // the buffer edge.
    double blend_factor = 1.0;
    if (tile_overlap > 0) {
        int dist_to_edge = min(min(idx, width - 1 - idx),
                               min(idy, height - 1 - idy));
        if (dist_to_edge < tile_overlap) {
            blend_factor = (double)dist_to_edge / (double)tile_overlap;
        }
    }

    value *= blend_factor;
    value = fmax(-3.0, fmin(3.0, value));
    double normalized = (value + 3.0) / 6.0;
    normalized = fmax(0.0, fmin(1.0, normalized));

    int linear_idx = idy * width + idx;
    get_target_style_colormap_rgb(normalized, &out_r[linear_idx], &out_g[linear_idx], &out_b[linear_idx]);
}
''', 'implicit_function_rgb_double', options=('--fmad=false',))

# Define target colors as global constants for Numba
# (RGB values in 0-255 range). These must match the constants hard-coded in
# the CUDA colormap device functions.
_TARGET_BLUE_RGB = (20, 60, 170)
_TARGET_BLACK_RGB = (0, 0, 0)
_TARGET_RED_RGB = (170, 25, 50)

# Numba JIT-compiled colormap function for CPU kernels - Blue-Black-Red style
@jit(nopython=True, cache=True)
def target_style_colormap_numba(normalized_value):
    """
    Applies a custom Blue-Black-Red colormap.
    - normalized_value = 0.0 maps to Blue
    - normalized_value = 0.5 maps to Black
    - normalized_value = 1.0 maps to Red

    Returns a tuple of three np.uint8 values (r, g, b), computed exactly like
    the CUDA device function get_target_style_colormap_rgb.
    """
    r_float, g_float, b_float = 0.0, 0.0, 0.0

    if normalized_value < 0.5:
        # Interpolate from Blue (at 0.0) to Black (at 0.5)
        t = normalized_value * 2.0

        r_float = float(_TARGET_BLUE_RGB[0]) * (1.0 - t) + float(_TARGET_BLACK_RGB[0]) * t
        g_float = float(_TARGET_BLUE_RGB[1]) * (1.0 - t) + float(_TARGET_BLACK_RGB[1]) * t
        b_float = float(_TARGET_BLUE_RGB[2]) * (1.0 - t) + float(_TARGET_BLACK_RGB[2]) * t
    else:
        # Interpolate from Black (at 0.5) to Red (at 1.0)
        t = (normalized_value - 0.5) * 2.0

        r_float = float(_TARGET_BLACK_RGB[0]) * (1.0 - t) + float(_TARGET_RED_RGB[0]) * t
        g_float = float(_TARGET_BLACK_RGB[1]) * (1.0 - t) + float(_TARGET_RED_RGB[1]) * t
        b_float = float(_TARGET_BLACK_RGB[2]) * (1.0 - t) + float(_TARGET_RED_RGB[2]) * t

    # Clamp and convert to uint8 with rounding
    r_out = np.uint8(max(0.0, min(255.0, r_float + 0.5)))
    g_out = np.uint8(max(0.0, min(255.0, g_float + 0.5)))
    b_out = np.uint8(max(0.0, min(255.0, b_float + 0.5)))

    return r_out, g_out, b_out

# Numba CPU kernel for complex function using Blue-Black-Red colormap
@jit(nopython=True, parallel=True, fastmath=False, cache=True)
def compute_implicit_function_cpu_double(center_x, center_y, scale, width, height, tile_overlap):
    """CPU (Numba) twin of the ``implicit_function_rgb_double`` CUDA kernel.

    Parameters
    ----------
    center_x, center_y : float
        Plane coordinates of buffer pixel (0, 0). The renderer passes the tile
        origin here, not the image centre.
    scale : float
        Plane units per pixel.
    width, height : int
        Buffer size in pixels.
    tile_overlap : int
        Width of the edge band faded toward 0; 0 disables the fade.

    Returns
    -------
    numpy.ndarray
        ``(height, width, 3)`` uint8 RGB image.
    """
    result = np.zeros((height, width, 3), dtype=np.uint8)

    # Rows are independent; prange is what lets parallel=True actually split
    # them across threads (a plain range loop is not parallelised).
    for idy in prange(height):
        for idx in range(width):
            x = center_x + float(idx) * scale
            y = center_y + float(idy) * scale

            term1 = np.sin(x * y * 1000.0)

            r_squared = x * x + y * y
            r = np.sqrt(r_squared + 1e-10)

            cos_r_squared = np.cos(r_squared)
            sin_100r = np.sin(100.0 * r)
            term2 = 0.0

            if cos_r_squared > 0:
                # Positive base: any real exponent is well defined.
                term2 = cos_r_squared ** sin_100r
            elif cos_r_squared < 0 and abs(sin_100r - round(sin_100r)) < 1e-6:
                # Negative base with integer exponent n: (-a)^n = (-1)^n * a^n.
                # Keep the exponent's sign so n == -1 gives 1/base.
                exp_int = int(round(sin_100r))
                term2 = abs(cos_r_squared) ** float(exp_int)
                if exp_int % 2 != 0:
                    term2 = -term2

            value = term1 + term2 - 0.5

            # Fade toward 0 near the buffer edge (same rule as the GPU kernel).
            blend_factor = 1.0
            if tile_overlap > 0:
                dist_to_edge = min(min(idx, width - 1 - idx),
                                   min(idy, height - 1 - idy))
                if dist_to_edge < tile_overlap:
                    blend_factor = float(dist_to_edge) / float(tile_overlap)
            value *= blend_factor

            value = max(-3.0, min(3.0, value))
            normalized = (value + 3.0) / 6.0
            normalized = max(0.0, min(1.0, normalized))

            r_out, g_out, b_out = target_style_colormap_numba(normalized)
            result[idy, idx, 0] = r_out
            result[idy, idx, 1] = g_out
            result[idy, idx, 2] = b_out
    return result

In [None]:
# Cell 2b: Kernel Management - Simple Function
# Simple CUDA kernel for sin(x*y) with Blue-Black-Red colormap
# Simple implicit function, GPU side (CUDA, double precision).
# For buffer pixel (idx, idy) it evaluates sin(x * y) with
#   x = center_x + idx * scale,  y = center_y + idy * scale
# (the renderer passes the tile origin in center_x / center_y). The value is:
#   - faded toward 0 within `tile_overlap` pixels of the buffer edge,
#   - mapped from [-1, 1] to [0, 1],
#   - coloured Blue-Black-Red (0 -> blue, 0.5 -> black, 1 -> red).
# Outputs are three planar uint8 buffers indexed idy * width + idx.
# '--fmad=false' disables fused multiply-add contraction, presumably so results
# match the Numba twin compute_implicit_function_cpu_simple_double.
implicit_function_simple_double = cp.RawKernel(r'''
// Device function for Blue-Black-Red colormap
__device__ void get_target_style_colormap_rgb(double normalized_value,
                                             unsigned char* r_channel,
                                             unsigned char* g_channel,
                                             unsigned char* b_channel) {
    double r_float, g_float, b_float;

    // Define target colors (0-255 range, as doubles for interpolation)
    const double blue_r = 20.0,  blue_g = 60.0,  blue_b = 170.0;
    const double black_r = 0.0,  black_g = 0.0,  black_b = 0.0;
    const double red_r = 170.0,  red_g = 25.0,   red_b = 50.0;

    if (normalized_value < 0.5) {
        // Interpolate from Blue (at 0.0) to Black (at 0.5)
        double t = normalized_value * 2.0;

        r_float = blue_r * (1.0 - t) + black_r * t;
        g_float = blue_g * (1.0 - t) + black_g * t;
        b_float = blue_b * (1.0 - t) + black_b * t;
    } else {
        // Interpolate from Black (at 0.5) to Red (at 1.0)
        double t = (normalized_value - 0.5) * 2.0;

        r_float = black_r * (1.0 - t) + red_r * t;
        g_float = black_g * (1.0 - t) + red_g * t;
        b_float = black_b * (1.0 - t) + red_b * t;
    }

    // Clamp and convert to unsigned char with rounding
    *r_channel = (unsigned char)(fmax(0.0, fmin(255.0, r_float + 0.5)));
    *g_channel = (unsigned char)(fmax(0.0, fmin(255.0, g_float + 0.5)));
    *b_channel = (unsigned char)(fmax(0.0, fmin(255.0, b_float + 0.5)));
}

extern "C" __global__
void implicit_function_simple_double(
    unsigned char* out_r,
    unsigned char* out_g,
    unsigned char* out_b,
    const double center_x,
    const double center_y,
    const double scale,
    const int width,
    const int height,
    const int tile_overlap
) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int idy = blockDim.y * blockIdx.y + threadIdx.y;

    if (idx >= width || idy >= height) return;

    double pixel_x = (double)idx;
    double pixel_y = (double)idy;

    // Direct calculation - tile origin already contains proper offset
    double x = center_x + pixel_x * scale;
    double y = center_y + pixel_y * scale;

    double value = sin(x * y);

    double blend_factor = 1.0;
    if (tile_overlap > 0) {
        int dist_to_edge = min(min(idx, width - 1 - idx),
                               min(idy, height - 1 - idy));
        if (dist_to_edge < tile_overlap) {
            blend_factor = (double)dist_to_edge / (double)tile_overlap;
        }
    }
    value *= blend_factor;

    double normalized = (value + 1.0) / 2.0;
    normalized = fmax(0.0, fmin(1.0, normalized));

    int linear_idx = idy * width + idx;
    get_target_style_colormap_rgb(normalized, &out_r[linear_idx], &out_g[linear_idx], &out_b[linear_idx]);
}
''', 'implicit_function_simple_double', options=('--fmad=false',))

# Numba CPU kernel for simple function using Blue-Black-Red colormap
@jit(nopython=True, parallel=True, fastmath=False, cache=True)
def compute_implicit_function_cpu_simple_double(center_x, center_y, scale, width, height, tile_overlap):
    """CPU (Numba) twin of the ``implicit_function_simple_double`` CUDA kernel.

    Evaluates ``sin(x * y)`` with ``x = center_x + idx * scale`` and
    ``y = center_y + idy * scale`` for every pixel of a ``height`` x ``width``
    buffer. The value is then:

    - faded toward 0 within ``tile_overlap`` pixels of the buffer edge,
    - mapped from [-1, 1] to [0, 1],
    - coloured with the Blue-Black-Red colormap.

    Parameters
    ----------
    center_x, center_y : float
        Plane coordinates of buffer pixel (0, 0). The renderer passes the tile
        origin here.
    scale : float
        Plane units per pixel.
    width, height : int
        Buffer size in pixels.
    tile_overlap : int
        Width of the faded edge band; 0 disables the fade.

    Returns
    -------
    numpy.ndarray
        ``(height, width, 3)`` uint8 RGB image.
    """
    result = np.zeros((height, width, 3), dtype=np.uint8)

    # Rows are independent; prange is what lets parallel=True actually spread
    # them across threads (a plain range loop stays serial).
    for idy in prange(height):
        for idx in range(width):
            # Direct calculation - tile origin already contains proper offset
            x = center_x + float(idx) * scale
            y = center_y + float(idy) * scale

            value = np.sin(x * y)

            # Fade toward 0 (black) near the buffer edge; same rule as the GPU kernel.
            blend_factor = 1.0
            if tile_overlap > 0:
                dist_to_edge = min(min(idx, width - 1 - idx),
                                   min(idy, height - 1 - idy))
                if dist_to_edge < tile_overlap:
                    blend_factor = float(dist_to_edge) / float(tile_overlap)
            value *= blend_factor

            normalized = (value + 1.0) / 2.0
            normalized = max(0.0, min(1.0, normalized))

            r_out, g_out, b_out = target_style_colormap_numba(normalized)
            result[idy, idx, 0] = r_out
            result[idy, idx, 1] = g_out
            result[idy, idx, 2] = b_out

    return result

In [None]:
# Cell 3: GPU-CPU Hybrid Renderer
class HybridGPUCPURenderer:
    """Render image tiles of the implicit function on the GPU (CuPy) or the CPU (Numba).

    Pixel ``p`` of the full ``width`` x ``height`` image maps to the plane
    coordinate ``center + (p - size * 0.5) * scale`` on each axis.

    Each tile is rendered into a buffer extended by ``tile_overlap`` pixels on
    every side (clipped to the image), then cropped back to exactly the
    requested region. With ``validate=True`` a coordinate test pattern is
    rendered instead, and the whole uncropped buffer is returned.
    """

    # Coordinate pattern kernel for validation:
    # R ramps 0..255 along x, G ramps 0..255 along y, B is constant 128.
    _COORDINATE_PATTERN_KERNEL = r'''
    extern "C" __global__
    void coordinate_pattern(
        unsigned char* out_r, unsigned char* out_g, unsigned char* out_b,
        const int width, const int height
    ) {
        int idx = blockDim.x * blockIdx.x + threadIdx.x;
        int idy = blockDim.y * blockIdx.y + threadIdx.y;
        if (idx >= width || idy >= height) return;
        int linear_idx = idy * width + idx;
        out_r[linear_idx] = (unsigned char)(((double)idx * 255.0 / (double)(width > 1 ? width - 1 : 1)) + 0.5);
        out_g[linear_idx] = (unsigned char)(((double)idy * 255.0 / (double)(height > 1 ? height - 1 : 1)) + 0.5);
        out_b[linear_idx] = 128;
    }'''

    def __init__(self, width, height, tile_size, center_x, center_y, scale, tile_overlap,
                 cuda_threads_x=32, cuda_threads_y=32, kernel_type='complex'):
        """Select the kernels and initialise the GPU if one is available.

        Args:
            width, height: full image size in pixels.
            tile_size: default tile edge used by the consistency check.
            center_x, center_y: plane coordinates of the image centre.
            scale: plane units per pixel.
            tile_overlap: extra pixels rendered on each tile side.
            cuda_threads_x, cuda_threads_y: CUDA block dimensions.
            kernel_type: 'complex' or 'simple' (case-insensitive).

        Raises:
            ValueError: for an unknown ``kernel_type``.

        Side effect: on success, installs a process-wide CuPy allocator backed
        by a managed-memory pool (``self.pool``).
        """
        self.width = width
        self.height = height
        self.tile_size = tile_size
        self.tile_overlap = tile_overlap

        self.center_x = np.float64(center_x)
        self.center_y = np.float64(center_y)
        self.scale = np.float64(scale)

        self.cuda_threads_x = cuda_threads_x
        self.cuda_threads_y = cuda_threads_y

        self.kernel_type = kernel_type.lower()

        # Select kernels based on type
        if self.kernel_type == 'complex':
            self.gpu_kernel = implicit_function_kernel_double
            self.cpu_kernel = compute_implicit_function_cpu_double
            print("Using COMPLEX implicit function kernel.")
        elif self.kernel_type == 'simple':
            self.gpu_kernel = implicit_function_simple_double
            self.cpu_kernel = compute_implicit_function_cpu_simple_double
            print("Using SIMPLE implicit function kernel (sin(xy)).")
        else:
            raise ValueError(f"Unknown kernel_type: {kernel_type}. Choose 'complex' or 'simple'.")

        self._coordinate_pattern_kernel = None

        try:
            self.gpu_available = cp.cuda.runtime.getDeviceCount() > 0
        except Exception as e:
            print(f"CUDA availability check failed: {e}")
            self.gpu_available = False

        if not self.gpu_available:
            print("CUDA not available, using CPU only.")
            self.gpu_kernel = None
        else:
            try:
                # Create the coordinate pattern kernel once (CuPy compiles it on first launch).
                self._coordinate_pattern_kernel = cp.RawKernel(
                    self._COORDINATE_PATTERN_KERNEL, 'coordinate_pattern')

                self.pool = cp.cuda.MemoryPool(cp.cuda.malloc_managed)
                cp.cuda.set_allocator(self.pool.malloc)
                device = cp.cuda.Device()
                self.device_id = device.id
                memory_info = device.mem_info
                print(f"GPU initialized: ID {self.device_id}, "
                      f"Total Memory {memory_info[1]/1024**3:.1f}GB, "
                      f"Free Memory {memory_info[0]/1024**3:.1f}GB")
            except Exception as e:
                print(f"Error initializing GPU: {e}")
                self.gpu_available = False
                self.gpu_kernel = None

    # ------------------------------------------------------------------ helpers

    def _padded_region(self, x_start, y_start, tile_width, tile_height):
        """Return ``(pad_x, pad_y, pad_w, pad_h)`` of the overlap-extended buffer, clipped to the image."""
        pad_x = max(0, x_start - self.tile_overlap)
        pad_y = max(0, y_start - self.tile_overlap)
        pad_w = min(tile_width + 2 * self.tile_overlap, self.width - pad_x)
        pad_h = min(tile_height + 2 * self.tile_overlap, self.height - pad_y)
        return pad_x, pad_y, pad_w, pad_h

    def _tile_origin(self, pad_x, pad_y):
        """Plane coordinates of the buffer's top-left pixel."""
        origin_x = self.center_x + (pad_x - self.width * 0.5) * self.scale
        origin_y = self.center_y + (pad_y - self.height * 0.5) * self.scale
        return origin_x, origin_y

    def _crop_to_request(self, buffer, x_start, y_start, tile_width, tile_height, pad_x, pad_y):
        """Cut exactly the requested (image-clipped) region out of a padded buffer.

        The offsets come from the actual padding (``x_start - pad_x``), which is
        smaller than ``tile_overlap`` near the image's left/top edge.
        """
        left = x_start - pad_x
        top = y_start - pad_y
        keep_w = min(x_start + tile_width, self.width) - x_start
        keep_h = min(y_start + tile_height, self.height) - y_start
        if keep_w <= 0 or keep_h <= 0:
            return np.zeros((0, 0, 3), dtype=np.uint8)
        return np.ascontiguousarray(buffer[top:top + keep_h, left:left + keep_w])

    @staticmethod
    def _coordinate_pattern_cpu(width, height):
        """NumPy equivalent of the ``coordinate_pattern`` CUDA kernel."""
        max_w = width - 1 if width > 1 else 1
        max_h = height - 1 if height > 1 else 1
        # Same float64 operation order as the kernel; astype truncates like the C cast.
        ramp_x = (np.arange(width, dtype=np.float64) * 255.0 / float(max_w) + 0.5).astype(np.uint8)
        ramp_y = (np.arange(height, dtype=np.float64) * 255.0 / float(max_h) + 0.5).astype(np.uint8)
        pattern = np.empty((height, width, 3), dtype=np.uint8)
        pattern[:, :, 0] = ramp_x[np.newaxis, :]
        pattern[:, :, 1] = ramp_y[:, np.newaxis]
        pattern[:, :, 2] = 128
        return pattern

    def _release_gpu_memory(self):
        """Free cached blocks of the pool this renderer actually allocates from."""
        # Allocations go through self.pool (installed via set_allocator), so
        # freeing the default pool alone would release nothing.
        pool = getattr(self, 'pool', None)
        if pool is None:
            pool = cp.get_default_memory_pool()
        pool.free_all_blocks()

    # ---------------------------------------------------------------- rendering

    def compute_tile_gpu(self, x_start, y_start, tile_width, tile_height, validate=False):
        """Render one tile on the GPU and return it as an ``(h, w, 3)`` uint8 array.

        Falls back to :meth:`compute_tile_cpu` when no GPU is available.

        Args:
            x_start, y_start: top-left image pixel of the requested tile.
            tile_width, tile_height: requested size; anything beyond the
                image's right/bottom edge is dropped.
            validate: render the coordinate pattern over the whole padded
                buffer (not cropped) instead of the function.

        Returns:
            The tile, or an empty ``(0, 0, 3)`` array if none of it is inside the image.
        """
        if not self.gpu_available:
            return self.compute_tile_cpu(x_start, y_start, tile_width, tile_height, validate)

        pad_x, pad_y, pad_w, pad_h = self._padded_region(x_start, y_start, tile_width, tile_height)
        if pad_w <= 0 or pad_h <= 0:
            return np.zeros((0, 0, 3), dtype=np.uint8)

        size = pad_w * pad_h
        out_r = cp.zeros(size, dtype=cp.uint8)
        out_g = cp.zeros(size, dtype=cp.uint8)
        out_b = cp.zeros(size, dtype=cp.uint8)

        threads_per_block = (self.cuda_threads_x, self.cuda_threads_y)
        blocks_per_grid = ((pad_w + threads_per_block[0] - 1) // threads_per_block[0],
                           (pad_h + threads_per_block[1] - 1) // threads_per_block[1])

        # CuPy passes Python ints as `long long`; the kernels declare `int`,
        # so cast explicitly.
        width_arg = np.int32(pad_w)
        height_arg = np.int32(pad_h)

        if validate:
            if self._coordinate_pattern_kernel is not None:
                self._coordinate_pattern_kernel(
                    blocks_per_grid, threads_per_block,
                    (out_r, out_g, out_b, width_arg, height_arg))
        elif self.gpu_kernel is not None:
            origin_x, origin_y = self._tile_origin(pad_x, pad_y)
            self.gpu_kernel(
                blocks_per_grid, threads_per_block,
                (out_r, out_g, out_b, np.float64(origin_x), np.float64(origin_y),
                 np.float64(self.scale), width_arg, height_arg, np.int32(self.tile_overlap)))

        cp.cuda.Stream.null.synchronize()
        plane_shape = (pad_h, pad_w)
        rgb = np.stack([cp.asnumpy(plane.reshape(plane_shape)) for plane in (out_r, out_g, out_b)],
                       axis=2)
        del out_r, out_g, out_b
        self._release_gpu_memory()

        if validate:
            return rgb
        return self._crop_to_request(rgb, x_start, y_start, tile_width, tile_height, pad_x, pad_y)

    def compute_tile_cpu(self, x_start, y_start, tile_width, tile_height, validate=False):
        """Render one tile on the CPU; same contract as :meth:`compute_tile_gpu`."""
        pad_x, pad_y, pad_w, pad_h = self._padded_region(x_start, y_start, tile_width, tile_height)
        if pad_w <= 0 or pad_h <= 0:
            return np.zeros((0, 0, 3), dtype=np.uint8)

        if validate:
            return self._coordinate_pattern_cpu(pad_w, pad_h)

        if self.cpu_kernel is None:
            buffer = np.zeros((pad_h, pad_w, 3), dtype=np.uint8)
        else:
            origin_x, origin_y = self._tile_origin(pad_x, pad_y)
            buffer = self.cpu_kernel(origin_x, origin_y, self.scale, pad_w, pad_h, self.tile_overlap)

        return self._crop_to_request(buffer, x_start, y_start, tile_width, tile_height, pad_x, pad_y)

    def validate_gpu_cpu_consistency(self, num_samples, epsilon):
        """Render sample tiles on both back ends and compare them.

        The sample positions are:

        - the top-left window,
        - the centre window,
        - the bottom-right window, when there is room,
        - random windows, up to ``num_samples`` positions in total.

        Args:
            num_samples: target number of tile positions to test.
            epsilon: largest allowed per-channel difference, as a fraction of 255.

        Returns:
            True if every tested tile matches within ``epsilon``, or if no GPU
            is available. False on any shape mismatch or larger difference.
        """
        if not self.gpu_available:
            print("GPU not available, skipping consistency validation.")
            return True

        if self.gpu_kernel is None or self.cpu_kernel is None:
            print(f"Cannot validate kernel type '{self.kernel_type}' - kernels not available.")
            return False

        print(f"Validating GPU-CPU consistency for '{self.kernel_type}' kernel...")
        # Local generator seeded like np.random.seed(42), without resetting
        # NumPy's global random state for the rest of the notebook.
        rng = np.random.RandomState(42)

        sample_size_w = min(self.tile_size, self.width)
        sample_size_h = min(self.tile_size, self.height)

        test_positions = [
            (0, 0),  # Top-left
            ((self.width - sample_size_w) // 2, (self.height - sample_size_h) // 2),  # Center
        ]
        if self.width > sample_size_w and self.height > sample_size_h:
            test_positions.append((self.width - sample_size_w,
                                   self.height - sample_size_h))  # Bottom-right
            # Add random positions
            for _ in range(max(0, num_samples - len(test_positions))):
                x = rng.randint(0, self.width - sample_size_w + 1)
                y = rng.randint(0, self.height - sample_size_h + 1)
                test_positions.append((x, y))

        # Drop duplicates while keeping a stable order.
        test_positions = list(dict.fromkeys(test_positions))

        all_consistent = True
        for i, (x, y) in enumerate(test_positions):
            tile_w = min(sample_size_w, self.width - x)
            tile_h = min(sample_size_h, self.height - y)

            print(f"Testing tile {i+1}/{len(test_positions)} at ({x}, {y})...")
            gpu_result = self.compute_tile_gpu(x, y, tile_w, tile_h, validate=False)
            cpu_result = self.compute_tile_cpu(x, y, tile_w, tile_h, validate=False)

            if gpu_result.shape != cpu_result.shape:
                print(f"  WARNING: Shape mismatch! GPU: {gpu_result.shape}, CPU: {cpu_result.shape}")
                all_consistent = False
                continue

            if gpu_result.size == 0:
                continue

            abs_diff = np.abs(gpu_result.astype(np.float64) - cpu_result.astype(np.float64))
            max_diff = np.max(abs_diff) / 255.0
            mean_diff = np.mean(abs_diff) / 255.0
            print(f"  Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")

            if max_diff > epsilon:
                print(f"  WARNING: Inconsistency detected! Max diff {max_diff} > {epsilon}")
                all_consistent = False

        print("✓ GPU-CPU consistency validated!" if all_consistent else "✗ GPU-CPU inconsistency detected!")
        return all_consistent

In [None]:
# Cell 4: Zarr Storage Management
class ZarrTiledStorage:
    """RGB image on disk as three 2-D uint8 Zarr arrays: 'red', 'green' and 'blue'.

    Constructing an instance initialises the group at ``path`` with
    ``overwrite=True``, which discards anything already stored there. It then
    records the image size and view parameters as group attributes.
    """

    def __init__(self, path, width, height, chunk_size, center_x, center_y, scale):
        self.path = path
        self.width = width
        self.height = height
        self.chunk_size = chunk_size

        # View parameters kept in float64, like the renderer.
        self.center_x, self.center_y, self.scale = (
            np.float64(center_x), np.float64(center_y), np.float64(scale))

        self.store = zarr.storage.DirectoryStore(path)
        # NOTE(review): overwrite=True wipes existing tiles; if a run is resumed
        # from a checkpoint with a fresh instance, previously written tiles are lost.
        self.root = zarr.group(store=self.store, overwrite=True)

        chunk_edge = max(1, chunk_size)  # square chunks, at least 1 pixel
        channel_spec = dict(
            shape=(height, width),
            chunks=(chunk_edge, chunk_edge),
            dtype='uint8',
            compressor=Blosc(cname='zstd', clevel=3, shuffle=Blosc.SHUFFLE),
        )
        self.array_r, self.array_g, self.array_b = (
            self.root.create_dataset(channel_name, **channel_spec)
            for channel_name in ('red', 'green', 'blue'))

        # Store metadata
        self.root.attrs.update({
            'width': width,
            'height': height,
            'chunk_size': chunk_size,
            'center_x': float(center_x),
            'center_y': float(center_y),
            'scale': float(scale),
            'coordinate_system': 'center_origin',
            'precision': 'float64',
            'created': datetime.now().isoformat(),
        })

    def write_tile(self, x, y, tile_rgb):
        """Write an ``(h, w, 3)`` tile with its top-left corner at image pixel (x, y).

        The part of the tile beyond the image's right/bottom edge is dropped;
        empty tiles are ignored.
        """
        tile_h, tile_w, _ = tile_rgb.shape
        if tile_h == 0 or tile_w == 0:
            return

        rows = min(y + tile_h, self.height) - y
        cols = min(x + tile_w, self.width) - x
        if rows <= 0 or cols <= 0:
            return

        window = (slice(y, y + rows), slice(x, x + cols))
        for channel, target in enumerate((self.array_r, self.array_g, self.array_b)):
            target[window] = tile_rgb[:rows, :cols, channel]

In [None]:
# Cell 5: PyVips Environment Setup
def setup_pyvips_environment(concurrency, disc_threshold, cache_mem_mb, cache_files):
    """Apply libvips threading/disc settings and operation-cache limits.

    Args:
        concurrency: value exported as ``VIPS_CONCURRENCY``.
        disc_threshold: value exported as ``VIPS_DISC_THRESHOLD``.
        cache_mem_mb: operation-cache memory cap in MiB.
        cache_files: operation-cache cap on open files.

    NOTE(review): pyvips is imported in the imports cell, before these
    environment variables are set. Confirm that libvips still honours them at
    this point; otherwise set them before ``import pyvips``. Also,
    ``cache_set_max(0)`` disables the operation cache, so the memory and file
    caps likely have little effect.
    """
    os.environ.update({
        'VIPS_CONCURRENCY': str(concurrency),
        'VIPS_DISC_THRESHOLD': str(disc_threshold),
    })

    bytes_per_mib = 1024 * 1024
    pyvips.cache_set_max(0)  # zero cached operations: operation cache disabled
    pyvips.cache_set_max_mem(cache_mem_mb * bytes_per_mib)
    pyvips.cache_set_max_files(cache_files)

    print(f"PyVips configured: Concurrency={concurrency}, DiscThreshold={disc_threshold}, "
          f"MaxMemMB={cache_mem_mb}, MaxFiles={cache_files}")

In [None]:
# Cell 6: Gigapixel Generator
class GigapixelGenerator:
    """Render a large image tile by tile into Zarr storage with resumable progress.

    Tiles are produced by ``HybridGPUCPURenderer`` (GPU and CPU alternate in a
    checkerboard when a GPU is available), GPU memory is watched through
    ``MemoryPoolManager``, and the set of finished ``(x, y, w, h)`` tiles is
    persisted to ``<checkpoint_dir>/progress.json`` together with the
    parameters that produced them, so a checkpoint is only reused when those
    parameters match.
    """

    # Absolute tolerance for comparing float parameters stored in a checkpoint.
    _PARAM_TOLERANCE = 1e-9

    def __init__(self, width, height, tile_size, checkpoint_dir,
                 center_x, center_y, scale, tile_overlap,
                 cuda_threads_x, cuda_threads_y,
                 mempool_chunk, mempool_size, mempool_limit_gb, mempool_threshold_gb,
                 validate_mode=False, kernel_type='complex'):
        """Create the renderer and memory monitor and load any matching checkpoint.

        ``checkpoint_dir`` is created if missing. ``validate_mode`` makes every
        tile render call pass ``validate=True`` to the renderer.
        """
        self.width = width
        self.height = height
        self.tile_size = tile_size
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.validate_mode = validate_mode

        self.center_x = np.float64(center_x)
        self.center_y = np.float64(center_y)
        self.scale = np.float64(scale)
        self.tile_overlap = tile_overlap
        self.kernel_type = kernel_type

        self.renderer = HybridGPUCPURenderer(
            width, height, tile_size,
            self.center_x, self.center_y, self.scale, self.tile_overlap,
            cuda_threads_x, cuda_threads_y,
            kernel_type=kernel_type)

        self.memory_pool = MemoryPoolManager(
            chunk_size=mempool_chunk,
            pool_size=mempool_size,
            memory_limit_gb=mempool_limit_gb,
            threshold_gb=mempool_threshold_gb)

        self.checkpoint_file = self.checkpoint_dir / 'progress.json'
        self.completed_tiles = self.load_checkpoint()
        self.validation_results = {}

    # ------------------------------------------------------------------ #
    # Checkpointing
    # ------------------------------------------------------------------ #
    def _checkpoint_matches(self, data):
        """Return True if checkpoint ``data`` was written for this configuration.

        Checkpoints without a ``center_x`` key carry no parameters and are
        accepted as-is. Any other checkpoint with a missing or differing
        parameter is rejected.
        """
        if 'center_x' not in data:
            return True
        tol = self._PARAM_TOLERANCE
        try:
            return (abs(data['center_x'] - self.center_x) <= tol and
                    abs(data['center_y'] - self.center_y) <= tol and
                    abs(data['scale'] - self.scale) <= tol and
                    data.get('kernel_type', 'complex') == self.kernel_type and
                    data.get('width') == self.width and
                    data.get('height') == self.height and
                    data.get('tile_size') == self.tile_size)
        except (KeyError, TypeError):
            # A partially written / hand-edited checkpoint is not trustworthy.
            return False

    def load_checkpoint(self):
        """Return the set of completed ``(x, y, w, h)`` tiles from the checkpoint.

        An empty set is returned when the file is absent, unreadable,
        malformed, or was written for different image/kernel parameters.
        """
        if not self.checkpoint_file.exists():
            return set()
        try:
            with open(self.checkpoint_file, 'r') as f:
                data = json.load(f)
            if not self._checkpoint_matches(data):
                print("WARNING: Checkpoint parameters mismatch. Starting fresh.")
                return set()
            return set(tuple(t) for t in data.get('completed_tiles', []))
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # ValueError covers json.JSONDecodeError; TypeError/AttributeError
            # cover structurally wrong content (e.g. a list instead of a dict).
            print(f"Error reading checkpoint ({e}). Starting fresh.")
            return set()

    def save_checkpoint(self):
        """Persist completed tiles and the generation parameters to ``progress.json``.

        The JSON is written to a temporary sibling file and then moved into
        place with ``os.replace`` so an interruption mid-write cannot leave a
        truncated checkpoint behind.
        """
        payload = {
            'completed_tiles': sorted(self.completed_tiles),
            'timestamp': datetime.now().isoformat(),
            'center_x': float(self.center_x),
            'center_y': float(self.center_y),
            'scale': float(self.scale),
            'width': self.width,
            'height': self.height,
            'tile_size': self.tile_size,
            'kernel_type': self.kernel_type
        }
        tmp_path = self.checkpoint_file.with_name(self.checkpoint_file.name + '.tmp')
        with open(tmp_path, 'w') as f:
            json.dump(payload, f, indent=4)
        os.replace(tmp_path, self.checkpoint_file)

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    def _save_gradient(self, device, gradient):
        """Save a validation gradient as ``<device>_gradient_<kernel>.png`` and record the outcome."""
        if gradient is None or gradient.size == 0:
            return
        filename = f'{device}_gradient_{self.kernel_type}.png'
        try:
            imageio.imwrite(filename, gradient)
            print(f"Saved {filename}")
            self.validation_results[f'{device}_gradient'] = 'Generated'
        except Exception as e:
            print(f"Could not save {device.upper()} gradient: {e}")
            self.validation_results[f'{device}_gradient'] = 'Error'

    def validate_coordinate_patterns(self, pattern_size, pattern_scale, sample_tiles, epsilon):
        """Save GPU/CPU validation gradients for the top-left patch and check GPU/CPU agreement.

        Args:
            pattern_size: Side of the top-left patch (clipped to the image size).
            pattern_scale: Accepted for interface compatibility; not used.
            sample_tiles, epsilon: Forwarded to
                ``renderer.validate_gpu_cpu_consistency``.

        Returns:
            The consistency result, or True if the patch is empty.
        """
        print("Generating coordinate validation patterns...")
        test_w = min(pattern_size, self.width)
        test_h = min(pattern_size, self.height)

        if test_w <= 0 or test_h <= 0:
            print("Image too small for pattern validation.")
            self.validation_results['coordinate_patterns'] = 'Skipped'
            return True

        if self.renderer.gpu_available:
            self._save_gradient(
                'gpu', self.renderer.compute_tile_gpu(0, 0, test_w, test_h, validate=True))
        self._save_gradient(
            'cpu', self.renderer.compute_tile_cpu(0, 0, test_w, test_h, validate=True))

        print(f"\nValidating GPU-CPU consistency for '{self.renderer.kernel_type}' kernel...")
        consistency = self.renderer.validate_gpu_cpu_consistency(sample_tiles, epsilon)
        self.validation_results[f'consistency_{self.renderer.kernel_type}'] = consistency
        return consistency

    def verify_tile_boundaries(self, sample_array, max_diff):
        """Flag tile seams where neighbouring pixels differ by more than ``max_diff``.

        Args:
            sample_array: Image sample assumed to start at the image origin
                (seams are taken at multiples of ``tile_size``); 2-D or
                3-D ``(H, W, C)``.
            max_diff: Largest allowed absolute difference across a seam.

        Returns:
            True if no seam exceeds ``max_diff`` (or the check was skipped),
            False otherwise. The outcome is also stored in
            ``validation_results['boundaries']``.
        """
        print("Verifying tile boundaries...")
        if sample_array is None or sample_array.size == 0:
            print("Sample array is empty.")
            self.validation_results['boundaries'] = 'Skipped'
            return True

        sample = np.asarray(sample_array)
        # Differencing must happen in a signed, wider type: the Zarr channels
        # are uint8, where 10 - 20 wraps to 246 and would fake a discontinuity.
        if np.issubdtype(sample.dtype, np.integer):
            sample = sample.astype(np.int64)
        if sample.ndim == 2:
            sample = sample[:, :, np.newaxis]

        img_h, img_w, _ = sample.shape
        if img_h < self.tile_size and img_w < self.tile_size:
            print("Sample smaller than tile size.")
            self.validation_results['boundaries'] = 'Skipped'
            return True

        discontinuities = []

        # Horizontal seams: compare the last row of one tile with the first of the next.
        for y in range(self.tile_size, img_h, self.tile_size):
            diff = np.abs(sample[y - 1, :, :] - sample[y, :, :])
            max_diff_val = np.max(diff) if diff.size > 0 else 0
            if max_diff_val > max_diff:
                discontinuities.append(('horizontal', y, max_diff_val))

        # Vertical seams: compare the last column of one tile with the first of the next.
        for x in range(self.tile_size, img_w, self.tile_size):
            diff = np.abs(sample[:, x - 1, :] - sample[:, x, :])
            max_diff_val = np.max(diff) if diff.size > 0 else 0
            if max_diff_val > max_diff:
                discontinuities.append(('vertical', x, max_diff_val))

        if discontinuities:
            print(f"Found {len(discontinuities)} boundary discontinuities:")
            for direction, pos, diff_val in discontinuities[:5]:
                print(f"  {direction} at {pos}: max_diff={diff_val}")
            self.validation_results['boundaries'] = False
            return False

        print("✓ No significant boundary discontinuities found!")
        self.validation_results['boundaries'] = True
        return True

    # ------------------------------------------------------------------ #
    # Generation
    # ------------------------------------------------------------------ #
    def _tile_grid(self):
        """Return ``(x, y, w, h)`` for every tile in row-major order; edge tiles are cropped."""
        return [(x, y,
                 min(self.tile_size, self.width - x),
                 min(self.tile_size, self.height - y))
                for y in range(0, self.height, self.tile_size)
                for x in range(0, self.width, self.tile_size)]

    def _render_tile(self, x, y, w, h):
        """Render one tile, alternating GPU and CPU in a checkerboard by tile index."""
        use_gpu = (self.renderer.gpu_available and
                   ((x // self.tile_size) + (y // self.tile_size)) % 2 == 0)
        render = self.renderer.compute_tile_gpu if use_gpu else self.renderer.compute_tile_cpu
        return render(x, y, w, h, validate=bool(self.validate_mode))

    def generate_with_zarr_backend(self, zarr_path, batch_size, checkpoint_interval,
                                   run_validation, val_pattern_size, val_pattern_scale,
                                   val_sample_tiles, val_epsilon,
                                   val_boundary_size, val_boundary_diff):
        """Render every tile not yet in the checkpoint into a ``ZarrTiledStorage``.

        Args:
            zarr_path: Location handed to ``ZarrTiledStorage``.
            batch_size: Tiles rendered between GPU-memory checks (< 1 is treated as 1).
            checkpoint_interval: Save progress once at least this many tiles
                have been rendered since the last save (< 1 is treated as 1).
            run_validation: If True and not in validate mode, run
                ``validate_coordinate_patterns`` first.
            val_pattern_size, val_pattern_scale, val_sample_tiles, val_epsilon:
                Forwarded to ``validate_coordinate_patterns``.
            val_boundary_size, val_boundary_diff: Accepted for interface
                compatibility; not used by this method.

        Returns:
            The ``ZarrTiledStorage``, or ``None`` if it could not be created.

        Raises:
            Any exception raised while rendering or writing a tile, re-raised
            after the progress checkpoint has been saved.
        """
        print(f"Generating {self.width}x{self.height} image. Tile size: {self.tile_size}x{self.tile_size}")
        print(f"Mathematical center: ({self.center_x}, {self.center_y}), Scale: {self.scale}")
        print(f"Tile overlap: {self.tile_overlap} pixels")
        print(f"Using kernel type: '{self.kernel_type}'")

        if run_validation and not self.validate_mode:
            print("\n=== Running Validation Checks ===")
            self.validate_coordinate_patterns(val_pattern_size, val_pattern_scale,
                                              val_sample_tiles, val_epsilon)

        print(f"Initializing Zarr storage at: {zarr_path}")
        try:
            storage = ZarrTiledStorage(zarr_path, self.width, self.height, self.tile_size,
                                       self.center_x, self.center_y, self.scale)
            print("Zarr storage initialized.")
        except Exception as e:
            print(f"Error initializing Zarr storage: {e}")
            return None

        all_tiles = self._tile_grid()
        remaining_tiles = [t for t in all_tiles if t not in self.completed_tiles]
        if not remaining_tiles:
            print("All tiles already generated.")
            return storage

        batch_size = max(1, int(batch_size))
        checkpoint_interval = max(1, int(checkpoint_interval))

        print(f"Total tiles: {len(all_tiles)}, To compute: {len(remaining_tiles)}")
        pbar = tqdm(total=len(remaining_tiles), desc="Generating tiles")
        # Counting tiles since the last save (instead of testing pbar.n % interval)
        # guarantees periodic saves even when the interval is not a multiple of
        # the batch size.
        tiles_since_checkpoint = 0

        try:
            for i in range(0, len(remaining_tiles), batch_size):
                batch = remaining_tiles[i:i + batch_size]

                if not self.memory_pool.check_memory():
                    print("Low memory detected, cleaning up...")
                    self.memory_pool.cleanup()
                    gc.collect()

                for x, y, w, h in batch:
                    tile_data = self._render_tile(x, y, w, h)
                    if tile_data is not None and tile_data.size > 0:
                        storage.write_tile(x, y, tile_data)
                        # Only tiles that were actually written count as done, so
                        # an empty render is retried on the next resume.
                        self.completed_tiles.add((x, y, w, h))
                    else:
                        print(f"Warning: Empty tile at ({x},{y})")
                    pbar.update(1)
                    tiles_since_checkpoint += 1

                if tiles_since_checkpoint >= checkpoint_interval:
                    self.save_checkpoint()
                    tiles_since_checkpoint = 0

        except Exception as e:
            print(f"Error during generation: {e}")
            import traceback
            traceback.print_exc()
            raise  # the checkpoint is saved in `finally`
        finally:
            pbar.close()
            self.save_checkpoint()

        print("Generation complete!")
        return storage

In [None]:
# Cell 7: Memory Pool Manager
class MemoryPoolManager:
    """Monitors free GPU memory and releases CuPy's cached blocks on demand.

    ``chunk_size``, ``pool_size`` and ``memory_limit`` are stored for callers
    but are not consulted by any method of this class.
    """

    def __init__(self, chunk_size, pool_size, memory_limit_gb, threshold_gb=2):
        """Record limits and probe for a CUDA device.

        Args:
            chunk_size, pool_size: Stored as-is.
            memory_limit_gb: Stored in bytes as ``memory_limit``.
            threshold_gb: Free-memory headroom (GiB) that ``check_memory``
                requires on top of the requested bytes.
        """
        self.chunk_size = chunk_size
        self.pool_size = pool_size
        self.memory_limit = memory_limit_gb * 1024**3
        self.threshold = threshold_gb * 1024**3
        self.gpu_available = False

        try:
            self.gpu_available = cp.cuda.is_available()
            if self.gpu_available:
                self._log_gpu_info()
        except Exception as e:
            # Best effort: any failure while probing CUDA just means "no GPU".
            print(f"GPU detection failed, continuing without GPU: {e}")
            self.gpu_available = False

    def _log_gpu_info(self):
        """Print total and free memory of the current CUDA device."""
        try:
            # CuPy's Device.mem_info is (free, total), in cudaMemGetInfo order.
            free, total = cp.cuda.Device().mem_info
            print(f"GPU Memory: {total/1024**3:.1f}GB total, {free/1024**3:.1f}GB free")
        except Exception as e:
            print(f"GPU info error: {e}")
            self.gpu_available = False

    def check_memory(self, required_bytes=0):
        """Return True if free GPU memory exceeds ``required_bytes`` plus the threshold.

        Always True without a GPU; False if the query itself fails.
        """
        if not self.gpu_available:
            return True

        try:
            cp.cuda.Stream.null.synchronize()
            # (free, total): compare against *free* memory, not device capacity.
            free, _total = cp.cuda.Device().mem_info
            return free > required_bytes + self.threshold
        except Exception as e:
            print(f"Memory check error: {e}")
            return False

    def cleanup(self):
        """Synchronise the null stream and return CuPy's cached blocks to the driver."""
        if self.gpu_available:
            try:
                cp.cuda.Stream.null.synchronize()
                cp.get_default_memory_pool().free_all_blocks()
                gc.collect()
            except Exception as e:
                print(f"Cleanup error: {e}")


In [None]:
# Cell 8: Zarr to TIFF Converter
class ZarrToTIFFConverter:
    """Assemble the per-channel (red/green/blue) Zarr store into a pyramidal BigTIFF."""

    def __init__(self, zarr_path, output_tiff_path):
        """Open the Zarr group read-only and read its size and coordinate attributes.

        ``width``/``height`` attributes are required; ``chunk_size``,
        ``center_x``, ``center_y`` and ``scale`` fall back to defaults.
        """
        self.zarr_path = zarr_path
        self.output_tiff_path = output_tiff_path
        self.store = zarr.storage.DirectoryStore(zarr_path)
        self.root = zarr.open_group(store=self.store, mode='r')

        self.width = self.root.attrs['width']
        self.height = self.root.attrs['height']
        self.zarr_chunk_size = self.root.attrs.get('chunk_size', 4096)

        self.center_x = self.root.attrs.get('center_x', 0.0)
        self.center_y = self.root.attrs.get('center_y', 0.0)
        self.scale = self.root.attrs.get('scale', 0.01)

    def convert_with_pyvips(self, read_tile_size, tiff_tile_width, tiff_tile_height, compression):
        """Copy the Zarr channels into a tiled, pyramidal BigTIFF.

        Args:
            read_tile_size: Side length of the blocks read from Zarr per step.
            tiff_tile_width, tiff_tile_height: Tile size inside the TIFF.
            compression: ``tiffsave`` compression name (e.g. ``'deflate'``).

        The centre/scale attributes are set as the ``image-description``
        (``center_x=…,center_y=…,scale=…``).
        NOTE(review): ``tiffsave(properties=True)`` writes a libvips properties
        document to IMAGEDESCRIPTION, which may wrap or replace this string —
        confirm the expected on-disk format.
        """
        print(f"Converting {self.width}x{self.height} Zarr to BigTIFF...")
        print(f"Read tile size: {read_tile_size}, TIFF tile size: {tiff_tile_width}x{tiff_tile_height}")

        # Resolve the channel arrays once rather than on every tile: each group
        # lookup builds a new zarr Array object (loading its metadata).
        red, green, blue = self.root['red'], self.root['green'], self.root['blue']

        # NOTE(review): every insert() wraps the previous image, so the libvips
        # pipeline grows by one operation per tile, and pyvips keeps a reference
        # to each tile's numpy buffer until tiffsave finishes — at full size this
        # holds the whole image in RAM. Confirm memory headroom, or consider
        # pyvips.Image.arrayjoin / a streamed source for very large outputs.
        image = pyvips.Image.black(self.width, self.height, bands=3)

        tiles_y = (self.height + read_tile_size - 1) // read_tile_size
        tiles_x = (self.width + read_tile_size - 1) // read_tile_size
        pbar = tqdm(total=tiles_x * tiles_y, desc="Converting to TIFF")
        try:
            for y in range(0, self.height, read_tile_size):
                tile_h = min(read_tile_size, self.height - y)
                for x in range(0, self.width, read_tile_size):
                    tile_w = min(read_tile_size, self.width - x)

                    rows, cols = slice(y, y + tile_h), slice(x, x + tile_w)
                    tile_rgb = np.stack([red[rows, cols], green[rows, cols], blue[rows, cols]],
                                        axis=2)

                    h, w, bands = tile_rgb.shape
                    # np.stack returns a C-contiguous array, so the flat view is valid memory.
                    linear = tile_rgb.reshape(w * h * bands)
                    tile_vips = pyvips.Image.new_from_memory(linear.data, w, h, bands, format='uchar')
                    image = image.insert(tile_vips, x, y)
                    pbar.update(1)
        finally:
            pbar.close()

        metadata_str = f'center_x={self.center_x},center_y={self.center_y},scale={self.scale}'
        image_copy = image.copy()
        image_copy.set_type(pyvips.GValue.gstr_type, 'image-description', metadata_str)

        print("Saving BigTIFF...")
        image_copy.tiffsave(
            self.output_tiff_path, tile=True, compression=compression,
            predictor='horizontal', bigtiff=True,
            tile_width=tiff_tile_width, tile_height=tiff_tile_height,
            pyramid=True, properties=True)
        print(f"Conversion complete! Saved to {self.output_tiff_path}")

In [None]:
# Cell 9: PNG Conversion
def convert_tiff_to_png(tiff_path, png_dir, png_sizes, allow_upscale):
    """Write resized PNG previews of a TIFF.

    Args:
        tiff_path: Source TIFF.
        png_dir: Output directory (created if missing); files are named
            ``gigapixel_<name>.png``.
        png_sizes: Iterable of ``(name, max_width, max_height)``. Each preview
            keeps the source aspect ratio and fits inside the given box.
        allow_upscale: If False, a box larger than the source yields a
            scale-1 copy instead of an enlargement.

    A TIFF that cannot be opened aborts with a message; a failure for one
    size is printed and the remaining sizes are still attempted.
    """
    Path(png_dir).mkdir(parents=True, exist_ok=True)
    print(f"Loading TIFF: {tiff_path}")

    try:
        image = pyvips.Image.new_from_file(tiff_path, access='sequential')
    except pyvips.Error as e:
        print(f"Error loading TIFF: {e}")
        return

    print(f"Source image size: {image.width}x{image.height}")

    try:
        description = image.get('image-description')
        print(f"Image metadata: {description}")
    except pyvips.Error:
        pass

    pbar = tqdm(total=len(png_sizes), desc="Creating PNG versions")
    try:
        for name, target_w, target_h in png_sizes:
            try:
                h_scale = target_w / image.width if image.width > 0 else 1
                v_scale = target_h / image.height if image.height > 0 else 1
                scale_factor = min(h_scale, v_scale)

                if scale_factor > 1 and not allow_upscale:
                    scale_factor = 1

                final_w = int(image.width * scale_factor)
                final_h = int(image.height * scale_factor)

                if final_w == 0 or final_h == 0:
                    print(f"\nSkipping {name}: Invalid dimensions")
                    pbar.update(1)
                    continue

                print(f"\nCreating '{name}': {final_w}x{final_h}", end="")
                if scale_factor > 1:
                    print(" (upscaled)", end="")
                print()

                kernel = 'lanczos3' if scale_factor < 1 else 'cubic'
                # Re-open for every output: an image loaded with
                # access='sequential' is only meant for one top-to-bottom pass,
                # and each resize + pngsave is a separate full pass.
                source = pyvips.Image.new_from_file(tiff_path, access='sequential')
                resized = source.resize(scale_factor, kernel=kernel)

                output_path = Path(png_dir) / f"gigapixel_{name}.png"
                resized.pngsave(str(output_path), compression=9, interlace=True)
                size_mb = output_path.stat().st_size / 1024**2
                print(f"Saved: {output_path} ({size_mb:.1f} MB)")
            except Exception as e:
                print(f"Error creating '{name}': {e}")
            pbar.update(1)
    finally:
        pbar.close()
    print("\nPNG conversion complete!")

In [None]:
# Cell 10: Main Pipeline Function
def _print_validation_summary(validation_results):
    """Print one line per recorded check; booleans render as ✓ Passed / ✗ Failed."""
    if not validation_results:
        return
    print("\n--- Validation Summary ---")
    for key, value in sorted(validation_results.items()):
        status = value if isinstance(value, str) else ('✓ Passed' if value else '✗ Failed')
        print(f"  {key}: {status}")


def run_gigapixel_pipeline(
    # Image dimensions
    width, height, tile_size, tile_overlap,
    # Mathematical coordinates
    center_x, center_y, scale,
    # Kernel selection
    main_kernel,
    # Output paths
    checkpoint_dir, zarr_path, tiff_path, png_dir,
    # TIFF settings
    tiff_tile_w, tiff_tile_h, tiff_compression,
    # PyVips settings
    vips_concurrency, vips_disc_threshold, vips_cache_mem, vips_cache_files,
    # Validation parameters
    val_enabled, val_width, val_height, val_tile_size, val_overlap,
    val_center_x, val_center_y, val_scale, val_kernel,
    val_pattern_size, val_pattern_scale, val_epsilon, val_sample_tiles,
    val_boundary_size, val_boundary_diff,
    # CUDA settings
    cuda_threads_x, cuda_threads_y,
    # Generation settings
    batch_size, checkpoint_interval,
    # Memory settings
    mempool_chunk, mempool_size, mempool_limit_gb, mempool_threshold_gb,
    # PNG settings
    png_upscale, png_sizes):
    """Run optional validation, full generation, TIFF conversion and PNG export.

    Stages:
      0. If ``val_enabled``: render a small image with the ``val_*`` settings
         under ``<checkpoint_dir>/validation`` and ``validation.zarr``, then
         run the pattern/consistency and tile-seam checks. Errors here are
         reported and the pipeline continues.
      1. Render the full image into ``zarr_path``. On failure the collected
         results are printed and the pipeline stops.
      2. Convert the Zarr store to a pyramidal BigTIFF at ``tiff_path``.
      3. Export PNG previews (``png_sizes``) into ``png_dir``; skipped if
         step 2 failed or produced no file.

    ``scale=None`` selects ``20 / min(width, height)`` (0.01 for a zero size).

    Returns:
        dict mapping check/stage names to a bool or status string — the same
        data printed in the summary.
    """
    setup_pyvips_environment(vips_concurrency, vips_disc_threshold, vips_cache_mem, vips_cache_files)

    # Ensure correct types
    center_x = np.float64(center_x)
    center_y = np.float64(center_y)

    if scale is None:
        scale = np.float64(20.0 / min(width, height)) if min(width, height) > 0 else np.float64(0.01)
    else:
        scale = np.float64(scale)

    print("=== Gigapixel Pipeline Starting ===")
    print(f"Target size: {width}x{height}, Tile size: {tile_size}, Overlap: {tile_overlap}")
    print(f"Center: ({center_x:.5f}, {center_y:.5f}), Scale: {scale:.5e}")
    print(f"Main kernel: '{main_kernel}'")

    validation_results = {}

    # Step 0: Optional validation run
    if val_enabled:
        print(f"\n--- Validation Run ---")
        print(f"Size: {val_width}x{val_height}, Tile: {val_tile_size}")
        print(f"Center: ({val_center_x:.5f}, {val_center_y:.5f}), Scale: {val_scale:.5e}")
        print(f"Kernel: '{val_kernel}'")

        val_checkpoint = Path(checkpoint_dir) / "validation"
        val_zarr = Path(checkpoint_dir) / "validation.zarr"

        validator = GigapixelGenerator(
            width=val_width, height=val_height, tile_size=val_tile_size,
            checkpoint_dir=str(val_checkpoint),
            center_x=val_center_x, center_y=val_center_y, scale=val_scale,
            tile_overlap=val_overlap,
            cuda_threads_x=cuda_threads_x, cuda_threads_y=cuda_threads_y,
            mempool_chunk=mempool_chunk, mempool_size=mempool_size,
            mempool_limit_gb=mempool_limit_gb, mempool_threshold_gb=mempool_threshold_gb,
            validate_mode=False,
            kernel_type=val_kernel)

        try:
            val_storage = validator.generate_with_zarr_backend(
                zarr_path=str(val_zarr),
                batch_size=max(1, batch_size // 2),
                checkpoint_interval=max(10, checkpoint_interval // 2),
                run_validation=True,
                val_pattern_size=val_pattern_size,
                val_pattern_scale=val_pattern_scale,
                val_sample_tiles=val_sample_tiles,
                val_epsilon=val_epsilon,
                val_boundary_size=val_boundary_size,
                val_boundary_diff=val_boundary_diff)

            validation_results.update(validator.validation_results)
            print("Validation run complete.")

            if (val_width >= val_boundary_size and val_height >= val_boundary_size
                    and val_storage is not None):
                print(f"\n--- Checking Tile Boundaries (Sample: {val_boundary_size}x{val_boundary_size}) ---")
                sample_h = min(val_boundary_size, val_height)
                sample_w = min(val_boundary_size, val_width)
                sample_r = val_storage.array_r[0:sample_h, 0:sample_w]
                validator.verify_tile_boundaries(sample_r, val_boundary_diff)
                validation_results.update(validator.validation_results)

        except Exception as e:
            print(f"Validation error: {e}")
            import traceback
            traceback.print_exc()
            print("Continuing with main generation...")

    # Step 1: Main generation
    print(f"\n--- Main Generation ---")
    main_generator = GigapixelGenerator(
        width=width, height=height, tile_size=tile_size,
        checkpoint_dir=checkpoint_dir,
        center_x=center_x, center_y=center_y, scale=scale,
        tile_overlap=tile_overlap,
        cuda_threads_x=cuda_threads_x, cuda_threads_y=cuda_threads_y,
        mempool_chunk=mempool_chunk, mempool_size=mempool_size,
        mempool_limit_gb=mempool_limit_gb, mempool_threshold_gb=mempool_threshold_gb,
        validate_mode=False,
        kernel_type=main_kernel)

    try:
        storage = main_generator.generate_with_zarr_backend(
            zarr_path=zarr_path,
            batch_size=batch_size,
            checkpoint_interval=checkpoint_interval,
            run_validation=False,
            val_pattern_size=val_pattern_size,
            val_pattern_scale=val_pattern_scale,
            val_sample_tiles=val_sample_tiles,
            val_epsilon=val_epsilon,
            val_boundary_size=val_boundary_size,
            val_boundary_diff=val_boundary_diff)

        # generate_with_zarr_backend returns None when the Zarr store could not
        # be created; that is a failure, not a completed generation.
        if storage is None:
            raise RuntimeError("Zarr storage could not be initialized")

        print("\n✓ Main generation complete!")
        validation_results.update(main_generator.validation_results)

        if width >= val_boundary_size and height >= val_boundary_size:
            print(f"\n--- Checking Main Tile Boundaries ---")
            sample_h = min(val_boundary_size, height)
            sample_w = min(val_boundary_size, width)
            sample_r = storage.array_r[0:sample_h, 0:sample_w]
            main_generator.verify_tile_boundaries(sample_r, val_boundary_diff)
            validation_results.update(main_generator.validation_results)

    except Exception as e:
        print(f"\n✗ Main generation failed: {e}")
        import traceback
        traceback.print_exc()
        validation_results['main_generation'] = f'Failed: {e}'
        # Still report what was collected (e.g. the validation run's checks).
        _print_validation_summary(validation_results)
        return validation_results

    # Step 2: Convert to TIFF
    print("\n--- Converting to TIFF ---")
    if not Path(zarr_path).exists():
        print("Zarr file not found. Skipping TIFF conversion.")
        validation_results['tiff_conversion'] = 'Skipped'
    else:
        try:
            converter = ZarrToTIFFConverter(zarr_path, tiff_path)
            converter.convert_with_pyvips(
                read_tile_size=tile_size,
                tiff_tile_width=tiff_tile_w,
                tiff_tile_height=tiff_tile_h,
                compression=tiff_compression)
            print("\n✓ TIFF conversion complete!")
            validation_results['tiff_conversion'] = 'Success'
        except Exception as e:
            print(f"\n✗ TIFF conversion failed: {e}")
            validation_results['tiff_conversion'] = f'Failed: {e}'

    # Step 3: Create PNG versions
    print("\n--- Creating PNG Versions ---")
    if validation_results.get('tiff_conversion', '').startswith('Failed'):
        # A failed conversion can leave a partial TIFF on disk; don't preview it.
        print("TIFF conversion failed. Skipping PNG conversion.")
        validation_results['png_conversion'] = 'Skipped'
    elif not Path(tiff_path).exists():
        print("TIFF file not found. Skipping PNG conversion.")
        validation_results['png_conversion'] = 'Skipped'
    else:
        try:
            convert_tiff_to_png(tiff_path, png_dir, png_sizes, png_upscale)
            print("\n✓ PNG versions created!")
            validation_results['png_conversion'] = 'Success'
        except Exception as e:
            print(f"\n✗ PNG conversion failed: {e}")
            validation_results['png_conversion'] = f'Failed: {e}'

    # Summary
    print("\n=== Pipeline Complete! ===")
    print(f"Zarr: {zarr_path}")
    print(f"TIFF: {tiff_path}")
    print(f"PNGs: {png_dir}")
    _print_validation_summary(validation_results)
    return validation_results

In [None]:
# Cell 11: Pipeline Configuration
# Run identification
# A fresh timestamped ID is generated unless GIGAPIXEL_RUN_ID is set, which
# lets a later session point back at an existing run directory.
RUN_ID = os.environ.get('GIGAPIXEL_RUN_ID') or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
print(f"Run ID: {RUN_ID}")

# Kernel selection
MAIN_KERNEL = 'complex'  # 'complex' or 'simple'
VAL_KERNEL = 'simple'    # Validation run kernel

# Image dimensions
IMAGE_WIDTH = 169420
IMAGE_HEIGHT = IMAGE_WIDTH
TILE_SIZE = 1000
TILE_OVERLAP = 1  # In pixels

# Mathematical parameters
CENTER_X = 0.0
CENTER_Y = 0.0
SCALE = 0.001

# Output paths
# Anything under /content is lost when the Colab runtime is recycled. Set
# GIGAPIXEL_OUTPUT_ROOT to a folder on the Google Drive mounted in Cell 1.5
# (e.g. /content/drive/MyDrive/gigapixel_output) to keep checkpoints/outputs.
# NOTE(review): resuming an existing run also depends on ZarrTiledStorage
# reopening its store without clearing it — confirm before relying on it.
OUTPUT_ROOT = Path(os.environ.get('GIGAPIXEL_OUTPUT_ROOT', '/content/gigapixel_output'))
BASE_DIR = OUTPUT_ROOT / RUN_ID
BASE_DIR.mkdir(parents=True, exist_ok=True)

CHECKPOINT_DIR = str(BASE_DIR / "checkpoints")
ZARR_PATH = str(BASE_DIR / "image_output.zarr")
TIFF_PATH = str(BASE_DIR / "image_output.tiff")
PNG_DIR = str(BASE_DIR / "png_previews")

# TIFF parameters
TIFF_TILE_WIDTH = 512
TIFF_TILE_HEIGHT = 512
TIFF_COMPRESSION = 'deflate'

# PyVips settings
VIPS_CONCURRENCY = psutil.cpu_count(logical=True) or 2
VIPS_DISC_THRESHOLD = '10gb'
VIPS_CACHE_MEM_MB = 256
VIPS_CACHE_FILES = 20

# Validation parameters
VAL_ENABLED = True
VAL_WIDTH = 1024
VAL_HEIGHT = 1024
VAL_TILE_SIZE = 256
VAL_OVERLAP = TILE_OVERLAP
VAL_CENTER_X = 0.0
VAL_CENTER_Y = 0.0
VAL_SCALE = 0.01

# Validation checks
VAL_PATTERN_SIZE = 256
VAL_PATTERN_SCALE = 1.0
VAL_GPU_CPU_EPSILON = 1e-4
VAL_SAMPLE_TILES = 3
VAL_BOUNDARY_SIZE = 512
VAL_BOUNDARY_DIFF = 15

# CUDA settings
CUDA_THREADS_X = 32
CUDA_THREADS_Y = 32

# Generation settings
BATCH_SIZE = 4
CHECKPOINT_INTERVAL = 600  # Save checkpoint every N tiles

# Memory management
MEMPOOL_CHUNK = TILE_SIZE
MEMPOOL_SIZE = 10
# Derived from host RAM (psutil), not GPU memory.
MEMPOOL_LIMIT_GB = int(psutil.virtual_memory().total / (1024**3) * 0.75)
MEMPOOL_THRESHOLD_GB = 1.5

# PNG export: (name, max_width, max_height); previews keep the aspect ratio.
PNG_UPSCALE = True
PNG_SIZES = [
    ('thumbnail', 1024, 1024),
    ('preview', 2048, 2048),
    ('display', 5096, 5096),
    ('hires', 10192, 10192)
]

print(f"Configuration complete for {RUN_ID}")

In [None]:
# Cell 12: Execute Pipeline
if __name__ == "__main__":
    print(f"\n{'='*60}")
    print(f"Starting Gigapixel Pipeline - {RUN_ID}")
    print(f"{'='*60}\n")

    run_gigapixel_pipeline(
        # Core parameters
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        tile_size=TILE_SIZE,
        tile_overlap=TILE_OVERLAP,
        center_x=CENTER_X,
        center_y=CENTER_Y,
        scale=SCALE,
        main_kernel=MAIN_KERNEL,

        # Paths
        checkpoint_dir=CHECKPOINT_DIR,
        zarr_path=ZARR_PATH,
        tiff_path=TIFF_PATH,
        png_dir=PNG_DIR,

        # TIFF settings
        tiff_tile_w=TIFF_TILE_WIDTH,
        tiff_tile_h=TIFF_TILE_HEIGHT,
        tiff_compression=TIFF_COMPRESSION,

        # PyVips environment
        vips_concurrency=VIPS_CONCURRENCY,
        vips_disc_threshold=VIPS_DISC_THRESHOLD,
        vips_cache_mem=VIPS_CACHE_MEM_MB,
        vips_cache_files=VIPS_CACHE_FILES,

        # Validation
        val_enabled=VAL_ENABLED,
        val_width=VAL_WIDTH,
        val_height=VAL_HEIGHT,
        val_tile_size=VAL_TILE_SIZE,
        val_overlap=VAL_OVERLAP,
        val_center_x=VAL_CENTER_X,
        val_center_y=VAL_CENTER_Y,
        val_scale=VAL_SCALE,
        # Use the configured validation kernel (previously MAIN_KERNEL was
        # passed here, silently ignoring VAL_KERNEL).
        val_kernel=VAL_KERNEL,
        val_pattern_size=VAL_PATTERN_SIZE,
        val_pattern_scale=VAL_PATTERN_SCALE,
        val_epsilon=VAL_GPU_CPU_EPSILON,
        val_sample_tiles=VAL_SAMPLE_TILES,
        val_boundary_size=VAL_BOUNDARY_SIZE,
        val_boundary_diff=VAL_BOUNDARY_DIFF,

        # CUDA
        cuda_threads_x=CUDA_THREADS_X,
        cuda_threads_y=CUDA_THREADS_Y,

        # Generation
        batch_size=BATCH_SIZE,
        checkpoint_interval=CHECKPOINT_INTERVAL,

        # Memory
        mempool_chunk=MEMPOOL_CHUNK,
        mempool_size=MEMPOOL_SIZE,
        mempool_limit_gb=MEMPOOL_LIMIT_GB,
        mempool_threshold_gb=MEMPOOL_THRESHOLD_GB,

        # PNG export
        png_upscale=PNG_UPSCALE,
        png_sizes=PNG_SIZES
    )

    print(f"\n{'='*60}")
    print(f"Pipeline Complete - {RUN_ID}")
    print(f"{'='*60}\n")