In [196]:
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import ToUndirected
import os
import json

from dataset import Dataset

def load_graph():
    dataset = Dataset()
    dataset.load("data")

    # Add the embeddings to the nodes
    for node in dataset.nodes:
        node[1]["embedding"] = torch.empty((0, 1024))
        if node[0] in os.listdir(f"embeddings/"):
            embedding_paths = [f"embeddings/{node[0]}/{i}" for i in
                               os.listdir(f"embeddings/{node[0]}")]
            for path in embedding_paths:
                with open(path, "r") as f:
                    try:
                        embedding = torch.Tensor(json.load(f))
                        node[1]["embedding"] = torch.cat(
                            (node[1]["embedding"], embedding),
                            dim=0
                        )
                    except json.decoder.JSONDecodeError:
                        pass
        # Take the first 30 embeddings
        node[1]["embedding"] = node[1]["embedding"][:30, :]
        node[1]["embedding"] = torch.nn.ConstantPad2d(
            (0, 0, 30-node[1]["embedding"].shape[0], torch.nan), 0
        )(node[1]["embedding"])
        node[1]["embedding"] = node[1]["embedding"].flatten()

    # Remove edges with no ids in the nodes (the collection was stopped before expansion to these nodes)
    edge_ids = [i[0][0] for i in dataset.edges] + [i[0][1] for i in dataset.edges]
    edge_ids = list(set(edge_ids))
    node_ids = [i[0] for i in dataset.nodes]
    not_in = [i for i in edge_ids if i not in node_ids]
    print(len(dataset.edges))
    dataset.edges = [i for i in dataset.edges if i[0][1] not in not_in]
    print(len(dataset.edges))

    # Change the idx of the album nodes
    id_map = {}
    for idx, node in enumerate(dataset.nodes):
        id_map[node[0]] = idx
        node[0] = idx

    # And change their indices in the album_artist_edges as well
    for edge in dataset.edges:
        edge[0][0] = id_map[edge[0][0]]
        edge[0][1] = id_map[edge[0][1]]

    # TODO: remove nodes with no embeddings and reconnect their neighbours
    for node in dataset.nodes:
        pass

    x = torch.vstack([i[1]["embedding"] for i in dataset.nodes])
    edges = torch.Tensor([[i[0][0], i[0][1]] for i in dataset.edges]).t().contiguous().long()
    data = Data(x=x, edge_index=edges)
    data = ToUndirected()(data)
    return data, dataset

graph, dataset = load_graph()
 
print(">>> num_nodes:", graph.num_nodes)
print(">>> num_edges:", graph.num_edges)
graph.validate()


616
541
>>> num_nodes: 438
>>> num_edges: 541


True

In [186]:
count = 0
for node in dataset.nodes:
    n_missing = node[1]["embedding"].sum(dim=1).isnan().sum()
    if n_missing == 30: count += 1
print(count / len(dataset.nodes))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [202]:
from bokeh.io import output_notebook, show
from bokeh.models import Range1d, Circle, MultiLine
from bokeh.plotting import figure
from bokeh.plotting import from_networkx
import networkx
from torch_geometric.utils import to_networkx

G = to_networkx(graph)
plot = figure(tools="pan,wheel_zoom,save,reset", active_scroll='wheel_zoom',
              x_range=Range1d(-10.1, 10.1), y_range=Range1d(-10.1, 10.1), width=800)
network_graph = from_networkx(G, networkx.spring_layout, scale=10, center=(0, 0))
network_graph.node_renderer.glyph = Circle(size=15, fill_color='skyblue',
                                           fill_alpha=0.8, line_width=0.7,
                                           line_color="black")
network_graph.edge_renderer.glyph = MultiLine(line_alpha=0.3, line_width=1)
plot.renderers.append(network_graph)
output_notebook(hide_banner=True)
show(plot)

In [203]:
graph.x.shape

torch.Size([438, 30720])

In [215]:
from torch_geometric.loader import NeighborLoader
sampler = NeighborLoader(graph, num_neighbors=[2]*2, batch_size=16)

samples = [i for i in iter(sampler)]
print(">>> n_samples:", len(samples))

random_idx = torch.randint(0, len(samples), (1,)).item()
print(">>> random_idx:", random_idx)
sample = samples[random_idx]

print(sample.keys)
print(sample.x.shape)
print(sample.edge_index.shape)


G = to_networkx(sample)
plot = figure(tools="pan,wheel_zoom,save,reset", active_scroll='wheel_zoom',
              x_range=Range1d(-10.1, 10.1), y_range=Range1d(-10.1, 10.1), width=800)
network_graph = from_networkx(G, networkx.spring_layout, scale=10, center=(0, 0))
network_graph.node_renderer.glyph = Circle(size=15, fill_color='skyblue',
                                           fill_alpha=0.8, line_width=0.7,
                                           line_color="black")
network_graph.edge_renderer.glyph = MultiLine(line_alpha=0.3, line_width=1)
plot.renderers.append(network_graph)
output_notebook(hide_banner=True)
show(plot)

>>> n_samples: 28
>>> random_idx: 17
['input_id', 'n_id', 'e_id', 'num_sampled_nodes', 'x', 'batch_size', 'edge_index', 'num_sampled_edges']
torch.Size([19, 30720])
torch.Size([2, 18])


In [82]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Batch

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(5, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, 1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.sigmoid(x)
        return x

model = GCN(hidden_channels=10).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

samples = [i.cuda() for i in iter(sampler)]
batches = [i.to_homogeneous() for i in iter(sampler)]

for idx, batch in enumerate(batches):
    optimizer.zero_grad()

    print(type(batch))
    print(dir(batch))

    print(batch[0])

    stop

    print(sample.x.shape)
    # x = sample.x[:5].unsqueeze(0)
    # y = sample.y[0]
    # edge_index = sample.edge_index

    # out = model(x, torch.empty(2, 0, dtype=torch.long).cuda())
    # loss = criterion(out.squeeze(), y.squeeze(0))

    # print(loss.item(), out.item(), y.item())



<class 'torch_geometric.data.data.Data'>
['__call__', '__cat_dim__', '__class__', '__contains__', '__copy__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__inc__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_edge_attr_cls', '_edge_to_layout', '_edges_to_layout', '_get_edge_index', '_get_tensor', '_get_tensor_size', '_multi_get_tensor', '_put_edge_index', '_put_tensor', '_remove_edge_index', '_remove_tensor', '_store', '_tensor_attr_cls', '_to_type', 'apply', 'apply_', 'batch', 'clone', 'coalesce', 'contains_isolated_nodes', 'contains_self_loops', 'contiguous', 'coo', 'cpu', 'csc', 'csr', 'cuda', 'debug', 'detach', '

KeyError: 0