In [1]:
from thinker.util import __project__
import os
import argparse
import yaml
import numpy as np
import torch
from thinker.actor_net import ActorNet
from thinker.main import Env
import thinker.util as util
from thinker.self_play import init_env_out, create_env_out

class DetectBuffer:
    """Groups per-step data into planning stages and writes files of ``t`` stages
    together with their look-ahead targets (see ``insert``)."""

    def __init__(self, outdir, t, rec_t, logger, delay_n=5):
        """
        Store training data grouped in planning stages and write it out once
        the delayed target output for those stages is also ready.

            Args:
                outdir (str): directory that receives ``data_{file_idx}.pt`` files
                t (int): number of planning stages per output file
                rec_t (int): number of steps in a planning stage
                logger: object whose ``info`` method reports each saved file
                delay_n (int): number of future planning stages folded into the
                    target y (0 disables the look-ahead)
        """
        self.outdir = outdir
        self.t = t  # number of planning stages per file
        self.rec_t = rec_t
        self.logger = logger
        self.delay_n = delay_n

        # processed_n counts buffered entries that start a planning stage (status 0 or 3)
        self.processed_n, self.xs, self.y, self.done, self.step_status = 0, [], [], [], []
        self.file_idx = -1  # index of the last file written; -1 = none yet

    def insert(self, xs, y, done, step_status):
        """
        Append one step of data. Once enough planning stages are buffered, write
        the first ``t`` stages (with their delayed targets) to
        ``{outdir}/data_{file_idx}.pt``.

        Args:
            xs (dict): training inputs, each value of shape (B, *)
            y (tensor): bool tensor of shape (B), the raw per-step target; the
                delay_n look-ahead is applied when a file is written
            done (tensor): bool tensor of shape (B), episode-end indicator
            step_status (int): 0 / 3 = real step (starts a planning stage),
                2 / 3 = last step of a planning stage, 1 = intermediate step
        Returns:
            int: index of the last file written (-1 if none yet)
        Output:
            saves (xs, target_y) with xs[k] of shape (t, rec_t, B, *), plus
            xs["done"] of shape (t, B) and xs["step_status"] of shape (t, rec_t);
            target_y has shape (t, B)
        """
        last_step_real = (step_status == 0) | (step_status == 3)
        if len(self.step_status) == 0 and not last_step_real:
            # skip until the first real step so the buffer starts on a stage boundary
            return self.file_idx

        self.xs.append(util.dict_map(xs, lambda x: x.cpu()))
        self.y.append(y.cpu())
        self.done.append(done.cpu())
        self.step_status.append(step_status)
        self.processed_n += int(last_step_real)

        # enough stages buffered for t outputs plus their delay_n look-ahead stages
        if self.processed_n >= self.t + self.delay_n + 1:
            self.file_idx += 1
            out = self._extract_data(self.t)
            self.processed_n = sum(int(i == 0) + int(i == 3) for i in self.step_status)
            assert self.processed_n == self.delay_n + 1, f"should only have {self.delay_n + 1} data left instead of {self.processed_n}"
            path = f'{self.outdir}/data_{self.file_idx}.pt'
            torch.save(out, path)
            out_shape = out[0]['env_state'].shape
            n = self.file_idx * out_shape[0] * out_shape[2]
            self.logger.info(f"{n}: File saved to {path}; env_state shape {out_shape}")

        return self.file_idx

    def _extract_data(self, t):
        """
        Remove the first ``t`` planning stages from the buffer and return
        ``(xs, target_y)``, where target_y folds in the following ``delay_n``
        stages (which are only peeked at and stay buffered).
        """
        xs, y, done, step_status = self._collect_data(t)
        if self.delay_n > 0:
            # peek (without removing) at the next delay_n stages for the look-ahead target
            future_y, future_done = self._collect_data(self.delay_n, y_done_only=True)
            y = torch.concat([y, future_y], dim=0)
            done = torch.concat([done, future_done], dim=0)

        last_step_real = (step_status == 0) | (step_status == 3)
        assert last_step_real[0], "cur_t should start with 0"
        assert last_step_real.shape[0] == t*self.rec_t, \
            f" last_step_real.shape is {last_step_real.shape}, expected {t*self.rec_t} for the first dimension."
        assert y.shape[0] == (t + self.delay_n)*self.rec_t, \
            f" y.shape is {y.shape}, expected {(t + self.delay_n)*self.rec_t} for the first dimension."

        B = y.shape[1]
        # one y / done value per stage: the one recorded at the stage's first step
        y = y.view(t + self.delay_n, self.rec_t, B)[:, 0]
        done = done.view(t + self.delay_n, self.rec_t, B)[:, 0]
        step_status = step_status.view(t, self.rec_t)
        target_y = self._compute_target_y(y, done, self.delay_n)

        for k in xs.keys():
            xs[k] = xs[k].view((t, self.rec_t) + xs[k].shape[1:])

        xs["done"] = done[:t]
        xs["step_status"] = step_status

        return xs, target_y

    def _collect_data(self, t, y_done_only=False):
        """
        Gather the entries of the first ``t`` planning stages (a stage ends at a
        step whose status is 2 or 3).

        With ``y_done_only=False`` the entries are removed from the buffer and
        ``(xs, y, done, step_status)`` is returned, each stacked along a new
        leading step dimension. Otherwise the buffer is left untouched and only
        ``(y, done)`` is returned.
        """
        step_status = torch.tensor(self.step_status, dtype=torch.long)
        next_step_real = (step_status == 2) | (step_status == 3)
        # flatten rather than squeeze: a single match must stay 1-D to be indexable
        idx = torch.nonzero(next_step_real, as_tuple=False).flatten()
        last_idx = int(idx[t - 1]) + 1
        y = torch.stack(self.y[:last_idx], dim=0)
        done = torch.stack(self.done[:last_idx], dim=0)
        if y_done_only:
            return y, done

        xs = {k: torch.stack([v[k] for v in self.xs[:last_idx]], dim=0) for k in self.xs[0].keys()}
        step_status = step_status[:last_idx]
        self.xs = self.xs[last_idx:]
        self.y = self.y[last_idx:]
        self.done = self.done[last_idx:]
        self.step_status = self.step_status[last_idx:]
        return xs, y, done, step_status

    def _compute_target_y(self, y, done, delay_n):
        """
        Fold up to ``delay_n`` future stages into each stage's target:

            target_y[i] = y[i] | (~done[i+1] & y[i+1]) | (~done[i+1] & ~done[i+2] & y[i+2])
                          | ... | (~done[i+1] & ... & ~done[i+delay_n] & y[i+delay_n])

        so a later positive only counts if none of the stages in between is done.

        Args:
            y, done (tensor): shape (t + delay_n, B)
            delay_n (int): look-ahead in stages; 0 returns a copy of y
        Returns:
            tensor of shape (t, B)
        """
        t = y.shape[0] - delay_n
        target_y = y[:t].clone()
        # running product ~done[i+1] & ... & ~done[i+m]
        not_done_cum = torch.ones_like(done[:t])
        for m in range(1, delay_n + 1):
            not_done_cum = not_done_cum & ~done[m:m + t]
            target_y = target_y | (not_done_cum & y[m:m + t])
        return target_y





