new pomcp from scratch!

In [1]:
# new pomcp from scratch

In [6]:
# imports
from random import randint
import numpy as np
import random
import datetime # for limiting calculation to wall clock time
import math
import copy
import matplotlib.pyplot as plt


In [7]:
# # game class
class TigerProblem():
    def __init__(self,obs_truth = 0.8):
#         self.state = ('tiger-left' if random.random() < 0.5 else 'tiger-right')
        self.obs_truth = obs_truth # how much you can trust observations
        self.all_states = set(['tiger-left','tiger-right'])
        
    # this contains the transition function for the MDP
    def next_state(self, state, action):
        
        assert action in ['listen','open-left','open-right'], "invalid action"
        
        if action == 'listen':
            return state
        elif action == 'open-left':
            return ('tiger-left' if random.random() < 0.5 else 'tiger-right')
        elif action == 'open-right':            
            return ('tiger-left' if random.random() < 0.5 else 'tiger-right')
       
    # observation function for the POMDP
    def observation(self, state, action):  
        
        assert action in ['listen','open-left','open-right'], "invalid action"
        
        all_s = self.all_states
        
        if action == 'listen':
            if random.random() < self.obs_truth:
                return state
            else:
                other = [st for st in self.all_states if st != state]
                return other[0]
        else:
            return []
        
    # Take a sequence of game states representing the full game tree, and return the full list
    # of actions that are legal actions
    def legal_actions(self,state_hist): 
        return ['listen','open-left','open-right']
    
    # Should this be the length of state_tree (how long tama alive for)? or is it trial by trial rwd?
    def reward(self, state, action):
        if action == 'listen':
            return -1
        elif action == 'open-left':
            return -20 if state == 'tiger-left' else 10
        elif action == 'open-right':            
            return -20 if state == 'tiger-right' else 10
            
    # GENERATOR MODEL OF GAME
    # returns next state, observation, and reward given an action taken in given state
    # takes tuple state
    def G_model(self,state,action):
        s = self.next_state(state,action)
        obs = self.observation(state,action)
        rwd = self.reward(state,action) # note that this should be more like immediate reward of state, not long-term?
        done = False
        return s, obs, rwd, done
    
    # Initial state distribution
    # Initial state distribution
    def sample_prior(self):
        s = ('tiger-left' if random.random() < 0.5 else 'tiger-right')
        return s
    
    # when filtering particles, this is the rule to keep one given a real observation
    def keep_particle(self, part, real_obs):
        trash_prob = 0.8
        if real_obs == []:
            return True
        if part != real_obs and random.random() < trash_prob:
            return False
        return True

    #     if real_obs == []:
    #         return True
    #     else:
    #         if part == real_obs:
    #             return True
    #     return False

    # generate a new particle from one randomly sampled from current belief (e.g., just add a lil noise)
    def new_particle(self, part):
        s = ('tiger-left' if random.random() < 0.5 else 'tiger-right')
        return s

In [24]:
class SearchTree(object):
    def __init__(self,visits=1,value=0):
        self.visits = visits
        self.value = value
        self.children=[]
        
        
class ActionNode(SearchTree):
    def __init__(self,action=None,visits=1,value=0):
        super().__init__(visits,value)
        self.action = action
        
        
class ObservationNode(SearchTree):
    def __init__(self,observation=[],visits=1,value=0,belief=[]):
        super().__init__(visits,value)
        self.observation = observation
        self.belief = belief
        
    def expand(self, legal_actions):
        for a in legal_actions:
            self.children += [ActionNode(a)]
            
        # upper confidence bound value for given node "child"
    def ucb(self, child): #maybe use index of child not object
#         print("self.visit=",self.visit," len of self.children=",len(self.children))
        logval = math.log(self.visits) #, len(self.children))
        div = logval / child.visits
        return math.sqrt(div)
    
