In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 300)
        self.fc2 = nn.Linear(300, 400)
        self.fc3 = nn.Linear(400, 2)

    def forward(self, x):  # x are the states: [d_obs_x, d_obs_y, d_goal_x, d_goal_y]
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

Net(
  (fc1): Linear(in_features=4, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=400, bias=True)
  (fc3): Linear(in_features=400, out_features=2, bias=True)
)


In [40]:
# Testing the network

random_data = torch.rand((1, 4))

result = net(random_data)
print (result)

tensor([[0.0121, 0.0248]], grad_fn=<AddmmBackward>)


In [41]:
# Preparing training and testing data
import numpy as np

data_root_path = f'/home/yigit/phd/yigit_phd_thesis/cnmp/data/sfm/small_env_changing_s_g/demonstrations/'

x_train_ = torch.cat((torch.from_numpy(np.load(f'{data_root_path}d_x.npy')).float(), torch.from_numpy(np.load(f'{data_root_path}d_gamma.npy')).float()), 2)
y_train_ = torch.from_numpy(np.load(f'{data_root_path}d_y.npy')).float()
x_test_ = torch.cat((torch.from_numpy(np.load(f'{data_root_path}v_d_x.npy')).float(), torch.from_numpy(np.load(f'{data_root_path}v_d_gamma.npy')).float()), 2)
y_test_ = torch.from_numpy(np.load(f'{data_root_path}v_d_y.npy')).float()

# flattening first 2 dims to get (7068*400, 4) training instances
num_trajs_train, num_steps, d_x_train = x_train_.shape
_, _, d_y_train = y_train_.shape
num_trajs_test, _, d_x_test = x_test_.shape
_, _, d_y_test = y_test_.shape

x_train = x_train_.view(num_trajs_train*num_steps, d_x_train)
y_train = y_train_.view(num_trajs_train*num_steps, d_y_train)
x_test = x_test_.view(num_trajs_test*num_steps, d_x_test)
y_test = y_test_.view(num_trajs_test*num_steps, d_y_test)

print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

del x_train_, y_train_, x_test_, y_test_

torch.Size([2827200, 4])
torch.Size([2827200, 2])
torch.Size([57600, 4])
torch.Size([57600, 2])


In [42]:
batch_size = 100

num_instances_train, _ = x_train.shape

In [43]:
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

epochs = 130

for epoch in range(epochs):  # loop over the dataset <epochs> times

    train_ids = torch.randperm(num_instances_train)  # shuffle instances every epoch
    num_iters = int(num_instances_train/batch_size)
    
    running_loss = 0.0
    for i in range(num_iters):
        batch_train_ids = train_ids[i*batch_size:(i+1)*batch_size]
        
        # get the inputs; data is a list of [inputs, labels]
        inputs = x_train[batch_train_ids, :]
        labels = y_train[batch_train_ids, :]

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

[1,  2000] loss: 0.080
[1,  4000] loss: 0.045
[1,  6000] loss: 0.038
[1,  8000] loss: 0.033
[1, 10000] loss: 0.027
[1, 12000] loss: 0.023
[1, 14000] loss: 0.019
[1, 16000] loss: 0.016
[1, 18000] loss: 0.013
[1, 20000] loss: 0.012
[1, 22000] loss: 0.010
[1, 24000] loss: 0.009
[1, 26000] loss: 0.009
[1, 28000] loss: 0.008
[2,  2000] loss: 0.007
[2,  4000] loss: 0.007
[2,  6000] loss: 0.006
[2,  8000] loss: 0.006
[2, 10000] loss: 0.005
[2, 12000] loss: 0.005
[2, 14000] loss: 0.005
[2, 16000] loss: 0.004
[2, 18000] loss: 0.004
[2, 20000] loss: 0.004
[2, 22000] loss: 0.004
[2, 24000] loss: 0.004
[2, 26000] loss: 0.004
[2, 28000] loss: 0.003
[3,  2000] loss: 0.003
[3,  4000] loss: 0.003
[3,  6000] loss: 0.003
[3,  8000] loss: 0.003
[3, 10000] loss: 0.003
[3, 12000] loss: 0.003
[3, 14000] loss: 0.003
[3, 16000] loss: 0.002
[3, 18000] loss: 0.002
[3, 20000] loss: 0.002
[3, 22000] loss: 0.002
[3, 24000] loss: 0.002
[3, 26000] loss: 0.002
[3, 28000] loss: 0.002
[4,  2000] loss: 0.002
[4,  4000] 

