-
Notifications
You must be signed in to change notification settings - Fork 0
/
backend.py
189 lines (156 loc) · 6.64 KB
/
backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# some methods are adapted from https://github.com/plamere/spotipy/blob/master/examples/artist_discography.py
#Shows the list of all songs sung by the artist or the band
import logging
from typing import List, Dict
from neo4j import (
GraphDatabase,
basic_auth,
)
from spotipy.oauth2 import SpotifyClientCredentials
import spotipy
from tqdm import tqdm
# Module-level logger shared by every function in this file.
logger = logging.getLogger('artist-connections')
logging.basicConfig(level='INFO')
# The logger is switched off by default, so none of the info messages below are emitted.
logger.disabled = True  # set to False for debugging info to get printed to console
def get_artist(name:str) -> Dict:
    """Look up an artist on Spotify by name.

    Runs an artist-type search for ``'artist:' + name`` through the
    module-level ``sp`` client and returns the first entry of the result
    list, or None when the search returns no items.
    """
    matches = sp.search(q='artist:' + name, type='artist')['artists']['items']
    if not matches:
        return None
    return matches[0]
def get_artist_image_url(name:str) -> str:
    """Return the URL of the first image listed for the artist.

    An empty string is returned when the artist cannot be found or has no
    images.
    """
    artist = get_artist(name)
    # Guard: no artist, or an artist without any images.
    if not artist or not artist['images']:
        return ""
    return artist['images'][0]["url"]
def get_artist_tracks(artist:Dict) -> List[Dict]:
    """Collect the tracks of every album of an artist.

    Albums are fetched with ``album_type='album'`` and de-duplicated by
    lower-cased name; the tracks of each remaining album are appended in
    order. Tracks are returned as the dictionaries handed back by ``sp``.
    """
    def every_item(page):
        # Follow the 'next' links of a paged response and gather all items.
        gathered = list(page['items'])
        while page['next']:
            page = sp.next(page)
            gathered.extend(page['items'])
        return gathered

    albums = every_item(sp.artist_albums(artist['id'], album_type='album'))
    logger.info('Total albums: %s', len(albums))

    seen_titles = set()  # lower-cased album names already processed
    tracks = []
    for album in albums:
        title = album['name'].lower()
        if title in seen_titles:
            continue
        logger.info('ALBUM: %s', title)
        seen_titles.add(title)
        tracks.extend(every_item(sp.album_tracks(album['id'])))

    for position, track in enumerate(tracks, start=1):
        logger.info('%s. %s', position, track['name'])
    return tracks
#TODO: artists are repeated in database, matching fake profiles with same names
def write_track_to_database(tx, track:Dict):
    """Push a track and its associated artists to the database.

    Merges one ``Track`` node, one ``Artist`` node per entry of
    ``track['artists']`` and a ``PERFORMED_IN`` relationship from every
    artist to the track, all in a single query run on ``tx``.

    Every value is passed as a query parameter instead of being spliced
    into the Cypher text. Track names are therefore stored verbatim (no
    extra backslashes), and names containing quotes or other special
    characters can no longer break or alter the statement. The query text
    is constant, so it no longer needs identifiers derived from artist
    names.

    Args:
        tx: transaction object whose ``run(query, parameters)`` executes
            the statement.
        track: track dictionary with ``id``, ``name``, ``href`` and
            ``artists`` keys; each artist needs ``name``, ``id`` and
            ``external_urls['spotify']``.
    """
    logger.info(
        "Writing to database: %s-%s",
        track['name'],
        [artist['name'] for artist in track['artists']],
    )

    parameters = {
        "track_name": track["name"],
        "track_id": track["id"],
        "track_link": track["href"],
        "artists": [
            {
                # Apostrophes are still dropped from artist names, as before,
                # so nodes written earlier keep matching and lookups by name
                # see the same spelling.
                "name": artist["name"].replace("'", ""),
                "link": artist["external_urls"]["spotify"],
                "id": artist["id"],
            }
            for artist in track["artists"]
        ],
    }

    # WITH is needed to carry the track node from MERGE into UNWIND.
    command = (
        "MERGE (t:Track {name: $track_name, spotify_id: $track_id, link: $track_link}) "
        "WITH t "
        "UNWIND $artists AS artist "
        "MERGE (a:Artist {name: artist.name, link: artist.link, id: artist.id}) "
        "MERGE (a)-[:PERFORMED_IN]->(t)"
    )
    logger.info("%s | parameters: %s", command, parameters)
    tx.run(command, parameters)
