## Setup

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [12]:
import copy
from datetime import datetime
from FireSimulator import *
from FireSimulatorUtilities import *
import glob
import itertools
import matplotlib.patches as patches
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import sys
import time

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0)
# plt.rcParams['image.interpolation'] = 'nearest'
# plt.rcParams['image.cmap'] = 'gray'

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Heuristic Solution

In [3]:
def heuristic(agent_id, pos, img_st, seen_fire, center, close_agent_id, close_pos):

    traj = []
    actions = []
    traj.append((pos[0],pos[1]))
    img_dim = img_st.shape[0]
    fire_neigh = [(-1,0),(0,-1),(1,0),(0,1)]
    move_neigh = [(-1,0),(1,0),(-1,1),(0,1),(1,1),(-1,-1),(0,-1),(1,-1)] #excluded (0,0)
    action_set = [4,1,2,3,5,8,7,6]
    
    dists = None
    x,y = pos

    r = img_dim//2
    c = img_dim//2

    if img_st[r,c] in [1,2] or seen_fire:
        seen_fire = True
        dists = []
        
        cen_vec = np.array([x-center,y-center])
        cen_vec = cen_vec/np.linalg.norm(cen_vec)
        for a in range(1,9):
        #for a in [2,5,7,4]:
            new_pos = actions_to_trajectory(traj[-1],[a])[1]

            rl = -new_pos[1] + y + img_dim//2
            cl = new_pos[0] -x + img_dim//2
            #if a in [1,3,8,6] and img_st[rl,cl] in [0]:
            #    continue

            move_vec = np.array([new_pos[0]-x,new_pos[1]-y])
            if a != 0:
                move_vec = move_vec/np.linalg.norm(move_vec)
            dists.append((np.cross(cen_vec,move_vec),new_pos,a))

        cir_pos = min(dists)[1]
        cir_act = min(dists)[2]

        ri = -cir_pos[1] + y + img_dim//2
        ci = cir_pos[0] -x + img_dim//2

        left_act = None
        if cir_act==1:
            left_act = [6,4]
            #left_act = [4]
            righ_act = [2]
        elif cir_act==2:
            left_act = [4,1]
            #left_act = [1]
            righ_act = [3]
        elif cir_act==3:
            left_act = [1,2]
            #left_act = [2]
            righ_act = [5]
        elif cir_act==5:
            left_act = [2,3]
            #left_act = [3]
            righ_act = [8]
        elif cir_act==8:
            left_act = [3,5]
            #left_act = [5]
            righ_act = [7]
        elif cir_act==7:
            left_act = [5,8]
            #left_act = [8]
            righ_act = [6]
        elif cir_act==6:
            left_act = [8,7]
            #left_act = [7]
            righ_act = [4]
        elif cir_act==4:
            left_act = [7,6]
            #left_act = [6]
            righ_act = [1]

        out = False
        for a in left_act:
            new_pos = actions_to_trajectory(traj[-1],[a])[1]
            ro = -new_pos[1] + y + img_dim//2
            co = new_pos[0] - x + img_dim//2
            if img_st[ro,co] in [1]:
                cir_pos = new_pos
                cir_act = a
                out = True
                break

        if not out:
            for a in left_act:
                new_pos = actions_to_trajectory(traj[-1],[a])[1]
                ro = -new_pos[1] + y + img_dim//2
                co = new_pos[0] - x + img_dim//2
                if img_st[ro,co] in [2]:
                    cir_pos = new_pos
                    cir_act = a
                    out = True
                    break

        counter = 0
        for (dr,dc) in move_neigh:
            rn = ri + dr
            cn = ci + dc
            if rn>=0 and rn<img_dim and cn>=0 and cn<img_dim and img_st[rn,cn] in [0]:
                counter += 1

        if not out and img_st[ri,ci] in [0] and counter>=6:
            for a in righ_act:
                new_pos = actions_to_trajectory(traj[-1],[a])[1]
                cir_pos = new_pos
                cir_act = a

        if np.linalg.norm(cir_pos-close_pos,2)<=1 and agent_id > close_agent_id:
            cir_pos = traj[-1]
            cir_act = 0

        traj.append(cir_pos)
        actions.append(cir_act)      

    if not seen_fire:
        dists = []
        #for a in range(9):
        for idx,a in enumerate([2,5,7,4,1,3,8,6]):
            new_pos = actions_to_trajectory(traj[-1],[a])[1]
            incntv = -(8-idx)*0.1
            dists.append((np.abs(center-new_pos[0])+np.abs(center-new_pos[1])+incntv,new_pos,a))

        #print(dists)
        #print()
        score, pos, act = min(dists)
        traj.append(pos)
        actions.append(act)
        
    return traj, actions, seen_fire

## DQN architecture

In [4]:
torch.cuda.is_available()

True

In [5]:
dtype = torch.cuda.FloatTensor