#     def sample_belief(self):
#         return random.choice(self.belief)
    
    def next_hist(self,action,obs):
        act_child = next((c for c in self.children if c.action==action), None)
        assert act_child != None, "shouldn't you be expanded already?"
        assert isinstance(act_child, ActionNode), "action child should be an action node!"
        
        obs_child = next((c for c in act_child.children if c.observation == obs), None) 
        if obs_child is None:
            act_child.children += [ObservationNode(obs)]
#             print(act_child.children[0].observation)
            obs_child = next((c for c in act_child.children if c.observation == obs), None)     
        return obs_child
    
#     def next_hist_rollout(self,action,obs):
        

In [28]:
class POMCP:
    def __init__(self, 
                 game=TigerProblem(),
                 discount=0.8,
                 epsilon=1e-7,
                 explore=1,
                 n_particles=100,
                 reinvigoration=20, 
                 **kwargs):
        
#         self.context = {}
        self.game = game
        self.discount = discount
        self.epsilon = epsilon
        self.explore = explore
        self.n_particles = n_particles
        self.reinvigoration = reinvigoration
        self.G = game.G_model      
        self.tree = None
        self.history = []
        # list of all possible actions
#         self.actions = kwargs.get('actions') 
        
        seconds = kwargs.get('time',30)
        self.calculation_time = datetime.timedelta(seconds=seconds)
        self.maxdepth = kwargs.get('maxdepth',20)
        self.nsims = kwargs.get('nsims',1000)
        
    def search(self,obs):
        
        self.history += [obs]
        
        if self.tree is None:
            self.tree = ObservationNode(obs)                        
#             particle = self.game.sample_prior()
#             self.simulate(particle,self.tree,0)
        else:
            self.prune_tree(obs)
            
        for _ in range(self.nsims):
            particle = self.draw_sample()
            self.simulate(particle,self.tree,0)
        
        child = self.greedy_action_selection(self.tree,self.game.legal_actions(self.tree)) # will again need to handle legal actions differently for real
        self.tree = child # move forward to child action node (will move to obs node when real obs occurs)
        self.history += [child.action]
        
        return child.action
    
    def simulate(self,state,tree,depth):
        if depth >= self.maxdepth:
            return 0
        
#         legal = self.game.legal_actions(state,tree,depth)
        legal = self.game.legal_actions(tree) # would want it to be more elegant/complicated for real
    
        if len(tree.children) == 0:
            tree.expand(legal)
            return self.rollout(state,depth)
        
        if len(legal)==1:
            action = legal[0]
            child = tree.children[0]
        else:
            child = self.ucb_action_selection(tree,legal)
            action = child.action
            
        next_state, next_obs, r, done = self.G(state,action)
        next_tree = tree.next_hist(action,next_obs)
        reward = r + self.discount * self.simulate(next_state,next_tree,depth+1)
        
        tree.belief += [state] 
        tree.visits += 1
        
        child.visits += 1
        child.value += (reward - child.value)/child.visits
        
        return reward
    
    def rollout(self,state,depth):
        if depth >= self.maxdepth:
            return 0
        
        legal = self.game.legal_actions(["whatever but change this later"])
        a = random.choice(legal)
        
        next_state, next_obs, r, done = self.G(state,a)
#         next_tree = tree.next_hist(a,next_obs)
        
        if done:
            return r
        
        return r + self.discount * self.rollout(next_state,depth+1)
        
    def prune_tree(self,obs):
        #current tree is an action node. find child node with observation obs
        obs_child = next((c for c in self.tree.children if c.observation == obs), None) 
        self.tree = obs_child
        return
        
    def greedy_action_selection(self,tree,legal):
        children = [child for child in tree.children if child.action in legal] #filter(lambda child: child.action in legal_actions, tree.children)
        child_vals = np.array([child.value for child in children])
        favechildren = np.argwhere(child_vals == np.amax(child_vals))
        child = children[random.choice(favechildren.flatten().tolist())]
        return child
        
    def ucb_action_selection(self,tree,legal):
        children = [child for child in tree.children if child.action in legal] #filter(lambda child: child.action in legal_actions, tree.children)
        child_vals = np.array([child.value + self.explore * tree.ucb(child) for child in children])
        favechildren = np.argwhere(child_vals == np.amax(child_vals))
        child = children[random.choice(favechildren.flatten().tolist())]
        return child
    
    def draw_sample(self):
        if self.tree.belief == []:
            return self.game.sample_prior()
        else:
            return random.choice(self.tree.belief)

