In [None]:
import random
import time
import copy
from collections import deque

import numpy as np

import torch as tc
import torch.optim as opt

from model import SnakeNet
from core import CUDA_AVAILABLE, DEVICE
from config import NROW,NCOL
from per import PER
from env import Env
from util import state2input

In [None]:
# Training hyperparameters (grouped by what they control).

# --- Replay memory / optimisation ---
BATCH_SIZE = 16                # transitions requested from the replay memory per train() call
REPLAY_MEM_SIZE = 1_000_000    # capacity handed to the PER replay memory
LEARNING_RATE = 2e-4           # Adam learning rate
L2_DECAY = 1e-6                # Adam weight_decay
DISC_RATIO = 0.99              # discount applied to the bootstrapped next-state value

# --- Epsilon-greedy exploration ---
# After each episode: eps = ((EPISODE_CNT-epi)/EPISODE_CNT)**EPS_EXPONENT,
# clamped into [EPS_MIN, EPS_MAX].
EPS_MAX = 1.0
EPS_MIN = 0.01
EPS_EXPONENT = 2.7

# --- Schedule ---
EPISODE_CNT = 10_000                  # number of episodes played by the training loop
LEARN_FREQ = 4                        # train() is attempted when state['time'] % LEARN_FREQ == 0
TARGET_UPD_FREQ = LEARN_FREQ * 500    # target net is synced when state['time'] % TARGET_UPD_FREQ == 0
STAT_DISPLAY_FREQ = 10                # print running loss / eps every this many episodes
SAVE_TEMP_FREQ = 100                  # checkpoint the weights every this many episodes

