In [None]:
import numpy as np
import scipy
import math
from scipy.special import rel_entr
from scipy.stats import entropy
import copy
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
from bounds import *

### Compute bounds for randomly generated 2×2×2 joint distributions P(x1, x2, y)

In [None]:
# Generate synthetic data: sample `trials` random joint distributions P[x1, x2, y]
# over three binary variables and record get_bounds(P) for each one.
# Normalising i.i.d. Exp(1) weights draws P uniformly from the probability simplex.
trials = 100
s_max, diff_max, P_max = 0, 0, None
allPs = []
for _ in tqdm(range(trials)):
    weights = np.random.exponential(scale=1.0, size=8)
    P = (weights / sum(weights)).reshape((2, 2, 2))
    all_quantities, all_bounds = get_bounds(P)
    allPs.append((P, all_quantities, all_bounds))

# NOTE(review): the file name says "1000" while trials = 100 -- confirm which is intended.
out_path = '/content/drive/My Drive/neurips_bounds_data/all_Ps_1000_new_upper.pkl'
with open(out_path, 'wb') as out_file:
    pickle.dump(allPs, out_file)

In [None]:
# Load a previously saved batch of (P, quantities, bounds) triples.
# NOTE: pickle.load can execute arbitrary code -- only load files you created yourself.
with open('/content/drive/My Drive/neurips_bounds_data/all_Ps_10000.pkl', 'rb') as fp:
    savedPs = pickle.load(fp)
print (len(savedPs))

# all_R = [all_quantities['R'] - all_quantities['I_x1x2'] + all_quantities['I_x1x2_given_y'] for (P, all_quantities, all_bounds) in savedPs]

# Visualise the lower_diff-based quantity against S for every sample where it is positive.
all_S = []
all_lower_diff = []
for (P, all_quantities, all_bounds) in savedPs:
    u_max = max(all_quantities['U1'], all_quantities['U2'])
    # all_bounds['lower_diff'] is stored already offset by max(U1, U2) (cf. the
    # 'diff' / 'diff-U' printout in the single-P cell): undo the offset, scale by
    # 1/4, then subtract max(U1, U2) again.
    lower_diff = (all_bounds['lower_diff'] + u_max) / 4 - u_max
    if lower_diff > 0:
        all_S.append(all_quantities['S'])
        all_lower_diff.append(lower_diff)

fig, ax = plt.subplots()
# Single fixed colour: no `c` data is passed, so a cmap would be ignored (and warned about).
ax.scatter(all_lower_diff, all_S, color='blue', zorder=10)
ax.set_xlabel('(lower_diff + max(U1, U2)) / 4 - max(U1, U2)')
ax.set_ylabel('S')
ax.set_title('S vs. lower_diff-based quantity (positive values only)')

plt.show()

In [None]:
# Random search over 2x2x2 joint distributions P[x1, x2, y]: store get_rus_diff's
# outputs for every draw, print draws with large s and large diff - max(U1, U2),
# and track the draw with the largest diff - max(U1, U2).
allPs = []
trials = 100000
s_max = 0
diff_max = 0
P_max = None
n_failed = 0
last_error = None
for _ in tqdm(range(trials)):
    k = np.random.exponential(scale=1.0, size=8)
    P = (k / sum(k)).reshape((2,2,2))
    try:
        r, u1, u2, s, diff, Py_givenx1, Py_givenx2, Px1x2, Px1, Px2, I_x1x2 = get_rus_diff(P)
    except Exception as err:
        # Some draws may make get_rus_diff raise; skip them but keep count.
        # Catching Exception (not a bare except) lets KeyboardInterrupt through,
        # so this long loop can still be interrupted from the notebook.
        n_failed += 1
        last_error = err
        continue
    allPs.append((P, r, u1, u2, s, diff, Py_givenx1, Py_givenx2, Px1x2, Px1, Px2, I_x1x2))
    # allPs keeps the raw diff; below we work with the margin over max(U1, U2).
    diff = diff-max(u1,u2)
    if s >= 0.3 and diff >= 0.5:
        print ('s=', s, 'diff-U=', diff, 'P=', P)
    if diff >= diff_max:
        s_max = s
        diff_max = diff
        P_max = P
print ('s_max=', s_max, 'diff_max=', diff_max, 'P_max=', P_max)
if n_failed:
    print ('skipped', n_failed, 'of', trials, 'trials; last error:', repr(last_error))

