In [21]:
%load_ext autoreload
%autoreload 2

from __future__ import absolute_import, division, print_function
import argparse
from datetime import datetime
import imp
import numpy as np
import torch
from utils.monitor import Monitor
from envs.mo_env import MultiObjectiveEnv
# from gym_env_moll.multiobjective import LunarLander
# import gym
import json


# Choose CUDA tensor constructors when a GPU is present, CPU ones otherwise.
use_cuda = torch.cuda.is_available()
_tensor_ns = torch.cuda if use_cuda else torch
FloatTensor = _tensor_ns.FloatTensor
LongTensor = _tensor_ns.LongTensor
ByteTensor = _tensor_ns.ByteTensor
del _tensor_ns
Tensor = FloatTensor


def generate_next_preference(preference, alpha=10000):
    """Sample a new preference vector near `preference` from a Dirichlet distribution.

    The Dirichlet concentration is ``alpha * (preference + 1e-6)``. A larger
    ``alpha`` keeps the sample closer to ``preference``. The ``1e-6`` keeps every
    concentration strictly positive, as ``np.random.dirichlet`` requires.

    Args:
        preference: 1-D weight vector. Accepts a list, a numpy array (any numeric
            dtype) or a torch tensor on any device.
        alpha: concentration scale.

    Returns:
        FloatTensor with the same length as `preference`; its entries sum to 1.
    """
    if torch.is_tensor(preference):
        # np.array() cannot read a CUDA tensor directly; move it to host memory first.
        preference = preference.detach().cpu().numpy()
    # Force a float dtype so integer inputs don't break the +1e-6 offset.
    concentration = alpha * (np.asarray(preference, dtype=np.float64) + 1e-6)
    return FloatTensor(np.random.dirichlet(concentration))

def init_log_file(log_file_str):
    """Create (or truncate) the log file and write the opening bracket of its top-level list.

    Missing parent directories are created first, so a fresh checkout without
    a ``logs/`` directory does not fail with FileNotFoundError.

    Args:
        log_file_str: path of the log file to (re)create.
    """
    import os

    log_dir = os.path.dirname(log_file_str)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_file_str, mode='w') as log_file:
        log_file.write('[\n')

def write_log(log_file_str, data, is_json=False):
    """Append one chunk to the log file at `log_file_str`.

    Args:
        log_file_str: path of the log file; it is created if it does not exist.
        data: raw text to append, or, when `is_json` is true, any
            JSON-serialisable object.
        is_json: serialise `data` with ``json.dump`` instead of writing it verbatim.
    """
    with open(log_file_str, mode='a+') as handle:
        if not is_json:
            handle.write(data)
            return
        json.dump(data, handle)

def _reference_q_pair(q, fixed_probe, env_name, method):
    """Pick the two monitored action values from `q` and reduce them for printing.

    Args:
        q: third output of ``agent.predict(fixed_probe)``, indexed as ``q[0, action]``.
        fixed_probe: the reference preference `q` was predicted for.
        env_name: selects which two actions are monitored.
        method: "crl-naive" moves the values to CPU. "crl-envelope" and
            "crl-energy" scalarise them with `fixed_probe`. Any other method
            returns them unchanged.

    Returns:
        ``(act_1, act_2)``.

    Raises:
        ValueError: if no monitored actions are defined for `env_name`.
    """
    if env_name == "dst":
        act_1, act_2 = q[0, 3], q[0, 1]
    elif env_name in ['ft', 'ft5', 'ft7']:
        act_1, act_2 = q[0, 1], q[0, 0]
    else:
        raise ValueError("no monitored actions defined for env_name %r" % (env_name,))

    if method == "crl-naive":
        return act_1.data.cpu(), act_2.data.cpu()
    if method in ("crl-envelope", "crl-energy"):
        # Scalarise with the same preference the Q-values were predicted for.
        return fixed_probe.dot(act_1.data), fixed_probe.dot(act_2.data)
    return act_1, act_2


