In [1]:
import copy
from dash import Dash, dcc, html, Input, Output
import grakel
from grakel.kernels import WeisfeilerLehman, VertexHistogram, ShortestPath, SubgraphMatching
import json
import os
import pandas as pd
import plotly.express as px
from rdflib import Graph, URIRef
from rdflib.namespace import RDF, SOSA
from rdflib.extras.external_graph_libs import rdflib_to_networkx_digraph
from sklearn.model_selection import KFold
from tqdm import tqdm
import time

from viscars.recommenders import Recommender
from viscars.evaluation.evaluators import Evaluator
from viscars.evaluation.metrics import MetricType
from viscars.evaluation.metrics.factory import MetricFactory
from viscars.namespace import DASHB

In [2]:
GRAPH_INPUT_DIR = 'D:\\Documents\\UGent\\PhD\\projects\\PhD\\VisCARS\\data\\protego'
GRAPH_SIMILARITY_DIR = '/home/pieter/OpenKE/benchmarks/DASHB'

graph = Graph()
graph.parse(os.path.join(GRAPH_INPUT_DIR, 'protego_zplus.ttl'))
graph.parse(os.path.join(GRAPH_INPUT_DIR, 'graph.ttl'))
items = graph.subjects(RDF.type, SOSA.ObservableProperty)

### Patients (Context)

In [3]:
cc_similarities = {}
with open(os.path.join(GRAPH_INPUT_DIR, 'RDF2VecUU.json')) as input_f:
    cc_similarities = json.load(input_f)

### ObservableProperties (Item)

In [4]:
qry = '''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    PREFIX sosa: <http://www.w3.org/ns/sosa/>
    
    SELECT ?property ?metric WHERE {
        ?property a sosa:ObservableProperty ;
            dashb:produces ?metric .
    }
'''
properties = {property_: metric for property_, metric in graph.query(qry)}

ii_similarities = {}
# with open(os.path.join(GRAPH_SIMILARITY_DIR, 'results', 'RDF2VecII.json')) as input_f:
#     ii_similarities = json.load(input_f)

### Dashboard Users (User)

In [5]:
qry = '''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    
    SELECT ?user ?label ?role WHERE {
        ?user a dashb:User ;
            rdfs:label ?label ;
            dashb:memberOf ?role .
    }
'''
users = [(str(user), str(label), str(role)) for user, label, role in graph.query(qry)]

uu_similarities = {}
for user, label, role in users:
    uu_similarities[user] = {}
    for user_, label_, role_ in users:
        if role == role_:
            uu_similarities[user][user_] = 1
        else:
            uu_similarities[user][user_] = 0

In [6]:
def kneighbours_cc(uid, cid, iid, k=20, strict=True):
    cc_similarities_ = copy.deepcopy(cc_similarities)
    
    if strict:
        for c, w in cc_similarities[cid].items():
            qry = f'''
                PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
                PREFIX sosa: <http://www.w3.org/ns/sosa/>
                PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

                SELECT * WHERE {{
                    ?widget a dashb:Widget ;
                        dashb:createdBy <{uid}> ;
                        dashb:hasProperty ?property .
                        
                    ?sensor ssn-ext:subSystemOf <{c}> ;
                        sosa:observes ?property .
                }}
            '''
            if len(list(graph.query(qry))) == 0:
                cc_similarities_[cid].pop(c)
    
    if cid in cc_similarities_[cid].keys():
        cc_similarities_[cid].pop(cid)
    
    return sorted(cc_similarities_[cid].items(), key=lambda x: x[1])[:k]

def kneighbours_ii(iid, k=20):
    ii_similarities_ = copy.deepcopy(ii_similarities)
    ii_similarities_[iid].pop(uid)
    return sorted(ii_similarities_[iid].items(), key=lambda x: x[1])[:k]

