In [1]:
from sklearn.linear_model import LogisticRegression
import scipy.io
import numpy as np
import h5py
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
import torch.nn as nn
from scipy.stats.stats import pearsonr
import torch
from matplotlib.pyplot import figure

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class DataManager():
    def __init__(self, params, num_bands):
        self.rl_data = None
        self.dataset_type = params['dataset_type']
        self.data_file_path = params['data_file_path']
        self.sample_ratio = params['sample_ratio']
        
        self.num_bands = num_bands
        #load the data
        assert self.dataset_type in ('IndianPines', 'Botswana', 'SalientObjects'), f'{self.dataset_type} is not valid'
        #separating out in case any of the data requires unique pre-processig
        if self.dataset_type == 'IndianPines':
            self.load_indian_pine_data()
        elif self.dataset_type == 'Botswana':
            self.load_botswana_data()
        elif self.dataset_type == 'SalientObjects':
            self.load_salient_objects_data()
        #self.x_train = None
        #self.y_train = None
        #self.x_test = None
        #self.y_test = None
        
    def load_indian_pine_data(self):
        hyper_path = '/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/HyperSpectralRL/data/data_indian_pines_drl.mat'
        hyper = scipy.io.loadmat(hyper_path)['x'][:, :self.num_bands]
        #hyper = np.load(hyper_path)
        # randomly sample for x% of the pixels
        indices = np.random.randint(0, hyper.shape[0], int(hyper.shape[0]*self.sample_ratio))
        self.rl_data = hyper[indices, :]
        print(self.rl_data.shape)
        
    def load_salient_objects_data(self):
        hyper_path = '../data/salient_objects/hyperspectral_imagery/0001.npy'
        hyper = np.load(hyper_path)
        print(hyper.shape)
        # randomly sample for x% of the pixels
        indices = np.random.randint(0, hyper.shape[0], int(hyper.shape[0]*self.sample_ratio))
        self.rl_data = hyper[indices, :]
        print(self.rl_data.shape)
        
    def load_botswana_data(self):
        self.rl_data = scipy.io.loadmat(self.data_file_path)
    #def load_salient_objects(self)


In [3]:
class ReplayBuffer():
    
    def __init__(self, size=100000):
        self.size = size
        self.paths = []
        
    def add_trajectories(self, paths):
        self.paths.extend(paths)
        self.paths = self.paths[-self.size:]
        
    def sample_buffer_random(self, num_trajectories):
        
        rand_idx = np.random.permutation(len(self.paths))[:num_trajectories]
        return [self.paths[i] for i in rand_idx]
        #return self.paths[rand_idx]
    
        

In [4]:
class Agent():
    
    def __init__(self, params):
        
        self.agent_params = params['agent']
        self.num_bands = self.agent_params['num_bands']
        self.n_iter = self.agent_params['n_iter']
        self.trajectory_sample_size = self.agent_params['trajectory_sample_size']
        self.batch_size = self.agent_params['batch_size']
        self.num_critic_updates = self.agent_params['num_critic_updates']
        
        valid_rewards = ['correlation', 'mutual_info']
        assert self.agent_params['reward_type'] in valid_rewards, 'rewards must be one of ' + valid_rewards.join(',') 
        
        if self.agent_params['reward_type'] == 'correlation':
            self.reward_func = self.calculate_correlations
        elif self.agent_param['reward_type'] == 'mutual_info':
            self.reward_func = self.calculate_mutual_infos
        
        self.data_params = params['data']
        self.DataManager = DataManager(self.data_params, self.num_bands)
        self.band_selection_num = self.data_params['band_selection_num']

        self.critic_params = params['critic']
        self.critic = QCritic(self.critic_params, self.num_bands)
        
        
        self.policy_params = params['policy']
        self.policy = ArgMaxPolicy(self.policy_params, self.critic)
        
        self.replay_buffer = ReplayBuffer()
        
        self.cache = {}
        
        self.logging_df = pd.DataFrame()
        
    
    def generateTrajectories(self):
        
        #we expect paths to be a list of trajectories
        #a trajectory is a list of Path objects
        paths = []
        for i in range(self.trajectory_sample_size):
            
            path = self.sampleTrajectory()
