**This file implements the baseline methods that we compare against our method in the paper. Most of the code in this file (excluding the OPE part) is credited to the authors of those papers.**

In [None]:
import json
#!! do not import matplotlib until you check input arguments
import numpy as np
import os
import seeding
import sys
import torch
from tqdm import tqdm
import logging
import shutil
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import copy
import pandas as pd
import os
import pickle
from collections import defaultdict
import argparse
import glob
import json
import numpy as np
import os
import random
import torch

In [3]:
def get_parser():
    """Create an empty argument parser whose help text shows default values.

    Returns
    -------
    argparse.ArgumentParser
        A parser with no arguments registered yet, configured with
        ``argparse.ArgumentDefaultsHelpFormatter``.
    """
    formatter = argparse.ArgumentDefaultsHelpFormatter
    parser = argparse.ArgumentParser(formatter_class=formatter)
    return parser

def load_experiment(tag, coefs=None):
    """Load the per-seed training logs of an experiment.

    Log files are looked up with the glob pattern
    ``results/logs/<tag>*/train-*.txt``. Every line of a log file is parsed
    as one JSON object, and all objects of a file are expected to share the
    keys of the first one.

    Parameters
    ----------
    tag : str
        Prefix of the experiment directory name under ``results/logs``.
    coefs : dict, optional
        Maps loss field names to their weights. Only used to synthesize the
        total loss ``'L'`` for logs that do not record it. Defaults to
        ``{'L_inv': 1.0, 'L_fwd': 0.1, 'L_cpc': 1.0, 'L_fac': 0.1}``.

    Returns
    -------
    dict
        Maps each seed (the text between the last ``-`` of the path and the
        following ``.``) to a dict of ``field -> np.ndarray`` holding one
        entry per log line. Empty if no file matches.
    """
    pattern = os.path.join('results/logs', tag + '*', 'train-*.txt')
    logfiles = sorted(glob.glob(pattern))
    seeds = [f.split('-')[-1].split('.')[0] for f in logfiles]

    # Read through a context manager so every file handle is closed promptly.
    logs = []
    for path in logfiles:
        with open(path, 'r') as log_file:
            logs.append(log_file.read().splitlines())

    def read_log(log, coefs=coefs):
        results = [json.loads(item) for item in log]
        fields = results[0].keys()
        data = {f: np.asarray([item[f] for item in results]) for f in fields}
        if coefs is None:
            coefs = {
                'L_inv': 1.0,
                'L_fwd': 0.1,
                'L_cpc': 1.0,
                'L_fac': 0.1,
            }
        if 'L' not in fields:
            # Rebuild the total loss as the weighted sum of its terms; the
            # factorization term is shifted by -1 before weighting.
            data['L'] = sum([
                coefs[f] * data[f] if f != 'L_fac' else coefs[f] * (data[f] - 1)
                for f in coefs
            ])
        return data

    return dict(zip(seeds, (read_log(log) for log in logs)))

In [5]:
class Network(torch.nn.Module):
    """Base module with parameter counting, checkpoint and freezing helpers.

    Printing an instance shows the regular module description followed by
    the total number of parameters.
    """
    def __init__(self):
        super().__init__()
        # True while freeze() has disabled gradients for all parameters.
        self.frozen = False

    def __str__(self):
        # numel() is always an int, so the total stays an integer even when
        # the module owns 0-dim (scalar) parameters.
        n_params = sum(p.numel() for p in self.parameters())
        return super().__str__() + '\n' + 'Total params: {}'.format(n_params)

    def print_summary(self):
        """Print the module description together with its parameter count."""
        print(str(self))

    def save(self, name, model_dir, is_best=False):
        """Write the state dict to ``<model_dir>/<name>_latest.pytorch``.

        Parameters
        ----------
        name : str
            File name prefix of the checkpoint.
        model_dir : str
            Target directory; created if it does not exist.
        is_best : bool
            If True, the checkpoint is additionally copied to
            ``<model_dir>/<name>_best.pytorch``.
        """
        os.makedirs(model_dir, exist_ok=True)
        model_file = os.path.join(model_dir, '{}_latest.pytorch'.format(name))
        torch.save(self.state_dict(), model_file)
        logging.info('Model saved to {}'.format(model_file))
        if is_best:
            best_file = os.path.join(model_dir, '{}_best.pytorch'.format(name))
            shutil.copyfile(model_file, best_file)
            logging.info('New best model! Model copied to {}'.format(best_file))

    def load(self, model_file, force_cpu=False):
        """Load a state dict saved with ``save`` into this module.

        Parameters
        ----------
        model_file : str
            Path of the checkpoint file.
        force_cpu : bool
            If True, all tensors are mapped to the CPU while loading.
        """
        logging.info('Loading model from {}...'.format(model_file))
        map_loc = 'cpu' if force_cpu else None
        state_dict = torch.load(model_file, map_location=map_loc)
        self.load_state_dict(state_dict)

    def freeze(self):
        """Disable gradients for every parameter (no-op if already frozen)."""
        if not self.frozen:
            for param in self.parameters():
                param.requires_grad = False
            self.frozen = True

    def unfreeze(self):
        """Re-enable gradients for every parameter (no-op if not frozen)."""
        if self.frozen:
            for param in self.parameters():
                param.requires_grad = True
            self.frozen = False

