# In [None]:
import os
from collections import deque

import gym
import numpy as np
import torch
import torchvision.transforms

from DQN_model import CnnDQN
from helper_functions import phi_transformer


class frame_saver:
    """Preprocess game frames and write recorded episode sequences to disk.

    ``add_frame`` runs a torchvision pipeline (PIL conversion, optional
    grayscale, resize, tensor conversion) on one frame and appends the result,
    as an HWC ``uint8`` array, to ``frame_stack``. ``save_seqs`` writes an
    array of sequences to ``<save_path>_<run_name>_<n_seqs>.npy``.
    """

    def __init__(self, save_path,
                 img_dim,
                 channels,
                 run_name,
                 n_frames_stack=10**4, n_seqs=None):
        """Store the output settings and build the preprocessing pipeline.

        Args:
            save_path: Path prefix of the file written by ``save_seqs``.
            img_dim: Size passed to ``torchvision.transforms.Resize``.
            channels: 1 to convert frames to single-channel grayscale, 3 to
                resize them without any colour conversion.
            run_name: Tag inserted into the output file name (e.g. "train").
            n_frames_stack: Stored as ``self.n_frames_stack``; not read by
                this class itself.
            n_seqs: Count written into the output file name. ``None`` (the
                default) means ``save_seqs`` uses ``len(seq_arr)`` instead.
                It must have a default because it follows the defaulted
                ``n_frames_stack`` parameter.

        Raises:
            ValueError: If ``channels`` is neither 1 nor 3.
        """
        # Validate first so an unsupported value fails here rather than with
        # a missing ``self.process`` on the first ``add_frame`` call.
        if channels not in (1, 3):
            raise ValueError(
                "channels must be 1 (grayscale) or 3 (colour), got %r" % (channels,))

        self.save_path = save_path
        self.frame_stack = deque([])  # processed HWC uint8 frames, in arrival order
        self.n_stacks = 1
        self.n_frames_stack = n_frames_stack
        self.run_name = run_name
        self.n_seqs = n_seqs

        steps = [torchvision.transforms.ToPILImage()]
        if channels == 1:
            steps.append(torchvision.transforms.Grayscale(num_output_channels=channels))
        steps.append(torchvision.transforms.Resize(img_dim))
        steps.append(torchvision.transforms.ToTensor())
        self.process = torchvision.transforms.Compose(steps)

    def add_frame(self, S):
        """Preprocess frame ``S`` and append it to ``frame_stack``.

        The pipeline ends in ``ToTensor``, which yields a CHW float tensor in
        [0, 1]. It is transposed to HWC and scaled back to ``uint8``.
        """
        S_processed = self.process(S).numpy()
        S_processed = np.transpose(S_processed, (1, 2, 0))
        # Round instead of truncating: a float32 value k/255 multiplied by 255
        # can land just below k, and a bare astype would then yield k - 1.
        # Clip so values marginally outside [0, 1] cannot wrap around.
        S_processed = np.clip(np.rint(255 * S_processed), 0, 255).astype("uint8")
        self.frame_stack.append(S_processed)

    def save_seqs(self, seq_arr):
        """Save ``seq_arr`` to ``<save_path>_<run_name>_<n_seqs>.npy``.

        ``n_seqs`` is the value given to the constructor, or ``len(seq_arr)``
        when that was ``None``. If ``seq_arr`` has ``dtype=object``, NumPy
        pickles it, and reading it back needs ``np.load(..., allow_pickle=True)``.
        """
        n_seqs = len(seq_arr) if self.n_seqs is None else self.n_seqs
        name = self.save_path + "_" + self.run_name + "_" + str(n_seqs) + ".npy"
        np.save(name, seq_arr)

        