#             print(f'Iter {i}')
#             print([p['re'] for p in path])
            
            paths.append(path)
    
        return paths
    
    def sampleTrajectory(self, iter_num = 1):
            
        #select 30 actions
        state = np.zeros(self.num_bands)
        state_next = state.copy()
        
        #paths will be a list of dictionaries
        path = []
        for i in range(self.band_selection_num):
            
            action, action_type = self.policy.get_action(state)
            state_next[action] += 1

            reward, correlation_current_state, correlation_next_state = self.calculate_reward(state, state_next)

            terminal = 1 if i == self.band_selection_num - 1 else 0
            path.append(self.Path(state.copy(), action, state_next.copy(), reward, terminal))
            
            state = state_next.copy()
        
            if iter_num % 25 == 0:
                print("Iter : ", iter_num)
                q_values = self.critic.get_action(state)
                
                sampled_paths = self.replay_buffer.sample_buffer_random(1)
                
                flat_sampled_path = [path for trajectory in sampled_paths for path in trajectory]
                obs = np.array([path['ob'] for path in flat_sampled_path])
                acs = np.array([path['ac'] for path in flat_sampled_path])
                obs_next = np.array([path['ob_next'] for path in flat_sampled_path])
                res = np.array([path['re'] for path in flat_sampled_path])
                terminals = np.array([path['terminal'] for path in flat_sampled_path])
                
                loss_value = self.critic.update(obs, acs, obs_next, res, terminals)
                
                row = {
                    "iter_num": iter_num,
                    "Selected Band": i,
                    "Action Type": action_type,
                    "Mean": torch.mean(q_values).detach().numpy(),
                    "Min": torch.min(q_values).detach().numpy(),
                    "Max": torch.max(q_values).detach().numpy(),
                    "Correlation Current State" : correlation_current_state,
                    "Correlation Next State" : correlation_next_state,
                    "Reward" : reward,
                    "Loss" : loss_value
                }
                
                self.logging_df = self.logging_df.append(row, ignore_index=True)
                
                
#                 print(self.logging_df)
        
        #path returns state, action, state_next, reward, terminal
        return path
                   
        
    def runAgent(self):
        
        for iter_num in range(self.n_iter):
            
            print('Iteration ', iter_num, ':')
            
            paths = self.generateTrajectories()
            self.replay_buffer.add_trajectories(paths)
            
            for _ in range(self.num_critic_updates):
                sampled_paths = self.replay_buffer.sample_buffer_random(self.agent_params['batch_size'])
                
                flat_sampled_path = [path for trajectory in sampled_paths for path in trajectory]
                obs = np.array([path['ob'] for path in flat_sampled_path])
                acs = np.array([path['ac'] for path in flat_sampled_path])
                obs_next = np.array([path['ob_next'] for path in flat_sampled_path])
                res = np.array([path['re'] for path in flat_sampled_path])
                terminals = np.array([path['terminal'] for path in flat_sampled_path])
                
                critic_loss = self.critic.update(obs, acs, obs_next, res, terminals)
                
            self.critic.update_target_network()
            
            #sample a single trajectory
            print('------------------------------------EVAL Results------------------------------')
            eval_path = self.sampleTrajectory(iter_num)
#             print(self.cache)
            #print(eval_path[-1])
            print('Selected_Bands: ', np.argwhere(eval_path[-1]['ob_next']>0).flatten())
            print('Num_Selected_Bands: ', np.argwhere(eval_path[-1]['ob_next']>0).shape[0])
            print('Eval_Return: ', np.sum(eval_path[-1]['re']))
            print('Critic_Loss: ', critic_loss)
            print('Correlation: ', self.logging_df.loc[self.logging_df.shape[0]-1, 'Correlation Next State'])
                

    def calculate_reward(self, state, state_next):
        #for future, save down the previous state so that we can avoid a calc
        
#         print(list(np.argwhere(np.array(state) != 0)), list(np.argwhere(np.array(state_next) != 0)))
        if list(np.argwhere(np.array(state) != 0)) == list(np.argwhere(np.array(state_next) != 0)):
#             print("same action selected")
            return -1, "Indef", "Indef"
        else:
        
            a = self.reward_func(state)
            b = self.reward_func(state_next)
            #FLIPPED THE SIGN FOR TESTING
            return np.exp(a-b), a, b
    
    
    def calculate_correlations(self, state):
        
#         if repr(state) in self.cache:
#             return self.cache[repr(state)]
        
        #deal with the first state
        ##### THIS LOGIC SEEMS WRONG - REGARDLESS OF THE FIRST PICK, YOU HAVE A REWARD OF 0#####
        if np.sum(state) <= 1:
            return 0
        
        selected_bands = []
        non_zero_bands = np.argwhere(np.array(state) != 0)
        for band in non_zero_bands:
