In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.spatial.transform import Rotation as R
import wandb

In [None]:
# Weights & Biases run identity for this experiment.
user = "tkaminsky"
project = "Policy Training with matrices"
display_name = "Experiment 6"

wandb.init(
    project=project,
    entity=user,
    name=display_name,
)

In [None]:
class NeuralNetwork(nn.Module):
    """MLP policy mapping an 8-feature input to a unit quaternion plus one scalar.

    Input: ``(batch, 8)`` — presumably the concatenated init/goal quaternions
    produced by ``get_dataset`` (TODO confirm for other callers).
    Output: ``(batch, 5)``. Columns 0-3 are a per-row unit quaternion whose
    first component is non-negative (q and -q are the same rotation, so one
    sign is chosen). Column 4 is the raw, unconstrained fifth network output —
    presumably the step magnitude matching column 4 of the dataset targets.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # not used by forward(); kept so the module layout is unchanged
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(8, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 5),
        )

    def forward(self, x):
        # Run the MLP once and derive both parts of the output from that result.
        raw = self.linear_relu_stack(x)
        # Normalise each row's quaternion to unit length (dim=1), not by the norm
        # of the whole batch.
        quaternion = F.normalize(raw[:, 0:4], dim=1)
        # Canonicalise the sign so the first component is >= 0. torch.where keeps
        # an exact-zero first component intact, whereas multiplying by
        # torch.sign would zero the whole quaternion.
        quaternion = torch.where(quaternion[:, :1] < 0, -quaternion, quaternion)
        # Build a new tensor rather than writing into `raw` in place.
        return torch.cat((quaternion, raw[:, 4:]), dim=1)


model = NeuralNetwork()

In [None]:
class MatrixNetwork(nn.Module):
    """MLP policy that predicts a 3x3 rotation matrix.

    Input: ``(batch, 18)`` features — presumably the flattened init and goal
    rotation matrices produced by ``get_mat_dataset`` (TODO confirm for other
    callers).
    Output: ``(batch, 9)``. Each row is a row-major flattened proper rotation
    matrix: orthonormal columns with determinant +1.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # not used by forward(); kept so the module layout is unchanged
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(18, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 9),
        )

    def gram_schmidt(self, A):
        """Map the columns of ``A`` to a right-handed orthonormal frame.

        Parameters
        ----------
        A : torch.Tensor
            Shape ``(3, 3)`` or batched ``(..., 3, 3)``; its columns are the raw
            vectors.

        Returns
        -------
        torch.Tensor
            Same shape as ``A``.
            - Column 0 is ``A``'s first column, normalised.
            - Column 1 is ``A``'s second column with its component along
              column 0 removed, then normalised.
            - Column 2 is the cross product of columns 0 and 1.

            The third input column is ignored. Orthogonalising it directly can
            produce a reflection (determinant -1), which no rotation target
            can match.
        """
        u1 = F.normalize(A[..., :, 0], dim=-1)
        v2 = A[..., :, 1]
        # Remove the component of v2 along u1 (batched dot product), then normalise.
        u2 = F.normalize(v2 - (u1 * v2).sum(dim=-1, keepdim=True) * u1, dim=-1)
        # Right-handed completion guarantees det(+1).
        u3 = torch.cross(u1, u2, dim=-1)
        return torch.stack((u1, u2, u3), dim=-1)

    def forward(self, x):
        # Interpret the 9 raw outputs as a (row-major) 3x3 matrix per sample and
        # orthonormalise the whole batch in one vectorised call.
        raw = self.linear_relu_stack(x).reshape(-1, 3, 3)
        return self.gram_schmidt(raw).reshape(-1, 9)


model = MatrixNetwork()

In [None]:
v = .1  # maximum step magnitude in radians (caps the norm of the relative rotation vector)


def _random_canonical_quats(batch_size):
    """Sample ``batch_size`` random unit 4-vectors, sign-flipped so column 0 is >= 0.

    Normalising an isotropic Gaussian gives a uniform sample on the unit
    sphere. scipy's ``Rotation.from_quat`` reads these rows as scalar-last
    ``(x, y, z, w)``, so the canonicalised component is ``x``, not the scalar
    part.
    """
    q = np.random.randn(batch_size, 4)
    q /= np.linalg.norm(q, axis=1, keepdims=True)
    # np.where (rather than multiplying by np.sign) so an exact 0 never zeroes the row.
    return np.where(q[:, :1] < 0, -q, q)


