In [None]:
import random
import time
import copy
from collections import deque

import numpy as np

import torch as tc
import torch.optim as opt

from model import SnakeNet
from core import CUDA_AVAILABLE, DEVICE
from config import NROW,NCOL,EPISODE_MAXLEN
from per import PER
from env import Env

In [None]:
# Training hyper-parameters

# --- optimisation ---
BATCH_SIZE = 16                      # transitions requested from the replay memory per train() call
LEARNING_RATE = 1e-4                 # step size handed to the Adam optimiser
DISC_RATIO = 0.9                     # discount applied to the bootstrapped next-state value

# --- exploration (epsilon-greedy) ---
EPS_MAX = 1.0                        # initial / upper bound of epsilon
EPS_MIN = 0.01                       # lower bound of epsilon
EPS_EXPONENT = 3.0                   # exponent of the per-episode epsilon decay curve

# --- schedule ---
EPISODE_CNT = 20000                  # number of episodes run by the training loop
LEARN_FREQ = 1                       # train() is triggered on steps divisible by this
TARGET_UPD_FREQ = LEARN_FREQ * 1000  # interval, in steps, between target-network syncs
REPLAY_MEM_SIZE = 500000             # capacity passed to the prioritized replay memory

# --- housekeeping ---
STAT_DISPLAY_FREQ = 10               # print a progress line every this many episodes
SAVE_TEMP_FREQ = 100                 # checkpoint the weights every this many episodes

