In [27]:
# %pip install pillow torch torchvision pandas numpy matplotlib

In [28]:
# from google.colab import drive
import numpy as np
from PIL import Image
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell: displays whether CUDA was detected.
torch.cuda.is_available()

False

In [29]:
# Size handed to transforms.Resize for every RGB image.
IMG_SIZE = 64
# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 32
# The last VALID_BATCHES * BATCH_SIZE samples are held out for validation.
VALID_BATCHES = 10
N = 9999   # exclusive upper bound on the sample indices that get loaded
# Output width requested from every model built in this notebook.
NUM_POSITIONS = 9

# Root directory of the dataset; adjust to the local checkout.
DATA_DIR = "../../../NVIDIA Multimodal Models/code/data/replicator_data_cubes/"
# Google Colab alternative (kept for reference):
# drive.mount('/content/drive')
# DATA_DIR = "/content/drive/MyDrive/data/replicator_data_cubes/"

# Sanity check: list what is inside the dataset directory.
print(list(os.listdir(DATA_DIR)))
# NOTE(review): `device` is already assigned in the imports cell; this repeats it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mean-squared-error criterion read as a global by train_model.
loss_func = nn.MSELoss()

['azimuth.npy', 'positions.csv', 'lidar', 'colors.csv', 'rgb', 'pointcloud', 'distance', 'zenith.npy']


In [30]:
# Preprocessing applied to every RGB image when a ReplicatorDataset is built.
# NOTE(review): with an int size, Resize is documented to match only the shorter
# edge - presumably the source images are square; confirm.
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),  # Scales data into [0,1]
])

def get_torch_xyza(lidar_depth, azimuth, zenith, max_depth=50.0):
    """Convert a spherical LiDAR depth map into stacked x, y, z and validity channels.

    Args:
        lidar_depth: Depth tensor that broadcasts against an
            ``(len(azimuth), len(zenith))`` grid.
        azimuth: Azimuth angles; indexed as ``azimuth[:, None]`` so they vary
            along the first grid axis (assumes a 1-D tensor - TODO confirm).
        zenith: Zenith angles; indexed as ``zenith[None, :]`` so they vary
            along the second grid axis (assumes a 1-D tensor - TODO confirm).
        max_depth: Readings strictly below this value are flagged as valid.
            Defaults to 50.0, the threshold that was previously hard-coded.

    Returns:
        Tensor obtained by stacking ``(x, y, z, a)`` along a new leading
        dimension, where ``a`` is 1 where ``lidar_depth < max_depth`` and 0
        elsewhere.
    """
    # Both angles are negated before use; compute each negation once.
    neg_azimuth = -azimuth[:, None]
    neg_zenith = -zenith[None, :]
    cos_zenith = torch.cos(neg_zenith)

    x = lidar_depth * torch.sin(neg_azimuth) * cos_zenith
    y = lidar_depth * torch.cos(neg_azimuth) * cos_zenith
    z = lidar_depth * torch.sin(neg_zenith)
    # Validity mask kept in the depth's dtype so it can be stacked with x/y/z.
    a = torch.where(
        lidar_depth < max_depth,
        torch.ones_like(lidar_depth),
        torch.zeros_like(lidar_depth),
    )
    return torch.stack((x, y, z, a))

