In [1]:
import os, sys
sys.path.append('..')
import json

import pandas as pd

from tqdm import tqdm
from arena_util import write_json
from arena_util import remove_seen
from gensim.models import Word2Vec
from gensim.models.keyedvectors import WordEmbeddingsKeyedVectors

In [2]:
class PlaylistEmbedding:
    """Playlist recommender built on Word2Vec token embeddings.

    Every playlist is turned into one "sentence" made of its song ids
    (as strings) followed by its tags. A Word2Vec model is trained on those
    sentences, each playlist is embedded as the sum of its token vectors, and
    the answer for a validation playlist is the most frequent songs/tags among
    the playlists returned by ``p2v_model.most_similar``.
    """

    def __init__(self, FILE_PATH):
        """Load the playlists and the popularity-based fallback answers.

        Args:
            FILE_PATH: Directory that contains ``train.json`` and ``val.json``.
        """
        self.FILE_PATH = FILE_PATH
        # Word2Vec hyper-parameters forwarded to ``get_w2v`` by ``run``.
        self.min_count = 3
        self.size = 100
        self.window = 210
        # NOTE(review): gensim documents ``sg`` as 0 or 1; 5 is outside that
        # range. Value kept as-is -- confirm it is intended.
        self.sg = 5
        # Playlist-level vectors, filled in by ``update_p2v``.
        self.p2v_model = WordEmbeddingsKeyedVectors(self.size)
        with open(os.path.join(FILE_PATH, 'train.json'), encoding="utf-8") as f:
            self.train = json.load(f)
        with open(os.path.join(FILE_PATH, 'val.json'), encoding="utf-8") as f:
            self.val = json.load(f)
        # Fallback answers; read from a fixed location, not from FILE_PATH.
        with open(os.path.join('../arena_data/results/', 'genre_most_popular_results.json'), encoding="utf-8") as f:
            self.most_results = json.load(f)

    def get_dic(self, train, val):
        """Build per-playlist lookups and the Word2Vec training corpus.

        Sets ``self.song_dic`` and ``self.tag_dic`` (playlist id as ``str`` ->
        songs / tags) and ``self.total`` (one token list per playlist, keeping
        only playlists with more than one token).

        Args:
            train: List of playlist dicts with ``id``, ``songs`` and ``tags``.
            val: List of playlist dicts with the same keys.
        """
        data = train + val
        song_dic = {}
        tag_dic = {}
        for q in tqdm(data):
            key = str(q['id'])
            song_dic[key] = q['songs']
            tag_dic[key] = q['tags']
        self.song_dic = song_dic
        self.tag_dic = tag_dic
        # Song ids are stringified so songs and tags share one token type.
        sentences = ([str(song) for song in q['songs']] + list(q['tags']) for q in data)
        self.total = [tokens for tokens in sentences if len(tokens) > 1]

    def get_w2v(self, total, min_count, size, window, sg):
        """Train a Word2Vec model on ``total`` and store it as ``self.w2v_model``.

        Args:
            total: Iterable of token lists (see ``get_dic``).
            min_count, size, window, sg: Passed straight through to ``Word2Vec``.
        """
        w2v_model = Word2Vec(total, min_count=min_count, size=size, window=window, sg=sg)
        self.w2v_model = w2v_model

    def update_p2v(self, train, val, w2v_model):
        """Embed each playlist as the sum of its token vectors.

        Only playlists with at least one song are considered; tokens missing
        from the Word2Vec vocabulary are skipped. Playlists for which no token
        vector was found are not added to ``self.p2v_model``.

        Args:
            train: List of playlist dicts with ``id``, ``songs`` and ``tags``.
            val: List of playlist dicts with the same keys.
            w2v_model: Trained model exposing ``wv.get_vector``.
        """
        ids = []
        vectors = []
        for q in tqdm(train + val):
            # Starts as the int 0; the first successful addition turns it into
            # a fresh array, so the model's own vectors are never modified.
            tmp_vec = 0
            if len(q['songs']) >= 1:
                for token in q['songs'] + q['tags']:
                    try:
                        tmp_vec += w2v_model.wv.get_vector(str(token))
                    except KeyError:
                        # Token is not in the Word2Vec vocabulary.
                        pass
            if not isinstance(tmp_vec, int):
                ids.append(str(q['id']))
                vectors.append(tmp_vec)
        # Nothing to register when no playlist could be embedded.
        if ids:
            self.p2v_model.add(ids, vectors)

    @staticmethod
    def _most_frequent(items, topn):
        """Return up to ``topn`` distinct items, most frequent first."""
        if not items:
            return []
        return list(pd.Series(items).value_counts().index[:topn])

    def get_result(self, p2v_model, song_dic, tag_dic, most_results, val):
        """Produce 100 songs and 10 tags for every playlist in ``val``.

        For each playlist the 200 most similar playlists are looked up, their
        songs and tags are pooled, and the most frequent ones that the playlist
        does not already contain are kept. A playlist that has no vector in
        ``p2v_model`` (``KeyError``) receives ``most_results[n]`` instead; any
        other exception propagates. Short answers are then topped up from
        ``most_results[n]``. The result is stored in ``self.answers``.

        Args:
            p2v_model: Object exposing ``most_similar(key, topn=...)`` that
                returns ``(playlist_id, score)`` pairs.
            song_dic: Playlist id (``str``) -> list of songs.
            tag_dic: Playlist id (``str``) -> list of tags.
            most_results: Fallback answers; entry ``n`` is used for ``val[n]``
                (assumed to be positionally aligned with ``val`` -- TODO confirm).
            val: Playlists to answer, each with ``id``, ``songs`` and ``tags``.
        """
        n_songs, n_tags = 100, 10
        answers = []
        for n, q in tqdm(enumerate(val), total=len(val)):
            try:
                neighbours = p2v_model.most_similar(str(q['id']), topn=200)
            except KeyError:
                # No embedding for this playlist: use the fallback answer.
                # Lists are copied so the top-up below never aliases the input.
                fallback = most_results[n]
                answers.append({
                    "id": fallback["id"],
                    "songs": list(fallback["songs"]),
                    "tags": list(fallback["tags"]),
                })
                continue

            get_song = []
            get_tag = []
            for neighbour in neighbours:
                neighbour_id = neighbour[0]
                get_song += song_dic.get(neighbour_id, [])
                get_tag += tag_dic.get(neighbour_id, [])
            get_song = self._most_frequent(get_song, 200)
            get_tag = self._most_frequent(get_tag, 20)
            answers.append({
                "id": q["id"],
                "songs": remove_seen(q["songs"], get_song)[:n_songs],
                "tags": remove_seen(q["tags"], get_tag)[:n_tags],
            })

        # Top up answers that came back short, using the fallback passed in.
        for n, q in enumerate(tqdm(answers)):
            fallback = most_results[n]
            if len(q['songs']) < n_songs:
                q['songs'] += remove_seen(q['songs'], fallback['songs'])[:n_songs - len(q['songs'])]
            if len(q['tags']) < n_tags:
                q['tags'] += remove_seen(q['tags'], fallback['tags'])[:n_tags - len(q['tags'])]
        self.answers = answers

    def run(self):
        """Run the whole pipeline and write ``results_word2vec.json``."""
        print('get_dic')
        self.get_dic(self.train, self.val)
        print('get_w2v')
        self.get_w2v(self.total, self.min_count, self.size, self.window, self.sg)
        print('update_p2v')
        self.update_p2v(self.train, self.val, self.w2v_model)
        print('get_result')
        self.get_result(self.p2v_model, self.song_dic, self.tag_dic, self.most_results, self.val)
        write_json(self.answers, 'results_word2vec.json')

In [3]:
# Directory holding train.json / val.json. Set the MELON_DATA_DIR environment
# variable to point elsewhere instead of editing this cell.
path = os.environ.get('MELON_DATA_DIR', '/Volumes/Seagate Backup Plus Drive/KIE/dataset/melon/')
U_space = PlaylistEmbedding(path)

In [4]:
# Runs get_dic -> get_w2v -> update_p2v -> get_result in order, then writes the
# answers with write_json to 'results_word2vec.json'.
U_space.run()

 32%|███▏      | 43691/138086 [00:00<00:00, 432238.98it/s]get_dic
100%|██████████| 138086/138086 [00:00<00:00, 493675.92it/s]
get_w2v
  0%|          | 383/138086 [00:00<00:36, 3822.66it/s]update_p2v
100%|██████████| 138086/138086 [00:24<00:00, 5582.61it/s] 
  0%|          | 1/23015 [00:00<57:19,  6.69it/s]get_result
100%|██████████| 23015/23015 [03:49<00:00, 100.41it/s]
100%|██████████| 23015/23015 [00:00<00:00, 772521.00it/s]