In [None]:
print (len(allPs))
# Each stored entry starts with (P, r, u1, u2, s, diff, ...); report the entries
# with s >= 0.2 whose margin diff/4 - max(u1, u2) is at least 0.1.
for P, r, u1, u2, s, diff in (entry[:6] for entry in allPs):
    margin = diff / 4.0 - max(u1, u2)
    if s >= 0.2 and margin >= 0.1:
        print ('s=', s, 'diff-U=', margin, 'P=', P)

In [None]:
# Inspect every quantity and bound for one hand-picked distribution P[x1, x2, y].
P = np.array([[[0.23148881,0.00941881],[0.07064852,0.00650789]],[[0.01758162,0.42417211],[0.22418906,0.01599317]]])
# Other candidates inspected earlier (uncomment one to override P):
# P = np.array([[[0.00941881,0.23148881],[0.07064852,0.00650789]],[[0.42417211,0.01758162],[0.01599317,0.22418906]]])
# P = np.array([[[0.0,0.25],[0.15,0.1]],[[0.25,0.0],[0.1,0.15]]])
# P = np.array([[[0.03349886,0.02460521], [0.23188986,0.0167147 ]], [[0.04237604,0.64494174], [0.00478614,0.00118744]]])
# P = np.array([[[0.11564164,0.01350078], [0.03169016,0.00144381]], [[0.05276016,0.66327317], [0.10926603,0.01242425]]])
all_quantities, all_bounds = get_bounds(P)

# Pull out the entries of get_bounds' result under the names used below.
r, u1, u2, s = (all_quantities[name] for name in ('R', 'U1', 'U2', 'S'))
diff = all_bounds['lower_diff']
Py_givenx1, Py_givenx2 = all_quantities['Py_given_x1'], all_quantities['Py_given_x2']
Px1x2, Px1, Px2 = (all_quantities[name] for name in ('Px1x2', 'Px1', 'Px2'))
I_x1x2 = all_quantities['I_x1x2']

print ('r=', r, 'u1=', u1, 'u2=', u2, 's=', s, 'total=', r+u1+u2+s)
print ('y|x1=', Py_givenx1, 'y|x2=', Py_givenx2, 'diff=', diff+max(u1,u2), 'diff-U=', diff)
print ('Px1x2=', Px1x2, 'Px1=', Px1, 'Px2=', Px2, 'I_x1x2=', I_x1x2)

### generate binary data

In [None]:
def group(L, support):
    """Pack a sequence of digit arrays into a single base-`support` code.

    Computes, element-wise, sum_i L[i] * support**i, so L[0] is the least
    significant digit.

    Args:
        L: non-empty sequence of arrays, each broadcastable to L[0].shape.
        support: base of the encoding (e.g. 2 for bits).

    Returns:
        A float ndarray with shape L[0].shape.
    """
    code = np.zeros(L[0].shape)
    for power, digit in enumerate(L):
        code += digit * support ** power
    return code

