In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before
# each cell is executed, so edits to the project's .py files are picked up.
%load_ext autoreload
%autoreload 2 

In [None]:
!git clone https://ymentha14:mysecretpassword@github.com/ymentha14/se3_project.git

In [None]:
# Move into the cloned repository so that relative paths (requirements.txt)
# and the `src.*` imports used further down resolve from the project root.
%cd se3_project

In [None]:
!pip install -r requirements.txt

In [None]:
!wandb login myhash

# SE3 implementation

In [None]:
import os
import time
from copy import copy
from pathlib import Path
from pdb import set_trace

import matplotlib.pyplot as plt
import numpy as np
import psutil
import src.se3.visualization as viz
import torch
import torch.nn as nn
from se3_transformer_pytorch.irr_repr import rot
from src.se3.se3_expes import expes_kwargs, run
from src.ri_distances.pnt_cloud_generation import (SpiralGenerator, center,
                                                   get_custom_spiral,
                                                   get_spiral,
                                                   get_src_scaled_spirals,
                                                   get_src_shifted_spirals,
                                                   to_numpy_array,
                                                   to_torch_tensor)
from src.ri_distances.SGW.risgw import RISGW_loss
from src.se3.torch_funcs import (MachineScaleChecker, get_model,
                                 get_predictions, predict,
                                 start_training, train_one_epoch,
                                 visualize_prediction,get_batch)
from src.se3.visualization import plot_coordinates, viz_point_cloud,plot_coordinates
from tqdm import tqdm
import wandb

# Global configuration for the notebook (the original cell repeated this
# block twice verbatim; one copy is enough).

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

plt.style.use('ggplot')

# NOTE(review): absolute path under /content/drive — presumably a mounted
# Google Drive on Colab; adjust when running elsewhere.
model_path = Path("/content/drive/MyDrive/se3_transformer.pt")

# NOTE(review): the original comment read "works best in float64" although
# float32 is what is set here — confirm which dtype is intended.
torch.set_default_dtype(torch.float32)

# W&B group
os.environ["WANDB_RUN_GROUP"] = "experiment A"

## Colab running of all experiments

In [None]:
# `sys` is needed for the flush below and is not imported by the notebook's
# import cell; without this import the loop fails with a NameError.
import sys

# Run every configured experiment in sequence, without W&B logging.
for i, (src_kwargs, trgt_kwargs, asym_features) in enumerate(
        tqdm(expes_kwargs, desc="SE(3) overfits experiments", leave=False)):
    run(i, src_kwargs, trgt_kwargs, asym_features, use_wandb=False)
    # Push buffered output out after each experiment.
    sys.stdout.flush()

## Experiment 1

We notice that the model seems unable to learn anything: this might be due to its translational equivariance.

In [None]:
# --- Experiment 1: hyper-parameters and flags ---
epochs = 50
batch_size = 4
lr = 0.01
center_output = False
use_wandb = False
asym_features = True

# --- Model, loss, optimizer and LR scheduler ---
transformer = get_model()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(transformer.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=100, factor=0.4, threshold=0.001, verbose=True
)

# --- Data: default spiral as source, spiral built with shift=0.3 as target ---
src_gen = SpiralGenerator()
trgt_gen = SpiralGenerator(shift=0.3)
points = src_gen.generate()
target_points = trgt_gen.generate()
viz_point_cloud([(points, 'src'), (target_points, 'trgt')])

In [None]:
# Train on the Experiment 1 source/target generators.
start_training(
    transformer, lr, optimizer, epochs, criterion, batch_size, scheduler,
    device, src_gen, trgt_gen, center_output,
    asym_features=asym_features,
    use_wandb=use_wandb,
)

In [None]:
# Fetch source, target and predicted clouds from get_predictions and plot them.
points, target_points, predicted_points = get_predictions(
    transformer,
    src_gen,
    trgt_gen,
    center_output,
)
fig = visualize_prediction(points, target_points, predicted_points)
if use_wandb:
    # Log the figure as an image under the "chart" key.
    wandb.log({"chart": wandb.Image(fig)})
