
## IGMC (Inductive Graph-based Matrix Completion) implementation


In [None]:
!pip install torch-sparse
!pip install torch-scatter
!pip install torch-geometric

In [None]:
import numpy as np
import pandas as pd
import multiprocessing as mp
import argparse
import scipy.io as sio
import random
import pdb
import scipy.sparse as sp
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset
import warnings
from urllib.request import urlopen
from zipfile import ZipFile
import shutil
import os.path
from tqdm import tqdm
import networkx as nx
import pickle as pkl
import os, sys, pdb, math, time
from copy import deepcopy

# Silence SciPy's SparseEfficiencyWarning. Presumably this is because
# subgraph_extraction_labeling assigns into a CSR matrix (subgraph[0, 0] = 0),
# which SciPy may report as inefficient.
warnings.simplefilter('ignore', sp.SparseEfficiencyWarning)
import torch.multiprocessing
# NOTE(review): presumably switched to 'file_system' to avoid running out of
# file descriptors when many tensors are shared between worker processes
# (multiprocessing.Pool / DataLoader workers) — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

In [None]:
def map_data(data):
    """Relabel arbitrary ids as consecutive integers ``0 .. n-1``.

    New ids are assigned in ascending order of the original values.

    Args:
        data: iterable of hashable, mutually comparable ids.

    Returns:
        Tuple ``(relabelled, id_dictionary, n)``:
        - ``relabelled``: numpy array with each element of ``data`` replaced
          by its new id.
        - ``id_dictionary``: mapping ``{original_id: new_id}``.
        - ``n``: number of distinct ids.
    """
    ordered_ids = sorted(set(data))
    id_dictionary = dict(zip(ordered_ids, range(len(ordered_ids))))
    relabelled = np.array([id_dictionary[item] for item in data])
    return relabelled, id_dictionary, len(ordered_ids)

In [None]:
def load_data(fname, seed=1234, verbose=True):
    """Load MovieLens-100K ratings and side features from the working directory.

    Reads ``./u.data`` (tab-separated user / item / rating / timestamp rows),
    ``./u.item`` (pipe-separated movie metadata with genre flags) and
    ``./u.user`` (pipe-separated user demographics). The rating rows are
    shuffled with Python's ``random`` module. User and item ids are then
    relabelled to consecutive integers with ``map_data``.

    Args:
        fname: Not used; all file paths above are hard-coded.
            NOTE(review): kept only so existing call sites keep working.
        seed: Passed to ``random.seed`` before shuffling. This reseeds the
            global ``random`` module.
        verbose: If True, print dataset size statistics.

    Returns:
        Tuple ``(num_users, num_movies, u_nodes, v_nodes, ratings,
        u_features, v_features)``:
        - ``u_nodes`` (int64) and ``v_nodes`` (int32): relabelled user and
          item ids of every rating.
        - ``ratings``: float64.
        - ``u_features``: ``scipy.sparse.csr_matrix`` of shape
          ``(num_users, 2 + #occupations)``. Columns are raw age, gender
          (M=0, F=1), then a one-hot occupation.
        - ``v_features``: ``csr_matrix`` of shape ``(num_movies, #genres)``
          holding the 0/1 genre flags from column 'Action' onward.
    """
    sep = '\t'
    filename = './u.data'
    dtypes = {
        'u_nodes': np.int32, 'v_nodes': np.int32,
        'ratings': np.float32, 'timestamp': np.float64}

    data = pd.read_csv(
        filename, sep=sep, header=None,
        names=['u_nodes', 'v_nodes', 'ratings', 'timestamp'], dtype=dtypes)
    data_array = data.values.tolist()
    random.seed(seed)
    random.shuffle(data_array)
    data_array = np.array(data_array)

    u_nodes_ratings = data_array[:, 0].astype(dtypes['u_nodes'])
    v_nodes_ratings = data_array[:, 1].astype(dtypes['v_nodes'])
    ratings = data_array[:, 2].astype(dtypes['ratings'])

    u_nodes_ratings, u_dict, num_users = map_data(u_nodes_ratings)
    v_nodes_ratings, v_dict, num_movies = map_data(v_nodes_ratings)

    u_nodes_ratings, v_nodes_ratings = u_nodes_ratings.astype(np.int64), v_nodes_ratings.astype(np.int32)
    ratings = ratings.astype(np.float64)

    # Movie features (genre flags).
    sep = r'|'
    movie_file = './u.item'
    movie_headers = ['movie id', 'movie title', 'release date', 'video release date',
                     'IMDb URL', 'unknown', 'Action', 'Adventure', 'Animation',
                     'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy',
                     'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
                     'Thriller', 'War', 'Western']
    # u.item is Latin-1 encoded (accented titles). The default UTF-8 decoder
    # fails on it, so match load_official_trainvaltest_split.
    movie_df = pd.read_csv(movie_file, sep=sep, header=None,
                           names=movie_headers, engine='python', encoding='latin-1')

    genre_headers = movie_df.columns.values[6:]
    num_genres = genre_headers.shape[0]

    v_features = np.zeros((num_movies, num_genres), dtype=np.float32)
    for movie_id, g_vec in zip(movie_df['movie id'].values.tolist(), movie_df[genre_headers].values.tolist()):
        # Only movies that appear in the ratings file have a relabelled id.
        if movie_id in v_dict:
            v_features[v_dict[movie_id], :] = g_vec

    # User features.
    sep = r'|'
    users_file = './u.user'
    users_headers = ['user id', 'age', 'gender', 'occupation', 'zip code']
    users_df = pd.read_csv(users_file, sep=sep, header=None,
                           names=users_headers, engine='python')

    # Sort so the occupation -> column mapping does not depend on string-hash
    # randomisation (set iteration order differs between interpreter runs).
    occupation = sorted(set(users_df['occupation'].values.tolist()))

    gender_dict = {'M': 0., 'F': 1.}
    occupation_dict = {f: i for i, f in enumerate(occupation, start=2)}

    num_feats = 2 + len(occupation_dict)

    u_features = np.zeros((num_users, num_feats), dtype=np.float32)
    for _, row in users_df.iterrows():
        u_id = row['user id']
        if u_id in u_dict:
            # Age is stored unscaled here (load_official_trainvaltest_split
            # divides it by the maximum age).
            u_features[u_dict[u_id], 0] = row['age']
            u_features[u_dict[u_id], 1] = gender_dict[row['gender']]
            u_features[u_dict[u_id], occupation_dict[row['occupation']]] = 1.

    u_features = sp.csr_matrix(u_features)
    v_features = sp.csr_matrix(v_features)

    if verbose:
        print('Number of users = %d' % num_users)
        print('Number of items = %d' % num_movies)
        print('Number of links = %d' % ratings.shape[0])
        print('Fraction of positive links = %.4f' % (float(ratings.shape[0]) / (num_users * num_movies),))

    return num_users, num_movies, u_nodes_ratings, v_nodes_ratings, ratings, u_features, v_features

