Please add URLs for the libraries you are talking about:

-   
-

In [None]:
import sklearn
import operator
import numpy as np
import matplotlib.pyplot as plt
import pyspark
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.feature import Normalizer
from pyspark.ml.linalg import DenseVector

def star(f):
  """Adapt `f` so that it takes one iterable and receives its items as positional arguments."""
  def call_with_unpacked(args):
    return f(*args)
  return call_with_unpacked

In [None]:
## Constants
N = 20  # number of training points
M = 10  # number of test points
D = 2   # dimensionality of every point
T = 50  # number of rays
one_vs_all = False  # when True the training set is reused as the test set

# Reusing the training set as the test set only makes sense if both sizes agree.
assert N == M or not one_vs_all

In [None]:
def _indexed_normal_sample(size):
  """Return a pair RDD of (index, vector) holding `size` D-dimensional normal samples."""
  # NOTE(review): RDD.zip assumes both RDDs have the same number of partitions
  # and the same number of elements per partition -- confirm this holds for
  # sc.range / normalVectorRDD on the target cluster.
  return sc.range(size).zip(RandomRDDs.normalVectorRDD(sc, size, D))

# In one-vs-all mode the test set is the very same RDD as the training set.
train_data = _indexed_normal_sample(N)
test_data = train_data if one_vs_all else _indexed_normal_sample(M)


In [None]:
def get_uni_sphere():
  """Return a pair RDD of (ray index, direction) with T normalised D-dimensional directions.

  The directions are i.i.d. normal vectors rescaled by `Normalizer` with its
  default settings, then paired with the indices 0..T-1.
  """
  gaussian_vectors = RandomRDDs.normalVectorRDD(sc, T, D)
  directions = Normalizer().transform(gaussian_vectors)
  ray_indices = sc.range(T)
  return ray_indices.zip(directions)
  
# Shared ray directions, read as a global by compute_pu.
rays = get_uni_sphere()

In [None]:
def compute_dst_sq(): # (N, M)
  """Return a pair RDD of squared distances between every train/test pair.

  Each record is ((n, m), |x_n - x'_m|^2), built from the cartesian product of
  the global `train_data` and `test_data` pair RDDs.
  """
  def squared_distance(record):
    # Explicit unpacking replaces the tuple-parameter lambda, which is a
    # SyntaxError on Python 3 (PEP 3113); this form runs on Python 2 and 3.
    (n, train_vec), (m, test_vec) = record
    return (n, m), np.sum((train_vec - test_vec) ** 2)

  return train_data.cartesian(test_data).map(squared_distance)

def compute_pu(data): # (data.N, T)
  """Return a pair RDD of projections of every data point onto every ray.

  Args:
    data: pair RDD of (index, vector).

  Returns:
    Pair RDD whose records are ((n, t), <data_n, ray_t>), using the global `rays`.
  """
  def project(record):
    # Explicit unpacking keeps this valid on Python 3 (PEP 3113).
    (n, data_vec), (t, ray_vec) = record
    return (n, t), np.dot(data_vec, ray_vec)

  # The original built this RDD but never returned it, so callers received None.
  return data.cartesian(rays).map(project)

# Squared distances and ray projections, read as globals by compute_ray_lengths.
dst = compute_dst_sq()
pu_train = compute_pu(train_data)
pu_test = compute_pu(test_data)

In [None]:
# Sanity check: for every train index, sum the test indices it is paired with.
# dst.keys() yields the (n, m) key pairs directly; it replaces a tuple-parameter
# lambda that is a SyntaxError on Python 3 (PEP 3113).
dst.keys().aggregateByKey(0, operator.add, operator.add).collect()

  

>     Out[24]: [(0, 45),
>      (1, 45),
>      (2, 45),
>      (3, 45),
>      (4, 45),
>      (5, 45),
>      (6, 45),
>      (7, 45),
>      (8, 45),
>      (9, 45),
>      (10, 45),
>      (11, 45),
>      (12, 45),
>      (13, 45),
>      (14, 45),
>      (15, 45),
>      (16, 45),
>      (17, 45),
>      (18, 45),
>      (19, 45)]

In [None]:

# Scratch check: the builtin min of two infinities is inf, so min can be used
# as the fold function with an inf zero value (see compute_ray_lengths).
min(np.inf, np.inf)

  

