In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import sys
from tqdm.notebook import tqdm
import time
import copy
import argparse
import trimesh
import logging

import torch
import torchvision
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler

from modules import utils, volutils
from modules.models import INR

In [None]:
def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``'False'`` -- becomes True. This converter
    accepts the usual spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='INCODE')

# Shared Parameters
parser.add_argument('--input',type=str, default='./incode_data/Shape/occupancies/preprocessed_lucy.npy', help='Input image path')
parser.add_argument('--output',type=str, default='./output/', help='Output path')
parser.add_argument('--inr_model',type=str, default='incode', help='[gauss, mfn, relu, siren, wire, wire2d, ffn, incode]')
parser.add_argument('--lr',type=float, default=1e-4, help='Learning rate')
# Use an explicit converter: type=bool would turn '--using_schedular False' into True.
parser.add_argument('--using_schedular', type=_str2bool, default=True, help='Whether to use schedular')
parser.add_argument('--scheduler_b', type=float, default=0.2, help='Learning rate scheduler')
# Integer literal so the default matches the declared type (1e5 is a float).
parser.add_argument('--maxpoints', type=int, default=100000, help='Batch size')
parser.add_argument('--niters', type=int, default=2, help='Number if iterations')
parser.add_argument('--steps_til_summary', type=int, default=1, help='Number of steps till summary visualization')
parser.add_argument('--res', type=int, default=512, help='resolution (N^3) of the mesh, same for xyz')
parser.add_argument('--mcubes_thres', type=float, default=0.5, help='Threshold for marching cubes')


# INCODE Parameters (weights of the per-coefficient regularization terms)
parser.add_argument('--a_coef',type=float, default=0.1993, help='a coeficient')
parser.add_argument('--b_coef',type=float, default=0.0196, help='b coeficient')
parser.add_argument('--c_coef',type=float, default=0.0588, help='c coeficient')
parser.add_argument('--d_coef',type=float, default=0.0269, help='d coeficient')


# Parse an empty argv so the notebook always runs with the defaults above
# (and does not pick up the Jupyter kernel's own command-line arguments).
args = parser.parse_args(args=[])

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading Data

In [None]:
## Load the .npy archive produced by the pre-processing code
# `[()]` unwraps the 0-d array returned by np.load so `data` can be indexed by key.
# NOTE(review): allow_pickle=True can execute arbitrary code while unpickling;
# only load archives that come from a trusted source.
data = np.load(args.input, allow_pickle=True)[()]

### Input volume `im`: bit-packed in the archive, unpacked to [res, res, res]
scale = 1.0
im = np.unpackbits(data['im']).reshape(args.res, args.res, args.res).astype(np.float32)
im = ndimage.zoom(im / im.max(), [scale] * 3, order=0)

# Crop to the extent of the voxels whose value exceeds 0.99.
# NOTE(review): the upper slice bounds are exclusive, so the plane at each
# maximum index is dropped -- confirm this is intended.
hidx, widx, tidx = np.nonzero(im > 0.99)
im = im[tuple(slice(idx.min(), idx.max()) for idx in (hidx, widx, tidx))]

H, W, T = im.shape


### `gt_im`: stored as a 128^3 volume; this is what the task-specific model receives
gt_im = np.unpackbits(data['gt_im']).reshape(128, 128, 128).astype(np.float32)
gt_im = ndimage.zoom(gt_im / gt_im.max(), [1.0] * 3, order=0)
hidx, widx, tidx = np.nonzero(gt_im > 0.99)
gt_im = gt_im[tuple(slice(idx.min(), idx.max()) for idx in (hidx, widx, tidx))]
# Prepend two singleton axes, then repeat three times along the second one.
gt_im = np.repeat(gt_im[np.newaxis, np.newaxis], 3, axis=1)
gt_im = torch.from_numpy(gt_im).to(device)


### Initial pose of the mesh, forwarded to the marching/saving step
mesh_whl = data['mesh_whl']

## Defining Model

### Defining desired Positional Encoding

In [None]:
# Frequency encoding; its mapping size is the largest dimension of the cropped volume
pos_encode_freq = dict(type='frequency', use_nyquist=True, mapping_input=int(max(H, W, T)))

# Gaussian encoding
pos_encode_gaus = dict(type='gaussian', scale_B=10, mapping_input=256)

# No positional encoding
pos_encode_no = dict(type=None)

### Model Configurations

In [None]:
### Harmonizer configuration (passed to the model as `MLP_configs`)
MLP_configs = dict(
    task='shape',
    model='r3d_18',
    truncated_layer=3,
    in_channels=128,
    hidden_channels=[64, 32, 4],
    mlp_bias=0.3120,
    activation_layer=nn.SiLU,
    GT=gt_im,  # ground-truth volume tensor prepared in the data-loading cell
)

### INR configuration: 3-D coordinates in, one value out, no positional encoding
inr_builder = INR(args.inr_model)
model = inr_builder.run(
    in_features=3,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    first_omega_0=30.0,
    hidden_omega_0=30.0,
    pos_encode_configs=pos_encode_no,
    MLP_configs=MLP_configs,
)
model = model.to(device)

## Training Code

In [None]:
# Optimizer setup
optim = torch.optim.Adam(lr=args.lr, params=model.parameters())
# Multiplicative LR factor decays from 1 to `scheduler_b` over `niters` scheduler
# steps and is held at `scheduler_b` afterwards.
scheduler = lr_scheduler.LambdaLR(optim, lambda x: args.scheduler_b ** min(x / args.niters, 1))

