In [287]:
import numpy as np
import seaborn as sn
from tqdm.notebook import tqdm
import random
import contextlib
import threading
import logging
import math
import time
import sys
import collections
import copy

In [288]:
logging.basicConfig(level=logging.DEBUG, format="%(name)s:%(threadName)s: %(message)s")
# Logger handed to every Node created by SyncModel; its level is adjusted in a later cell.
nodes_logger = logging.getLogger('nodes')

In [312]:
class Embedding:
    """Abstract space into which the nodes of a graph are placed.

    Subclasses provide a ``Point`` type, ``distance(p1, p2)`` between two points and
    ``get_slots()`` returning the positions nodes are initially assigned to.
    """

    class Point:
        """Abstract position in the embedding; subclasses must implement ``__repr__``."""

        def __repr__(self):
            raise NotImplementedError

        def __str__(self):
            return repr(self)

    def __init__(self, n):
        # Number of nodes to place (SyncModel passes the size of the graph).
        self.n = n

    def distance(self, p1, p2):
        """Return the distance between points ``p1`` and ``p2``."""
        raise NotImplementedError

    def get_slots(self, n=None):
        """Return the list of positions the nodes are initially placed on.

        Subclasses implement this without arguments and use ``self.n``. ``n`` is
        accepted (and ignored) only for compatibility with the old abstract signature.
        """
        raise NotImplementedError

In [313]:
class LineEmbedding(Embedding):
    """Places nodes on the real line; a point is a single coordinate."""

    class Point(Embedding.Point):
        """Scalar coordinate on the line."""

        def __init__(self, cord):
            self.cord = cord

        def __repr__(self):
            return str(self.cord)

    def distance(self, p1, p2):
        """Absolute difference between the two coordinates."""
        delta = p1.cord - p2.cord
        return abs(delta)

    def get_slots(self):
        """Return ``self.n`` evenly spaced points covering ``[0, self.n]``, both ends included."""
        coords = np.linspace(0, self.n, self.n)
        return list(map(self.Point, coords))
    

In [314]:
class CircleEmbedding(Embedding):
    """Places nodes on a circle of circumference ``self.n``.

    Coordinates live in ``[0, self.n)``. Distance is measured along the shorter arc,
    so ``0`` and ``self.n`` denote the same point.
    """

    class Point(Embedding.Point):
        """Coordinate along the circle."""

        def __init__(self, cord):
            self.cord = cord

        def __repr__(self):
            return str(self.cord)

    def distance(self, p1, p2):
        """Shorter-arc distance between two points on the circle."""
        a, b = sorted((p1.cord, p2.cord))
        # Either go directly from a to b, or wrap around through 0 == self.n.
        return min(b - a, a + (self.n - b))

    def get_slots(self):
        """Return ``self.n`` evenly spaced, pairwise distinct points ``0, 1, ..., n - 1``.

        ``endpoint=False`` excludes ``self.n``, which would coincide with ``0`` on the
        circle and give two slots at distance zero.
        """
        return [self.Point(x) for x in np.linspace(0, self.n, self.n, endpoint=False)]
    

In [315]:
class PlaneEmbedding(Embedding):
    """Places nodes on a square grid in the plane.

    NOTE(review): ``distance`` returns the *squared* Euclidean distance. Node scores are
    sums of log-distances, so this doubles every log term. That may be intentional (a
    1/d^2 link model for two dimensions). Confirm before changing it.
    """

    class Point(Embedding.Point):
        """Point ``(x, y)``, printed in complex-number style."""

        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __repr__(self):
            return f'{self.x} + i{self.y}'

    def distance(self, p1, p2):
        """Squared Euclidean distance between ``p1`` and ``p2``."""
        dx = p1.x - p2.x
        dy = p1.y - p2.y
        return dx ** 2 + dy ** 2

    def get_slots(self):
        """Return a ``side x side`` grid spanning ``[0, side]^2``, ordered by ``x`` then ``y``.

        ``side = floor(sqrt(n) + 1)`` is strictly greater than ``sqrt(n)``, so the grid
        always has more than ``self.n`` slots.
        """
        side = math.floor(math.sqrt(self.n) + 1)
        axis = np.linspace(0, side, side)
        return [self.Point(x, y) for x in axis for y in axis]


