# Vectorized Broadway HMCMC

In [None]:
# Colab-only setup: mount Google Drive and move into the project directory so the
# relative 'data/...' paths used below resolve. Outside Colab the google.colab
# import fails and the notebook keeps the current working directory.
try:
    from google.colab import drive
except ImportError:
    pass
else:
    import os
    try:
        drive.mount('/content/drive')
        # Absolute path: a relative chdir fails when the cell is re-run, because the
        # working directory has already been changed by the first run.
        os.chdir('/content/drive/MyDrive/School/DS-GA 1006/code')
        print(os.getcwd())
    except Exception as err:  # best effort as before, but report the reason
        print(f'Google Drive setup skipped: {err!r}')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/School/DS-GA 1006/code


In [None]:
# ! pip install -r requirements.txt

### imports

In [None]:
import tensorflow as tf
import pandas as pd
import tensorflow_probability as tfp
from tqdm import tqdm
from functools import lru_cache
import numpy as np
from collections import defaultdict
import copy
import gc

In [None]:
SEED = 4
# Execute @tf.function-decorated code eagerly, so the Python side effects inside
# those functions (prints, attribute updates) run on every call.
tf.config.run_functions_eagerly(True)
tf.random.set_seed(SEED)
# Shared, seeded initializer for the model parameters.
parameter_initializer = tf.keras.initializers.RandomNormal(seed=SEED)

In [None]:
# Read the three headerless CSV exports and name their columns explicitly.
node_data = pd.read_csv(
    'data/nyse_node_sp1.csv',
    header=None,
    names=['name', 'ever_committee', 'node_id', 'ethnicity', 'ever_sponsor'],
)
edge_data = pd.read_csv(
    'data/nyse_edge_buy_sp_sp1.csv',
    header=None,
    names=['buyer_id', 'sponsor1_id', 'sponsor2_id', 'f1', 'f2', 'f3', 'f4',
           'blackballs', 'whiteballs', 'year'],
)
committee_data = pd.read_csv(
    'data/nyse_edge_buy_com1.csv',
    header=None,
    names=['buyer_id', 'committee_id', 'f1', 'f2', 'f3', 'f4',
           'blackballs', 'whiteballs', 'year'],
)

In [None]:
def process_data(node_data, edge_data, committee_data):
    """Build node attributes, purchase transactions and simple network statistics.

    Args:
        node_data: DataFrame with columns ``name``, ``ever_committee``, ``node_id``,
            ``ethnicity`` and ``ever_sponsor``. ``ethnicity`` is one-hot encoded
            (first level dropped, plus a NaN indicator). Missing
            ``ever_committee`` / ``ever_sponsor`` values become 0.
        edge_data: one row per purchase, with ``buyer_id``, ``sponsor1_id``,
            ``sponsor2_id``, ``whiteballs``, ``blackballs`` and ``year``. Other
            columns are ignored.
        committee_data: rows with ``buyer_id``, ``committee_id`` and ``year``
            linking a purchase to its committee members.

    Returns:
        ``(node_attrs, transactions, network_stats, edges)`` where
        ``node_attrs`` maps node_id -> list of the remaining node columns;
        ``transactions`` is one dict per ``edge_data`` row;
        ``network_stats`` maps node_id -> ``{'degree', 'sponsor_count'}``
        (a buyer gains 2 degree per purchase, each sponsor gains 1);
        ``edges`` maps buyer_id -> set of sponsor ids.

    Raises:
        KeyError: if a buyer or sponsor id in ``edge_data`` is not a node_id.
    """
    node_data = pd.get_dummies(data=node_data, columns=['ethnicity'], dummy_na=True, prefix='ethnicity', drop_first=True, dtype=int)
    node_data[['ever_committee', 'ever_sponsor']] = node_data[['ever_committee', 'ever_sponsor']].fillna(0)
    node_attrs = node_data.set_index('node_id').drop(columns=['name']).T.to_dict('list')

    # Initialize network statistics
    network_stats = {node_id: {'degree': 0, 'sponsor_count': 0} for node_id in node_attrs}
    edges = defaultdict(set)

    # Group committee members by (buyer_id, year) once instead of re-filtering the
    # whole committee table for every purchase: O(E + C) rather than O(E * C).
    # groupby keeps rows in their original order inside each group (same order as
    # a boolean filter). Rows with a missing key are dropped, just as an equality
    # test against NaN never matched.
    committee_lookup = {
        key: group.tolist()
        for key, group in committee_data.groupby(['buyer_id', 'year'], sort=False)['committee_id']
    }

    transactions = []
    for _, row in edge_data.iterrows():
        buyer_id = row['buyer_id']
        sponsor1_id = row['sponsor1_id']
        sponsor2_id = row['sponsor2_id']
        year = row['year']

        # Update network statistics
        network_stats[buyer_id]['degree'] += 2
        network_stats[sponsor1_id]['degree'] += 1
        network_stats[sponsor2_id]['degree'] += 1
        network_stats[sponsor1_id]['sponsor_count'] += 1
        network_stats[sponsor2_id]['sponsor_count'] += 1
        edges[buyer_id].add(sponsor1_id)
        edges[buyer_id].add(sponsor2_id)

        # Copy so that transactions sharing (buyer, year) never share a list object.
        committee_members = list(committee_lookup.get((buyer_id, year), ()))

        transactions.append({
            'buyer_id': buyer_id,
            'sponsor1_id': sponsor1_id,
            'sponsor2_id': sponsor2_id,
            'committee_members': committee_members,
            'year': year,
            'whiteballs': row['whiteballs'],
            'blackballs': row['blackballs']
        })

    return node_attrs, transactions, network_stats, edges

