In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from auxiliary import loadData_staticTargetAddrMatch, prepare_data_loaders
from algorithms import RssPosAlgo_NeuralNet_MLP4layer
from algorithms import RssPosAlgo_NeuralNet_supervisedTrainingLoop
from algorithms import RssPosAlgo_NearestNeighbour_Interpolation
from algorithms import RssPosAlgo_NearestNeighbour_GetKmeansDb

### prepare model, dataloaders and training parameters

In [2]:
# ---- Training configuration: the single source of truth for the cells below ----
SEED             = 0
epochs           = 601
# 32 is the batch size the loaders were actually built with in the recorded run;
# previously a separate `batch_size = 8` was declared here but never used.
batch_size       = 32
# Passed straight to prepare_data_loaders; 0.5 yields the even split printed below.
train_test_split = 0.5

# Seed the RNGs before any random step (weight init, and any shuffling the helpers
# may do -- assuming they draw from numpy/torch RNGs) so a re-run is reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---- Data ----
# Raw capture: tshark JSON export from experiment exp004.
datajsonpath = "../experiments/exp004_20241022_sna_kadirburakerdem/data-tshark/data.json"

# MAC addresses of the transmitters whose RSS readings are loaded as model inputs.
# NOTE(review): inch=3 below presumably means one input per address -- confirm.
target_addresses = ["d8:47:32:eb:6c:38",
                    "50:c7:bf:19:e6:4d",
                    "4c:77:6d:5f:dc:20"]

inp_rss_vals, gt_locations = loadData_staticTargetAddrMatch(datajsonpath, second_hold = 5, shuffle=False,
                                                            target_addresses=target_addresses,
                                                            snap250ms=False)

train_loader, test_loader, xtr, ytr, xts, yts = prepare_data_loaders(inp_rss_vals, gt_locations,
                                                                     batch_size = batch_size,
                                                                     train_test_split = train_test_split)

print("Subset sizes | train:", xtr.shape[0], ", test:",xts.shape[0])

# ---- Model, loss and optimizer ----
MLP = RssPosAlgo_NeuralNet_MLP4layer(inch=3)
MLP.train()  # put the model in training mode before the training loop

MLP_criterion = nn.MSELoss(reduction='mean')
MLP_optimizer = optim.Adam(MLP.parameters(), lr=3e-4)

Subset sizes | train: 3086 , test: 3087


### train model

In [None]:
# Train MLP with the helper's supervised loop; its return value (presumably the
# trained model) replaces MLP. `testfreq` is measured in epochs: the test loss is
# reported every 20 epochs.
# NOTE(review): the first log line shows "training loss: -1.000", which looks like a
# placeholder printed before any training loss exists -- confirm in the helper.
MLP = RssPosAlgo_NeuralNet_supervisedTrainingLoop(
    model        = MLP,
    criterion    = MLP_criterion,
    optimizer    = MLP_optimizer,
    train_loader = train_loader,
    test_loader  = test_loader,
    epochs       = epochs,
    testfreq     = 20,
)

Epoch [1/601] test loss: 2.432, training loss: -1.000


### save the model

In [None]:
# Persist only the learned weights (state_dict); they are reloaded into a fresh
# model instance for evaluation. torch.save does not create missing parent
# directories, so make sure the output folder exists first.
os.makedirs('savedmodels', exist_ok=True)
torch.save(MLP.state_dict(), 'savedmodels/dev008_exp004_MLP4layer.pth')

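### evaluate model

Compare the trained MLP with the nearest-neighbour + interpolation baseline (whose reference database is built by k-means over the training set) on the whole test set.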

In [None]:
# Rebuild the model and load the saved weights, so evaluation depends only on the
# checkpoint file and not on in-memory training state.
MLP = RssPosAlgo_NeuralNet_MLP4layer(inch=3)
# map_location="cpu": evaluation converts outputs with .numpy(), which needs CPU
# tensors, and it lets a checkpoint holding CUDA tensors load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you produced or trust.
MLP.load_state_dict(torch.load('savedmodels/dev008_exp004_MLP4layer.pth', map_location="cpu"));
MLP.eval();  # inference mode for evaluation

In [None]:
# Reference database for the nearest-neighbour + interpolation baseline,
# built by k-means clustering of the training set.
db_kmeans = RssPosAlgo_NearestNeighbour_GetKmeansDb(xtr, ytr, num_clusters=3)

# Per-coordinate SUMS of squared position error over the test set
# (despite the "meanerror" names, these are totals; the mean is taken below).
meanerror_mlp         = 0
meanerror_nene_interp = 0
num_test = len(xts)
assert num_test > 0 and num_test == len(yts), "test inputs/targets must be non-empty and aligned"

with torch.no_grad():  # inference only: don't build an autograd graph per sample
    for x_test_sample, y_test_sample in zip(xts, yts):
        y_true                 = y_test_sample.numpy()
        loc_pred_mlp           = MLP(x_test_sample).detach().numpy()
        loc_pred_nene_interp   = RssPosAlgo_NearestNeighbour_Interpolation(x_test_sample, db_kmeans)
        meanerror_mlp         += (y_true - loc_pred_mlp)**2
        meanerror_nene_interp += (y_true - loc_pred_nene_interp)**2

# RMSE of the Euclidean position error: sqrt(mean over samples of squared distance).
# (np.linalg.norm of the per-coordinate MSE vector, used previously, combines squared
# errors in quadrature -- it is in squared units and is not an error distance.)
rmse_mlp         = np.sqrt(np.sum(meanerror_mlp / num_test))
rmse_nene_interp = np.sqrt(np.sum(meanerror_nene_interp / num_test))
print("MLP         RMSE:", rmse_mlp)
print("NeNe+Interp RMSE:", rmse_nene_interp)

### test a few samples

Print the ground-truth position next to both predictions for individual samples.

From the test set:

In [None]:
# Show ground truth next to both predictors for one test-set sample.
# Display only: nothing is added to the evaluation cell's error accumulators,
# otherwise every run of this cell would silently inflate them.
sampleid = 10
with torch.no_grad():  # inference only
    loc_pred_mlp = MLP(xts[sampleid]).detach().numpy()
loc_pred_nene_interp = RssPosAlgo_NearestNeighbour_Interpolation(xts[sampleid], db_kmeans)
print("Actual position  :", yts[sampleid].numpy())
print("MLP prediction   :", loc_pred_mlp)
print("NeNe+Interp pred :", loc_pred_nene_interp)

---

From the training set:

In [None]:
# Show ground truth next to both predictors for one training-set sample.
# Display only: nothing is added to the evaluation cell's error accumulators,
# otherwise every run of this cell would silently inflate them.
sampleid = 20
with torch.no_grad():  # inference only
    loc_pred_mlp = MLP(xtr[sampleid]).detach().numpy()
loc_pred_nene_interp = RssPosAlgo_NearestNeighbour_Interpolation(xtr[sampleid], db_kmeans)
print("Actual position  :", ytr[sampleid].numpy())
print("MLP prediction   :", loc_pred_mlp)
print("NeNe+Interp pred :", loc_pred_nene_interp)