In [None]:
class SparseRowIndex:
    """Fast row slicing for a ``scipy.sparse.csr_matrix``.

    Each row's ``data`` / ``indices`` slice is pre-split and stored in a 1-D
    object array, so selecting a handful of rows only concatenates those
    pieces.

    Attributes:
        data: 1-D object array; element ``k`` is row ``k``'s stored values.
        indices: 1-D object array; element ``k`` is row ``k``'s column indices.
        indptr: int64 array; element ``k`` is row ``k``'s number of stored
            entries (not a cumulative pointer).
        shape: shape of the source matrix.
    """

    def __init__(self, csr_matrix):
        n_rows = csr_matrix.shape[0]
        # np.empty + per-element assignment guarantees a 1-D object array.
        # np.array(list_of_arrays, dtype=object) silently builds a 2-D object
        # array when every row has the same nnz, which later yields
        # object-dtype slices.
        self.data = np.empty(n_rows, dtype=object)
        self.indices = np.empty(n_rows, dtype=object)
        bounds = zip(csr_matrix.indptr[:-1], csr_matrix.indptr[1:])
        for row, (row_start, row_end) in enumerate(bounds):
            self.data[row] = csr_matrix.data[row_start:row_end]
            self.indices[row] = csr_matrix.indices[row_start:row_end]
        self.indptr = np.diff(csr_matrix.indptr).astype(np.int64)  # nnz per row
        self.shape = csr_matrix.shape
        self._data_dtype = csr_matrix.data.dtype
        self._index_dtype = csr_matrix.indices.dtype

    def __getitem__(self, row_selector):
        """Return the selected rows, in selection order, as a new CSR matrix.

        ``row_selector`` may be a scalar row index, a sequence / array of row
        indices, a boolean mask or a slice. An empty selection yields a
        ``(0, n_cols)`` matrix.
        """
        if np.isscalar(row_selector):
            row_selector = [row_selector]
        counts = self.indptr[row_selector]
        n_selected = len(counts)
        if n_selected:
            indices = np.concatenate(list(self.indices[row_selector]))
            data = np.concatenate(list(self.data[row_selector]))
        else:
            # np.concatenate refuses an empty sequence.
            indices = np.empty(0, dtype=self._index_dtype)
            data = np.empty(0, dtype=self._data_dtype)
        indptr = np.concatenate(([0], np.cumsum(counts)))
        return sp.csr_matrix((data, indices, indptr), shape=(n_selected, self.shape[1]))

In [None]:
class SparseColIndex:
    """Fast column slicing for a ``scipy.sparse.csc_matrix``.

    Each column's ``data`` / ``indices`` slice is pre-split and stored in a
    1-D object array, so selecting a handful of columns only concatenates
    those pieces.

    Attributes:
        data: 1-D object array; element ``k`` is column ``k``'s stored values.
        indices: 1-D object array; element ``k`` is column ``k``'s row indices.
        indptr: int64 array; element ``k`` is column ``k``'s number of stored
            entries (not a cumulative pointer).
        shape: shape of the source matrix.
    """

    def __init__(self, csc_matrix):
        n_cols = csc_matrix.shape[1]
        # np.empty + per-element assignment guarantees a 1-D object array even
        # when all columns have the same nnz (np.array(..., dtype=object) would
        # build a 2-D object array in that case).
        self.data = np.empty(n_cols, dtype=object)
        self.indices = np.empty(n_cols, dtype=object)
        bounds = zip(csc_matrix.indptr[:-1], csc_matrix.indptr[1:])
        for col, (col_start, col_end) in enumerate(bounds):
            self.data[col] = csc_matrix.data[col_start:col_end]
            self.indices[col] = csc_matrix.indices[col_start:col_end]
        self.indptr = np.diff(csc_matrix.indptr).astype(np.int64)  # nnz per column
        self.shape = csc_matrix.shape
        self._data_dtype = csc_matrix.data.dtype
        self._index_dtype = csc_matrix.indices.dtype

    def __getitem__(self, col_selector):
        """Return the selected columns, in selection order, as a new CSC matrix.

        ``col_selector`` may be a scalar column index, a sequence / array of
        column indices, a boolean mask or a slice. An empty selection yields
        an ``(n_rows, 0)`` matrix.
        """
        if np.isscalar(col_selector):
            col_selector = [col_selector]
        counts = self.indptr[col_selector]
        n_selected = len(counts)
        if n_selected:
            indices = np.concatenate(list(self.indices[col_selector]))
            data = np.concatenate(list(self.data[col_selector]))
        else:
            # np.concatenate refuses an empty sequence.
            indices = np.empty(0, dtype=self._index_dtype)
            data = np.empty(0, dtype=self._data_dtype)
        indptr = np.concatenate(([0], np.cumsum(counts)))
        return sp.csc_matrix((data, indices, indptr), shape=(self.shape[0], n_selected))


