In [77]:
from typing import Dict, Tuple, List
import numpy as np
from scipy.spatial.distance import cosine
from itertools import combinations
from functools import  reduce

In [181]:
@staticmethod
def similarity(embeddings: Dict[int, np.ndarray]) -> Dict[Tuple[int, int], float]:
    """Calculate pairwise similarities between each item
    in embedding.

    Args:
        embeddings (Dict[int, np.ndarray]): Items embeddings.

    Returns:
        Dict[Tuple[int, int], float]:
        Keys are in form of (i, j) - combinations pairs of item_ids
        with i < j.
        Round each value to 8 decimal places.
    """
    combs = combinations(embeddings, 2)
    pair_sims = {comb: round(cosine(embeddings[comb[0]], embeddings[comb[1]]), 8) for comb in combs}

    return pair_sims

@staticmethod
def knn(
    sim: Dict[Tuple[int, int], float], top: int
) -> Dict[int, List[Tuple[int, float]]]:
    """Return closest neighbors for each item.

    Args:
        sim (Dict[Tuple[int, int], float]): <similarity> method output.
        top (int): Number of top neighbors to consider.

    Returns:
        Dict[int, List[Tuple[int, float]]]: Dict with top closest neighbors
        for each item.
    """
    knn_dict = {item_id: [] for pair in sim for item_id in pair}

    for (item1, item2), similarity_score in sim.items():
        knn_dict[item1].append((item2, similarity_score))
        knn_dict[item2].append((item1, similarity_score))

    for item_id, neighbors in knn_dict.items():
        sorted(neighbors, key=lambda x: x[1], reverse=True)
        knn_dict[item_id] = neighbors[:top]

    return knn_dict

@staticmethod
def knn_price(
    knn_dict: Dict[int, List[Tuple[int, float]]],
    prices: Dict[int, float],
) -> Dict[int, float]:
    """Calculate weighted average prices for each item.
    Weights should be positive numbers in [0, 2] interval.

    Args:
        knn_dict (Dict[int, List[Tuple[int, float]]]): <knn> method output.
        prices (Dict[int, float]): Price dict for each item.

    Returns:
        Dict[int, float]: New prices dict, rounded to 2 decimal places.
    """
    norm_knn_dict = {}
    for item_id, neighbors in knn_dict.items():
        # min_sim = min(neighbors, key=lambda x: x[1])[1] 
        # max_sim = max(neighbors, key=lambda x: x[1])[1] 
        sum_sim = reduce(lambda x, y: x+y[1] + 1, neighbors, 0)
        normed_neighbors = [(pair[0], (pair[1] + 1) / sum_sim) for pair in neighbors]
        norm_knn_dict[item_id] = normed_neighbors

    knn_price_dict = {}
    for item_id, neighbors in norm_knn_dict.items():
        weighed_price = round(sum([prices[item[0]] * item[1] for item in neighbors]), 2)
        knn_price_dict[item_id] = weighed_price

    return knn_price_dict

@staticmethod
def transform(
    embeddings: Dict[int, np.ndarray],
    prices: Dict[int, float],
    top: int,
) -> Dict[int, float]:
    """Run the whole pipeline: embeddings -> pair scores -> neighbors -> prices.

    Args:
        embeddings (Dict[int, np.ndarray]): Items embeddings.
        prices (Dict[int, float]): Price dict for each item.
        top (int): Number of top neighbors to consider.

    Returns:
        Dict[int, float]: The <knn_price> result computed from the <knn>
        neighbors of the <similarity> scores.
    """
    return knn_price(knn(similarity(embeddings), top), prices)

In [182]:
# Echo the pairwise scores for every pair of ids in `embeddings`.
similarity(embeddings)

{(1, 2): 1.15456349,
 (1, 3): 1.27053417,
 (1, 4): 0.81818899,
 (1, 5): 1.03886083,
 (1, 6): 1.03886083,
 (1, 7): 1.03886083,
 (1, 8): 1.03886083,
 (1, 9): 1.03886083,
 (1, 10): 1.03886083,
 (2, 3): 0.68653905,
 (2, 4): 0.85417666,
 (2, 5): 1.45207678,
 (2, 6): 1.45207678,
 (2, 7): 1.45207678,
 (2, 8): 1.45207678,
 (2, 9): 1.45207678,
 (2, 10): 1.45207678,
 (3, 4): 0.89426132,
 (3, 5): 0.56303506,
 (3, 6): 0.56303506,
 (3, 7): 0.56303506,
 (3, 8): 0.56303506,
 (3, 9): 0.56303506,
 (3, 10): 0.56303506,
 (4, 5): 0.96480001,
 (4, 6): 0.96480001,
 (4, 7): 0.96480001,
 (4, 8): 0.96480001,
 (4, 9): 0.96480001,
 (4, 10): 0.96480001,
 (5, 6): 0,
 (5, 7): 0,
 (5, 8): 0,
 (5, 9): 0,
 (5, 10): 0,
 (6, 7): 0,
 (6, 8): 0,
 (6, 9): 0,
 (6, 10): 0,
 (7, 8): 0,
 (7, 9): 0,
 (7, 10): 0,
 (8, 9): 0,
 (8, 10): 0,
 (9, 10): 0}

