In [1]:
!pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cublas_cu12-12.4.5.8-py

In [2]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.monitor import Monitor
import datetime
from typing import List, Tuple, Dict, Any
import pprint
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

In [3]:
input_data_universal = {
    "Periods": 3,
    "Limit_Projects_Per_Period": 1,
    "Total_Overall_Budget": 407790913,
    "Limit_Sites_In_Region_Per_Period": 4,
    "Limit_Total_Sites_In_Region": 5,
    "Regions": {
        "Category_A": {
            "Site_Types_Available": {
                "Site_Type_Alpha": {
                    "Priority_Score": 3,
                    "Overall_Cost": 15882074,
                    "Regional_Cost_Impact": 3424817
                },
                "Site_Type_Beta": {
                    "Priority_Score": 5,
                    "Overall_Cost": 19892720,
                    "Regional_Cost_Impact": 3099341
                },
                "Site_Type_Gamma": {
                    "Priority_Score": 7,
                    "Overall_Cost": 15987521,
                    "Regional_Cost_Impact": 4192091
                }
            },
            "Number_Of_Needy": 3963,
            "Region_Rank": 4,
            "Initial_Regional_Budget": 5827884
        },
        "Category_B": {
            "Site_Types_Available": {
                "Site_Type_Alpha": {
                    "Priority_Score": 16,
                    "Overall_Cost": 14352038,
                    "Regional_Cost_Impact": 3075759
                },
                "Site_Type_Beta": {
                    "Priority_Score": 12,
                    "Overall_Cost": 12079944,
                    "Regional_Cost_Impact": 3738422
                },
                "Site_Type_Gamma": {
                    "Priority_Score": 13,
                    "Overall_Cost": 14485456,
                    "Regional_Cost_Impact": 5122573
                }
            },
            "Number_Of_Needy": 5523,
            "Region_Rank": 4,
            "Initial_Regional_Budget": 6296763
        }
    },
    "General_Site_Type_Info": {
        "Site_Type_Alpha": { "Capacity_Or_Feature": 94 },
        "Site_Type_Beta": { "Capacity_Or_Feature": 103 },
        "Site_Type_Gamma": { "Capacity_Or_Feature": 117 }
    }
}

