In [7]:
import torch
import torch_geometric

import glob
import os

In [123]:
class NetworkDataset(torch_geometric.data.Dataset):
    '''
        Graph dataset backed by pre-saved graph objects named `data_<idx>.pt`
        inside the `root` directory.

        On load, node features (`x`) and edge features (`edge_attr`) are cast to
        float32 so they match the default dtype of torch layers.

        Args:
            root: directory holding the `data_<idx>.pt` files.
            transform, pre_transform: forwarded unchanged to
                `torch_geometric.data.Dataset`.
    '''
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
    
    @property
    def file_names(self):
        '''
            Sorted paths of the sample files under `self.root`.

            Only `data_*.pt` files are listed. Stray files or sub-directories in
            the folder therefore don't inflate `len()`, which would make `get()`
            request a `data_<idx>.pt` that does not exist. The path is resolved
            from `self.root` rather than a cwd-relative literal so the dataset
            works from any working directory and with any root.
        '''
        return sorted(glob.glob(os.path.join(self.root, 'data_*.pt')))
    
    def len(self):
        '''Number of `data_*.pt` sample files under `self.root`.'''
        return len(self.file_names)

    def get(self, idx):
        '''
            Load sample `idx` from `<root>/data_<idx>.pt` and cast its features to float32.
        '''
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        # Newer torch releases default to weights_only=True, which may refuse full graph
        # objects — confirm against the installed torch version.
        data = torch.load(os.path.join(self.root, f'data_{idx}.pt'))
        data.x = data.x.type(torch.float32)
        # Samples without edge features keep edge_attr as None instead of crashing here.
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr.type(torch.float32)
        return data

In [124]:
# Sum-to-one normalization of edge features is prepared but currently disabled.
# To enable it, build the dataset with `transform=normalize_features`.
normalize_features = torch_geometric.transforms.NormalizeFeatures(attrs=['edge_attr'])
dataset = NetworkDataset(root='data')
# batch_size=1 is the DataLoader default, spelled out for the reader.
loader = torch_geometric.loader.DataLoader(dataset, batch_size=1)

In [191]:
# Exploratory: list the attributes/methods available on the dataset object
# (NetworkDataset plus what it inherits from torch_geometric.data.Dataset).
dir(dataset)

