# UAV 3D Trajectory Prediction using GRU

## Load and Test the Trained Model

### Import Required Libraries

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

### Load the Trained Model

In [None]:
class TrajectoryPredictor(torch.nn.Module):
    """GRU encoder-decoder that maps an observed sequence to a fixed-length forecast.

    ``gru1`` encodes the input sequence; its final hidden state initialises
    ``gru2``, which is driven by an all-zero input sequence of ``pred_len``
    steps. A linear layer projects every decoder step to ``output_dim``.

    Args:
        input_dim: Number of features per input time step.
        hidden_dim: Hidden size shared by both GRUs.
        output_dim: Number of features per predicted time step.
        num_layers: Number of stacked layers in each GRU.
        dropout: Dropout probability between the layers of the encoder GRU.
        pred_len: Number of future steps to emit. Defaults to 10, the value
            that was previously hard-coded in ``forward``.

    Raises:
        ValueError: If ``pred_len`` is smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout=0.5, pred_len=10):
        super().__init__()
        if pred_len < 1:
            raise ValueError(f"pred_len must be >= 1, got {pred_len}")
        self.hidden_dim = hidden_dim
        # Plain attribute (not a parameter/buffer), so existing checkpoints still load.
        self.pred_len = pred_len
        self.gru1 = torch.nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.gru2 = torch.nn.GRU(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return predictions of shape ``(batch, pred_len, output_dim)``.

        Args:
            x: Batch-first input tensor; ``x.size(0)`` is used as the batch size.
        """
        # Only the encoder's final hidden state is carried over to the decoder.
        _, h_n = self.gru1(x)
        # new_zeros allocates directly with h_n's device and dtype, avoiding a
        # CPU allocation + copy and a dtype mismatch for non-float32 models.
        dec_input = h_n.new_zeros(x.size(0), self.pred_len, self.hidden_dim)
        out, _ = self.gru2(dec_input, h_n)
        return self.fc(out)

# Rebuild the architecture with the same hyperparameters the checkpoint expects.
model = TrajectoryPredictor(input_dim=3, hidden_dim=64, output_dim=3, num_layers=2, dropout=0.5)
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine; the model is kept on the CPU in this notebook anyway.
state_dict = torch.load('5000_pos_max_norm_64.pth', map_location='cpu')
model.load_state_dict(state_dict)
# Switch to inference mode (disables dropout).
model.eval()

### Load Test Data

In [None]:
# Read the test split from disk; the archive is indexed by the 'inputs' and
# 'outputs' keys (assumed to mirror how the training data was stored).
test_data = np.load('test_data.npz')
test_inputs, test_outputs = test_data['inputs'], test_data['outputs']

# Wrap both arrays as float32 tensors and serve them in batches of 64.
test_dataset = TensorDataset(
    torch.tensor(test_inputs, dtype=torch.float32),
    torch.tensor(test_outputs, dtype=torch.float32),
)
test_loader = DataLoader(dataset=test_dataset, batch_size=64)

### Evaluate the Model on Test Data

In [None]:
def evaluate_model(model, loader):
    """Evaluate ``model`` over ``loader`` and return sample-weighted average metrics.

    Every metric is computed per batch, weighted by the batch size, and
    divided by the total number of samples. Consequently 'RMSE' is the
    weighted mean of per-batch RMSE values (not the square root of the
    overall MSE), and 'R2' / 'Adjusted R2' are weighted means of per-batch
    scores computed against each batch's own target mean.

    Args:
        model: A ``torch.nn.Module`` mapping a batch of inputs to predictions
            with the same shape as the targets.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        dict with keys 'MSE', 'RMSE', 'MAE', 'R2' and 'Adjusted R2'. The
        values are 0-dim tensors produced by the tensor arithmetic below.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()

    # Run on the device that holds the model's weights. A module without
    # parameters has no device of its own, so batches are left where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_mse, total_rmse, total_mae, total_r2, total_adj_r2 = 0.0, 0.0, 0.0, 0.0, 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, targets in loader:
            if device is not None:
                inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = inputs.size(0)
            residuals = targets - outputs

            # Mean Squared Error
            mse = (residuals ** 2).mean()
            total_mse += mse * batch_size

            # Root Mean Squared Error (of this batch)
            rmse = torch.sqrt(mse)
            total_rmse += rmse * batch_size

            # Mean Absolute Error
            mae = torch.abs(residuals).mean()
            total_mae += mae * batch_size

            # R2 score: 1 - SS_res / SS_tot, using the batch's own target mean.
            ss_res = mse * targets.numel()
            ss_tot = ((targets - targets.mean()) ** 2).sum()
            r2 = 1 - ss_res / ss_tot
            total_r2 += r2 * batch_size

            # Adjusted R2 with n = batch size and p = inputs.size(1).
            # NOTE(review): when batch_size <= inputs.size(1) + 1 the
            # denominator is zero or negative and the value is inf/nan or not
            # meaningful -- confirm this is the intended definition.
            adjusted_r2 = 1 - (1 - r2) * (batch_size - 1) / (batch_size - inputs.size(1) - 1)
            total_adj_r2 += adjusted_r2 * batch_size

            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute metrics")

    # Compile metrics
    evaluation_results = {
        'MSE': total_mse / total_samples,
        'RMSE': total_rmse / total_samples,
        'MAE': total_mae / total_samples,
        'R2': total_r2 / total_samples,
        'Adjusted R2': total_adj_r2 / total_samples
    }
    return evaluation_results

# Score the loaded model against the test loader and report the averaged metrics
test_metrics = evaluate_model(model=model, loader=test_loader)
print("Test Metrics:", test_metrics)

### Visualizing Predictions

In [None]:
# Function to plot trajectories
def plot_trajectories(inputs, targets, predictions, filename):
    """Plot the input, target, and predicted trajectories in 3D and save the figure.

    Args:
        inputs: Observed sequence, indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``
            for the x, y and z coordinates (presumably a ``(steps, 3)`` array).
        targets: True future trajectory, indexed the same way.
        predictions: Predicted future trajectory, indexed the same way.
        filename: Path the figure is written to.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # Plotting the trajectories
        ax.plot(inputs[:, 0], inputs[:, 1], inputs[:, 2], label='Input Sequence', color='blue')
        ax.plot(targets[:, 0], targets[:, 1], targets[:, 2], label='True Future Trajectory', color='green')
        ax.plot(predictions[:, 0], predictions[:, 1], predictions[:, 2], label='Predicted Trajectory', color='red', linestyle='--')

        # Setting labels and title
        ax.set_xlabel('X Axis')
        ax.set_ylabel('Y Axis')
        ax.set_zlabel('Z Axis')
        ax.set_title('3D Trajectory Prediction')
        ax.legend()

        # Saving the plot
        fig.savefig(filename, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls in a loop do not accumulate open figures.
        plt.close(fig)

# Plot every sample from the first 5 test batches, one PNG per sample.
# Inference only: no_grad avoids building an autograd graph for each batch.
with torch.no_grad():
    for i, (inputs, targets) in enumerate(test_loader):
        if i >= 5:  # Limit to first 5 batches
            break
        inputs, targets = inputs.to('cpu'), targets.to('cpu')
        predictions = model(inputs)
        for j, (sample_in, sample_true, sample_pred) in enumerate(zip(inputs, targets, predictions)):
            plot_filename = f'test_trajectory_{i}_{j}.png'
            plot_trajectories(sample_in.numpy(), sample_true.numpy(), sample_pred.numpy(), plot_filename)