class ReplicatorDataset(Dataset):
    """In-memory dataset of RGB images, LiDAR depth maps and target positions.

    Every sample with index in ``[start_idx, stop_idx)`` is read from
    ``root_dir`` and moved to the module-level ``device`` at construction
    time, so ``__getitem__`` only indexes and runs the LiDAR conversion.

    Args:
        root_dir: Directory holding ``positions.csv``, ``azimuth.npy``,
            ``zenith.npy`` and the ``rgb/`` and ``lidar/`` sub-directories.
            A trailing path separator is optional.
        start_idx: Index of the first sample to load (inclusive).
        stop_idx: Index one past the last sample to load (exclusive).
    """

    def __init__(self, root_dir, start_idx, stop_idx):
        self.root_dir = root_dir
        self.rgb_imgs = []
        self.lidar_depths = []
        # Rows are sliced with the same bounds as the files, so positions[i]
        # pairs with rgb_imgs[i] and lidar_depths[i].
        self.positions = np.genfromtxt(
            os.path.join(root_dir, "positions.csv"), delimiter=",", skip_header=1
        )[start_idx:stop_idx]

        # The angle grids are shared by every LiDAR scan.
        azimuth = np.load(os.path.join(root_dir, "azimuth.npy"))
        zenith = np.load(os.path.join(root_dir, "zenith.npy"))
        self.azimuth = torch.from_numpy(azimuth).to(device)
        self.zenith = torch.from_numpy(zenith).to(device)

        for idx in range(start_idx, stop_idx):
            # Files are named by zero-padded index, e.g. "0042.png" / "0042.npy".
            file_number = "{:04d}".format(idx)

            # Transform inside the context manager so the file handle is
            # released as soon as the pixels have been copied into a tensor.
            with Image.open(os.path.join(root_dir, "rgb", file_number + ".png")) as rgb_img:
                rgb_tensor = img_transforms(rgb_img).to(device)
            self.rgb_imgs.append(rgb_tensor)

            lidar_depth = np.load(os.path.join(root_dir, "lidar", file_number + ".npy"))
            lidar_depth = torch.from_numpy(lidar_depth).to(torch.float32).to(device)
            self.lidar_depths.append(lidar_depth)

    def __len__(self):
        """Return the number of position rows loaded for this split."""
        return len(self.positions)

    def __getitem__(self, idx):
        """Return ``(rgb_img, lidar_xyza, position)`` for sample ``idx``."""
        rgb_img = self.rgb_imgs[idx]
        lidar_depth = self.lidar_depths[idx]
        # The x/y/z/validity channels are derived from the raw depth on access.
        lidar_xyza = get_torch_xyza(lidar_depth, self.azimuth, self.zenith)

        position = self.positions[idx]
        position = torch.from_numpy(position).to(torch.float32).to(device)

        return rgb_img, lidar_xyza, position
    