node_attrs, transactions, network_stats, edges = process_data(
    node_data, edge_data, committee_data
)
# The per-node statistic s used by the model is the node's degree.
node_stats = {node_id: stats['degree'] for node_id, stats in network_stats.items()}

### Model


In [None]:
class VectorizedVStar(tf.Module):
    """Pseudo-surplus V*(i, j) = U*(i, j) + U*(j, i) with a linear utility U*.

    U*(i, j) is the dot product of ``theta`` with the concatenated features
    ``[x_i, x_j, s_i, s_j]``, so ``theta`` has length ``2 * x_dim + 2``. The
    dot product is computed in float16.

    Two input layouts are accepted:

    * grid   -- xi [unique_n, 1, x_dim], xj [n, 1, x_dim], si [unique_n, 1],
      sj [n, 1]: V* for every (row of xi, row of xj) combination, [unique_n, n];
    * paired -- xi, xj [batch, x_dim], si, sj [batch]: V* for each row pair
      (xi[b], xj[b]), shape [batch, 1].
    """

    def __init__(self):
        super().__init__()

    @tf.function
    def U_star(self, xi, xj, si, sj, theta):
        """
        Linear utility U*(i, j) = [x_i, x_j, s_i, s_j] . theta.

        Args:
            xi: [unique_n, 1, x_dim] (grid) or [batch, x_dim] (paired)
            xj: [n, 1, x_dim] (grid) or [batch, x_dim] (paired)
            si: [unique_n, 1] (grid) or [batch] (paired)
            sj: [n, 1] (grid) or [batch] (paired)
            theta: [param_dim] parameters, param_dim = 2 * x_dim + 2
        Return:
            [unique_n, n] (grid) or [batch, 1] (paired) tensor of utility values
        """
        unique_n = tf.shape(xi)[0]
        n = tf.shape(xj)[0]

        # Column vector [param_dim, 1] so the dot product is a single matmul.
        theta_reshaped = tf.convert_to_tensor(theta, dtype=tf.float16)[:, tf.newaxis]

        if len(xi.shape) == 3:
            # Grid layout: pair every xi row with every xj row.
            si_expanded = tf.tile(tf.expand_dims(si, -1), [1, n, 1])  # [unique_n, n, 1]
            sj_expanded = tf.tile(tf.expand_dims(sj, 0), [unique_n, 1, 1])  # [unique_n, n, 1]
            xi_expanded = tf.tile(xi, [1, n, 1])  # [unique_n, n, x_dim]
            xj = tf.squeeze(xj, axis=1)  # Remove the middle dimension first: [n, x_dim]
            xj = tf.expand_dims(xj, 0)   # Add dimension at start: [1, n, x_dim]
            xj_expanded = tf.tile(xj, [unique_n, 1, 1])  # [unique_n, n, x_dim]

            # Concatenate features
            inputs = tf.concat([
                xi_expanded,
                xj_expanded,
                si_expanded,
                sj_expanded
            ], axis=-1)  # [unique_n, n, param_dim]
            assert len(inputs.shape) == 3 and inputs.shape[-1] == theta.shape[0], f'{inputs.shape=} {theta.shape[0]=}'
            inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])  # [unique_n*n, param_dim]

            # [unique_n*n, param_dim] @ [param_dim, 1] -> [unique_n*n, 1]
            dot_products = tf.matmul(tf.cast(inputs, tf.float16), theta_reshaped)

            # Reshape to [unique_n, n]
            result = tf.reshape(dot_products, [unique_n, n])
        else:
            # Paired layout: row b of every input belongs to the same pair.
            inputs = tf.concat([xi, xj, si[..., tf.newaxis], sj[..., tf.newaxis]], axis=-1)

            result = tf.matmul(tf.cast(inputs, tf.float16), theta_reshaped)  # [batch, 1]

        return result

    @tf.function
    def V_star(self, xi, xj, si, sj, theta):
        """
        Vectorized pseudo-surplus V*(i, j) = U*(i, j) + U*(j, i).

        Returns [unique_n, n] for the grid layout and [batch, 1] for the paired
        layout (see the class docstring).
        """
        forward = self.U_star(xi, xj, si, sj, theta)
        backward = self.U_star(xj, xi, sj, si, theta)
        if len(xi.shape) == 3:
            # Grid layout: the swapped call is laid out [n, unique_n]; transpose it
            # so it lines up with forward's [unique_n, n].
            backward = tf.transpose(backward)
        # Paired layout: both terms are already [batch, 1]. Transposing here would
        # broadcast [batch, 1] + [1, batch] into a [batch, batch] matrix that adds
        # utilities of different pairs.
        return forward + backward

    def __call__(self, xi, xj, si, sj, theta):
        """
        Main call method that handles both individual pairs and batches.

        Non-tensor inputs (e.g. lists or arrays) are converted to float16
        tensors first.
        """
        # If inputs are already tensors, use them directly
        if isinstance(xi, tf.Tensor):
            return self.V_star(xi, xj, si, sj, theta)

        # Convert individual inputs to tensors if needed
        xi = tf.convert_to_tensor(xi, dtype=tf.float16)
        xj = tf.convert_to_tensor(xj, dtype=tf.float16)
        si = tf.convert_to_tensor(si, dtype=tf.float16)
        sj = tf.convert_to_tensor(sj, dtype=tf.float16)
        theta = tf.convert_to_tensor(theta, dtype=tf.float16)

        return self.V_star(xi, xj, si, sj, theta)