In [6]:
class eelfff(nn.Module):
    
    def __init__(self, img_dim=8):
        super(eelfff, self).__init__()
        self.img_dim = img_dim
        
        # inputs: image + rot vec + id compare + pos-other_pos
        self.net = nn.Sequential(
                                nn.Linear(self.img_dim**2 + 2 + 1 + 2, 2048),
                                nn.ReLU(inplace=True),
                                nn.Linear(2048, 2048),
                                nn.ReLU(inplace=True),
                                nn.Linear(2048, 9)
                            )

    def forward(self, feat):
        return self.net(feat)


test network with random data

In [7]:
tic = time.clock()
N = 4
img_dim = 3

model = eelfff(img_dim).type(dtype)
feat = Variable(torch.randn(N,img_dim**2+2+1+2)).type(dtype)
Q = model(feat)
toc = time.clock()

print(Q.size())
print("%0.2fs = %0.2fm elapsed for this test" %(toc-tic,(toc-tic)/60))

torch.Size([4, 9])
342.12s = 5.70m elapsed for this test


load network from file

In [8]:
filename = 'simple_ext-26-Aug-2017-15:17.pth.tar'

checkpoint = torch.load(filename)
model.load_state_dict(checkpoint['state_dict'])

## Code to simulate with MADQN agent actions

In [57]:
seed = 100
plotting = True
# 50
# s = 100, a = 20

grid_size = 50
num_agents = 12
capacity = 10

newalpha = None
newbeta = None
# newalpha = 0.3
# newbeta = 0.92

base_station = np.array([5,5])
repeat_lim = 6
dp = 0.15/0.2763

