In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.spatial.transform import Rotation as R
import wandb

In [None]:
# Weights & Biases run settings: account (entity), project and display name of the run.
user = "tkaminsky"
project = "Policy Training with Symmetries"
display_name = "Experiment 1"

# Start the W&B run; the training cell reports its loss to it through wandb.log.
wandb.init(entity=user, project=project, name=display_name)

In [None]:
class MatrixNetwork(nn.Module):
    """MLP that maps an 18-dimensional input to an orthonormal 3x3 matrix.

    The raw 9 outputs of the MLP are reshaped to a 3x3 matrix whose columns are
    orthonormalised with Gram-Schmidt, then flattened back to 9 values
    (row-major) per sample.
    """

    def __init__(self):
        super().__init__()
        # Not used by forward(); kept so the attribute remains available.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(18, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 9),
        )

    def gram_schmidt(self, A):
        """Orthonormalise the columns of ``A`` with classical Gram-Schmidt.

        Args:
            A: tensor of shape ``(..., 3, 3)``; a single 3x3 matrix or any
                batch of them. The columns are the vectors to orthonormalise.

        Returns:
            Tensor of the same shape whose columns are orthonormal.

        Note:
            Only orthonormality is enforced, so the result may have
            determinant -1 (a reflection). Linearly dependent columns lead to
            a division by zero, exactly as in the per-matrix formulation.
        """
        v1, v2, v3 = A[..., :, 0], A[..., :, 1], A[..., :, 2]

        u1 = v1 / v1.norm(dim=-1, keepdim=True)

        # Remove the u1 component from v2, then normalise.
        u2 = v2 - (u1 * v2).sum(dim=-1, keepdim=True) * u1
        u2 = u2 / u2.norm(dim=-1, keepdim=True)

        # Remove the u1 and u2 components from v3, then normalise.
        u3 = (
            v3
            - (u1 * v3).sum(dim=-1, keepdim=True) * u1
            - (u2 * v3).sum(dim=-1, keepdim=True) * u2
        )
        u3 = u3 / u3.norm(dim=-1, keepdim=True)

        # Stacking on the last axis puts u1, u2, u3 back as columns.
        return torch.stack((u1, u2, u3), dim=-1)

    def forward(self, x):
        """Return the flattened orthonormal matrix for each input sample.

        Args:
            x: tensor whose last dimension is 18 (batched or a single vector).

        Returns:
            Tensor of shape ``(batch, 9)``; a single input vector yields
            shape ``(1, 9)``.
        """
        raw_matrices = self.linear_relu_stack(x).reshape(-1, 3, 3)
        # One batched Gram-Schmidt instead of a Python loop over samples.
        return self.gram_schmidt(raw_matrices).reshape(-1, 9)


model = MatrixNetwork()

In [None]:
v = .1  # Largest rotation magnitude (radians) a single step is allowed to cover.


def _sample_unit_quaternions(count):
    """Draw ``count`` random unit quaternions whose first component is non-negative."""
    quats = np.random.randn(count, 4)
    quats /= np.linalg.norm(quats, axis=1)[:, None]
    # Negating a whole quaternion leaves the rotation unchanged, so this only
    # picks a representative. np.where (rather than np.sign) avoids zeroing a
    # quaternion whose first component is exactly 0.
    quats *= np.where(quats[:, 0] < 0, -1.0, 1.0)[:, None]
    return quats