class VectorizedPsi(tf.Module):
    """Element-wise psi = w * s * exp(V) / (1 + H) over a [unique_n, n] grid."""

    def __init__(self):
        super().__init__()

    def __call__(self, H, x_batch, s_batch, w_batch, V_batch, H_batch):
        """
        Vectorized PSI computation for batches of inputs.

        Args:
            H: H lookup object. Not used here -- the H values arrive already
                evaluated in ``H_batch`` -- but kept for call compatibility.
            x_batch: [unique_n, n, x_dim] tensor; only its first two dimensions
                are used, to shape the output.
            s_batch: [unique_n, n] tensor of s values
            w_batch: [unique_n, n] tensor of weight values
            V_batch: [unique_n, n] tensor of V values
            H_batch: [unique_n, n, k] tensor of H values
        Returns:
            [unique_n, n, k] tensor of PSI values, in the dtype of ``V_batch``.
        Raises:
            ValueError: if exp(V) or the final result is NaN or infinite.
        """
        unique_n, n = tf.shape(x_batch)[0], tf.shape(x_batch)[1]
        k_dim = tf.shape(H_batch)[2]
        out_dtype = V_batch.dtype

        # Compute in float32: exp() overflows float16 (max ~65504) as soon as V
        # exceeds ~11, which would otherwise abort the fixed-point iteration.
        s_flat = tf.reshape(tf.cast(s_batch, tf.float32), [-1])[..., tf.newaxis]  # [unique_n*n, 1]
        w_flat = tf.reshape(tf.cast(w_batch, tf.float32), [-1])[..., tf.newaxis]  # [unique_n*n, 1]
        V_flat = tf.reshape(tf.cast(V_batch, tf.float32), [-1])                   # [unique_n*n]
        H_values = tf.reshape(tf.cast(H_batch, tf.float32), [-1, k_dim])          # [unique_n*n, k]

        # Compute numerator
        num = tf.exp(V_flat)[..., tf.newaxis]  # [unique_n*n, 1]
        if not tf.reduce_all(tf.math.is_finite(num)):
            print(f'{V_flat=}')
            print(f'{num=}')
            raise ValueError('NAN when computing psi: numerator')

        # Compute denominator
        denom = 1 + H_values  # [unique_n*n, k]

        # Cast back so callers keep receiving the input precision; values beyond
        # that dtype's range become inf and are reported just below.
        res_flat = tf.cast(w_flat * s_flat * num / denom, out_dtype)  # [unique_n*n, k]
        if not tf.reduce_all(tf.math.is_finite(res_flat)):
            print(f'{w_flat=}')
            print(f'{s_flat=}')
            print(f'{num=}')
            print(f'{H_values=}')
            raise ValueError('NAN when computing psi: all results')

        # Reshape back to original dimensions
        return tf.reshape(res_flat, [unique_n, n, -1])  # [unique_n, n, k]