In [None]:
class GraphDataset(InMemoryDataset):
    """In-memory PyG dataset of pre-extracted enclosing subgraphs, one per link.

    All subgraphs are built once in ``process`` via ``links2subgraphs`` and
    saved with ``torch.save`` to ``self.processed_paths[0]``. The constructor
    then loads that file back.

    NOTE(review): PyG's dataset base class normally skips ``process`` when
    the processed file already exists under ``root``. The file name only
    encodes ``max_num``, so changing ``h`` / ``max_nodes_per_hop`` / the
    links may silently reuse a stale cache — confirm and clear ``root``
    when settings change.
    """
    def __init__(self, root, A, links, labels, h, sample_ratio, max_nodes_per_hop,
                 u_features, v_features, class_values, max_num=None, parallel=True):
        # Row- and column-sliceable views of the rating matrix A (adj_train in
        # this notebook), used for neighbour lookups during extraction.
        self.Arow = SparseRowIndex(A)
        self.Acol = SparseColIndex(A.tocsc())
        self.links = links
        self.labels = labels
        self.h = h
        self.sample_ratio = sample_ratio
        self.max_nodes_per_hop = max_nodes_per_hop
        self.u_features = u_features
        self.v_features = v_features
        self.class_values = class_values
        self.parallel = parallel
        self.max_num = max_num
        if max_num is not None:
            # Keep a fixed random subset of max_num links. This reseeds
            # NumPy's *global* RNG as a side effect.
            np.random.seed(123)
            num_links = len(links[0])
            perm = np.random.permutation(num_links)
            perm = perm[:max_num]
            self.links = (links[0][perm], links[1][perm])
            self.labels = labels[perm]
        # All attributes above must already be set here: the base-class
        # constructor presumably calls processed_file_names / process(),
        # which read them.
        super(GraphDataset, self).__init__(root)
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may refuse to unpickle PyG Data objects — confirm against the
        # installed torch version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Cache file name: 'data.pt', or 'data_<max_num>.pt' when subsampled."""
        name = 'data.pt'
        if self.max_num is not None:
            name = 'data_{}.pt'.format(self.max_num)
        return [name]

    def process(self):
        """Extract every link's subgraph, collate them and save to disk."""
        data_list = links2subgraphs(self.Arow, self.Acol, self.links, self.labels, self.h,
                                    self.sample_ratio, self.max_nodes_per_hop,
                                    self.u_features, self.v_features,
                                    self.class_values, self.parallel)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        # Drop the per-graph list; only the collated tensors are needed now.
        del data_list

In [None]:
class DynamicDataset(Dataset):
    """PyG dataset that extracts each link's enclosing subgraph on demand.

    Nothing is precomputed or cached: every ``get`` call re-runs
    ``subgraph_extraction_labeling``. When a hop is truncated to
    ``max_nodes_per_hop`` nodes, the retained neighbours are drawn with
    ``random.sample``. Repeated accesses to the same index may therefore
    yield different subgraphs.

    If ``max_num`` is given, a fixed subset of ``max_num`` links is kept. It
    is chosen after reseeding NumPy's global RNG with 123.
    """

    def __init__(self, root, A, links, labels, h, sample_ratio, max_nodes_per_hop,
                 u_features, v_features, class_values, max_num=None):
        super().__init__(root)
        # Row / column slicers over the rating matrix A.
        self.Arow = SparseRowIndex(A)
        self.Acol = SparseColIndex(A.tocsc())
        # Extraction settings forwarded verbatim to subgraph_extraction_labeling.
        self.h = h
        self.sample_ratio = sample_ratio
        self.max_nodes_per_hop = max_nodes_per_hop
        self.u_features = u_features
        self.v_features = v_features
        self.class_values = class_values

        if max_num is None:
            self.links = links
            self.labels = labels
        else:
            np.random.seed(123)
            keep = np.random.permutation(len(links[0]))[:max_num]
            user_links, item_links = links[0], links[1]
            self.links = (user_links[keep], item_links[keep])
            self.labels = labels[keep]

    def len(self):
        """Number of links (one subgraph per link)."""
        return len(self.links[0])

    def get(self, idx):
        """Extract and return the labelled enclosing subgraph of link ``idx``."""
        user, item = self.links[0][idx], self.links[1][idx]
        extracted = subgraph_extraction_labeling(
            (user, item), self.Arow, self.Acol, self.h, self.sample_ratio,
            self.max_nodes_per_hop, self.u_features, self.v_features,
            self.class_values, self.labels[idx],
        )
        return construct_pyg_graph(*extracted)