In [None]:
# --- Collection settings ----------------------------------------------------
total_n = 200000   # target number of (stage, env) samples; determines file_n below
env_n = 16         # number of parallel environments
delay_n = 5        # look-ahead (in planning stages) folded into the target y
greedy = True      # act greedily with the actor
savedir = "../logs/thinker"
outdir = "../data/detect"
xpid = "v18_mcts"

_logger = util.logger()
_logger.info(f"Initializing {xpid} from {savedir}")
device = torch.device("cuda")

# Resolve the checkpoint directory. realpath resolves a symlink relative to the
# link's own directory (os.readlink + abspath would resolve a relative link target
# against the CWD instead); "~" is expanded before any symlink check.
ckpdir = os.path.realpath(os.path.expanduser(os.path.join(savedir, xpid)))
outdir = os.path.abspath(os.path.expanduser(outdir))

config_path = os.path.join(ckpdir, 'config_c.yaml')
flags = util.create_flags(config_path, save_flags=False)
flags.shallow_enc = False

# Environment built from the checkpoint of `xpid` under `savedir` (gpu=True, parallel=False).
# NOTE(review): return_x / return_h request extra observation fields defined in
# thinker.main.Env; presumably return_x provides env_out.xs used in the collection loop — confirm.
env = Env(
        name=flags.name,
        env_n=env_n,
        gpu=True,
        train_model=False,
        parallel=False,
        savedir=savedir,        
        xpid=xpid,
        ckp=True,
        return_x=True,
        return_h=True,
    )

# wrapper_type == 1 is treated as "thinker disabled": tree_reps are not recorded and the
# actor's hidden state is stored instead.
disable_thinker = flags.wrapper_type == 1   
# Without the thinker wrapper but with a model, imaginary rollouts are generated with
# env.model_net and recorded alongside the real steps.
im_rollout = disable_thinker and env.has_model
# For MCTS runs the actor checkpoint is not loaded; recorded imaginary rollouts follow
# env.most_visited_path.
mcts = flags.mcts