# VectorizedH, the helper class that stores H values for vectorized lookup, is used below but its definition does not appear in this notebook section.


def compute_all_v_values(unique_x, x_tensor, unique_s, s_tensor, theta):
    """Return V*(u, j) for every unique (x, s) row u against every node j.

    A singleton middle axis is inserted so that VectorizedVStar takes its grid
    layout. The result is a [unique_n, n] tensor.
    """
    v_star = VectorizedVStar()
    return v_star(
        tf.expand_dims(unique_x, 1),   # [unique_n, 1, x_dim]
        tf.expand_dims(x_tensor, 1),   # [n, 1, x_dim]
        tf.expand_dims(unique_s, 1),   # [unique_n, 1]
        tf.expand_dims(s_tensor, 1),   # [n, 1]
        theta,
    )

def compute_psi_values(unique_x, x_tensor, w_tensor, s_tensor, H, V_values):
    """Compute psi(u, j) = w_j * s_j * exp(V(u, j)) / (1 + H(x_j)) for all pairs.

    Args:
        unique_x: [unique_n, x_dim] attributes of the unique (x, s) rows.
        x_tensor: [n, x_dim] attributes of every node j.
        w_tensor: [n] importance weights of every node j.
        s_tensor: [n] statistics of every node j.
        H: current H lookup.
        V_values: V*(u, j) values; reshaped to [unique_n, n].
    Returns:
        [unique_n, n, k] tensor of psi values.
    """
    n_unique = tf.shape(unique_x)[0]
    n_nodes = tf.shape(x_tensor)[0]

    # H of every partner node j, as [n, k]. H is evaluated per node and then
    # broadcast along the unique-row axis, so h_all[u, j] == H(x_j) lines up
    # with V(u, j). (Tiling H(unique_x) along a single flat axis put H of
    # unrelated rows next to each V value.) The reshape also accepts a flat [n]
    # result. Assumes H accepts any [m, x_dim] batch of attribute rows, as it
    # does for unique_x elsewhere.
    H_nodes = tf.reshape(H(x_tensor), [n_nodes, -1])

    # Reshape inputs for broadcasting
    x_all = tf.tile(x_tensor[tf.newaxis, ...], [n_unique, 1, 1])  # [unique_n, n, x_dim]
    s_all = tf.tile(s_tensor[tf.newaxis, ...], [n_unique, 1])     # [unique_n, n]
    w_all = tf.tile(w_tensor[tf.newaxis, ...], [n_unique, 1])     # [unique_n, n]
    h_all = tf.tile(H_nodes[tf.newaxis, ...], [n_unique, 1, 1])   # [unique_n, n, k]

    # Reshape V values to match
    V_reshaped = tf.reshape(V_values, [n_unique, n_nodes])  # [unique_n, n]

    # Compute PSI values for all pairs
    return VectorizedPsi()(H, x_all, s_all, w_all, V_reshaped, h_all)  # [unique_n, n, k]

