In [8]:
import torch
import torch.nn as nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

In [9]:
# Data flow: environment -> observation -> features extractor -> features -> neural network -> action -> environment

In [10]:
from rough.code.gcn_gru_model import RecurrentGCNGRU


class CustomGCNGRUFeatureExtractor(BaseFeaturesExtractor):
    """SB3 features extractor: GCN-GRU graph embedding plus the flattened request fields.

    The graph (``adjacency_matrix`` + ``node_features``) is passed through
    ``RecurrentGCNGRU`` and then a ``Linear`` + ``ReLU`` projection. The result is
    flattened and concatenated with the flattened ``traffic_src``, ``traffic_dst``,
    ``bit_rate`` and ``spectrum_details`` observations, in that order.

    :param observation_space: Dict observation space containing the six keys read
        in ``forward``.
    :param features_dim: first argument of ``RecurrentGCNGRU`` and output width of
        the projection layer (whose input width is ``features_dim // 2``). It is
        NOT the width of the final feature vector: because the request fields are
        appended, the true width is measured in ``__init__`` and exposed as
        ``self.features_dim``, which SB3 uses to size the policy/value networks.
    """

    # Observation keys appended (flattened) after the graph embedding, in order.
    _REQUEST_KEYS = ("traffic_src", "traffic_dst", "bit_rate", "spectrum_details")

    def __init__(self, observation_space, features_dim=512):
        super().__init__(observation_space, features_dim)

        gru_hidden_size = 256
        num_gru_layers = 1
        self.preprocessing_model = RecurrentGCNGRU(features_dim, gru_hidden_size, num_gru_layers)

        # NOTE(review): the projection input width is features_dim // 2, which equals
        # gru_hidden_size only for the default features_dim=512 — confirm which of the
        # two RecurrentGCNGRU's output width actually follows.
        self.linear = nn.Sequential(nn.Linear(features_dim // 2, features_dim), nn.ReLU())

        # forward() appends the request fields to the graph embedding, so its output is
        # wider than ``features_dim``. Report the real width; otherwise SB3 builds its
        # first MLP layer with the wrong input size.
        self._features_dim = self._infer_output_dim(observation_space)

    def _infer_output_dim(self, observation_space):
        """Return the per-sample width of ``forward``'s output, measured on one sampled observation."""
        # Local import: only needed here. Applying SB3's own preprocessing (e.g. one-hot
        # for Discrete spaces) makes the dry run see the same shapes as training.
        from stable_baselines3.common.preprocessing import preprocess_obs

        sample = {
            key: torch.as_tensor(value).unsqueeze(0)  # add a batch dimension of size 1
            for key, value in observation_space.sample().items()
        }

        was_training = self.training
        self.eval()  # keep dropout / batch-norm statistics untouched by the dry run
        try:
            with torch.no_grad():
                output = self.forward(preprocess_obs(sample, observation_space))
        finally:
            self.train(was_training)

        # NOTE(review): if RecurrentGCNGRU keeps hidden state between calls, this dry
        # run advances it once — confirm that state is reset before training starts.
        return int(output.shape[1])

    def forward(self, observations):
        """Embed the graph and concatenate it with the flattened request fields.

        :param observations: dict of batched tensors keyed by ``adjacency_matrix``,
            ``node_features``, ``traffic_src``, ``traffic_dst``, ``bit_rate`` and
            ``spectrum_details``.
        :return: 2-D tensor of shape ``(batch, self.features_dim)``.
        """
        gcn_gru_output = self.preprocessing_model(
            observations["node_features"], observations["adjacency_matrix"]
        )
        graph_embedding = self.linear(gcn_gru_output)

        # Everything is flattened from dim 1 so each part is (batch, k) before concatenation.
        parts = [graph_embedding.flatten(start_dim=1)]
        parts.extend(observations[key].flatten(start_dim=1) for key in self._REQUEST_KEYS)
        return torch.cat(parts, dim=1)