In [1]:
%load_ext autoreload
%autoreload 1

import sys
sys.path.append('..')
import os
import skimage
import time
import kaolin as kal
import tntorch as tn
import torch
import trimesh
import tqdm
import os.path as osp
import matplotlib.pyplot as plt
from t4dt.metrics import compute_metrics, hausdorff, MSDM2
from t4dt.utils import sdf2mesh, get_3dtensors_patches
from t4dt.t4dt import reduce_tucker, get_qtt_frame, qtt2tensor3d

In [2]:
ckpt = torch.load('../logs/sweeps/qtt_00032_longshort_flying_eagle_high_ranks_high_thr.pt')

In [3]:
scene = ckpt[(-0.05,0.05)][400]['compressed_scene']

In [16]:
def invert(sdf, clamp_value):
    df = sdf.abs()
    sign = torch.sign(sdf)
    sign[sdf == 0] = 1
    return sign * torch.exp(-df.clamp_max(clamp_value))

def invert_back(inv):
    sign = torch.sign(inv)
    rec = -torch.log(torch.abs(inv))
    rec[~torch.isfinite(rec)] = 0
    return sign * rec

In [5]:
max_rank = 400
lrs = scene.clone()
rmax = 800
ranks_tt = scene.ranks_tt.clone()
ranks_tt[4::4] = torch.where(ranks_tt[4::4] < rmax, ranks_tt[4::4], rmax)
ranks_tt[3::4] = torch.where(ranks_tt[3::4] < rmax, ranks_tt[3::4], rmax)
lrs.round_tt(rmax=ranks_tt[1:-1])

frame0_qtt = get_qtt_frame(lrs, 0)

In [6]:
data_dir = '/scratch2/data/cape_release/'
model = '00032'
scene = 'longshort_flying_eagle'
frames = []
for frame in sorted(os.listdir(osp.join(data_dir, 'meshes', model, scene, 'posed'))):
    if frame.startswith('sdf'):
        frames.append(frame)
        
folder = osp.join(data_dir, 'meshes', model, scene, 'posed')
files = [(osp.join(folder, frame), osp.join(folder, frame[4:-2] + 'obj'))
            for frame in frames]
sdf0 = torch.load(files[0][0])
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])

In [12]:
# NOTE: ideas
# 1. standartization входов и выходов
# 2. обрезка после cnn
# 3. границы - проверить значения
# 4. мб выкинуть бн
# 5. radam
# 6. сколько артифактов попадает в receptive field?
# 7. сравнить с блюром
# 8. positional encoding input for cnn?
# 9. siren?

In [14]:
# NOTE: slow near boundary voxels search

# result = []

# offsets = torch.tensor([
#    [-1, -1, -1],
#    [-1, -1,  0],
#    [-1, -1,  1],
#    [-1,  0, -1],
#    [-1,  0,  0],
#    [-1,  0,  1],
#    [-1,  1, -1],
#    [-1,  1,  0],
#    [-1,  1,  1],
#    [ 0, -1, -1],
#    [ 0, -1,  0],
#    [ 0, -1,  1],
#    [ 0,  0, -1],
#    [ 0,  0,  0],
#    [ 0,  0,  1],
#    [ 0,  1, -1],
#    [ 0,  1,  0],
#    [ 0,  1,  1],
#    [ 1, -1, -1],
#    [ 1, -1,  0],
#    [ 1, -1,  1],
#    [ 1,  0, -1],
#    [ 1,  0,  0],
#    [ 1,  0,  1],
#    [ 1,  1, -1],
#    [ 1,  1,  0],
#    [ 1,  1,  1]])

# for i in range(512):
#     for j in range(512):
#         for k in range(512):
#             p = torch.tensor([i, j, k])
#             candidates = p + offsets
#             candidates.clamp_min_(0)
#             candidates.clamp_max_(511)
#             signs = torch.sign(sdf0['sdf'][candidates[:, 0], candidates[:, 1], candidates[:, 2]])
#             if torch.any(signs != torch.sign(sdf0['sdf'][p[0], p[1], p[2]])):
#                 result.append(p)

In [7]:
for core in frame0_qtt.cores:
    core.requires_grad_()
print('make the cores trainable')

make the cores trainable


In [8]:
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 6, 5)
        self.conv2 = nn.Conv3d(6, 16, 5)
        self.conv3 = nn.Conv3d(16, 1, 5)
        self.bn1 = torch.nn.BatchNorm3d(1)
        self.gelu = torch.nn.GELU()
        self.rec_field = 13

    def forward(self, x):
        x = self.bn1(x)
        x = self.gelu(self.conv1(x))
        x = self.gelu(self.conv2(x))
        x = self.conv3(x)
        return x


net = CNN()

In [9]:
res = net(torch.rand((1, 1, 13, 13, 13)))

In [10]:
idxs = torch.load(
    '/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/near_surface_voxel_idxs_cross.pt')

