In [6]:
from torch_geometric.datasets import OGB_MAG

dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
data = dataset[0]

In [7]:
data

HeteroData(
  [1mpaper[0m={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  [1mauthor[0m={ x=[1134649, 128] },
  [1minstitution[0m={ x=[8740, 128] },
  [1mfield_of_study[0m={ x=[59965, 128] },
  [1m(author, affiliated_with, institution)[0m={ edge_index=[2, 1043998] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 7145660] },
  [1m(paper, cites, paper)[0m={ edge_index=[2, 5416271] },
  [1m(paper, has_topic, field_of_study)[0m={ edge_index=[2, 7505078] }
)

In [8]:
import os
import networkx as nx
import json
from tqdm import tqdm
import torch
import torch_geometric

def nx2hetero(G):
	ids_by_type = {
		"playlist": {},
		"track": {},
		"artist": {},
		"album": {}
	}
	def node_id(type, id, exception=False):
		d = ids_by_type[type]
		if id not in d:
			if exception:
				raise Exception(f'node {id} not found for type {type}')
			d[id] = len(d)
		return d[id]

	
	node_features_by_type = {
		"playlist": [],
		"track": [],
		"artist": [],
		"album": []
	}
	for node in G.nodes(data=True):
		t = node[1]["node_type"]
		if t == "playlist":
			node_id("playlist", node[0])
			node_features_by_type["playlist"] += [[node[1]["num_followers"], 1, 1, 1, 1, 1]]
		elif t == "track":
			node_id("track", node[0])
			node_features_by_type["track"] += [[node[1]["duration"], 1, 1, 1, 1, 1]]
		elif t == "artist":
			node_id("artist", node[0])
			node_features_by_type["artist"] += [[1, 1, 1, 1, 1, 1]]
		elif t == "album":
			node_id("album", node[0])
			node_features_by_type["album"] += [[1, 1, 1, 1, 1, 1]]


	edge_index_by_type = {
		("playlist", "contains", "track"): [],
		("album", "includes", "track"): [],
		("artist", "authors", "track"): []
	}
	for edge in G.edges(data=True):
		if G[edge[0]][edge[1]]["edge_type"] == "track-playlist":
			s_id = node_id("track", edge[0], exception=True)
			t_id = node_id("playlist", edge[1], exception=True)
			
			edge_index_by_type[("playlist", "contains", "track")] += [(t_id, s_id)]
		elif G[edge[0]][edge[1]]["edge_type"] == "track-album":
			s_id = node_id("track", edge[0], exception=True)
			t_id = node_id("album", edge[1], exception=True)
			
			edge_index_by_type[("album", "includes", "track")] += [(t_id, s_id)]
		elif G[edge[0]][edge[1]]["edge_type"] == "track-artist":
			s_id = node_id("track", edge[0], exception=True)
			t_id = node_id("artist", edge[1], exception=True)
			
			edge_index_by_type[("artist", "authors", "track")] += [(t_id, s_id)]

	# construct HeteroData
	hetero = torch_geometric.data.HeteroData()

	# add initial node features
	hetero["playlist"].x = torch.FloatTensor(node_features_by_type["playlist"])
	hetero["track"].x = torch.FloatTensor(node_features_by_type["track"])
	hetero["artist"].x = torch.FloatTensor(node_features_by_type["artist"])
	hetero["album"].x = torch.FloatTensor(node_features_by_type["album"])	
	
	# add edge indices
	hetero["playlist", "contains", "track"].edge_index = torch.tensor(edge_index_by_type[("playlist", "contains", "track")]).t()
	hetero["album", "includes", "track"].edge_index = torch.tensor(edge_index_by_type[("album", "includes", "track")]).t()
	hetero["artist", "authors", "track"].edge_index = torch.tensor(edge_index_by_type[("artist", "authors", "track")]).t()

	return hetero


In [9]:
import pickle
base = "spotify_million_playlist_dataset"
pickles = base + "/pickles"
graph_path = os.path.join(pickles, "top-G-500.pkl")

G = pickle.load(open(graph_path, "rb"))

our_data = nx2hetero(G)

Exception: node spotify:album:0vlAYzvBDkRrRFpmR4v5MF not found for type track

In [None]:
len(list(nx.connected_components(G.to_undirected())))

7

In [None]:
# create training mask for playlist nodes
train_mask = torch.zeros(our_data["playlist"].x.shape[0], dtype=torch.bool)
train_mask[torch.randperm(train_mask.shape[0])[:int(train_mask.shape[0]*0.8)]] = True

our_data["playlist"].train_mask = train_mask

our_data["playlist"].y = torch.LongTensor([1]*our_data["playlist"].x.shape[0])

In [None]:
our_data["playlist", "contains", "track"].edge_index

tensor([[    0,   515,   664,  ...,   999,   999,   999],
        [    0,     0,     0,  ..., 35286, 35287, 35288]])

In [None]:
our_data.metadata()

(['playlist', 'track', 'artist', 'album'],
 [('playlist', 'contains', 'track'),
  ('album', 'includes', 'track'),
  ('artist', 'authors', 'track')])

In [None]:
data = our_data

In [None]:
data

HeteroData(
  [1mplaylist[0m={
    x=[1000, 6],
    train_mask=[1000],
    y=[1000]
  },
  [1mtrack[0m={ x=[35289, 6] },
  [1martist[0m={ x=[10091, 6] },
  [1malbum[0m={ x=[20469, 6] },
  [1m(playlist, contains, track)[0m={ edge_index=[2, 66331] },
  [1m(album, includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 35289] }
)

In [None]:
data.is_undirected()

False

In [None]:
homogeneous_data = data.to_homogeneous()
homogeneous_data

Data(edge_index=[2, 136909], x=[66849, 6], train_mask=[66849], y=[66849], node_type=[66849], edge_type=[136909])

In [None]:
data

HeteroData(
  [1mplaylist[0m={
    x=[1000, 6],
    train_mask=[1000],
    y=[1000]
  },
  [1mtrack[0m={ x=[35289, 6] },
  [1martist[0m={ x=[10091, 6] },
  [1malbum[0m={ x=[20469, 6] },
  [1m(playlist, contains, track)[0m={ edge_index=[2, 66331] },
  [1m(album, includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 35289] }
)

In [None]:
import torch_geometric.transforms as T

if not data.is_undirected():
    data = T.ToUndirected()(data)
# data = T.NormalizeFeatures()(data)
if data.has_isolated_nodes():
    data = T.RemoveIsolatedNodes()(data)

In [None]:
data.is_undirected(), data.has_isolated_nodes()

(True, False)

In [None]:
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
import torch
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels, normalize=True)
        self.conv2 = SAGEConv((-1, -1), hidden_channels, normalize=True)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    
class LinkPredictor(torch.nn.Module):
    def forward(self, x_playlist, x_track, playlist_track_edge):
        playlist_embedding = x_playlist[playlist_track_edge[0]]
        track_embedding = x_track[playlist_track_edge[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (playlist_embedding * track_embedding).sum(dim=-1)

class HeteroModel(torch.nn.Module):
    def __init__(self, hidden_channels, node_features, metadata):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        
        self.node_lin = {
            k: torch.nn.Linear(v.shape[1], hidden_channels, bias=True) for k, v in node_features.items()
        }
        
        for _, v in self.node_lin.items():
            torch.nn.init.xavier_uniform_(v.weight)

        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=metadata)

        self.classifier = LinkPredictor()

    def forward(self, data):
        x_dict = {
            k: self.node_lin[k](v) for k, v in data.x_dict.items()
        }
        
        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["playlist"],
            x_dict["track"],
            data["playlist", "contains", "track"].edge_label_index,
        )
        return pred
    
    def reset_parameters(self):
        for _, v in self.node_lin.items():
            torch.nn.init.xavier_uniform_(v.weight)
        self.gnn.reset_parameters()


model = HeteroModel(256, data.x_dict, data.metadata())
# model = model.to('cuda:0')

In [None]:
data.validate()

True

In [None]:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    edge_types=("playlist", "contains", "track"),
    rev_edge_types=("track", "rev_contains", "playlist"), 
)

train_data, val_data, test_data = transform(data)

In [None]:
from torch_geometric.loader import LinkNeighborLoader

edge_label_index = train_data["playlist", "contains", "track"].edge_label_index
edge_label = train_data["playlist", "contains", "track"].edge_label
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("playlist", "contains", "track"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

batch = next(iter(train_loader))

In [None]:
def train():
    model.train()

    total_examples = total_loss = 0

    sample_outputs = []
    for batch in tqdm(train_loader, desc='Training'):
        optimizer.zero_grad()
        # batch = batch.to('cuda:0')
        out = model(batch)
        
        sample_outputs.append(round(out[0].item(), 2))
        loss = F.cross_entropy(out,
                               batch["playlist", "contains", "track"].edge_label)
        loss.backward()
        optimizer.step()

        total_examples += len(out)
        total_loss += float(loss) * len(out)

    return total_loss / total_examples, sample_outputs

In [None]:
import tqdm
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
print(f"Device: '{device}'")
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.005)
for epoch in range(1, 100):
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        sampled_data.to(device)
        pred = model(sampled_data)
        ground_truth = sampled_data["playlist", "contains", "track"].edge_label
        loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

Device: 'cpu'


100%|███████████████████████████████████████████████| 125/125 [00:25<00:00,  4.98it/s]


Epoch: 001, Loss: 0.6432


100%|███████████████████████████████████████████████| 125/125 [00:25<00:00,  4.91it/s]


Epoch: 002, Loss: 0.6362


100%|███████████████████████████████████████████████| 125/125 [00:24<00:00,  5.03it/s]


Epoch: 003, Loss: 0.6344


100%|███████████████████████████████████████████████| 125/125 [00:25<00:00,  4.92it/s]


Epoch: 004, Loss: 0.6337


100%|███████████████████████████████████████████████| 125/125 [00:24<00:00,  5.01it/s]


Epoch: 005, Loss: 0.6366


100%|███████████████████████████████████████████████| 125/125 [00:24<00:00,  5.04it/s]


Epoch: 006, Loss: 0.6357


100%|███████████████████████████████████████████████| 125/125 [00:24<00:00,  5.03it/s]


Epoch: 007, Loss: 0.6390


100%|███████████████████████████████████████████████| 125/125 [00:24<00:00,  5.06it/s]


Epoch: 008, Loss: 0.6350


100%|███████████████████████████████████████████████| 125/125 [00:24<00:00,  5.03it/s]


Epoch: 009, Loss: 0.6383


100%|███████████████████████████████████████████████| 125/125 [00:25<00:00,  4.92it/s]


Epoch: 010, Loss: 0.6392


 54%|██████████████████████████                      | 68/125 [00:13<00:11,  4.94it/s]


KeyboardInterrupt: 

In [None]:
# Define the validation seed edges:
edge_label_index = val_data["playlist", "contains", "track"].edge_label_index
edge_label = val_data["playlist", "contains", "track"].edge_label
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 10],
    edge_label_index=(("playlist", "contains", "track"), edge_label_index),
    edge_label=edge_label,
    batch_size=3 * 128,
    shuffle=False,
)
sampled_data = next(iter(val_loader))

In [None]:
from sklearn.metrics import roc_auc_score
preds = []
ground_truths = []
for sampled_data in tqdm.tqdm(val_loader):
    with torch.no_grad():
        sampled_data.to(device)
        preds.append(model(sampled_data))
        ground_truths.append(sampled_data["playlist", "contains", "track"].edge_label)
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc = roc_auc_score(ground_truth, pred)
print()
print(f"Validation AUC: {auc:.4f}")

100%|█████████████████████████████████████████████████| 52/52 [00:01<00:00, 28.42it/s]


Validation AUC: 0.6357



