In this notebook, we run tests for the function "compute_linear_velocity_batch_time_arb_var", which is important for many flow-matching applications hereafter.

# Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import time

import torch
import torch.nn.functional as F
from torch import nn, Tensor
# We won't import the new function directly, but instead copy a version here for use.

# Function To Test

In [2]:
def compute_linear_velocity_batch_time_arb_var(
    current_points: torch.Tensor,  # Shape [M, *dims]
    data: torch.Tensor,            # Shape [N, *dims]
    t: torch.Tensor,               # Shape [M]
    sigma_i: float,
    sigma_f: torch.Tensor,         # Shape [N]
    coefficients: torch.Tensor,    # Shape [N]
    return_intermediates: bool = False  # Optional debugging
) -> torch.Tensor | tuple[torch.Tensor, dict]:
    """
    Computes velocity for batched inputs with time as tensor.

    Args:
        current_points: [M, *dims] positions
        data: [N, *dims] target points
        t: [M] batch of time values
        sigma_i: float
        sigma_f: [N] final std devs
        coefficients: [N] mixture weights
        return_intermediates: if True, also return dictionary of intermediates

    Returns:
        velocities: [M, *dims]
        intermediates: dict of named tensors (if return_intermediates=True)
    """
    intermediates = {}

    t_reshaped = t.view(-1, *([1]*(data.dim())))  # [M, 1, *dims]
    t_reshaped_2 = t.unsqueeze(-1)                # [M, 1]
    sigma_f_reshaped = sigma_f.unsqueeze(0)       # [1, N]
    sigma_f_reshaped_2 = sigma_f.view(1, -1, *[1]*(data.dim() - 1))  # [1, N, *dims]
    coefficients_reshaped = coefficients.unsqueeze(0)  # [1, N]

    data_exp = data.unsqueeze(0)                    # [1, N, *dims]
    data_scaled = t_reshaped * data_exp             # [M, N, *dims]
    current_expanded = current_points.unsqueeze(1)  # [M, 1, *dims]
    
    diff = (current_expanded - data_scaled)         # [M, N, *dims]
    squared_dist = torch.sum(diff**2, dim=tuple(range(2, diff.dim())))  # [M, N]

    dims = current_points.shape[1:]  # This gets *dims
    net_dim_coeff = torch.prod(torch.tensor(dims)).item()*0.5
    
    denominator = (1 - t_reshaped_2)**2 * sigma_i + t_reshaped_2**2 * sigma_f_reshaped  # [M, N]
    logits = -0.5 * squared_dist / (denominator)
    logit_interm = -0.5 * squared_dist / (denominator)
    coefficient_term = torch.log((1 - t_reshaped_2) * coefficients_reshaped) - net_dim_coeff * torch.log(denominator)
    logits += torch.log((1 - t_reshaped_2) * coefficients_reshaped) - net_dim_coeff * torch.log(denominator)

    weights = torch.softmax(logits, dim=1)  # [M, N]

    denominator_2 = (1 - t_reshaped)**2 * sigma_i + t_reshaped**2 * sigma_f_reshaped_2  # [M, N, *dims]
    x_num = t_reshaped * sigma_f_reshaped_2 - (1 - t_reshaped) * sigma_i                # [M, N, *dims]
    data_num = (1 - t_reshaped) * sigma_i                                               # [M, 1, *dims]
    net_weight_vec = (current_expanded * x_num + data_num * data_exp) / denominator_2  # [M, N, *dims]

    velocities = torch.sum(weights.unsqueeze(-1) * net_weight_vec, dim=1)  # [M, *dims]

    if return_intermediates:
        intermediates = {
            "t_reshaped": t_reshaped,
            "t_reshaped_2": t_reshaped_2,
            "sigma_f_reshaped": sigma_f_reshaped,
            "sigma_f_reshaped_2": sigma_f_reshaped_2,
            "coefficients_reshaped": coefficients_reshaped,
            "data_exp": data_exp,
            "data_scaled": data_scaled,
            "current_expanded": current_expanded,
            "diff": diff,
            "squared_dist": squared_dist,
            "denominator": denominator,
            "logits": logits,
            "weights": weights,
            "denominator_2": denominator_2,
            "x_num": x_num,
            "data_num": data_num,
            "net_weight_vec": net_weight_vec,
            "velocities": velocities,
            "logit_interm": logit_interm,
            "coefficient_term": coefficient_term
        }
        return velocities, intermediates

    return velocities


# Test Case

In [3]:
# Define test case: M = 3 query points, N = 2 data points, D = 2 dimensions
current_points = torch.tensor([
    [0.0, 0.0],
    [1.0, 0.0],
    [2.0, 0.0],
])                                       # [M, D]
data = torch.tensor([
    [0.0, 1.0],
    [1.0, 1.0],
])                                       # [N, D]
t = torch.tensor([0.1, 0.5, 0.9])        # [M], one time value per query point
sigma_i = 1.0
sigma_f = torch.tensor([1.0, 2.0])       # [N]
coefficients = torch.tensor([0.5, 1.0])  # [N]

In [4]:
# Evaluate the function once on the test case, keeping every intermediate
# tensor so the sections below can check them one by one.
velocities, intermediates = compute_linear_velocity_batch_time_arb_var(
    current_points=current_points,
    data=data,
    t=t,
    sigma_i=sigma_i,
    sigma_f=sigma_f,
    coefficients=coefficients,
    return_intermediates=True,
)

## Squared Dist Check

For each pair (current_point, data_point), scale the data point by t, subtract it from the current point, and sum the squared components of the difference.

    current_point = [0,0], t=0.1

        With data_point [0,1]: [0,0] - 0.1*[0,1] = [0,-0.1] → squared_dist = 0 + 0.01 = 0.01

        With data_point [1,1]: [0,0] - 0.1*[1,1] = [-0.1,-0.1] → squared_dist = 0.01 + 0.01 = 0.02

    current_point = [1,0], t=0.5

        With data_point [0,1]: [1,0] - 0.5*[0,1] = [1,-0.5] → squared_dist = 1 + 0.25 = 1.25

        With data_point [1,1]: [1,0] - 0.5*[1,1] = [0.5,-0.5] → squared_dist = 0.25 + 0.25 = 0.5

    current_point = [2,0], t=0.9

        With data_point [0,1]: [2,0] - 0.9*[0,1] = [2,-0.9] → squared_dist = 4 + 0.81 = 4.81

        With data_point [1,1]: [2,0] - 0.9*[1,1] = [1.1,-0.9] → squared_dist = 1.21 + 0.81 = 2.02


In [5]:
# Check squared distances against the hand-computed values above
expected_squared_dist = torch.tensor([[0.01, 0.02],
                                      [1.25, 0.50],
                                      [4.81, 2.02]])
print("Computed squared_dist:")
print(intermediates["squared_dist"])
print("\nExpected squared_dist:")
print(expected_squared_dist)

# Fail loudly on a regression instead of relying on a visual comparison.
torch.testing.assert_close(
    intermediates["squared_dist"], expected_squared_dist, rtol=0.0, atol=1e-4
)

Computed squared_dist:
tensor([[0.0100, 0.0200],
        [1.2500, 0.5000],
        [4.8100, 2.0200]])

Expected squared_dist:
tensor([[0.0100, 0.0200],
        [1.2500, 0.5000],
        [4.8100, 2.0200]])


## Denominator Check

For current_point [0,0] (t=0.1):

    denominator = (0.9²)*1 + (0.1²)*σ_f = [0.81 + 0.01*1, 0.81 + 0.01*2] = [0.82, 0.83]

For current_point [1,0] (t=0.5):

    denominator = (0.5²)*1 + (0.5²)*σ_f = [0.25 + 0.25*1, 0.25 + 0.25*2] = [0.5, 0.75]

For current_point [2,0] (t=0.9):

    denominator = (0.1²)*1 + (0.9²)*σ_f = [0.01 + 0.81*1, 0.01 + 0.81*2] = [0.82, 1.63]

In [6]:
# Print computed values
print("Computed denominator:")
print(intermediates["denominator"])
# Manual calculations: (1 - t)^2 * sigma_i + t^2 * sigma_f for each (t, sigma_f) pair
manual_denominator = torch.tensor([
    [(1-0.1)**2 * 1 + (0.1)**2 * 1, (1-0.1)**2 * 1 + (0.1)**2 * 2],
    [(1-0.5)**2 * 1 + (0.5)**2 * 1, (1-0.5)**2 * 1 + (0.5)**2 * 2],
    [(1-0.9)**2 * 1 + (0.9)**2 * 1, (1-0.9)**2 * 1 + (0.9)**2 * 2]
])

print("\nManual denominator:")
print(manual_denominator)

# Fail loudly on a regression instead of relying on a visual comparison.
torch.testing.assert_close(
    intermediates["denominator"], manual_denominator, rtol=0.0, atol=1e-4
)

Computed denominator:
tensor([[0.8200, 0.8300],
        [0.5000, 0.7500],
        [0.8200, 1.6300]])

Manual denominator:
tensor([[0.8200, 0.8300],
        [0.5000, 0.7500],
        [0.8200, 1.6300]])


## Logit Check

There are [M, N] logits. For each current point, there are two squared distances with the 2 target points. We first focus on

    logit = -0.5 * squared_dist / (denominator)

For current_point [0,0] (t=0.1):

    logit = -0.5 * [0.0100, 0.0200]/[0.82, 0.83] = [-0.006098, -0.012048]

For current_point [1,0] (t=0.5):

    logit = -0.5 * [1.25, 0.5]/[0.5, 0.75] = [-1.25, -0.333333]

For current_point [2,0] (t=0.9):

    logit = -0.5 * [4.8100, 2.0200]/[0.82, 1.63] = [-2.932927, -0.619632]

In [7]:
# Print computed values
print("Computed logit interm:")
print(intermediates["logit_interm"])
# Manual calculations: -0.5 * squared_dist / denominator, rounded to six decimals
manual_logit_interm = torch.tensor([
    [-0.006098, -0.012048],
    [-1.25, -0.333333],
    [-2.932927, -0.619632]
])

print("\nManual logit interm:")
print(manual_logit_interm)

# Fail loudly on a regression instead of relying on a visual comparison.
torch.testing.assert_close(
    intermediates["logit_interm"], manual_logit_interm, rtol=0.0, atol=1e-4
)

Computed logit interm:
tensor([[-0.0061, -0.0120],
        [-1.2500, -0.3333],
        [-2.9329, -0.6196]])

Manual logit interm:
tensor([[-0.0061, -0.0120],
        [-1.2500, -0.3333],
        [-2.9329, -0.6196]])


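We have the correct distance-based logit terms. Next, the coefficient terms. For this 2D problem `net_dim_coeff = D/2 = 1`, so the prefactor of the log-denominator drops out and the expression reduces to: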

   `coefficient_term = torch.log((1 - t_reshaped_2) * coefficients_reshaped) - torch.log(denominator)`

### 1. For current_point [0,0] (t=0.1):

```
(1-t)*coefficients = 0.9 * [0.5, 1.0] = [0.45, 0.9]
log([0.45, 0.9]) ≈ [-0.798508, -0.105361]
-log(denominator) = -log([0.82, 0.83]) ≈ [0.198446, 0.186372]
```


### 2. For current_point [1,0] (t=0.5):

```
(1-t)*coefficients = 0.5 * [0.5, 1.0] = [0.25, 0.5]
log([0.25, 0.5]) ≈ [-1.386294, -0.693147]
-log(denominator) = -log([0.5, 0.75]) ≈ [0.693147, 0.287682]
```


### 3. For current_point [2,0] (t=0.9):

```
(1-t)*coefficients = 0.1 * [0.5, 1.0] = [0.05, 0.1]
log([0.05, 0.1]) ≈ [-2.995732, -2.302585]
-log(denominator) = -log([0.82, 1.63]) ≈ [0.198446, -0.488580]
```


In [17]:
# Print computed values
print("Computed coefficient_term:")
print(intermediates["coefficient_term"])
# Manual calculations: log((1 - t) * coefficients) + (-log(denominator))
manual_coefficient_term = torch.tensor([
    [-0.798508 + 0.198446, -0.105361 + 0.186372],
    [-1.386294 + 0.693147, -0.693147 + 0.287682],
    [-2.995732 + 0.198446, -2.302585 + -0.4885800148186709]
])
print("\nManual coefficient_term:")
print(manual_coefficient_term)

# The logarithms above were entered by hand, so compare with a loose absolute
# tolerance rather than expecting bit-level agreement.
torch.testing.assert_close(
    intermediates["coefficient_term"], manual_coefficient_term, rtol=0.0, atol=1e-3
)

Computed coefficient_term:
tensor([[-0.6001,  0.0810],
        [-0.6931, -0.4055],
        [-2.7973, -2.7912]])

Manual coefficient_term:
tensor([[-0.6001,  0.0810],
        [-0.6931, -0.4055],
        [-2.7973, -2.7912]])


They match quite closely. Finally, we compare the logits

In [18]:
# Print computed values
print("Computed logits:")
print(intermediates["logits"])
# Manual calculations: distance term plus coefficient term from the two previous checks
manual_logits = manual_logit_interm + manual_coefficient_term

# Label fixed: this prints the full manual logits, not the distance-only "logit interm".
print("\nManual logits:")
print(manual_logits)

# Inherits the hand-rounding of the manual coefficient term, hence the loose tolerance.
torch.testing.assert_close(
    intermediates["logits"], manual_logits, rtol=0.0, atol=1e-3
)

Computed logits:
tensor([[-0.6062,  0.0689],
        [-1.9431, -0.7388],
        [-5.7302, -3.4108]])

Manual logit interm:
tensor([[-0.6062,  0.0690],
        [-1.9431, -0.7388],
        [-5.7302, -3.4108]])


Now we compute the manual softmax weights from the logits and compare them with the computed weights.

## Softmax Check
Given the logits:
```
[
  [-0.6062,  0.0690],  # t=0.1
  [-1.9431, -0.7388],  # t=0.5
  [-5.7302, -3.4108]   # t=0.9
]
```

The softmax is calculated as:
```
weights = exp(logits) / sum(exp(logits), dim=1)
```

#### 1. For t=0.1:
```
- `exp(-0.6062) ≈ 0.5456`
- `exp(0.0690) ≈ 1.0714`
- `sum = 0.5456 + 1.0714 ≈ 1.6170`
- `weights = [0.5456/1.6170, 1.0714/1.6170] ≈ [0.3374, 0.6626]`
```

#### 2. For t=0.5:
```
- `exp(-1.9431) ≈ 0.1431`
- `exp(-0.7388) ≈ 0.4776`
- `sum = 0.1431 + 0.4776 ≈ 0.6207`
- `weights = [0.1431/0.6207, 0.4776/0.6207] ≈ [0.2305, 0.7695]`
```

#### 3. For t=0.9:
```
- `exp(-5.7302) ≈ 0.0033`
- `exp(-3.4108) ≈ 0.0330`
- `sum = 0.0033 + 0.0330 ≈ 0.0363`
- `weights = [0.0033/0.0363, 0.0330/0.0363] ≈ [0.0909, 0.9091]`
```

Yielding the following final weights
```
tensor([[0.3374, 0.6626],
        [0.2305, 0.7695],
        [0.0909, 0.9091]])
```


In [20]:
print("Exponential of logits:")
print(torch.exp(manual_logits))

# Print computed values
print("Computed weights:")
print(intermediates["weights"])

# Hand-computed softmax from the markdown above (exponentials rounded to 4 decimals)
manual_weights = torch.tensor([
    [0.3374, 0.6626],
    [0.2305, 0.7695],
    [0.0909, 0.9091]
])
print("Manually computed weights:")
print(manual_weights)

# Rounding the small exponentials to four decimals shifts the last row by
# about 1.4e-3, so the tolerance is set just above that.
torch.testing.assert_close(
    intermediates["weights"], manual_weights, rtol=0.0, atol=2e-3
)

Exponential of logits:
tensor([[0.5454, 1.0714],
        [0.1433, 0.4777],
        [0.0032, 0.0330]])
Computed weights:
tensor([[0.3374, 0.6626],
        [0.2307, 0.7693],
        [0.0895, 0.9105]])
Manually computed weights:
tensor([[0.3374, 0.6626],
        [0.2305, 0.7695],
        [0.0909, 0.9091]])


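This shows good agreement between the two. The residual differences (largest in the last row, 0.0909 vs. 0.0895) come from rounding the exponentials to four decimal places in the manual calculation.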

## Weighted Vector Check
Next, we verify that the net weight vector is as expected.

In [11]:
# Reference: how these intermediates are built inside the function
#   denominator_2  = (1 - t_reshaped)**2 * sigma_i + t_reshaped**2 * sigma_f_reshaped_2
#   x_num          = t_reshaped * sigma_f_reshaped_2 - (1 - t_reshaped) * sigma_i
#   data_num       = (1 - t_reshaped) * sigma_i
#   net_weight_vec = (current_expanded * x_num + data_num * data_exp) / denominator_2

for heading, key in (
    ("Denominator:", "denominator_2"),
    ("X vector pre-factor:", "x_num"),
    ("Data pre-factor:", "data_num"),
):
    print(heading)
    print(intermediates[key])

Denominator:
tensor([[[0.8200],
         [0.8300]],

        [[0.5000],
         [0.7500]],

        [[0.8200],
         [1.6300]]])
X vector pre-factor:
tensor([[[-0.8000],
         [-0.7000]],

        [[ 0.0000],
         [ 0.5000]],

        [[ 0.8000],
         [ 1.7000]]])
Data pre-factor:
tensor([[[0.9000]],

        [[0.5000]],

        [[0.1000]]])


The denominator was previously calculated and shows good agreement. 

For the data prefactors, we obtain tensors of the form $(1 - t) \sigma_i$ as expected.

For the x prefactors, we explicitly calculate. 

### 1. For t=0.1:
```
t * sigma_f - (1 - t) * sigma_i = 0.1 * [1, 2] - (0.9) * [1, 1] = [-0.8, -0.7]
```
### 2. For t=0.5:
```
t * sigma_f - (1 - t) * sigma_i = 0.5 * [1, 2] - (0.5) * [1, 1] = [0, 0.5]
```
### 3. For t=0.9:
```
t * sigma_f - (1 - t) * sigma_i = 0.9 * [1, 2] - (0.1) * [1, 1] = [0.8, 1.7]
```
which matches our results.

Next, we check the net vector, given by net_weight_vec = (current_point * x_num + data_num * data_point) / denominator, where x_num and data_num are the X and data pre-factors printed above.

### For t=0.1 (first batch element):
Current point: [0,0]
Data points: [0,1] and [1,1]

#### For data point [0,1]:
```
numerator = [0,0] * [-0.8] + 0.9 * [0,1] = [0,0.9]
denominator = 0.82
result = [0/0.82, 0.9/0.82] ≈ [0.0000, 1.0976]
```

#### For data point [1,1]:
```
numerator = [0,0] * [-0.7] + 0.9 * [1,1] = [0.9,0.9]
denominator = 0.83
result = [0.9/0.83, 0.9/0.83] ≈ [1.0843, 1.0843]
```

### For t=0.5 (second batch element):
Current point: [1,0]

#### For data point [0,1]:
```
numerator = [1,0] * [0] + 0.5 * [0,1] = [0,0.5]
denominator = 0.50
result = [0/0.5, 0.5/0.5] = [0.0000, 1.0000]
```

#### For data point [1,1]:
```
numerator = [1,0] * [0.5] + 0.5 * [1,1] = [0.5+0.5, 0+0.5] = [1.0,0.5]
denominator = 0.75
result = [1.0/0.75, 0.5/0.75] ≈ [1.3333, 0.6667]
```

### For t=0.9 (third batch element):
Current point: [2,0]

#### For data point [0,1]:
```
numerator = [2,0] * [0.8] + 0.1 * [0,1] = [1.6,0.1]
denominator = 0.82
result = [1.6/0.82, 0.1/0.82] ≈ [1.9512, 0.1220]
```

#### For data point [1,1]:
```
numerator = [2,0] * [1.7] + 0.1 * [1,1] = [3.4+0.1, 0+0.1] = [3.5,0.1]
denominator = 1.63
result = [3.5/1.63, 0.1/1.63] ≈ [2.1472, 0.0613]
```

In [21]:
print("Net Vector (To be used with Softmax:")
print(intermediates["net_weight_vec"])

# Hand-computed values from the markdown above, rounded to four decimals
manual_net_weight_vec = torch.tensor([
    [[0.0000, 1.0976], [1.0843, 1.0843]],   # t=0.1
    [[0.0000, 1.0000], [1.3333, 0.6667]],    # t=0.5
    [[1.9512, 0.1220], [2.1472, 0.0613]]     # t=0.9
])
print("Manual calculation:")
print(manual_net_weight_vec)

# Four-decimal rounding can be off by up to 5e-5 per entry; leave headroom above that.
torch.testing.assert_close(
    intermediates["net_weight_vec"], manual_net_weight_vec, rtol=0.0, atol=5e-4
)

Net Vector (To be used with Softmax:
tensor([[[0.0000, 1.0976],
         [1.0843, 1.0843]],

        [[0.0000, 1.0000],
         [1.3333, 0.6667]],

        [[1.9512, 0.1220],
         [2.1472, 0.0613]]])
Manual calculation:
tensor([[[0.0000, 1.0976],
         [1.0843, 1.0843]],

        [[0.0000, 1.0000],
         [1.3333, 0.6667]],

        [[1.9512, 0.1220],
         [2.1472, 0.0613]]])


They all match!  

Now we compare the final velocities

## Velocity Check

#### 1. For t=0.1:
```
weights = [0.3374, 0.6626]
net_weight_vec = [[0.0000, 1.0976], [1.0843, 1.0843]]

velocity = (0.3374 * [0.0000, 1.0976]) + (0.6626 * [1.0843, 1.0843])
         = [0.0000, 0.3703] + [0.7185, 0.7185]
         = [0.7185, 1.0888]
```

#### 2. For t=0.5:
```
weights = [0.2305, 0.7695]
net_weight_vec = [[0.0000, 1.0000], [1.3333, 0.6667]]

velocity = (0.2305 * [0.0000, 1.0000]) + (0.7695 * [1.3333, 0.6667])
         = [0.0000, 0.2305] + [1.0260, 0.5130]
         = [1.0260, 0.7435]
```

#### 3. For t=0.9:
```
weights = [0.0909, 0.9091]
net_weight_vec = [[1.9512, 0.1220], [2.1472, 0.0613]]

velocity = (0.0909 * [1.9512, 0.1220]) + (0.9091 * [2.1472, 0.0613])
         = [0.1774, 0.0111] + [1.9520, 0.0557]
         = [2.1294, 0.0668]
```

In [22]:
# Manual inputs (same hand-computed net vectors as in the previous check)
manual_net_weight_vec = torch.tensor([
    [[0.0000, 1.0976], [1.0843, 1.0843]],
    [[0.0000, 1.0000], [1.3333, 0.6667]],
    [[1.9512, 0.1220], [2.1472, 0.0613]]
])

# Manual velocity calculation: softmax-weighted sum over the N data points
manual_velocities = torch.sum(manual_weights.unsqueeze(-1) * manual_net_weight_vec, dim=1)

print("Manual velocities:")
print(manual_velocities)

# Compare with function output
print("\nFunction's velocities:")
print(intermediates["velocities"])
print("\nDifference:")
print(intermediates["velocities"] - manual_velocities)

# The manual weights and vectors are rounded to four decimals, which leaves
# differences of a few 1e-4; assert agreement just above that level.
torch.testing.assert_close(
    intermediates["velocities"], manual_velocities, rtol=0.0, atol=1e-3
)

Manual velocities:
tensor([[0.7185, 1.0888],
        [1.0260, 0.7435],
        [2.1294, 0.0668]])

Function's velocities:
tensor([[0.7185, 1.0888],
        [1.0257, 0.7436],
        [2.1297, 0.0668]])

Difference:
tensor([[ 6.6757e-05,  1.1325e-05],
        [-2.4438e-04,  4.1902e-05],
        [ 3.0661e-04, -4.2379e-05]])


Which matches within the accuracy of our operations! This confirms our function does as we expect it to behave.