In [1]:
import cvxpy as cp
import numpy as np
from scipy.special import rel_entr
import pickle


In [2]:
def solve_Q_new(P: np.ndarray):
  '''
  Compute optimal Q given 3d array P
  with dimensions corresponding to x1, x2, and y respectively.

  Q is represented as one (|X1|, |X2|) non-negative variable per value of y.
  It is constrained to have total mass 1 and to reproduce the (x1, y) and
  (x2, y) marginals of P. The objective minimised is
  sum_y sum(rel_entr(Q[y], Q_x1x2[y])), where every Q_x1x2[y] is tied to
  cp.sum(Q) / |Y|.
  NOTE(review): presumably this maximises the conditional entropy of y given
  (x1, x2) under Q -- confirm against the method's reference.

  Args:
    P: array of shape (|X1|, |X2|, |Y|) whose pairwise marginals provide the
      right-hand sides of the constraints.

  Returns:
    np.ndarray of the same shape as P holding the optimised Q.

  Raises:
    RuntimeError: if the solver leaves any slice of Q without a value.
  '''
  n_y = P.shape[2]
  slice_shape = (P.shape[0], P.shape[1])

  # Pairwise marginals of P that Q must match.
  Px2y = P.sum(axis=0)  # indexed [x2, y]
  Px1y = P.sum(axis=1)  # indexed [x1, y]

  Q = [cp.Variable(slice_shape, nonneg=True) for _ in range(n_y)]
  Q_x1x2 = [cp.Variable(slice_shape, nonneg=True) for _ in range(n_y)]

  # The total mass of Q over all (x1, x2, y) is 1.
  sum_to_one_Q = cp.sum([cp.sum(q) for q in Q]) == 1

  # [A]: p(x1, y) == q(x1, y) -- summing a slice over x2, one vector constraint per y.
  A_cstrs = [cp.sum(Q[y], axis=1) == Px1y[:, y] for y in range(n_y)]

  # [B]: p(x2, y) == q(x2, y) -- summing a slice over x1, one vector constraint per y.
  B_cstrs = [cp.sum(Q[y], axis=0) == Px2y[:, y] for y in range(n_y)]

  # Reference term of the relative entropy: every Q_x1x2[y] equals cp.sum(Q) / |Y|.
  Q_pdt_dist_cstrs = [cp.sum(Q) / n_y == Q_x1x2[y] for y in range(n_y)]

  # objective
  obj = cp.sum([cp.sum(cp.rel_entr(Q[y], Q_x1x2[y])) for y in range(n_y)])
  all_constrs = [sum_to_one_Q] + A_cstrs + B_cstrs + Q_pdt_dist_cstrs
  prob = cp.Problem(cp.Minimize(obj), all_constrs)
  # NOTE(review): `max_iters` is forwarded to whichever solver CVXPY selects by
  # default; confirm the installed default solver accepts this option name.
  prob.solve(verbose=False, max_iters=50000)

  # Fail with the solver status instead of an opaque np.stack error on None values.
  if any(q.value is None for q in Q):
    raise RuntimeError(
        'solve_Q_new: solver returned no solution (status: {})'.format(prob.status))

  return np.stack([q.value for q in Q], axis=2)

In [3]:
def gen_binary_data(num_data, seed=None):
  '''
  Draw two binary inputs and build several toy (x1, x2, y) datasets from them.

  Truth table of the 'and' task, as an example (x1 x2 -> y):
    00  0
    01  0
    10  0
    11  1

  Args:
    num_data: number of samples to draw.
    seed: optional seed. When given, sampling uses a private
      np.random.RandomState(seed) so the data is reproducible; when None
      (the default) the global np.random state is used.

  Returns:
    dict mapping a task name to a tuple (x1, x2, y). x1, x2 and y are integer
    arrays of shape (num_data, 1), except that the first element of the
    'redundant_*_unique1' tasks is the (num_data, 2) concatenation of both
    inputs.
  '''
  # The np.random module and RandomState expose the same randint signature.
  rng = np.random if seed is None else np.random.RandomState(seed)

  x1 = rng.randint(0, 2, (num_data, 1))
  x2 = rng.randint(0, 2, (num_data, 1))
  x1_and_x2 = 1 * np.logical_and(x1, x2)
  x1_or_x2 = 1 * np.logical_or(x1, x2)
  x1_xor_x2 = 1 * np.logical_xor(x1, x2)
  both_inputs = np.concatenate([x1, x2], axis=1)

  data = {
      'and': (x1, x2, x1_and_x2),
      'or': (x1, x2, x1_or_x2),
      'xor': (x1, x2, x1_xor_x2),
      'unique1': (x1, x2, x1),
      'redundant': (x1, x1, x1),
      'redundant_and_unique1': (both_inputs, x2, x1_and_x2),
      'redundant_or_unique1': (both_inputs, x2, x1_or_x2),
      'redundant_xor_unique1': (both_inputs, x2, x1_xor_x2),
  }
  return data

