In [1]:
from sklearn.linear_model import LogisticRegression
import scipy.io
import numpy as np
import h5py
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
import torch.nn as nn
from scipy.stats.stats import pearsonr
import torch

In [2]:
class DataManager():
    
    def __init__(self, params):
        
        self.rl_data = None
        self.dataset_type = params['dataset_type']
        self.data_file_path = params['data_file_path']
        
        #load the data
        assert self.dataset_type in ('IndianPines', 'Botswana'), f'{self.dataset_type} is not valid' 
        #separating out in case any of the data requires unique pre-processig
        if self.dataset_type == 'IndianPines':
            self.load_indian_pine_data()
        elif self.dataset_type == 'Botswana':
            self.load_botswana_data()
            
        #self.x_train = None
        #self.y_train = None
        #self.x_test = None
        #self.y_test = None

    def load_indian_pine_data(self):
        
        #self.rl_data = scipy.io.loadmat(self.data_file_path)
        self.rl_data = h5py.File(self.data_file_path, 'r')
        
        #self.x_train = np.array(data['x_tra']).transpose()
        #self.y_train = np.argmax(np.array(data['y_tra']).transpose(), axis=1)
        #self.x_test = np.array(data['x_test']).transpose()
        #self.y_test = np.argmax(np.array(data['y_test']).transpose(), axis=1)
        
    def load_botswana_data(self):
        
        self.rl_data = scipy.io.loadmat(self.data_file_path)
        

In [3]:
class ReplayBuffer():
    
    def __init__(self, size=10000):
        self.size = size
        self.paths = []
        
    def add_trajectories(self, paths):
        self.paths.extend(paths)
        self.paths = self.paths[-self.max_size:]
        
    def sample_buffer_random(self, num_trajectories):
        
        rand_idx = np.random.permutation(len(self.paths))[:num_trajectories]
        return self.paths[rand_idx]
    
        

In [4]:
class Agent():
    
    def __init__(self, params):
        
        self.agent_params = params['agent']
        self.n_iter = self.agent_params['n_iter']
        self.trajectory_sample_size = self.agent_params['trajectory_sample_size']
        self.batch_size = self.agent_params['batch_size']
        self.num_critic_updates = self.agent_params['num_critic_updates']
        
        self.data_params = params['data']
        self.DataManager = DataManager(self.data_params)
        self.band_selection_num = self.data_params['band_selection_num']

        self.critic_params = params['critic']
        self.critic = QCritic(self.critic_params)
        
        self.policy = ArgMaxPolicy(self.critic)
        
        self.replay_buffer = ReplayBuffer()
        
    
    def generateTrajectories(self):
        
        #we expect paths to be a list of trajectories
        #a trajectory is a list of Path objects
        paths = []
        for _ in range(self.trajectory_sample_size):
            paths.append(self.sampleTrajectory())
    
        return paths
    
    def sampleTrajectory(self):
            
        #select 30 actions
        state = np.zeros(200)
        state_next = state

        #paths will be a list of dictionary
        path = []
        for i in range(self.band_selection_num):
            
            action = self.policy.get_action(state)
            state_next[action] = 1

            reward = calculate_rewards(state, state_next)
            terminal = 1 if i == 29 else 0
            path.append(Path(state, action, state_next, reward, terminal))

        return path
                   
        
    def runAgent(self):
        
        for iter_num in range(self.n_iter):
            
            print('Iteration ', iter_num, ':')
            
            paths = self.generateTrajectories()
            self.replay_buffer.add_trajectories(paths)
            
            for _ in range(self.num_critic_updates):
                sampled_paths = self.replay_buffer.sample_buffer_random(self.agent_params['batch_size'])
                
                flat_sampled_path = [path for trajectory in sampled_paths for path in trajectory]
                obs = [path['ob'] for path in flat_sampled_path]
                acs = [path['ac'] for path in flat_sampled_path]
                obs_next = [path['ob_next'] for path in flat_sampled_path]
                res = [path['re'] for path in flat_sampled_path]
                terminals = [path['terminal'] for path in flat_sampled_path]
                
                self.critic.update(obs, acs, obs_next, res, terminals)
            
            
    def calculate_reward(state, state_next):
        #for future, save down the previous state so that we can avoid a calc
        return self.calculate_correation(state) - self.calculate_correation(state_next)
    
    
    def calculate_correlations(state):
        
        selected_bands = np.squeeze(np.argwhere(np.array(state)==1))
        corr_sum = 0
        for i in selected_bands:
            for j in selected_bands:
                corr_sum += pearsonr(self.DataManger.rl_data[:, i], self.DataManger.rl_data[:, i])
        return corr_sum/(len(selected_bands)**2)
            
            
    def Path(ob, ac, ob_next, re, terminal):
        return {'ob':ob,
                'ac':ac,
                'ob_next':ob_next,
                're':re,
                'terminal':terminal
                }
        
        

In [5]:
class ArgMaxPolicy():
    
    def __init__(self, critic):
        self.critic = critic
        
    def get_action(self, obs):
        
        
        q_value_estimates = self.critic.get_action(obs)
        
        #return index of best action
        return np.argmax(q_value_estimates)

In [6]:
class QCritic():
    
    def __init__(self, params):
        
        self.critic = self.create_network()
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=0.005)
        self.gamma = params['gamma']
        
        self.loss = nn.SmoothL1Loss()
    
    def create_network(self):
        
        q_net  = nn.Sequential(
        nn.Linear(200, 400),
        nn.ReLU(),
        nn.Linear(400, 400),
        nn.ReLU(),
        nn.Linear(400, 200)
        )
        
        return q_net
    
        
    def forward(self, obs):
        # will take in one hot encoded states and output a list of qu values
        
        q_values = self.critic(obs)
        
        return q_values
    
    def get_action(self, obs):
        
        
        return self.critic(obs)
    
    def update(self, obs, ac_n, next_obs, reward_n):
        
        q_values = self.critic(obs)
        
        q_values_next = self.critic(next_obs)
        
        target = reward_n + self.gamma*q_values_next
        target = target.detach()
        
        loss = self.loss(q_values, target)
        
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()
        
        return loss.item()
        

In [None]:
#utility functions
## taken from Prof. Sergey Levine's CS285 HW
def from_numpy(*args, **kwargs):
    return torch.from_numpy(*args, **kwargs).float().to(device)

def to_numpy(tensor):
    return tensor.to('cpu').detach().numpy()

In [7]:
params = {'agent':{
            'n_iter':100,
            'trajectory_sample_size':100,
            'batch_size':50,
            'num_critic_updates':10
            },
          'data':{
            'band_selection_num':30,
            'dataset_type':'IndianPines',
            'data_file_path':'/Users/romitbarua/Documents/Berkeley/Fall 2022/CS285-Deep Reinforcement Learning/Final Project/DRL4BS/data4classification/indian_pines_randomSampling_0.1_run_1.mat' 
            },
          'critic':{
            'gamma':0.99
            }
         }

In [8]:
agent = Agent(params)
agent.runAgent()

Iteration  0 :


TypeError: linear(): argument 'input' (position 1) must be Tensor, not numpy.ndarray