# Setup

In [None]:
# Standard Python libraries
import sys
import random
import string
import os
import math

# Third party libraries
from pylibdmtx.pylibdmtx import encode
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
import torch
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

from retinex import msrcr
import skimage

from ultralytics import YOLO

# Custom funcs
sys.path.append('../scripts')
from hourglass import *

## Parameters

In [None]:
input_size = 256 # Input image size for model (orig hourglasses 256x256)
output_size = 64 # Output heatmap size for model (orig hourglasses 64x64)
yolo_size = 640  # YOLOv11 input size (orig yolo 640x640); also the side length of the synthetic DMC images/heatmaps
n_stacks = 8     # Number of stacks in hourglass model

sigma = 1        # Std. dev. (in output-heatmap pixels) of the Gaussian drawn around each corner in create_heatmap

n_train = 800    # Number of training samples
n_val = 100      # Number of validation samples
n_test = 100     # Number of test samples

batch_size = 1   # Batch size for training

max_complex_epoch = 150 # Epoch at which shape augmentation (scale/translate) reaches its maximum range (ramps up from epoch 10)

# Calculating weight for loss function: ratio of background pixels to Gaussian pixels in one output heatmap.
# With sigma = 1 and the 0.05 cut-off used in create_heatmap, a Gaussian covers the pixels with
# dx^2 + dy^2 <= 5, i.e. columns of 3, 5, 5, 5, 3 pixels (21 in total). All 4 corners are drawn
# into the same heatmap, hence the * 4. Update this count if sigma or that cut-off changes.
# pixels_per_gauss = (5 + 7 + (9*5) + 7 + 5) * 4 # Circles are built from vertical pixels of 5, 7, 9, 9, 9, 9, 9, 7, 5 (and there are 4 of them)
pixels_per_gauss = (3 + 5 + 5 + 5 + 3) * 4 # Circles are built from vertical pixels of 3, 5, 5, 5, 3
pixels_total = output_size*output_size
weight = (pixels_total - pixels_per_gauss) / pixels_per_gauss
print(weight) # There are this many times more background pixels than gaussian pixels

reset_synth = False # Set to True to regenerate synthetic data

In [None]:
# Example with 7x7 patch from https://medium.com/towards-data-science/human-pose-estimation-with-stacked-hourglass-network-and-tensorflow-c4e9f84fd3ce
# pixels_per_gauss = 7*7
# pixels_total = 64*64
# weight = (pixels_total - pixels_per_gauss) / pixels_per_gauss
# print(weight) # prints ~82.6

# Data Synthesis

In [None]:
def gen_string():
    '''
    Generates a serial number to encode.

    Serial numbers are 15 characters long:
    - Indexes 0, 2 and 4-10 are random digits
    - Indexes 1 and 3 are random uppercase letters
    - Indexes 11-14 are a zero-padded 4-digit number (drawn at random from 0001-0098)

    Example serial number: 4 L 4 N 0418028 0001
    '''
    parts = []

    # First 11 characters: letters at positions 1 and 3, digits elsewhere
    for j in range(11):
        if j in (1, 3):
            parts.append(random.choice(string.ascii_uppercase))
        else:
            parts.append(str(random.randrange(0, 10)))

    # Last 4 characters: always pad to exactly 4 digits (e.g. 7 -> '0007', 42 -> '0042')
    parts.append(str(random.randrange(1, 99)).zfill(4))

    return ''.join(parts)

def encode_image(to_encode, dmc_size):
    '''Render `to_encode` as a Data Matrix code.

    The string is UTF-8 encoded and passed to pylibdmtx; the raw RGB bitmap it returns is
    scaled to a `dmc_size` x `dmc_size` PIL image with nearest-neighbour resampling, so the
    module edges stay hard (no grey interpolation).
    '''
    dmtx = encode(to_encode.encode('utf8'))
    canvas = Image.frombytes('RGB', (dmtx.width, dmtx.height), dmtx.pixels)
    return canvas.resize((dmc_size, dmc_size), Image.NEAREST)

def get_corner_coords(img, debug=False):
    '''Return the four inner corners of the DMC data region.

    The corners sit at an inset of 10% of the image width from each edge (so only the inner
    modules are covered), in the order top-left, top-right, bottom-left, bottom-right.

    Returns:
        (raw_coords, label_info, img): raw_coords is a list of (x, y) pixel tuples,
        label_info the same points as [x / width, y / height] lists, and img the input
        image (converted to RGB with the corners painted red when `debug` is True).
    '''
    inset = int(0.1 * img.width)  # padding around the DMC info zone, in pixels
    near = inset - 1
    far_x = img.width - inset
    far_y = img.height - inset

    corners = [(near, near), (far_x, near), (near, far_y), (far_x, far_y)]

    if debug:
        # Mark each corner red for visual inspection
        img = img.convert('RGB')
        for corner in corners:
            img.putpixel(corner, (255, 0, 0))

    label_info = [[cx / img.width, cy / img.height] for cx, cy in corners]

    return corners, label_info, img

def get_heatmaps_basic(img, raw_coords, heatmap_size, debug=False):
    '''Returns "heatmaps" for each corner of DMC, except they are just single points.

    Args:
        img: image the corners belong to. Only ``width``/``height`` are read unless ``debug``.
        raw_coords: four (x, y) pixel coordinates, e.g. from get_corner_coords.
        heatmap_size: side length of each heatmap. Coordinates are used as-is (not rescaled),
            so this is expected to equal the image size.
        debug: if True, paint each marked pixel red onto ``img`` in place (expects an RGB PIL image).

    Returns:
        float64 numpy array of shape (4, heatmap_size, heatmap_size). Each channel holds a single 1
        at its corner if that corner is an integral pixel inside the image; otherwise it stays all zero.
    '''
    heatmaps = np.zeros((4, heatmap_size, heatmap_size))

    for i in range(4):
        x, y = raw_coords[i][0], raw_coords[i][1]
        col, row = int(x), int(y)
        # Mark the corner directly rather than scanning every image pixel. Only integral
        # coordinates inside the image are marked. No resize is needed: each heatmap is
        # already heatmap_size x heatmap_size.
        if col == x and row == y and 0 <= col < img.width and 0 <= row < img.height:
            heatmaps[i, row, col] = 1

    # Paint heatmaps on image for viz/debug
    if debug:
        debug_heatmaps = [cv2.resize(heatmaps[i], (img.width, img.height)) for i in range(4)]
        for debug_heatmap in debug_heatmaps:
            # Visit only the non-zero pixels (row-major order, same as a full scan)
            rows, cols = np.nonzero(debug_heatmap > 0)
            for r, c in zip(rows.tolist(), cols.tolist()):
                img.putpixel((c, r), (int(debug_heatmap[r, c] * 255), 0, 0))

    return heatmaps

def get_texture_crop(textures_path, size, debug=False):
    '''Gets a random texture image from the given path and returns a random crop.

    Args:
        textures_path: folder containing texture images (every entry is treated as an image).
        size: side length of the square crop; textures must be at least this large.
        debug: if True, print the chosen file name.

    Returns:
        RGB PIL image of size `size` x `size`.
    '''
    # Sort so a given random state always picks the same file, independent of filesystem order
    texture_name = random.choice(sorted(os.listdir(textures_path)))

    if debug:
        print(texture_name)

    # Force 3 channels: RGBA / greyscale / palette textures would otherwise break the 3-channel
    # texture slot that gen_save builds. convert() loads the pixels, so the file can be closed here.
    with Image.open(os.path.join(textures_path, texture_name)) as source:
        texture = source.convert('RGB')

    # Get random crop
    return v2.RandomCrop((size, size))(texture)

In [None]:
# Testing texture crops (debug=True prints the chosen texture's file name)
get_texture_crop('../data/textures/', size=yolo_size, debug=True)

In [None]:
# Testing functions: build one DMC, mark its corners (debug=True paints them red) and show the single-point heatmaps
test = gen_string()
img = encode_image(test, dmc_size=yolo_size)
raw_coords, label_info, img = get_corner_coords(img, debug=True)
print(raw_coords)
print(label_info)
heatmaps = get_heatmaps_basic(img, raw_coords, heatmap_size=yolo_size, debug=True)
print(heatmaps.shape)
print(heatmaps[0].shape)
print(heatmaps[0])

# Display image
display(img)

# Display heatmaps (each is a single 1-pixel point, so it can be hard to spot at this scale)
plt.figure(figsize=(12, 12))
for i in range(4):
    plt.subplot(2, 2, i+1)
    heatmap = heatmaps[i]
    # scale up to input size (no-op here: the heatmaps were built at heatmap_size=yolo_size)
    heatmap = cv2.resize(heatmap, (yolo_size, yolo_size))
    plt.imshow(heatmap)
plt.show()

In [None]:
def _pil_to_chw(pil_img):
    '''Convert an RGB PIL image to a float32 (3, H, W) tensor with values in [0, 1].'''
    pixels = np.array(pil_img) / 255
    return torch.tensor(pixels).float().permute(2, 0, 1)


