In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import LinearTransformation
from torch.nn.functional import affine_grid, grid_sample

from kornia.geometry.transform import warp_affine3d

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import affine_transform
from mpl_toolkits.axes_grid1 import ImageGrid
import os

import nn_architecture
import data_loader as dl


In [None]:
def get_device():
    """Return the torch device computations should run on.

    Uses the first CUDA GPU ("cuda:0") when one is available, otherwise the CPU.
    The Apple-Silicon MPS branch is currently disabled (commented out below).
    """
    # Disabled MPS option:
    # if torch.backends.mps.is_available():
    #     return torch.device("mps")
    name = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(name)

In [None]:
# Resolve the compute device once (CUDA when available, otherwise CPU) and echo it.
device = get_device()
device

In [None]:
def show_eval(res_matrix, x, y, figsize=(15, 5)):
    """Plot max-intensity projections of x, y, and x warped by a predicted affine matrix.

    Draws three side-by-side panels:
      1. the projection of ``x``,
      2. the projection of ``y``,
      3. the projection of ``x`` resampled with ``scipy.ndimage.affine_transform``
         using ``res_matrix``.
    Each projection is the max over the original first axis; after
    ``transpose(1, 2, 0)`` that axis is the last one.

    Args:
        res_matrix: tensor holding the 12 entries of a 3x4 affine matrix in
            row-major order; it is extended to a 4x4 homogeneous matrix.
        x, y: rank-3 tensors (``transpose(1, 2, 0)`` requires exactly 3 axes).
            They may live on any device; they are copied to host memory here.
        figsize: matplotlib figure size in inches.

    Note:
        ``affine_transform`` interprets the matrix as an output->input (pull)
        mapping. Whether ``res_matrix`` must be inverted first depends on how
        the training targets were generated — TODO confirm.
    """
    # .cpu() is required before .numpy() for tensors that live on the GPU.
    affine = res_matrix.detach().cpu().numpy().reshape(3, 4)
    full_matrix = np.vstack([affine, [0, 0, 0, 1]])
    x_nmp = x.detach().cpu().numpy().transpose(1, 2, 0)
    y_nmp = y.detach().cpu().numpy().transpose(1, 2, 0)
    x_new = affine_transform(x_nmp, full_matrix)

    fig = plt.figure(figsize=figsize)
    grid = ImageGrid(fig, 111, nrows_ncols=(1, 3), axes_pad=0.1)
    panels = [(x_nmp, "x"), (y_nmp, "y"), (x_new, "x warped by predicted matrix")]
    for ax, (volume, title) in zip(grid, panels):
        ax.imshow(np.max(volume, axis=2), cmap='gray')
        ax.set_title(title)

    plt.show()

In [None]:
def eval_model(model_state_file, data_path, iterations=0, min_val=-1000, max_val=1000,
               device=None, show_plots=False):
    """Evaluate a saved Siam_AirNet2 checkpoint on a dataset and print per-sample results.

    For every sample, prints:
      - the target 3x4 matrix,
      - the predicted matrix,
      - their difference,
      - the summed squared error (``nn.MSELoss(reduction='sum')``).

    Args:
        model_state_file: path to a ``state_dict`` file loaded with ``torch.load``.
        data_path: dataset location passed to ``dl.Img3dDataSet``.
        iterations: maximum number of samples to evaluate; 0 (default) means all.
        min_val, max_val: value bounds forwarded to ``dl.Img3dDataSet``
            (presumably an intensity clipping/normalization window — TODO confirm).
        device: torch device for model and data; defaults to ``get_device()``.
        show_plots: if True, call ``show_eval`` for every evaluated sample.

    Returns:
        0-d tensor with the mean per-sample loss over the evaluated samples.

    Raises:
        ValueError: if no samples were evaluated (e.g. an empty dataset).

    Side effects:
        Calls ``torch.set_printoptions(precision=3)``, which is global.
    """
    if device is None:
        device = get_device()

    model = nn_architecture.Siam_AirNet2()
    criterion = nn.MSELoss(reduction='sum')

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints
    # (torch >= 1.13 supports weights_only=True to restrict this).
    model.load_state_dict(torch.load(model_state_file, map_location=device))
    # The dataset is constructed with `device`, so the model must live there too;
    # otherwise CUDA inputs would meet CPU weights.
    model.to(device)
    model.eval()

    batch_size = 1  # the reshape(3, 4) calls below assume one sample per batch
    dataset = dl.Img3dDataSet(data_path, min_val, max_val, device)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    overall_loss = 0.0
    n_evaluated = 0
    torch.set_printoptions(precision=3)
    # no_grad: without it every per-sample graph stays referenced by
    # `overall_loss` until return, growing memory with the dataset size.
    with torch.no_grad():
        for i, (x, y, matrix) in enumerate(data_loader):
            if iterations and i >= iterations:
                break
            # No-op if the dataset already placed the tensors on `device`.
            x, y, matrix = x.to(device), y.to(device), matrix.to(device)
            print(f"index: {i}")
            print(matrix.reshape(3, 4))
            res_matrix = model(x, y)
            loss = criterion(res_matrix, matrix.flatten(start_dim=1))
            overall_loss += loss
            print(res_matrix.reshape(3, 4).float())
            print(f"diff:\n{(matrix.reshape(3, 4) - res_matrix.reshape(3, 4))}")
            print(f"loss: {loss}")
            if show_plots:
                show_eval(res_matrix, x[0].cpu(), y[0].cpu())
            n_evaluated += 1

    if n_evaluated == 0:
        raise ValueError(f"no samples evaluated from {data_path!r}")
    return overall_loss / n_evaluated

In [None]:
eval_model(model_state_file="models/best_model.pt", data_path="./data/val")

In [None]:
# Two affine matrices of 12 values each, laid out here as 3x4 rows and flattened.
# Presumably a target (m1) and a model prediction (m2) copied from eval_model
# output — TODO confirm provenance.
m1 = np.array([
    [0.9563, -0.2924, 0.0000, 53.7707],
    [0.2924, 0.9563, 0.0000, -39.7882],
    [0.0000, 0.0000, 1.0000, 0.0000],
]).flatten()
m2 = np.array([
    [9.3715e-01, -2.6060e-01, -6.0337e-04, 5.5185e+01],
    [2.5515e-01, 9.2936e-01, -1.4178e-02, -2.1137e+01],
    [-4.2215e-03, 5.5613e-03, 9.9444e-01, -3.5994e-03],
]).flatten()

In [None]:
(m1 - m2) ** 2  # element-wise squared error between the two matrices

In [None]:
eval_model(model_state_file="models/continue_920_1.93.pt", data_path="./data/val")

In [None]:
# NOTE(review): unlike most checkpoints here, this path is not under models/ — confirm location.
eval_model(model_state_file="best_model_731_1.93.pt", data_path="./data/val")

In [None]:
eval_model(model_state_file="models/continueModel_1.19.pt", data_path="./data/val")

In [None]:
# Quick check on the first 10 samples only.
# NOTE(review): this checkpoint path is not under models/ — confirm location.
eval_model(model_state_file="model_test1.pt", data_path="./data/val", iterations=10)

In [80]:
np.random.uniform(low=-45, high=45)  # unseeded draw from NumPy's global RNG; varies per run

-39.27412412007226