In [4]:
class ResourcePlanningEnv(gym.Env):
    metadata = {"render_modes": [], "render_fps": 4}

    def __init__(self, config: Dict[str, Any]):
        super().__init__()

        self.config = config

        self.periods = config["Periods"]
        self.total_overall_budget_initial = float(config["Total_Overall_Budget"])
        self.limit_projects_per_period = config["Limit_Projects_Per_Period"]
        self.limit_sites_in_category_per_period = config["Limit_Sites_In_Region_Per_Period"]
        self.limit_total_sites_in_category = config["Limit_Total_Sites_In_Region"]

        self.category_names = list(config["Regions"].keys())
        self.n_categories = len(self.category_names)

        all_site_types_set = set()
        for cat_name in self.category_names:
            available_types_dict = config["Regions"][cat_name].get("Site_Types_Available", {})
            for site_type_name in available_types_dict.keys():
                all_site_types_set.add(site_type_name)
        self.global_site_type_names = sorted(list(all_site_types_set))

        self.possible_sites: List[Tuple[str, str]] = []
        self.site_details: Dict[Tuple[str, str], Dict[str, Any]] = {}
        for cat_name in self.category_names:
            available_types_dict = config["Regions"][cat_name].get("Site_Types_Available", {})
            for site_type_name in available_types_dict:
                if site_type_name in self.global_site_type_names:
                    site_key = (cat_name, site_type_name)
                    self.possible_sites.append(site_key)
                    self.site_details[site_key] = available_types_dict[site_type_name]

        self.n_possible_sites = len(self.possible_sites)
        self.pass_action_index = self.n_possible_sites
        self.action_space = gym.spaces.Discrete(self.n_possible_sites + 1)

        obs_space_size = 1 + 1 + 1 + self.n_categories + self.n_categories + self.n_categories + self.n_possible_sites

        low_bounds = np.zeros(obs_space_size, dtype=np.float32)
        high_bounds = np.ones(obs_space_size, dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=low_bounds, high=high_bounds, dtype=np.float32)

        self.current_period = 0
        self.remaining_overall_budget = 0.0
        self.projects_built_this_period = 0
        self.sites_built_in_category_this_period = {}
        self.total_sites_built_in_category = {}
        self.initial_regional_budgets = {}
        self.remaining_regional_budgets = {}
        self.site_built_mask = np.zeros(self.n_possible_sites, dtype=np.int8)


    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.current_period = 0
        self.remaining_overall_budget = self.total_overall_budget_initial
        self.projects_built_this_period = 0

        self.initial_regional_budgets = {
            cat_name: float(cat_details.get("Initial_Regional_Budget", 0.0))
            for cat_name, cat_details in self.config["Regions"].items()
        }
        self.remaining_regional_budgets = self.initial_regional_budgets.copy()

        self.sites_built_in_category_this_period = {cat: 0 for cat in self.category_names}
        self.total_sites_built_in_category = {cat: 0 for cat in self.category_names}
        self.site_built_mask = np.zeros(self.n_possible_sites, dtype=np.int8)
        return self._get_obs(), self._get_info()

    def step(self, action: int):
        if not (0 <= action < self.action_space.n):
             print(f"ERROR: Invalid action {action}")
             return self._get_obs(), -100.0, True, False, {"error": "Invalid action index"}

        terminated = False
        truncated = False
        reward = 0.0
        site_built_successfully = False
        action_desc_for_info = "N/A"
        built_site_key_for_info = None

        if action == self.pass_action_index:
            action_desc_for_info = "PASS"
            self.current_period += 1
            self.projects_built_this_period = 0
            self.sites_built_in_category_this_period = {cat: 0 for cat in self.category_names}
            reward = 0.0
            if self.current_period >= self.periods:
                terminated = True
        else:
            site_idx = action
            if site_idx < 0 or site_idx >= self.n_possible_sites:
                 print(f"WARNING: Invalid site_idx {site_idx}")
                 return self._get_obs(), -10.0, False, False, {}

            category_name, site_type_name = self.possible_sites[site_idx]
            action_desc_for_info = f"BUILD {category_name}-{site_type_name}"
            details = self.site_details[(category_name, site_type_name)]

            overall_cost = float(details["Overall_Cost"])
            regional_cost_impact = float(details.get("Regional_Cost_Impact", 0.0))
            priority_score = float(details["Priority_Score"])

            can_build = True
            if self.site_built_mask[site_idx] == 1: can_build = False
            elif overall_cost > self.remaining_overall_budget: can_build = False
            elif regional_cost_impact > self.remaining_regional_budgets.get(category_name, 0): can_build = False
            elif self.projects_built_this_period >= self.limit_projects_per_period: can_build = False
            elif self.sites_built_in_category_this_period.get(category_name,0) >= self.limit_sites_in_category_per_period: can_build = False
            elif self.total_sites_built_in_category.get(category_name,0) >= self.limit_total_sites_in_category: can_build = False

            if can_build:
                self.remaining_overall_budget -= overall_cost
                self.remaining_regional_budgets[category_name] -= regional_cost_impact
                self.projects_built_this_period += 1
                self.sites_built_in_category_this_period[category_name] += 1
                self.total_sites_built_in_category[category_name] += 1
                self.site_built_mask[site_idx] = 1
                reward = priority_score
                site_built_successfully = True
                built_site_key_for_info = (category_name, site_type_name)
            else:
                reward = -1.0

        observation = self._get_obs()
        current_info = self._get_info()
        current_info['action_description'] = action_desc_for_info
        current_info['action_valid'] = site_built_successfully or (action == self.pass_action_index)
        current_info['site_built_key'] = built_site_key_for_info

        return observation, reward, terminated, truncated, current_info

    def _get_obs(self) -> np.ndarray:
        obs_list = []

        obs_list.append(self.current_period / max(1.0, float(self.periods - 1)) if self.periods > 0 else 0.0)

        obs_list.append(self.remaining_overall_budget / self.total_overall_budget_initial if self.total_overall_budget_initial > 0 else 0.0)

        obs_list.append(self.projects_built_this_period / float(self.limit_projects_per_period) if self.limit_projects_per_period > 0 else 0.0)

        for cat_name in self.category_names:
            obs_list.append(self.sites_built_in_category_this_period.get(cat_name, 0) / float(self.limit_sites_in_category_per_period) if self.limit_sites_in_category_per_period > 0 else 0.0)

        for cat_name in self.category_names:
            obs_list.append(self.total_sites_built_in_category.get(cat_name, 0) / float(self.limit_total_sites_in_category) if self.limit_total_sites_in_category > 0 else 0.0)

        for cat_name in self.category_names:
            initial_reg_budget = self.initial_regional_budgets.get(cat_name, 0.0)
            remaining_reg_budget = self.remaining_regional_budgets.get(cat_name, 0.0)
            if initial_reg_budget <= 0:
                norm_reg_budget = 1.0 if remaining_reg_budget > 1e-6 else 0.0
            else:
                norm_reg_budget = remaining_reg_budget / initial_reg_budget
            obs_list.append(np.clip(norm_reg_budget, 0.0, 1.0))

        obs_list.extend(self.site_built_mask.astype(np.float32))
        return np.array(obs_list, dtype=np.float32)

    def _get_info(self) -> Dict[str, Any]:
        return {
            "current_period": self.current_period,
            "remaining_overall_budget": self.remaining_overall_budget,
            "remaining_regional_budgets": self.remaining_regional_budgets.copy(),
            "projects_built_this_period": self.projects_built_this_period,
            "sites_built_in_category_this_period": self.sites_built_in_category_this_period.copy(),
            "total_sites_built_in_category": self.total_sites_built_in_category.copy(),
            "site_built_mask_readable": {self.possible_sites[i]: int(self.site_built_mask[i]) for i in range(self.n_possible_sites) if i < len(self.possible_sites)},
            "pass_action_index": self.pass_action_index
        }
    def render(self): pass
    def close(self): pass