In [None]:
import time as timelib
from PIL import Image
from PIL import ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt
from IPython import display
def render(state,sleep_time=0,clear=True):
    """Draw the board for `state` with matplotlib and return the frame as a tensor.

    Args:
        state: mapping with 'foods' and 'snake' entries, each an iterable of
            (row, col) cells. The last snake entry is drawn as a diamond
            (the head); every other snake cell is a filled square.
        sleep_time: when truthy, number of seconds to sleep after showing the frame.
        clear: when truthy, clear the notebook cell output afterwards (wait=True).

    Returns:
        The output of torchvision's ToTensor applied to the down-scaled RGB frame.
    """
    cell = 30  # pixel size of one board cell before down-scaling
    board = Image.new("RGBA", (NCOL*cell, NROW*cell), (10,50,100,100))
    pen = ImageDraw.Draw(board)

    # Grid: one black-outlined square per board cell.
    for row in range(NROW):
        top, bottom = row*cell, (row+1)*cell
        for col in range(NCOL):
            left, right = col*cell, (col+1)*cell
            pen.rectangle(((left, top), (right, bottom)), outline='black', width=1)

    # Food: yellow discs filling their cell.
    for row, col in state['foods']:
        bounds = (col*cell, row*cell, (col+1)*cell, (row+1)*cell)
        pen.ellipse(bounds, fill = 'yellow', outline ='yellow')

    # Snake: red squares for the body, a red diamond for the head.
    for row, col in state['snake']:
        if (row, col) != state['snake'][-1]:
            pen.rectangle(((col*cell, row*cell), ((col+1)*cell, (row+1)*cell)), fill='red', outline='red')
            continue
        diamond = [
            ((col+1/2)*cell, row*cell),
            (col*cell, (row+1/2)*cell),
            ((col+1/2)*cell, (row+1)*cell),
            ((col+1)*cell, (row+1/2)*cell),
        ]
        pen.polygon(diamond, fill = 'red')

    # Down-scale to a 64-pixel-wide RGB frame, keeping the board's aspect ratio.
    frame = board.resize((64, 64*NROW//NCOL)).convert('RGB')
    pixels = np.array(frame)
    plt.axis("off")
    tensor_image = transforms.ToTensor()(pixels)
    plt.imshow(transforms.ToPILImage()(tensor_image))
    plt.show()
    plt.close()

    if sleep_time:
        timelib.sleep(sleep_time)
    if clear:
        display.clear_output(wait=True)
    return tensor_image

In [None]:
def stateTransform(state,flipy,flipx,deltay,deltax):
    """Return a mirrored and/or shifted copy of a game state.

    Every (row, col) cell in state['snake'] and state['foods'] is optionally
    mirrored along its axis and then shifted, wrapping around the board edges
    (modulo NROW / NCOL).

    Args:
        state: mapping with at least 'snake' and 'foods', each an iterable of
            (row, col) pairs. It is NOT modified.
        flipy: truthy to mirror row indices (row -> NROW-1-row).
        flipx: truthy to mirror column indices (col -> NCOL-1-col).
        deltay: row offset added after the optional mirror.
        deltax: column offset added after the optional mirror.

    Returns:
        A shallow copy of `state` whose 'snake' and 'foods' entries are new
        lists of transformed (row, col) tuples; all other entries are shared
        with the input.
    """
    def fy(y): return (NROW-1-y+deltay if flipy else y+deltay)%NROW
    def fx(x): return (NCOL-1-x+deltax if flipx else x+deltax)%NCOL
    # Shallow copy: the two coordinate lists are rebuilt below, so the caller's
    # state (which may be stored in the replay memory) is never mutated.
    ret=copy.copy(state)
    ret['snake']=[(fy(y),fx(x)) for y,x in state['snake']]
    ret['foods']=[(fy(y),fx(x)) for y,x in state['foods']]
    return ret
def actTransform(act,flipy,flipx):
    """Map an action index to its counterpart on a mirrored board.

    Odd actions (1, 3) are reversed when `flipx` is truthy and even actions
    (0, 2) when `flipy` is truthy; reversing means (act + 2) % 4. Otherwise
    the action is returned unchanged.
    """
    mirrored = flipx if act % 2 else flipy
    if mirrored:
        return (act + 2) % 4
    return act
# Every (flipy, flipx, deltay, deltax) combination: both mirror flags on/off
# combined with every row shift in [0, NROW) and column shift in [0, NCOL).
transformOptions = [
    (flip_y, flip_x, shift_y, shift_x)
    for flip_y in range(2)
    for flip_x in range(2)
    for shift_y in range(NROW)
    for shift_x in range(NCOL)
]

def state2input(state):
    """Encode a game state as four NROW x NCOL planes (nested Python lists).

    Planes, in order:
        0. snake shape - each snake cell holds (position in the snake list + 1)
           divided by the snake length, so the last element (the head) is 1.0
        1. head  - 1.0 at the last snake cell
        2. tail  - 1.0 at the first snake cell
        3. foods - 1.0 at every food cell
    All other cells are 0. (A time-remaining plane based on state['time'] and
    EPISODE_MAXLEN is not part of the encoding.)
    """
    def empty_plane():
        return [[0]*NCOL for _ in range(NROW)]

    body = state['snake']
    length = len(body)
    shape_plane, head_plane, tail_plane, food_plane = (empty_plane() for _ in range(4))

    for rank, (row, col) in enumerate(body, start=1):
        shape_plane[row][col] = rank/length

    head_row, head_col = body[-1]
    head_plane[head_row][head_col] = 1.

    tail_row, tail_col = body[0]
    tail_plane[tail_row][tail_col] = 1.

    for row, col in state['foods']:
        food_plane[row][col] = 1.

    return [shape_plane, head_plane, tail_plane, food_plane]

print("CUDA: ",CUDA_AVAILABLE)

# Online Q-network, moved to the GPU when CUDA is available.
net = SnakeNet()
if CUDA_AVAILABLE:
    net = net.cuda()

# Target network: starts as an exact copy of the online network and stays in
# eval mode; the online network is the one being optimised.
net_target = copy.deepcopy(net)
net_target.load_state_dict(net.state_dict())
net_target.eval()
net.train()

# Adam over the online network's parameters only.
opter = opt.Adam(net.parameters(), lr=LEARNING_RATE)
# Prioritized replay memory.
# NOTE(review): alpha/beta are presumably the prioritisation and
# importance-sampling exponents - confirm against per.py.
perm = PER(REPLAY_MEM_SIZE, alpha=0.8, beta=0.3)

def train():
    """Run one optimisation step of the online network on a replay batch.

    Samples up to BATCH_SIZE entries from the prioritized replay memory `perm`,
    builds a double-DQN target (next action chosen by `net`, its value read
    from `net_target`) and minimises the importance-weighted squared TD error
    with `opter`.

    Relies on module-level state: `perm`, `net`, `net_target`, `opter` and the
    episode counter `epi` set by the training loop, so it must not be called
    before that loop has started.

    Returns:
        float: the scalar loss of this step.
    """
    # Each sampled entry unpacks to (index, importance weight, transition).
    # The second argument is the fraction of training completed.
    # NOTE(review): presumably used by PER to anneal beta - confirm in per.py.
    idxs,isws,bat=zip(*perm.sample(BATCH_SIZE,epi/EPISODE_CNT))
    # NOTE(review): idxs is unused, so the priorities of the sampled
    # transitions are never refreshed after this update - confirm whether PER
    # expects a priority update here.

    s1bat=tc.tensor([state2input(s1) for (s1,a,r,s2) in bat]).to(DEVICE)
    abat=tc.tensor([a for (s1,a,r,s2) in bat]).to(DEVICE)
    rbat=tc.tensor([r for (s1,a,r,s2) in bat]).to(DEVICE)
    s2bat=tc.tensor([state2input(s2) for (s1,a,r,s2) in bat]).to(DEVICE)
    # 1 for terminal next-states so their bootstrap term is masked out below.
    dbat=tc.tensor([int(s2['done']) for (s1,a,r,s2) in bat]).to(DEVICE)

    q1=net(s1bat)
    # Q-value of the action actually taken in each transition.
    x=q1.gather(1,abat.unsqueeze(dim=1)).squeeze(1)

    with tc.no_grad():
        q2t=net_target(s2bat)
        q2=net(s2bat)
        # Double DQN: the online net picks the action, the target net scores it.
        # gather works for any batch length, unlike a hand-built flat index
        # over range(BATCH_SIZE).
        best_next=tc.argmax(q2,dim=1,keepdim=True)
        next_val=q2t.gather(1,best_next).squeeze(1)
        y=rbat+DISC_RATIO*((1-dbat)*next_val)

    loss=(tc.tensor(isws).to(DEVICE)*(x-y)**2).mean()
    opter.zero_grad()
    loss.backward()
    opter.step()
    return loss.item()

# ---- Training loop ---------------------------------------------------------
# Epsilon-greedy rollouts; every transition is pushed to the prioritized replay
# memory with its current TD error, and train() is called from inside the
# episode. NOTE: train() reads the module-level loop variable `epi`.
eps=EPS_MAX
env = Env()
total_steps=0            # environment steps across all episodes (drives target sync)
loss_sum,loss_cnt=0.,0   # losses reported by train() since the last stats line
rwd_sum,rwd_cnt=0.,0     # per-episode reward totals since the last stats line
for epi in range(1,EPISODE_CNT+1):
    s1=env.reset()
    epi_rwd=0.
    while not s1['done']:
        #render(s1,0.1,True)
        in1=tc.tensor(state2input(s1)).to(DEVICE)
        # Acting and priority estimation need no gradients; without no_grad the
        # TD error pushed below would keep an autograd graph alive per transition.
        with tc.no_grad():
            q1=net(in1)
        if random.random()<eps:
            # Explore uniformly over ALL actions the network outputs
            # (np.random.randint's upper bound is exclusive).
            actidx=int(np.random.randint(0,q1.shape[-1]))
        else:
            actidx=int(np.argmax(q1.cpu().numpy()))
        s2,rwd=env.step(actidx)
        epi_rwd+=float(rwd)

        # Initial priority: double-DQN TD error of this transition.
        in2=tc.tensor(state2input(s2)).to(DEVICE)
        with tc.no_grad():
            q2t=net_target(in2)
            q2=net(in2)
            val2=rwd+DISC_RATIO*((1-s2['done'])*q2t[0][tc.argmax(q2,dim=1)])
            td=val2-q1[0][actidx]
        perm.push(td,(s1,actidx,rwd,s2))
        s1=s2
        total_steps+=1

        if s1['time']%LEARN_FREQ==0 and perm.cnt>=BATCH_SIZE:
            step_loss=train()
            # train() may or may not report its loss; only average real values.
            if step_loss is not None:
                loss_sum+=float(step_loss)
                loss_cnt+=1
        # Sync on a global step counter. NOTE(review): state['time'] appears to
        # be per-episode, so keying the sync on it would rarely (or never)
        # reach TARGET_UPD_FREQ - confirm against env.py.
        if total_steps%TARGET_UPD_FREQ==0:
            net_target.load_state_dict(net.state_dict())
            net_target.eval()
    rwd_sum+=epi_rwd
    rwd_cnt+=1
    # Polynomial decay of epsilon with episode progress, clamped to [EPS_MIN, EPS_MAX].
    eps=min(EPS_MAX,max(EPS_MIN, ((EPISODE_CNT-epi)/EPISODE_CNT)**EPS_EXPONENT ))

    if epi%SAVE_TEMP_FREQ==0:
        # Periodic checkpoint; also wipe the accumulated cell output.
        tc.save(net.state_dict(),'./netw.pt')
        from IPython.display import clear_output
        clear_output(wait=True)
    if epi%STAT_DISPLAY_FREQ==0:
        print("{}/{}({:.2f}%): LossAvg={:.4f} RwdAvg={:.4f}".format(
            epi,
            EPISODE_CNT,
            epi/EPISODE_CNT*100,
            loss_sum/loss_cnt if loss_cnt else 0.,
            rwd_sum/rwd_cnt if rwd_cnt else 0.))
        loss_sum,loss_cnt=0.,0
        rwd_sum,rwd_cnt=0.,0
# Final save of the trained weights.
tc.save(net.state_dict(),'./netw.pt')
print("DONE!!!")