In [1]:
%cd ..

/Users/pablomirallesgonzalez/Documents/masters-degree/Análisis de Redes Sociales/sna-ceb-assignment


In [2]:
import sys
sys.path.append('src')

In [3]:
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import functools
import random

from networkx import community as nxcom
from genetic_algs import NSGA2, tournament_select, random_locus_crossover, mutate_locus, create_locus

## Carga de datos

In [4]:
GRAPH_FILE = "data/amazon_graph.graphml"
graph = nx.read_graphml(GRAPH_FILE)

In [5]:
graph.number_of_nodes(), graph.number_of_edges()

(475, 1184)

In [6]:
nodes = list(graph.nodes())
node_to_idx = {node: idx for idx, node in enumerate(nodes)}
idx_to_node = {idx: node for idx, node in enumerate(nodes)}

edges: list[list[int]] = [
    [ node_to_idx[neighbour] for neighbour in graph.neighbors(idx_to_node[idx]) ]
    for idx in range(graph.number_of_nodes())
]

In [7]:
edges

[[154, 462, 64],
 [301, 271, 273],
 [9, 364, 230, 357],
 [126, 436, 205, 64, 145],
 [261, 65, 84, 342, 246, 365, 170],
 [445, 57, 210, 121, 91, 86],
 [402],
 [304, 392, 153, 224, 29, 184, 64, 429, 379],
 [381, 92, 395],
 [2, 147],
 [117],
 [386, 14, 384, 398, 40],
 [464, 258],
 [70, 189, 321],
 [11, 386, 370, 15, 268, 398],
 [14, 386, 370, 268, 398, 40],
 [253, 418, 419, 83, 367, 369, 406, 408, 82, 42, 17],
 [16, 253, 418, 419, 83, 367, 369, 371, 408, 82],
 [473, 90, 59, 363, 458],
 [311, 195, 137],
 [418, 139],
 [216, 435, 430, 228],
 [104, 207, 215, 413],
 [226, 33],
 [104, 207, 215, 392],
 [455,
  75,
  60,
  185,
  441,
  218,
  112,
  359,
  227,
  279,
  469,
  438,
  243,
  259,
  28,
  111,
  426,
  120,
  175,
  355],
 [94, 311, 309],
 [471, 185, 191, 106, 175],
 [25, 286, 463, 441, 279, 469, 391, 243, 259, 426, 254, 232, 120],
 [7, 304, 184, 379, 325, 39],
 [387, 465, 122, 260],
 [335, 276, 35, 337, 110, 155],
 [357],
 [23, 402],
 [197],
 [31, 335, 75, 196, 276, 337, 438, 110

In [8]:
class DisjointSetUnion:
    def __init__(self, n: int):
        self.parent = [i for i in range(n)]
    
    def find(self, x: int) -> int:
        if self.parent[x] == x:
            return x
        self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def join(self, x: int, y: int) -> None:
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        if self.rank[x] < self.rank[y]:
            x, y = y, x
        self.parent[y] = x
        if self.rank[x] == self.rank[y]:
            self.rank[x] += 1
    
    def get_components(self) -> list[list[int]]:
        components = {}
        for idx in range(len(self.parent)):
            components.setdefault(self.find(idx), []).append(idx)
        return list(components.values())


In [9]:
population_size = 100
ngen = 100
pcross = 0.8

pmut = 0.2
mutate_ratio = 0.2
mutate_fn = functools.partial(mutate_locus, edges, mutate_ratio)

T = 8
select_fn = functools.partial(tournament_select, T)

def fitness_fn(individual: list[int]) -> tuple[float, float]:
    # random
    return random.random(), random.random()

population = create_locus(edges, population_size)

In [10]:
ga = NSGA2[list[int], tuple[float, float]](
    fitness_fn=fitness_fn,
    select_fn=select_fn,
    crossover_fn=random_locus_crossover,
    mutate_fn=mutate_fn,
    pcross=pcross,
    pmut=pmut,
)

In [11]:
ga.run(population, ngen=ngen)

[((0.9999282763736768, 0.9138591962606499),
  [154,
   273,
   357,
   64,
   170,
   57,
   402,
   304,
   395,
   2,
   117,
   40,
   258,
   321,
   370,
   398,
   367,
   82,
   363,
   311,
   418,
   216,
   104,
   226,
   207,
   359,
   311,
   471,
   243,
   39,
   260,
   335,
   357,
   23,
   197,
   245,
   357,
   437,
   301,
   121,
   268,
   376,
   104,
   281,
   79,
   293,
   341,
   190,
   287,
   124,
   397,
   164,
   431,
   108,
   66,
   329,
   70,
   5,
   75,
   18,
   359,
   65,
   376,
   258,
   126,
   168,
   54,
   283,
   149,
   212,
   321,
   474,
   297,
   301,
   301,
   441,
   301,
   180,
   150,
   113,
   222,
   146,
   406,
   16,
   95,
   240,
   210,
   90,
   473,
   474,
   87,
   210,
   8,
   178,
   261,
   179,
   432,
   201,
   173,
   190,
   315,
   470,
   338,
   258,
   413,
   260,
   448,
   290,
   53,
   185,
   245,
   441,
   455,
   150,
   176,
   456,
   64,
   10,
   85,
   454,
   355,
   91,
   387,
