In [1]:
import os
import random

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from gqn_dataset import GQNDataset, Scene, transform_viewpoint, sample_batch, GQNDataset_pdisco
from model import GQN

In [2]:
# Path to the split *file* (despite the `_dir` suffix); the entries listed in it are
# resolved relative to its parent directory. Set GQN_SPLIT_FILE to point elsewhere
# instead of editing this hard-coded, user-specific default.
train_data_dir = os.environ.get(
    "GQN_SPLIT_FILE", "/home/mprabhud/dataset/shapenet_renders/npys/split_allpt.txt"
)

# Seed the RNGs so Restart & Run All reproduces the outputs: torch drives DataLoader
# shuffling and model initialisation; Python `random` is presumably what sample_batch
# uses for context/query selection (confirm in gqn_dataset).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Dataset over the entries of the split file.
# NOTE(review): `root_dir` actually receives a file path here; transform_viewpoint
# presumably maps raw poses to the 7-D viewpoint vectors shown below — confirm in gqn_dataset.
train_dataset = GQNDataset_pdisco(
    target_transform=transform_viewpoint,
    root_dir=train_data_dir,
)

In [4]:
# Peek at one sample: element 1 of the first dataset item (presumably the viewpoints,
# matching the loader's (x_data, v_data, metadata) unpacking), view index 1.
train_dataset[0][1][1]

tensor([ 5.3157, -2.7362, -5.3157,  0.9397, -0.3420,  1.0000,  0.0000])

In [5]:
# Read the first whitespace-separated token of every line in the split file.
# Blank / whitespace-only lines are skipped (the old `line.split()[0]` raised IndexError on them),
# and the file is iterated directly instead of materialising readlines().
with open(train_data_dir, encoding="utf-8") as split_file:
    data = [line.split()[0] for line in split_file if line.strip()]

# Keep only the ".p" entries, resolved relative to the split file's directory.
split_root = os.path.dirname(train_data_dir)
all_files = [os.path.join(split_root, entry) for entry in data if entry.endswith(".p")]

In [6]:
# Shuffled mini-batches of 36 dataset items each.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=36,
)

In [7]:
# Pull a single batch for inspection. next(iter(...)) avoids a throwaway loop variable.
# If the loader is empty it raises StopIteration right here, rather than leaving
# x_data undefined for later cells.
x_data, v_data, metadata = next(iter(train_loader))

In [8]:
# Batch shapes: images (B, views, 64, 64, 3) channel-last, viewpoints (B, views, 7), per the output below.
(x_data.size(), v_data.size(), metadata)

(torch.Size([36, 24, 64, 64, 3]), torch.Size([36, 24, 7]), {})

In [9]:
# Split the batch into context views (x, v), one query view (x_q, v_q) and the sampled indices.
# NOTE(review): the "Shepard-Metzler" key is passed although these are ShapeNet renders —
# confirm which sample_batch behaviour this key selects.
x, v, x_q, v_q, context_idx, query_idx = sample_batch(
    x_data, v_data, "Shepard-Metzler"
)
# Channel-last -> channel-first for the Conv2d(3, ...) layers:
# context (B, M, H, W, C) -> (B, M, C, H, W); query presumably (B, H, W, C) -> (B, C, H, W).
x, x_q = x.permute(0, 1, 4, 2, 3), x_q.permute(0, 3, 1, 2)

In [10]:
# GQN with a Tower representation network (see the printed module tree below).
# NOTE(review): L is presumably the number of generation steps and shared_core=True shares
# one core across them — confirm in model.GQN.
model = GQN(
    L=12,
    shared_core=True,
    representation="tower",
)

In [11]:
# Display the module tree: phi (Tower representation network), the inference core, etc.
model

GQN(
  (phi): Tower(
    (conv1): Conv2d(3, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv2): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv3): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv5): Conv2d(263, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(263, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv8): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (inference_core): InferenceCore(
    (downsample_x): Conv2d(3, 3, kernel_size=(4, 4), stride=(4, 4), bias=False)
    (upsample_v): ConvTranspose2d(7, 7, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (upsample_r): ConvTranspose2d(256, 256, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (downsample_u): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)
    (cor

In [12]:
# Sigma passed to the model's forward call. sigma_f is presumably the end value of an
# annealing schedule starting at sigma_i; it is not used in the visible cells.
sigma_i = 2.0
sigma_f = 0.7
sigma = sigma_i


In [13]:
# One ELBO value per batch element (shape (36,) below); the result carries grad_fn, so an
# autograd graph is kept. NOTE(review): argument order is (x, v, v_q, x_q, sigma) — query
# viewpoint before query image; confirm against GQN.forward.
elbo = model(x, v, v_q, x_q, sigma)

In [14]:
# Expect one ELBO entry per batch element.
elbo.size()

torch.Size([36])

In [30]:
# Context images after the channel-first permute: (B, M, C, H, W).
x.size()

torch.Size([36, 1, 3, 64, 64])

In [15]:
# View indices chosen by sample_batch for the context set and the query.
(context_idx, query_idx)

([20], 13)

In [16]:
# Full per-sample ELBO vector (36 values); elbo.mean() or elbo[:5] gives a more compact view.
elbo

tensor([-19841.6699, -19875.7949, -19853.3613, -19862.2988, -19834.0488,
        -19847.9004, -19877.7051, -19850.6250, -19833.5723, -19911.8301,
        -19848.1777, -19823.2070, -19874.4316, -19873.2949, -19850.0684,
        -19872.8223, -19871.9336, -19886.7441, -19923.9219, -19851.8008,
        -19902.4551, -19841.6816, -19900.2539, -19909.7773, -19835.6074,
        -19911.2500, -19886.7500, -19868.8730, -19925.2168, -19908.2480,
        -19852.7969, -19885.3555, -19854.0703, -19889.6113, -19867.1113,
        -19823.1289], grad_fn=<AddBackward0>)