fig

## Experiment 2

We switch to a scaled setting

In [None]:
# --- Experiment 2: hyper-parameters and flags ---
epochs = 15
batch_size = 4
lr = 0.01
center_output = False
use_wandb = True

# --- Model, loss, optimizer and LR scheduler ---
transformer = get_model()
criterion = torch.nn.MSELoss()
# Alternative kept from the original cell:
# optimizer = torch.optim.SGD(transformer.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.Adam(transformer.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=100, factor=0.4, threshold=0.001, verbose=True
)

# --- Data: both generators use shift=2.0; the target additionally uses scaling=2.0 ---
src_gen = SpiralGenerator(shift=2.0)
trgt_gen = SpiralGenerator(scaling=2.0, shift=2.0)
points = src_gen.generate()
target_points = trgt_gen.generate()
viz_point_cloud([(points, 'src'), (target_points, 'trgt')])

In [None]:
# Train on the Experiment 2 generators.
# `use_wandb` is passed by keyword, as in the Experiment 1 call: that call also
# passes `asym_features=` by keyword, so the positional slot right after
# `center_output` is ambiguous and could bind the W&B flag to the wrong parameter.
# NOTE(review): start_training's signature is not visible here — confirm.
start_training(
    transformer, lr, optimizer, epochs, criterion, batch_size, scheduler,
    device, src_gen, trgt_gen, center_output,
    use_wandb=use_wandb,
)

In [None]:
# Retrieve the three point clouds for Experiment 2 and draw them together.
points, target_points, predicted_points = get_predictions(
    transformer, src_gen, trgt_gen, center_output
)
fig = visualize_prediction(points, target_points, predicted_points)
if use_wandb:
    chart_image = wandb.Image(fig)
    wandb.log({"chart": chart_image})
fig

In [None]:
# Plot the coordinates of target vs. prediction (z is the one of interest).
# Intuition: the model cannot learn non-symmetrical data.
fig = plot_coordinates(target_points, predicted_points)
if use_wandb:
    z_coord_image = wandb.Image(fig)
    wandb.log({"z_coord": z_coord_image})

In [None]:
# We train for 70 more epochs.
# The original call wrote `epochs=85` in the middle of the positional arguments,
# which is a SyntaxError ("positional argument follows keyword argument"); the
# value is now passed positionally in the same slot as the other training calls.
# NOTE(review): 85 presumably means 15 (already done) + 70 — confirm how
# start_training interprets `epochs` when called again on the same model.
# `use_wandb` is passed by keyword, matching the Experiment 1 call.
start_training(
    transformer, lr, optimizer, 85, criterion, batch_size, scheduler,
    device, src_gen, trgt_gen, center_output,
    use_wandb=use_wandb,
)

In [None]:
def MSE(x, y):
    """Return the mean of the element-wise squared difference between x and y."""
    return ((x - y) ** 2).mean()


# Error between the target and the model prediction.
d1 = MSE(target_points, predicted_points)

# Third component of the axis-0 mean of the input cloud
# (the z barycenter, assuming an (N, 3) layout — TODO confirm).
input_z_pos = points.mean(axis=0)[2]

# Baseline: the centered target offset by the input's z position, compared
# against the target itself.
recentered_target = center(target_points) + [0, 0, input_z_pos]
d2 = MSE(recentered_target, target_points)

print(f"D(target-pred) = {d1:.2f} D(target-src_centered_trgt) = {d2:.2f}")

In [None]:
# Plot the target vs. predicted coordinates again (z is the one of interest).
# Intuition: the model cannot learn non-symmetrical data.
fig = plot_coordinates(
    target_points,
    predicted_points,
)
if use_wandb:
    wandb.log(
        {"z_coord": wandb.Image(fig)}
    )