>     Out[28]: inf

In [None]:
def compute_ray_lengths(): # (M, T)
  """Return a pair RDD of ray lengths keyed by (test index m, ray index t).

  For train point n, test point m and ray t the candidate length is

    dst[n, m] / (2 * (pu_train[n, t] - pu_test[m, t]))

  i.e. how far one travels from x'_m along unit direction ray_t before reaching
  the hyperplane equidistant from x'_m and x_n. The value for (m, t) is the
  minimum candidate over all train points n; inf means no candidate qualified.

  The original keyed the final records by (n, m) and therefore minimised over
  rays. That contradicted the (M, T) layout stated in the header and the
  lengths[m, t] indexing of the integral code, so the key is now (m, t).

  Reads the globals `dst`, `pu_train`, `pu_test`, `sc`, `T` and `one_vs_all`.
  """
  def compute_length(n, m, dst_val, pu_train_val, pu_test_val):
    if one_vs_all and n == m:
      # a point must not bound its own ray
      return np.inf
    gap = pu_train_val - pu_test_val
    if gap <= 0:
      # The ray points away from (gap < 0) or runs parallel to (gap == 0) the
      # bisecting hyperplane and never reaches it. Checking the gap first also
      # avoids the division by zero the original performed when gap == 0.
      return np.inf
    return dst_val / (2 * gap)

  # Explicit unpacking replaces the tuple-parameter lambdas, which are a
  # SyntaxError on Python 3 (PEP 3113).
  def key_by_train_and_ray(record):
    ((n, m), dst_val), t = record
    return (n, t), (m, dst_val)

  def key_by_test_and_ray(record):
    (n, t), ((m, dst_val), pu_train_val) = record
    return (m, t), (n, dst_val, pu_train_val)

  def candidate_length(record):
    (m, t), ((n, dst_val, pu_train_val), pu_test_val) = record
    return (m, t), compute_length(n, m, dst_val, pu_train_val, pu_test_val)

  lengths = dst.cartesian(sc.range(T)) \
    .map(key_by_train_and_ray) \
    .join(pu_train) \
    .map(key_by_test_and_ray) \
    .join(pu_test) \
    .map(candidate_length) \
    .aggregateByKey(np.inf, min, min)
  return lengths

# Ray lengths, read as a global by the integral computations further down.
lengths = compute_ray_lengths()

In [None]:
# Pull every ray-length record to the driver for inspection.
lengths.collect()

  

