In [164]:
import os
import json
import collections
import heapq

In [2]:
def load_testing_data(file):
    with open(file) as f:
        data = json.loads(f.read())
    return data

In [137]:
K = 3

In [4]:
testing_data = load_testing_data("dataset/spotify_challenge/challenge_set.json")

In [123]:
task_to_playlists = {}
pid_to_playlist = {}
playlist_to_num_samples = {}
pid_to_set_tracks = {}
for p in testing_data["playlists"]:
    pid_to_playlist[p["pid"]] = p
    playlist_to_num_samples[p["pid"]] = p["num_samples"]
    pid_to_set_tracks[p["pid"]] = set([t["track_uri"] for t in p["tracks"]])
    for t in p["tracks"]:
        if not task_to_playlists.get(t["track_uri"]):
            task_to_playlists[t["track_uri"]] = []
        task_to_playlists[t["track_uri"]].append(p["pid"])

In [135]:
def process_mpd(path):
    filenames = os.listdir(path)
    count = 0
    for filename in sorted(filenames):
        if filename.startswith("mpd.slice.") and filename.endswith(".json"):
            fullpath = os.sep.join((path, filename))
            with open(fullpath) as f:
                js = f.read()
                mpd_slice = json.loads(js)
                for playlist in mpd_slice["playlists"]:
                    count += 1
                    if count % 1000 == 0:
                        print(count)
                    process_playlist(playlist)


In [83]:
def get_ids_as_set(p):
    return set([t["track_uri"] for t in p["tracks"]])

def calc_dist(ids1, ids2):
    count_common = len(ids1.intersection(ids2))
    return count_common / (len(ids1) + len(ids2) - count_common)

In [136]:
def flatten(t):
    return [item for sublist in t for item in sublist]

In [236]:
knn = {p["pid"]: [] for p in testing_data["playlists"]}
pot_recs_count = collections.Counter()

In [208]:
from dataclasses import dataclass, field

@dataclass(order=True)
class Neighbour(dict):
    dist: float
    pid: str=field(compare=False)
    unique: int=field(compare=False)
        
    def __init__(self, dist, pid, unique):
        self.dist = dist
        self.pid = pid
        self.unique = unique
        super().__init__(self, dist=dist, pid=pid, unique=unique)

In [237]:
def dist_nums(x, y, common):
    return common / (x + y - common)

def process_playlist(playlist):
    nn_playlists = flatten([task_to_playlists.get(t["track_uri"], []) for t in playlist["tracks"]])
    aggr = {}
    for test_p_id in nn_playlists:
        aggr[test_p_id] = aggr.get(test_p_id, 0) + 1
    
    p_num_tracks = playlist["num_tracks"]
    for pid, count in aggr.items():
        dist = dist_nums(playlist_to_num_samples[pid], p_num_tracks, count)
        ns = knn[pid]
        unique_count = p_num_tracks - count
        if pot_recs_count.get(pid, 0) < K * 500:
            heapq.heappush(ns, Neighbour(dist, playlist["pid"], unique_count))
            pot_recs_count[pid] += unique_count
        elif ns[0].dist < dist:
            pot_recs_count[pid] += unique_count - ns[0].unique
            heapq.heappop(ns)
            heapq.heappush(ns, Neighbour(dist, playlist["pid"], unique_count))

In [238]:
process_mpd("dataset/spotify/data")

1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
145000
146000
147000
148000
149000
150000
151000
152000
153000
154000
155000
156000
157000
158000
15

In [239]:
print(len(pot_recs_count))
for x, count in pot_recs_count.most_common(10000):
    print(count)

9000
5672
4744
4100
3956
3834
3667
3590
3551
3407
3382
3332
3331
3321
3298
3278
3188
3150
3146
3125
3116
3106
3089
3078
3076
3069
3065
3037
3036
2998
2991
2977
2970
2963
2962
2948
2944
2941
2939
2930
2926
2921
2910
2910
2907
2882
2877
2876
2868
2868
2864
2853
2853
2851
2848
2847
2847
2844
2844
2844
2839
2832
2826
2825
2811
2810
2805
2803
2803
2801
2795
2788
2774
2772
2767
2766
2753
2752
2748
2748
2744
2733
2730
2722
2712
2708
2706
2705
2703
2703
2698
2698
2694
2687
2686
2685
2684
2673
2670
2667
2665
2663
2649
2647
2646
2641
2641
2640
2639
2638
2638
2638
2631
2628
2628
2626
2622
2620
2617
2614
2614
2613
2612
2610
2610
2609
2608
2605
2603
2601
2601
2598
2595
2593
2592
2587
2578
2574
2572
2571
2571
2569
2569
2564
2564
2563
2560
2559
2555
2555
2551
2550
2547
2547
2546
2538
2532
2531
2530
2529
2528
2523
2523
2523
2521
2521
2517
2511
2509
2509
2508
2507
2506
2504
2502
2502
2498
2497
2497
2496
2493
2493
2489
2486
2484
2483
2482
2481
2481
2480
2480
2479
2479
2476
2475
2475
2475
2473
2473
2472