def get_dataset(batch_size, N, max_step=None):
    """Generate batches of (init, goal) quaternion pairs with step targets.

    Each sample maps an (init quaternion, goal quaternion) pair to
    (axis quaternion, step size).

    Parameters
    ----------
    batch_size : int
        Samples per batch.
    N : int
        Number of *batches* to generate (not samples).
    max_step : float, optional
        Cap on the step magnitude in radians. Defaults to the module-level
        ``v``.

    Returns
    -------
    dict
        ``"Data"``: list of N arrays of shape ``(batch_size, 8)``. Each row is
        the init and goal quaternions concatenated, each sign-canonicalised so
        that its first stored component is non-negative.

        ``"Targets"``: list of N arrays of shape ``(batch_size, 5)``.
        Columns 0-3 are the quaternion of a 1-radian rotation about the axis of
        ``init^-1 * goal``, with the same sign convention. Column 4 is the step
        magnitude ``min(angle, max_step)``.
    """
    step_cap = v if max_step is None else max_step
    ds = []
    targets = []
    while len(ds) < N:
        if len(ds) % 1000 == 0:
            print(len(ds))
        # Random initial and final orientations (same RNG draw order as before).
        world_q_init = _random_canonical_quats(batch_size)
        world_q_goal = _random_canonical_quats(batch_size)

        # Relative rotation init^-1 * goal, as rotation vectors, for the whole batch.
        rotvecs = (R.from_quat(world_q_init).inv() * R.from_quat(world_q_goal)).as_rotvec()
        angles = np.linalg.norm(rotvecs, axis=1, keepdims=True)

        # Unit rotation axis. When init == goal the axis is left as zero (the
        # identity rotation) instead of computing 0/0 = NaN.
        axes = np.zeros_like(rotvecs)
        np.divide(rotvecs, angles, out=axes, where=angles > 0)

        # Quaternion of a 1-radian rotation about that axis. Negate the WHOLE
        # quaternion when canonicalising the sign (q and -q are the same
        # rotation); negating a single component would change the rotation.
        target_q = R.from_rotvec(axes).as_quat()
        target_q = np.where(target_q[:, :1] < 0, -target_q, target_q)

        # Move either at speed step_cap or the distance to the target, whichever is smaller.
        delta = np.minimum(angles, step_cap)

        batch_targets = np.concatenate((target_q, delta), axis=1)
        batch_ds = np.concatenate((world_q_init, world_q_goal), axis=1)
        ds.append(batch_ds)
        targets.append(batch_targets)
    return {"Data": ds, "Targets": targets}
            

# NOTE(review): this builds 100000 batches (~12.8M quaternion samples), and the
# result is replaced by the `ds = get_mat_dataset(...)` cell before training,
# so on Restart & Run All it is pure cost.
ds = get_dataset(batch_size=128, N=100000)

In [None]:
v = .1  # maximum step magnitude in radians (caps the norm of the target rotation vector)