>     Out[31]: [((7, 2), 1.2629881606637856),
>      ((2, 5), 0.46159384717891644),
>      ((6, 5), 0.95842580740469974),
>      ((18, 6), 0.32080733712963638),
>      ((13, 8), 0.46353226691190746),
>      ((1, 0), 1.0322240727976018),
>      ((5, 8), 0.6221419411239566),
>      ((17, 3), 0.61821329225771737),
>      ((0, 3), 0.219725553792082),
>      ((16, 0), 0.69839369568667831),
>      ((10, 1), 0.80372052032936458),
>      ((1, 9), 0.6906042774481157),
>      ((14, 9), 0.30787078063744644),
>      ((5, 1), 0.56701491283297545),
>      ((9, 4), 2.6914369106251379),
>      ((4, 2), 0.83261777184691588),
>      ((8, 7), 0.92089719214043397),
>      ((15, 7), 0.6122929027277727),
>      ((3, 7), 1.6523862157448843),
>      ((19, 2), 1.4048399465198804),
>      ((14, 0), 0.91906793254495744),
>      ((2, 3), 0.35773447203831094),
>      ((18, 4), 1.2246623104426142),
>      ((13, 6), 0.58199265332822736),
>      ((1, 6), 0.95367841097105377),
>      ((17, 1), 0.44129466037979431),
>      ((12, 5), 1.2830897155516801),
>      ((0, 5), 0.44393953576163336),
>      ((7, 5), 0.53564880418765026),
>      ((16, 2), 0.95627453086415348),
>      ((11, 0), 1.2333468007146429),
>      ((15, 8), 0.88964327449567226),
>      ((6, 2), 1.6117287288279358),
>      ((16, 9), 0.095526457287020194),
>      ((11, 9), 0.70781973297086676),
>      ((8, 9), 0.72167944524395145),
>      ((15, 1), 0.29895184233403005),
>      ((3, 9), 1.490492902934093),
>      ((19, 4), 0.97324542536975511),
>      ((10, 6), 1.0529978961119231),
>      ((14, 6), 0.55529821022659243),
>      ((9, 3), 1.1085463528245132),
>      ((8, 0), 0.37028634746485956),
>      ((2, 1), 0.46205517947440172),
>      ((18, 2), 0.56645372306614805),
>      ((13, 4), 0.71681320831306405),
>      ((1, 4), 1.3800640688796841),
>      ((17, 7), 0.77177090103880552),
>      ((12, 7), 1.6212433035840421),
>      ((0, 7), 0.37103105931475194),
>      ((7, 7), 0.89609578741092),
>      ((16, 4), 1.3357323528593228),
>      ((11, 2), 0.78999404722987587),
>      ((2, 8), 1.6356977175805627),
>      ((6, 0), 1.0505239373155486),
>      ((17, 8), 0.80166528767792589),
>      ((9, 8), 2.3721365487288408),
>      ((15, 3), 0.43280006527594322),
>      ((19, 6), 0.58173434860086592),
>      ((10, 4), 2.1567903888321225),
>      ((14, 4), 1.5503383282996692),
>      ((5, 6), 0.25573979867893204),
>      ((9, 1), 1.1869170917911065),
>      ((4, 5), 0.15307280071624169),
>      ((8, 2), 0.83050957804747538),
>      ((3, 0), 0.93205737195976313),
>      ((18, 0), 0.57192356623770746),
>      ((13, 2), 1.0710728785481662),
>      ((17, 5), 0.79637115272487891),
>      ((12, 1), 1.473976129854998),
>      ((7, 1), 0.98318742314990359),
>      ((16, 6), 0.34684621253802156),
>      ((11, 4), 1.9138025643888774),
>      ((2, 6), 0.86028252896257962),
>      ((6, 6), 0.91516449710675229),
>      ((1, 3), 1.2197435489134678),
>      ((12, 8), 0.99771143076806668),
>      ((0, 0), 0.76971944922538238),
>      ((4, 8), 1.2419355901375471),
>      ((6, 9), 0.62875120124918571),
>      ((19, 8), 0.71700189156658778),
>      ((10, 2), 0.88523768813616943),
>      ((5, 4), 0.93777136684828089),
>      ((9, 7), 0.8873564854703283),
>      ((4, 7), 0.51551090927227683),
>      ((8, 4), 0.94692204277052683),
>      ((3, 2), 1.4166432409716621),
>      ((19, 1), 1.0411599133967446),
>      ((13, 0), 0.30704640355383234),
>      ((12, 3), 1.409316268266543),
>      ((7, 3), 0.78097032550863132),
>      ((11, 6), 0.81012216708222196),
>      ((2, 4), 1.9504574658050049),
>      ((6, 4), 1.4479420438996458),
>      ((18, 9), 0.60322830048684573),
>      ((13, 9), 0.85289983716955908),
>      ((10, 9), 0.95742735247451938),
>      ((1, 1), 1.3306212787883265),
>      ((5, 9), 0.54699081589252607),
>      ((0, 2), 0.43170301576064163),
>      ((7, 8), 1.5159764776409621),
>      ((10, 0), 1.4789712913830841),
>      ((14, 8), 1.2479481593136914),
>      ((5, 2), 0.8677224929221472),
>      ((9, 5), 1.1666264864242017),
>      ((0, 9), 0.58104164248643286),
>      ((4, 1), 0.52398631192961764),
>      ((8, 6), 0.41482314114008911),
>      ((15, 4), 1.2070687579725508),
>      ((3, 4), 0.92320092590328184),
>      ((19, 3), 0.99290161131331789),
>      ((14, 3), 0.52311500866533189),
>      ((2, 2), 0.43548602918493379),
>      ((18, 7), 0.61533909915557894),
>      ((13, 7), 1.1663107507667714),
>      ((1, 7), 1.4211421573691456),
>      ((17, 2), 0.68243775416795349),
>      ((12, 4), 1.1577835155114173),
>      ((0, 4), 1.4517699002604638),
>      ((16, 1), 0.60623433555525552),
>      ((11, 1), 0.62488896944257266),
>      ((15, 9), 0.51293535774093213),
>      ((1, 8), 1.1681587505698761),
>      ((5, 0), 0.27163785388663164),
>      ((4, 3), 0.34258200289763463),
>      ((16, 8), 1.0257869694126154),
>      ((8, 8), 0.6628136357908152),
>      ((15, 6), 0.23404413231021351),
>      ((3, 6), 1.1999606961707669),
>      ((19, 5), 0.92865090152918539),
>      ((14, 1), 0.69423573030208374),
>      ((2, 0), 1.2695570487326961),
>      ((18, 5), 0.65041335994421112),
>      ((13, 5), 1.1345429570993462),
>      ((1, 5), 1.0634242407422114),
>      ((17, 0), 0.48745179448738402),
>      ((12, 6), 1.0363281833528688),
>      ((0, 6), 0.40985309990503543),
>      ((7, 4), 1.8317418412096649),
>      ((16, 3), 0.49805933788983492),
>      ((11, 3), 0.40707072459369453),
>      ((6, 3), 1.1359668269478991),
>      ((17, 9), 0.67834037590152141),
>      ((9, 9), 1.6352097052386736),
>      ((11, 8), 1.5938733715898281),
>      ((15, 0), 0.53195044382126766),
>      ((3, 8), 0.87931929122143004),
>      ((19, 7), 1.2165659623867986),
>      ((10, 7), 0.50909899990080987),
>      ((14, 7), 0.69384219551105308),
>      ((5, 7), 0.87726908515642577),
>      ((9, 2), 0.97173713088887748),
>      ((4, 4), 1.5620074841367915),
>      ((8, 1), 0.59290980751784039),
>      ((3, 1), 1.3220991596382465),
>      ((18, 3), 0.45922635090637298),
>      ((13, 3), 0.99663184896681201),
>      ((17, 6), 0.36870821465687342),
>      ((12, 0), 0.97372959857745889),
>      ((7, 6), 0.87478979105029642),
>      ((16, 5), 0.39554682008869746),
>      ((11, 5), 0.24223044742057079),
>      ((2, 9), 0.91462884114403686),
>      ((6, 1), 1.266488450342635),
>      ((15, 2), 0.6188218455525013),
>      ((6, 8), 1.2125012977118395),
>      ((19, 9), 0.44077308824584893),
>      ((10, 5), 0.48684836097464856),
>      ((14, 5), 0.31910276829560952),
>      ((5, 5), 0.80930267812653067),
>      ((9, 0), 2.0080291991953323),
>      ((4, 6), 0.4913627407115953),
>      ((8, 3), 0.76912017088188789),
>      ((3, 3), 1.5156693420028249),
>      ((18, 1), 0.2875018838011616),
>      ((13, 1), 0.83962659409281137),
>      ((17, 4), 1.0977395853744727),
>      ((12, 2), 1.8377815902953802),
>      ((7, 0), 1.2158276369479522),
>      ((16, 7), 0.70834934624726187),
>      ((11, 7), 0.39872191545506208),
>      ((2, 7), 0.13802398789471446),
>      ((6, 7), 1.3253396091863827),
>      ((18, 8), 0.9190160048918159),
>      ((10, 8), 1.8359993905873138),
>      ((1, 2), 1.6871098563143547),
>      ((12, 9), 0.84729867310861384),
>      ((0, 1), 0.068016869080292031),
>      ((7, 9), 0.5985883086826872),
>      ((4, 9), 0.34203425952994843),
>      ((10, 3), 0.60945946588635047),
>      ((5, 3), 0.68216703107810439),
>      ((9, 6), 1.6123958327310088),
>      ((0, 8), 1.1315481029905607),
>      ((4, 0), 0.90116073516300088),
>      ((8, 5), 0.91351621419499029),
>      ((15, 5), 0.59232304443170058),
>      ((3, 5), 1.7031419713107707),
>      ((19, 0), 0.56443477104689566),
>      ((14, 2), 1.0125379240330314)]