1665
1665
1664
1664
1664
1664
1664
1664
1664
1664
1664
1663
1663
1663
1663
1663
1663
1663
1663
1663
1663
1662
1662
1662
1662
1662
1662
1662
1661
1661
1661
1661
1661
1661
1661
1661
1661
1661
1661
1660
1660
1660
1660
1660
1660
1660
1660
1660
1660
1660
1660
1659
1659
1659
1659
1659
1658
1658
1658
1658
1658
1658
1658
1658
1658
1658
1658
1658
1657
1657
1657
1657
1657
1657
1657
1657
1657
1657
1656
1656
1656
1656
1656
1656
1656
1656
1656
1656
1656
1655
1655
1655
1655
1655
1655
1655
1655
1655
1655
1655
1655
1655
1654
1654
1654
1654
1654
1654
1654
1654
1654
1654
1654
1654
1654
1653
1653
1653
1653
1653
1653
1653
1653
1653
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1652
1651
1651
1651
1651
1651
1651
1651
1651
1651
1651
1651
1651
1650
1650
1650
1650
1650
1650
1650
1650
1650
1649
1649
1649
1649
1649
1649
1649
1649
1648
1648
1648
1648
1648
1648
1648
1648
1648
1647
1647
1647
1647
1647
1647
1647
1647
1647
1647
1647
1647
1647
1647
1646
1646
1646
1646
1646
1646
1646
1646


1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1552
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1551
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1550
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1549
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1548
1547
1547
1547
1547
1547
1547
1547


1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1515
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514
1514


1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1500
1499
1498
1498
1498
1498
1498
1498
1497
1497
1497
1497
1496
1496
1495
1495
1494
1493
1493
1493
1492
1491
1491
1490
1488
1487
1486
1485
1485
1483
1481
1480
1480
1479
1479
1479
1479
1478
1477
1476
1475
1475
1474
1473
1471
1471
1470
1469
1469
1467
1464
1460
1450
1446
1443
1433
1433
1425
1408
1404
1389
1382
1376
1367
1361
1359
1265
1176
1103
1090
1052
1027
1025
1011
657
650
514
434
355
346
288
245
235
118
105
76
22


In [240]:
with open("knn.json", "w") as f:
    json.dump(knn, f)

In [241]:
neighbour_to_playlist = {}
for pid, neighbours in knn.items():
    for n in neighbours:
        if not neighbour_to_playlist.get(n.pid):
            neighbour_to_playlist[n.pid] = []
        neighbour_to_playlist[n.pid].append(pid)

In [242]:
len(neighbour_to_playlist)

270155

In [243]:
def extract_tracks_from_training_data(path):
    filenames = os.listdir(path)
    count = 0
    for filename in sorted(filenames):
        if filename.startswith("mpd.slice.") and filename.endswith(".json"):
            fullpath = os.sep.join((path, filename))
            with open(fullpath) as f:
                js = f.read()
                mpd_slice = json.loads(js)
                for playlist in mpd_slice["playlists"]:
                    count += 1
                    if count % 1000 == 0:
                        print(count)
                    extract_tracks(playlist)

In [244]:
recommendations = {}
for pid in knn:
    recommendations[pid] = []

In [245]:
def extract_tracks(playlist):
    track_uris = set([t["track_uri"] for t in playlist["tracks"]])
    for test_pid in neighbour_to_playlist.get(playlist["pid"], []):
        not_common = track_uris - pid_to_set_tracks[test_pid]
        recommendations[test_pid].extend(not_common)

In [246]:
extract_tracks_from_training_data("dataset/spotify/data")

1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
145000
146000
147000
148000
149000
150000
151000
152000
153000
154000
155000
156000
157000
158000
15

In [247]:
with open("recommendations.json", "w") as f:
    json.dump(recommendations, f)

In [248]:
recs_compressed = {}

for pid, recs in recommendations.items():
    counts = collections.Counter()
    for r in recs:
        counts[r] += 1
    print(len(counts))
    recs_compressed[pid] = [t for t, _ in counts.most_common(500)]

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


903
845
924
1200
770
613
785
1017
1012
297
620
969
644
663
1176
1312
872
970
911
1069
834
977
1086
678
599
960
896
1209
1262
934
170
1040
605
653
1158
809
768
776
975
651
740
748
905
1162
803
921
703
694
1251
1109
720
977
434
898
1069
655
1010
858
823
569
382
709
670
665
703
876
642
1069
763
722
887
757
706
1116
328
624
692
964
803
993
655
924
941
638
858
1107
1111
1038
884
839
694
1248
969
698
889
953
1087
958
1129
1068
862
1113
690
1026
917
913
1075
817
705
1042
651
1192
785
1265
624
617
1210
1032
1421
879
771
957
1002
823
801
770
839
758
794
1216
1200
1046
993
870
646
990
1198
1163
1188
1215
1143
691
1169
608
1237
769
1073
885
929
909
1251
1004
1084
1383
881
904
801
790
754
527
1044
1203
847
591
861
1073
1094
1307
1072
1080
487
1069
765
1225
1012
911
776
939
989
1156
998
531
1179
768
999
1043
1186
703
860
424
1167
700
417
1362
697
1045
945
961
1097
1235
833
722
968
787
808
578
1106
730
738
722
1138
1168
801
896
542
1266
870
711
841
958
841
780
651
1237
979
1180
746
538
802
964
861
4