In [29]:
s = ('tiger-left' if random.random() < 0.5 else 'tiger-right')
game = TigerProblem()
agent = POMCP(game, 0.9, maxdepth=20, nsims=1000)

In [30]:
# Check that the tree runs simulations to choose the next action using the choose_move() method

print("INITIAL STATE: ")
print(s) # initial tiger problem state

obs = []

action = agent.search(obs)
print("Taking action:", action)
state = game.next_state(s,action)
game.reward(s,action)

INITIAL STATE: 
tiger-right
Taking action: listen


-1

In [44]:
# if next obs is tiger-left,
observeright = 1 #1 if observes right on next observation, 0 if observes left
nparticles = len(agent.tree.children[observeright].belief)
nleftbelief = len([b for b in agent.tree.children[observeright].belief if b=='tiger-left'])
nrightbelief = len([b for b in agent.tree.children[observeright].belief if b=='tiger-right'])
print(['left belief:',nleftbelief])
print(['right belief:',nrightbelief])
print(['left belief percentage:',nleftbelief/nparticles])

['left belief:', 8340]
['right belief:', 9065]
['left belief percentage:', 0.4791726515369147]


In [None]:
# so clearly i still need to limit how many particles are added to each observation node but also like, 
# this doesn't seem right, the beliefs should be different given the real observation that happens next
# ALSO, when there are many many many more possible actions i might not really need to limit it that much


In [49]:
# Now, let the POMCP do its thing for several actions in a row

# Initialize the tiger problem
s = ('tiger-left' if random.random() < 0.5 else 'tiger-right')
game = TigerProblem() #obs_truth=0.95)
    
print("INITIAL STATE: ")
print(s) # initial tiger state

agent = POMCP(game, 0.9, maxdepth=20, nsims=1000)

action_seq = []
state = s
obs = []
R = 0


while len(action_seq) <= 20: # play for a certain amount of time (better rule?)

    action = agent.search(obs)
    action_seq.append(action)
    print('Action %i: True state is %s'% (len(action_seq), state))
    print("Taking action %s."% action)
    
    obs = game.observation(state,action)
    if obs!=[]:
        print("observed ",obs)

    r = game.reward(state,action)
    R = R + r
    print("Reward so far: ",R)    
    
    state = game.next_state(state,action)
    

    
print("game over!")

INITIAL STATE: 
tiger-left
Action 1: True state is tiger-left
Taking action listen.
observed  tiger-left
Reward so far:  -1
Action 2: True state is tiger-left
Taking action open-right.
Reward so far:  9
Action 3: True state is tiger-left
Taking action listen.
observed  tiger-left
Reward so far:  8
Action 4: True state is tiger-left
Taking action listen.
observed  tiger-left
Reward so far:  7
Action 5: True state is tiger-left
Taking action listen.
observed  tiger-left
Reward so far:  6
Action 6: True state is tiger-left
Taking action open-right.
Reward so far:  16
Action 7: True state is tiger-left
Taking action listen.
observed  tiger-left
Reward so far:  15
Action 8: True state is tiger-left
Taking action open-right.
Reward so far:  25
Action 9: True state is tiger-right
Taking action listen.
observed  tiger-right
Reward so far:  24
Action 10: True state is tiger-right
Taking action open-left.
Reward so far:  34
Action 11: True state is tiger-left
Taking action open-right.
Reward so 