In [None]:
def links2subgraphs(Arow,
                    Acol,
                    links,
                    labels,
                    h=1,
                    sample_ratio=1.0,
                    max_nodes_per_hop=None,
                    u_features=None,
                    v_features=None,
                    class_values=None,
                    parallel=True):
    """Extract the enclosing subgraph of every link and convert it to a PyG ``Data``.

    Args:
        Arow, Acol: row / column slicers (``SparseRowIndex`` /
            ``SparseColIndex``) of the rating matrix.
        links: pair ``(user_indices, item_indices)`` of equal-length sequences.
        labels: per-link label passed as ``y`` to ``subgraph_extraction_labeling``.
        h, sample_ratio, max_nodes_per_hop, u_features, v_features,
        class_values: forwarded unchanged to ``subgraph_extraction_labeling``.
        parallel: if True, extract with a ``multiprocessing.Pool`` of
            ``mp.cpu_count()`` workers; otherwise extract serially.

    Returns:
        List of ``Data`` objects in the same order as ``links``, in both modes.
    """
    print('Enclosing subgraph extraction begins...')
    if not parallel:
        g_list = []
        with tqdm(total=len(links[0])) as pbar:
            for i, j, g_label in zip(links[0], links[1], labels):
                tmp = subgraph_extraction_labeling(
                    (i, j), Arow, Acol, h, sample_ratio, max_nodes_per_hop, u_features,
                    v_features, class_values, g_label
                )
                g_list.append(construct_pyg_graph(*tmp))
                pbar.update(1)
        return g_list

    start = time.time()
    task_args = [
        ((i, j), Arow, Acol, h, sample_ratio, max_nodes_per_hop, u_features,
         v_features, class_values, g_label)
        for i, j, g_label in zip(links[0], links[1], labels)
    ]
    # The Pool context manager terminates the workers on exit, including when
    # an exception (or KeyboardInterrupt) escapes the progress loop.
    with mp.Pool(mp.cpu_count()) as pool:
        async_result = pool.starmap_async(subgraph_extraction_labeling, task_args)
        # NOTE: `_number_left` is a private attribute counting unfinished task
        # *chunks*, so the bar advances per chunk rather than per link.
        remaining = async_result._number_left
        with tqdm(total=remaining) as pbar:
            while True:
                # Read the counter once so no progress is lost between reads.
                left = async_result._number_left
                pbar.update(remaining - left)
                remaining = left
                if async_result.ready():
                    break
                time.sleep(1)
            pbar.update(remaining)
        results = async_result.get()
    end = time.time()
    print("Time elapsed for subgraph extraction: {}s".format(end-start))
    print("Transforming to pytorch_geometric graphs...")
    # Popping from the end drops each raw result as soon as it is converted.
    # Reverse first so the output keeps the order of `links`, as in the
    # serial path.
    results.reverse()
    g_list = []
    with tqdm(total=len(results)) as pbar:
        while results:
            g_list.append(construct_pyg_graph(*results.pop()))
            pbar.update(1)
    end2 = time.time()
    print("Time elapsed for transforming to pytorch_geometric graphs: {}s".format(end2-end))
    return g_list

In [None]:
def subgraph_extraction_labeling(ind, Arow, Acol, h=1, sample_ratio=1.0, max_nodes_per_hop=None,
                                 u_features=None, v_features=None, class_values=None,
                                 y=1):
    """Extract and label the h-hop enclosing subgraph around one (user, item) link.

    Starting from the target user ``ind[0]`` and item ``ind[1]``, the search
    alternates hops: users expand to rated items via ``Arow`` and items
    expand to users via ``Acol``. Each hop's new nodes may be subsampled by
    ``sample_ratio`` and capped at ``max_nodes_per_hop`` using
    ``random.sample``.

    Args:
        ind: ``(user_index, item_index)`` of the target link.
        Arow, Acol: row / column slicers of the rating matrix. Stored values
            are rating class + 1, so 0 means "no edge".
        h: number of hops.
        sample_ratio: fraction of each hop's new nodes to keep (< 1 samples).
        max_nodes_per_hop: optional cap on new nodes per hop and side.
        u_features, v_features: optional user / item feature matrices.
        class_values: sequence mapping a class index to its target value.
        y: class index of the target link.

    Returns:
        Tuple ``(u, v, r, node_labels, max_node_label, y, node_features)``:
        - ``u``, ``v``: edge endpoints. Item node ids are offset by the number
          of user nodes.
        - ``r``: relation index of each edge (stored value - 1).
        - ``node_labels``: ``2*dist`` for users and ``2*dist + 1`` for items.
          The target user is labelled 0 and the target item 1.
        - ``max_node_label``: ``2*h + 1``.
        - ``y``: ``class_values[y]``.
        - ``node_features``: ``[target_user_row, target_item_row]`` when both
          feature matrices are given, else ``None``.
    """
    u_nodes, v_nodes = [ind[0]], [ind[1]]
    u_dist, v_dist = [0], [0]
    u_visited, v_visited = {ind[0]}, {ind[1]}
    u_fringe, v_fringe = {ind[0]}, {ind[1]}
    for dist in range(1, h + 1):
        # Users' neighbours are items (rows of A); items' neighbours are users (columns of A).
        v_fringe, u_fringe = neighbors(u_fringe, Arow), neighbors(v_fringe, Acol)
        u_fringe = u_fringe - u_visited
        v_fringe = v_fringe - v_visited
        u_visited = u_visited.union(u_fringe)
        v_visited = v_visited.union(v_fringe)
        # random.sample no longer accepts sets (TypeError since Python 3.11).
        # list(set) is exactly what older versions did internally (via tuple),
        # so the sampled nodes are unchanged.
        if sample_ratio < 1.0:
            u_fringe = random.sample(list(u_fringe), int(sample_ratio * len(u_fringe)))
            v_fringe = random.sample(list(v_fringe), int(sample_ratio * len(v_fringe)))
        if max_nodes_per_hop is not None:
            if max_nodes_per_hop < len(u_fringe):
                u_fringe = random.sample(list(u_fringe), max_nodes_per_hop)
            if max_nodes_per_hop < len(v_fringe):
                v_fringe = random.sample(list(v_fringe), max_nodes_per_hop)
        if len(u_fringe) == 0 and len(v_fringe) == 0:
            break
        u_nodes = u_nodes + list(u_fringe)
        v_nodes = v_nodes + list(v_fringe)
        u_dist = u_dist + [dist] * len(u_fringe)
        v_dist = v_dist + [dist] * len(v_fringe)
    subgraph = Arow[u_nodes][:, v_nodes]
    # Hide the target link itself (row 0 = target user, column 0 = target item).
    subgraph[0, 0] = 0

    # Prepare the pyg graph constructor input.
    u, v, r = sp.find(subgraph)  # r is 1, 2, ... (rating labels + 1)
    v += len(u_nodes)
    r = r - 1  # transform r back to the rating label
    node_labels = [x * 2 for x in u_dist] + [x * 2 + 1 for x in v_dist]
    max_node_label = 2 * h + 1
    y = class_values[y]

    # Only the target user's and target item's feature rows are output.
    # Slicing with a one-element list keeps the same row type as the original
    # `features[nodes][0]`.
    node_features = None
    if u_features is not None and v_features is not None:
        node_features = [u_features[u_nodes[:1]][0], v_features[v_nodes[:1]][0]]

    return u, v, r, node_labels, max_node_label, y, node_features


