In [7]:
import copy
import glob
import os

import torch
import torch_geometric

In [96]:
class NetworkDataset(torch_geometric.data.Dataset):
    '''
        Dataset of pre-built graphs stored as ``data_<idx>.pt`` files directly under ``root``.

        Args:
            root: directory containing the ``data_<idx>.pt`` files.
            transform: optional transform applied to every graph returned by ``get``.
            pre_transform: forwarded to the torch_geometric ``Dataset`` base class.
    '''
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def file_names(self):
        # Resolve relative to self.root (not the current working directory) and only
        # count files matching the name pattern get() loads, so len() and get() agree.
        # NOTE(review): assumes indices are contiguous 0..len-1 with no gaps.
        pattern = os.path.join(glob.escape(self.root), 'data_*.pt')
        return sorted(glob.glob(pattern))

    def len(self):
        return len(self.file_names)

    def get(self, idx):
        # NOTE(review): torch.load unpickles the file contents; only load trusted files.
        data = torch.load(os.path.join(self.root, f'data_{idx}.pt'))
        return data

In [99]:
DATA_ROOT = 'data'
# Set to True to apply `normalize_features` to every graph loaded from the dataset.
NORMALIZE_EDGE_ATTR = False

# Normalization transform for the `edge_attr` field; used only when NORMALIZE_EDGE_ATTR is True.
normalize_features = torch_geometric.transforms.NormalizeFeatures(['edge_attr'])
dataset = NetworkDataset(DATA_ROOT, transform=normalize_features if NORMALIZE_EDGE_ATTR else None)
# One graph per iteration.
loader = torch_geometric.loader.DataLoader(dataset, batch_size=1)

In [104]:
# Inspect the edge features of one sample graph, rounded to 4 decimals for readability.
sample_edge_attr = dataset[12].edge_attr
print(torch.round(sample_edge_attr, decimals=4))

tensor([[1.7000e-02, 2.0000e+00, 3.8665e+02],
        [1.7550e-01, 2.0000e+00, 1.8692e+01],
        [1.9600e-02, 2.0000e+00, 1.6702e+02],
        ...,
        [1.7820e-01, 2.0000e+00, 9.2070e+01],
        [2.9200e-02, 2.0000e+00, 4.4939e+02],
        [4.3600e-01, 2.0000e+00, 1.6556e+02]], dtype=torch.float64)


In [None]:
class Stage1Model(torch.nn.Module):
    '''
        Propagates graph data to current node and neighbors.

        Applies a stack of ``n_computation_blocks`` GCNConv layers (without self-loops)
        to the node features of the input graph.

        Args:
            n_computation_blocks: number of stacked GCNConv layers; must be >= 1.
            out_embedding_dim: output feature size of every layer.

        Raises:
            ValueError: if ``n_computation_blocks`` < 1.
    '''
    def __init__(self, n_computation_blocks, out_embedding_dim):
        super().__init__()
        if n_computation_blocks < 1:
            raise ValueError(f'n_computation_blocks must be >= 1, got {n_computation_blocks}')
        # ModuleList (not a plain list) registers the convs as submodules, so their
        # parameters appear in .parameters(), follow .to(device) and are saved in state_dict().
        self.computation_blocks = torch.nn.ModuleList()
        # in_channels=-1: the first conv infers its input size from the first forward call.
        self.computation_blocks.append(
            torch_geometric.nn.conv.GCNConv(in_channels=-1,
                                            out_channels=out_embedding_dim,
                                            add_self_loops=False))
        for _ in range(n_computation_blocks - 1):
            self.computation_blocks.append(
                torch_geometric.nn.conv.GCNConv(in_channels=out_embedding_dim,
                                                out_channels=out_embedding_dim,
                                                add_self_loops=False))
        # NOTE(review): no nonlinearity is applied between the conv layers — confirm intended.

    def forward(self, network_graph):
        '''
            Args:
                network_graph: graph object exposing ``x`` (node features) and ``edge_index``.

            Returns:
                A shallow copy of ``network_graph`` whose ``x`` holds the propagated node
                embeddings; the input graph object is not modified.
        '''
        x = network_graph.x
        edge_index = network_graph.edge_index
        for block in self.computation_blocks:
            # GCNConv consumes (x, edge_index) tensors, not a graph object.
            x = block(x, edge_index)
        propagated = copy.copy(network_graph)
        propagated.x = x
        return propagated