obs_space = env.observation_space
action_space = env.action_space 

actor_param = {
    "obs_space": obs_space,
    "action_space": action_space,
    "flags": flags,
    "tree_rep_meaning": env.get_tree_rep_meaning() if not disable_thinker else None,
    "record_state": True,
}
actor_net = ActorNet(**actor_param)

# Load the trained actor weights (mapped to CPU first), move to `device`, eval mode.
if not mcts:
    path = os.path.join(ckpdir, "ckp_actor.tar")
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    actor_net.set_weights(checkpoint["actor_net_state_dict"])
    actor_net.to(device)
    actor_net.train(False)

state = env.reset()
env_out = init_env_out(state, flags=flags, dim_actions=actor_net.dim_actions, tuple_action=actor_net.tuple_action)  
actor_state = actor_net.initial_state(batch_size=env_n, device=device)

file_idx = 0

# create dir
# Pick the first unused "<outdir>/<xpid>-<n>" (n = 0, 1, ...) and create it.
n = 0
while True:
    name = "%s-%d" % (xpid, n)
    outdir_ = os.path.join(outdir, name)
    if not os.path.exists(outdir_):
        os.makedirs(outdir_)
        print(f"Outputting to {outdir_}")
        break
    n += 1
outdir = outdir_

# Steps per planning stage: flags.rec_t for thinker runs; for imaginary / MCTS rollouts a
# recorded stage is one real step plus delay_n model steps.
rec_t=flags.rec_t if not im_rollout and not mcts else delay_n + 1
# t = 3200 // env_n stages per file, i.e. about 3200 (stage, env) samples per file.
detect_buffer = DetectBuffer(outdir=outdir, t=3200//env_n, rec_t=rec_t, logger=_logger, delay_n=delay_n)
# Number of data files to write; the collection loop renames the last one to val.pt.
file_n = total_n // (env_n * detect_buffer.t) + 1
_logger.info(f"Data output directory: {outdir}")
_logger.info(f"Number of file to be generated: {file_n}")

with torch.set_grad_enabled(False):

    # Initial action from the observation returned by env.reset().
    actor_out, actor_state = actor_net(env_out=env_out, core_state=actor_state, greedy=greedy)
    if not disable_thinker:
        primary_action, reset_action = actor_out.action
    else:
        primary_action, reset_action = actor_out.action, None

    # save setting
    env_state_shape = env.observation_space["real_states"].shape[1:]
    #if rescale: env_state_shape = (3, 40, 40)
    tree_rep_shape = env.observation_space["tree_reps"].shape[1:] if not disable_thinker else None
    hidden_state_shape = actor_net.hidden_state.shape[1:] if disable_thinker else None

    # NOTE(review): "rec_t" stores flags.rec_t, while the buffer uses delay_n + 1 for
    # im_rollout / mcts runs; the analysis cell reads it for tree_rep offsets — confirm intended.
    flags_detect = {
        "dim_actions": actor_net.dim_actions,
        "num_actions": actor_net.num_actions,
        "tuple_actions": actor_net.tuple_action,
        "name": flags.name,
        "env_state_shape": list(env_state_shape),
        "tree_rep_shape": list(tree_rep_shape) if not disable_thinker else None,
        "hidden_state_shape": list(hidden_state_shape) if disable_thinker else None,
        "rec_t": flags.rec_t,
        "delay_n": delay_n,
        "ckpdir": ckpdir,
        "xpid": xpid,
        "dxpid": name,
        "disable_thinker": disable_thinker,
        "im_rollout": im_rollout,
        "mcts": mcts,
    }

    yaml_file_path = os.path.join(outdir, 'config_detect.yaml')
    with open(yaml_file_path, 'w') as file:
        yaml.dump(flags_detect, file)

    rets = []
    last_file_idx = None

    while True:
        state, reward, done, info = env.step(
            primary_action=primary_action,
            reset_action=reset_action,
            action_prob=actor_out.action_prob[-1])

        env_out = create_env_out(actor_out.action, state, reward, done, info, flags=flags)
        if torch.any(done):
            rets.extend(info["episode_return"][done].cpu().tolist())

        # status 0 / 3 = real step, 2 / 3 = last step of a planning stage (as read by
        # DetectBuffer); im_rollout runs only take real steps.
        step_status = info['step_status'][0].item() if not im_rollout else 0

        if im_rollout or (mcts and step_status in [2, 3]):
            # generate imaginary rollout or most visited rollout (mcts)
            if im_rollout:
                action = actor_out.action.unsqueeze(0)
            else:
                action = actor_out.action[0].unsqueeze(0)
            model_net_out = env.model_net(
                x=env_out.real_states[0],
                done=env_out.done[0],
                actions=action,
                state=None,
                one_hot=False,
                ret_xs=True
            )
            new_env_out = env_out
            if mcts:
                most_visited_actions = torch.tensor(env.most_visited_path(delay_n), dtype=torch.long, device=device)
            for m in range(delay_n):
                if not mcts:
                    # NOTE(review): this advances actor_state through the imagined steps and
                    # that state is then used for the next real action — confirm intended.
                    actor_out, actor_state = actor_net(env_out=new_env_out, core_state=actor_state, greedy=greedy)
                    action = actor_out.action
                else:
                    action = most_visited_actions[m]
                model_net_out = env.model_net.forward_single(
                    state=model_net_out.state,
                    action=action,
                    one_hot=False,
                    ret_xs=True,
                )
                new_state = {"real_states": (torch.clamp(model_net_out.xs,0,1)*255).to(torch.uint8)[0]}
                new_env_out = create_env_out(action, new_state, reward, done, info, flags=flags)
                xs = {
                    "env_state": model_net_out.xs[0].half(),
                    "pri_action": action,
                    "cost": torch.zeros_like(info["cost"]),
                }
                if im_rollout: xs["hidden_state"] = actor_net.hidden_state
                if mcts: xs["reset_action"] = torch.zeros_like(actor_out.action[1])
                # Imagined steps carry no cost and no episode end. info["cost"] is the
                # template (y below is info['cost']); `y` itself is only assigned later
                # in the loop body, so it does not exist yet on the first iteration.
                file_idx = detect_buffer.insert(
                    xs,
                    torch.zeros_like(info["cost"]),
                    torch.zeros_like(done),
                    1 if m < delay_n - 1 else 2
                )

        actor_out, actor_state = actor_net(env_out=env_out, core_state=actor_state, greedy=greedy)
        if not disable_thinker:
            primary_action, reset_action = actor_out.action
        else:
            primary_action, reset_action = actor_out.action, None

        # write to detect buffer
        if not disable_thinker:
            env_state = env_out.xs[0]
        else:
            env_state = env_out.real_states[0]
            env_state = env.normalize(env_state)

        env_state = env_state.half()
        xs = {
            "env_state": env_state,
            "pri_action": primary_action,
            "cost": info["cost"],
        }
        if not disable_thinker:
            if not mcts: xs.update({"tree_rep": state["tree_reps"]})
            xs.update({"reset_action": actor_out.action[1]})
        else:
            xs.update({"hidden_state": actor_net.hidden_state})
        y = info['cost']

        if not (mcts and step_status != 0): # no recording for non-real mcts
            file_idx = detect_buffer.insert(xs, y, done, step_status)

        if file_idx >= file_n:
            # last file is for validation
            os.rename(f'{outdir}/data_{file_idx}.pt', f'{outdir}/val.pt')
            break

        if last_file_idx is not None and file_idx != last_file_idx:
            # avoid a mean-of-empty warning before the first episode has finished
            mean_ret = np.mean(rets) if rets else float("nan")
            print(f"Episode {len(rets)}; Return  {mean_ret}")

        last_file_idx = file_idx

    env.close()

In [115]:
from detect_train import *
from PIL import Image
import torchvision.transforms as transforms
import torch

# Machine-specific data root (absolute path); adjust when running elsewhere.
DATA_ROOT = "/home/scuk/RS/thinker/data"

#datadir = os.path.join(DATA_ROOT, "detect/v5_sok-32052928-0")
datadir = os.path.join(DATA_ROOT, "detect/v5c_sp0-49956736-0")
dataset = CustomDataset(datadir=datadir, transform=None, chunk_n=1, data_n=10000)
# Build the chunk-aware sampler once and pass that same instance to the loader.
sampler = ChunkSampler(dataset)
dataloader = DataLoader(dataset, batch_size=2048, sampler=sampler)

device = torch.device("cuda")

# load setting written by the data-collection run
yaml_file_path = os.path.join(datadir, 'config_detect.yaml')
with open(yaml_file_path, 'r') as file:
    flags_data = yaml.safe_load(file)
flags_data = argparse.Namespace(**flags_data)
num_actions = flags_data.num_actions
rec_t = flags_data.rec_t

# Template image for find_max_similarity_single_function (used by the sokoban detector).
image_path = os.path.join(DATA_ROOT, 'player_on_dan_small.bmp')
transform = transforms.Compose([
    transforms.ToTensor(),  # converts to a tensor, scales to [0, 1]
])
# the context manager closes the file handle once the pixels have been converted
with Image.open(image_path) as image:
    search_image = transform(image).to(device)



In [159]:
def find_max_similarity_single_function(x, search_image, block_size=8):
    """Max cosine similarity between any non-overlapping block of each image and a template.

    Args:
        x (tensor): images of shape (B, C, H, W). Rows/columns that do not fill a
            whole block at the bottom/right edge are ignored.
        search_image (tensor): template with C * block_size * block_size elements,
            laid out as (C, block_size, block_size).
        block_size (int): side length of the square blocks (default 8).
    Returns:
        tensor of shape (B,): the largest block/template cosine similarity per image.
    """
    from torch.nn import functional as F  # explicit, instead of relying on a star import

    B, C, H, W = x.shape
    num_blocks_h = H // block_size  # number of blocks along the height
    num_blocks_w = W // block_size  # number of blocks along the width
    # crop partial edge blocks so the reshape below is always valid
    x = x[:, :, :num_blocks_h * block_size, :num_blocks_w * block_size]

    # (B, C, nh, bs, nw, bs) -> (B, nh, nw, C, bs, bs) -> (B, nh*nw, C*bs*bs)
    x_blocks = (
        x.reshape(B, C, num_blocks_h, block_size, num_blocks_w, block_size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(B, num_blocks_h * num_blocks_w, C * block_size * block_size)
    )
    # Cosine similarity: L2-normalise each flattened block (over C*bs*bs, not just
    # channels) and the flattened template. The +1e-6 gives all-zero blocks a uniform
    # direction instead of a zero vector.
    x_blocks_norm = F.normalize(x_blocks + 1e-6, p=2, dim=-1)
    search_image_norm = F.normalize(torch.flatten(search_image), p=2, dim=-1)

    similarity = torch.sum(x_blocks_norm * search_image_norm, dim=-1)  # (B, nh*nw)

    # best-matching block per image
    max_similarity, _ = similarity.max(dim=1)
    return max_similarity

def mask_top_rank(x, rank):
    """Mask the entries of each row equal to that row's `rank`-th largest distinct value.

    Args:
        x (tensor): shape (B, N).
        rank (int): 0 selects the row maximum, 1 the second-largest distinct value, ...
    Returns:
        bool tensor of shape (B, N); rows with fewer than rank + 1 distinct values
        are all False.
    """
    B, N = x.shape
    sorted_values, _ = x.sort(dim=1, descending=True)
    # True wherever the sorted value changes, i.e. a new distinct value starts.
    # Compared with != (not subtraction) so equal infinities count as one value.
    value_changes = sorted_values[:, 1:] != sorted_values[:, :-1]
    # dense rank of every sorted position (0 = largest distinct value)
    dense_rank = torch.cumsum(value_changes, dim=-1)
    dense_rank = torch.cat(
        [torch.zeros(B, 1, dtype=dense_rank.dtype, device=x.device), dense_rank], dim=-1
    )
    hits = dense_rank == rank
    found = hits.any(dim=-1)
    # first sorted position holding that rank
    idx = torch.argmax(hits.to(torch.long), dim=1)
    rank_values = sorted_values[torch.arange(B, device=x.device), idx]
    mask = x == rank_values.unsqueeze(-1)
    mask &= found.unsqueeze(-1)
    return mask

# Rule-based detector evaluation. A sample is predicted positive when a "hit" occurs
# at step 0 of the stage, or at a step whose last rollout return is among the top
# (search_rank + 1) distinct values in that stage.
eval_results = {}
search_rank = 0

# tree_rep offsets of the reset flag and the rollout return (layout from flags_data).
idx_reset = num_actions * 4 + 6
idx_rr = idx_reset + flags_data.rec_t + 1

# For point goal: features 60*i+22 .. 60*i+37 for i = 3.
# NOTE(review): assumes the last env_state dim has 240 entries laid out in blocks of 60 — confirm.
feature_mask = torch.zeros(240, dtype=torch.bool, device=device)
for i in range(3, 4):
    feature_mask[60*i+22:60*i+22+16] = True
    #feature_mask[60*i+41:60*i+41+16] = True

with torch.set_grad_enabled(False):

    for xs, target_y in dataloader:

        env_state = xs["env_state"].to(device)
        tree_rep = xs["tree_rep"].to(device)
        target_y = target_y.to(device)

        # B = batch size, T = steps per planning stage in this data
        # (kept separate from the global rec_t)
        B, T = env_state.shape[:2]

        # for sokoban
        # max_sim = find_max_similarity_single_function(torch.flatten(env_state, 0, 1), search_image)
        # max_sim = max_sim.view(B, T)

        # for point goal: a step is a hit if any masked feature exceeds 0.9
        max_sim = torch.sum((env_state[:, :, feature_mask] > 0.9).float(), dim=-1) >= 1

        # compute last rollout return
        reset = tree_rep[:, :, idx_reset].bool()
        rollout_return = tree_rep[:, :, idx_rr]

        # Back-fill: step n takes the rollout return at the first step >= n whose reset
        # flag is set; if there is none, it takes the stage's final value.
        last_rollout_return = rollout_return.clone()
        r = last_rollout_return[:, -1].clone()
        for n in range(T - 1, -1, -1):
            r[reset[:, n]] = last_rollout_return[reset[:, n], n]
            last_rollout_return[:, n] = r

        # steps inspected: step 0 plus the top-ranked last rollout returns
        search_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
        search_mask[:, 0] = True
        for m in range(search_rank + 1):
            search_mask = search_mask | mask_top_rank(last_rollout_return, m)

        max_sim[~search_mask] = 0
        max_sim = torch.max(max_sim, dim=-1)[0]
        pred_y = max_sim > 0.95

        result = evaluate_detect(target_y, pred_y)
        for k, v in result.items():
            eval_results.setdefault(k, []).append(v)

# NOTE: unweighted mean of per-batch metrics (a smaller final batch counts as much as a full one).
for k in eval_results:
    eval_results[k] = np.mean(np.array(eval_results[k]))

print(eval_results)


{'acc': 0.870221757888794, 'rec': 0.6387614607810974, 'prec': 0.6094553828239441, 'f1': 0.6229297759159571, 'neg_p': 0.8314314842224121}


In [18]:
# Split the steps of the last evaluated batch by whether they match the search image,
# then keep up to 20 random examples of each group for visual inspection.
def _random_subset(candidates, k=20):
    """Return at most k entries of the 1-D index tensor `candidates`, in random order."""
    order = torch.randperm(candidates.shape[0])
    return candidates[order][:k]

max_sim = find_max_similarity_single_function(torch.flatten(env_state, 0, 1), search_image)
env_state_ = env_state.flatten(0, 1)

d_env_state = env_state_[_random_subset(torch.nonzero(max_sim > 0.95).squeeze(-1))]
s_env_state = env_state_[_random_subset(torch.nonzero(max_sim < 0.95).squeeze(-1))]

In [37]:
import os

import torch.nn.functional as F
from torchvision.utils import save_image

border_size = 0   # width in pixels of the black border added around each sample
scale_factor = 6  # nearest-neighbour upscaling factor
sample_dir = "../data/sample"


def _pad_and_upscale(imgs, border, scale):
    """Zero-pad (N, C, H, W) images by `border` pixels on every side, then upscale by `scale`.

    Padding keeps the input's channel count instead of assuming 3 channels.
    """
    if border > 0:
        imgs = F.pad(imgs, (border, border, border, border))
    return F.interpolate(imgs, scale_factor=scale, mode='nearest')


def _save_samples(imgs, prefix):
    """Write each image of an (N, C, H, W) batch to {sample_dir}/{prefix}_{i}.png."""
    for i, img in enumerate(imgs):
        save_image(img, os.path.join(sample_dir, f"{prefix}_{i}.png"))


# make sure the output directory exists before writing images into it
os.makedirs(sample_dir, exist_ok=True)

d_env_state_ = _pad_and_upscale(d_env_state, border_size, scale_factor)
s_env_state_ = _pad_and_upscale(s_env_state, border_size, scale_factor)
_save_samples(d_env_state_, "d")
_save_samples(s_env_state_, "s")


In [14]:
# deprecated

import os 
from torch.utils.data import Dataset, DataLoader
import torch
from thinker import util

# Absolute, user-expanded location of the (deprecated) sokoban detection data.
datadir = os.path.abspath(os.path.expanduser("../data/detect/v5_sok-5993808-1/"))

class CustomDataset(Dataset):
    """(Deprecated) Per-sample view over the ``*.pt`` files in ``datadir``.

    Each file holds ``(xs, y)`` with ``xs[k]`` of shape (t, rec_t, b, ...) and ``y``
    of shape (t, b); sample ``idx`` maps to one (file, t_idx, b_idx) triple and
    yields ``xs[k][t_idx, :, b_idx]`` (without 'step_status' / 'done') and ``y[t_idx, b_idx]``.

    NOTE(review): this shadows the ``CustomDataset`` imported from detect_train (which
    takes chunk_n / data_n); re-running the data-loading cell after this one would call
    this class instead.
    """

    def __init__(self, datadir, transform=None):
        self.datadir = datadir
        # sorted so the file -> sample-index mapping is deterministic
        self.file_list = sorted(f for f in os.listdir(datadir) if f.endswith('.pt'))
        if not self.file_list:
            raise FileNotFoundError(f"no .pt files found in {datadir}")
        self.transform = transform  # stored but not applied
        # one-file cache: avoids re-loading the whole file for every sample
        self._cached_file_idx, self._cached_data = None, None
        xs, y = self._load(0)
        self.t = xs['env_state'].shape[0]
        self.b = xs['env_state'].shape[2]
        self.samples_per_file = self.t * self.b

    def _load(self, file_idx):
        """Return the (xs, y) tuple stored in file ``file_idx``, loading it if not cached."""
        if file_idx != self._cached_file_idx:
            self._cached_data = torch.load(os.path.join(self.datadir, self.file_list[file_idx]))
            self._cached_file_idx = file_idx
        return self._cached_data

    def __len__(self):
        return len(self.file_list) * self.samples_per_file  # assumes every file has the same (t, b)

    def __getitem__(self, idx):
        file_idx = idx // self.samples_per_file
        within_file_idx = idx % self.samples_per_file
        t_idx = within_file_idx // self.b
        b_idx = within_file_idx % self.b
        xs, y = self._load(file_idx)
        # build a new dict rather than pop(): the loaded dict is cached and must stay intact
        xs = {k: v for k, v in xs.items() if k not in ('step_status', 'done')}
        xs = util.dict_map(xs, lambda x: x[t_idx, :, b_idx])
        y = y[t_idx, b_idx]
        return xs, y

# To load data and train
# Smoke test for the deprecated dataset: build it and pull one shuffled batch of 64 samples.
dataset = CustomDataset(datadir=datadir)
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=64)
train_features, train_labels = next(iter(train_dataloader))


0.2288818359375

In [25]:
from thinker.actor_net import DRCNet, ActorNetBase
from thinker.main import Env
from thinker.self_play import init_env_out, create_env_out
from thinker import util
import torch
import numpy as np

# Build a fresh Sokoban env and an (untrained) ActorNetBase for the checkpoint-conversion
# experiments below. Note: this rebinds env_n, flags, env, obs_space, action_space, device
# and actor_net from the data-collection cell.
env_n = 16
# NOTE(review): per the recorded cell output, this wrote config_c.yaml into a new
# logs/detect-<timestamp> directory — presumably a side effect of create_setting/Env; confirm.
flags = util.create_setting(args=[], drc=False, save_flags=False, see_h=True, legacy=True, wrapper_type=0, has_action_seq=False)
env = Env(
        name="Sokoban-v0",
        env_n=env_n,
        gpu=True,
        train_model=False,
        parallel=False,
        return_x=True,
        return_h=True,
        flags=flags,
    )

obs_space = env.observation_space
action_space = env.action_space 

device = torch.device("cuda")
# Alternative DRC actor, kept for reference:
# actor_net = DRCNet(obs_space=obs_space, action_space=action_space, flags=flags, tree_rep_meaning=None)
actor_net = ActorNetBase(obs_space=obs_space, action_space=action_space, flags=flags, tree_rep_meaning=None)


Initializing env 0 with device cuda
Model network size: 6637133


Symlinked log directory: /home/stephen/RS/thinker/notebook/logs/latest
Wrote config file to /home/stephen/RS/thinker/notebook/logs/detect-20240205-143653/config_c.yaml


In [26]:
# Inspect the tree-representation observation space (the recorded output shows (env_n, 111)).
obs_space["tree_reps"]

Box(-inf, inf, (16, 111), float32)

In [None]:
# Old checkpoint parameter names -> ActorNetBase module names. The substitutions are plain
# substring replacements applied in this order.
# NOTE(review): "core" matches any key containing that substring — confirm no other
# parameter names contain it.
_key_renames = [
    ("actor_encoder", "h_encoder"),
    ("core", "tree_rep_encoder.rnn"),
    ("initial_enc", "tree_rep_encoder.rnn_in_fc"),
    ("model_stat_fc", "tree_rep_encoder.rnn_out_fc"),
]


def _rename_key(key):
    """Apply every substitution of `_key_renames`, in order, to one parameter name."""
    for old, new in _key_renames:
        key = key.replace(old, new)
    return key


state_dict = torch.load("../logs/detect/v1a_base_dirloss/ckp_actor.tar")["actor_net_state_dict"]
new_state_dict = {_rename_key(name): tensor for name, tensor in state_dict.items()}
actor_net.load_state_dict(new_state_dict)

In [None]:
# List every parameter/buffer of the current actor with its shape.
for param_name, param in actor_net.state_dict().items():
    print(param_name, param.shape)

In [None]:
# List every remapped checkpoint entry with its shape (compare with the actor's own).
for param_name, param in new_state_dict.items():
    print(param_name, param.shape)

In [None]:
# Load only the actor weights from the old checkpoint and dump them for inspection.
checkpoint = torch.load("../logs/detect/v1a_base_dirloss/ckp_actor.tar")
checkpoint = checkpoint["actor_net_state_dict"]
print(checkpoint)

In [None]:
# Roll out the actor (non-greedy) and report the running mean episode return.
max_episodes = 1000  # stop after this many finished episodes so the cell terminates

actor_net = actor_net.to(device)
state = env.reset()
env_out = init_env_out(state, flags=flags, dim_actions=actor_net.dim_actions, tuple_action=actor_net.tuple_action)
actor_state = actor_net.initial_state(batch_size=env_n, device=device)
rets = []

with torch.set_grad_enabled(False):

    while len(rets) < max_episodes:
        actor_out, actor_state = actor_net(env_out=env_out, core_state=actor_state, greedy=False)
        primary_action, reset_action = actor_out.action, None
        state, reward, done, info = env.step(
            primary_action=primary_action,
            reset_action=reset_action)
        if torch.any(done):
            rets.extend(info["episode_return"][done].cpu().tolist())
            print(f"Episode {len(rets)}; Return  {np.mean(np.array(rets))}")
        env_out = create_env_out(primary_action, state, reward, done, info, flags=flags)

In [8]:
# Keep only the actor-related entries of a full checkpoint and save them as a new,
# slimmer checkpoint.
# NOTE(review): `checkpoint` must be the full checkpoint dict. Several cells rebind it,
# and one of them binds only the actor state dict, which lacks these keys.
cs = ['step', 'real_step', 'actor_net_optimizer_state_dict', 'actor_net_scheduler_state_dict', 'actor_net_state_dict']
missing = [c for c in cs if c not in checkpoint]
if missing:
    raise KeyError(f"checkpoint is missing {missing}; load the full checkpoint first")
# build a fresh dict (previously `checkpoint_` had to pre-exist from an earlier session)
checkpoint_ = {c: checkpoint[c] for c in cs}
out_path = "../logs/detect/v5b_sok_drc/ckp_actor.tar"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
torch.save(checkpoint_, out_path)


In [None]:
import numpy as np

# Roll out an actor that uses the `obs=` calling convention (the commented lines show the
# ActorNet-style call) and report the running mean episode return.
max_episodes = 1000  # stop after this many finished episodes so the cell terminates

actor_net = actor_net.to(device)
state = env.reset()
#env_out = init_env_out(state, flags=flags, dim_actions=actor_net.dim_actions, tuple_action=actor_net.tuple_action)
env_out = init_env_out(state, flags=flags, dim_actions=1, tuple_action=False)
actor_state = actor_net.initial_state(batch_size=env_n, device=device)
rets = []

with torch.set_grad_enabled(False):

    while len(rets) < max_episodes:
        #actor_out, actor_state = actor_net(env_out=env_out, core_state=actor_state, greedy=False)
        #primary_action, reset_action = actor_out.action, None
        primary_action, actor_state = actor_net(obs=env_out, core_state=actor_state, greedy=False)
        primary_action = primary_action[0]
        reset_action = None

        state, reward, done, info = env.step(
            primary_action=primary_action,
            reset_action=reset_action)
        if torch.any(done):
            rets.extend(info["episode_return"][done].cpu().tolist())
            print(f"Episode {len(rets)}; Return  {np.mean(np.array(rets))}")
        env_out = create_env_out(primary_action, state, reward, done, info, flags=flags)

In [None]:
from thinker.wrapper import DMSuiteEnv

# Smoke test: one episode of acrobot-swingup with random actions. A separate name
# (dm_env) is used so the Sokoban `env` from the earlier cells is not overwritten.
dm_env = DMSuiteEnv(domain_name="acrobot", task_name="swingup", rgb=False)
obs = dm_env.reset()
done = False
episode_return, episode_len = 0.0, 0
while not done:
    action = dm_env.action_space.sample()
    obs, reward, done, _ = dm_env.step(action)
    episode_return += reward
    episode_len += 1
print(f"Episode finished after {episode_len} steps; return {episode_return}; final obs {obs}")
    