In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from utils import load_pickle, save_pickle, summarize_distribution
from lib.log import logger
import igraph

2022-09-26 15:12:28,610 Note: NumExpr detected 56 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2022-09-26 15:12:28,611 NumExpr defaulting to 8 threads.


In [2]:
graph = load_pickle("/root/Lab_Related/data/Heter-GAT/Classic/graph/graph-directed.p")
diffusion_dict = load_pickle("/root/Lab_Related/data/Heter-GAT/Classic/ActionLog.p")

In [3]:
# hyper-param
ego_size = 49 # without center node, which means containing ego_size + 1 nodes per batch
Ntimestages = 8
Nslice = Ntimestages + 1
sample_ratio = 1
restart_prob = 0.2
walk_length = 1000
min_degree, max_degree, min_active_neighbor = 3, 7635, 3
min_degree, max_degree, min_active_neighbor = 3, 450, 3
min_influence, max_influence = 30, 200

In [5]:
from utils import SubGraphSample
from typing import List, Any
import random
import itertools
import numpy as np

subgraph_samples = [SubGraphSample() for _ in range(Ntimestages)]

def random_walk_with_restart(g: igraph.Graph, start: List[int], restart_prob: float):
    current = random.choice(seq=start)
    stop = False
    while not stop:
        stop = yield current
        neighbors = g.neighbors(current, mode="out")
        if random.random() < restart_prob or len(neighbors) == 0:
            current = random.choice(seq=start)
        else:
            current = random.choice(seq=neighbors)

def create_sample(center_user: int, label: int, hashtag: int, time_stage: int, users_affected_now: List[int], subgraph_sample: SubGraphSample):
    neighbors = graph.neighbors(center_user, mode="out")
    active_neighbors = set(users_affected_now) & set(neighbors)
    inactive_neighbors = set(neighbors) - active_neighbors
    active_neighbors, inactive_neighbors = list(active_neighbors), list(inactive_neighbors)
    if len(active_neighbors) < min_active_neighbor:
        return
    
    subnetwork_size = ego_size + 1
    subnetwork = []
    if len(active_neighbors) < ego_size:
        subnetwork = set(active_neighbors)
        for v in itertools.islice(
            random_walk_with_restart(g=graph, start=[center_user]+active_neighbors,restart_prob=restart_prob),
            walk_length
        ):
            if v != center_user and v not in subnetwork:
                subnetwork.add(v)
                if len(subnetwork) == ego_size:
                    break
        subnetwork = list(subnetwork)
        if len(subnetwork) < ego_size:
            return
    else:
        samples = np.random.choice(active_neighbors, size=ego_size, replace=False)
        subnetwork = samples.tolist()
    subnetwork.append(center_user)

    ranks = np.array(subnetwork).argsort().argsort()
    subgraph = graph.subgraph(subnetwork, implementation="create_from_scratch")
    adjacency = np.array(subgraph.get_adjacency().data, dtype=int)
    adjacency = adjacency[ranks][:,ranks]
    subgraph_sample.adj_matrices.append(adjacency)

    influence_feature = np.zeros((subnetwork_size,2))
    for idx, v in enumerate(subnetwork[:-1]):
        if v in users_affected_now:
            influence_feature[idx, 0] = 1
    influence_feature[subnetwork_size-1,1] = 1
    subgraph_sample.influence_features.append(influence_feature)

    subgraph_sample.vertex_ids.append(np.array(subnetwork, dtype=int))
    subgraph_sample.labels.append(label)
    subgraph_sample.tags.append(hashtag)
    subgraph_sample.time_stages.append(time_stage)