In [None]:
def _old_compute_integrals(): # (M)
  """Legacy per-test-point integral over ray lengths; returns an array of length M.

  Not called anywhere in the visible part of this notebook.

  NOTE(review): indexes the global `lengths` as a 2-D array (lengths[m, t]),
  whereas compute_ray_lengths returns a pair RDD -- confirm the input is
  densified before reviving this function.
  NOTE(review): the inner loop reads only the first cur_T entries of a row and
  accumulates differences of consecutive entries, so it presumably expects each
  row sorted ascending with the inf entries last -- TODO confirm.
  """
  integrals = np.zeros((M,))
  for m in range(M):
    current = 0.
    cur_coeff = 0.
    # number of finite ray lengths in row m
    cur_T = np.sum(lengths[m, :] < np.inf)
    prev_len = 0.
    
    for t in range(cur_T):
      cur_len = lengths[m, t]
      # length increment weighted by the number of entries from position t onwards
      current += (cur_len - prev_len) * (cur_T - t)
      if t > 0:  # can omit with large D
        # correction built from the running sum of cur_len ** D over earlier entries
        current += cur_coeff * (-D + 1) * (np.power(cur_len, -D + 1) - np.power(prev_len, -D + 1))
      cur_coeff += np.power(cur_len, D)
      prev_len = cur_len
    
    integrals[m] = current
  return integrals