In [316]:
class Node:
    """One vertex of the graph taking part in the position-swapping embedding.

    A node knows its own position and the last position announced by each neighbour.
    In ``act`` it pairs with a random neighbour, locking both nodes so no third node
    interferes. It then compares how well the current and the swapped positions fit
    the two neighbourhoods, and possibly swaps.

    NOTE(review): assumes a symmetric adjacency. ``act`` asks the chosen neighbour to
    score against this node, which needs this node in the neighbour's ``neighbours_pos``.

    Attributes:
        neighbours_pos: dict neighbour index -> last announced position (``None`` until
            that neighbour has announced one).
        router: list of all nodes indexed by node index; used to message neighbours.
        locked_with: index of the node this one is currently paired with, or ``None``.
        pos_change_queue: ``(sender_idx, position)`` updates that arrived while paired
            with a different node; applied when the pairing ends.
        _position: this node's own position.
        idx: this node's index in ``router``.
        embedding: object providing ``distance(p1, p2)``.
        _lock: guards ``locked_with``, ``pos_change_queue`` and writes to ``neighbours_pos``.
        logger: logger used to report zero-distance configurations.
    """
    __slots__ = [
        "neighbours_pos",
        "router",
        "locked_with",
        "pos_change_queue",
        "_position",
        "idx",
        "embedding",
        "_lock",
        "logger",
    ]

    @contextlib.contextmanager
    def locking_communications(self, other_idx):
        """Pair this node with ``other_idx`` for the duration of a ``with`` block.

        Yields True only if both nodes were free and are now locked to each other.
        On exit it releases exactly the locks this call acquired, so a failed attempt
        never unlocks a pairing owned by another caller.
        """
        # Look the partner up before locking anything, so a bad index cannot leave
        # this node locked.
        other = self.router[other_idx]
        self_locked = other_locked = False
        try:
            self_locked = self.initialize(other_idx)
            if self_locked:
                other_locked = other.initialize(self.idx)
            yield self_locked and other_locked
        finally:
            if other_locked:
                other.finalize(self.idx)
            if self_locked:
                self.finalize(other_idx)

    def initialize(self, other_idx):
        """Try to pair this node with ``other_idx``; return True on success.

        Pairing is exclusive. If the node is already paired, even with ``other_idx``,
        this fails. Otherwise two nodes picking each other concurrently could both
        believe they own the pairing.
        """
        with self._lock:
            if self.locked_with is not None:
                return False
            self.locked_with = other_idx
            return True

    def finalize(self, other_idx):
        """End the pairing with ``other_idx`` (no-op if not paired with it).

        Updates queued during the pairing are applied while still holding the lock,
        so a newer direct update cannot be overwritten by an older queued one.
        """
        with self._lock:
            if self.locked_with != other_idx:
                return
            self.locked_with = None
            self._drain_queue()

    def __init__(self, router, embedding, idx, neighbours_ids, logger):
        self.idx = idx
        self.neighbours_pos = {neigh: None for neigh in neighbours_ids}
        self.router = router
        self.locked_with = None
        self.pos_change_queue = []
        self._position = None
        self.embedding = embedding
        self._lock = threading.Lock()
        self.logger = logger

    def materialize(self):
        """Broadcast this node's current position to all of its neighbours."""
        for neigh in self.neighbours_pos:
            self.router[neigh].update_pos(self.idx, self.position)

    def _apply_update(self, other, pos):
        # Caller must hold self._lock. While paired with someone other than the
        # sender, defer the update so the pair's view stays stable during scoring.
        if self.locked_with is not None and self.locked_with != other:
            self.pos_change_queue.append((other, pos))
        else:
            self.neighbours_pos[other] = pos

    def _drain_queue(self):
        # Caller must hold self._lock. Re-applies queued updates in arrival order.
        rows, self.pos_change_queue = self.pos_change_queue, []
        for other, pos in rows:
            self._apply_update(other, pos)

    def update_pos(self, other, pos):
        """Record that neighbour ``other`` moved to ``pos`` (queued if paired with someone else)."""
        with self._lock:
            self._apply_update(other, pos)

    def catch_up_with_changelogs(self):
        """Apply every queued update that the current pairing state allows."""
        with self._lock:
            self._drain_queue()

    @property
    def position(self):
        return self._position

    @position.setter
    def position(self, value):
        # Every position change is immediately announced to the neighbours.
        self._position = value
        self.materialize()

    def swap(self, other_pos):
        """Move to ``other_pos`` (broadcasting it) and return the previous position."""
        old_pos = self.position
        self.position = other_pos
        return old_pos

    def act(self):
        """Try one swap with a uniformly random neighbour.

        The swap is accepted with Metropolis probability ``min(1, exp(score_diff))``,
        where ``score_diff`` is how much the swap lowers the summed log-distances of
        both nodes to their neighbours.
        NOTE(review): acceptance rule assumed to be Metropolis-Hastings as in
        Sandberg-style swap embeddings. Confirm.
        """
        if not self.neighbours_pos:
            return  # isolated node: nobody to swap with
        other_idx = random.choice(list(self.neighbours_pos))
        with self.locking_communications(other_idx) as success:
            if not success or self.neighbours_pos[other_idx] is None:
                return
            other = self.router[other_idx]
            other_old, other_new = other.score(self.idx)
            self_old, self_new = self.score(other_idx)

            # Positive when the swapped configuration is tighter (smaller log-distances).
            score_diff = (self_old + other_old) - (self_new + other_new)

            # Short-circuit keeps exp() from overflowing for large positive differences.
            if score_diff >= 0 or random.random() < math.exp(score_diff):
                self.position = other.swap(self.position)

    def loop(self, n=0):
        """Call ``act`` ``n`` times."""
        for _ in range(n):
            self.act()

    def __repr__(self):
        return f"<Node {self.idx} @ {repr(self.position)}>"

    def score(self, other_idx):
        """Return ``(old_score, new_score)`` for a hypothetical swap with ``other_idx``.

        Both are sums of ``log(distance)`` over this node's neighbours, excluding
        ``other_idx`` and neighbours whose position is still unknown.
        ``old_score`` measures from this node's current position; ``new_score``
        measures from ``other_idx``'s position. Zero distances (where log is undefined)
        are logged and contribute nothing.
        """
        other_pos = self.neighbours_pos[other_idx]
        old_score, new_score = 0, 0
        for idx, pos in self.neighbours_pos.items():
            if idx == other_idx or pos is None:
                continue
            try:
                old_score += math.log(self.embedding.distance(self.position, pos))
            except ValueError:
                self.logger.critical(f"Zero distance <PseudoNode {self.idx} @ {self.position}> - <PseudoNode {idx} @ {pos}>")
            try:
                new_score += math.log(self.embedding.distance(other_pos, pos))
            except ValueError:
                self.logger.critical(f"Zero distance <PseudoNode {self.idx} @ {other_pos}> - <PseudoNode {idx} @ {pos}>")
        return old_score, new_score

