In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: " + device)

Device: cuda


# Load Dataset as Chunks

In [2]:
class ChunkedDepthDataset(Dataset):
    def __init__(self, root_dir):
        self.files = sorted([
            os.path.join(root_dir, f)
            for f in os.listdir(root_dir)
            if f.endswith(".pt")
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return torch.load(self.files[idx])

In [3]:
DATA_ROOT = "D:/diffdepth_data/processed/train"
dataset = ChunkedDepthDataset(DATA_ROOT)
print("Number of chunks: ", len(dataset))

Number of chunks:  507


In [4]:
def chunk_collate(batch):
    samples = []
    for chunk in batch:
        samples.extend(chunk)
    return samples

# DataLoader

In [5]:
loader = DataLoader(
    dataset,
    batch_size = 1,
    shuffle = True,
    collate_fn = chunk_collate
)

batch = next(iter(loader))
print(len(batch))
print(batch[0]["rgb"].shape, batch[0]["depth"].shape)

  return torch.load(self.files[idx])


100
torch.Size([3, 256, 256]) torch.Size([1, 256, 256])


# Noise Scheduling

In [6]:
import torch

T = 1000

def linear_beta_schedule(timestep):
    beta_start = 1e-4
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timestep)

betas = linear_beta_schedule(T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

In [7]:
betas = betas.to(device)
alphas = alphas.to(device)
alphas_cumprod = alphas_cumprod.to(device)

print(betas.shape)
print(alphas_cumprod[0], alphas_cumprod[-1])

torch.Size([1000])
tensor(0.9999, device='cuda:0') tensor(4.0358e-05, device='cuda:0')


# Forward Diffusion

In [8]:
def get_index_from_list(vals, t, x_shape):
    batch_size = t.shape[0]
    out = vals.gather(-1,t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

In [9]:
def forward_diffusion_sample(x0, t, device=device):
    noise = torch.randn_like(x0)
    sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[t]).view(-1,1,1,1)
    sqrt_one_minus = torch.sqrt(1 - alphas_cumprod[t]).view(-1,1,1,1)
    return sqrt_alpha_cumprod * x0 + sqrt_one_minus * noise, noise

In [10]:
def preprocess_batch(batch, device=device):
    rgbs = torch.stack([s["rgb"] for s in batch]).to(device)
    depths = torch.stack([s["depth"] for s in batch]).to(device)
    return rgbs, depths

In [11]:
batch = next(iter(loader))
rgb, depth = preprocess_batch(batch)

t = torch.randint(0,T,(depth.shape[0],), device=device)
noisy_depth, noise = forward_diffusion_sample(depth.to(device), t)

print(noisy_depth.shape, noise.shape)

  return torch.load(self.files[idx])


torch.Size([100, 1, 256, 256]) torch.Size([100, 1, 256, 256])