def get_dataset(batch_size, N):
    """Generate random (initial, goal) orientation pairs and their step targets.

    Args:
        batch_size: number of orientation pairs per batch.
        N: number of batches to generate (batches, not individual samples).

    Returns:
        Dict with two lists of length ``N``:
            "Data": arrays of shape ``(batch_size, 8)``, the initial
                quaternion followed by the goal quaternion.
            "Targets": arrays of shape ``(batch_size, 5)``; the first four
                columns are the quaternion of a 1-radian rotation about the
                axis of the relative rotation ``init^-1 * goal``, the last
                column is the step size ``min(angle, v)``.

    Quaternions are read and written in SciPy's default scalar-last
    ``(x, y, z, w)`` order. The step cap is the module-level ``v``.
    """
    ds = []
    targets = []
    while len(ds) < N:
        # Progress report every 1000 batches.
        if len(ds) % 1000 == 0:
            print(len(ds))

        world_q_init = _sample_unit_quaternions(batch_size)
        world_q_goal = _sample_unit_quaternions(batch_size)

        # Relative rotation for the whole batch in one vectorised call.
        relative = R.from_quat(world_q_init).inv() * R.from_quat(world_q_goal)
        rotvecs = relative.as_rotvec()
        angles = np.linalg.norm(rotvecs, axis=1)

        # Guard the degenerate init == goal case: the axis is undefined there,
        # so the zero vector (identity rotation) is kept instead of producing NaNs.
        safe_angles = np.where(angles > 0, angles, 1.0)
        unit_rotvecs = rotvecs / safe_angles[:, None]

        # A unit rotation vector is a 1-radian rotation, so the scalar part of
        # its quaternion is cos(0.5) > 0 and no sign fix-up is required.
        # Negating a single component (as was done before) would mirror the
        # rotation axis rather than yield an equivalent quaternion.
        axis_quats = R.from_rotvec(unit_rotvecs).as_quat()

        # Move either at speed v or the distance to the target, whichever is smaller.
        deltas = np.minimum(angles, v)

        batch_targets = np.concatenate((axis_quats, deltas[:, None]), axis=1)
        batch_ds = np.concatenate((world_q_init, world_q_goal), axis=1)
        ds.append(batch_ds)
        targets.append(batch_targets)
    return {"Data": ds, "Targets": targets}
            

# The second argument counts batches, not samples: this builds 100000 batches of 128 pairs each.
ds = get_dataset(128, 100000)    

# Training Loop

In [None]:
T = 10      # NOTE(review): not used anywhere in this cell.
N = 50000   # Number of optimisation steps; each step consumes one pre-generated batch.
B = 256     # NOTE(review): not used in this cell; the dataset was built with batches of 128.

# NOTE(review): `model` starts with nn.Linear(18, ...) and returns 9 values per
# sample, whereas get_dataset produces 8-column inputs and 5-column targets.
# As written the forward pass / loss will fail on these shapes -- reconcile the
# dataset representation with the network before training.

# Initialize the optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

# Make sure the network is in training mode even if an evaluation cell ran before.
model.train()

for i in range(N):

    # torch.as_tensor converts straight to float32 without an intermediate copy.
    curr = torch.as_tensor(ds["Data"][i], dtype=torch.float32)
    target = torch.as_tensor(ds["Targets"][i], dtype=torch.float32)

    optimizer.zero_grad()

    output = model(curr)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

    # nn.MSELoss already averages over every element of the batch, so the value
    # is logged as-is; dividing by B again would under-report it.
    avg_loss = loss.item()
    wandb.log({"loss": avg_loss})


In [None]:
# Draw one fresh (initial, goal) pair through get_dataset so the evaluation
# target is built by exactly the same code as the training targets, instead of
# a hand-copied duplicate of that logic. get_dataset prints its progress
# counter (0) once as a side effect.
eval_sample = get_dataset(1, 1)
eval_input = eval_sample["Data"][0]        # shape (1, 8): initial then goal quaternion
world_q_init = eval_input[:, :4]
world_q_goal = eval_input[:, 4:]
batch_targets = eval_sample["Targets"][0]  # shape (1, 5)

# NOTE(review): `model` expects 18 inputs and returns 9 outputs per sample, so
# the 8-column input / 5-column target used here do not match it -- confirm the
# intended representation.

# Set model to evaluation mode
model.eval()

# Make predictions; gradients are not needed for evaluation.
with torch.no_grad():
    output = model(torch.tensor(eval_input).float())

# The desired output is the target
print("Target: ", batch_targets)
# The actual output is the prediction
print("Prediction: ", output)

# Print the MSE between the target and the prediction
print("MSE: ", loss_fn(output, torch.tensor(batch_targets).float()))

# Loop for Symmetry Net

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation
import os
from torch.utils.data import Dataset
from PIL import Image
import healpy as hp
import tensorflow_graphics.geometry.transformation as tfg
from torchvision import transforms
from torch.utils.data import DataLoader
# from train import DATA_DIR, DEVICE, PARAMS_F
from bingham_vis.bc_dataloader import BottleCapDataset
from ipdf.ipdf import IPDF
from get_sample import get_sample

