In [7]:
import torch
import torch_geometric

import glob
import os

In [123]:
class NetworkDataset(torch_geometric.data.Dataset):
    '''
        Graph dataset backed by pre-saved ``data_{idx}.pt`` files stored directly in ``root``.

        Each file is loaded with ``torch.load`` and is expected to hold a graph object
        with ``x`` and ``edge_attr`` attributes; both are cast to float32 on load.
        This class defines no ``process()``, so ``pre_transform`` is not applied here.
    '''
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def file_names(self):
        '''Sorted paths of the ``data_*.pt`` files under ``self.root``.'''
        # Glob relative to root (not the current working directory) and only match
        # the files get() actually reads, so len() agrees with the loadable indices.
        return sorted(glob.glob(os.path.join(self.root, 'data_*.pt')))

    def len(self):
        '''Number of ``data_*.pt`` files found under ``self.root``.'''
        return len(self.file_names)

    def get(self, idx):
        '''
            Load graph ``idx`` from ``<root>/data_{idx}.pt`` and cast its float features.

            Raises FileNotFoundError if the file does not exist.
        '''
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        # On torch versions defaulting to weights_only=True, loading a graph object
        # may require weights_only=False; confirm against the installed version.
        data = torch.load(os.path.join(self.root, f'data_{idx}.pt'))
        # Downstream Linear/conv layers expect float32 inputs.
        if data.x is not None:
            data.x = data.x.float()
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr.float()
        return data

In [124]:
# Edge-attribute normalisation transform (torch_geometric NormalizeFeatures).
# Currently NOT applied -- pass `transform=normalize_features` to NetworkDataset to enable it.
normalize_features = torch_geometric.transforms.NormalizeFeatures(['edge_attr'])
dataset = NetworkDataset('data')
# Uses torch_geometric DataLoader defaults (no explicit batch size or shuffling).
loader = torch_geometric.loader.DataLoader(dataset)

In [157]:
# Edge-feature dimensionality; Stage1Model sizes its NNConv edge network from this value.
dataset.num_edge_features

3

In [174]:
class Stage1Model(torch.nn.Module):
    '''
        Propagates graph data to current node and neighbors.

        conv1 (NNConv) folds edge attributes into the node features; conv2 (GCNConv)
        then mixes node features only. forward() returns a node-embedding tensor of
        shape (num_nodes, out_embedding_dim).

        Args:
            n_computation_blocks: currently unused; kept for interface compatibility.
            out_embedding_dim: size of the output node embeddings.
            num_node_features: input node-feature size; defaults to the notebook-global
                ``dataset.num_node_features``.
            num_edge_features: input edge-feature size; defaults to the notebook-global
                ``dataset.num_edge_features``.
            hidden_dim: width of the intermediate embedding between conv1 and conv2.
    '''
    def __init__(self, n_computation_blocks, out_embedding_dim,
                 num_node_features=None, num_edge_features=None, hidden_dim=64):
        super().__init__()
        # Fall back to the global `dataset` only when dims are not given, so the
        # model can be built without that hidden notebook dependency.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_edge_features is None:
            num_edge_features = dataset.num_edge_features

        # NNConv's edge network maps each edge's attributes to a flattened
        # (num_node_features x hidden_dim) weight matrix for that edge.
        # NOTE(review): the ReLU forces those per-edge weights to be non-negative;
        # confirm this is intended.
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(num_edge_features, num_node_features * hidden_dim),
            torch.nn.ReLU()
        )
        # conv1 gets edge information into the nodes
        self.conv1 = torch_geometric.nn.conv.NNConv(num_node_features, hidden_dim, nn1)
        # conv2 works on the nodes only
        self.conv2 = torch_geometric.nn.conv.GCNConv(hidden_dim, out_embedding_dim)

    def forward(self, network_graph):
        '''Return (num_nodes, out_embedding_dim) node embeddings for ``network_graph``.'''
        x, edge_index, edge_attr = network_graph.x, network_graph.edge_index, network_graph.edge_attr

        x = self.conv1(x, edge_index, edge_attr)
        # Without a non-linearity here conv2(conv1(x)) collapses to a linear map of x.
        x = torch.relu(x)
        return self.conv2(x, edge_index)
    
# Smoke test: build Stage1Model and inspect the embedding shape for the first graph.
test = Stage1Model(n_computation_blocks=3, out_embedding_dim=32)
test(dataset[0]).shape

torch.Size([570, 32])

In [152]:
class PreferenceEmbedder(torch.nn.Module):
    '''
        Embeds a user preference vector with one Linear layer followed by ReLU.

        Input: a tensor whose LAST dimension is ``n_preferences`` -- e.g. a 1-D
        ``(n_preferences,)`` vector or a batch ``(batch, n_preferences)``.
        (A ``(n_preferences, 1)`` column vector does not fit: Linear acts on the last dim.)
        Output: the same leading dimensions with last dimension ``embedding_dim``;
        every entry is non-negative because of the ReLU.
    '''
    def __init__(self, n_preferences, embedding_dim):
        super().__init__()
        self.fclayer = torch.nn.Linear(n_preferences, embedding_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, user_preferences):
        '''Return ``relu(fclayer(user_preferences))``.'''
        return self.relu(self.fclayer(user_preferences))
    
