In [1]:
import torch
from torch.cuda.amp import autocast
import torch_scatter

import MGBlurr.blurring as blur
from MGTomo import gridop
from MGTomo.gridop import RBox as R, PBox as P
from skimage import data
from skimage.transform import resize

In [2]:
def orthant_bounds(xh, xH, P_inf, lh, P_nonzero = None):
    if P_nonzero is None:
        coarse_dim = xH.shape[0]
        P_nonzero = gridop.compute_nonzero_elements_of_P(coarse_dim)
    
    lH = torch.zeros_like(xH)

    for col_coord, indices in P_nonzero.items():
        rows, cols = zip(*indices)
        
        rows = torch.tensor(rows)
        cols = torch.tensor(cols)
        
        diffs = xh[rows, cols]
        lmax = torch.max(lh[rows, cols] - diffs)
        
        lH[col_coord] = xH[col_coord] + lmax / P_inf
    return lH

In [11]:
def orthant_bounds_optimized(xh, xH, P_inf, lh, P_nonzero=None):
    coarse_dim = xH.shape[0]

    if P_nonzero is None:
        P_nonzero = gridop.compute_nonzero_elements_of_P(coarse_dim)
    
    lH = torch.zeros_like(xH)

    all_rows = []
    all_cols = []
    col_coords_flat = []
    col_coords = []

    for (x,y), indices in P_nonzero.items():
        rows, cols = zip(*indices)
        all_rows.extend(rows)
        all_cols.extend(cols)
        col_coords_flat.extend([x*coarse_dim + y]*len(rows))
        col_coords.append((x,y))

    all_rows_tensor = torch.tensor(all_rows)
    all_cols_tensor = torch.tensor(all_cols)
    all_col_coords = torch.tensor(col_coords_flat)

    rowsH_tensor, colsH_tensor = torch.tensor(col_coords).unbind(dim=1)

    diffs = xh[all_rows_tensor, all_cols_tensor]
  
    lmax = torch_scatter.scatter_max(lh[all_rows_tensor, all_cols_tensor] - diffs, all_col_coords, dim = 0)[0]
    lH[rowsH_tensor, colsH_tensor] = xH[rowsH_tensor, colsH_tensor] + lmax / P_inf

    return lH


In [4]:
def box_bounds_optimized(xh, xH, P_inf, lh, uh, P_nonzero=None):
    if P_nonzero is None:
        coarse_dim = xH.shape[0]
        P_nonzero = gridop.compute_nonzero_elements_of_P(coarse_dim)
    
    lH = torch.zeros_like(xH)
    uH = torch.zeros_like(xH)
    coarse_dim = xH.shape[0]

    # Collect all the rows and columns indices in one go
    all_rows = []
    all_cols = []
    col_coords_flat = []
    col_coords = []

    for (x,y), indices in P_nonzero.items():
        rows, cols = zip(*indices)
        all_rows.extend(rows)
        all_cols.extend(cols)
        col_coords_flat.extend([x*coarse_dim + y]*len(rows))
        col_coords.append((x,y))

    # Convert to tensors once
    all_rows_tensor = torch.tensor(all_rows)
    all_cols_tensor = torch.tensor(all_cols)
    all_col_coords = torch.tensor(col_coords_flat)

    rowsH, colsH = zip(*col_coords)
    rowsH_tensor = torch.tensor(rowsH)
    colsH_tensor = torch.tensor(colsH)

    # Calculate diffs in one go
    diffs = xh[all_rows_tensor, all_cols_tensor]
  
    lmax = torch_scatter.scatter_max(lh[all_rows_tensor, all_cols_tensor] - diffs, all_col_coords, dim = 0)[0]
    umin = torch_scatter.scatter_min(uh[all_rows_tensor, all_cols_tensor] - diffs, all_col_coords, dim = 0)[0]
    lH[rowsH_tensor, colsH_tensor] = xH[rowsH_tensor, colsH_tensor] + lmax / P_inf
    uH[rowsH_tensor, colsH_tensor] = xH[rowsH_tensor, colsH_tensor] + umin / P_inf

    return lH, uH

In [5]:
N = 1023
max_levels = 1
maxIter = [1,2,16,32,64,128]
kernel_size = 33
sigma = 10

# load image
x_orig = data.camera()
x_orig = resize(x_orig, (N,N), anti_aliasing = False)

x_torch = torch.tensor(x_orig, requires_grad = True)

A = [blur.GaussianBlurOperator(N, kernel_size, sigma)]
b = [torch.poisson(A[0](x_torch)*50)/50]
P_nonzero = []