In [None]:
def construct_pyg_graph(u, v, r, node_labels, max_node_label, y, node_features):
    """Build a torch_geometric ``Data`` graph from ``subgraph_extraction_labeling`` output.

    Args:
        u, v: edge endpoint node indices (item ids already offset past the users).
        r: relation (rating class) index of each edge.
        node_labels: integer structural label of every node. It is one-hot
            encoded into ``data.x`` with ``max_node_label + 1`` columns.
        max_node_label: largest possible node label.
        y: regression target, stored as ``data.y`` with shape ``(1,)``.
        node_features: one of:
            - ``None``;
            - a ``[user_row, item_row]`` list, stored as ``data.u_feature`` /
              ``data.v_feature`` with shape ``(1, F)``. Rows may be dense 1-D
              arrays or 1 x F scipy sparse matrices;
            - a per-node matrix appended column-wise to ``data.x``.

    Returns:
        The assembled ``Data`` object.
    """
    u, v = torch.LongTensor(u), torch.LongTensor(v)
    r = torch.LongTensor(r)
    # Every rating edge is stored in both directions with the same relation type.
    edge_index = torch.stack([torch.cat([u, v]), torch.cat([v, u])], 0)
    edge_type = torch.cat([r, r])
    x = torch.FloatTensor(one_hot(node_labels, max_node_label + 1))
    y = torch.FloatTensor([y])
    data = Data(x, edge_index, edge_type=edge_type, y=y)

    if node_features is not None:
        if isinstance(node_features, list):  # [u_feature, v_feature]
            u_feature, v_feature = node_features
            # Rows sliced from this file's csr feature matrices arrive as 1 x F
            # sparse matrices, which cannot be passed to torch.FloatTensor.
            # Densify them to 1-D so both cases give a (1, F) tensor below.
            if sp.issparse(u_feature):
                u_feature = u_feature.toarray().ravel()
            if sp.issparse(v_feature):
                v_feature = v_feature.toarray().ravel()
            data.u_feature = torch.FloatTensor(u_feature).unsqueeze(0)
            data.v_feature = torch.FloatTensor(v_feature).unsqueeze(0)
        else:
            x2 = torch.FloatTensor(node_features)
            data.x = torch.cat([data.x, x2], 1)
    return data


In [None]:
def neighbors(fringe, A):
    """Return the set of column indices stored in any row of *fringe*.

    Args:
        fringe: collection of row indices (set or list); may be empty.
        A: object whose ``A[list_of_rows]`` returns a CSR/CSC-like matrix
            exposing ``.indices`` (a scipy compressed matrix or the
            SparseRowIndex / SparseColIndex helpers).

    Returns:
        Set of adjacent indices; empty when *fringe* is empty.
    """
    if fringe:
        selected = A[list(fringe)]
        return set(selected.indices)
    return set()

In [None]:
def one_hot(idx, length):
    """Return a float64 matrix whose row ``k`` is the one-hot encoding of ``idx[k]``.

    Args:
        idx: sequence of integer positions in ``[0, length)``. May be empty.
        length: width of each one-hot row.

    Returns:
        ``np.ndarray`` of shape ``(len(idx), length)``.
    """
    idx = np.asarray(idx)
    if idx.size == 0:
        # np.asarray([]) is float64, which numpy rejects as an index array.
        idx = idx.astype(np.intp)
    encoded = np.zeros([len(idx), length])
    encoded[np.arange(len(idx)), idx] = 1.0
    return encoded