def gen_save(type, yolo_size):
    '''Generates a random serial number, encodes it into a DMC image, and saves it to train/val/test folders.

    The saved ``<serial>.pt`` tensor stacks, along dim 0:
      - channels 0-2: the DMC image (float, [0, 1])
      - channels 3-5: a random texture crop (float, [0, 1])
      - channels 6-9: one single-point heatmap per corner (top-left, top-right, bottom-left, bottom-right)
    The real augmentation (text, shape/colour transforms, Gaussian heatmaps) happens at load time.

    Args:
        type: destination split folder name ('train', 'val' or 'test'). Shadows the builtin
            ``type``; kept for compatibility with existing callers.
        yolo_size: side length in pixels of the DMC image, texture crop and heatmaps.

    Raises:
        RuntimeError: if a corner lands on (0, 0), which downstream checks use as the
            "corner lost" signal.
    '''
    to_encode = gen_string()
    img = encode_image(to_encode, dmc_size=yolo_size)

    # Get corner values
    raw_coords, label_info, img = get_corner_coords(img)

    # A real corner at (0, 0) would be indistinguishable from the downstream failure signal
    if (0, 0) in raw_coords:
        print(raw_coords)
        raise RuntimeError(f'Corner at (0, 0) for {to_encode}: {raw_coords}')

    # Basic single-point heatmaps; the Gaussian targets are created during augmentation
    heatmaps = torch.tensor(get_heatmaps_basic(img, raw_coords, heatmap_size=yolo_size, debug=False)).float()

    # Random texture crop (size must match the DMC image)
    texture = get_texture_crop('../data/textures/', size=yolo_size)

    # Save image / texture / heatmaps as one tensor
    sample = torch.cat((_pil_to_chw(img), _pil_to_chw(texture), heatmaps), dim=0)
    torch.save(sample, f'../data/hourglass_localization_rectification/{type}/{to_encode}.pt')

def delete_old():
    '''Empty the train, val and test synthetic-data folders (files only; the folders remain).'''
    for folder in ('train', 'val', 'test'):
        stale = os.listdir(f'../data/hourglass_localization_rectification/{folder}')
        for file in stale:
            os.remove(f'../data/hourglass_localization_rectification/{folder}/{file}')

In [None]:
# Delete old synth data (only when reset_synth is set to True in the parameters cell)
if reset_synth:
    delete_old()

In [None]:
# Generating train/val/test datasets.
# Only the samples that are missing are generated, so an interrupted run can be resumed.
for split, target in (('train', n_train), ('val', n_val), ('test', n_test)):
    existing = len(os.listdir(f'../data/hourglass_localization_rectification/{split}'))
    missing = max(0, target - existing)
    print(f'Generating {missing} of {target} {split} images ({existing} already present)...')
    for _ in range(missing):
        gen_save(split, yolo_size=yolo_size)

In [None]:
# Test file
def load_test_file():
    '''Load one stored training sample for inspection.

    Takes the first entry os.listdir returns for the train folder (order depends on the
    filesystem), prints its path and returns the loaded tensor.
    '''
    train_entries = os.listdir('../data/hourglass_localization_rectification/train')
    test_file = '../data/hourglass_localization_rectification/train/' + train_entries[0]
    sample = torch.load(test_file)
    print(test_file)
    return sample

load_test_file()

In [None]:
# Helper for presenting images / creating PIL images
def tensors2PIL(tensor, size, debug=False):
    '''Convert a stacked sample tensor into viewable PIL images.

    Args:
        tensor: channels 0-2 = DMC, 3-5 = texture, 6+ = heatmaps (the first 4 are shown),
            all `size` x `size` with values expected in [0, 1].
        size: spatial side length of every channel.
        debug: if True, also `display` every image.

    Returns:
        (dmc, texture, heatmaps): a greyscale ('L') DMC image, an RGB texture image and a list of
        4 RGB images, each a JET-coloured heatmap blended 50/50 over the DMC.
    '''
    tensor = tensor.cpu()

    # Split into DMC, texture and heatmap channels
    dmc = tensor[:3].numpy()
    texture = tensor[3:6].numpy()
    heatmap_stack = tensor[6:].numpy()

    # Scale to 0-255 and clip BOTH ends before the uint8 cast (out-of-range values would wrap around)
    dmc = np.clip(dmc * 255, 0, 255).astype(np.uint8)
    texture = np.clip(texture * 255, 0, 255).astype(np.uint8)

    # CHW -> HWC
    dmc = dmc.reshape(3, size, size).transpose(1, 2, 0)
    texture = texture.reshape(3, size, size).transpose(1, 2, 0)

    dmc = Image.fromarray(dmc).convert('L')
    texture = Image.fromarray(texture).convert('RGB')

    # Grey DMC as 3 channels, computed once and reused for every overlay
    dmc_rgb = cv2.cvtColor(np.array(dmc), cv2.COLOR_GRAY2RGB)

    heatmaps = []
    for i in range(4):
        heatmap = heatmap_stack[i].reshape(size, size)
        heatmap_u8 = np.clip(heatmap * 255, 0, 255).astype(np.uint8)
        heatmap_colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)  # BGR output
        overlay = cv2.addWeighted(dmc_rgb, 0.5, heatmap_colored, 0.5, 0)  # Blend images
        heatmaps.append(Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)))

    if debug:
        display(dmc)
        display(texture)
        for heatmap_img in heatmaps:
            display(heatmap_img)

    return dmc, texture, heatmaps

# Visual check: load one stored sample and show its DMC, texture and heatmap overlays
test_tensor = load_test_file()
dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