#             print(band[0])
            selected_bands.extend([band[0]]*int(state[band[0]]))
        #print(selected_bands)
        #selected_bands = np.squeeze(np.argwhere(np.array(state)==1))
        corr_sum = 0
        for idx_i, i in enumerate(selected_bands):
            for idx_j, j in enumerate(selected_bands):
                if idx_i != idx_j:
                    
                    if repr((i,j)) in self.cache:
                        result = self.cache[repr((i,j))]
                    else:
                        result = abs(pearsonr(self.DataManager.rl_data[:, i], self.DataManager.rl_data[:, j])[0])
                        self.cache[repr((i,j))] = result
                    
                    corr_sum += result
                    
#                     corr_sum += abs(pearsonr(self.DataManager.rl_data[:, i], self.DataManager.rl_data[:, j])[0])
        
#         self.cache[repr(state)] = corr_sum/(len(selected_bands)**2)
        
#         return self.cache[repr(state)]
        return corr_sum/(len(selected_bands)**2)

    def calculate_mutual_infos(self, state):
    
        selected_bands = []
        non_zero_bands = np.argwhere(np.array(state) != 0)
        for band in non_zero_bands:
            selected_bands.extend([band[0]]*int(state[band[0]]))
    
        normalized_mutual_info_score_sum = 0
        for i in selected_bands:
            for j in selected_bands:
                
                if i != j:

                    normalized_mutual_info_score_sum += normalized_mutual_info_score(self.DataManager.rl_data[:, i],
                                                                                     self.DataManager.rl_data[:, j])

        return normalized_mutual_info_score_sum/(len(selected_bands)**2)

            
    def Path(self, ob, ac, ob_next, re, terminal):
        return {'ob':ob,
                'ac':ac,
                'ob_next':ob_next,
                're':re,
                'terminal':terminal
                }

In [5]:
class ArgMaxPolicy():
    
    def __init__(self, params, critic):
        self.epsilon = params['epsilon']
        self.epsilon_decay = params['epsilon_decay']
        self.critic = critic
        
    def get_action(self, obs):
        
        q_value_estimates = self.critic.get_action(obs)
        unselected_bands = np.squeeze(np.argwhere(obs == 0))
        #print(obs)
#         print('Predicted Q-Values:', q_value_estimates)
        
        rand = np.random.rand()
        if rand < self.epsilon:
            #select a random action
#             print('Selected Random')
            unselected_bands = np.squeeze(np.argwhere(obs == 0))
            selected_idx = np.random.choice(unselected_bands)
            action_type = "Random Action"

        else:
#           print('Selected Max')
            #q_value_estimates_idx = torch.argsort(q_value_estimates, dim=1)
            #q_value_estimates = q_value_estimates[unselected_bands, :]
        
            #q_filter = q_value_estimates[unselected_bands]
            q_value_estimates_idx = torch.argsort(q_value_estimates, descending=True)
            q_value_estimates_idx = q_value_estimates_idx[torch.isin(q_value_estimates_idx, torch.tensor(unselected_bands))]
            selected_idx = q_value_estimates_idx[0].item()
                
            action_type = "Max Action"

            
        self.decay_epsilon()
        return selected_idx, action_type
                
    def decay_epsilon(self):
        self.epsilon *= self.epsilon_decay 

In [6]:
class ActorPolicy():
    
    def __init__(self):
        pass
    
    
    
    

In [7]:
class QCritic():
    
    def __init__(self, params, num_bands):
        
        self.num_bands = num_bands

        
        self.critic = self.create_network()
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=0.005)

        self.critic_target = self.create_network()
        
        self.gamma = params['gamma']
        
        self.loss = nn.SmoothL1Loss()
    
    def create_network(self):
        
        q_net  = nn.Sequential(
        nn.Linear(self.num_bands, self.num_bands*2),
        nn.ReLU(),
        nn.Linear(self.num_bands*2, self.num_bands*2),
        nn.ReLU(),
        nn.Linear(self.num_bands*2, self.num_bands)
        )
        
        return q_net
    
        
    def forward(self, obs):
        # will take in one hot encoded states and output a list of qu values
        
        q_values = self.critic(obs)
        
        return q_values
    
    def get_action(self, obs):
        
        if isinstance(obs, np.ndarray):
            obs = from_numpy(obs)
            
        return self.critic(obs)
    
    def update(self, obs, ac_n, next_obs, reward_n, terminals):
        
        obs = self.check_tensor(obs) #comes in as shape 
        ac_n = self.check_tensor(ac_n)
        next_obs = self.check_tensor(next_obs)
        reward_n = self.check_tensor(reward_n)
        terminals = self.check_tensor(terminals)
        
        full_q_values = self.critic(obs)
        q_actions = full_q_values.argmax(dim=1)
        q_values = torch.gather(full_q_values, 1, q_actions.unsqueeze(1)).squeeze(1)
        
        
        #print('Obs ', obs.shape)
        #print('Full Q ', full_q_values.shape)
        #print('Q Actions ', q_actions.shape)
        #print('Q Val ', q_values.shape)
        
        full_q_next_target = self.critic_target(next_obs)
        q_actions_next = self.critic(next_obs).argmax(dim=1)
        #q_values_next = full_q_next.max(dim=1)
        #print('q_values_next', q_values_next)
        q_values_next = torch.gather(full_q_next_target, 1, q_actions_next.unsqueeze(1)).squeeze(1)
        
        #print('reward', type(reward_n.shape))
        #print('q_values_next', type(q_values_next))
        #print('terminals', type(terminals))
        #print('gamma', type(self.gamma))
        target = reward_n + self.gamma*q_values_next*(1-terminals)
        target = target.detach()
        
        #print(f'Target Dim: {target.shape}')
        #print(f'Q_Values Dim: {q_values.shape}')
        loss = self.loss(q_values, target)
        
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()
        
        return loss.item()
    
    def check_tensor(self, ar):
        
        if isinstance(ar, np.ndarray):
            ar = from_numpy(ar)
            
        return ar
    
    def update_target_network(self):
        for target_param, param in zip(
            self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data)
        
        