In [None]:
# Includes the preprocessing steps as well.
def load_official_trainvaltest_split(testing=False, rating_map=None, post_rating_map=None, ratio=1.0):
    """Load MovieLens-100K's official ``u1.base`` / ``u1.test`` split plus side features.

    Files are read from the working directory: ``./u1.base``, ``./u1.test``,
    ``./u.item`` and ``./u.user``. After an internal shuffle (seed 42 on
    NumPy's global RNG), the first ``ceil(20%)`` of the training rows become
    the validation set. Test rows are kept as-is.

    Args:
        testing: if True, the validation links are merged back into the
            returned training labels/indices and into ``rating_mx_train``.
        rating_map: optional dict applied to every raw rating first.
        post_rating_map: optional dict mapping class values to the values
            stored in ``rating_mx_train`` (before the +1 offset).
        ratio: if < 1, keep only the earliest ``ratio`` fraction of the
            training rows, ordered by timestamp.

    Returns:
        Tuple ``(u_features, v_features, rating_mx_train, train_labels,
        u_train_idx, v_train_idx, val_labels, u_val_idx, v_val_idx,
        test_labels, u_test_idx, v_test_idx, class_values)``:
        - ``u_features``: csr matrix, users x (2 + #occupations). Columns are
          age / max age, gender (M=0, F=1), then a one-hot occupation.
        - ``v_features``: csr matrix, items x genre flags.
        - ``rating_mx_train``: csr users x items, holding class index + 1 for
          each training link and 0 elsewhere.
        - ``*_labels``: class indices into ``class_values``.
        - ``class_values``: sorted unique ratings.
    """
    sep = '\t'
    dtypes = {
        'u_nodes': np.int32, 'v_nodes': np.int32,
        'ratings': np.float32, 'timestamp': np.float64}
    column_names = ['u_nodes', 'v_nodes', 'ratings', 'timestamp']

    filename_train = './u1.base'
    filename_test = './u1.test'

    data_train = pd.read_csv(
        filename_train, sep=sep, header=None, names=column_names, dtype=dtypes)
    data_test = pd.read_csv(
        filename_test, sep=sep, header=None, names=column_names, dtype=dtypes)

    data_array_train = np.array(data_train.values.tolist())
    data_array_test = np.array(data_test.values.tolist())

    if ratio < 1.0:
        # Keep the earliest `ratio` fraction of training rows (last column = timestamp).
        data_array_train = data_array_train[data_array_train[:, -1].argsort()[:int(ratio*len(data_array_train))]]

    data_array = np.concatenate([data_array_train, data_array_test], axis=0)

    u_nodes_ratings = data_array[:, 0].astype(dtypes['u_nodes'])
    v_nodes_ratings = data_array[:, 1].astype(dtypes['v_nodes'])
    ratings = data_array[:, 2].astype(dtypes['ratings'])
    if rating_map is not None:
        for i, x in enumerate(ratings):
            ratings[i] = rating_map[x]

    u_nodes_ratings, u_dict, num_users = map_data(u_nodes_ratings)
    v_nodes_ratings, v_dict, num_movies = map_data(v_nodes_ratings)

    u_nodes = u_nodes_ratings.astype(np.int64)
    v_nodes = v_nodes_ratings.astype(np.int32)
    ratings = ratings.astype(np.float64)

    neutral_rating = -1  # marks (user, item) cells that have no rating

    # Assumes the training ratings contain at least one example of every rating value.
    rating_dict = {r: i for i, r in enumerate(np.sort(np.unique(ratings)).tolist())}

    labels = np.full((num_users, num_movies), neutral_rating, dtype=np.int32)
    labels[u_nodes, v_nodes] = np.array([rating_dict[r] for r in ratings])

    for i in range(len(u_nodes)):
        assert(labels[u_nodes[i], v_nodes[i]] == rating_dict[ratings[i]])

    labels = labels.reshape([-1])

    # Number of validation and test edges (see cf-nade code).
    num_train = data_array_train.shape[0]
    num_test = data_array_test.shape[0]
    num_val = int(np.ceil(num_train * 0.2))
    num_train = num_train - num_val

    pairs_nonzero = np.array([[u, v] for u, v in zip(u_nodes, v_nodes)])
    # Flat index of each (user, item) cell in the reshaped labels array.
    idx_nonzero = np.array([u * num_movies + v for u, v in pairs_nonzero])

    for i in range(len(ratings)):
        assert(labels[idx_nonzero[i]] == rating_dict[ratings[i]])

    idx_nonzero_train = idx_nonzero[0:num_train+num_val]
    idx_nonzero_test = idx_nonzero[num_train+num_val:]

    pairs_nonzero_train = pairs_nonzero[0:num_train+num_val]
    pairs_nonzero_test = pairs_nonzero[num_train+num_val:]

    # Internally shuffle the training set (before splitting off the validation set).
    rand_idx = list(range(len(idx_nonzero_train)))
    np.random.seed(42)
    np.random.shuffle(rand_idx)
    idx_nonzero_train = idx_nonzero_train[rand_idx]
    pairs_nonzero_train = pairs_nonzero_train[rand_idx]

    idx_nonzero = np.concatenate([idx_nonzero_train, idx_nonzero_test], axis=0)
    pairs_nonzero = np.concatenate([pairs_nonzero_train, pairs_nonzero_test], axis=0)

    val_idx = idx_nonzero[0:num_val]
    train_idx = idx_nonzero[num_val:num_train + num_val]
    test_idx = idx_nonzero[num_train + num_val:]

    assert(len(test_idx) == num_test)

    val_pairs_idx = pairs_nonzero[0:num_val]
    train_pairs_idx = pairs_nonzero[num_val:num_train + num_val]
    test_pairs_idx = pairs_nonzero[num_train + num_val:]

    u_test_idx, v_test_idx = test_pairs_idx.transpose()
    u_val_idx, v_val_idx = val_pairs_idx.transpose()
    u_train_idx, v_train_idx = train_pairs_idx.transpose()

    # Create labels.
    train_labels = labels[train_idx]
    val_labels = labels[val_idx]
    test_labels = labels[test_idx]

    if testing:
        u_train_idx = np.hstack([u_train_idx, u_val_idx])
        v_train_idx = np.hstack([v_train_idx, v_val_idx])
        train_labels = np.hstack([train_labels, val_labels])
        # For adjacency matrix construction.
        train_idx = np.hstack([train_idx, val_idx])

    class_values = np.sort(np.unique(ratings))

    # Make the training adjacency matrix (class index + 1, so 0 means "no edge").
    rating_mx_train = np.zeros(num_users * num_movies, dtype=np.float32)
    if post_rating_map is None:
        rating_mx_train[train_idx] = labels[train_idx].astype(np.float32) + 1.
    else:
        rating_mx_train[train_idx] = np.array([post_rating_map[r] for r in class_values[labels[train_idx]]]) + 1.
    rating_mx_train = sp.csr_matrix(rating_mx_train.reshape(num_users, num_movies))

    # Movie features (genres).
    sep = r'|'
    movie_file = './u.item'
    movie_headers = ['movie id', 'movie title', 'release date', 'video release date',
                     'IMDb URL', 'unknown', 'Action', 'Adventure', 'Animation',
                     'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy',
                     'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
                     'Thriller', 'War', 'Western']
    movie_df = pd.read_csv(movie_file, sep=sep, header=None,
                           names=movie_headers, engine='python', encoding="latin-1")

    genre_headers = movie_df.columns.values[6:]
    num_genres = genre_headers.shape[0]

    v_features = np.zeros((num_movies, num_genres), dtype=np.float32)
    for movie_id, g_vec in zip(movie_df['movie id'].values.tolist(), movie_df[genre_headers].values.tolist()):
        # Only movies that appear in the ratings files have a relabelled id.
        if movie_id in v_dict:
            v_features[v_dict[movie_id], :] = g_vec

    # User features.
    sep = r'|'
    users_file = './u.user'
    users_headers = ['user id', 'age', 'gender', 'occupation', 'zip code']
    users_df = pd.read_csv(users_file, sep=sep, header=None,
                           names=users_headers, engine='python')

    # Sorted so the occupation -> column mapping is identical in every run
    # (set iteration order of strings depends on hash randomisation).
    occupation = sorted(set(users_df['occupation'].values.tolist()))

    age = users_df['age'].values
    age_max = age.max()

    gender_dict = {'M': 0., 'F': 1.}
    occupation_dict = {f: i for i, f in enumerate(occupation, start=2)}

    num_feats = 2 + len(occupation_dict)

    u_features = np.zeros((num_users, num_feats), dtype=np.float32)
    for _, row in users_df.iterrows():
        u_id = row['user id']
        if u_id in u_dict:
            # Age scaled by the maximum age. Builtin float(): np.float was
            # removed in NumPy 1.24.
            u_features[u_dict[u_id], 0] = row['age'] / float(age_max)
            # Gender.
            u_features[u_dict[u_id], 1] = gender_dict[row['gender']]
            # Occupation (one-hot).
            u_features[u_dict[u_id], occupation_dict[row['occupation']]] = 1.

    u_features = sp.csr_matrix(u_features)
    v_features = sp.csr_matrix(v_features)

    print("User features shape: "+str(u_features.shape))
    print("Item features shape: "+str(v_features.shape))

    return u_features, v_features, rating_mx_train, train_labels, u_train_idx, v_train_idx, \
        val_labels, u_val_idx, v_val_idx, test_labels, u_test_idx, v_test_idx, class_values

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.data import DataLoader
from torch_geometric.nn import RGCNConv
from torch_geometric.utils import dropout_adj