In [None]:
def add_text_n_symbols(tensor_image):
    '''Adds text and symbols randomly around the DMC image.

    The DMC (channels 0-2 of `tensor_image`) is padded with `yolo_size` white pixels on every side.
    This gives a 3x3 grid of yolo_size tiles with the DMC in the centre. Each of the 8 outer tiles has
    a 50% chance of receiving something; if it does, there is a 50/50 split between a random 1-10
    character alphanumeric string (sometimes split over 2-3 lines) and one punctuation symbol. The
    text is drawn in black with PIL's default font at a random size. The result is resized back to
    `yolo_size` x `yolo_size` (nearest neighbour).

    The heatmaps (channels 6+) get the same padding and resize so corners stay aligned. Each of the
    first 4 is then reduced to its brightest pixel(s), set to 1.

    Returns:
        (dmc, heatmaps): dmc is a float32 (3, yolo_size, yolo_size) tensor with values in [0, 1].
    '''
    # Pad DMC with white to leave room for text/symbols on every side
    dmc = v2.functional.pad(tensor_image[:3], (yolo_size, yolo_size, yolo_size, yolo_size), fill=1)
    dmc = v2.functional.to_pil_image(dmc).convert('RGB')

    # One white tile per outer cell of the 3x3 grid
    segments = [Image.new('RGB', (yolo_size, yolo_size), (255, 255, 255)) for _ in range(8)]

    symbols = ['!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '_', '+', '=', '[', ']', '{', '}', '|', ';', ':', '<', '>', ',', '.', '/', '?']

    for i in range(8):
        # 50% chance to add anything; otherwise the tile stays white
        if random.random() <= 0.5:
            continue

        # Integer bounds: random.randint rejects float arguments on Python 3.12+
        font = ImageFont.load_default(random.randint(yolo_size // 8, yolo_size // 2))
        text = ''

        # 50% chance to add text
        if random.random() > 0.5:
            # From 1 to 10 characters of random letters and numbers
            for _ in range(random.randint(1, 10)):
                text += random.choice(string.ascii_letters + string.digits)

            # Even lengths may be split in two lines; lengths 3 and 9 may be split in two/three lines
            if len(text) % 2 == 0:
                if random.random() > 0.5:
                    text = text[:len(text)//2] + '\n' + text[len(text)//2:]

            elif len(text) % 3 == 0:
                if random.random() > 0.5:
                    if len(text) < 9:
                        text = text[:len(text)//3] + '\n' + text[len(text)//3:]
                    else:
                        text = text[:len(text)//3] + '\n' + text[len(text)//3:len(text)//3*2] + '\n' + text[len(text)//3*2:]

        # else add symbol
        else:
            text = random.choice(symbols)

        # Draw text/symbol on segment in random x and y position
        ImageDraw.Draw(segments[i]).text((random.randint(0, yolo_size-100), random.randint(0, yolo_size-50)), text, font=font, fill=(0, 0, 0))

    # Draw segments on DMC image (the centre tile keeps the DMC)
    dmc.paste(segments[0], (0, 0))                     # top left
    dmc.paste(segments[1], (yolo_size, 0))             # top
    dmc.paste(segments[2], (yolo_size*2, 0))           # top right
    dmc.paste(segments[3], (0, yolo_size))             # left
    dmc.paste(segments[4], (yolo_size*2, yolo_size))   # right
    dmc.paste(segments[5], (0, yolo_size*2))           # bottom left
    dmc.paste(segments[6], (yolo_size, yolo_size*2))   # bottom
    dmc.paste(segments[7], (yolo_size*2, yolo_size*2)) # bottom right

    # Resize to input size
    dmc = dmc.resize((yolo_size, yolo_size), Image.NEAREST)

    # Convert to float tensor in [0, 1]
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True)
    ])
    dmc = transform(dmc)

    # Pad and resize heatmaps identically so the corners stay aligned with the DMC
    heatmaps = tensor_image[6:]
    heatmaps = v2.functional.pad(heatmaps, (yolo_size, yolo_size, yolo_size, yolo_size), fill=0)
    heatmaps = v2.functional.resize(heatmaps, (yolo_size, yolo_size))

    # For each heatmap keep the brightest pixel only and make it 1
    for i in range(4):
        heatmap = heatmaps[i]
        heatmap[heatmap != heatmap.max()] = 0
        heatmap[heatmap == heatmap.max()] = 1
        heatmaps[i] = heatmap

    return dmc, heatmaps

# Visual check of add_text_n_symbols on one stored sample
test_tensor = load_test_file()
test_tensor[:3], test_tensor[6:] = add_text_n_symbols(test_tensor)
dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

In [None]:
def flood_fill(tensor_image, x, y):
    '''Paint the 4-connected black region containing pixel (x, y) white, in place.

    Args:
        tensor_image: 2-D tensor. Pixels equal to 0 are "black"; every black pixel reachable
            from the start through up/down/left/right black neighbours is set to 1.
        x: row index of the start pixel.
        y: column index of the start pixel.

    Returns:
        `tensor_image` itself (modified in place). Nothing changes if the start pixel is out of
        bounds or not black.
    '''
    rows, cols = tensor_image.shape[-2], tensor_image.shape[-1]
    if not (0 <= x < rows and 0 <= y < cols):
        return tensor_image

    # Work on a NumPy bool mask: indexing a torch tensor element by element from Python is
    # orders of magnitude slower than NumPy indexing.
    black = (tensor_image == 0).cpu().numpy()
    if not black[x, y]:
        return tensor_image

    filled = np.zeros((rows, cols), dtype=bool)
    filled[x, y] = True
    stack = [(x, y)]
    while stack:
        cx, cy = stack.pop()
        for nx, ny in ((cx - 1, cy), (cx + 1, cy), (cx, cy - 1), (cx, cy + 1)):
            # Mark pixels when pushed, so each one enters the stack at most once
            if 0 <= nx < rows and 0 <= ny < cols and black[nx, ny] and not filled[nx, ny]:
                filled[nx, ny] = True
                stack.append((nx, ny))

    tensor_image[torch.from_numpy(filled).to(tensor_image.device)] = 1
    return tensor_image

def _binarize_corner_heatmaps(heatmaps, generation_error, message=None):
    '''Reduce each of the 4 heatmaps (in place) to a 0/1 mask of its brightest pixel(s); return the updated error flag.

    A mask whose first nonzero pixel is (0, 0) is treated as a lost corner. An all-zero heatmap becomes
    all ones and so trips this check. `message` (if given) is printed when the flag is first raised.
    '''
    for i in range(4):
        brightest = heatmaps[i].max()
        torch_mask = torch.where(heatmaps[i] == brightest, torch.tensor(1.0), torch.tensor(0.0))
        if not generation_error and torch_mask.nonzero()[0].tolist() == [0, 0]:
            if message is not None:
                print(message)
            generation_error = True
        heatmaps[i] = torch_mask
    return generation_error


def shape_transform(tensor, training_mode=False, epoch=None, max_complex_epoch=max_complex_epoch, debug=False):
    '''Applies random shape transformations to image.

    The DMC (channels 0-2) and the 4 corner heatmaps (channels 6-9) are stacked and warped together
    (random rotation with expansion, perspective, shear/scale/translate), resized back to the original
    size, and re-binarized. Outer black regions introduced by the warp are flood-filled white.

    Args:
        tensor: sample tensor as saved by gen_save (after add_text_n_symbols).
        training_mode: if True, scale/translate ranges follow a curriculum based on `epoch`:
            none before epoch 10, a linear ramp up to `max_complex_epoch`, full range afterwards.
            If False, the full range is always used.
        epoch: current epoch; required when `training_mode` is True.
        max_complex_epoch: epoch at which the full augmentation range is reached.
        debug: print the scale and translate ranges used.

    Returns:
        (dmc, heatmaps, generation_error): dmc is a 3-channel 0/1 tensor, heatmaps a 0/1 mask per
        corner, and generation_error is True when a corner looks lost.

    Raises:
        ValueError: if `training_mode` is True and `epoch` is None.
    '''
    generation_error = False

    # Split into DMC and heatmaps
    dmc = tensor[:3]
    heatmaps = tensor[6:]

    # Temporarily downsize dmc to 1 channel to allow for next steps (IMPORTANT: THIS WILL SUCK IF SHAPE TRANSFORM SOMEHOW GETS USED AFTER COLOR TRANSFORM)
    dmc = dmc[0]

    # Stack so every geometric transform is applied identically to the DMC and the heatmaps
    batched = torch.stack((dmc, heatmaps[0], heatmaps[1], heatmaps[2], heatmaps[3]), dim=0)

    generation_error = _binarize_corner_heatmaps(batched[1:], generation_error, 'generation_error - before transforms')

    # Defining max scale and translations
    max_scale = (0.8, 1.8)
    max_translate = (0.2, 0.2)

    if training_mode:
        if epoch is None:
            raise ValueError('shape_transform: epoch is required when training_mode=True')
        if epoch < 10:
            scale = (1, 1)
            translate = (0, 0)
        elif epoch < max_complex_epoch:
            # Linear ramp from no augmentation at epoch 10 to the full range at max_complex_epoch.
            # Both scale bounds move away from 1 towards their own limits, so the range grows
            # continuously into max_scale.
            progress = (epoch - 10) / (max_complex_epoch - 10)
            scale = (1 + (max_scale[0] - 1) * progress, 1 + (max_scale[1] - 1) * progress)
            translate = (max_translate[0] * progress, max_translate[1] * progress)
        else:
            scale = max_scale
            translate = max_translate
    else:
        # If not in training mode, just use max values
        scale = max_scale
        translate = max_translate

    if debug:
        print(scale)
        print(translate)

    # Applying transformations to all tensors
    transforms = v2.Compose([
        v2.RandomRotation(180, # Random rotation (-180 to 180 degrees, so full rotation)
                          interpolation=Image.BILINEAR,
                          expand=True),
        v2.RandomPerspective(distortion_scale=0.5,
                             p=0.5,
                             interpolation=Image.BILINEAR,
                             ),
        v2.RandomAffine(degrees=0, # No rotation
                        shear=(-20, 20, -20, 20), # Random shear on x and y axis (squish)
                        interpolation=Image.BILINEAR,
                        scale=scale, # Random scaling (zoom in/out)
                        translate=translate, # Random translation (move dmc around a bit)
                        ),
    ])
    batched = transforms(batched)

    # Resize all back to the original size (expand=True may have enlarged the canvas)
    dmc_width, dmc_height = dmc.shape[-1], dmc.shape[-2]
    batched = v2.functional.resize(batched, (dmc_height, dmc_width), interpolation=Image.BILINEAR)

    # Binarize dmc
    dmc = torch.where(batched[0] > 0.5, torch.tensor(1.0), torch.tensor(0.0))

    generation_error = _binarize_corner_heatmaps(batched[1:], generation_error)

    # Fill outer black areas of dmc with white, starting from each image corner
    last_row, last_col = dmc.shape[-2] - 1, dmc.shape[-1] - 1
    for row, col in ((0, 0), (0, last_col), (last_row, 0), (last_row, last_col)):
        dmc = flood_fill(dmc, row, col)

    # Split back into DMC and heatmaps
    dmc = dmc.unsqueeze(0)
    heatmaps = batched[1:]

    # Binarize heatmaps so that brightest pixel is 1 and rest are 0
    generation_error = _binarize_corner_heatmaps(heatmaps, generation_error, 'generation_error - after paintbucket')

    # Add back color channels to DMC
    dmc = torch.cat((dmc, dmc, dmc), dim=0)

    return dmc, heatmaps, generation_error

# Visual check: text/symbols, then a shape transform at epoch 150 (not < max_complex_epoch, so the full scale/translate range is used)
test_tensor = load_test_file()
test_tensor[:3], test_tensor[6:] = add_text_n_symbols(test_tensor)
test_tensor[:3], test_tensor[6:], generation_error = shape_transform(test_tensor, training_mode=True, epoch=150, max_complex_epoch=max_complex_epoch, debug=True)
dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

In [None]:
def combine_images(tensor):
    '''Blend the DMC (channels 0-2) onto the texture (channels 3-5).

    Dark DMC pixels (< 0.5) are lifted by one random amount drawn from U(0.1, 0.5), so the
    code is not pure black. The texture is then multiplied channel-wise by the DMC.
    Returns a new (3, H, W) tensor; the input tensor is left unchanged.
    '''
    code = tensor[:3]
    background = tensor[3:6]

    lift = random.uniform(0.1, 0.5)
    code = torch.where(code < 0.5, code + lift, code)

    return background * code

# Visual check of combine_images: the first three channels are replaced by the DMC-on-texture blend
test_tensor = load_test_file()
test_tensor[:3] = combine_images(test_tensor)
combined_dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

In [None]:
def color_transform(tensor):
    '''Apply random photometric augmentation to a single image tensor.

    A batch dimension is added for the v2 pipeline and removed again, so the output has the
    input's shape. Python's `random` picks the blur kernel size (3 or 5) and the sharpness
    factor (0.5-1.5). torchvision then samples the photometric distortion, the blur sigma
    (0.1-25, so anything from barely to heavily blurred) and whether sharpening is applied (p=0.5).
    '''
    kernel_size = random.choice([3, 5])
    sharpness = random.uniform(0.5, 1.5)

    photometric = v2.RandomPhotometricDistort(brightness = (0.5, 1.5),
                                              contrast   = (0.5, 1.5),
                                              saturation = (0.5, 1.5),
                                              hue        = (-0.5, 0.5),
                                              )
    blur = v2.GaussianBlur(kernel_size=kernel_size, sigma=(0.1, 25))
    sharpen = v2.RandomAdjustSharpness(sharpness, 0.5)

    pipeline = v2.Compose([photometric, blur, sharpen])

    return pipeline(tensor.unsqueeze(0)).squeeze(0)

# Testing color_transform (DMC and texture are transformed independently, each with its own random parameters)
test_tensor = load_test_file()
test_tensor[:3] = color_transform(test_tensor[:3])
test_tensor[3:6] = color_transform(test_tensor[3:6])
dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

In [None]:
# Testing all transforms, in the same order DMCDataset.__getitem__ applies them (before its YOLO crop).
# training_mode=False below means shape_transform uses its full scale/translate range.
test_tensor = load_test_file()

print('DMC TEXTURE PAIR')
# dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

print('TEXT AND SYMBOLS')
test_tensor[:3], test_tensor[6:] = add_text_n_symbols(test_tensor)
# dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

print('SHAPE TRANSFORM')
test_tensor[:3], test_tensor[6:], generation_error = shape_transform(test_tensor, training_mode=False, debug=True)
# dmc, texture, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

if generation_error:
    print('GENERATION ERROR')

print('COMBINE IMAGES')
dmc = combine_images(test_tensor)
# dmc_img, texture_img, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

print('COLOR TRANSFORM')
dmc = color_transform(dmc)

# Overwrite for the sake of seeing: the DMC slot now holds the final combined + colour-transformed image
test_tensor[:3] = dmc
print('FINAL SIZES')
print(dmc.shape)
print(test_tensor[6:].shape)
dmc_img, texture_img, heatmaps = tensors2PIL(test_tensor, size=yolo_size, debug=True)

In [None]:
# Helper function for resizing heatmaps
def resize_heatmaps(heatmaps, size):
    '''Rescale 4 single-point heatmaps to `size` x `size`.

    For each of the first 4 heatmaps, the (first) brightest pixel is located. Its (row, col) is
    scaled proportionally per axis to the new size, truncating toward zero, and a single 1 is
    placed there. Works for non-square inputs. Returns a float32 tensor of shape (4, size, size).
    '''
    resized = np.zeros((4, size, size))
    for channel in range(4):
        source = heatmaps[channel]
        peak = np.argmax(source.cpu().numpy())
        row, col = np.unravel_index(peak, source.shape)
        new_row = int(row * size / source.shape[0])
        new_col = int(col * size / source.shape[1])
        resized[channel, new_row, new_col] = 1
    return torch.tensor(resized).float()

# Actually creating the heatmaps from the pixel heatmaps
def create_heatmap(heatmaps, size, sigma, threshold=0.05):
    '''Render single-point corner heatmaps into ONE Gaussian target heatmap.

    Args:
        heatmaps: iterable of 2-D tensors; the (first) brightest pixel of each is a keypoint.
            Its (row, col) is used directly as the output position, so the inputs are expected
            to be `size` x `size` already (see resize_heatmaps).
        size: side length of the output heatmap.
        sigma: standard deviation of the Gaussian, in output pixels.
        threshold: Gaussian values below this are not drawn, truncating the tails (default 0.05).

    Returns:
        float32 tensor of shape (size, size). All keypoints share this single map. Where two
        Gaussians overlap, the later keypoint's values overwrite the earlier ones.
    '''
    heatmap = np.zeros((size, size))
    # Row / column index of every output pixel, so each Gaussian is evaluated in one vectorized step
    rows, cols = np.mgrid[0:size, 0:size]
    for pixel_heatmap in heatmaps:
        # Center of keypoint
        y, x = np.unravel_index(np.argmax(pixel_heatmap.cpu().numpy()), pixel_heatmap.shape)
        value = np.exp(-((cols - x) ** 2 + (rows - y) ** 2) / (2 * sigma ** 2))
        keep = value >= threshold
        heatmap[keep] = value[keep]

    return torch.tensor(heatmap).float()

In [None]:
# Find latest YOLO model from "trainXX" folder.
# Run folders are expected to be named 'train' (first run), then 'train2', 'train3', ...
_run_suffixes = [d[len('train'):] for d in os.listdir('../yolo/runs/obb') if d.startswith('train')]
_numbered_runs = [s for s in _run_suffixes if s.isdecimal()]
if _numbered_runs:
    model_to_train = max(_numbered_runs, key=int)  # highest-numbered run is the latest
elif '' in _run_suffixes:
    model_to_train = ''  # only the un-numbered first run exists
else:
    raise FileNotFoundError('No "train*" run folders found in ../yolo/runs/obb')
print(f'Training YOLO from train{model_to_train}')

YOLO_model = YOLO(f'../yolo/runs/obb/train{model_to_train}/weights/best.pt')
YOLO_model.eval() # Set to eval mode

def save_yolo_failure(image, heatmap):
    '''Saves image and heatmap to yolo failures folder for further training.

    Args:
        image: tensor of shape (1, 3, H, W) with values in [0, 1] (the batched image YOLO_crop builds).
        heatmap: tensor of shape (4, H, W); the first brightest pixel of each channel is used as a corner.

    Writes ``<idx>.jpg`` to yolo_failures/train/images/ and a matching ``<idx>.txt`` to
    yolo_failures/train/labels/. idx is one past the highest numeric image name present (0 if empty).
    The label line is ``0 x1 y1 x2 y2 x3 y3 x4 y4``: coordinates normalized by image width/height,
    corners in heatmap channel order.
    '''
    # Convert image to 8-bit PIL; .cpu() allows GPU tensors, clipping avoids uint8 wrap-around
    image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
    image = Image.fromarray(image)

    # Filename is latest index filename + 1
    existing = os.listdir('../data/hourglass_localization_rectification/yolo_failures/train/images/')
    if len(existing) == 0:
        idx = 0
    else:
        idx = max(int(name.split('.')[0]) for name in existing) + 1

    image.save(f'../data/hourglass_localization_rectification/yolo_failures/train/images/{idx}.jpg')

    # [row, col] of the first brightest pixel in each heatmap channel
    corners = [(heatmap[i] == torch.max(heatmap[i])).nonzero().tolist()[0] for i in range(4)]

    # PIL reports size as (width, height)
    width, height = image.size
    coords = ' '.join(f'{col / width} {row / height}' for row, col in corners)

    # Save label in YOLO format (class_index x1 y1 x2 y2 x3 y3 x4 y4)
    with open(f'../data/hourglass_localization_rectification/yolo_failures/train/labels/{idx}.txt', 'w') as f:
        f.write(f'0 {coords}')

def YOLO_crop(model, image, heatmap):
    '''Uses YOLO model to detect DMC in image tensor, and crops both image and heatmap tensors accordingly
    Returns cropped image and heatmap tensors and error flag for if no DMC was found

    Args:
        model: Ultralytics YOLO OBB model; only the first detected box (``obb.xywhr[0]``) is used.
        image: tensor whose first 3 channels form the RGB image (presumably (3, H, W) floats in
            [0, 1] - TODO confirm against callers).
        heatmap: corner heatmaps aligned with ``image``; 4 channels are indexed below, so this is
            expected to be (4, H, W).

    Returns:
        (image_crop, heatmap_crop, False) on success: image_crop is (3, h, w) and each of the 4
        heatmap_crop channels holds exactly one 1. (None, None, True) when YOLO finds nothing or a
        corner does not survive the crop. In both failure cases the uncropped sample is first
        passed to save_yolo_failure.
    '''
    # Combine image tensors into one RGB tensor
    image = torch.stack((image[0], image[1], image[2]), dim=0)

    # Add batch dimension
    image = image.unsqueeze(0)

    # Run YOLO model on image
    results = model.predict(image, verbose=False)

    # If no detections, flag an error and save image and label for further training of YOLO
    if results[0].obb is None or len(results[0].obb.xywhr) == 0:
        print('No detections')
        save_yolo_failure(image, heatmap)
        return None, None, True

    # Crop image tensor to bounding box.
    # (x, y, w, h) are treated as pixel units of the input image and r as an angle in radians.
    x, y, w, h, r = results[0].obb.xywhr[0]
    cos_r, sin_r = torch.cos(r), torch.sin(r)
    H, W = image.shape[2], image.shape[3]
    pixel_pad = W * 0.1 # 10% of image size
    h, w = h + pixel_pad, w + pixel_pad # Add padding to height and width (total, split over both sides)
    # Box centre and half-extents in the normalized [-1, 1] coordinates used by affine_grid
    tx = 2*x / W - 1
    ty = 2*y / H - 1
    sx = w / W
    sy = h / H
    # theta = R(r) @ diag(sx, sy) plus translation: maps each output (crop) position to the rotated box in the input.
    # NOTE(review): rotating in normalized coordinates is only geometrically exact when H == W (square input).
    theta = torch.tensor([
        [cos_r * sx, -sin_r * sy, tx],
        [sin_r * sx, cos_r * sy, ty]
        ], dtype=torch.float32)
    theta = theta.unsqueeze(0)
    h, w = int(h.item()), int(w.item())
    grid = F.affine_grid(theta, (1, image.shape[1], h, w), align_corners=True)
    image_crop = F.grid_sample(image, grid, align_corners=True)
    image_crop = image_crop.squeeze(0)

    # Crop heatmap tensors to bounding box (same theta, so corners stay aligned with the image crop)
    # theta = theta.repeat(4, 1, 1) # Repeat theta for each heatmap
    grid = F.affine_grid(theta, (1, 4, h, w), align_corners=True)
    heatmap_crop = heatmap.unsqueeze(0)
    heatmap_crop = F.grid_sample(heatmap_crop, grid, align_corners=True)
    heatmap_crop = heatmap_crop.squeeze(0)

    # Keep only the brightest pixel in each heatmap (bilinear sampling spreads/attenuates the single point)
    for i in range(4):
        brightest = heatmap_crop[i].max()
        torch_mask = torch.where(heatmap_crop[i] == brightest, torch.tensor(1.0), torch.tensor(0.0))
        heatmap_crop[i] = torch_mask
    # for i in range(4):
    #     print(f'heatmap {i} nonzero: {torch.count_nonzero(heatmap[i])}')

    # # Display image (DEBUG)
    # image_crop = image_crop.squeeze(0).permute(1, 2, 0).numpy()
    # image_crop = (image_crop * 255).astype(np.uint8)
    # image_crop = Image.fromarray(image_crop)
    # display(image_crop)

    # # Display grayscale heatmaps (DEBUG)
    # plt.figure(figsize=(12, 12))
    # for i in range(4):
    #     plt.subplot(2, 2, i+1)
    #     heatmap_i = heatmap_crop[i].squeeze(0).numpy()
    #     heatmap_i = (heatmap_i * 255).astype(np.uint8)
    #     heatmap_i = Image.fromarray(heatmap_i)
    #     plt.imshow(heatmap_i)
    #     plt.axis('off')
    # plt.show()

    # Validate that each heatmap still has a single bright pixel.
    # A corner cropped away leaves an all-zero channel, whose mask is all ones and fails this check.
    for i in range(4):
        brightest = heatmap_crop[i].max()
        torch_mask = torch.where(heatmap_crop[i] == brightest, torch.tensor(1.0), torch.tensor(0.0))
        if torch_mask.nonzero().shape[0] != 1:
            print('generation_error - after YOLO crop')
            save_yolo_failure(image, heatmap)
            # Display image
            image_crop = image_crop.squeeze(0).permute(1, 2, 0).numpy()
            image_crop = (image_crop * 255).astype(np.uint8)
            image_crop = Image.fromarray(image_crop)
            display(image_crop)
            return None, None, True

    return image_crop, heatmap_crop, False

# Try the YOLO crop on the fully augmented sample from the "Testing all transforms" cell (its test_tensor[:3] holds the final image)
image, heatmaps, yolo_err = YOLO_crop(YOLO_model, test_tensor[:3], test_tensor[6:])

In [None]:
# Clearing YOLO failure folders: remove saved images and labels, plus labels.cache if present, for each split
for folder in ['train', 'val']:
    doomed = [f'../data/hourglass_localization_rectification/yolo_failures/{folder}/images/{file}'
              for file in os.listdir(f'../data/hourglass_localization_rectification/yolo_failures/{folder}/images/')]
    doomed += [f'../data/hourglass_localization_rectification/yolo_failures/{folder}/labels/{file}'
               for file in os.listdir(f'../data/hourglass_localization_rectification/yolo_failures/{folder}/labels/')]
    for path in doomed:
        os.remove(path)

    cache_path = f'../data/hourglass_localization_rectification/yolo_failures/{folder}/labels.cache'
    if os.path.exists(cache_path):
        os.remove(cache_path)

In [None]:
# Dataloader
class DMCDataset(Dataset):
    '''Dataset that builds a freshly augmented (image, heatmap) pair from each stored .pt sample.

    Each __getitem__ reloads the stored tensor and runs add_text_n_symbols -> shape_transform
    (repeated while it reports a generation error) -> combine_images -> color_transform -> YOLO_crop.
    The whole pipeline is repeated while YOLO_crop fails. The crop is resized to input_size, and
    the corners become one Gaussian heatmap of output_size.

    Args:
        image_dir: folder containing the .pt samples written by gen_save.
        training_mode, epoch, max_complex_epoch: forwarded to shape_transform to control
            augmentation strength.
    '''

    def __init__(self, image_dir, training_mode=False, epoch=None, max_complex_epoch=max_complex_epoch):
        self.image_dir = image_dir
        # Sorted so an index always maps to the same file, independent of filesystem order
        self.files = sorted(os.listdir(image_dir))
        self.training_mode = training_mode
        self.epoch = epoch
        self.max_complex_epoch = max_complex_epoch

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        '''Return (image, heatmap, synth_err_count) for sample `idx`.

        synth_err_count is the number of shape_transform attempts in the final (successful) pass.

        Raises:
            RuntimeError: if shape_transform or YOLO_crop keeps failing (100 attempts each).
        '''
        yolo_err_count = 0
        yolo_max_err = 100
        synth_max_err = 100

        while True:
            # Shape transform dmc and heatmaps, retrying while a corner gets lost
            synth_err = True
            synth_err_count = 0
            while synth_err:
                # Load the tensor containing dmc, texture, and heatmaps
                tensor = torch.load(f'{self.image_dir}/{self.files[idx]}')
                tensor[:3], tensor[6:] = add_text_n_symbols(tensor) # Add text and symbols
                tensor[:3], tensor[6:], synth_err = shape_transform(tensor, training_mode=self.training_mode, epoch=self.epoch, max_complex_epoch=self.max_complex_epoch)
                synth_err_count += 1
                if synth_err and synth_err_count >= synth_max_err:
                    raise RuntimeError(f'shape_transform failed {synth_err_count} times in a row for {self.files[idx]}')

            # Combine dmc and texture pair
            dmc = combine_images(tensor)

            # Color transform dmc
            dmc = color_transform(dmc)

            # Crop image and heatmaps using YOLO
            dmc, heatmaps, yolo_err = YOLO_crop(YOLO_model, dmc, tensor[6:])
            if not yolo_err:
                break

            yolo_err_count += 1
            if yolo_err_count >= yolo_max_err:
                print('YOLO max error reached')
                raise RuntimeError(f'YOLO_crop failed {yolo_err_count} times for {self.files[idx]}')

        # Resize image to input size
        dmc = v2.functional.resize(dmc, (input_size, input_size), interpolation=Image.BILINEAR)

        # Resize heatmaps to output size
        heatmaps = resize_heatmaps(heatmaps, output_size)

        # Create actual heatmaps from binarized heatmaps
        heatmap = create_heatmap(heatmaps, output_size, sigma)

        return dmc, heatmap, synth_err_count

def define_train_loader(epoch):
    '''Return a shuffled synthetic training DataLoader for curriculum stage ``epoch``.

    A fresh ``DMCDataset`` is built in training mode, so ``shape_transform``
    receives ``epoch`` together with the global ``max_complex_epoch``.
    '''
    dataset_kwargs = dict(
        image_dir='../data/hourglass_localization_rectification/train',
        training_mode=True,
        epoch=epoch,
        max_complex_epoch=max_complex_epoch,
    )
    return torch.utils.data.DataLoader(
        DMCDataset(**dataset_kwargs),
        batch_size=batch_size,
        shuffle=True,
    )

train_loader = define_train_loader(epoch=0)

# Sanity-check the dataloader on a single batch
for dmc, heatmap, count in train_loader:
    for value in (dmc.shape, heatmap.shape, count):
        print(value)
    print(f'{count.sum().item()-batch_size} generation errors')

    # Display the first dmc of the batch, then its heatmap
    first_dmc = dmc[0].squeeze(0).permute(1, 2, 0)
    first_heatmap = heatmap[0].squeeze(0)
    for picture in (first_dmc, first_heatmap):
        plt.figure(figsize=(5, 5))
        plt.axis('Off')
        plt.imshow(picture)
        plt.show()
    break

In [None]:
# Validation and test sets use DMCDataset's defaults (training_mode=False, epoch=None)
val_dataset = DMCDataset(image_dir='../data/hourglass_localization_rectification/val')
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

test_dataset = DMCDataset(image_dir='../data/hourglass_localization_rectification/test')
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Stacked Hourglass Setup

In [None]:
# Example usage with dataloader: push one batch through a freshly initialised network
model = StackedHourglassNetwork(num_stacks=n_stacks, num_features=input_size, num_output_points=1)

for images, heatmaps, count in train_loader:
    for shape in (images.shape, heatmaps.shape):
        print(shape)
    print(f'{count.sum().item()-batch_size} generation errors')
    outputs = model(images)
    output = next(iter(outputs))  # first element of the model output (first image)
    print(output.shape) # Expected: (8, 1, H, W) representing 1 heatmap per image
    break
print(output)

In [None]:
# Convert tensor to numpy for visualization.
# `output` is expected to be (n_stacks, 1, H, W), so output[0, 0] is the
# first stack's single 2-D heatmap.
heatmap = output[0, 0].detach().cpu().numpy()  # Shape: (H, W)
print(heatmap.shape)

plt.figure(figsize=(5, 5))
plt.imshow(heatmap, cmap='jet')
plt.title('Produced Heatmap')
plt.axis('off')

plt.show()

In [None]:
def extract_keypoints(heatmap, input_size, output_size, num_keypoints=4, suppression_radius=2):
    '''Return the ``num_keypoints`` strongest peaks of a heatmap, scaled to input resolution.

    Peaks are picked greedily. After each pick, a square window of
    ``suppression_radius`` pixels around it is cleared, so the next pick comes
    from a different blob rather than a neighbouring pixel of the same one.

    Args:
        heatmap: 2-D array-like (H, W) at ``output_size`` resolution. Leading
            singleton axes, e.g. (1, H, W), are squeezed away. Not modified.
        input_size: Resolution the keypoints are mapped to.
        output_size: Resolution of ``heatmap``.
        num_keypoints: Number of peaks to extract.
        suppression_radius: Half-width, in heatmap pixels, of the window cleared
            around each selected peak. The default of 2 covers the ~5x5 core of
            a sigma=1 Gaussian. Use 0 to clear only the peak pixel.

    Returns:
        List of ``(x, y)`` integer tuples in ``input_size`` coordinates, strongest first.

    Raises:
        ValueError: if the heatmap is not 2-D after squeezing leading singleton axes.
    '''
    # Float copy: leaves the caller's array untouched and can hold -inf
    heatmap = np.array(heatmap, dtype=np.float64)
    while heatmap.ndim > 2 and heatmap.shape[0] == 1:
        heatmap = heatmap[0]
    if heatmap.ndim != 2:
        raise ValueError(f'Expected a 2-D heatmap, got shape {heatmap.shape}')

    scale = input_size / output_size  # Scale factor to match original resolution
    height, width = heatmap.shape
    keypoints = []

    for _ in range(num_keypoints):
        y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        keypoints.append((int(x * scale), int(y * scale))) # Scale keypoints
        # Clear with -inf rather than 0: predicted heatmaps can be negative,
        # and a 0 would then become the maximum again.
        y0, y1 = max(y - suppression_radius, 0), min(y + suppression_radius + 1, height)
        x0, x1 = max(x - suppression_radius, 0), min(x + suppression_radius + 1, width)
        heatmap[y0:y1, x0:x1] = -np.inf

    return keypoints

def heatmap_viz(image, keypoints, heatmap, alpha=0.5):
    '''Show an image with its keypoints next to the image blended with a heatmap.

    Args:
        image: RGB ``uint8`` array of shape (H, W, 3).
        keypoints: Iterable of integer ``(x, y)`` positions in ``image`` pixel coordinates.
        heatmap: 2-D float array; resized to the image size and min-max normalised.
        alpha: Weight of the coloured heatmap in the overlay (0 shows only the image).
    '''
    # Combined figure
    _, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Original image with keypoints (OpenCV draws in BGR, so convert there and back)
    image_kp = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)] # Colors are blue, green, red, yellow (BGR)
    for i, (x, y) in enumerate(keypoints):
        cv2.circle(image_kp, (int(x), int(y)), 5, colors[i % len(colors)], -1)  # Draw colored circle
    axes[0].imshow(cv2.cvtColor(image_kp, cv2.COLOR_BGR2RGB))
    axes[0].axis('off')
    axes[0].set_title('Original Image')

    # Heatmaps
    heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))  # Resize to input image size
    h_min, h_max = heatmap_resized.min(), heatmap_resized.max()
    if h_max > h_min:
        heatmap_resized = (heatmap_resized - h_min) / (h_max - h_min)  # Normalize to [0, 1]
    else:
        # Constant heatmap: avoid 0/0 = NaN and show it as all-zero
        heatmap_resized = np.zeros_like(heatmap_resized)
    heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)  # BGR output
    # `image` is RGB, so bring the colormap to RGB before blending; otherwise the
    # image's red and blue channels end up swapped in the overlay.
    heatmap_rgb = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(image, 1 - alpha, heatmap_rgb, alpha, 0)  # Blend images
    axes[1].imshow(overlay)
    axes[1].axis('off')
    axes[1].set_title('Heatmap')

    plt.show()

