In [217]:
# some code from DeepSDF https://github.com/facebookresearch/DeepSDF/tree/main

from torchvision import datasets, transforms
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import math

from torchinfo import summary
from easydict import EasyDict as ed
import time
import random

from model import SDFNet, Decoder, Simple
import workspace as ws
import os

import logging

### Unseen data

In [211]:
class CustomSDFDataset(Dataset):
	"""Expose an image dataset as per-pixel ``(x, y, sdf)`` samples.

	Every item of the wrapped dataset must be an ``(image, label)`` pair whose
	image flattens to ``sdf_dim ** 2`` values. The negated pixel value is used
	as the SDF value of the matching grid position.
	"""

	def __init__(self, dataset):
		"""Store the wrapped dataset and precompute the coordinate grid.

		Args:
			dataset: indexable, sized collection of ``(image, label)`` pairs.
		"""
		self.indices = range(len(dataset))
		self.sdf = dataset
		# Fraction of grid points to draw per item. Sub-sampling is currently
		# disabled: every item returns the full grid in row-major order.
		self.sampling_percent = 1
		self.sdf_dim = 28
		self.xy = self._init_xy()

	def _init_xy(self):
		"""Build the ``(sdf_dim**2, 2)`` float tensor of grid coordinates in [0, 1].

		Row ``r * sdf_dim + c`` holds ``(x[c], y[r])``.
		"""
		x = np.linspace(0, 1, self.sdf_dim)
		y = x
		xy = np.meshgrid(x, y)
		xy = np.stack(xy).reshape(2, -1).T
		return torch.tensor(xy, dtype=torch.float)

	def __len__(self):
		"""Number of items in the wrapped dataset."""
		return len(self.sdf)

	def __getitem__(self, idx):
		"""Return ``(X, idx, label)`` where ``X`` has one ``(x, y, sdf)`` row per grid point."""
		pos_data, pos_idx = self._get_pos_data()
		sdf, label = self.sdf[idx]
		# Add a trailing axis so the SDF column can be concatenated to (N, 2).
		sdf_data = self._get_s_data(sdf, pos_idx)[:, None]
		X = torch.cat([pos_data, sdf_data], dim=-1)
		return X, idx, label

	def _get_pos_data(self):
		"""Return the grid coordinates to use and the flat indices they came from.

		The full grid is always returned; the previous implementation also
		computed an unused sample size and a second, discarded gather here.
		"""
		pos_idx = range(self.sdf_dim**2)
		return self.xy[pos_idx], pos_idx

	def _get_s_data(self, sdf, pos_idx):
		"""Return the negated, flattened image values at ``pos_idx``."""
		return -sdf.flatten()[pos_idx]

	def get_pos_from_idx(self, idx_x, idx_y):
		"""Return the ``(x, y)`` coordinate of grid column ``idx_x`` and row ``idx_y``."""
		return self.xy[idx_y*self.sdf_dim + idx_x]

# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
	[
		transforms.ToTensor(),
		transforms.Normalize((0.5,), (0.5,)),  # alternative stats: 0.1307 / 0.3081
	]
)

# NOTE(review): this cell is titled "Unseen data" but loads the train=True
# split; confirm this is not the split the decoder was fitted on.
dataset1 = datasets.MNIST('./data', train=True, transform=transform)
train_kwargs = dict(batch_size=6, shuffle=True)
train_dataset = CustomSDFDataset(dataset1)
sdf_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)


## Reconstruct

