In [1]:
import sys
sys.path.append("/usr/local/lib/python3.8/dist-packages/")
sys.path.append("../")

from tqdm.notebook import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader

import os
import snook.data as sd
import snook.model as sm
import torch

In [2]:
# Asset locations for the procedural pool-table scene, relative to this notebook.
colors = sd.COLORS
balls = []
for color in colors:
    balls.append(f"../resources/fbx/ball_{color}.fbx")
cue = "../resources/fbx/cue.fbx"
pool = "../resources/fbx/pool.fbx"
hdri = "../resources/hdri"

# The scene is assembled from three config parts: asset files, table geometry,
# and distance settings.
# NOTE(review): the meaning/units of the cTable and cDistances arguments are
# defined in snook.data — confirm there before tuning these numbers.
scene_files = sd.cFiles(balls, cue, pool, hdri)
scene_table = sd.cTable((2.07793, 1.03677), (0.25, 0.20), 1.70342)
scene_distances = sd.cDistances(0.1, 1.5, (10.0, 20.0))
scene = sd.Scene(scene_files, scene_table, scene_distances)

# Output folders: images go to renders/, per-image files from scene.register go to data/.
for directory in ("renders", "data"):
    os.makedirs(directory, exist_ok=True)

# For each index: draw a new random layout, render it, and write its companion
# data file (presumably the ball annotations — see snook.data.Scene.register).
for idx in tqdm(range(16), desc="Generating"):
    scene.sample()
    scene.render(f"renders/{idx}.png")
    scene.register(f"data/{idx}.txt")

FBX Import: start importing ../resources/fbx/ball_black.fbx
FBX version: 7400
	FBX import: Prepare...
		Done (0.000208 sec)

	FBX import: Templates...
		Done (0.000086 sec)

	FBX import: Nodes...
		Done (0.000071 sec)

	FBX import: Connections...
		Done (0.000071 sec)

	FBX import: Meshes...
		Done (0.014240 sec)

	FBX import: Materials & Textures...
		Done (0.007236 sec)

	FBX import: Cameras & Lamps...
		Done (0.000289 sec)

	FBX import: Objects & Armatures...
		Done (0.000748 sec)

	FBX import: ShapeKeys...
		Done (0.000294 sec)

	FBX import: Animations...
		Done (0.000523 sec)

	FBX import: Assign materials...
		Done (0.000098 sec)

	FBX import: Assign textures...
		Done (0.005284 sec)

	FBX import: Cycles z-offset workaround...
		Done (0.000940 sec)

	Done (0.034459 sec)

Import finished.
FBX Import: start importing ../resources/fbx/ball_black.fbx
FBX version: 7400
	FBX import: Prepare...
		Done (0.001244 sec)

	FBX import: Templates...
		Done (0.000528 sec)

	FBX import: Nodes...

		Done (0.015292 sec)

	FBX import: Materials & Textures...
		Done (0.000611 sec)

	FBX import: Cameras & Lamps...
		Done (0.000051 sec)

	FBX import: Objects & Armatures...
		Done (0.000413 sec)

	FBX import: ShapeKeys...
		Done (0.000100 sec)

	FBX import: Animations...
		Done (0.000043 sec)

	FBX import: Assign materials...
		Done (0.000132 sec)

	FBX import: Assign textures...
		Done (0.000108 sec)

	FBX import: Cycles z-offset workaround...
		Done (0.000065 sec)

	Done (0.022199 sec)

Import finished.
FBX Import: start importing ../resources/fbx/ball_yellow.fbx
FBX version: 7400
	FBX import: Prepare...
		Done (0.000435 sec)

	FBX import: Templates...
		Done (0.000083 sec)

	FBX import: Nodes...
		Done (0.000069 sec)

	FBX import: Connections...
		Done (0.000066 sec)

	FBX import: Meshes...
		Done (0.014823 sec)

	FBX import: Materials & Textures...
		Done (0.000495 sec)

	FBX import: Cameras & Lamps...
		Done (0.000562 sec)

	FBX import: Objects & Armatures...
		Done (0.000331 sec

FBX version: 7400
	FBX import: Prepare...
		Done (0.000758 sec)

	FBX import: Templates...
		Done (0.000135 sec)

	FBX import: Nodes...
		Done (0.000074 sec)

	FBX import: Connections...
		Done (0.000075 sec)

	FBX import: Meshes...
		Done (9.906386 sec)

	FBX import: Materials & Textures...
		Done (0.002939 sec)

	FBX import: Cameras & Lamps...
		Done (0.000223 sec)

	FBX import: Objects & Armatures...
		Done (0.000295 sec)

	FBX import: ShapeKeys...
		Done (0.000081 sec)

	FBX import: Animations...
		Done (0.000077 sec)

	FBX import: Assign materials...
		Done (0.000771 sec)

	FBX import: Assign textures...
		Done (0.001617 sec)

	FBX import: Cycles z-offset workaround...
		Done (0.000213 sec)

	Done (10.435700 sec)

Import finished.


HBox(children=(FloatProgress(value=0.0, description='Generating', max=16.0, style=ProgressStyle(description_wi…




In [3]:
# Dataset over the generated renders/ and data/ folders; each item unpacks as
# (render, mask, heatmap) in the training loop.
# NOTE(review): `spread` presumably sets the heatmap peak width — confirm in snook.data.
dataset = sd.ReMaHeDataset("renders", "data", spread=4.0)
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Encoder stages, one sm.Layer per spec tuple.
# TODO(review): confirm the tuple fields (likely in_channels, out_channels, depth) against sm.Layer.
layer_specs = ((16, 24, 1), (24, 32, 6), (32, 64, 6))
layers = [sm.Layer(*spec) for spec in layer_specs]

model = sm.AutoEncoder(layers, 3, 1, scale=0.4).cuda()
loss = sm.AdaptiveWingLoss().cuda()
optim = AdamW(model.parameters())

In [4]:
# Train the autoencoder to regress the target heatmaps.
# Per batch: clear grads -> forward -> loss -> backward -> optimizer step.
model.train()  # make sure training-mode behavior is active for any dropout/norm layers
for epoch in tqdm(range(10), desc="Epoch"):
    # leave=False: drop each finished per-epoch bar instead of stacking 10 widgets
    pbar = tqdm(loader, "Training", leave=False)
    for render, mask, heatmap in pbar:
        # TODO(review): `mask` is loaded by the dataset but not used by the loss yet,
        # so it is not copied to the GPU. Move it with .cuda() once it is used.
        # non_blocking is safe/effective here because the loader uses pin_memory=True.
        render = render.cuda(non_blocking=True)
        heatmap = heatmap.cuda(non_blocking=True)

        optim.zero_grad()
        # squeeze(1) drops the single output channel so the prediction matches
        # `heatmap`'s shape (assumed (B, H, W) — confirm against ReMaHeDataset).
        error = loss(model(render).squeeze(1), heatmap)
        error.backward()
        optim.step()  # apply the gradients — without this the weights never change

        pbar.set_postfix(loss=error.item())

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=10.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…