In [6]:
class AutoEncoder(Network):
    """Encoder/decoder pair trained with a mean-squared reconstruction loss.

    The encoder is a ``PhiNet``. The decoder is obtained by mirroring the
    encoder layers (each ``Linear`` with its input and output sizes swapped)
    and ends with a ``Tanh`` followed by a reshape to the input shape.

    Parameters
    ----------
    n_actions : int
        Stored on the instance; not used by the reconstruction loss.
    input_shape : int or sequence of int
        Shape of a single observation.
    n_latent_dims : int
        Size of the latent code.
    n_hidden_layers : int
        Number of hidden layers of the encoder (and of the mirrored decoder).
    n_units_per_layer : int
        Width of each hidden layer.
    lr : float
        Learning rate of the Adam optimizer.
    coefs : dict, optional
        Loss coefficients stored in ``self.coefs`` (missing keys default to
        1.0). They do not influence ``compute_loss``.
    """
    def __init__(self,
                 n_actions,
                 input_shape=2,
                 n_latent_dims=4,
                 n_hidden_layers=1,
                 n_units_per_layer=32,
                 lr=0.001,
                 coefs=None):
        super().__init__()
        self.n_actions = n_actions
        self.n_latent_dims = n_latent_dims
        self.lr = lr
        self.coefs = defaultdict(lambda: 1.0)
        if coefs is not None:
            for k, v in coefs.items():
                self.coefs[k] = v

        # The output reshape unpacks the shape, so a scalar shape (such as
        # the default) has to be wrapped into a tuple first.
        output_shape = (input_shape,) if np.isscalar(input_shape) else tuple(input_shape)

        self.phi = PhiNet(input_shape=input_shape,
                          n_latent_dims=n_latent_dims,
                          n_units_per_layer=n_units_per_layer,
                          n_hidden_layers=n_hidden_layers)
        self.reverse_phi = PhiNet(input_shape=input_shape,
                                  n_latent_dims=n_latent_dims,
                                  n_units_per_layer=n_units_per_layer,
                                  n_hidden_layers=n_hidden_layers)
        # Replace the second encoder's network by its mirror image: drop the
        # leading reshape and the final activation, swap the sizes of every
        # Linear, then reverse the order and finish with Tanh + reshape.
        self.reverse_phi.phi = nn.Sequential(
            *reversed([Reshape(-1, *output_shape), nn.Tanh()] + [
                nn.Linear(l.out_features, l.in_features) if isinstance(l, nn.Linear) else l
                for l in self.reverse_phi.layers[1:-1]
            ]))
        self.mse = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def encode(self, x0):
        """Map observations to latent codes."""
        return self.phi(x0)

    def decode(self, z0):
        """Map latent codes back to observation space."""
        return self.reverse_phi(z0)

    def compute_loss(self, x0):
        """Return the MSE between ``x0`` and its reconstruction."""
        loss = self.mse(x0, self.decode(self.encode(x0)))
        return loss

    def train_batch(self, x0, *args, **kwargs):
        """Run one optimizer step on ``x0`` and return the loss tensor.

        Extra arguments are accepted and ignored so the call signature
        matches ``FeatureNet.train_batch``.
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.compute_loss(x0)
        loss.backward()
        self.optimizer.step()
        return loss
class ContrastiveNet(Network):
    """MLP with sigmoid output that scores pairs of latent states.

    Parameters
    ----------
    n_latent_dims : int
        Size of each latent state; the network input is two of them
        concatenated.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.frozen = False

        self.layers = []
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(2 * n_latent_dims, 1)])
        else:
            self.layers.extend(
                [torch.nn.Linear(2 * n_latent_dims, n_units_per_layer),
                 torch.nn.Tanh()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend(
                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),
                     torch.nn.Tanh()])
            self.layers.extend([torch.nn.Linear(n_units_per_layer, 1)])
        self.layers.extend([torch.nn.Sigmoid()])
        self.model = torch.nn.Sequential(*self.layers)

    def forward(self, z0, z1):
        """Return one sigmoid score per ``(z0, z1)`` pair.

        The inputs are concatenated along the last dimension and all
        singleton dimensions of the output are squeezed away.
        """
        context = torch.cat((z0, z1), -1)
        fakes = self.model(context).squeeze()
        return fakes