def convert_data_to_distribution(x1: np.ndarray, x2: np.ndarray, y: np.ndarray):
  '''
  Build the empirical joint distribution of (x1, x2, y) from paired samples.

  Each input is squeezed, its distinct values are mapped to category indices
  by extract_categorical_from_data, and the index triples are counted and
  normalised to sum to 1.

  Args:
    x1, x2, y: arrays holding the same number of elements; sample i of each
      array forms one observation.

  Returns:
    (joint_distribution, maps): joint_distribution has one axis per variable
    (x1, x2, y) with one entry per distinct value; maps is the tuple
    (x1_raw_to_discrete, x2_raw_to_discrete, y_raw_to_discrete) of
    raw-value -> index dicts.
  '''
  assert x1.size == x2.size
  assert x1.size == y.size

  # atleast_1d keeps a single sample iterable: squeeze() alone turns an array
  # with one element into a 0-d array, which cannot be iterated.
  x1_discrete, x1_raw_to_discrete = extract_categorical_from_data(np.atleast_1d(x1.squeeze()))
  x2_discrete, x2_raw_to_discrete = extract_categorical_from_data(np.atleast_1d(x2.squeeze()))
  y_discrete, y_raw_to_discrete = extract_categorical_from_data(np.atleast_1d(y.squeeze()))

  joint_distribution = np.zeros((len(x1_raw_to_discrete), len(x2_raw_to_discrete), len(y_raw_to_discrete)))

  # Explicit integer dtype so that empty inputs still yield valid index arrays.
  indices = tuple(
      np.asarray(codes, dtype=np.intp)
      for codes in (x1_discrete, x2_discrete, y_discrete))
  # Unbuffered add: every occurrence of a repeated index triple is counted.
  np.add.at(joint_distribution, indices, 1)
  joint_distribution /= np.sum(joint_distribution)

  return joint_distribution, (x1_raw_to_discrete, x2_raw_to_discrete, y_raw_to_discrete)

def extract_categorical_from_data(x):
  '''
  Map the distinct values of x to consecutive integer indices.

  Indices follow the sorted order of the distinct values, so the mapping does
  not depend on set iteration order. If the values cannot be compared with
  each other, the order of first appearance in x is used instead.

  Args:
    x: iterable of hashable values; it is iterated more than once.

  Returns:
    (discrete_data, raw_to_discrete): discrete_data is a list with the index
    of every element of x, raw_to_discrete is the value -> index dict.
  '''
  supp = set(x)
  try:
    ordered = sorted(supp)
  except TypeError:
    # Mixed, non-comparable values: fall back to first-appearance order.
    ordered = list(dict.fromkeys(x))

  raw_to_discrete = {value: index for index, value in enumerate(ordered)}
  discrete_data = [raw_to_discrete[x_] for x_ in x]

  return discrete_data, raw_to_discrete

def MI(P: np.ndarray):
  ''' P has 2 dimensions.

  Returns the summed relative entropy between P and the outer product of its
  two marginals, i.e. the mutual information between the row and column
  variables when P is a joint distribution.
  '''
  row_marginal = P.sum(axis=1)
  col_marginal = P.sum(axis=0)
  # Entry [i, j] of the reference is row_marginal[i] * col_marginal[j].
  independent = row_marginal.ravel()[:, np.newaxis] * col_marginal.ravel()[np.newaxis, :]

  return rel_entr(P, independent).sum()

def CoI(P:np.ndarray):
  ''' P has 3 dimensions, in order X1, X2, Y.

  Returns MI(X1; Y) + MI(X2; Y) - MI(Y; (X1, X2)), each term computed by MI
  on the corresponding 2-d table derived from P.
  '''
  # Joint table of (X1, Y): X2 summed out.
  joint_x1_y = P.sum(axis=1)

  # Joint table of (X2, Y): X1 summed out.
  joint_x2_y = P.sum(axis=0)

  # Y against the flattened pair (X1, X2): rows index y, columns index (x1, x2).
  n_pairs = P.shape[0] * P.shape[1]
  joint_y_pair = np.moveaxis(P, 2, 0).reshape(P.shape[2], n_pairs)

  return MI(joint_x1_y) + MI(joint_x2_y) - MI(joint_y_pair)