[25, 24000] loss: 0.001
[25, 26000] loss: 0.001
[25, 28000] loss: 0.001
[26,  2000] loss: 0.001
[26,  4000] loss: 0.001
[26,  6000] loss: 0.001
[26,  8000] loss: 0.001
[26, 10000] loss: 0.001
[26, 12000] loss: 0.001
[26, 14000] loss: 0.001
[26, 16000] loss: 0.001
[26, 18000] loss: 0.001
[26, 20000] loss: 0.001
[26, 22000] loss: 0.001
[26, 24000] loss: 0.001
[26, 26000] loss: 0.001
[26, 28000] loss: 0.001
[27,  2000] loss: 0.001
[27,  4000] loss: 0.001
[27,  6000] loss: 0.001
[27,  8000] loss: 0.001
[27, 10000] loss: 0.001
[27, 12000] loss: 0.001
[27, 14000] loss: 0.001
[27, 16000] loss: 0.001
[27, 18000] loss: 0.001
[27, 20000] loss: 0.001
[27, 22000] loss: 0.001
[27, 24000] loss: 0.001
[27, 26000] loss: 0.001
[27, 28000] loss: 0.001
[28,  2000] loss: 0.001
[28,  4000] loss: 0.001
[28,  6000] loss: 0.001
[28,  8000] loss: 0.001
[28, 10000] loss: 0.001
[28, 12000] loss: 0.001
[28, 14000] loss: 0.001
[28, 16000] loss: 0.001
[28, 18000] loss: 0.001
[28, 20000] loss: 0.001
[28, 22000] loss

[50,  8000] loss: 0.001
[50, 10000] loss: 0.001
[50, 12000] loss: 0.001
[50, 14000] loss: 0.001
[50, 16000] loss: 0.001
[50, 18000] loss: 0.001
[50, 20000] loss: 0.001
[50, 22000] loss: 0.001
[50, 24000] loss: 0.001
[50, 26000] loss: 0.001
[50, 28000] loss: 0.001
[51,  2000] loss: 0.001
[51,  4000] loss: 0.001
[51,  6000] loss: 0.001
[51,  8000] loss: 0.001
[51, 10000] loss: 0.001
[51, 12000] loss: 0.001
[51, 14000] loss: 0.001
[51, 16000] loss: 0.001
[51, 18000] loss: 0.001
[51, 20000] loss: 0.001
[51, 22000] loss: 0.001
[51, 24000] loss: 0.001
[51, 26000] loss: 0.001
[51, 28000] loss: 0.001
[52,  2000] loss: 0.001
[52,  4000] loss: 0.001
[52,  6000] loss: 0.001
[52,  8000] loss: 0.001
[52, 10000] loss: 0.001
[52, 12000] loss: 0.001
[52, 14000] loss: 0.001
[52, 16000] loss: 0.001
[52, 18000] loss: 0.001
[52, 20000] loss: 0.001
[52, 22000] loss: 0.001
[52, 24000] loss: 0.001
[52, 26000] loss: 0.001
[52, 28000] loss: 0.001
[53,  2000] loss: 0.001
[53,  4000] loss: 0.001
[53,  6000] loss

[74, 20000] loss: 0.001
[74, 22000] loss: 0.001
[74, 24000] loss: 0.001
[74, 26000] loss: 0.001
[74, 28000] loss: 0.001
[75,  2000] loss: 0.001
[75,  4000] loss: 0.001
[75,  6000] loss: 0.001
[75,  8000] loss: 0.001
[75, 10000] loss: 0.001
[75, 12000] loss: 0.001
[75, 14000] loss: 0.001
[75, 16000] loss: 0.001
[75, 18000] loss: 0.001
[75, 20000] loss: 0.001
[75, 22000] loss: 0.001
[75, 24000] loss: 0.001
[75, 26000] loss: 0.001
[75, 28000] loss: 0.001
[76,  2000] loss: 0.001
[76,  4000] loss: 0.001
[76,  6000] loss: 0.001
[76,  8000] loss: 0.001
[76, 10000] loss: 0.001
[76, 12000] loss: 0.001
[76, 14000] loss: 0.001
[76, 16000] loss: 0.001
[76, 18000] loss: 0.001
[76, 20000] loss: 0.001
[76, 22000] loss: 0.001
[76, 24000] loss: 0.001
[76, 26000] loss: 0.001
[76, 28000] loss: 0.001
[77,  2000] loss: 0.001
[77,  4000] loss: 0.001
[77,  6000] loss: 0.001
[77,  8000] loss: 0.001
[77, 10000] loss: 0.001
[77, 12000] loss: 0.001
[77, 14000] loss: 0.001
[77, 16000] loss: 0.001
[77, 18000] loss