def dump_data(dump_dirpath: str, adj_matrices, influence_features, vertex_ids, labels, tags, time_stages):
    adj_matrices = np.array(adj_matrices)
    influence_features = np.array(influence_features)
    vertex_ids = np.array(vertex_ids)
    labels = np.array(labels)
    tags = np.array(tags)
    time_stages = np.array(time_stages)

    os.makedirs(dump_dirpath, exist_ok=True)
    with open(os.path.join(dump_dirpath, "adjacency_matrix.npy"), "wb") as f:
        np.save(f, adj_matrices)
    with open(os.path.join(dump_dirpath, "influence_feature.npy"), "wb") as f:
        np.save(f, influence_features)
    with open(os.path.join(dump_dirpath, "vertex_id.npy"), "wb") as f:
        np.save(f, vertex_ids)
    with open(os.path.join(dump_dirpath, "label.npy"), "wb") as f:
        np.save(f, labels)
    with open(os.path.join(dump_dirpath, "tag.npy"), "wb") as f:
        np.save(f, tags)
    with open(os.path.join(dump_dirpath, "time_stage.npy"), "wb") as f:
        np.save(f, time_stages)

    logger.info("Dump %d instances in total" % (len(labels)))

def find_rightest_bound(timestamp: float, min_timestamp: float, time_span: float):
    idx = int((timestamp-min_timestamp)//time_span)
    return Ntimestages-1 if idx>=Ntimestages-1 else idx

def get_samples(hashtag: int, cascades: List[Any], degree: List[int]):
    """
    Function: 分别找到每个(cascades, time-stage)下的正负样本, 其中正样本指的是已激活的用户, 负样本指的是这些正样本的子节点
    """
    cascade_item_idx = 0
    users_affected_now = set()
    users_affected_all = set([item[0] for item in cascades])
    time_span = (cascades[-1][1]-cascades[0][1])/Ntimestages

    for user, timestamp in cascades[1:]:
        while cascade_item_idx < len(cascades) and cascades[cascade_item_idx][1] < timestamp:
            users_affected_now.add(cascades[cascade_item_idx][0])
            cascade_item_idx += 1
        if len(users_affected_now) == 0 or user in users_affected_now:
            continue
        tidx = find_rightest_bound(timestamp, min_timestamp=cascades[0][1], time_span=time_span)
        
        # Pos
        if degree[user] >= min_degree and degree[user] <= max_degree:
            create_sample(center_user=user, label=1, hashtag=hashtag, time_stage=tidx, users_affected_now=users_affected_now, subgraph_sample=subgraph_samples[tidx])
        
        # Neg
        neg_samples = list(set(graph.neighbors(user, mode="out")) - users_affected_all)
        neg_samples = np.random.choice(neg_samples, size=min(len(neg_samples), sample_ratio), replace=False)
        for neg_sample in neg_samples:
            if degree[neg_sample] >= min_degree and degree[neg_sample] <= max_degree:
                create_sample(center_user=neg_sample, label=0, hashtag=hashtag, time_stage=tidx, users_affected_now=users_affected_now, subgraph_sample=subgraph_samples[tidx])

degree = graph.degree(mode="out")
for aidx, (hashtag, cascades) in enumerate(diffusion_dict.items()):
    get_samples(hashtag=hashtag, cascades=cascades, degree=degree)

    logger.info(f"aidx={aidx:>4}, hashtag={hashtag:>4}, cascades length={len(cascades):>8}, subgraph samples length=" + \
        " ".join([f"{len(sample):>6}" for sample in subgraph_samples])
    )
    break

# for idx in range(Ntimestages):
#     dump_data(
#         dump_dirpath=f"stages1/{idx}",
#         adj_matrices=subgraph_samples[idx].adj_matrices,
#         influence_features=subgraph_samples[idx].influence_features,
#         vertex_ids=subgraph_samples[idx].vertex_ids,
#         labels=subgraph_samples[idx].labels,
#         tags=subgraph_samples[idx].tags,
#         time_stages=subgraph_samples[idx].time_stages,
#     )


2022-09-26 15:21:20,318 [282, 159, 176, 176, 259, 389, 543, 735]
2022-09-26 15:21:20,320 aidx=   0, hashtag= 186, cascades length=  120774, subgraph samples length=     0      0      0      6     11     21     56    124


In [57]:
# def func(a=[]):
#     a.append(2)
#     return a

func()
func.__defaults__

([2, 2],)

In [34]:
s.a

[2, 2]

In [None]:
# def get_heternetwork_nodetype(current:int, nb_users:int=20000)->str:
#     return "Tweet" if current >= nb_users else "User"

# def static_vars(**kwargs):
#     def decorate(func):
#         for k in kwargs:
#             setattr(func, k, kwargs[k])
#         return func
#     return decorate

# @static_vars(node_cnts={"User":0, "Tweet":0}, max_cnts = {"User": 50, "Tweet": 50*20})
# def control_node_cnts(current:int)->bool:
#     node_type = get_heternetwork_nodetype(current)
#     valid = control_node_cnts.node_cnts[node_type]+1<=control_node_cnts.max_cnts[node_type]
#     if valid:
#         control_node_cnts.node_cnts[node_type] += 1
#     # logger.info("node_cnts: {}".format(" ".join(f"{key}={value}" for key,value in control_node_cnts.node_cnts.items())))
#     return valid

# def heter_random_walk_with_restart(g, start:List[int], restart_prob:float, valid_fn:Callable[[int],bool]):
#     current = random.choice(start)
#     stop = False
#     valid_fn.node_cnts={"User":0, "Tweet":0}

#     while not stop:
#         stop = yield current
#         valid = False
        
#         while not valid:
#             node_type = get_heternetwork_nodetype(current)
#             # NOTE: We treat all edges as directed ones
#             # Bcz we create U-T Edges in the direction of user->tweet
#             neighbor_mode = "in" if node_type == "Tweet" else "out"
#             if random.random() < restart_prob or g.degree(current, mode=neighbor_mode)==0:
#                 current = random.choice(start)
#                 valid = valid_fn(current)
#             else:
#                 current = random.choice(g.neighbors(current, mode=neighbor_mode))
#                 valid = valid_fn(current)
#             # logger.info(f"current={current}, valid={valid}")

# for v in itertools.islice(heter_random_walk_with_restart(graph, [2603], 0.2, control_node_cnts), 10):
#     print(v)
# print(control_node_cnts.node_cnts)

# def get_samples(cascades: List[tuple[int,int]]):
#     """
#     Function: 分别找到每个(cascades, time-stage)下的正负样本, 其中正样本指的是已激活的用户, 负样本指的是这些正样本的子节点
#     """
#     pos_samples, neg_samples = [set() for _ in range(Nslice)], [set() for _ in range(Nslice)]

#     cascade_item_idx = 0
#     min_timestamp, max_timestamp = cascades[0][1], cascades[-1][1]
#     time_span = (max_timestamp - min_timestamp) / Nslice
#     users_affected_now = [set() for _ in range(Nslice)]
    
#     for tidx in range(Nslice)
#         lower_b, upper_b = min_timestamp+tidx*time_span, min_timestamp+(tidx+1)*time_span
#         cur_pos, cur_neg = set(), set()
#         while cascade_item_idx < len(cascades) and cascades[cascade_item_idx][1] <= upper_b:
#             cur_pos.add(cascades[cascade_item_idx][0])
#             cur_neg |= set(graph.neighbors(cascades[cascade_item_idx][0], mode="out"))
#             cascade_item_idx += 1
#         users_affected_now[tidx] = cur_pos
#         if tidx >= 1:
#             users_affected_now[tidx] |= users_affected_now[tidx-1]

#         pos_samples[tidx] |= cur_pos
#         neg_samples[tidx] |= cur_neg
#         # if tidx >= 1:
#         #     neg_samples[tidx-1] -= cur_pos
        
#     return pos_samples, neg_samples, users_affected_now

# pos_samples, neg_samples, users_affected_now = get_samples(cascades=diffusion_dict[186])
# pos_samples, neg_samples, users_affected_now = get_samples(cascades=cascades)
# for tidx in range(len(users_affected_now)):
#     for pos_sample in pos_samples[tidx]:
#         if degree[pos_sample] >= min_degree and degree[pos_sample] <= max_degree:
#             create_sample(center_user=pos_sample, label=1, users_affected_now=users_affected_now[tidx], subgraph_sample=subgraph_samples[tidx])
    
#     neg_users = np.random.choice(list(neg_samples[tidx]), size=min(len(neg_samples[tidx]), sample_ratio*len(pos_samples[tidx])), replace=False)
#     for neg_sample in neg_users:
#         if degree[neg_sample] >= min_degree and degree[neg_sample] <= max_degree:
#             create_sample(center_user=neg_sample, label=0, users_affected_now=users_affected_now[tidx], subgraph_sample=subgraph_samples[tidx])

