In [59]:
#从已有模型finetuning
# config=[
#     "--logdir", "./runs/model_test/main_model",
#     "--model_suffix", "free",
# ]
import os
os.environ['NOTEBOOK']="1"

from params import parser

from model.PPO import PPO
from train.base import *


In [60]:

class PPOT(PPO):
    def __init__(self, config):
        super().__init__(config)


    def update(self, memory, freeze_feature_exact = False):
        '''
        :param memory: data used for PPO training
        :return: total_loss and critic_loss
        '''

        # 获取转置后的训练数据，用于策略更新
        t_data = memory.transpose_data()  # Tensor len 13  pre torch.Size([1000, 50, 10])
        # 计算广义优势估计（GAE）和目标价值  A_t, G_t
        t_advantage_seq, v_target_seq = memory.get_gae_advantages()

        full_batch_size = len(t_data[-1])  # 获取完整批次大小 # 1000
        num_batch = np.ceil(full_batch_size / self.minibatch_size)  # 计算小批次数 1.0

        loss_epochs = 0
        v_loss_epochs = 0
        if freeze_feature_exact:
            for name, param in self.policy.named_parameters():
                if name.startswith('feature_exact'):
                    param.requires_grad = False

        for _ in range(self.k_epochs):  # 4
            # 对每个迭代进行小批次的策略更新
            # Split into multiple batches of updates due to memory limitations
            
            for i in range(int(num_batch)):
                if i + 1 < num_batch:
                    start_idx = i * self.minibatch_size
                    end_idx = (i + 1) * self.minibatch_size
                else:
                    # the last batch  处理最后一个小批次
                    start_idx = i * self.minibatch_size
                    end_idx = full_batch_size

                # 通过策略网络获取动作分布和值函数估计
                pis, vals = self.policy(fea_j=t_data[0][start_idx:end_idx],
                                        op_mask=t_data[1][start_idx:end_idx],
                                        candidate=t_data[6][start_idx:end_idx],
                                        fea_m=t_data[2][start_idx:end_idx],
                                        mch_mask=t_data[3][start_idx:end_idx],
                                        comp_idx=t_data[5][start_idx:end_idx],
                                        dynamic_pair_mask=t_data[4][start_idx:end_idx],
                                        fea_pairs=t_data[7][start_idx:end_idx])

                action_batch = t_data[8][start_idx: end_idx]  # 获取动作序列
                logprobs, ent_loss = eval_actions(pis, action_batch)  # 计算动作的概率和熵损失
                ratios = torch.exp(logprobs - t_data[12][start_idx: end_idx].detach())  # 计算重要性采样比率

                advantages = t_advantage_seq[start_idx: end_idx]  # 获取优势估计
                surr1 = ratios * advantages  # 计算第一个损失项
                surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages  # 计算第二个损失项

                v_loss = self.V_loss_2(vals.squeeze(1), v_target_seq[start_idx: end_idx])  # 计算价值损失
                p_loss = - torch.min(surr1, surr2)  # 计算策略损失   L^PPO-clip(pi_theta)
                ent_loss = - ent_loss.clone()  # 计算熵损失
                loss = self.vloss_coef * v_loss + self.ploss_coef * p_loss + self.entloss_coef * ent_loss  # 计算总损失
                # 梯度清零，进行反向传播和优化
                self.optimizer.zero_grad()  
                loss_epochs += loss.mean().detach()
                v_loss_epochs += v_loss.mean().detach()
                loss.mean().backward()
                # # 查看哪些参数受到loss的影响
                # for name, param in self.policy.named_parameters():
                #     if param.grad is not None and torch.sum(torch.abs(param.grad)) > 0:
                #         print(name, "受到了loss的影响")
                #     else:
                #         print(name, "没有受到loss的影响")
                self.optimizer.step()
        # soft update 进行软更新
        for policy_old_params, policy_params in zip(self.policy_old.parameters(), self.policy.parameters()):
            policy_old_params.data.copy_(self.tau * policy_old_params.data + (1 - self.tau) * policy_params.data)

        return loss_epochs.item() / self.k_epochs, v_loss_epochs.item() / self.k_epochs


In [61]:

def PPO_initialize():
    ppo = PPOT(configs)
    
    # writer = SummaryWriter(log_dir=configs.logdir, flush_secs=180)

    # writer.add_graph(dict(ppo.policy.named_parameters()))
    # writer.close()
    return ppo


In [64]:

class DANTrainer(Trainer):
    def __init__(self, config):

        super().__init__(config)
        self.env = FJSPEnvForSameOpNums(self.n_j, self.n_m)
        self.finetuning_model = f'../trained_network/SD2/10x5+mix.pth'
        self.ppo = PPO_initialize()
        self.ppo.policy.load_state_dict(torch.load(self.finetuning_model, map_location='cuda'))
        self.ppo.policy_old = deepcopy(self.ppo.policy)
        print(self.finetuning_model)


    def train(self):
        """
            train the model following the config
        """
        setup_seed(self.seed_train)
        self.log = []
        self.validation_log = []
        self.record = float('inf')
        print("-" * 25 + "Training Setting" + "-" * 25)
        print(f"source : {self.data_source}")
    
        print(f"model name :{self.finetuning_model}")
        print(f"vali data :{self.vali_data_path}")
        print("\n")

        self.train_st = time.time()

        for i_update in tqdm(range(self.max_updates), file=sys.stdout, desc="progress", colour='blue'):
            ep_st = time.time()

            # resampling the training data
            if i_update  == 0:
                dataset_job_length, dataset_op_pt = self.sample_training_instances()
                # print(dataset_op_pt[0])
                state = self.env.set_initial_data(dataset_job_length, dataset_op_pt)
                state.print_shape()
                print(f"EnvState(\n"
                f"  fea_j_tensor 形状: {state.fea_j_tensor[0]},\n"
                f"  op_mask_tensor 形状: {state.op_mask_tensor[0]},\n"
                f"  candidate_tensor 形状: {state.candidate_tensor[0]},\n"
                f"  fea_m_tensor 形状: {state.fea_m_tensor[0]},\n"
                f"  mch_mask_tensor 形状: {state.mch_mask_tensor[0]},\n"
                f"  comp_idx_tensor 形状: {state.comp_idx_tensor[0]},\n"
                f"  dynamic_pair_mask_tensor 形状: {state.dynamic_pair_mask_tensor[0]},\n"
                f"  fea_pairs_tensor 形状: {state.fea_pairs_tensor[0]}\n"
                f")")

            else:
                state = self.env.reset()
                

            ep_rewards = - deepcopy(self.env.init_quality)

            while True:

                # state store
                self.memory.push(state)
                with torch.no_grad():

                    pi_envs, vals_envs = self.ppo.policy_old(fea_j=state.fea_j_tensor,  # [sz_b, N, 8]
                                                            op_mask=state.op_mask_tensor,  # [sz_b, N, N]
                                                            candidate=state.candidate_tensor,  # [sz_b, J]
                                                            fea_m=state.fea_m_tensor,  # [sz_b, M, 6]
                                                            mch_mask=state.mch_mask_tensor,  # [sz_b, M, M]
                                                            comp_idx=state.comp_idx_tensor,  # [sz_b, M, M, J]
                                                            dynamic_pair_mask=state.dynamic_pair_mask_tensor,  # [sz_b, J, M]
                                                            fea_pairs=state.fea_pairs_tensor)  # [sz_b, J, M]

                # sample the action
                action_envs, action_logprob_envs = sample_action(pi_envs)

                # state transition
                state, reward, done = self.env.step(actions=action_envs.cpu().numpy())
                ep_rewards += reward
                reward = torch.from_numpy(reward).to(device)

                # collect the transition
                self.memory.done_seq.append(torch.from_numpy(done).to(device))
                self.memory.reward_seq.append(reward)
                self.memory.action_seq.append(action_envs)
                self.memory.log_probs.append(action_logprob_envs)
                self.memory.val_seq.append(vals_envs.squeeze(1))

                if done.all():
                    break

            loss, v_loss = self.ppo.update(self.memory)
            self.memory.clear_memory()

            mean_rewards_all_env = np.mean(ep_rewards)
            mean_makespan_all_env = np.mean(self.env.current_makespan)
            # print(self.env.current_makespan)
            if i_update < 2: vali_result = mean_makespan_all_env 

            # save the mean rewards of all instances in current training data
            self.log.append([i_update, mean_rewards_all_env])

            ep_et = time.time()
            # print the reward, makespan, loss and training time of the current episode
            tqdm.write(
                'Episode {}\t reward: {:.2f}\t makespan: {:.2f}\t Mean_loss: {:.8f},  training time: {:.2f}'.format(
                    i_update + 1, mean_rewards_all_env, mean_makespan_all_env, loss, ep_et - ep_st))
            scalars = {f"makespan_{i}":m  for i, m in zip(range(self.num_envs), self.env.current_makespan)}
            scalars.update({
                'Loss/train': loss
                ,'makespan_train':mean_makespan_all_env
                ,'makespan_validate':vali_result
            })
            
            self.iter_log(i_update, scalars)

        self.train_et = time.time()

        # log results
        self.save_training_log()