def check_artist_exists(artist_name: str):
    """Check whether an ``Artist`` node with exactly this name is in the database.

    The name is sent as a query parameter, so quotes or other special
    characters in it cannot break or alter the Cypher statement.

    Args:
        artist_name: value compared against the ``name`` property.

    Returns:
        True when at least one matching node exists, otherwise False.
    """
    def run_query(tx, artist:str):
        command = "MATCH (u:Artist {name: $name}) WITH COUNT(u) > 0 as node_exists RETURN node_exists"
        return tx.run(command, {"name": artist}).single()["node_exists"]

    with driver.session() as session:
        return session.read_transaction(run_query, artist_name)
def find_collaboration_path(artist1:str, artist2:str):
    """Find the shortest collaboration path connecting two artists.

    Make sure to check that both artists exist first. Both names are sent
    as query parameters, so quotes in a name cannot break the statement.

    Args:
        artist1: ``name`` property of the artist at one end of the path.
        artist2: ``name`` property of the artist at the other end.

    Returns:
        The nodes of the shortest path, or None if no path is found.
    """
    def run_query(tx, artist1: str, artist2:str):
        command = (
            "MATCH p=shortestPath((a1:Artist {name: $artist1})-[*]-(a2:Artist {name: $artist2})) "
            "RETURN p"
        )
        record = tx.run(command, {"artist1": artist1, "artist2": artist2}).single()
        if record is None:
            return None
        return record["p"].nodes

    with driver.session() as session:
        return session.read_transaction(run_query, artist1, artist2)
def find_direct_collaborators(artist:str):
    """Find the artists that are exactly two relationships away from <artist>.

    The name is sent as a query parameter, and the session is closed even
    when the query raises.

    Args:
        artist: ``name`` property of the starting artist.

    Returns:
        List of the distinct ``Artist`` nodes matched by the query.
    """
    query = "MATCH (a:Artist {name: $name})-[*2]-(b:Artist) RETURN DISTINCT b"
    with driver.session() as session:
        # Records must be consumed before the session is closed.
        return [record["b"] for record in session.run(query, {"name": artist})]
def count_direct_collaborations(artist1:str, artist2:str) -> int:
    """Count the number of direct collaborations between artist1 and artist2.

    A direct collaboration is a ``Track`` that both artists are linked to
    through ``PERFORMED_IN``. Both names are sent as query parameters, and
    the session is closed even when the query raises.

    Args:
        artist1: ``name`` property of the first artist.
        artist2: ``name`` property of the second artist.

    Returns:
        Number of distinct tracks shared by the two artists.
    """
    # count(DISTINCT t) counts each shared track once; the previous
    # "RETURN DISTINCT count(t)" only de-duplicated the single result row.
    query = (
        "MATCH (a:Artist {name: $artist1})-[:PERFORMED_IN]-(t:Track)"
        "-[:PERFORMED_IN]-(b:Artist {name: $artist2}) "
        "RETURN count(DISTINCT t) as num_collabs"
    )
    with driver.session() as session:
        record = session.run(query, {"artist1": artist1, "artist2": artist2}).single()
        return record["num_collabs"]
#TODO: once we have flask up and running, make driver and sp part of Flask global context instead of passing as arguments
def populate_database(list_of_artists: List[str]):
    """Populate the neo4j database with song data of the given artists.

    For every name the artist is looked up, all album tracks are fetched,
    and each track credited to more than one artist is written in its own
    write transaction. Names for which no artist is found are skipped
    instead of aborting the whole run.

    Args:
        list_of_artists: artist names to look up and import.
    """
    for artist_name in tqdm(list_of_artists, unit=" artist", desc="Populating database"):
        artist = get_artist(artist_name)
        if artist is None:
            # get_artist returns None when the search has no results.
            logger.warning("Artist not found, skipping: %s", artist_name)
            continue

        # only write tracks with collaborators
        collaborations = [
            track for track in get_artist_tracks(artist) if len(track['artists']) > 1
        ]
        if not collaborations:
            continue

        # One session per artist rather than one per track.
        with driver.session() as session:
            for track in collaborations:
                session.write_transaction(write_track_to_database, track)
def save(driver_: GraphDatabase.driver, sp_: spotipy.Spotify):
    """Store the db driver and the sp connection as module globals.

    Temporary utility so that whatever file we are running from can hand
    these resources over to all the methods in this file; replace with
    Flask g once we get that set up.
    """
    global driver, sp
    driver, sp = driver_, sp_