In [1]:
import torch
import matplotlib.pyplot as plt
import imageio

# Code for rendering rectangular primitive based on Soft Rasterization paper (CVPR 2019)
# Demo shows results for a rectangular mask, saving an animation of the updates as a GIF

def render_rectangle(image_shape, center, width, height, sigma=5.0):
    """
    Renders rectangle with center, width, and height params.
    Generates a probability map based on distance to rectangle edges.
    Uses sigmoid function to normalize 0,1 like in the Soft Rasterization paper, 
    however here we do rectangle instead of the triangle from the paper
    """
    mask = torch.zeros(image_shape)
    cx, cy = center

    # Grid pixel coords ij
    y_coords, x_coords = torch.meshgrid(torch.arange(image_shape[0]), torch.arange(image_shape[1]), indexing='ij')

    # Calculate distances to edge (of rectangle)
    x_dist = torch.clamp(torch.abs(x_coords - cx) - width / 2, min=0)
    y_dist = torch.clamp(torch.abs(y_coords - cy) - height / 2, min=0)

    # 2D distance calc
    dist = torch.sqrt(x_dist**2 + y_dist**2)

    # Normalize output to 0,1 with sigmoid
    mask = torch.sigmoid(-dist/sigma)  # Watch sign value of sigmoid input
    return mask

def loss_function(rendered_mask, target_mask):
    """
    Loss calc between rendered mask and target mask.
    Loss function is binary cross entropy but can be changed to others easily.
    """
    loss = torch.nn.functional.binary_cross_entropy(rendered_mask, target_mask)
    return loss

def optimize_rectangle(target_mask, init_center, init_width, init_height, lr=0.01, iterations=500, sigma=5.0):
    """
    Optimize rectangle params by minimizing loss wrt to target mask.
    """
    # Init GIF
    frames = []
    # Init params
    center = torch.tensor(init_center, dtype=torch.float32, requires_grad=True)
    width = torch.tensor(init_width, dtype=torch.float32, requires_grad=True)
    height = torch.tensor(init_height, dtype=torch.float32, requires_grad=True)

    optimizer = torch.optim.Adam([center, width, height], lr=lr)
    
    for iteration in range(iterations):
        optimizer.zero_grad()

        # Call rectangle render
        rendered_mask = render_rectangle(target_mask.shape, center, width, height, sigma)

        # Check loss
        loss = loss_function(rendered_mask, target_mask)

        loss.backward()
        optimizer.step()

        if iteration % 10 == 0:
            # Debugging Outputs
            print(f"Iteration {iteration}: Loss = {loss.item()}, Center = {center.detach()}, Width = {width.item()}, Height = {height.item()}")

        # Store frames for GIF (every 50)
        if iteration % 50 == 0:
            plt.figure(figsize=(8, 4))
            plt.subplot(1, 2, 1)
            plt.title("Target Mask")
            plt.imshow(target_mask.numpy(), cmap='gray')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.title("Rendered Rectangle")
            plt.imshow(rendered_mask.detach().numpy(), cmap='gray')
            plt.axis('off')

            # Save and append fig
            plt.savefig(f'frame_{iteration}.png')
            plt.close()
            frames.append(f'frame_{iteration}.png')

    # Create GIF
    with imageio.get_writer('rectangle_render.gif', mode='I', duration=0.5) as writer:
        for frame in frames:
            image = imageio.imread(frame)
            writer.append_data(image)

    return center.detach(), width.item(), height.item()

# Render Demo:
image_shape = (100, 100)
target_mask = torch.zeros(image_shape)

# Define rectangle mask for render test
target_mask[40:70, 20:50] = 1.0  
init_center = [30.0, 30.0]
init_width = 25.0
init_height = 15.0

# Optimize params to mask to render rectangle
optimized_center, optimized_width, optimized_height = optimize_rectangle(
    target_mask, init_center, init_width, init_height, lr=0.05, iterations=1000, sigma=5.0
)

print("True Center: (35, 55), Optimized Center:", optimized_center)
print("True Width: 30, Optimized Width:", optimized_width)
print("True Height: 30, Optimized Height:", optimized_height)


  from .autonotebook import tqdm as notebook_tqdm


Iteration 0: Loss = 0.3808533847332001, Center = tensor([30.0500, 30.0500]), Width = 24.950000762939453, Height = 15.050000190734863
Iteration 10: Loss = 0.36773422360420227, Center = tensor([30.5496, 30.5499]), Width = 24.448881149291992, Height = 15.550236701965332
Iteration 20: Loss = 0.35464048385620117, Center = tensor([31.0474, 31.0495]), Width = 23.943334579467773, Height = 16.051258087158203
Iteration 30: Loss = 0.34136053919792175, Center = tensor([31.5625, 31.5499]), Width = 23.43753433227539, Height = 16.55127716064453
Iteration 40: Loss = 0.3284522294998169, Center = tensor([32.0774, 32.0475]), Width = 22.930971145629883, Height = 17.047348022460938
Iteration 50: Loss = 0.31607818603515625, Center = tensor([32.5677, 32.5383]), Width = 22.42222023010254, Height = 17.5372314453125
Iteration 60: Loss = 0.30432185530662537, Center = tensor([33.0145, 33.0208]), Width = 21.91551399230957, Height = 18.01995086669922
Iteration 70: Loss = 0.293133407831192, Center = tensor([33.4218,

