In [11]:
import cv2
import torch
import numpy as np
from torchvision import transforms as T
from torchvision.models.segmentation import deeplabv3_resnet101

In [13]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [14]:
# Load DeepLabV3-ResNet101 with pretrained weights, switch it to eval mode and
# move it to `device`.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights`
# enum; `DEFAULT` resolves to the same checkpoint the legacy flag loaded.
# Older torchvision has no weights enum, so fall back to the legacy flag there.
try:
    from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights
except ImportError:
    model = deeplabv3_resnet101(pretrained=True)
else:
    model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
model = model.eval().to(device)

In [15]:
# Preprocessing pipeline for the segmentation model.
# Steps: numpy frame -> PIL image, resize to a fixed (height, width), convert
# to a [0, 1] float tensor, then normalize with the standard ImageNet mean/std.
# The statistics are listed per RGB channel, so inputs should be RGB arrays.
# Note that OpenCV frames arrive in BGR order.
INPUT_SIZE = (360, 640)  # (height, width) as expected by T.Resize
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = T.Compose(
    [
        T.ToPILImage(),
        T.Resize(INPUT_SIZE),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [16]:
# One random but reproducible RGB color per output class. 21 matches the class
# count of the pretrained model's head (VOC label set).
# The seeded Generator keeps the palette identical across kernel restarts.
# `high` is exclusive, so 256 allows the full 0-255 range; the previous
# randint(0, 255) could never produce 255.
NUM_CLASSES = 21
PALETTE_SEED = 0
CLASS_COLORS = np.random.default_rng(PALETTE_SEED).integers(
    0, 256, size=(NUM_CLASSES, 3), dtype=np.uint8
)

In [17]:
def decode_segmap(segmentation, colors=None):
    """Convert a map of class indices into an RGB color image.

    Args:
        segmentation: integer array of class indices, typically (H, W).
        colors: optional (num_classes, 3) palette. Defaults to the
            module-level CLASS_COLORS.

    Returns:
        uint8 array of shape segmentation.shape + (3,) where each pixel holds
        colors[label].
    """
    palette = np.asarray(CLASS_COLORS if colors is None else colors, dtype=np.uint8)
    # Fancy indexing looks up every pixel's color in one vectorized pass,
    # instead of building a boolean mask per distinct label.
    return palette[np.asarray(segmentation)]

In [18]:
# Open the capture device; index 0 selects the default camera.
CAMERA_INDEX = 0
cap = cv2.VideoCapture(CAMERA_INDEX)

In [None]:
    
def class_coverage(seg):
    """Return the fraction of pixels assigned to each predicted class.

    Args:
        seg: integer array of class indices, e.g. the `segmentation` map
            produced by the webcam loop.

    Returns:
        dict mapping each class index present in `seg` to its share of the
        pixels (values sum to 1.0). Returns an empty dict for an empty array.
    """
    seg = np.asarray(seg)
    labels, counts = np.unique(seg, return_counts=True)
    total = seg.size
    return {int(label): int(count) / total for label, count in zip(labels, counts)}


# Webcam frames have no ground-truth mask, so a real pixel accuracy cannot be
# computed here. The previous comparison against `input_tensor.argmax(1)` (the
# strongest channel of the normalized input image) measured nothing meaningful.
# Report predicted class coverage instead. Skip gracefully when the webcam loop
# has not produced a `segmentation` yet, e.g. on Restart & Run All.
if "segmentation" in globals():
    coverage = class_coverage(segmentation)
    print("Class coverage: " + ", ".join(f"{k}: {v:.2%}" for k, v in coverage.items()))
else:
    print("No segmentation yet - run the webcam loop first.")

In [19]:
# Fail fast if the webcam could not be opened.
# Raising stops "Run All" at this cell with a clear traceback and keeps the
# kernel alive. Calling exit() inside an IPython/Jupyter kernel instead asks
# the kernel itself to shut down.
if not cap.isOpened():
    raise RuntimeError("Error: Could not open webcam.")

print("Press 'q' to exit.")

Press 'q' to exit.


In [None]:
# Live loop: grab a frame, segment it, overlay a per-class color mask and show
# the result. Stops when 'q' is pressed or the camera stops delivering frames.
while True:
    ret, frame = cap.read()
    if not ret:
        print("Failed to grab frame.")
        break

    # OpenCV delivers BGR frames, but `transform` normalizes with RGB channel
    # statistics, so convert before preprocessing.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    input_tensor = transform(frame_rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)['out']
        # Take the single batch item explicitly. squeeze() would also drop any
        # other size-1 dimension.
        segmentation = output[0].argmax(dim=0).cpu().numpy()

    seg_rgb = decode_segmap(segmentation)
    # Nearest-neighbour keeps every mask pixel on an exact palette color. The
    # default bilinear interpolation blends neighbouring class colors at edges.
    seg_rgb = cv2.resize(seg_rgb, (frame.shape[1], frame.shape[0]),
                         interpolation=cv2.INTER_NEAREST)
    # The palette is RGB and `frame` is BGR. Convert the mask so the displayed
    # colors match the printed class color mapping.
    seg_bgr = cv2.cvtColor(seg_rgb, cv2.COLOR_RGB2BGR)

    # Blend original frame with color mask
    overlay = cv2.addWeighted(frame, 0.6, seg_bgr, 0.4, 0)

    cv2.imshow("Multi-Object Colorization", overlay)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

In [None]:
# Release the camera and close the OpenCV window(s) opened by the loop.
cap.release()
cv2.destroyAllWindows()
# Pump the HighGUI event loop once. On some backends (commonly reported with
# GTK, and when driven from Jupyter) windows only disappear after pending
# events are processed.
cv2.waitKey(1)

In [None]:

# Persist the model weights and print the palette used for the overlay.
# NOTE(review): nothing in this notebook updates the weights, so this saves the
# pretrained checkpoint unchanged.
MODEL_PATH = 'deeplabv3_resnet101.pth'
torch.save(model.state_dict(), MODEL_PATH)

# A real newline ("\n", not the escaped "\\n") starts the message on a fresh
# line after the loop's output.
print("\nModel saved. Class color mapping:")
for i, color in enumerate(CLASS_COLORS):
    print(f"Class {i}: RGB {color.tolist()}")