## Experiment 3


The goal is to learn a non-symmetrical spiral from a symmetrical one; we assume this is not possible.

In [None]:
# --- Experiment 3: hyper-parameters and flags ---
epochs = 100
batch_size = 4
lr = 0.01
center_input = False
center_output = False
center_target = False
use_wandb = True

# --- Model, loss, optimizer and LR scheduler ---
transformer = get_model()
criterion = torch.nn.MSELoss()
# Alternative kept from the original cell:
# optimizer = torch.optim.SGD(transformer.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.Adam(transformer.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=100, factor=0.4, threshold=0.001, verbose=True
)

# --- Data: centered source; centered target built with asym=True and width_factor=1.5 ---
src_gen = SpiralGenerator(centering=True)
trgt_gen = SpiralGenerator(asym=True, centering=True, width_factor=1.5)
points = src_gen.generate()
target_points = trgt_gen.generate()
viz_point_cloud([(points, 'src'), (target_points, 'trgt')])

In [None]:
# Train on the Experiment 3 generators.
# `use_wandb` is passed by keyword, as in the Experiment 1 call: that call also
# passes `asym_features=` by keyword, so the positional slot right after
# `center_output` is ambiguous and could bind the W&B flag to the wrong parameter.
# NOTE(review): start_training's signature is not visible here — confirm.
start_training(
    transformer, lr, optimizer, epochs, criterion, batch_size, scheduler,
    device, src_gen, trgt_gen, center_output,
    use_wandb=use_wandb,
)

In [None]:
# Experiment 3: gather the clouds returned by get_predictions and visualise them.
points, target_points, predicted_points = get_predictions(
    transformer,
    src_gen,
    trgt_gen,
    center_output,
)
fig = visualize_prediction(
    points,
    target_points,
    predicted_points,
)
if use_wandb:
    wandb.log({"chart": wandb.Image(fig)})
fig

## Experiment 4

We now confirm that it is possible to learn the non-symmetrical spiral from the asymmetrical one.

In [None]:
# --- Experiment 4: hyper-parameters and flags ---
epochs = 100
batch_size = 4
lr = 0.01
center_input = False
center_output = False
center_target = False
use_wandb = True

# --- Model, loss, optimizer and LR scheduler ---
transformer = get_model()
criterion = torch.nn.MSELoss()
# Alternative kept from the original cell:
# optimizer = torch.optim.SGD(transformer.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.Adam(transformer.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=100, factor=0.4, threshold=0.001, verbose=True
)

# --- Data: both generators use asym=True and centering=True;
#     the target additionally uses scaling=3.0 and width_factor=1.5 ---
src_gen = SpiralGenerator(centering=True, asym=True)
trgt_gen = SpiralGenerator(
    asym=True, centering=True, scaling=3.0, width_factor=1.5
)
points = src_gen.generate()
target_points = trgt_gen.generate()
viz_point_cloud([(points, 'src'), (target_points, 'trgt')])

In [None]:
# Train on the Experiment 4 generators.
# `use_wandb` is passed by keyword, as in the Experiment 1 call: that call also
# passes `asym_features=` by keyword, so the positional slot right after
# `center_output` is ambiguous and could bind the W&B flag to the wrong parameter.
# NOTE(review): start_training's signature is not visible here — confirm.
start_training(
    transformer, lr, optimizer, epochs, criterion, batch_size, scheduler,
    device, src_gen, trgt_gen, center_output,
    use_wandb=use_wandb,
)

In [None]:
# Experiment 4: obtain source / target / predicted clouds, then plot them.
points, target_points, predicted_points = get_predictions(
    transformer, src_gen, trgt_gen, center_output
)
fig = visualize_prediction(points, target_points, predicted_points)
if use_wandb:
    chart_image = wandb.Image(fig)
    wandb.log({"chart": chart_image})
fig

## Experiment 5

