In [1]:
import mdp
import numpy as np
import matplotlib.pyplot as plt

### Approximate Q-learning with a linear Q-function

In [4]:
# defining a linear Q function class
class LinearQ:
    """Linear approximation of the action-value function: Q(s, a) = phi(s, a) . w.

    ``features`` supplies phi through ``extract(state, action)`` and its size
    through ``get_num_actions()`` and ``get_num_features()``. The weight vector
    holds one weight per (action, feature) pair and starts at zero.
    """

    def __init__(self, features):
        self.features = features
        # one weight per (action, feature) pair, initialised to zero
        num_weights = self.features.get_num_actions() * self.features.get_num_features()
        self.weights = np.zeros(num_weights)

    def _phi(self, state, action):
        """Return phi(state, action) as a 1-D float array shaped like ``self.weights``.

        Coercing here lets an extractor return a plain list: ``delta * some_list``
        would otherwise raise for a float ``delta`` or silently repeat the list for
        an int ``delta``.

        Raises:
            ValueError: if the extracted vector does not match the weight shape
                (this includes an extractor that returns ``None``).
        """
        phi = np.asarray(self.features.extract(state, action), dtype=float)
        if phi.shape != self.weights.shape:
            raise ValueError(
                f"feature vector has shape {phi.shape}, expected {self.weights.shape}"
            )
        return phi

    # update the weights
    def update(self, state, action, delta):
        """Move the weights along the feature vector: w <- w + delta * phi(state, action).

        ``delta`` is applied as given. Presumably the caller has already combined
        the learning rate and the TD error; confirm against the training loop.
        """
        self.weights += delta * self._phi(state, action)

    # evaluate q function
    def evaluate(self, state, action):
        """Return the scalar estimate Q(state, action) = phi(state, action) . w."""
        return np.dot(self._phi(state, action), self.weights)



    


# defining a feature extractor class for gridworld problem (hand-engineered features)
class GridWorldFeatures:
    """Hand-engineered state-action features for the grid-world problem.

    The vector returned by ``extract_features`` holds one block of
    ``num_features`` values per action in ``mdp.get_actions()``, in that order.
    Only the block of the queried action is filled; all other blocks are zero.
    """

    def __init__(self, mdp):
        # mdp must expose `goal` (an (x, y) pair), `exit` and `get_actions()`
        self.mdp = mdp
        self.num_features = 3

    def get_num_features(self):
        """Return the number of features per action block."""
        return self.num_features

    def get_num_actions(self):
        """Return the number of actions, i.e. the number of feature blocks."""
        return len(self.mdp.get_actions())

    def extract_features(self, state, action):
        """Return phi(state, action) as a 1-D float numpy array.

        The length is always num_actions * num_features. The block belonging to
        ``action`` contains three normalized values:
          1) x-coordinate relative to the goal:  (x + e) / (xg + e)
          2) y-coordinate relative to the goal:  (y + e) / (yg + e)
          3) Manhattan distance to the goal:     (|xg - x| + |yg - y| + e) / (xg + yg + e)
        Here (x, y) = state, (xg, yg) = mdp.goal, and e = 0.01 keeps the
        denominators non-zero when a goal coordinate is 0.
        In the exit state every entry is zero.
        """
        # NOTE(review): features 1-2 were originally described as "x/y-distance
        # from goal", but the code computes a position ratio rather than a
        # distance. Confirm which is intended.
        e = 0.01  # small additive value for avoiding division by zero
        actions = self.mdp.get_actions()
        feature_values = np.zeros(len(actions) * self.num_features)

        if state != self.mdp.exit:
            # unpack only here: the exit state need not be an (x, y) pair
            (xg, yg) = self.mdp.goal
            (x, y) = state
            active_block = [
                (x + e) / (xg + e),
                (y + e) / (yg + e),
                (abs(xg - x) + abs(yg - y) + e) / (xg + yg + e),
            ]
            for i, a in enumerate(actions):
                if a == action:
                    start = i * self.num_features
                    feature_values[start:start + self.num_features] = active_block

        return feature_values

    def extract(self, state, action):
        """Alias for ``extract_features``; ``LinearQ`` calls ``features.extract``."""
        return self.extract_features(state, action)



In [3]:
# Wire the pieces together: grid-world MDP -> feature extractor -> linear Q-function.
# withQTable=False presumably skips the tabular Q-table, since Q is approximated here.
gw = mdp.GridWorld(discount_factor=0.9, withQTable=False)
features = GridWorldFeatures(mdp=gw)
Qfunction = LinearQ(features=features)



In [6]:
# Scratch check: max() over dict items, keyed on the value, returns the
# (key, value) pair with the largest value (first one wins on ties).
d = dict(a=8, b=4, c=15)
m = max(d.items(), key=lambda item: item[1])
print(m)

('c', 15)