# Checkpoint holding the IPDF model weights.
PARAMS_F = "ipdf/bc_full_200k.pth"
# Use the first GPU when one is present and fall back to the CPU otherwise, so
# the notebook also runs on machines without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

DATA_DIR = "bingham_vis/bc_data_large"
NUMBER_QUERIES = 200000
inches_per_subplot = 4
# Shapes that were listed here: 'Arrow', 'Circle', 'Cross', 'Diamond', 'Hexagon',
# 'Key', 'Line', 'Pentagon', 'U'. Only 'Cross' is selected; the full list used
# to be assigned and then immediately overwritten.
subset = ['Cross']
NEG_SAMPLES = 1

# Bounds (radians) used when sampling the x and y Euler angles of a pose.
MIN = -np.pi/6
MAX = np.pi/6

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.spatial.transform import Rotation as R
import wandb

# Minimum IPDF probability a grid rotation needs in order to be kept by get_rotations.
THRESHOLD = .08

# Converts PIL images to tensors before they are passed to the IPDF model.
preprocess_transforms = transforms.ToTensor()

In [None]:
# Loading of the policy network is currently disabled.
# NOTE(review): the policy rollout cell calls `policy_model`, which no active
# code in this notebook defines -- re-enable these lines (and confirm the
# checkpoint name) before running that cell.
# policy_model = MatrixNetwork()
# policy_model.load_state_dict(torch.load("policy_model_sm.pt"))
# policy_model.eval()

# Load the IPDF model weights from PARAMS_F. map_location makes the load work
# regardless of which device the checkpoint was saved from.
ipdf_model = IPDF()
ipdf_model.load_state_dict(torch.load(PARAMS_F, map_location=DEVICE))
ipdf_model.eval()

# Move the IPDF model to the compute device.
# policy_model.to(DEVICE)
ipdf_model.to(DEVICE)

In [None]:
def build_grid(add_cover = False):
    """Build a regular Euler-angle grid of candidate rotation matrices.

    The x and y angles span ``[2 * MIN, 2 * MAX]`` and the z angle spans
    ``[0, 2 * pi]``; every (x, y, z) combination is converted to a rotation
    matrix with the 'XYZ' Euler convention.

    Args:
        add_cover: accepted for backward compatibility but currently ignored
            (the sparse "cover" grid it used to enable is not implemented).

    Returns:
        Float tensor of shape ``(count * count * 3 * count, 3, 3)`` on
        ``DEVICE``, ordered with x varying slowest and z fastest.
    """
    # The grid size is pinned to this value. An earlier computation that chose
    # it from NUMBER_QUERIES was always overwritten by this assignment, so it
    # has been dropped.
    size = 294912 / 100

    count = 4 * int(size ** (1/3))

    x_rots = np.linspace(MIN * 2, MAX * 2, num = count)
    y_rots = np.linspace(MIN * 2, MAX * 2, num = count)
    # Both endpoints are included, so z = 0 and z = 2*pi give the same rotation twice.
    z_rots = np.linspace(0, np.pi * 2, num = 3 * count)

    # Enumerate every (x, y, z) combination in the same order as nested
    # x / y / z loops, then convert them all in a single vectorised call
    # instead of one Rotation object per grid point.
    xs, ys, zs = np.meshgrid(x_rots, y_rots, z_rots, indexing='ij')
    euler_angles = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3)
    R_grid = Rotation.from_euler('XYZ', euler_angles, degrees=False).as_matrix()

    R_grid = torch.from_numpy(R_grid).float().to(DEVICE)
    print("Shape of R_grid: ", R_grid.shape)
    return R_grid

