In [1]:
import tensorflow as tf
from tqdm import tqdm

2024-05-22 22:29:58.886811: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def get_mutual_distance(batch):
    """Pairwise *squared* Euclidean distances between the rows of `batch`.

    Uses the expansion |a - b|^2 = |a|^2 - 2 a.b + |b|^2 so that the whole
    matrix is obtained from a single matmul.
    See: https://stackoverflow.com/a/37040451

    Despite the name, no square root is taken: the entries are squared
    distances. That is sufficient for ranking neighbours, because the
    square root is monotonic.

    Args:
        batch: rank-2 tensor with one point per row.

    Returns:
        Square tensor whose [i, j] entry is the squared distance between
        row i and row j of `batch`, clamped to be non-negative.
    """
    # Squared norm of every row, kept as a column vector.
    sq_norm = tf.reduce_sum(tf.square(batch), axis=1, keepdims=True)
    gram = tf.matmul(batch, tf.transpose(batch))
    sq_dist = sq_norm - 2 * gram + tf.transpose(sq_norm)
    # The expansion subtracts nearly equal numbers for close (or identical)
    # points, so round-off can leave tiny negative values. A squared distance
    # is never negative, so clamp at zero.
    return tf.maximum(sq_dist, 0)

In [3]:
# Random test state: 1000 rows of 3 components each, drawn uniformly
# (tf.random.uniform defaults to the range [0, 1)).
x = tf.random.uniform(shape=(1000, 3))
v = tf.random.uniform(shape=(1000, 3))

In [4]:
def get_neighbours(x, num_nb):
    """Indices of the `num_nb` nearest other rows for every row of `x`.

    Args:
        x: rank-2 tensor with one point per row.
        num_nb: number of neighbours to return per point; expected to be
            smaller than the number of rows in `x`.

    Returns:
        Integer tensor of shape [rows, num_nb]. Row i lists the indices of
        the rows closest to row i, nearest first, never including i itself.
    """
    dist = get_mutual_distance(x)
    # Exclude each point from its own neighbour list explicitly. Dropping the
    # first top_k column instead would assume the diagonal is always the
    # strict minimum of its row, which fails for duplicate points and when
    # round-off makes the computed self-distance slightly non-zero.
    n_points = tf.shape(dist)[:1]
    masked_diag = tf.fill(n_points, tf.constant(dist.dtype.max, dtype=dist.dtype))
    dist = tf.linalg.set_diag(dist, masked_diag)
    # top_k selects the largest entries, so negate to get the smallest distances.
    return tf.math.top_k(-dist, num_nb)[1]

In [5]:
# Indices of the 5 nearest neighbours of every row of x.
nb = get_neighbours(x, num_nb=5)

In [6]:
# Display the neighbour index table: one row per point, one column per neighbour.
nb

<tf.Tensor: shape=(1000, 5), dtype=int32, numpy=
array([[409, 378, 548, 796, 987],
       [555, 275, 337, 645, 755],
       [927, 388, 402, 360, 380],
       ...,
       [568, 203,  90, 349, 391],
       [431, 882,  57, 371, 778],
       [463, 566, 835, 715, 545]], dtype=int32)>

In [7]:
# nb[8056, :]  # stale probe: row 8056 is out of range for the (1000, 5) `nb` shown above

In [8]:
# Look up v for every neighbour index: result is [points, neighbours, components].
nb_v = tf.gather(params=v, indices=nb, axis=0)
nb_v.shape

TensorShape([1000, 5, 3])

In [9]:
# Shape check: mean over the neighbour axis of (neighbour v - own v).
# v gets a length-1 neighbour axis so it broadcasts against nb_v.
tf.reduce_mean(nb_v - v[:, tf.newaxis, :], axis=1).shape

TensorShape([1000, 3])

In [23]:
def get_dynamics(num_nb, J, dt, T):
    """Build a compiled single-step update for the state (x, v).

    Args:
        num_nb: number of nearest neighbours (found from `x`) each row is
            coupled to.
        J: factor scaling the pull of `v` toward the mean `v` of its neighbours.
        dt: step size.
        T: noise strength; the random kick added to `v` at each step has
            standard deviation sqrt(2 * T * dt).

    Returns:
        A `tf.function` taking `(x, v)` and returning the updated `(x, v)`.
    """
    @tf.function
    def dynamics(x, v):
        # Neighbours are chosen from the current x; their v values are gathered.
        neighbour_idx = get_neighbours(x, num_nb)
        neighbour_v = tf.gather(v, neighbour_idx, axis=0)

        # Relax v toward the average v of its neighbours.
        alignment = J * (tf.reduce_mean(neighbour_v, axis=1) - v) * dt
        v = v + alignment

        # Add a Gaussian random kick.
        noise = tf.random.normal(tf.shape(v)) * tf.sqrt(2*T*dt)
        v = v + noise

        # x advances using the already-updated v.
        x = x + v * dt
        return x, v
    return dynamics

# Step function with 5 neighbours, coupling 10, step size 0.01, noise strength 0.01.
dynamics = get_dynamics(num_nb=5, J=10.0, dt=0.01, T=0.01)

In [24]:
# Fresh random initial state: 1000 rows of 3 components, uniform in [0, 1).
x = tf.random.uniform(shape=(1000, 3))
v = tf.random.uniform(shape=(1000, 3))

# Advance the state 3000 steps, feeding each step's output into the next.
for _ in tqdm(range(3000)):
    x, v = dynamics(x, v)

100%|██████████████████████████████████████████████████████| 3000/3000 [00:07<00:00, 424.19it/s]


In [25]:
# Display v after the simulation loop has finished.
v

<tf.Tensor: shape=(1000, 3), dtype=float32, numpy=
array([[0.26754615, 0.8959832 , 0.70251447],
       [0.73039967, 0.32357755, 0.40126997],
       [0.6559002 , 0.64199996, 0.3803612 ],
       ...,
       [0.3917897 , 0.57835096, 0.5278683 ],
       [0.25851813, 0.71204674, 0.42719024],
       [0.50672305, 0.42646226, 0.37950367]], dtype=float32)>