# Environment Setup

In [None]:
# High-quality figure settings
import matplotlib as mpl
#mpl.rcParams['figure.dpi'] = 300
#mpl.rcParams['savefig.dpi'] = 300
#mpl.rcParams['font.size'] = 12
#mpl.rcParams['figure.figsize'] = [8.0, 6.0]
#mpl.rcParams['axes.labelsize'] = 'medium'
#mpl.rcParams['axes.titlesize'] = 'medium'
#mpl.rcParams['legend.fontsize'] = 'medium'
#mpl.rcParams['axes.linewidth'] = 0.8
#mpl.rcParams['lines.linewidth'] = 1.5

## Auxiliary Utilities

In [None]:
!pip install timm

In [None]:
def visualize_results(*images_and_titles):
    """Show several images side by side, each with its title.

    Args:
        *images_and_titles: Alternating ``image, title`` arguments. Each image is a
            torch tensor shaped (C, H, W) with C in {1, 3, 4}, or (H, W). Single-channel
            and 2-D images are drawn in grayscale.

    Raises:
        ValueError: if no pairs are given, an image has no matching title, or an image
            has an unsupported number of dimensions.
    """
    if not images_and_titles or len(images_and_titles) % 2:
        raise ValueError("visualize_results expects alternating (image, title) pairs")

    # Local import: this cell runs before the cell that imports pyplot at module level.
    import matplotlib.pyplot as plt

    pairs = list(zip(images_and_titles[::2], images_and_titles[1::2]))
    # squeeze=False always yields a 2-D array of axes, so no single-image special case.
    fig, axes = plt.subplots(1, len(pairs), figsize=(5 * len(pairs), 5), squeeze=False)

    for ax, (image, title) in zip(axes[0], pairs):
        array = image.detach().cpu()
        if array.dim() == 3 and array.shape[0] == 1:
            # imshow rejects (H, W, 1); draw single-channel images as grayscale.
            array = array[0]
        if array.dim() == 3:
            # For RGB(A) images
            ax.imshow(array.permute(1, 2, 0).numpy())
        elif array.dim() == 2:
            # For grayscale images or masks
            ax.imshow(array.numpy(), cmap='gray')
        else:
            raise ValueError(f"Unsupported image dimension: {image.dim()}")

        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## Synthetic Dataset Functions

In [None]:
module_name='mpslib'
try:
    __import__(module_name)
    print('%s allready installed. skipping installation.' % module_name)
    exe_folder = ''

except ImportError:
    import sys
    is_colab = 'google.colab' in sys.modules
    print (is_colab)
    if is_colab:
        print('%s cannot be loaded. trying to install it.' % module_name)
        !pip install scikit-mps pyvista panel

        # Recompile from src on Colab
        import pathlib
        import mpslib as mps

        O=mps.mpslib()
        O.compile_mpslib()
        # Next line is needed in GoogleColabe
        !bash mpslib_download_and_install.sh

    else:
        print('Please install MPSlib and scikit-mps from http://github.com/ergosimulation/mpslib/')

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from scipy.spatial.distance import cdist
from numpy.linalg import inv
from scipy.interpolate import interp1d
import numpy as np
from typing import Dict, Any, Optional

# Utility Functions

def create_grid(size):
    """Return the (row, col) coordinates of every cell of a ``size`` grid.

    Args:
        size: (height, width).

    Returns:
        A float32 tensor of shape (1, height * width, 2) in row-major order.
    """
    rows = torch.arange(size[0], dtype=torch.float32)
    cols = torch.arange(size[1], dtype=torch.float32)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing='ij')
    coords = torch.stack([row_idx.reshape(-1), col_idx.reshape(-1)], dim=1)
    return coords[None]

def create_param_grid(size, params):
    """Scatter point parameters onto a dense grid.

    Args:
        size: (height, width) of the grid.
        params: (N, 2 + C) tensor of ``(y, x, v_1, ..., v_C)`` rows. Coordinates are
            truncated with ``.long()``; rows outside the grid are ignored.

    Returns:
        ``(param_grid, param_mask)``: a (C, height, width) tensor with each point's
        values at its cell (later rows win on collisions), and a (1, height, width)
        tensor that is 1 at every occupied cell.
    """
    _, n_cols = params.shape
    values_grid = torch.zeros(n_cols - 2, size[0], size[1])
    occupied = torch.zeros(1, size[0], size[1])

    for row in params:
        iy, ix = row[:2].long()
        if not (0 <= iy < size[0] and 0 <= ix < size[1]):
            continue
        values_grid[:, iy, ix] = row[2:]
        occupied[0, iy, ix] = 1

    return values_grid, occupied

def create_id_spatial_data(size, params, grid, exponent=2.0):
    """Inverse-distance-weighted interpolation of column 2 of ``params`` onto the grid.

    Each cell gets ``sum(w_k * v_k) / sum(w_k)`` with ``w_k = 1 / (d_k ** exponent + 1e-6)``.
    The 1e-6 keeps the weight finite when a cell coincides with a data point.

    Args:
        size: (height, width) of the output.
        params: (N, >=3) tensor of ``(y, x, value, ...)`` rows.
        grid: (1, height * width, 2) coordinate tensor, e.g. from ``create_grid``.
        exponent: power applied to the distances.

    Returns:
        A (1, height, width) tensor.
    """
    sample_xy = params[:, :2].unsqueeze(0)
    sample_vals = params[:, 2]
    dist = torch.cdist(grid, sample_xy).squeeze(0)  # [height * width, N]
    w = 1 / (dist ** exponent + 1e-6)
    estimate = (w * sample_vals).sum(dim=1) / w.sum(dim=1)
    return estimate.reshape(1, size[0], size[1])

def create_stationary_spatial_data(size, params, grid, range_param=3.0, variance=0.1):
    """Draw one conditional Gaussian random field that honours the point data, mapped into [0, 1].

    The field has an exponential covariance ``exp(-h / range_param)`` and its prior
    mean is the mean of the point values. It is conditioned on the grid cells nearest
    to each point, sampled once with ``variance`` scaling the conditional covariance,
    and then passed through a logistic function.

    The full (H*W) x (H*W) covariance matrix is built, so cost grows quickly with grid size.

    Args:
        size: (height, width) of the output.
        params: (N, >=3) tensor or array of ``(y, x, value, ...)`` rows.
        grid: (1, H*W, 2) tensor (e.g. from ``create_grid``) or an (H*W, 2) array.
        range_param: range of the exponential covariance (was hard-coded to 3).
        variance: scale applied to the conditional covariance (was hard-coded to 0.1).

    Returns:
        A (1, height, width) float32 tensor.
    """
    if isinstance(grid, torch.Tensor):
        grid = grid.squeeze(0).cpu().numpy()
    if isinstance(params, torch.Tensor):
        params = params.detach().cpu().numpy()

    coords = params[:, :2]
    known_values = params[:, 2]
    mean = float(np.mean(known_values))

    def exponential_covariance(h):
        return np.exp(-h / range_param)

    # Covariance between every pair of grid cells
    cov_matrix = exponential_covariance(squareform(pdist(grid)))

    # Snap each data point to its nearest grid cell. Several points can land on the same
    # cell, which would make C_oo singular; condition once per cell on their mean value.
    closest_indices = np.argmin(cdist(grid, coords), axis=0)
    obs_indices, inverse = np.unique(closest_indices, return_inverse=True)
    inverse = inverse.reshape(-1)
    obs_values = (np.bincount(inverse, weights=known_values, minlength=len(obs_indices))
                  / np.bincount(inverse, minlength=len(obs_indices)))

    C_oo = cov_matrix[np.ix_(obs_indices, obs_indices)]
    C_ou = cov_matrix[obs_indices, :]

    # Gaussian conditioning; solve() is more stable than forming inv(C_oo) explicitly
    weights = np.linalg.solve(C_oo, C_ou)
    conditional_mean = mean + weights.T @ (obs_values - mean)
    conditional_cov = cov_matrix - C_ou.T @ weights

    z_conditional = np.random.multivariate_normal(conditional_mean, variance * conditional_cov)
    z_conditional = 1 / (1 + np.exp(-z_conditional))  # Logistic transformation
    z_conditional = np.clip(z_conditional, 0, 1)  # Clipping to ensure [0, 1]

    return torch.from_numpy(z_conditional.reshape((1, size[0], size[1]))).float()

def create_base_stationary_spatial_data(size, grid, mean=0.5, variance=0.1, range_param=3):
    """Unconditional stationary Gaussian random field generated with FFTs, mapped into (0, 1).

    White noise is filtered by the square root of the spectrum of a circulant exponential
    covariance, then standardised to ``mean``/``variance`` and passed through a logistic
    function.

    Args:
        size: (height, width) of the output.
        grid: unused; kept so every builder shares the same call shape.
        mean, variance: target moments before the logistic map (defaults 0.5, 0.1).
        range_param: range of the exponential covariance (default 3).

    Returns:
        A (1, height, width) float32 tensor.
    """
    height, width = size

    def exponential_covariance(h):
        return np.exp(-h / range_param)

    # Wrap-around lag distances laid out in FFT order (0, 1, 2, ..., 2, 1). For even sizes
    # this equals fftshift(arange(-n//2, n//2)). For odd sizes it stays symmetric, whereas
    # that construction produced a lag of n//2 + 1 and a non-symmetric covariance.
    lag_y = np.arange(height)
    lag_y = np.minimum(lag_y, height - lag_y)
    lag_x = np.arange(width)
    lag_x = np.minimum(lag_x, width - lag_x)
    yy, xx = np.meshgrid(lag_y, lag_x, indexing='ij')
    cov = exponential_covariance(np.sqrt(xx**2 + yy**2))

    # Spectrum of the circulant covariance; clip tiny negative eigenvalues
    fft_cov = np.real(np.fft.fft2(cov))
    fft_cov[fft_cov < 0] = 0

    # Colour white noise with the square-root spectrum
    white_noise = np.random.normal(size=(height, width))
    field = np.real(np.fft.ifft2(np.fft.fft2(white_noise) * np.sqrt(fft_cov)))

    # Standardise; a constant field (e.g. 1x1) has zero std and is left at zero
    field = field - np.mean(field)
    std = np.std(field)
    if std > 0:
        field = field / std
    field = field * np.sqrt(variance) + mean

    # Apply logistic transformation and clip values between 0 and 1
    field = 1 / (1 + np.exp(-field))
    field = np.clip(field, 0, 1)

    return torch.from_numpy(field.reshape((1, height, width))).float()

def create_vid_spatial_data(size, params, grid, exponent=2.0):
    """Laterally varying field: values interpolated linearly along x, constant down each column.

    Only the x coordinate (column 1) and the value (column 2) of ``params`` are used.
    Points that share an x coordinate are averaged, because linear interpolation through
    duplicate abscissae divides by zero. A single distinct x yields a constant field
    instead of an interpolation error. Outside the sampled x range the linear trend is
    extrapolated.

    Args:
        size: (height, width) of the output.
        params: (N, >=3) tensor or array of ``(y, x, value, ...)`` rows, N >= 1.
        grid, exponent: unused; kept for a uniform builder interface.

    Returns:
        A (1, height, width) float32 tensor.

    Raises:
        ValueError: if ``params`` has no rows.
    """
    height, width = size
    points = params.detach().cpu().numpy() if isinstance(params, torch.Tensor) else np.asarray(params)
    if len(points) == 0:
        raise ValueError("create_vid_spatial_data needs at least one point")

    # np.unique sorts the x coordinates; average the values of points sharing an x
    x_coords, inverse = np.unique(points[:, 1], return_inverse=True)
    inverse = inverse.reshape(-1)
    values = np.bincount(inverse, weights=points[:, 2]) / np.bincount(inverse)

    x_range = np.arange(width)
    if len(x_coords) == 1:
        interpolated_values = np.full(width, values[0])
    else:
        f = interp1d(x_coords, values, kind='linear', bounds_error=False, fill_value='extrapolate')
        interpolated_values = f(x_range)

    # Create 2D output by repeating the interpolated values vertically
    output = np.tile(interpolated_values, (height, 1))

    return torch.tensor(output, dtype=torch.float32).unsqueeze(0)

def create_cvid_spatial_data(size, params, grid, exponent=2.0):
    """Like the lateral (vid) field, but with values rounded to the nearest 0.1.

    Values (column 2) are interpolated linearly along x (column 1) and repeated down
    each column. Points sharing an x are averaged, because duplicate abscissae would
    divide by zero. A single distinct x gives a constant field. Values are extrapolated
    beyond the sampled range. ``np.round`` rounds half to even.

    Args:
        size: (height, width) of the output.
        params: (N, >=3) tensor or array of ``(y, x, value, ...)`` rows, N >= 1.
        grid, exponent: unused; kept for a uniform builder interface.

    Returns:
        A (1, height, width) float32 tensor of values quantised to 0.1 steps.

    Raises:
        ValueError: if ``params`` has no rows.
    """
    height, width = size
    points = params.detach().cpu().numpy() if isinstance(params, torch.Tensor) else np.asarray(params)
    if len(points) == 0:
        raise ValueError("create_cvid_spatial_data needs at least one point")

    # np.unique sorts the x coordinates; average the values of points sharing an x
    x_coords, inverse = np.unique(points[:, 1], return_inverse=True)
    inverse = inverse.reshape(-1)
    values = np.bincount(inverse, weights=points[:, 2]) / np.bincount(inverse)

    x_range = np.arange(width)
    if len(x_coords) == 1:
        interpolated_values = np.full(width, values[0])
    else:
        f = interp1d(x_coords, values, kind='linear', bounds_error=False, fill_value='extrapolate')
        interpolated_values = f(x_range)

    # Create 2D output by repeating the interpolated values vertically
    output = np.tile(interpolated_values, (height, 1))

    # Round values to nearest 0.1
    output = np.round(output * 10) / 10

    return torch.tensor(output, dtype=torch.float32).unsqueeze(0)

def create_nn_spatial_data(size, params, grid):
    """Nearest-neighbour fill: every grid cell takes the value of its closest point.

    Ties go to the earliest point in ``params``, because ``argmin`` returns the first
    minimum.

    Args:
        size: (height, width) of the output.
        params: (N, >=3) tensor of ``(y, x, value, ...)`` rows.
        grid: (1, height * width, 2) coordinate tensor, e.g. from ``create_grid``.

    Returns:
        A (1, height, width) tensor.
    """
    sample_xy = params[:, :2].unsqueeze(0)
    owner = torch.cdist(grid, sample_xy).squeeze(0).argmin(dim=1)  # nearest point per cell
    return params[:, 2][owner].reshape(1, size[0], size[1])

def create_kriging_spatial_data(size, params, grid, range_param=100.0, nugget=1e-5):
    """Interpolate point values onto the grid by zero-mean simple kriging.

    The covariance model is exponential, ``exp(-h / range_param)``. ``nugget`` is added
    to the diagonal of the data covariance for numerical stability. Because the implied
    mean is 0, estimates relax toward 0 at cells that are far (relative to
    ``range_param``) from every data point.

    Args:
        size: (height, width) of the output.
        params: (N, >=3) tensor of ``(y, x, value, ...)`` rows.
        grid: coordinate tensor whose ``grid[0]`` is (height * width, 2); presumably
            the output of ``create_grid``.
        range_param: covariance range.
        nugget: diagonal regulariser.

    Returns:
        A (1, height, width) float32 tensor.
    """
    def covariance(h):
        return np.exp(-h / range_param)

    data_xy = params[:, :2].cpu().numpy()
    data_z = params[:, 2].cpu().numpy()

    # Data-to-data covariance plus nugget, then the kriging weights
    C = covariance(cdist(data_xy, data_xy, metric='euclidean'))
    C += np.eye(len(C)) * nugget
    alpha = np.linalg.solve(C, data_z)

    # Grid-to-data covariance gives the estimate at every cell
    target_xy = grid[0].cpu().numpy()
    c_star = covariance(cdist(target_xy, data_xy, metric='euclidean'))
    estimate = (c_star @ alpha).reshape(size[0], size[1])

    return torch.tensor(estimate, dtype=torch.float32).unsqueeze(0)


def create_layered_spatial_data(size, params, grid):
    """Rasterise point data into stacked layers bounded by piecewise-linear boundaries.

    Points are grouped by the key ``int(value * 10)``. For each group with at least two
    points, a boundary row ``y(x)`` is linearly interpolated (and extrapolated) through
    the group's points. Every column is then filled from that row down to the bottom of
    the grid with ``key / 10``. Groups are processed in ascending key order, so a later
    group overwrites cells below its boundary. Within a column, a boundary is never
    placed above the start row recorded for the previously written group.

    Args:
        size: (height, width) of the output grid.
        params: rows of exactly three elements ``(y, x, value)``; the elements must
            support ``.item()``, so presumably an (N, 3) tensor. NOTE(review): confirm
            against callers.
        grid: unused; accepted for a uniform builder interface.

    Returns:
        A (1, height, width) float tensor; cells no layer reaches stay 0.
    """
    height, width = size
    output = torch.full((1, height, width), 0.0)  # Start with all zeros

    # Group points by int(value * 10), i.e. the value truncated to a 0.1 step.
    # Each point is stored as (x, y, value) so that it can be sorted along x below.
    layers = {}
    for y, x, value in params:
        layer = int(value*10)
        if layer not in layers:
            layers[layer] = []
        layers[layer].append((x.item(), y.item(), value.item()))

    # Process groups in ascending key order
    sorted_layers = sorted(layers.items(), key=lambda x: x[0])

    # Per column: the start row written by the most recent group (initially row 0).
    # A later boundary is clamped so it never starts above this row.
    highest_filled = torch.zeros(width, dtype=torch.long)

    # Build one linear boundary y(x) per group
    for layer_value, points in sorted_layers:
        if len(points) < 2:
            continue  # Need at least 2 points for interpolation
        points.sort(key=lambda p: p[0])  # Sort points by x-coordinate
        x_coords, y_coords, values = zip(*points)

        # Create a linear interpolation (extrapolated beyond the end points)
        spline = interp1d(x_coords, y_coords, kind='linear', bounds_error=False, fill_value='extrapolate')

        # Fill each column from the boundary row down to the bottom of the grid
        for x in range(width):
            y_intersect = int(spline(x))  # int() truncates toward zero
            y_start = max(y_intersect, highest_filled[x])
            if y_start < height:
                output[0, y_start:, x] = layer_value / 10
                highest_filled[x] = y_start
            # Boundaries at or below the last row leave the column untouched

    return output

def create_mps_spatial_data(size, params, grid):
    """Simulate a binary facies field with MPSlib (GENESIM) conditioned on the points.

    Each point becomes hard data at ``(params[:, 0], params[:, 1], 0)`` with indicator
    1 where its value (column 2) is below 0.5, and 0 otherwise. The training image is
    MPSlib's Strebelle image. Relies on the module-level ``mps`` bound by the install
    cell, and writes the parameter file ``mps.txt``.

    Args:
        size: (height, width) of the simulation grid.
        params: (N, >=3) tensor or array of ``(y, x, value, ...)`` rows.
        grid: unused; kept for a uniform builder interface.

    Returns:
        A (1, height, width) float32 tensor built from the first realisation.
    """
    height, width = size

    # Initialize MPSlib with GENESIM method
    O = mps.mpslib(method='mps_genesim', simulation_grid_size=np.array([ height, width, 1]), verbose_level=0, debug_level=-1)

    # Set up MPSlib parameters
    O.parameter_filename = 'mps.txt'
    O.par['n_real'] = 1
    O.par['n_cond'] = 25
    O.par['template_size'] = np.array([[10, 5], [10, 5], [1, 1]])

    points = params.detach().cpu().numpy() if isinstance(params, torch.Tensor) else np.asarray(params)

    # Hard data columns: coordinate 1, coordinate 2, z (=0), binary facies indicator
    indicator = (points[:, 2] < 0.5).astype(float)
    hard_data = np.column_stack((points[:, 0], points[:, 1], np.zeros(points.shape[0]), indicator))
    O.d_hard = hard_data

    # Generate or load a training image
    TI, _ = mps.trainingimages.strebelle(di=2, coarse3d=1)
    O.ti = TI

    # Run the simulation
    O.run()

    # Get the simulated result
    sim = O.sim[0].squeeze()

    return torch.tensor(sim, dtype=torch.float32).unsqueeze(0)

def plot_sample(x: torch.Tensor, param_grid: torch.Tensor, param_mask: torch.Tensor, mask: Optional[torch.Tensor] = None):
    """Plot one sample: one figure row per channel of ``x``.

    The columns are: the field, then (only when ``mask`` is given) the field overlaid with
    the sampled cells, then the parameter grid, then the parameter grid overlaid with the
    conditioning points. Point locations are taken from channel 0 of ``mask`` and
    ``param_mask``.

    Args:
        x: (C, H, W) field. Tensors are moved to the CPU and converted to numpy.
        param_grid: per-channel parameter grid; channel ``i`` is drawn in row ``i``.
        param_mask: mask whose channel 0 marks the conditioning points.
        mask: optional sampling mask whose channel 0 marks the sampled cells.
    """
    def as_array(value):
        return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

    n_rows = x.shape[0]
    with_mask = mask is not None
    fig, axs = plt.subplots(n_rows, 4 if with_mask else 3,
                            figsize=(20 if with_mask else 15, 5 * n_rows))
    plt.subplots_adjust(wspace=0.3, hspace=0.3)
    if n_rows == 1:
        axs = [axs]

    x = as_array(x)
    param_grid = as_array(param_grid)
    param_mask = as_array(param_mask)

    # Point locations are the same for every channel, so compute them once
    param_hits = param_mask[0] > 0
    param_rc = np.argwhere(param_hits)
    if with_mask:
        mask = as_array(mask)
        sample_hits = mask[0] > 0
        sample_rc = np.argwhere(sample_hits)

    for row in range(n_rows):
        field = x[row]
        ax_field = axs[row][0]
        ax_field.imshow(field, cmap='viridis', interpolation='nearest')
        ax_field.set_title('Spatial Data')
        ax_field.axis('off')

        if with_mask:
            ax_sample = axs[row][1]
            ax_sample.imshow(field, cmap='viridis', interpolation='nearest', alpha=0.5)
            ax_sample.scatter(sample_rc[:, 1], sample_rc[:, 0], c=field[sample_hits], cmap='viridis', edgecolors='black', s=50)
            ax_sample.set_title('Sampling Mask')
            ax_sample.axis('off')

        pgrid = param_grid[row]
        ax_grid = axs[row][-2]
        ax_grid.imshow(pgrid, cmap='viridis', interpolation='nearest')
        ax_grid.set_title('Parameter Grid')
        ax_grid.axis('off')

        ax_points = axs[row][-1]
        ax_points.imshow(pgrid, cmap='viridis', interpolation='nearest', alpha=0.5)
        ax_points.scatter(param_rc[:, 1], param_rc[:, 0], c=pgrid[param_hits], cmap='viridis', edgecolors='black', s=50)
        ax_points.set_title('Parameter Mask')
        ax_points.axis('off')

    plt.tight_layout()
    plt.show()

# Datasets
class SpatialGenerator():
    """Generate multi-channel synthetic fields from randomly drawn point data.

    ``param_generator()`` must return an (N, 2 + C) tensor of ``(y, x, v_1, ..., v_C)``
    rows. Output channel ``i`` is built from column ``2 + i`` with ``methods[i]``; when
    there are more channels than methods, the last method is reused.

    Supported method names (case-insensitive, surrounding whitespace ignored): 'id',
    'vid', 'cvid', 'nn', 'kriging', 'stationary', 'bstationary', 'layers' (also
    'layered', the spelling CategoricalSpatialGenerator uses) and 'mps'.
    """

    def __init__(self, size, param_generator, methods=None):
        """
        Args:
            size: (height, width) of the generated grids.
            param_generator: zero-argument callable returning the point tensor.
            methods: one method name or a list/tuple of names (default ``['kriging']``).

        Raises:
            ValueError: if ``methods`` is empty.
        """
        self.size = size
        self.param_generator = param_generator
        if methods is None:
            methods = ['kriging']
        elif isinstance(methods, str):
            methods = [methods]
        # Copy so later edits to the caller's list (or a shared default) cannot leak in
        self.methods = list(methods)
        if not self.methods:
            raise ValueError("SpatialGenerator needs at least one interpolation method")
        self.grid = create_grid(size)

    def generate_item(self):
        """Draw fresh parameters and build the fields.

        Returns:
            ``(x, param_grid, param_mask)``: the (C, H, W) fields plus the outputs of
            ``create_param_grid`` for the same parameters.
        """
        # Generate parameters using the provided callable
        params = self.param_generator()

        num_channels = params.shape[1] - 2  # columns after (y, x) are value channels
        x = torch.zeros(num_channels, self.size[0], self.size[1])
        for i in range(num_channels):
            method = self.methods[min(i, len(self.methods) - 1)].lower().strip()
            channel_params = params[:, [0, 1, i + 2]]
            x[i] = self._interpolate(method, channel_params)

        param_grid, param_mask = create_param_grid(self.size, params)

        return x, param_grid, param_mask

    def _interpolate(self, method, channel_params):
        """Dispatch one channel to its field builder; returns a (1, H, W) tensor."""
        size, grid = self.size, self.grid
        builders = {
            'id': lambda: create_id_spatial_data(size, channel_params, grid),
            'vid': lambda: create_vid_spatial_data(size, channel_params, grid),
            'cvid': lambda: create_cvid_spatial_data(size, channel_params, grid),
            'nn': lambda: create_nn_spatial_data(size, channel_params, grid),
            'kriging': lambda: create_kriging_spatial_data(size, channel_params, grid),
            'stationary': lambda: create_stationary_spatial_data(size, channel_params, grid),
            'bstationary': lambda: create_base_stationary_spatial_data(size, grid),
            'layers': lambda: create_layered_spatial_data(size, channel_params, grid),
            'layered': lambda: create_layered_spatial_data(size, channel_params, grid),
            'mps': lambda: create_mps_spatial_data(size, channel_params, grid),
        }
        if method not in builders:
            raise ValueError(f"Unsupported interpolation method: {method}")
        return builders[method]()

class CategoricalSpatialGenerator():
    """Generate a categorical field plus per-category continuous fields.

    Parameters from ``param_generator()`` (rows ``(y, x, v_1, ...)``) are min-max
    normalised per value column. Channel 0 of the output is a category map: column 2 is
    interpolated with ``methods[0]`` and binned into ``num_categories`` equal bins.
    Every further channel ``i`` is filled category by category. For each category,
    column ``i + 2`` of only those points whose own column-2 value falls in that
    category is interpolated with ``methods[i]``, and written where the category map
    equals that category.
    """

    def __init__(self, size, param_generator, num_categories, methods):
        """
        Args:
            size: (height, width) of the generated grids.
            param_generator: zero-argument callable returning an (N, >=3) tensor.
            num_categories: number of bins for the category map.
            methods: list of method names; ``methods[0]`` builds the category map.
        """
        self.size = size
        self.param_generator = param_generator
        self.num_categories = num_categories
        self.methods = methods
        self.grid = create_grid(size)

    def generate_item(self):
        """Draw fresh parameters and build ``(x, param_grid, param_mask)``.

        ``x`` and ``param_grid`` are (len(methods), H, W) and receive identical values.
        ``param_mask`` is (1, H, W) and is 1 at every in-bounds point.
        """
        params = self.normalize_data(self.param_generator())
        height, width = self.size

        x = torch.zeros(len(self.methods), height, width)
        param_grid = torch.zeros(len(self.methods), height, width)
        param_mask = torch.zeros(1, height, width)

        # Category map from the first value column
        category_map = self.create_categorical_data(params).squeeze()
        x[0] = category_map
        param_grid[0] = category_map

        # Mark every point whose (truncated) cell lies inside the grid
        cells = params[:, :2].long()
        rows, cols = cells[:, 0], cells[:, 1]
        inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
        param_mask[0, rows[inside], cols[inside]] = 1

        # Fill the remaining channels category by category
        point_categories = self.continuous_to_categorical(params[:, 2])
        for category in range(self.num_categories):
            in_category = point_categories == category
            if not bool(in_category.any()):
                continue
            cat_params = params[in_category]
            region = category_map == category
            for i, method in enumerate(self.methods[1:], start=1):
                channel_params = cat_params[:, [0, 1, i + 2]]
                interpolated_values = self.interpolate(method, channel_params)
                x[i] = torch.where(region, interpolated_values, x[i])
                param_grid[i] = torch.where(region, interpolated_values, param_grid[i])

        return x, param_grid, param_mask

    def normalize_data(self, params):
        """Return a copy of ``params`` with each value column (2 onward) min-max scaled.

        The caller's tensor is not modified. A constant column maps to 0, because the
        1e-6 added to the denominator avoids division by zero.
        """
        params = params.clone()
        values = params[:, 2:]
        min_val = values.min(0, keepdim=True)[0]
        max_val = values.max(0, keepdim=True)[0]
        params[:, 2:] = (values - min_val) / (max_val - min_val + 1e-6)
        return params

    def create_categorical_data(self, params):
        """Interpolate column 2 with ``methods[0]`` and bin it into categories."""
        first_method = self.methods[0]
        continuous_data = self.interpolate_method(first_method, params[:, :3])
        return self.continuous_to_categorical(continuous_data)

    def continuous_to_categorical(self, x):
        """Map values in [0, 1] to integer bins ``0 .. num_categories - 1`` (clamped)."""
        return torch.clamp((x * self.num_categories).long(), 0, self.num_categories - 1)

    def interpolate(self, method, channel_params):
        """Build one field with ``method``.

        'id', 'vid' and 'cvid' results have their leading channel axis dropped,
        giving (H, W). The other methods return (1, H, W). Both shapes broadcast in
        ``generate_item``.
        """
        result = self.interpolate_method(method, channel_params)
        return result[0] if method in ('id', 'vid', 'cvid') else result

    def interpolate_method(self, method, channel_params):
        """Dispatch to the field builder named ``method``; returns a (1, H, W) tensor.

        Raises:
            ValueError: for an unknown method name.
        """
        size, grid = self.size, self.grid
        builders = {
            'id': lambda: create_id_spatial_data(size, channel_params, grid),
            'vid': lambda: create_vid_spatial_data(size, channel_params, grid),
            'cvid': lambda: create_cvid_spatial_data(size, channel_params, grid),
            'nn': lambda: create_nn_spatial_data(size, channel_params, grid),
            'kriging': lambda: create_kriging_spatial_data(size, channel_params, grid),
            'layered': lambda: create_layered_spatial_data(size, channel_params, grid),
            'layers': lambda: create_layered_spatial_data(size, channel_params, grid),
            'mps': lambda: create_mps_spatial_data(size, channel_params, grid),
            'stationary': lambda: create_stationary_spatial_data(size, channel_params, grid),
            'bstationary': lambda: create_base_stationary_spatial_data(size, grid),
        }
        if method not in builders:
            raise ValueError(f"Unsupported interpolation method: {method}")
        return builders[method]()


def param_generator():
    """Draw 10 random points for a 32 x 64 grid.

    Columns 0 and 1 are (row, col) coordinates in [0, 32) x [0, 64). Columns 2-4 all
    hold the same uniform [0, 1) value per point. (An earlier independent-columns
    variant was immediately overwritten and never used, so it has been removed.)

    Returns:
        A (10, 5) float tensor.
    """
    coords = torch.rand(10, 2) * torch.tensor([32, 64])
    values = torch.rand(10, 1).repeat(1, 3)
    return torch.cat([coords, values], dim=1)


size = (32, 64)

#x, param_grid, param_mask = SpatialGenerator(size, param_generator, methods=['stationary', 'kriging', 'mps']).generate_item()
#plot_sample(x, param_grid, param_mask)

#num_categories = 20
#x, param_grid, param_mask = CategoricalSpatialGenerator(size, param_generator, num_categories, methods=['mps', 'stationary']).generate_item()
#plot_sample(x, param_grid, param_mask)


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import random

class BaseModel:
    """Simple geological model: an ordered list of deposits (layers) with values.

    Each deposit is a sequence of ``(x, y)`` points plus a value specification: a single
    number, or a 1-D sequence/array of anchor values spread evenly along the points.
    """

    def __init__(self):
        self.layers = []

    def add_deposit(self, points, value):
        """Append a deposit and return ``self`` so calls can be chained."""
        self.layers.append((points, value))
        return self

    def _distribute_values(self, points, value):
        """Return one value per point.

        A scalar (Python or numpy number, or a 0-d array) is repeated. A non-empty 1-D
        list/tuple/array is linearly interpolated from evenly spaced anchors onto the
        points.

        Raises:
            ValueError: for any other kind of ``value``.
        """
        num_points = len(points)
        if isinstance(value, (int, float, np.integer, np.floating)):
            return [value] * num_points
        if isinstance(value, np.ndarray) and value.ndim == 0:
            return [value.item()] * num_points
        if isinstance(value, (tuple, list, np.ndarray)):
            anchors = np.asarray(value, dtype=float)
            if anchors.ndim != 1 or anchors.size == 0:
                raise ValueError("Value must be a number or an array of numbers")
            return np.interp(np.linspace(0, 1, num_points), np.linspace(0, 1, anchors.size), anchors)
        raise ValueError("Value must be a number or an array of numbers")

    def generate_primarys(self):
        """Flatten all deposits into a float32 tensor of ``(y, x, layer_index, value)`` rows.

        An empty model yields shape (0, 4), so callers can still slice columns.
        """
        primarys = [
            (y, x, layer, v)
            for layer, (points, value) in enumerate(self.layers)
            for (x, y), v in zip(points, self._distribute_values(points, value))
        ]
        if not primarys:
            return torch.empty(0, 4, dtype=torch.float32)
        return torch.tensor(primarys, dtype=torch.float32)

    def plot_model(self):
        """Scatter every deposit coloured by value, with y increasing downwards."""
        fig, ax = plt.subplots(figsize=(10, 6))
        for i, (points, value) in enumerate(self.layers):
            x_coords, y_coords = zip(*points)
            distributed_values = self._distribute_values(points, value)
            scatter = ax.scatter(x_coords, y_coords, c=distributed_values, cmap='viridis', s=50, edgecolors='black')
            ax.plot(x_coords, y_coords, '-', alpha=0.5)
            plt.colorbar(scatter, ax=ax, label=f'Layer {i}')
        ax.set_ylim(ax.get_ylim()[::-1])  # Invert y-axis
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_title('Geological Model Layers')
        ax.grid(True)
        plt.tight_layout()
        plt.show()

def generate_points(num_points, width, elevations, noise_factor=0.1):
    """Sample ``num_points`` evenly spaced points along a noisy elevation profile.

    With more than two ``elevations``, they act as control points evenly spaced over
    [0, width] and are linearly interpolated. Otherwise a straight line runs from the
    first elevation to the last. Gaussian noise with std ``noise_factor * (max - min)``
    of the elevations is added, using numpy's global RNG.

    Returns:
        A list of ``(x, y)`` tuples.
    """
    xs = np.linspace(0, width, num_points)

    if len(elevations) <= 2:
        ys = np.linspace(elevations[0], elevations[-1], num_points)
    else:
        knots = np.linspace(0, width, len(elevations))
        ys = np.interp(xs, knots, elevations)

    spread = max(elevations) - min(elevations)
    ys = ys + np.random.normal(0, noise_factor * spread, num_points)

    return list(zip(xs, ys))

def sine_vals(num_values, start=0, amp=40, noise=0.1):
    """Sample one period of a sine wave with random amplitude and phase, plus noise.

    Args:
        num_values: number of samples, evenly spaced over [0, 2*pi].
        start: constant offset added to every sample.
        amp: upper bound on the amplitude, which is drawn as ``amp * U(0, 1)``.
        noise: standard deviation of the additive Gaussian noise (now honoured; it
            used to be ignored in favour of a fixed 0.1, which remains the default).

    Returns:
        A numpy array of ``num_values`` floats. Uses numpy's global RNG.
    """
    x = np.linspace(0, 2 * np.pi, num_values)
    amplitude = amp * np.random.rand()
    phase = np.random.rand() * 2 * np.pi
    y = amplitude * np.sin(x + phase) + np.random.normal(0, noise, num_values) + start
    return y

def generate_constrained_values(num_values, max_change=0.1):
    """Bounded random walk in [0, 1].

    Starts at ``random.random()``. Each step adds ``uniform(-max_change, max_change)``
    and clamps the result into [0, 1]. At least one value is always returned, even when
    ``num_values`` < 1. Uses Python's ``random`` module.
    """
    walk = [random.random()]
    while len(walk) < num_values:
        step = random.uniform(-max_change, max_change)
        walk.append(min(1, max(0, walk[-1] + step)))
    return walk

def two_layer_generator(width=64, height=32, min_layers=2, max_layers=2, min_points=5, max_points=15):
    """Draw random layered point data from a ``BaseModel``.

    The first deposit is a flat line at y = 0 (no noise). Each further deposit ``i``
    follows a noisy sine offset by ``amp * i - amp / 2``, with ``amp = height / layers``.
    Every deposit gets a bounded random walk of 2-5 anchor values in [0, 1], spread
    along its points.

    Args:
        width, height: model extent in x and y (defaults match the 64 x 32 grid).
        min_layers, max_layers: inclusive bounds on the number of deposits (default
            always 2, hence the name).
        min_points, max_points: inclusive bounds on the number of points per deposit.

    Returns:
        The tensor from ``BaseModel.generate_primarys``: rows of
        ``(y, x, layer_index, value)``.
    """
    model = BaseModel()
    layers = random.randint(min_layers, max_layers)
    points = random.randint(min_points, max_points)
    amp = height / layers

    # First layer always at 0
    first_layer_points = generate_points(points, width, [0, 0], noise_factor=0)
    random_values = generate_constrained_values(random.randint(2, 5))
    model.add_deposit(first_layer_points, random_values)

    for i in range(1, layers):
        random_values = generate_constrained_values(random.randint(2, 5))
        model.add_deposit(
            generate_points(points, width, sine_vals(points, amp * i - amp/2, amp=amp), noise_factor=0.05),
            random_values
        )

    return model.generate_primarys()

# Example usage:
# Grid shape (height, width) used by the demo generators below.
size = (32, 64)

# Create the generator
#generator = CategoricalSpatialGenerator(size, lambda: two_layer_generator(), num_categories=10, methods=['layered', 'vid'])

# Generate the item
#x, primary_grid, primary_mask = generator.generate_item()

# Plot the sample
#plot_sample(x, primary_grid, primary_mask)


In [None]:
import os, random, torch, numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm.notebook import tqdm

# Grid sampling pattern
def grid_sampling(size, fraction=0.1):
    """Regular lattice mask sampling roughly ``fraction`` of the cells.

    One cell is kept every ``step`` rows and columns, with
    ``step = floor(sqrt(cells / target))``. The step is derived from the total cell
    count, so non-square grids also get about ``fraction`` coverage. Deriving it from
    ``min(size)`` gave about 25% on a 32 x 64 grid. At least one point is always
    targeted, so tiny grids no longer divide by zero.

    Args:
        size: (height, width).
        fraction: target share of sampled cells (default 10%).

    Returns:
        A (1, height, width) float mask.
    """
    height, width = size
    num_points = max(1, int(fraction * height * width))
    step = max(1, int((height * width / num_points) ** 0.5))
    mask = torch.zeros(1, *size)
    mask[0, ::step, ::step] = 1
    return mask

# Clustered sampling pattern
def clustered_sampling(size, fraction=0.1, num_clusters=5):
    """Mask with points scattered around a few random cluster centres.

    About ``fraction`` of the cells are targeted, split evenly over ``num_clusters``
    Gaussian blobs. Each blob's standard deviation is one tenth of the grid extent along
    each axis. Points are clamped into the grid and may coincide, so the realised count
    can be lower. Uses torch's global RNG.

    Args:
        size: (height, width).
        fraction: target share of sampled cells (default 10%).
        num_clusters: number of blobs (default 5).

    Returns:
        A (1, height, width) float mask.

    Raises:
        ValueError: if ``num_clusters`` < 1.
    """
    if num_clusters < 1:
        raise ValueError("num_clusters must be at least 1")
    num_points = int(fraction * size[0] * size[1])
    mask = torch.zeros(1, *size)
    points_per_cluster = num_points // num_clusters
    for _ in range(num_clusters):
        center = torch.tensor([torch.randint(0, size[0], (1,)).item(),
                               torch.randint(0, size[1], (1,)).item()])
        cluster_points = torch.randn(points_per_cluster, 2) * torch.tensor([size[0], size[1]]) / 10 + center
        cluster_points[:, 0] = cluster_points[:, 0].clamp(0, size[0] - 1)
        cluster_points[:, 1] = cluster_points[:, 1].clamp(0, size[1] - 1)
        cluster_points = cluster_points.long()
        mask[0, cluster_points[:, 0], cluster_points[:, 1]] = 1
    return mask


class SpatialDataset(Dataset):
    """Map-style dataset of synthetic spatial samples, optionally cached as pickles on disk.

    Each item is ``(x, primary_grid, primary_mask, secondary_grid, secondary_mask)``.
    ``generator.generate_item()`` supplies the first three, ``secondary_grid_fn(x)``
    derives the secondary grid, and ``sampling_fn`` (called with
    ``secondary_grid.shape[1:]``) draws the secondary mask.

    Args:
        num_generations: items to generate at construction. With a ``data_folder``,
            0 means "only load what is already on disk".
        generator: object whose ``generate_item()`` returns ``(x, primary_grid, primary_mask)``.
        sampling_fn: callable returning a sampling mask for a given spatial shape.
        secondary_grid_fn: callable mapping ``x`` to the secondary grid.
        data_folder: when given, new items are pickled into this folder and then *every*
            ``.pkl`` file in it is loaded, so the dataset accumulates across runs.
        dynamic_secondary_mask: redraw the secondary mask on every ``__getitem__``.
        x_channels, primary_channels, secondary_channels: optional channel selection.
            An int keeps a singleton channel axis; a list/tuple selects several channels.
    """

    def __init__(self, num_generations, generator, sampling_fn, secondary_grid_fn, data_folder=None, dynamic_secondary_mask=False, x_channels=None, primary_channels=None, secondary_channels=None):
        self.num_generations = num_generations
        self.generator = generator
        self.sampling_fn = sampling_fn
        self.secondary_grid_fn = secondary_grid_fn
        self.data_folder = data_folder
        self.dynamic_secondary_mask = dynamic_secondary_mask
        self.x_channels = x_channels
        self.primary_channels = primary_channels
        self.secondary_channels = secondary_channels

        if data_folder is not None:
            os.makedirs(self.data_folder, exist_ok=True)
            if num_generations > 0:
                self._generate_and_save_entries()
            self.data = self._load_all_entries()
        else:
            self.data = self._generate_items()

    def __len__(self):
        return len(self.data)

    def _generate_and_save_entries(self):
        """Generate ``num_generations`` items and pickle them into a new file."""
        import uuid  # stdlib; only needed to make the file name unique

        entries = self._generate_items()
        # The timestamp keeps files roughly chronological. The random suffix stops two
        # datasets created within the same second from overwriting each other's file.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f'entries_{stamp}_{uuid.uuid4().hex[:8]}.pkl'
        # 'xb' refuses to clobber an existing file instead of silently replacing it.
        with open(os.path.join(self.data_folder, filename), 'xb') as f:
            pickle.dump(entries, f)

    def _load_all_entries(self):
        """Load and concatenate every ``.pkl`` file in ``data_folder``, in file-name order.

        SECURITY: ``pickle.load`` can execute arbitrary code, so only point
        ``data_folder`` at files you created yourself.
        """
        entries = []
        for file in sorted(os.listdir(self.data_folder)):  # sorted: deterministic order
            if file.endswith('.pkl'):
                with open(os.path.join(self.data_folder, file), 'rb') as f:
                    entries.extend(pickle.load(f))
        return entries

    def _generate_items(self):
        return [self._generate_item() for _ in tqdm(range(self.num_generations), desc="Generating items", mininterval=1.0)]

    def _generate_item(self):
        x, primary_grid, primary_mask = self.generator.generate_item()
        secondary_grid = self.secondary_grid_fn(x)
        secondary_mask = self.sampling_fn(secondary_grid.shape[1:])

        return x, primary_grid, primary_mask, secondary_grid, secondary_mask

    def __getitem__(self, idx):
        x, primary_grid, primary_mask, secondary_grid, saved_secondary_mask = self.data[idx]

        if self.dynamic_secondary_mask:
            secondary_mask = self.sampling_fn(secondary_grid.shape[1:])
        else:
            secondary_mask = saved_secondary_mask

        x = self._apply_channel_selection(x, self.x_channels)
        primary_grid = self._apply_channel_selection(primary_grid, self.primary_channels)
        secondary_grid = self._apply_channel_selection(secondary_grid, self.secondary_channels)

        return x, primary_grid, primary_mask, secondary_grid, secondary_mask

    def _apply_channel_selection(self, tensor, channel_selection):
        """Select channels along axis 0. An int keeps the axis (slice of length 1)."""
        if channel_selection is not None:
            if isinstance(channel_selection, int):
                return tensor[channel_selection:channel_selection+1]
            elif isinstance(channel_selection, (list, tuple)):
                return tensor[list(channel_selection)]
            else:
                raise ValueError("channel_selection must be an integer, list, or tuple")
        return tensor

# Compute device for later cells: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def secondary_grid_fn(x):
    """Secondary variable that is the complement of the primary field: ``1 - x``."""
    return -x + 1

def equal_grid_fn(x):
    """Secondary variable identical to the primary field (identity)."""
    secondary = x
    return secondary

def random_sampling(size, min_samples=3, fraction=0.01):
    """Bernoulli random mask covering about ``fraction`` of the cells, with a minimum count.

    Cells are kept with probability ``max(min_samples, fraction * cells) / cells``.
    Extra random cells are then OR-ed in until at least ``min_samples`` are set.
    ``min_samples`` is capped at the number of cells, so the top-up loop always
    terminates (it previously spun forever when ``min_samples`` exceeded the grid size).

    Args:
        size: (height, width).
        min_samples: minimum number of sampled cells.
        fraction: target share of sampled cells (default 1%).

    Returns:
        A (1, height, width) float mask; all zeros for an empty grid.
    """
    total_points = size[0] * size[1]
    if total_points == 0:
        return torch.zeros(1, *size)
    min_samples = min(min_samples, total_points)
    target_samples = max(min_samples, int(fraction * total_points))

    mask = (torch.rand(1, *size) < (target_samples / total_points)).float()

    # Ensure we have at least min_samples
    while torch.sum(mask) < min_samples:
        additional_samples = torch.rand(1, *size) < (min_samples - torch.sum(mask)) / total_points
        mask = torch.logical_or(mask, additional_samples).float()

    return mask

def drilling_sampling(size, min_drillholes=5, max_drillholes=15, min_samples=3, max_samples=20):
    """Mask imitating vertical drill holes: random columns, each sampled at random depths.

    Draws ``randint(min_drillholes, max_drillholes)`` holes. Each hole picks a random
    column and ``min(randint(min_samples, max_samples), height)`` distinct rows in it.
    Several holes may hit the same column. Uses Python's ``random`` module.

    Returns:
        A (1, height, width) float mask with 1 at the sampled cells.
    """
    height, width = size
    sampled = torch.zeros(1, *size)
    n_holes = random.randint(min_drillholes, max_drillholes)

    for _hole in range(n_holes):
        column = random.randint(0, width - 1)
        depth_count = min(random.randint(min_samples, max_samples), height)
        for row in random.sample(range(height), depth_count):
            sampled[0, row, column] = 1

    return sampled
# Build a small on-disk dataset from the two-layer geological generator and inspect one batch.
# NOTE: SpatialDataset writes a new pickle into ./data on every run and then loads *all*
# pickles there, so len(dataset) grows each time this cell is re-executed.
generator = CategoricalSpatialGenerator(size, two_layer_generator, num_categories=10, methods=['layered', 'vid'])
dataset = SpatialDataset(1, generator, drilling_sampling, secondary_grid_fn, data_folder="data", dynamic_secondary_mask=False, x_channels=1, secondary_channels=1, primary_channels=1)
# Pinned host memory only helps with a CUDA device (and warns without one).
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
print(len(dataset))

x, primary_grid, primary_mask, secondary_grid, secondary_mask = next(iter(dataloader))
print(f"Shapes: x: {x.shape}, primary_grid: {primary_grid.shape}, primary_mask: {primary_mask.shape}, secondary_grid: {secondary_grid.shape}, secondary_mask: {secondary_mask.shape}")

plot_sample(x[0], primary_grid[0], primary_mask[0])
plot_sample(secondary_grid[0], secondary_grid[0] * secondary_mask[0], secondary_mask[0])

## Sparse Autoencoder Model

In [None]:
"""
This implementation is based on the SparK (Sparse masKed modeling) approach
from the paper "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
by Tian et al. (https://arxiv.org/pdf/2301.03580.pdf).

Modifications:
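- Changed _get_active_ex_or_ii to derive each lower-resolution mask by max-pooling the full-resolution input mask (the original SparK code upsamples a coarse mask with repeat_interleave), so a feature cell stays active whenever any observed input cell falls inside it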
- Integrated the sparse masking strategy with the ConvNeXt architecture
"""

import torch
import torch.nn as nn
from timm.models.layers import DropPath

_cur_active: torch.Tensor = None

def _get_active_ex_or_ii(H, W, returning_active_ex=True):
    cur_H, cur_W = _cur_active.shape[-2:]

    if H != cur_H or W != cur_W:
        scale_factor = cur_H // H
        active_ex = F.max_pool2d(_cur_active, kernel_size=scale_factor, stride=scale_factor)
        active_ex = (active_ex > 0.5).float()
    else:
        active_ex = _cur_active

    return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True)


def sp_conv_forward(self, x: torch.Tensor):
    """Sparse forward installed on SparseConv2d / SparseMaxPooling / SparseAvgPooling.

    Runs the parent layer's dense forward, then zeroes every output position that is
    inactive at the output's resolution, so masked cells stay exactly zero.

    ``super(type(self), self)`` resolves to the dense parent class. If one of these
    sparse classes were subclassed again, it would resolve back to this function and recurse.
    """
    dense_out = super(type(self), self).forward(x)
    out_h, out_w = dense_out.shape[2], dense_out.shape[3]
    keep = _get_active_ex_or_ii(H=out_h, W=out_w, returning_active_ex=True)
    dense_out *= keep  # (B, C, H, W) *= (B, 1, H, W): broadcast over channels, in place
    return dense_out


def sp_bn_forward(self, x: torch.Tensor):
    """Sparse batch-norm forward installed on SparseBatchNorm2d / SparseSyncBatchNorm2d.

    Gathers the channel vectors at active positions into an (N_active, C) matrix and
    normalizes them with the parent batch-norm, so the batch statistics ignore masked
    cells. The results are then scattered back into an all-zero (B, C, H, W) tensor.
    """
    active_idx = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
    channels_last = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
    normalized = super(type(self), self).forward(channels_last[active_idx])  # (N_active, C)

    scattered = torch.zeros_like(channels_last)
    scattered[active_idx] = normalized
    return scattered.permute(0, 3, 1, 2)  # back to (B, C, H, W)


class SparseConv2d(nn.Conv2d):
    """nn.Conv2d whose output is zeroed at inactive positions (forward is `sp_conv_forward`)."""
    forward = sp_conv_forward   # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseMaxPooling(nn.MaxPool2d):
    """nn.MaxPool2d whose output is zeroed at inactive positions (forward is `sp_conv_forward`).

    NOTE(review): `sp_conv_forward` reads `.shape` of the parent's output, so this
    assumes return_indices=False (with True the parent returns a tuple).
    """
    forward = sp_conv_forward   # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseAvgPooling(nn.AvgPool2d):
    """nn.AvgPool2d whose output is zeroed at inactive positions (forward is `sp_conv_forward`)."""
    forward = sp_conv_forward   # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseBatchNorm2d(nn.BatchNorm1d):
    """Mask-aware replacement for nn.BatchNorm2d.

    It subclasses BatchNorm1d despite the name, because `sp_bn_forward` feeds the parent
    a flattened (N_active, C) matrix of the features at active positions.
    """
    forward = sp_bn_forward     # hack: override the forward function; see `sp_bn_forward` above for more details


class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """SyncBatchNorm variant of SparseBatchNorm2d (forward is `sp_bn_forward`)."""
    forward = sp_bn_forward     # hack: override the forward function; see `sp_bn_forward` above for more details


class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r"""Channel LayerNorm for both ConvNeXt layouts, with optional sparse (masked) mode.

    Args:
        normalized_shape: number of channels C.
        eps: value added to the variance for numerical stability.
        data_format: ``"channels_last"`` for (B, H, W, C) input or ``"channels_first"``
            for (B, C, H, W) input; anything else raises NotImplementedError.
        sparse: if True, 4-D inputs are normalized only at active positions (read from
            the global active mask) and every inactive position is set to zero.

    Inputs that are not 4-D (e.g. (B, C) or (B, L, C)) use plain LayerNorm over the last
    dimension. In sparse mode they raise NotImplementedError.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def _normalize_active(self, bhwc):
        """LayerNorm only the active positions of a (B, H, W, C) tensor; zero all others."""
        idx = _get_active_ex_or_ii(H=bhwc.shape[1], W=bhwc.shape[2], returning_active_ex=False)
        normalized = super().forward(bhwc[idx])
        result = torch.zeros_like(bhwc)
        result[idx] = normalized
        return result

    def forward(self, x):
        if x.ndim != 4:
            # (B, L, C) or (B, C): no spatial grid, so there is no sparse variant.
            if self.sparse:
                raise NotImplementedError
            return super().forward(x)

        if self.data_format == "channels_last":  # (B, H, W, C)
            return self._normalize_active(x) if self.sparse else super().forward(x)

        # channels_first: (B, C, H, W)
        if self.sparse:
            return self._normalize_active(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x_hat + self.bias[:, None, None]

    def __repr__(self):
        layout = self.data_format.split("_")[-1]
        return super().__repr__()[:-1] + f', ch={layout}, sp={self.sparse})'


class SparseConvNeXtBlock(nn.Module):
    r"""ConvNeXt residual block with optional sparse masking.

    Branch: depthwise ks x ks conv -> LayerNorm (channels_last) -> Linear(dim, 4*dim)
    -> GELU -> Linear(4*dim, dim) -> optional layer scale ``gamma`` -> [mask]
    -> DropPath. The branch output is added to the block input.

    Args:
        dim: channels in and out.
        drop_path: stochastic-depth rate; 0 uses nn.Identity instead of DropPath.
        layer_scale_init_value: initial value of ``gamma``; a value <= 0 disables layer scale.
        sparse: if True, the LayerNorm only normalizes active positions and the branch
            output is zeroed at inactive positions.
        ks: depthwise kernel size; padding ``ks // 2`` keeps H and W unchanged for odd ks.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        hidden = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks // 2, groups=dim)  # depthwise
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        # The pointwise (1x1) convs are Linear layers applied over the trailing channel dim.
        self.pwconv1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        shortcut = x
        branch = self.dwconv(x).permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(branch))))
        if self.gamma is not None:
            branch = self.gamma * branch
        branch = branch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        if self.sparse:
            # The Linear biases make the branch non-zero everywhere, so re-apply the mask here.
            # Inactive cells then receive no update and keep their input value.
            branch *= _get_active_ex_or_ii(H=branch.shape[2], W=branch.shape[3], returning_active_ex=True)

        return shortcut + self.drop_path(branch)

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'


class SparseEncoder(nn.Module):
    """Wrap a dense CNN backbone, replacing its layers with mask-aware sparse equivalents.

    Args:
        cnn: backbone exposing ``get_downsample_ratio()``, ``get_feature_map_channels()``
            and ``forward(x, hierarchical=True)``, which returns a list of feature maps.
        input_size: spatial input size, stored as ``self.input_size``.
        sbn: convert BatchNorm layers to SparseSyncBatchNorm2d instead of SparseBatchNorm2d.
        verbose: passed through the recursive conversion; not otherwise used.
    """

    def __init__(self, cnn, input_size, sbn=False, verbose=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
        self.input_size = input_size
        self.downsample_ratio = cnn.get_downsample_ratio()
        # Misspelled alias kept so existing code that reads `downsample_raito` keeps working.
        self.downsample_raito = self.downsample_ratio
        self.enc_feat_map_chs = cnn.get_feature_map_channels()

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        """Recursively replace dense layers of ``m`` with their sparse counterparts.

        Conv2d, MaxPool2d, AvgPool2d, BatchNorm2d/SyncBatchNorm and plain LayerNorm layers
        are rebuilt with the same hyper-parameters, and their weights and buffers are
        copied. Any other module is kept, and its children are converted in place.

        Returns:
            The converted module.

        Raises:
            NotImplementedError: for nn.Conv1d.
        """
        oup = m
        if isinstance(m, nn.Conv2d):
            has_bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=has_bias, padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if has_bias:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation,
                                   return_indices=m.return_indices, ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode,
                                   count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            bn_cls = SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d
            # num_features also works when affine=False (where m.weight is None).
            oup = bn_cls(m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                         track_running_stats=m.track_running_stats)
            # Copy only what exists: affine=False has no weight/bias, and
            # track_running_stats=False has no running buffers.
            if m.affine:
                oup.weight.data.copy_(m.weight.data)
                oup.bias.data.copy_(m.bias.data)
            if m.track_running_stats:
                oup.running_mean.data.copy_(m.running_mean.data)
                oup.running_var.data.copy_(m.running_var.data)
                oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            oup = SparseConvNeXtLayerNorm(m.normalized_shape[0], eps=m.eps)
            # A LayerNorm without affine params behaves like the new layer's identity
            # init (weight=1, bias=0), so copy only the params that exist.
            if m.weight is not None:
                oup.weight.data.copy_(m.weight.data)
            if m.bias is not None:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.Conv1d):
            raise NotImplementedError

        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        return oup

    def forward(self, x):
        """Return the backbone's hierarchical feature maps for ``x``."""
        return self.sp_cnn(x, hierarchical=True)

In [None]:
"""
This implementation is based on the ConvNeXt architecture
from the paper "A ConvNet for the 2020s"
by Liu et al. (https://arxiv.org/pdf/2201.03545.pdf).

Modifications:
- Adapted to use a flexible number of stages, controlled by the lengths of 'depths' and 'dims' lists
- Integrated with the SparK sparse masking strategy from Tian et al. (https://arxiv.org/pdf/2301.03580.pdf)
- Added a custom Decoder class for upsampling and reconstruction

The ConvNeXt architecture is combined with the sparse masking approach
from SparK (Tian et al., https://arxiv.org/pdf/2301.03580.pdf)
to create a sparse, hierarchical masked modeling framework for convolutional networks.
"""

from typing import List
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from timm.models import register_model
import torch.nn.functional as F
import math


class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 1
        num_classes (int): Number of classes for the classification head; a value <= 0
            replaces the final norm and head with nn.Identity. Default: 1000
        depths (sequence of int): Number of blocks at each stage; its length sets the
            number of stages. Default: [3, 3, 9, 3]
        dims (sequence of int): Feature dimension at each stage; needs at least
            ``len(depths)`` entries. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate, increased linearly over all blocks. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Accepted for API compatibility; currently unused. Default: 1.
        global_pool (str): Accepted for API compatibility; currently unused. Default: 'avg'
        sparse (bool): Build sparse LayerNorms / blocks that read the global active mask. Default: True

    Raises:
        ValueError: if ``depths`` is empty or ``dims`` is shorter than ``depths``.
    """

    def __init__(self, in_chans=1, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
                 sparse=True,
                 ):
        super().__init__()
        # Copy so the module never shares (or mutates) the caller's lists or the list defaults.
        depths, dims = list(depths), list(dims)
        num_stages = len(depths)
        if num_stages == 0 or len(dims) < num_stages:
            raise ValueError(f"ConvNeXt needs at least one stage and len(dims) >= len(depths); "
                             f"got depths={depths}, dims={dims}")
        self.dims: List[int] = dims
        self.depths = depths
        self.num_stages = num_stages

        # Stem (4x4, stride 4) followed by one 2x2 stride-2 downsampling layer per extra stage.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # One feature-resolution stage per entry in `depths`, each a stack of residual blocks.
        self.stages = nn.ModuleList()
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        dp_rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(num_stages):
            stage = nn.Sequential(
                *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
                                      layer_scale_init_value=layer_scale_init_value, sparse=sparse)
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Initialise the backbone only; the head created below keeps PyTorch's default init.
        self.apply(self._init_weights)
        last_dim = dims[num_stages - 1]
        if num_classes > 0:
            self.norm = SparseConvNeXtLayerNorm(last_dim, eps=1e-6, sparse=False)  # final norm layer for LE/FT; should not be sparse
            self.fc = nn.Linear(last_dim, num_classes)
        else:
            self.norm = nn.Identity()
            self.fc = nn.Identity()

    def _init_weights(self, m):
        """Truncated-normal init for Conv2d/Linear weights; zero biases (when present)."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def get_downsample_ratio(self) -> int:
        """Total stride of the last stage: 4 from the stem, times 2 for each later stage."""
        return 4 * 2 ** (self.num_stages - 1)

    def get_feature_map_channels(self) -> List[int]:
        """Channels of each feature map returned by ``forward(x, hierarchical=True)``."""
        return self.dims[:self.num_stages]

    def _forward_stages(self, x):
        """Run downsample layer i then stage i for every stage; return each stage's output."""
        feats = []
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
            feats.append(x)
        return feats

    def forward(self, x, hierarchical=False):
        """Return per-stage feature maps (shallow to deep) if ``hierarchical``, else logits.

        The non-hierarchical path runs the full backbone, global-average-pools the last
        stage, and applies the final norm and head:
        (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls).
        """
        feats = self._forward_stages(x)
        if hierarchical:
            return feats
        return self.fc(self.norm(feats[-1].mean([-2, -1])))

    def get_classifier(self):
        return self.fc

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'

def is_pow2n(x):
    """Return True iff ``x`` is a positive integer power of two (1, 2, 4, 8, ...)."""
    if not x > 0:
        return False
    # A power of two has a single set bit, so clearing its lowest set bit leaves zero.
    return (x & (x - 1)) == 0

class UNetBlock(nn.Module):
    """Decoder block: 2x transposed-conv upsampling followed by two 3x3 conv + norm layers.

    Args:
        cin: input channels (kept through the upsampling and the first conv).
        cout: output channels.
        bn2d: normalization layer class, called as ``bn2d(num_channels)``.
    """

    def __init__(self, cin, cout, bn2d):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
        self.up_sample = nn.ConvTranspose2d(cin, cin, kernel_size=4, stride=2, padding=1, bias=True)
        refine_layers = [
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cin),
            nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            bn2d(cout),
        ]
        self.conv = nn.Sequential(*refine_layers)

    def forward(self, x):
        return self.conv(self.up_sample(x))

class Decoder(nn.Module):
    """Upsampling decoder that fuses a list of encoder feature maps, deepest first.

    Args:
        up_sample_ratio: total upsampling factor; must be a power of two. The decoder has
            ``log2(up_sample_ratio)`` UNetBlocks, each doubling the resolution.
        out_chans: channels produced by the final 1x1 projection.
        width: channels entering the first block. The channel count halves after every
            block; this is a simple rule and can be replaced.
        sbn: use nn.SyncBatchNorm (True) or nn.BatchNorm2d (False) inside the blocks.
    """

    def __init__(self, up_sample_ratio, out_chans=1, width=768, sbn=True):
        super().__init__()
        self.width = width
        assert is_pow2n(up_sample_ratio)
        num_blocks = round(math.log2(up_sample_ratio))
        # Channel schedule: width, width/2, ..., width/2**num_blocks.
        channels = [self.width // 2 ** i for i in range(num_blocks + 1)]
        norm_cls = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
        self.dec = nn.ModuleList(
            UNetBlock(c_in, c_out, norm_cls) for c_in, c_out in zip(channels, channels[1:])
        )
        self.proj = nn.Conv2d(channels[-1], out_chans, kernel_size=1, stride=1, bias=True)

        self.initialize()

    def forward(self, to_dec: List[torch.Tensor]):
        """Decode ``to_dec`` (deepest feature map first).

        Before block i runs, ``to_dec[i]`` (when present and not None) is added to the
        running tensor. Extra blocks beyond ``len(to_dec)`` just upsample further.
        """
        x = 0
        for i, block in enumerate(self.dec):
            skip = to_dec[i] if i < len(to_dec) else None
            if skip is not None:
                x = x + skip
            x = block(x)
        return self.proj(x)

    def extra_repr(self) -> str:
        return f'width={self.width}'

    def initialize(self):
        """Initialise all submodules in ``self.modules()`` order.

        - Linear and Conv2d weights: truncated normal (std 0.02).
        - ConvTranspose2d weights: Kaiming normal (fan_out, relu).
        - Norm layers: weight 1, bias 0.
        - Biases of the conv/linear layers: 0.
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)



In [None]:
import torch.optim as optim

def train_model(sparse_encoder, imputation_decoder, subsurface_decoder, dataloader, num_epochs=10, learning_rate=1e-4, device="cuda", save_model=False, model_save_path="model.pth", use_secondary=True, use_input_mask=True, visualize=True, input_reconstruction_weight=10.0, use_sparse=True):
    """Train the sparse encoder and decoder(s) with AdamW and an MSE loss.

    Each batch must unpack as ``(primary_grid, realization_input_grid,
    realization_input_mask, secondary_grid, random_sampling_mask)``.

    The global ``_cur_active`` mask read by the sparse layers is set per batch to:
    - the realization input mask, when ``use_input_mask`` is True;
    - otherwise the random sampling mask;
    - all ones, when ``use_sparse`` is False.

    The encoder input is ``secondary_grid`` (``use_secondary``) or ``primary_grid``,
    multiplied by that mask.

    Per-batch loss is MSE(primary_output, primary_grid), plus MSE(secondary_output,
    secondary_grid) when ``use_secondary`` is True. The masked "input reconstruction" MSE
    is computed and reported but NOT added to the loss.
    NOTE(review): ``input_reconstruction_weight`` is currently unused. Confirm whether the
    reconstruction term was meant to be added as ``input_reconstruction_weight * loss``.

    Returns:
        ``(sparse_encoder, imputation_decoder or None, subsurface_decoder, training_losses)``.
        ``training_losses`` holds one dict of epoch-average losses per epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    global _cur_active  # mask consumed by the sparse layers

    sparse_encoder.to(device)
    subsurface_decoder.to(device)
    sparse_encoder.train()
    subsurface_decoder.train()

    if use_secondary:
        imputation_decoder.to(device)
        imputation_decoder.train()
        optimizer = optim.AdamW(list(sparse_encoder.parameters()) + list(imputation_decoder.parameters()) + list(subsurface_decoder.parameters()), lr=learning_rate)
    else:
        optimizer = optim.AdamW(list(sparse_encoder.parameters()) + list(subsurface_decoder.parameters()), lr=learning_rate)

    criterion = nn.MSELoss()

    training_losses = []

    for epoch in range(num_epochs):
        total_primary_loss = 0
        total_secondary_loss = 0
        total_imput_reconstruction_loss = 0
        num_batches = 0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            primary_grid, realization_input_grid, realization_input_mask, secondary_grid, random_sampling_mask = batch
            batch_primary_grid = primary_grid.to(device)
            batch_realization_input_grid = realization_input_grid.to(device)
            batch_realization_input_mask = realization_input_mask.to(device)
            batch_secondary_grid = secondary_grid.to(device).float()
            batch_random_sampling_mask = random_sampling_mask.to(device)

            optimizer.zero_grad()

            # Choose which cells the sparse layers treat as observed.
            if not use_sparse:
                _cur_active = torch.ones_like(batch_realization_input_mask)
            elif use_input_mask:
                _cur_active = batch_realization_input_mask
            else:
                _cur_active = batch_random_sampling_mask

            source_grid = batch_secondary_grid if use_secondary else batch_primary_grid
            model_input = source_grid * _cur_active

            features = sparse_encoder(model_input)
            # Reverse the encoder's shallow-to-deep list so the deepest map comes first.
            decoder_features = features[::-1]

            primary_output = subsurface_decoder(decoder_features)
            primary_loss = criterion(primary_output, batch_primary_grid)
            total_primary_loss += primary_loss.item()

            if use_secondary:
                secondary_output = imputation_decoder(decoder_features)
                secondary_loss = criterion(secondary_output, batch_secondary_grid)
                input_reconstruction_loss = criterion(secondary_output * _cur_active, batch_secondary_grid * _cur_active)
                total_secondary_loss += secondary_loss.item()
            else:
                input_reconstruction_loss = criterion(primary_output * _cur_active, batch_primary_grid * _cur_active)
                secondary_loss = 0

            total_imput_reconstruction_loss += input_reconstruction_loss.item()
            loss = secondary_loss + primary_loss

            loss.backward()
            optimizer.step()

            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_model: dataloader yielded no batches")

        avg_primary_loss = total_primary_loss / num_batches
        avg_imput_reconstruction_loss = total_imput_reconstruction_loss / num_batches

        if use_secondary:
            avg_secondary_loss = total_secondary_loss / num_batches
            training_losses.append({'epoch': epoch+1, 'secondary_loss': avg_secondary_loss, 'primary_loss': avg_primary_loss})
            print(f"Epoch {epoch+1}/{num_epochs}, Average Secondary Loss: {avg_secondary_loss:.10f}, Average Primary Loss: {avg_primary_loss:.10f}, Average Imput Reconstruction Loss: {avg_imput_reconstruction_loss:.10f}")
        else:
            training_losses.append({'epoch': epoch+1, 'primary_loss': avg_primary_loss})
            print(f"Epoch {epoch+1}/{num_epochs}, Average Primary Loss: {avg_primary_loss:.10f}, Average Imput Reconstruction Loss: {avg_imput_reconstruction_loss:.10f}")

        if visualize:
            # Plots the last batch of the epoch.
            with torch.no_grad():
                visualize_results(batch_realization_input_grid[0], "realization_input_grid", batch_realization_input_mask[0], "realization_input_mask", batch_random_sampling_mask[0], "random_sampling_mask")
                visualize_results(batch_primary_grid[0], "primary_grid", primary_output[0], "primary_output")
                if use_secondary:
                    visualize_results(batch_secondary_grid[0], "secondary_grid", secondary_output[0], "secondary_output")

    if save_model:
        model_dict = {
            'sparse_encoder_state_dict': sparse_encoder.state_dict(),
            'subsurface_decoder_state_dict': subsurface_decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'use_secondary': use_secondary,
            'use_sparse': use_sparse,
            'training_losses': training_losses
        }
        if use_secondary:
            model_dict['imputation_decoder_state_dict'] = imputation_decoder.state_dict()
        torch.save(model_dict, model_save_path)
        print(f"Model saved to {model_save_path}")

    return sparse_encoder, imputation_decoder if use_secondary else None, subsurface_decoder, training_losses

def load_model(model_path, sparse_encoder, imputation_decoder, subsurface_decoder, device="cuda"):
    """Restore weights from a checkpoint written by ``train_model(save_model=True)``.

    The given modules are loaded in place and returned. A missing ``use_secondary`` or
    ``use_sparse`` flag in the checkpoint defaults to True, and a missing
    ``training_losses`` defaults to None.

    Returns:
        ``(sparse_encoder, imputation_decoder or None, subsurface_decoder,
        use_secondary, use_sparse, training_losses)``.

    Raises:
        ValueError: the checkpoint includes an imputation decoder but
            ``imputation_decoder`` is None.
    """
    ckpt = torch.load(model_path, map_location=device)

    sparse_encoder.load_state_dict(ckpt['sparse_encoder_state_dict'])
    subsurface_decoder.load_state_dict(ckpt['subsurface_decoder_state_dict'])
    losses = ckpt.get('training_losses', None)
    has_imputation = ckpt.get('use_secondary', True)
    sparse_flag = ckpt.get('use_sparse', True)

    if not has_imputation:
        if imputation_decoder is not None:
            print("Warning: imputation_decoder provided but model was saved without imputation. Ignoring imputation_decoder.")
        return sparse_encoder, None, subsurface_decoder, has_imputation, sparse_flag, losses

    if imputation_decoder is None:
        raise ValueError("Model was saved with imputation, but no imputation_decoder was provided.")
    imputation_decoder.load_state_dict(ckpt['imputation_decoder_state_dict'])
    return sparse_encoder, imputation_decoder, subsurface_decoder, has_imputation, sparse_flag, losses

# Test

In [None]:
import matplotlib.colors as mcolors

def evaluate_model(sparse_encoder, subsurface_decoder, dataloader, device="cuda", use_sparse=False, use_secondary=False, use_input_mask=True):
    """Evaluate the encoder + subsurface decoder on ``dataloader`` without gradients.

    Batches unpack like in ``train_model``. Mask selection (global ``_cur_active``) and
    the choice of encoder input follow the same ``use_sparse`` / ``use_input_mask`` /
    ``use_secondary`` rules.

    Returns:
        avg_loss: MSE against ``primary_grid``, averaged per sample. Each batch's MSE is
            weighted by its batch size, so a smaller final batch is not over-weighted.
            When all grids share one shape this equals the element-wise MSE over the
            whole dataset.
        avg_variance: mean over batches of the variance of each batch's predictions.
        total_prediction_variance: variance of all predictions pooled together.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    global _cur_active  # mask consumed by the sparse layers

    sparse_encoder.to(device)
    subsurface_decoder.to(device)
    sparse_encoder.eval()
    subsurface_decoder.eval()

    criterion = nn.MSELoss()
    weighted_loss_sum = 0.0
    num_samples = 0
    total_variance = 0.0
    num_batches = 0
    all_outputs = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            primary_grid, _realization_input_grid, realization_input_mask, secondary_grid, random_sampling_mask = batch
            batch_primary_grid = primary_grid.to(device)
            batch_realization_input_mask = realization_input_mask.to(device)
            batch_secondary_grid = secondary_grid.to(device).float()
            batch_random_sampling_mask = random_sampling_mask.to(device)

            if not use_sparse:
                _cur_active = torch.ones_like(batch_realization_input_mask)
            elif use_input_mask:
                _cur_active = batch_realization_input_mask
            else:
                _cur_active = batch_random_sampling_mask

            source_grid = batch_secondary_grid if use_secondary else batch_primary_grid
            model_input = source_grid * _cur_active

            features = sparse_encoder(model_input)
            primary_output = subsurface_decoder(features[::-1])

            # Keep outputs on CPU for the pooled variance computed after the loop.
            all_outputs.append(primary_output.cpu())

            batch_size = primary_output.shape[0]
            weighted_loss_sum += criterion(primary_output, batch_primary_grid).item() * batch_size
            num_samples += batch_size

            total_variance += torch.var(primary_output).item()
            num_batches += 1

    if num_batches == 0 or num_samples == 0:
        raise ValueError("evaluate_model: dataloader yielded no samples")

    avg_loss = weighted_loss_sum / num_samples
    avg_variance = total_variance / num_batches

    total_prediction_variance = torch.var(torch.cat(all_outputs, dim=0)).item()

    return avg_loss, avg_variance, total_prediction_variance

# Define the process_batch function
def process_batch(sparse_encoder, subsurface_decoder, batch, device="cuda", use_secondary=False, use_sparse=False, use_input_mask=True):
    """Run a single batch through the encoder and subsurface decoder (eval mode, no grad).

    ``batch`` unpacks as ``(primary_grid, realization_input_grid, realization_input_mask,
    secondary_grid, random_sampling_mask)``. The global ``_cur_active`` mask is set to:
    - the input mask, when ``use_input_mask`` is True;
    - otherwise the random mask;
    - all ones, when ``use_sparse`` is False.

    The encoder sees ``secondary_grid`` (``use_secondary``) or ``primary_grid``,
    multiplied by that mask.

    Returns:
        CPU tensors ``(model_input, primary_output, mask_used, primary_grid)``.
    """
    global _cur_active

    sparse_encoder.eval()
    subsurface_decoder.eval()
    with torch.no_grad():
        primary_grid, realization_input_grid, realization_input_mask, secondary_grid, random_sampling_mask = batch
        primary_dev = primary_grid.to(device)
        realization_input_grid.to(device)
        input_mask_dev = realization_input_mask.to(device)
        secondary_dev = secondary_grid.to(device).float()
        random_mask_dev = random_sampling_mask.to(device)

        if not use_sparse:
            _cur_active = torch.ones_like(input_mask_dev).to(device)
            current_mask = torch.ones_like(input_mask_dev).to(device)
        elif use_input_mask:
            _cur_active = current_mask = input_mask_dev
        else:
            _cur_active = current_mask = random_mask_dev

        model_input = (secondary_dev if use_secondary else primary_dev) * _cur_active
        features = sparse_encoder(model_input)
        primary_output = subsurface_decoder(features[::-1])

        return model_input.cpu(), primary_output.cpu(), current_mask.cpu(), primary_dev.cpu()

def plot_collected_images(collected_images, images_per_row=4):
    """Plot predicted grids with the observed (conditioning) points overlaid.

    Each entry is a dict with:
    - ``'input_grid'``, ``'output_grid'`` and ``'input_mask'``, each indexed with ``[0]``
      and squeezed to 2-D;
    - ``'title'``;
    - optionally ``'original_image'``.

    If any entry has an original image, every entry gets two panels, original (left) and
    prediction (right).

    The observed points (mask > 0) are drawn as scatter markers. They use the SAME colour
    range as the prediction image, spanning both the predictions and the observed values,
    so a point's colour can be compared directly with the background. An entry that fails
    to render shows the error text instead of aborting the whole figure.
    Does nothing when ``collected_images`` is empty.
    """
    num_images = len(collected_images)
    if num_images == 0:
        return  # nothing to draw (and would otherwise divide by zero below)
    # Double the columns if we have original images to show them side by side
    has_originals = any('original_image' in data for data in collected_images)
    cols_per_item = 2 if has_originals else 1
    images_per_row = max(1, min(images_per_row, num_images))
    num_rows = (num_images + images_per_row - 1) // images_per_row

    # squeeze=False always returns a 2-D array of Axes, so flatten() works for any grid.
    fig, axes = plt.subplots(num_rows, images_per_row * cols_per_item,
                             figsize=(5 * images_per_row * cols_per_item, 5 * num_rows),
                             squeeze=False)
    axes = axes.flatten()

    for idx, data in enumerate(collected_images):
        # Calculate base index for this item's axes
        base_idx = idx * cols_per_item

        try:
            input_grid = data['input_grid'][0].squeeze().cpu().numpy()
            output_grid = data['output_grid'][0].squeeze().cpu().numpy()
            input_mask = data['input_mask'][0].squeeze().cpu().numpy()
            title = data['title']

            # Plot original image if available
            if has_originals:
                ax_orig = axes[base_idx]
                if 'original_image' in data:
                    original = data['original_image'][0].squeeze().cpu().numpy()
                    im_orig = ax_orig.imshow(original, cmap='viridis', interpolation='nearest')
                    ax_orig.set_title(f"{title}\n(Original)")
                    plt.colorbar(im_orig, ax=ax_orig)
                else:
                    ax_orig.text(0.5, 0.5, 'No original image',
                               ha='center', va='center', transform=ax_orig.transAxes)
                ax_orig.axis('off')

                # Use next axis for prediction
                ax = axes[base_idx + 1]
            else:
                ax = axes[base_idx]

            # Observed points as (row, col) indices and their values in the input grid.
            mask_indices = np.argwhere(input_mask > 0)
            mask_values = input_grid[tuple(mask_indices.T)] if len(mask_indices) > 0 else None

            # One colour range shared by background and points, so their colours are comparable.
            vmin, vmax = np.nanmin(output_grid), np.nanmax(output_grid)
            if mask_values is not None:
                vmin = min(vmin, np.nanmin(mask_values))
                vmax = max(vmax, np.nanmax(mask_values))

            # Plot the predicted values as background
            im = ax.imshow(output_grid, cmap='viridis', interpolation='nearest', vmin=vmin, vmax=vmax)

            if mask_values is not None:
                # scatter takes x=column, y=row.
                ax.scatter(mask_indices[:, 1],
                           mask_indices[:, 0],
                           c=mask_values,
                           cmap='viridis',
                           vmin=vmin,
                           vmax=vmax,
                           edgecolors='black',
                           s=50)

            ax.set_title(f"{title}\n(Prediction)" if has_originals else title)
            ax.axis('off')
            plt.colorbar(im, ax=ax)

        except Exception as e:
            # Best effort: show the error in place of this entry and keep plotting the rest.
            current_ax = axes[base_idx]
            current_ax.text(0.5, 0.5, f'Error processing image {idx}\n{str(e)}',
                          ha='center', va='center', transform=current_ax.transAxes, color='red')
            current_ax.axis('off')

            # If we're showing originals, clear the second axis too
            if has_originals and base_idx + 1 < len(axes):
                axes[base_idx + 1].axis('off')

    # Remove any empty subplots
    for idx in range(num_images * cols_per_item, len(axes)):
        fig.delaxes(axes[idx])

    plt.tight_layout()
    plt.show()

# Non-sparse model

In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Training-set sizes to sweep, and the cumulative epoch checkpoints at which
# each model is evaluated (training resumes from the previous checkpoint).
data_ranges = [100,200,400,800,1600,3200,6400,12800]
epochs = [1,2,4,8,16,32,64,128]

# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All reproduces the
# same point clouds. The point generators below draw from torch's RNG directly;
# SpatialGenerator / SpatialDataset are assumed to use these global RNGs too
# (TODO confirm). Reuse the same SEED in every experiment cell so different
# model variants are compared on the same samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# One dict per (data_size, epochs) checkpoint; converted to a DataFrame later.
results = []

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Point-cloud generators: each call draws N points whose first two columns are
# uniform in [0, 64) and whose third column is uniform in [0, 1).
# torch.randint's upper bound is exclusive, so N lies in [3, 49] for
# interpolation, [51, 69] for close extrapolation and [70, 89] for far
# extrapolation.
interpolation_generator = lambda: torch.rand(torch.randint(3, 50, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_close_generator = lambda: torch.rand(torch.randint(51, 70, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_far_generator = lambda: torch.rand(torch.randint(70, 90, (1,)).item(), 3) * torch.tensor([64, 64, 1])

test_generators = {
    'interpolation': interpolation_generator,
    'extrapolation_close': extrapolation_close_generator,
    'extrapolation_far': extrapolation_far_generator
}

# Every test split has 1000 samples; only the number of points per sample differs.
test_datasets = {}
for name, gen in test_generators.items():
    generator_obj = SpatialGenerator((64, 64), gen, methods=['kriging'])
    test_dataset = SpatialDataset(1000, generator_obj, drilling_sampling, lambda x: x,
                                  dynamic_secondary_mask=True, x_channels=0,
                                  secondary_channels=0, primary_channels=0)
    test_datasets[name] = test_dataset

# Unshuffled loaders, so every evaluation walks each split in the same order.
test_dataloaders = {name: DataLoader(dataset, batch_size=128, shuffle=False,
                                     num_workers=0, pin_memory=True) for name, dataset in test_datasets.items()}

# Prediction snapshots gathered during training, plotted at the end of the cell.
collected_images = []
use_sparse=False
use_input_mask=True

# CUDA kernels run asynchronously, so wall-clock timings must synchronise the
# device before reading the clock; on CPU this is a no-op.
_sync_device = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)

for data_range in data_ranges:
    print(f"\nStarting training with data size: {data_range}")

    # Training samples come from the interpolation generator, so the
    # 'interpolation' test split matches the training distribution.
    generator = SpatialGenerator((64, 64), interpolation_generator, methods=['kriging'])
    train_dataset = SpatialDataset(data_range, generator, drilling_sampling, lambda x: x,
                                   dynamic_secondary_mask=True, x_channels=0,
                                   secondary_channels=0, primary_channels=0)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                  num_workers=0, pin_memory=True)

    # Fresh models for every training-set size.
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(64, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)

    # Progressive training: checkpoints in `epochs` are cumulative, so only the
    # difference from the previous checkpoint is trained each time.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        # perf_counter is monotonic and high-resolution, unlike time.time().
        _sync_device()
        training_start_time = time.perf_counter()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_data{data_range}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        _sync_device()
        training_time = time.perf_counter() - training_start_time

        # Presumably the loss of the last epoch in this segment -- confirm
        # against train_model's return value.
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate the model on all three test splits, timing each one.
        test_metrics = {}
        evaluation_times = {}
        for name, dataloader in test_dataloaders.items():
            _sync_device()
            evaluation_start_time = time.perf_counter()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            _sync_device()
            evaluation_times[name] = time.perf_counter() - evaluation_start_time
            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }

        results.append({
            'data_size': data_range,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance'],
            'training_points_range': '3-50'
        })

        current_epoch = target_epoch

        # Snapshot predictions on one (shuffled) training batch ...
        train_batch = next(iter(train_dataloader))
        train_input, train_output, train_mask = process_batch(
            sparse_encoder, subsurface_decoder, train_batch, device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Size: {data_range}, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask
        })

        # ... and on the first batch of each test split. The pre-built loaders
        # are unshuffled and configured like a fresh DataLoader would be, so
        # reusing them gives the same batch without rebuilding a loader.
        for name, test_dataloader in test_dataloaders.items():
            test_batch = next(iter(test_dataloader))
            test_input, test_output, test_mask = process_batch(
                sparse_encoder, subsurface_decoder, test_batch, device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask
            })

# Tabulate one row per (data_size, epochs) checkpoint.
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name (plt.style.use('seaborn') raises OSError there); only
# fall back to the old name on versions that predate the rename.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# 1. Test losses vs training-set size, one line per (split, epoch checkpoint)
ax1 = axes[0, 0]
test_types = {
    'test_loss_interpolation': 'Interpolation (3-50)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-70)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-90)'
}

for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['data_size'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax1.set_xlabel('Training Data Size')
ax1.set_ylabel('Test Loss')
ax1.set_title('Test Losses vs Training Data Size\n(Training Range: 3-50 points)')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test losses vs epochs, one line per (split, training-set size)
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for data_size in data_ranges:
        size_data = df_results[df_results['data_size'] == data_size]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} (Size {data_size})', marker='o', alpha=0.7)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Losses vs Epochs\n(Training Range: 3-50 points)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Separate figure: heatmaps of training loss and of each test loss
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {
    'training_loss': 'Training Loss',
    **test_types
}
for idx, (test_type, label) in enumerate(test_types_with_training.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_heatmaps[idx])
    axes_heatmaps[idx].set_title(f'{label}\nHeatmap')
fig_heatmaps.tight_layout()

# 4. Distribution of test losses per split, across all checkpoints
ax4 = axes[1, 1]
box_data = [df_results[test_type] for test_type in test_types.keys()]
# Tick labels are set separately: boxplot's `labels=` keyword was renamed
# `tick_labels` in Matplotlib 3.9, while set_xticklabels works on all versions.
ax4.boxplot(box_data)
ax4.set_xticklabels([label.replace('\n', ' ') for label in test_types.values()])
ax4.set_ylabel('Test Loss')
ax4.set_title('Distribution of Test Losses by Type\n(Training Range: 3-50 points)')
ax4.grid(True)

# 5. Extrapolation / interpolation loss ratios (1.0 means equal performance)
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                     ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['data_size'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xlabel('Training Data Size')
ax3.set_ylabel('Loss Ratio')
ax3.set_title('Extrapolation Performance Relative to Interpolation\n(Training Range: 3-50 points)')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Separate figure: heatmaps of the total variance for each test split
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for idx, (variance_type, label) in enumerate(variance_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_variance_heatmaps[idx])
    axes_variance_heatmaps[idx].set_title(f'{label}\nHeatmap')
fig_variance_heatmaps.tight_layout()

# 7. Training loss heatmap
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')

# 8. Training time heatmap
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_xlabel('Epochs')
ax8.set_ylabel('Training Data Size')

# 9. Separate figure: evaluation time heatmaps for each test split
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for idx, (eval_time_type, label) in enumerate(evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=axes_eval_time_heatmaps[idx])
    axes_eval_time_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_eval_time_heatmaps[idx].set_xlabel('Epochs')
    axes_eval_time_heatmaps[idx].set_ylabel('Training Data Size')
fig_eval_time_heatmaps.tight_layout()

# plt.tight_layout() only acts on the current figure, i.e. the most recently
# created one (the evaluation-time heatmaps); lay out the 3x2 summary figure
# explicitly instead.
fig.tight_layout()
plt.show()

# Plot the prediction snapshots collected during training.
plot_collected_images(collected_images, images_per_row=4)



In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Training-set sizes to sweep, and the cumulative epoch checkpoints at which
# each model is evaluated (training resumes from the previous checkpoint).
data_ranges = [100,200,400,800,1600,3200,6400,12800]
epochs = [1,2,4,8,16,32,64,128]

# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All reproduces the
# same point clouds. The point generators below draw from torch's RNG directly;
# SpatialGenerator / SpatialDataset are assumed to use these global RNGs too
# (TODO confirm). Reuse the same SEED in every experiment cell so different
# model variants are compared on the same samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# One dict per (data_size, epochs) checkpoint; converted to a DataFrame later.
results = []

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Point-cloud generators: each call draws N points whose first two columns are
# uniform in [0, 64) and whose third column is uniform in [0, 1).
# torch.randint's upper bound is exclusive, so N lies in [3, 49] for
# interpolation, [51, 69] for close extrapolation and [70, 89] for far
# extrapolation.
interpolation_generator = lambda: torch.rand(torch.randint(3, 50, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_close_generator = lambda: torch.rand(torch.randint(51, 70, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_far_generator = lambda: torch.rand(torch.randint(70, 90, (1,)).item(), 3) * torch.tensor([64, 64, 1])

test_generators = {
    'interpolation': interpolation_generator,
    'extrapolation_close': extrapolation_close_generator,
    'extrapolation_far': extrapolation_far_generator
}

# Every test split has 1000 samples; only the number of points per sample differs.
test_datasets = {}
for name, gen in test_generators.items():
    generator_obj = SpatialGenerator((64, 64), gen, methods=['kriging'])
    test_dataset = SpatialDataset(1000, generator_obj, drilling_sampling, lambda x: x,
                                  dynamic_secondary_mask=True, x_channels=0,
                                  secondary_channels=0, primary_channels=0)
    test_datasets[name] = test_dataset

# Unshuffled loaders, so every evaluation walks each split in the same order.
test_dataloaders = {name: DataLoader(dataset, batch_size=128, shuffle=False,
                                     num_workers=0, pin_memory=True) for name, dataset in test_datasets.items()}

# Prediction snapshots gathered during training, plotted at the end of the cell.
collected_images = []
use_sparse=True
use_input_mask=True

# CUDA kernels run asynchronously, so wall-clock timings must synchronise the
# device before reading the clock; on CPU this is a no-op.
_sync_device = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)

for data_range in data_ranges:
    print(f"\nStarting training with data size: {data_range}")

    # Training samples come from the interpolation generator, so the
    # 'interpolation' test split matches the training distribution.
    generator = SpatialGenerator((64, 64), interpolation_generator, methods=['kriging'])
    train_dataset = SpatialDataset(data_range, generator, drilling_sampling, lambda x: x,
                                   dynamic_secondary_mask=True, x_channels=0,
                                   secondary_channels=0, primary_channels=0)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                  num_workers=0, pin_memory=True)

    # Fresh models for every training-set size.
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(64, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)

    # Progressive training: checkpoints in `epochs` are cumulative, so only the
    # difference from the previous checkpoint is trained each time.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        # perf_counter is monotonic and high-resolution, unlike time.time().
        _sync_device()
        training_start_time = time.perf_counter()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_data{data_range}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        _sync_device()
        training_time = time.perf_counter() - training_start_time

        # Presumably the loss of the last epoch in this segment -- confirm
        # against train_model's return value.
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate the model on all three test splits, timing each one.
        test_metrics = {}
        evaluation_times = {}
        for name, dataloader in test_dataloaders.items():
            _sync_device()
            evaluation_start_time = time.perf_counter()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            _sync_device()
            evaluation_times[name] = time.perf_counter() - evaluation_start_time
            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }

        results.append({
            'data_size': data_range,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance'],
            'training_points_range': '3-50'
        })

        current_epoch = target_epoch

        # Snapshot predictions on one (shuffled) training batch ...
        train_batch = next(iter(train_dataloader))
        train_input, train_output, train_mask = process_batch(
            sparse_encoder, subsurface_decoder, train_batch, device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Size: {data_range}, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask
        })

        # ... and on the first batch of each test split. The pre-built loaders
        # are unshuffled and configured like a fresh DataLoader would be, so
        # reusing them gives the same batch without rebuilding a loader.
        for name, test_dataloader in test_dataloaders.items():
            test_batch = next(iter(test_dataloader))
            test_input, test_output, test_mask = process_batch(
                sparse_encoder, subsurface_decoder, test_batch, device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask
            })

# Tabulate one row per (data_size, epochs) checkpoint.
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name (plt.style.use('seaborn') raises OSError there); only
# fall back to the old name on versions that predate the rename.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# 1. Test losses vs training-set size, one line per (split, epoch checkpoint)
ax1 = axes[0, 0]
test_types = {
    'test_loss_interpolation': 'Interpolation (3-50)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-70)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-90)'
}

for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['data_size'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax1.set_xlabel('Training Data Size')
ax1.set_ylabel('Test Loss')
ax1.set_title('Test Losses vs Training Data Size\n(Training Range: 3-50 points)')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test losses vs epochs, one line per (split, training-set size)
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for data_size in data_ranges:
        size_data = df_results[df_results['data_size'] == data_size]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} (Size {data_size})', marker='o', alpha=0.7)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Losses vs Epochs\n(Training Range: 3-50 points)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Separate figure: heatmaps of training loss and of each test loss
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {
    'training_loss': 'Training Loss',
    **test_types
}
for idx, (test_type, label) in enumerate(test_types_with_training.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_heatmaps[idx])
    axes_heatmaps[idx].set_title(f'{label}\nHeatmap')
fig_heatmaps.tight_layout()

# 4. Distribution of test losses per split, across all checkpoints
ax4 = axes[1, 1]
box_data = [df_results[test_type] for test_type in test_types.keys()]
# Tick labels are set separately: boxplot's `labels=` keyword was renamed
# `tick_labels` in Matplotlib 3.9, while set_xticklabels works on all versions.
ax4.boxplot(box_data)
ax4.set_xticklabels([label.replace('\n', ' ') for label in test_types.values()])
ax4.set_ylabel('Test Loss')
ax4.set_title('Distribution of Test Losses by Type\n(Training Range: 3-50 points)')
ax4.grid(True)

# 5. Extrapolation / interpolation loss ratios (1.0 means equal performance)
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                     ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['data_size'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xlabel('Training Data Size')
ax3.set_ylabel('Loss Ratio')
ax3.set_title('Extrapolation Performance Relative to Interpolation\n(Training Range: 3-50 points)')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Separate figure: heatmaps of the total variance for each test split
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for idx, (variance_type, label) in enumerate(variance_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_variance_heatmaps[idx])
    axes_variance_heatmaps[idx].set_title(f'{label}\nHeatmap')
fig_variance_heatmaps.tight_layout()

# 7. Training loss heatmap
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')

# 8. Training time heatmap
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_xlabel('Epochs')
ax8.set_ylabel('Training Data Size')

# 9. Separate figure: evaluation time heatmaps for each test split
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for idx, (eval_time_type, label) in enumerate(evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=axes_eval_time_heatmaps[idx])
    axes_eval_time_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_eval_time_heatmaps[idx].set_xlabel('Epochs')
    axes_eval_time_heatmaps[idx].set_ylabel('Training Data Size')
fig_eval_time_heatmaps.tight_layout()

# plt.tight_layout() only acts on the current figure, i.e. the most recently
# created one (the evaluation-time heatmaps); lay out the 3x2 summary figure
# explicitly instead.
fig.tight_layout()
plt.show()

# Plot the prediction snapshots collected during training.
plot_collected_images(collected_images, images_per_row=4)



In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Experiment grid: each training-set size is trained progressively through these epoch checkpoints.
data_ranges = [100, 200, 400, 800, 1600, 3200, 6400, 12800]
epochs = [1, 2, 4, 8, 16, 32, 64, 128]

# Seed the Python, NumPy and PyTorch RNGs so a fresh "Run All" draws the same random
# point sets. NOTE(review): any other RNG used inside SpatialGenerator / SpatialDataset
# is not covered by this — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# One dict per (data_size, epochs) checkpoint; converted to a DataFrame after training.
results = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned host memory only speeds up host-to-GPU copies; without a GPU PyTorch just warns.
PIN_MEMORY = device.type == 'cuda'

# Random point-set generators. Each call returns an (n, 3) tensor, presumably (x, y, value),
# with the first two columns in [0, 64) and the third in [0, 1). torch.randint's upper bound
# is exclusive, so n is 3-49 (interpolation), 51-69 (close) or 70-89 (far extrapolation).
interpolation_generator = lambda: torch.rand(torch.randint(3, 50, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_close_generator = lambda: torch.rand(torch.randint(51, 70, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_far_generator = lambda: torch.rand(torch.randint(70, 90, (1,)).item(), 3) * torch.tensor([64, 64, 1])

# Three 1000-sample test sets that differ only in how many input points they get.
test_generators = {
    'interpolation': interpolation_generator,
    'extrapolation_close': extrapolation_close_generator,
    'extrapolation_far': extrapolation_far_generator
}

test_datasets = {}
for name, gen in test_generators.items():
    generator_obj = SpatialGenerator((64, 64), gen, methods=['kriging'])
    test_dataset = SpatialDataset(1000, generator_obj, drilling_sampling, lambda x: x,
                                  dynamic_secondary_mask=True, x_channels=0,
                                  secondary_channels=0, primary_channels=0)
    test_datasets[name] = test_dataset

# Unshuffled, so every pass visits the samples in the same index order.
test_dataloaders = {
    name: DataLoader(dataset, batch_size=128, shuffle=False,
                     num_workers=0, pin_memory=PIN_MEMORY)
    for name, dataset in test_datasets.items()
}

# Prediction snapshots gathered during training, consumed by plot_collected_images.
collected_images = []
use_sparse = True
use_input_mask = True

for data_range in data_ranges:
    print(f"\nStarting training with data size: {data_range}")

    # Training fields use the interpolation generator (3-49 points per sample), so the
    # interpolation test set is in-distribution with respect to point count.
    generator = SpatialGenerator((64, 64), interpolation_generator, methods=['kriging'])
    train_dataset = SpatialDataset(data_range, generator, drilling_sampling, lambda x: x,
                                   dynamic_secondary_mask=True, x_channels=0,
                                   secondary_channels=0, primary_channels=0)
    # Pinned memory only helps host-to-GPU copies; on CPU-only runtimes PyTorch just warns.
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                  num_workers=0, pin_memory=device.type == 'cuda')

    # Fresh, untrained models for every data size.
    convnext = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(64, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)

    # Progressive training: each checkpoint in `epochs` continues from the previous one and
    # trains only the missing epochs.
    current_epoch = 0
    cumulative_training_time = 0.0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        # CUDA kernels are launched asynchronously. Synchronize on both sides of each timed
        # region so the timer does not stop before the GPU work has actually finished.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        training_start_time = time.perf_counter()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_data{data_range}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        if device.type == 'cuda':
            torch.cuda.synchronize()
        # Time for this increment only (epochs_to_train epochs); keep a running total as well.
        training_time = time.perf_counter() - training_start_time
        cumulative_training_time += training_time

        # Primary loss from the last entry of the returned training history.
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate on all three test sets, timing each one.
        test_metrics = {}
        evaluation_times = {}
        for name, dataloader in test_dataloaders.items():
            if device.type == 'cuda':
                torch.cuda.synchronize()
            evaluation_start_time = time.perf_counter()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            if device.type == 'cuda':
                torch.cuda.synchronize()
            evaluation_times[name] = time.perf_counter() - evaluation_start_time

            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }

        results.append({
            'data_size': data_range,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'cumulative_training_time': cumulative_training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance'],
            # torch.randint(3, 50) excludes 50, so training samples have 3-49 points.
            'training_points_range': '3-49'
        })

        current_epoch = target_epoch

        # Snapshot predictions on one (shuffled) training batch...
        train_batch = next(iter(train_dataloader))
        train_input, train_output, train_mask = process_batch(
            sparse_encoder, subsurface_decoder, train_batch, device=device,
            use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Size: {data_range}, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask
        })

        # ...and on the first batch of each test set. The test loaders are unshuffled, so
        # reusing them gives the same first batch as building a new DataLoader every time.
        # NOTE(review): collected_images keeps every snapshot alive for the whole run; if
        # process_batch returns GPU tensors this accumulates device memory — confirm.
        for name, test_dataloader in test_dataloaders.items():
            test_batch = next(iter(test_dataloader))
            test_input, test_output, test_mask = process_batch(
                sparse_encoder, subsurface_decoder, test_batch, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Train Size: {data_range}, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask
            })

# One row per (data_size, epochs) checkpoint.
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# The bare 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and removed in 3.8,
# where plt.style.use('seaborn') raises OSError. Use whichever name this install provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

# Main summary figure; the heatmap groups below get figures of their own.
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# Point-count ranges actually drawn by the generators (torch.randint's upper bound is exclusive).
train_range_note = '(Training Range: 3-49 points)'
test_types = {
    'test_loss_interpolation': 'Interpolation (3-49)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-69)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-89)'
}

# 1. Test losses vs training data size, one line per (test set, epoch budget).
ax1 = axes[0, 0]
for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['data_size'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
# Data sizes double at each step, so a log2 axis spaces them evenly.
ax1.set_xscale('log', base=2)
ax1.set_xlabel('Training Data Size')
ax1.set_ylabel('Test Loss')
ax1.set_title(f'Test Losses vs Training Data Size\n{train_range_note}')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test losses vs epochs, one line per (test set, data size).
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for data_size in data_ranges:
        size_data = df_results[df_results['data_size'] == data_size]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} (Size {data_size})', marker='o', alpha=0.7)
# Epoch checkpoints double as well.
ax2.set_xscale('log', base=2)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title(f'Test Losses vs Epochs\n{train_range_note}')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Heatmaps of training loss and each test loss over the (data_size, epochs) grid.
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {'training_loss': 'Training Loss', **test_types}
for heat_ax, (test_type, label) in zip(axes_heatmaps, test_types_with_training.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=heat_ax)
    heat_ax.set_title(f'{label}\nHeatmap')
    heat_ax.set_xlabel('Epochs')
    heat_ax.set_ylabel('Training Data Size')
fig_heatmaps.tight_layout()

# 4. Distribution of test losses per test set, pooled over all checkpoints.
ax4 = axes[1, 1]
ax4.boxplot([df_results[test_type] for test_type in test_types])
# Set tick labels separately: boxplot's `labels=` was renamed `tick_labels=` in matplotlib 3.9.
ax4.set_xticklabels(list(test_types.values()))
ax4.set_ylabel('Test Loss')
ax4.set_title(f'Distribution of Test Losses by Type\n{train_range_note}')
ax4.grid(True)

# 5. Extrapolation loss relative to interpolation loss (1.0 = equal performance).
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                     ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['data_size'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xscale('log', base=2)
ax3.set_xlabel('Training Data Size')
ax3.set_ylabel('Loss Ratio')
ax3.set_title(f'Extrapolation Performance Relative to Interpolation\n{train_range_note}')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Heatmaps of the per-test-set 'total_variance' reported by evaluate_model.
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for var_ax, (variance_type, label) in zip(axes_variance_heatmaps, variance_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=var_ax)
    var_ax.set_title(f'{label}\nHeatmap')
    var_ax.set_xlabel('Epochs')
    var_ax.set_ylabel('Training Data Size')
fig_variance_heatmaps.tight_layout()

# 7. Training loss heatmap.
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')
ax7.set_xlabel('Epochs')
ax7.set_ylabel('Training Data Size')

# 8. Training time heatmap: seconds spent on each epoch increment, not cumulative.
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_xlabel('Epochs')
ax8.set_ylabel('Training Data Size')

# 9. Evaluation-time heatmaps, one per test set.
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for eval_ax, (eval_time_type, label) in zip(axes_eval_time_heatmaps, evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=eval_ax)
    eval_ax.set_title(f'{label}\nHeatmap')
    eval_ax.set_xlabel('Epochs')
    eval_ax.set_ylabel('Training Data Size')
fig_eval_time_heatmaps.tight_layout()

# A bare plt.tight_layout() would only act on the most recent figure (the eval-time
# heatmaps), so lay out the main summary figure explicitly.
fig.tight_layout()
plt.show()

# Grid of the prediction snapshots gathered during training.
plot_collected_images(collected_images, images_per_row=4)



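# Fine-Tuning Results

The runs below start every data size from the pretrained `base_id_1000k.pth` checkpoint and fine-tune it on kriging-generated fields. They use the same data-size × epoch grid as the from-scratch runs above.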

In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Experiment grid: each training-set size is fine-tuned progressively through these epoch checkpoints.
data_ranges = [100, 200, 400, 800, 1600, 3200, 6400, 12800]
epochs = [1, 2, 4, 8, 16, 32, 64, 128]

# Seed the Python, NumPy and PyTorch RNGs so a fresh "Run All" draws the same random
# point sets. NOTE(review): any other RNG used inside SpatialGenerator / SpatialDataset
# is not covered by this — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# One dict per (data_size, epochs) checkpoint; converted to a DataFrame after training.
results = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned host memory only speeds up host-to-GPU copies; without a GPU PyTorch just warns.
PIN_MEMORY = device.type == 'cuda'

# Random point-set generators. Each call returns an (n, 3) tensor, presumably (x, y, value),
# with the first two columns in [0, 64) and the third in [0, 1). torch.randint's upper bound
# is exclusive, so n is 3-49 (interpolation), 51-69 (close) or 70-89 (far extrapolation).
interpolation_generator = lambda: torch.rand(torch.randint(3, 50, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_close_generator = lambda: torch.rand(torch.randint(51, 70, (1,)).item(), 3) * torch.tensor([64, 64, 1])
extrapolation_far_generator = lambda: torch.rand(torch.randint(70, 90, (1,)).item(), 3) * torch.tensor([64, 64, 1])

# Three 1000-sample test sets that differ only in how many input points they get.
test_generators = {
    'interpolation': interpolation_generator,
    'extrapolation_close': extrapolation_close_generator,
    'extrapolation_far': extrapolation_far_generator
}

test_datasets = {}
for name, gen in test_generators.items():
    generator_obj = SpatialGenerator((64, 64), gen, methods=['kriging'])
    test_dataset = SpatialDataset(1000, generator_obj, drilling_sampling, lambda x: x,
                                  dynamic_secondary_mask=True, x_channels=0,
                                  secondary_channels=0, primary_channels=0)
    test_datasets[name] = test_dataset

# Unshuffled, so every pass visits the samples in the same index order.
test_dataloaders = {
    name: DataLoader(dataset, batch_size=128, shuffle=False,
                     num_workers=0, pin_memory=PIN_MEMORY)
    for name, dataset in test_datasets.items()
}

# Prediction snapshots gathered during training, consumed by plot_collected_images.
collected_images = []
use_sparse = True
use_input_mask = True

# Pretrained checkpoint that every data size starts from. This is a relative path,
# presumably a mounted Google Drive under Colab.
last_model_path = "drive/MyDrive/Models/base_id_1000k.pth"

for data_range in data_ranges:
    print(f"\nStarting training with data size: {data_range}")

    # Training fields use the interpolation generator (3-49 points per sample), so the
    # interpolation test set is in-distribution with respect to point count.
    generator = SpatialGenerator((64, 64), interpolation_generator, methods=['kriging'])
    train_dataset = SpatialDataset(data_range, generator, drilling_sampling, lambda x: x,
                                   dynamic_secondary_mask=True, x_channels=0,
                                   secondary_channels=0, primary_channels=0)
    # Pinned memory only helps host-to-GPU copies; on CPU-only runtimes PyTorch just warns.
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                  num_workers=0, pin_memory=device.type == 'cuda')

    # Build the architecture, then overwrite its weights from the pretrained checkpoint.
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(64, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)

    # load_model also returns the checkpoint's use_secondary / use_sparse flags.
    # use_sparse replaces the cell-level value and is used below. use_secondary is not
    # used: training runs with use_secondary=False.
    sparse_encoder, _, subsurface_decoder, use_secondary, use_sparse, training_losses = load_model(
        last_model_path,
        sparse_encoder,
        None,
        subsurface_decoder,
        device=device
    )

    # Progressive fine-tuning: each checkpoint in `epochs` continues from the previous one
    # and trains only the missing epochs.
    current_epoch = 0
    cumulative_training_time = 0.0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        # CUDA kernels are launched asynchronously. Synchronize on both sides of each timed
        # region so the timer does not stop before the GPU work has actually finished.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        training_start_time = time.perf_counter()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_data{data_range}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        if device.type == 'cuda':
            torch.cuda.synchronize()
        # Time for this increment only (epochs_to_train epochs); keep a running total as well.
        training_time = time.perf_counter() - training_start_time
        cumulative_training_time += training_time

        # Primary loss from the last entry of the returned training history.
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate on all three test sets, timing each one.
        test_metrics = {}
        evaluation_times = {}
        for name, dataloader in test_dataloaders.items():
            if device.type == 'cuda':
                torch.cuda.synchronize()
            evaluation_start_time = time.perf_counter()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            if device.type == 'cuda':
                torch.cuda.synchronize()
            evaluation_times[name] = time.perf_counter() - evaluation_start_time

            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }

        results.append({
            'data_size': data_range,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'cumulative_training_time': cumulative_training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance'],
            # torch.randint(3, 50) excludes 50, so training samples have 3-49 points.
            'training_points_range': '3-49'
        })

        current_epoch = target_epoch

        # Snapshot predictions on one (shuffled) training batch...
        train_batch = next(iter(train_dataloader))
        train_input, train_output, train_mask = process_batch(
            sparse_encoder, subsurface_decoder, train_batch, device=device,
            use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Size: {data_range}, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask
        })

        # ...and on the first batch of each test set. The test loaders are unshuffled, so
        # reusing them gives the same first batch as building a new DataLoader every time.
        # NOTE(review): collected_images keeps every snapshot alive for the whole run; if
        # process_batch returns GPU tensors this accumulates device memory — confirm.
        for name, test_dataloader in test_dataloaders.items():
            test_batch = next(iter(test_dataloader))
            test_input, test_output, test_mask = process_batch(
                sparse_encoder, subsurface_decoder, test_batch, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Train Size: {data_range}, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask
            })

# One row per (data_size, epochs) checkpoint.
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# The bare 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and removed in 3.8,
# where plt.style.use('seaborn') raises OSError. Use whichever name this install provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

# Main summary figure; the heatmap groups below get figures of their own.
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# Point-count ranges actually drawn by the generators (torch.randint's upper bound is exclusive).
train_range_note = '(Training Range: 3-49 points)'
test_types = {
    'test_loss_interpolation': 'Interpolation (3-49)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-69)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-89)'
}

# 1. Test losses vs training data size, one line per (test set, epoch budget).
ax1 = axes[0, 0]
for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['data_size'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
# Data sizes double at each step, so a log2 axis spaces them evenly.
ax1.set_xscale('log', base=2)
ax1.set_xlabel('Training Data Size')
ax1.set_ylabel('Test Loss')
ax1.set_title(f'Test Losses vs Training Data Size\n{train_range_note}')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test losses vs epochs, one line per (test set, data size).
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for data_size in data_ranges:
        size_data = df_results[df_results['data_size'] == data_size]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} (Size {data_size})', marker='o', alpha=0.7)
# Epoch checkpoints double as well.
ax2.set_xscale('log', base=2)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title(f'Test Losses vs Epochs\n{train_range_note}')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Heatmaps of training loss and each test loss over the (data_size, epochs) grid.
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {'training_loss': 'Training Loss', **test_types}
for heat_ax, (test_type, label) in zip(axes_heatmaps, test_types_with_training.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=heat_ax)
    heat_ax.set_title(f'{label}\nHeatmap')
    heat_ax.set_xlabel('Epochs')
    heat_ax.set_ylabel('Training Data Size')
fig_heatmaps.tight_layout()

# 4. Distribution of test losses per test set, pooled over all checkpoints.
ax4 = axes[1, 1]
ax4.boxplot([df_results[test_type] for test_type in test_types])
# Set tick labels separately: boxplot's `labels=` was renamed `tick_labels=` in matplotlib 3.9.
ax4.set_xticklabels(list(test_types.values()))
ax4.set_ylabel('Test Loss')
ax4.set_title(f'Distribution of Test Losses by Type\n{train_range_note}')
ax4.grid(True)

# 5. Extrapolation loss relative to interpolation loss (1.0 = equal performance).
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                     ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['data_size'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xscale('log', base=2)
ax3.set_xlabel('Training Data Size')
ax3.set_ylabel('Loss Ratio')
ax3.set_title(f'Extrapolation Performance Relative to Interpolation\n{train_range_note}')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Heatmaps of the per-test-set 'total_variance' reported by evaluate_model.
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for var_ax, (variance_type, label) in zip(axes_variance_heatmaps, variance_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=var_ax)
    var_ax.set_title(f'{label}\nHeatmap')
    var_ax.set_xlabel('Epochs')
    var_ax.set_ylabel('Training Data Size')
fig_variance_heatmaps.tight_layout()

# 7. Training loss heatmap.
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')
ax7.set_xlabel('Epochs')
ax7.set_ylabel('Training Data Size')

# 8. Training time heatmap: seconds spent on each epoch increment, not cumulative.
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_xlabel('Epochs')
ax8.set_ylabel('Training Data Size')

# 9. Evaluation-time heatmaps, one per test set.
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for eval_ax, (eval_time_type, label) in zip(axes_eval_time_heatmaps, evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=eval_ax)
    eval_ax.set_title(f'{label}\nHeatmap')
    eval_ax.set_xlabel('Epochs')
    eval_ax.set_ylabel('Training Data Size')
fig_eval_time_heatmaps.tight_layout()

# A bare plt.tight_layout() would only act on the most recent figure (the eval-time
# heatmaps), so lay out the main summary figure explicitly.
fig.tight_layout()
plt.show()

# Grid of the prediction snapshots gathered during fine-tuning.
plot_collected_images(collected_images, images_per_row=4)



In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Sweep grid: training-set sizes and cumulative epoch checkpoints, both doubling each step.
data_ranges = [100 * 2 ** k for k in range(8)]
epochs = [2 ** k for k in range(8)]

# One record per (data_size, epochs) checkpoint is appended by the training loop below.
results = []

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _random_point_cloud_generator(low, high):
    """Return a zero-argument callable that draws a random (n, 3) point tensor.

    On every call, n is drawn uniformly from [low, high). The first two columns are
    scaled to [0, 64) and the third column is left in [0, 1).
    """
    def _generate():
        n_points = torch.randint(low, high, (1,)).item()
        return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])
    return _generate


# All test sets share a size (1000 samples) and differ only in the number of points per sample.
interpolation_generator = _random_point_cloud_generator(3, 50)
extrapolation_close_generator = _random_point_cloud_generator(51, 70)
extrapolation_far_generator = _random_point_cloud_generator(70, 90)

test_generators = dict(
    interpolation=interpolation_generator,
    extrapolation_close=extrapolation_close_generator,
    extrapolation_far=extrapolation_far_generator,
)

# One 1000-sample test set per scenario. Every scenario uses the same 64x64 'id'
# SpatialGenerator setup and differs only in its point generator.
test_datasets = {}
for scenario, point_gen in test_generators.items():
    test_datasets[scenario] = SpatialDataset(
        1000,
        SpatialGenerator((64, 64), point_gen, methods=['id']),
        drilling_sampling,
        lambda x: x,
        dynamic_secondary_mask=True,
        x_channels=0,
        secondary_channels=0,
        primary_channels=0,
    )

# Deterministic (non-shuffled) loaders over each test set
test_dataloaders = {
    scenario: DataLoader(ds, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)
    for scenario, ds in test_datasets.items()
}

# Reconstructions sampled after each training stage; plotted at the end of the cell.
collected_images = []
use_sparse = True  # may be overwritten by load_model in the training loop
use_input_mask = True

# Data-size sweep. For every training-set size: build a fresh model, load the same checkpoint,
# then train in stages so the model is evaluated at each cumulative epoch count in `epochs`.
# Each stage appends one row to `results` and several sample reconstructions to `collected_images`.
for data_range in data_ranges:
    print(f"\nStarting training with data size: {data_range}")

    # Training data uses the same point generator as the interpolation test set (3-50 points)
    generator = SpatialGenerator((64, 64), interpolation_generator, methods=['id'])
    train_dataset = SpatialDataset(data_range, generator, drilling_sampling, lambda x: x,
                                   dynamic_secondary_mask=True, x_channels=0,
                                   secondary_channels=0, primary_channels=0)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                  num_workers=0, pin_memory=True)

    # Build a fresh encoder/decoder pair for this data size
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(64, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)
    last_model_path = "drive/MyDrive/Models/base_id_1000k.pth"

    # Load the same checkpoint for every data size (presumably restoring its weights into the
    # fresh modules; load_model is defined elsewhere). This overwrites `use_sparse`.
    # `use_secondary` is not used below: train_model is always called with use_secondary=False.
    sparse_encoder, _ , subsurface_decoder, use_secondary, use_sparse, training_losses = load_model(
        last_model_path,
        sparse_encoder,
        None,
        subsurface_decoder,
        device=device
    )

    # Progressive training: `epochs` holds cumulative targets, so each stage only trains
    # the difference from the previous target.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        # Skip targets that are not beyond what has already been trained
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        # Wall-clock time of this stage only (epochs_to_train epochs), not cumulative
        training_start_time = time.time()

        # Train the model (no checkpointing or plotting during the sweep)
        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_data{data_range}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        # End timing the training
        training_end_time = time.time()
        training_time = training_end_time - training_start_time  # Calculate training time

        # Last recorded loss of this stage; assumes training_losses is a list of dicts
        # with a 'primary_loss' key — TODO confirm against train_model
        final_training_loss = training_losses[-1]['primary_loss']

        # Per-test-set metrics and wall-clock evaluation times for this stage
        test_metrics = {}
        evaluation_times = {}

        # Evaluate the model on all three test sets
        for name, dataloader in test_dataloaders.items():
            # Start timing the evaluation
            evaluation_start_time = time.time()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            # End timing the evaluation
            evaluation_end_time = time.time()
            evaluation_time = evaluation_end_time - evaluation_start_time  # Calculate evaluation time

            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }
            evaluation_times[name] = evaluation_time  # Store evaluation time

        # One summary row per (data_size, cumulative epochs) checkpoint
        results.append({
            'data_size': data_range,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,  # Add training time
            'evaluation_time_interpolation': evaluation_times['interpolation'],  # Add evaluation times
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance'],
            'training_points_range': '3-50'
        })

        current_epoch = target_epoch

        # Visualize the first batch of the (shuffled) training loader.
        # NOTE(review): process_batch is unpacked into 3 values here, while the layered-sampling
        # cell further down unpacks 4 (adding the original image). Confirm this matches the
        # current process_batch definition.
        train_iter = iter(train_dataloader)
        train_batch = next(train_iter)
        train_input, train_output, train_mask = process_batch(
            sparse_encoder, subsurface_decoder, train_batch, device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Size: {data_range}, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask
        })

        # Visualize the first batch of each test set. This builds a new non-shuffled loader over
        # the same dataset that test_dataloaders[name] wraps.
        for name, dataset in test_datasets.items():
            test_dataloader = DataLoader(dataset, batch_size=128, shuffle=False,
                                         num_workers=0, pin_memory=True)
            test_iter = iter(test_dataloader)
            test_batch = next(test_iter)
            test_input, test_output, test_mask = process_batch(
                sparse_encoder, subsurface_decoder, test_batch, device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask
            })

# Convert results to DataFrame
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# Create visualization plots.
# Matplotlib 3.6 deprecated the bare 'seaborn' style name and 3.8 removed it (it now ships as
# 'seaborn-v0_8'), so plt.style.use('seaborn') raises on current versions. Use whichever name
# the installed Matplotlib provides.
_seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_seaborn_style)
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# 1. Test Losses vs Data Size for Different Epochs
ax1 = axes[0, 0]
test_types = {
    'test_loss_interpolation': 'Interpolation (3-50)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-70)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-90)'
}

for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['data_size'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax1.set_xlabel('Training Data Size')
ax1.set_ylabel('Test Loss')
ax1.set_title('Test Losses vs Training Data Size\n(Training Range: 3-50 points)')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test Losses vs Epochs for Different Data Sizes
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for data_size in data_ranges:
        size_data = df_results[df_results['data_size'] == data_size]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} (Size {data_size})', marker='o', alpha=0.7)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Losses vs Epochs\n(Training Range: 3-50 points)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Heatmaps for each test type including Training Loss (separate 1x4 figure)
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {
    'training_loss': 'Training Loss',
    **test_types
}
for idx, (test_type, label) in enumerate(test_types_with_training.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_heatmaps[idx])
    axes_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_heatmaps[idx].set_xlabel('Epochs')
    axes_heatmaps[idx].set_ylabel('Training Data Size')
fig_heatmaps.tight_layout()

# 4. Comparison of test types.
# Tick labels are set after plotting: boxplot's `labels=` keyword was renamed `tick_labels`
# in Matplotlib 3.9, and set_xticklabels works on both old and new versions.
ax4 = axes[1, 1]
box_data = [df_results[test_type] for test_type in test_types.keys()]
ax4.boxplot(box_data)
ax4.set_xticklabels([label.replace('\n', ' ') for label in test_types.values()])
ax4.set_ylabel('Test Loss')
ax4.set_title('Distribution of Test Losses by Type\n(Training Range: 3-50 points)')
ax4.grid(True)

# 5. Loss ratio analysis: extrapolation loss relative to interpolation loss (1.0 = equal)
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                     ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['data_size'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xlabel('Training Data Size')
ax3.set_ylabel('Loss Ratio')
ax3.set_title('Extrapolation Performance Relative to Interpolation\n(Training Range: 3-50 points)')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Heatmaps for variances of each test type (separate 1x3 figure)
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for idx, (variance_type, label) in enumerate(variance_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_variance_heatmaps[idx])
    axes_variance_heatmaps[idx].set_title(f'{label}\nHeatmap')
fig_variance_heatmaps.tight_layout()

# 7. Training Loss Heatmap (main grid, bottom-right)
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')

# 8. Training time heatmap (main grid, bottom-left). Each cell is the duration of one
# progressive stage, not the cumulative time to reach that epoch count.
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='data_size', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_xlabel('Epochs')
ax8.set_ylabel('Training Data Size')

# 9. Evaluation Time Heatmaps for each test set (separate 1x3 figure)
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for idx, (eval_time_type, label) in enumerate(evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='data_size', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=axes_eval_time_heatmaps[idx])
    axes_eval_time_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_eval_time_heatmaps[idx].set_xlabel('Epochs')
    axes_eval_time_heatmaps[idx].set_ylabel('Training Data Size')
fig_eval_time_heatmaps.tight_layout()

# plt.tight_layout() would only affect the current figure (the evaluation-time figure above);
# lay out the main 3x2 summary grid explicitly.
fig.tight_layout()
plt.show()

# Render the reconstructions collected during the sweep
plot_collected_images(collected_images, images_per_row=4)



# Layered Data Sampling

In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Sweep configuration: the percentage of grid cells revealed to the model, and the
# cumulative epoch checkpoints at which each run is evaluated.
sampling_percentages = list(range(1, 6)) + [10, 15]
epochs = [1, 2, 5, 10, 20, 100, 150]
fixed_dataset_size = 1000  # training samples per run, the same for every percentage
size = 32, 64  # grid size handed to CategoricalSpatialGenerator

def percentage_random_sampling(size, percentage):
    """Build a random binary mask that reveals a percentage of a 2-D grid.

    Args:
        size: (height, width) of the grid.
        percentage: share of cells to reveal, in percent. The number of revealed cells is
            ``int(percentage / 100 * height * width)`` (truncated toward zero). Values >= 100
            reveal every cell; values that yield zero or fewer cells (including negative
            percentages) reveal none.

    Returns:
        A float tensor of shape (1, height, width): 1 at sampled cells, 0 elsewhere. Cells are
        drawn without replacement using the global torch RNG.
    """
    mask = torch.zeros(1, *size)
    height, width = size
    total_points = height * width

    if percentage >= 100:
        mask[:] = 1
        return mask

    target_points = int((percentage / 100.0) * total_points)
    # A negative percentage gives a negative count, and randperm(...)[:negative] would then
    # select almost every cell. Treat any non-positive count as "sample nothing".
    if target_points <= 0:
        return mask

    indices = torch.randperm(total_points)[:target_points]
    # Flat indices map onto the (height, width) grid in row-major order
    mask[0, indices // width, indices % width] = 1
    return mask

# One record per (sampling_percentage, epochs) checkpoint is appended by the training loop.
results = []
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _random_point_cloud_generator(low, high):
    """Return a zero-argument callable that draws a random (n, 3) point tensor.

    On every call, n is drawn uniformly from [low, high). The first two columns are
    scaled to [0, 64) and the third column is left in [0, 1).
    """
    def _generate():
        n_points = torch.randint(low, high, (1,)).item()
        return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])
    return _generate


# Point generators keyed by test-set name
interpolation_generator = _random_point_cloud_generator(3, 50)
extrapolation_close_generator = _random_point_cloud_generator(51, 70)
extrapolation_far_generator = _random_point_cloud_generator(70, 90)

test_generators = dict(
    interpolation=interpolation_generator,
    extrapolation_close=extrapolation_close_generator,
    extrapolation_far=extrapolation_far_generator,
)

# Create consistent validation datasets for each sampling percentage.
# torch / numpy / random are reseeded before each test set is built. The RNG states captured
# right after construction are stored so the evaluation loop can restore them before
# evaluating on that set.
print("Creating consistent validation datasets...")
consistent_test_datasets = {}
validation_states = {}

for sampling_percentage in sampling_percentages:
    consistent_test_datasets[sampling_percentage] = {}
    validation_states[sampling_percentage] = {}

    for name, gen in test_generators.items():
        # NOTE(review): `gen` is unused, and the seed depends only on the sampling percentage.
        # The three test sets for one percentage are therefore constructed identically, and
        # presumably hold the same data if the project generators only draw from these RNGs.
        # Their metrics would then not be an interpolation/extrapolation comparison — TODO confirm intent.
        torch.manual_seed(42 + sampling_percentage)
        np.random.seed(42 + sampling_percentage)
        random.seed(42 + sampling_percentage)

        generator_obj = CategoricalSpatialGenerator(size, lambda: two_layer_generator(),
                                                    num_categories=10, methods=['layered', 'vid'])

        # Freeze the percentage as a default argument. A plain closure reads the module-level
        # `sampling_percentage` at call time. If the dataset samples lazily (possibly, given
        # dynamic_secondary_mask=True — confirm), it would use whatever value that variable
        # holds later (e.g. the last percentage), not this one.
        test_dataset = SpatialDataset(1000, generator_obj,
                                      lambda size, pct=sampling_percentage: percentage_random_sampling(size, pct),
                                      lambda x: x,
                                      dynamic_secondary_mask=True,
                                      x_channels=1,
                                      secondary_channels=1,
                                      primary_channels=1)

        consistent_test_datasets[sampling_percentage][name] = test_dataset

        # Store random states for this sampling percentage and test type
        validation_states[sampling_percentage][name] = {
            'torch_rng': torch.get_rng_state(),
            'numpy_rng': np.random.get_state(),
            'random_rng': random.getstate()
        }

# One non-shuffled loader per (sampling percentage, test name)
consistent_test_dataloaders = {
    sampling_percentage: {
        name: DataLoader(dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)
        for name, dataset in datasets_by_name.items()
    }
    for sampling_percentage, datasets_by_name in consistent_test_datasets.items()
}

# Reconstructions sampled after each training stage; plotted at the end of the cell
collected_images = []
use_sparse = True  # overwritten by load_model inside the loop
use_input_mask = False

# Sampling-percentage sweep. For every percentage: build a fresh model, load the same checkpoint,
# then train in stages so the model is evaluated at each cumulative epoch count in `epochs`.
for sampling_percentage in sampling_percentages:
    print(f"\nStarting training with sampling percentage: {sampling_percentage}%")

    # Fixed-size training set whose masks reveal `sampling_percentage` percent of the grid.
    # NOTE(review): the sampler lambda reads the global `sampling_percentage` when it is called.
    # That is correct here only because train_dataset is used within the same loop iteration.
    generator = CategoricalSpatialGenerator(size, lambda: two_layer_generator(),
                                          num_categories=10, methods=['layered', 'vid'])
    train_dataset = SpatialDataset(fixed_dataset_size, generator,
                                  lambda size: percentage_random_sampling(size, sampling_percentage),
                                  lambda x: x,
                                  dynamic_secondary_mask=True,
                                  x_channels=1,
                                  secondary_channels=1,
                                  primary_channels=1)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                num_workers=0, pin_memory=True)

    # Build a fresh encoder/decoder pair for the (32, 64) grid
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(32, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)
    last_model_path = "drive/MyDrive/Models/base_id_1000k.pth"

    # Load the same checkpoint for every percentage (presumably restoring its weights into the
    # fresh modules; load_model is defined elsewhere). `use_secondary` is not used below.
    sparse_encoder, _ , subsurface_decoder, use_secondary, use_sparse, training_losses = load_model(
        last_model_path,
        sparse_encoder,
        None,
        subsurface_decoder,
        device=device
    )

    # Progressive training: `epochs` holds cumulative targets, so each stage only trains
    # the difference from the previous target.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        # Wall-clock time of this stage only, not cumulative
        training_start_time = time.time()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_sampling{sampling_percentage}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        training_end_time = time.time()
        training_time = training_end_time - training_start_time
        # Assumes training_losses is a list of dicts with a 'primary_loss' key — TODO confirm
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate on consistent test sets for this sampling percentage
        test_metrics = {}
        evaluation_times = {}

        for name, dataloader in consistent_test_dataloaders[sampling_percentage].items():
            # Restore the RNG states captured when this test set was built, so evaluation sees
            # the same randomness at every stage.
            # NOTE(review): this also resets the global RNGs that the next training stage uses.
            torch.set_rng_state(validation_states[sampling_percentage][name]['torch_rng'])
            np.random.set_state(validation_states[sampling_percentage][name]['numpy_rng'])
            random.setstate(validation_states[sampling_percentage][name]['random_rng'])

            evaluation_start_time = time.time()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            evaluation_end_time = time.time()
            evaluation_time = evaluation_end_time - evaluation_start_time

            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }
            evaluation_times[name] = evaluation_time

        # One summary row per (sampling_percentage, cumulative epochs) checkpoint
        results.append({
            'sampling_percentage': sampling_percentage,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance']
        })

        current_epoch = target_epoch

        # Collect visualization samples
        # Training data visualization: first batch of the shuffled training loader
        train_iter = iter(train_dataloader)
        train_batch = next(train_iter)
        train_input, train_output, train_mask, train_original = process_batch(
            sparse_encoder, subsurface_decoder, train_batch,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Sampling: {sampling_percentage}%, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask,
            'original_image': train_original
        })

        # Test data visualization: first batch of each test loader.
        # NOTE(review): the validation RNG states are not restored here. If the dataset samples
        # lazily, the visualized batch may differ from the data that was evaluated — confirm.
        for name, dataloader in consistent_test_dataloaders[sampling_percentage].items():
            test_iter = iter(dataloader)
            test_batch = next(test_iter)
            test_input, test_output, test_mask, test_original = process_batch(
                sparse_encoder, subsurface_decoder, test_batch,
                device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Sampling: {sampling_percentage}%, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask,
                'original_image': test_original
            })

# Convert results to DataFrame and create visualizations
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# Main 3x2 summary figure, indexed by sampling percentage instead of data size.
# Matplotlib 3.6 deprecated the bare 'seaborn' style name and 3.8 removed it (it now ships as
# 'seaborn-v0_8'), so plt.style.use('seaborn') raises on current versions. Use whichever name
# the installed Matplotlib provides.
_seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_seaborn_style)
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# NOTE(review): the point-count ranges in these labels come from the random-point sweep. The
# test sets in this cell are built from two_layer_generator(), so the ranges likely do not apply.
test_types = {
    'test_loss_interpolation': 'Interpolation (3-50)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-70)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-90)'
}

# 1. Test Losses vs Sampling Percentage for Different Epochs
ax1 = axes[0, 0]
for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['sampling_percentage'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax1.set_xlabel('Sampling Percentage')
ax1.set_ylabel('Test Loss')
ax1.set_title('Test Losses vs Sampling Percentage\n(Fixed Dataset Size)')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test Losses vs Epochs for Different Sampling Percentages
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for sampling_pct in sampling_percentages:
        size_data = df_results[df_results['sampling_percentage'] == sampling_pct]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} ({sampling_pct}%)', marker='o', alpha=0.7)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Losses vs Epochs\n(By Sampling Percentage)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Heatmaps for each test type including Training Loss (separate 1x4 figure)
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {
    'training_loss': 'Training Loss',
    **test_types
}
for idx, (test_type, label) in enumerate(test_types_with_training.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_heatmaps[idx])
    axes_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_heatmaps[idx].set_xlabel('Epochs')
fig_heatmaps.tight_layout()

# 4. Comparison of test types.
# Tick labels are set after plotting: boxplot's `labels=` keyword was renamed `tick_labels`
# in Matplotlib 3.9, and set_xticklabels works on both old and new versions.
ax4 = axes[1, 1]
box_data = [df_results[test_type] for test_type in test_types.keys()]
ax4.boxplot(box_data)
ax4.set_xticklabels([label.replace('\n', ' ') for label in test_types.values()])
ax4.set_ylabel('Test Loss')
ax4.set_title('Distribution of Test Losses by Type\n(Across All Sampling Percentages)')
ax4.grid(True)

# 5. Loss ratio analysis: extrapolation loss relative to interpolation loss (1.0 = equal)
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                     ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['sampling_percentage'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xlabel('Sampling Percentage')
ax3.set_ylabel('Loss Ratio')
ax3.set_title('Extrapolation Performance Relative to Interpolation\n(By Sampling Percentage)')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Heatmaps for variances of each test type (separate 1x3 figure)
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for idx, (variance_type, label) in enumerate(variance_test_types.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_variance_heatmaps[idx])
    axes_variance_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_variance_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_variance_heatmaps[idx].set_xlabel('Epochs')
fig_variance_heatmaps.tight_layout()

# 7. Training Loss Heatmap (main grid, bottom-right)
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='sampling_percentage', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')
ax7.set_ylabel('Sampling Percentage')
ax7.set_xlabel('Epochs')

# 8. Training time heatmap (main grid, bottom-left). Each cell is the duration of one
# progressive stage, not the cumulative time.
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='sampling_percentage', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_ylabel('Sampling Percentage')
ax8.set_xlabel('Epochs')

# 9. Evaluation Time Heatmaps for each test set (separate 1x3 figure)
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for idx, (eval_time_type, label) in enumerate(evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=axes_eval_time_heatmaps[idx])
    axes_eval_time_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_eval_time_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_eval_time_heatmaps[idx].set_xlabel('Epochs')
# Lay out once, after all panels are drawn (previously this ran on every loop iteration)
fig_eval_time_heatmaps.tight_layout()

# plt.tight_layout() would only affect the current figure (the evaluation-time figure above);
# lay out the main 3x2 summary grid explicitly.
fig.tight_layout()
plt.show()

# Visualize collected images
plot_collected_images(collected_images, images_per_row=4)

In [None]:
# Re-render the images currently held in `collected_images` (filled by the training cell above)
# without re-running training.
plot_collected_images(collected_images, images_per_row=4)

In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Sweep configuration: the percentage of grid cells revealed to the model, and the
# cumulative epoch checkpoints at which each run is evaluated.
sampling_percentages = list(range(1, 6)) + [10, 15]
epochs = [1, 2, 5, 10, 20, 100, 150]
fixed_dataset_size = 1000  # training samples per run, the same for every percentage
size = 32, 64  # grid size handed to CategoricalSpatialGenerator

def percentage_random_sampling(size, percentage):
    """Build a random binary observation mask covering ``percentage`` % of a grid.

    Args:
        size: ``(height, width)`` of the grid.
        percentage: Share of cells to reveal, in percent. Values >= 100 reveal
            every cell; values that yield zero cells (including negative
            values) reveal nothing.

    Returns:
        A float tensor of shape ``(1, height, width)`` holding
        ``int(percentage * height * width / 100)`` ones at distinct cells chosen
        with ``torch.randperm`` (so placement follows the global torch seed).
    """
    height, width = size
    mask = torch.zeros(1, height, width)
    total_points = height * width

    if percentage >= 100:
        mask[:] = 1
        return mask

    # Multiply before dividing so integer percentages give an exact count
    # (29 / 100.0 * 100 evaluates to 28.999..., which truncates to 28).
    target_points = int(percentage * total_points / 100)

    # A negative count would turn the slice below into [:-k] and reveal almost
    # every cell instead of none.
    if target_points <= 0:
        return mask

    indices = torch.randperm(total_points)[:target_points]
    # Row-major flat indices map directly onto the (height, width) plane.
    mask.view(-1)[indices] = 1
    return mask

# One metrics record per (sampling %, epoch milestone) is appended here by the training loop.
results = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Random point clouds of shape (n, 3): columns scaled to [0, 64), [0, 64), [0, 1).
# Only the range the point count n is drawn from differs between them.
# NOTE(review): the validation loop iterates these as `gen` but never uses them.
def interpolation_generator():
    """Random (n, 3) points with n drawn from [3, 50)."""
    n_points = torch.randint(3, 50, (1,)).item()
    return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])


def extrapolation_close_generator():
    """Random (n, 3) points with n drawn from [51, 70)."""
    n_points = torch.randint(51, 70, (1,)).item()
    return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])


def extrapolation_far_generator():
    """Random (n, 3) points with n drawn from [70, 90)."""
    n_points = torch.randint(70, 90, (1,)).item()
    return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])


test_generators = dict(
    interpolation=interpolation_generator,
    extrapolation_close=extrapolation_close_generator,
    extrapolation_far=extrapolation_far_generator,
)

# Create consistent validation datasets for each sampling percentage.
# Each (sampling %, test type) dataset is built right after reseeding torch, numpy and
# `random`; the RNG states reached after construction are saved so the training loop
# can restore them before every evaluation pass.
print("Creating consistent validation datasets...")
consistent_test_datasets = {}
validation_states = {}

for sampling_percentage in sampling_percentages:
    consistent_test_datasets[sampling_percentage] = {}
    validation_states[sampling_percentage] = {}

    # NOTE(review): `gen` is never used -- every test type is built from
    # `two_layer_generator()` with the same seed, so the 'interpolation',
    # 'extrapolation_close' and 'extrapolation_far' datasets appear to be identical
    # and their losses cannot distinguish regimes. Wire `gen` in (or vary the seed)
    # once the CategoricalSpatialGenerator API is confirmed.
    for name, gen in test_generators.items():
        # Set seeds for reproducibility
        torch.manual_seed(42 + sampling_percentage)
        np.random.seed(42 + sampling_percentage)
        random.seed(42 + sampling_percentage)

        generator_obj = CategoricalSpatialGenerator(size, lambda: two_layer_generator(),
                                                  num_categories=10, methods=['layered', 'vid'])

        # `pct=sampling_percentage` binds the current value; a bare closure would read the
        # module-level loop variable whenever the mask function is eventually called.
        test_dataset = SpatialDataset(1000, generator_obj,
                                    lambda size, pct=sampling_percentage: percentage_random_sampling(size, pct),
                                    lambda x: x,
                                    dynamic_secondary_mask=False,
                                    x_channels=1,
                                    secondary_channels=1,
                                    primary_channels=1)

        consistent_test_datasets[sampling_percentage][name] = test_dataset

        # Store random states for this sampling percentage and test type
        validation_states[sampling_percentage][name] = {
            'torch_rng': torch.get_rng_state(),
            'numpy_rng': np.random.get_state(),
            'random_rng': random.getstate()
        }

# shuffle=False keeps the batch order fixed across evaluations.
consistent_test_dataloaders = {
    sampling_percentage: {
        name: DataLoader(dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)
        for name, dataset in test_datasets.items()
    }
    for sampling_percentage, test_datasets in consistent_test_datasets.items()
}

collected_images = []
use_sparse = True  # rebound per sampling percentage by load_model's return value below
use_input_mask = False

for sampling_percentage in sampling_percentages:
    print(f"\nStarting training with sampling percentage: {sampling_percentage}%")

    # Create training dataset with current sampling percentage.
    # `pct=sampling_percentage` binds the value now instead of reading the loop variable
    # whenever the mask function happens to be called.
    generator = CategoricalSpatialGenerator(size, lambda: two_layer_generator(),
                                          num_categories=10, methods=['layered', 'vid'])
    train_dataset = SpatialDataset(fixed_dataset_size, generator,
                                  lambda size, pct=sampling_percentage: percentage_random_sampling(size, pct),
                                  lambda x: x,
                                  dynamic_secondary_mask=False,
                                  x_channels=1,
                                  secondary_channels=1,
                                  primary_channels=1)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                num_workers=0, pin_memory=True)

    # Initialize the models
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(32, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)
    # Relative path; presumably the Colab Google Drive mount -- adjust when running elsewhere.
    last_model_path = "drive/MyDrive/Models/base_id_1000k.pth"

    # Load the model state. The same checkpoint is loaded for every sampling percentage,
    # so each run presumably starts from identical pretrained weights. This also rebinds
    # `use_secondary` and `use_sparse` from the values load_model returns.
    sparse_encoder, _ , subsurface_decoder, use_secondary, use_sparse, training_losses = load_model(
        last_model_path,
        sparse_encoder,
        None,
        subsurface_decoder,
        device=device
    )

    # Progressive training loop: `epochs` holds cumulative milestones, so each step trains
    # only the gap between the previous milestone and the next one.
    # NOTE(review): if train_model creates a new optimizer on each call (the later
    # `train_model` definition in this notebook does), optimizer state resets at every
    # milestone, which is not the same as one continuous run.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        training_start_time = time.time()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_sampling{sampling_percentage}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=False
        )

        # CUDA launches are asynchronous: wait for queued kernels before stopping the clock.
        if device.type == "cuda":
            torch.cuda.synchronize()
        training_end_time = time.time()
        training_time = training_end_time - training_start_time
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate on consistent test sets for this sampling percentage
        test_metrics = {}
        evaluation_times = {}

        for name, dataloader in consistent_test_dataloaders[sampling_percentage].items():
            # Restore the RNG states captured right after this test set was built, so any
            # randomness drawn from these generators during evaluation replays identically
            # at every milestone.
            torch.set_rng_state(validation_states[sampling_percentage][name]['torch_rng'])
            np.random.set_state(validation_states[sampling_percentage][name]['numpy_rng'])
            random.setstate(validation_states[sampling_percentage][name]['random_rng'])

            evaluation_start_time = time.time()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            if device.type == "cuda":
                torch.cuda.synchronize()
            evaluation_end_time = time.time()
            evaluation_time = evaluation_end_time - evaluation_start_time

            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }
            evaluation_times[name] = evaluation_time

        # Record results (one flat row per sampling percentage x epoch milestone)
        results.append({
            'sampling_percentage': sampling_percentage,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance']
        })

        current_epoch = target_epoch

        # Collect visualization samples
        # Training data visualization: first batch of the (shuffled) training loader.
        # NOTE(review): process_batch is unpacked into four values here but into three in
        # the next experiment cell; only one of them can match its current definition.
        train_iter = iter(train_dataloader)
        train_batch = next(train_iter)
        train_input, train_output, train_mask, train_original = process_batch(
            sparse_encoder, subsurface_decoder, train_batch,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Sampling: {sampling_percentage}%, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask,
            'original_image': train_original
        })

        # Test data visualization using consistent validation data (first batch of each loader).
        # NOTE(review): RNG state is not restored here, unlike the evaluation pass above.
        for name, dataloader in consistent_test_dataloaders[sampling_percentage].items():
            test_iter = iter(dataloader)
            test_batch = next(test_iter)
            test_input, test_output, test_mask, test_original = process_batch(
                sparse_encoder, subsurface_decoder, test_batch,
                device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Sampling: {sampling_percentage}%, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask,
                'original_image': test_original
            })

# Convert results to DataFrame and create visualizations
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# Summary figure: a 3x2 grid (loss curves, ratios, box plot, training heatmaps); the
# per-metric heatmap rows below each get their own figure.
# The bare 'seaborn' style was deprecated in Matplotlib 3.6 and removed in 3.8, where it
# survives as 'seaborn-v0_8'; use whichever name this installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# Result column -> display label for the three test-set losses (ranges mirror the
# point-count ranges of the test generators).
test_types = {
    'test_loss_interpolation': 'Interpolation (3-50)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-70)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-90)'
}

# 1. Test Losses vs Sampling Percentage for Different Epochs
ax1 = axes[0, 0]
for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['sampling_percentage'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax1.set_xlabel('Sampling Percentage')
ax1.set_ylabel('Test Loss')
ax1.set_title('Test Losses vs Sampling Percentage\n(Fixed Dataset Size)')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test Losses vs Epochs for Different Sampling Percentages
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for sampling_pct in sampling_percentages:
        size_data = df_results[df_results['sampling_percentage'] == sampling_pct]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} ({sampling_pct}%)', marker='o', alpha=0.7)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Losses vs Epochs\n(By Sampling Percentage)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Heatmaps for each test type including Training Loss (separate figure)
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {
    'training_loss': 'Training Loss',
    **test_types
}
for idx, (test_type, label) in enumerate(test_types_with_training.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_heatmaps[idx])
    axes_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_heatmaps[idx].set_xlabel('Epochs')
fig_heatmaps.tight_layout()

# 4. Comparison of test types
ax4 = axes[1, 1]
box_data = [df_results[test_type] for test_type in test_types.keys()]
# Tick labels are set separately: boxplot's `labels=` keyword was renamed `tick_labels`
# in Matplotlib 3.9 and deprecated, while set_xticklabels works on every version.
ax4.boxplot(box_data)
ax4.set_xticklabels([label.replace('\n', ' ') for label in test_types.values()])
ax4.set_ylabel('Test Loss')
ax4.set_title('Distribution of Test Losses by Type\n(Across All Sampling Percentages)')
ax4.grid(True)

# 5. Loss Ratio Analysis (ratios are also kept as new df_results columns)
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                    ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['sampling_percentage'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xlabel('Sampling Percentage')
ax3.set_ylabel('Loss Ratio')
ax3.set_title('Extrapolation Performance Relative to Interpolation\n(By Sampling Percentage)')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Heatmaps for variances of each test type (separate figure)
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for idx, (variance_type, label) in enumerate(variance_test_types.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_variance_heatmaps[idx])
    axes_variance_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_variance_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_variance_heatmaps[idx].set_xlabel('Epochs')
fig_variance_heatmaps.tight_layout()

# 7. Training Loss Heatmap
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='sampling_percentage', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')
ax7.set_ylabel('Sampling Percentage')
ax7.set_xlabel('Epochs')

# 8. Training Time Heatmap
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='sampling_percentage', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_ylabel('Sampling Percentage')
ax8.set_xlabel('Epochs')

# 9. Evaluation Time Heatmaps for each test set (separate figure)
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for idx, (eval_time_type, label) in enumerate(evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=axes_eval_time_heatmaps[idx])
    axes_eval_time_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_eval_time_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_eval_time_heatmaps[idx].set_xlabel('Epochs')
# Lay out once after all panels exist, like the other heatmap rows.
fig_eval_time_heatmaps.tight_layout()

# plt.tight_layout() would only touch the current (last-created) figure, leaving the 3x2
# summary figure un-laid-out; target it explicitly.
fig.tight_layout()
plt.show()

# Visualize collected images
plot_collected_images(collected_images, images_per_row=4)

In [None]:
# Re-render the samples gathered in `collected_images` by the experiment cell above,
# without re-running the experiment.
plot_collected_images(collected_images, images_per_row=4)

In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Experiment grid: mask density (percent of grid cells observed) x cumulative epoch milestones.
sampling_percentages = [*range(1, 6), 10, 15]  # percent of the grid's cells revealed
epochs = [1] + [5 * 2**k for k in range(6)]  # 1, then doubling from 5 up to 160
fixed_dataset_size = 1_000  # training-set size held constant across sampling percentages
size = (32, 64)  # (height, width) of every generated grid

def percentage_random_sampling(size, percentage):
    """Build a random binary observation mask covering ``percentage`` % of a grid.

    Args:
        size: ``(height, width)`` of the grid.
        percentage: Share of cells to reveal, in percent. Values >= 100 reveal
            every cell; values that yield zero cells (including negative
            values) reveal nothing.

    Returns:
        A float tensor of shape ``(1, height, width)`` holding
        ``int(percentage * height * width / 100)`` ones at distinct cells chosen
        with ``torch.randperm`` (so placement follows the global torch seed).
    """
    height, width = size
    mask = torch.zeros(1, height, width)
    total_points = height * width

    if percentage >= 100:
        mask[:] = 1
        return mask

    # Multiply before dividing so integer percentages give an exact count
    # (29 / 100.0 * 100 evaluates to 28.999..., which truncates to 28).
    target_points = int(percentage * total_points / 100)

    # A negative count would turn the slice below into [:-k] and reveal almost
    # every cell instead of none.
    if target_points <= 0:
        return mask

    indices = torch.randperm(total_points)[:target_points]
    # Row-major flat indices map directly onto the (height, width) plane.
    mask.view(-1)[indices] = 1
    return mask

# One metrics record per (sampling %, epoch milestone) is appended here by the training loop.
results = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Random point clouds of shape (n, 3): columns scaled to [0, 64), [0, 64), [0, 1).
# Only the range the point count n is drawn from differs between them.
# NOTE(review): the validation loop iterates these as `gen` but never uses them.
def interpolation_generator():
    """Random (n, 3) points with n drawn from [3, 50)."""
    n_points = torch.randint(3, 50, (1,)).item()
    return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])


def extrapolation_close_generator():
    """Random (n, 3) points with n drawn from [51, 70)."""
    n_points = torch.randint(51, 70, (1,)).item()
    return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])


def extrapolation_far_generator():
    """Random (n, 3) points with n drawn from [70, 90)."""
    n_points = torch.randint(70, 90, (1,)).item()
    return torch.rand(n_points, 3) * torch.tensor([64, 64, 1])


test_generators = dict(
    interpolation=interpolation_generator,
    extrapolation_close=extrapolation_close_generator,
    extrapolation_far=extrapolation_far_generator,
)

# Create consistent validation datasets for each sampling percentage.
# Each (sampling %, test type) dataset is built right after reseeding torch, numpy and
# `random`; the RNG states reached after construction are saved so the training loop
# can restore them before every evaluation pass.
print("Creating consistent validation datasets...")
consistent_test_datasets = {}
validation_states = {}

for sampling_percentage in sampling_percentages:
    consistent_test_datasets[sampling_percentage] = {}
    validation_states[sampling_percentage] = {}

    # NOTE(review): `gen` is never used -- every test type is built from
    # `two_layer_generator()` with the same seed, so the 'interpolation',
    # 'extrapolation_close' and 'extrapolation_far' datasets appear to be identical
    # and their losses cannot distinguish regimes. Wire `gen` in (or vary the seed)
    # once the CategoricalSpatialGenerator API is confirmed.
    for name, gen in test_generators.items():
        # Set seeds for reproducibility
        torch.manual_seed(42 + sampling_percentage)
        np.random.seed(42 + sampling_percentage)
        random.seed(42 + sampling_percentage)

        generator_obj = CategoricalSpatialGenerator(size, lambda: two_layer_generator(),
                                                  num_categories=10, methods=['layered', 'vid'])

        # `pct=sampling_percentage` binds the current value. With dynamic_secondary_mask=True
        # the mask function is presumably re-invoked lazily, and a bare closure would then
        # read whatever the module-level loop variable holds at that moment.
        test_dataset = SpatialDataset(1000, generator_obj,
                                    lambda size, pct=sampling_percentage: percentage_random_sampling(size, pct),
                                    lambda x: x,
                                    dynamic_secondary_mask=True,
                                    x_channels=1,
                                    secondary_channels=1,
                                    primary_channels=1)

        consistent_test_datasets[sampling_percentage][name] = test_dataset

        # Store random states for this sampling percentage and test type
        validation_states[sampling_percentage][name] = {
            'torch_rng': torch.get_rng_state(),
            'numpy_rng': np.random.get_state(),
            'random_rng': random.getstate()
        }

# shuffle=False keeps the batch order fixed across evaluations.
consistent_test_dataloaders = {
    sampling_percentage: {
        name: DataLoader(dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)
        for name, dataset in test_datasets.items()
    }
    for sampling_percentage, test_datasets in consistent_test_datasets.items()
}

collected_images = []
use_sparse = True
use_input_mask = False

for sampling_percentage in sampling_percentages:
    print(f"\nStarting training with sampling percentage: {sampling_percentage}%")

    # Create training dataset with current sampling percentage.
    # `pct=sampling_percentage` binds the value now instead of reading the loop variable
    # whenever the (dynamic) mask function happens to be called.
    generator = CategoricalSpatialGenerator(size, lambda: two_layer_generator(),
                                          num_categories=10, methods=['layered', 'vid'])
    train_dataset = SpatialDataset(fixed_dataset_size, generator,
                                  lambda size, pct=sampling_percentage: percentage_random_sampling(size, pct),
                                  lambda x: x,
                                  dynamic_secondary_mask=True,
                                  x_channels=1,
                                  secondary_channels=1,
                                  primary_channels=1)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                                num_workers=0, pin_memory=True)

    # Initialize the models. No checkpoint is loaded in this cell, so each sampling
    # percentage trains from freshly constructed modules.
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=(32, 64)).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)

    # Progressive training loop: `epochs` holds cumulative milestones, so each step trains
    # only the gap between the previous milestone and the next one.
    # NOTE(review): if train_model creates a new optimizer on each call (the later
    # `train_model` definition in this notebook does), optimizer state resets at every
    # milestone, which is not the same as one continuous run.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        training_start_time = time.time()

        sparse_encoder, _, subsurface_decoder, training_losses = train_model(
            sparse_encoder, None, subsurface_decoder, train_dataloader,
            use_sparse=use_sparse,
            model_save_path=f"model_sampling{sampling_percentage}_epochs{target_epoch}.pth",
            num_epochs=epochs_to_train,
            use_secondary=False,
            use_input_mask=use_input_mask,
            device=device,
            save_model=False,
            visualize=True
        )

        # CUDA launches are asynchronous: wait for queued kernels before stopping the clock.
        if device.type == "cuda":
            torch.cuda.synchronize()
        training_end_time = time.time()
        training_time = training_end_time - training_start_time
        final_training_loss = training_losses[-1]['primary_loss']

        # Evaluate on consistent test sets for this sampling percentage
        test_metrics = {}
        evaluation_times = {}

        for name, dataloader in consistent_test_dataloaders[sampling_percentage].items():
            # Restore the RNG states captured right after this test set was built, so any
            # randomness drawn from these generators during evaluation (e.g. dynamic masks)
            # replays identically at every milestone.
            torch.set_rng_state(validation_states[sampling_percentage][name]['torch_rng'])
            np.random.set_state(validation_states[sampling_percentage][name]['numpy_rng'])
            random.setstate(validation_states[sampling_percentage][name]['random_rng'])

            evaluation_start_time = time.time()

            loss, variance, total_variance = evaluate_model(
                sparse_encoder, subsurface_decoder,
                dataloader, device=device,
                use_sparse=use_sparse, use_input_mask=use_input_mask
            )

            if device.type == "cuda":
                torch.cuda.synchronize()
            evaluation_end_time = time.time()
            evaluation_time = evaluation_end_time - evaluation_start_time

            test_metrics[name] = {
                'loss': loss,
                'batch_variance': variance,
                'total_variance': total_variance
            }
            evaluation_times[name] = evaluation_time

        # Record results (one flat row per sampling percentage x epoch milestone)
        results.append({
            'sampling_percentage': sampling_percentage,
            'epochs': target_epoch,
            'training_loss': final_training_loss,
            'training_time': training_time,
            'evaluation_time_interpolation': evaluation_times['interpolation'],
            'evaluation_time_extrapolation_close': evaluation_times['extrapolation_close'],
            'evaluation_time_extrapolation_far': evaluation_times['extrapolation_far'],
            'test_loss_interpolation': test_metrics['interpolation']['loss'],
            'test_variance_interpolation': test_metrics['interpolation']['total_variance'],
            'test_loss_extrapolation_close': test_metrics['extrapolation_close']['loss'],
            'test_variance_extrapolation_close': test_metrics['extrapolation_close']['total_variance'],
            'test_loss_extrapolation_far': test_metrics['extrapolation_far']['loss'],
            'test_variance_extrapolation_far': test_metrics['extrapolation_far']['total_variance']
        })

        current_epoch = target_epoch

        # Collect visualization samples
        # Training data visualization: first batch of the (shuffled) training loader.
        # NOTE(review): process_batch is unpacked into three values here but into four in
        # the previous experiment cell; only one of them can match its current definition.
        train_iter = iter(train_dataloader)
        train_batch = next(train_iter)
        train_input, train_output, train_mask = process_batch(
            sparse_encoder, subsurface_decoder, train_batch,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
        )
        collected_images.append({
            'title': f'Train Sampling: {sampling_percentage}%, Epochs: {target_epoch}',
            'input_grid': train_input,
            'output_grid': train_output,
            'input_mask': train_mask
        })

        # Test data visualization using consistent validation data (first batch of each loader).
        # NOTE(review): RNG state is not restored here, so with dynamic masks the displayed
        # batch may differ from the one that was evaluated.
        for name, dataloader in consistent_test_dataloaders[sampling_percentage].items():
            test_iter = iter(dataloader)
            test_batch = next(test_iter)
            test_input, test_output, test_mask = process_batch(
                sparse_encoder, subsurface_decoder, test_batch,
                device=device, use_sparse=use_sparse, use_input_mask=use_input_mask
            )
            collected_images.append({
                'title': f'Test: {name}, Sampling: {sampling_percentage}%, Epochs: {target_epoch}',
                'input_grid': test_input,
                'output_grid': test_output,
                'input_mask': test_mask
            })

# Convert results to DataFrame and create visualizations
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)

# Same plot layout as the previous experiment cell, applied to this run's results.

# Summary figure: a 3x2 grid (loss curves, ratios, box plot, training heatmaps); the
# per-metric heatmap rows below each get their own figure.
# The bare 'seaborn' style was deprecated in Matplotlib 3.6 and removed in 3.8, where it
# survives as 'seaborn-v0_8'; use whichever name this installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
fig, axes = plt.subplots(3, 2, figsize=(20, 24))

# Result column -> display label for the three test-set losses (ranges mirror the
# point-count ranges of the test generators).
test_types = {
    'test_loss_interpolation': 'Interpolation (3-50)',
    'test_loss_extrapolation_close': 'Close Extrapolation (51-70)',
    'test_loss_extrapolation_far': 'Far Extrapolation (70-90)'
}

# 1. Test Losses vs Sampling Percentage for Different Epochs
ax1 = axes[0, 0]
for test_type, label in test_types.items():
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax1.plot(epoch_data['sampling_percentage'], epoch_data[test_type],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax1.set_xlabel('Sampling Percentage')
ax1.set_ylabel('Test Loss')
ax1.set_title('Test Losses vs Sampling Percentage\n(Fixed Dataset Size)')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True)

# 2. Test Losses vs Epochs for Different Sampling Percentages
ax2 = axes[0, 1]
for test_type, label in test_types.items():
    for sampling_pct in sampling_percentages:
        size_data = df_results[df_results['sampling_percentage'] == sampling_pct]
        ax2.plot(size_data['epochs'], size_data[test_type],
                 label=f'{label} ({sampling_pct}%)', marker='o', alpha=0.7)
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Test Loss')
ax2.set_title('Test Losses vs Epochs\n(By Sampling Percentage)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True)

# 3. Heatmaps for each test type including Training Loss (separate figure)
fig_heatmaps, axes_heatmaps = plt.subplots(1, 4, figsize=(25, 6))
test_types_with_training = {
    'training_loss': 'Training Loss',
    **test_types
}
for idx, (test_type, label) in enumerate(test_types_with_training.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=test_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_heatmaps[idx])
    axes_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_heatmaps[idx].set_xlabel('Epochs')
fig_heatmaps.tight_layout()

# 4. Comparison of test types
ax4 = axes[1, 1]
box_data = [df_results[test_type] for test_type in test_types.keys()]
# Tick labels are set separately: boxplot's `labels=` keyword was renamed `tick_labels`
# in Matplotlib 3.9 and deprecated, while set_xticklabels works on every version.
ax4.boxplot(box_data)
ax4.set_xticklabels([label.replace('\n', ' ') for label in test_types.values()])
ax4.set_ylabel('Test Loss')
ax4.set_title('Distribution of Test Losses by Type\n(Across All Sampling Percentages)')
ax4.grid(True)

# 5. Loss Ratio Analysis (ratios are also kept as new df_results columns)
ax3 = axes[1, 0]
df_results['close_extrapolation_ratio'] = df_results['test_loss_extrapolation_close'] / df_results['test_loss_interpolation']
df_results['far_extrapolation_ratio'] = df_results['test_loss_extrapolation_far'] / df_results['test_loss_interpolation']

for ratio, label in [('close_extrapolation_ratio', 'Close Extrapolation / Interpolation'),
                    ('far_extrapolation_ratio', 'Far Extrapolation / Interpolation')]:
    for epoch in epochs:
        epoch_data = df_results[df_results['epochs'] == epoch]
        ax3.plot(epoch_data['sampling_percentage'], epoch_data[ratio],
                 label=f'{label} (Epochs {epoch})', marker='o', alpha=0.7)
ax3.axhline(y=1, color='r', linestyle='--', label='Baseline (Equal Performance)')
ax3.set_xlabel('Sampling Percentage')
ax3.set_ylabel('Loss Ratio')
ax3.set_title('Extrapolation Performance Relative to Interpolation\n(By Sampling Percentage)')
ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax3.grid(True)

# 6. Heatmaps for variances of each test type (separate figure)
fig_variance_heatmaps, axes_variance_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
variance_test_types = {
    'test_variance_interpolation': 'Interpolation Variance',
    'test_variance_extrapolation_close': 'Close Extrapolation Variance',
    'test_variance_extrapolation_far': 'Far Extrapolation Variance'
}
for idx, (variance_type, label) in enumerate(variance_test_types.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=variance_type)
    sns.heatmap(pivot_table, annot=True, fmt=".6f", cmap='viridis', ax=axes_variance_heatmaps[idx])
    axes_variance_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_variance_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_variance_heatmaps[idx].set_xlabel('Epochs')
fig_variance_heatmaps.tight_layout()

# 7. Training Loss Heatmap
ax7 = axes[2, 1]
training_loss_pivot = df_results.pivot(index='sampling_percentage', columns='epochs', values='training_loss')
sns.heatmap(training_loss_pivot, annot=True, fmt=".6f", cmap='viridis', ax=ax7)
ax7.set_title('Training Loss Heatmap')
ax7.set_ylabel('Sampling Percentage')
ax7.set_xlabel('Epochs')

# 8. Training Time Heatmap
ax8 = axes[2, 0]
training_time_pivot = df_results.pivot(index='sampling_percentage', columns='epochs', values='training_time')
sns.heatmap(training_time_pivot, annot=True, fmt=".2f", cmap='Blues', ax=ax8)
ax8.set_title('Training Time Heatmap')
ax8.set_ylabel('Sampling Percentage')
ax8.set_xlabel('Epochs')

# 9. Evaluation Time Heatmaps for each test set (separate figure)
fig_eval_time_heatmaps, axes_eval_time_heatmaps = plt.subplots(1, 3, figsize=(20, 6))
evaluation_time_test_types = {
    'evaluation_time_interpolation': 'Interpolation Evaluation Time',
    'evaluation_time_extrapolation_close': 'Close Extrapolation Evaluation Time',
    'evaluation_time_extrapolation_far': 'Far Extrapolation Evaluation Time'
}
for idx, (eval_time_type, label) in enumerate(evaluation_time_test_types.items()):
    pivot_table = df_results.pivot(index='sampling_percentage', columns='epochs', values=eval_time_type)
    sns.heatmap(pivot_table, annot=True, fmt=".2f", cmap='Greens', ax=axes_eval_time_heatmaps[idx])
    axes_eval_time_heatmaps[idx].set_title(f'{label}\nHeatmap')
    axes_eval_time_heatmaps[idx].set_ylabel('Sampling Percentage')
    axes_eval_time_heatmaps[idx].set_xlabel('Epochs')
# Lay out once after all panels exist, like the other heatmap rows.
fig_eval_time_heatmaps.tight_layout()

# plt.tight_layout() would only touch the current (last-created) figure, leaving the 3x2
# summary figure un-laid-out; target it explicitly.
fig.tight_layout()
plt.show()

# Visualize collected images
plot_collected_images(collected_images, images_per_row=4)

# Secondary Variable

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm


def train_model(sparse_encoder, sample_decoder, parameter_decoder, dataloader, num_epochs=10, learning_rate=1e-4, device="cuda"):
    """Jointly train the encoder and both decoders on masked response data.

    For every batch the response grid is multiplied by its response mask and
    encoded. The reversed feature list is decoded by ``sample_decoder``, whose
    output is compared (MSE) against the full response grid, and by
    ``parameter_decoder``, whose output is compared (MSE) against ``x``. The two
    losses are summed and minimised with a single AdamW optimizer.

    Args:
        sparse_encoder: Encoder module. Before each forward pass the module-level
            ``_cur_active`` is set to the batch's response mask.
        sample_decoder: Decoder trained to reconstruct the response grid.
        parameter_decoder: Decoder trained to predict ``x``.
        dataloader: Iterable of ``(x, param_grid, param_mask, response_grid,
            response_mask)`` batches.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: AdamW learning rate.
        device: Device the batch tensors are moved to. The modules are not moved
            here; the caller must place them on ``device`` beforehand.

    Returns:
        ``(sparse_encoder, sample_decoder, parameter_decoder, training_losses)``.
        ``training_losses`` holds one dict per epoch with keys ``'epoch'``,
        ``'sample_loss'``, ``'parameter_loss'`` and ``'total_loss'``. Each value
        is a mean over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    global _cur_active

    sparse_encoder.train()
    sample_decoder.train()
    parameter_decoder.train()

    # A fresh optimizer is built on every call, so AdamW moment estimates do not
    # carry over between successive calls (e.g. staged/progressive training).
    optimizer = optim.AdamW(list(sparse_encoder.parameters()) +
                            list(sample_decoder.parameters()) +
                            list(parameter_decoder.parameters()),
                            lr=learning_rate)
    criterion = nn.MSELoss()
    training_losses = []

    for epoch in range(num_epochs):
        total_parameter_loss = 0.0
        total_sample_loss = 0.0
        num_batches = 0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x, _param_grid, _param_mask, response_grid, response_mask = batch
            # Cast the regression target to float32 like the evaluation helpers do,
            # so the MSE target dtype matches the float32 model output.
            batch_x = x.to(device).float()
            batch_response_grid = response_grid.to(device).float()
            batch_response_mask = response_mask.to(device)

            optimizer.zero_grad()

            # The sparse encoder reads its active-site mask from this global.
            _cur_active = batch_response_mask

            features = sparse_encoder(batch_response_grid * batch_response_mask)
            # Each decoder gets its own reversed copy of the feature list.
            sample_output = sample_decoder(features[::-1])
            parameter_output = parameter_decoder(features[::-1])

            sample_loss = criterion(sample_output, batch_response_grid)
            parameter_loss = criterion(parameter_output, batch_x)

            total_sample_loss += sample_loss.item()
            total_parameter_loss += parameter_loss.item()

            loss = sample_loss + parameter_loss
            loss.backward()
            optimizer.step()

            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_model: dataloader yielded no batches")

        avg_sample_loss = total_sample_loss / num_batches
        avg_parameter_loss = total_parameter_loss / num_batches

        training_losses.append({
            'epoch': epoch + 1,
            'sample_loss': avg_sample_loss,
            'parameter_loss': avg_parameter_loss,
            'total_loss': avg_sample_loss + avg_parameter_loss
        })

        print(f"Epoch {epoch+1}/{num_epochs}, Average Sample Loss: {avg_sample_loss:.10f}, Average Parameter Loss: {avg_parameter_loss:.10f}")

    return sparse_encoder, sample_decoder, parameter_decoder, training_losses

# Select the compute device (GPU when available, otherwise CPU) and report it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

def evaluate_model(sparse_encoder, subsurface_decoder, dataloader, device="cuda", use_sparse=False, use_secondary=False, use_input_mask=True, imputation_decoder=None):
    """Evaluate the primary (subsurface) prediction against ``x``.

    The encoder and subsurface decoder are moved to ``device`` and put in eval
    mode. For each batch the module-level ``_cur_active`` mask is set as follows:
    all ones when ``use_sparse`` is False; otherwise ``param_mask`` if
    ``use_input_mask`` is True, else ``response_mask``. The encoder input is the
    masked response grid when ``use_secondary`` is True, otherwise the masked
    ``x``.

    ``imputation_decoder``, when given together with ``use_secondary``, is still
    moved to ``device`` and set to eval mode. It is not run, because none of the
    returned metrics depend on its output.

    Returns:
        ``(avg_loss, avg_variance, total_prediction_variance)``:
        - ``avg_loss`` is the mean over batches of the per-batch MSE.
        - ``avg_variance`` is the mean over batches of the per-batch prediction
          variance.
        - ``total_prediction_variance`` is the variance of all predictions
          concatenated along the batch axis.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    global _cur_active

    sparse_encoder.to(device)
    subsurface_decoder.to(device)
    sparse_encoder.eval()
    subsurface_decoder.eval()

    if use_secondary and imputation_decoder is not None:
        imputation_decoder.to(device)
        imputation_decoder.eval()

    criterion = nn.MSELoss()
    total_loss = 0.0
    total_variance = 0.0
    num_batches = 0
    all_outputs = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            x, _param_grid, param_mask, response_grid, response_mask = batch
            batch_x = x.to(device).float()  # Convert to float32
            batch_param_mask = param_mask.to(device).float()
            batch_response_grid = response_grid.to(device).float()
            batch_response_mask = response_mask.to(device).float()

            if not use_sparse:
                # ones_like already allocates on batch_param_mask's device.
                _cur_active = torch.ones_like(batch_param_mask)
            elif use_input_mask:
                _cur_active = batch_param_mask
            else:
                _cur_active = batch_response_mask

            if use_secondary:
                model_input = batch_response_grid * _cur_active
            else:
                model_input = batch_x * _cur_active

            features = sparse_encoder(model_input)
            primary_output = subsurface_decoder(features[::-1])

            all_outputs.append(primary_output.cpu())

            loss = criterion(primary_output, batch_x)
            total_loss += loss.item()

            total_variance += torch.var(primary_output).item()

            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_model: dataloader yielded no batches")

    avg_loss = total_loss / num_batches
    avg_variance = total_variance / num_batches

    total_prediction_variance = torch.var(torch.cat(all_outputs, dim=0)).item()

    return avg_loss, avg_variance, total_prediction_variance

def process_batch(sparse_encoder, subsurface_decoder, batch, device="cuda", use_secondary=False, use_sparse=False, use_input_mask=True, imputation_decoder=None):
    """
    Run one batch through the encoder and decoder(s) without gradients.

    The five batch tensors are moved to ``device`` and cast to float32. The
    module-level ``_cur_active`` mask is then chosen: all ones when
    ``use_sparse`` is False; otherwise ``param_mask`` if ``use_input_mask`` is
    True, else ``response_mask``. The encoder input is the masked response grid
    when ``use_secondary`` is True, otherwise the masked ``x``.

    Side effect: the encoder, the subsurface decoder and (if given) the
    imputation decoder are converted to float32 in place.

    Returns:
        ``(model_input, primary_output, secondary_output, active_mask, x)`` on
        the CPU. ``secondary_output`` is None unless ``use_secondary`` is True
        and an ``imputation_decoder`` is supplied.
    """
    global _cur_active

    run_imputation = use_secondary and imputation_decoder is not None

    for module in (sparse_encoder, subsurface_decoder):
        module.eval()
    if run_imputation:
        imputation_decoder.eval()

    with torch.no_grad():
        x, param_grid, param_mask, response_grid, response_mask = batch
        # Move every tensor to the target device as float32.
        batch_x, batch_param_grid, batch_param_mask, batch_response_grid, batch_response_mask = (
            tensor.to(device).float()
            for tensor in (x, param_grid, param_mask, response_grid, response_mask)
        )

        if not use_sparse:
            _cur_active = torch.ones_like(batch_param_mask).to(device).float()
        elif use_input_mask:
            _cur_active = batch_param_mask
        else:
            _cur_active = batch_response_mask

        source = batch_response_grid if use_secondary else batch_x
        model_input = source * _cur_active

        # Make sure the model weights are float32 to match the inputs.
        sparse_encoder = sparse_encoder.float()
        subsurface_decoder = subsurface_decoder.float()
        if imputation_decoder is not None:
            imputation_decoder = imputation_decoder.float()

        features = sparse_encoder(model_input)
        primary_output = subsurface_decoder(features[::-1])

        secondary_output = None
        if run_imputation:
            secondary_output = imputation_decoder(features[::-1]).cpu()

        return model_input.cpu(), primary_output.cpu(), secondary_output, _cur_active.cpu(), batch_x.cpu()

def plot_collected_images(collected_images, images_per_row=2, max_images=10):
    """
    Plot collected images with a limit on the number of images shown.
    Includes original input visualization.

    Each entry of ``collected_images`` is a dict with the keys ``'title'``,
    ``'original_secondary'``, ``'output_grid'`` and ``'input_mask'``. It may
    also have ``'original_image'`` and ``'imputation_output'``; an entry whose
    value is None counts as absent. Tensor values are indexed with ``[0]`` and
    squeezed before plotting.

    Columns per item, in order:
    - the original input (``'original_secondary'``);
    - the original image, if any item has one;
    - the subsurface prediction (``'output_grid'``);
    - the imputation output with the masked sample points overlaid, if any
      item has one.

    If plotting an item raises, the error text is drawn into that item's panel
    instead of aborting the whole figure.

    Args:
        collected_images: Sequence of dicts as described above.
        images_per_row: Maximum number of items per figure row (values < 1 are
            treated as 1).
        max_images: Only the first ``max_images`` items are plotted.
    """
    # Limit the number of images to prevent matplotlib size issues
    collected_images = collected_images[:max_images]
    num_images = len(collected_images)
    if num_images == 0:
        # Nothing to draw (and num_rows below would divide by zero).
        return

    def _has(data, key):
        # Present only when the key exists with a non-None value; callers store
        # None e.g. when no imputation decoder was used.
        return data.get(key) is not None

    has_originals = any(_has(data, 'original_image') for data in collected_images)
    has_imputation = any(_has(data, 'imputation_output') for data in collected_images)
    # Input + subsurface prediction, plus the optional original / imputation columns.
    cols_per_item = 2 + int(has_originals) + int(has_imputation)

    images_per_row = max(1, min(images_per_row, num_images))
    num_rows = (num_images + images_per_row - 1) // images_per_row
    total_cols = images_per_row * cols_per_item

    fig, axes = plt.subplots(num_rows, total_cols,
                             figsize=(5 * total_cols, 5 * num_rows))
    # plt.subplots may return a single Axes or a 1-D array; normalise to 2-D.
    axes = np.asarray(axes).reshape(num_rows, total_cols)

    for idx, data in enumerate(collected_images):
        row_idx = idx // images_per_row
        col_start = (idx % images_per_row) * cols_per_item
        last_col = col_start + cols_per_item - 1
        # Initialised before the try so the error handler always has a valid
        # column for this item, even if the very first lookup fails.
        col_idx = col_start

        try:
            input_grid = data['original_secondary'][0].squeeze().cpu().numpy()
            output_grid = data['output_grid'][0].squeeze().cpu().numpy()
            input_mask = data['input_mask'][0].squeeze().cpu().numpy()
            title = data['title']

            # Plot original input
            ax_input = axes[row_idx, col_idx]
            im_input = ax_input.imshow(input_grid, cmap='viridis', interpolation='nearest')
            ax_input.set_title(f"{title}\n(Original Input)")
            plt.colorbar(im_input, ax=ax_input)
            ax_input.axis('off')
            col_idx += 1

            # Plot original image if available
            if has_originals:
                ax_orig = axes[row_idx, col_idx]
                if _has(data, 'original_image'):
                    original = data['original_image'][0].squeeze().cpu().numpy()
                    im_orig = ax_orig.imshow(original, cmap='viridis', interpolation='nearest')
                    ax_orig.set_title(f"{title}\n(Original)")
                    plt.colorbar(im_orig, ax=ax_orig)
                else:
                    ax_orig.text(0.5, 0.5, 'No original image',
                                 ha='center', va='center', transform=ax_orig.transAxes)
                ax_orig.axis('off')
                col_idx += 1

            # Plot subsurface prediction
            ax = axes[row_idx, col_idx]
            im = ax.imshow(output_grid, cmap='viridis', interpolation='nearest')
            ax.set_title(f"{title}\n(Subsurface Prediction)")
            ax.axis('off')
            plt.colorbar(im, ax=ax)
            col_idx += 1

            # Plot imputation output with points on top
            if has_imputation:
                ax_imp = axes[row_idx, col_idx]
                if _has(data, 'imputation_output'):
                    imputation_output = data['imputation_output'][0].squeeze().cpu().numpy()
                    im_imp = ax_imp.imshow(imputation_output, cmap='viridis', interpolation='nearest')

                    # Overlay the masked sample locations, coloured by the
                    # 'original_secondary' value at each point.
                    mask_indices = np.argwhere(input_mask > 0)
                    if len(mask_indices) > 0:
                        mask_values = input_grid[tuple(mask_indices.T)]
                        ax_imp.scatter(mask_indices[:, 1],
                                       mask_indices[:, 0],
                                       c=mask_values,
                                       cmap='viridis',
                                       edgecolors='black',
                                       s=50)

                    ax_imp.set_title(f"{title}\n(Imputation Output)")
                    plt.colorbar(im_imp, ax=ax_imp)
                else:
                    ax_imp.text(0.5, 0.5, 'No imputation output',
                                ha='center', va='center', transform=ax_imp.transAxes)
                ax_imp.axis('off')

        except Exception as e:
            # Best effort: report the failure in this item's panel instead of
            # aborting the whole grid.
            current_ax = axes[row_idx, min(col_idx, last_col)]
            current_ax.text(0.5, 0.5, f'Error processing image {idx}\n{str(e)}',
                            ha='center', va='center', transform=current_ax.transAxes, color='red')
            current_ax.axis('off')

    # Remove empty subplots (flat row-major index == item * cols_per_item + column)
    for flat_idx in range(num_images * cols_per_item, num_rows * total_cols):
        fig.delaxes(axes.flat[flat_idx])

    plt.tight_layout()
    plt.show()

# Create a fixed validation sample to use across all epochs
def create_fixed_validation_sample(dataset, num_samples=1):
    """
    Create fixed validation samples with consistent sampling points.

    While the samples are drawn, the Python, NumPy and torch RNGs are seeded
    with 42. A dataset that draws its sampling mask randomly in ``__getitem__``
    therefore returns the same masks on every call. The caller's RNG states
    (including CUDA, when available) are saved beforehand and restored
    afterwards, even on error, so this function does not perturb the
    surrounding run's randomness.

    NOTE(review): every sample is taken from ``dataset[0]``, so the samples
    differ only through randomness inside ``dataset.__getitem__``. Confirm this
    is intended rather than ``dataset[i]``.

    Args:
        dataset: Indexable whose items are ``(x, param_grid, param_mask,
            response_grid, response_mask)``. Each element is a tensor, a NumPy
            array, or anything ``torch.as_tensor`` accepts.
        num_samples: Number of samples to draw.

    Returns:
        List of ``num_samples`` dicts keyed ``'x'``, ``'param_grid'``,
        ``'param_mask'``, ``'response_grid'`` and ``'response_mask'``. Every
        value is a tensor that owns its memory and does not alias the dataset.
    """
    keys = ('x', 'param_grid', 'param_mask', 'response_grid', 'response_mask')

    # Snapshot every RNG that is about to be reseeded.
    py_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

    fixed_samples = []
    try:
        # Set fixed seed temporarily
        torch.manual_seed(42)
        np.random.seed(42)
        random.seed(42)

        for _ in range(num_samples):
            x, param_grid, param_mask, response_grid, response_mask = dataset[0]
            values = (x, param_grid, param_mask, response_grid, response_mask)
            # torch.as_tensor shares memory with NumPy inputs; clone() makes
            # the fixed sample independent of the dataset's buffers.
            fixed_samples.append({key: torch.as_tensor(value).clone()
                                  for key, value in zip(keys, values)})
    finally:
        # Restore the caller's RNG states
        random.setstate(py_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)

    return fixed_samples

def process_fixed_validation_sample(sparse_encoder, subsurface_decoder, sample_data,
                                  device="cuda", use_secondary=False, use_sparse=False,
                                  use_input_mask=True, imputation_decoder=None):
    """
    Run a single, un-batched validation sample through the model without gradients.

    ``sample_data`` holds the tensors ``'x'``, ``'param_grid'``,
    ``'param_mask'``, ``'response_grid'`` and ``'response_mask'``. Each one gets
    a leading batch axis, is moved to ``device`` and is cast to float32. The
    module-level ``_cur_active`` mask is all ones when ``use_sparse`` is False;
    otherwise it is ``param_mask`` if ``use_input_mask`` is True, else
    ``response_mask``. The encoder input is the masked response grid when
    ``use_secondary`` is True, otherwise the masked ``x``.

    Returns:
        ``(model_input, primary_output, secondary_output, active_mask, x,
        response_grid)`` on the CPU. ``secondary_output`` is None unless
        ``use_secondary`` is True and an ``imputation_decoder`` is supplied.
    """
    global _cur_active

    run_imputation = use_secondary and imputation_decoder is not None

    for module in (sparse_encoder, subsurface_decoder):
        module.eval()
    if run_imputation:
        imputation_decoder.eval()

    with torch.no_grad():
        def as_batch(key):
            # Add the batch axis, then move/cast like the batched helpers.
            return sample_data[key].unsqueeze(0).to(device).float()

        batch_x = as_batch('x')
        batch_param_grid = as_batch('param_grid')
        batch_param_mask = as_batch('param_mask')
        batch_response_grid = as_batch('response_grid')
        batch_response_mask = as_batch('response_mask')

        if not use_sparse:
            _cur_active = torch.ones_like(batch_param_mask).to(device).float()
        elif use_input_mask:
            _cur_active = batch_param_mask
        else:
            _cur_active = batch_response_mask

        source = batch_response_grid if use_secondary else batch_x
        model_input = source * _cur_active

        features = sparse_encoder(model_input)
        primary_output = subsurface_decoder(features[::-1])

        secondary_output = None
        if run_imputation:
            secondary_output = imputation_decoder(features[::-1]).cpu()

        return (model_input.cpu(), primary_output.cpu(), secondary_output,
                _cur_active.cpu(), batch_x.cpu(), batch_response_grid.cpu())


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def plot_comprehensive_heatmaps(results_df):
    """
    Draw a 2x2 grid of annotated loss heatmaps (sampling percentage x epochs).

    Rows show the primary and secondary variable; columns show training and
    validation loss. ``results_df`` must provide the columns
    ``sampling_percentage``, ``epochs``, ``primary_train_loss``,
    ``primary_val_loss``, ``secondary_train_loss`` and ``secondary_val_loss``.
    ``DataFrame.pivot`` requires at most one row per (sampling_percentage,
    epochs) pair.
    """
    # (grid position, results column, panel title)
    panels = (
        ((0, 0), 'primary_train_loss', 'Primary Variable Training Loss'),
        ((0, 1), 'primary_val_loss', 'Primary Variable Validation Loss'),
        ((1, 0), 'secondary_train_loss', 'Secondary Variable Training Loss'),
        ((1, 1), 'secondary_val_loss', 'Secondary Variable Validation Loss'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(20, 20))

    for position, column, title in panels:
        panel_ax = axes[position]
        table = results_df.pivot(index='sampling_percentage', columns='epochs', values=column)
        sns.heatmap(table, annot=True, fmt=".6f", cmap='viridis', ax=panel_ax,
                    cbar_kws={'label': 'Loss Value'})
        panel_ax.set_title(title)
        panel_ax.set_ylabel('Sampling Percentage')
        panel_ax.set_xlabel('Epochs')

    plt.tight_layout()
    plt.show()

# Modify the training loop to collect separate losses
def evaluate_with_separate_losses(sparse_encoder, subsurface_decoder, dataloader,
                                device="cuda", use_sparse=False, use_input_mask=True,
                                imputation_decoder=None):
    """Compute mean primary and secondary MSE over ``dataloader`` without gradients.

    The encoder input is always the response grid, masked by the response mask
    when ``use_sparse`` is True and by all ones otherwise. The same mask is
    written to the module-level ``_cur_active`` read by the sparse encoder. The
    primary loss compares the subsurface decoder's output against ``x``. The
    secondary loss, computed only when an ``imputation_decoder`` is supplied,
    compares its output against the response grid.

    NOTE: ``use_input_mask`` is accepted for signature parity with the other
    evaluation helpers but is not consulted; the response mask is always used.
    The modules are not moved to ``device``; the caller must do that.

    Returns:
        ``(avg_primary_loss, avg_secondary_loss)``. Each is the mean over
        batches of the per-batch MSE. ``avg_secondary_loss`` is ``0`` when no
        imputation decoder is given.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    global _cur_active

    # Explicit None check: an nn.Module that defines __len__ can be falsy.
    has_imputation = imputation_decoder is not None

    sparse_encoder.eval()
    subsurface_decoder.eval()
    if has_imputation:
        imputation_decoder.eval()

    criterion = nn.MSELoss()  # stateless, safe to reuse for both losses
    total_primary_loss = 0.0
    total_secondary_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            x, _param_grid, _param_mask, response_grid, response_mask = batch
            batch_x = x.to(device).float()
            batch_response_grid = response_grid.to(device).float()
            batch_response_mask = response_mask.to(device).float()

            _cur_active = batch_response_mask if use_sparse else torch.ones_like(batch_response_mask)

            features = sparse_encoder(batch_response_grid * _cur_active)
            primary_output = subsurface_decoder(features[::-1])
            secondary_output = imputation_decoder(features[::-1]) if has_imputation else None

            # Calculate primary loss
            total_primary_loss += criterion(primary_output, batch_x).item()

            # Calculate secondary loss if applicable
            if secondary_output is not None:
                total_secondary_loss += criterion(secondary_output, batch_response_grid).item()

            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_with_separate_losses: dataloader yielded no batches")

    avg_primary_loss = total_primary_loss / num_batches
    avg_secondary_loss = total_secondary_loss / num_batches if has_imputation else 0

    return avg_primary_loss, avg_secondary_loss


In [None]:
import os
import time
import random
import torch
import numpy as np
from datetime import datetime
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score
from skimage.metrics import structural_similarity as ssim

# Ensure that 'size', 'generator', and 'response_grid_fn' are defined
# Grid shape (presumably (height, width)) passed to the generator below and to
# SparseEncoder(input_size=size) in the experiment loop.
size = (32, 64)

def response_grid_fn(x):
    """Identity function passed to SpatialDataset as ``secondary_grid_fn``.

    Returns ``x`` unchanged. NOTE(review): presumably this makes the
    secondary/response grid equal to the generated field — confirm against
    SpatialDataset.
    """
    return x

# Define your generator
# NOTE(review): CategoricalSpatialGenerator and two_layer_generator are defined in
# another cell; the meaning of num_categories / methods is not visible here.
generator = CategoricalSpatialGenerator(
    size,
    lambda: two_layer_generator(),
    num_categories=10,
    methods=['layered', 'vid']
)

# Adjusted drilling_sampling_custom to accept sampling_percentage
def get_drilling_sampling_custom(sampling_percentage):
    def drilling_sampling_custom(size, min_drillholes=5, max_drillholes=15, min_samples=3, max_samples=20):
        mask = torch.zeros(1, *size)
        height, width = size
        total_points = height * width
        target_samples = int((random.uniform(50, 90) / 100.0) * total_points)
        num_samples = max(min_samples, target_samples)

        num_drillholes = sampling_percentage
        samples_per_drillhole = num_samples // num_drillholes if num_drillholes > 0 else num_samples

        for _ in range(num_drillholes):
            x = random.randint(0, width - 1)
            num_samples_drillhole = random.randint(min_samples, max_samples)
            num_samples_drillhole = min(samples_per_drillhole, num_samples_drillhole)
            y_positions = random.sample(range(height), min(num_samples_drillhole, height))
            for y in y_positions:
                mask[0, y, x] = 1


        return mask
    return drilling_sampling_custom

# Experiment grid. Each "sampling percentage" is handed to
# get_drilling_sampling_custom, which uses it as the number of drillholes.
sampling_percentages = [1, 2, 3, 5, 10, 20, 40, 60]
# Cumulative epoch checkpoints at which every run is evaluated.
epochs = [1, 5, 10, 20, 40, 80, 160, 320]
fixed_dataset_size = 1000

# Results accumulator and compute device
results = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The train and test datasets differ only in their data folder. sampling_fn
# starts as None and is assigned per sampling percentage in the experiment loop.
train_dataset, test_dataset = (
    SpatialDataset(
        num_generations=0,
        generator=generator,
        sampling_fn=None,
        secondary_grid_fn=response_grid_fn,
        data_folder=folder,
        dynamic_secondary_mask=True,
        x_channels=1,            # Input channels
        secondary_channels=0,    # Channels for the secondary grid
        primary_channels=1,      # Channels for the primary grid
    )
    for folder in ("drive/MyDrive/data_res/study_case_1",
                   "drive/MyDrive/data_res/study_case_1_test")
)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              num_workers=0, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False,
                             num_workers=0, pin_memory=True)

# State shared with the experiment loop
collected_images = []
use_sparse = True
use_input_mask = False

# Experiment loop. For every entry of `sampling_percentages`:
# - fresh models are built and the encoder / imputation decoder are loaded from
#   `last_model_path`;
# - training then continues through the cumulative `epochs` checkpoints;
# - after each checkpoint, train/val losses are appended to `results` and
#   predictions on fixed validation samples to `collected_images`.
for sampling_percentage in sampling_percentages:
    print(f"\nStarting training with sampling percentage: {sampling_percentage}%")

    # Set the sampling function for the current sampling_percentage
    # (get_drilling_sampling_custom uses this value as the drillhole count).
    sampling_fn = get_drilling_sampling_custom(sampling_percentage)

    # Update the sampling_fn in the datasets
    # (assumes SpatialDataset reads `sampling_fn` at __getitem__ time — TODO confirm)
    train_dataset.sampling_fn = sampling_fn
    test_dataset.sampling_fn = sampling_fn

    # Create fixed validation samples for this sampling percentage
    fixed_validation_samples = create_fixed_validation_sample(test_dataset, num_samples=5)

    # Initialize models
    convnext = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]).to(device)
    sparse_encoder = SparseEncoder(convnext, input_size=size).to(device)
    subsurface_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)
    imputation_decoder = Decoder(up_sample_ratio=32, out_chans=1, width=768, sbn=False).to(device)
    last_model_path = "drive/MyDrive/Models/base_id_1000k.pth"

    # Load the model state
    # None is passed in the second model slot and its return value is discarded,
    # so subsurface_decoder presumably keeps its fresh initialisation.
    # NOTE(review): this rebinds `use_sparse` (set to True before the loop) and
    # `use_secondary` to whatever load_model returns; `use_sparse` then controls
    # the masking in the evaluation calls below — confirm this is intended.
    sparse_encoder, _ , imputation_decoder, use_secondary, use_sparse, training_losses = load_model(
        last_model_path,
        sparse_encoder,
        None,
        imputation_decoder,
        device=device
    )
    # Progressive training loop
    # `epochs` holds cumulative targets; each stage trains only the difference
    # from the previous checkpoint.
    current_epoch = 0

    for target_epoch in epochs:
        epochs_to_train = target_epoch - current_epoch
        if epochs_to_train <= 0:
            continue

        print(f"Training from epoch {current_epoch + 1} to {target_epoch}")

        training_start_time = time.time()

        # Train the model
        # train_model(encoder, sample_decoder, parameter_decoder, ...):
        # - the imputation decoder is fitted to the response grid;
        # - the subsurface decoder is fitted to x.
        # train_model builds a new AdamW optimizer on each call.

        sparse_encoder, imputation_decoder, subsurface_decoder, training_losses = train_model(
            sparse_encoder,
            imputation_decoder,
            subsurface_decoder,
            train_dataloader,
            num_epochs=epochs_to_train,
            device=device
        )

        training_end_time = time.time()
        # Duration of this stage only, not the cumulative time to reach target_epoch.
        training_time = training_end_time - training_start_time
        # NOTE(review): final_training_loss is not used below.
        final_training_loss = training_losses[-1]['total_loss']

        # Evaluate the model
        # (evaluate_with_separate_losses does not consult use_input_mask)
        primary_train_loss, secondary_train_loss = evaluate_with_separate_losses(
            sparse_encoder, subsurface_decoder, train_dataloader,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask,
            imputation_decoder=imputation_decoder
        )

        # Evaluate validation losses
        primary_val_loss, secondary_val_loss = evaluate_with_separate_losses(
            sparse_encoder, subsurface_decoder, test_dataloader,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask,
            imputation_decoder=imputation_decoder
        )

        results.append({
            'sampling_percentage': sampling_percentage,
            'epochs': target_epoch,
            'primary_train_loss': primary_train_loss,
            'primary_val_loss': primary_val_loss,
            'secondary_train_loss': secondary_train_loss,
            'secondary_val_loss': secondary_val_loss,
            'training_time': training_time
        })


        # Process fixed validation samples
        for idx, fixed_sample in enumerate(fixed_validation_samples):
            val_input, val_output, val_imputation_output, val_mask, val_original, val_secondary = process_fixed_validation_sample(
                sparse_encoder, subsurface_decoder, fixed_sample,
                device=device, use_sparse=use_sparse, use_input_mask=use_input_mask,
                use_secondary=True, imputation_decoder=imputation_decoder
            )

            collected_images.append({
                'title': f'Validation {idx+1} - {sampling_percentage}%, Epoch {target_epoch}',
                'input_grid': val_input,
                'output_grid': val_output,
                'imputation_output': val_imputation_output,
                'input_mask': val_mask,
                'original_image': val_original,
                'original_secondary': val_secondary
            })

        current_epoch = target_epoch

        # Training data visualization
        # NOTE(review): the train_*/test_* outputs below are not used afterwards.
        # process_batch also casts the models to float32 in place, and drawing
        # from the shuffled train loader presumably consumes RNG state.
        train_iter = iter(train_dataloader)
        train_batch = next(train_iter)
        train_input, train_output, train_imputation_output, train_mask, train_original = process_batch(
            sparse_encoder, subsurface_decoder, train_batch,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask,
            use_secondary=True, imputation_decoder=imputation_decoder
        )

        # Test data visualization
        test_iter = iter(test_dataloader)
        test_batch = next(test_iter)
        test_input, test_output, test_imputation_output, test_mask, test_original = process_batch(
            sparse_encoder, subsurface_decoder, test_batch,
            device=device, use_sparse=use_sparse, use_input_mask=use_input_mask,
            use_secondary=True, imputation_decoder=imputation_decoder
        )

    # Save the model after training with this sampling percentage
    # (training_losses holds only the epochs of the last training stage).
    torch.save({
        'sparse_encoder_state_dict': sparse_encoder.state_dict(),
        'imputation_decoder_state_dict': imputation_decoder.state_dict(),
        'subsurface_decoder_state_dict': subsurface_decoder.state_dict(),
        'training_losses': training_losses,
    }, f"model_sampling{sampling_percentage}.pth")

# Convert results to DataFrame and create visualizations
# One row per (sampling_percentage, epoch checkpoint).
df_results = pd.DataFrame(results)
print("\nResults Summary:")
print(df_results)



In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def create_loss_visualizations(df_results, save_path=None):
    """
    Create comprehensive visualizations for training results

    Draws a 3x2 figure with six panels:
    1. Primary validation loss vs sampling percentage, one line per epoch count.
    2. Secondary validation loss vs epochs, one line per sampling percentage.
    3. Primary train vs validation loss scatter, coloured by sampling percentage
       (with a colorbar) and with a y = x reference line.
    4. Training-time box plot per sampling percentage.
    5. Primary validation-loss heatmap.
    6. Secondary validation-loss heatmap.

    Side effect: calls ``plt.style.use('default')`` and ``sns.set_theme``,
    which change the global plotting style for the rest of the session.

    Parameters:
    df_results (pd.DataFrame): DataFrame containing training results with columns:
        - sampling_percentage
        - epochs
        - primary_train_loss
        - primary_val_loss
        - secondary_train_loss
        - secondary_val_loss
        - training_time
    save_path (str, optional): If given, the figure is saved there (300 dpi)
        before it is shown.
    """
    # Set the style
    plt.style.use('default')
    sns.set_theme(style="whitegrid")

    # Create main figure for loss plots
    fig, axes = plt.subplots(3, 2, figsize=(20, 24))

    # 1. Primary Loss vs Sampling Percentage for Different Epochs
    # Points are sorted along x so each curve is drawn left-to-right regardless
    # of the row order in df_results.
    for epoch in df_results['epochs'].unique():
        epoch_data = df_results[df_results['epochs'] == epoch].sort_values('sampling_percentage')
        axes[0, 0].plot(epoch_data['sampling_percentage'],
                        epoch_data['primary_val_loss'],
                        label=f'Epochs {epoch}',
                        marker='o',
                        alpha=0.7)

    axes[0, 0].set_xlabel('Sampling Percentage (%)')
    axes[0, 0].set_ylabel('Validation Loss')
    axes[0, 0].set_title('Primary Variable Validation Loss vs Sampling Percentage')
    axes[0, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[0, 0].set_xscale('log')

    # 2. Secondary Loss vs Epochs for Different Sampling Percentages
    for pct in df_results['sampling_percentage'].unique():
        pct_data = df_results[df_results['sampling_percentage'] == pct].sort_values('epochs')
        axes[0, 1].plot(pct_data['epochs'],
                        pct_data['secondary_val_loss'],
                        label=f'{pct}%',
                        marker='o',
                        alpha=0.7)

    axes[0, 1].set_xlabel('Epochs')
    axes[0, 1].set_ylabel('Validation Loss')
    axes[0, 1].set_title('Secondary Variable Validation Loss vs Epochs')
    axes[0, 1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[0, 1].set_xscale('log')

    # 3. Training vs Validation Loss Comparison
    scatter = axes[1, 0].scatter(df_results['primary_train_loss'],
                                 df_results['primary_val_loss'],
                                 alpha=0.6,
                                 c=df_results['sampling_percentage'],
                                 cmap='viridis')

    # y = x reference spanning the union of both axis ranges
    lims = [
        np.min([axes[1, 0].get_xlim(), axes[1, 0].get_ylim()]),
        np.max([axes[1, 0].get_xlim(), axes[1, 0].get_ylim()])
    ]
    axes[1, 0].plot(lims, lims, 'k--', alpha=0.5)
    axes[1, 0].set_xlabel('Training Loss')
    axes[1, 0].set_ylabel('Validation Loss')
    axes[1, 0].set_title('Training vs Validation Loss')
    # The marker colour encodes sampling percentage; a colorbar makes it readable.
    fig.colorbar(scatter, ax=axes[1, 0], label='Sampling Percentage (%)')

    # 4. Training Time Analysis
    sns.boxplot(data=df_results,
                x='sampling_percentage',
                y='training_time',
                ax=axes[1, 1])
    axes[1, 1].set_xlabel('Sampling Percentage (%)')
    axes[1, 1].set_ylabel('Training Time (seconds)')
    axes[1, 1].set_title('Training Time Distribution')

    # 5. Loss Heatmaps
    loss_pivot = df_results.pivot(index='sampling_percentage',
                                  columns='epochs',
                                  values='primary_val_loss')
    sns.heatmap(loss_pivot,
                annot=True,
                fmt='.2e',
                cmap='viridis',
                ax=axes[2, 0],
                cbar_kws={'label': 'Validation Loss'})
    axes[2, 0].set_title('Primary Variable Validation Loss Heatmap')
    axes[2, 0].set_ylabel('Sampling Percentage (%)')
    axes[2, 0].set_xlabel('Epochs')

    # 6. Secondary Loss Heatmap
    secondary_pivot = df_results.pivot(index='sampling_percentage',
                                       columns='epochs',
                                       values='secondary_val_loss')
    sns.heatmap(secondary_pivot,
                annot=True,
                fmt='.2e',
                cmap='viridis',
                ax=axes[2, 1],
                cbar_kws={'label': 'Validation Loss'})
    axes[2, 1].set_title('Secondary Variable Validation Loss Heatmap')
    axes[2, 1].set_ylabel('Sampling Percentage (%)')
    axes[2, 1].set_xlabel('Epochs')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def plot_training_metrics_summary(df_results):
    """
    Plot training efficiency against sampling percentage.

    Within each sampling percentage, the run with the fewest epochs is the
    baseline. ``loss_improvement`` is the fractional drop in
    ``primary_val_loss`` relative to that baseline. ``time_efficiency`` is
    ``loss_improvement / training_time``. Points are sized by ``epochs`` and
    coloured by ``primary_val_loss``, and the x axis is log-scaled.

    Parameters
    ----------
    df_results : pd.DataFrame
        Needs 'sampling_percentage', 'epochs', 'primary_val_loss' and
        'training_time' columns. Assumes a unique index (e.g. the default
        RangeIndex from ``pd.DataFrame(results)``).

    Side effects
    ------------
    Adds (or overwrites) the 'loss_improvement' and 'time_efficiency' columns
    on ``df_results`` in place. Recomputing them is idempotent.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Baseline = lowest-epoch run per sampling percentage. Sorting first makes
    # this independent of the frame's row order; a bare x.iloc[0] silently
    # used whichever run happened to come first. The reindex restores the
    # original row order so the arithmetic below aligns row-by-row.
    baseline_loss = (
        df_results.sort_values('epochs', kind='stable')
        .groupby('sampling_percentage')['primary_val_loss']
        .transform(lambda losses: losses.iloc[0])
        .reindex(df_results.index)
    )

    improvement = (baseline_loss - df_results['primary_val_loss']) / baseline_loss
    df_results['loss_improvement'] = improvement
    # A zero baseline loss or zero training time yields +/-inf; blank those
    # points (NaN) instead of letting them distort the y axis.
    df_results['time_efficiency'] = (
        (improvement / df_results['training_time']).replace([np.inf, -np.inf], np.nan)
    )

    sns.scatterplot(data=df_results,
                    x='sampling_percentage',
                    y='time_efficiency',
                    size='epochs',
                    hue='primary_val_loss',
                    palette='viridis',
                    ax=ax)

    ax.set_xscale('log')
    ax.set(xlabel='Sampling Percentage (%)',
           ylabel='Efficiency (Loss Improvement / Training Time)',
           title='Training Efficiency Analysis')

    fig.tight_layout()
    plt.show()

def print_summary_statistics(df_results):
    """
    Print the best and the most "efficient" training configuration.

    "Best" is the row with the lowest ``primary_val_loss``. "Most efficient"
    is the row with the lowest ``primary_val_loss * training_time``, a
    loss-time product where lower is better.

    Parameters
    ----------
    df_results : pd.DataFrame
        Needs 'sampling_percentage', 'epochs', 'primary_val_loss' and
        'training_time' columns.

    Side effects
    ------------
    Adds (or overwrites) an 'efficiency' column on ``df_results`` in place.
    Recomputing it is idempotent. Prints a short notice and returns early
    when there are no usable values to rank.
    """
    def _print_config(header, idx):
        # .at keeps each column's own dtype. .loc[idx] would upcast an
        # all-numeric row to float and print e.g. "5.0%" / "Epochs: 20.0".
        print(header)
        print(f"Sampling Percentage: {df_results.at[idx, 'sampling_percentage']}%")
        print(f"Epochs: {df_results.at[idx, 'epochs']}")
        print(f"Validation Loss: {df_results.at[idx, 'primary_val_loss']:.6f}")
        print(f"Training Time: {df_results.at[idx, 'training_time']:.2f} seconds")

    print("\nSummary Statistics:")
    print("-" * 50)

    val_loss = df_results['primary_val_loss']
    # idxmin raises on an empty / all-NaN column, so bail out with a message.
    if val_loss.dropna().empty:
        print("No validation losses available.")
        return

    _print_config("Best Configuration:", val_loss.idxmin())

    print("\nEfficiency Analysis:")
    df_results['efficiency'] = val_loss * df_results['training_time']
    efficiency = df_results['efficiency']
    if efficiency.dropna().empty:
        print("No training times available.")
        return

    _print_config("Most Efficient Configuration:", efficiency.idxmin())

In [None]:
# Assemble the per-run result records into one DataFrame, then draw the heatmap overview.
df_results = pd.DataFrame(data=results)
plot_comprehensive_heatmaps(df_results)

In [None]:
# Loss overview figure (also written to training_results.png),
# followed by the efficiency scatter plot and the printed text summary.
create_loss_visualizations(df_results, save_path='training_results.png')
plot_training_metrics_summary(df_results=df_results)
print_summary_statistics(df_results=df_results)

In [None]:
def create_loss_data_arrays(df_results, epochs, sampling_percentages, fill_value=np.nan):
    """
    Build 2-D loss grids (rows = sampling percentages, columns = epochs) for heatmaps.

    Parameters:
    df_results (pd.DataFrame): one row per run, with 'sampling_percentage',
        'epochs', 'primary_val_loss', 'primary_train_loss',
        'secondary_val_loss' and 'secondary_train_loss' columns
    epochs (list): epoch values, in column order
    sampling_percentages (list): sampling-percentage values, in row order
    fill_value (float): value for (sampling percentage, epoch) pairs with no
        matching run. Defaults to NaN so a missing run is not mistaken for a
        zero loss. Pass 0.0 for the old behaviour.

    If several runs share a pair, the first one in frame order is used.
    Each grid is also printed as ``name = [...]``.

    Returns:
    tuple: (primary_test, primary_train, secondary_test, secondary_train)
        float numpy arrays of shape (len(sampling_percentages), len(epochs)).
        The "test" grids hold the *validation* losses.
    """
    # Output name -> source column; insertion order matches the returned tuple.
    source_columns = {
        'primary_test_loss_data': 'primary_val_loss',
        'primary_train_loss_data': 'primary_train_loss',
        'secondary_test_loss_data': 'secondary_val_loss',
        'secondary_train_loss_data': 'secondary_train_loss',
    }
    shape = (len(sampling_percentages), len(epochs))
    grids = {name: np.full(shape, fill_value, dtype=float) for name in source_columns}

    for i, samp in enumerate(sampling_percentages):
        for j, ep in enumerate(epochs):
            matches = df_results[(df_results['sampling_percentage'] == samp)
                                 & (df_results['epochs'] == ep)]
            if matches.empty:
                continue
            row = matches.iloc[0]
            for name, column in source_columns.items():
                grids[name][i, j] = row[column]

    def format_array(arr):
        # Scientific notation keeps significant digits for small losses that
        # '%.6f' would round to 0.000000. The ', ' separator puts one grid row
        # per line (',\n ' broke after every element), and the wide line
        # width keeps each row unwrapped.
        return np.array2string(arr,
                               formatter={'float_kind': lambda x: "%.6e" % x},
                               separator=', ',
                               max_line_width=10000)

    for k, (name, grid) in enumerate(grids.items()):
        print(("\n" if k else "") + f"{name} = ", format_array(grid))

    return tuple(grids.values())

# Grid axes for the printed loss arrays: rows are sampling percentages, columns are epochs.
epochs = [1, 5, 10, 20, 40, 80, 160, 320]
sampling_percentages = [1, 2, 3, 5, 10, 20, 40, 60]

loss_arrays = create_loss_data_arrays(df_results, epochs, sampling_percentages)

# Give each of the four returned grids its own name.
(primary_test_loss_data,
 primary_train_loss_data,
 secondary_test_loss_data,
 secondary_train_loss_data) = loss_arrays

In [None]:
# NOTE(review): this cell calls `plot_collected_images` before the cell that defines it
# below; on a fresh kernel (Restart & Run All) it raises NameError unless an earlier
# cell outside this view defined it — consider moving this call after the definition.
plot_collected_images(collected_images, images_per_row=5)

In [None]:
def _parse_sampling_percentage(title):
    """Return the integer just before the first '%' in ``title``, or None if absent/unparseable."""
    if '%' not in title:
        return None
    tokens = title.split('%')[0].split()
    if not tokens:
        return None
    try:
        return int(tokens[-1])
    except ValueError:
        # e.g. "(5%)" or "2.5%": skip rather than abort the whole plot.
        return None


def _first_item_to_numpy(batch):
    """Return ``batch[0]`` as a squeezed numpy array, detaching tensors and moving them to CPU first."""
    item = batch[0]
    if hasattr(item, 'detach'):
        # Tensor-like (presumably torch): drop any autograd graph so .numpy()-style
        # conversion cannot fail on grad-tracking tensors.
        item = item.detach().cpu()
    return np.asarray(item).squeeze()


def plot_collected_images(collected_images, images_per_row=2, max_images=None):
    """
    Plot collected images in a grid, one group of panels per collected item.

    Each item contributes, left to right: the original secondary input, the
    original image (column present only if any item has 'original_image'),
    the subsurface prediction, and the imputation output with the observed
    (masked) points overlaid (column present only if any item has
    'imputation_output').

    Args:
        collected_images: List of dicts. Each dict is read for 'title',
            'original_secondary', 'output_grid', 'input_mask', and optionally
            'original_image' / 'imputation_output'. Array entries are indexed
            with [0] (presumably a batch dimension — confirm upstream).
        images_per_row: Number of items per row of panels (values < 1 act as 1).
        max_images: Maximum number of items to display (None for all).

    Returns None. Items missing 'original_secondary' are skipped with a
    warning, and per-item errors are drawn as red text on that item's panel.
    """
    # Diagnostics over the full collection (before any max_images trimming).
    print(f"Total images in collection: {len(collected_images)}")
    print("\nAvailable sampling percentages:")
    percentages = {
        pct
        for pct in (_parse_sampling_percentage(data['title']) for data in collected_images)
        if pct is not None
    }
    print(sorted(percentages))

    if max_images is not None:
        collected_images = collected_images[:max_images]

    num_images = len(collected_images)
    print(f"\nProcessing {num_images} images")
    if num_images == 0:
        # Nothing to draw; the grid arithmetic below would divide by zero.
        print("No images to plot.")
        return

    has_originals = any('original_image' in data for data in collected_images)
    has_imputation = any('imputation_output' in data for data in collected_images)
    # Input + prediction, plus one optional column each for original and imputation.
    cols_per_item = 2 + int(has_originals) + int(has_imputation)

    images_per_row = max(1, min(images_per_row, num_images))
    num_rows = (num_images + images_per_row - 1) // images_per_row

    # squeeze=False always yields a 2-D array of Axes, so axes[row, col]
    # works for single-row and single-column layouts alike.
    fig, axes = plt.subplots(num_rows, images_per_row * cols_per_item,
                             figsize=(5 * images_per_row * cols_per_item, 5 * num_rows),
                             squeeze=False)

    print("\nProcessing individual images:")
    for idx, data in enumerate(collected_images):
        print(f"Processing image {idx + 1}/{num_images}: {data['title']}")

        row_idx = idx // images_per_row
        col_start = (idx % images_per_row) * cols_per_item
        # Set before the try so the error handler always targets this item's
        # panels (it was otherwise undefined or left over from the previous item).
        col_idx = col_start

        try:
            if 'original_secondary' not in data:
                print(f"Warning: Missing original_secondary for image {idx}")
                continue

            input_grid = _first_item_to_numpy(data['original_secondary'])
            output_grid = _first_item_to_numpy(data['output_grid'])
            input_mask = _first_item_to_numpy(data['input_mask'])
            title = data['title']

            # Original input
            ax_input = axes[row_idx, col_idx]
            im_input = ax_input.imshow(input_grid, cmap='viridis', interpolation='nearest')
            ax_input.set_title(f"{title}\n(Original Input)")
            plt.colorbar(im_input, ax=ax_input)
            ax_input.axis('off')
            col_idx += 1

            # Original image (placeholder text if this item has none)
            if has_originals:
                ax_orig = axes[row_idx, col_idx]
                if 'original_image' in data:
                    original = _first_item_to_numpy(data['original_image'])
                    im_orig = ax_orig.imshow(original, cmap='viridis', interpolation='nearest')
                    ax_orig.set_title(f"{title}\n(Original)")
                    plt.colorbar(im_orig, ax=ax_orig)
                else:
                    ax_orig.text(0.5, 0.5, 'No original image',
                                 ha='center', va='center', transform=ax_orig.transAxes)
                ax_orig.axis('off')
                col_idx += 1

            # Subsurface prediction
            ax = axes[row_idx, col_idx]
            im = ax.imshow(output_grid, cmap='viridis', interpolation='nearest')
            ax.set_title(f"{title}\n(Subsurface Prediction)")
            ax.axis('off')
            plt.colorbar(im, ax=ax)
            col_idx += 1

            # Imputation output with the masked input points on top
            if has_imputation:
                ax_imp = axes[row_idx, col_idx]
                if 'imputation_output' in data:
                    imputation_output = _first_item_to_numpy(data['imputation_output'])
                    im_imp = ax_imp.imshow(imputation_output, cmap='viridis',
                                           interpolation='nearest')

                    mask_indices = np.argwhere(input_mask > 0)
                    if len(mask_indices) > 0:
                        mask_values = input_grid[tuple(mask_indices.T)]
                        # argwhere gives (row, col); scatter wants (x=col, y=row).
                        ax_imp.scatter(mask_indices[:, 1],
                                       mask_indices[:, 0],
                                       c=mask_values,
                                       cmap='viridis',
                                       edgecolors='black',
                                       s=50)

                    ax_imp.set_title(f"{title}\n(Imputation Output)")
                    plt.colorbar(im_imp, ax=ax_imp)
                else:
                    ax_imp.text(0.5, 0.5, 'No imputation output',
                                ha='center', va='center', transform=ax_imp.transAxes)
                ax_imp.axis('off')

        except Exception as e:
            # Best effort: report the failure on this item's current panel and keep going.
            print(f"Error processing image {idx}: {str(e)}")
            current_ax = axes[row_idx, col_idx]
            current_ax.text(0.5, 0.5, f'Error processing image {idx}\n{str(e)}',
                            ha='center', va='center', transform=current_ax.transAxes, color='red')
            current_ax.axis('off')

    # Remove the panels of unused slots in the last row.
    for slot in range(num_images, num_rows * images_per_row):
        slot_row, slot_col = divmod(slot, images_per_row)
        for offset in range(cols_per_item):
            fig.delaxes(axes[slot_row, slot_col * cols_per_item + offset])

    plt.tight_layout()
    plt.show()

In [None]:
# Display every collected item, five per row of panels.
plot_collected_images(collected_images=collected_images, images_per_row=5)

In [None]:
# NOTE(review): identical to the preceding display cell — likely a leftover; consider removing.
plot_collected_images(collected_images=collected_images, images_per_row=5)