In [None]:
import time as timelib
from PIL import Image
from PIL import ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt
from IPython import display
def render(state,sleep_time=0,clear=True):
    """Draw the board described by ``state`` and show it with matplotlib.

    Args:
        state: mapping with a ``'foods'`` entry (iterable of ``(y, x)`` cells)
            and a ``'snake'`` entry (sequence of ``(y, x)`` cells). The last
            snake cell is drawn as a diamond, the others as filled squares
            (presumably the last cell is the head -- confirm against Env).
        sleep_time: seconds to pause after showing the frame; falsy skips it.
        clear: when true, clear the cell output (deferred until the next
            output arrives) so successive frames replace each other.

    Returns:
        The tensor produced by ``transforms.ToTensor()`` from the down-scaled
        RGB frame (64 pixels wide, ``64*NROW//NCOL`` pixels high).
    """
    cell = 30  # side length of one board cell in pixels, before down-scaling
    canvas = Image.new("RGBA", (NCOL*cell, NROW*cell), (10,50,100,100))
    pen = ImageDraw.Draw(canvas)

    # Grid lines: one outlined square per board cell.
    for row in range(NROW):
        top, bottom = row*cell, (row+1)*cell
        for col in range(NCOL):
            left, right = col*cell, (col+1)*cell
            pen.rectangle(((left, top), (right, bottom)), outline='black', width=1)

    # Food: yellow discs filling their cell.
    for row, col in state['foods']:
        pen.ellipse(
            (col*cell, row*cell, (col+1)*cell, (row+1)*cell),
            fill='yellow', outline='yellow',
        )

    # Snake: any cell equal to the last entry becomes a diamond, the rest squares.
    last_cell = state['snake'][-1]
    for row, col in state['snake']:
        if (row, col) != last_cell:
            pen.rectangle(
                ((col*cell, row*cell), ((col+1)*cell, (row+1)*cell)),
                fill='red', outline='red',
            )
            continue
        diamond = [
            ((col+1/2)*cell, row*cell),
            (col*cell, (row+1/2)*cell),
            ((col+1/2)*cell, (row+1)*cell),
            ((col+1)*cell, (row+1/2)*cell),
        ]
        pen.polygon(diamond, fill='red')

    # Down-scale, drop alpha, and convert to a tensor for the caller.
    frame = canvas.resize((64, 64*NROW//NCOL)).convert('RGB')
    tensor_image = transforms.ToTensor()(np.array(frame))

    # Display the tensor (converted back to a PIL image) without axes.
    plt.axis("off")
    plt.imshow(transforms.ToPILImage()(tensor_image))
    plt.show()
    plt.close()

    if sleep_time:
        timelib.sleep(sleep_time)
    if clear:
        display.clear_output(wait=True)
    return tensor_image

In [4]:
# Report which device path is taken.
print("CUDA: ",CUDA_AVAILABLE)

# Online network: moved to the GPU only when CUDA is available.
net=SnakeNet()
if CUDA_AVAILABLE:
    net=net.cuda()

# Target network: starts as an exact copy of the online network and is only
# ever used for evaluation.
net_target=copy.deepcopy(net)
net_target.load_state_dict(net.state_dict())
net.train()
net_target.eval()

# Optimiser, prioritized replay memory, and the running list of batch losses
# (filled by train(), averaged and reset by the training loop).
opter=opt.Adam(params=net.parameters(),lr=LEARNING_RATE,weight_decay=L2_DECAY)
perm=PER(REPLAY_MEM_SIZE,alpha=0.75,beta=0.4)
losses=list()

def train():
    """Run one Double-DQN optimisation step on a prioritized minibatch.

    Samples ``BATCH_SIZE`` transitions from the global replay memory ``perm``,
    builds the target ``r + DISC_RATIO * (1-done) * Q_target(s2, argmax_a Q(s2, a))``,
    minimises the importance-weighted squared error against ``Q(s1, a)`` with
    the global optimiser ``opter``, and appends the scalar loss to the global
    ``losses`` list.

    Relies on module-level state: ``perm``, ``net``, ``net_target``, ``opter``,
    ``losses`` and the episode counter ``epi`` set by the training loop
    (``epi/EPISODE_CNT`` is passed as the second argument of ``perm.sample``).
    """
    # NOTE(review): idxs is never used, so the sampled transitions keep the
    # priority they were pushed with. PER usually refreshes priorities after a
    # learning step; the PER class' update API is not visible here -- confirm.
    idxs,isws,bat=zip(*perm.sample(BATCH_SIZE,epi/EPISODE_CNT))

    s1bat=tc.tensor([state2input(s1) for (s1,a,r,s2) in bat]).to(DEVICE)
    abat=tc.tensor([a for (s1,a,r,s2) in bat]).to(DEVICE)
    rbat=tc.tensor([r for (s1,a,r,s2) in bat]).to(DEVICE)
    s2bat=tc.tensor([state2input(s2) for (s1,a,r,s2) in bat]).to(DEVICE)
    dbat=tc.tensor([int(s2['done']) for (s1,a,r,s2) in bat]).to(DEVICE)

    # Q(s1, a) for the actions that were actually taken. squeeze(1) (rather
    # than a bare squeeze()) keeps the batch dimension even for a single row.
    q1=net(s1bat)
    x=q1.gather(1,abat.unsqueeze(dim=1)).squeeze(1)

    with tc.no_grad():
        q2t=net_target(s2bat)
        q2=net(s2bat)
        # Double DQN: the online net picks the next action, the target net
        # scores it. gather() selects one column per row directly, so no
        # flattened indices and no assumption about the batch length.
        best_next=tc.argmax(q2,dim=1,keepdim=True)
        q2t_best=q2t.gather(1,best_next).squeeze(1)
        y=rbat+DISC_RATIO*((1-dbat)*q2t_best)

    # Importance-sampling weights from the replay memory scale each sample's
    # squared TD error before averaging.
    loss=(tc.tensor(isws).to(DEVICE)*(x-y)**2).mean()
    opter.zero_grad()
    loss.backward()
    opter.step()

    losses.append(float(loss))

# Main training loop: play EPISODE_CNT episodes with an epsilon-greedy policy,
# push every transition into the prioritized replay memory, and periodically
# learn, sync the target network, checkpoint and print statistics.
from IPython.display import clear_output

eps=EPS_MAX
env = Env(True)
for epi in range(1,EPISODE_CNT+1):
    s1=env.reset()
    while not s1['done']:
        # render(s1,0.1,True)  # uncomment to watch the agent play
        # Acting needs no gradients: evaluate the online net once and reuse
        # the result both for action selection and for the initial priority.
        with tc.no_grad():
            q1=net(tc.tensor(state2input(s1)).to(DEVICE))
        if random.random()<eps:
            action=np.random.randint(0,3)
        else:
            action=np.argmax(q1.cpu().detach().numpy())
        s2,rwd=env.step(action)

        # Initial priority = Double-DQN TD error of this transition. Computed
        # entirely under no_grad so the value handed to the replay memory does
        # not carry an autograd graph.
        with tc.no_grad():
            val1=q1[0][action]
            q2t=net_target(tc.tensor(state2input(s2)).to(DEVICE))
            q2=net(tc.tensor(state2input(s2)).to(DEVICE))
            val2=rwd+DISC_RATIO*((1-s2['done'])*q2t[0][tc.argmax(q2,dim=1)])
        perm.push(val2-val1,(s1,action,rwd,s2))
        s1=s2

        # NOTE(review): both schedules key off s1['time']; whether that counter
        # is global or restarts every episode is decided inside Env -- confirm,
        # since a per-episode counter would make target syncs very rare.
        if s1['time']%LEARN_FREQ==0 and perm.cnt>=BATCH_SIZE:
            train()
        if s1['time']%TARGET_UPD_FREQ==0:
            net_target.load_state_dict(net.state_dict())
            net_target.eval()
    # Polynomial epsilon decay over episodes, clamped into [EPS_MIN, EPS_MAX].
    eps=min(EPS_MAX,max(EPS_MIN, ((EPISODE_CNT-epi)/EPISODE_CNT)**EPS_EXPONENT ))

    if epi%SAVE_TEMP_FREQ==0:
        tc.save(net.state_dict(),'./netw.pt')
        clear_output(wait=True)
    if epi%STAT_DISPLAY_FREQ==0:
        # No train() call may have happened in this window (e.g. before the
        # replay memory holds BATCH_SIZE transitions); report NaN instead of
        # dividing by zero.
        loss_avg=sum(losses)/len(losses) if losses else float('nan')
        print("{}/{}({:.2f}%): LossAvg={:.4f} EPS={:.4f}".format(
            epi,
            EPISODE_CNT,
            epi/EPISODE_CNT*100,
            loss_avg,
            eps))
        losses=[]
# Save the final weights.
tc.save(net.state_dict(),'./netw.pt')
print("DONE!!!")

10000/10000(100.00%): LossAvg=0.0019 EPS=0.0100
DONE!!!
