In [1]:
%load_ext autoreload
%autoreload 2
import os
# Must be set before torch is imported further down in this cell.
# NOTE(review): "-1" is presumably meant to hide every CUDA device so the
# notebook runs on CPU only — confirm this is the intent.
os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
import numpy as np
from matplotlib import pyplot as plt
import torch
from sparse_causal_model_learner_rl.sacred_gin_tune.sacred_wrapper import load_config_files
from sparse_causal_model_learner_rl.learners.rl_learner import CausalModelLearnerRL
from sparse_causal_model_learner_rl.config import Config
from keychest.features_xy import arr_to_dict, dict_to_arr, obs_features_handcoded
from keychest.gofa_model import manual_model_features
from tqdm.auto import tqdm
from torch import nn
import gin

In [2]:
# Gin configuration files, given relative to the notebook's working directory.
config_files = [
    '../keychest/config/5x5_1f1c1k.gin',
    '../sparse_causal_model_learner_rl/configs/rec_nonlin_gnn.gin',
]

# NOTE(review): interactive mode is presumably enabled so that re-running
# cells does not clash with already-registered gin configurables — confirm.
gin.enter_interactive_mode()

# Left as the last expression so the value returned by the loader is shown.
load_config_files(config_files)

['5x5_1f1c1k', 'rec_nonlin_gnn']

In [3]:
# Build the RL causal-model learner from a default-constructed Config
# (presumably populated by the gin files loaded above — TODO confirm).
# The recorded run prints the environment being created (KeyChest-v0).
learner = CausalModelLearnerRL(Config())

Make environment KeyChest-v0 None {}


In [4]:
# NOTE(review): presumably instantiates the models/optimizers declared in the
# gin config; the call is opaque from this notebook — confirm in
# CausalModelLearnerRL.
learner.create_trainables()



In [5]:
# Gather environment transitions. The following cell reads 'obs_x', 'obs_y'
# and 'action_x' from learner._context, which is presumably filled by this
# call — TODO confirm.
learner.collect_steps()

In [6]:
# Pull the collected batches out of the learner's (private) context dict.
# NOTE(review): obs_x / obs_y look like observations before / after a step and
# action_x like the action taken in between — confirm against the learner.
obs_x, obs_y, action_x = (
    learner._context[name] for name in ('obs_x', 'obs_y', 'action_x')
)

In [7]:
# Mean of the squared absolute element-wise difference between obs_x and obs_y.
(obs_x - obs_y).abs().pow(2).mean()

tensor(0.7311)

In [8]:
# Hand-coded feature dict computed from the environment engine. In this
# notebook only its keys are used afterwards (as the feature names handed to
# arr_to_dict and iterated over in the comparison loop).
f_sample = obs_features_handcoded(learner.env.engine)

In [9]:
# Convert every observation in each batch into a {feature name: value} dict,
# using the key order of f_sample. obs_x is processed first, then obs_y.
f_x, f_y = (
    [arr_to_dict(observation.numpy(), f_sample.keys()) for observation in batch]
    for batch in (obs_x, obs_y)
)

In [10]:
# Show the first feature dict from each batch side by side for a visual check.
f_x[0], f_y[0]

({'button__x': 9.0,
  'button__y': 2.0,
  'chest__00__x': 7.0,
  'chest__00__y': 3.0,
  'food__00__x': 7.0,
  'food__00__y': 2.0,
  'health': 8.0,
  'key__00__x': 5.0,
  'key__00__y': 4.0,
  'keys': 0.0,
  'lamp_off__x': 5.0,
  'lamp_off__y': 3.0,
  'lamp_on__x': -1.0,
  'lamp_on__y': -1.0,
  'lamp_status': 0.0,
  'player__x': 6.0,
  'player__y': 4.0},
 {'button__x': 9.0,
  'button__y': 2.0,
  'chest__00__x': 7.0,
  'chest__00__y': 3.0,
  'food__00__x': 7.0,
  'food__00__y': 2.0,
  'health': 7.0,
  'key__00__x': 5.0,
  'key__00__y': 4.0,
  'keys': 0.0,
  'lamp_off__x': 5.0,
  'lamp_off__y': 3.0,
  'lamp_on__x': -1.0,
  'lamp_on__y': -1.0,
  'lamp_status': 0.0,
  'player__x': 5.0,
  'player__y': 4.0})

In [11]:
# First entry of action_x (a 4-element vector with a single 1 in the recorded
# run — presumably a one-hot action encoding; TODO confirm).
action_x[0]

tensor([0., 1., 0., 0.])

In [12]:
# NOTE(review): sum of the engine's food_rows and keys_rows plus one; what this
# quantity stands for is not evident from the notebook — add an explanation or
# verify against the keychest engine.
learner.env.engine.food_rows + learner.env.engine.keys_rows + 1

5

In [13]:
# Apply the hand-written model to every (feature dict, action) pair; the
# results are compared against f_y in the next cell.
f_t1 = [
    manual_model_features(features, action.numpy(), learner.env.engine)
    for features, action in zip(f_x, action_x)
]

In [14]:
# Per feature name, up to 10 examples where the manual model's output (f_t1)
# disagrees with the reference features (f_y). A feature gets an entry as soon
# as one mismatch is seen, even once its example list is already full.
keys_differ = {}
for expected, predicted in zip(f_y, f_t1):
    for key in f_sample.keys():
        if expected[key] != predicted[key]:
            examples = keys_differ.setdefault(key, [])
            if len(examples) < 10:
                examples.append({'correct': expected[key],
                                 'given': predicted[key]})


In [15]:
# Display the collected mismatches; an empty dict means no compared feature
# differed between f_y and f_t1 for any compared pair.
keys_differ

{}