def gather_frames(frame_save, model, env, n_frames_gather, n_seqs=None, n_phi=4, epsilon=0.1):
    """Play ``env`` epsilon-greedily, recording every frame and each finished episode.

    Each chosen action is repeated for up to ``n_phi`` environment steps, and
    every observation is passed to ``frame_save.add_frame``. When an episode
    ends, the frame buffer ``S`` is stored as one sequence. ``S`` has shape
    ``(n_phi, 210, 160, 3)`` and is kept as it stands at that moment. The
    episode reward is printed and the environment is reset. A fresh action is
    then chosen for the new episode.

    Gathering stops once ``n_seqs`` episodes have finished (when ``n_seqs`` is
    not None) or ``n_frames_gather`` frames have been recorded, whichever
    comes first. The sequences go to ``frame_save.save_seqs`` as a 1-D
    ``object`` array with one entry per finished episode.

    Args:
        frame_save: Object providing ``add_frame(frame)`` and
            ``save_seqs(seq_arr)`` (e.g. a ``frame_saver``).
        model: Q-network whose ``act`` method is called on
            ``phi_transformer(S, n_phi)`` for greedy moves.
        env: Gym-style environment. ``step`` must return
            ``(obs, reward, done, info)``, and observations must fit a
            ``(210, 160, 3)`` uint8 slot.
        n_frames_gather: Maximum number of frames to record.
        n_seqs: Number of finished episodes to record. ``None`` means no
            episode limit, so only ``n_frames_gather`` applies.
        n_phi: Steps each action is repeated for; also the buffer depth.
        epsilon: Probability of taking a uniformly random action.
    """
    episode_reward = 0
    n_frames = 0
    seqs = []

    S = np.zeros((n_phi,) + (210, 160, 3), dtype="uint8")
    S[n_phi - 1] = env.reset()

    while n_frames < n_frames_gather and (n_seqs is None or len(seqs) < n_seqs):
        # Choose an action epsilon-greedily.
        if np.random.rand() < epsilon:  # random move
            a = np.random.randint(env.action_space.n)
        else:  # greedy move from the Q-network
            with torch.no_grad():
                S_model = phi_transformer(S, n_phi)
                a = model.act(S_model)

        # Repeat the action for up to n_phi steps.
        for j in range(n_phi):
            S[j], r_temp, done, _ = env.step(a)
            episode_reward += r_temp
            frame_save.add_frame(S[j])
            n_frames += 1

            if done:
                # S is rebound below, so the stored array is never mutated later.
                seqs.append(S)
                print(episode_reward)
                S = np.zeros((n_phi,) + (210, 160, 3), dtype="uint8")
                S[n_phi - 1] = env.reset()
                episode_reward = 0
                # Leave the repeat loop: the old action must not carry into
                # the new episode, and the episode limit is re-checked above.
                break
            if n_frames >= n_frames_gather:
                break

    # Fill element-wise: np.array(seqs, dtype=object) would build a 5-D
    # object array instead of a 1-D array of per-episode arrays.
    seq_arr = np.empty((len(seqs),), dtype=object)
    for i, seq in enumerate(seqs):
        seq_arr[i] = seq
    frame_save.save_seqs(seq_arr)


# Hyper parameters
epsilon = 0.05  # Chance to do random action
add_frame_chance = 0.05  # NOTE(review): not used anywhere in this script
env_id = "Riverraid-v0"
n_phi = 4  # Number of steps each chosen action is repeated for
n_frames_train_gather = 3*10**5  # Number of frames for training
n_frames_stack = 10**5  # Number of frames to store before writing
# Number of finished episodes to record; gathering also stops at
# n_frames_train_gather frames. TODO(review): placeholder value, tune as needed.
n_seqs_train = 100

load_path = "models/DQN_" + env_id
save_path = "saved_frames/" + env_id

# Make sure the output directory exists (no error if it already does).
os.makedirs(os.path.dirname(save_path), exist_ok=True)

env = gym.make(env_id)
# NOTE(review): this unpickles a whole model object (hence the CnnDQN import).
# On PyTorch >= 2.6 torch.load defaults to weights_only=True, which cannot
# load such an object; pass weights_only=False there (trusted files only).
Q_model = torch.load(load_path)
#Q_model = Q_model.eval()


frame_store_train = frame_saver(save_path, [64, 64], 3, "train",
                                n_frames_stack=n_frames_stack, n_seqs=n_seqs_train)
#frame_store_test = frame_saver(save_path, [64,64], 3, "test", n_frames_stack=50000)
#frame_store_val = frame_saver(save_path, [64,64], 3, "validation", n_frames_stack=50000)

## Begin train frame loop
# Pass epsilon and n_phi explicitly; otherwise gather_frames falls back to its
# own defaults and the hyper-parameters above are silently ignored.
gather_frames(frame_store_train, Q_model, env, n_frames_train_gather,
              n_seqs=n_seqs_train, n_phi=n_phi, epsilon=epsilon)
#gather_frames(frame_store_test, Q_model, env, 1000)
#gather_frames(frame_store_val, Q_model, env, 1000)