def H_star(node_attrs, node_stats, weights, theta, H, max_iter=20, tol=1e-4):
    """Solve the fixed point for H by repeatedly averaging psi values.

    Each iteration evaluates psi for every unique (x, s) pair against every node
    using the previous H. It averages psi over all pairs sharing the same x and
    over all nodes, then writes the averages back with ``H.update``.

    Args:
        node_attrs: node_id -> attribute vector x.
        node_stats: node_id -> statistic s.
        weights: node_id -> importance weight w.
        theta: parameter vector passed to V*.
        H: H lookup object; it is updated in place and also returned.
        max_iter: maximum number of fixed-point iterations.
        tol: stop once the largest absolute change of H over the unique x
            values falls below this threshold.
    Returns:
        ``H`` after the final update.
    """
    # initialize tensors for vectorized processing
    H_current = H
    node_ids = list(node_attrs.keys())
    x_tensor = tf.convert_to_tensor(np.array([node_attrs[i] for i in node_ids]), dtype=tf.float16)
    s_tensor = tf.convert_to_tensor(np.array([node_stats[i] for i in node_ids]), dtype=tf.float16)
    w_tensor = tf.convert_to_tensor(np.array([weights[i] for i in node_ids]), dtype=tf.float16)

    # Freeze one ordering of the unique (x, s) pairs, so that row idx of
    # unique_x / unique_s and the indices stored in x_group refer to the same pair.
    unique_node_pairs = list({(tuple(node_attrs[node_id]), node_stats[node_id])
                              for node_id in node_ids})
    unique_x = tf.convert_to_tensor([x for x, s in unique_node_pairs], dtype=tf.float16)
    unique_s = tf.convert_to_tensor([s for x, s in unique_node_pairs], dtype=tf.float16)

    # Row indices of unique_x grouped by H-lookup key of x
    x_group = defaultdict(list)
    for idx, (x, _) in enumerate(unique_node_pairs):
        x_key = VectorizedH.generate_key(x)
        x_group[x_key].append(idx)

    # V does not depend on H, so it is computed once outside the iteration.
    V = tf.stop_gradient(compute_all_v_values(unique_x, x_tensor, unique_s, s_tensor, theta))
    print(f'Computed {V=} \n Using {theta=}')

    for i in tqdm(range(max_iter), desc='H_star', position=0, leave=True):
        H_prev = copy.deepcopy(H_current)

        psi_values = tf.stop_gradient(compute_psi_values(unique_x, x_tensor, w_tensor, s_tensor, H_prev, V))
        # Compute means for each unique x (in float32 to limit float16 rounding)
        psi_mean = {}
        for x, indices in x_group.items():
            selected_psi_values = tf.gather(psi_values, indices, axis=0)
            psi = tf.stop_gradient(tf.reduce_mean(tf.cast(selected_psi_values, tf.float32), axis=[0, 1]))
            psi_mean[x] = tf.cast(psi, tf.float16)

        # Update H lookup
        H_current.update(psi_mean)

        # Check convergence - comparing full tensors
        diff = tf.stop_gradient(tf.reduce_max(tf.abs(H_current(unique_x) - H_prev(unique_x))))
        if diff < tol:
            # i is zero-based, so i + 1 iterations have completed.
            print(f"Convergence achieved after {i + 1} iterations.")
            break

    return H_current

# Importance weight function (section 4.2.1, pg 36).
@lru_cache(maxsize=32)
def _positive_share(node_vals):
    """Return the fraction of entries of the tuple ``node_vals`` that are > 0.

    Cached on the tuple's value, so the share is computed once when the same
    statistics are passed for every node. Returned as a NumPy float, so a zero
    share gives inf/nan with a RuntimeWarning rather than ZeroDivisionError.
    """
    return np.mean(np.asarray(node_vals) > 0)


def calculate_weights(s, node_vals):
    """Return ``s`` divided by the share of ``node_vals`` that are positive.

    Args:
        s: numerator, e.g. a node's ``stat > 0`` indicator.
        node_vals: iterable of all node statistics (e.g. ``dict.values()``).

    The previous ``lru_cache`` on this function keyed on the ``dict_values``
    view object itself. A fresh view is created on every call, so the cache
    never hit, kept up to 1000 views alive, and recomputed a float16
    TensorFlow mean for every node. The share is now computed with NumPy and
    cached by value.
    """
    return s / _positive_share(tuple(node_vals))

In [None]:
# Log-likelihood contributions for a batch of observed (i, j) pairs.
def vectorized_log_likelihood_contribution(Lijt, x_i, x_j, s_i, s_j, theta, H ):
    """Return 0.5 * Lijt * (V*(i, j) - log(1 + H(x_i)) - log(1 + H(x_j))).

    Args:
        Lijt: [batch] link indicators / weights.
        x_i, x_j: [batch, feature_dim] attributes of the two endpoints.
        s_i, s_j: [batch] statistics of the two endpoints.
        theta: parameter vector for V*.
        H: H lookup, called on a batch of attribute rows.

    NOTE(review): V* and H(x) come back with a trailing axis rather than as flat
    [batch] vectors (H presumably [batch, k]). Multiplying by the 1-D ``Lijt``
    therefore broadcasts the result to [batch, batch]. With an all-ones
    ``Lijt``, summing that array and dividing by sum(Lijt) yields the plain sum
    over pairs, not their mean -- confirm which one is intended.
    """
    surplus = VectorizedVStar()(x_i, x_j, s_i, s_j, theta)
    h_i, h_j = H(x_i), H(x_j)
    return 0.5 * Lijt * (surplus - tf.math.log1p(h_i) - tf.math.log1p(h_j))