def _log_step(log_file_str, agent, state, action, reward, terminal, probe,
              cnt, num_eps, tot_reward):
    """Print one step's diagnostics and append it as a JSON record to the log.

    After the record, writes ``'\\n]'`` to close the episode's list when `terminal`
    is true, or ``',\\n'`` as the separator before the next record otherwise.
    """
    _, Q, q = agent.predict(probe, state)
    probe_list = probe.detach().cpu().tolist()

    log = {
        'state': state.tolist(),
        'action': action,
        'reward': reward.tolist(),
        'terminal': terminal,
        'probe': probe_list,
        'q_val': q.tolist(),
        'cnt': cnt,
        'num_eps': num_eps
    }

    print('probe', probe_list)
    print('state', log['state'])
    print('action', log['action'])
    print('reward', log['reward'])
    print('q_val', log['q_val'])
    # .cpu() so this also works when the model lives on a GPU.
    print('Q_val', Q.detach().cpu().tolist())
    print('tot_reward', tot_reward)
    print('cnt', log['cnt'])
    print('num_eps', log['num_eps'])
    print('eps', agent.epsilon)
    print('---------------------------------------')

    write_log(log_file_str, log, True)
    write_log(log_file_str, '\n]' if terminal else ',\n')


def train(env, agent, args, num_episodes=60):
    """Run the training loop and return the trained agent.

    Each episode draws a random preference ("probe") on the simplex. At every
    step the agent acts with the current probe and stores the transition
    together with a perturbed next preference from ``generate_next_preference``,
    then calls ``agent.learn()``. After each episode, the Q-values for a fixed
    reference preference are printed.

    Args:
        env: environment exposing ``reset()``, ``observe()``, ``step(action)``
            and ``reward_spec``.
        agent: learner exposing ``act``, ``memorize``, ``learn``, ``reset``,
            ``predict``, ``save`` and ``epsilon``.
        args: config namespace. Reads ``env_name``, ``alpha``, ``gamma``,
            ``log`` (only its truthiness is used) and ``method``.
        num_episodes: number of episodes to run. The default of 60 matches the
            short run this notebook used; pass ``args.episode_num`` for the
            configured full length.

    Returns:
        The trained ``agent``.

    Side effects:
        Writes a JSON log under ``./logs/``. When ``args.log`` is truthy, every
        50th episode is written as a list of per-step records inside one
        top-level list. Checkpoints are saved under ``./saved_models/`` every
        500 episodes and at the end.
    """
    date_tag = datetime.today().strftime("%Y_%m_%d")
    save_file_name = 'multihead3_Q_log_' + args.env_name + '_' + date_tag
    log_file_str = './logs/' + save_file_name + '.json'
    save_loc = './saved_models/'

    init_log_file(log_file_str)

    reward_size = len(env.reward_spec)
    # Fixed reference preference used only for monitoring: weights 0.8 / 0.2 on
    # the first two objectives and 0 on the rest.
    fixed_probe = FloatTensor([0.8, 0.2] + [0.0] * (reward_size - 2))
    fixed_probe_np = fixed_probe.cpu().numpy()  # loop-invariant, hoisted
    env.reset()
    alpha = args.alpha

    # Only reported in the episode summary; nothing updates it any more.
    dirichet_param = 0.1

    max_steps_in_env = 100
    log_started = False  # has any episode list been written to the JSON log yet?

    for num_eps in range(num_episodes):
        terminal = False
        env.reset()
        q_loss = 0
        exploration_loss = 0
        cnt = 0
        tot_reward = 0

        # Random preference on the simplex: |N(0, I)| sample, L1-normalised.
        probe = np.random.randn(reward_size)
        probe = FloatTensor(np.abs(probe) / np.linalg.norm(probe, ord=1))

        # Every 50th episode is traced to stdout and to the JSON log. The
        # episode's list is opened only when it will actually be filled and closed.
        log_episode = bool(args.log) and num_eps % 50 == 0
        if log_episode:
            write_log(log_file_str, ',\n[' if log_started else '[')
            log_started = True

        while not terminal:
            state = env.observe()
            action = agent.act(state, probe)
            next_state, reward, terminal = env.step(action)
            # The preference drifts within the episode: the next step uses a
            # sample perturbed around the current one.
            next_preference = generate_next_preference(probe, alpha)

            agent.memorize(state, action, next_state, reward, terminal, probe, next_preference)
            loss = agent.learn()
            q_loss += loss[0]
            exploration_loss += loss[1]

            # Cap the episode length; agent.reset() is only called on this path.
            if cnt > max_steps_in_env:
                terminal = True
                agent.reset()

            tot_reward = tot_reward + fixed_probe_np.dot(reward) * np.power(args.gamma, cnt)
            acting_probe = probe  # the preference that chose `action`
            probe = next_preference
            cnt = cnt + 1

            if reward[0] > 8:
                print(reward, state)

            if log_episode:
                _log_step(log_file_str, agent, state, action, reward, terminal,
                          acting_probe, cnt, num_eps, tot_reward)

        _, Q, q = agent.predict(fixed_probe)
        act_1, act_2 = _reference_q_pair(q, fixed_probe, args.env_name, args.method)

        print("eps %d reward (1) %0.2f, the Q is %0.2f | %0.2f; the probe is %0.2f | %0.2f; dirichet: %0.3f; q_loss: %0.4f; exploration_loss: %0.4f" % (
            num_eps,
            tot_reward,
            act_1,
            act_2,
            probe[0],
            probe[1],
            dirichet_param,
            q_loss / cnt,
            exploration_loss / cnt))

        if (num_eps + 1) % 500 == 0:
            agent.save(save_loc, save_file_name + "_eps_" + str(num_eps))

    # Close the top-level list opened by init_log_file so the log is valid JSON.
    write_log(log_file_str, '\n]\n')

    if num_episodes > 0:
        agent.save(save_loc, save_file_name + "_eps_" + str(num_episodes - 1))
    return agent