class CPCNet(Network):
    """MLP with sigmoid output that scores (context, latent) pairs.

    Parameters
    ----------
    n_latent_dims : int
        Size of the context and of the latent vector; the network input is
        both concatenated.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.frozen = False

        self.layers = []
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(2*n_latent_dims, 1)])
        else:
            self.layers.extend([torch.nn.Linear(2*n_latent_dims, n_units_per_layer), torch.nn.Tanh()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend([torch.nn.Linear(n_units_per_layer, n_units_per_layer), torch.nn.Tanh()])
            self.layers.extend([torch.nn.Linear(n_units_per_layer, 1)])
        self.layers.extend([torch.nn.Sigmoid()])
        self.model = torch.nn.Sequential(*self.layers)

    def forward(self, c, z):
        """Return one sigmoid score per ``(c, z)`` pair.

        The inputs are concatenated along the last dimension and all
        singleton dimensions of the output are squeezed away.
        """
        context = torch.cat((c, z), -1)
        fakes = self.model(context).squeeze()
        return fakes
class FeatureNet(Network):
    """State encoder trained with inverse-model, contrastive and distance losses.

    The encoder ``phi`` maps observations to latent states. Auxiliary heads
    (an inverse model, a contrastive inverse discriminator and a transition
    discriminator) provide the training signals. All parameters are
    optimized jointly with Adam.

    Parameters
    ----------
    n_actions : int
        Number of discrete actions.
    input_shape : int or sequence of int
        Shape of a single observation.
    n_latent_dims : int
        Size of the latent state.
    n_hidden_layers : int
        Number of hidden layers of the encoder and the inverse heads. The
        transition discriminator always uses one hidden layer.
    n_units_per_layer : int
        Width of each hidden layer.
    lr : float
        Learning rate of the Adam optimizer.
    coefs : dict, optional
        Loss weights keyed by ``'L_inv'``, ``'L_coinv'``, ``'L_rat'``,
        ``'L_dis'`` and ``'L_ora'``. Missing keys default to 1.0; a weight
        of exactly 0.0 skips the computation of that loss.
    """
    def __init__(self,
                 n_actions,
                 input_shape=2,
                 n_latent_dims=4,
                 n_hidden_layers=1,
                 n_units_per_layer=32,
                 lr=0.001,
                 coefs=None):
        super().__init__()
        self.n_actions = n_actions
        self.n_latent_dims = n_latent_dims
        self.lr = lr
        self.coefs = defaultdict(lambda: 1.0)
        if coefs is not None:
            for k, v in coefs.items():
                self.coefs[k] = v

        self.phi = PhiNet(input_shape=input_shape,
                          n_latent_dims=n_latent_dims,
                          n_units_per_layer=n_units_per_layer,
                          n_hidden_layers=n_hidden_layers)
        # self.fwd_model = FwdNet(n_actions=n_actions, n_latent_dims=n_latent_dims, n_hidden_layers=n_hidden_layers, n_units_per_layer=n_units_per_layer)
        self.inv_model = InvNet(n_actions=n_actions,
                                n_latent_dims=n_latent_dims,
                                n_units_per_layer=n_units_per_layer,
                                n_hidden_layers=n_hidden_layers)
        self.inv_discriminator = InvDiscriminator(n_actions=n_actions,
                                                  n_latent_dims=n_latent_dims,
                                                  n_units_per_layer=n_units_per_layer,
                                                  n_hidden_layers=n_hidden_layers)
        self.discriminator = ContrastiveNet(n_latent_dims=n_latent_dims,
                                            n_hidden_layers=1,
                                            n_units_per_layer=n_units_per_layer)

        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.bce_loss = torch.nn.BCELoss()
        self.mse = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    def inverse_loss(self, z0, z1, a):
        """Cross-entropy between the inverse model's action logits and ``a``."""
        if self.coefs['L_inv'] == 0.0:
            return torch.tensor(0.0)
        a_hat = self.inv_model(z0, z1)
        return self.cross_entropy(input=a_hat, target=a)

    def contrastive_inverse_loss(self, z0, z1, a):
        """BCE loss for telling real actions from uniformly sampled ones.

        Every transition ``(z0, z1)`` appears twice: once with its real
        action (label 0) and once with a random action (label 1).
        """
        if self.coefs['L_coinv'] == 0.0:
            return torch.tensor(0.0)
        N = len(z0)

        a_neg = torch.randint_like(a, low=0, high=self.n_actions)

        # concatenate positive and negative examples
        z0_extended = torch.cat([z0, z0], dim=0)
        z1_extended = torch.cat([z1, z1], dim=0)
        a_pos_neg = torch.cat([a, a_neg], dim=0)
        # Build the labels on the same device as the latent states.
        is_fake = torch.cat(
            [torch.zeros(N, device=z0.device),
             torch.ones(N, device=z0.device)], dim=0)

        # Compute which ones are fakes
        fakes = self.inv_discriminator(z0_extended, z1_extended, a_pos_neg)
        return self.bce_loss(input=fakes, target=is_fake.float())

    def ratio_loss(self, z0, z1):
        """BCE loss for telling real next states from shuffled ones.

        Every ``z0`` appears twice: once with its real ``z1`` (label 0) and
        once with a ``z1`` drawn from a random permutation of the batch
        (label 1).
        """
        if self.coefs['L_rat'] == 0.0:
            return torch.tensor(0.0)
        N = len(z0)
        # shuffle next states
        idx = torch.randperm(N)
        z1_neg = z1.view(N, -1)[idx].view(z1.size())

        # concatenate positive and negative examples
        z0_extended = torch.cat([z0, z0], dim=0)
        z1_pos_neg = torch.cat([z1, z1_neg], dim=0)
        # Build the labels on the same device as the latent states.
        is_fake = torch.cat(
            [torch.zeros(N, device=z0.device),
             torch.ones(N, device=z0.device)], dim=0)

        # Compute which ones are fakes
        fakes = self.discriminator(z0_extended, z1_pos_neg)
        return self.bce_loss(input=fakes, target=is_fake.float())

    def distance_loss(self, z0, z1):
        """Mean squared amount by which ``||z1 - z0||`` exceeds 0.1."""
        if self.coefs['L_dis'] == 0.0:
            return torch.tensor(0.0)
        dz = torch.norm(z1 - z0, dim=-1, p=2)
        max_dz = 0.1
        excess = torch.nn.functional.relu(dz - max_dz)
        return self.mse(excess, torch.zeros_like(excess))

    def oracle_loss(self, z0, z1, d):
        """MSE between latent distances and the given distances ``d / 10``.

        The latent distances are those of the pairs ``(z0, z1)`` followed by
        those of ``(z0, z1.flip(0))``, so ``d`` needs one entry for each of
        these ``2 * len(z0)`` pairs. This loss is not part of
        ``compute_loss``.
        """
        if self.coefs['L_ora'] == 0.0:
            return torch.tensor(0.0)

        dz = torch.cat(
            [torch.norm(z1 - z0, dim=-1, p=2),
             torch.norm(z1.flip(0) - z0, dim=-1, p=2)], dim=0)

        return self.mse(dz, d / 10.0)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def predict_a(self, z0, z1):
        raise NotImplementedError
        # a_logits = self.inv_model(z0, z1)
        # return torch.argmax(a_logits, dim=-1)

    def compute_loss(self, z0, z1, a):
        """Return the weighted sum of the four training losses."""
        loss = 0
        loss += self.coefs['L_coinv'] * self.contrastive_inverse_loss(z0, z1, a)
        loss += self.coefs['L_inv'] * self.inverse_loss(z0, z1, a)
        # loss += self.coefs['L_fwd'] * self.compute_fwd_loss(z0, z1, z1_hat)
        loss += self.coefs['L_rat'] * self.ratio_loss(z0, z1)
        loss += self.coefs['L_dis'] * self.distance_loss(z0, z1)
        return loss

    def train_batch(self, x0, x1, a):
        """Encode ``x0``/``x1``, take one optimizer step and return the loss.

        Parameters
        ----------
        x0, x1 : Tensor
            Batches of observations and of their successors.
        a : LongTensor
            The actions taken between ``x0`` and ``x1``.
        """
        self.train()
        self.optimizer.zero_grad()
        z0 = self.phi(x0)
        z1 = self.phi(x1)
        # z1_hat = self.fwd_model(z0, a)
        loss = self.compute_loss(z0, z1, a)
        loss.backward()
        self.optimizer.step()
        return loss