In [None]:
# --- Experiment 5: hyper-parameters and flags ---
epochs = 100
batch_size = 4
lr = 0.01
center_input = False
center_output = False
center_target = False
use_wandb = True

# --- Model, loss, optimizer and LR scheduler ---
transformer = get_model()
criterion = torch.nn.MSELoss()
# Alternative kept from the original cell:
# optimizer = torch.optim.SGD(transformer.parameters(), lr=lr, momentum=0.9)
optimizer = torch.optim.Adam(transformer.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=100, factor=0.4, threshold=0.001, verbose=True
)

# --- Data: both generators use asym=True; shift is 0.3 for the source and 1.0 for the target ---
src_gen = SpiralGenerator(asym=True, shift=0.3)
trgt_gen = SpiralGenerator(asym=True, shift=1.0)
points = src_gen.generate()
target_points = trgt_gen.generate()
viz_point_cloud([(points, 'src'), (target_points, 'trgt')])

The cell above visualizes the type of point cloud that the model is trained to overfit.

In [None]:
# Train on the Experiment 5 generators.
# `use_wandb` is passed by keyword, as in the Experiment 1 call: that call also
# passes `asym_features=` by keyword, so the positional slot right after
# `center_output` is ambiguous and could bind the W&B flag to the wrong parameter.
# NOTE(review): start_training's signature is not visible here — confirm.
start_training(
    transformer, lr, optimizer, epochs, criterion, batch_size, scheduler,
    device, src_gen, trgt_gen, center_output,
    use_wandb=use_wandb,
)

In [None]:
# Experiment 5: collect the clouds from get_predictions and render the comparison figure.
points, target_points, predicted_points = get_predictions(
    transformer,
    src_gen,
    trgt_gen,
    center_output,
)
fig = visualize_prediction(points, target_points, predicted_points)
if use_wandb:
    # Log the figure as an image under the "chart" key.
    wandb.log({"chart": wandb.Image(fig)})
fig

## Is the output systematically centered?

In [None]:
# Size, spread and offset of the random test cloud.
N = 25
scale = 1000
shift = 500

# Uniform noise in [0, 1) stretched by `scale` ...
rand_points_tens = torch.rand(1, N, 3) * scale
# ... then moved by `shift` along each of the three axes.
rand_points_tens += torch.tensor([shift] * 3)

predicted_deltas_tens = predict(transformer, rand_points_tens)

# Compare the mean over axis 1 of the input with that of the model output.
input_barycenter = rand_points_tens.mean(axis=1)
output_barycenter = predicted_deltas_tens.mean(axis=1)
print(f"Input barycenter:{input_barycenter}")
print(f"Output barycenter:{output_barycenter}")

## Translational Equivariance

In [None]:
# Translational equivariance check: with f(x) = x + delta(x), an equivariant
# model satisfies f(x + s) == f(x) + s, so the two plotted clouds should overlap.
# NOTE(review): assumes `predict` returns per-point displacements, as the
# *_delta_* names suggest — confirm.
transformer = get_model()
shift = torch.tensor([15,15,15])
position = torch.tensor([0,0,10])
points_tens = torch.rand(1,10,3) + position

# we add the shift post-prediction: f(x) + s
post_delta_tens = predict(transformer,points_tens)
post_prediction_tens = points_tens + post_delta_tens + shift

# we add the shift before the prediction: f(x + s)
# The deltas must be added to the *shifted* input. The original added them to
# the unshifted points, which left the two clouds offset by `shift` even for a
# perfectly equivariant model.
shifted_points_tens = points_tens + shift
pre_delta_tens = predict(transformer,shifted_points_tens)
pre_prediction_tens = shifted_points_tens + pre_delta_tens

pre_shift, post_shift = to_numpy_array(pre_prediction_tens),to_numpy_array(post_prediction_tens)
viz_point_cloud([(pre_shift,'pre'),(post_shift,'post')])