def log_likelihood_optimized(theta, node_attrs, node_stats, weights, H=None):
    """Log-likelihood of the observed pairs under parameters ``theta``.

    Solves for H with ``H_star``, starting from ``H`` or from a fresh
    ``VectorizedH`` when ``H`` is None. Pairs (i, j) come from the module-level
    ``edges`` mapping (node_id -> partner ids), skipping self-loops. If ``edges``
    is a defaultdict, looking up a node without partners inserts an empty entry.

    The result is the edge-term total divided by sum(Lijt), plus the mean over
    pairs of log(s_i) - log(1 + H(x_i)).

    Returns:
        ``(ll, H)``: ``ll`` is a float32 scalar tensor, ``H`` the solved lookup.
    """
    # Half precision keeps the memory footprint of the fixed-point solve small.
    theta = tf.cast(theta, tf.float16)

    if H is None:
        H = VectorizedH(k_dim=1, node_attrs=node_attrs, node_stats=node_stats)
    H = H_star(node_attrs, node_stats, weights, theta, H)

    pairs = [(i, j) for i in node_attrs.keys() for j in edges[i] if i != j]

    def to_half(values):
        return tf.convert_to_tensor(values, dtype=tf.float16)

    x_i = to_half([node_attrs[i] for i, _ in pairs])
    x_j = to_half([node_attrs[j] for _, j in pairs])
    s_i = to_half([node_stats[i] for i, _ in pairs])
    s_j = to_half([node_stats[j] for _, j in pairs])
    Lijt = tf.ones(s_i.shape, dtype=tf.float16)

    edge_ll = vectorized_log_likelihood_contribution(
        Lijt, x_i, x_j, s_i, s_j, theta, H,
    )
    edge_total = tf.reduce_sum(tf.cast(edge_ll, tf.float32))
    ll = edge_total / tf.cast(tf.reduce_sum(Lijt), tf.float32)

    # Assumes s_i > 0 and 1 + H > 0 for every endpoint i, so both logs are finite.
    node_terms = tf.math.log(s_i) - tf.math.log1p(H(x_i))
    ll += tf.reduce_sum(tf.cast(node_terms, tf.float32)) / tf.cast(tf.shape(x_i)[0], tf.float32)

    return ll, H


In [None]:
# Set up HMC
num_results = 1000
num_burnin_steps = 1000

# Initialize parameters: one coefficient per attribute of i and of j, plus s_i and s_j.
x = len(list(node_attrs.values())[0])
initial_state = [0]*x*2 + [0, 0]  # x_i, x_j, si, sj
initial_state = tf.Variable(parameter_initializer([len(initial_state)], dtype=tf.float16))
print(f'{initial_state=}')

# Define the HMC transition kernel.
# TFP calls target_log_prob_fn with the chain state only and expects a scalar
# back, whereas log_likelihood_optimized needs the data arguments and returns
# (ll, H). ``weights`` is looked up when the kernel runs; it is defined in a
# later cell.
# NOTE(review): initial_state is float16 while step_size is float32; HMC
# generally needs matching dtypes -- confirm before sampling with this kernel.
step_size = tf.Variable(0.01)
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=lambda theta: log_likelihood_optimized(
            theta, node_attrs, node_stats, weights
        )[0],
        num_leapfrog_steps=10,
        step_size=step_size),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

initial_state=<tf.Variable 'Variable:0' shape=(20,) dtype=float16, numpy=
array([-0.0291   ,  0.04324  , -0.0783   ,  0.00989  ,  0.03323  ,
       -0.04358  ,  0.01701  ,  0.00419  , -0.06586  , -0.0003023,
        0.00598  , -0.01663  ,  0.10504  , -0.0408   ,  0.0409   ,
       -0.03204  , -0.00766  , -0.0664   , -0.02983  , -0.008224 ],
      dtype=float16)>




In [None]:
# Importance weight per node: calculate_weights(s > 0, ...) divides the node's
# indicator (True/False) by the share of nodes with a positive statistic.
weights = {node_id: calculate_weights(s > 0, node_stats.values()) for node_id, s in node_stats.items() }

# Evaluate the log-likelihood once at the random initial parameters; the
# trailing ``ll`` expression displays the value in the notebook.
ll, H = log_likelihood_optimized(initial_state, node_attrs, node_stats, weights)
ll

VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.18]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.919]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.984]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.533]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.814]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.681]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.912]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[2.664]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[6.52]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[6.188]
  x=0

H_star: 100%|██████████| 20/20 [01:21<00:00,  4.06s/it]


<tf.Tensor: shape=(), dtype=float32, numpy=-2858.459>

In [None]:
# Show the solved H lookup, then drop it and the scalar log-likelihood.
print(f"{H!r}")
del ll
del H

VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.116]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.095]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.1045]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.204]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.074]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.252]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[0.876]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.044]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[0.9844]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.109]
 

In [None]:
class MyHMCMC:
    """Fit theta by gradient ascent (MLE) or sample it with adaptive HMC.

    The solved H lookup is carried between likelihood evaluations in
    ``current_H``, so each fixed-point solve starts from the previous solution.

    Args:
        num_dims_theta: length of the parameter vector theta.
        num_dims_h: stored; not used by the methods of this class.
        num_chains: stored; not used by the methods of this class.
        verbosity: print the log-likelihood, H and gradients while running.
        lr: Adam learning rate for ``optimize_w_mle``; also the initial
            per-dimension HMC step size in ``run_chain``.
        clip_val: gradients are clipped to [-clip_val, clip_val] in
            ``optimize_w_mle``.
    """

    def __init__(self, num_dims_theta, num_dims_h=1, num_chains=2, verbosity=True, lr=1e-4, clip_val=5):
        self.num_dims_h = num_dims_h
        self.num_chains = num_chains
        self.verbosity = verbosity
        self.num_dims_theta = num_dims_theta
        self.param_state = None   # current theta; set by optimize_w_mle
        self.current_H = None     # H lookup reused as the next fixed-point start
        self.current_ll = None    # last unscaled log-likelihood
        self.learning_rate = lr
        self.clip_val = clip_val

    def log_likelihood_wrapper(self, node_attrs, node_stats, weights):
        """Return ``log_prob(theta)`` = log-likelihood / 1000 for the given data.

        Each call stores the solved H in ``self.current_H`` and the unscaled
        log-likelihood in ``self.current_ll``, and prints both when verbose.
        """
        # Plain Python rather than @tf.function: the body relies on Python side
        # effects (attribute updates, prints) that only behave when run eagerly.
        # Gradients are taken by the caller (HMC or optimize_w_mle), so no tape
        # is needed here.
        def log_prob(theta):
            ll, H = log_likelihood_optimized(theta, node_attrs, node_stats, weights, H=self.current_H)

            # Update H and likelihood
            self.current_H = H
            self.current_ll = ll
            if self.verbosity:
                print(f"Log Likelihood: {ll}")
                print(f"H: {repr(H)}")
            # Scale down so HMC can explore the parameter space with modest steps.
            return ll / 1000
        return log_prob

    def run_chain(self, node_attrs, node_stats, weights, burn_in_steps=100, num_results = 20):
        """Run adaptive HMC from ``self.param_state`` and return the samples.

        The step size starts at ``self.learning_rate`` in every dimension and is
        adapted during the first 80% of the burn-in steps.

        Returns:
            Tensor of post-burn-in states, shape [num_results, num_dims_theta].
        """
        assert self.param_state is not None, 'INitialize parameters first'

        tf.keras.backend.clear_session()

        # Define HMCMC kernel; one step size per theta dimension.
        step_size = tf.fill([self.num_dims_theta], self.learning_rate)
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
            tfp.mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=self.log_likelihood_wrapper(
                    node_attrs, node_stats, weights
                ),
                num_leapfrog_steps=5,
                step_size=step_size),
            num_adaptation_steps=int(burn_in_steps * 0.8))

        # With trace_fn=None and no final kernel results, only the states are returned.
        samples = tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=burn_in_steps,
            current_state=self.param_state,
            kernel=adaptive_hmc,
            trace_fn=None,
            return_final_kernel_results=False)

        return samples

    def optimize(self, node_attrs, node_stats, weights, burn_in_steps=100, num_results = 20):
        """Sample theta with HMC and pick the sample with the highest log-likelihood.

        Returns:
            ``(best_sample, samples, log_probs)``, where ``log_probs`` holds the
            unscaled log-likelihood of each sample.
        """
        samples = self.run_chain(node_attrs, node_stats, weights, burn_in_steps, num_results)

        # log probability for sampled params
        log_probs = []
        for theta_sample in samples:
            ll, _ = log_likelihood_optimized(
                theta_sample, node_attrs, node_stats, weights, H=self.current_H
            )
            log_probs.append(ll)

        if self.verbosity:
            print(f"Log Likelihood: {log_probs}")
            print(f"H: {repr(self.current_H)}")

        best_idx = tf.argmax(log_probs)
        best_sample = samples[best_idx]

        return best_sample, samples, log_probs

    def optimize_w_mle(self, node_attrs, node_stats, weights, num_epochs):
        """Maximize the log-likelihood over theta with Adam.

        Re-initializes ``self.param_state`` randomly, then runs ``num_epochs``
        clipped gradient-ascent steps on it.

        Returns:
            List of the per-epoch objective values (log-likelihood / 1000).
        """
        self.param_state = tf.Variable(parameter_initializer([self.num_dims_theta], dtype=tf.float32))
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
        ll_function = self.log_likelihood_wrapper(node_attrs, node_stats, weights)
        losses = []

        def train_step():
            with tf.GradientTape() as tape:
                ll = ll_function(self.param_state)
                # Adam minimizes, so descend on -ll to maximize the likelihood.
                loss = -ll

            # Get gradients and update parameters
            grads = tape.gradient(loss, [self.param_state])
            clipped_grads = [tf.clip_by_value(g, -self.clip_val, self.clip_val) for g in grads]
            if self.verbosity:
                print(f"{clipped_grads=}, {grads=}")
            optimizer.apply_gradients(zip(clipped_grads, [self.param_state]))

            return ll

        for epoch in range(num_epochs):
            ll = train_step()
            gc.collect()
            losses.append(ll)
            if self.verbosity:
                print(f"Epoch {epoch+1}, Log Likelihood: {ll}")

        return losses

