In [23]:
#!pip install pil torch torchvision pandas numpy matplotlib
!pip install wandb -qU

In [24]:
from google.colab import drive
import numpy as np
from PIL import Image
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import wandb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("cuda:", torch.cuda.is_available())
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvincenteichhorn[0m ([33mvincenteichhorn-hasso-plattner-institut[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
DATA_DIR = "/content/data"
ZIP_FILE = "https://nx82872.your-storageshare.de/s/RSd8ee55qQsMSPb/download"
# download the zip, extract and place contents into DATA_DIR
!wget -v $ZIP_FILE -O data.zip
!unzip data.zip -d $DATA_DIR

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/data/data/replicator_data_parallel/rgb/8117.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/2399.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/4791.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/8422.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/0220.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/8587.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/2179.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/5948.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/3123.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/5448.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/3282.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/9434.png  
  inflating: /content/data/data/replicator_data_parallel/rgb/8511.png  

In [67]:
DATA_DIR = "/content/data/"
DATA_DIR = f"{DATA_DIR}data/assessment/"
print(list(os.listdir(DATA_DIR)))

['spheres', 'cubes']


In [68]:
IMG_SIZE = 64
BATCH_SIZE = 32
VALID_BATCHES = 10
N = 9999
NUM_POSITIONS = 9

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_func = nn.MSELoss()

In [69]:
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),  # Scales data into [0,1]
])

def get_torch_xyza(lidar_depth, azimuth, zenith):
    x = lidar_depth * torch.sin(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    y = lidar_depth * torch.cos(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    z = lidar_depth * torch.sin(-zenith[None, :])
    a = torch.where(lidar_depth < 50.0, torch.ones_like(lidar_depth), torch.zeros_like(lidar_depth))
    xyza = torch.stack((x, y, z, a))
    return xyza

class MyDataset(Dataset):
    def __init__(self, root_dir, start_idx, stop_idx):
        self.classes = ["cubes", "spheres"]
        self.root_dir = root_dir
        self.rgb = []
        self.lidar = []
        self.class_idxs = []

        for class_idx, class_name in enumerate(self.classes):
            for idx in range(start_idx, stop_idx):
                file_number = "{:04d}".format(idx)
                rbg_img = Image.open(self.root_dir + class_name + "/rgb/" + file_number + ".png")
                rbg_img = img_transforms(rbg_img).to(device)
                self.rgb.append(rbg_img)

                lidar_depth = np.load(self.root_dir + class_name + "/lidar/" + file_number + ".npy")
                lidar_depth = torch.from_numpy(lidar_depth[None, :, :]).to(torch.float32).to(device)
                self.lidar.append(lidar_depth)

                self.class_idxs.append(torch.tensor(class_idx, dtype=torch.float32)[None].to(device))

    def __len__(self):
        return len(self.class_idxs)

    def __getitem__(self, idx):
        rbg_img = self.rgb[idx]
        lidar_depth = self.lidar[idx]
        class_idx = self.class_idxs[idx]
        return rbg_img, lidar_depth, class_idx


In [70]:
train_data = MyDataset(DATA_DIR, 0, N-VALID_BATCHES*BATCH_SIZE)
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_data = MyDataset(DATA_DIR, N-VALID_BATCHES*BATCH_SIZE, N)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [71]:
class BaseNet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        kernel_size = 3

        # Convolution
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 50, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(50, 100, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(100, 200, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(200, 200, kernel_size, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )

        self.dense_emb = nn.Sequential(
            nn.Linear(200 * 4 * 4, 100),
            nn.ReLU(),
            nn.Linear(100, out_ch)
        )

    def forward(self, x):
        conv = self.conv(x)
        emb = self.dense_emb(conv)
        return F.normalize(emb)

class Classifier(nn.Module):
    def __init__(self, in_ch, n_classes):
        super().__init__()
        n_classes = 1

        self.classifier = nn.Sequential(
            nn.Linear(in_ch),
            nn.ReLU(),
            nn.Linear(100, n_classes)
        )

    def forward(self, x):
        return self.classifier(x)

class EarlyNet(nn.Module):
    def __init__(self, in_chs, out_ch, n_classes):
        super().__init__()
        self.base_net = BaseNet(sum(in_chs), out_ch)
        self.classifier = Classifier(out_ch, n_classes)

    def forward(self, inputs):
        x = torch.cat(inputs, 1)
        x = self.base_net(x)
        x = self.classifier(x)
        return x

class LateNet(nn.Module):
    def __init__(self, in_chs, out_ch, n_classes):
        super().__init__()
        self.networks = nn.ModuleList()
        for in_ch in in_chs:
            self.networks.append(BaseNet(in_ch, out_ch))
        sum_out_chs = out_ch * len(in_chs)
        self.classifier = Classifier(sum_out_chs, n_classes)

    def forward(self, inputs):
        network_outputs = [F.relu(net(inp)) for net, inp in zip(self.networks, inputs)]
        x = torch.cat(network_outputs, 1)
        x = self.classifier(x)
        return x


class MatMulNet(nn.Module):
    def __init__(self, in_chs, out_ch, kernel_size=3):
        super().__init__()
        self.networks = nn.ModuleList()
        for in_ch in in_chs:
            self.networks.append(nn.Sequential(
                nn.Conv2d(in_ch, 25, kernel_size, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(25, 50, kernel_size, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(50, 100, kernel_size, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ))

        self.fc1 = nn.Linear(200 * 8 * 4, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, out_ch)

    def forward(self, inputs):

        network_inputs = [F.relu(net(inp)) for net, inp in zip(self.networks, inputs)]

        x = torch.matmul(*network_inputs)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [59]:
def format_positions(positions):
    return ['{0: .3f}'.format(x) for x in positions]

def get_outputs(model, batch, inputs_idx):
    inputs = batch[inputs_idx].to(device)
    target = batch[-1].to(device)
    outputs = model(inputs)
    return outputs, target

def print_loss(epoch, loss, outputs, target, is_train=True, is_debug=False):
    loss_type = "train loss:" if is_train else "valid loss:"
    print("epoch", str(epoch), loss_type, str(loss))
    if is_debug:
        print("example pred:", format_positions(outputs[0].tolist()))
        print("example real:", format_positions(target[0].tolist()))


def train_model(model, optimizer, input_fn, epochs, train_dataloader, valid_dataloader, target_idx=-1):
    train_losses = []
    valid_losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for step, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            target = batch[target_idx].to(device)
            outputs = model(input_fn(batch))

            loss = loss_func(outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / (step + 1)
        train_losses.append(train_loss)
        print_loss(epoch, train_loss, outputs, target, is_train=True)

        model.eval()
        valid_loss = 0
        for step, batch in enumerate(valid_dataloader):
            target = batch[target_idx].to(device)
            outputs = model(input_fn(batch))
            valid_loss += loss_func(outputs, target).item()
        valid_loss = valid_loss / (step + 1)
        valid_losses.append(valid_loss)
        print_loss(epoch, valid_loss, outputs, target, is_train=False)
        wandb.log({
          "train/loss": train_loss,
          "valid/loss": valid_loss,
        })

    return train_losses, valid_losses

In [60]:
early_net = EarlyNet([4, 4], NUM_POSITIONS).to(device)

rgb_net = BaseNet(4, NUM_POSITIONS).to(device)
xyz_net = BaseNet(4, NUM_POSITIONS).to(device)
late_net = LateNet([rgb_net, xyz_net], [NUM_POSITIONS, NUM_POSITIONS], NUM_POSITIONS).to(device)

cat_net = CatNet([4, 4], NUM_POSITIONS).to(device)

matmul_net = MatMulNet([4, 4], NUM_POSITIONS).to(device)

In [61]:
def experiment(model, learning_rate=0.001, epochs=10, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader):
    wandb.init(
      project="cilp-extended-assessment-fusion_comparison",
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      name=f"experiment_{model.__class__.__name__}",
      # Track hyperparameters and run metadata
      config={
        "learning_rate": learning_rate,
        "architecture": "MM",
        "epochs": epochs,
      }
    )
    def get_inputs(batch):
        inputs_rgb = batch[0].to(device)
        inputs_xyz = batch[1].to(device)
        return (inputs_rgb, inputs_xyz)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    train_losses, valid_losses = train_model(
        model,
        optimizer,
        get_inputs,
        epochs,
        train_dataloader,
        valid_dataloader,
    )
    wandb.finish()
    return train_losses, valid_losses

In [62]:
experiment(early_net, epochs=20, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)

epoch 0 train loss: 6.756419404452999
epoch 0 valid loss: 5.899211263656616
epoch 1 train loss: 5.880294605596176
epoch 1 valid loss: 5.60968279838562
epoch 2 train loss: 5.761123864066522
epoch 2 valid loss: 5.550690364837647
epoch 3 train loss: 5.708968487796405
epoch 3 valid loss: 5.582197904586792
epoch 4 train loss: 5.662723545996559
epoch 4 valid loss: 5.584937191009521
epoch 5 train loss: 5.140413963242082
epoch 5 valid loss: 4.0964220523834225
epoch 6 train loss: 3.739746113486637
epoch 6 valid loss: 3.49600203037262
epoch 7 train loss: 3.339962653766405
epoch 7 valid loss: 3.3674086332321167
epoch 8 train loss: 2.728398662923977
epoch 8 valid loss: 2.032873249053955
epoch 9 train loss: 1.3907620261441793
epoch 9 valid loss: 1.3522067070007324
epoch 10 train loss: 0.9421257899691727
epoch 10 valid loss: 1.2334296703338623
epoch 11 train loss: 0.7152684478965027
epoch 11 valid loss: 1.1048839867115021
epoch 12 train loss: 0.5968741086737209
epoch 12 valid loss: 1.013459187746048

0,1
train/loss,█▇▇▇▇▆▅▄▄▂▂▂▁▁▁▁▁▁▁▁
valid/loss,█████▅▅▄▃▂▁▁▁▁▁▁▁▁▁▁

0,1
train/loss,0.23534
valid/loss,0.91155


([6.756419404452999,
  5.880294605596176,
  5.761123864066522,
  5.708968487796405,
  5.662723545996559,
  5.140413963242082,
  3.739746113486637,
  3.339962653766405,
  2.728398662923977,
  1.3907620261441793,
  0.9421257899691727,
  0.7152684478965027,
  0.5968741086737209,
  0.498447349726759,
  0.4346500269032472,
  0.37115297441845696,
  0.3216400309805049,
  0.29183690501562015,
  0.2678002835504267,
  0.23533728032909482],
 [5.899211263656616,
  5.60968279838562,
  5.550690364837647,
  5.582197904586792,
  5.584937191009521,
  4.0964220523834225,
  3.49600203037262,
  3.3674086332321167,
  2.032873249053955,
  1.3522067070007324,
  1.2334296703338623,
  1.1048839867115021,
  1.013459187746048,
  0.9733812034130096,
  1.0699026882648468,
  0.965760612487793,
  0.9265827357769012,
  0.9135573327541351,
  0.8934089660644531,
  0.9115458965301514])

In [63]:
experiment(late_net, epochs=20, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)

epoch 0 train loss: 7.085805825050303
epoch 0 valid loss: 6.3363118171691895
epoch 1 train loss: 5.913138372219161
epoch 1 valid loss: 5.66188154220581
epoch 2 train loss: 5.180394315561712
epoch 2 valid loss: 4.671744060516358
epoch 3 train loss: 4.60520323775462
epoch 3 valid loss: 4.471091079711914
epoch 4 train loss: 4.393731571980659
epoch 4 valid loss: 4.354999375343323
epoch 5 train loss: 4.235220491491406
epoch 5 valid loss: 4.330100655555725
epoch 6 train loss: 4.008480162810016
epoch 6 valid loss: 3.9873774528503416
epoch 7 train loss: 3.771113938053712
epoch 7 valid loss: 3.813468170166016
epoch 8 train loss: 3.46314595866677
epoch 8 valid loss: 3.45658118724823
epoch 9 train loss: 3.1779726530542436
epoch 9 valid loss: 3.24690101146698
epoch 10 train loss: 2.9760330362825202
epoch 10 valid loss: 3.167403769493103
epoch 11 train loss: 2.813352865888583
epoch 11 valid loss: 3.0798446893692017
epoch 12 train loss: 2.6788837878119867
epoch 12 valid loss: 3.146285963058472
epoch

0,1
train/loss,█▆▅▅▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁
valid/loss,█▇▅▄▄▄▃▃▂▂▂▁▂▁▁▁▁▁▁▁

0,1
train/loss,1.96051
valid/loss,2.90283


([7.085805825050303,
  5.913138372219161,
  5.180394315561712,
  4.60520323775462,
  4.393731571980659,
  4.235220491491406,
  4.008480162810016,
  3.771113938053712,
  3.46314595866677,
  3.1779726530542436,
  2.9760330362825202,
  2.813352865888583,
  2.6788837878119867,
  2.5732290326364784,
  2.4511940968747172,
  2.3310100913837255,
  2.229201841038584,
  2.1273893406059567,
  2.0475583727786084,
  1.9605131544024739],
 [6.3363118171691895,
  5.66188154220581,
  4.671744060516358,
  4.471091079711914,
  4.354999375343323,
  4.330100655555725,
  3.9873774528503416,
  3.813468170166016,
  3.45658118724823,
  3.24690101146698,
  3.167403769493103,
  3.0798446893692017,
  3.146285963058472,
  2.9572325229644774,
  2.9985490798950196,
  2.99368360042572,
  2.895545792579651,
  2.8322766542434694,
  2.916111612319946,
  2.9028254747390747])

In [64]:
experiment(cat_net, epochs=20, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)

epoch 0 train loss: 4.151572912734076
epoch 0 valid loss: 1.80889151096344
epoch 1 train loss: 1.3353784536289064
epoch 1 valid loss: 1.1224386155605317
epoch 2 train loss: 0.8895630712146001
epoch 2 valid loss: 0.9068935990333558
epoch 3 train loss: 0.6571516803163566
epoch 3 valid loss: 0.8069858074188232
epoch 4 train loss: 0.4949865300726417
epoch 4 valid loss: 0.7089505851268768
epoch 5 train loss: 0.372176484753754
epoch 5 valid loss: 0.6663749933242797
epoch 6 train loss: 0.28999307702313987
epoch 6 valid loss: 0.6807492434978485
epoch 7 train loss: 0.2238926049691952
epoch 7 valid loss: 0.6765113353729248
epoch 8 train loss: 0.17852465544414048
epoch 8 valid loss: 0.6013859510421753
epoch 9 train loss: 0.14971074933149167
epoch 9 valid loss: 0.6220628499984742
epoch 10 train loss: 0.12717481364576233
epoch 10 valid loss: 0.6005114287137985
epoch 11 train loss: 0.11180579018415204
epoch 11 valid loss: 0.5870946854352951
epoch 12 train loss: 0.10556566458664193
epoch 12 valid los

0,1
train/loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/loss,█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/loss,0.07381
valid/loss,0.54653


([4.151572912734076,
  1.3353784536289064,
  0.8895630712146001,
  0.6571516803163566,
  0.4949865300726417,
  0.372176484753754,
  0.28999307702313987,
  0.2238926049691952,
  0.17852465544414048,
  0.14971074933149167,
  0.12717481364576233,
  0.11180579018415204,
  0.10556566458664193,
  0.09686598113879857,
  0.09471641984206951,
  0.08727454828308118,
  0.08289253308332914,
  0.0766411427383786,
  0.07188354530032502,
  0.07380517553661438],
 [1.80889151096344,
  1.1224386155605317,
  0.9068935990333558,
  0.8069858074188232,
  0.7089505851268768,
  0.6663749933242797,
  0.6807492434978485,
  0.6765113353729248,
  0.6013859510421753,
  0.6220628499984742,
  0.6005114287137985,
  0.5870946854352951,
  0.5720940053462982,
  0.5922597706317901,
  0.5568033933639527,
  0.5823369771242142,
  0.5756460070610047,
  0.5651071816682816,
  0.5670351684093475,
  0.5465286165475846])

In [65]:
experiment(matmul_net, epochs=20, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)

epoch 0 train loss: 4.265565169568093
epoch 0 valid loss: 2.108427810668945
epoch 1 train loss: 1.6903292422263039
epoch 1 valid loss: 1.4387306094169616
epoch 2 train loss: 1.3208947649459966
epoch 2 valid loss: 1.2761252999305726
epoch 3 train loss: 1.1024977414813262
epoch 3 valid loss: 1.0401997208595275
epoch 4 train loss: 0.9526560720031625
epoch 4 valid loss: 0.9301730275154114
epoch 5 train loss: 0.8160212170209317
epoch 5 valid loss: 0.882434344291687
epoch 6 train loss: 0.7373313599864378
epoch 6 valid loss: 0.7938616871833801
epoch 7 train loss: 0.6401270994090087
epoch 7 valid loss: 0.8201888382434845
epoch 8 train loss: 0.5828588827951064
epoch 8 valid loss: 0.7277527630329133
epoch 9 train loss: 0.5114198679363491
epoch 9 valid loss: 0.788463830947876
epoch 10 train loss: 0.476552776745613
epoch 10 valid loss: 0.7596542298793793
epoch 11 train loss: 0.4252698106678906
epoch 11 valid loss: 0.7582119107246399
epoch 12 train loss: 0.37932258296683924
epoch 12 valid loss: 0.7

0,1
train/loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
valid/loss,█▅▄▃▂▂▂▂▁▂▂▂▁▁▁▁▁▂▁▁

0,1
train/loss,0.21385
valid/loss,0.72827


([4.265565169568093,
  1.6903292422263039,
  1.3208947649459966,
  1.1024977414813262,
  0.9526560720031625,
  0.8160212170209317,
  0.7373313599864378,
  0.6401270994090087,
  0.5828588827951064,
  0.5114198679363491,
  0.476552776745613,
  0.4252698106678906,
  0.37932258296683924,
  0.3375476522260154,
  0.3198965232794648,
  0.2901511339359725,
  0.26579474642971496,
  0.25341797807556116,
  0.25235479769130414,
  0.21385002170769585],
 [2.108427810668945,
  1.4387306094169616,
  1.2761252999305726,
  1.0401997208595275,
  0.9301730275154114,
  0.882434344291687,
  0.7938616871833801,
  0.8201888382434845,
  0.7277527630329133,
  0.788463830947876,
  0.7596542298793793,
  0.7582119107246399,
  0.7273328244686127,
  0.6477871417999268,
  0.7067088901996612,
  0.7268461287021637,
  0.7060805231332778,
  0.7740357339382171,
  0.6363675922155381,
  0.7282679677009583])