In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from collections import namedtuple, deque
import torch.nn as nn
import torch.nn.functional as F
import random as rand

In [2]:
class world:
    def __init__(self, function_set, dim = 20):
        self.dim = dim
        self.function_set = function_set
        
    def batch_state(self, state_ix): 
    #state_ix is  list of integers, one for each state in batch giving index of non-zero entry
    #1d - for each, will create one-hot vector (length equal to dim) with non-zero entry at corresponding index in x.
        batch_size = len(state_ix)
        state = torch.zeros(batch_size, self.dim)
        state[list(range(batch_size)),state_ix] = 1 #set corresponding index to 1
        return state
    
    def batch_act(self, fun_ix, state_old): 
    #fun_ix is list of integers, one for each state giving index of function to be applied
    #state is batch tensor of state: (B,n)
        batch_size = len(fun_ix)
        assert batch_size ==state_old.size(0)
        state_new = torch.zeros_like(state_old)
        for i in range(batch_size):
            state_new[i] = self.function_set[fun_ix[i]](state_old[i])
        return state_new
    

In [3]:
class phinet(nn.Module):
    def __init__(self, input_dim):
        super(phinet,self).__init__()
        self.fc1 = nn.Linear(input_dim,10)
        self.fc2 = nn.Linear(10,1)
    def forward(self,x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return x

In [19]:
def train(model, device, out_world, in_world, optimizer, batch_size, print_interval, num_batches):
    #learn phi map from out_world to in_world.
    model.train()
#     lower_bd, upper_bd = out_world.dim//3, 2*out_world.dim//3 #ensure location is within bounded range so no edge effects.
    for step in range(num_batches):
        w_ix = rand.choices(population = list(range(1,out_world.dim-1)), k = batch_size)
#         w_ix = rand.choices(population = list(range(lower_bd,upper_bd)), k = batch_size)
        w = out_world.batch_state(w_ix) #generate (w)
        w = w.to(device)
        fun_ix = rand.choices([0,1],k=batch_size) #randomly select f
        psif_w = out_world.batch_act(fun_ix, w) #generate psi(f)(w)
        phpsf_w = model(psif_w)
        return 
        phi_w = model(w) #generate phi(w)
        fph_w = in_world.batch_act(fun_ix,phi_w) #generate f(phi(w))
        optimizer.zero_grad()
        loss = F.mse_loss(phpsf_w, fph_w)
        loss.backward()
        optimizer.step()
        if step % print_interval == 0:
            print('{} :\t Loss: {:.6f}'.format(
                step, loss.item()))


In [5]:
import torch.optim as optim

In [10]:
device = torch.device(1)

In [11]:
out_dim = 16
trans_left = torch.zeros_like(torch.eye(out_dim)).to(device)
trans_right = torch.zeros_like(torch.eye(out_dim)).to(device)
for i in range(out_dim-1):
    trans_right[i+1,i] = 1
    trans_left[i,i+1] = 1

In [12]:
phi = phinet(out_dim).to(device)
optimizer = optim.Adam(phi.parameters(), lr=0.001) #e-1

In [13]:
out_world = world(function_set = [lambda x: torch.mv(trans_left,x), lambda x: torch.mv(trans_right,x)], dim = out_dim)
in_world = world(function_set = [lambda x: x+1, lambda x: x-1], dim = 1)

In [14]:
batch_size = 50
print_interval = 100
num_batches = 2000

In [21]:
train(phi, device, out_world, in_world, optimizer, batch_size, print_interval, num_batches)

In [18]:
test_ix = rand.choices(population = list(range(4,12)), k = 20)
test_w = out_world.batch_state(test_ix) #generate (w)
test_w = test_w.to(device)


In [315]:
phi.eval()
phi(test_w)

tensor([[-3.2923],
        [-0.2947],
        [-3.2923],
        [-1.2939],
        [-3.2923],
        [ 1.7043],
        [-4.2918],
        [-2.2930],
        [ 0.7048],
        [ 1.7043],
        [-0.2947],
        [-2.2930],
        [-0.2947],
        [-4.2918],
        [-5.2913],
        [ 1.7043],
        [-0.2947],
        [-4.2918],
        [-1.2939],
        [-1.2939]], device='cuda:1', grad_fn=<AddmmBackward>)

In [316]:
test_ix

[9, 6, 9, 7, 9, 4, 10, 8, 5, 4, 6, 8, 6, 10, 11, 4, 6, 10, 7, 7]

In [None]:
#Train

#1. input w fed through phi-net to get phi(w)
w 
#2. randomly sample among F

#3. map using fixed psito F'

#4. generate w' = phi(psi(F)w)

#5. compute loss ||w' - w|| and compute