In [None]:
# The log templates in the next cell are relative ("./logs/..."), so work from two directories
# above the notebook (presumably the repository root -- confirm). The starting directory is
# remembered on the first run so re-executing this cell does not keep climbing the tree.
import os

try:
    _NOTEBOOK_DIR
except NameError:
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, "..", ".."))
print(os.getcwd())

In [None]:
# Log-file templates; the {ENV}, {TRIAL}, {N_EPISODE}, {SEED} and {DATE} placeholders are
# filled with str.format by process_log_files().
no_prune_base_path = "./logs/{ENV}/{TRIAL}/n_episodes_{N_EPISODE}/seed_{SEED}/{DATE}/log.txt"
# Pruned runs sit under an extra "remain_rate_30" directory (presumably the pruning remain
# rate in percent -- confirm against the training script).
is_prune_base_path = "./logs/{ENV}/{TRIAL}/n_episodes_{N_EPISODE}/remain_rate_30/seed_{SEED}/{DATE}/log.txt"

# Run-date directories per environment. Each date list is aligned index-by-index with
# "episodes": date_*[i] is the run directory for the episodes[i] budget.
EXPERIMENTS = {
    "BreakoutNoFrameskip-v4": {
        "date_no_prune": ["2024_02_29_16_50_14", "2024_02_29_18_11_28", "2024_02_29_23_25_57"],
        "date_is_prune": ["2024_02_29_16_48_22", "2024_02_29_20_09_42", "2024_03_01_06_45_25"],
        "episodes": [300, 600, 900],
    },
    "PongNoFrameskip-v4": {
        "date_no_prune": ["2024_02_24_21_21_45", "2024_02_26_02_14_48"],
        "date_is_prune": ["2024_02_25_00_52_58", "2024_02_26_17_56_51"],
        "episodes": [600, 900],
    },
}

# Environment to analyse; must be a key of EXPERIMENTS.
ENV = "PongNoFrameskip-v4"
DATE_NO_PRUNE_LIST = list(EXPERIMENTS[ENV]["date_no_prune"])
DATE_IS_PRUNE_LIST = list(EXPERIMENTS[ENV]["date_is_prune"])
EPISODES = list(EXPERIMENTS[ENV]["episodes"])

# Fail fast on a mis-edited registry: the log parser pairs DATE_LIST[i] with EPISODES[i].
assert len(DATE_NO_PRUNE_LIST) == len(EPISODES) == len(DATE_IS_PRUNE_LIST), \
    "date lists must align one-to-one with EPISODES"

In [None]:
import re
from datetime import datetime
import os
import numpy as np

