In [None]:
# Mount Google Drive at /content/drive; the next cell stores the project data
# under /content/drive/MyDrive.
# NOTE(review): `google.colab` is presumably only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive mounted successfully!')

In [None]:
# Download the data directory from the GitHub repository into Google Drive.
#
# The cell is safe to re-run: an existing clone in CLONE_DIR is reused and the
# data directory is merged into the Drive folder instead of failing.
import os
import shutil
import subprocess

REPO_URL = 'https://github.com/pooyaht/ILR-IR.git'
CLONE_DIR = '/tmp/ILR-IR'
PROJECT_DIR = '/content/drive/MyDrive/ILR-IR'

# Create target directory
os.makedirs(PROJECT_DIR, exist_ok=True)

# Clone repository and copy data
data_ready = False
try:
    if os.path.isdir(os.path.join(CLONE_DIR, '.git')):
        # `git clone` refuses to write into an existing non-empty directory,
        # so reuse the checkout left by a previous run of this cell.
        print('Repository already cloned, reusing existing checkout')
    else:
        subprocess.run(['git', 'clone', REPO_URL, CLONE_DIR], check=True)
        print('Repository cloned successfully')

    # Pure-Python equivalent of `cp -r /tmp/ILR-IR/data /content/drive/MyDrive/ILR-IR/`;
    # dirs_exist_ok lets a re-run merge into an already copied directory.
    shutil.copytree(os.path.join(CLONE_DIR, 'data'),
                    os.path.join(PROJECT_DIR, 'data'),
                    dirs_exist_ok=True)
    data_ready = True
except (subprocess.CalledProcessError, OSError) as err:
    # CalledProcessError: git exited non-zero; OSError: git missing or copy failed.
    print(f'Download failed: {err}')
    print('Git clone failed. Please manually download the data from: https://github.com/pooyaht/ILR-IR/tree/main/data')
finally:
    # Change to the project directory (done even on failure so that manually
    # downloaded data placed there is picked up by the following cells).
    os.chdir(PROJECT_DIR + '/')
    print(f'Changed working directory to: {os.getcwd()}')

# Only claim success when the clone and the copy both actually succeeded.
if data_ready:
    print('Data downloaded and setup completed successfully!')

In [None]:
# File: pyHGT/data.py
# Source: ./pyHGT/data.py

import json, os
import math, copy, time
import numpy as np
from collections import defaultdict
import pandas as pd

import math
from tqdm import tqdm

import seaborn as sb
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import dill
from functools import partial
import multiprocessing as mp
import torch.nn.functional as F

class Graph:
    """Plain container for the preprocessed path data of one dataset split.

    Every attribute is a ``defaultdict`` that is filled by the preprocessing
    code after construction; the constructor only creates the empty tables.
    """

    def __init__(self):
        super().__init__()

        # Notes carried over from the original (translated):
        #   t_r_id_p_dict: time -> relation -> quad id -> list of path lists
        #       (2-D; the first row holds all paths of the correct triple).
        #   t_max_num (no longer stored): maximum number of paths between any
        #       triple within each time slot.
        # Earlier revisions kept separate *_train / *_valid / *_test copies of
        # these tables; a single set per Graph instance is used now.

        # Tables whose missing keys default to an empty dict.
        self.t_r_id_p_dict = defaultdict(dict)
        self.t_r_id_target_dict = defaultdict(dict)
        self.r_copy = defaultdict(dict)

        # Tables whose missing keys default to an empty list.
        self.t_paths = defaultdict(list)
        self.t_paths_len = defaultdict(list)
        self.t_paths_time = defaultdict(list)
        self.t_paths_m_time = defaultdict(list)
    
class RenameUnpickler(dill.Unpickler):
    """Unpickler that redirects legacy module paths to ``pyHGT.data``."""

    def find_class(self, module, name):
        """Resolve ``name``, looking in ``pyHGT.data`` for the legacy modules
        ``GPT_GNN.data`` and ``data``; other modules are resolved unchanged."""
        target_module = "pyHGT.data" if module in ("GPT_GNN.data", "data") else module
        return super().find_class(target_module, name)


def renamed_load(file_obj):
    """Unpickle one object from the binary file ``file_obj`` via RenameUnpickler.

    NOTE(security): unpickling can execute arbitrary code; only load files
    produced by this project.
    """
    unpickler = RenameUnpickler(file_obj)
    return unpickler.load()


In [None]:
# File: pyHGT/model.py
# Source: ./pyHGT/model.py

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn.functional as F
import math


def get_device():
    """Pick the compute device, preferring MPS, then CUDA, then CPU.

    Returns:
        (torch.device, bool): the device and whether it is an accelerator
        (``True`` for MPS/CUDA, ``False`` for CPU).
    """
    if torch.backends.mps.is_available():
        return torch.device('mps'), True
    if torch.cuda.is_available():
        return torch.device('cuda'), True
    return torch.device('cpu'), False


# Resolve the device once at import time; TypeGAT reads these module globals.
DEVICE, HAS_ACCELERATION = get_device()
# Legacy alias: despite the name it is True for any accelerator (MPS or CUDA).
CUDA = HAS_ACCELERATION


class RelTemporalEncoding(nn.Module):
    """Adds a linearly projected, fixed sinusoidal time embedding to the input."""

    def __init__(self, n_hid, max_len=4020, dropout=0.2):
        """Build the frozen sinusoidal table and the trainable projection.

        Args:
            n_hid: embedding width.
            max_len: number of distinct time indices in the table.
            dropout: accepted for signature compatibility; not used.
        """
        super().__init__()
        scale = math.sqrt(n_hid)
        steps = torch.arange(0., max_len).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, n_hid, 2) *
                             -(math.log(10000.0) / n_hid))
        angles = steps * inv_freq

        # Even columns hold the sine terms, odd columns the cosine terms.
        table = nn.Embedding(max_len, n_hid)
        table.weight.data[:, 0::2] = torch.sin(angles) / scale
        table.weight.data[:, 1::2] = torch.cos(angles) / scale
        table.requires_grad_(False)  # the table itself is never trained

        self.emb = table
        self.lin = nn.Linear(n_hid, n_hid)

    def forward(self, x, t):
        """Return ``x`` plus the projected embedding of the time indices ``t``."""
        time_emb = self.emb(t)
        return x + self.lin(time_emb)