In [11]:
MIN_TSDF = -0.05
MAX_TSDF = 0.05
NUM_EPOCHS = 100
NUM_SAMPLES = 5000
RES = 512

optimizer = torch.optim.SGD(frame0_qtt.cores + list(net.parameters()), lr=1e-3)
gt_sdf = sdf0['sdf']

In [29]:
net.cuda()
net.train()
for i in range(NUM_EPOCHS):
    epoch_loss = 0
    num_batches = 0
    for batch_start in range(0, idxs[frames[0]].shape[0], NUM_SAMPLES):
        optimizer.zero_grad()
        reconstruction = qtt2tensor3d(frame0_qtt.torch())
        local_idxs = idxs[frames[0]][batch_start:batch_start + NUM_SAMPLES]
        patches = get_3dtensors_patches(
            reconstruction.clamp_max(MAX_TSDF).clamp_min(MIN_TSDF),
            net.rec_field,
            local_idxs)[:, None].cuda()
        predicted_voxels = net(patches).reshape(-1)
        clamped_pred = predicted_voxels.clamp_max(MAX_TSDF).clamp_min(MIN_TSDF)
        clamped_gt = gt_sdf[local_idxs[:, 0], local_idxs[:, 1], local_idxs[:, 2]].cuda()
        clamped_gt = clamped_gt.clamp_max(MAX_TSDF).clamp_min(MIN_TSDF)
        loss = ((invert(clamped_pred, MAX_TSDF) - invert(clamped_gt, MAX_TSDF))**2).mean()
        loss.backward()
        epoch_loss += loss
        num_batches += 1
        print(f'Batch: {batch_start}, loss: {torch.sqrt(loss).item()}')
    print(f'Epoch: {i}, loss: {torch.sqrt(epoch_loss / num_batches).item()}')
    optimizer.step()

Batch: 0, loss: 1.3781309127807617
Batch: 5000, loss: 1.3931143283843994
Batch: 15000, loss: 1.3875826597213745
Batch: 20000, loss: 1.4073865413665771
Batch: 25000, loss: 1.32358717918396
Batch: 30000, loss: 1.285507082939148
Batch: 35000, loss: 1.423962116241455
Batch: 40000, loss: 1.3753794431686401
Batch: 45000, loss: 1.414846420288086
Batch: 50000, loss: 1.4188364744186401
Batch: 55000, loss: 1.3863054513931274
Batch: 60000, loss: 1.3700129985809326
Batch: 65000, loss: 1.4689780473709106
Batch: 70000, loss: 1.427836298942566
Batch: 75000, loss: 1.3767499923706055
Batch: 80000, loss: 1.4170403480529785
Batch: 85000, loss: 1.3882029056549072
Batch: 90000, loss: 1.4170283079147339
Batch: 95000, loss: 1.3594329357147217
Batch: 100000, loss: 1.4054861068725586
Batch: 105000, loss: 1.3923677206039429
Batch: 110000, loss: 1.4077253341674805
Batch: 115000, loss: 1.3883569240570068
Batch: 120000, loss: 1.3901290893554688
Batch: 125000, loss: 1.3951390981674194
Batch: 130000, loss: 1.3770514

KeyboardInterrupt: 

In [36]:
torch.sqrt(((reconstruction - gt_sdf)**2).mean())

tensor(0.4947, grad_fn=<SqrtBackward0>)

In [30]:
# torch.save(net.state_dict(), '../logs/CNN_qtt_00032_fling_eagle_learning_frame0_inv.pt')

In [10]:
state_dict = torch.load('../logs/CNN_qtt_00032_fling_eagle_learning_frame0.pt')
net.load_state_dict(state_dict)

<All keys matched successfully>

In [31]:
patch = get_3dtensors_patches(qtt2tensor3d(frame0_qtt.torch()), 512 + 13, torch.tensor([[255, 255, 255]]))
net.cpu()
with torch.no_grad():
    predicted_voxels = net(patch[:, None])

In [32]:
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])
mesh = sdf2mesh(invert_back(predicted_voxels[0, 0, 10:-10, 10:-10, 10:-10]), coords)
obj = trimesh.exchange.obj.export_obj(mesh, include_texture=False)
with open('./mesh_sdf.obj', 'w') as f:
    f.write(obj)

In [55]:
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])
mesh = sdf2mesh(predicted_voxels[0, 0, 10:-10, 10:-10, 10:-10] * (2 * 0.05), coords)
obj = trimesh.exchange.obj.export_obj(mesh, include_texture=False)
with open('./mesh_sdf.obj', 'w') as f:
    f.write(obj)

In [57]:
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])
mesh = sdf2mesh(reconstruction, coords)
obj = trimesh.exchange.obj.export_obj(mesh, include_texture=False)
with open('./mesh_sdf.obj', 'w') as f:
    f.write(obj)