parser = argparse.ArgumentParser(description='MORL')
# CONFIG
parser.add_argument('--env-name', default='ft', metavar='ENVNAME',
                    help='environment to train on: dst | ft | ft5 | ft7')
parser.add_argument('--method', default='crl-naive', metavar='METHODS',
                    help='methods: crl-naive | crl-envelope | crl-energy')
parser.add_argument('--model', default='linear', metavar='MODELS',
                    help='linear | cnn | cnn + lstm')
parser.add_argument('--gamma', type=float, default=0.99, metavar='GAMMA',
                    help='discount factor for infinite-horizon MDPs')
# TRAINING
parser.add_argument('--mem-size', type=int, default=4000, metavar='M',
                    help='max size of the replay memory')
parser.add_argument('--batch-size', type=int, default=256, metavar='B',
                    help='batch size')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate')
parser.add_argument('--epsilon', type=float, default=0.5, metavar='EPS',
                    help='epsilon greedy exploration')
# Epsilon decay is on by default. `store_true` alone could never turn it off,
# so a matching `--no-epsilon-decay` flag writes False to the same dest.
parser.add_argument('--epsilon-decay', default=True, action='store_true',
                    help='linear epsilon decay to zero (default: on)')
parser.add_argument('--no-epsilon-decay', dest='epsilon_decay', action='store_false',
                    help='disable epsilon decay')
parser.add_argument('--weight-num', type=int, default=16, metavar='WN',
                    help='number of sampled weights per iteration')
parser.add_argument('--episode-num', type=int, default=10000, metavar='EN',
                    help='number of episodes for training')
parser.add_argument('--optimizer', default='Adam', metavar='OPT',
                    help='optimizer: Adam | RMSprop')
parser.add_argument('--update-freq', type=int, default=100, metavar='FREQ',
                    help='update frequency')
parser.add_argument('--beta', type=float, default=0.01, metavar='BETA',
                    help='(initial) beta for envelope algorithm, default = 0.01')
parser.add_argument('--homotopy', default=False, action='store_true',
                    help='use homotopy optimization method')
# LOG & SAVING
parser.add_argument('--serialize', default=False, action='store_true',
                    help='load a serialized model from --save')
parser.add_argument('--save', default='crl/naive/saved/', metavar='SAVE',
                    help='path for saving trained models')
parser.add_argument('--name', default='', metavar='name',
                    help='specify a name for saving the model')
# NOTE(review): train() only tests this value for truthiness, so the non-empty
# default path always enables step logging.
parser.add_argument('--log', default='crl/naive/logs/', metavar='LOG',
                    help='path for recording training information')

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# The FloatTensor / LongTensor / ByteTensor / Tensor aliases are defined at the
# top of this cell from the same torch.cuda.is_available() check, so they are
# not redefined here.