class InvDiscriminator(Network):
    """MLP with sigmoid output that scores ``(z0, z1, action)`` triples.

    Parameters
    ----------
    n_actions : int
        Number of discrete actions; actions are fed in one-hot encoded.
    n_latent_dims : int
        Size of each latent state.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_actions, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.n_actions = n_actions

        self.layers = []
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(2 * n_latent_dims + n_actions, 1)])
        else:
            self.layers.extend([
                torch.nn.Linear(2 * n_latent_dims + n_actions, n_units_per_layer),
                torch.nn.Tanh()
            ])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend(
                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),
                     torch.nn.Tanh()])
            self.layers.extend([torch.nn.Linear(n_units_per_layer, 1)])
        self.layers.extend([torch.nn.Sigmoid()])
        self.model = torch.nn.Sequential(*self.layers)

    def forward(self, z0, z1, a):
        """Return one sigmoid score per ``(z0, z1, a)`` triple.

        ``a`` holds action indices; it is one-hot encoded and concatenated
        with both latent states along the last dimension.
        """
        context = torch.cat((z0, z1, one_hot(a, self.n_actions)), -1)
        fakes = self.model(context).squeeze()
        return fakes
class FwdNet(Network):
    """MLP that predicts the next latent state from a latent state and action.

    Parameters
    ----------
    n_actions : int
        Number of discrete actions; actions are fed in one-hot encoded.
    n_latent_dims : int
        Size of the latent state (input and output).
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_actions, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.n_actions = n_actions
        self.frozen = False

        self.fwd_layers = []
        if n_hidden_layers == 0:
            self.fwd_layers.extend([torch.nn.Linear(n_latent_dims+self.n_actions, n_latent_dims)])
        else:
            self.fwd_layers.extend([torch.nn.Linear(n_latent_dims + self.n_actions, n_units_per_layer), torch.nn.Tanh()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.fwd_layers.extend([torch.nn.Linear(n_units_per_layer, n_units_per_layer), torch.nn.Tanh()])
            self.fwd_layers.extend([torch.nn.Linear(n_units_per_layer, n_latent_dims)])
        # self.fwd_layers.extend([torch.nn.BatchNorm1d(n_latent_dims, affine=False)])
        self.fwd_model = torch.nn.Sequential(*self.fwd_layers)

    def forward(self, z, a):
        """Return the predicted next latent state for ``z`` and action indices ``a``."""
        a_onehot = one_hot(a, depth=self.n_actions)
        context = torch.cat((z, a_onehot), -1)
        z_hat = self.fwd_model(context)
        return z_hat
class InvNet(Network):
    """MLP that predicts action logits from a pair of latent states.

    Parameters
    ----------
    n_actions : int
        Number of discrete actions, i.e. the number of output logits.
    n_latent_dims : int
        Size of each latent state.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_actions, n_latent_dims=4, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.n_actions = n_actions

        self.layers = []
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(2 * n_latent_dims, n_actions)])
        else:
            self.layers.extend(
                [torch.nn.Linear(2 * n_latent_dims, n_units_per_layer),
                 torch.nn.Tanh()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend(
                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),
                     torch.nn.Tanh()])
            self.layers.extend([torch.nn.Linear(n_units_per_layer, n_actions)])

        self.inv_model = torch.nn.Sequential(*self.layers)

    def forward(self, z0, z1):
        """Return unnormalized action logits for each ``(z0, z1)`` pair."""
        context = torch.cat((z0, z1), -1)
        a_logits = self.inv_model(context)
        return a_logits
class Reshape(torch.nn.Module):
    """Layer that returns a view of its input with a fixed target size.

    Parameters
    ----------
    args : int...
        The size handed to ``Tensor.view`` on every forward pass
    """
    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def __repr__(self):
        # Rendered as the class name directly followed by the size tuple.
        return '{}{}'.format(type(self).__name__, self.shape)

    def forward(self, input):
        target_size = self.shape
        return input.view(*target_size)


class Sequential(torch.nn.Sequential, Network):
    """``torch.nn.Sequential`` container that also inherits the ``Network`` helpers."""

def one_hot(x, depth, dtype=torch.float32):
    """Convert a tensor of indices to one-hot vectors

    Parameters
    ----------
    x : LongTensor
        Indices in ``[0, depth)``; any number of dimensions is accepted
    depth : int
        The length of each output vector
    dtype : torch.dtype
        The dtype of the returned tensor

    Returns
    -------
    Tensor
        Tensor of shape ``x.shape + (depth,)`` on the device of ``x`` that
        is 1 at the indexed position of the last dimension and 0 elsewhere
    """
    index = x.unsqueeze(-1)
    encoded = torch.zeros(*x.shape, depth, dtype=dtype, device=x.device)
    return encoded.scatter_(-1, index, 1)

def extract(input, idx, idx_dim, batch_dim=0):
    """Extract one slice of ``input`` along ``idx_dim`` per batch element.

    For every position ``b`` along ``batch_dim`` the slice at index
    ``idx[b]`` of ``idx_dim`` is selected.

    Notes:
        idx must have the same size as input.shape[batch_dim].
        Output tensor has the shape of input with idx_dim removed.

    Args:
        input (Tensor): the source tensor (any dtype)
        idx (LongTensor): the indices of slices to extract
        idx_dim (int): the dimension along which to extract slices
        batch_dim (int): the dimension to treat as the batch dimension

    Raises:
        RuntimeError: if idx_dim equals batch_dim, or if the length of idx
            differs from input.shape[batch_dim]

    Example::

        >>> t = torch.arange(24, dtype=torch.float32).view(3,4,2)
        >>> i = torch.tensor([1, 3, 0], dtype=torch.int64)
        >>> extract(t, i, idx_dim=1, batch_dim=0)
            tensor([[ 2.,  3.],
                    [14., 15.],
                    [16., 17.]])
    """
    if idx_dim == batch_dim:
        raise RuntimeError('idx_dim cannot be the same as batch_dim')
    if len(idx) != input.shape[batch_dim]:
        raise RuntimeError(
            "idx length '{}' not compatible with batch_dim '{}' for input shape '{}'".format(
                len(idx), batch_dim, list(input.shape)))
    viewshape = [1] * input.ndimension()
    viewshape[batch_dim] = input.shape[batch_dim]
    # Gather a single slice along idx_dim and drop that dimension. This
    # avoids averaging duplicated values, so integer inputs work and float
    # inputs are returned without rounding.
    gather_shape = list(input.shape)
    gather_shape[idx_dim] = 1
    index = idx.view(*viewshape).expand(*gather_shape)
    return torch.gather(input, idx_dim, index).squeeze(idx_dim)
class PhiNet(Network):
    """MLP encoder that maps flattened observations to latent states.

    Parameters
    ----------
    input_shape : int or sequence of int
        Shape of a single observation; inputs are flattened to
        ``prod(input_shape)`` features per sample.
    n_latent_dims : int
        Size of the latent state.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    final_activation : callable or None
        Factory for the module applied after the last linear layer; pass
        None to return the raw linear output.
    """
    def __init__(self,
                 input_shape=2,
                 n_latent_dims=4,
                 n_hidden_layers=1,
                 n_units_per_layer=32,
                 final_activation=torch.nn.Tanh):
        super().__init__()
        self.input_shape = input_shape

        shape_flat = np.prod(self.input_shape)

        self.layers = []
        self.layers.extend([Reshape(-1, shape_flat)])
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(shape_flat, n_latent_dims)])
        else:
            self.layers.extend([torch.nn.Linear(shape_flat, n_units_per_layer), torch.nn.Tanh()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend(
                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),
                     torch.nn.Tanh()])
            self.layers.extend([
                torch.nn.Linear(n_units_per_layer, n_latent_dims),
            ])
        if final_activation is not None:
            self.layers.extend([final_activation()])
        self.phi = torch.nn.Sequential(*self.layers)

    def forward(self, x):
        """Return the latent states of the observations ``x``."""
        z = self.phi(x)
        return z
class QNet(Network):
    """ReLU MLP that maps feature vectors to one value per action.

    Parameters
    ----------
    n_features : int
        Size of the input feature vector.
    n_actions : int
        Number of outputs.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_features, n_actions, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.n_actions = n_actions

        self.layers = []
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(n_features, n_actions)])
        else:
            self.layers.extend([torch.nn.Linear(n_features, n_units_per_layer), torch.nn.ReLU()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend(
                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),
                     torch.nn.ReLU()])
            self.layers.extend([torch.nn.Linear(n_units_per_layer, n_actions)])

        self.model = torch.nn.Sequential(*self.layers)

    def forward(self, z):
        """Return the network output for the features ``z``."""
        return self.model(z)
