In [1]:
import fvdb
from fvdb.nn import VDBTensor
import torch
from UNet import SparseUNet
import tqdm.notebook as tqdm

# Switch off the `allow_tf32` flag on fvdb's SparseConv3d class (class-level, so it
# applies to every instance created afterwards).
# NOTE(review): presumably this keeps sparse convolutions in full FP32 precision
# instead of TF32 — confirm against the fVDB documentation.
fvdb.nn.SparseConv3d.allow_tf32 = False

In [2]:
from torch.utils.data import Dataset
import os

class ChunkDataset(Dataset):
    """Dataset that serves one saved fVDB chunk file per index.

    Every entry found in ``chunksPath`` is treated as a file readable by
    ``fvdb.load``; no filtering by extension or file type is performed.
    """

    def __init__(self, chunksPath):
        """Collect the paths of all entries inside ``chunksPath``.

        Args:
            chunksPath: Directory containing the serialized chunk files.

        Raises:
            OSError: Propagated from ``os.listdir`` when the directory is
                missing or unreadable.
        """
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.paths = [
            os.path.join(chunksPath, filename)
            for filename in sorted(os.listdir(chunksPath))
        ]

    def __len__(self):
        """Return the number of chunk files discovered at construction."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load chunk ``idx`` and return it as a ``VDBTensor`` with int64 labels.

        Args:
            idx: Position of the chunk in ``self.paths``.

        Returns:
            A ``VDBTensor`` pairing the loaded grid batch with its label data
            converted to ``torch.long``.
        """
        # The third value returned by fvdb.load is not needed here.
        grid_batch, labels, _names = fvdb.load(self.paths[idx])
        # `.to()` returns a converted object instead of modifying in place, so
        # the result has to be rebound (the original call discarded it).
        labels = labels.to(torch.long)

        return VDBTensor(grid_batch, labels)

In [3]:
# Each entry in the chunk directory becomes one dataset item.
dataset = ChunkDataset("data/training_data/chunks")

# fvdb.jcat is used as the collate function, so the items sampled for a batch
# are joined by fVDB rather than by PyTorch's default collation.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=6,
    prefetch_factor=24,
    collate_fn=fvdb.jcat,
)

In [4]:
# Display the length of the DataLoader, i.e. how many batches one pass over it yields.
len(dataloader)

6979

In [5]:
# The number of classes equals the number of lines in the block list file.
# NOTE(review): assumes one block type per line with no blank lines — confirm.
num_classes = 0
with open("minecraft-serialization/block_list.txt", 'r') as block_list:
    for _ in block_list:
        num_classes += 1

# Build the network, then move its parameters to the GPU.
model = SparseUNet(num_classes)
model = model.to('cuda')

# Cross-entropy criterion and Adam optimizer over all model parameters.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [6]:
# Training loop: the progress bar advances once per epoch, while its
# description is refreshed after every batch with the latest loss value.
epochs = 5
with tqdm.tqdm(total=epochs) as pbar:
    for epoch in range(epochs):
        for i, vdb_tensor in enumerate(dataloader):
            vdb_tensor = vdb_tensor.cuda()

            # Targets for the loss: the data stored on the batch, squeezed
            # and cast to int64.
            labels = vdb_tensor.data.jdata
            target = labels.squeeze().to(torch.long)

            # Network input: a constant 1.0 feature for each row of `target`,
            # wrapped so that it follows the layout of the batch's grid.
            ones = torch.ones((target.shape[0], 1), device='cuda', dtype=torch.float32)
            actives = vdb_tensor.grid.jagged_like(ones)
            X = VDBTensor(grid=vdb_tensor.grid, data=actives)

            # Standard optimisation step: reset gradients, forward pass,
            # loss, backward pass, parameter update.
            optimizer.zero_grad()
            y_hat = model(X)
            logits = y_hat.data.jdata
            l = loss(logits, target)
            l.backward()
            optimizer.step()

            pbar.set_description(f"Epoch {epoch}, batch {i}, loss {l.item()}")
        pbar.update(1)

  0%|          | 0/5 [00:00<?, ?it/s]

KeyboardInterrupt: 