<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/pyg2neo/Movie_recommendations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install sentence_transformers torch_geometric torch-scatter torch-sparse torch neo4j



In [43]:
import torch
import pandas as pd
import numpy as np
from torch.nn import Linear
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero

from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected, RandomLinkSplit

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# To force CPU execution (e.g. while debugging), uncomment the next line:
#device = 'cpu'

In [3]:
print(device)

cuda


In [4]:
import os

from neo4j import GraphDatabase

# Connection settings are read from the environment so that real credentials
# never have to be committed together with the notebook:
#   NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD
# The fallbacks are the values this notebook originally hard-coded, kept so
# that "Run All" behaves as before when the variables are not set.
# NOTE(review): the fallback looks like a throw-away sandbox instance --
# confirm, and drop the defaults if this ever points at a real database.
url = os.environ.get('NEO4J_URL', 'bolt://3.86.43.255:7687')
user = os.environ.get('NEO4J_USER', 'neo4j')
password = os.environ.get('NEO4J_PASSWORD', 'company-science-journals')

# A single driver is shared by every query helper in this notebook.
driver = GraphDatabase.driver(url, auth=(user, password))

def fetch_data(query):
  """Run a Cypher query and return its result as a pandas DataFrame.

  Args:
    query: Cypher statement executed through the module-level ``driver``.

  Returns:
    A DataFrame with one row per returned record and one column per
    result key.
  """
  with driver.session() as session:
    result = session.run(query)
    # Materialise the records while the session is still open.
    rows = []
    for record in result:
      rows.append(record.values())
    column_names = result.keys()
    return pd.DataFrame(rows, columns=column_names)


In [59]:
# Project an in-memory GDS graph named 'movies' that contains the Movie and
# Person nodes, with ACTED_IN and DIRECTED relationships treated as undirected.
# NOTE(review): presumably this call fails when a graph named 'movies' already
# exists, so re-running the cell may need the old projection dropped first --
# confirm against the GDS version in use.
fetch_data("""
CALL gds.graph.create('movies', ['Movie', 'Person'], {ACTED_IN: {orientation:'UNDIRECTED'}, DIRECTED: {orientation:'UNDIRECTED'}})
""")

Unnamed: 0,nodeProjection,relationshipProjection,graphName,nodeCount,relationshipCount,createMillis
0,"{'Movie': {'properties': {}, 'label': 'Movie'}...","{'DIRECTED': {'orientation': 'UNDIRECTED', 'ag...",movies,28172,91834,226


In [60]:
# Run FastRP on the 'movies' projection and write a 56-dimensional embedding
# back to the database under the node property 'fastrp'. The movie query
# further down reads this property as an extra feature.
fetch_data("""
CALL gds.fastRP.write('movies', {writeProperty:'fastrp', embeddingDimension:56})
""")

Unnamed: 0,nodeCount,nodePropertiesWritten,createMillis,computeMillis,writeMillis,configuration
0,28172,28172,0,489,4367,"{'writeConcurrency': 4, 'normalizationStrength..."


In [84]:
def load_node(cypher, index_col, encoders=None, **kwargs):
    """Fetch nodes with a Cypher query and build their features and id mapping.

    Args:
        cypher: Cypher query handed to ``fetch_data``.
        index_col: Name of the result column that identifies each node.
        encoders: Optional dict mapping a result column name to a callable
            that converts that column into a tensor.
        **kwargs: Accepted but not used.

    Returns:
        A tuple ``(x, mapping)``. ``x`` is the concatenation (along the last
        dimension) of every encoder output, or ``None`` when no encoders are
        given. ``mapping`` maps each distinct identifier to a consecutive
        integer, in the order produced by ``index.unique()``.
    """
    nodes = fetch_data(cypher).set_index(index_col)

    mapping = {}
    for position, node_id in enumerate(nodes.index.unique()):
        mapping[node_id] = position

    if encoders is None:
        return None, mapping

    encoded_columns = []
    for column, encode in encoders.items():
        encoded_columns.append(encode(nodes[column]))
    return torch.cat(encoded_columns, dim=-1), mapping