def kneighbours_uu(uid, cid, iid, k=20, strict=True):
    uu_similarities_ = copy.deepcopy(uu_similarities)
    
    if strict:
        for u, w in uu_similarities[uid].items():
            qry = f'''
                PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>

                SELECT * WHERE {{
                    ?widget a dashb:Widget ;
                        dashb:createdBy <{u}> ;
                        dashb:hasProperty <{iid}> .
                }}
            '''
            if len(list(graph.query(qry))) == 0:
                uu_similarities_[uid].pop(u)
                
    if uid in uu_similarities_[uid].keys():
        uu_similarities_[uid].pop(uid)
                
    return sorted(uu_similarities_[uid].items(), key=lambda x: x[1])[:k]

In [7]:
def predict_uucf(uid, cid, iid):
    # Calculate K-nearest neighbours of the user
    neighbours = kneighbours_uu(uid, cid, iid)
    
    total_rw = 0
    total_w = 0
    for n, w in neighbours:
        # Get ratings from user n for item i
        qry = f'''
            PREFIX dashb:<http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    
            SELECT * WHERE {{
                ?widget a dashb:Widget ;
                    dashb:createdBy <{n}> ;
                    dashb:hasProperty <{iid}> .
            }}
        '''
        r_aui = 1 if len(list(graph.query(qry))) > 0 else 0
        w_au = w  # Similarity score of user u and user n
        total_rw += r_aui * w_au
        total_w += w
    return total_rw / total_w if total_w > 0 else 0

In [8]:
cid = 'http://example.com/tx/patients/zplus_6'

qry = f'''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    PREFIX sosa: <http://www.w3.org/ns/sosa/>
    PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

    SELECT ?property WHERE {{
        ?sensor ssn-ext:subSystemOf <{cid}> ;
            sosa:observes ?property .
    }}
'''
items = [str(property_[0]) for property_ in graph.query(qry)]
for iid in items:
    p = predict_uucf('https://dynamicdashboard.ilabt.imec.be/users/10', cid, iid)
    print(f'[{p}] {iid}')