In [165]:
# Same call; the output captured below only contains pairs of ids 1-5.
similarity(embeddings)

{(1, 2): 1.15456349,
 (1, 3): 1.27053417,
 (1, 4): 0.81818899,
 (1, 5): 1.03886083,
 (2, 3): 0.68653905,
 (2, 4): 0.85417666,
 (2, 5): 1.45207678,
 (3, 4): 0.89426132,
 (3, 5): 0.56303506,
 (4, 5): 0.96480001}

In [152]:
# Full pipeline with top=3 neighbors per item.
transform(embeddings, prices, 3)

{1: 90.89, 2: 59.4, 3: 82.6, 4: 57.31, 5: 57.57}

In [148]:
# Price the items straight from the neighbor lists held in `a`.
knn_price(a, prices)

{1: 90.89, 2: 59.4, 3: 82.6, 4: 57.31, 5: 57.57}

In [179]:
# Sample fixtures: ids 1-4 have distinct vectors, ids 5-10 all carry the
# same vector (each as its own array object).
embeddings = {
    1: np.array([-26.57, -76.61, 81.61, -9.11, 74.8, 54.23, 32.56, -22.62, -72.44, -82.78]),
    2: np.array([-55.98, 82.87, 86.07, 18.71, -18.66, -46.74, -68.18, 60.29, 98.92, -78.95]),
    3: np.array([-27.97, 25.39, -96.85, 3.51, 95.57, -27.48, -80.27, 8.39, 89.96, -36.68]),
    4: np.array([-37.0, -49.39, 43.3, 73.36, 29.98, -56.44, -15.91, -56.46, 24.54, 12.43]),
    **{
        item_id: np.array(
            [-22.71, 4.47, -65.42, 10.11, 98.34, 17.96, -10.77, 2.5, -26.55, 69.16]
        )
        for item_id in range(5, 11)
    },
}

# NOTE(review): prices exist only for ids 1-5 while embeddings cover 1-10;
# confirm whether ids 6-10 are meant to have prices as well.
prices = {1: 100.5, 2: 12.2, 3: 60.0, 4: 11.1, 5: 245.2}

In [123]:
# Keep the pairwise scores so they can be fed to knn().
sim = similarity(embeddings)

In [116]:
# Echo the stored pairwise scores.
sim

{(1, 2): 1.1545634919962326,
 (1, 3): 1.270534168427317,
 (1, 4): 0.8181889878718,
 (1, 5): 1.038860825974679,
 (2, 3): 0.686539047652722,
 (2, 4): 0.8541766630741916,
 (2, 5): 1.4520767766565754,
 (3, 4): 0.8942613246389717,
 (3, 5): 0.563035061881952,
 (4, 5): 0.9648000066514621}

In [117]:
# Neighbor lists with top=3, built from the stored pairwise scores.
a = knn(sim, 3)

In [118]:
# Echo the neighbor lists.
a

{1: [(4, 0.8181889878718), (5, 1.038860825974679), (2, 1.1545634919962326)],
 2: [(3, 0.686539047652722), (4, 0.8541766630741916), (1, 1.1545634919962326)],
 3: [(5, 0.563035061881952), (2, 0.686539047652722), (4, 0.8942613246389717)],
 4: [(1, 0.8181889878718), (2, 0.8541766630741916), (3, 0.8942613246389717)],
 5: [(3, 0.563035061881952), (4, 0.9648000066514621), (1, 1.038860825974679)]}

In [137]:
# NOTE(review): the output captured below shows per-item sums and normalized
# weights rather than prices, so it appears to come from an earlier debugging
# revision of knn_price - confirm before relying on it.
knn_price(a, prices)

6.011613305842712
5.695279202723146
5.143835434173646
5.566626975584963
5.5666958945080935


{1: [(4, 0.3024460981388631),
  (5, 0.3391536884105805),
  (2, 0.35840021345055634)],
 2: [(3, 0.2961293007103706),
  (4, 0.32556378661605107),
  (1, 0.37830691267357847)],
 3: [(5, 0.3038656819185454),
  (2, 0.3278757785383279),
  (4, 0.3682585395431267)],
 4: [(1, 0.32662310513104525),
  (2, 0.33308800305940156),
  (3, 0.34028889180955313)],
 5: [(3, 0.28078326739996473),
  (4, 0.3529562318268301),
  (1, 0.36626050077320504)]}