class SimpleNet(Network):
    """Plain Tanh MLP from ``n_inputs`` features to ``n_outputs`` values.

    Parameters
    ----------
    n_inputs : int
        Size of the input vector.
    n_outputs : int
        Size of the output vector.
    n_hidden_layers : int
        Number of hidden layers (0 gives a single linear layer).
    n_units_per_layer : int
        Width of each hidden layer.
    """
    def __init__(self, n_inputs, n_outputs, n_hidden_layers=1, n_units_per_layer=32):
        super().__init__()
        self.n_outputs = n_outputs
        self.frozen = False

        self.layers = []
        if n_hidden_layers == 0:
            self.layers.extend([torch.nn.Linear(n_inputs, n_outputs)])
        else:
            self.layers.extend([torch.nn.Linear(n_inputs, n_units_per_layer), torch.nn.Tanh()])
            # Create a fresh Linear for every additional hidden layer;
            # repeating a list would reuse one module and tie their weights.
            for _ in range(n_hidden_layers - 1):
                self.layers.extend(
                    [torch.nn.Linear(n_units_per_layer, n_units_per_layer),
                     torch.nn.Tanh()])
            self.layers.extend([torch.nn.Linear(n_units_per_layer, n_outputs)])

        self.model = torch.nn.Sequential(*self.layers)

    def forward(self, z0):
        """Return the raw (unnormalized) network output for ``z0``."""
        a_logits = self.model(z0)
        return a_logits

In [8]:
# Command-line configuration of the representation-learning experiment.
parser = get_parser()
# parser.add_argument('-d','--dims', help='Number of latent dimensions', type=int, default=2)
# yapf: disable
parser.add_argument('--type', type=str, default='markov', choices=['markov', 'autoencoder', 'pixel-predictor'],
                    help='Which type of representation learning method')
parser.add_argument('-n','--n_updates', type=int, default=3000,
                    help='Number of training updates')
parser.add_argument('-r','--rows', type=int, default=6,
                    help='Number of gridworld rows')
parser.add_argument('-c','--cols', type=int, default=6,
                    help='Number of gridworld columns')
parser.add_argument('-w', '--walls', type=str, default='empty', choices=['empty', 'maze', 'spiral', 'loop'],
                    help='The wall configuration mode of gridworld')
parser.add_argument('-l','--latent_dims', type=int, default=2,
                    help='Number of latent dimensions to use for representation')
parser.add_argument('--L_inv', type=float, default=1.0,
                    help='Coefficient for inverse-model-matching loss')
parser.add_argument('--L_coinv', type=float, default=0.0,
                    help='Coefficient for *contrastive* inverse-model-matching loss')
# parser.add_argument('--L_fwd', type=float, default=0.0,
#                     help='Coefficient for forward dynamics loss')
parser.add_argument('--L_rat', type=float, default=1.0,
                    help='Coefficient for ratio-matching loss')
# parser.add_argument('--L_fac', type=float, default=0.0,
#                     help='Coefficient for factorization loss')
parser.add_argument('--L_dis', type=float, default=0.0,
                    help='Coefficient for planning-distance loss')
parser.add_argument('--L_ora', type=float, default=0.0,
                    help='Coefficient for oracle distance loss')
parser.add_argument('-lr','--learning_rate', type=float, default=0.003,
                    help='Learning rate for Adam optimizer')
parser.add_argument('--batch_size', type=int, default=2048,
                    help='Mini batch size for training updates')
parser.add_argument('-s','--seed', type=int, default=0,
                    help='Random seed')
parser.add_argument('-t','--tag', type=str, required=True,
                    help='Tag for identifying experiment')
parser.add_argument('-v','--video', action='store_true',
                    help="Save training video")
parser.add_argument('--no_graphics', action='store_true',
                    help='Turn off graphics (e.g. for running on cluster)')
parser.add_argument('--save', action='store_true',
                    help='Save final network weights')
parser.add_argument('--cleanvis', action='store_true',
                    help='Switch to representation-only visualization')
parser.add_argument('--no_sigma', action='store_true',
                    help='Turn off sensors and just use true state; i.e. x=s')
parser.add_argument('--rearrange_xy', action='store_true',
                    help='Rearrange discrete x-y positions to break smoothness')
# When the program name contains 'ipykernel' (i.e. the code runs inside a
# Jupyter kernel), the real command line is ignored and this fixed argument
# list is parsed instead.
# NOTE(review): FeatureNet.compute_loss in this file adds no oracle term, so
# '--L_ora' appears to have no effect on training -- confirm.
if 'ipykernel' in sys.argv[0]:
    arglist = [
        '--type', 'markov',
        '-w', 'spiral',
        '--tag', 'test-spiral',
        '-r', '6',
        '-c', '6',
        '--L_ora', '1.0',
        '--save'
    ]
    args = parser.parse_args(arglist)
else:
    args = parser.parse_args()
# matplotlib is imported only after argument parsing so that the backend can
# still be switched when graphics are disabled.
if args.no_graphics:
    import matplotlib
    # Force matplotlib to not use any Xwindows backend.
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Output locations, all derived from the experiment tag.
log_dir = 'results/logs/' + str(args.tag)
vid_dir = 'results/videos/' + str(args.tag)
maze_dir = 'results/mazes/' + str(args.tag)
os.makedirs(log_dir, exist_ok=True)
# The video/image/maze file names are only defined when --video is given.
if args.video:
    os.makedirs(vid_dir, exist_ok=True)
    os.makedirs(maze_dir, exist_ok=True)
    video_filename = vid_dir + '/video-{}.mp4'.format(args.seed)
    image_filename = vid_dir + '/final-{}.png'.format(args.seed)
    maze_file = maze_dir + '/maze-{}.png'.format(args.seed)