In [None]:
# Training hyperparameters.
EPOCHS = 30            # number of passes over the training set
BATCH_SIZE = 128       # graphs per mini-batch
LR = 1e-3              # initial Adam learning rate
LR_DECAY_STEP = 20     # divide the learning rate every this many epochs...
LR_DECAY_VALUE = 10    # ...by this factor

In [None]:
# Seed torch's RNG. When a GPU is present, also reseed CUDA and run there.
torch.manual_seed(1234)
if torch.cuda.is_available():
    torch.cuda.manual_seed(123)
    torch.cuda.synchronize()
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Load the official u1.base / u1.test split. With testing=True the validation
# links are merged back into the training labels/indices and into adj_train,
# so the val_* values below are NOT held out from training.
(u_features, v_features, adj_train, train_labels, train_u_indices, train_v_indices, val_labels,
val_u_indices, val_v_indices, test_labels, test_u_indices, test_v_indices, class_values
) = load_official_trainvaltest_split(testing=True)

In [None]:
# Training subgraphs are extracted on the fly (DynamicDataset). Test subgraphs
# are extracted once and saved under `root` (GraphDataset). Both use the
# training adjacency, so test links never appear as edges. h=1 yields node
# labels 0..3, i.e. the 4 input channels IGMC expects by default.
train_dataset = DynamicDataset(
    root='data/ml_100k/testmode/train', A=adj_train,
    links=(train_u_indices, train_v_indices), labels=train_labels, h=1, sample_ratio=1.0,
    max_nodes_per_hop=200, u_features=None, v_features=None, class_values=class_values)
test_dataset = GraphDataset(
    root='data/ml_100k/testmode/test', A=adj_train,
    links=(test_u_indices, test_v_indices), labels=test_labels, h=1, sample_ratio=1.0,
    max_nodes_per_hop=200, u_features=None, v_features=None, class_values=class_values)

len(train_dataset), len(test_dataset)