In [65]:
configs = parser.parse_args(args=[
    "--logdir", "./runs/model_test/main_model",
    "--model_suffix", "free",
    "--max_updates", "21",
    ])

trainer = DANTrainer(configs)

trainer.train()


vali_data = ./data/data_train_vali/SD2/10x5+mix
save model name:  10x5+mix+free
../trained_network/SD2/10x5+mix.pth
-------------------------Training Setting-------------------------
source : SD2
model name :../trained_network/SD2/10x5+mix.pth
vali data :./data/data_train_vali/SD2/10x5+mix


progress:   0%|[34m          [0m| 0/21 [00:00<?, ?it/s]torch.Size([20, 50, 10])
torch.Size([20, 50, 3])
torch.Size([20, 10])
torch.Size([20, 5, 8])
torch.Size([20, 5, 5])
torch.Size([20, 5, 5, 10])
torch.Size([20, 10, 5])
torch.Size([20, 10, 5, 8])
EnvState(
  fea_j_tensor 形状: tensor([[ 0.0000e+00, -1.3136e+00, -7.8598e-03,  2.9876e-01,  5.6845e-01,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -7.0125e-02, -2.8880e-02],
        [ 0.0000e+00, -1.1668e+00, -8.8117e-01,  7.8588e-01, -2.5253e-01,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -7.0125e-02,  1.4151e+00],
        [ 0.0000e+00, -6.2177e-01, -5.1526e-02, -4.8062e-01, -4.2045e-01,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -7.0125

-------------------------Training Setting-------------------------
source : SD2
model name :../trained_network/SD2/10x5+mix.pth
vali data :./data/data_train_vali/SD2/10x5+mix


progress:   0%|[34m          [0m| 0/21 [00:00<?, ?it/s][[ 0  0 71 27 81]
 [ 7 15 76 55 72]
 [ 0 44  0 56 26]
 [ 0 85  0  0 13]
 [19 82  0 52 45]
 [ 0 57 92 50  0]
 [ 4  0 12 22 90]
 [ 0  0 12  0 95]
 [ 0  0 88 15  0]
 [ 0 13  0  0  0]
 [ 0 62  0  0 48]
 [ 0  0 87 47  0]
 [ 0  0 26  0  0]
 [ 1  0  0  0  0]
 [69  0  0  0 61]
 [66 77 68 45  0]
 [ 8 89 71 14 29]
 [64  0  0  0  0]
 [ 0  0 59  0  0]
 [79  7 66 95 71]
 [ 0  0 77 77  0]
 [ 0 14 45  2 42]
 [79  0 88 64  0]
 [ 4 18 89 88 70]
 [98  0  3 19  0]
 [46 58 36 19 92]
 [47  0  0  0  0]
 [ 0  0 27  0  0]
 [11  0  0 59  0]
 [ 0 26 28 58  4]
 [79 86 60 49 57]
 [88 88  5 70 57]
 [61 10 67 34 70]
 [ 0 80 29  0  0]
 [ 4 49  4  0 92]
 [56  0  0 50 88]
 [86  0  0  0 67]
 [46 11  9 97 26]
 [ 0  0  0 71  0]
 [40  0 93 39  6]
 [10 41  0  0 22]
 [ 0  0  0 45  0]
 [ 0 82 62