In [1]:
%load_ext autoreload
%autoreload 1

import sys
sys.path.append('..')
import os
import skimage
import time
import kaolin as kal
import tntorch as tn
import torch
import trimesh
import tqdm
import os.path as osp
import matplotlib.pyplot as plt
from t4dt.metrics import compute_metrics, hausdorff, MSDM2
from t4dt.utils import sdf2mesh, get_3dtensors_patches
from t4dt.t4dt import reduce_tucker, get_qtt_frame, qtt2tensor3d

In [2]:
ckpt = torch.load('../logs/sweeps/qtt_00032_longshort_flying_eagle_high_ranks_high_thr.pt')

In [9]:
scene = ckpt[(-0.05,0.05)][400]['compressed_scene']

In [10]:
i = 14
max_rank = 400
lrs = scene.clone()
ranks_tt = scene.ranks_tt.clone()
ranks_tt[i + 1] = min(max_rank, ranks_tt[i + 1])
lrs.round_tt(rmax=ranks_tt[1:-1])

frame0_qtt = get_qtt_frame(lrs, 0)

In [11]:
data_dir = '/scratch2/data/cape_release/'
model = '00032'
scene = 'longshort_flying_eagle'
frames = []
for frame in sorted(os.listdir(osp.join(data_dir, 'meshes', model, scene, 'posed'))):
    if frame.startswith('sdf'):
        frames.append(frame)
        
folder = osp.join(data_dir, 'meshes', model, scene, 'posed')
files = [(osp.join(folder, frame), osp.join(folder, frame[4:-2] + 'obj'))
            for frame in frames]
sdf0 = torch.load(files[0][0])
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])

In [12]:
# NOTE: ideas
# 1. standartization входов и выходов
# 2. обрезка после cnn
# 3. границы - проверить значения
# 4. мб выкинуть бн
# 5. radam
# 6. сколько артифактов попадает в receptive field?
# 7. сравнить с блюром
# 8. positional encoding input for cnn?
# 9. siren?

In [14]:
# NOTE: slow near boundary voxels search

# result = []

# offsets = torch.tensor([
#    [-1, -1, -1],
#    [-1, -1,  0],
#    [-1, -1,  1],
#    [-1,  0, -1],
#    [-1,  0,  0],
#    [-1,  0,  1],
#    [-1,  1, -1],
#    [-1,  1,  0],
#    [-1,  1,  1],
#    [ 0, -1, -1],
#    [ 0, -1,  0],
#    [ 0, -1,  1],
#    [ 0,  0, -1],
#    [ 0,  0,  0],
#    [ 0,  0,  1],
#    [ 0,  1, -1],
#    [ 0,  1,  0],
#    [ 0,  1,  1],
#    [ 1, -1, -1],
#    [ 1, -1,  0],
#    [ 1, -1,  1],
#    [ 1,  0, -1],
#    [ 1,  0,  0],
#    [ 1,  0,  1],
#    [ 1,  1, -1],
#    [ 1,  1,  0],
#    [ 1,  1,  1]])

# for i in range(512):
#     for j in range(512):
#         for k in range(512):
#             p = torch.tensor([i, j, k])
#             candidates = p + offsets
#             candidates.clamp_min_(0)
#             candidates.clamp_max_(511)
#             signs = torch.sign(sdf0['sdf'][candidates[:, 0], candidates[:, 1], candidates[:, 2]])
#             if torch.any(signs != torch.sign(sdf0['sdf'][p[0], p[1], p[2]])):
#                 result.append(p)

In [15]:
for core in frame0_qtt.cores:
    core.requires_grad_()
print('make the cores trainable')