# Smoke test: embed the first graph's preference vector (data.y) into 10 dimensions.
data = dataset[0]
test = PreferenceEmbedder(n_preferences=data.y.size(0), embedding_dim=10)
test(data.y)

tensor([ 0.0000,  0.0000, 51.3322,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        26.0617,  0.0000], grad_fn=<ReluBackward0>)

In [None]:
class Stage2Model(torch.nn.Module):
    '''
        Does final computations on available node choices and returns finalized choice scores.

        Runs ``n_computation_blocks`` GCNConv blocks over the node features (ReLU
        between blocks), collapses each node's embedding to a scalar by summing it,
        and softmax-normalises those scalars over the nodes of each graph.
        forward() returns a 1-D tensor with one score per node.

        NOTE(review): the original sum-aggregate-then-softmax pipeline could not run;
        "per-node scores softmaxed within each graph" is the interpretation chosen
        here -- confirm against the intended design.
    '''
    def __init__(self, n_computation_blocks, out_embedding_dim):
        super().__init__()
        assert n_computation_blocks > 0, 'n_computation_blocks must be positive'
        # ModuleList (not a plain list) so the blocks' parameters are registered:
        # they appear in .parameters()/state_dict and move with .to(device).
        self.computation_blocks = torch.nn.ModuleList()
        # in_channels=-1 lets the first block infer its input width lazily.
        # NOTE(review): with add_self_loops=False a node's own features do not feed
        # its output, and nodes without incoming edges get only the bias -- confirm.
        self.computation_blocks.append(
            torch_geometric.nn.conv.GCNConv(in_channels=-1,
                                            out_channels=out_embedding_dim,
                                            add_self_loops=False))
        for _ in range(n_computation_blocks - 1):
            self.computation_blocks.append(
                torch_geometric.nn.conv.GCNConv(in_channels=out_embedding_dim,
                                                out_channels=out_embedding_dim,
                                                add_self_loops=False))
        self.softmax = torch_geometric.utils.softmax

    def forward(self, network_graph):
        '''Return per-node scores that sum to 1 within each graph of ``network_graph``.'''
        x, edge_index = network_graph.x, network_graph.edge_index
        last = len(self.computation_blocks) - 1
        for i, block in enumerate(self.computation_blocks):
            # GCNConv takes (x, edge_index), not a graph object.
            x = block(x, edge_index)
            if i < last:
                x = torch.relu(x)

        # Graph id per node: batched graphs carry `batch`; a single graph is all zeros.
        batch = getattr(network_graph, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        node_scores = x.sum(dim=-1)
        # Normalise the scores over the nodes sharing the same graph id.
        return self.softmax(node_scores, batch)

In [None]:
class CustomGraphModel(torch.nn.Module):
    '''
        Stage 1 blocks -> mask/delete -> +preference embeddings -> stage 2 blocks -> out

        Args:
            Stage1Model: module mapping a graph to a node-embedding tensor.
            PreferenceEmbedder: module mapping user preferences to an embedding.
            Stage2Model: module mapping a graph (``x`` + ``edge_index``) to the output.
    '''
    def __init__(self, Stage1Model, PreferenceEmbedder, Stage2Model):
        super().__init__()
        # Attribute names kept unchanged so existing state_dict keys stay valid.
        self.Stage1Model = Stage1Model
        self.PreferenceEmbedder = PreferenceEmbedder
        self.Stage2Model = Stage2Model

    @staticmethod
    def _with_node_features(graph, x):
        '''Return a shallow copy of ``graph`` whose node features ``x`` are replaced.'''
        import copy
        new_graph = copy.copy(graph)
        new_graph.x = x
        return new_graph

    def get_neighbor_graph(self, network_graph_propagated):
        '''
            Masks all nodes and edges that are not connected to the current node.
            The current node is excluded.

            Not implemented yet: raises instead of silently returning None, so the
            failure points here rather than at a later attribute access.
        '''
        raise NotImplementedError('get_neighbor_graph is not implemented yet')

    def forward(self, network_graph, user_preferences):
        '''Run the two-stage pipeline on ``network_graph`` conditioned on ``user_preferences``.'''
        # The stage-1 module returns node embeddings, not a graph: reattach them to the
        # graph structure on a copy so the caller's graph is left untouched.
        node_embeddings = self.Stage1Model(network_graph)
        network_graph_propagated = self._with_node_features(network_graph, node_embeddings)

        neighbor_graph = self.get_neighbor_graph(network_graph_propagated)
        user_preferences_embedded = self.PreferenceEmbedder(user_preferences)
        # Broadcast-add the preference embedding to every node's features.
        # NOTE(review): assumes the stage-1 embedding size equals the preference
        # embedding size -- confirm when wiring the real modules together.
        neighbor_graph_with_task_information = self._with_node_features(
            neighbor_graph, torch.add(neighbor_graph.x, user_preferences_embedded))

        # The stage-2 module consumes a graph, not a bare feature tensor.
        return self.Stage2Model(neighbor_graph_with_task_information)

In [None]:
class RLFramework:
    '''Placeholder for the (presumably reinforcement-learning) training framework; no behaviour yet.'''