In [None]:
def compute_statistics(window):
    """Return the per-band mean of a window, rescaled for bands 0-3.

    ``window`` is indexed as ``window[band, :, :]`` and ``window.shape[0]``
    is taken as the number of bands. The mean of each of the first four
    bands is multiplied by a band-specific factor; any additional bands
    keep their raw mean.

    Returns a 1-D NumPy array holding one value per band.
    """
    # (leading multiplier, trailing divisor) for bands 0..3. The remaining
    # constants in the expression are identical for all four bands.
    # NOTE(review): the physical meaning of these constants is not
    # documented in this file -- confirm against the sensor calibration.
    band_constants = (
        (51.3, 151.51),
        (30.2, 182.01),
        (36.7, 157.54),
        (46.2, 110.86),
    )
    scale_by_band = [
        numerator * 3.14 * (0.99 ** 2) / (2 ** 16 - 1) / 0.589 / denominator
        for numerator, denominator in band_constants
    ]

    band_means = []
    for index in range(window.shape[0]):
        value = window[index, :, :].mean()
        # Only the first four bands have a scaling factor.
        if index < len(scale_by_band):
            value *= scale_by_band[index]
        band_means.append(value)

    return np.array(band_means)

def crop_mask_to_image(mask, image_transform, image_shape):
    """Read the part of ``mask`` that covers the image's extent.

    The image bounds are derived from ``image_shape`` (rows, cols) and
    ``image_transform``, converted into a window in the mask's own grid
    (``mask.transform``), and band 1 of the mask is read through that window.

    If the array read does not already have ``image_shape``, it is resized to
    that shape and thresholded at 0.5 into a ``uint8`` array of 0/1 values.
    When the shapes already match, the values are returned exactly as read.
    """
    n_rows, n_cols = image_shape[0], image_shape[1]
    left, bottom, right, top = rasterio.transform.array_bounds(
        n_rows, n_cols, image_transform
    )

    # Express the image footprint in the mask's pixel grid, snapped to
    # whole-pixel offsets and lengths.
    footprint = rasterio.windows.from_bounds(
        left, bottom, right, top, transform=mask.transform
    )
    footprint = footprint.round_offsets().round_lengths()

    mask_pixels = mask.read(1, window=footprint)

    # Nothing more to do when the read already lines up with the image.
    if mask_pixels.shape == image_shape:
        return mask_pixels

    # Small mismatches are absorbed by resampling onto the image grid.
    from skimage.transform import resize
    resampled = resize(
        mask_pixels,
        image_shape,
        preserve_range=True,
        anti_aliasing=True,
        order=0,
    )
    # Re-binarize after resampling.
    return (resampled > 0.5).astype(np.uint8)


def classify_image(image_path, mask_path, model, window_size=5):
    """Classify every pixel of an image with ``model`` and apply the water mask.

    For each pixel a ``window_size`` x ``window_size`` neighbourhood (taken
    from a reflect-padded copy of the image) is summarised by
    ``compute_statistics`` and the resulting feature rows are passed to
    ``model.predict``. Pixels where the cropped mask equals 0 are then
    overwritten with the label 6.

    Args:
        image_path: Path of the raster to classify; all bands are read.
        mask_path: Path of the mask raster; it is cropped to the image extent
            via ``crop_mask_to_image``.
        model: Object exposing ``predict(features)`` that returns one label
            per feature row.
        window_size: Side length of the square neighbourhood, must be >= 1.
            Both odd and even sizes are supported.

    Returns:
        ``(classified_image, src)`` where ``classified_image`` has shape
        ``(height, width)``. NOTE: ``src`` is returned from inside its
        ``with`` block, so the context manager has already exited when the
        caller receives it -- presumably it is only used for metadata;
        verify against callers.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    with rasterio.open(image_path) as src:
        image = src.read()
        height, width = image.shape[1:]
        image_transform = src.transform

        with rasterio.open(mask_path) as mask_src:
            mask = crop_mask_to_image(mask_src, image_transform, (height, width))

        # Pad so that exactly one window exists per pixel. For odd sizes both
        # sides get window_size // 2 (unchanged behaviour). For even sizes the
        # trailing side gets one pixel less; symmetric padding would yield
        # (height + 1) x (width + 1) windows and break the reshape below.
        pad_before = window_size // 2
        pad_after = window_size - 1 - pad_before
        padded_image = np.pad(
            image,
            ((0, 0), (pad_before, pad_after), (pad_before, pad_after)),
            mode='reflect',
        )

        # Sliding windows spanning all bands: one window per pixel.
        windows = view_as_windows(padded_image, (image.shape[0], window_size, window_size))

        # Collapse the leading grid axes into a flat list of
        # (bands, window_size, window_size) windows.
        flattened_windows = windows.reshape(-1, *windows.shape[-3:])

        # Compute statistics for each window in parallel
        features = Parallel(n_jobs=-1)(
            delayed(compute_statistics)(window) for window in flattened_windows
        )

        # One feature row per pixel, in row-major pixel order.
        features = np.array(features)

        predictions = model.predict(features)

        # Reshape predictions back to the image shape
        classified_image = predictions.reshape(height, width)

        # Apply the water mask
        classified_image[mask == 0] = 6  # Replace with appropriate label

        return classified_image, src