def compute_integrals(): # (M)
  """Return the mean finite ray length of every row of the global `lengths`.

  `lengths` may be a 2-D ndarray, or a collection of ((row, col), value)
  records: either an object with a `collect()` method, such as the pair RDD
  built by compute_ray_lengths, or a plain iterable. Records are first placed
  into a dense array whose missing cells are inf.

  Non-finite entries are ignored. A row without any finite entry yields NaN.
  Unlike the original, the global `lengths` is never modified; the original
  overwrote its inf entries with 0 in place.
  """
  table = lengths
  if not isinstance(table, np.ndarray):
    records = table.collect() if hasattr(table, "collect") else list(table)
    if not records:
      return np.zeros((0,))
    n_rows = 1 + max(row for (row, _), _ in records)
    n_cols = 1 + max(col for (_, col), _ in records)
    table = np.full((n_rows, n_cols), np.inf)
    for (row, col), value in records:
      table[row, col] = value

  is_finite = np.isfinite(table)
  finite = np.sum(is_finite, axis=1)
  totals = np.sum(np.where(is_finite, table, 0.), axis=1)
  # Divide only where a finite entry exists; the remaining rows keep NaN
  # without triggering a division warning.
  integrals = np.full(finite.shape, np.nan)
  np.divide(totals, finite, out=integrals, where=finite > 0)
  return integrals

# Per-row integrals; inverted into plot weights in a later cell.
integrals = compute_integrals()

In [None]:
# Notebook cell output: show the computed integrals.
integrals

In [None]:
# Invert the integrals, then rescale so that the largest weight equals 50
# (the weights are later passed to plt.scatter as its third argument).
weights = 1. / integrals
weights *= 50. / np.max(weights)

In [None]:
# test_data is a pair RDD of (index, vector) and cannot be sliced like an
# ndarray. Collect the vectors to the driver ordered by index so that row m is
# test point m (assumes weights is ordered by test index -- TODO confirm).
test_points = np.array(test_data.sortByKey().values().collect())
fig = plt.figure()
plt.scatter(test_points[:, 0], test_points[:, 1], weights)
display(fig)

In [None]:
# Import from the public package path: `sklearn.neighbors.kde` is a private
# module that newer scikit-learn releases no longer provide.
from sklearn.neighbors import KernelDensity

# train_data / test_data are pair RDDs of (index, vector); scikit-learn needs
# plain arrays, so collect the vectors to the driver. The test points are
# ordered by index so that kde_weights[m] belongs to test point m.
kde_train_points = np.array(train_data.values().collect())
kde_test_points = np.array(test_data.sortByKey().values().collect())
kde = KernelDensity(kernel='tophat', bandwidth=0.13).fit(kde_train_points)
kde_weights = kde.score_samples(kde_test_points)  # log-density per test point
kde_weights = np.exp(kde_weights)
# rescale so that the largest weight equals 50
kde_weights = (50. / np.max(kde_weights)) * kde_weights

In [None]:
# test_data is a pair RDD of (index, vector) and cannot be sliced like an
# ndarray. Collect the vectors to the driver ordered by index so that row m is
# test point m (assumes kde_weights is ordered by test index -- TODO confirm).
kde_plot_points = np.array(test_data.sortByKey().values().collect())
fig = plt.figure()
plt.scatter(kde_plot_points[:, 0], kde_plot_points[:, 1], kde_weights)
display(fig)