# Parse an empty argument list so the kernel's own command-line flags are
# ignored and every option keeps its default.
args = parser.parse_args(args=[])

# Set up the environment.
# Alternative (currently disabled): args.env_name = 'Lunar' with
# env = gym.make('gym.envs.multiobjective/LunarLander')
env = MultiObjectiveEnv(args.env_name)

# Dimensions of the state, reward and action spaces.
state_size = len(env.state_spec)
reward_size = len(env.reward_spec)
action_bounds = env.action_spec[2]
action_size = action_bounds[1] - action_bounds[0]

# The agent is created in a later cell.
agent = None

# Concentration scale that train() passes to generate_next_preference.
args.alpha = 4000

from crl.envelope.meta_mod import MetaAgent
# from crl.envelope.models.multiheadoutput import EnvelopeLinearCQN
from crl.envelope.models.multihead3 import EnvelopeLinearCQN
from crl.envelope.exemplar import Exemplar

# Optionally restore a saved model from
# "<args.save>m.<args.model>_e.<args.env_name>_n.<args.name>.pkl".
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from trusted sources.
if args.serialize:
    model = torch.load("{}{}.pkl".format(args.save,
                                 "m.{}_e.{}_n.{}".format(args.model, args.env_name, args.name)))


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [317]:
import numpy as np


def sample_gaussian_preference(preference, variance=1e-4, max_tries=10000):
    """Perturb `preference` with isotropic Gaussian noise, rejecting negative samples.

    This is an experimental alternative to ``generate_next_preference``. It has
    its own name so that running this cell does not shadow the function
    ``train`` calls. The two differ in return type (ndarray here, FloatTensor
    there) and in the second parameter.

    Args:
        preference: 1-D numpy array of weights.
        variance: per-coordinate noise variance (covariance is ``variance * I``).
        max_tries: maximum number of draws before giving up.

    Returns:
        ndarray with the same shape as `preference`, all entries >= 0. It is
        not renormalised to sum to 1.

    Raises:
        RuntimeError: if no non-negative sample is found within `max_tries` draws.
    """
    cov = np.identity(preference.shape[0]) * variance
    for _ in range(max_tries):
        sample = np.random.multivariate_normal(preference, cov, 1)[0]
        if not np.any(sample < 0):
            return sample
    raise RuntimeError(
        "could not draw a non-negative sample near %r in %d tries" % (preference, max_tries))


x = np.random.uniform(size=3)
y = sample_gaussian_preference(x)
y, x, y.shape, x.shape

(array([0.60557752, 0.26519936, 0.13173717]),
 array([0.59234412, 0.26514629, 0.11415272]),
 (3,),
 (3,))

In [169]:
# Build a fresh Q-network unless --serialize already restored one, so a loaded
# model is not silently discarded.
if not args.serialize:
    model = EnvelopeLinearCQN(state_size, action_size, reward_size)
# NOTE(review): the meaning of the positional Exemplar arguments (1e-3, -1, 3)
# is defined in crl.envelope.exemplar -- confirm there before changing them.
exemplar_model = Exemplar(reward_size, reward_size, 1e-3, -1, device, 3)
agent = MetaAgent(model, exemplar_model, args, is_train=True)

agent = train(env, agent, args)