973
1033
721
637
799
714
590
1029
804
900
677
1226
673
1067
667
821
862
646
1022
954
1251
593
1117
600
992
772
967
728
672
1281
780
591
384
1047
1087
888
1319
728
1227
1110
973
913
710
718
585
1076
703
695
690
863
1058
651
667
1196
969
589
958
1050
825
1023
644
845
705
1207
1061
1383
1218
829
993
971
1100
670
696
990
407
946
657
1141
621
822
882
1061
607
791
715
664
546
1232
433
618
848
766
520
890
1238
1114
981
816
1186
575
662
678
650
771
908
1098
603
486
785
1154
952
524
885
939
810
782
1015
1123
948
1072
963
714
877
779
756
677
788
702
645
1333
858
785
563
1109
847
863
355
652
1213
590
694
1188
1140
840
657
496
739
694
615
892
757
554
872
940
834
860
720
1248
676
762
1228
982
667
1192
782
711
740
665
820
1115
975
930
818
978
747
1096
445
1251
727
709
878
581
949
682
1162
1020
876
491
893
604
1172
685
1084
540
623
890
952
822
1345
1043
916
765
684
986
934
1187
811
1366
1256
740
627
601
701
1217
797
937
707
857
872
1162
945
1038
913
1048
652
1150
571
1205
665
792
720
933
452
686
785


1375
649
797
873
608
597
816
1089
679
898
549
1123
871
467
966
1001
880
618
1274
555
792
1114
761
704
924
685
860
547
619
775
734
609
344
1089
940
1079
702
923
1517
778
759
613
703
727
1036
588
917
712
931
520
1228
1024
717
375
676
1054
795
595
917
1006
747
742
754
592
649
676
885
817
294
753
611
831
824
626
820
710
704
779
928
645
780
639
435
669
1286
1022
1216
1107
996
908
937
718
809
1062
878
1400
878
882
887
654
1082
1051
921
842
1108
563
1002
980
720
833
609
654
587
952
907
829
900
579
874
671
801
718
1004
689
836
755
789
730
801
645
558
686
878
1011
548
872
717
331
574
1135
1137
950
738
708
683
997
1192
547
655
770
784
861
835
802
600
818
657
487
718
684
607
574
656
1061
1094
1223
805
741
785
537
808
783
816
672
657
746
1132
556
609
496
528
773
1175
894
385
1172
630
855
284
976
755
769
809
770
806
960
1024
1025
534
999
885
848
831
444
1083
1002
596
775
832
789
793
879
849
757
834
870
818
581
254
1033
973
1119
883
1763
843
743
1019
1294
564
576
852
922
1158
1012
645
929
632
757
79

858
299
891
1185
1068
1464
1002
988
936
735
901
521
1046
853
874
1119
87
628
977
962
864
913
1054
1155
784
530
1076
665
729
1021
593
573
1082
1077
825
1070
339
881
1188
908
1302
837
1005
1003
783
972
902
719
1109
747
936
854
1120
696
807
815
840
648
738
864
643
1041
984
1129
680
888
590
967
885
983
1202
713
934
853
816
1575
819
1140
1308
950
743
888
778
694
779
897
880
896
721
850
818
936
1005
551
960
597
835
995
1268
1126
672
580
777
910
753
762
1115
769
1134
770
1324
688
693
968
908
991
710
946
1098
811
1041
912
634
789
620
699
1008
897
646
135
574
833
878
1115
1266
890
591
522
1305
877
259
924
1339
151
614
799
761
662
746
615
746
910
915
1074
729
1147
665
843
1140
593
968
1234
744
619
778
726
935
625
648
736
881
1348
853
946
773
1023
904
987
1241
1009
1063
670
1058
670
739
1060
1056
811
1109
726
1314
1092
1077
1035
957
765
605
1055
1947
876
873
915
962
988
1011
278
1593
1522
1015
1104
1076
1011
731
1199
655
972
950
869
761
967
1350
825
722
527
594
1068
1026
1063
972
757
1169
754
730

In [249]:
with open("knn_submission.csv", "w") as f:
    f.write("team_info,Pepilipep,p.angelov99@gmail.com\n")
    for x, y in recs_compressed.items():
        recs_string = ",".join(y)
        f.write(f"{x},{recs_string}\n")