In [8]:
#utility functions
## taken from Prof. Sergey Levine's CS285 HW

device = 'cpu'

def from_numpy(*args, **kwargs):
    return torch.from_numpy(*args, **kwargs).float().to(device)

def to_numpy(tensor):
    return tensor.to('cpu').detach().numpy()

In [9]:
params = {'agent':{
            'n_iter':5000,
            'trajectory_sample_size': 10,
            'batch_size':10,
            'num_critic_updates':10,
            'num_bands':200,
            'reward_type':'correlation'
            },
          'data':{
            'band_selection_num':30,
            'dataset_type':'IndianPines',
            'data_file_path':r'/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/HyperSpectralRL/data/data_indian_pines_drl.mat',
            'sample_ratio':0.1
            },
          'critic':{
            'gamma':0.99
            },
          'policy':{
            'epsilon':0.99,
            'epsilon_decay':0.9999
            }
         }

In [10]:
agent = Agent(params)
agent.runAgent()

(2102, 200)
Iteration  0 :
------------------------------------EVAL Results------------------------------
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Iter :  0
Selected_Bands:  [ 10  24  25  33  34  39  47  48  62  69  70  76  82  83  86  91  96  98
  99 123 141 144 145 146 169 178 179 180 185 195]
Num_Selected_Bands:  30
Eval_Return:  0.9933934521666574
Critic_Loss:  0.008239541202783585
Correlation:  0.5441496535195535
Iteration  1 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [  7  22  29  30  33  46  49  56  60  64  65  67  74  75  80  83  89 106
 108 111 114 130 133 136 146 151 160 175 187 188]
Num_Selected_Bands:  30
Eval_Return:  1.022642571516488
Critic_Loss:  0.16687403619289398
Corre

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 19  22  38  39  45  48  59  65  73  78  81  86  91 103 111 119 127 129
 140 141 150 153 159 164 167 169 183 187 190 191]
Num_Selected_Bands:  30
Eval_Return:  0.9912129934490119
Critic_Loss:  0.6800194978713989
Correlation:  0.5441496535195535
Iteration  23 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [  0   5  38  53  59  65  73  81  91  94 106 108 111 115 118 119 120 127
 137 153 158 159 160 161 163 169 183 184 190 191]
Num_Selected_Bands:  30
Eval_Return:  0.9921617298887097
Critic_Loss:  0.593532383441925
Correlation:  0.5441496535195535
Iteration  24 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [  5  14  22  35  38  42  45  49  57  59  64  65  73  77  81  91  98 111
 116 119 127 134 140 159 169 170 183 190 191 198]
Num_Selected_Bands:  30
Eval_Return:  0.9940788292843065
Critic_Lo

------------------------------------EVAL Results------------------------------
Selected_Bands:  [  5   6  23  30  38  40  45  48  54  59  65  73  81  91 107 108 111 117
 119 127 140 159 160 167 169 172 183 187 190 191]
Num_Selected_Bands:  30
Eval_Return:  0.9927416248094941
Critic_Loss:  0.8309006690979004
Correlation:  0.6321512357528142
Iteration  45 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 38  41  45  48  56  59  64  65  73  78  81  91 101 111 119 120 127 129
 140 154 156 159 164 169 183 187 190 191 192 195]
Num_Selected_Bands:  30
Eval_Return:  0.995002261811601
Critic_Loss:  0.9441883563995361
Correlation:  0.6321512357528142
Iteration  46 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  64  65  73  76  78  81  91  98 111 119 127 129 131
 140 154 159 164 169 170 183 185 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0060133474349575
