In [None]:
import json
import os
from zipfile import ZipFile

import altair as alt
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from hybridbrep import GeneralConvEncDec, BRepFaceAutoencoder, HybridPartDataset
from train_latent_space import BRepFaceAutoencoder as OldAutoencoder, BRepDS as OldDS

In [None]:
# Build the two architectures being compared and load their trained weights.
#   old_model: hybridbrep.BRepFaceAutoencoder(64, 1024, 4, True)
#   new_model: hybridbrep.GeneralConvEncDec(64, 1024, 4)
old_model = BRepFaceAutoencoder(64, 1024, 4, True)
new_model = GeneralConvEncDec(64, 1024, 4)

# NOTE(review): absolute paths on the original author's machine; adjust locally.
old_ckpt_path = '/home/ben/Documents/research/repbrep/training_logs/reconstruction/old_with_loops/version_0/checkpoints/epoch=54-val_loss=0.002403.ckpt'
new_ckpt_path = '/home/ben/Documents/research/repbrep/training_logs/reconstruction/new_with_edges/version_1/checkpoints/epoch=183-val_loss=0.002646.ckpt'

# The evaluation loops below run entirely on CPU tensors, so map the checkpoint
# onto the CPU. Otherwise a GPU-saved checkpoint cannot be loaded on a CPU-only box.
old_ckpt = torch.load(old_ckpt_path, map_location='cpu')
new_ckpt = torch.load(new_ckpt_path, map_location='cpu')

old_model.load_state_dict(old_ckpt['state_dict'])
new_model.load_state_dict(new_ckpt['state_dict'])

# load_state_dict does not change the train/eval flag, and torch.no_grad() in the
# evaluation cells does not either. Switch to inference mode so any dropout or
# batch-norm layers behave deterministically when computing test metrics.
old_model.eval()
new_model.eval();

In [None]:
# List the STEP files contained in the preprocessed part archive.
with ZipFile('../../datasets/fusion360seg_hpart_fixed.zip', 'r') as zf:
    names = list(filter(lambda member: member.endswith('.stp'), zf.namelist()))
    

In [None]:
# Test split of the new-format dataset: index JSON plus the zip of preprocessed parts.
ds_test = HybridPartDataset(
    '../../datasets/fusion360seg.json',
    '../../datasets/fusion360seg_hpart_fixed.zip',
    mode='test',
)

In [None]:
# Evaluate old_model on the new test split (ds_test). For every part, record:
#   - the mean Euclidean distance between predicted and target xyz (channels :3)
#   - the mean absolute error of the last channel ('mask', per the variable names)
# Both are appended as 0-d tensors. The next cell converts them to floats.
# NOTE: `av_xzy_dist_old` (sic) is the name the next cell reads from.
av_xzy_dist_old = []
av_mask_abs_error_old = []

with torch.no_grad():
    for part in tqdm(ds_test):
        coords = part.surface_coords
        n_rows, n_per_row = coords.shape[0], coords.shape[1]
        # Flatten all (u, v) query points and build a parallel row index for each one.
        # Presumably a row is one face; TODO confirm against HybridPartDataset.
        flat_uvs = coords.reshape((-1, 2))
        row_index = torch.arange(n_rows).repeat_interleave(n_per_row)

        samples = part.surface_samples
        expected = torch.cat([samples[:, :, :3], samples[:, :, -1:]], dim=-1)
        residual = old_model(part, flat_uvs, row_index).reshape_as(expected) - expected

        av_xzy_dist_old.append(residual[:, :, :3].pow(2).sum(dim=-1).sqrt().mean())
        av_mask_abs_error_old.append(residual[:, :, -1].abs().mean())

In [None]:
# Convert the old-model per-part metrics to plain Python floats.
# float() accepts both 0-d tensors and floats. L53's list overwrites its own input,
# so with `.item()` a second run of this cell raised AttributeError on the floats.
av_xyz_dist_old = [float(x) for x in av_xzy_dist_old]
av_mask_abs_error_old = [float(x) for x in av_mask_abs_error_old]

In [None]:
# Evaluate new_model on the new test split. new_model.enc_dec returns a pair
# (face codes, face predictions); only the predictions are scored. Metrics
# match the old-model cell, but are stored directly as Python floats.
av_xyz_dist_new = []
av_mask_abs_error_new = []

