In [176]:
from torch_geometric.datasets import OGB_MAG

dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
data = dataset[0]

In [177]:
data

HeteroData(
  [1mpaper[0m={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  [1mauthor[0m={ x=[1134649, 128] },
  [1minstitution[0m={ x=[8740, 128] },
  [1mfield_of_study[0m={ x=[59965, 128] },
  [1m(author, affiliated_with, institution)[0m={ edge_index=[2, 1043998] },
  [1m(author, writes, paper)[0m={ edge_index=[2, 7145660] },
  [1m(paper, cites, paper)[0m={ edge_index=[2, 5416271] },
  [1m(paper, has_topic, field_of_study)[0m={ edge_index=[2, 7505078] }
)

In [178]:
import os
import networkx as nx
import json
from tqdm import tqdm
import torch
import torch_geometric

def nx2hetero(G):
	ids_by_type = {
		"playlist": {},
		"track": {},
		"artist": {},
		"album": {}
	}
	def node_id(type, id, exception=False):
		d = ids_by_type[type]
		if id not in d:
			if exception:
				raise Exception(f'node {id} not found for type {type}')
			d[id] = len(d)
		return d[id]

	
	node_features_by_type = {
		"playlist": [],
		"track": [],
		"artist": [],
		"album": []
	}
	for node in G.nodes(data=True):
		t = node[1]["node_type"]
		if t == "playlist":
			node_id("playlist", node[0])
			node_features_by_type["playlist"] += [node[1]["num_followers"]]
		elif t == "track":
			node_id("track", node[0])
			node_features_by_type["track"] += [node[1]["duration"]]
		elif t == "artist":
			node_id("artist", node[0])
			node_features_by_type["artist"] += [1]
		elif t == "album":
			node_id("album", node[0])
			node_features_by_type["album"] += [1]

	print(ids_by_type["playlist"])


	edge_index_by_type = {
		("playlist", "contains", "track"): [],
		("album", "includes", "track"): [],
		("artist", "authors", "track"): []
	}
	for edge in G.edges(data=True):
		if G[edge[0]][edge[1]]["edge_type"] == "track-playlist":
			s_id = node_id("track", edge[0], exception=True)
			t_id = node_id("playlist", edge[1], exception=True)
			
			edge_index_by_type[("playlist", "contains", "track")] += [(t_id, s_id)]
		elif G[edge[0]][edge[1]]["edge_type"] == "track-album":
			s_id = node_id("track", edge[0], exception=True)
			t_id = node_id("album", edge[1], exception=True)
			
			edge_index_by_type[("album", "includes", "track")] += [(t_id, s_id)]
		elif G[edge[0]][edge[1]]["edge_type"] == "track-artist":
			s_id = node_id("track", edge[0], exception=True)
			t_id = node_id("artist", edge[1], exception=True)
			
			edge_index_by_type[("artist", "authors", "track")] += [(t_id, s_id)]

	# construct HeteroData
	hetero = torch_geometric.data.HeteroData()

	# add initial node features
	hetero["playlist"].x = torch.FloatTensor(node_features_by_type["playlist"]).reshape(-1,1)
	hetero["track"].x = torch.FloatTensor(node_features_by_type["track"]).reshape(-1,1)
	hetero["artist"].x = torch.FloatTensor(node_features_by_type["artist"]).reshape(-1,1)
	hetero["album"].x = torch.FloatTensor(node_features_by_type["album"]).reshape(-1,1)

	# add edge indices
	hetero["playlist", "contains", "track"].edge_index = torch.tensor(edge_index_by_type[("playlist", "contains", "track")]).t()
	hetero["album", "includes", "track"].edge_index = torch.tensor(edge_index_by_type[("album", "includes", "track")]).t()
	hetero["artist", "authors", "track"].edge_index = torch.tensor(edge_index_by_type[("artist", "authors", "track")]).t()

	return hetero


In [179]:
import pickle
base = "spotify_million_playlist_dataset"
pickles = base + "/pickles"
graph_path = os.path.join(pickles, "G_example.pkl")

G = pickle.load(open(graph_path, "rb"))

our_data = nx2hetero(G)

{'spotify:playlist:42000': 0, 'spotify:playlist:42001': 1, 'spotify:playlist:42002': 2, 'spotify:playlist:42003': 3, 'spotify:playlist:42004': 4, 'spotify:playlist:42005': 5, 'spotify:playlist:42006': 6, 'spotify:playlist:42007': 7, 'spotify:playlist:42008': 8, 'spotify:playlist:42009': 9, 'spotify:playlist:42010': 10, 'spotify:playlist:42011': 11, 'spotify:playlist:42012': 12, 'spotify:playlist:42013': 13, 'spotify:playlist:42014': 14, 'spotify:playlist:42015': 15, 'spotify:playlist:42016': 16, 'spotify:playlist:42017': 17, 'spotify:playlist:42018': 18, 'spotify:playlist:42019': 19, 'spotify:playlist:42020': 20, 'spotify:playlist:42021': 21, 'spotify:playlist:42022': 22, 'spotify:playlist:42023': 23, 'spotify:playlist:42024': 24, 'spotify:playlist:42025': 25, 'spotify:playlist:42026': 26, 'spotify:playlist:42027': 27, 'spotify:playlist:42028': 28, 'spotify:playlist:42029': 29, 'spotify:playlist:42030': 30, 'spotify:playlist:42031': 31, 'spotify:playlist:42032': 32, 'spotify:playlist:4

In [180]:
# create training mask for playlist nodes
train_mask = torch.zeros(our_data["playlist"].x.shape[0], dtype=torch.bool)
train_mask[torch.randperm(train_mask.shape[0])[:int(train_mask.shape[0]*0.8)]] = True

our_data["playlist"].train_mask = train_mask

our_data["playlist"].y = torch.LongTensor([1]*our_data["playlist"].x.shape[0])

In [181]:
our_data["playlist", "contains", "track"].edge_index

tensor([[    0,   515,   664,  ...,   999,   999,   999],
        [    0,     0,     0,  ..., 35286, 35287, 35288]])

In [182]:
our_data.metadata()

(['playlist', 'track', 'artist', 'album'],
 [('playlist', 'contains', 'track'),
  ('album', 'includes', 'track'),
  ('artist', 'authors', 'track')])

In [183]:
data = our_data

In [184]:
data.is_undirected()

False

In [185]:
homogeneous_data = data.to_homogeneous()
homogeneous_data

Data(edge_index=[2, 136909], x=[66849, 1], train_mask=[66849], y=[66849], node_type=[66849], edge_type=[136909])

In [186]:
data

HeteroData(
  [1mplaylist[0m={
    x=[1000, 1],
    train_mask=[1000],
    y=[1000]
  },
  [1mtrack[0m={ x=[35289, 1] },
  [1martist[0m={ x=[10091, 1] },
  [1malbum[0m={ x=[20469, 1] },
  [1m(playlist, contains, track)[0m={ edge_index=[2, 66331] },
  [1m(album, includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 35289] }
)

In [187]:
import torch_geometric.transforms as T

if not data.is_undirected():
    data = T.ToUndirected()(data)
# data = T.NormalizeFeatures()(data)
if data.has_isolated_nodes():
    data = T.RemoveIsolatedNodes()(data)

In [188]:
data.is_undirected(), data.has_isolated_nodes()

(True, False)

In [189]:
data

HeteroData(
  [1mplaylist[0m={
    x=[1000, 1],
    train_mask=[1000],
    y=[1000]
  },
  [1mtrack[0m={ x=[35289, 1] },
  [1martist[0m={ x=[10091, 1] },
  [1malbum[0m={ x=[20469, 1] },
  [1m(playlist, contains, track)[0m={ edge_index=[2, 66331] },
  [1m(album, includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 35289] },
  [1m(track, rev_contains, playlist)[0m={ edge_index=[2, 66331] },
  [1m(track, rev_includes, album)[0m={ edge_index=[2, 35289] },
  [1m(track, rev_authors, artist)[0m={ edge_index=[2, 35289] }
)

In [190]:
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
import torch
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    
class LinkPredictor(torch.nn.Module):
    def forward(self, x_playlist, x_track, playlist_track_edge):
        playlist_embedding = x_playlist[playlist_track_edge[0]]
        track_embedding = x_track[playlist_track_edge[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (playlist_embedding * track_embedding).sum(dim=-1)

class HeteroModel(torch.nn.Module):
    def __init__(self, hidden_channels, node_features, metadata):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        
        self.node_lin = {
            k: torch.nn.Linear(v.shape[1], hidden_channels) for k, v in node_features.items()
        }
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=metadata)

        self.classifier = LinkPredictor()

    def forward(self, data):
        x_dict = {
            k: self.node_lin[k](v) for k, v in data.x_dict.items()
        }
        
        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["playlist"],
            x_dict["track"],
            data["playlist", "contains", "track"].edge_index,
        )
        return pred


model = HeteroModel(64, data.x_dict, data.metadata())
# model = model.to('cuda:0')

In [191]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [192]:
data.validate()

True

In [193]:
data

HeteroData(
  [1mplaylist[0m={
    x=[1000, 1],
    train_mask=[1000],
    y=[1000]
  },
  [1mtrack[0m={ x=[35289, 1] },
  [1martist[0m={ x=[10091, 1] },
  [1malbum[0m={ x=[20469, 1] },
  [1m(playlist, contains, track)[0m={ edge_index=[2, 66331] },
  [1m(album, includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 35289] },
  [1m(track, rev_contains, playlist)[0m={ edge_index=[2, 66331] },
  [1m(track, rev_includes, album)[0m={ edge_index=[2, 35289] },
  [1m(track, rev_authors, artist)[0m={ edge_index=[2, 35289] }
)

In [194]:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
    edge_types=("playlist", "contains", "track"),
    rev_edge_types=("track", "rev_contains", "playlist"), 
)

train_data, val_data, test_data = transform(data)

In [195]:
train_data

HeteroData(
  [1mplaylist[0m={
    x=[1000, 1],
    train_mask=[1000],
    y=[1000]
  },
  [1mtrack[0m={ x=[35289, 1] },
  [1martist[0m={ x=[10091, 1] },
  [1malbum[0m={ x=[20469, 1] },
  [1m(playlist, contains, track)[0m={
    edge_index=[2, 37146],
    edge_label=[15919],
    edge_label_index=[2, 15919]
  },
  [1m(album, includes, track)[0m={ edge_index=[2, 35289] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 35289] },
  [1m(track, rev_contains, playlist)[0m={ edge_index=[2, 37146] },
  [1m(track, rev_includes, album)[0m={ edge_index=[2, 35289] },
  [1m(track, rev_authors, artist)[0m={ edge_index=[2, 35289] }
)

In [196]:
from torch_geometric.loader import LinkNeighborLoader

edge_label_index = train_data["playlist", "contains", "track"].edge_label_index
edge_label = train_data["playlist", "contains", "track"].edge_label
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("playlist", "contains", "track"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

batch = next(iter(train_loader))

In [204]:
batch

HeteroData(
  [1mplaylist[0m={
    x=[913, 1],
    train_mask=[913],
    y=[913]
  },
  [1mtrack[0m={ x=[8496, 1] },
  [1martist[0m={ x=[2513, 1] },
  [1malbum[0m={ x=[3886, 1] },
  [1m(playlist, contains, track)[0m={
    edge_index=[2, 11873],
    edge_label=[384],
    edge_label_index=[2, 384],
    input_id=[128]
  },
  [1m(album, includes, track)[0m={ edge_index=[2, 5010] },
  [1m(artist, authors, track)[0m={ edge_index=[2, 5010] },
  [1m(track, rev_contains, playlist)[0m={ edge_index=[2, 7601] },
  [1m(track, rev_includes, album)[0m={ edge_index=[2, 1293] },
  [1m(track, rev_authors, artist)[0m={ edge_index=[2, 2205] }
)

In [200]:
def train():
    model.train()

    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # batch = batch.to('cuda:0')
        batch_size = 100
        out = model(batch)
        loss = F.cross_entropy(out,
                               batch["playlist", "contains", "track"].edge_label)
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        print(f'Loss: {loss:.4f}')
        total_loss += float(loss) * batch_size

    return total_loss / total_examples

In [201]:
train()

AttributeError: 'NodeStorage' object has no attribute 'batch_size'