Critic_Lo

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 10  34  35  38  44  45  48  59  65  73  78  81  89  91  98 111 119 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.036072589167126
Critic_Loss:  1.0424110889434814
Correlation:  0.6188338831172048
Iteration  67 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 24  34  38  45  48  56  59  65  73  75  78  81  89  91  98 111 119 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9918857686244501
Critic_Loss:  1.0895600318908691
Correlation:  0.6188338831172048
Iteration  68 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [  1  34  38  45  48  59  65  73  75  78  81  89  91  93  98 111 119 127
 129 140 154 159 164 169 183 187 190 191 192 195]
Num_Selected_Bands:  30
Eval_Return:  1.032569509183627
Critic_Los

------------------------------------EVAL Results------------------------------
Selected_Bands:  [  1   6  34  38  45  48  59  65  66  73  78  81  89  91  98 111 119 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0051296412624968
Critic_Loss:  1.4808528423309326
Correlation:  0.5755578983976853
Iteration  89 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [  6  34  38  45  48  59  65  73  78  81  89  91  98 111 119 122 127 129
 139 140 144 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0379393387978784
Critic_Loss:  1.4833528995513916
Correlation:  0.5755578983976853
Iteration  90 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  64  65  73  78  81  89  91  98 101 111 119 127 129
 139 140 154 159 164 169 183 187 190 191 195 198]
Num_Selected_Bands:  30
Eval_Return:  1.0361323705855474
Critic_L

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  36  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9932920353346101
Critic_Loss:  1.4900213479995728
Correlation:  0.5991755868517956
Iteration  111 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9951609643485331
Critic_Loss:  1.50579035282135
Correlation:  0.5991755868517956
Iteration  112 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [  8  34  38  45  48  54  59  65  73  78  81  89  91  98 111 119 127 129
 139 140 148 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0375529048763719
Critic_L

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  1.658300518989563
Correlation:  0.6028345737393231
Iteration  133 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9928518147826921
Critic_Loss:  1.6971951723098755
Correlation:  0.6028345737393231
Iteration  134 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  1.823441982269287
Correlation:  0.6039539915296552
Iteration  155 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  1.8350380659103394
Correlation:  0.6039539915296552
Iteration  156 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  2.0238988399505615
Correlation:  0.6039539915296552
Iteration  177 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  2.060150384902954
Correlation:  0.6039539915296552
Iteration  178 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  2.2986631393432617
Correlation:  0.6039539915296552
Iteration  200 :
------------------------------------EVAL Results------------------------------
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Iter :  200
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Criti

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  2.5520007610321045
Correlation:  0.6039539915296552
Iteration  222 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic_Loss:  2.5503885746002197
Correlation:  0.6039539915296552
Iteration  223 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9940475867444101
Critic

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.725217819213867
Correlation:  0.6039539915296552
Iteration  244 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.724937677383423
Correlation:  0.6039539915296552
Iteration  245 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_L

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.8318557739257812
Correlation:  0.6039539915296552
Iteration  266 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.8358898162841797
Correlation:  0.6039539915296552
Iteration  267 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129 139
 140 145 147 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9939078296433568
Critic

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.9024345874786377
Correlation:  0.6039539915296552
Iteration  288 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.902218818664551
Correlation:  0.6039539915296552
Iteration  289 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.9324140548706055
Correlation:  0.6039539915296552
Iteration  310 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.9364051818847656
Correlation:  0.6039539915296552
Iteration  311 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.9530770778656006
Correlation:  0.6039539915296552
Iteration  332 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.953378915786743
Correlation:  0.6039539915296552
Iteration  333 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.972662925720215
Correlation:  0.6039539915296552
Iteration  354 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  2.9739701747894287
Correlation:  0.6039539915296552
Iteration  355 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_

Iter :  375
Iter :  375
Iter :  375
Iter :  375
Iter :  375
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.0299246311187744
Correlation:  0.6039539915296552
Iteration  376 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.0233497619628906
Correlation:  0.6039539915296552
Iteration  377 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.025629758

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.055832624435425
Correlation:  0.6039539915296552
Iteration  399 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.056812047958374
Correlation:  0.6039539915296552
Iteration  400 :
------------------------------------EVAL Results------------------------------
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
Iter :  400
I

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  4.0373005867004395
Correlation:  0.6039539915296552
Iteration  421 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.0673017501831055
Correlation:  0.6039539915296552
Iteration  422 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.0820810794830322
Correlation:  0.6039539915296552
Iteration  443 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.082416534423828
Correlation:  0.6039539915296552
Iteration  444 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.0957467555999756
Correlation:  0.6039539915296552
Iteration  465 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.096383571624756
Correlation:  0.6039539915296552
Iteration  466 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.096707820892334
Correlation:  0.6039539915296552
Iteration  487 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_Loss:  3.096829414367676
Correlation:  0.6039539915296552
Iteration  488 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  66  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  1.0057044600316978
Critic_L

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.1014320850372314
Correlation:  0.6148904593880271
Iteration  509 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.101365804672241
Correlation:  0.6148904593880271
Iteration  510 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.109609842300415
Correlation:  0.6148904593880271
Iteration  531 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.1049256324768066
Correlation:  0.6148904593880271
Iteration  532 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Loss:  3.1074776649475098
Correlation:  0.61707363323206
Iteration  553 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Loss:  3.107506513595581
Correlation:  0.61707363323206
Iteration  554 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Loss

------------------------------------EVAL Results------------------------------
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Iter :  575
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Loss:  3.1067283153533936
Correlation:  0.61707363323206
Iteration  576 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Loss:  3.1063334941864014
Correlation:  0.61707363323206
Iteration  598 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Loss:  3.1074602603912354
Correlation:  0.61707363323206
Iteration  599 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9916343041064233
Critic_Los

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9939145723224556
Critic_Loss:  3.098707675933838
Correlation:  0.61707363323206
Iteration  620 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9939145723224556
Critic_Loss:  3.0987555980682373
Correlation:  0.61707363323206
Iteration  621 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 120 127
 129 139 140 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9948097863894603
Critic_Loss

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.1049513816833496
Correlation:  0.6148904593880271
Iteration  642 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.10615611076355
Correlation:  0.6148904593880271
Iteration  643 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_L

------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.1225759983062744
Correlation:  0.6148904593880271
Iteration  664 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic_Loss:  3.1221537590026855
Correlation:  0.6148904593880271
Iteration  665 :
------------------------------------EVAL Results------------------------------
Selected_Bands:  [ 25  34  38  45  48  59  65  73  78  81  89  91  98 105 111 119 127 129
 139 140 145 154 159 164 169 183 187 190 191 195]
Num_Selected_Bands:  30
Eval_Return:  0.9947655311816366
Critic

KeyboardInterrupt: 

In [12]:
agent.logging_df[agent.logging_df['iter_num']==0]

Unnamed: 0,iter_num,Selected Band,Action Type,Mean,Min,Max,Correlation Current State,Correlation Next State,Reward,Loss
0,0.0,0.0,Max Action,0.032197937,-0.114741586,0.8134908,0.0,0.0,1.0,0.473343
1,0.0,1.0,Max Action,0.03804788,-0.11941104,0.9804973,0.0,0.185507,0.830683,0.2468
2,0.0,2.0,Random Action,0.050463285,-0.15431102,1.3450689,0.185507,0.422493,0.789002,0.116353
3,0.0,3.0,Random Action,0.05984806,-0.15989354,1.6444672,0.422493,0.478256,0.945764,0.086307
4,0.0,4.0,Random Action,0.065917924,-0.16938102,1.7850955,0.478256,0.445233,1.033574,0.146224
5,0.0,5.0,Random Action,0.060496625,-0.1595816,1.6218024,0.445233,0.464352,0.981063,0.708086
6,0.0,6.0,Random Action,0.06333691,-0.1668063,1.6709975,0.464352,0.506604,0.958628,2.892471
7,0.0,7.0,Random Action,0.059510164,-0.17310499,1.5724148,0.506604,0.511924,0.994694,1.238159
8,0.0,8.0,Random Action,0.0593328,-0.2010959,1.5547637,0.511924,0.545951,0.966545,0.81709
9,0.0,9.0,Random Action,0.05901288,-0.23830225,1.5443425,0.545951,0.508474,1.038189,0.037396


In [None]:
filter_df = agent.logging_df[agent.logging_df['Selected Band']==29.0]
sns.lineplot(x='iter_num', y='Correlation Next State', data=filter_df)

In [None]:
for i in range(30):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    plot_df = agent.logging_df[agent.logging_df["Selected Band"] == i]
    ax1.plot(plot_df["iter_num"], plot_df["Mean"], color='red', label="Mean")
    ax1.plot(plot_df["iter_num"], plot_df["Max"], color='blue', label="Max")
    ax1.plot(plot_df["iter_num"], plot_df["Min"], color='green', label="Min")
    ax1.axhline(plot_df['Reward'].mean(), color='red')
    ax1.set_title(f'Band Selection {i}')
    ax1.legend()
    ax2.plot(plot_df["iter_num"], plot_df["Loss"], color='red')
    ax2.set_title('Loss Function')
    plt.show()

In [None]:
agent.logging_df

In [None]:
band_1_logging_df = agent.logging_df[agent.logging_df["Selected Band"] == 1.0]

In [None]:
band_1_logging_df

In [None]:
plt.figure(num = 3, figsize=(8, 5))
plt.plot(band_1_logging_df["iter_num"], band_1_logging_df["Mean"], color='red', label="Mean")
plt.plot(band_1_logging_df["iter_num"], band_1_logging_df["Max"], color='blue', label="Max")
plt.plot(band_1_logging_df["iter_num"], band_1_logging_df["Min"], color='green', label="Min")

plt.axhline(band_1_logging_df['Reward'].mean( ))

plt.xlabel("iter num")
plt.ylabel("Q values")
plt.legend(loc="upper left")

plt.show()

In [None]:
plt.figure(num = 3, figsize=(8, 5))
plt.plot(band_1_logging_df["iter_num"], band_1_logging_df["Loss"], color='red')
plt.show()


In [None]:
path = '/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/HyperSpectralRL/data/data_indian_pines_drl.mat'
data = scipy.io.loadmat(path)['x']
band_1 = data[:, 173]
band_2 = data[:, 2]

In [None]:
num = np.sum((band_1 - np.mean(band_1))*(band_2 - np.mean(band_2)))
den = np.sqrt(np.sum((band_1 - np.mean(band_1))**2)*np.sum((band_2 - np.mean(band_2))**2))
num/den

In [None]:
band_1

In [None]:
band_2

In [None]:
corr = []
band_1 = data[:, 40]
for i in range(2, 200):

    band_2 = data[:, i]
    num = np.sum((band_1 - np.mean(band_1))*(band_2 - np.mean(band_2)))
    den = np.sqrt(np.sum((band_1 - np.mean(band_1))**2)*np.sum((band_2 - np.mean(band_2))**2))
    corr.append(num/den)
    
    
sns.histplot(corr)

In [None]:
path = r'/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/HyperSpectralRL/data/data_indian_pines_drl.mat'
num_iter = 50
num_bands = 200
bands_to_select = 30
data = scipy.io.loadmat(path)['x'][:, :num_bands]

all_band_selection = []
all_corr = []
all_reward = []

for num_iter in range(num_iter):
    selected_bands = list(np.random.choice(np.arange(0, num_bands, 1), 1))
    corr = [0]
    rewards = [0]
    for i in range(bands_to_select):

        selected_bands.extend(list(np.random.choice(np.arange(0, num_bands, 1), 1)))
        corr_sum = 0
        
        for idx_i, i in enumerate(selected_bands):
            for idx_j, j in enumerate(selected_bands):
                if idx_i != idx_j:
                    corr_sum += abs(pearsonr(data[:, i], data[:, j])[0])
        corr.append(corr_sum/(len(selected_bands)**2))
        rewards.append(corr[-2] - corr[-1])
    
    if num_iter % 100 == 0:
        print(num_iter)
        
    all_band_selection.append(selected_bands)
    all_corr.append(corr)
    all_reward.append(rewards)
    
all_band_selection = np.array(all_band_selection)
all_corr = np.array(all_corr)
all_reward = np.array(all_reward)

In [None]:
num_iter = 1000
cum_reward = np.zeros((num_iter, bands_to_select))
for i in range(bands_to_select):
    cum_reward[:, i] = np.sum(all_reward[:, i+1:], axis=1)
    
# reasonbleness check to make sure this works
#print(all_reward[0, :])
#print(sum(all_reward[0, :]))
#print(cum_reward[0])

In [None]:
for i in range(bands_to_select):
    sns.histplot(cum_reward[:,i])
    plt.title(f'Distribution of Q-Values at band selection {i+1}')
    plt.show()

In [None]:


# reward functions

def calculate_correlations(data, num_bands_originally, num_bands_kept):
    
    selected_bands = np.random.randint(0,num_bands_originally,num_bands_kept)
    corr_sum = 0
    for i in selected_bands:
        for j in selected_bands:
            if i != j:
            
                corr_sum += np.abs(pearsonr(data[:, i], 
                                 data[:, j])[0])
            
    return corr_sum/(len(selected_bands)**2)


def calculate_mutual_infos(data, num_bands_originally, num_bands_kept):
    
    selected_bands = np.random.randint(0,num_bands_originally,num_bands_kept)
    normalized_mutual_info_score_sum = 0
    for i in selected_bands:
        for j in selected_bands:
            
            normalized_mutual_info_score_sum += normalized_mutual_info_score(data[:, i],
                                                                             data[:, j])
            
    return normalized_mutual_info_score_sum/(len(selected_bands)**2)



In [None]:

# rewards
    
path = r'/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/HyperSpectralRL/data/data_indian_pines_drl.mat'
num_iter = 50
num_bands = 200
bands_to_select = 30
hyper_multiple = scipy.io.loadmat(path)['x'][:, :num_bands]

    
num_runs = 25


correlations = []
for i in range(num_runs):
    correlations.append(calculate_correlations(hyper_multiple, num_bands_originally=hyper_multiple.shape[-1], num_bands_kept=30))
print(f'\nCorrelation reward for random 10 bands, x{num_runs} runs:', np.mean(correlations))

# plot rewards
a_string = ['pearson correlation (cumulative avg)'] * len(correlations)    
#b_string = ['normalized mutual information (cumulative avg)'] * len(mis)
#strings = a_string + b_string
pd_df = pd.DataFrame([correlations, a_string]).T
pd_df[0] = pd_df[0].astype(float, copy=True)
sns.histplot(data=pd_df, bins=15, x=0, kde=True)
plt.title(f'Test', fontsize=17)
#plt.xlim([0,1])
plt.show()
plt.figure()


In [None]:
idx =  [  4,   6 , 31,  36,  41,  51,  64,  68,  71,  72,  81,  86,  95 ,105, 106, 107, 113, 124,
 138, 139, 145, 147, 160, 162, 167, 168, 175, 177, 185, 187]

idx = [ 19,  38,  46,  51,  59,  62,  63,  71,  72,  79,  81,  82,  83,  87,  89,  98, 101, 105,
 123, 129, 133, 137, 138, 142, 157, 173, 182, 187 , 189, 192]

def calculate_correlations(data):
    
    selected_bands = idx
    corr_sum = 0
    for i in selected_bands:
        for j in selected_bands:
            if i != j:
            
                corr_sum += np.abs(pearsonr(data[:, i], 
                                 data[:, j])[0])
            
    return corr_sum/(len(selected_bands)**2)

calculate_correlations(agent.DataManager.rl_data)

In [None]:
idx = list(scipy.io.loadmat('/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/Final Project/DRL4BS/results/drl_30_bands_indian_pines.mat')['selected_bands'][0])

print(list(idx))

def calculate_correlations(data):
    
    selected_bands = idx
    corr_sum = 0
    for i in selected_bands:
        for j in selected_bands:
            if i != j:
            
                corr_sum += np.abs(pearsonr(data[:, int(i)], 
                                 data[:, int(j)])[0])
            
    return corr_sum/(len(selected_bands)**2)

calculate_correlations(agent.DataManager.rl_data)

In [None]:

class DataManager():
    def __init__(self, params, num_bands):
        self.rl_data = None
        self.dataset_type = params['dataset_type']
        self.data_file_path = params['data_file_path']
        self.sample_ratio = params['sample_ratio']
        
        self.num_bands = num_bands
        #load the data
        assert self.dataset_type in ('IndianPines', 'Botswana', 'SalientObjects'), f'{self.dataset_type} is not valid'
        #separating out in case any of the data requires unique pre-processig
        if self.dataset_type == 'IndianPines':
            self.load_indian_pine_data()
        elif self.dataset_type == 'Botswana':
            self.load_botswana_data()
        elif self.dataset_type == 'SalientObjects':
            self.load_salient_objects_data()
        #self.x_train = None
        #self.y_train = None
        #self.x_test = None
        #self.y_test = None
        
    def load_indian_pine_data(self):
        hyper_path = '/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/HyperSpectralRL/data/data_indian_pines_drl.mat'
        hyper = np.load(hyper_path)
        print(hyper.shape)
        # randomly sample for x% of the pixels
        indices = np.random.randint(0, hyper.shape[0], int(hyper.shape[0]*self.sample_ratio))
        self.rl_data = hyper[indices, :]
        print(self.rl_data.shape)
        
    def load_salient_objects_data(self):
        hyper_path = '../data/salient_objects/hyperspectral_imagery/0001.npy'
        hyper = np.load(hyper_path)
        print(hyper.shape)
        # randomly sample for x% of the pixels
        indices = np.random.randint(0, hyper.shape[0], int(hyper.shape[0]*self.sample_ratio))
        self.rl_data = hyper[indices, :]
        print(self.rl_data.shape)
        
    def load_botswana_data(self):
        self.rl_data = scipy.io.loadmat(self.data_file_path)
    #def load_salient_objects(self)

In [None]:
researcher_selected_bands = scipy.io.loadmat('data/original_paper/results/drl_30_bands_indian_pines.mat')