def gen_binary_data(num_data, seed=None):
    """Generate a synthetic binary dataset whose label is a parity (XOR) of input bits.

    x1 and x2 each consist of three fair bits, packed into one integer code with
    `group` (bit i has weight 2**i):

    - x1 bits 0, 1, 2 are independent fair bits.
    - x2 bits 0 and 1 are independent fair bits; x2 bit 2 is a copy of x1 bit 2.
    - y = parity of all six bits. Because x2 bit 2 equals x1 bit 2 at that point,
      those two terms cancel mod 2, so y = x1[0] ^ x2[0] ^ x1[1] ^ x2[1].
    - After y is computed, x2 bit 2 is kept with probability 0.5 (else zeroed)
      and x2 bit 0 is kept with probability 0.3 (else zeroed).

    Args:
        num_data: number of samples to draw.
        seed: optional integer seed. With None (the default) numpy's global
            random state is used; otherwise a private RandomState(seed) is used
            so the dataset is reproducible.

    Returns:
        dict {'s': (x1, x2, y_s)}, each an array of shape (num_data, 1) as
        produced by `group`: x1 and x2 codes in [0, 7], y_s in {0, 1}.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    x1dim = 3

    x1 = [rng.randint(0, 2, (num_data, 1)) for _ in range(x1dim)]
    # x2 starts as a copy of x1; bits 0 and 1 are then redrawn independently,
    # so only bit 2 is shared between the two inputs.
    x2 = [bit.copy() for bit in x1]
    x2[0] = rng.randint(0, 2, (num_data, 1))
    x2[1] = rng.randint(0, 2, (num_data, 1))

    # s: synergistic target (parity of all six bits)
    ydim1 = (x1[0] + x2[0] + x1[1] + x2[1] + x1[2] + x2[2]) % 2
    # Alternative targets explored previously:
    # ydim2 = (x1[0] + x2[0] + x1[1]) % 2
    # ydim3 = (x1[0] + x2[2]) % 2
    # ydim4 = (x1[1] + x2[1] + x1[2] + x2[2]) % 2
    # ydim5 = (x2[0] + x1[1] + x2[1] + x1[2]) % 2
    # rand_prob = 0.7
    # ydim3 = (np.random.rand(num_data,1) < rand_prob) * ydim3
    # rand_prob = 0.5
    # ydim5 = (np.random.rand(num_data,1) < rand_prob) * ydim5
    y_s = group([ydim1], support=2)

    # Corrupt x2 only AFTER y has been computed: the label still depends on the
    # clean bits, but the observed x2 loses some of them.
    # rand_prob = 0.7
    # x1[1] = (np.random.rand(num_data,1) < rand_prob) * x1[1]
    rand_prob = 0.5
    x2[2] = (rng.rand(num_data, 1) < rand_prob) * x2[2]

    rand_prob = 0.3
    x2[0] = (rng.rand(num_data, 1) < rand_prob) * x2[0]

    x1 = group(x1, support=2)
    x2 = group(x2, support=2)
    data = {
        's': (x1, x2, y_s),
    }
    return data

In [None]:
# Sample the binary dataset, check the label balance, then estimate r, u1, u2, s
# on the empirical distribution.
data = gen_binary_data(10000)
y = data['s'][2]
labels, label_counts = np.unique(y, return_counts=True)
print (labels, label_counts)

P, maps = convert_data_to_distribution(*data['s'])
r, u1, u2, s = test(P)
print (r, u1, u2, s, r+u1+u2+s)

In [None]:
# Sanity checks on P[x1, x2, y]: H(Y) and H(Y|X1,X2) in bits, plus
# E_{x1,x2} || P(y|x1) - P(y|x2) ||^2.


def _safe_divide(num, den):
    """Element-wise num / den with broadcasting; 0 wherever den == 0."""
    out = np.zeros(np.broadcast(num, den).shape)
    return np.divide(num, den, out=out, where=den > 0)


Py = np.sum(np.sum(P, axis=0), axis=0)
y_ent = scipy.special.entr(Py).sum(axis=0) / np.log(2)
print (y_ent)

Px1x2 = np.sum(P, axis=2, keepdims=True)
# Some (x1, x2) cells can have zero mass (e.g. with gen_binary_data, x2's bit 2 can
# be 1 only when x1's bit 2 is 1). A plain P / Px1x2 gives 0/0 = NaN there, which
# turns H(Y|X1,X2) into NaN. Those cells carry zero weight, so use 0 instead.
Py_givenx1x2 = _safe_divide(P, Px1x2)
print (Py_givenx1x2.shape)
Py_givenx1x2_ent = scipy.special.entr(Py_givenx1x2) / np.log(2)
# Weight each cell's conditional entropy terms by P(x1, x2); keep the per-y breakdown.
Py_givenx1x2_ent = np.einsum('ijk,ij->k', Py_givenx1x2_ent, Px1x2.squeeze(axis=-1))
print (Py_givenx1x2_ent.sum(), Py_givenx1x2_ent.shape)

Px1y = np.sum(P, axis=1, keepdims=True)
Px2y = np.sum(P, axis=0, keepdims=True)
Px1 = np.sum(Px1y, axis=2, keepdims=True)
Px2 = np.sum(Px2y, axis=2, keepdims=True)
Py_givenx1 = _safe_divide(Px1y, Px1)
Py_givenx2 = _safe_divide(Px2y, Px2)

print (Py_givenx1.shape, Py_givenx2.shape)
# (n1, 1, ny) - (1, n2, ny) broadcasts to (n1, n2, ny).
diff = (Py_givenx1 - Py_givenx2)**2
print (diff.shape)
diff = np.einsum('ij,ij', Px1x2.squeeze(axis=-1), diff.sum(axis=-1))
print (diff)