center = (grid_size+1)/2
spawn_loc = np.arange(grid_size//3//2,grid_size,grid_size//3)
perturbs = np.arange(-grid_size//3//2+1,grid_size//3//2+1,1)

In [58]:
st = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
print('[%s] start' % st)

tic = time.clock()

s = seed
np.random.seed(1000+s)

# initialize simulator
sim = FireSimulator(grid_size, rng=s)
if newalpha is not None:
    sim.alpha = newalpha
if newbeta is not None:
    sim.beta = newbeta
sim.step([])
    
n = num_agents
agent_pos = np.random.choice(spawn_loc, (n,2)) + np.random.choice(perturbs, (n,2))
agent_pos = agent_pos.astype(np.int32)

control = []
repeat_ctr = 1
agent_data = {}
for k in range(n):
    agent_data[k] = {}
    agent_data[k]['sf'] = False
    if capacity is not None:
        agent_data[k]['cap'] = capacity
        
# create plotting objects, plot initial condition
if plotting:
    plt_ctr = 1
    fig = plt.figure()
    ax = fig.add_subplot(111, aspect='equal')
    plt.tick_params(axis='both', which='both',
                    labelbottom='off', labelleft='off',
                    bottom='off', left='off')
    ax.set_xlim([0,grid_size+1])
    ax.set_ylim([0,grid_size+1])

    base_rect = plt.Rectangle((base_station[0]-0.5,base_station[0]-0.5),1,1,edgecolor='none',facecolor='C9')

    plot_rects = {}
    for i in range(grid_size):
        for j in range(grid_size):
            x = col_to_x(j)
            y = row_to_y(grid_size,i)

            plot_rects[(i,j)] = plt.Rectangle((x-0.5,y-0.5),1,1,alpha=0.6,edgecolor='none')
            if sim.state[i,j] == 0:
                color = 'g'
            elif sim.state[i,j] == 1:
                color = 'r'
            elif sim.state[i,j] == 2:
                color = 'k'

            plot_rects[(i,j)].set_facecolor(color)
            ax.add_patch(plot_rects[(i,j)])

    ax.add_patch(base_rect)
    for k in range(num_agents):
        plt.plot(agent_pos[k,0],agent_pos[k,1],"bo",markersize=4)

    plt.savefig('img/sim%0.4d.png' %(plt_ctr), dpi=300, bbox_inches='tight')
    plt_ctr += 1
    plt.close()

# run simulation        
while not sim.end:
    new_agent_pos = np.zeros((n,2)).astype(np.int32)
    
    # calculate action for each agent
    for k in range(num_agents):
        img, img_st, _ = CreateImageBW(sim.state, agent_pos[k,:])
        if img_st[8//2,8//2] in [1,2]:
            agent_data[k]['sf'] = True                

        dists = [(np.linalg.norm(agent_pos[k,:]-p,2),j,p) for j,p in enumerate(agent_pos) if j!=k]
        min_dist, min_id, min_pos = min(dists)                

        if not agent_data[k]['sf']:

            dists = []
            for idx,a in enumerate([2,5,7,4,1,3,8,6]):
                new_pos = actions_to_trajectory(agent_pos[k,:],[a])[1]
                incntv = -(8-idx)*0.1
                dists.append((np.abs(center-new_pos[0])+np.abs(center-new_pos[1])+incntv,new_pos,a))

            score, pos, action = min(dists)
            traj = actions_to_trajectory(agent_pos[k,:], [action])

        else:
            rot_vec = agent_pos[k,:] - center
            rot_vec = rot_vec / np.linalg.norm(rot_vec,2)
            rot_vec = np.array([rot_vec[1],-rot_vec[0]])

            pos_vec = agent_pos[k,:] - min_pos
            if pos_vec[0]!=0 and pos_vec[1]!= 0:
                pos_vec = pos_vec / np.linalg.norm(pos_vec,2)

            state = np.concatenate((img[3:6,3:6].reshape((img_dim**2,)), rot_vec, 
                                    np.asarray(k>min_id)[np.newaxis], pos_vec))

            state = Variable(torch.from_numpy(state)).type(dtype)
            Q = model(state.unsqueeze(0))[0].data.cpu().numpy()
            action = np.argmax(Q)
            traj = actions_to_trajectory(agent_pos[k,:], [action])

        # generate control from trajectory
        # account for capacity constraint
        agent_control = FindGridIntersections(sim.state, traj)
        for el in agent_control:
            control.extend([el]) 
            if capacity is not None:
                agent_data[k]['cap'] -= 1
                if agent_data[k]['cap'] <= 0:
                    break

        control = list(set(control))

        # update agent location
        if capacity is not None and agent_data[k]['cap'] <= 0:
            agent_data[k]['sf'] = False
            agent_data[k]['cap'] = capacity
            new_agent_pos[k,:] = base_station
        else:
            new_agent_pos[k,:] = [traj[-1][0], traj[-1][1]]

    # update simulator periodically 
    if repeat_ctr % repeat_lim == 0:
        sim.step(control, dbeta=dp)
        control = []
    repeat_ctr += 1

    # update agent position
    agent_pos = new_agent_pos
    new_agent_pos = np.zeros((n,2)).astype(np.int32)
    
    # plot new information
    if plotting:
        fig = plt.figure()
        ax = fig.add_subplot(111, aspect='equal')
        plt.tick_params(axis='both', which='both',labelbottom='off', labelleft='off',bottom='off', left='off')
        ax.set_xlim([0,grid_size+1])
        ax.set_ylim([0,grid_size+1])

        for i in range(grid_size):
            for j in range(grid_size):
                x = col_to_x(j)
                y = row_to_y(grid_size,i)

                if sim.state[i,j] == 0:
                    color = 'g'
                elif sim.state[i,j] == 1:
                    color = 'r'
                elif sim.state[i,j] == 2:
                    color = 'k'

                plot_rects[(i,j)].set_facecolor(color)
                patch = copy.copy(plot_rects[(i,j)])
                patch.axes = None
                patch.figure = None
                patch.set_transform(ax.transData)
                ax.add_patch(patch)

        base_patch = copy.copy(base_rect)
        base_patch.axes = None
        base_patch.figure = None
        base_patch.set_transform(ax.transData)
        ax.add_patch(base_patch)
        for k in range(num_agents):
            plt.plot(agent_pos[k,0],agent_pos[k,1],"bo",markersize=4)

        plt.savefig('img/sim%04d.png' %(plt_ctr), dpi=300, bbox_inches='tight')
        plt_ctr += 1
        plt.close()

datetime.today().strftime('%Y-%m-%d %H:%M:%S')
print('[%s] finish' % st)

toc = time.clock()
print("%0.2fs = %0.2fm elapsed" % (toc-tic,(toc-tic)/60))

print('number of model updates: %d' %sim.iter)
print('remaining healthy trees: %0.4f' %(sim.stats[0]/sum(sim.stats)))

[2018-02-22 19:58:10] start
[2018-02-22 19:58:10] finish
1452.39s = 24.21m elapsed
number of model updates: 42
remaining healthy trees: 0.9824


In [59]:
for _ in range(repeat_lim):
    # plot new information
    fig = plt.figure()
    ax = fig.add_subplot(111, aspect='equal')
    plt.tick_params(axis='both', which='both',labelbottom='off', labelleft='off', bottom='off', left='off')
    ax.set_xlim([0,grid_size+1])
    ax.set_ylim([0,grid_size+1])

    for i in range(grid_size):
        for j in range(grid_size):
            x = col_to_x(j)
            y = row_to_y(grid_size,i)

            if sim.state[i,j] == 0:
                color = 'g'
            elif sim.state[i,j] == 1:
                color = 'r'
            elif sim.state[i,j] == 2:
                color = 'k'

            plot_rects[(i,j)].set_facecolor(color)
            patch = copy.copy(plot_rects[(i,j)])
            patch.axes = None
            patch.figure = None
            patch.set_transform(ax.transData)
            ax.add_patch(patch)
            
    for k in range(num_agents):
        plt.plot(agent_pos[k,0],agent_pos[k,1],"bo",markersize=4)

    plt.savefig('img/sim%04d.png' %(plt_ctr), dpi=300, bbox_inches='tight')
    plt_ctr += 1
    plt.close()