In [3]:
import networkx as nx
import numpy as np
import random
from tqdm.auto import trange, tqdm
from typing import List, Tuple
import pandas as pd

In [2]:
G = nx.read_edgelist('graph.edgelist')

In [4]:
class LINE:
    def __init__(self, G: nx.Graph, dim: int = 64, order: int = 2, lr: float = 0.025, num_epochs: int = 100):
        self.G = G
        self.dim = dim
        self.order = order
        self.lr = lr
        self.num_epochs = num_epochs
        self.node_list = list(G.nodes())
        self.node_index = {node: idx for idx, node in enumerate(self.node_list)}
        self.num_nodes = len(self.node_list)
        # Инициализация эмбеддингов узлов как двумерный массив NumPy
        self.node_embeddings = np.random.randn(self.num_nodes, dim) / np.sqrt(dim)
    
    def get_first_order_pairs(self) -> List[Tuple[int, int]]:
        return [(self.node_index[u], self.node_index[v]) for u, v in self.G.edges()]
    
    def get_second_order_pairs(self) -> List[Tuple[int, int]]:
        pairs = []
        for node in self.node_list:
            neighbors = list(self.G.neighbors(node))
            idx = self.node_index[node]
            neighbor_indices = [self.node_index[neighbor] for neighbor in neighbors]
            for u in neighbor_indices:
                for v in neighbor_indices:
                    if u != v:
                        pairs.append((u, v))
        return pairs
    
    def train(self):
        first_order_pairs = self.get_first_order_pairs()
        # second_order_pairs = self.get_second_order_pairs()  # Если понадобится использовать
        
        for epoch in trange(self.num_epochs, desc="Training"):
            # Перемешиваем пары первого порядка
            random.shuffle(first_order_pairs)
            # random.shuffle(second_order_pairs)  # Если понадобится использовать
            
            # Обработка пар первого порядка
            for u, v in tqdm(first_order_pairs, leave=False):
                emb_u = self.node_embeddings[u]
                emb_v = self.node_embeddings[v]
                
                dot_product = np.dot(emb_u, emb_v)
                exp_neg_dot = np.exp(-dot_product)
                
                sigmoid = 1 / (1 + exp_neg_dot)
                grad_common = sigmoid - 1  # Для максимизации вероятности связи
                
                # Обновление эмбеддингов
                self.node_embeddings[u] -= self.lr * grad_common * emb_v
                self.node_embeddings[v] -= self.lr * grad_common * emb_u
            
            if epoch % 5 == 0:
                loss = self.calculate_loss(first_order_pairs)
                print(f"Epoch {epoch + 1}/{self.num_epochs}, Loss: {loss:.4f}")
    
    def calculate_loss(self, first_order_pairs: List[Tuple[int, int]]) -> float:
        embeddings_u = self.node_embeddings[[u for u, v in first_order_pairs]]
        embeddings_v = self.node_embeddings[[v for u, v in first_order_pairs]]
        dot_products = np.einsum('ij,ij->i', embeddings_u, embeddings_v)
        
        # Числовая стабильность для log(1 + exp(x))
        loss = np.where(dot_products > 0,
                        np.log1p(np.exp(-dot_products)),
                        -dot_products + np.log1p(np.exp(dot_products)))
        return np.sum(loss)
    
    def get_embeddings(self) -> dict:
        return {node: self.node_embeddings[idx] for node, idx in self.node_index.items()}

In [5]:
line_model = LINE(G, dim=64, order=1, num_epochs=100)
line_model.train()

Training:   0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 1/100, Loss: 7362303.6640


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 6/100, Loss: 5608819.5052


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 11/100, Loss: 2772057.8645


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 16/100, Loss: 1416931.6645


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 21/100, Loss: 827379.8958


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 26/100, Loss: 538567.5604


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 31/100, Loss: 379464.4280


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 36/100, Loss: 283278.7689


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 41/100, Loss: 220823.1888


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 46/100, Loss: 177954.0035


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 51/100, Loss: 147205.7184


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 56/100, Loss: 124354.0659


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 61/100, Loss: 106868.5031


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 66/100, Loss: 93159.6217


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 71/100, Loss: 82188.6418


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 76/100, Loss: 73253.2862


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 81/100, Loss: 65864.7488


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 86/100, Loss: 59674.1044


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 91/100, Loss: 54426.7055


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

Epoch 96/100, Loss: 49933.0461


  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

  0%|          | 0/10794057 [00:00<?, ?it/s]

In [33]:
import pickle


embeddings = line_model.get_embeddings()

with open("embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)