def _parse_log_file(log_file_path, start_pattern, end_pattern, episode_pattern, record):
    """Append every episode line inside the marked section of one log file to `record`.

    Extraction starts after a line matching `start_pattern` and stops at the first line
    matching `end_pattern` once extraction has started; lines before the start marker
    are ignored. `record` must hold "reward_list", "timestamp_list" and "episode_list".
    """
    extracting = False
    with open(log_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # A start marker (re)opens the section; it is checked before the end marker.
            if start_pattern.match(line):
                extracting = True
                continue
            if not extracting:
                continue
            if end_pattern.match(line):
                break
            match = episode_pattern.match(line)
            if match:
                timestamp_str, episode, _step, reward = match.groups()
                record["reward_list"].append(float(reward))
                # The log uses a comma before the milliseconds; %f accepts the 3 digits.
                record["timestamp_list"].append(datetime.strptime(timestamp_str, '%Y-%m-%d %H:%M:%S,%f'))
                record["episode_list"].append(int(episode))


def process_log_files(ENV, DATE_LIST, EPISODES, base_path, trial, n_seeds=5):
    """Parse per-episode rewards and timestamps for every (episode budget, seed) run of one trial.

    Parameters
    ----------
    ENV : str
        Environment name, substituted for {ENV} in `base_path` and used in the result keys.
    DATE_LIST : list of str
        Run-date directory names, aligned index-by-index with `EPISODES`.
    EPISODES : list of int
        Episode budgets, substituted for {N_EPISODE}.
    base_path : str
        Path template with {ENV}, {TRIAL}, {N_EPISODE}, {SEED} and {DATE} placeholders.
    trial : str
        Trial name substituted for {TRIAL}.
    n_seeds : int, default 5
        Seeds 0 .. n_seeds-1 are looked up for each episode budget.

    Returns
    -------
    dict
        ``{"<ENV>_Episode<N>": {"Seed<s>_<DATE>": {"reward_list": [float, ...],
        "timestamp_list": [datetime, ...], "episode_list": [int, ...]}}}``.
        Runs whose log file does not exist are reported with a print and omitted.

    Raises
    ------
    ValueError
        If `DATE_LIST` and `EPISODES` have different lengths.
    """
    if len(DATE_LIST) != len(EPISODES):
        raise ValueError(
            f"DATE_LIST has {len(DATE_LIST)} entries but EPISODES has {len(EPISODES)}; "
            "they must align one-to-one"
        )

    # Markers delimiting the section of interest. The run is considered finished at the
    # line reporting that the notification e-mail was sent successfully.
    start_pattern = re.compile(r'^.*save_dir.*$')
    end_pattern = re.compile(r'.*メールが正常に送信されました。.*')

    # "<timestamp> - INFO - Episode: <n>, Step: <n>, Reward: <x>"; the reward may be
    # logged with or without a fractional part.
    episode_pattern = re.compile(
        r'(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}) - INFO - '
        r'Episode: (\d+),\s+Step: (\d+),\s+Reward: (-?\d+(?:\.\d+)?)'
    )

    data_dict = {}
    for episode_count, date in zip(EPISODES, DATE_LIST):
        episode_key = f"{ENV}_Episode{episode_count}"
        for seed in range(n_seeds):
            log_file_path = base_path.format(ENV=ENV, TRIAL=trial, N_EPISODE=episode_count, SEED=seed, DATE=date)

            if not os.path.exists(log_file_path):
                print(f"File does not exist: {log_file_path}")
                continue

            seed_key = f"Seed{seed}_{date}"
            record = data_dict.setdefault(episode_key, {}).setdefault(
                seed_key, {"reward_list": [], "timestamp_list": [], "episode_list": []}
            )
            _parse_log_file(log_file_path, start_pattern, end_pattern, episode_pattern, record)

            # Surface logs that lack the start marker or episode lines instead of
            # silently returning empty series.
            if not record["reward_list"]:
                print(f"No episode lines extracted from: {log_file_path}")
    return data_dict

# Parse the training logs of both trials: the 'no_prune' baseline and the 'is_prune' runs.
no_prune_data_dict = process_log_files(
    ENV, DATE_NO_PRUNE_LIST, EPISODES,
    base_path=no_prune_base_path, trial='no_prune',
)
is_prune_data_dict = process_log_files(
    ENV, DATE_IS_PRUNE_LIST, EPISODES,
    base_path=is_prune_base_path, trial='is_prune',
)

       
       


In [None]:
import matplotlib.pyplot as plt

# Number of episodes per plotted point.
cutoff = 10
# Parsed run to plot; set to no_prune_data_dict to plot the unpruned baseline instead.
dic = is_prune_data_dict


def _episode_count_from_key(key):
    """Return the episode budget N from a result key of the form "<ENV>_Episode<N>"."""
    # rsplit rather than a fixed [-3:] slice so budgets of any width (50, 1000, ...) parse.
    return int(key.rsplit("_Episode", 1)[1])


def _cumulative_reward_curve(timestamps, rewards, n_episodes, window):
    """Sample cumulative wall-clock time and cumulative reward every `window` episodes.

    Parameters
    ----------
    timestamps : list of datetime
        One log timestamp per parsed episode, in file order.
    rewards : list of float
        Per-episode rewards aligned with `timestamps`.
    n_episodes : int
        Episode budget of the run; entries beyond it are ignored.
    window : int
        Episodes per sampled point.

    Returns
    -------
    (list of float, list of float)
        Seconds elapsed since the first parsed episode, and the running reward sum,
        at the end of each window. Both are empty when no episodes were parsed.
    """
    # Bound by what was actually logged so an incomplete log cannot index an empty slice.
    n = min(n_episodes, len(timestamps), len(rewards))
    times, cum_rewards = [], []
    if n == 0:
        return times, cum_rewards
    first_ts = timestamps[0]
    running = 0.0
    for start in range(0, n, window):
        end = min(start + window, n)
        running += sum(rewards[start:end])
        # Measure from the first episode so the gaps between consecutive windows are counted
        # (summing per-window last-minus-first spans drops one interval per window), and use
        # total_seconds() to keep sub-second precision and the days component.
        times.append((timestamps[end - 1] - first_ts).total_seconds())
        cum_rewards.append(running)
    return times, cum_rewards


for k, per_epoch in dic.items():
    episode = _episode_count_from_key(k)

    all_times = []
    all_rewards = []
    fig, ax = plt.subplots()
    for per_epoch_k, per_seed in per_epoch.items():
        # Seed keys look like "Seed<s>_<DATE>"; keep the "Seed<s>" part (works for s >= 10).
        seed_name = per_epoch_k.split("_", 1)[0]
        time_list, reward_list = _cumulative_reward_curve(
            per_seed["timestamp_list"], per_seed["reward_list"], episode, cutoff
        )
        if not time_list:
            print(f"No episodes parsed for {k} / {per_epoch_k}; skipped")
            continue
        ax.plot(time_list, reward_list, label=seed_name)
        all_times.append(time_list)
        all_rewards.append(reward_list)

    ax.set(
        title=f"Cumulative Time vs. Cumulative Reward ({k})",
        xlabel="Cumulative Time (seconds)",
        ylabel="Cumulative Reward",
    )
    if all_rewards:
        ax.legend()
    plt.show()

    if not all_rewards:
        continue

    # Seeds may have logged different numbers of episodes; truncate to the shortest so the
    # per-seed curves stack into rectangular arrays.
    n_points = min(len(r) for r in all_rewards)
    rewards_arr = np.array([r[:n_points] for r in all_rewards])
    times_arr = np.array([ts[:n_points] for ts in all_times])

    mean_rewards = rewards_arr.mean(axis=0)
    std_rewards = rewards_arr.std(axis=0)
    mean_times = times_arr.mean(axis=0)

    fig, ax = plt.subplots()
    ax.plot(mean_times, mean_rewards, label=f"Mean over {len(all_rewards)} seeds")
    ax.fill_between(mean_times, mean_rewards - std_rewards, mean_rewards + std_rewards,
                    alpha=0.5, label="±1 std")
    ax.set(
        title="Mean and std of Cumulative Reward at {} (Episode {})".format(ENV, episode),
        xlabel="Cumulative Time (seconds, mean over seeds)",
        ylabel="Cumulative Reward",
    )
    ax.legend()
    plt.show()

In [None]:
# Debug inspection: seed label of the last seed key visited by the plotting loop (relies on
# `per_epoch_k` left over from that cell). Keys look like "Seed<s>_<DATE>"; splitting on "_"
# avoids truncating two-digit seeds the way a fixed [:5] slice turns "Seed10" into "Seed1".
per_epoch_k.split("_", 1)[0]

In [None]:
# Scratch check with 100-episode windows on the last seed visited by the plotting loop
# (uses `episode` and `per_seed` left over from that cell): wall-clock span of each window
# and the running reward at its end.
window = 100
timestamps = per_seed['timestamp_list']
rewards = per_seed['reward_list']
# Stop at the data actually logged so a short/incomplete log cannot index an empty slice.
n_logged = min(episode, len(timestamps))

total_reward = 0
window_seconds = []
window_cum_rewards = []
for start in range(0, n_logged, window):
    end = start + window
    chunk = timestamps[start:end]
    sec = chunk[-1] - chunk[0]
    total_reward += sum(rewards[start:end])
    # Keep results in this cell's own lists (one entry per window) rather than mutating
    # the plotting cell's `time_list`.
    window_seconds.append(sec.total_seconds())
    window_cum_rewards.append(total_reward)


In [None]:
# Display the cumulative reward accumulated by the 100-episode scratch loop in the previous cell.
total_reward

In [None]:
# Inspect the last timestamp of the final window visited by the scratch loop. Uses `per_seed`,
# `start` and `end` left over from earlier cells; raises IndexError if that slice is empty.
per_seed['timestamp_list'][start:end][-1]