In [None]:
class IGMC(torch.nn.Module):
    """Inductive Graph-based Matrix Completion rating regressor.

    A stack of relational GCN layers (one relation per rating class, tanh
    after each) encodes an enclosing subgraph. The concatenated layer outputs
    of the target user (node label 0) and target item (node label 1) are fed
    through a two-layer MLP that predicts a scalar rating per graph.

    All arguments are optional; the defaults reproduce the original
    hard-coded architecture.

    Args:
        in_channels: width of ``data.x``, i.e. the one-hot node-label size
            ``2*h + 2`` (4 for 1-hop subgraphs).
        hidden_channels: output width of every RGCN layer.
        num_layers: number of RGCN layers.
        num_relations: number of edge types (rating classes).
        num_bases: basis-decomposition size of each RGCN layer.
        adj_dropout: edge dropout probability applied during training.
        dropout: dropout probability after the first MLP layer.
    """

    def __init__(self, in_channels=4, hidden_channels=32, num_layers=4, num_relations=5,
                 num_bases=4, adj_dropout=0.2, dropout=0.2):
        super(IGMC, self).__init__()
        self.adj_dropout = adj_dropout
        self.dropout = dropout
        self.rel_graph_convs = torch.nn.ModuleList()
        widths = [in_channels] + [hidden_channels] * num_layers
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.rel_graph_convs.append(RGCNConv(in_channels=c_in, out_channels=c_out,
                                                 num_relations=num_relations, num_bases=num_bases))
        # Target user and item embeddings (num_layers * hidden_channels each) are concatenated.
        self.linear_layer1 = Linear(2 * num_layers * hidden_channels, 128)
        self.linear_layer2 = Linear(128, 1)

    def reset_parameters(self):
        """Re-initialise all learnable parameters."""
        self.linear_layer1.reset_parameters()
        self.linear_layer2.reset_parameters()
        for conv in self.rel_graph_convs:
            conv.reset_parameters()

    def forward(self, data):
        """Return one predicted rating per graph in the (batched) ``data``."""
        num_nodes = len(data.x)
        # NOTE(review): dropout_adj is deprecated in recent PyG releases in
        # favour of dropout_edge — confirm against the installed version.
        edge_index_dr, edge_type_dr = dropout_adj(data.edge_index, data.edge_type, p=self.adj_dropout,
                                                  num_nodes=num_nodes, training=self.training)

        out = data.x
        layer_outputs = []
        for conv in self.rel_graph_convs:
            out = torch.tanh(conv(out, edge_index_dr, edge_type_dr))
            layer_outputs.append(out)
        h = torch.cat(layer_outputs, 1)
        # data.x is a one-hot of the node label: column 0 marks the target
        # user and column 1 the target item (one of each per graph).
        is_target_user = data.x[:, 0] == 1
        is_target_item = data.x[:, 1] == 1
        g = torch.cat([h[is_target_user], h[is_target_item]], 1)
        out = F.relu(self.linear_layer1(g))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.linear_layer2(out)
        return out[:, 0]

model = IGMC()

In [None]:
# Notebook display of the training dataset object.
train_dataset

In [None]:
# Use the configured BATCH_SIZE instead of a hard-coded 200 (the training loop
# also reads BATCH_SIZE). Test order is irrelevant, so it is not shuffled.
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False)

In [None]:
# Move the model to the chosen device (nn.Module.to works in place and returns
# the module), re-initialise its weights, then build the optimiser.
model = model.to(device)
model.reset_parameters()
optimizer = Adam(params=model.parameters(), lr=LR, weight_decay=0)

In [None]:
train_losslist = []
test_losslist = []
for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss_all = 0.0

    for train_batch in train_loader:
        optimizer.zero_grad()
        train_batch = train_batch.to(device)
        y_pred = model(train_batch)
        y_true = train_batch.y
        train_loss = F.mse_loss(y_pred, y_true)
        train_loss.backward()
        # mse_loss is a per-graph mean. Weight it by the real batch size
        # (the last batch is usually partial) so the epoch average is exact.
        train_loss_all += train_batch.num_graphs * float(train_loss)
        optimizer.step()
    train_loss_all = train_loss_all / len(train_loader.dataset)
    train_losslist.append(train_loss_all)

    # Evaluate with dropout and edge dropout disabled.
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for test_batch in test_loader:
            test_batch = test_batch.to(device)
            y_pred = model(test_batch)
            y_true = test_batch.y
            test_loss += float(F.mse_loss(y_pred, y_true, reduction='sum'))
    mse_loss = test_loss / len(test_loader.dataset)
    test_losslist.append(mse_loss)
    print('epoch', epoch,'; train loss', train_loss_all,'; test loss',mse_loss)

    # Step decay: divide the learning rate by LR_DECAY_VALUE every LR_DECAY_STEP epochs.
    if epoch % LR_DECAY_STEP == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / LR_DECAY_VALUE

In [None]:
# Final evaluation on the test split: dropout off, no autograd graph.
# The sum of squared errors is accumulated, then divided by the number of test graphs.
model.eval()
test_loss = 0
with torch.no_grad():
    for test_batch in test_loader:
        test_batch = test_batch.to(device)
        y_pred = model(test_batch)
        y_true = test_batch.y
        test_loss += F.mse_loss(y_pred, y_true, reduction='sum')
mse_loss = float(test_loss) / len(test_loader.dataset)

print('test MSE loss', mse_loss)
print('test RMSE loss', math.sqrt(mse_loss))

In [None]:
import matplotlib.pyplot as plt

# Epochs are numbered from 1 in the training loop; label the x-axis the same way.
# (This also avoids shadowing the builtin `iter`.)
epochs = range(1, len(train_losslist) + 1)
plt.plot(epochs, train_losslist, label='train')
plt.plot(epochs, test_losslist, label='test')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('training and test loss curves')
plt.legend()
plt.show()