class TypeGAT(nn.Module):
    """Scores candidate relation paths against query relations.

    Paths (sequences of relation ids with per-step time offsets) are encoded
    with a GRU and compared to L2-normalised relation embeddings through a
    dot product.
    """

    def __init__(self, num_e, num_r, relation_embeddings, out_dim):
        """Create the relation table, temporal encoding and path GRU.

        Args:
            num_e: stored as ``self.num_e`` (not otherwise used in this class).
            num_r: stored as ``self.num_r`` (not otherwise used in this class).
            relation_embeddings: initial relation embedding tensor; its second
                dimension is the GRU input size and it becomes a trainable
                parameter.
            out_dim: GRU hidden size, also the width of the zero padding row.
        """
        super(TypeGAT, self).__init__()

        self.num_e = num_e
        self.num_r = num_r
        self.in_dim = relation_embeddings.shape[1]
        self.out_dim = out_dim

        # Zero row used as padding. It is a plain tensor attribute (not a
        # registered buffer), so it is moved to DEVICE by hand here.
        self.pad = torch.zeros(1, self.out_dim)
        if HAS_ACCELERATION:
            self.pad = self.pad.to(DEVICE)

        self.relation_embeddings = nn.Parameter(relation_embeddings)
        self.emb = RelTemporalEncoding(self.out_dim)

        self.gru = nn.GRU(input_size=self.in_dim,
                          hidden_size=self.out_dim, batch_first=True)

        self.gru.reset_parameters()

    def forward2(self, path_index, batch_relation, paths, paths_time, lengths, path_r, path_neg_index, batch_his_r):
        """Training-time scoring of a batch of queries against encoded paths.

        NOTE(review): tensor shapes are not fixed by this code; ``paths`` and
        ``paths_time`` are presumably (num_paths, max_path_len) index tensors
        and ``lengths`` the true length of each path - confirm against caller.

        Returns:
            tuple: ``(score, path_emb[path_neg_index], pad_r[path_r])`` where
            ``score`` is the per-query maximum masked path score.
        """
        r_inp = self.relation_embeddings

        # update relations r<-path
        # Append the zero row so the index just past the last relation looks
        # up an all-zero embedding.
        pad_r = torch.cat((r_inp, self.pad), dim=0)
        emb = pad_r[paths]
        emb = self.emb(emb, paths_time)  # temporal information

        # pack_padded_sequence is given the lengths as a CPU tensor.
        lengths_cpu = lengths.cpu()
        packed = pack_padded_sequence(
            emb, lengths_cpu, batch_first=True, enforce_sorted=False).to(paths.device)
        _, hidden = self.gru(packed)

        # Prepend the zero row: path id 0 maps to an all-zero embedding and
        # real path ids are therefore offset by one.
        path_emb = torch.cat((self.pad, hidden.squeeze(0)), dim=0)
        del emb, packed, paths

        # Rebuild the relation table with normalised rows (padding row stays zero).
        pad_r = torch.cat((F.normalize(r_inp, dim=1),
                          self.pad.to(r_inp.device)), dim=0)
        # pad_r = F.normalize(pad_r, dim=1)
        path_emb = F.normalize(path_emb, dim=1)

        # batch*num_paths
        scores = torch.mm(path_emb, pad_r[batch_relation].t()).t()
        # Keep only the scores of the path ids listed in path_index for each
        # row; every other entry is multiplied by zero before the max.
        mask = torch.zeros((scores.size(0), scores.size(1))).to(scores.device)
        m_index = min(path_index.size(1), mask.size(1))
        mask = mask.scatter(1, path_index[:, 0:m_index], 1)
        max_score, max_id = torch.max(scores * mask, 1)

        # NOTE(review): his_score is computed but not used in the returned
        # score below (see the commented-out `max_score+his_score`).
        scores_r = torch.mm(pad_r, pad_r.t())[batch_relation]
        his_score = torch.mean(torch.diagonal(
            scores_r[:, batch_his_r], dim1=0, dim2=1).t(), 1)

        # scores = torch.mm(path_emb, pad_r[batch_relation[0]].unsqueeze(1)).squeeze(1)
        # max_score, max_id = torch.max(scores[path_index], 1)

        # scores_r = torch.mm(pad_r, pad_r[batch_relation[0]].unsqueeze(1)).squeeze(1)
        # his_score = torch.mean(scores_r[batch_his_r], 1)

        # score = max_score+his_score
        score = max_score

        return score, path_emb[path_neg_index], pad_r[path_r]

    def test(self, path_index, batch_relation, paths, lengths, paths_time, batch_his_r):
        """Evaluation-time scoring for a single query relation.

        Only ``batch_relation[0]`` is used as the query relation. The returned
        score is the maximum path score per row of ``path_index`` plus the mean
        similarity between the query relation and the relations in
        ``batch_his_r``.

        NOTE(review): unlike ``forward2`` the packed sequence is not moved with
        ``.to(paths.device)`` here - confirm that is intentional.
        """
        r_inp = self.relation_embeddings

        # Local zero row, moved to the parameter's device where it is used.
        pad = torch.zeros(1, self.out_dim)

        # update relations r<-path
        pad_r = torch.cat((r_inp, pad.to(r_inp.device)), dim=0)
        emb = pad_r[paths]
        emb = self.emb(emb, paths_time)  # temporal information
        lengths_cpu = lengths.cpu()
        packed = pack_padded_sequence(
            emb, lengths_cpu, batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(packed)
        # Row 0 is the zero padding row; encoded paths start at row 1.
        path_emb = torch.cat(
            (self.pad.to(r_inp.device), hidden.squeeze(0)), dim=0)

        del emb, packed, paths
        pad_r = torch.cat((F.normalize(r_inp, dim=1),
                          pad.to(r_inp.device)), dim=0)
        path_emb = F.normalize(path_emb, dim=1)

        # Score of every path against the single query relation.
        scores = torch.mm(
            path_emb, pad_r[batch_relation[0]].unsqueeze(1)).squeeze(1)
        max_score, max_id = torch.max(scores[path_index], 1)

        # Similarity of every relation (and the padding row) to the query relation.
        scores_r = torch.mm(
            pad_r, pad_r[batch_relation[0]].unsqueeze(1)).squeeze(1)
        his_score = torch.mean(scores_r[batch_his_r], 1)

        score = max_score + his_score
        # score = his_score

        return score

    def __repr__(self):
        """Return just the class name."""
        return self.__class__.__name__


In [None]:
# File: dataprocess.py
# Source: ./dataprocess.py

import os
import numpy as np
import networkx as nx
from collections import defaultdict
from more_itertools import flatten
from sklearn.utils import shuffle
from operator import itemgetter
import dill

import argparse


def parse_args():
    """Return the preprocessing options.

    Inside a notebook kernel (``ipykernel`` loaded) a fixed namespace is
    returned; otherwise the command line is parsed.

    NOTE(review): the notebook namespace uses ``his_len=13`` while the command
    line default is 50 - confirm which value is intended.
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("-data", "--data",
                        default="./data/ICEWS14_forecasting", help="data directory")
    parser.add_argument("-state", "--state",
                        default="train", help="train or test")
    parser.add_argument("-neg_ratio", "--ratio",
                        default=1, type=int, help="training neg ratio")
    parser.add_argument("-his_len", "--his_len",
                        default=50, type=int, help="1 hop historial relations",)

    if 'ipykernel' not in sys.modules:
        # Running as a script: honour the command line.
        return parser.parse_args()

    # Running in a notebook: sys.argv belongs to the kernel, use fixed values.
    return argparse.Namespace(
        data='./data/ICEWS14_forecasting',
        state='train',
        ratio=1,
        his_len=13
    )


# Notebook-compatible argument setup: `main()` below reads this module-level
# `args`, so it must exist before `main()` is called.
import sys

if 'ipykernel' not in sys.modules:
    # Running as script - use command line args
    args = parse_args()
else:
    # Running in notebook - create args with fixed values
    import argparse
    args = argparse.Namespace(
        data='./data/ICEWS14_forecasting',
        state='train',
        ratio=1,
        his_len=13,
    )


def all_simple_edge_paths(G, source, target, cutoff=None):
    if source not in G:
        raise nx.NodeNotFound("source node %s not in graph" % source)
    if target in G:
        targets = {target}
    else:
        try:
            targets = set(target)
        except TypeError:
            raise nx.NodeNotFound("target node %s not in graph" % target)
    if source in targets:
        return []
    if cutoff is None:
        cutoff = len(G) - 1
    if cutoff < 1:
        return []
    if G.is_multigraph():
        for simp_path in _all_simple_edge_paths_multigraph(G, source, targets, cutoff):
            yield simp_path
    else:
        for simp_path in _all_simple_paths_graph(G, source, targets, cutoff):
            yield list(zip(simp_path[:-1], simp_path[1:]))


def _all_simple_edge_paths_multigraph(G, source, targets, cutoff):
    if not cutoff or cutoff < 1:
        return []
    visited = [source]
    stack = [iter(G.edges(source, keys=True))]

    while stack:
        children = stack[-1]
        child = next(children, None)
        if child is None:
            stack.pop()
            visited.pop()
        elif len(visited) < cutoff:
            if child[1] in targets:
                yield visited[1:] + [child]
            if child[1] not in [v[0] for v in visited[1:]]:
                visited.append(child)
                stack.append(iter(G.edges(child[1], keys=True)))
        else:  # len(visited) == cutoff:
            for (u, v, k) in [child] + list(children):
                if v in targets:
                    yield visited[1:] + [(u, v, k)]
            stack.pop()
            visited.pop()


def _all_simple_paths_graph(G, source, targets, cutoff):
    if cutoff < 1:
        return

    visited = [source]
    stack = [iter(G[source])]

    while stack:
        children = stack[-1]
        child = next(children, None)

        if child is None:
            stack.pop()
            visited.pop()
        elif len(visited) < cutoff:
            if child in targets:
                yield visited + [child]

            if child not in visited:
                visited.append(child)
                stack.append(iter(G[child]))
        else:
            for neighbor in [child] + list(children):
                if neighbor in targets and neighbor not in visited:
                    yield visited + [neighbor]
            stack.pop()
            visited.pop()


def parse_line(line):
    """Return the first three whitespace-separated fields of ``line`` as
    ``(e1, relation, e2)`` strings."""
    fields = line.strip().split()
    # Tokens produced by split() carry no surrounding whitespace.
    return fields[0], fields[1], fields[2]


def build_data(path, num_r):
    """Read ``<path>/data.txt`` and index its quadruples by timestamp.

    Every line holds four integers ``e1 relation e2 time``. For each fact an
    inverse fact ``(e2, relation + num_r, e1)`` is added as well.

    Returns:
        tuple: ``(t_quadid, t_quadid_re, all_triples, all_times)`` where
        ``all_triples`` is the list of distinct (forward and inverse) triples,
        ``all_times`` the sorted list of timestamps, and the two dicts map each
        timestamp to the ``all_triples`` indices of its forward / inverse facts
        in file order.
    """
    forward_by_time = {}
    inverse_by_time = {}
    triple_set = set()
    seen_times = set()

    with open(os.path.join(path, 'data.txt'), 'r') as fr:
        for raw_line in fr:
            fields = raw_line.split()
            stamp = int(fields[3])
            seen_times.add(stamp)

            head, rel, tail = int(fields[0]), int(fields[1]), int(fields[2])
            forward = (head, rel, tail)
            inverse = (tail, rel + num_r, head)

            triple_set.add(forward)
            triple_set.add(inverse)

            forward_by_time.setdefault(stamp, []).append(forward)
            inverse_by_time.setdefault(stamp, []).append(inverse)

    # The position in this list is the id of a triple.
    all_triples = list(triple_set)
    triple_id = {triple: pos for pos, triple in enumerate(all_triples)}

    all_times = sorted(seen_times)

    t_quadid = {
        stamp: [triple_id[triple] for triple in forward_by_time[stamp]]
        for stamp in all_times
    }
    t_quadid_re = {
        stamp: [triple_id[triple] for triple in inverse_by_time[stamp]]
        for stamp in all_times
    }
    print("number of triples ->", len(all_triples))

    return t_quadid, t_quadid_re, all_triples, all_times


class Corpus:
    """Selects quadruples, samples negative targets and collects relation paths.

    ``all_triples`` is the id -> ``(subject, relation, object)`` table; quads
    are referred to by their index into it.
    """

    def __init__(self, args, all_triples, num_e, num_r):
        """Store the triple table and the entity / relation counts.

        ``args`` is accepted but not used here.
        """

        self.all_triples = all_triples
        self.num_e = num_e
        self.num_r = num_r

    def get_neg_triples(self, args, all_times, t_quads, test_idx):
        """Select quads with connected endpoints and collect negative targets.

        An undirected graph of (subject, object) pairs is grown snapshot by
        snapshot. With ``args.state == 'train'`` it starts empty and the
        snapshots ``range(test_idx)`` are scanned; otherwise it is pre-filled
        with every snapshot before ``test_idx`` and the snapshots
        ``range(test_idx, len(all_times))`` are scanned. The edges of a
        snapshot are added only after that snapshot has been processed.

        Returns:
            tuple: ``(quads_select, quads_neg)``. ``quads_select`` maps a
            snapshot index to a list of ``[quad, pre, i, length]`` entries;
            ``quads_neg`` maps a snapshot index to ``{quad: [[target, l], ...]}``.
            Snapshots with no selected quad or no negatives are omitted.
        """
        quads_select = {}  # quad_id, cur_id, shortest length
        quads_neg = {}
        G = nx.Graph()
        if args.state == 'train':
            # if flag == 0:
            times = range(test_idx)

        else:
            times = range(test_idx, len(all_times))

            # Seed the graph with all facts that precede the test period.
            keys_his = all_times[:test_idx]

            quads_his = list(itemgetter(*keys_his)(t_quads))
            triples_his = np.array(self.all_triples)[
                list(set(flatten(quads_his)))]
            G.add_edges_from(triples_his[:, [0, 2]])

        for idx in tqdm(times):

            quads = t_quads[all_times[idx]]
            quad_id = []
            neg_len = defaultdict(lambda: [])
            pre = 0  # position of the previously selected quad in this snapshot
            # Facts of the current snapshot as [s, r, o] lists; used to filter
            # out candidate negatives that are actually true at this time.
            valid_triples = np.array(self.all_triples)[quads].tolist()

            for i, quad in enumerate(quads):
                triple = self.all_triples[quad]

                try:
                    # Search from the subject towards the object with cutoff 3.
                    # NOTE(review): presumably `pred` is empty when the object
                    # is not reached within the cutoff - confirm against the
                    # networkx `predecessor` documentation.
                    pred = nx.predecessor(G, triple[0], triple[2], 3)
                    if len(pred) > 0:
                        length = nx.shortest_path_length(
                            G, triple[0], triple[2])
                    else:
                        # current facts
                        # G.add_edge(triple[0], triple[2])
                        continue
                except:
                    # Any exception (presumably a node missing from G) silently
                    # skips this quad.
                    pass
                else:
                    if length < 4:
                        quad_id.append([quad, pre, i, length])
                        pre = i

                        # Every node within 3 hops of the subject, with its
                        # distance; the subject itself is removed.
                        paths_len = nx.single_source_shortest_path_length(
                            G, triple[0], 3)
                        del paths_len[triple[0]]
                        if args.state == 'train':
                            # Training: consider at most 3 randomly chosen candidates.
                            ids = shuffle(list(paths_len.keys()))[
                                :min(len(paths_len), 3)]
                            for target in ids:
                                l = paths_len[target]
                                if [triple[0], triple[1], target] not in valid_triples:
                                    neg_len[quad].append([target, l])
                        else:
                            # Otherwise: keep every reachable candidate.
                            for target, l in paths_len.items():
                                if [triple[0], triple[1], target] not in valid_triples:
                                    neg_len[quad].append([target, l])

                # current facts
                # G.add_edge(triple[0], triple[2])
            if len(quad_id) != 0 and len(neg_len) != 0:
                quads_select[idx] = quad_id
                quads_neg[idx] = neg_len

            # The facts of this snapshot become history for the next one.
            triples_his = np.array(self.all_triples)[t_quads[all_times[idx]]]
            G.add_edges_from(triples_his[:, [0, 2]])

        return quads_select, quads_neg

    def get_path_test(self, args, G, s, targets, lens, cur_time, num_r):
        """Collect edge paths from ``s`` to each target and register them.

        New paths are appended to ``self.paths`` / ``self.paths_time`` /
        ``self.lengths`` / ``self.paths_m_time`` (created by
        ``get_iteration_batch``). The ``self.pathlen_*`` tables de-duplicate
        paths by their (relation, relative time) steps; a stored value of 0
        means "not seen yet", otherwise it is the 1-based path id.

        ``lens`` is only referenced by the commented-out alternatives below.

        Returns:
            tuple: ``(target_pid, target_pid_sort)``. ``target_pid`` maps a
            target node to its path ids. ``target_pid_sort`` maps the same
            targets to the relations of their one-edge paths whose relative
            time is at most ``args.his_len``, or to ``[num_r * 2]`` when there
            are none.
        """
        target_pid = defaultdict(list)
        target_his_pid = defaultdict(list)  # historical interaction relations between s and o

        try:
            paths = []
            for i in range(len(targets)):
                # NOTE(review): the induced subgraph contains only `s` and this
                # target, so intermediate nodes are excluded from the path
                # search - confirm this is intended.
                c_graph = G.subgraph([s, targets[i]])
                sG = G.subgraph(c_graph)
                paths.extend(list(all_simple_edge_paths(sG, s, targets[i], 3)))
                # paths.extend(list(all_simple_edge_paths(sG, s, targets[i], max(lens[i],2))))
                # paths.extend(list(all_simple_edge_paths(sG, s, targets[i], lens[i])))

            # Process paths from shortest to longest.
            path_len = [len(path) for path in paths]
            p_id = np.argsort(path_len)

            tar_dict = {}
        except:  # exception type left unspecified: any error is caught
            print('sample error')
        else:  # runs only when no error occurred above
            if len(paths) != 0:
                # print(paths)
                for id in p_id:
                    path = paths[id]

                    t = np.array(path)[-1][1]  # end node of the last edge
                    pa = np.array(path)[:, 2]  # edge keys along the path
                    pa_t = [cur_time - G.edges[p]['time']
                            for p in path]  # relative time

                    # tar_dict[t] = [length of first path seen, smallest max relative time so far]
                    if t not in tar_dict.keys():
                        tar_dict[t] = [len(pa), max(pa_t)]
                    elif max(pa_t) < tar_dict[t][1]:
                        tar_dict[t][1] = max(pa_t)
                    elif len(pa) > tar_dict[t][0] and max(pa_t) > tar_dict[t][1]:
                        # Longer and older than what is already recorded: skip.
                        continue

                    if len(pa) == 3:
                        pl = self.pathlen_3[(pa[0], pa_t[0])][(
                            pa[1], pa_t[1])][(pa[2], pa_t[2])]
                        if pl == 0:
                            # First occurrence: register and use the new 1-based id.
                            self.paths.append(pa)
                            self.paths_time.append(pa_t)
                            self.lengths.append(len(pa))
                            self.paths_m_time.append(max(pa_t))
                            target_pid[t].append(len(self.paths))
                            self.pathlen_3[(pa[0], pa_t[0])][(pa[1], pa_t[1])][(
                                pa[2], pa_t[2])] = len(self.paths)
                        else:
                            if pl not in target_pid[t]:
                                target_pid[t].append(pl)
                    elif len(pa) == 2:
                        pl = self.pathlen_2[(pa[0], pa_t[0])][(pa[1], pa_t[1])]
                        if pl == 0:
                            self.paths.append(pa)
                            self.paths_time.append(pa_t)
                            self.lengths.append(len(pa))
                            self.paths_m_time.append(max(pa_t))
                            target_pid[t].append(len(self.paths))
                            self.pathlen_2[(pa[0], pa_t[0])][(
                                pa[1], pa_t[1])] = len(self.paths)
                        else:
                            if pl not in target_pid[t]:
                                target_pid[t].append(pl)
                    elif len(pa) == 1:
                        pl = self.pathlen_1[(pa[0], pa_t[0])]
                        # One-edge paths recent enough count as historical relations.
                        if pa_t[0] <= args.his_len:
                            target_his_pid[t].append(pa[0])
                        if pl == 0:
                            self.paths.append(pa)
                            self.paths_time.append(pa_t)
                            self.lengths.append(len(pa))
                            self.paths_m_time.append(max(pa_t))
                            target_pid[t].append(len(self.paths))
                            self.pathlen_1[(pa[0], pa_t[0])] = len(self.paths)

                        else:
                            if pl not in target_pid[t]:
                                target_pid[t].append(pl)
        target_pid_sort = defaultdict(list)
        for t in target_pid.keys():
            if t not in target_his_pid.keys():
                # NOTE(review): num_r * 2 is presumably the padding relation id - confirm.
                target_pid_sort[t] = [num_r * 2]
            else:
                target_pid_sort[t] = target_his_pid[t]

        return target_pid, target_pid_sort

    def get_iteration_batch(self, args, G, batch_quads, negs, quads_cur, cur_time, num_r):
        """Build the path tables for the selected quads of one snapshot.

        For every entry of ``batch_quads`` the positive object and its negative
        targets from ``negs`` are passed to ``get_path_test``; quads whose
        positive object received no path are dropped. In the per-quad lists the
        positive target always comes first, followed by the negatives.

        ``quads_cur`` is only referenced by commented-out code.

        Returns:
            tuple: ``(paths_dict, targets_dict, paths, lengths, paths_time,
            paths_dict_copy, paths_m_time)``. The three dicts are keyed by
            relation and then by quad id.
        """

        # Per-batch path storage, filled by get_path_test.
        self.paths = []
        self.lengths = []
        self.paths_time = []
        self.paths_m_time = []

        # De-duplication tables for paths of length 1, 2 and 3 (0 = unseen).
        self.pathlen_1 = defaultdict(int)
        self.pathlen_2 = defaultdict(lambda: defaultdict(int))
        self.pathlen_3 = defaultdict(lambda:
                                     defaultdict(lambda:
                                                 defaultdict(int)))

        paths_dict = defaultdict(lambda: defaultdict(list))
        targets_dict = defaultdict(lambda: defaultdict(list))

        paths_dict_copy = defaultdict(lambda: defaultdict(list))

        for quad, pre, pid, length in tqdm(batch_quads):
            target = []
            lens = []
            s, r, o = self.all_triples[quad]
            # if pid >= 0:
            #    edges = np.array(self.all_triples)[quads_cur[pre:pid]][:, [0, 2, 1]]
            #    G.add_edges_from(edges, time=cur_time)

            # The positive object is always the first target.
            target.append(o)
            lens.append(length)
            neg = np.array(negs[quad])

            if len(neg) > 0:
                if args.state == 'train':
                    # Training: use at most args.ratio shuffled negatives.
                    neg = shuffle(negs[quad])
                    neg_num = min(len(neg), args.ratio)
                    t_l = np.array(neg[:neg_num])
                    t_neg = t_l[:, [0]]
                    target.extend(t_neg.reshape(-1).tolist())

                    l_neg = t_l[:, [1]]
                    lens.extend(l_neg.reshape(-1).tolist())
                elif args.state == 'test':
                    # Testing: use every negative.
                    t_neg = neg[:, [0]]
                    target.extend(t_neg.reshape(-1).tolist())

                    l_neg = neg[:, [1]]
                    lens.extend(l_neg.reshape(-1).tolist())

            # Restrict the graph to the subject and all candidate targets.
            subnodes = []
            subnodes.append(s)
            subnodes.extend(target)
            graph1 = G.subgraph(subnodes)
            H = G.subgraph(list(graph1.nodes()))

            target_pid, target_his_pid = self.get_path_test(
                args, H, s, target, lens, cur_time, num_r)
            # Skip quads for which no path to the true object was found.
            if o not in target_pid.keys():
                continue

            # pos triple

            paths_dict[r][quad].append(target_pid[o])
            targets_dict[r][quad].append(o)

            paths_dict_copy[r][quad].append(target_his_pid[o])
            del target_pid[o], target_his_pid[o]

            # neg_triples
            if len(target_pid.keys()) == 0:
                continue

            paths_dict[r][quad].extend(list(target_pid.values()))
            targets_dict[r][quad].extend(list(target_pid.keys()))

            paths_dict_copy[r][quad].extend(list(target_his_pid.values()))

        # The de-duplication tables are only needed while building the batch.
        del self.pathlen_1, self.pathlen_2, self.pathlen_3

        return paths_dict, targets_dict, self.paths, self.lengths, self.paths_time, paths_dict_copy, self.paths_m_time


def _dump_pickle(obj, path):
    """Serialise ``obj`` to ``path`` with dill, closing the file afterwards."""
    with open(path, 'wb') as fw:
        dill.dump(obj, fw)


def _load_pickle(path):
    """Load one dill-pickled object from ``path``, closing the file afterwards.

    NOTE(security): unpickling can execute arbitrary code; only load cache
    files produced by this script.
    """
    with open(path, 'rb') as fr:
        return renamed_load(fr)


def _load_or_build_selection(corpus, all_times, t_quads, test_idx, select_path, neg_path):
    """Return ``(quads_select, quads_neg)`` from the cache files, computing and
    caching them with ``Corpus.get_neg_triples`` when ``select_path`` is missing."""
    if os.path.exists(select_path):
        return _load_pickle(select_path), _load_pickle(neg_path)
    quads_select, quads_neg = corpus.get_neg_triples(
        args, all_times, t_quads, test_idx)
    _dump_pickle(quads_select, select_path)
    _dump_pickle(quads_neg, neg_path)
    return quads_select, quads_neg


def _add_snapshot_edges(G, triples_arr, t_quads, t_quads_re, timestamp, time_idx):
    """Add the forward and inverse facts of ``timestamp`` to ``G`` as
    ``(subject, object, relation)`` edges tagged with ``time=time_idx``."""
    for quad_ids in (t_quads[timestamp], t_quads_re[timestamp]):
        G.add_edges_from(triples_arr[quad_ids][:, [0, 2, 1]], time=time_idx)


def _record_snapshot(graph, corpus, G, idx, quads, negs, quads_cur, num_r):
    """Run ``Corpus.get_iteration_batch`` for snapshot ``idx`` and store the
    resulting tables in ``graph``."""
    (paths_dict, targets_dict, paths, lengths, paths_time,
     paths_dict_copy, paths_m_time) = corpus.get_iteration_batch(
        args, G, quads, negs, quads_cur, idx, num_r)
    graph.t_r_id_p_dict[idx] = paths_dict
    graph.t_r_id_target_dict[idx] = targets_dict
    graph.t_paths[idx] = paths
    graph.t_paths_len[idx] = lengths
    graph.t_paths_time[idx] = paths_time
    graph.t_paths_m_time[idx] = paths_m_time
    graph.r_copy[idx] = paths_dict_copy
    print(idx, len(lengths))


def main():
    """Preprocess the dataset in ``args.data`` into pickled ``Graph`` objects.

    Reads ``stat.txt``, ``data.txt`` and ``split.txt`` from ``args.data``.
    With ``args.state == 'train'`` it writes ``graph_preprocess_train.pk`` and
    ``graph_preprocess_valid.pk``; otherwise it replays the train/valid
    snapshots into the history graph and writes ``graph_preprocess_test.pk``.
    Quad selections are cached in ``quads_select*.pk`` / ``quads_neg*.pk``.

    Uses the module-level ``args`` namespace.
    """
    with open(os.path.join('{}'.format(args.data), 'stat.txt'), 'r') as fr:
        for line in fr:
            line_split = line.split()
            num_e, num_r = int(line_split[0]), int(line_split[1])

    t_quads, t_quads_re, all_triples, all_times = build_data(args.data, num_r)

    with open(os.path.join('{}'.format(args.data), 'split.txt'), 'r') as fr:
        for line in fr:
            line_split = line.split()
            valid_start = int(line_split[1].split(',')[0])
            test_start = int(line_split[2].split(',')[0])

    time_list = list(all_times)
    valid_idx = time_list.index(valid_start)
    test_idx = time_list.index(test_start)

    Corpus_ = Corpus(args, all_triples, num_e, num_r)
    # Converted once; the triple table is not modified afterwards.
    triples_arr = np.array(all_triples)
    is_train = args.state == 'train'

    graph_train = Graph()
    print('sample')

    if is_train:
        # The train-state selection is only consumed by the train/valid loops.
        # It is deliberately not built in other states: get_neg_triples would
        # then compute the test-state selection and store it under the
        # train-state cache names.
        quads_select, quads_neg = _load_or_build_selection(
            Corpus_, all_times, t_quads, test_idx,
            args.data + '/quads_select.pk', args.data + '/quads_neg.pk')

    G = nx.MultiDiGraph()

    print('train')

    # Snapshot 0 has no history, so processing starts at index 1. The graph
    # always receives the facts of the previous snapshot first.
    for idx in range(1, valid_idx):
        _add_snapshot_edges(G, triples_arr, t_quads, t_quads_re,
                            all_times[idx - 1], idx - 1)
        if not is_train:
            continue
        try:
            quads = quads_select[idx]
            negs = quads_neg[idx]
        except KeyError:
            # No quad was selected for this snapshot.
            continue
        _record_snapshot(graph_train, Corpus_, G, idx, quads, negs,
                         t_quads[all_times[idx]], num_r)

    if is_train:
        _dump_pickle(graph_train, args.data + '/graph_preprocess_train.pk')
        del graph_train

    print('valid')
    graph_valid = Graph()
    for idx in range(valid_idx, test_idx):
        _add_snapshot_edges(G, triples_arr, t_quads, t_quads_re,
                            all_times[idx - 1], idx - 1)
        if not is_train:
            continue
        try:
            quads = quads_select[idx]
            negs = quads_neg[idx]
        except KeyError:
            continue
        _record_snapshot(graph_valid, Corpus_, G, idx, quads, negs,
                         t_quads[all_times[idx]], num_r)

    if is_train:
        _dump_pickle(graph_valid, args.data + '/graph_preprocess_valid.pk')
        del graph_valid
        return

    graph_test = Graph()
    quads_select, quads_neg = _load_or_build_selection(
        Corpus_, all_times, t_quads, test_idx,
        args.data + '/quads_select_test.pk', args.data + '/quads_neg_test.pk')

    quads_num = 0
    quads_select_num = 0
    quads_select_neg = 0
    print('test')
    for idx in range(test_idx, len(all_times)):
        _add_snapshot_edges(G, triples_arr, t_quads, t_quads_re,
                            all_times[idx - 1], idx - 1)
        quads_cur = t_quads[all_times[idx]]
        quads_num = quads_num + len(quads_cur)

        try:
            quads = quads_select[idx]
            negs = quads_neg[idx]
        except KeyError:
            continue
        quads_select_num = quads_select_num + len(quads)
        quads_select_neg = quads_select_neg + len(negs)

        _record_snapshot(graph_test, Corpus_, G, idx, quads, negs,
                         quads_cur, num_r)

    del quads_select, quads_neg
    print("select quads:", quads_select_num)
    print("select neg:", quads_select_neg)
    print("all quads:", quads_num)

    _dump_pickle(graph_test, args.data + '/graph_preprocess_test.pk')


if __name__ == '__main__':
    main()


In [None]:
# File: main_pre.py
# Source: ./main_pre.py

import torch
from torch.autograd import Variable
import torch.nn as nn
import numpy as np

import argparse
import os
import time

from torch.nn.utils.rnn import pad_sequence

# Restrict CUDA to the first GPU for this process.
# NOTE(review): this runs after `import torch`; presumably it must take effect
# before CUDA is first initialised - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def get_device():
    """Pick the compute device (MPS, then CUDA, then CPU) and announce it.

    Returns:
        (torch.device, bool): the device and whether it is an accelerator.
    """
    if torch.backends.mps.is_available():
        kind, banner, accelerated = 'mps', "Using Apple MPS", True
    elif torch.cuda.is_available():
        kind, banner, accelerated = 'cuda', "Using CUDA", True
    else:
        kind, banner, accelerated = 'cpu', "Using CPU", False

    device = torch.device(kind)
    print(banner)
    return device, accelerated


# Resolve the device once for the training code in this cell.
DEVICE, HAS_ACCELERATION = get_device()
# Legacy alias: despite the name it is True for any accelerator (MPS or CUDA).
CUDA = HAS_ACCELERATION


def parse_args():
    """Return the training options.

    When run as a script the command line is parsed. Inside a notebook kernel
    (``ipykernel`` loaded) ``sys.argv`` belongs to the kernel, so the parser
    defaults are used instead. The notebook namespace therefore carries every
    option the training code reads (``dataset``, ``epochs_conv``,
    ``weight_decay_conv``, ``embedding_size``, ``lr``), plus the
    ``state`` / ``ratio`` / ``his_len`` fields it has always exposed.
    """
    import sys

    parser = argparse.ArgumentParser()

    parser.add_argument("-data", "--data",
                        default="./data/ICEWS14_forecasting", help="data directory")
    parser.add_argument('--dataset', type=str, default='ICEWS14_forecasting')
    parser.add_argument("-e_c", "--epochs_conv", type=int,
                        default=100, help="Number of epochs")
    parser.add_argument("-w_conv", "--weight_decay_conv", type=float,
                        default=1e-6, help="L2 reglarization for conv")
    parser.add_argument("-emb_size", "--embedding_size", type=int,
                        default=200, help="Size of embeddings (if pretrained not used)")
    parser.add_argument("-l", "--lr", type=float, default=1e-4)

    if 'ipykernel' not in sys.modules:
        # Running as script - use command line args
        return parser.parse_args()

    # Running in notebook - take the parser defaults so the namespace can never
    # drift from the options declared above.
    parsed = parser.parse_args([])
    # Fields the previous notebook namespace exposed; kept for compatibility.
    parsed.state = 'train'
    parsed.ratio = 1
    parsed.his_len = 13
    return parsed


def save_model(model, name, folder_name, epoch):
    """Save ``model.state_dict()`` to ``<folder_name>trained<epoch>.pth``.

    ``folder_name`` is used as a plain string prefix (no separator is added)
    and ``name`` is accepted but not used.
    """
    print("Saving Model")
    state = model.state_dict()
    checkpoint_path = folder_name + "trained" + str(epoch) + ".pth"
    torch.save(state, checkpoint_path)
    print("Done saving Model")


def mkdirs(path):
    """Create ``path`` (including parent directories) unless it already exists."""
    if os.path.exists(path):
        return
    os.makedirs(path)


# Notebook-compatible argument setup
import sys
if 'ipykernel' in sys.modules:
    # Running in notebook - create args with defaults.
    # The namespace must carry every option the training code reads
    # (args.dataset, args.embedding_size, args.lr, ...); the values mirror the
    # defaults declared in parse_args().
    import argparse
    args = argparse.Namespace(
        data='./data/ICEWS14_forecasting',
        dataset='ICEWS14_forecasting',
        epochs_conv=100,
        weight_decay_conv=1e-6,
        embedding_size=200,
        lr=1e-4,
        # Fields kept from the preprocessing configuration.
        state='train',
        ratio=1,
        his_len=13,
    )
else:
    # Running as script - use command line args
    args = parse_args()

# Output locations for the trained models of this dataset.
mkdirs('./results/bestmodel/{}/conv'.format(args.dataset))
mkdirs('./results/bestmodel/{}/gat'.format(args.dataset))
model_state_file = './results/bestmodel/{}/'.format(args.dataset)


def load_data(args):
    """Read the dataset sizes and create random initial relation embeddings.

    ``<args.data>/stat.txt`` is scanned line by line; the first two
    whitespace-separated integers of a line are taken as the entity and
    relation counts. If several non-blank lines are present the last one wins.

    Args:
        args: namespace providing ``data`` (dataset directory) and
            ``embedding_size``.

    Returns:
        tuple: ``(num_e, num_r, relation_embeddings)`` where
        ``relation_embeddings`` is a float tensor of shape
        ``(num_r * 2, args.embedding_size)`` drawn from ``np.random.randn``.

    Raises:
        ValueError: if ``stat.txt`` contains no non-blank line.
    """
    stat_path = os.path.join('{}'.format(args.data), 'stat.txt')
    num_e = num_r = None
    with open(stat_path, 'r') as fr:
        for line in fr:
            line_split = line.split()
            if not line_split:
                # A blank line (e.g. a trailing newline) carries no counts.
                continue
            num_e, num_r = int(line_split[0]), int(line_split[1])

    if num_e is None:
        raise ValueError(
            "No entity/relation counts found in {}".format(stat_path))

    relation_embeddings = np.random.randn(num_r * 2, args.embedding_size)
    print("Initialised relations and entities randomly")
    return num_e, num_r, torch.FloatTensor(relation_embeddings)


# Module-level dataset sizes and initial relation embeddings, read once from
# the directory named by args.data.
(num_e, num_r, relation_embeddings) = load_data(args)

print("Initial relation dimensions {}".format(relation_embeddings.shape))


def list_to_array(x, pad):
    """Turn a list of variable-length sequences into a padded 2-D int array.

    Row ``i`` holds the items of ``x[i]`` followed by ``pad`` up to the
    length of the longest sequence.

    Args:
        x: iterable of sequences (possibly of different lengths).
        pad: fill value for the missing trailing positions.

    Returns:
        numpy.ndarray: integer array of shape ``(len(x), longest_length)``;
        shape ``(0, 0)`` when ``x`` is empty.
    """
    rows = [list(labels) for labels in x]
    if not rows:
        # pd.concat (used previously) raises on an empty list of frames.
        return np.empty((0, 0), dtype=int)
    # One DataFrame from the ragged rows pads short rows with NaN. This gives
    # the same result as concatenating one single-column frame per row and
    # transposing, without building len(x) intermediate frames.
    return pd.DataFrame(rows).fillna(pad).values.astype(int)


def _flatten_path_ids(nested):
    """Flatten arbitrarily nested lists/tuples of ids into one flat list, keeping order."""
    flat = []
    for item in nested:
        if isinstance(item, (list, tuple)):
            flat.extend(_flatten_path_ids(item))
        else:
            flat.append(item)
    return flat


def train_conv(args):
    """Train a TypeGAT model on the preprocessed training graph.

    Every time step of ``graph_train.t_r_id_p_dict`` forms one optimisation
    step. For each entry the first candidate is labelled ``+1`` and all
    remaining candidates ``-1``; the model is trained with a soft-margin
    loss on those labels. A checkpoint is written after every epoch via
    ``save_model``. Training stops early once more than five epochs have run
    and the last five epoch averages differ pairwise by less than ``1e-6``.

    Args:
        args: namespace providing ``embedding_size``, ``lr``,
            ``weight_decay_conv``, ``epochs_conv`` and ``data``.

    Returns:
        list: the average loss of each completed epoch (empty when
        ``args.epochs_conv`` is 0).
    """
    print("Defining model")
    model = TypeGAT(num_e, num_r*2, relation_embeddings, args.embedding_size)
    if HAS_ACCELERATION:
        model.to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay_conv)

    # Halve the learning rate every 25 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=25, gamma=0.5, last_epoch=-1)

    margin_loss = torch.nn.SoftMarginLoss()
    cosine_loss = nn.CosineEmbeddingLoss(margin=0.0)

    epoch_losses = []
    print("Number of epochs {}".format(args.epochs_conv))

    # NOTE(review): the graph file is deserialised through renamed_load
    # (pickle-style loading); only open preprocessed files you trust.
    # The context manager closes the handle that used to be leaked.
    with open(os.path.join(args.data + '/graph_preprocess_train.pk'), 'rb') as graph_file:
        graph_train = renamed_load(graph_file)

    total_time_steps = len(graph_train.t_r_id_p_dict)
    print(f"Total time steps per epoch: {total_time_steps}")

    for epoch in range(args.epochs_conv):
        print(f"\n{'='*50}")
        print(f"EPOCH {epoch + 1}/{args.epochs_conv}")
        print(f"{'='*50}")

        model.train()
        start_time = time.time()
        epoch_loss = []

        time_steps = list(graph_train.t_r_id_p_dict.items())
        progress_bar = tqdm(
            time_steps, desc=f"Epoch {epoch+1}", unit="timestep")

        batch_count = 0
        loss_sum = 0.0

        for step_idx, (t, r_dict) in enumerate(progress_bar):
            step_start_time = time.time()

            batch_values = []
            batch_paths_id = []
            batch_relation = []
            batch_his_r = []

            # for neg paths
            path_r = []
            path_neg_index = []

            # Count total batches in this time step
            total_relations = len(r_dict)
            relations_processed = 0

            for r, id_p in r_dict.items():
                len_r = 0
                p_neg_temp = []
                for quad_id, ps in id_p.items():
                    len_r = len_r + len(ps)

                    # First candidate is the positive one, the rest are negatives.
                    value = [-1] * len(ps)
                    value[0] = 1
                    batch_values.extend(value)
                    batch_paths_id.extend(ps)

                    batch_his_r.extend(graph_train.r_copy[t][r][quad_id])
                    if len(ps) > 1:
                        # Collect every id of the negative candidates. This
                        # replaces an eval() over the stringified list, which
                        # raised SyntaxError on empty sub-lists.
                        p_neg_temp.extend(_flatten_path_ids(ps[1:]))

                batch_relation.extend([r]*len_r)
                path_r.extend([r]*len(p_neg_temp))
                path_neg_index.extend(p_neg_temp)

                relations_processed += 1

            if len(batch_values) > 0:
                path_values = [-1]*len(path_r)
                path_values = torch.FloatTensor(
                    np.expand_dims(np.array(path_values), axis=1))
                path_neg_index = torch.LongTensor(np.array(path_neg_index))
                path_r = torch.LongTensor(np.array(path_r))

                batch_paths_id = torch.LongTensor(
                    list_to_array(batch_paths_id, 0))
                batch_relation = torch.LongTensor(np.array(batch_relation))
                batch_values = torch.FloatTensor(
                    np.expand_dims(np.array(batch_values), axis=1))
                batch_his_r = torch.LongTensor(
                    list_to_array(batch_his_r, num_r*2))

                paths = graph_train.t_paths[t]
                paths_time = graph_train.t_paths_time[t]
                lengths = graph_train.t_paths_len[t]

                if len(paths) != 0:
                    # Relation sequences are padded with num_r*2, the time
                    # sequences with 0.
                    paths = pad_sequence([torch.LongTensor(np.array(p)) for p in paths], batch_first=True,
                                         padding_value=num_r*2)
                    paths_time = pad_sequence([torch.LongTensor(np.array(p)) for p in paths_time], batch_first=True,
                                              padding_value=0)
                else:
                    paths = torch.LongTensor(np.array(paths))
                    paths_time = torch.LongTensor(np.array(paths_time))
                lengths = torch.LongTensor(np.array(lengths))

                if HAS_ACCELERATION:
                    # Plain tensors are enough for autograd; the deprecated
                    # Variable wrapper is no longer applied.
                    batch_paths_id = batch_paths_id.to(DEVICE)
                    batch_relation = batch_relation.to(DEVICE)
                    batch_his_r = batch_his_r.to(DEVICE)
                    paths = paths.to(DEVICE)
                    paths_time = paths_time.to(DEVICE)
                    lengths = lengths.to(DEVICE)
                    path_r = path_r.to(DEVICE)
                    path_neg_index = path_neg_index.to(DEVICE)
                    batch_values = batch_values.to(DEVICE)
                    path_values = path_values.to(DEVICE)

                optimizer.zero_grad()

                try:
                    preds, p_emb, r_emb = model.forward2(batch_paths_id, batch_relation, paths, paths_time, lengths, path_r,
                                                         path_neg_index, batch_his_r)

                    del batch_paths_id, batch_relation, paths, paths_time, lengths, path_r, path_neg_index

                    loss_e = margin_loss(preds.view(-1), batch_values.view(-1))
                    # loss_f is evaluated but not part of the optimised loss.
                    # It is kept so that a failure here still skips the step
                    # exactly as before.
                    loss_f = cosine_loss(p_emb, r_emb, path_values.view(-1))
                    del preds, p_emb, r_emb

                    loss = loss_e  # You can add loss_f back if needed

                    if torch.isnan(loss):
                        print(
                            f"Warning: NaN loss detected at timestep {t}, skipping...")
                        continue

                    loss.backward()
                    optimizer.step()

                    batch_count += 1
                    current_loss = loss.item()
                    epoch_loss.append(current_loss)
                    loss_sum += current_loss

                    avg_loss = loss_sum / batch_count
                    step_time = time.time() - step_start_time

                    progress_bar.set_postfix({
                        'Loss': f'{current_loss:.4f}',
                        'Avg_Loss': f'{avg_loss:.4f}',
                        'Step_Time': f'{step_time:.2f}s',
                        'Relations': f'{relations_processed}/{total_relations}'
                    })

                except Exception as e:
                    # Best effort: a failing time step is reported and skipped
                    # so that one bad batch does not abort the whole run.
                    print(f"Error in forward pass at timestep {t}: {e}")
                    continue

            if (step_idx + 1) % 10 == 0:
                elapsed_time = time.time() - start_time
                avg_loss = loss_sum / max(batch_count, 1)
                print(f"\n  Step {step_idx + 1}/{total_time_steps} | "
                      f"Avg Loss: {avg_loss:.4f} | "
                      f"Batches: {batch_count} | "
                      f"Elapsed: {elapsed_time:.1f}s")

        progress_bar.close()

        scheduler.step()

        epoch_time = time.time() - start_time
        avg_epoch_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0

        print(f"\n{'='*50}")
        print(f"EPOCH {epoch + 1} SUMMARY:")
        print(f"  Average Loss: {avg_epoch_loss:.6f}")
        print(f"  Total Batches: {batch_count}")
        print(f"  Epoch Time: {epoch_time:.1f}s")
        print(f"  Batches/sec: {batch_count/epoch_time:.2f}")
        print(f"  Current LR: {optimizer.param_groups[0]['lr']:.6f}")
        print(f"{'='*50}")

        epoch_losses.append(avg_epoch_loss)

        save_model(model, args.data, model_state_file, epoch)

        if len(epoch_losses) > 5:
            recent_losses = epoch_losses[-5:]
            if all(abs(recent_losses[i] - recent_losses[i-1]) < 1e-6 for i in range(1, len(recent_losses))):
                print("Early stopping: Loss has converged")
                break

    print("\nTraining completed!")
    if epoch_losses:
        # With zero epochs there is no loss to report.
        print(f"Final loss: {epoch_losses[-1]:.6f}")
    return epoch_losses


def _resolve_checkpoint(folder, preferred_epoch):
    """Return ``trained<preferred_epoch>.pth`` in ``folder`` if present, else the highest-numbered ``trained<N>.pth``."""
    import re

    preferred = os.path.join(folder, 'trained{}.pth'.format(preferred_epoch))
    if os.path.exists(preferred):
        return preferred

    latest_epoch, latest_path = -1, None
    if os.path.isdir(folder):
        for file_name in os.listdir(folder):
            match = re.fullmatch(r'trained(\d+)\.pth', file_name)
            if match and int(match.group(1)) > latest_epoch:
                latest_epoch = int(match.group(1))
                latest_path = os.path.join(folder, file_name)

    if latest_path is None:
        raise FileNotFoundError(
            "No trained<epoch>.pth checkpoint found in {}".format(folder))
    return latest_path


def evaluate_conv(args):
    """Score a saved TypeGAT checkpoint on the preprocessed test graph.

    For every entry of every time step the candidates are scored with
    ``model.test`` and the rank of candidate 0 (the valid one, stored first)
    is recorded. Per-time-step statistics are printed, followed by the
    size-weighted overall MR, MRR and Hits@{1,3,10,100}.

    The checkpoint of the last configured epoch
    (``trained<epochs_conv - 1>.pth``, i.e. ``trained99.pth`` by default) is
    preferred. If it does not exist, for instance because training stopped
    early, the highest-numbered checkpoint in ``model_state_file`` is used.

    Args:
        args: namespace providing ``embedding_size`` and ``data``;
            ``epochs_conv`` is used when present.

    Returns:
        dict | None: the overall metrics keyed ``mr``, ``mrr``, ``hits1``,
        ``hits3``, ``hits10`` and ``hits100``; ``None`` when nothing was ranked.

    Raises:
        FileNotFoundError: if no checkpoint can be found.
    """
    model = TypeGAT(num_e, num_r*2, relation_embeddings, args.embedding_size)

    checkpoint_path = _resolve_checkpoint(
        model_state_file, getattr(args, 'epochs_conv', 100) - 1)
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=DEVICE), strict=False)

    if HAS_ACCELERATION:
        model.to(DEVICE)

    model.eval()

    # NOTE(review): the graph file is deserialised through renamed_load
    # (pickle-style loading); only open preprocessed files you trust.
    # The context manager closes the handle that used to be leaked.
    with open(os.path.join(args.data + '/graph_preprocess_test.pk'), 'rb') as graph_file:
        graph_test = renamed_load(graph_file)

    with torch.no_grad():
        mr, mrr, hits1, hits3, hits10, hits100 = 0, 0, 0, 0, 0, 0
        test_size = 0
        for t, r_dict in graph_test.t_r_id_p_dict.items():
            size = 0
            ranks_tail = []
            reciprocal_ranks_tail = []

            for r, id_p in r_dict.items():
                for quad_id, ps in id_p.items():
                    size = size + 1

                    batch_paths_id = torch.LongTensor(
                        list_to_array(list(ps), 0))
                    batch_relation = torch.LongTensor(np.array([r] * len(ps)))
                    batch_his_r = torch.LongTensor(
                        list_to_array(list(graph_test.r_copy[t][r][quad_id]), num_r * 2))

                    # NOTE(review): the padded path tensors depend only on t
                    # and could be built once per time step if model.test
                    # does not modify them in place - confirm before hoisting.
                    paths = graph_test.t_paths[t]
                    lengths = graph_test.t_paths_len[t]
                    paths_time = graph_test.t_paths_time[t]

                    if len(paths) != 0:
                        paths = pad_sequence([torch.LongTensor(np.array(p)) for p in paths], batch_first=True,
                                             padding_value=num_r*2)
                        paths_time = pad_sequence([torch.LongTensor(np.array(p)) for p in paths_time], batch_first=True,
                                                  padding_value=0)
                    else:
                        paths = torch.LongTensor(np.array(paths))
                        paths_time = torch.LongTensor(np.array(paths_time))
                    lengths = torch.LongTensor(np.array(lengths))

                    if HAS_ACCELERATION:
                        batch_paths_id = batch_paths_id.to(DEVICE)
                        batch_relation = batch_relation.to(DEVICE)
                        paths = paths.to(DEVICE)
                        lengths = lengths.to(DEVICE)
                        paths_time = paths_time.to(DEVICE)
                        batch_his_r = batch_his_r.to(DEVICE)

                    scores_tail = model.test(
                        batch_paths_id, batch_relation, paths, lengths, paths_time, batch_his_r)

                    del batch_paths_id, batch_relation, paths, lengths

                    sorted_scores_tail, sorted_indices_tail = torch.sort(
                        scores_tail.view(-1), dim=-1, descending=True)
                    del scores_tail

                    # Just search for zeroth index in the sorted scores, we appended valid triple at top
                    ranks_tail.append(
                        np.where(sorted_indices_tail.cpu().numpy() == 0)[0][0] + 1)
                    reciprocal_ranks_tail.append(1.0 / ranks_tail[-1])

            ranked = len(ranks_tail)
            if ranked == 0:
                continue

            t_hits100 = sum(1 for rank in ranks_tail if rank <= 100) / ranked
            t_hits10 = sum(1 for rank in ranks_tail if rank <= 10) / ranked
            t_hits3 = sum(1 for rank in ranks_tail if rank <= 3) / ranked
            t_hits1 = sum(1 for rank in ranks_tail if rank == 1) / ranked
            t_mr = sum(ranks_tail) / ranked
            t_mrr = sum(reciprocal_ranks_tail) / len(reciprocal_ranks_tail)

            print("\nCumulative stats are -> ")
            print("Hits@100 are {}".format(t_hits100))
            print("Hits@10 are {}".format(t_hits10))
            print("Hits@3 are {}".format(t_hits3))
            print("Hits@1 are {}".format(t_hits1))
            print("Mean rank {}".format(t_mr))
            print("Mean Reciprocal Rank {}".format(t_mrr))

            test_size = test_size + size

            # Weight each time step by its number of test entries.
            mrr += t_mrr * size
            mr += t_mr * size
            hits1 += t_hits1 * size
            hits3 += t_hits3 * size
            hits10 += t_hits10 * size
            hits100 += t_hits100 * size

        if test_size == 0:
            # Nothing was ranked; avoid dividing by zero below.
            print("No test entries were ranked; metrics are undefined.")
            return None

        mrr = mrr / test_size
        mr = mr / test_size
        hits1 = hits1 / test_size
        hits3 = hits3 / test_size
        hits10 = hits10 / test_size
        hits100 = hits100 / test_size

        print("MR : {:.6f}".format(mr))
        print("MRR : {:.6f}".format(mrr))
        print("Hits @ 1: {:.6f}".format(hits1))
        print("Hits @ 3: {:.6f}".format(hits3))
        print("Hits @ 10: {:.6f}".format(hits10))
        print("Hits @ 100: {:.6f}".format(hits100))

        return {
            'mr': float(mr),
            'mrr': float(mrr),
            'hits1': float(hits1),
            'hits3': float(hits3),
            'hits10': float(hits10),
            'hits100': float(hits100),
        }


# Run training first, then evaluation, both with the same configuration.
for pipeline_stage in (train_conv, evaluate_conv):
    pipeline_stage(args)


In [None]:
# File: model_summary.py
# Source: ./model_summary.py

import argparse
import torch
import torch.nn as nn
import numpy as np
import os


def parse_args():
    """Return the configuration namespace for the model summary.

    Inside a Jupyter kernel (``ipykernel`` is loaded) a fixed namespace with
    ``data``, ``state``, ``ratio`` and ``his_len`` is returned. Otherwise the
    command line is parsed, which only defines ``--data``.
    """
    import sys

    # Notebook-compatible argument parsing
    if 'ipykernel' in sys.modules:
        # Running in notebook - use defaults
        return argparse.Namespace(
            data='./data/ICEWS14_forecasting',
            state='train',
            ratio=1,
            his_len=13,
        )

    # Running as script - use command line args
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "-data", "--data",
        default="./data/ICEWS14_forecasting", help="data directory")
    return cli_parser.parse_args()


# Notebook-compatible argument setup
import sys
import argparse

# With ipykernel loaded (notebook run) fixed defaults are used; a plain
# script run takes its values from parse_args().
args = (
    argparse.Namespace(
        data='./data/ICEWS14_forecasting',
        state='train',
        ratio=1,
        his_len=13,
    )
    if 'ipykernel' in sys.modules
    else parse_args()
)


def load_data_info(data_path=None):
    """Read the entity and relation counts from ``stat.txt``.

    The first two whitespace-separated integers of a line are taken as the
    entity and relation counts. If several non-blank lines are present the
    last one wins.

    Args:
        data_path: dataset directory containing ``stat.txt``. Defaults to the
            module-level ``args.data``.

    Returns:
        tuple: ``(num_e, num_r)``.

    Raises:
        ValueError: if ``stat.txt`` contains no non-blank line.
    """
    if data_path is None:
        data_path = args.data

    stat_path = os.path.join(data_path, 'stat.txt')
    num_e = num_r = None
    with open(stat_path, 'r') as fr:
        for line in fr:
            line_split = line.split()
            if not line_split:
                # A blank line (e.g. a trailing newline) carries no counts.
                continue
            num_e, num_r = int(line_split[0]), int(line_split[1])

    if num_e is None:
        raise ValueError(
            "No entity/relation counts found in {}".format(stat_path))

    return num_e, num_r


def print_model_summary():
    """Print the TypeGAT architecture, its parameter counts and the dataset sizes.

    The model is built with an embedding size of 200 and randomly initialised
    relation embeddings; nothing is trained or loaded from a checkpoint.
    """
    rule = "=" * 50

    def heading(title):
        # Section title framed by two separator lines.
        print(rule)
        print(title)
        print(rule)

    num_e, num_r = load_data_info()
    embedding_size = 200

    initial_weights = np.random.randn(num_r * 2, embedding_size)
    model = TypeGAT(num_e, num_r * 2,
                    torch.FloatTensor(initial_weights), embedding_size)

    heading("MODEL ARCHITECTURE")
    print(model)
    print()

    total_params = 0
    trainable_params = 0
    for weight in model.parameters():
        total_params += weight.numel()
        if weight.requires_grad:
            trainable_params += weight.numel()

    heading("PARAMETER SUMMARY")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {total_params - trainable_params:,}")
    print()

    heading("DETAILED PARAMETER BREAKDOWN")
    for name, param in model.named_parameters():
        print(f"{name:30} | Shape: {str(param.shape):20} | Params: {param.numel():,}")

    print()
    heading("DATA INFO")
    print(f"Number of entities: {num_e:,}")
    print(f"Number of relations: {num_r:,}")
    print(f"Total relations (with inverse): {num_r * 2:,}")
    print(f"Embedding dimension: {embedding_size}")


# Print the summary only when this code runs as the main module, not when it
# is imported.
if __name__ == "__main__":
    print_model_summary()