In [244]:
def reconstruct(
    decoder,
    num_iterations,
    latent_size,
    sdf_data,
    stat,
    clamp_dist,
    num_samples=28*28,
    lr=5e-4,
    l2reg=False,
):
    """Fit a latent code to observed SDF samples by gradient descent.

    Only the latent code is handed to the optimizer; the decoder is put in
    eval mode and its parameters are never stepped.

    Args:
        decoder: callable module mapping ``(N, latent_size + 2)`` inputs to
            ``(N, 1)`` SDF predictions.
        num_iterations: number of optimisation steps.
        latent_size: width of the latent code (used when ``stat`` is a scalar).
        sdf_data: ``(N, 3)`` tensor with rows ``(x, y, sdf)``.
        stat: either a scalar standard deviation for a zero-mean normal
            initialisation, or a ``(mean, std)`` pair of tensors passed to
            ``torch.normal``.
        clamp_dist: unused; clamping is currently disabled.
        num_samples: kept for backward compatibility. The latent is expanded
            to the actual number of rows in ``sdf_data``, so a mismatching
            value no longer raises.
        lr: initial learning rate; divided by 10 halfway through.
        l2reg: when true, add ``1e-4 * mean(latent ** 2)`` to the loss.

    Returns:
        ``(loss_num, latent)``: the last loss as a 0-d numpy array (``0`` if
        no iteration ran) and the optimised ``(1, latent_size)`` latent tensor.
    """
    def adjust_learning_rate(
        initial_lr, optimizer, num_iterations, decreased_by, adjust_lr_every
    ):
        # Step schedule: divide by `decreased_by` every `adjust_lr_every` steps.
        lr = initial_lr * ((1 / decreased_by) ** (num_iterations // adjust_lr_every))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    decreased_by = 10
    # Clamp to 1: with num_iterations == 1 the half-way interval would be 0
    # and the schedule's integer division would raise ZeroDivisionError.
    adjust_lr_every = max(1, num_iterations // 2)

    if isinstance(stat, (int, float)):
        # Scalar: zero-mean normal initialisation with the given std.
        latent = torch.ones(1, latent_size).normal_(mean=0, std=float(stat))
    else:
        latent = torch.normal(stat[0].detach(), stat[1].detach())

    latent.requires_grad_(True)

    optimizer = torch.optim.Adam([latent], lr=lr)

    loss_num = 0
    loss_l1 = torch.nn.L1Loss()

    # Loop-invariant setup, hoisted out of the optimisation loop.
    decoder.eval()
    xyz = sdf_data[:, 0:2]
    sdf_gt = sdf_data[:, 2].unsqueeze(1)
    num_points = xyz.shape[0]

    for e in range(num_iterations):
        adjust_learning_rate(lr, optimizer, e, decreased_by, adjust_lr_every)

        optimizer.zero_grad()

        # One copy of the latent per query point (expand shares storage).
        latent_inputs = latent.expand(num_points, -1)
        inputs = torch.cat([latent_inputs, xyz], 1)

        pred_sdf = decoder(inputs)

        # NOTE(review): extra first-iteration forward pass kept from the
        # original code ("why is this needed?"); it looks redundant in eval
        # mode -- confirm against the decoder before removing.
        if e == 0:
            pred_sdf = decoder(inputs)

        loss = loss_l1(pred_sdf, sdf_gt)
        if l2reg:
            loss = loss + 1e-4 * torch.mean(latent.pow(2))
        loss.backward()
        optimizer.step()

        loss_num = loss.detach().cpu().numpy()
        if e % 50 == 0:
            logging.debug(loss_num)
            logging.debug(e)
            logging.debug(latent.norm())

    return loss_num, latent


In [274]:
def reconstruct_partial(
	decoder,
	num_iterations,
	latent_size,
	sdf_data,
	stat,
	clamp_dist,
	num_samples=28*14,
	lr=5e-4,
	l2reg=False,
):
	"""Fit a latent code using only the right half of a 28x28 SDF grid.

	``sdf_data`` is reshaped to ``(28, 28, 3)`` and only columns 14..27 are
	kept as observations. The previous slice ``[:, :14]`` kept the left half,
	contradicting both its own comment and ``create_image_partial``, which
	blanks columns ``:14`` when drawing the partial ground truth.

	Only the latent code is handed to the optimizer; the decoder is put in
	eval mode and its parameters are never stepped.

	Args:
		decoder: callable module mapping ``(N, latent_size + 2)`` inputs to
			``(N, 1)`` SDF predictions.
		num_iterations: number of optimisation steps.
		latent_size: width of the latent code (used when ``stat`` is a scalar).
		sdf_data: ``(784, 3)`` tensor with rows ``(x, y, sdf)`` in row-major
			grid order.
		stat: either a scalar standard deviation for a zero-mean normal
			initialisation, or a ``(mean, std)`` pair of tensors passed to
			``torch.normal``.
		clamp_dist: unused; clamping is currently disabled.
		num_samples: kept for backward compatibility. The latent is expanded
			to the actual number of retained rows.
		lr: initial learning rate; divided by 10 halfway through.
		l2reg: when true, add ``1e-4 * mean(latent ** 2)`` to the loss.

	Returns:
		``(loss_num, latent)``: the last loss as a 0-d numpy array (``0`` if
		no iteration ran) and the optimised ``(1, latent_size)`` latent tensor.
	"""
	def adjust_learning_rate(
		initial_lr, optimizer, num_iterations, decreased_by, adjust_lr_every
	):
		# Step schedule: divide by `decreased_by` every `adjust_lr_every` steps.
		lr = initial_lr * ((1 / decreased_by) ** (num_iterations // adjust_lr_every))
		for param_group in optimizer.param_groups:
			param_group["lr"] = lr

	decreased_by = 10
	# Clamp to 1: with num_iterations == 1 the half-way interval would be 0
	# and the schedule's integer division would raise ZeroDivisionError.
	adjust_lr_every = max(1, num_iterations // 2)

	if isinstance(stat, (int, float)):
		# Scalar: zero-mean normal initialisation with the given std.
		latent = torch.ones(1, latent_size).normal_(mean=0, std=float(stat))
	else:
		latent = torch.normal(stat[0].detach(), stat[1].detach())

	latent.requires_grad_(True)

	optimizer = torch.optim.Adam([latent], lr=lr)

	loss_num = 0
	loss_l1 = torch.nn.L1Loss()

	grid_side = 28
	half = grid_side // 2
	# Shape to image dimensions, then keep only the right half of the image.
	sdf_data = sdf_data.reshape(grid_side, grid_side, 3)
	sdf_data = sdf_data[:, half:, :].reshape(-1, 3)

	# Loop-invariant setup, hoisted out of the optimisation loop.
	decoder.eval()
	xyz = sdf_data[:, 0:2]
	sdf_gt = sdf_data[:, 2].unsqueeze(1)
	num_points = xyz.shape[0]

	for e in range(num_iterations):
		adjust_learning_rate(lr, optimizer, e, decreased_by, adjust_lr_every)

		optimizer.zero_grad()

		# One copy of the latent per query point (expand shares storage).
		latent_inputs = latent.expand(num_points, -1)
		inputs = torch.cat([latent_inputs, xyz], 1)

		pred_sdf = decoder(inputs)

		# NOTE(review): extra first-iteration forward pass kept from the
		# original code ("why is this needed?"); it looks redundant in eval
		# mode -- confirm against the decoder before removing.
		if e == 0:
			pred_sdf = decoder(inputs)

		loss = loss_l1(pred_sdf, sdf_gt)
		if l2reg:
			loss = loss + 1e-4 * torch.mean(latent.pow(2))
		loss.backward()
		optimizer.step()

		loss_num = loss.detach().cpu().numpy()
		if e % 50 == 0:
			logging.debug(loss_num)
			logging.debug(e)
			logging.debug(latent.norm())

	return loss_num, latent

In [None]:
# Step learning-rate schedules: entry 0 is for the decoder, entry 1 for the
# embeddings. "Interval" is the number of epochs between steps.
lrschedule = [
	dict(Type="Step", Initial=1e-4, Interval=300, Factor=0.5),
	dict(Type="Step", Initial=3e-4, Interval=300, Factor=0.5),
]

# Experiment configuration, exposed with attribute access through EasyDict.
config = ed({
	"latent_size": 3,
	"code_bound": 1,
	"code_regularization": True,
	"CodeRegularizationLambda": 1e-4,
	"CodeInitStdDev": 1.0,
	"LearningRateSchedule": lrschedule,
	"grad_clip": None,  # for decoder
	"SnapshotFrequency": 30,  # checkpoints
	"NumEpochs": 2000,
	"LogFrequency": 5,
	"ClampingDistance": 0.3,
	"NetworkSpecs": {
		"dims": [512] * 8,
		"dropout": list(range(8)),
		"dropout_prob": 0.2,
		"norm_layers": list(range(8)),
		"latent_in": [4],
		"xyz_in_all": False,
		"use_tanh": False,
		"latent_dropout": False,
		"weight_norm": True,
	},
})

experiment_directory = os.path.join('./experiment')

# Options for the reconstruction run.
args = ed(
	dict(
		experiment_directory='./experiment',
		checkpoint='latest',
		iterations=800,
		split_filename='the split to reconstruct',
	)
)

def empirical_stat(latent_vecs, indices):
	"""Return the mean and variance over the selected latent codes.

	Args:
		latent_vecs: indexable collection of latent tensors.
		indices: iterable of keys into ``latent_vecs``.

	Returns:
		``(mean, var)`` computed along dim 0 of the concatenated codes.

	NOTE(review): the second value is a variance, while ``reconstruct`` feeds
	``stat[1]`` to ``torch.normal`` as a standard deviation -- confirm whether
	a square root is needed at the call site.
	"""
	selected = [latent_vecs[ind] for ind in indices]
	# Concatenate once instead of re-allocating per item; this also keeps the
	# result on the latents' own device. An empty selection falls back to the
	# empty tensor the original loop started from.
	lat_mat = torch.cat(selected, 0) if selected else torch.zeros(0)
	mean = torch.mean(lat_mat, 0)
	var = torch.var(lat_mat, 0)
	return mean, var

latent_size = config["latent_size"]

decoder = SDFNet(latent_size)

# Load onto the CPU explicitly: everything in this notebook runs on the CPU,
# and a checkpoint saved from a GPU would otherwise fail to load on a machine
# without CUDA.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
saved_model_state = torch.load(
	os.path.join(
		args.experiment_directory, ws.model_params_subdir, args.checkpoint + ".pth"
	),
	map_location="cpu",
)
saved_model_epoch = saved_model_state["epoch"]

decoder.load_state_dict(saved_model_state["model_state_dict"])

logging.debug(decoder)

err_sum = 0.0
repeat = 1
save_latvec_only = False
rerun = 0

# Output layout: <experiment>/<reconstructions>/<epoch>/{meshes, Partial, codes}
reconstruction_dir = os.path.join(
	args.experiment_directory, ws.reconstructions_subdir, str(saved_model_epoch)
)
reconstruction_meshes_dir = os.path.join(
	reconstruction_dir, ws.reconstruction_meshes_subdir
)
reconstruction_partial_dir = os.path.join(
	reconstruction_dir, "Partial"
)
reconstruction_codes_dir = os.path.join(
	reconstruction_dir, ws.reconstruction_codes_subdir
)

# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(reconstruction_dir, exist_ok=True)
os.makedirs(reconstruction_meshes_dir, exist_ok=True)
os.makedirs(reconstruction_partial_dir, exist_ok=True)
os.makedirs(reconstruction_codes_dir, exist_ok=True)


def create_image(decoder, wh, sdf_gt, lat_vec, i):
	"""Save the decoded SDF image and the ground-truth image for sample ``i``.

	Args:
		decoder: callable mapping ``(N, latent + 2)`` inputs to SDF values.
		wh: side length of the square grid the decoder is evaluated on.
		sdf_gt: tensor whose column 2 holds the ground-truth SDF values; it is
			reshaped to 28 rows for display.
		lat_vec: latent code, repeated once per grid point.
		i: sample index used in the output file names.

	Writes ``SDF-<i>inference.png`` and ``SDF-<i>gt.png`` into
	``reconstruction_meshes_dir``.
	"""
	# Evaluation grid over the unit square, one (x, y) row per pixel.
	axis = np.linspace(0, 1.0, wh)
	grid = np.stack(np.meshgrid(axis, axis)).reshape(2, -1).T
	grid = torch.tensor(grid, dtype=torch.float)

	codes = lat_vec.repeat(len(grid), 1)
	decoded = decoder(torch.cat([codes, grid], -1)).view(wh, -1)

	plt.figure()
	plt.gca().imshow(decoded.detach().numpy(), cmap='gray')
	inference_path = os.path.join(reconstruction_meshes_dir, "SDF-" + str(i) + "inference.png")
	plt.savefig(inference_path)

	gt_image = sdf_gt[:, 2].view(28, -1)
	plt.figure()
	plt.gca().imshow(gt_image.detach().numpy(), cmap='gray')
	gt_path = os.path.join(reconstruction_meshes_dir, "SDF-" + str(i) + "gt.png")
	plt.savefig(gt_path)

def create_image_partial(decoder, wh, sdf_gt, lat_vec, i):
	"""Save decoded, ground-truth and half-masked ground-truth images for sample ``i``.

	Args:
		decoder: callable mapping ``(N, latent + 2)`` inputs to SDF values.
		wh: side length of the square grid the decoder is evaluated on.
		sdf_gt: tensor whose column 2 holds the ground-truth SDF values; it is
			reshaped to 28 rows for display. It is not modified.
		lat_vec: latent code, repeated once per grid point.
		i: sample index used in the output file names.

	Writes ``SDF-<i>inference.png``, ``SDF-<i>gt.png`` and
	``SDF-<i>gt_partial.png`` into ``reconstruction_partial_dir``.
	"""
	# Evaluation grid over the unit square, one (x, y) row per pixel.
	x = y = np.linspace(0,1.0,wh)
	xy = np.meshgrid(x, y)
	xy = torch.tensor(np.stack(xy).reshape(2,-1).T, dtype=torch.float)

	decoder_input = torch.cat([lat_vec.repeat(len(xy), 1), xy],-1)
	S = decoder(decoder_input)
	S = S.view(wh,-1)

	plt.figure()
	plt.gca().imshow(S.detach().numpy(), cmap='gray')
	filename = os.path.join(reconstruction_partial_dir, "SDF-" + str(i) + "inference.png")
	plt.savefig(filename)

	gt_image = sdf_gt[:,2].view(28,-1)
	plt.figure()
	plt.gca().imshow(gt_image.detach().numpy(), cmap='gray')
	filename = os.path.join(reconstruction_partial_dir, "SDF-" + str(i) + "gt.png")
	plt.savefig(filename)

	# Mask a copy: gt_image is a view of the caller's tensor, so zeroing it
	# in place would overwrite the caller's SDF data.
	partial_image = gt_image.clone()
	partial_image[:, :14] = 0
	plt.figure()
	plt.gca().imshow(partial_image.detach().numpy(), cmap='gray')
	filename = os.path.join(reconstruction_partial_dir, "SDF-" + str(i) + "gt_partial.png")
	plt.savefig(filename)



In [None]:
### Reconstruct 1. Low resolution unseen images and 2. Low resolution partial images
data_sdf, _, label = next(iter(sdf_loader))

# Each pass is (fit function, render function, number of samples, is_partial).
# The full-image pass runs over the whole batch before the partial pass.
reconstruction_passes = (
	(reconstruct, create_image, 28*28, False),
	(reconstruct_partial, create_image_partial, 28*14, True),
)

for fit_fn, render_fn, n_samples, is_partial in reconstruction_passes:
	for i, sdf in enumerate(data_sdf):
		# Time each reconstruction on its own; the timer used to be started
		# once before both loops, so every log line was cumulative.
		start = time.time()
		err, latent = fit_fn(
			decoder,
			int(args.iterations),
			latent_size,
			sdf,
			0.01,  # std of the latent initialisation
			0.1,  # clamp distance (currently unused by the fit functions)
			num_samples=n_samples,
			lr=5e-3,
			l2reg=True,
		)
		logging.debug("reconstruct time: {}".format(time.time() - start))
		# NOTE(review): err_sum mixes both passes and keeps growing when this
		# cell is re-run without resetting it.
		err_sum += err

		decoder.eval()

		sdf_gt = sdf
		wh = 1000

		with torch.no_grad():
			if is_partial:
				print('create partial')
			render_fn(decoder, wh, sdf_gt, latent, i)
	
	