fine_dim = N
for i in range(1, max_levels+1):
    coarse_dim = blur.reduce_dim(fine_dim)
    A.append(blur.GaussianBlurOperator(coarse_dim, kernel_size, sigma))
    rhs = resize(b[-1].detach().numpy(), (coarse_dim, coarse_dim), anti_aliasing=False)
    b.append(torch.tensor(rhs, requires_grad=True)) #maybe use a different way to define bH
    P_nonzero.append(gridop.compute_nonzero_elements_of_P(coarse_dim))
    fine_dim = coarse_dim

In [6]:
xh = torch.rand(1023,1023)
xH = gridop.R(xh)


lh = torch.zeros_like(xh)
uh = torch.ones_like(xh)

P_inf = 1

In [7]:
orthant_bounds(xh, xH, 1, lh, P_nonzero[0])

tensor([[10.1860,  8.9615,  7.3578,  ..., 10.7755,  8.3521,  8.8985],
        [ 9.9855,  8.3821,  9.0785,  ...,  9.4574,  5.6935,  5.3180],
        [ 5.7296,  7.3632,  9.4823,  ...,  5.2250,  3.7747,  7.3206],
        ...,
        [ 8.4354,  8.1098,  9.9939,  ...,  7.0442,  6.2207,  9.9500],
        [ 8.3429,  9.8839, 10.8665,  ...,  9.8678, 10.8622,  9.0042],
        [ 7.1527, 10.2624,  7.0788,  ...,  7.8803,  8.1784,  5.1626]])

In [12]:
orthant_bounds_optimized(xh, xH, 1, lh, P_nonzero[0])

261121


tensor([[10.1860,  8.9615,  7.3578,  ..., 10.7755,  8.3521,  8.8985],
        [ 9.9855,  8.3821,  9.0785,  ...,  9.4574,  5.6935,  5.3180],
        [ 5.7296,  7.3632,  9.4823,  ...,  5.2250,  3.7747,  7.3206],
        ...,
        [ 8.4354,  8.1098,  9.9939,  ...,  7.0442,  6.2207,  9.9500],
        [ 8.3429,  9.8839, 10.8665,  ...,  9.8678, 10.8622,  9.0042],
        [ 7.1527, 10.2624,  7.0788,  ...,  7.8803,  8.1784,  5.1626]])

In [36]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [37]:
device

device(type='cuda')

In [38]:
xh = torch.rand(1023,1023, device=device)
xH = gridop.R(xh)


lh = torch.zeros_like(xh, device=device)
uh = torch.ones_like(xh, device=device)

P_inf = 1

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [None]:
import torch
import torch_scatter  # Ensure you have torch_scatter installed

def orthant_bounds_optimized(xh, xH, P_inf, lh, P_nonzero=None):
    if P_nonzero is None:
        coarse_dim = xH.shape[0]
        P_nonzero = gridop.compute_nonzero_elements_of_P(coarse_dim)
    
    lH = torch.zeros_like(xH)

    # Flatten indices from P_nonzero
    all_rows = []
    all_cols = []
    col_coord_list = []

    for col_coord, indices in P_nonzero.items():
        rows, cols = zip(*indices)
        all_rows.append(torch.tensor(rows, device=xh.device))  # Ensure device consistency
        all_cols.append(torch.tensor(cols, device=xh.device))  # Ensure device consistency
        col_coord_list.append(col_coord)

    all_rows = torch.cat(all_rows)
    all_cols = torch.cat(all_cols)

    # Compute diffs in one go
    diffs = xh[all_rows, all_cols]
    lh_selected = lh[all_rows, all_cols]

    # Calculate lmax across all selected rows and columns
    lmax_per_col = torch_scatter.scatter_max(lh_selected - diffs, torch.tensor(col_coord_list, device=xh.device), dim=0)[0]

    # Initialize lmax_per_col with -inf to handle any cases where col_coord might be missing
    final_lmax_per_col = torch.full((xH.shape[0],), float('-inf'), device=xh.device)  # Change here

    # Place the lmax_per_col values back into final_lmax_per_col
    final_lmax_per_col.scatter_(0, torch.tensor(col_coord_list, device=xh.device), lmax_per_col)

    # Update lH in one go
    lH[torch.tensor(col_coord_list, device=xh.device)] = xH[torch.tensor(col_coord_list, device=xh.device)] + final_lmax_per_col / P_inf
    
    return lH


In [None]:
orthant_bounds(xh, xH, 1, lh, P_nonzero[0])