In [None]:
def get_rotations(im_pil, R_grid):
    """Return the grid rotations the IPDF model considers likely for an image.

    Args:
        im_pil: input image; converted to a tensor with ``preprocess_transforms``.
        R_grid: tensor of candidate rotation matrices, reshaped here to
            ``(1, num_queries, 9)`` before being passed to the model.

    Returns:
        Float tensor of shape ``(k, 3, 3)`` on ``DEVICE`` with the query
        rotations whose probability exceeds ``THRESHOLD`` (``k`` may be 0).
    """
    queries = R_grid.reshape(1, -1, 9).to(DEVICE)
    im_t = preprocess_transforms(im_pil).unsqueeze(0).to(DEVICE)

    query_rotations, probabilities = ipdf_model.output_pdf(im_t, queries)

    # Stay on the device and detach: this avoids a GPU -> NumPy -> GPU round
    # trip and also works when the call is made outside torch.no_grad().
    query_rotations = query_rotations.detach().squeeze()
    probabilities = probabilities.detach().squeeze()

    # Indices (along the first axis) of the queries above the threshold.
    keep = torch.nonzero(probabilities > THRESHOLD, as_tuple=True)[0]

    # Reshape to (num_rotations, 3, 3)
    return query_rotations[keep].reshape(-1, 3, 3).float().to(DEVICE)

In [None]:
# Randomly choose an initial orientation between [-pi/6, pi/6] x [-pi/6, pi/6] x [0, 2pi]
# and a final orientation between [-pi/6, pi/6] x [-pi/6, pi/6] x [0, 2pi]
# (x and y angles are drawn from [MIN, MAX], the z angle from the full circle).
# Note: the rollout cells re-sample init_rot / goal_rot at the start of every
# trial, so the values produced here are overwritten there.
init_x = np.random.uniform(MIN, MAX)
init_y = np.random.uniform(MIN, MAX)
init_z = np.random.uniform(0, 2*np.pi)
# Convert the 'XYZ' Euler angles to a 3x3 rotation matrix.
init_rot = R.from_euler('XYZ', [init_x, init_y, init_z]).as_matrix()
goal_x = np.random.uniform(MIN, MAX)
goal_y = np.random.uniform(MIN, MAX)
goal_z = np.random.uniform(0, 2*np.pi)
goal_rot = R.from_euler('XYZ', [goal_x, goal_y, goal_z]).as_matrix()

# Make into float32 tensors on the compute device
init_rot = torch.from_numpy(init_rot).float().to(DEVICE)
goal_rot = torch.from_numpy(goal_rot).float().to(DEVICE)

In [None]:
H = 200  # Maximum number of policy steps per trial.

ims = []

R_grid = build_grid().to(DEVICE)

num_correct = 0
n_trials = 20

# Create the output folder up front so saving a GIF cannot fail on a missing directory.
os.makedirs('results_rev', exist_ok=True)

# NOTE(review): `policy_model` is not defined by any active code visible in this
# notebook (its loading code is commented out); it must exist before this cell runs.