In [85]:
def load_edge(cypher, src_index_col, src_mapping, dst_index_col, dst_mapping,
              encoders=None, **kwargs):
    """Fetch relationships with a Cypher query and build edge tensors.

    Args:
        cypher: Cypher query handed to ``fetch_data``.
        src_index_col: Result column holding the source node identifier.
        src_mapping: Dict translating source identifiers to node indices.
        dst_index_col: Result column holding the destination node identifier.
        dst_mapping: Dict translating destination identifiers to node indices.
        encoders: Optional dict mapping a result column name to a callable
            that converts that column into a tensor.
        **kwargs: Accepted but not used.

    Returns:
        A tuple ``(edge_index, edge_attr)``. ``edge_index`` is a tensor whose
        first row holds the source indices and whose second row holds the
        destination indices. ``edge_attr`` is the concatenation (along the
        last dimension) of every encoder output, or ``None`` when no encoders
        are given.

    Raises:
        KeyError: If an identifier is missing from its mapping.
    """
    edges = fetch_data(cypher)

    # Translate database identifiers into consecutive node indices.
    src = []
    for src_id in edges[src_index_col]:
        src.append(src_mapping[src_id])
    dst = []
    for dst_id in edges[dst_index_col]:
        dst.append(dst_mapping[dst_id])
    edge_index = torch.tensor([src, dst])

    if encoders is None:
        return edge_index, None

    encoded_columns = []
    for column, encode in encoders.items():
        encoded_columns.append(encode(edges[column]))
    return edge_index, torch.cat(encoded_columns, dim=-1)