In [31]:
# Training split: every sample index before the held-out tail.
train_data = ReplicatorDataset(DATA_DIR, 0, N-VALID_BATCHES*BATCH_SIZE)
# drop_last=True discards a final partial batch so every batch has BATCH_SIZE samples.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Validation split: the last VALID_BATCHES * BATCH_SIZE sample indices below N.
valid_data = ReplicatorDataset(DATA_DIR, N-VALID_BATCHES*BATCH_SIZE, N)
# Validation order is fixed (shuffle=False).
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [32]:
class BaseNet(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a three-layer fully connected head.

    The first linear layer is sized for ``200 * 8 * 8`` features, i.e. a
    200-channel 8x8 map after the three 2x poolings.

    Args:
        in_ch: Number of channels in the input tensor.
        out_ch: Number of values produced per sample.
        kernel_size: Kernel size shared by the three convolutions (padding is
            fixed at 1).
    """

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        # Convolutional feature extractor; channel widths 50 -> 100 -> 200.
        self.conv1 = nn.Conv2d(in_ch, 50, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(50, 100, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(100, 200, kernel_size, padding=1)
        # A single pooling module is reused after each convolution.
        self.pool = nn.MaxPool2d(2)
        # Regression head.
        self.fc1 = nn.Linear(200 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, out_ch)

    def forward(self, x):
        """Map an input batch to a ``(batch, out_ch)`` tensor."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Keep the batch dimension, flatten everything else.
        features = x.flatten(1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        return self.fc3(hidden)

class EarlyNet(nn.Module):
    """Early-fusion model: inputs are concatenated on dim 1 and fed to one BaseNet.

    Args:
        in_chs: Channel count of each input; their sum is the BaseNet input width.
        out_ch: Number of values produced per sample.
    """

    def __init__(self, in_chs, out_ch):
        super().__init__()
        total_in_ch = sum(in_chs)
        self.base_net = BaseNet(total_in_ch, out_ch)

    def forward(self, inputs):
        """Concatenate the sequence of input tensors along dim 1 and run the base network."""
        fused = torch.cat(inputs, dim=1)
        return self.base_net(fused)
    
class LateNet(nn.Module):
    """Late-fusion model: one network per input, merged by a small fully connected head.

    The sub-networks are stored in an ``nn.ModuleList`` so that they are
    registered as children of this module. ``parameters()``, ``state_dict()``
    and ``.to(device)`` therefore include them, and an optimizer built from
    ``late_net.parameters()`` trains them.

    Args:
        networks: Iterable of modules, one per input, applied in order.
        out_chs: Output width of each sub-network; their sum is the width of
            the concatenated feature vector.
        out_ch: Width of the final output. Defaults to ``sum(out_chs)``.
    """

    def __init__(self, networks, out_chs, out_ch=None):
        super().__init__()
        self.networks = nn.ModuleList(networks)
        sum_out_chs = sum(out_chs)
        if out_ch is None:
            out_ch = sum_out_chs
        self.fc1 = nn.Linear(sum_out_chs, sum_out_chs * 10)
        self.fc2 = nn.Linear(sum_out_chs * 10, out_ch)

    def forward(self, inputs):
        """Run each input through its network, concatenate on dim 1, apply the head."""
        branch_outputs = [net(inp) for net, inp in zip(self.networks, inputs)]
        x = torch.cat(branch_outputs, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class CatNet(nn.Module):
    """Intermediate fusion by concatenation of per-input convolutional features.

    Each input gets its own three-stage conv/ReLU/max-pool branch ending in
    100 channels. The branch outputs are concatenated on dim 1, flattened and
    passed through a three-layer fully connected head.

    The branches live in an ``nn.ModuleList`` so they are registered as
    sub-modules: ``.to(device)``, ``parameters()`` and ``state_dict()`` all
    include them.

    Args:
        in_chs: Channel count of each input; one branch is built per entry.
        out_ch: Number of values produced per sample.
        kernel_size: Kernel size of every convolution (padding is fixed at 1).
    """

    def __init__(self, in_chs, out_ch, kernel_size=3):
        super().__init__()
        in_chs = list(in_chs)
        self.networks = nn.ModuleList(
            [self._make_branch(in_ch, kernel_size) for in_ch in in_chs]
        )

        # Each branch contributes 100 channels of 8x8 features
        # (200 * 8 * 8 in total for the usual two inputs).
        self.fc1 = nn.Linear(100 * len(in_chs) * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, out_ch)

    @staticmethod
    def _make_branch(in_ch, kernel_size):
        """Build one conv/ReLU/pool feature extractor (in_ch -> 25 -> 50 -> 100)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, 25, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(25, 50, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(50, 100, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, inputs):
        """Encode each input with its branch, concatenate, and regress."""
        branch_features = [net(inp) for net, inp in zip(self.networks, inputs)]

        x = torch.cat(branch_features, 1)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class MatMulNet(nn.Module):
    """Intermediate fusion by matrix multiplication of per-input convolutional features.

    Each input gets its own three-stage conv/ReLU/max-pool branch ending in
    100 channels. The two branch outputs are combined with ``torch.matmul``,
    which multiplies the trailing two (spatial) dimensions and keeps the
    channel count at 100. The flattened result therefore has ``100 * 8 * 8``
    features, not ``200 * 8 * 8``.

    The branches live in an ``nn.ModuleList`` so they are registered as
    sub-modules: ``.to(device)``, ``parameters()`` and ``state_dict()`` all
    include them.

    Args:
        in_chs: Channel count of each input; ``forward`` unpacks the branch
            outputs into ``torch.matmul``, so exactly two entries are expected.
        out_ch: Number of values produced per sample.
        kernel_size: Kernel size of every convolution (padding is fixed at 1).
    """

    def __init__(self, in_chs, out_ch, kernel_size=3):
        super().__init__()
        self.networks = nn.ModuleList(
            [self._make_branch(in_ch, kernel_size) for in_ch in in_chs]
        )

        # matmul of two (B, 100, 8, 8) maps is again (B, 100, 8, 8).
        self.fc1 = nn.Linear(100 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, out_ch)

    @staticmethod
    def _make_branch(in_ch, kernel_size):
        """Build one conv/ReLU/pool feature extractor (in_ch -> 25 -> 50 -> 100)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, 25, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(25, 50, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(50, 100, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, inputs):
        """Encode each input with its branch, matrix-multiply the maps, and regress."""
        branch_features = [net(inp) for net, inp in zip(self.networks, inputs)]

        x = torch.matmul(*branch_features)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [33]:
def format_positions(positions):
    """Render each number with three decimals and a leading space for non-negatives."""
    as_fixed = '{0: .3f}'.format
    return list(map(as_fixed, positions))

def get_outputs(model, batch, inputs_idx):
    """Run ``model`` on one element of ``batch`` and return ``(outputs, target)``.

    Args:
        model: Callable applied to the selected input.
        batch: Indexable batch; its last element is used as the target.
        inputs_idx: Index of the element of ``batch`` fed to ``model``.
    """
    # Move the selected input first, then the target, onto the global device.
    model_input, target = (batch[i].to(device) for i in (inputs_idx, -1))
    return model(model_input), target

def print_loss(epoch, loss, outputs, target, is_train=True, is_debug=False):
    """Print the epoch loss and, when ``is_debug`` is set, one prediction/target pair.

    Args:
        epoch: Epoch identifier included in the printed line.
        loss: Loss value to report.
        outputs: Model outputs; only ``outputs[0]`` is read, and only in debug mode.
        target: Targets; only ``target[0]`` is read, and only in debug mode.
        is_train: Selects the "train loss:" label, otherwise "valid loss:".
        is_debug: Also print the first prediction and the first target.
    """
    if is_train:
        label = "train loss:"
    else:
        label = "valid loss:"
    print("epoch", str(epoch), label, str(loss))

    if not is_debug:
        return
    first_pred = outputs[0].tolist()
    print("example pred:", format_positions(first_pred))
    first_real = target[0].tolist()
    print("example real:", format_positions(first_real))
        

def train_model(model, optimizer, input_fn, epochs, train_dataloader, valid_dataloader, target_idx=-1):
    """Train ``model`` for ``epochs`` epochs and validate after each one.

    Uses the module-level ``loss_func`` as the criterion and ``device`` as the
    destination for targets.

    Args:
        model: Module to optimise; called as ``model(input_fn(batch))``.
        optimizer: Optimizer already bound to the model's parameters.
        input_fn: Callable that extracts the model input from a batch.
        epochs: Number of passes over ``train_dataloader``.
        train_dataloader: Iterable of training batches.
        valid_dataloader: Iterable of validation batches.
        target_idx: Index of the target within each batch.

    Returns:
        ``(train_losses, valid_losses)``: two lists with the mean per-batch
        loss of every epoch.

    Raises:
        ValueError: If either dataloader yields no batches, since a mean loss
            cannot be computed.
    """
    train_losses = []
    valid_losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        num_batches = 0
        for batch in train_dataloader:
            optimizer.zero_grad()
            target = batch[target_idx].to(device)
            outputs = model(input_fn(batch))

            loss = loss_func(outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_dataloader yielded no batches")
        train_loss = train_loss / num_batches
        train_losses.append(train_loss)
        print_loss(epoch, train_loss, outputs, target, is_train=True)

        model.eval()
        valid_loss = 0.0
        num_batches = 0
        # No gradients are needed for evaluation; skipping them saves memory and time.
        with torch.no_grad():
            for batch in valid_dataloader:
                target = batch[target_idx].to(device)
                outputs = model(input_fn(batch))
                valid_loss += loss_func(outputs, target).item()
                num_batches += 1

        # Counting batches explicitly avoids reusing the training loop's
        # counter when the validation loader is empty.
        if num_batches == 0:
            raise ValueError("valid_dataloader yielded no batches")
        valid_loss = valid_loss / num_batches
        valid_losses.append(valid_loss)
        print_loss(epoch, valid_loss, outputs, target, is_train=False)
    return train_losses, valid_losses

In [34]:
# Early fusion: both inputs are concatenated on dim 1 before a single BaseNet.
early_net = EarlyNet([4, 4], NUM_POSITIONS).to(device)

# Late fusion: one BaseNet per input, merged by LateNet's fully connected head.
# NOTE(review): LateNet's last layer emits sum(out_chs) = 2 * NUM_POSITIONS values here;
# confirm that matches the target width expected by the loss.
rgb_net = BaseNet(4, NUM_POSITIONS).to(device)
xyz_net = BaseNet(4, NUM_POSITIONS).to(device)
late_net = LateNet([rgb_net, xyz_net], [NUM_POSITIONS, NUM_POSITIONS]).to(device)

# Intermediate fusion: per-input conv branches whose feature maps are concatenated.
cat_net = CatNet([4, 4], NUM_POSITIONS).to(device)

# Intermediate fusion: per-input conv branches whose feature maps are matrix-multiplied.
matmul_net = MatMulNet([4, 4], NUM_POSITIONS).to(device)

In [35]:
def experiment(model, learning_rate=0.001, epochs=10, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader):
    """Train ``model`` with Adam on the first two batch elements and return the loss curves.

    Args:
        model: Module that accepts a 2-tuple of input tensors.
        learning_rate: Learning rate handed to Adam.
        epochs: Number of training epochs.
        train_dataloader: Training batches; defaults to the loader that was
            bound to the global name when this function was defined.
        valid_dataloader: Validation batches; same default binding rule.

    Returns:
        ``(train_losses, valid_losses)`` as produced by ``train_model``.
    """
    def first_two_on_device(batch):
        # Elements 0 and 1 of the batch are moved to the device, in that order.
        return tuple(batch[i].to(device) for i in (0, 1))

    return train_model(
        model,
        Adam(model.parameters(), lr=learning_rate),
        first_two_on_device,
        epochs,
        train_dataloader,
        valid_dataloader,
    )

In [36]:
# Train the early-fusion model for 20 epochs; as the cell's last expression,
# the returned (train_losses, valid_losses) tuple is displayed.
experiment(early_net, epochs=20, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)

KeyboardInterrupt: 