In [2]:
# reload modules
%load_ext autoreload
%autoreload 2

In [3]:
import os
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torch.utils.data import TensorDataset, Dataset
import h5py
import tables as pytables
import random
import numpy as np
import time
import matplotlib.pyplot as plt
import os.path as osp
import glob
import time

from lit_dataset_clean import HumanDataModule
from acid_dataset import AcidDataModule
from acid_corrnet3d_clean import LitCorrNet3D

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def plot_point_clouds(points, colors, sizes, extra_data=None, opacity = 0.8, transform=None, show=False):

    """
    points: list of Nx3 array
    colors: list of colors (strings)
    extra_data: list of plotly data
    """

    if not isinstance(points, list):
        points = [points]
    if not isinstance(colors, list):
        colors = [colors]

    import plotly.graph_objects as go

    data = []
    for point, color, size in zip(points, colors, sizes):
        data.append(go.Scatter3d(
            x=point[:, 0], y=point[:, 1], z=point[:, 2],
            mode='markers',
            marker=dict(
                size=size,
                color=color,                
                opacity=0.8
            )
        ))

    if extra_data is not None:
        data += extra_data

    if transform is not None:
        local_frame =  PlotlyVisualizer().plotly_create_local_frame(transform)
        data += local_frame
        
    fig = go.Figure(
        data=data,
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False)
            )
        )
    )

    if show:
        fig.show()

    return fig

In [7]:
import pickle

with open('args.pkl', 'rb') as f:
    args = pickle.load(f)

args.input_pts = 1024
model = LitCorrNet3D(**vars(args))
  
dm_human = AcidDataModule
dm = dm_human.from_argparse_args(args) 

In [8]:
batch = next(iter(dm.train_dataloader()))

In [9]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [10]:
ckpt1 = "tb_logs/supervised_2_mse/version_0/checkpoints/epoch=2-step=112499.ckpt"
model_test1 = model.load_from_checkpoint(ckpt1)
model_test1.to(device);

In [11]:
true_correspendence, pinput1, input2, index_ = batch
pinput1, input2, true_correspendence = pinput1.to(device), input2.to(device), true_correspendence.to(device)
p, out_a, out_b =model_test1._run_step(pinput1,input2)

# inputs
true_labels = torch.argmax(true_correspendence, axis = -1)

# outputs
p = p.transpose(1,2)
predicted_labels = torch.argmax(p, axis = -1)
predicted_correspondence = torch.zeros_like(true_correspendence)
for i in range(predicted_correspondence.shape[0]):
    for j in range(predicted_correspondence.shape[1]):
        predicted_correspondence[i, j, predicted_labels[i,j]] = 1
        
pinput2 = torch.bmm(predicted_correspondence, input2)


In [35]:
index = 5
pts1 = pinput1[index].cpu().numpy()
plot_point_clouds([pts1[:1, :], pts1[1:, :]], ["green", "red"], [5, 1])

In [36]:
tinput2 = torch.bmm(true_correspendence, input2)
index = 5
pts3 = tinput2[index].cpu().numpy()
plot_point_clouds([pts3[:1, :], pts3[1:, :]], ["green", "red"], [5, 1])

In [37]:
index = 5
pinput2 = torch.bmm(predicted_correspondence, input2)
pts2 = pinput2[index].cpu().numpy()
plot_point_clouds([pts2[:1, :], pts2[1:, :]], ["green", "red"], [5, 1])

In [58]:
# calculate accuracy using true labels and predicted labels

def accuracy(true_labels, predicted_labels):
    return [torch.sum(true_labels[i] == predicted_labels[i]).item() / true_labels.shape[1]*100 for i in range(true_labels.shape[0])]

print("Accuracy: ", accuracy(true_labels, predicted_labels))

Accuracy:  [0.09765625, 0.09765625, 0.1953125, 0.1953125, 0.0, 0.0, 0.09765625, 0.0]


In [59]:
ckpt2 = "tb_logs/supervised_1_ce/version_0/checkpoints/epoch=3-step=149999.ckpt"
model_test2 = model.load_from_checkpoint(ckpt1)
model_test2.to(device);

In [65]:
p, out_a, out_b =model_test2._run_step(pinput1,input2)

# inputs
true_labels = torch.argmax(true_correspendence, axis = -1)

# outputs
p = p.transpose(1,2)
predicted_labels = torch.argmax(p, axis = -1)
predicted_correspondence = torch.zeros_like(true_correspendence)
for i in range(predicted_correspondence.shape[0]):
    for j in range(predicted_correspondence.shape[1]):
        predicted_correspondence[i, j, predicted_labels[i,j]] = 1
        
pinput2 = torch.bmm(predicted_correspondence, input2)

In [66]:
index = 5
pts1 = pinput1[index].cpu().numpy()
plot_point_clouds([pts1[:1, :], pts1[1:, :]], ["green", "red"], [5, 1])

In [67]:
tinput2 = torch.bmm(true_correspendence, input2)
index = 5
pts3 = tinput2[index].cpu().numpy()
plot_point_clouds([pts3[:1, :], pts3[1:, :]], ["green", "red"], [5, 1])

In [68]:
index = 5
pinput2 = torch.bmm(predicted_correspondence, input2)
pts2 = pinput2[index].cpu().numpy()
plot_point_clouds([pts2[:1, :], pts2[1:, :]], ["green", "red"], [5, 1])

In [70]:
# calculate accuracy using true labels and predicted labels

def accuracy(true_labels, predicted_labels):
    return [torch.sum(true_labels[i] == predicted_labels[i]).item() for i in range(true_labels.shape[0])]

print("Accuracy: ", accuracy(true_labels, predicted_labels))

Accuracy:  [1, 1, 2, 2, 0, 0, 1, 0]


In [77]:
state_dict = torch.load(ckpt1)["state_dict"]

In [79]:
model.load_state_dict(state_dict)

<All keys matched successfully>