In [5]:
print("--- Создание и проверка среды ---")

env_raw = ResourcePlanningEnv(config=input_data_universal)
env_for_ppo = Monitor(env_raw)

check_env(ResourcePlanningEnv(config=input_data_universal))
print("Среда успешно создана и проверена (базовый класс)!")
print(f"Пространство действий PPO env: {env_for_ppo.action_space}")
print(f"Пространство наблюдений PPO env (low): {env_for_ppo.observation_space.low}")
print(f"Пространство наблюдений PPO env (high): {env_for_ppo.observation_space.high}")

original_env = env_for_ppo.unwrapped
if isinstance(original_env, Monitor): original_env = original_env.unwrapped
print(f"Количество возможных сайтов/проектов: {original_env.n_possible_sites}")
print(f"Индекс действия Pass: {original_env.pass_action_index}")

--- Создание и проверка среды ---
Среда успешно создана и проверена (базовый класс)!
Пространство действий PPO env: Discrete(7)
Пространство наблюдений PPO env (low): [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Пространство наблюдений PPO env (high): [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Количество возможных сайтов/проектов: 6
Индекс действия Pass: 6


In [6]:
# --- Обучение модели PPO ---
print("\n--- Обучение модели PPO ---")
total_timesteps_ppo = 50000
policy_kwargs = dict(net_arch=dict(pi=[128, 128], vf=[128, 128]))

print(f"Начало обучения ({datetime.datetime.now()}) на {total_timesteps_ppo} шагах...")
model = PPO(
    "MlpPolicy", env_for_ppo, verbose=1,
    n_steps=256, batch_size=64, n_epochs=10,
    gamma=0.98, ent_coef=0.01, learning_rate=3e-4,
    policy_kwargs=policy_kwargs, seed=42
)
model.learn(total_timesteps=total_timesteps_ppo, progress_bar=True)
print(f"Обучение завершено ({datetime.datetime.now()}).")


--- Обучение модели PPO ---
Начало обучения (2025-05-20 04:40:24.621441) на 50000 шагах...
Using cuda device
Wrapping the env in a DummyVecEnv.


Output()



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 19.4     |
|    ep_rew_mean     | 2.77     |
| time/              |          |
|    fps             | 193      |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 256      |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 17.6        |
|    ep_rew_mean          | 4.21        |
| time/                   |             |
|    fps                  | 203         |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 512         |
| train/                  |             |
|    approx_kl            | 0.018761158 |
|    clip_fraction        | 0.0965      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.94       |
|    explained_variance   | -0.0151     |
|    learning_rate        | 0.

Обучение завершено (2025-05-20 04:42:34.469644).


In [7]:
print("\n--- Тестирование обученной модели ---")
test_env_raw = ResourcePlanningEnv(config=input_data_universal)
test_env = Monitor(test_env_raw)
obs, info_reset = test_env.reset()

planned_site_details_test = []
total_reward_test = 0
step_count = 0

original_test_env = test_env.unwrapped
if isinstance(original_test_env, Monitor): original_test_env = original_test_env.unwrapped

max_steps_test = original_test_env.periods * (original_test_env.n_possible_sites + original_test_env.limit_projects_per_period + 1)

terminated = False
truncated = False
while not terminated and not truncated and step_count < max_steps_test:
    action, _states = model.predict(obs, deterministic=True)
    action_int = int(action)

    period_before_step = original_test_env.current_period

    obs, reward, terminated, truncated, info_step = test_env.step(action)
    total_reward_test += reward
    step_count += 1

    action_desc = info_step.get('action_description', "N/A")
    current_period_for_log = info_step.get('current_period', "N/A")

    print(f"Шаг {step_count}: Действие={action_desc}, Награда={reward:.1f}, Terminated={terminated}, Truncated={truncated}, Period now: {current_period_for_log}")

    if info_step.get('action_valid') and info_step.get('site_built_key') is not None:
        site_key = info_step['site_built_key']
        site_data = original_test_env.site_details.get(site_key)
        if site_data:
            planned_site_details_test.append({
                "period": period_before_step,
                "category": site_key[0],
                "type": site_key[1],
                "priority_score": site_data.get("Priority_Score"),
                "overall_cost": site_data.get("Overall_Cost"),
                "regional_cost_impact": site_data.get("Regional_Cost_Impact")
            })
        else:
            print(f"Warning: Could not find details for built site {site_key}")

    if terminated or truncated:
        print(f"\nЭпизод завершен на шаге {step_count}.")
        break

print("\n--- Результат планирования ---")
if planned_site_details_test:
    print("Освоенные ресурсы (Детали):")
    total_priority_calc = 0
    total_overall_cost_calc = 0
    total_regional_cost_by_category = {cat_name: 0.0 for cat_name in original_test_env.category_names}

    for proj in planned_site_details_test:
        print(f"- Период: {proj['period']}, Категория: {proj['category']}, Тип: {proj['type']} (Ценность: {proj['priority_score']}, Общ. стоимость: {proj['overall_cost']}, Регион. стоимость: {proj['regional_cost_impact']})")
        if proj['priority_score']: total_priority_calc += proj['priority_score']
        if proj['overall_cost']: total_overall_cost_calc += proj['overall_cost']
        if proj['category'] in total_regional_cost_by_category and proj['regional_cost_impact']:
            total_regional_cost_by_category[proj['category']] += proj['regional_cost_impact']

    print(f"\nВсего освоено инициатив: {len(planned_site_details_test)}")
    print(f"Суммарная оценочная ценность: {total_priority_calc}")
    print(f"Потраченный общий бюджет: {total_overall_cost_calc}")

    final_overall_budget_info = info_step.get('remaining_overall_budget', "N/A")
    print(f"Оставшийся общий бюджет (из среды): {final_overall_budget_info}")

    print("Затраты и остатки по региональным бюджетам:")
    final_regional_budgets_info = info_step.get('remaining_regional_budgets', {})
    for cat_name in original_test_env.category_names:
        spent = total_regional_cost_by_category[cat_name]
        remaining = final_regional_budgets_info.get(cat_name, "N/A")
        initial = original_test_env.initial_regional_budgets.get(cat_name, "N/A")
        print(f"  - {cat_name}: Потрачено={spent}, Остаток={remaining} (Начальный={initial})")

    print(f"Итоговая награда за тест: {total_reward_test:.2f}")
else:
    print("Ни одного ресурса не было освоено.")
    print(f"Итоговая награда за тест: {total_reward_test:.2f}")

if hasattr(env_for_ppo, 'close'): env_for_ppo.close()
if hasattr(test_env, 'close') and test_env != env_for_ppo : test_env.close()


--- Тестирование обученной модели ---
Шаг 1: Действие=BUILD Category_B-Site_Type_Alpha, Награда=16.0, Terminated=False, Truncated=False, Period now: 0
Шаг 2: Действие=PASS, Награда=0.0, Terminated=False, Truncated=False, Period now: 1
Шаг 3: Действие=BUILD Category_A-Site_Type_Gamma, Награда=7.0, Terminated=False, Truncated=False, Period now: 1
Шаг 4: Действие=PASS, Награда=0.0, Terminated=False, Truncated=False, Period now: 2
Шаг 5: Действие=PASS, Награда=0.0, Terminated=True, Truncated=False, Period now: 3

Эпизод завершен на шаге 5.

--- Результат планирования ---
Освоенные ресурсы (Детали):
- Период: 0, Категория: Category_B, Тип: Site_Type_Alpha (Ценность: 16, Общ. стоимость: 14352038, Регион. стоимость: 3075759)
- Период: 1, Категория: Category_A, Тип: Site_Type_Gamma (Ценность: 7, Общ. стоимость: 15987521, Регион. стоимость: 4192091)

Всего освоено инициатив: 2
Суммарная оценочная ценность: 23
Потраченный общий бюджет: 30339559
Оставшийся общий бюджет (из среды): 377451354.0
За