with torch.no_grad():
    for part in tqdm(ds_test):
        _, face_preds = new_model.enc_dec(part, part.surface_coords)

        samples = part.surface_samples
        expected = torch.cat([samples[:, :, :3], samples[:, :, -1:]], dim=-1)
        residual = face_preds.reshape_as(expected) - expected

        av_xyz_dist_new.append(residual[:, :, :3].pow(2).sum(dim=-1).sqrt().mean().item())
        av_mask_abs_error_new.append(residual[:, :, -1].abs().mean().item())


In [None]:
# Baseline: the original autoencoder from train_latent_space, with its own
# checkpoint and its own preprocessed (old-format) test split.
orig_model = OldAutoencoder(64, 1024, 4)
orig_ckpt_path = '../../models/BRepFaceAutoencoder_64_1024_4/BRepFaceAutoencoder_64_1024_4.ckpt'
# Evaluation runs on CPU tensors, so load the weights onto the CPU.
orig_ckpt = torch.load(orig_ckpt_path, map_location='cpu')
orig_model.load_state_dict(orig_ckpt['state_dict'])
# load_state_dict leaves the module in training mode. Switch to inference mode
# so any dropout or batch-norm layers do not perturb the test metrics.
orig_model.eval()
# NOTE(review): absolute paths on the original author's machine; adjust locally.
orig_index_path = '/media/ben/Data/fusion360segmentation/simple_train_test.json'
orig_data_path = '/media/ben/Data/fusion360segmentation/simple_preprocessed'
orig_ds_test = OldDS(orig_index_path, orig_data_path, mode='test')

In [None]:
# Evaluate the original model on its own old-format test split (orig_ds_test).
# Same per-part metrics as the other evaluation cells, appended as 0-d tensors.
# NOTE: `av_xzy_dist_orig` (sic) is the name later cells read from.
av_xzy_dist_orig = []
av_mask_abs_error_orig = []

with torch.no_grad():
    for part in tqdm(orig_ds_test):
        coords = part.surface_coords
        n_rows, n_per_row = coords.shape[0], coords.shape[1]
        flat_uvs = coords.reshape((-1, 2))
        row_index = torch.arange(n_rows).repeat_interleave(n_per_row)

        samples = part.surface_samples
        expected = torch.cat([samples[:, :, :3], samples[:, :, -1:]], dim=-1)
        residual = orig_model(part, flat_uvs, row_index).reshape_as(expected) - expected

        av_xzy_dist_orig.append(residual[:, :, :3].pow(2).sum(dim=-1).sqrt().mean())
        av_mask_abs_error_orig.append(residual[:, :, -1].abs().mean())

In [None]:
# Mean over parts of the original model's xyz error, then of its mask error.
for per_part in (av_xzy_dist_orig, av_mask_abs_error_orig):
    print(np.mean([value.item() for value in per_part]))

In [None]:
# Load the old and new dataset index files.
def _read_json(path):
    """Parse the JSON document stored at `path`."""
    with open(path, 'r') as fh:
        return json.load(fh)

orig_index = _read_json(orig_index_path)
new_index = _read_json('../../datasets/fusion360seg.json')

In [None]:
# Cross-sampling evaluation: feed the original model its own old-format part
# (orig_ds_test[i]), but query and score it at the NEW dataset's uv samples and
# targets (ds_test[i]). Parts are paired purely by position, so this is only
# meaningful if both test splits list the same parts in the same order.
# NOTE(review): orig_index / new_index are loaded above but never used to verify
# that ordering. TODO confirm the splits are aligned.
assert len(orig_ds_test) == len(ds_test), (
    f'test splits differ in size ({len(orig_ds_test)} vs {len(ds_test)}); '
    'index-based pairing of parts would be misaligned'
)

av_xyz_dist_orig_new_sampling = []
av_mask_abs_error_orig_new_sampling = []

with torch.no_grad():
    for i in tqdm(range(len(orig_ds_test))):
        data = orig_ds_test[i]
        data_new = ds_test[i]
        # Query points and targets come from the new-format sample of the same index.
        uvs = data_new.surface_coords.reshape((-1,2))
        uv_idx = torch.arange(data_new.surface_coords.shape[0]).repeat_interleave(data_new.surface_coords.shape[1])
        target = torch.cat([data_new.surface_samples[:,:,:3],data_new.surface_samples[:,:,-1].unsqueeze(-1)],dim=-1)
        pred = orig_model(data, uvs, uv_idx).reshape_as(target)
        diffs = (pred - target)
        av_xyz_dist = (diffs[:,:,:3]**2).sum(dim=-1).sqrt().mean().item()
        av_mask_abs_error = diffs[:,:,-1].abs().mean().item()

        av_xyz_dist_orig_new_sampling.append(av_xyz_dist)
        av_mask_abs_error_orig_new_sampling.append(av_mask_abs_error)