In [317]:
def spawn_thread_swarm(router, begin=0, end=None, iterations=0):
    """Start one daemon thread per node in ``router[begin:end]``.

    Each thread runs ``node.loop(iterations)``. ``iterations`` defaults to 0, which
    keeps the old behaviour: each thread returns immediately, because ``Node.loop``
    defaults to zero iterations. Pass a positive value to actually run the nodes
    concurrently.

    Returns:
        list of ``(node, thread)`` pairs, in router order.
    """
    if end is None:
        end = len(router)
    swarm = []
    for node in router[begin:end]:
        thread = threading.Thread(
            target=node.loop,
            args=(iterations,),
            name=f"Thread-{repr(node)}",
            daemon=True,
        )
        thread.start()
        swarm.append((node, thread))
    return swarm

def stop_nodes(swarm):
    """Wait for every thread of a swarm to finish.

    ``swarm`` is a sequence of ``(node, thread)`` pairs. Threads are joined in order.
    Nothing signals them to stop early, so this blocks until each thread's target returns.
    """
    threads = [pair[1] for pair in swarm]
    for worker in threads:
        worker.join()

In [318]:
class Graph:
    """Directed graph stored as an adjacency list of integer node indices.

    Args:
        connectivity_list: ``connectivity_list[i]`` lists the neighbour indices of node
            ``i``. It is deep-copied, so later changes by the caller do not leak in.
        encoding: optional mapping from original node labels to indices (as built by
            ``encode``). Defaults to an empty dict.

    Supports ``len(graph)``, iteration over rows and ``graph[i]``, so a Graph can be
    used wherever a plain adjacency list is expected.
    """

    def __init__(self, connectivity_list, encoding=None):
        self.encoding = {} if encoding is None else encoding
        self.connectivity_list = copy.deepcopy(connectivity_list)

    def __len__(self):
        return len(self.connectivity_list)

    def __iter__(self):
        return iter(self.connectivity_list)

    def __getitem__(self, idx):
        return self.connectivity_list[idx]

    @classmethod
    def encode(cls, token_connectivity):
        """Build a graph from a mapping ``label -> iterable of neighbour labels``.

        Keys are numbered in the mapping's iteration order. Labels that only occur as
        neighbours get the following indices and an empty adjacency row, so every
        index referenced in the result has a row of its own.
        """
        encoding = {actor: idx for idx, actor in enumerate(token_connectivity)}

        connectivity_list = []
        for neighbours in token_connectivity.values():
            connections = []
            for neighbour in neighbours:
                if neighbour not in encoding:
                    encoding[neighbour] = len(encoding)
                connections.append(encoding[neighbour])
            connectivity_list.append(connections)

        # Labels discovered only as neighbours have no outgoing edges.
        connectivity_list.extend([] for _ in range(len(encoding) - len(connectivity_list)))
        return cls(connectivity_list, encoding=encoding)
    