[tensor([[[ 6754.4312, -1004.0063],
          [ 6747.3096,  1005.0659]]], requires_grad=True),
 tensor([[[-7.0508e-01, -4.8293e-03, -5.1553e-02,  3.7291e-02],
          [-7.0334e-01,  4.9346e-03,  5.1715e-02, -3.7335e-02]],
 
         [[-3.8785e-04,  4.6748e-01, -3.2061e-01, -3.6658e-01],
          [ 4.8426e-03,  6.5004e-01,  1.9554e-01,  2.8894e-01]]],
        requires_grad=True),
 tensor([[[-6.9130e-01,  9.6349e-03,  2.8192e-02, -2.6585e-02,  4.4182e-02,
            6.1450e-04, -4.1783e-04,  1.9681e-04],
          [-7.1783e-01, -9.0506e-03, -2.7075e-02,  2.5624e-02, -4.2597e-02,
           -6.2309e-04,  4.0467e-04, -1.8155e-04]],
 
         [[-1.2038e-02, -9.4116e-01,  4.5220e-02, -4.2361e-02,  5.3004e-02,
            3.8202e-04, -1.7072e-04,  1.0858e-04],
          [ 4.5968e-03,  1.3772e-03,  6.3662e-03, -5.9651e-03,  1.3345e-02,
           -3.2660e-01,  4.2795e-03, -2.1541e-02]],
 
         [[-8.1834e-03, -1.0883e-02, -8.5266e-01, -2.9212e-01,  2.7911e-01,
            6.4131e-04, -

In [16]:
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 6, 5)
        self.conv2 = nn.Conv3d(6, 16, 5)
        self.conv3 = nn.Conv3d(16, 1, 5)
        self.bn1 = torch.nn.BatchNorm3d(1)
        self.gelu = torch.nn.GELU()
        self.rec_field = 13

    def forward(self, x):
        x = self.bn1(x)
        x = self.gelu(self.conv1(x))
        x = self.gelu(self.conv2(x))
        x = self.conv3(x)
        return x


net = CNN()

In [17]:
res = net(torch.rand((1, 1, 13, 13, 13)))

In [18]:
idxs = torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/near_surface_voxel_idxs_cross.pt')

In [19]:
MIN_TSDF = -0.05
MAX_TSDF = 0.05
NUM_EPOCHS = 100
NUM_SAMPLES = 5000
RES = 512

optimizer = torch.optim.SGD(frame0_qtt.cores + list(net.parameters()), lr=1e-3)
gt_sdf = sdf0['sdf']

In [34]:
net.cuda()
net.train()
for i in range(NUM_EPOCHS):
    epoch_loss = 0
    num_batches = 0
    for batch_start in range(0, idxs[frames[0]].shape[0], NUM_SAMPLES):
        optimizer.zero_grad()
        reconstruction = qtt2tensor3d(frame0_qtt.torch())
        local_idxs = idxs[frames[0]][batch_start:batch_start + NUM_SAMPLES]
        patches = get_3dtensors_patches(
            reconstruction.clamp_max(MAX_TSDF).clamp_min(MIN_TSDF),
            net.rec_field,
            local_idxs)[:, None].cuda() / (2 * 0.05)
        predicted_voxels = net(patches).reshape(-1)
        clamped_pred = predicted_voxels.clamp_max(MAX_TSDF).clamp_min(MIN_TSDF)
        clamped_gt = gt_sdf[local_idxs[:, 0], local_idxs[:, 1], local_idxs[:, 2]].cuda()
        clamped_gt = clamped_gt.clamp_max(MAX_TSDF).clamp_min(MIN_TSDF) / (2 * 0.05)
        loss = ((clamped_pred - clamped_gt)**2).mean()
        loss.backward()
        epoch_loss += loss
        num_batches += 1
        print(f'Batch: {batch_start}, loss: {torch.sqrt(loss).item()}')
    print(f'Epoch: {i}, loss: {torch.sqrt(epoch_loss / num_batches).item()}')
    optimizer.step()

Batch: 0, loss: 0.04489718750119209
Batch: 5000, loss: 0.050299305468797684
Batch: 10000, loss: 0.050056811422109604
Batch: 15000, loss: 0.04951167851686478


KeyboardInterrupt: 

In [36]:
torch.sqrt(((reconstruction - gt_sdf)**2).mean())

tensor(0.4947, grad_fn=<SqrtBackward0>)

In [46]:
torch.save(net.state_dict(), '../logs/CNN_qtt_00032_fling_eagle_learning_frame0.pt')

In [10]:
state_dict = torch.load('../logs/CNN_qtt_00032_fling_eagle_learning_frame0.pt')
net.load_state_dict(state_dict)

<All keys matched successfully>

In [50]:
patch = get_3dtensors_patches(qtt2tensor3d(frame0_qtt.torch()), 512 + 13, torch.tensor([[255, 255, 255]]))
net.cpu()
with torch.no_grad():
    predicted_voxels = net(patch[:, None])

In [55]:
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])
mesh = sdf2mesh(predicted_voxels[0, 0, 10:-10, 10:-10, 10:-10] * (2 * 0.05), coords)
obj = trimesh.exchange.obj.export_obj(mesh, include_texture=False)
with open('./mesh_sdf.obj', 'w') as f:
    f.write(obj)

In [57]:
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])
mesh = sdf2mesh(reconstruction, coords)
obj = trimesh.exchange.obj.export_obj(mesh, include_texture=False)
with open('./mesh_sdf.obj', 'w') as f:
    f.write(obj)