In [None]:
class PreferenceEmbedder(torch.nn.Module):
    '''
        Embeds a (n_preferences, 1) user preference vector to a (embedding_dim, 1) vector
        with a single learned linear map.

        Args:
            user_preferences: accepted for interface compatibility but not used; the number
                of preferences is inferred from the first input passed to ``forward``.
            embedding_dim: size of the output embedding.
    '''
    def __init__(self, user_preferences, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # LazyLinear infers n_preferences on the first forward call (the same lazy-size
        # approach as GCNConv(in_channels=-1) used elsewhere in this notebook).
        self.linear = torch.nn.LazyLinear(embedding_dim)

    def forward(self, user_preferences):
        '''
            Args:
                user_preferences: tensor of shape (n_preferences, 1) (or (n_preferences,)).

            Returns:
                Tensor of shape (embedding_dim, 1).
        '''
        # (n_preferences, 1) -> (1, n_preferences): the linear map acts on the preference axis.
        embedded = self.linear(user_preferences.reshape(1, -1))
        # (1, embedding_dim) -> (embedding_dim, 1), the documented output shape.
        return embedded.reshape(-1, 1)

In [None]:
class Stage2Model(torch.nn.Module):
    '''
        Does final computations on available node choices and returns finalized choice scores.

        Applies ``n_computation_blocks`` GCNConv layers (without self-loops) to the node
        features, then a softmax across the nodes of each graph, so each output column
        sums to 1 over the nodes of one graph.

        NOTE(review): with out_embedding_dim=1 this yields one choice probability per node;
        confirm this is the intended choice distribution.

        Args:
            n_computation_blocks: number of stacked GCNConv layers; must be >= 1.
            out_embedding_dim: output feature size of every layer.

        Raises:
            ValueError: if ``n_computation_blocks`` < 1.
    '''
    def __init__(self, n_computation_blocks, out_embedding_dim):
        super().__init__()
        if n_computation_blocks < 1:
            raise ValueError(f'n_computation_blocks must be >= 1, got {n_computation_blocks}')
        # ModuleList (not a plain list) registers the convs as submodules so they are trained,
        # moved with .to(device) and saved in state_dict().
        self.computation_blocks = torch.nn.ModuleList()
        # in_channels=-1: the first conv infers its input size from the first forward call.
        self.computation_blocks.append(
            torch_geometric.nn.conv.GCNConv(in_channels=-1,
                                            out_channels=out_embedding_dim,
                                            add_self_loops=False))
        for _ in range(n_computation_blocks - 1):
            self.computation_blocks.append(
                torch_geometric.nn.conv.GCNConv(in_channels=out_embedding_dim,
                                                out_channels=out_embedding_dim,
                                                add_self_loops=False))
        # Grouped softmax: normalizes values that share the same group index.
        self.softmax = torch_geometric.utils.softmax

    def forward(self, network_graph):
        '''
            Args:
                network_graph: graph object exposing ``x`` and ``edge_index``, and optionally
                    a ``batch`` vector mapping each node to its graph.

            Returns:
                Tensor of shape (num_nodes, out_embedding_dim) with scores normalized
                across the nodes of each graph.
        '''
        x = network_graph.x
        edge_index = network_graph.edge_index
        for block in self.computation_blocks:
            # GCNConv consumes (x, edge_index) tensors, not a graph object.
            x = block(x, edge_index)
        # The softmax needs a long (num_nodes,) group index. Use the batch vector when
        # present; otherwise treat all nodes as belonging to a single graph.
        batch = getattr(network_graph, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        scores = self.softmax(x, batch)
        return scores

In [None]:
class CustomGraphModel(torch.nn.Module):
    '''
        Stage 1 blocks -> mask/delete -> +preference embeddings -> stage 2 blocks -> out

        Args:
            Stage1Model: module mapping the input graph to a propagated graph exposing ``x``.
            PreferenceEmbedder: module embedding the user preference vector.
            Stage2Model: module mapping the preference-augmented graph to the output.
    '''
    def __init__(self, Stage1Model, PreferenceEmbedder, Stage2Model):
        super().__init__()
        # Parameter names shadow the class names; kept as-is for keyword-argument compatibility.
        self.Stage1Model = Stage1Model
        self.PreferenceEmbedder = PreferenceEmbedder
        self.Stage2Model = Stage2Model

    def get_neighbor_graph(self, network_graph_propagated):
        '''
            Masks all nodes and edges that are not connected to the current node.
            The current node is excluded.

            Not implemented yet. Raising makes that explicit; a silent ``None`` would only
            fail later inside forward().
        '''
        raise NotImplementedError('CustomGraphModel.get_neighbor_graph is not implemented yet')

    def forward(self, network_graph, user_preferences):
        '''
            Args:
                network_graph: input graph passed to ``Stage1Model``.
                user_preferences: input passed to ``PreferenceEmbedder``.

            Returns:
                Whatever ``Stage2Model`` returns for the preference-augmented neighbor graph.
        '''
        network_graph_propagated = self.Stage1Model(network_graph)

        neighbor_graph = self.get_neighbor_graph(network_graph_propagated)
        user_preferences_embedded = self.PreferenceEmbedder(user_preferences)
        # Add the preference embedding to every node's features on a shallow copy, so the
        # neighbor graph itself is left untouched. Reshaping to (1, embedding_dim) broadcasts
        # over the node axis whether the embedder returns (embedding_dim,) or (embedding_dim, 1).
        # NOTE(review): assumes node features have embedding_dim columns — confirm.
        neighbor_graph_with_task_information = copy.copy(neighbor_graph)
        neighbor_graph_with_task_information.x = (
            neighbor_graph.x + user_preferences_embedded.reshape(1, -1)
        )

        out = self.Stage2Model(neighbor_graph_with_task_information)
        return out

In [None]:
class RLFramework:
    '''
        Placeholder class; it currently defines no attributes or methods.
    '''