In [5]:
import copy
import random
import statistics
import time
from collections import deque

import numpy as np

import torch as tc
import torch.nn as nn
import torch.optim as opt
import torch.distributions as tcdist
import torch.nn.functional as F

from model import SnakeNet
from core import CUDA_AVAILABLE, DEVICE
from config import NROW,NCOL
from env import Env
from util import state2input

In [6]:
# Train parameters

BATCH_SIZE = 32          # number of parallel envs; each update trains on one transition per env
DISC_RATIO = 0.99        # discount factor applied to V(s2) in the TD target
GRADCLIP_NORM = 1        # max gradient norm passed to clip_grad_norm_
L2_DECAY = 1e-6          # Adam weight_decay
LEARNING_RATE = 1e-5     # Adam learning rate
STAT_DISPLAY_FREQ = 50   # print running stats every N updates
SAVE_TEMP_FREQ = 300     # checkpoint the weights to ./netw.pt every N updates
TRAIN_CNT = 10000        # total number of updates
SEED = 0                 # seed for Python / NumPy / PyTorch RNGs

# Seed every RNG this notebook draws from (action sampling, weight init) so runs
# are repeatable. NOTE(review): Env may keep its own RNG — confirm it is covered.
random.seed(SEED)
np.random.seed(SEED)
tc.manual_seed(SEED)

In [7]:
import time as timelib
from PIL import Image
from PIL import ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt
from IPython import display
def render(state,sleep_time=0,clear=True):
    """Draw a game state, show it inline, and return the frame as a tensor.

    Args:
        state: dict with 'foods' and 'snake' sequences of (row, col) cells.
            The last element of state['snake'] is drawn as the head (a diamond),
            the rest as body squares.
        sleep_time: seconds to pause after showing the frame (0 = no pause).
        clear: if True, call display.clear_output(wait=True) afterwards so that
            successive calls replace each other in the cell output.

    Returns:
        The downscaled RGB frame (width 64, height 64*NROW//NCOL) converted with
        torchvision's ToTensor, i.e. a (3, H, W) float tensor in [0, 1].
    """
    grid_size = 30  # side of one board cell in pixels, before the final downscale
    canvas = Image.new("RGBA", (NCOL*grid_size, NROW*grid_size), (10,50,100,100))
    draw = ImageDraw.Draw(canvas)

    # Board grid: one outlined square per cell.
    for i in range(NROW):
        for j in range(NCOL):
            x1, y1 = j*grid_size, i*grid_size
            draw.rectangle(((x1, y1), (x1+grid_size, y1+grid_size)), outline='black', width=1)

    # Food: yellow circles.
    for y,x in state['foods']:
        draw.ellipse((x*grid_size, y*grid_size, (x+1)*grid_size, (y+1)*grid_size), fill='yellow', outline='yellow')

    # Snake: red squares for the body, red diamond for the head (last element).
    # The head is identified by position in the list, not by comparing cell values:
    # a value comparison against a tuple silently fails when cells are lists.
    snake = state['snake']
    head_idx = len(snake) - 1
    for idx, (y, x) in enumerate(snake):
        if idx == head_idx:
            draw.polygon([((x+1/2)*grid_size, y*grid_size), (x*grid_size, (y+1/2)*grid_size),
                          ((x+1/2)*grid_size, (y+1)*grid_size), ((x+1)*grid_size, (y+1/2)*grid_size)], fill='red')
        else:
            draw.rectangle(((x*grid_size, y*grid_size), ((x+1)*grid_size, (y+1)*grid_size)), fill='red', outline='red')

    frame = canvas.resize((64, 64*NROW//NCOL)).convert('RGB')
    numpy_image = np.array(frame)
    tensor_image = transforms.ToTensor()(numpy_image)

    plt.axis("off")
    # Show the uint8 pixels directly instead of converting the float tensor back
    # to a PIL image, which re-quantizes every pixel.
    plt.imshow(numpy_image)
    plt.show()
    plt.close()
    if sleep_time:
        timelib.sleep(sleep_time)
    if clear:
        display.clear_output(wait=True)
    return tensor_image

In [8]:
print("CUDA: ",CUDA_AVAILABLE)
# Place the model on DEVICE: train() moves every batch tensor to DEVICE, so the
# parameters must live on that same device.
net = SnakeNet().to(DEVICE)
net.train()

opter = opt.Adam(net.parameters(), lr=LEARNING_RATE, weight_decay=L2_DECAY)

# Running statistics for the training loop; they are printed and then reset
# every STAT_DISPLAY_FREQ updates.
losses = []    # scalar loss of each update (appended by train())
scores = []    # env.score of finished episodes
probmaxs = []  # largest action probability at each action selection

def train(bat):
    """Run one one-step actor-critic update on a batch of transitions.

    Args:
        bat: list of (s1, a, r, s2) tuples. s1/s2 are env state dicts (s2 must
            have a 'done' key), a is the integer action taken in s1 and r the
            reward received.

    Side effects: steps the global optimizer `opter` on the global `net` and
    appends the scalar loss to the global `losses` list.

    Loss = policy-gradient term weighted by the TD error
         + 0.5 * squared TD error of V(s1)
         + 0.01 * sum(p * log p)   (negative entropy, so entropy is encouraged).
    """
    s1bat = tc.tensor([state2input(s1) for (s1, _, _, _) in bat]).to(DEVICE)
    abat = tc.tensor([a for (_, a, _, _) in bat]).to(DEVICE)
    rbat = tc.tensor([r for (_, _, r, _) in bat], dtype=tc.float32).to(DEVICE)
    s2bat = tc.tensor([state2input(s2) for (_, _, _, s2) in bat]).to(DEVICE)
    dbat = tc.tensor([float(s2['done']) for (_, _, _, s2) in bat]).to(DEVICE)

    # NOTE(review): assumes calcval returns (B, 1) or (B,); reshape(-1) flattens
    # either to (B,) and, unlike squeeze(), keeps the batch dim when B == 1.
    # The bootstrap target r + gamma * V(s2) (zeroed for terminal s2) is only a
    # fixed target, so no autograd graph is built through V(s2).
    with tc.no_grad():
        v2 = net.calcval(s2bat).reshape(-1)
        td_target = rbat + DISC_RATIO * ((1 - dbat) * v2)

    v1 = net.calcval(s1bat).reshape(-1)  # V(s1), computed once and reused
    delta = (td_target - v1).detach()    # TD error used as the advantage

    polraw = net.calcpol(s1bat)          # (B, n_actions) logits
    logp_all = F.log_softmax(polraw, dim=1)
    # log pi(a|s1) straight from log_softmax: log(softmax(...)) is -inf when the
    # chosen action's probability underflows to 0, making the loss inf/nan.
    logp_a = logp_all.gather(1, abat.unsqueeze(dim=1)).reshape(-1)

    loss_val = ((v1 - td_target) ** 2).sum()
    loss_pol = (-logp_a * delta).sum()
    neg_entropy = (logp_all * F.softmax(polraw, dim=1)).sum()
    loss = loss_pol + 0.5 * loss_val + 0.01 * neg_entropy

    opter.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), GRADCLIP_NORM)
    opter.step()

    losses.append(loss.item())

