# **Part 1: Preparation**

### **Imports**

In [None]:
import os
import json

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

from tqdm import tqdm, trange
from PIL import Image
from IPython.display import clear_output
from torch.utils.data import Dataset, DataLoader

### **Settings**
I usually don't work with hardcoded settings like this, but it was the easiest way to create a tutorial.

In [None]:
# --- Annotation ---
# Width, in source-image pixels, of the window cut out around each rail.
NUMBER_OF_PIXELS = 150
# Left x coordinate of the fixed window for the left and the right rail.
START_INDEX_LEFT, START_INDEX_RIGHT = 112, 497

# --- Training ---
# Size of the resized crop that is fed to the network.
IMAGE_WIDTH, IMAGE_HEIGHT = 128, 64
# Number of label pixels set on each side of the annotated rail center.
LABEL_PADDING = 2

# **Part 2: Data Annotation**

In [None]:
def load_data():
    """Parse ``rail_locations.json`` from the working directory and return its content."""
    with open("rail_locations.json") as annotation_file:
        parsed = json.load(annotation_file)
    return parsed

In [None]:
def show_rails(image, start_indexes, number_of_pixels):
    """Display the two rail windows of *image* next to each other.

    Args:
        image: PIL image to cut the windows from.
        start_indexes: Left x coordinate of the first and second window.
        number_of_pixels: Width of each window in pixels.
    """
    # An x tick every 5 pixels makes it easy to read off the rail position.
    tick_positions = list(range(0, number_of_pixels, 5))

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))
    for i, axis in enumerate(axes):
        left_edge = start_indexes[i]
        # Full-height vertical strip starting at the window's left edge.
        window = image.crop([left_edge, 0, left_edge + number_of_pixels, image.size[1]])
        axis.imshow(window, cmap="gray")
        axis.set_xticks(tick_positions)
        axis.set_yticks([])

    plt.tight_layout()
    plt.show()

In [None]:
def annotate_rail_images(start_indexes, number_of_pixels, split):
    """Interactively annotate the rail centers of every image that has none yet.

    The entries returned by ``load_data()`` are visited in random order. For
    each entry without a ``"locations"`` key the rail windows are displayed and
    the user types one integer per window: the x position of the rail center
    inside that window. The values are stored as absolute image coordinates
    (window start + typed offset) together with ``split``. The complete list
    is written back to ``rail_locations.json`` (in the shuffled order) after
    every image, so an interrupted session keeps all finished annotations.

    Args:
        start_indexes: Left x coordinate of each window shown to the user.
        number_of_pixels: Width of each window in pixels.
        split: Value stored under the ``"split"`` key of every entry annotated
            in this call (e.g. ``"train"`` or ``"valid"``).
    """
    training_data = load_data()
    np.random.shuffle(training_data)

    for entry in training_data:
        if "locations" in entry:
            continue
        clear_output()

        # Pillow keeps the file open until the image is closed; release it as
        # soon as the windows are drawn instead of leaking one handle per image.
        with Image.open(entry["image"]) as image:
            show_rails(image, start_indexes, number_of_pixels)

        # Ask again on malformed input: a wrong count or a non-integer would
        # otherwise crash the session or store an unusable annotation.
        offsets = None
        while offsets is None:
            answer = input("Rail centers:").split()
            try:
                offsets = [int(value) for value in answer]
            except ValueError:
                offsets = None
            if offsets is not None and len(offsets) != len(start_indexes):
                offsets = None
            if offsets is None:
                print(f"Please enter {len(start_indexes)} whole numbers separated by spaces.")

        # Typed values are relative to each window; store absolute x positions.
        entry["locations"] = [
            offset + window_start for offset, window_start in zip(offsets, start_indexes)
        ]
        entry["split"] = split

        with open("rail_locations.json", "w") as f:
            json.dump(training_data, f)

    clear_output()

In [None]:
# Split that the images annotated in this session are assigned to.
split = "train" # or "valid"

# The settings cell names these constants START_INDEX_LEFT / START_INDEX_RIGHT;
# the shortened names START_INDEX_L / START_INDEX_R do not exist.
annotate_rail_images([START_INDEX_LEFT, START_INDEX_RIGHT], NUMBER_OF_PIXELS, split)

# **Part 3: Data Loading**