In [None]:
# Print the mean per-part xyz and mask errors for every model/data combination.
# The '(test old data)' lists still hold 0-d tensors, so they are converted here.
summary_rows = [
    ('Old Network, New Data', av_xyz_dist_old, av_mask_abs_error_old),
    ('New Network, New Data', av_xyz_dist_new, av_mask_abs_error_new),
    ('Old Network, Old Data (test old data)',
     [x.item() for x in av_xzy_dist_orig],
     [x.item() for x in av_mask_abs_error_orig]),
    ('Old Network, Old Data, (test new data)',
     av_xyz_dist_orig_new_sampling,
     av_mask_abs_error_orig_new_sampling),
]
for heading, xyz_errors, mask_errors in summary_rows:
    print(heading)
    print(np.mean(xyz_errors))
    print(np.mean(mask_errors))

In [None]:
# Build a long-format table of per-part errors: one row per (test part, model, metric).
# Each entry below is (model label, metric name, per-part values). The order of
# the entries fixes the row order of the resulting DataFrame.
# NOTE(review): the 'Old Network, Old Data' rows come from the *_orig_new_sampling
# lists, i.e. the original model scored at the NEW dataset's samples (printed above
# as '(test new data)'). Confirm this label is intended.
error_series = [
    ('Old Network, New Data', 'xyz_dist', av_xyz_dist_old),
    ('Old Network, New Data', 'mask_error', av_mask_abs_error_old),
    ('New Network, New Data', 'xyz_dist', av_xyz_dist_new),
    ('New Network, New Data', 'mask_error', av_mask_abs_error_new),
    ('Old Network, Old Data', 'xyz_dist', av_xyz_dist_orig_new_sampling),
    ('Old Network, Old Data', 'mask_error', av_mask_abs_error_orig_new_sampling),
]

records = [
    {'test_idx': test_idx, 'model': model, 'metric': metric, 'value': value}
    for model, metric, values in error_series
    for test_idx, value in enumerate(values)
]

error_rec = pd.DataFrame.from_records(records)

In [None]:
# Cache the long-format error table so the plotting cells can be re-run without
# repeating the evaluation loops. to_parquet does not create missing parent
# directories, so make sure the results folder exists first.
results_path = '../../results/recon_ablations.parquet'
os.makedirs(os.path.dirname(results_path), exist_ok=True)
error_rec.to_parquet(results_path)

In [None]:
# Reload the cached error table, so the analysis below can resume after a kernel restart.
error_rec = pd.read_parquet(path='../../results/recon_ablations.parquet')

In [None]:
# Number of rows (parts × metrics) recorded for the new network.
int(error_rec['model'].eq('New Network, New Data').sum())

In [None]:
# Mean of each metric per model. Use the string 'mean' rather than np.mean:
# pandas >= 2.1 emits a FutureWarning when NumPy callables are passed to agg.
error_rec.groupby(['metric', 'model']).agg({'value': 'mean'}).reset_index()

In [None]:
# Bar chart of the per-model mean of each metric, one column per metric.
# 'mean' (string) avoids the pandas >= 2.1 FutureWarning for np.mean in agg.
alt.Chart(error_rec.groupby(['metric','model']).agg({'value':'mean'}).reset_index()).mark_bar().encode(
    x='model',
    color='model',
    y='value',
    column='metric'
)

In [None]:
# Scratch computation: sqrt(err**2 / n).
# NOTE(review): presumably the per-axis error when a total error `err` is split
# evenly across `n` axes (3-D vs 2-D). Confirm intent.
def _per_axis_error(err, n):
    return ((err ** 2) / n) ** (1 / 2)

print(_per_axis_error(.025, 3))
print(_per_axis_error(.015, 2))

In [None]:
# Per-metric facet: bars show the mean value per model, overlaid with a
# confidence-interval error bar (mark_errorbar extent='ci') over the per-part values.
# The row limit is disabled because error_rec holds one row per (part, model,
# metric) and may exceed Altair's default row cap.
alt.data_transformers.disable_max_rows()
source = error_rec

error_bars = alt.Chart().mark_errorbar(extent='ci').encode(
    x='model',
    y='value'
)
bars = alt.Chart().mark_bar().encode(
    x='model',
    y=alt.Y('mean(value):Q', title='Mean Value'),
    color='model',
)

layered = alt.layer(bars, error_bars, data=source)
layered.facet(column='metric')

In [None]:
# Total number of rows in the long-format error table.
error_rec.shape[0]