with torch.no_grad():
    obj = "Cross"
    # Named `sample_type` rather than `type` so the built-in `type` is not shadowed.
    sample_type = "Bottle"
    for rep in range(n_trials):
        ims = []

        # Sample a random start pose and a random goal pose for this trial.
        init_x = np.random.uniform(MIN, MAX)
        init_y = np.random.uniform(MIN, MAX)
        init_z = np.random.uniform(0, 2*np.pi)
        init_rot = R.from_euler('XYZ', [init_x, init_y, init_z]).as_matrix()
        goal_x = np.random.uniform(MIN, MAX)
        goal_y = np.random.uniform(MIN, MAX)
        goal_z = np.random.uniform(0, 2*np.pi)
        goal_rot = R.from_euler('XYZ', [goal_x, goal_y, goal_z]).as_matrix()

        # Make into tensors
        init_rot = torch.from_numpy(init_rot).float().to(DEVICE)
        goal_rot = torch.from_numpy(goal_rot).float().to(DEVICE)

        gt_rot = init_rot

        goal_im = get_sample(obj, type=sample_type, orientation=goal_rot.cpu().numpy())

        for i in range(H):
            # Get a sample image from the current orientation
            im_pil = get_sample(obj, type=sample_type, orientation=gt_rot.cpu().numpy())
            ims.append(im_pil)

            # Stop as soon as the current rotation is close enough to the goal.
            # Success is recorded once, after the loop.
            if torch.norm(gt_rot - goal_rot) < .1:
                break

            rotations = get_rotations(im_pil, R_grid)

            if len(rotations) == 0:
                print("No rotations found")
                # Apply a random rotation
                random_rot = R.from_euler('XYZ', np.random.uniform(-np.pi/12, np.pi/12, size=3)).as_matrix()
                gt_rot = torch.matmul(torch.from_numpy(random_rot).float().to(DEVICE), gt_rot)
                continue

            # Candidate rotation closest to the goal. It is not used by the
            # active code path below: the policy is fed the true rotation.
            # Swap rot_guess into concat_guess to drive the policy from the
            # IPDF estimate instead.
            norms = torch.norm(rotations - goal_rot, dim=(1,2))
            rot_guess = rotations[torch.argmin(norms)]
            # Concatenate the current rotation and the goal rotation into an 18-dimensional vector
            concat_guess = torch.cat((gt_rot.flatten(), goal_rot.flatten()))
            # Get the action that will take us to the next rotation
            action = policy_model(concat_guess).reshape(3,3)

            # Update the current rotation (action applied on the left)
            gt_rot = torch.matmul(action, gt_rot)

        # Evaluate success once, after the rollout, so a goal reached on the
        # very last step is counted as well as shaded green.
        reached_goal = bool(torch.norm(gt_rot - goal_rot) < .1)
        if reached_goal:
            print("Reached goal")
            num_correct += 1
        else:
            print("Did not reach goal")

        # Build the closing frame: grayscale copy of the last frame, keeping
        # only the green channel on success or only the red channel on failure.
        final_im = np.mean(np.array(ims[-1]), axis=2)
        final_im = np.stack((final_im, final_im, final_im), axis=2)
        kept_channel = 1 if reached_goal else 0
        for channel in range(3):
            if channel != kept_channel:
                final_im[:, :, channel] = 0
        ims.append(Image.fromarray(final_im.astype(np.uint8)))

        ims_to_save = []

        for im in ims:
            # Plot goal image alongside current image
            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.imshow(goal_im)
            ax2.imshow(im)
            # Turn the plot into a pil image. buffer_rgba replaces the removed
            # tostring_rgb; the copy detaches the pixels from the canvas
            # before the figure is closed.
            fig.canvas.draw()
            data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            ims_to_save.append(Image.fromarray(data))
            # Close the figure so that figures do not pile up across frames and trials.
            plt.close(fig)

        # Hold the final frame for 20 extra frames. The rendered frame is
        # reused instead of drawing 20 identical figures.
        ims.extend([ims[-1]] * 20)
        ims_to_save.extend([ims_to_save[-1]] * 20)

        # Turn the list of images into a gif, with the goal image on the left and the current image on the right
        ims_to_save[0].save(f'results_rev/perfect_{obj}_{sample_type}_{rep}.gif', save_all=True, append_images=ims_to_save[1:], duration=100, loop=0)
    

# Policy-less Rollout Loop


In [None]:
H = 50  # Maximum number of steps per trial.

ims = []

# Largest rotation (radians) applied in one step. Kept under its own name so
# this cell does not overwrite the global `v` that get_dataset reads.
step_cap = .2

R_grid = build_grid(add_cover=False).to(DEVICE)

num_correct = 0
n_trials = 20

# Create the output folder up front so saving a GIF cannot fail on a missing directory.
os.makedirs('results_rev', exist_ok=True)

