In [1]:
%load_ext autoreload
%autoreload 2

In [38]:
import numpy as np
import matplotlib.pyplot as plt
from torch import nn 
from torchvision import models
import torch 
import gym , sys
import cv2
from tqdm import tqdm

from rlxai import A2C
env = gym.make("Breakout-v0")

In [25]:

def img_crop(img_arr) :
    return img_arr[55:-15,15:-15,:]
def rgb2gray(rgb):
    image_data = cv2.cvtColor(cv2.resize(rgb, (84, 84)), cv2.COLOR_BGR2GRAY)
    image_data[image_data > 0] = 255
    image_data = np.reshape(image_data,(84, 84, 1))
    return image_data
def totensor(img_arr) :
    return torch.FloatTensor(img_arr.transpose((2, 0, 1))).unsqueeze(dim=0)


In [26]:
def data_transform(x) :
    x = img_crop(x)
    x = rgb2gray(x)
    x = totensor(x)
    return x

In [27]:
s = env.reset()
ss = data_transform(s)
ss.shape

torch.Size([1, 1, 84, 84])

In [33]:
actor = models.resnet18(pretrained=True)
actor.conv1 =nn.Conv2d(1,64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
actor.fc = nn.Linear(512, env.action_space.n)

critic = models.resnet18(pretrained=True)
critic.conv1 =nn.Conv2d(1,64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
critic.fc = nn.Linear(512, 1)

In [34]:
def data_transform(x) :
    x = img_crop(x)
    x = rgb2gray(x)
    x = totensor(x)
    return x

In [35]:
A2C_MODEL = A2C(actor=actor,critic=critic,lr=1e-3,gamma=0.99,lam=0.9)

In [52]:
from IPython.display import clear_output
from array2gif import  write_gif 

In [76]:
s = env.reset()
img_collection = [s.transpose((2, 0, 1))]
for t in range(10000) :
    s,r,done,info = env.step(env.action_space.sample())
    if done :
        s = env.reset()
    if t % 10 == 0 :
        img_collection.append(s.transpose((2, 0, 1)))
write_gif(img_collection,"test.gif")


In [83]:
############################################################
rewards      = []  #
is_terminals = []  #
values       = []  #
logprobs     = []  #
############################################################
max_steps_per_episode = 100_000
batch_size = 500
s = env.reset()
A2C_MODEL.load_model("./")
img_collection = []
with tqdm(total=max_steps_per_episode, file=sys.stdout) as pbar:
    for t in range(1, max_steps_per_episode):
        s_tensor = data_transform(s)
        act , v , logprob = A2C_MODEL.choose_action(s_tensor,inference=False)
        s,r,done,info = env.step(int(act))
        rewards.append(r)
        is_terminals.append(done)
        values.append(v)
        logprobs.append(logprob)
        if (batch_size ==len(rewards)) | done :
            s_tensor = data_transform(s)
            _ , v , _ = A2C_MODEL.choose_action(s_tensor,inference=False)
            actor_loss,  critic_loss = A2C_MODEL.get_loss(rewards , is_terminals,values,logprobs,float(v) )
            print()
            A2C_MODEL.train(actor_loss=actor_loss,critic_loss=critic_loss)
            A2C_MODEL.save_model("./")
            ############################################################
            rewards      = []  #
            is_terminals = []  #
            values       = []  #
            logprobs     = []  #
            ############################################################
            a_l , c_l =actor_loss.detach().numpy() , critic_loss.detach().numpy()
            write_gif(img_collection[::10],f"./gif/eval_{t:05d}.gif")
            del img_collection
            img_collection = []
            if done :
                print("restart...")
                s = env.reset()
                img_collection.append(s.transpose((2, 0, 1)))
        else :
            img_collection.append(s.transpose((2, 0, 1)))
            a_l , c_l = 0,0
        #pbar.set_description(f"{done} {r} {a_l:3f}, {c_l:3f}")
        pbar.update(1)

  0%|          | 498/100000 [00:50<2:36:37, 10.59it/s]
  1%|          | 998/100000 [02:47<2:26:16, 11.28it/s]
  1%|▏         | 1499/100000 [04:48<2:29:15, 11.00it/s]
  2%|▏         | 1999/100000 [06:47<2:26:23, 11.16it/s]
  2%|▏         | 1999/100000 [07:06<2:26:23, 11.16it/s]