# In [None]:
import random
import time
import copy
from collections import deque

import numpy as np

import torch as tc
import torch.optim as opt
import torch.distributions as tcdist

from model import SnakeNet
from core import CUDA_AVAILABLE, DEVICE
from config import NROW,NCOL,EPISODE_MAXLEN
from env import Env
from util import state2input

import time as timelib
from PIL import Image
from PIL import ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt
from IPython import display

def render(state, sleep_time=0, clear=True, save_name=None, grid_size=30):
    """Draw one board state, display it with matplotlib and return it as a tensor.

    The board is drawn as an NROW x NCOL grid of ``grid_size``-pixel cells and
    then downscaled to 64 pixels wide (the height keeps the NROW/NCOL ratio).

    Args:
        state: board state; ``state['foods']`` and ``state['snake']`` are
            sequences of (row, col) cells. Foods are drawn as yellow circles,
            snake cells as red squares, and any cell equal to the last snake
            entry as a red diamond (presumably the head — TODO confirm).
        sleep_time: seconds to pause after showing the frame; 0 skips the pause.
        clear: if truthy, close the figure and clear the notebook output so
            the next call replaces this frame.
        save_name: optional file path; the downscaled frame is saved there.
        grid_size: pixel size of one cell before downscaling (default 30).

    Returns:
        The downscaled RGB frame as produced by ``transforms.ToTensor()``.
    """
    canvas = Image.new("RGBA", (NCOL * grid_size, NROW * grid_size), (10, 50, 100, 100))
    draw = ImageDraw.Draw(canvas)

    def cell_box(y, x):
        # Pixel bounding box (x1, y1, x2, y2) of grid cell (y, x).
        return (x * grid_size, y * grid_size, (x + 1) * grid_size, (y + 1) * grid_size)

    for i in range(NROW):
        for j in range(NCOL):
            draw.rectangle(cell_box(i, j), outline='black', width=1)
    for y, x in state['foods']:
        draw.ellipse(cell_box(y, x), fill='yellow', outline='yellow')

    snake = state['snake']
    # tuple() so the comparison below also matches when cells are stored as lists.
    head = tuple(snake[-1]) if len(snake) else None
    for y, x in snake:
        if (y, x) == head:
            # Diamond through the midpoints of the cell's four edges.
            draw.polygon([((x + 0.5) * grid_size, y * grid_size),
                          (x * grid_size, (y + 0.5) * grid_size),
                          ((x + 0.5) * grid_size, (y + 1) * grid_size),
                          ((x + 1) * grid_size, (y + 0.5) * grid_size)], fill='red')
        else:
            draw.rectangle(cell_box(y, x), fill='red', outline='red')

    frame = canvas.resize((64, 64 * NROW // NCOL)).convert('RGB')
    if save_name:
        frame.save(save_name)
    tensor_image = transforms.ToTensor()(np.array(frame))
    plt.axis("off")
    # Show the PIL frame directly: converting the float tensor back with
    # ToPILImage truncates x*255 to uint8 and can shift pixel values by one.
    plt.imshow(frame)
    plt.show()
    if sleep_time:
        timelib.sleep(sleep_time)
    if clear:
        plt.close()
        display.clear_output(wait=False)
    return tensor_image

def stateTransform(state, flipy, flipx, deltay, deltax):
    """Return a shifted and optionally mirrored deep copy of a board state.

    Every (row, col) cell in ``state['snake']`` and ``state['foods']`` is first
    translated by (deltay, deltax) with wrap-around modulo the NROW x NCOL grid,
    then mirrored vertically when ``flipy`` is truthy and horizontally when
    ``flipx`` is truthy. All other keys of ``state`` are deep-copied unchanged.
    """
    def _remap(coord, delta, size, flip):
        if flip:
            return (size - 1 - (coord + delta + size) % size) % size
        return (coord + size + delta) % size

    def _cell(pos):
        row, col = pos
        return (_remap(row, deltay, NROW, flipy), _remap(col, deltax, NCOL, flipx))

    out = copy.deepcopy(state)
    out['snake'] = list(map(_cell, state['snake']))
    out['foods'] = list(map(_cell, state['foods']))
    return out
def actTransform(act, flipy, flipx):
    """Map an action chosen on a mirrored board back to the unmirrored board.

    Odd actions (1, 3) are reversed by a horizontal mirror (``flipx``) and even
    actions (0, 2) by a vertical mirror (``flipy``); reversing an action means
    adding 2 modulo 4. Unaffected actions are returned unchanged.
    """
    flip = flipx if act % 2 else flipy
    return (act + 2) % 4 if flip else act

# Watch loop: reload the checkpoint at ./netw.pt, then play and render one
# episode with the sampled policy. Runs until interrupted.
while True:
    # The network is rebuilt every episode so a re-saved ./netw.pt is picked up
    # (presumably written by a concurrently running trainer — TODO confirm).
    net = SnakeNet().cuda() if CUDA_AVAILABLE else SnakeNet()
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # machine; load_state_dict then copies the weights onto the model's device.
    net.load_state_dict(tc.load('./netw.pt', map_location='cpu'))
    net.eval()

    with tc.no_grad():
        env = Env(True)
        state = env.reset()
        rwdsum = 0
        while not state['done']:
            # Mirroring is disabled (0, 0); use np.random.randint(0, 2) for each
            # to check the policy under mirrored boards.
            fy, fx = 0, 0
            obs = tc.tensor(state2input(stateTransform(state, fy, fx, 0, 0))).to(DEVICE)
            pol = net.calcpol(obs)
            a = tcdist.Categorical(pol).sample().item()
            # The action was chosen on the mirrored board; map it back before stepping.
            state, rwd = env.step(actTransform(a, fy, fx))

            rwdsum += rwd
            print(state['time'], rwdsum)
            print(pol)
            render(state, 0.05, True)
            # To save frames instead: render(state, 0.3, True, 'imgs/' + str(state['time']) + '.jpg')

# In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Hyperparameters
learning_rate = 0.0002  # Adam step size for ActorCritic's optimizer
gamma         = 0.98    # discount factor in the one-step TD target
n_rollout     = 10      # environment steps collected per update in main()

# Mean loss of each update, appended by ActorCritic.train_net; main() averages
# and resets this list every print interval.
losses=[]

class ActorCritic(nn.Module):
    """Small actor-critic network with a shared hidden layer.

    ``fc1`` maps a 4-dimensional state to 256 hidden units. The actor head
    ``fc_pi`` produces probabilities over 2 actions and the critic head
    ``fc_v`` a scalar state value. Transitions are buffered with
    :meth:`put_data` and consumed by :meth:`train_net`.
    """

    def __init__(self):
        super().__init__()
        # (s, a, r, s_prime, done) tuples collected since the last update.
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        """Return action probabilities for state(s) ``x``.

        ``softmax_dim`` must be the action axis: 0 for a single unbatched
        state, 1 for a batch of states.
        """
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        return F.softmax(x, dim=softmax_dim)

    def v(self, x):
        """Return the critic's value estimate(s) for state(s) ``x``."""
        x = F.relu(self.fc1(x))
        return self.fc_v(x)

    def put_data(self, transition):
        """Append one ``(s, a, r, s_prime, done)`` transition to the buffer."""
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions to batched tensors and clear the buffer.

        Returns:
            ``(s, a, r, s_prime, done_mask)``: float states, an int64 ``(N, 1)``
            action tensor (usable with ``gather``), ``(N, 1)`` float rewards
            divided by 100, float next states, and an ``(N, 1)`` mask that is
            0.0 for terminal transitions and 1.0 otherwise.
        """
        import numpy as np  # local so this cell does not rely on another cell's imports

        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for s, a, r, s_prime, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r / 100.0])
            s_prime_lst.append(s_prime)
            done_lst.append([0.0 if done else 1.0])

        # Stack observations with NumPy first: torch.tensor() on a list of
        # ndarrays is very slow and PyTorch warns about it.
        s_batch = torch.tensor(np.asarray(s_lst), dtype=torch.float)
        a_batch = torch.tensor(a_lst)
        r_batch = torch.tensor(r_lst, dtype=torch.float)
        s_prime_batch = torch.tensor(np.asarray(s_prime_lst), dtype=torch.float)
        done_batch = torch.tensor(done_lst, dtype=torch.float)
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

    def train_net(self):
        """Run one actor-critic update on the buffered transitions.

        The TD target is ``r + gamma * V(s') * done_mask``. The actor loss is
        ``-log pi(a|s) * delta`` with the TD error ``delta`` detached. The
        critic loss is smooth-L1 between ``V(s)`` and the constant target. The
        mean loss is appended to the module-level ``losses`` list. Does nothing
        when the buffer is empty.
        """
        if not self.data:
            return
        s, a, r, s_prime, done = self.make_batch()
        # The target is treated as a constant, so no graph is kept for V(s').
        with torch.no_grad():
            td_target = r + gamma * self.v(s_prime) * done
        # V(s) is evaluated once and shared by the actor and critic terms.
        v_s = self.v(s)
        delta = td_target - v_s

        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        loss = (-torch.log(pi_a) * delta.detach()
                + F.smooth_l1_loss(v_s, td_target)).mean()

        self.optimizer.zero_grad()
        losses.append(loss.item())
        loss.backward()
        self.optimizer.step()
      
def main():
    """Train ActorCritic on CartPole-v1, reporting progress every 20 episodes.

    Runs 10000 episodes and updates the model after every ``n_rollout`` steps
    or at episode end. At each report it prints the mean update loss and mean
    episode score since the previous report, then resets both accumulators.
    Works with both the old gym API (bare ``reset``, 4-tuple ``step``) and the
    gym >= 0.26 API (``(obs, info)`` reset, 5-tuple ``step``).
    """
    global losses
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    print_interval = 20
    score = 0.0
    n_scored = 0  # episodes accumulated in `score` since the last report

    try:
        for n_epi in range(10000):
            done = False
            s = env.reset()
            if isinstance(s, tuple):  # gym >= 0.26 returns (obs, info)
                s = s[0]
            while not done:
                for _ in range(n_rollout):
                    prob = model.pi(torch.from_numpy(s).float())
                    a = Categorical(prob).sample().item()
                    step = env.step(a)
                    if len(step) == 5:  # gym >= 0.26: (obs, r, terminated, truncated, info)
                        s_prime, r, terminated, truncated, _ = step
                        done = terminated or truncated
                    else:
                        s_prime, r, done, _ = step
                    model.put_data((s, a, r, s_prime, done))

                    s = s_prime
                    score += r

                    if done:
                        break

                model.train_net()

            n_scored += 1
            if n_epi % print_interval == 0 and n_epi != 0:
                # Divide by the episodes actually accumulated: the first window
                # covers episodes 0..print_interval, one more than later windows.
                avg_loss = sum(losses) / len(losses) if losses else float('nan')
                print("# of episode :{}, avg loss: {:.5f}, avg score : {:.1f}".format(n_epi, avg_loss, score / n_scored))
                losses = []
                score = 0.0
                n_scored = 0
                # NOTE(review): nested inside the print_interval check, this only
                # fires when n_epi is a multiple of both 20 and 50, i.e. every 100
                # episodes — confirm whether every 50 was intended.
                if n_epi % 50 == 0:
                    print(model.state_dict())
    finally:
        env.close()

# Entry point: start training when run as a script. In a Jupyter/IPython
# kernel __name__ is also '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()