In [None]:
def create_label(x, size, padding):
    """Build a 1-D binary target that marks the rail center.

    Args:
        x: Rail center as a fraction of the crop width (0 = left edge,
            1 = right edge).
        size: Length of the label vector in pixels.
        padding: Number of pixels set to 1 on each side of the center pixel.

    Returns:
        A float tensor of length ``size`` that is 1 inside the window
        ``[center - padding, center + padding]`` (clipped to the vector) and
        0 everywhere else. A center far outside the vector yields all zeros.
    """
    center = int(round(size * x))
    # Slice ends are exclusive, so the upper bound is clamped to ``size``;
    # clamping to ``size - 1`` would make the last pixel unreachable.
    start = min(size, max(0, center - padding))
    # ``max(start, ...)`` keeps ``end`` from going negative, which slicing
    # would interpret as "counted from the right".
    end = max(start, min(size, center + padding + 1))

    label = torch.zeros(size)
    label[start:end] = 1
    return label

def flip_image(crop, x):
    """Randomly mirror a crop for augmentation.

    A horizontal and a vertical flip are each applied independently with
    probability 0.5. The horizontal flip also mirrors the relative rail
    position; the vertical flip leaves it unchanged.

    Args:
        crop: PIL image to augment.
        x: Rail center as a fraction of the crop width.

    Returns:
        Tuple ``(crop, x)`` holding the possibly flipped image and the rail
        position that matches it.
    """
    if np.random.rand() < 0.5:
        crop = crop.transpose(Image.FLIP_LEFT_RIGHT)
        x = 1 - x

    if np.random.rand() < 0.5:
        crop = crop.transpose(Image.FLIP_TOP_BOTTOM)

    # ``transpose`` returns a new image and ``x`` is rebound locally, so both
    # must be handed back to the caller for the augmentation to take effect.
    return crop, x

def create_crop(image, x, start_index, train, number_of_pixels, size, padding, height=64):
    """Cut one rail window out of *image* and build its training target.

    Args:
        image: Source PIL image.
        x: Absolute x coordinate of the rail center in *image*.
        start_index: Left edge of the fixed window, used when ``train`` is false.
        train: When true, the window is placed randomly so that the rail lies
            between 10 and ``number_of_pixels - 10`` pixels from its left edge,
            and flip, brightness and contrast augmentation are applied.
        number_of_pixels: Width of the window in source pixels.
        size: Width of the resized crop and length of the label vector.
        padding: Label half-width passed on to ``create_label``.
        height: Height of the resized crop. Defaults to 64, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(tensor, label)``: the crop converted with ``TF.to_tensor`` and
        the binary label vector of length ``size``.
    """
    x_start = x - (np.random.uniform(10, number_of_pixels - 10)) if train else start_index
    crop = image.crop([x_start, 0, x_start + number_of_pixels, image.size[1]])
    crop = crop.resize((size, height), Image.LANCZOS)
    # From here on ``x`` is the rail position as a fraction of the crop width.
    x = (x - x_start) / number_of_pixels

    if train:
        # Use the returned pair; discarding it would skip the flips entirely.
        crop, x = flip_image(crop, x)
        crop = TF.adjust_brightness(crop, np.random.uniform(0.5, 1.5))
        crop = TF.adjust_contrast(crop, np.random.uniform(0.8, 1.2))

    label = create_label(x, size, padding)

    return TF.to_tensor(crop), label

def get_crops(data_point, train=False):
    """Create the left and right rail crop of one annotated image.

    Args:
        data_point: Dict with the opened image under ``"image"`` and the two
            absolute rail centers under ``"locations"`` (left first).
        train: Forwarded to ``create_crop``; enables random placement and
            augmentation.

    Returns:
        Tuple ``(images, labels)``: both crops concatenated along the first
        dimension, and both label vectors concatenated into one vector.
    """
    crops, targets = [], []

    # Index 0 is the left rail, index 1 the right rail.
    for side, window_start in enumerate((START_INDEX_LEFT, START_INDEX_RIGHT)):
        crop, target = create_crop(
            image=data_point["image"],
            x=data_point["locations"][side],
            start_index=window_start,
            train=train,
            number_of_pixels=NUMBER_OF_PIXELS,
            size=IMAGE_WIDTH,
            padding=LABEL_PADDING,
        )
        crops.append(crop)
        targets.append(target)

    return torch.cat(crops), torch.cat(targets)

In [None]:
class RailDataset(Dataset):
    """Dataset that turns annotated images into rail crops on the fly.

    Args:
        images: Sequence of data points as accepted by ``get_crops``.
        train: Whether ``get_crops`` should apply training-time augmentation.
    """

    def __init__(self, images, train):
        self.images, self.train = images, train

    def __len__(self):
        """Number of data points in the dataset."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(images, labels)`` pair for the data point at *index*."""
        data_point = self.images[index]
        images, labels = get_crops(data_point, self.train)
        return images, labels