# Debugging aid: print and draw the first environment at every step (slow).
RENDER_FIRST_ENV = False


def _summary(values):
    """Return (max, mean, median, min) of `values`, or -1 for each if it is empty."""
    if not values:
        return (-1, -1, -1, -1)
    return (max(values), sum(values) / len(values), statistics.median(values), min(values))


envs = [Env(True) for _ in range(BATCH_SIZE)]
for traini in range(1, TRAIN_CNT + 1):
    # Collect one transition from every env, then do a single update on them.
    batch = []
    for env in envs:
        s1 = env.state
        if not s1 or s1['done']:
            # Record scores only for finished episodes. A falsy state is taken to
            # mean the env has not been reset yet, so its score is not an episode
            # result (NOTE(review): confirm against Env).
            if s1:
                scores.append(env.score)
            s1 = env.reset()
        if RENDER_FIRST_ENV and env is envs[0]:
            print(env.score)
            render(s1, 0, True)
        # Action selection only samples from the policy; no autograd graph needed.
        with tc.no_grad():
            pol = F.softmax(net.calcpol(tc.tensor(state2input(s1)).to(DEVICE)), dim=1)
        probmaxs.append(float(tc.max(pol)))
        a = tcdist.Categorical(pol).sample().item()
        s2, rwd = env.step(a)
        batch.append((s1, a, rwd, s2))
    train(batch)
    if traini % SAVE_TEMP_FREQ == 0:
        tc.save(net.state_dict(), './netw.pt')
        display.clear_output(wait=True)
    if traini % STAT_DISPLAY_FREQ == 0:
        score_max, score_avg, score_med, score_min = _summary(scores)
        print("#{} ({:.2f}%): LossAvg={:.3f}, MaxprobAvg={:.3f}, ScoreMax={:.2f}, ScoreAvg={:.2f}, ScoreMed={:.2f}, ScoreMin={:.2f}".format(
            traini,
            traini / TRAIN_CNT * 100,
            sum(losses) / len(losses) if losses else -1,
            sum(probmaxs) / len(probmaxs) if probmaxs else -1,
            score_max,
            score_avg,
            score_med,
            score_min))
        # Reset the running statistics for the next reporting window.
        losses = []
        probmaxs = []
        scores = []
#save
tc.save(net.state_dict(), './netw.pt')
print("DONE!!!")

#9900 (99.00%): LossAvg=-10.279, MaxprobAvg=0.947, ScoreMax=26.33, ScoreAvg=23.07, ScoreMed=22.01, ScoreMin=21.73
#9950 (99.50%): LossAvg=-14.578, MaxprobAvg=0.937, ScoreMax=29.33, ScoreAvg=22.15, ScoreMed=21.79, ScoreMin=15.20
#10000 (100.00%): LossAvg=-11.567, MaxprobAvg=0.935, ScoreMax=27.33, ScoreAvg=21.66, ScoreMed=22.87, ScoreMin=8.59
DONE!!!


In [9]:
# Leftover debugging cell: opens the post-mortem debugger on the most recent
# uncaught exception. The captured output appears to come from an earlier failed
# run (a torch _pad_circular frame), not from the training cell above, which
# finished with "DONE!!!". Safe to remove before sharing.
%debug

> /home/loboprix/.local/lib/python3.8/site-packages/torch/nn/functional.py(4435)_pad_circular()
   4433         in_h1 = in_shape[3] - max(-padding[-3], 0)
   4434 
-> 4435         out[..., out_d0:out_d1, out_h0:out_h1] = input[..., in_d0:in_d1, in_h0:in_h1]
   4436     elif ndim == 3:
   4437         out

SyntaxError: invalid syntax (<ipython-input-10-70b6e365df8e>, line 1)