with torch.no_grad():
    obj = "Cross"
    # Named `sample_type` rather than `type` so the built-in `type` is not shadowed.
    sample_type = "Bottle"
    for rep in range(n_trials):
        ims = []

        # Sample a random start pose and a random goal pose for this trial.
        init_x = np.random.uniform(MIN, MAX)
        init_y = np.random.uniform(MIN, MAX)
        init_z = np.random.uniform(0, 2*np.pi)
        init_rot = R.from_euler('XYZ', [init_x, init_y, init_z]).as_matrix()
        goal_x = np.random.uniform(MIN, MAX)
        goal_y = np.random.uniform(MIN, MAX)
        goal_z = np.random.uniform(0, 2*np.pi)
        goal_rot = R.from_euler('XYZ', [goal_x, goal_y, goal_z]).as_matrix()

        # Make into tensors
        init_rot = torch.from_numpy(init_rot).float().to(DEVICE)
        goal_rot = torch.from_numpy(goal_rot).float().to(DEVICE)

        gt_rot = init_rot

        goal_im = get_sample(obj, type=sample_type, orientation=goal_rot.cpu().numpy())

        for i in range(H):
            # Get a sample image from the current orientation
            im_pil = get_sample(obj, type=sample_type, orientation=gt_rot.cpu().numpy())
            ims.append(im_pil)

            # Stop as soon as the current rotation is close enough to the goal.
            # Success is recorded once, after the loop.
            if torch.norm(gt_rot - goal_rot) < .1:
                break

            rotations = get_rotations(im_pil, R_grid)

            if len(rotations) == 0:
                print("No rotations found")
                # Apply a random rotation
                random_rot = R.from_euler('XYZ', np.random.uniform(-np.pi/12, np.pi/12, size=3)).as_matrix()
                gt_rot = torch.matmul(torch.from_numpy(random_rot).float().to(DEVICE), gt_rot)
                continue

            # Among the rotations the IPDF model finds likely, take the one closest to the goal
            norms = torch.norm(rotations - goal_rot, dim=(1,2))
            rot_guess = rotations[torch.argmin(norms)]

            # Rotation that takes the estimated rotation to the goal rotation
            rot_diff = torch.matmul(goal_rot, rot_guess.inverse())

            # Turn into a rotvec
            rotvec = R.from_matrix(rot_diff.cpu().numpy()).as_rotvec()
            # Cap the step: rescale only when the rotation is larger than step_cap
            rotvec_norm = np.linalg.norm(rotvec)
            if rotvec_norm > step_cap:
                rotvec = rotvec / rotvec_norm * step_cap
            # Turn back into a rotation matrix
            rot_diff = torch.from_numpy(R.from_rotvec(rotvec).as_matrix()).float().to(DEVICE)

            # Update the current rotation (step applied on the left)
            gt_rot = torch.matmul(rot_diff, gt_rot)

        # Evaluate success once, after the rollout, so a goal reached on the
        # very last step is counted as well as shaded green.
        reached_goal = bool(torch.norm(gt_rot - goal_rot) < .1)
        if reached_goal:
            print("Reached goal")
            num_correct += 1
        else:
            print("Did not reach goal")

        # Build the closing frame: grayscale copy of the last frame, keeping
        # only the green channel on success or only the red channel on failure.
        final_im = np.mean(np.array(ims[-1]), axis=2)
        final_im = np.stack((final_im, final_im, final_im), axis=2)
        kept_channel = 1 if reached_goal else 0
        for channel in range(3):
            if channel != kept_channel:
                final_im[:, :, channel] = 0
        ims.append(Image.fromarray(final_im.astype(np.uint8)))

        ims_to_save = []

        for im in ims:
            # Plot goal image alongside current image
            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.imshow(goal_im)
            ax2.imshow(im)
            # Turn the plot into a pil image. buffer_rgba replaces the removed
            # tostring_rgb; the copy detaches the pixels from the canvas
            # before the figure is closed.
            fig.canvas.draw()
            data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            ims_to_save.append(Image.fromarray(data))
            plt.close(fig)

        # Hold the final frame for 20 extra frames. The rendered frame is
        # reused instead of drawing 20 identical figures.
        ims.extend([ims[-1]] * 20)
        ims_to_save.extend([ims_to_save[-1]] * 20)

        # Turn the list of images into a gif, with the goal image on the left and the current image on the right
        ims_to_save[0].save(f'results_rev/cover_guess_{obj}_{sample_type}_{rep}.gif', save_all=True, append_images=ims_to_save[1:], duration=100, loop=0)

In [None]:
# Fraction of trials that reached the goal in whichever rollout cell was run last
# (each rollout cell resets num_correct and n_trials).
num_correct / n_trials

In [None]:
# Save the frames currently held in `ims` (those of the last trial that was run) as a GIF.
# These are the individual frames, not the side-by-side goal/current composites
# written by the rollout cells.
# NOTE(review): the file name says "U_cap" while the rollout cells set obj = "Cross" -- confirm.
ims[0].save('test_U_cap.gif', save_all=True, append_images=ims[1:], duration=100, loop=0)