['__abstractmethods__',
 '__add__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slotnames__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_abc_impl',
 '_download',
 '_indices',
 '_infer_num_classes',
 '_is_protocol',
 '_process',
 'download',
 'file_names',
 'get',
 'get_summary',
 'has_download',
 'has_process',
 'index_select',
 'indices',
 'len',
 'log',
 'num_classes',
 'num_edge_features',
 'num_features',
 'num_node_features',
 'pre_filter',
 'pre_transform',
 'print_summary',
 'process',
 'processed_dir',
 'processed_file_names',
 'processed_paths',
 'raw_dir',
 'raw_file_names',
 'raw_paths',
 '

In [188]:
class Stage1Model(torch.nn.Module):
    '''
        Propagates graph data to current node and neighbors.

        Pipeline: NNConv (mixes edge features into node embeddings) -> GCNConv
        -> GCNConv -> GATv2Conv, with a ReLU after each of the first three layers.

        Args:
            out_embedding_dim: width of the per-node embedding returned by `forward`.
            num_node_features: input node-feature width. Defaults to the
                notebook-global `dataset.num_node_features`, the original behavior.
            num_edge_features: input edge-feature width. Defaults to the
                notebook-global `dataset.num_edge_features`.
    '''
    def __init__(self, out_embedding_dim, num_node_features=None, num_edge_features=None):
        super().__init__()
        # Only fall back to the global `dataset` when sizes aren't given, so the
        # model can also be built without that hidden dependency.
        if num_node_features is None:
            num_node_features = dataset.num_node_features
        if num_edge_features is None:
            num_edge_features = dataset.num_edge_features
        
        # Edge network: turns each edge's features into a flattened
        # (num_node_features x 64) weight matrix used by NNConv.
        # The last layer has no activation: a trailing ReLU would force every
        # generated weight to be >= 0.
        nn1 = torch.nn.Sequential(
            torch.nn.Linear(num_edge_features, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, num_node_features*64),
        )
        # conv1 gets edge information into the nodes
        self.conv1 = torch_geometric.nn.conv.NNConv(num_node_features, 64, nn1)
        # conv2 and conv3 work on the nodes only
        self.conv2 = torch_geometric.nn.conv.GCNConv(64, 128)
        self.conv3 = torch_geometric.nn.conv.GCNConv(128, 64)
        # 3 attention heads averaged (concat=False) -> output width is out_embedding_dim
        self.attn = torch_geometric.nn.conv.GATv2Conv(64, out_embedding_dim, heads=3, concat=False)
        # Parameter-free, so state_dict keys are unchanged by adding it.
        self.relu = torch.nn.ReLU()
        
    def forward(self, network_graph):
        '''
            Args:
                network_graph: graph object exposing `x`, `edge_index` and `edge_attr`.
            Returns:
                Per-node embeddings produced by the attention layer.
        '''
        x, edge_index, edge_attr = network_graph.x, network_graph.edge_index, network_graph.edge_attr
        
        # Without a nonlinearity between them, conv1..conv3 (all linear in x for
        # a fixed graph) compose into a single linear propagation, and the extra
        # depth adds no expressiveness.
        x = self.relu(self.conv1(x, edge_index, edge_attr))
        x = self.relu(self.conv2(x, edge_index))
        x = self.relu(self.conv3(x, edge_index))
        out = self.attn(x, edge_index)
        
        return out
    
# Smoke test: embed the first graph of the dataset with a 32-wide output.
test = Stage1Model(out_embedding_dim=32)
test(network_graph=dataset[0])

tensor([[ 0.0719, -0.0214, -0.0346,  ...,  0.0399, -0.0074, -0.0208],
        [ 0.0719, -0.0214, -0.0346,  ...,  0.0399, -0.0074, -0.0208],
        [ 0.0719, -0.0214, -0.0346,  ...,  0.0399, -0.0074, -0.0208],
        ...,
        [ 0.0719, -0.0214, -0.0346,  ...,  0.0399, -0.0074, -0.0208],
        [ 0.0719, -0.0214, -0.0346,  ...,  0.0399, -0.0074, -0.0208],
        [ 0.0719, -0.0214, -0.0346,  ...,  0.0399, -0.0074, -0.0208]],
       grad_fn=<AddBackward0>)

In [152]:
class PreferenceEmbedder(torch.nn.Module):
    '''
        Maps a user-preference vector to a non-negative embedding using one
        linear layer followed by ReLU.

        The linear layer acts on the LAST dimension. The input therefore has shape
        (..., n_preferences), e.g. a 1-D (n_preferences,) vector, and the output
        has shape (..., embedding_dim).
    '''
    def __init__(self, n_preferences, embedding_dim):
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys.
        self.fclayer = torch.nn.Linear(in_features=n_preferences, out_features=embedding_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, user_preferences):
        projected = self.fclayer(user_preferences)
        return self.relu(projected)
    
# Smoke test: embed the first sample's preference vector (y) into 10 dims.
data = dataset[0]
test = PreferenceEmbedder(n_preferences=data.y.size(0), embedding_dim=10)
test(data.y)

tensor([ 0.0000,  0.0000, 51.3322,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        26.0617,  0.0000], grad_fn=<ReluBackward0>)

In [233]:
class Stage2Model(torch.nn.Module):
    '''
        Does final computations on available node choices and returns finalized choice scores.

        Pipeline: GCNConv -> GCNConv -> GATv2Conv, with a ReLU after each GCN layer.
        Each node's embedding is then summed to a scalar, and a softmax over the
        nodes turns these into scores that sum to 1.

        Args:
            out_embedding_dim: width of the attention layer's per-node output
                before it is summed away.
    '''
    def __init__(self, out_embedding_dim):
        super().__init__()
        # in_channels=-1 lets PyG infer the input width lazily on the first forward pass.
        self.conv1 = torch_geometric.nn.conv.GCNConv(in_channels=-1, 
                                                out_channels=64)
        self.conv2 = torch_geometric.nn.conv.GCNConv(in_channels=64, 
                                                out_channels=64)
        self.attn = torch_geometric.nn.conv.GATv2Conv(64, out_embedding_dim, heads=3, concat=False)
    
        # dim=0: normalize across nodes, not across embedding channels.
        self.softmax = torch.nn.Softmax(dim=0)
        # Parameter-free, so state_dict keys are unchanged by adding it.
        self.relu = torch.nn.ReLU()
        
    def forward(self, network_graph):
        '''
            Args:
                network_graph: graph object exposing `x` and `edge_index`.
            Returns:
                1-D tensor with one score per node, summing to 1.
        '''
        x, edge_index = network_graph.x, network_graph.edge_index
        # ReLU between layers; without it the two GCN layers compose into one linear map.
        x = self.relu(self.conv1(x, edge_index))
        x = self.relu(self.conv2(x, edge_index))
        x = self.attn(x, edge_index)
        
        # Collapse each node's embedding to a single score.
        x = torch.sum(x, dim=1)
        # NOTE(review): with several graphs batched together this normalizes across
        # all of their nodes jointly — fine for one graph per call.
        scores = self.softmax(x)
        return scores
    
# Smoke test: score every node of the first graph; expect one score per node.
test = Stage2Model(out_embedding_dim=32)
test(network_graph=dataset[0]).shape

torch.Size([570])

In [None]:
class CustomGraphModel(torch.nn.Module):
    '''
        Stage 1 blocks -> mask/delete -> +preference embeddings -> stage 2 blocks -> out

        Args:
            Stage1Model: module mapping a graph to a per-node embedding tensor.
            PreferenceEmbedder: module mapping the user-preference input to an
                embedding that is added to every stage-1 node embedding, so its
                width must match theirs.
            Stage2Model: module mapping a graph to the final output.
    '''
    def __init__(self, Stage1Model, PreferenceEmbedder, Stage2Model):
        super().__init__()
        self.Stage1Model = Stage1Model
        self.PreferenceEmbedder = PreferenceEmbedder
        self.Stage2Model = Stage2Model
    
    def get_neighbor_graph(self, network_graph_propagated):
        '''
            Masks all nodes and edges that are not connected to the current node.
            The current node is excluded.

            Not implemented yet. It raises instead of silently returning None,
            which would otherwise surface later in `forward` as a confusing
            AttributeError on `None`.
        '''
        raise NotImplementedError('get_neighbor_graph is not implemented yet')
    
    def forward(self, network_graph, user_preferences):
        '''
            Args:
                network_graph: input graph; must support `.clone()` and have an
                    `x` attribute, as PyG `Data` objects do.
                user_preferences: input to PreferenceEmbedder (presumably the
                    sample's `y` preference vector — TODO confirm).
            Returns:
                Stage2Model's output for the preference-augmented neighbor graph.
        '''
        # Stage 1 returns a bare node-embedding tensor. Put it back into a copy of
        # the input graph so masking and Stage 2 still see the edge structure,
        # without mutating the caller's graph.
        network_graph_propagated = network_graph.clone()
        network_graph_propagated.x = self.Stage1Model(network_graph)
        
        neighbor_graph = self.get_neighbor_graph(network_graph_propagated)
        user_preferences_embedded = self.PreferenceEmbedder(user_preferences)
        # Stage2Model reads `.x`/`.edge_index`, so it must receive a graph, not the
        # raw tensor. The preference embedding is broadcast-added to every node.
        neighbor_graph_with_task_information = neighbor_graph.clone()
        neighbor_graph_with_task_information.x = torch.add(neighbor_graph.x, user_preferences_embedded)
        
        out = self.Stage2Model(neighbor_graph_with_task_information)
        return out

stage_1_node_embedding_dim = 32
stage_2_node_embedding_dim = 32

# Stage 1: per-node embeddings of width stage_1_node_embedding_dim.
stage1model = Stage1Model(out_embedding_dim=stage_1_node_embedding_dim)
# The preference embedding is added to each stage-1 node embedding, so it
# shares their width. Its input size comes from the sample's preference vector y.
preference_embedder = PreferenceEmbedder(
    n_preferences=data.y.shape[0],
    embedding_dim=stage_1_node_embedding_dim,
)
stage2model = Stage2Model(out_embedding_dim=stage_2_node_embedding_dim)
model = CustomGraphModel(
    stage1model,
    preference_embedder,
    stage2model,
)

In [None]:
class RLFramework:
    '''Placeholder for the reinforcement-learning training framework; not implemented yet.'''