In [None]:
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset

# Put the notebook's working directory first on sys.path so the local
# retrieval/ package can be imported by later cells.
_notebook_dir = Path('.').resolve()
sys.path.insert(0, str(_notebook_dir))

In [None]:
# ── Config ──────────────────────────────────────────────────────────────────
# Either set a prompt (retrieval will resolve the index) or set DATASET_IDX directly.
# DATASET_IDX takes precedence: when it is not None, PROMPT is not used for retrieval.
# Setting both to None is an error (the retrieval cell has nothing to resolve).

PROMPT      = "a stone castle"   # text query; only used when DATASET_IDX is None
DATASET_IDX = None               # override with e.g. 20041 to skip retrieval

In [None]:
# ── Retrieval (skip if DATASET_IDX is set) ───────────────────────────────────
# Resolves DATASET_IDX from PROMPT via the local retrieval index.
# NOTE: this cell overwrites DATASET_IDX, so re-running it without first
# re-running the config cell reuses the previously retrieved index.
if DATASET_IDX is None:
    if PROMPT is None:
        # Without this guard retrieve(None) would be called with no query.
        raise ValueError('Set either PROMPT or DATASET_IDX in the config cell.')
    from retrieval.retrieve import load_index, retrieve
    load_index()
    DATASET_IDX = retrieve(PROMPT)
    print(f'Nearest neighbour for "{PROMPT}" → dataset index {DATASET_IDX}')
else:
    print(f'Using dataset index {DATASET_IDX} directly')

In [None]:
# ── Load sample ──────────────────────────────────────────────────────────────
print('Fetching sample from dataset (streaming)...')
ds = load_dataset('PeterAM4/blockgen-3d', split='train', streaming=True)
# next(..., None) so an out-of-range index fails with a clear message instead
# of a bare StopIteration.
sample = next(iter(ds.skip(DATASET_IDX).take(1)), None)
if sample is None:
    raise IndexError(f'Dataset index {DATASET_IDX} yielded no sample '
                     f'(past the end of the train split?)')

colors = np.array(sample['voxels_colors'],    dtype=np.float32)  # [3,32,32,32]
occ    = np.array(sample['voxels_occupancy'], dtype=np.float32)  # [1,32,32,32]
# Later cells index colour channels and occupancy channel 0 on the same grid.
assert colors.shape[1:] == occ.shape[1:], (colors.shape, occ.shape)

n_occupied = int((occ[0] > 0.5).sum())
print(f'Occupied voxels: {n_occupied}')
print(f'Prompt label in dataset: {sample.get("prompt", "(no prompt field)")}')

In [None]:
# ── Centre voxels ────────────────────────────────────────────────────────────
def center_voxels(colors, occ):
    occupied = occ[0] > 0.5
    coords = np.argwhere(occupied)
    if len(coords) == 0:
        return colors, occ
    centroid = coords.mean(axis=0).astype(int)
    shift = np.array([16, 16, 16]) - centroid
    for ax, s in enumerate(shift):
        colors = np.roll(colors, int(s), axis=ax + 1)
        occ    = np.roll(occ,    int(s), axis=ax + 1)
    return colors, occ

colors_c, occ_c = center_voxels(colors.copy(), occ.copy())

In [None]:
# ── Plot raw vs centred ───────────────────────────────────────────────────────
%matplotlib inline

fig = plt.figure(figsize=(14, 6))
fig.suptitle(f'Dataset index {DATASET_IDX}  |  prompt: "{PROMPT}"', fontsize=13)

for col, (c, o, label) in enumerate([
    (colors,   occ,   'raw'),
    (colors_c, occ_c, 'centred (what gets built)'),
]):
    mask = o[0] > 0.5
    xs, ys, zs = np.where(mask)
    rgb = np.clip(np.stack([c[0, xs, ys, zs],
                            c[1, xs, ys, zs],
                            c[2, xs, ys, zs]], axis=1), 0, 1)

    ax = fig.add_subplot(1, 2, col + 1, projection='3d')
    ax.scatter(xs, zs, ys, c=rgb, s=20, depthshade=True)
    ax.set_xlabel('X'); ax.set_ylabel('Z'); ax.set_zlabel('Y (height)')
    ax.set_title(f'{label}  ({len(xs)} voxels)')
    ax.set_xlim(0, 31); ax.set_ylim(0, 31); ax.set_zlim(0, 31)

plt.tight_layout()
out_path = f'voxel_{DATASET_IDX}.png'
plt.savefig(out_path, dpi=150)
print(f'Saved → {out_path}')
plt.show()

In [None]:
# ── Colour histogram: which colours make up this model? ──────────────────────
mask = occ_c[0] > 0.5
xs, ys, zs = np.where(mask)
# (N, 3) RGB rows of the occupied voxels, clipped to [0, 1].
rgb = np.clip(colors_c[:3, xs, ys, zs].T, 0, 1)

fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for ax, channel, name, color in zip(axes, rgb.T, ['R', 'G', 'B'], ['red', 'green', 'blue']):
    ax.hist(channel, bins=32, color=color, alpha=0.7)
    ax.set_title(f'{name} channel')
    ax.set_xlim(0, 1)
    ax.set_xlabel('value')
    ax.set_ylabel('voxel count')
fig.suptitle('Voxel colour distribution (centred)')
fig.tight_layout()
plt.show()

# Mean over zero voxels would be nan (with a RuntimeWarning), so guard it.
if len(rgb) == 0:
    print('No occupied voxels - mean colour undefined')
else:
    mean_rgb = rgb.mean(axis=0)
    print(f'Mean colour (R,G,B): {mean_rgb.round(3)}')

In [None]:
# ── Quick sanity: compare a few different prompts ────────────────────────────
# Imported here (not only in the retrieval cell) because that cell skips the
# import when DATASET_IDX is set directly.
from retrieval.retrieve import load_index, retrieve
load_index()

test_prompts = [
    'a stone castle',
    'a wooden cabin',
    'an oak tree',
    'a stone tower',
    'a small house',
]

# Open the stream once and derive a fresh skip/take view per prompt, rather
# than calling load_dataset on every iteration.
sanity_stream = load_dataset('PeterAM4/blockgen-3d', split='train', streaming=True)

for query in test_prompts:
    idx = retrieve(query)
    row = next(iter(sanity_stream.skip(idx).take(1)), None)
    if row is None:
        print(f'{query:<25} → idx {idx:>6}  |  no sample at this index')
        continue
    query_colors = np.array(row['voxels_colors'], dtype=np.float32)
    query_occ = np.array(row['voxels_occupancy'], dtype=np.float32)
    xs, ys, zs = np.where(query_occ[0] > 0.5)
    if len(xs) == 0:
        # Avoid a nan mean (and RuntimeWarning) for an empty model.
        print(f'{query:<25} → idx {idx:>6}  |  no occupied voxels')
        continue
    # Separate name so the histogram cell's mean_rgb is not clobbered.
    query_mean_rgb = query_colors[:3, xs, ys, zs].T.mean(axis=0)
    print(f'{query:<25} → idx {idx:>6}  |  mean RGB {query_mean_rgb.round(3)}  |  {len(xs)} voxels')