In [None]:
# Example visualization usage with dataloader: keypoints from a freshly initialised network
model = StackedHourglassNetwork(num_stacks=n_stacks, num_features=input_size, num_output_points=1)

images, labels, count = next(iter(train_loader))
outputs = model(images)

# First image -> first stack -> first heatmap channel, as numpy
heatmaps = outputs[0][0][0].detach().cpu().numpy()
print(heatmaps.shape)
keypoints = extract_keypoints(heatmaps, input_size=input_size, output_size=output_size)
print('Predicted keypoints:', keypoints)

# Channel-first float image -> channel-last uint8 for plotting
img = (images[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
heatmap_viz(img, keypoints, heatmaps)

In [None]:
# Example visualization of ground-truth keypoints and heatmaps (first batch only)
for images, labels, count in train_loader:
    for i in range(len(images)):
        image = (images[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        heatmaps = labels[i].detach().cpu().numpy()
        print(heatmaps.shape)
        keypoints = extract_keypoints(heatmaps, input_size=input_size, output_size=output_size)
        print('True keypoints:', keypoints)
        heatmap_viz(image, keypoints, heatmaps)
    break

# Model Training

In [None]:
class EarlyStopper:
    '''Signal early stopping once the validation loss stops improving.

    Tracks the lowest validation loss seen so far:

    - a loss strictly below it becomes the new minimum and resets the counter;
    - a loss above ``minimum + min_delta`` increments the counter;
    - anything in between (or NaN) leaves the counter unchanged.

    ``early_stop`` returns True once the counter reaches ``patience``.
    '''

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        '''Record ``validation_loss`` and report whether training should stop.'''
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        threshold = self.min_validation_loss + self.min_delta
        if not validation_loss > threshold:
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
def calculate_loss(outputs, heatmaps, pos_weight=None):
    '''Weighted MSE between every hourglass stack's prediction and the target heatmap.

    For each sample and each stack, the loss is the mean over (keypoints, H, W)
    of ``(target - prediction)**2 * w``. Here ``w = pos_weight + 1`` on pixels
    where the target is > 0 and ``w = 1`` elsewhere. Stack losses are summed, and
    so are sample losses across the batch. With batch size 1 this equals the
    original single-sample loss.

    Args:
        outputs: Predictions of shape (batch, num_stacks, num_keypoints, H, W).
        heatmaps: Targets of shape (batch, num_keypoints, H, W) or (batch, H, W).
        pos_weight: Extra weight for target (Gaussian) pixels. Defaults to the
            global ``weight`` (background-to-Gaussian pixel ratio).

    Returns:
        Scalar loss tensor.
    '''
    if pos_weight is None:
        pos_weight = weight

    # Accept (B, H, W) targets by inserting the keypoint axis
    if heatmaps.dim() == outputs.dim() - 2:
        heatmaps = heatmaps.unsqueeze(1)

    # (B, 1, K, H, W): broadcast the target over the stack axis
    target = heatmaps.unsqueeze(1)

    # 1 for background pixels, pos_weight + 1 for pixels inside a Gaussian
    weights = (target > 0).float() * pos_weight + 1

    per_stack = ((target - outputs) ** 2 * weights).mean(dim=(-3, -2, -1))  # (B, num_stacks)
    return per_stack.sum()

In [None]:
# Loss sanity check: an all-zero prediction against a target with a single 1 at the centre pixel
centre = 32
outputs = torch.zeros((batch_size, n_stacks, 1, 64, 64), device=device)
outputs[..., centre, centre] = 0  # prediction at the target pixel (0 = complete miss)
heatmaps = torch.zeros((batch_size, 1, 64, 64), device=device)
heatmaps[..., centre, centre] = 1

# Calculate loss
print(f'outputs from model: {outputs.shape}')
print(f'heatmaps pre loss: {heatmaps.shape}')
loss = calculate_loss(outputs, heatmaps)
print(loss)

In [None]:
# Training function
def train_model(model, early_stopper, train_loader, val_loader, epochs=10, lr=0.001, synthetic_training=False):
    '''Train ``model`` with RMSprop and restore the weights of the best validation epoch.

    Args:
        model: Network whose output is accepted by ``calculate_loss``.
        early_stopper: ``EarlyStopper``. It is consulted only once
            ``epoch > max_complex_epoch + early_stopper.patience`` (global ``max_complex_epoch``).
        train_loader: Training DataLoader. When ``synthetic_training`` is True it is
            replaced every epoch by ``define_train_loader(epoch)``.
        val_loader: Validation DataLoader.
        epochs: Maximum number of epochs.
        lr: RMSprop learning rate.
        synthetic_training: Rebuild the synthetic training loader each epoch so the curriculum advances.

    Returns:
        ``(model, losses, lowest_val_loss)``. ``model`` has the best-validation weights
        loaded. ``losses`` is ``{'train': [...], 'val': [...]}`` with per-epoch mean
        batch losses. ``lowest_val_loss`` is the lowest summed validation loss.
    '''
    import copy

    model.to(device)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)

    losses = {'train': [], 'val': []} # Store losses

    lowest_val_loss = float('inf')
    best_epoch = 0
    generation_errors = 0
    # Start from a snapshot of the initial weights so there is always something to
    # restore (e.g. if every validation loss is NaN and never compares as lower).
    best_model_state = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        model.train()

        if synthetic_training:
            train_loader = define_train_loader(epoch)

        train_loss = 0
        for images, heatmaps, count in train_loader:
            images, heatmaps = images.to(device), heatmaps.to(device)

            # Each sample reports its synthesis attempts; anything beyond 1 per sample is an error.
            # len(count) is the real batch size (the last batch can be smaller than batch_size).
            generation_errors += count.sum().item() - len(count)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Custom loss function
            loss = calculate_loss(outputs, heatmaps)

            # Backprop
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        losses['train'].append(train_loss/len(train_loader))

        # Validation loop
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, heatmaps, _ in val_loader:
                images, heatmaps = images.to(device), heatmaps.to(device)
                outputs = model(images)
                val_loss += calculate_loss(outputs, heatmaps).item()

        losses['val'].append(val_loss/len(val_loader))

        # Keep best model based on validation loss
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            best_epoch = epoch
            # state_dict() returns references to the live parameters. Without a deep
            # copy, later optimizer steps would overwrite the "best" weights.
            best_model_state = copy.deepcopy(model.state_dict())

        # Early stopping if most complex examples are in and validation loss increases
        if epoch > max_complex_epoch+early_stopper.patience and early_stopper.early_stop(val_loss):
            print(f'Early stopping at epoch {epoch+1}')
            break

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.5f}, Val Loss: {val_loss/len(val_loader):.5f}')

    print('Best epoch:', best_epoch+1)
    print(f'Total generation errors: {generation_errors}')

    # Reload the best model weights
    model.load_state_dict(best_model_state)

    return model, losses, lowest_val_loss

In [None]:
def fancy_train_model():
    '''Two-phase training on synthetic data, following the paper's learning-rate schedule.

    Phase 1 trains at lr=2.5e-4 while the synthetic curriculum advances every epoch.
    Phase 2 continues at lr/5 using only the most complex synthetic samples.

    Returns:
        ``(model, losses)``, where ``losses`` concatenates the per-epoch
        ``{'train', 'val'}`` losses of both phases.
    '''
    model = StackedHourglassNetwork(
        num_stacks=n_stacks,
        num_features=input_size,
        num_output_points=1,
        )

    print('Initial training...')
    early_stopper = EarlyStopper(patience=20, min_delta=0)
    train_loader = define_train_loader(epoch=0)
    model, losses, _ = train_model(model,
                                   early_stopper,
                                   train_loader,
                                   val_loader,
                                   epochs=1000,
                                   lr=2.5e-4, # As in paper
                                   synthetic_training=True
                                   )

    print('\nReduced LR training...')
    early_stopper = EarlyStopper(patience=20, min_delta=0)
    # Fixed loader at the most complex curriculum stage. synthetic_training must be
    # False: with True, train_model rebuilds the loader from define_train_loader(0)
    # and restarts the curriculum. DMCDataset synthesises a new sample on every
    # access, so this fixed loader still yields fresh data each epoch.
    # NOTE(review): train_model only consults the early stopper after
    # max_complex_epoch + patience epochs, so this phase runs at least that long.
    train_loader = define_train_loader(epoch=max_complex_epoch+1)
    model, losses_tmp, _ = train_model(model,
                                       early_stopper,
                                       train_loader,
                                       val_loader,
                                       epochs=1000,
                                       lr=(2.5e-4) / 5, # Reduce by factor 5 (as in paper)
                                       synthetic_training=False
                                       )

    losses['train'].extend(losses_tmp['train'])
    losses['val'].extend(losses_tmp['val'])

    print('\nTraining complete!')

    return model, losses

In [None]:
# Run the two-phase synthetic training (fancy_train_model builds a new network).
# Most recent change: random text and symbols added to synthesis
model, losses = fancy_train_model()

In [None]:
# Show per-epoch training and validation losses
for split, curve_label in (('train', 'Train Loss'), ('val', 'Validation Loss')):
    plt.plot(losses[split], label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Save the whole module object (architecture + weights), not just a state_dict;
# it is loaded back below with weights_only=False.
torch.save(model, '../models/hourglass_localization_rectification.pth')

# Model Evaluation

In [None]:
# Reload the synthetic-trained model. weights_only=False unpickles arbitrary
# objects, so only load checkpoint files you trust.
model = torch.load('../models/hourglass_localization_rectification.pth', weights_only=False, map_location=device)

In [None]:
def evaluate_model(model, loader):
    '''Print the summed, per-batch and per-sample ``calculate_loss`` of ``model`` over ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches (the averages would be undefined).
    '''
    model.eval()

    total_loss = 0.0
    count = 0       # number of batches
    n_samples = 0   # number of images; the last batch may be smaller than the others
    with torch.no_grad():
        for images, heatmaps, _ in loader:
            images, heatmaps = images.to(device), heatmaps.to(device)

            # Forward pass
            outputs = model(images)

            # Calculate loss
            loss = calculate_loss(outputs, heatmaps)

            total_loss += loss.item()
            count += 1
            n_samples += images.shape[0]

    if count == 0:
        raise ValueError('evaluate_model: loader yielded no batches')

    print(f'Total loss: {total_loss:.4f}')
    print(f'Average batch loss: {total_loss/count:.4f}')
    print(f'Average single loss: {total_loss/n_samples:.4f}')

# Evaluating the Synthetic-Trained Model

In [None]:
# Train loss. NOTE: `train_loader` is the global loader built with
# define_train_loader(epoch=0); the loaders used during training are local to
# fancy_train_model / train_model.
evaluate_model(model, train_loader)

In [None]:
# Val loss on the synthetic validation set (val_loader)
evaluate_model(model, val_loader)

In [None]:
# Test loss on the synthetic test set (test_loader)
evaluate_model(model, test_loader)

In [None]:
def compare_keypoints(orig_image, outputs, true_heatmaps):
    '''Visualise predicted (final hourglass stack) and ground-truth keypoints for one image.

    Args:
        orig_image: (3, H, W) image tensor, presumably with values in [0, 1].
        outputs: Per-stack predictions for this image, e.g. (num_stacks, 1, 64, 64).
        true_heatmaps: Ground-truth heatmap, e.g. (64, 64).
    '''
    last_stack_heatmap = outputs[-1][0]  # final stack, heatmap channel 0

    pred_keypoints = extract_keypoints(last_stack_heatmap, input_size=input_size, output_size=output_size)
    true_keypoints = extract_keypoints(true_heatmaps, input_size=input_size, output_size=output_size)

    # Channel-first float image -> channel-last uint8 for plotting
    img = (orig_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    print(f'Predicted keypoints: {pred_keypoints}')
    heatmap_viz(img, pred_keypoints, last_stack_heatmap)

    print(f'True keypoints: {true_keypoints}')
    heatmap_viz(img, true_keypoints, true_heatmaps)

def compare_hourglass_outputs(orig_image, outputs, true_heatmaps):
    '''Plot the input image, each hourglass stack's heatmap, and the ground truth in one row.

    Args:
        orig_image: (3, H, W) image tensor.
        outputs: Per-stack heatmaps for the image, e.g. (num_stacks, 1, 64, 64).
        true_heatmaps: Ground-truth heatmap, e.g. (64, 64).
    '''
    _, axes = plt.subplots(1, len(outputs) + 2, figsize=(20, 5))
    image_ax, *stack_axes, truth_ax = axes

    # Original image
    dmc_image = (orig_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    image_ax.imshow(dmc_image)
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # One panel per hourglass stack
    for stack_number, (ax, stack_output) in enumerate(zip(stack_axes, outputs), start=1):
        ax.imshow(stack_output[0], cmap='jet')
        ax.set_title(f'Hourglass {stack_number}')
        ax.axis('off')

    # Ground truth
    truth_ax.imshow(true_heatmaps, cmap='jet')
    truth_ax.set_title('True Heatmap')
    truth_ax.axis('off')

    plt.show()

def print_single_result(model, loader):
    '''Visualise the model's prediction for the first image of the first batch of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    '''
    model.eval()

    with torch.no_grad():
        first_batch = next(iter(loader), None)
        if first_batch is None:
            raise ValueError('print_single_result: loader yielded no batches')
        images, true_heatmaps, _ = first_batch

        image = images[0].to(device)              # (3, input_size, input_size)
        true_heatmaps = true_heatmaps.numpy()[0]  # e.g. (64, 64)

        # Single image forward pass
        outputs = model(image.unsqueeze(0))
        outputs = outputs[0].detach().cpu().numpy() # Grab first batch (8, 1, 64, 64)

        compare_keypoints(image, outputs, true_heatmaps)
        compare_hourglass_outputs(image, outputs, true_heatmaps)

def print_multiple_results(model, loader, n_print):
    '''Visualise predictions for the first ``n_print`` images produced by ``loader``.

    Stops consuming ``loader`` as soon as ``n_print`` images have been shown.
    Every extra item would otherwise cost a full dataset ``__getitem__``
    (synthesis and/or YOLO crop).
    '''
    if n_print <= 0:
        return

    model.eval()

    printed = 0
    with torch.no_grad():
        for images, true_heatmaps, _ in loader:
            images, true_heatmaps = images.to(device), true_heatmaps.numpy()

            for image, true_heatmap in zip(images, true_heatmaps):
                # Single-image forward pass. The batch tensor `images` is deliberately
                # not reassigned here, so batches larger than 1 are fully visited.
                outputs = model(image.unsqueeze(0))
                outputs = outputs[0].detach().cpu().numpy() # Grab first batch (8, 1, 64, 64)

                compare_keypoints(image, outputs, true_heatmap)
                compare_hourglass_outputs(image, outputs, true_heatmap)

                printed += 1
                if printed >= n_print:
                    return

## Train Print

In [None]:
# First sample of the (shuffled) synthetic training loader
print_single_result(model, train_loader)

## Validation Print

In [None]:
# First sample of the (shuffled) synthetic validation loader
print_single_result(model, val_loader)

## Test Print

In [None]:
# First sample of the synthetic test loader
print_single_result(model, test_loader)

# Finetuning to MAN data

In [None]:
# Dataloader for the MAN dataset (flip augmentation and Retinex are optional; used for train/val/test)
class MANDataset(Dataset):
    '''Real DMC images with four-point oriented-box labels, cropped with YOLO.

    Args:
        image_dir: Directory of images. Images and labels are paired by sorted filename order.
        label_dir: Directory of label text files. Values 1-8 of each file are four (x, y)
            pairs relative to the image size; value 0 is ignored (presumably the class id).
        augment: Apply random horizontal/vertical flips jointly to image and heatmaps,
            and retry the YOLO crop (with new flips) up to 10 times.
        retinex: Apply MSRCR Retinex enhancement after resizing.

    Each item is ``(image, gaussian_heatmap, 1)``. The constant 1 keeps the
    ``(image, heatmap, count)`` tuple shape used for generation-error counting
    during training; real images need no synthesis retries.

    Raises:
        ValueError: if the image and label directories contain different numbers of files,
            since index-based pairing would then silently misalign them.
    '''
    def __init__(self, image_dir, label_dir, augment=False, retinex=False):
        self.image_dir = image_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.label_dir = label_dir
        self.label_files = sorted(os.listdir(label_dir))
        if len(self.image_files) != len(self.label_files):
            raise ValueError(
                f'{image_dir} has {len(self.image_files)} files but '
                f'{label_dir} has {len(self.label_files)}'
            )
        self.augment = augment
        self.retinex = retinex
        # Built once here instead of on every __getitem__ retry
        self.to_tensor = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])
        self.flip = v2.Compose([
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        yolo_err = True
        yolo_err_count = 0
        yolo_max_err = 10

        while yolo_err:
            # Load the png image; the context manager closes the file once pixels are loaded
            with Image.open(f'{self.image_dir}/{self.image_files[idx]}') as image:
                image.load()

            # Convert to tensor
            dmc = self.to_tensor(image)

            # Resize to input size for model
            dmc = v2.functional.resize(dmc, (yolo_size, yolo_size))

            # Load the txt label
            label = np.loadtxt(f'{self.label_dir}/{self.label_files[idx]}')

            # Relative (x, y) pairs at label[1:3], [3:5], [5:7], [7:9] -> raw pixel coords
            raw_coords = [
                (int(label[k] * image.width), int(label[k + 1] * image.height))
                for k in range(1, 9, 2)
            ]

            # Create basic heatmap
            heatmap_basic = get_heatmaps_basic(image, raw_coords, yolo_size, debug=False)
            heatmap_basic = torch.tensor(heatmap_basic).float()

            if self.augment:
                # Flip image and heatmaps together so keypoints stay aligned with the image
                combined = self.flip(torch.cat((dmc, heatmap_basic), dim=0))

                # Split back into DMC and heatmaps
                dmc = combined[:3]
                heatmap_basic = combined[3:]

            # Without augmentation a retry would see identical input, so give up after one failure
            if not self.augment and yolo_err_count == 1:
                print('YOLO max error reached (with no augmentation), not cropping with YOLO')
                break
            # Avoid crop if augmenting and error count too high
            if yolo_err_count >= yolo_max_err:
                print('YOLO max error reached, not cropping with YOLO')
                break

            # Crop image and heatmaps using YOLO
            dmc, heatmap_basic, yolo_err = YOLO_crop(YOLO_model, dmc, heatmap_basic)

            if yolo_err:
                yolo_err_count += 1

        # Resize image to input size
        dmc = v2.functional.resize(dmc, (input_size, input_size), interpolation=Image.BILINEAR)

        # Resize heatmap to output size.
        # NOTE(review): DMCDataset downsizes heatmaps with resize_heatmaps(); this plain
        # (bilinear) resize may blur single-pixel peaks. Confirm create_heatmap copes with that.
        heatmap_basic = v2.functional.resize(heatmap_basic, (output_size, output_size))

        if self.retinex:
            img = dmc.permute(1, 2, 0).numpy()
            img = skimage.img_as_ubyte(img)
            img = msrcr(img, sigmas=(25., 50., 100.))
            img = img.astype(np.float32) / 255
            dmc = torch.tensor(img).permute(2, 0, 1)

        # Create gaussian heatmap from basic heatmap
        heatmap_gaussian = create_heatmap(heatmap_basic, output_size, sigma)

        return dmc, heatmap_gaussian, 1

man_root = '../data/MAN/roboflow_oriented_boxes'
MAN_train_dataset = MANDataset(
    image_dir=f'{man_root}/train/images',
    label_dir=f'{man_root}/train/labels',
    augment=True,
    retinex=False,
)
MAN_train_loader = torch.utils.data.DataLoader(MAN_train_dataset, batch_size=batch_size, shuffle=True)

# Sanity-check the MAN dataloader on a single batch
for dmc, heatmap, _ in MAN_train_loader:
    print(dmc.shape)
    print(heatmap.shape)

    # Show the image, then its heatmap
    for picture in (dmc.squeeze(0).permute(1, 2, 0), heatmap.squeeze(0)):
        plt.figure(figsize=(5, 5))
        plt.axis('Off')
        plt.imshow(picture)
        plt.show()
    break

In [None]:
# Validation/test MAN sets: no augmentation (MANDataset default), no Retinex, fixed order
man_root = '../data/MAN/roboflow_oriented_boxes'
MAN_val_dataset = MANDataset(
    image_dir=f'{man_root}/valid/images',
    label_dir=f'{man_root}/valid/labels',
    retinex=False,
)
MAN_val_loader = torch.utils.data.DataLoader(MAN_val_dataset, batch_size=batch_size, shuffle=False)

MAN_test_dataset = MANDataset(
    image_dir=f'{man_root}/test/images',
    label_dir=f'{man_root}/test/labels',
    retinex=False,
)
MAN_test_loader = torch.utils.data.DataLoader(MAN_test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Finetune training
def finetune_train_model(model):
    '''Finetune ``model`` on the MAN dataset using only the reduced-learning-rate phase.

    Unlike the synthetic schedule, the initial full-LR phase is skipped. Training
    runs at 2.5e-4 / 5 (as in the paper) on the global MAN train/val loaders,
    with patience-20 early stopping.

    Returns:
        ``(model, losses)`` exactly as returned by ``train_model``.
    '''
    print('\nReduced LR training...')
    stopper = EarlyStopper(patience=20, min_delta=0)
    reduced_lr = (2.5e-4) / 5  # Reduce by factor 5 (as in paper)
    model, losses, _ = train_model(
        model,
        stopper,
        MAN_train_loader,
        MAN_val_loader,
        epochs=1000,
        lr=reduced_lr,
    )

    print('\nTraining complete!')

    return model, losses

In [None]:
synthetic_model_path = '../models/hourglass_localization_rectification.pth'
finetuned_model_path = '../models/hourglass_localization_rectification_finetuned.pth'

# Start from the synthetic-trained model
model = torch.load(synthetic_model_path, weights_only=False, map_location=device)

# No synthesis curriculum when finetuning. train_model reads this global, so early
# stopping can kick in after `patience` epochs instead of after the curriculum.
max_complex_epoch = 0

model, losses = finetune_train_model(model)

torch.save(model, finetuned_model_path)

# Evaluating Finetuned Model

In [None]:
# Reload the finetuned model (whole pickled module; weights_only=False, so only load trusted files)
model = torch.load('../models/hourglass_localization_rectification_finetuned.pth', weights_only=False, map_location=device)

In [None]:
# Train loss on MAN (augment=True, so the random flips make this vary between runs)
evaluate_model(model, MAN_train_loader)

In [None]:
# Val loss on MAN (no augmentation)
evaluate_model(model, MAN_val_loader)

In [None]:
# Test loss on MAN (no augmentation)
evaluate_model(model, MAN_test_loader)

## Train Print

In [None]:
# First sample of the (shuffled, augmented) MAN training loader
print_single_result(model, MAN_train_loader)

## Val Print

In [None]:
# First sample of the MAN validation loader
print_single_result(model, MAN_val_loader)

## Test Print

In [None]:
# Visualise predictions for up to 100 MAN test images
print_multiple_results(model, MAN_test_loader, 100)