[99,  4000] loss: 0.001
[99,  6000] loss: 0.001
[99,  8000] loss: 0.001
[99, 10000] loss: 0.001
[99, 12000] loss: 0.001
[99, 14000] loss: 0.001
[99, 16000] loss: 0.001
[99, 18000] loss: 0.001
[99, 20000] loss: 0.001
[99, 22000] loss: 0.001
[99, 24000] loss: 0.001
[99, 26000] loss: 0.001
[99, 28000] loss: 0.001
[100,  2000] loss: 0.001
[100,  4000] loss: 0.001
[100,  6000] loss: 0.001
[100,  8000] loss: 0.001
[100, 10000] loss: 0.001
[100, 12000] loss: 0.001
[100, 14000] loss: 0.001
[100, 16000] loss: 0.001
[100, 18000] loss: 0.001
[100, 20000] loss: 0.001
[100, 22000] loss: 0.001
[100, 24000] loss: 0.001
[100, 26000] loss: 0.001
[100, 28000] loss: 0.001
[101,  2000] loss: 0.001
[101,  4000] loss: 0.001
[101,  6000] loss: 0.001
[101,  8000] loss: 0.001
[101, 10000] loss: 0.001
[101, 12000] loss: 0.001
[101, 14000] loss: 0.001
[101, 16000] loss: 0.001
[101, 18000] loss: 0.001
[101, 20000] loss: 0.001
[101, 22000] loss: 0.001
[101, 24000] loss: 0.001
[101, 26000] loss: 0.001
[101, 28000] 

[122, 18000] loss: 0.001
[122, 20000] loss: 0.001
[122, 22000] loss: 0.001
[122, 24000] loss: 0.001
[122, 26000] loss: 0.001
[122, 28000] loss: 0.001
[123,  2000] loss: 0.001
[123,  4000] loss: 0.001
[123,  6000] loss: 0.001
[123,  8000] loss: 0.001
[123, 10000] loss: 0.001
[123, 12000] loss: 0.001
[123, 14000] loss: 0.001
[123, 16000] loss: 0.001
[123, 18000] loss: 0.001
[123, 20000] loss: 0.001
[123, 22000] loss: 0.001
[123, 24000] loss: 0.001
[123, 26000] loss: 0.001
[123, 28000] loss: 0.001
[124,  2000] loss: 0.001
[124,  4000] loss: 0.001
[124,  6000] loss: 0.001
[124,  8000] loss: 0.001
[124, 10000] loss: 0.001
[124, 12000] loss: 0.001
[124, 14000] loss: 0.001
[124, 16000] loss: 0.001
[124, 18000] loss: 0.001
[124, 20000] loss: 0.001
[124, 22000] loss: 0.001
[124, 24000] loss: 0.001
[124, 26000] loss: 0.001
[124, 28000] loss: 0.001
[125,  2000] loss: 0.001
[125,  4000] loss: 0.001
[125,  6000] loss: 0.001
[125,  8000] loss: 0.001
[125, 10000] loss: 0.000
[125, 12000] loss: 0.001


In [51]:
# Saving the trained model
import time
model_root_path = f'/home/yigit/Documents/projects/irl_sfm/python_ws/nn_baseline_for_cnmp/model/'
# torch.save(net, f'{model_root_path}model_{str(int(time.time()))}.pt')
torch.save(net.state_dict(), f'{model_root_path}model_state_dict.pt')

In [50]:
print(torch.rand((1, 4)))
print(torch.Tensor([1,2,3,4]).float().view(1,4))

tensor([[0.4009, 0.1050, 0.6017, 0.6165]])
tensor([[1., 2., 3., 4.]])
