Once your training has finished, or once you have obtained my trained weights, it is time for the demonstration!

In [1]:
import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn

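To start, let's redefine our model (the ResNet backbone plus the upsampling head). Its architecture must match the one used during training exactly; otherwise the saved weights cannot be loaded into it.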

In [2]:
class KeypointDetector(nn.Module):
    """ResNet trunk plus a transposed-convolution decoder that predicts 2 heatmaps.

    The ResNet's final average-pool and fully-connected layers are removed, so
    the trunk returns a spatial feature map downsampled 32x from the input.
    Five stride-2 transposed convolutions then upsample by 2**5 = 32, so for
    inputs whose height and width are multiples of 32 the output heatmaps have
    the same spatial size as the input. A final sigmoid maps every value into
    (0, 1).

    The module names (``backbone``, ``upsample``) and the layer order inside
    ``upsample`` must stay exactly as written. Checkpoints are restored through
    ``load_state_dict``, whose keys (e.g. ``upsample.0.weight``) depend on them.

    Args:
        backbone: one of ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``.
        pretrained: if True, initialise the trunk with torchvision's
            ``DEFAULT`` pretrained weights; otherwise use random weights.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    # Number of channels produced by the last residual stage of each ResNet.
    _BACKBONE_FEATURES = {'resnet18': 512, 'resnet34': 512, 'resnet50': 2048}

    def __init__(self, backbone='resnet18', pretrained=True):
        super().__init__()
        if backbone not in self._BACKBONE_FEATURES:
            # Without this check an unknown name left `backbone_features`
            # unassigned and crashed later with an UnboundLocalError.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of "
                f"{sorted(self._BACKBONE_FEATURES)}"
            )
        backbone_features = self._BACKBONE_FEATURES[backbone]
        self.backbone = self._build_backbone(backbone, pretrained)

        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(backbone_features, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            # Two output channels, one heatmap each.
            nn.Conv2d(16, 2, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _build_backbone(name, pretrained):
        """Return torchvision's ResNet ``name`` without its avgpool and fc layers."""
        factories = {
            'resnet18': (models.resnet18, models.ResNet18_Weights),
            'resnet34': (models.resnet34, models.ResNet34_Weights),
            'resnet50': (models.resnet50, models.ResNet50_Weights),
        }
        factory, weights_enum = factories[name]
        resnet = factory(weights=weights_enum.DEFAULT if pretrained else None)
        # The last two children are avgpool and fc; dropping them keeps the
        # spatial feature map that the decoder upsamples.
        return nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        """Map an image batch ``x`` to a batch of 2-channel heatmaps in (0, 1).

        ``x`` is presumably a (batch, 3, H, W) float tensor, since the
        torchvision ResNet stem expects 3 input channels.
        """
        features = self.backbone(x)
        heatmaps = self.upsample(features)
        return heatmaps


Finally, set `image_path` to an image that contains a Data Matrix code and `model_weights_path` to your trained weights. Running the cell saves two heatmaps for the Data Matrix in the image:
- `<image name>_heatmap.png` shows all cells, with the black-cell channel in red and the other channel in blue.
- `<image name>_black_cells.png` shows the black cells only.

In [3]:
image_path = "image.jpg"
model_weights_path = "best_model_epoch_011.pth"
INPUT_SIZE = 512  # side length (pixels) the image is resized to before inference


def save_heatmap_png(array, out_path, **imshow_kwargs):
    """Render ``array`` with values clipped to [0, 1], no axes, and save it as an image.

    Extra keyword arguments (e.g. ``cmap``) are forwarded to ``Axes.imshow``.
    The figure is closed after saving so repeated runs do not accumulate figures.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(array, vmin=0, vmax=1, **imshow_kwargs)
    ax.axis('off')
    fig.savefig(out_path, dpi=150, bbox_inches='tight', pad_inches=0)
    plt.close(fig)


model = KeypointDetector(backbone='resnet18', pretrained=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(model_weights_path, map_location=device)
# Accept both a checkpoint dict wrapping the weights under 'model_state_dict'
# and a bare state_dict saved with torch.save(model.state_dict(), ...).
state_dict = checkpoint.get('model_state_dict', checkpoint)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()


# The context manager closes the file handle; convert() returns an in-memory copy.
with Image.open(image_path) as raw_image:
    image = raw_image.convert('RGB').resize((INPUT_SIZE, INPUT_SIZE), Image.LANCZOS)
# HWC uint8 -> CHW float in [0, 1]. No mean/std normalisation is applied;
# presumably this matches the training preprocessing - confirm.
image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
image_tensor = image_tensor.unsqueeze(0).to(device)


with torch.no_grad():
    outputs = model(image_tensor)

heatmap_black = outputs[0, 0].cpu().numpy()
heatmap_white = outputs[0, 1].cpu().numpy()

base_name = os.path.splitext(os.path.basename(image_path))[0]


# Overlay both channels: channel 0 in red, channel 1 in blue.
# The buffer follows the model's output size instead of a hard-coded 512x512.
combined_heatmap = np.zeros((*heatmap_black.shape, 3))
combined_heatmap[..., 0] = heatmap_black
combined_heatmap[..., 2] = heatmap_white

save_heatmap_png(combined_heatmap, f"{base_name}_heatmap.png")
save_heatmap_png(heatmap_black, f"{base_name}_black_cells.png", cmap='Reds')