# NOTE(review): this handle is neither written to nor closed anywhere in the
# visible code; opening it in 'w' mode truncates an existing log of the same
# seed.
log = open(log_dir + '/train-{}.txt'.format(args.seed), 'w')
# Record the parsed arguments next to the training log.
with open(log_dir + '/args-{}.txt'.format(args.seed), 'w') as arg_file:
    arg_file.write(repr(args))

# Seed the random number generators through the project's seeding module.
seeding.seed(args.seed)

In [9]:
class FQE_eval(torch.nn.Module):
    """Fully connected Q-network used for fitted Q evaluation.

    The network consists of ``n_layers`` hidden ``Linear`` layers of width
    ``n_nodes``, each followed by ``activation``, and a final ``Linear``
    layer with one output per action.

    Parameters
    ----------
    in_dim : int
        Size of the input feature vector.
    action_size : int
        Number of actions, i.e. number of outputs.
    n_layers : int
        Number of hidden layers.
    n_nodes : int
        Width of each hidden layer.
    activation : torch.nn.Module
        Activation module inserted after every hidden layer (the same
        instance is reused at every position).
    """
    def __init__(self, in_dim, action_size, n_layers=2, n_nodes=32, activation=nn.ReLU()):
        super().__init__()
        self.action_size = action_size

        # Input layer first, then the remaining hidden layers, then the head.
        stack = [nn.Linear(in_dim, n_nodes), activation]
        remaining_hidden = n_layers - 1
        while remaining_hidden > 0:
            stack += [nn.Linear(n_nodes, n_nodes), activation]
            remaining_hidden -= 1
        stack.append(nn.Linear(n_nodes, action_size))

        self.net = stack
        self.FQE_net = nn.Sequential(*stack)

        self.train()

    def forward(self, x):
        """Return the Q-values of all actions for the inputs ``x``."""
        return self.FQE_net(x)
    
def train_FQE_step(model, optimizer, x, a, r, x_next, terminal, observed_s_next, target_policy, gamma=0.99):
    """Perform one fitted-Q-evaluation gradient step on a mini batch.

    The regression target equals the current prediction everywhere except
    at the taken action, where it is
    ``r + gamma * Q(x_next, pi(observed_s_next)) * (1 - terminal)``.

    Parameters
    ----------
    model : torch.nn.Module
        Q-network mapping ``x`` to one value per action.
    optimizer : torch.optim.Optimizer
        Optimizer over the parameters of ``model``.
    x, x_next : Tensor
        Inputs of the Q-network for the current and the next step.
    a : LongTensor
        Actions taken.
    r : Tensor
        Rewards received.
    terminal : Tensor
        Termination indicators (1 where the episode ended).
    observed_s_next : Tensor
        Next states in the form expected by ``target_policy``.
    target_policy : callable
        Maps ``observed_s_next`` to the action indices of the target policy.
    gamma : float
        Discount factor.

    Returns
    -------
    float
        The MSE loss of this step.
    """
    model.train()
    optimizer.zero_grad()

    n_samples = x.shape[0]
    rows = torch.arange(n_samples)
    # the policy is based on observed state space
    next_actions = target_policy(observed_s_next)

    q_pred = model(x)
    with torch.no_grad():
        q_next = model(x_next)

    # Bootstrapped value of the next state, zeroed after terminal steps.
    not_terminal = torch.ones(n_samples) - terminal
    bootstrapped = r + gamma * q_next[rows, next_actions] * not_terminal

    # Only the entry of the taken action differs from the prediction, so the
    # loss has no gradient with respect to the other actions.
    q_target = q_pred.detach().clone()
    q_target[rows, a] = bootstrapped

    loss = nn.MSELoss()(q_pred, q_target)
    loss.backward()
    optimizer.step()

    return loss.item()