def CI(P, Q):
  ''' P and Q have 3 dimensions, in order X1, X2, Y, and equal shapes.

  Returns MI(Y; (X1, X2)) under P minus the same quantity under Q.
  '''
  assert P.shape == Q.shape

  def y_against_pair(joint):
    # Rows index y, columns index the flattened (x1, x2) pair.
    return joint.transpose([2, 0, 1]).reshape((joint.shape[2], joint.shape[0]*joint.shape[1]))

  flat_P = y_against_pair(P)
  flat_Q = y_against_pair(Q)
  return MI(flat_P) - MI(flat_Q)

def UI(P, cond_id=0):
  ''' P has 3 dimensions, in order X1, X2, Y
  We condition on X1 if cond_id = 0, if 1, then X2.

  For every value of the conditioning variable, the matching 2-d slice of P
  is renormalised to sum to 1, MI is computed on it, and the result is
  weighted by the marginal mass of that value. Values whose slice has zero
  mass are skipped (they contribute nothing), which avoids a 0/0 division
  turning the whole total into NaN.

  Returns:
    The weighted sum of the per-slice MI values.

  Raises:
    ValueError: if cond_id is neither 0 nor 1.
  '''
  if cond_id == 0:
    weights = P.sum(axis=(1,2)) # marginal of x1
    slices = [P[i,:,:] for i in range(P.shape[0])]
  elif cond_id == 1:
    weights = P.sum(axis=(0,2)) # marginal of x2
    slices = [P[:,i,:] for i in range(P.shape[1])]
  else:
    raise ValueError('cond_id must be 0 or 1, got {!r}'.format(cond_id))

  total = 0.
  for weight, conditional in zip(weights, slices):
    mass = conditional.sum()
    if mass == 0:
      # Nothing to normalise: this value of the conditioning variable never occurs.
      continue
    total += MI(conditional/mass) * weight

  return total

def test(P):
  '''
  Solve for Q from P, then compute and print the four measures.

  Each measure is printed right after it is computed, in the order
  redundancy (CoI of Q), unique1 (UI of Q with cond_id=1), unique2 (UI of Q
  with cond_id=0) and synergy (CI of P and Q).

  Returns:
    dict with keys 'redundancy', 'unique1', 'unique2' and 'synergy'.
  '''
  Q = solve_Q_new(P)

  # (printed label, result key, deferred computation), evaluated in order.
  steps = (
      ('Redundancy', 'redundancy', lambda: CoI(Q)),
      ('Unique', 'unique1', lambda: UI(Q, cond_id=1)),
      ('Unique', 'unique2', lambda: UI(Q, cond_id=0)),
      ('Synergy', 'synergy', lambda: CI(P, Q)),
  )

  measures = {}
  for label, key, compute in steps:
    measures[key] = compute()
    print(label, measures[key])
  return measures

In [15]:
# Hand-built distribution: mass 0.25 on each (x1, x2) pair, with y = 0 on the
# diagonal (x1 == x2) and y = 1 off the diagonal (x1 != x2), i.e. y = x1 XOR x2.
P = np.zeros((2, 2, 2))
P[..., 0] = 0.25 * np.eye(2)
P[..., 1] = np.array([[0., 0.25], [0.25, 0.]])
test(P)

Redundancy 7.254531412136951e-10
Synergy 0.6931471798344919
Unique -8.264737925823815e-26
Unique -8.264737925823815e-26


In [16]:
# Same measures on the empirical distribution of sampled 'xor' data.
data = gen_binary_data(num_data=100000)
P, maps = convert_data_to_distribution(*data['xor'])
test(P)

Redundancy 2.0675621478641933e-07
Synergy 0.6931443092606695
Unique 2.0027052065158908e-07
Unique 7.05164076805474e-11


In [17]:
# Measures on the empirical distribution of sampled 'and' data.
data = gen_binary_data(num_data=1000000)
P, maps = convert_data_to_distribution(*data['and'])
test(P)


Redundancy 0.21606748068594048
Synergy 0.34645409997043997
Unique 9.021767949435756e-06
Unique 1.0704557540107058e-08




In [18]:
# Random positive table of shape (5, 4, 3), normalised to sum to 1.
P = np.random.uniform(size=(5, 4, 3))
P = P / P.sum()
test(P)

Redundancy 0.005079711220936332
Synergy 0.1224282223359574
Unique 0.022288101617060463
Unique 0.019486346111584556