def get_mat_dataset(batch_size, N, max_step=None):
    """Generate batches of (init matrix, goal matrix) pairs with capped relative-rotation targets.

    Parameters
    ----------
    batch_size : int
        Samples per batch.
    N : int
        Number of *batches* to generate (not samples).
    max_step : float, optional
        Cap on the rotation angle of each target, in radians. Defaults to the
        module-level ``v``.

    Returns
    -------
    dict
        ``"Data"``: list of N arrays of shape ``(batch_size, 18)``. Each row is
        the init and goal rotation matrices, each flattened row-major, then
        concatenated.

        ``"Targets"``: list of N arrays of shape ``(batch_size, 9)``. Each row
        is a flattened rotation about the axis of ``init^-1 * goal`` by
        ``min(angle, max_step)``.
    """
    step_cap = v if max_step is None else max_step
    ds = []
    targets = []
    while len(ds) < N:
        if len(ds) % 1000 == 0:
            print(len(ds))
        # Randomly sample rotations from intrinsic x-y-z Euler angles.
        # NOTE(review): uniform Euler angles are not uniformly distributed over
        # SO(3); R.random() would be, if that is what is intended.
        x_init = np.random.uniform(0, 2 * np.pi, batch_size)
        y_init = np.random.uniform(0, 2 * np.pi, batch_size)
        z_init = np.random.uniform(0, 2 * np.pi, batch_size)
        rot_init = R.from_euler('XYZ', np.column_stack((x_init, y_init, z_init)))

        x_goal = np.random.uniform(0, 2 * np.pi, batch_size)
        y_goal = np.random.uniform(0, 2 * np.pi, batch_size)
        z_goal = np.random.uniform(0, 2 * np.pi, batch_size)
        rot_goal = R.from_euler('XYZ', np.column_stack((x_goal, y_goal, z_goal)))

        # Relative rotation init^-1 * goal for the whole batch, as rotation vectors.
        rotvecs = (rot_init.inv() * rot_goal).as_rotvec()
        angles = np.linalg.norm(rotvecs, axis=1, keepdims=True)

        # Move either at speed step_cap or the distance to the target, whichever
        # is smaller. The scale is step_cap / angle only where the angle exceeds
        # the cap, and 1 elsewhere. This avoids the 0/0 that normalising a zero
        # rotation vector would produce.
        scale = np.ones_like(angles)
        np.divide(step_cap, angles, out=scale, where=angles > step_cap)
        batch_targets = R.from_rotvec(rotvecs * scale).as_matrix().reshape(batch_size, 9)

        world_q_init = rot_init.as_matrix().reshape(-1, 9)
        world_q_goal = rot_goal.as_matrix().reshape(-1, 9)
        batch_ds = np.concatenate((world_q_init, world_q_goal), axis=1)
        ds.append(batch_ds)
        targets.append(batch_targets)
    return {"Data": ds, "Targets": targets}
            

# 10000 batches of 256 samples each: the matrix dataset consumed by the training loop.
ds = get_mat_dataset(batch_size=256, N=10000)

In [None]:
# Sanity-check the dataset layout: keys, number of batches, and per-batch array shape.
print(ds.keys())
for split in ("Data", "Targets"):
    batches = ds[split]
    print(len(batches))
    print(batches[0].shape)

# Training Loop

In [None]:
T = 10  # NOTE(review): not used below — presumably meant as an epoch count; TODO confirm
N = 10000  # maximum number of batches to train on (one optimizer step per batch)
B = 256  # NOTE(review): not used below; the batch size is fixed when the dataset is generated

# Initialize the optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# MSE over the 9 matrix entries = squared Frobenius (chordal) distance / 9, averaged over the batch.
loss_fn = nn.MSELoss()
LOG_EVERY = 100  # print the loss this often; every step is still logged to wandb

# The evaluation cell puts the model in eval mode; switch back so re-runs train correctly.
model.train()
# Never index past the end of a dataset that has fewer than N batches.
num_batches = min(N, len(ds["Data"]))
for i in range(num_batches):
    curr = torch.from_numpy(ds["Data"][i]).float()
    target = torch.from_numpy(ds["Targets"][i]).float().reshape(-1, 3, 3)

    optimizer.zero_grad()
    output = model(curr).reshape(-1, 3, 3)
    # `target` is already a float tensor; wrapping it in torch.tensor() again
    # only makes a redundant copy and emits a UserWarning.
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    batch_loss = loss.item()  # loss of this single batch (not a running average)
    if i % LOG_EVERY == 0:
        print(batch_loss)
    wandb.log({"loss": batch_loss})


In [None]:
# Persist only the learned parameters (the state_dict), not the module object itself.
torch.save(obj=model.state_dict(), f="policy_model_sm.pt")

In [None]:
# Evaluate the trained policy on one freshly sampled (init, goal) pair.
# `model` is the MatrixNetwork, whose first layer takes 18 features (two
# flattened 3x3 matrices), so the sample is drawn with get_mat_dataset — the
# same generator used for training. An 8-feature quaternion input cannot be
# fed to it.
eval_sample = get_mat_dataset(1, 1)
eval_input = torch.from_numpy(eval_sample["Data"][0]).float()
eval_target = torch.from_numpy(eval_sample["Targets"][0]).float()

# Set model to evaluation mode; no gradients are needed for a prediction.
model.eval()
with torch.no_grad():
    output = model(eval_input)

# The desired output is the target rotation matrix
print("Target: ", eval_target.reshape(3, 3))
# The actual output is the predicted rotation matrix
print("Prediction: ", output.reshape(3, 3))

# Print the MSE between the target and the prediction (loss_fn comes from the training cell)
print("MSE: ", loss_fn(output, eval_target))