In [None]:
from types import SimpleNamespace
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import smplx
import torch

from awol.code.nnutils.object_net import ObjectNet

def get_surface_normals(p, i, j, k):
    """Return a per-face RGB colour derived from each triangle's unit normal.

    Despite the name, the result is a colour encoding of the face normals (uint8),
    meant to be used as matplotlib face colours after dividing by 255.

    Args:
        p: (V, 3) array of vertex positions.
        i, j, k: integer index arrays of length F with the three corner vertices
            of each face (e.g. ``*faces.T`` for an (F, 3) face array).

    Returns:
        (F, 3) uint8 array.
        - Channels 0 and 1 map normal x / y from [-1, 1] to [23, 253].
        - Channel 2 maps normal z from [-1, 1] to [255, 127].
        - Degenerate (zero-area) faces get a zero normal, i.e. colour (138, 138, 191).
    """
    vn = np.cross(p[j] - p[i], p[k] - p[j])
    lengths = np.linalg.norm(vn, axis=1, keepdims=True)
    # Zero-area faces have a zero cross product; dividing by zero would give NaN,
    # and casting NaN to uint8 is undefined. Leave those normals at zero instead.
    vn = np.divide(vn, lengths, out=np.zeros(vn.shape, dtype=lengths.dtype), where=lengths > 0)
    # Map each component from [-1, 1] to [0, 1].
    vn = (vn + 1) / 2
    # x/y -> [23, 253]; z -> [255, 127] (z-facing faces end up darker in blue).
    vn[:, :2] = 230 * (0.1 + vn[:, :2])
    vn[:, 2] = 255 - 128 * vn[:, 2]
    return np.rint(vn).astype(np.uint8)

class SMALLayer(smplx.SMPLLayer):
    """smplx ``SMPLLayer`` specialised for the SMAL(+) animal body model.

    Only the class-level size constants and the vertex joint selector's extra
    joint indices differ from the parent; everything else is inherited.
    """

    # Body-joint count for the animal skeleton (overrides the parent's value);
    # NUM_JOINTS is kept equal to it.
    NUM_BODY_JOINTS = 34
    NUM_JOINTS = NUM_BODY_JOINTS
    # Number of shape coefficients (betas) in this model's shape space.
    SHAPE_SPACE_DIM = 145

    def __init__(self, *args, **kwargs):
        super(SMALLayer, self).__init__(*args, **kwargs)
        # NOTE(review): the parent's vertex joint selector presumably carries
        # SMPL-specific vertex indices for extra joints, which would not match the
        # SMAL mesh; an empty index tensor means no extra joints are selected.
        # Confirm against smplx's VertexJointSelector.
        no_extra_joints = torch.zeros(0, dtype=torch.int32)
        self.vertex_joint_selector.extra_joints_idxs = no_extra_joints

# Paths and pretrained-model identifiers used to build the text-to-animal pipeline.
CHECKPOINT_PATH = 'awol/code/cachedir/snapshots/submission_animal_realnvp_mask_pred_net_6000.pth'
SMAL_MODEL_PATH = 'awol/data/animal/smal_plus_nochumpy.pkl'
CLIP_ARCH = 'ViT-B-32'
CLIP_PRETRAINED = 'laion2b_s34b_b79k'

# Options passed straight to ObjectNet. They presumably must match the settings the
# checkpoint above was trained with — TODO confirm against the training config.
opts = SimpleNamespace(
    model_type='flow',
    flow_type='realnvp',
    object='animal',
    train_mask=True,
    num_hidden=1024,
    num_blocks=5,
    # The model's output is fed to the SMAL layer as `betas`, so keep the sizes tied.
    animal_emb_dim=SMALLayer.SHAPE_SPACE_DIM,
    noise=True,
    add_mask_cond=False,
    no_compression=False,
)

model = ObjectNet(opts)
# Nothing in this notebook moves modules to a GPU, so deserialize onto the CPU;
# without map_location a checkpoint saved from CUDA tensors fails to load on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=True))
model.eval()

clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_ARCH, pretrained=CLIP_PRETRAINED)
# open_clip returns the model in training mode; switch to eval for inference.
clip_model.eval()
tokenizer = open_clip.get_tokenizer(CLIP_ARCH)

smal = SMALLayer(model_path=SMAL_MODEL_PATH, num_betas=SMALLayer.SHAPE_SPACE_DIM)

In [None]:
# Encode one CLIP caption per animal, map the unit-normalised text embeddings to
# SMAL shape coefficients (betas) with the flow model, and run the SMAL layer
# with those betas only (no pose arguments).
animals = [
    "Bear",
    "Cat",
    "Cow",
    "Dog",
    "Giraffe",
]

# opts.noise=True and sigma=1 suggest the flow samples random noise at inference;
# seed torch so reruns give the same shapes (assumes ObjectNet draws from torch's RNG).
SEED = 0
torch.manual_seed(SEED)

# torch.cuda.amp.autocast() is deprecated and warns when CUDA is unavailable;
# this is the equivalent modern form, enabled only when a GPU is present.
with torch.inference_mode(), torch.autocast('cuda', enabled=torch.cuda.is_available()):
    prompts = [f"A photo of a {animal}" for animal in animals]
    features = clip_model.encode_text(tokenizer(prompts))

    # Unit-normalise each text embedding.
    features = features / features.norm(dim=-1, keepdim=True)

    betas = model(features, predict=True, sigma=1)

    smal_output = smal(betas=betas)

In [None]:
# Render each generated mesh, colouring faces by their normal direction.
for index, animal in enumerate(animals):
    verts = smal_output.vertices[index].cpu().numpy()
    face_colors = get_surface_normals(verts, *smal.faces.T)

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    surface = ax.plot_trisurf(*verts.T, triangles=smal.faces)
    surface.set_facecolor(face_colors / 255)

    # Scale each axis to the mesh's extent so the shape is not distorted.
    ax.set_box_aspect(np.ptp(verts, axis=0))
    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(0.5))
    ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.1))
    ax.zaxis.set_major_locator(mpl.ticker.MultipleLocator(0.5))
    ax.set_title(animal)
    # fig.show() is meant for GUI backends and only warns under non-GUI ones such as
    # the inline backend; plt.show() renders each figure as it is produced.
    plt.show()