In [None]:
# Chain lengths used by hmcmc_optimizer.optimize further below.
num_results = 20
num_burnin_steps = 100

# Number of attributes per node; theta stacks coefficients for x_i, x_j, s_i and s_j.
x = len(list(node_attrs.values())[0])
dim_theta = 2 * x + 2

hmcmc_optimizer = MyHMCMC(num_dims_theta=dim_theta)

In [None]:
# Fit theta by maximum likelihood for 20 epochs. Despite its name, ``theta``
# receives the list of per-epoch objective values returned by optimize_w_mle;
# the fitted parameters themselves live in hmcmc_optimizer.param_state.
theta = hmcmc_optimizer.optimize_w_mle(node_attrs, node_stats, weights, num_epochs=20)
theta, hmcmc_optimizer.param_state

VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.18]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.919]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.984]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.533]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.814]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.681]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.912]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[2.664]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[6.52]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[6.188]
  x=0

H_star: 100%|██████████| 20/20 [01:18<00:00,  3.95s/it]


Log Likelihood: -2858.306640625
H: VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.116]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.095]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.1045]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.204]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.074]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.253]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[0.876]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.044]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[0.9844]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.00

H_star:  30%|███       | 6/20 [00:26<01:02,  4.47s/it]

Convergence achieved after 6 iterations.





Log Likelihood: -2854.419921875
H: VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.114]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.093]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.103]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.202]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.072]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.25]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[0.874]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.041]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[0.9814]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.0000

H_star:  30%|███       | 6/20 [00:27<01:03,  4.55s/it]

Convergence achieved after 6 iterations.





Log Likelihood: -2851.677734375
H: VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.112]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.091]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.101]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.2]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.069]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.247]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[0.8716]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.038]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[0.9785]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.0000

H_star:  30%|███       | 6/20 [00:26<01:02,  4.45s/it]

Convergence achieved after 6 iterations.





Log Likelihood: -2844.71044921875
H: VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.11]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.088]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.099]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.198]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.067]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.245]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[0.869]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.036]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[0.9756]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.00

H_star:  50%|█████     | 10/20 [00:42<00:42,  4.21s/it]

Convergence achieved after 10 iterations.





Log Likelihood: -2837.7724609375
H: VectorizedH(k_dim=1)

Lookup Values:
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.108]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[1.086]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000: H=[1.096]
  x=0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000: H=[1.195]
  x=0.000000_0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000: H=[1.064]
  x=0.000000_0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.242]
  x=0.000000_0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[0.8667]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000: H=[1.033]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_0.000000_1.000000_0.000000: H=[0.9727]
  x=0.000000_1.000000_0.000000_0.000000_0.000000_0.000000_1.0

H_star: 100%|██████████| 20/20 [01:18<00:00,  3.91s/it]


In [None]:
# Sample theta with adaptive HMC starting from hmcmc_optimizer.param_state (set by
# the MLE cell above). optimize() returns three values,
# (best_sample, samples, log_probs); unpacking into two names raised ValueError.
best_params, all_params, log_likelihood = hmcmc_optimizer.optimize(
    node_attrs, node_stats, weights,
    num_results=num_results, burn_in_steps=num_burnin_steps
)

In [None]:
# Display the per-node importance weights.
weights

In [None]:
# Force a garbage-collection pass; the cell displays the number of unreachable
# objects found.
import gc
gc.collect()