In [86]:
class SequenceEncoder(object):
    """Encode a column of raw strings into sentence embeddings.

    Args:
        model_name: Name of the SentenceTransformer model to load.
        device: Device passed both to the model constructor and to every
            ``encode`` call.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        """Embed ``df.values`` and return the embeddings as a CPU tensor."""
        sentences = df.values
        embeddings = self.model.encode(
            sentences,
            show_progress_bar=True,
            convert_to_tensor=True,
            device=self.device,
        )
        # Always hand back a CPU tensor, whatever device did the encoding.
        return embeddings.cpu()

In [127]:
class GenresEncoder(object):
    """Multi-hot encode a column of separator-joined genre strings.

    Each cell is split by ``sep``; every distinct genre becomes one output
    column. Columns follow the sorted genre names, so the feature layout is
    identical across kernel restarts. (Iterating a plain ``set`` of strings
    can change order between interpreter runs because of hash randomisation,
    which would silently permute the feature columns.)

    Args:
        sep: Separator between genres inside one cell.
    """

    def __init__(self, sep='|'):
        self.sep = sep

    def __call__(self, df):
        """Return a float tensor of shape ``(len(df), num_genres)``.

        Entry ``[i, j]`` is 1 when row ``i`` lists the ``j``-th genre (in
        sorted order) and 0 otherwise.
        """
        # Split every cell once and reuse the result for both passes.
        split_rows = [cell.split(self.sep) for cell in df.values]

        # Sorted vocabulary -> deterministic column assignment.
        genres = sorted({genre for row in split_rows for genre in row})
        mapping = {genre: i for i, genre in enumerate(genres)}

        x = torch.zeros(len(df), len(mapping))
        for i, row in enumerate(split_rows):
            for genre in row:
                x[i, mapping[genre]] = 1
        return x

In [128]:
class IdentityEncoder(object):
    """Convert raw column values into a PyTorch tensor without re-encoding.

    Args:
        dtype: Target dtype applied to the tensor in the non-list case.
        is_list: When True, every cell is itself a sequence; each cell is
            turned into a tensor and the cells are stacked row by row.
    """

    def __init__(self, dtype=None, is_list=False):
        self.dtype = dtype
        self.is_list = is_list

    def __call__(self, df):
        """Return the column as a tensor (stacked per row when ``is_list``)."""
        values = df.values
        if not self.is_list:
            return torch.from_numpy(values).to(self.dtype)

        rows = [torch.tensor(cell) for cell in values]
        return torch.stack(rows)

In [129]:
# Load the User nodes. No encoders are passed, so user_x is None; the call is
# made for user_mapping (userId -> consecutive node index).
user_x, user_mapping = load_node("MATCH (u:User) RETURN u.userId AS userId" , index_col='userId')

In [162]:
# Fetch one row per movie: its id, title, genre names joined with '|', and the
# 'fastrp' property of the Movie node.
# NOTE(review): the MATCH pattern requires at least one IN_GENRE relationship,
# so movies without a genre are presumably absent from movie_mapping -- confirm
# that no rating refers to such a movie.
movie_query = """
MATCH (m:Movie)-[:IN_GENRE]->(genre:Genre)
WITH m, collect(genre.name) AS genres_list
RETURN m.movieId AS movieId, m.title AS title, apoc.text.join(genres_list, '|') AS genres, m.fastrp AS fastrp
"""

# Movie features are the concatenation of:
#   title  -> sentence embedding (SequenceEncoder)
#   genres -> multi-hot vector (GenresEncoder, default separator '|')
#   fastrp -> embedding list stacked as-is (IdentityEncoder with is_list=True)
movie_x, movie_mapping = load_node(
    movie_query, 
    index_col='movieId', encoders={
        'title': SequenceEncoder(),
        'genres': GenresEncoder(),
        'fastrp': IdentityEncoder(is_list=True)
    })

Batches:   0%|          | 0/286 [00:00<?, ?it/s]

In [163]:
# Fetch every (user)-[RATED]->(movie) relationship together with its rating.
rating_query = """
MATCH (u:User)-[r:RATED]->(m:Movie) 
RETURN u.userId AS userId, m.movieId AS movieId, r.rating AS rating
"""

# edge_index holds (user index, movie index) pairs translated through the two
# mappings; edge_label holds the ratings cast to torch.long.
# NOTE(review): casting to torch.long drops any fractional part of a rating --
# confirm that ratings are integral or that truncation is intended.
edge_index, edge_label = load_edge(
    rating_query,
    src_index_col='userId',
    src_mapping=user_mapping,
    dst_index_col='movieId',
    dst_mapping=movie_mapping,
    encoders={'rating': IdentityEncoder(dtype=torch.long)},
)

In [164]:
# Assemble the heterogeneous graph: 'user' and 'movie' nodes connected by
# ('user', 'rates', 'movie') edges whose label is the rating.
data = HeteroData()
data['user'].num_nodes = len(user_mapping)  # Users do not have any features.
# Add user node features for message passing:
# (an identity matrix, i.e. a one-hot vector per user of size num_users)
data['user'].x = torch.eye(data['user'].num_nodes, device=device)
# num_nodes was only needed to size the identity matrix above.
del data['user'].num_nodes

data['movie'].x = movie_x
data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_label
# Move the graph to the selected device; as the last expression of the cell
# this also displays a summary of the graph.
data.to(device, non_blocking=True)

HeteroData(
  [1muser[0m={ x=[671, 671] },
  [1mmovie[0m={ x=[9125, 460] },
  [1m(user, rates, movie)[0m={
    edge_index=[2, 100004],
    edge_label=[100004]
  }
)

In [165]:
# Seed PyTorch's global RNG so that the random edge split below (and any
# random initialisation that follows) is repeatable across kernel restarts.
# Without this, every run trains and evaluates on a different split.
SEED = 42
torch.manual_seed(SEED)

# 1. Add the reverse ('movie', 'rev_rates', 'user') relation so that message
#    passing can go from movies back to users as well.
data = ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

# 2. Perform a link-level split into training, validation, and test edges.
#    10% of the rating edges go to validation and 10% to test. No negative
#    edges are sampled (neg_sampling_ratio=0.0) because the task is rating
#    regression on existing edges.
transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)
train_data, val_data, test_data = transform(data)

In [166]:
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder producing node embeddings.

    Both layers are built with ``(-1, -1)`` input channels, so their input
    sizes are inferred lazily on the first forward pass.

    Args:
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1, a ReLU, then conv2, and return the result."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)


class EdgeDecoder(torch.nn.Module):
    """Predict one scalar per (user, movie) pair from their embeddings.

    The two embeddings are concatenated and passed through a two-layer MLP
    with a ReLU in between.

    Args:
        hidden_channels: Size of each node embedding and of the hidden layer.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Score every edge listed in ``edge_label_index``.

        Args:
            z_dict: Dict with node embeddings under the keys 'user' and
                'movie'.
            edge_label_index: Pair of index tensors; the first selects user
                embeddings, the second movie embeddings.

        Returns:
            A flat tensor with one prediction per edge.
        """
        user_idx, movie_idx = edge_label_index
        user_z = z_dict['user'][user_idx]
        movie_z = z_dict['movie'][movie_idx]
        pair = torch.cat((user_z, movie_z), dim=-1)

        hidden = self.lin1(pair).relu()
        scores = self.lin2(hidden)
        return scores.view(-1)

In [167]:
class Model(torch.nn.Module):
    """Heterogeneous GNN encoder followed by an edge-level rating decoder.

    Args:
        hidden_channels: Embedding size used by both encoder and decoder.
        metadata: Optional ``(node_types, edge_types)`` tuple describing the
            graph, as returned by ``HeteroData.metadata()``. When omitted,
            the metadata of the notebook-level ``data`` object is used, which
            is what the original implementation always did.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback to the notebook's global graph.
            metadata = data.metadata()
        encoder = GNNEncoder(hidden_channels, hidden_channels)
        # Replicate the homogeneous encoder for every node/edge type and sum
        # the messages that arrive at a node over different relations.
        self.encoder = to_hetero(encoder, metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the edges in ``edge_label_index``."""
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

In [168]:
# Build the model with 64-dimensional embeddings and place it on `device`.
model = Model(hidden_channels=64).to(device)


In [169]:
# Count how often each integer rating label occurs among the training edges,
# then weight every label by max_count / count so that rarer ratings
# contribute more to the loss.
# NOTE(review): a label value that never occurs has count 0 and therefore an
# infinite weight; this only matters if such a label is ever indexed.
weight = torch.bincount(train_data['user', 'movie'].edge_label)
weight = weight.max() / weight

def weighted_mse_loss(pred, target, weight=None):
    """Mean squared error with an optional per-label weight.

    Args:
        pred: Predicted values.
        target: Target labels; also used as indices into ``weight``.
        weight: Optional tensor indexed by label value. When ``None`` every
            sample has weight 1.

    Returns:
        The mean of ``weight[target] * (pred - target) ** 2``.
    """
    squared_error = (pred - target.to(pred.dtype)).pow(2)
    if weight is None:
        scale = 1.
    else:
        # Look up the weight of each sample's label.
        scale = weight[target].to(pred.dtype)
    return (scale * squared_error).mean()

In [170]:
# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

# Create the optimizer only after the dry run above, once the encoder's
# lazily initialised parameters exist.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [171]:
def train():
    """Run one full-batch optimisation step on ``train_data``.

    Returns:
        The weighted MSE training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    rating_edges = train_data['user', 'rates', 'movie']
    predictions = model(
        train_data.x_dict,
        train_data.edge_index_dict,
        rating_edges.edge_label_index,
    )
    loss = weighted_mse_loss(predictions, rating_edges.edge_label, weight)

    loss.backward()
    optimizer.step()
    return float(loss)

In [172]:
@torch.no_grad()
def test(data):
    """Evaluate the model on one split and return its RMSE.

    Args:
        data: A graph split exposing ``x_dict``, ``edge_index_dict`` and
            ('user', 'rates', 'movie') edges with ``edge_label_index`` and
            ``edge_label``.

    Returns:
        Root mean squared error between the predictions (clamped to the
        range [0, 5]) and the rating labels, as a Python float.
    """
    model.eval()
    rating_edges = data['user', 'rates', 'movie']
    scores = model(data.x_dict, data.edge_index_dict,
                   rating_edges.edge_label_index)
    # Keep the predictions inside the [0, 5] range before scoring.
    scores = scores.clamp(min=0, max=5)
    truth = rating_edges.edge_label.float()
    return float(F.mse_loss(scores, truth).sqrt())

In [173]:
# Full-batch training: every epoch takes one optimisation step and then
# reports the RMSE on the training, validation and test splits.
# range(1, 300) covers epochs 1 through 299.
for epoch in range(1, 300):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

Epoch: 001, Loss: 20.4778, Train: 3.2648, Val: 3.2541, Test: 3.2599
Epoch: 002, Loss: 16.5212, Train: 2.2975, Val: 2.2942, Test: 2.3017
Epoch: 003, Loss: 8.9984, Train: 1.2526, Val: 1.2495, Test: 1.2641
Epoch: 004, Loss: 11.1087, Train: 1.1689, Val: 1.1736, Test: 1.1874
Epoch: 005, Loss: 6.8246, Train: 1.7622, Val: 1.7627, Test: 1.7714
Epoch: 006, Loss: 6.7204, Train: 2.1124, Val: 2.1098, Test: 2.1169
Epoch: 007, Loss: 7.9894, Train: 2.1377, Val: 2.1348, Test: 2.1416
Epoch: 008, Loss: 8.0837, Train: 1.9314, Val: 1.9304, Test: 1.9373
Epoch: 009, Loss: 7.1595, Train: 1.5597, Val: 1.5622, Test: 1.5699
Epoch: 010, Loss: 6.0748, Train: 1.1838, Val: 1.1907, Test: 1.2001
Epoch: 011, Loss: 6.0290, Train: 1.0562, Val: 1.0647, Test: 1.0748
Epoch: 012, Loss: 6.8613, Train: 1.0655, Val: 1.0746, Test: 1.0837
Epoch: 013, Loss: 6.4660, Train: 1.2185, Val: 1.2257, Test: 1.2328
Epoch: 014, Loss: 5.6513, Train: 1.4770, Val: 1.4807, Test: 1.4861
Epoch: 015, Loss: 5.6393, Train: 1.6612, Val: 1.6626, Test:

Epoch: 131, Loss: 2.6561, Train: 1.0939, Val: 1.1589, Test: 1.1349
Epoch: 132, Loss: 2.6335, Train: 1.0437, Val: 1.1123, Test: 1.0901
Epoch: 133, Loss: 2.6119, Train: 1.0612, Val: 1.1293, Test: 1.1063
Epoch: 134, Loss: 2.5948, Train: 1.0668, Val: 1.1354, Test: 1.1120
Epoch: 135, Loss: 2.5845, Train: 1.0343, Val: 1.1059, Test: 1.0835
Epoch: 136, Loss: 2.5810, Train: 1.0957, Val: 1.1646, Test: 1.1396
Epoch: 137, Loss: 2.5875, Train: 1.0026, Val: 1.0769, Test: 1.0560
Epoch: 138, Loss: 2.6278, Train: 1.1503, Val: 1.2161, Test: 1.1901
Epoch: 139, Loss: 2.6634, Train: 0.9841, Val: 1.0567, Test: 1.0378
Epoch: 140, Loss: 2.7315, Train: 1.1025, Val: 1.1690, Test: 1.1454
Epoch: 141, Loss: 2.5738, Train: 1.0975, Val: 1.1661, Test: 1.1417
Epoch: 142, Loss: 2.5567, Train: 0.9885, Val: 1.0657, Test: 1.0442
Epoch: 143, Loss: 2.6642, Train: 1.1238, Val: 1.1971, Test: 1.1689
Epoch: 144, Loss: 2.5873, Train: 1.0262, Val: 1.1059, Test: 1.0804
Epoch: 145, Loss: 2.5059, Train: 1.0251, Val: 1.1049, Test: 1.

Epoch: 255, Loss: 2.2521, Train: 0.9947, Val: 1.1141, Test: 1.0861
Epoch: 256, Loss: 2.2566, Train: 1.0207, Val: 1.1394, Test: 1.1104
Epoch: 257, Loss: 2.2561, Train: 0.9996, Val: 1.1215, Test: 1.0927
Epoch: 258, Loss: 2.2498, Train: 0.9978, Val: 1.1211, Test: 1.0921
Epoch: 259, Loss: 2.2496, Train: 1.0164, Val: 1.1379, Test: 1.1085
Epoch: 260, Loss: 2.2516, Train: 0.9945, Val: 1.1157, Test: 1.0876
Epoch: 261, Loss: 2.2505, Train: 1.0094, Val: 1.1303, Test: 1.1016
Epoch: 262, Loss: 2.2460, Train: 1.0093, Val: 1.1321, Test: 1.1028
Epoch: 263, Loss: 2.2450, Train: 0.9916, Val: 1.1165, Test: 1.0876
Epoch: 264, Loss: 2.2481, Train: 1.0167, Val: 1.1400, Test: 1.1102
Epoch: 265, Loss: 2.2476, Train: 0.9959, Val: 1.1190, Test: 1.0904
Epoch: 266, Loss: 2.2443, Train: 1.0072, Val: 1.1301, Test: 1.1011
Epoch: 267, Loss: 2.2408, Train: 1.0080, Val: 1.1324, Test: 1.1029
Epoch: 268, Loss: 2.2401, Train: 0.9926, Val: 1.1187, Test: 1.0896
Epoch: 269, Loss: 2.2423, Train: 1.0167, Val: 1.1412, Test: 1.

In [149]:
num_movies = len(movie_mapping)
num_users = len(user_mapping)

for user_id in range(0,num_users): 

    row = torch.tensor([user_id] * num_movies)
    col = torch.arange(num_movies)
    edge_label_index = torch.stack([row, col], dim=0)

    pred = model(data.x_dict, data.edge_index_dict,
                 edge_label_index)
    pred = pred.clamp(min=0, max=5)
    mask = pred > 4