def train_FQE(data, num_epochs, target_policy, n_layers=3, n_nodes=32, lr=0.001, action_size=None):
    """Estimate the value of ``target_policy`` with fitted Q evaluation.

    A Q-network is trained on the transitions in ``data``; the estimate is
    the mean Q-value, under the target policy's action, over the initial
    states of all episodes.

    Parameters
    ----------
    data : sequence of Tensor
        ``[x, a, r, x', terminal, s']`` with one row per transition, where
        ``x``/``x'`` are the Q-network inputs and ``s'`` are the next states
        in the form expected by ``target_policy``. The network is cast to
        double, so ``x`` and ``x'`` are expected to be double tensors.
    num_epochs : int
        Number of passes over the data.
    target_policy : callable
        Maps a batch of observed states to a tensor of action indices.
    n_layers, n_nodes : int
        Number and width of the hidden layers of the Q-network.
    lr : float
        Learning rate of the Adam optimizer.
    action_size : int, optional
        Number of actions. If omitted it is inferred from the actions in
        ``data``; pass it explicitly when some actions may be absent.

    Returns
    -------
    Tensor
        0-dim tensor with the estimated value of the target policy.
    """
    #data = [x,a,r,x',terminal,s']
    x, actions, terminal, observed_s_next = data[0], data[1], data[4], data[5]
    obs_size = x[0].shape[0]
    if action_size is None:
        # Counting unique actions undersizes the network when an action id
        # below the maximum never occurs, so the largest id is used as well.
        action_size = max(len(torch.unique(actions)), int(actions.max().item()) + 1)
    model = FQE_eval(obs_size, action_size, n_layers, n_nodes).double()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    n_transitions = actions.shape[0]
    # the next index of True terminal is the initial state of the next episode
    # reshape(-1) keeps the index 1-D even when there is a single terminal.
    terminal_index = torch.nonzero(terminal.reshape(-1)).reshape(-1).long()
    # A terminal on the very last row has no following episode in the data.
    terminal_index = terminal_index[terminal_index + 1 < n_transitions]
    init_index = torch.cat([torch.tensor([0]), terminal_index+1]) #the first state is always initial
    # use s_next to get initial observed states
    # NOTE(review): this relies on s'[t] being the first state of the next
    # episode at a terminal step t, and uses s'[0] in place of the first
    # episode's initial state -- confirm against how the data was built.
    observed_init_index = torch.cat([torch.tensor([0]), terminal_index])
    initial_x = x[init_index]
    observed_init = observed_s_next[observed_init_index]
    target_init = target_policy(observed_init)
    num_episode = initial_x.shape[0]

    batch_size = max(n_transitions // 20, 10)
    dataset = TensorDataset(*data)
    batch_data = DataLoader(dataset, batch_size=batch_size)

    for epoch in range(num_epochs):
        for bx, ba, br, bx_next, bterminal, bobserved_s_next in batch_data:
            train_FQE_step(model, optimizer, bx, ba, br, bx_next, bterminal,
                           bobserved_s_next, target_policy)

    # Only the final model is used, so it is evaluated once after training.
    model.eval()
    with torch.no_grad():
        preds = model(initial_x) #Q-value estimation is based on abstracted space
    estimated_value = preds[torch.arange(num_episode), target_init].mean()
    return estimated_value

def behavior_policy(state, epsilon=0):
    """Epsilon-greedy policy driven by the sign of ``state[2]``.

    With probability ``epsilon`` a uniformly random action from ``{0, 1}``
    is returned; otherwise the action is 0 if ``state[2]`` is negative and
    1 if it is not.

    Parameters
    ----------
    state : sequence
        A single state; only the element at index 2 is read.
    epsilon : float
        Probability of taking a random action.
    """
    angle = state[2]
    explore = np.random.binomial(1, epsilon) == 1
    if explore:
        return np.random.choice([0, 1])
    return 0 if angle < 0 else 1
        
def random_policy(state, action_size=2, batch=True):
    """Uniformly random policy over ``action_size`` actions.

    Parameters
    ----------
    state : Tensor or any
        In batch mode only ``state.shape[0]`` is read, to determine how many
        actions to draw. Ignored otherwise.
    action_size : int
        Number of actions to choose from.
    batch : bool
        If True, return a LongTensor with one action per row of ``state``;
        otherwise return a single action drawn with NumPy.
    """
    if batch:
        size = state.shape[0]
        return torch.randint(0, action_size, (size,)).long()

    # Honor action_size here as well instead of a fixed two-action choice.
    return np.random.choice(action_size)

def nondyna_policy(state, action=1, batch=True):
    """Constant policy that always selects ``action``.

    Parameters
    ----------
    state : Tensor or any
        In batch mode only ``state.shape[0]`` is read, to determine the
        number of actions to return. Ignored otherwise.
    action : int
        The action to return.
    batch : bool
        If True, return a LongTensor filled with ``action`` (one entry per
        row of ``state``); otherwise return ``action`` itself.
    """
    if batch:
        # Fill directly instead of sampling from a one-element range, which
        # would needlessly advance the global random number generator.
        return torch.full((state.shape[0],), action, dtype=torch.long)

    return action
def cartpole_policy(state, batch=True):
    """Stochastic policy that picks action 1 with probability
    ``1 - 1 / (1 + exp(state[2] - state[0]))``.

    Parameters
    ----------
    state : Tensor or sequence
        A batch of states (rows) in batch mode, a single state otherwise.
        Only the entries at index 0 and 2 are read (named ``pos`` and
        ``angle`` in the code).
    batch : bool
        If True, sample one action per row with torch and return a
        LongTensor; otherwise sample a single action with NumPy.
    """
    if not batch:
        pos = state[0]
        angle = state[2]
        prob_0 = 1/(1+np.exp(angle-pos))
        return np.random.binomial(1, 1 - prob_0)

    pos, angle = state[:,0], state[:,2]
    prob_1 = 1 - 1/(1+torch.exp(angle-pos))
    actions = torch.bernoulli(prob_1)
    return actions.long()

def angle_policy(state, batch=True):
    """Deterministic policy: action 1 where ``state[..., 2] >= 0``, else 0.

    Parameters
    ----------
    state : Tensor or sequence
        A batch of states (rows) in batch mode, a single state otherwise.
    batch : bool
        If True, return a LongTensor with one action per row; otherwise
        delegate to ``behavior_policy`` with its default epsilon.
    """
    if not batch:
        return behavior_policy(state)

    non_negative = torch.ge(state[:,2], 0)
    return non_negative.long()

def plot_helper(df, title, xticks=None, xlabel="x", ylabel="y"):
    """Plot every column of ``df`` as a connected scatter line and show it.

    Parameters
    ----------
    df : pandas.DataFrame
        One curve is drawn per column; the legend label of a curve is the
        first element of its column key (``column[0]``).
    title : str
        Figure title.
    xticks : sequence, optional
        x positions of the rows. If given, they are also set as tick
        locations; otherwise ``0 .. len(df) - 1`` is used and the ticks are
        left to matplotlib.
    xlabel, ylabel : str
        Axis labels.
    """
    plt.figure(figsize=(6, 4))
    set_xtick = xticks is not None
    if not set_xtick:
        xticks = np.arange(df.shape[0])
    for column in df.columns:
        series = df[column]
        plt.scatter(xticks, series, label=column[0])
        plt.plot(xticks, series, linestyle='-')
    if set_xtick:
        plt.xticks(xticks)
    #plt.axhline(y=hline, color='r', linestyle='-')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()

class DQN(nn.Module):
    """Q-network with two hidden ReLU layers of 64 units.

    Parameters
    ----------
    in_dim : int
        Size of the state vector.
    action_size : int
        Number of actions, i.e. number of Q-value outputs.
    """
    def __init__(self, in_dim, action_size):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(in_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size)

    def forward(self, x):
        """Return the Q-values of all actions for the states ``x``."""
        for hidden in (self.fc1, self.fc2):
            x = torch.relu(hidden(x))
        return self.fc3(x)

def target_lunar(state, batch=True):
    """Greedy policy of the module-level ``policy_net``.

    Only the first 8 entries of each state are passed to the network, in
    both modes.

    Parameters
    ----------
    state : Tensor or array-like
        A batch of states (rows) in batch mode, a single state otherwise.
    batch : bool
        If True, return a LongTensor with one action per row; otherwise
        return the action index of the single state as a NumPy integer.
    """
    policy_net.eval()
    with torch.no_grad():
        if batch:
            return torch.argmax(policy_net(state[:, :8]), dim=1).long()

        # Slice to the same 8 features as the batch path so that longer
        # state vectors are handled consistently.
        single = torch.as_tensor(state).double().view(1, -1)[:, :8]
        q_values = policy_net(single).squeeze(0)
        return np.argmax(q_values.numpy())

In [11]:
# Target policy network: 8 state features in, 4 action values out.
policy_net = DQN(8, 4)
model_path = "dqn_lunar_lander.pt"
policy_net.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # loading trained model
# Cast to float64 so the network accepts double-precision states.
policy_net = policy_net.double()

## Markov abstraction

In [12]:
# Loss weights handed to the representation networks, taken from the
# command-line arguments.
coefs = {
    'L_inv': args.L_inv,
    'L_coinv': args.L_coinv,
    # 'L_fwd': args.L_fwd,
    'L_rat': args.L_rat,
    # 'L_fac': args.L_fac,
    'L_dis': args.L_dis,
    'L_ora': args.L_ora,
}
# Policy whose value is estimated by FQE in the experiment below.
#target_policy = angle_policy
target_policy = target_lunar
def get_batch(x0, x1, a, batch_size=10):
    """Sample a random mini batch of transitions without replacement.

    Parameters
    ----------
    x0, x1 : array-like or Tensor
        States and next states, indexed along the first dimension.
    a : array-like or Tensor
        Actions, one per transition.
    batch_size : int
        Number of transitions to draw. If it exceeds the number of
        transitions, every transition is returned once (in random order).

    Returns
    -------
    tuple
        ``(tx0, tx1, ta, idx)``: float tensors of states and next states, a
        long tensor of actions, and the NumPy array of sampled indices.
    """
    # Sampling without replacement cannot draw more items than exist.
    n_draw = min(batch_size, len(a))
    idx = np.random.choice(len(a), n_draw, replace=False)
    tx0 = torch.as_tensor(x0[idx]).float()
    tx1 = torch.as_tensor(x1[idx]).float()
    ta = torch.as_tensor(a[idx]).long()
    return tx0, tx1, ta, idx

def encode(data, model):
    """Replace states and next states of a dataset by their abstractions.

    Parameters
    ----------
    data : sequence of Tensor
        ``[s, a, r, s', terminal]`` or ``[x, a, r, x', terminal, s, s']``;
        only the entries at index 0 and 3 are transformed.
    model : torch.nn.Module
        Model whose ``phi`` attribute maps float states to latent states.
        It is switched to eval mode.

    Returns
    -------
    list
        A deep copy of ``data`` in which entries 0 and 3 hold
        ``model.phi`` of the original entries, cast to double. The input
        is left unmodified.
    """
    #Encode the original (s,a,r,s') tuple by forward abstraction
    #data = [s,a,r,s',terminal] or [x,a,r,x',terminal,s,s']
    data = list(copy.deepcopy(data))
    model.eval()
    with torch.no_grad():
        # Encode the next states directly rather than reusing the shifted
        # state encodings: s'[i] equals s[i+1] only inside an episode.
        phi_s = model.phi(data[0].float())
        phi_s_next = model.phi(data[3].float())
    data[0] = phi_s.double()
    data[3] = phi_s_next.double()
    return data

def train_abstraction(data, action_size, n_frames, n_updates_per_frame, coefs, type):
    """Train a state-abstraction network on a dataset of transitions.

    Parameters
    ----------
    data : sequence of Tensor
        Dataset whose entries 0, 1 and 3 are the states, actions and next
        states.
    action_size : int
        Number of discrete actions.
    n_frames : int
        The training loop runs ``n_frames + 1`` outer iterations.
    n_updates_per_frame : int
        Number of mini-batch updates per outer iteration.
    coefs : dict
        Loss coefficients forwarded to the network.
    type : str
        ``'markov'`` builds a ``FeatureNet``; any other value builds an
        ``AutoEncoder``.

    Returns
    -------
    Network
        The trained network. Latent size and learning rate are read from
        the module-level ``args``.
    """
    states, actions, next_states = data[0], data[1], data[3]
    # 5% of the data per mini batch, but never fewer than 10 samples.
    minibatch_size = max((states.shape[0]) // 20, 10)

    net_class = FeatureNet if type == 'markov' else AutoEncoder
    fnet = net_class(n_actions=action_size,
                     input_shape=states.shape[1:],
                     n_latent_dims=args.latent_dims,
                     n_hidden_layers=1,
                     n_units_per_layer=32,
                     lr=args.learning_rate,
                     coefs=coefs)

    for _frame in tqdm(range(n_frames + 1)):
        for _update in range(n_updates_per_frame):
            tx0, tx1, ta, _ = get_batch(states, next_states, actions, minibatch_size)
            fnet.train_batch(tx0, tx1, ta)
    return fnet

In [14]:
# Load the pre-collected transition data.
# NOTE: pickle.load can execute arbitrary code; only load files from a
# trusted source.
with open("lunar_0.5_data.pickle", 'rb') as f:
    data05 = pickle.load(f)

In [None]:
# Experiment: for every sample size, train both abstractions on each sample,
# run FQE on the encoded data and record the difference to the oracle value.
n_updates_per_frame = 20
n_frames = 20
sample_sizes = [7,13,20]
total_bias = []
for sample_size in sample_sizes:
    samples = data05[sample_size] #samples contain 30 sample, each consists of data from #sample_size of episodes
    # NOTE(review): 61.7 is presumably the reference value of the target
    # policy -- confirm. It is repeated once per compared method.
    oracle = torch.tensor(61.7).repeat(2)
    biases = []
    #data_list = []
    for sample in samples:
        #sample=[s,a,r,s',terminal]
        s_next = sample[3]
        markov_model = train_abstraction(sample, 4, n_frames, n_updates_per_frame, coefs, "markov")
        autoencoder = train_abstraction(sample, 4, n_frames, n_updates_per_frame, coefs, "autoencoder")
        
        # The raw next states are appended as the 6th entry because the
        # target policy acts on observed states, not on encoded ones.
        markov_sample = [*encode(sample, markov_model), s_next]
        auto_sample = [*encode(sample, autoencoder), s_next]

        markov_value = train_FQE(markov_sample, 20, target_policy)
        auto_value = train_FQE(auto_sample, 20, target_policy)
        
        # Estimated value minus oracle value: [markov, autoencoder].
        bias = torch.tensor([markov_value, auto_value ]) - oracle
        biases.append(bias)
    total_bias.append(biases)

# total_bias[i][j] holds the bias pair of sample j at sample size i.
with open("lunar0.5_markov_auto_bias.pickle", "wb") as f:
    pickle.dump(total_bias,f)