[1.0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.lifestyle/properties/enriched-call
[1.0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.60%253A77%253A71%253A7D%253A93%253AD7%252Fservice0009/properties/org.dyamand.types.health.GlucoseLevel
[0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.B0%253A91%253A22%253AFB%253AD0%253A78%252Fservice0009/properties/org.dyamand.types.health.GlucoseLevel
[0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A9D%253A6B%253A89%253A43%253ACD%252Fservice0025/properties/org.dyamand.types.health.HeartRate
[0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A9D%253A6B%253A89%253A43%253ACD%252Fservice0025/properties/org.dyamand.types.health.DiastolicBloodPressure
[0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A9D%253A6B%253A89%253A43%253ACD%252Fservice0025/properties/org.dyamand.types.health.Sy

In [9]:
def predict_cccf(uid, cid, iid):
    # Calculate K-nearest neighbours of the patient
    neighbours = kneighbours_cc(uid, cid, iid)
    
    total_rw = 0
    total_w = 0
    for n, w in neighbours:
        # Get ratings from user i for similar items (metric-based) of patient n
        qry = f'''
            PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
            PREFIX sosa: <http://www.w3.org/ns/sosa/>
            PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>
    
            SELECT * WHERE {{
                <{iid}> dashb:produces ?metric .
                
                ?sensor ssn-ext:subSystemOf <{n}> ;
                    sosa:observes ?property .
                ?property dashb:produces ?metric .
                
                ?widget a dashb:Widget ;
                    dashb:createdBy <{uid}> ;
                    dashb:hasProperty ?property .
            }}
        '''
        r_aui = 1 if len(list(graph.query(qry))) > 0 else 0
        w_au = w  # Similarity score of patient cid and patient n
        total_rw += r_aui * w_au
        total_w += w
    return total_rw / total_w if total_w > 0 else 0

In [10]:
cid = 'http://example.com/tx/patients/zplus_6'

qry = f'''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    PREFIX sosa: <http://www.w3.org/ns/sosa/>
    PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

    SELECT ?property WHERE {{
        ?sensor ssn-ext:subSystemOf <{cid}> ;
            sosa:observes ?property .
    }}
'''
items = [str(property_[0]) for property_ in graph.query(qry)]
for iid in items:
    p = predict_cccf('https://dynamicdashboard.ilabt.imec.be/users/10', cid, iid)
    print(f'[{p}] {iid}')

[1.0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.lifestyle/properties/enriched-call
[0.480242082135768] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.60%253A77%253A71%253A7D%253A93%253AD7%252Fservice0009/properties/org.dyamand.types.health.GlucoseLevel
[0.480242082135768] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.B0%253A91%253A22%253AFB%253AD0%253A78%252Fservice0009/properties/org.dyamand.types.health.GlucoseLevel
[0.0] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A9D%253A6B%253A89%253A43%253ACD%252Fservice0025/properties/org.dyamand.types.health.HeartRate
[0.24193654921762062] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A9D%253A6B%253A89%253A43%253ACD%252Fservice0025/properties/org.dyamand.types.health.DiastolicBloodPressure
[0.24193654921762062] https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A9D%253A6B%253A89%

In [11]:
def predict(uid, cid):
    # Get all items from the patient
    qry = f'''
        PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
        PREFIX sosa: <http://www.w3.org/ns/sosa/>
        PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

        SELECT ?property WHERE {{
            ?sensor ssn-ext:subSystemOf <{cid}> ;
                sosa:observes ?property .
        }}
    '''
    items = [str(property_[0]) for property_ in graph.query(qry)]
    
    scores = {}
    for iid in items:
        scores[iid] = 0.5 * predict_uucf(uid, cid, iid) + 0.5 * predict_cccf(uid, cid, iid)
        
    recommendations = [{'contextId': cid, 'itemId': item, 'score': score} for item, score in scores.items()]
    return sorted(recommendations, key=lambda n: n['score'], reverse=True)

In [12]:
predict('https://dynamicdashboard.ilabt.imec.be/users/10', 'http://example.com/tx/patients/zplus_6')

[{'contextId': 'http://example.com/tx/patients/zplus_6',
  'itemId': 'https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.lifestyle/properties/enriched-call',
  'score': 1.0},
 {'contextId': 'http://example.com/tx/patients/zplus_6',
  'itemId': 'https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.60%253A77%253A71%253A7D%253A93%253AD7%252Fservice0009/properties/org.dyamand.types.health.GlucoseLevel',
  'score': 0.740121041067884},
 {'contextId': 'http://example.com/tx/patients/zplus_6',
  'itemId': 'https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.B0%253A91%253A22%253AFB%253AD0%253A78%252Fservice0009/properties/org.dyamand.types.health.GlucoseLevel',
  'score': 0.240121041067884},
 {'contextId': 'http://example.com/tx/patients/zplus_6',
  'itemId': 'https://webthing.protego.dynamicdashboard.ilabt.imec.be/things/zplus_6.00%253A1C%253A05%253AFF%253AA9%253A4E%252Fservice0029/properties/org.dyamand.types.health.SpO2',
  'score': 0.

In [13]:
# for user, label, role in users:
#     for patient_ in tqdm(cc_similarities.keys()):
#         recommendations = predict(user, patient_)
#         for recommendation in recommendations:
#             if recommendation['score'] != 0:
#                 print(user, patient_)

In [14]:
qry = '''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    PREFIX sosa: <http://www.w3.org/ns/sosa/>
    PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

    SELECT ?user ?context ?property WHERE {
        ?sensor ssn-ext:subSystemOf ?context ;
            sosa:observes ?property .
            
        ?widget dashb:hasProperty ?property ;
                dashb:createdBy ?user .                
    }
'''

ratings = {'user': [], 'item': [], 'rating': [], 'context': []}
for row in graph.query(qry):
    user_ = row[0]
    context_ = row[1]
    item_ = row[2]
    
    ratings.get('user').append(user_)
    ratings.get('item').append(item_)
    ratings.get('rating').append(5.0)
    ratings.get('context').append(context_)
ratings_df = pd.DataFrame.from_dict(ratings)

In [15]:
ratings_df.head()

Unnamed: 0,user,item,rating,context
0,https://dynamicdashboard.ilabt.imec.be/users/7,https://webthing.protego.dynamicdashboard.ilab...,5.0,http://example.com/tx/patients/zplus_2
1,https://dynamicdashboard.ilabt.imec.be/users/9,https://webthing.protego.dynamicdashboard.ilab...,5.0,http://example.com/tx/patients/zplus_6
2,https://dynamicdashboard.ilabt.imec.be/users/7,https://webthing.protego.dynamicdashboard.ilab...,5.0,http://example.com/tx/patients/zplus_6
3,https://dynamicdashboard.ilabt.imec.be/users/7,https://webthing.protego.dynamicdashboard.ilab...,5.0,http://example.com/tx/patients/zplus_9
4,https://dynamicdashboard.ilabt.imec.be/users/9,https://webthing.protego.dynamicdashboard.ilab...,5.0,http://example.com/tx/patients/zplus_20


In [16]:
# Context metadata
qry = '''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    PREFIX sosa: <http://www.w3.org/ns/sosa/>
    PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

    SELECT ?context WHERE {
        ?sensor ssn-ext:subSystemOf ?context ;
            sosa:observes ?property .
    }
'''
context_metadata = {'id': []}

result = graph.query(qry)
for row in result:
    context_metadata.get('id').append(row[0])
context_metadata_df = pd.DataFrame.from_dict(context_metadata)

# User metadata
qry = '''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>

    SELECT ?user ?username ?role WHERE {
        ?user dashb:memberOf ?role .
    }
'''
user_metadata = {'id': [], 'type': []}

result = graph.query(qry)
for row in result:
    user_metadata.get('id').append(row[0])
    user_metadata.get('type').append(row[1])
user_metadata_df = pd.DataFrame.from_dict(user_metadata)

# Item metadata
qry = '''
    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
    PREFIX sosa: <http://www.w3.org/ns/sosa/>

    SELECT ?property WHERE {
        ?property a sosa:ObservableProperty .
    }
'''
item_metadata = {'id': []}

result = graph.query(qry)
for row in result:
    item_metadata.get('id').append(row[0])
item_metadata_df = pd.DataFrame.from_dict(item_metadata)


def build_subgraph_from_ratings(ratings: pd.DataFrame) -> Graph:
    sub_graph = Graph()

    for idx, row in context_metadata_df.iterrows():
        cid = row['id']
        sub_graph += graph.triples((cid, None, None))
        sub_graph += graph.triples((None, None, cid))

    for idx, row in user_metadata_df.iterrows():
        uid = row['id']
        sub_graph += graph.triples((uid, DASHB.memberOf, None))

    for idx, row in item_metadata_df.iterrows():
        iid = row['id']
        sub_graph += graph.triples((iid, None, None))
        sub_graph += graph.triples((None, None, iid))

    for idx, row in ratings.iterrows():
        uid = row['user']
        iid = row['item']
        cid = row['context']

        sub_graph += graph.triples((uid, None, None))
        sub_graph += graph.triples((None, None, uid))
        sub_graph += graph.triples((iid, None, None))
        sub_graph += graph.triples((None, None, iid))

        sub_graph += graph.triples((cid, None, None))
        sub_graph += graph.triples((None, None, cid))

    return sub_graph

In [17]:
class CF(Recommender):
    def __init__(self, graph: Graph, verbose=False):
        super().__init__(graph, verbose)
    
    def kneighbours_cc(self, uid, cid, iid, k=20, strict=True):
        cc_similarities_ = copy.deepcopy(cc_similarities)

        if strict:
            for c, w in cc_similarities[cid].items():
                qry = f'''
                    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
                    PREFIX sosa: <http://www.w3.org/ns/sosa/>
                    PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

                    SELECT * WHERE {{
                        ?widget a dashb:Widget ;
                            dashb:createdBy <{uid}> ;
                            dashb:hasProperty ?property .

                        ?sensor ssn-ext:subSystemOf <{c}> ;
                            sosa:observes ?property .
                    }}
                '''
                if len(list(self.graph.query(qry))) == 0:
                    cc_similarities_[cid].pop(c)

        if cid in cc_similarities_[cid].keys():
            cc_similarities_[cid].pop(cid)

        return sorted(cc_similarities_[cid].items(), key=lambda x: x[1])[:k]

    def kneighbours_ii(self, iid, k=20):
        ii_similarities_ = copy.deepcopy(ii_similarities)
        ii_similarities_[iid].pop(uid)
        return sorted(ii_similarities_[iid].items(), key=lambda x: x[1])[:k]

    def kneighbours_uu(self, uid, cid, iid, k=20, strict=True):
        uu_similarities_ = copy.deepcopy(uu_similarities)

        if strict:
            for u, w in uu_similarities[uid].items():
                qry = f'''
                    PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>

                    SELECT * WHERE {{
                        ?widget a dashb:Widget ;
                            dashb:createdBy <{u}> ;
                            dashb:hasProperty <{iid}> .
                    }}
                '''
                if len(list(self.graph.query(qry))) == 0:
                    uu_similarities_[uid].pop(u)

        if uid in uu_similarities_[uid].keys():
            uu_similarities_[uid].pop(uid)

        return sorted(uu_similarities_[uid].items(), key=lambda x: x[1])[:k]

    def _build_model(self):
        pass
    
    def predict_uucf(self, uid, cid, iid):
        # Calculate K-nearest neighbours of the user
        neighbours = self.kneighbours_uu(uid, cid, iid)

        total_rw = 0
        total_w = 0
        for n, w in neighbours:
            # Get ratings from user n for item i
            qry = f'''
                PREFIX dashb:<http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>

                SELECT * WHERE {{
                    ?widget a dashb:Widget ;
                        dashb:createdBy <{n}> ;
                        dashb:hasProperty <{iid}> .
                }}
            '''
            r_aui = 1 if len(list(self.graph.query(qry))) > 0 else 0
            w_au = w  # Similarity score of user u and user n
            total_rw += r_aui * w_au
            total_w += w
        return total_rw / total_w if total_w > 0 else 0
    
    def predict_cccf(self, uid, cid, iid):
        # Calculate K-nearest neighbours of the patient
        neighbours = self.kneighbours_cc(uid, cid, iid)

        total_rw = 0
        total_w = 0
        for n, w in neighbours:
            # Get ratings from user i for similar items (metric-based) of patient n
            qry = f'''
                PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
                PREFIX sosa: <http://www.w3.org/ns/sosa/>
                PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

                SELECT * WHERE {{
                    <{iid}> dashb:produces ?metric .

                    ?sensor ssn-ext:subSystemOf <{n}> ;
                        sosa:observes ?property .
                    ?property dashb:produces ?metric .

                    ?widget a dashb:Widget ;
                        dashb:createdBy <{uid}> ;
                        dashb:hasProperty ?property .
                }}
            '''
            r_aui = 1 if len(list(self.graph.query(qry))) > 0 else 0
            w_au = w  # Similarity score of patient cid and patient n
            total_rw += r_aui * w_au
            total_w += w
        return total_rw / total_w if total_w > 0 else 0

    def predict(self, uid, cid, *kwargs):
        # Get all items from the patient
        qry = f'''
            PREFIX dashb: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/dashboard#>
            PREFIX sosa: <http://www.w3.org/ns/sosa/>
            PREFIX ssn-ext: <http://dynamicdashboard.ilabt.imec.be/broker/ontologies/ssn-extension/>

            SELECT ?property WHERE {{
                ?sensor ssn-ext:subSystemOf <{cid}> ;
                    sosa:observes ?property .
            }}
        '''
        items = [str(property_[0]) for property_ in self.graph.query(qry)]

        scores = {}
        for iid in items:
            scores[iid] = 0.5 * self.predict_uucf(uid, cid, iid) + 0.5 * self.predict_cccf(uid, cid, iid)

        recommendations = [{'contextId': cid, 'itemId': item, 'score': score} for item, score in scores.items()]
        return sorted(recommendations, key=lambda n: n['score'], reverse=True)

    def top_n(self, uid: [], cid: [], n: int, **kwargs):
        pass

In [18]:
class KFoldCrossValidation():

    def __init__(self, recommender: Recommender, metrics: [], k=5):
        """
        :param project_id: ID of the project (to load the correct data).
        :param recommender: Recommender
        :param metrics: List of Metrics
        :param k: Number of folds
        """
        self.recommender = recommender
        self.metrics = metrics
        self.k = k

    def evaluate(self, ratings, **kwargs):
        kf = KFold(n_splits=self.k, shuffle=True)

        n_fold = 0

        result = {'folds': [], 'result': {}}
        for train_idx, test_idx in kf.split(ratings):
            train = ratings.iloc[train_idx]
            test = ratings.iloc[test_idx]

            graph = build_subgraph_from_ratings(train)
            self.recommender.set_graph(graph)

            fold_scores = {}

            for uid in tqdm(test['user'].unique()):
                df_user = test.loc[test['user'] == uid]

                for cid in df_user['context']:
                    predictions = self.recommender.predict(str(uid), str(cid), **kwargs)
                    recommendations = [r['itemId'] for r in predictions]

                    truth = []
                    t_user = test.loc[test['user'] == uid]
                    for idx, row in t_user.iterrows():
                        if row['context'] == cid:
                            truth.append(str(row['item']))
                    # truth = list(test.loc[test['user'] == uid].loc[test['context'] == cid]['item'])

                    for metric in self.metrics:
                        if str(metric) not in fold_scores.keys():
                            fold_scores[str(metric)] = []
                        score = metric.calculate(recommendations, truth)
                        fold_scores[str(metric)].append(score)

            result_for_fold = {}
            for metric, scores in fold_scores.items():
                avg = sum(scores) / len(scores)
                result_for_fold[metric] = avg

                if metric not in result['result'].keys():
                    result['result'][metric] = []
                result['result'][metric].append(avg)
            result['folds'].append(result_for_fold)

            n_fold += 1

        final_results = {}
        for metric_type, score in result['result'].items():
            final_results[metric_type] = sum(score) / len(score)

        result['result'] = final_results
        return result

In [19]:
metric_factory = MetricFactory()

metrics = ['f1@1', 'ndcg@1', 'ndcg@3']
parsed_metrics = []
for metric in metrics:
    m_split = metric.split('@')
    m_type = m_split[0]
    n = int(m_split[1]) if len(m_split) >= 2 else None

    metric_ = metric_factory.get(MetricType.reverse_lookup(m_type), n)
    parsed_metrics.append(metric_)

In [20]:
recommender = CF(graph)

evaluator = KFoldCrossValidation(recommender, metrics=parsed_metrics, k=5)
result = evaluator.evaluate(ratings_df)

for fold in result['folds']:
    print(fold)
print(result['result'])

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [03:32<00:00, 70.96s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [03:33<00:00, 71.04s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:33<00:00, 53.40s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:33<00:00, 53.47s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [03:11<00:00, 47.87s/it]

{'f1@1': 0.13333333333333333, 'ndcg@1': 0.12262943855309168, 'ndcg@3': 0.3750013399816747}
{'f1@1': 0.8, 'ndcg@1': 0.7678883156592751, 'ndcg@3': 0.8452588771061833}
{'f1@1': 0.3, 'ndcg@1': 0.3, 'ndcg@3': 0.3}
{'f1@1': 0.5333333333333334, 'ndcg@1': 0.5226294385530916, 'ndcg@3': 0.5226294385530916}
{'f1@1': 0.3333333333333333, 'ndcg@1': 0.3333333333333333, 'ndcg@3': 0.6838498630952543}
{'f1@1': 0.42000000000000004, 'ndcg@1': 0.4092961052197584, 'ndcg@3': 0.5453479037472408}



