<a href="https://colab.research.google.com/github/tobiaskatsch/AirborneOpticalSectioning/blob/master/Airborne_Optical_Sectioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Airborne Optical Sectioning

## Clone from GitHub and mount Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')
import os
import shutil
!pip install wandb
import wandb

# Define paths, names, and tokens
repo_name = 'AirborneOpticalSectioning'
base_path = "drive/MyDrive/ML/AirborneOpticalSectioning"
token_path = os.path.join(base_path, "token.txt")

# Read the token
with open(token_path, 'r') as file:
    token = file.read().strip()

# Prepare the repository URL
repo_url = 'https://github.com/tobiaskatsch/AirborneOpticalSectioning.git'.replace('https://', f'https://{token}:x-oauth-basic@')

Mounted at /content/drive
Collecting wandb
  Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m25.9 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.38.0-py2.py3-none-any.whl (252 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.8/252.8 kB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from Gi

## Load Data

In [2]:
import os
import numpy as np
from tqdm import tqdm

data_path = "drive/MyDrive/ML/AirborneOpticalSectioning/data"

input_batches = []
target_batches = []

NUM_BATCHES = 333

for batch_idx in tqdm(range(NUM_BATCHES), desc="Loading Data"):
    input_file_name = f"{batch_idx}.npy"
    target_file_name = f"{batch_idx}_y.npy"

    input_batch = np.load(os.path.join(data_path, input_file_name))
    target_batch = np.load(os.path.join(data_path, target_file_name))

    input_batches.append(input_batch)
    target_batches.append(target_batch)

inputs = np.concatenate(input_batches, axis=0)
targets = np.concatenate(target_batches, axis=0)

import sys
size_in_bytes = sys.getsizeof(inputs) + sys.getsizeof(targets)
size_in_gigabytes = size_in_bytes / (1024 ** 3)
print(f"Size of dataset: {size_in_gigabytes} GB")

Loading Data: 100%|██████████| 333/333 [08:17<00:00,  1.50s/it]


Size of dataset: 10.396484658122063 GB


## Create Dataset

In [3]:
import torch
from torch.utils.data import Dataset

class AOSDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = torch.from_numpy(inputs)
        self.targets = torch.from_numpy(targets)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        target_sample = self.targets[idx]
        return input_sample, target_sample

dataset = AOSDataset(inputs, targets)
print("dataset size: ", len(dataset))

dataset size:  10646


## DataLoaders

In [4]:
from torch.utils.data import DataLoader, random_split

batch_size = 16

train_frac = 0.85
val_frac = 0.05
test_frac = 0.1

total_size = len(dataset)
train_size = int(train_frac * total_size)
val_size = int(val_frac * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
print("train_size: ", total_size)
print("val_size: ", val_size)
print("test_size: ", test_size)

train_size:  10646
val_size:  532
test_size:  1065


## U-Net Model

In [5]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Contracting Path (Encoder)
        self.enc_conv1 = self.conv_block(3, 64)
        self.enc_conv2 = self.conv_block(64, 128)
        self.enc_conv3 = self.conv_block(128, 256)
        self.enc_conv4 = self.conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Expansive Path (Decoder)
        self.upconv4 = self.upconv(1024, 512)
        self.dec_conv4 = self.conv_block(1024, 512)
        self.upconv3 = self.upconv(512, 256)
        self.dec_conv3 = self.conv_block(512, 256)
        self.upconv2 = self.upconv(256, 128)
        self.dec_conv2 = self.conv_block(256, 128)
        self.upconv1 = self.upconv(128, 64)
        self.dec_conv1 = self.conv_block(128, 64)

        # Output Convolution
        self.out_conv = nn.Conv2d(64, 1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        # Contracting Path
        enc1 = self.enc_conv1(x)
        x = self.pool(enc1)
        enc2 = self.enc_conv2(x)
        x = self.pool(enc2)
        enc3 = self.enc_conv3(x)
        x = self.pool(enc3)
        enc4 = self.enc_conv4(x)
        x = self.pool(enc4)

        # Bottleneck
        x = self.bottleneck(x)

        # Expansive Path
        x = self.upconv4(x)
        x = torch.cat((x, enc4), dim=1)
        x = self.dec_conv4(x)
        x = self.upconv3(x)
        x = torch.cat((x, enc3), dim=1)
        x = self.dec_conv3(x)
        x = self.upconv2(x)
        x = torch.cat((x, enc2), dim=1)
        x = self.dec_conv2(x)
        x = self.upconv1(x)
        x = torch.cat((x, enc1), dim=1)
        x = self.dec_conv1(x)

        # Output Convolution
        x = self.out_conv(x)

        return x

## Training

In [6]:
import torch
import wandb

def train_epoch(train_loader, model, loss_fn, optimizer, device, epoch_id, log_every, val_loader=None, val_every=None):
    model.train()
    total_loss = 0

    for sample_id, (inputs, targets) in enumerate(train_loader):
        global_step = (epoch_id * len(train_loader)) + sample_id  # Updated to global_step

        # Move data to the device
        inputs, targets = inputs.float().to(device), targets.float().to(device)
        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Logging
        if log_every and global_step % log_every == 0:
            wandb.log({"Train Loss": loss.item(), "Global Step": global_step})

        # Validation
        if val_loader and val_every and global_step % val_every == 0:
            validate(val_loader, model, loss_fn, device, global_step)  # Pass global_step

    average_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch_id+1}], Average Loss: {average_loss}")
    wandb.log({"Average Train Loss": average_loss, "Epoch": epoch_id+1})

def validate(val_loader, model, loss_fn, device, global_step):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.float().to(device), targets.float().to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            val_loss += loss.item()
    average_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {average_val_loss}")
    wandb.log({"Validation Loss": average_val_loss, "Global Step": global_step})

def train(train_loader, model, loss_fn, optimizer, device, log_every, val_loader, val_every, epochs):
    for epoch_id in range(epochs):
        train_epoch(train_loader, model, loss_fn, optimizer, device, epoch_id, log_every, val_loader=val_loader, val_every=val_every)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import wandb

epochs = 100
log_every = 100
val_every = 1000

wandb.login()

wandb.init(project='AirborneOpticalSectioning')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet()
model.to(device)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

train(train_loader, model, loss_fn, optimizer, device, log_every, val_loader, epochs, val_every)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtobias-katsch42[0m ([33mthdeepresearch[0m). Use [1m`wandb login --relogin`[0m to force relogin


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Validation Loss: 10699.027142693014
Validation Loss: 6315.890021829044
Validation Loss: 5979.178531422334