probe [0.08870701491832733, 0.1992291361093521, 0.109388068318367, 0.23882588744163513, 0.22652246057987213, 0.13732744753360748]
state [0, 0]
action 0
reward [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
q_val [[-0.020355885848402977, -0.06106964498758316]]
Q_val [[[-0.05712473765015602, -0.05365177243947983, 0.054477281868457794, -0.04658041149377823, 0.030275505036115646, 0.002545067109167576], [-0.048632439225912094, -0.0428861528635025, -0.04601268470287323, 0.03273971751332283, -0.049255602061748505, -0.009033589623868465]]]
tot_reward 0.0
cnt 1
num_eps 0
eps 0.5
---------------------------------------
probe [0.07985246181488037, 0.20434263348579407, 0.11487565189599991, 0.23212654888629913, 0.21867218613624573, 0.1501305252313614]
state [1, 0]
action 0
reward [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
q_val [[-0.020011883229017258, -0.060004182159900665]]
Q_val [[[-0.05555865168571472, -0.05425175279378891, 0.04707087576389313, -0.050934769213199615, 0.03638920560479164, 0.001265830360352993], [-0.0531522

eps 42 reward (1) 0.86, the Q is -0.06 | -0.03; the probe is 0.12 | 0.05; dirichet: 0.100; q_loss: 0.8620; exploration_loss: 0.0506
eps 43 reward (1) 2.48, the Q is -0.06 | -0.03; the probe is 0.10 | 0.16; dirichet: 0.100; q_loss: 1.9171; exploration_loss: -5.4208
eps 44 reward (1) 2.02, the Q is -0.06 | -0.03; the probe is 0.45 | 0.09; dirichet: 0.100; q_loss: 1.5937; exploration_loss: -316.9390
[8.30192712 0.40973443 1.69099424 4.54961192 2.64473811 0.59753994] [ 5 22]
eps 45 reward (1) 6.39, the Q is -0.06 | -0.03; the probe is 0.24 | 0.27; dirichet: 0.100; q_loss: 1.5179; exploration_loss: -12961.9158
eps 46 reward (1) 3.52, the Q is -0.06 | -0.03; the probe is 0.05 | 0.21; dirichet: 0.100; q_loss: 1.4545; exploration_loss: -293900.3483
eps 47 reward (1) 4.00, the Q is -0.06 | -0.03; the probe is 0.26 | 0.03; dirichet: 0.100; q_loss: 1.3064; exploration_loss: -3930411.9375
eps 48 reward (1) 5.71, the Q is -0.06 | -0.03; the probe is 0.13 | 0.20; dirichet: 0.100; q_loss: 1.1665; exp

In [170]:
from torch.autograd import Variable

# Hand-run one exemplar-model update on a replay minibatch (debugging).
# Only the preference weights (w, w_) of the sampled transitions are used.
minibatch = agent.sample(agent.trans_mem, agent.priority_mem, agent.batch_size)

w_batch = list(map(lambda x: x.w, minibatch))
w_batch = Variable(torch.stack(w_batch), requires_grad=False).type(FloatTensor)
next_w_batch = list(map(lambda x: x.w_, minibatch))
next_w_batch = Variable(torch.stack(next_w_batch), requires_grad=False).type(FloatTensor)
w_batch, next_w_batch = agent.generate_neighbours(w_batch, next_w_batch, agent.weight_num)

# Draw preference rows (with replacement) and split them into equally sized
# "positive" and "negative" halves. Equal sizes keep sample1, sample2 and
# target the same length even for an odd batch size.
exemplar_batch_size = 10
index_list = np.random.randint(0, w_batch.shape[0], size=exemplar_batch_size)
sampled_w = w_batch[index_list]
half = sampled_w.shape[0] // 2
positive = sampled_w[:half]
negative = sampled_w[half:2 * half]

# Pairs (positive, positive) are labelled 1; pairs (positive, negative) are labelled 0.
sample1 = torch.cat((positive, positive), dim=0)
sample2 = torch.cat((positive, negative), dim=0)
# Create the labels on the samples' device so update() also works on a GPU.
target = torch.cat((
    torch.ones((positive.shape[0], 1), device=positive.device),
    torch.zeros((negative.shape[0], 1), device=positive.device),
), dim=0)

exploration_loss = agent.exemplar_exploration.update(sample1, sample2, target)
exploration_loss

(array(-3.175422e+12, dtype=float32),
 array([2.3442350e+11, 2.2944825e+11, 2.3466462e+11, 2.3141741e+11,
        2.3584940e+11, 2.3442350e+11, 2.2944825e+11, 2.3466462e+11,
        2.3141741e+11, 2.3584940e+11], dtype=float32),
 array([2.9142472e+12, 2.8992476e+12, 2.9603833e+12, 2.9534338e+12,
        2.9773338e+12, 2.9434330e+12, 2.9354788e+12, 2.9251747e+12,
        2.9450431e+12, 2.9688379e+12], dtype=float32))

In [171]:
# Display the anchor batch: the positive half of the sampled preferences, repeated twice.
sample1

tensor([[0.2704, 0.0390, 0.1751, 0.4148, 0.0037, 0.0969],
        [0.1703, 0.3134, 0.0524, 0.0937, 0.0288, 0.3414],
        [0.2636, 0.2194, 0.2767, 0.0560, 0.0835, 0.1007],
        [0.2969, 0.2855, 0.0879, 0.0326, 0.1278, 0.1693],
        [0.4453, 0.0721, 0.2340, 0.1131, 0.1216, 0.0139],
        [0.2704, 0.0390, 0.1751, 0.4148, 0.0037, 0.0969],
        [0.1703, 0.3134, 0.0524, 0.0937, 0.0288, 0.3414],
        [0.2636, 0.2194, 0.2767, 0.0560, 0.0835, 0.1007],
        [0.2969, 0.2855, 0.0879, 0.0326, 0.1278, 0.1693],
        [0.4453, 0.0721, 0.2340, 0.1131, 0.1216, 0.0139]])

In [172]:
# Display the paired batch: the positive half followed by the negative half.
sample2

tensor([[2.7042e-01, 3.8990e-02, 1.7514e-01, 4.1480e-01, 3.7394e-03, 9.6915e-02],
        [1.7032e-01, 3.1336e-01, 5.2384e-02, 9.3736e-02, 2.8825e-02, 3.4138e-01],
        [2.6359e-01, 2.1941e-01, 2.7675e-01, 5.6042e-02, 8.3490e-02, 1.0073e-01],
        [2.9694e-01, 2.8548e-01, 8.7858e-02, 3.2611e-02, 1.2781e-01, 1.6930e-01],
        [4.4533e-01, 7.2120e-02, 2.3401e-01, 1.1310e-01, 1.2156e-01, 1.3882e-02],
        [1.0452e-01, 1.1155e-01, 1.4873e-01, 1.7704e-01, 2.3103e-01, 2.2714e-01],
        [2.5565e-02, 9.8072e-05, 2.1025e-01, 3.5493e-01, 2.2789e-01, 1.8127e-01],
        [2.8549e-01, 8.6374e-02, 1.9811e-01, 2.7463e-01, 2.7206e-02, 1.2820e-01],
        [2.2068e-01, 5.9762e-02, 2.0818e-01, 9.6294e-02, 1.8529e-01, 2.2979e-01],
        [7.5183e-02, 2.0209e-01, 3.5177e-01, 3.1961e-02, 1.8940e-01, 1.4959e-01]])

In [173]:
# Query the exemplar model's get_prob on the paired batch. sample2 is already a
# tensor on the model's device; re-wrapping it in torch.Tensor(...) is
# redundant and can fail for CUDA tensors.
agent.exemplar_exploration.get_prob(sample2)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In [174]:
# Forward the anchor batch through the exemplar model's first encoder.
# NOTE(review): the recorded output has first-element activations of order 1e5,
# in line with the exploding exploration_loss printed during training.
agent.exemplar_exploration.encoder1(sample1)

(tensor([[-411236.5000,  393350.4688,  384208.4062],
         [-406850.9375,  389155.6562,  380111.0938],
         [-411448.1250,  393552.9062,  384406.1250],
         [-408593.0938,  390822.0625,  381738.7500],
         [-412486.1250,  394545.7500,  385375.8750],
         [-411236.5000,  393350.4688,  384208.4062],
         [-406850.9375,  389155.6562,  380111.0938],
         [-411448.1250,  393552.9062,  384406.1250],
         [-408593.0938,  390822.0625,  381738.7500],
         [-412486.1250,  394545.7500,  385375.8750]], grad_fn=<AddmmBackward0>),
 tensor([ 0.1171,  0.1172, 49.1794], grad_fn=<ExpBackward0>))

In [175]:
# Forward the paired batch through the exemplar model's second encoder.
# NOTE(review): the recorded output has first-element activations of order 1e6.
agent.exemplar_exploration.encoder2(sample2)

(tensor([[-1337813.6250,  1471313.8750, -1381276.7500],
         [-1334363.7500,  1467519.6250, -1377714.8750],
         [-1348354.8750,  1482906.8750, -1392160.5000],
         [-1346769.8750,  1481163.8750, -1390524.1250],
         [-1352208.8750,  1487145.6250, -1396139.7500],
         [-1344490.5000,  1478657.0000, -1388170.6250],
         [-1342675.3750,  1476660.7500, -1386296.5000],
         [-1340317.6250,  1474067.7500, -1383862.1250],
         [-1344859.2500,  1479062.5000, -1388551.3750],
         [-1350278.1250,  1485022.2500, -1394146.2500]],
        grad_fn=<AddmmBackward0>),
 tensor([0.1171, 0.1170, 0.1172], grad_fn=<ExpBackward0>))

In [176]:
# Re-display the paired batch (the recorded output matches the earlier display).
sample2

tensor([[2.7042e-01, 3.8990e-02, 1.7514e-01, 4.1480e-01, 3.7394e-03, 9.6915e-02],
        [1.7032e-01, 3.1336e-01, 5.2384e-02, 9.3736e-02, 2.8825e-02, 3.4138e-01],
        [2.6359e-01, 2.1941e-01, 2.7675e-01, 5.6042e-02, 8.3490e-02, 1.0073e-01],
        [2.9694e-01, 2.8548e-01, 8.7858e-02, 3.2611e-02, 1.2781e-01, 1.6930e-01],
        [4.4533e-01, 7.2120e-02, 2.3401e-01, 1.1310e-01, 1.2156e-01, 1.3882e-02],
        [1.0452e-01, 1.1155e-01, 1.4873e-01, 1.7704e-01, 2.3103e-01, 2.2714e-01],
        [2.5565e-02, 9.8072e-05, 2.1025e-01, 3.5493e-01, 2.2789e-01, 1.8127e-01],
        [2.8549e-01, 8.6374e-02, 1.9811e-01, 2.7463e-01, 2.7206e-02, 1.2820e-01],
        [2.2068e-01, 5.9762e-02, 2.0818e-01, 9.6294e-02, 1.8529e-01, 2.2979e-01],
        [7.5183e-02, 2.0209e-01, 3.5177e-01, 3.1961e-02, 1.8940e-01, 1.4959e-01]])

In [177]:
# Inspect encoder1's input-layer weights.
# NOTE(review): in the recorded output several rows have entries around 4-6,
# far larger than the |w| < 1 of the other recorded layers (which were
# recorded at earlier execution counts).
agent.exemplar_exploration.encoder1.input_layer.weight

Parameter containing:
tensor([[ 4.4945,  4.9176,  5.6169,  5.1803,  5.3713,  4.9508],
        [-0.5287, -0.4582, -0.3218, -0.1352, -0.4411,  0.2590],
        [ 5.5103,  5.1233,  5.4005,  4.8351,  4.6252,  5.1524],
        [ 4.8499,  3.8873,  3.9201,  4.3542,  3.6390,  3.7796],
        [-0.1354, -0.1777,  0.2604, -0.3452,  0.1909, -0.4116],
        [ 5.4439,  5.6386,  5.5182,  5.1969,  4.7592,  4.6876]],
       requires_grad=True)

In [49]:
# Inspect encoder1's first middle-layer weights.
# NOTE(review): recorded at In [49], before the training run in In [169], so
# this output may not reflect the current weights.
agent.exemplar_exploration.encoder1.middle_layers[0].weight

Parameter containing:
tensor([[-0.4052, -0.2013,  0.1476,  0.0046, -0.1058, -0.1568],
        [ 0.6673,  0.0222, -0.4492,  0.0455, -0.2971,  0.2218],
        [-0.0672, -0.0889, -0.3498, -0.5678, -0.6732,  0.3505],
        [-0.3164,  0.2907, -0.4615, -0.3610,  0.0238,  0.6773],
        [ 0.5307,  0.3437, -0.3207, -0.5510,  0.5744,  0.1655],
        [-0.1069,  0.1500,  0.8121,  0.5394,  0.0563, -0.1241]],
       requires_grad=True)

In [50]:
# Inspect encoder1's second middle-layer weights.
# NOTE(review): recorded at In [50], before the training run in In [169].
agent.exemplar_exploration.encoder1.middle_layers[1].weight

Parameter containing:
tensor([[ 0.0508,  0.1165, -0.4575, -0.5614, -0.5016, -0.3066],
        [-0.6025,  0.4505,  0.3224,  0.1770, -0.4793,  0.5280],
        [ 0.4596,  0.3312, -0.0333,  0.6614,  0.4940,  0.1429],
        [ 0.5125,  0.3081, -0.7372, -0.3431, -0.1594,  0.3183],
        [-0.0357, -0.5719, -0.4612,  0.1695,  0.5028,  0.4007],
        [-0.6890,  0.4767, -0.2983, -0.1980, -0.4688,  0.4817]],
       requires_grad=True)

In [51]:
# Inspect encoder1's output-layer weights.
# NOTE(review): recorded at In [51], before the training run in In [169].
agent.exemplar_exploration.encoder1.output_layer.weight

Parameter containing:
tensor([[-0.2280, -0.0511,  0.0087, -0.1226,  0.1295, -0.4337],
        [ 0.7213, -0.6539, -0.5365,  0.8129, -0.0790,  0.1673],
        [ 0.7710,  0.1478, -0.3960,  0.1658, -0.3018, -0.0182]],
       requires_grad=True)

In [52]:
# Inspect encoder2's input-layer weights.
# NOTE(review): recorded at In [52], before the training run in In [169].
agent.exemplar_exploration.encoder2.input_layer.weight

Parameter containing:
tensor([[-0.4219,  0.0388,  0.6717,  0.5690,  0.0200,  0.3857],
        [-0.2483, -0.8850, -0.5738, -0.2108, -0.6203,  0.4135],
        [ 0.3906, -0.6704, -0.3467,  0.4073,  0.5902, -0.4055],
        [ 0.0849,  0.4339, -0.4858, -0.1623, -0.0183, -0.3314],
        [ 0.2226,  0.7325, -0.3647,  0.0252,  0.2816, -0.7248],
        [ 0.3811,  0.2219, -0.3775, -0.5197, -0.0470,  0.2148]],
       requires_grad=True)

In [55]:
# Inspect encoder2's second middle-layer weights (shown before layer [0]).
# NOTE(review): recorded at In [55], before the training run in In [169].
agent.exemplar_exploration.encoder2.middle_layers[1].weight

Parameter containing:
tensor([[-0.0442, -0.6948,  0.0905, -0.5359,  0.3059, -0.6966],
        [ 0.4607, -0.4929, -0.6293,  0.3777, -0.0472, -0.6598],
        [ 0.5017, -0.2959,  0.6547, -0.4876,  0.4253, -0.1499],
        [-0.7312,  0.4837, -0.5116, -0.2307, -0.5350, -0.3990],
        [ 0.0530, -0.3585,  0.5933, -0.4743, -0.4582,  0.5733],
        [-0.1381,  0.5811,  0.1330,  0.5490,  0.2900, -0.3450]],
       requires_grad=True)

In [56]:
# Inspect encoder2's first middle-layer weights.
# NOTE(review): recorded at In [56], before the training run in In [169].
agent.exemplar_exploration.encoder2.middle_layers[0].weight

Parameter containing:
tensor([[ 0.2358,  0.7143, -0.0748,  0.3639,  0.4208, -0.1441],
        [ 0.6619, -0.5357, -0.3894,  0.5004,  0.4789,  0.0025],
        [-0.2524,  0.2116,  0.2926, -0.1054, -0.0636,  0.4401],
        [-0.7706,  0.2716, -0.7749, -0.4349, -0.9102,  0.6747],
        [-0.4367, -0.2666, -0.1968, -0.6260, -0.5411, -0.2329],
        [ 0.2322,  0.3201,  0.5454, -0.6749, -0.2609, -0.1297]],
       requires_grad=True)

In [57]:
# Inspect encoder2's output-layer weights.
# NOTE(review): recorded at In [57], before the training run in In [169].
agent.exemplar_exploration.encoder2.output_layer.weight

Parameter containing:
tensor([[ 0.4517, -0.5651,  0.3831, -0.7905, -0.5621,  0.5282],
        [ 0.5757,  0.6838, -0.4832,  0.7202,  0.5067, -0.0118],
        [ 0.1186, -0.5849,  0.3356,  0.5019, -0.4892, -0.7697]],
       requires_grad=True)