In [5]:
from src.causal_effect_estimation import causal_effect_estimation
import numpy as np
import torch

In [6]:
%load_ext autoreload
%autoreload 2

In [7]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [10]:
t = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]])
X = torch.tensor([[2.5, 25.0], [3.5, 35.0]], requires_grad=True)

In [11]:
def interpolated_percentile_full_vectorized(t, X):
    # Ensure t and X have the same number of columns
    assert t.shape[1] == X.shape[1], "Both matrices must have the same number of columns."
    
    sorted_t, _ = torch.sort(t, dim=0)
    
    # Expand dimensions for broadcasting
    sorted_t_exp = sorted_t.unsqueeze(0)
    X_exp = X.unsqueeze(1)

    # Find the positions where elements of X would be inserted into t
    pos = (sorted_t_exp < X_exp).sum(dim=1)
    
    # Handle values outside the range
    outside_range = (pos == 0) | (pos == t.shape[0])
    percentiles_outside = (pos.float() / t.shape[0]) * 100
    
    # Handle values inside the range using linear interpolation
    indices_i = torch.clamp(pos - 1, 0)
    indices_i1 = torch.clamp(pos, 0, t.shape[0]-1)
    
    v_i = torch.gather(sorted_t, 0, indices_i)
    v_i1 = torch.gather(sorted_t, 0, indices_i1)
    
    interpolated_pos = pos - 1 + (X - v_i) / (v_i1 - v_i)
    percentiles_inside = (interpolated_pos.float() / t.shape[0]) * 100
    
    # Combine results
    final_percentiles = torch.where(outside_range, percentiles_outside, percentiles_inside)

    return final_percentiles

# Test the function
#X.grad.zero_()  # Clear previous gradients
percentile_matrix_fully_vectorized = interpolated_percentile_full_vectorized(t, X)
torch.sum(percentile_matrix_fully_vectorized).backward()

X_grad_matrix_fully_vectorized = X.grad
percentile_matrix_fully_vectorized, X_grad_matrix_fully_vectorized


(tensor([[30.0000, 30.0000],
         [50.0000, 50.0000]], grad_fn=<WhereBackward0>),
 tensor([[20.,  2.],
         [20.,  2.]]))

In [14]:
%lprun -f interpolated_percentile_full_vectorized interpolated_percentile_full_vectorized(t, X)

Timer unit: 1e-09 s

Total time: 0.000404856 s
File: /tmp/ipykernel_70640/2057498416.py
Function: interpolated_percentile_full_vectorized at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def interpolated_percentile_full_vectorized(t, X):
     2                                               # Ensure t and X have the same number of columns
     3         1      20560.0  20560.0      5.1      assert t.shape[1] == X.shape[1], "Both matrices must have the same number of columns."
     4                                               
     5         1     132262.0 132262.0     32.7      sorted_t, _ = torch.sort(t, dim=0)
     6                                               
     7                                               # Expand dimensions for broadcasting
     8         1      17712.0  17712.0      4.4      sorted_t_exp = sorted_t.unsqueeze(0)
     9         1      13223.0  13223.0      3.3      X_exp = X.unsquee