In [None]:
# Only entries that already carry an annotation can be used for training.
with open("rail_locations.json") as f:
    data = [entry for entry in json.load(f) if "locations" in entry]

# Replace every stored file path by the opened PIL image.
for entry in data:
    entry["image"] = Image.open(entry["image"])

In [None]:
# Partition the annotated entries by the split name stored during annotation.
data_train, data_valid = (
    [entry for entry in data if entry["split"] == split_name]
    for split_name in ("train", "valid")
)

In [None]:
batch_size = 16

# Training data: augmented crops, reshuffled every epoch.
dataset_train = RailDataset(data_train, True)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# Validation data: fixed windows in a fixed order.
dataset_valid = RailDataset(data_valid, False)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False)

# **Part 4: Training**

In [None]:
class RailDetector(torch.nn.Module):
    """Small CNN that outputs one logit per crop column.

    Three stride-2 convolutions, each followed by 2x2 max pooling and ReLU,
    feed a linear layer with ``IMAGE_WIDTH`` outputs. The linear layer expects
    ``4 * size * 2`` input features.

    Args:
        size: Number of output channels of the first convolution; the second
            and third convolution use twice and four times as many.
    """

    def __init__(self, size=1):
        super().__init__()
        first, second, third = size, 2 * size, 4 * size
        self.conv1 = torch.nn.Conv2d(1, first, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(first, second, kernel_size=3, stride=2, padding=1)
        self.conv3 = torch.nn.Conv2d(second, third, kernel_size=3, stride=2, padding=1)
        self.embedding = torch.nn.Linear(third * 2, IMAGE_WIDTH)

    def forward(self, x):
        """Map a batch of single-channel crops to per-column logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(F.max_pool2d(conv(x), (2, 2)))
        # Flatten everything except the batch dimension for the linear layer.
        flat = x.view(x.shape[0], -1)
        return self.embedding(flat)

In [None]:
def validate(model, dataloader, loss_f):
    """Return the mean per-batch loss of *model* over *dataloader*.

    The model is switched to evaluation mode and stays in it; callers that
    continue training have to call ``model.train()`` themselves.

    Args:
        model: Network to evaluate.
        dataloader: Yields ``(images, labels)`` batches as produced by
            ``RailDataset``.
        loss_f: Callable ``loss_f(output, labels)`` returning a scalar tensor.

    Returns:
        Mean of the per-batch loss values (``nan`` if the loader is empty).
    """
    model.eval()
    valid_losses = []

    # Evaluation needs no gradients; skipping the autograd graph saves memory
    # and time without changing the loss values.
    with torch.no_grad():
        for images, labels in dataloader:
            # Turn every crop into its own single-channel sample.
            images = images.view(-1, IMAGE_HEIGHT, IMAGE_WIDTH).unsqueeze(1)
            labels = labels.view(-1, IMAGE_WIDTH)
            output = model(images)
            loss = loss_f(output, labels)
            valid_losses.append(loss.item())

    return np.mean(valid_losses)

In [None]:
def get_average_position(model, data, train=False):
    """Return the mean absolute pixel error of the predicted rail centers.

    For both crops of every data point the rail position is taken as the
    argmax of a centered 5-pixel rolling mean, once over the label and once
    over the model output, and the absolute difference is recorded.

    Args:
        model: Network to evaluate. Its training/evaluation mode is restored
            before returning.
        data: Iterable of data points as accepted by ``get_crops``.
        train: Forwarded to ``get_crops``; enables training-time augmentation.

    Returns:
        Mean absolute difference in crop pixels (``nan`` if *data* is empty).
    """
    differences = []

    # Run inference in evaluation mode without autograd, then put the model
    # back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for d in data:
                images, labels = get_crops(d, train=train)
                images = images.view(-1, IMAGE_HEIGHT, IMAGE_WIDTH).unsqueeze(1)
                labels = labels.view(-1, IMAGE_WIDTH)
                output = model(images)

                for i in range(2):
                    # Smoothing makes the argmax land on the middle of a plateau
                    # instead of on its first pixel.
                    x = pd.Series(labels[i].detach().numpy()).rolling(5, center=True, min_periods=1).mean().argmax()
                    p = pd.Series(output[i].detach().numpy()).rolling(5, center=True, min_periods=1).mean().argmax()
                    differences.append(abs(x - p))
    finally:
        model.train(was_training)

    return np.mean(differences)

In [None]:
number_of_epochs = 5000

loss_function = torch.nn.BCEWithLogitsLoss()
weights = "model.pth"
model = RailDetector(size=8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
best_score = get_average_position(model, data_valid)
# Write a checkpoint for the initial score right away. Otherwise, if no epoch
# ever improves on it, the final load below would fail or silently pick up a
# stale "model.pth" left over from an earlier run.
torch.save(model.state_dict(), weights)
progress_bar = trange(number_of_epochs)

for epoch in progress_bar:
    train_losses = []

    model.train()
    for images, labels in dataloader_train:
        optimizer.zero_grad()
        # Turn every crop into its own single-channel sample.
        images = images.view(-1, IMAGE_HEIGHT, IMAGE_WIDTH).unsqueeze(1)
        labels = labels.view(-1, IMAGE_WIDTH)
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    valid_loss = validate(model, dataloader_valid, loss_function)
    average_position = get_average_position(model, data_valid)
    # Model selection uses the position error, not the validation loss.
    if average_position < best_score:
        best_score = average_position
        torch.save(model.state_dict(), weights)

    # T: train loss, V: validation loss, P: position error, B: best position error.
    progress_bar.set_description("T: {:.3f}, V: {:.3f}, P: {:.2f}, B: {:.2f}".format(
        np.mean(train_losses), valid_loss, average_position, best_score
    ))

# Continue with the best checkpoint rather than the weights of the last epoch.
model.load_state_dict(torch.load(weights, weights_only=True))

In [None]:
def create_red_alpha(l_image):
    """Turn a single-channel mask into a red overlay.

    Args:
        l_image: PIL image in mode ``"L"``.

    Returns:
        Tuple ``(overlay, mask)``: an RGB image whose red channel is *l_image*
        and whose green and blue channels are zero, and *l_image* itself for
        use as the paste mask. No alpha channel is added to the overlay.
    """
    empty_channel = Image.new("L", l_image.size, 0)
    overlay = Image.merge("RGB", (l_image, empty_channel, empty_channel))
    return overlay, l_image

In [None]:
def create_blue_alpha(l_image):
    """Turn a single-channel mask into a blue overlay.

    Args:
        l_image: PIL image in mode ``"L"``.

    Returns:
        Tuple ``(overlay, mask)``: an RGB image whose blue channel is *l_image*
        and whose red and green channels are zero, and *l_image* itself for
        use as the paste mask. No alpha channel is added to the overlay.
    """
    empty_channel = Image.new("L", l_image.size, 0)
    overlay = Image.merge("RGB", (empty_channel, empty_channel, l_image))
    return overlay, l_image

In [None]:
# Show every validation crop with the annotated rail in red and the predicted
# rail position in blue.
for d in data_valid:
    images, labels = get_crops(d, train=False)
    images = images.view(-1, IMAGE_HEIGHT, IMAGE_WIDTH).unsqueeze(1)
    labels = labels.view(-1, IMAGE_WIDTH)
    output = model(images)

    for i in range(2):
        # Ground truth: stretch the label row over the full crop height.
        label_rows = labels[i].unsqueeze(0).repeat(IMAGE_HEIGHT, 1)
        red_overlay, red_alpha = create_red_alpha(TF.to_pil_image(label_rows))

        # Prediction: a 5-pixel band around the smoothed argmax, clipped to the crop.
        smoothed = pd.Series(output[i].detach().numpy()).rolling(5, center=True, min_periods=1).mean()
        x = int(smoothed.argmax())
        prediction_row = torch.zeros(IMAGE_WIDTH)
        prediction_row[max(0, x - 2):min(IMAGE_WIDTH, x + 3)] = 1
        prediction_rows = prediction_row.unsqueeze(0).repeat(IMAGE_HEIGHT, 1)
        blue_overlay, blue_alpha = create_blue_alpha(TF.to_pil_image(prediction_rows))

        # Halving the mask values makes both overlays semi-transparent.
        red_mask = Image.fromarray((np.array(red_alpha) / 2).astype(np.uint8))
        blue_mask = Image.fromarray((np.array(blue_alpha) / 2).astype(np.uint8))

        pil_image = TF.to_pil_image(images[i]).convert("RGBA")
        pil_image.paste(red_overlay, (0, 0), mask=red_mask)
        pil_image.paste(blue_overlay, (0, 0), mask=blue_mask)
        plt.axis("off")
        plt.imshow(pil_image)
        plt.show()