class SymGraph(Graph):
    """Undirected version of a graph: every edge ``i -> j`` also gets ``j -> i``.

    Node indices are preserved: row ``i`` of the result belongs to node ``i``, and a
    node without any edge keeps an empty row instead of being dropped (dropping it
    would shift every later row and corrupt the indices). Neighbour lists are sorted
    for a deterministic order.

    Assumes neighbour entries are non-negative integer indices.
    """

    def __init__(self, connectivity_list, encoding=None):
        # Materialize once so any iterable of iterables is accepted.
        rows = [list(neighbours) for neighbours in connectivity_list]

        # Room for every row index and every index that appears as a neighbour.
        size = max([len(rows)] + [neighbour + 1 for neighbours in rows for neighbour in neighbours])

        sym_connectivity_list = [set() for _ in range(size)]
        for idx, neighbours in enumerate(rows):
            for neighbour in neighbours:
                sym_connectivity_list[idx].add(neighbour)
                sym_connectivity_list[neighbour].add(idx)

        super().__init__([sorted(neighbours) for neighbours in sym_connectivity_list], encoding)

In [319]:
class SyncModel:
    """Sequential driver for the swap embedding: each epoch, every node acts once, in index order.

    Args:
        embedding_cls: ``Embedding`` subclass, instantiated with the number of nodes.
        graph: adjacency list where ``graph[i]`` holds the neighbour indices of node
            ``i``, or an object exposing such a list as ``connectivity_list`` (e.g. a
            ``Graph``).

    Raises:
        ValueError: if the embedding offers fewer slots than there are nodes.
    """

    def __init__(self, embedding_cls, graph):
        adjacency = getattr(graph, "connectivity_list", graph)
        self.embedding = embedding_cls(len(adjacency))
        self.graph = graph
        self.router = []
        self.epoch = 0

        # Create every node before assigning positions: setting a position immediately
        # broadcasts it to the node's neighbours through ``self.router``.
        for idx, conn_list in enumerate(adjacency):
            self.router.append(Node(self.router, self.embedding, idx, conn_list, nodes_logger))

        slots = list(self.embedding.get_slots())
        if len(slots) < len(self.router):
            # zip() would otherwise silently leave trailing nodes without a position.
            raise ValueError(
                f"{type(self.embedding).__name__} provides {len(slots)} slots "
                f"for {len(self.router)} nodes"
            )
        # Nodes take the first len(router) slots; any extra slots stay unused.
        for node, pos in zip(self.router, slots):
            node.position = pos

    def act(self, silent=False):
        """Let every node act once, in index order; show a progress bar unless ``silent``."""
        nodes = self.router if silent else tqdm(self.router)
        for node in nodes:
            node.act()

    def loop(self, n=0):
        """Run ``n`` epochs with an epoch progress bar, incrementing ``self.epoch`` after each."""
        for _ in tqdm(range(n)):
            self.act(silent=True)
            self.epoch += 1

In [297]:
# Edge list: the first line is a header (skipped), then one "a,b" integer pair per line.
HU_EDGES_PATH = "deezer_clean_data/HU_edges.csv"

sym_connectivity_list = collections.defaultdict(set)
with open(HU_EDGES_PATH) as hu_edges_file:
    next(hu_edges_file, None)  # skip the CSV header
    for line in hu_edges_file:
        line = line.strip()
        if not line:
            continue
        # int() tolerates surrounding spaces, so "1, 2" parses as well as "1,2".
        a, b = map(int, line.split(","))
        # Edges are treated as undirected: record both directions.
        sym_connectivity_list[a].add(b)
        sym_connectivity_list[b].add(a)

HU_deezer_graph = SymGraph([list(neighbours) for _, neighbours in sorted(sym_connectivity_list.items())]).connectivity_list

In [304]:
# Sequential model placing the HU graph's nodes on a line.
test = SyncModel(embedding_cls=LineEmbedding, graph=HU_deezer_graph)

In [305]:
# Only CRITICAL node messages from here on. Use setLevel rather than assigning `.level`:
# setLevel also clears the logger's cached isEnabledFor() answers, which earlier log
# calls may already have populated.
nodes_logger.setLevel(logging.CRITICAL)

In [309]:
# Run 100 epochs; in each epoch every node acts once.
test.loop(n=100)

  0%|          | 0/100 [00:00<?, ?it/s]