#### synthetic dataset measures

In [3]:
import pickle

In [12]:
# Measures for the 'redundancy' synthetic dataset (test split).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open('synthetic/DATA_redundancy_cluster.pickle', 'rb') as f:
    dataset = pickle.load(f)
# Order of the tuple is (x1, x2, y) as expected by convert_data_to_distribution.
data = tuple(dataset['test'][key] for key in ('0', '1', 'label'))
P, maps = convert_data_to_distribution(*data)
test(P)

Redundancy 0.16721953678967238
Unique 0.009443059215327298
Unique 0.00038723549898728936
Synergy 0.05407597605044967


In [13]:
# Measures for the 'uniqueness0' synthetic dataset (test split).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open('synthetic/DATA_uniqueness0_cluster.pickle', 'rb') as f:
    dataset = pickle.load(f)
# Order of the tuple is (x1, x2, y) as expected by convert_data_to_distribution.
data = tuple(dataset['test'][key] for key in ('0', '1', 'label'))
P, maps = convert_data_to_distribution(*data)
test(P)

Redundancy 0.0024068776832209815
Unique 0.17168997263936803
Unique -3.9762788372355886e-17
Synergy 0.053745916375025876


In [14]:
# Measures for the 'uniqueness1' synthetic dataset (test split).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open('synthetic/DATA_uniqueness1_cluster.pickle', 'rb') as f:
    dataset = pickle.load(f)
# Order of the tuple is (x1, x2, y) as expected by convert_data_to_distribution.
data = tuple(dataset['test'][key] for key in ('0', '1', 'label'))
P, maps = convert_data_to_distribution(*data)
test(P)

Redundancy 0.004109998201795223
Unique 4.739394431722425e-17
Unique 0.16024362610457085
Synergy 0.047320072834599824


In [15]:
# Measures for the 'synergy' synthetic dataset (test split).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open('synthetic/DATA_synergy_cluster.pickle', 'rb') as f:
    dataset = pickle.load(f)
# Order of the tuple is (x1, x2, y) as expected by convert_data_to_distribution.
data = tuple(dataset['test'][key] for key in ('0', '1', 'label'))
P, maps = convert_data_to_distribution(*data)
test(P)

Redundancy 0.0740888062450043
Unique 0.010071939432917378
Unique 0.0005577240092090323
Synergy 0.1381926821869899


In [4]:
# Compute the measures for every synthetic setting and store them in one file.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
results = {}
for setting in [
    'redundancy', 'uniqueness0', 'uniqueness1', 'synergy',
    'mix1', 'mix2', 'mix3', 'mix4', 'mix5', 'mix6',
]:
    with open('synthetic/DATA_{}_cluster.pickle'.format(setting), 'rb') as f:
        dataset = pickle.load(f)
    print(setting)
    # Order of the tuple is (x1, x2, y) as expected by convert_data_to_distribution.
    data = tuple(dataset['test'][key] for key in ('0', '1', 'label'))
    P, maps = convert_data_to_distribution(*data)
    result = test(P)
    results[setting] = result

# Persist the per-setting result dicts keyed by setting name.
with open('synthetic/experiments/datasets.pickle', 'wb') as f:
    pickle.dump(results, f)

redundancy
Redundancy 0.16721953678967238
Unique 0.009443059215327298
Unique 0.00038723549898728936
Synergy 0.05407597605044967
uniqueness0
Redundancy 0.0024068776832209815
Unique 0.17168997263936803
Unique -3.9762788372355886e-17
Synergy 0.053745916375025876
uniqueness1
Redundancy 0.004109998201795223
Unique 4.739394431722425e-17
Unique 0.16024362610457085
Synergy 0.047320072834599824
synergy
Redundancy 0.0740888062450043
Unique 0.010071939432917378
Unique 0.0005577240092090323
Synergy 0.1381926821869899
mix1
Redundancy 0.05076955525118054
Unique 0.005163602945388254
Unique 0.0012951574877414992
Synergy 0.07132778066395776
mix2
Redundancy 0.05399747675209331
Unique 0.05768927022274116
Unique 7.126959113402868e-16
Synergy 0.07139916901767802
mix3
Redundancy 0.07978805424537334
Unique 0.008983824324279941
Unique 0.0013647931713488403
Synergy 0.11584163577704519
mix4
Redundancy 0.09086883240778226
Unique 0.0027345520262268383
Unique 0.0056027411357744425
Synergy 0.07627241778313054
mix5