# Per-epoch IoU history, and the best IoU seen so far.
# IoU is maximised (the loop tests `iou > best_iou`), so the sentinel must be
# negative infinity; a positive-infinity start could never be beaten.
iou_values = []
best_iou = torch.tensor(float('-inf'))

# Generate coordinate grid, with a leading batch axis
coords = utils.get_coords(H, W, T, dim=3)[None, ...]

# Flatten the target volume to one value per voxel, with a leading batch axis
gt = torch.tensor(im).reshape(H * W * T, 1)[None, ...].to(device)

# Buffer that the training loop fills with the model's predictions
rec = torch.zeros_like(gt)

# A batch can never hold more points than the volume has voxels
args.maxpoints = int(min(H*W*T, args.maxpoints))

In [None]:
for step in tqdm(range(args.niters)):
    # Visit every voxel once per epoch, in a fresh random order
    indices = torch.randperm(H*W*T)

    loss_values = []
    # Process data points in batches of at most `args.maxpoints`
    for b_idx in range(0, H*W*T, args.maxpoints):
        b_indices = indices[b_idx:min(H*W*T, b_idx+args.maxpoints)]
        b_coords = coords[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)

        # The 'incode' model returns its coefficients alongside the prediction
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)
        else:
            model_output = model(b_coords)

        # Store this batch's predictions in the full reconstruction (no gradient tracking)
        with torch.no_grad():
            rec[:, b_indices, :] = model_output

        # Mean squared error against the target values of this batch
        output_loss = ((model_output - gt[:, b_indices, :])**2).mean()

        if args.inr_model == 'incode':
            # Regularization: relu(-x) is non-zero only for negative x, so each
            # term penalises a negative coefficient, weighted by its args.*_coef
            a_coef, b_coef, c_coef, d_coef = coef[0]
            reg_loss = args.a_coef * torch.relu(-a_coef) + \
                       args.b_coef * torch.relu(-b_coef) + \
                       args.c_coef * torch.relu(-c_coef) + \
                       args.d_coef * torch.relu(-d_coef)

            # Total loss for 'incode' model
            loss = output_loss + reg_loss
        else:
            # Total loss for other models
            loss = output_loss
        loss_values.append(loss.item())

        # Perform backpropagation and update model parameters
        optim.zero_grad()
        loss.backward()
        optim.step()


    # Calculate IoU of the reconstruction for this epoch
    with torch.no_grad():
        iou = volutils.get_IoU(rec, gt, args.mcubes_thres)
        iou_values.append(iou.item())

    # Adjust learning rate using the scheduler if enabled.
    # NOTE(review): the original wrapped this in an `incode and 30 < step`
    # branch whose two arms both called scheduler.step(); the dead branch is
    # collapsed here with identical behavior. If a warm-up delay for 'incode'
    # was intended, it was never in effect -- confirm before adding one.
    if args.using_schedular:
        scheduler.step()

    # Reconstruction reshaped to the volume, as a NumPy array on the host
    imrec = rec[0, ...].reshape(H, W, T).detach().cpu().numpy()

    # Keep a copy of the reconstruction with the highest IoU so far
    # (the first epoch is always taken)
    if (iou > best_iou) or (step == 0):
        best_iou = iou
        best_img = copy.deepcopy(imrec)

    # Display intermediate results at specified intervals
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.5f} | IoU: {:.4f}".format(step,
                                                                     np.mean(loss_values),
                                                                     iou.item()))


# Print maximum IoU achieved during training (the tracked metric is IoU, not PSNR)
print('--------------------')
print('Max IoU:', max(iou_values))
print('--------------------')


# Marching and saving volumes.
# Build the directory and the file path from the same joined path so they
# agree whether or not `args.output` ends with a path separator.
expname = os.path.splitext(os.path.basename(args.input))[0]
out_dir = os.path.join(args.output, args.inr_model)
os.makedirs(out_dir, exist_ok=True)
savename = os.path.join(out_dir, f'{expname}.dae')
volutils.march_and_save(best_img, mesh_whl, args.mcubes_thres, savename, True)

In [None]:
# Marching and saving volumes.
# The output directory and the file path are derived from the same joined
# path; previously the directory was built by plain concatenation
# (args.output + args.inr_model), which pointed somewhere else than the save
# path whenever `args.output` lacked a trailing separator.
expname = os.path.splitext(os.path.basename(args.input))[0]
out_dir = os.path.join(args.output, args.inr_model)
os.makedirs(out_dir, exist_ok=True)
savename = os.path.join(out_dir, f'{expname}.dae')
volutils.march_and_save(best_img, mesh_whl, args.mcubes_thres, savename, True)

# Convergence Rate

In [None]:
# Font settings for the axis labels. 'family' is the text property that selects
# the typeface (the title below uses the same key); the previous 'font' key is
# not accepted as a text property by every Matplotlib version.
font = {'family': 'Times New Roman', 'size': 12}

plt.figure()
# Default font for text created from here on (tick labels, legend)
axfont = {'family' : 'Times New Roman', 'weight' : 'regular', 'size'   : 10}
plt.rc('font', **axfont)

# IoU per epoch.
# NOTE(review): the final epoch is excluded by the [:-1] slice -- confirm this is intended.
plt.plot(np.arange(len(iou_values[:-1])), iou_values[:-1], label = f"{(args.inr_model).upper()}")
plt.xlabel('# Epochs', fontdict=font)
plt.ylabel('IoU', fontdict=font)
plt.title('Shape Representation', fontdict={'family': 'Times New Roman', 'size': 12, 'weight': 'bold'})
plt.legend()
plt.grid(True, color='lightgray')

plt.show()