In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from pycocotools.coco import COCO
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [None]:
class CocoKeypoints(Dataset):
    def __init__(self, root, annFile, target_size=(256, 192), flip_prob=0.5):
        self.root = root
        self.coco = COCO(annFile)
        self.target_size = target_size
        self.flip_prob = flip_prob
        self.image_ids = [img_id for img_id in self.coco.imgs.keys() 
                        if len(self.coco.getAnnIds(imgIds=img_id, catIds=1)) > 0]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=1)
        anns = self.coco.loadAnns(ann_ids)
        ann = max(anns, key=lambda x: x['num_keypoints'])

        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        img = cv2.resize(img, (self.target_size[1], self.target_size[0]))

        keypoints = np.array(ann['keypoints'], dtype=np.float32).reshape(-1, 3)
        kp = keypoints[:, :2]
        visibility = keypoints[:, 2]

        kp[:, 0] = kp[:, 0] * (self.target_size[1] / orig_w)
        kp[:, 1] = kp[:, 1] * (self.target_size[0] / orig_h)

        if np.random.rand() < self.flip_prob:
            img = img[:, ::-1, :].copy()
            kp[:, 0] = self.target_size[1] - kp[:, 0]
            left = [1, 3, 5, 7, 9, 11, 13, 15]
            right = [2, 4, 6, 8, 10, 12, 14, 16]
            kp[left + right] = kp[right + left]

        img = np.ascontiguousarray(img)
        img = transforms.functional.to_tensor(img)
        img = transforms.functional.normalize(img, 
                                            mean=[0.485, 0.456, 0.406], 
                                            std=[0.229, 0.224, 0.225])

        heatmap_h, heatmap_w = self.target_size[0]//4, self.target_size[1]//4
        heatmaps = np.zeros((17, heatmap_h, heatmap_w), dtype=np.float32)
        
        for i in range(17):
            if visibility[i] > 0:
                x = (kp[i, 0] / self.target_size[1]) * heatmap_w
                y = (kp[i, 1] / self.target_size[0]) * heatmap_h
                heatmaps[i] = self._gaussian_kernel(heatmap_h, heatmap_w, x, y, 2)
        
        return img, torch.tensor(heatmaps, dtype=torch.float32)

    def _gaussian_kernel(self, height, width, x, y, sigma):
        xv, yv = np.meshgrid(np.arange(width), np.arange(height))
        d2 = (xv - x)**2 + (yv - y)**2
        return np.exp(-d2 / (2 * sigma**2))

In [None]:
coco_root = '/kaggle/input/coco-2017-dataset/coco2017'
train_img_dir = os.path.join(coco_root, 'train2017')
train_ann_file = os.path.join(coco_root, 'annotations/person_keypoints_train2017.json')

dataset = CocoKeypoints(train_img_dir, train_ann_file)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [None]:
for image, heatmap in train_loader:
    print(image.shape)
    print(heatmap.shape)
    break

In [None]:
class PoseNet(nn.Module):
    def __init__(self, num_keypoints=17):
        super(PoseNet, self).__init__()
        resnet = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2048, 1024, kernel_size=4, stride=2, padding=1),  # 8x6 -> 16x12
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),   # 16x12 -> 32x24
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),    # 32x24 -> 64x48
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(256, num_keypoints, kernel_size=3, padding=1)             # Final heatmap
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.decoder.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        x = self.backbone(x)
        x = self.decoder(x)
        return x

In [None]:
from IPython.display import FileLink

def train(model, train_loader, epochs=10, lr=0.001, device='cuda'):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for i, (images, heatmaps) in enumerate(train_loader):
            images = images.to(device)
            heatmaps = heatmaps.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, heatmaps)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 50 == 49:
                print(f'Epoch {epoch+1}, Batch {i+1}: Loss {running_loss/50:.4f}')
                running_loss = 0.0

        # Save the model after each epoch
        model_path = f'model_epoch_{epoch+1}.pth'
        torch.save(model.state_dict(), model_path)
        print(f'Model saved: {model_path}')
        
        # Generate a download link
        display(FileLink(model_path))

        print(f'Epoch {epoch+1} completed')
    
    return model

In [None]:
def detect_keypoints(model, image_path, device='cuda'):
    # Load and preprocess image
    img = cv2.imread(image_path)
    orig_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = orig_img.shape[:2]
    
    img = cv2.resize(orig_img, (192, 256))
    img_tensor = transforms.functional.to_tensor(img)
    img_tensor = transforms.functional.normalize(img_tensor,
                                                mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225]).unsqueeze(0)
 
    model.eval()
    with torch.no_grad():
        heatmaps = model(img_tensor.to(device)).cpu().numpy()[0]

    keypoints = []
    for i in range(17):
        hm = heatmaps[i]
        y, x = np.unravel_index(hm.argmax(), hm.shape)
        x = (x / 48 * 192) * (orig_w / 192)
        y = (y / 64 * 256) * (orig_h / 256)
        keypoints.append((int(x), int(y)))
 
    plt.figure(figsize=(10, 10))
    plt.imshow(orig_img)
    for i, (x, y) in enumerate(keypoints):
        plt.scatter(x, y, s=50, marker='.', c='red')
    plt.axis('off')
    plt.show()

In [16]:
model = PoseNet()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = train(model, train_loader, epochs=10, device=device)

Epoch 1, Batch 50: Loss 2.1402
Epoch 1, Batch 100: Loss 0.0225
Epoch 1, Batch 150: Loss 0.0114
Epoch 1, Batch 200: Loss 0.0071
Epoch 1, Batch 250: Loss 0.0053
Epoch 1, Batch 300: Loss 0.0047
Epoch 1, Batch 350: Loss 0.0048
Epoch 1, Batch 400: Loss 0.0045
Epoch 1, Batch 450: Loss 0.0038
Epoch 1, Batch 500: Loss 0.0036
Epoch 1, Batch 550: Loss 0.0034
Epoch 1, Batch 600: Loss 0.0032
Epoch 1, Batch 650: Loss 0.0036
Epoch 1, Batch 700: Loss 0.0034
Epoch 1, Batch 750: Loss 0.0032
Epoch 1, Batch 800: Loss 0.0032
Epoch 1, Batch 850: Loss 0.0030
Epoch 1, Batch 900: Loss 0.0031
Epoch 1, Batch 950: Loss 0.0030
Epoch 1, Batch 1000: Loss 0.0030
Epoch 1, Batch 1050: Loss 0.0032
Epoch 1, Batch 1100: Loss 0.0032
Epoch 1, Batch 1150: Loss 0.0029
Epoch 1, Batch 1200: Loss 0.0029
Epoch 1, Batch 1250: Loss 0.0027
Epoch 1, Batch 1300: Loss 0.0028
Epoch 1, Batch 1350: Loss 0.0028
Epoch 1, Batch 1400: Loss 0.0027
Epoch 1, Batch 1450: Loss 0.0027
Epoch 1, Batch 1500: Loss 0.0027
Epoch 1, Batch 1550: Loss 0.00

Epoch 1 completed
Epoch 2, Batch 50: Loss 0.0026
Epoch 2, Batch 100: Loss 0.0028
Epoch 2, Batch 150: Loss 0.0027
Epoch 2, Batch 200: Loss 0.0026
Epoch 2, Batch 250: Loss 0.0026
Epoch 2, Batch 300: Loss 0.0027
Epoch 2, Batch 350: Loss 0.0026
Epoch 2, Batch 400: Loss 0.0027
Epoch 2, Batch 450: Loss 0.0026
Epoch 2, Batch 500: Loss 0.0027
Epoch 2, Batch 550: Loss 0.0028
Epoch 2, Batch 600: Loss 0.0029
Epoch 2, Batch 650: Loss 0.0027
Epoch 2, Batch 700: Loss 0.0028
Epoch 2, Batch 750: Loss 0.0027
Epoch 2, Batch 800: Loss 0.0027
Epoch 2, Batch 850: Loss 0.0027
Epoch 2, Batch 900: Loss 0.0026
Epoch 2, Batch 950: Loss 0.0027
Epoch 2, Batch 1000: Loss 0.0026
Epoch 2, Batch 1050: Loss 0.0026
Epoch 2, Batch 1100: Loss 0.0026
Epoch 2, Batch 1150: Loss 0.0026
Epoch 2, Batch 1200: Loss 0.0027
Epoch 2, Batch 1250: Loss 0.0026
Epoch 2, Batch 1300: Loss 0.0026
Epoch 2, Batch 1350: Loss 0.0026
Epoch 2, Batch 1400: Loss 0.0031
Epoch 2, Batch 1450: Loss 0.0028
Epoch 2, Batch 1500: Loss 0.0028
Epoch 2, Bat

Epoch 2 completed
Epoch 3, Batch 50: Loss 0.0031
Epoch 3, Batch 100: Loss 0.0031
Epoch 3, Batch 150: Loss 0.0029
Epoch 3, Batch 200: Loss 0.0034
Epoch 3, Batch 250: Loss 0.0030
Epoch 3, Batch 300: Loss 0.0031
Epoch 3, Batch 350: Loss 0.0030
Epoch 3, Batch 400: Loss 0.0031
Epoch 3, Batch 450: Loss 0.0032
Epoch 3, Batch 500: Loss 0.0028
Epoch 3, Batch 550: Loss 0.0032
Epoch 3, Batch 600: Loss 0.0031
Epoch 3, Batch 650: Loss 0.0031
Epoch 3, Batch 700: Loss 0.0029
Epoch 3, Batch 750: Loss 0.0034
Epoch 3, Batch 800: Loss 0.0032
Epoch 3, Batch 850: Loss 0.0031
Epoch 3, Batch 900: Loss 0.0031
Epoch 3, Batch 950: Loss 0.0030
Epoch 3, Batch 1000: Loss 0.0031
Epoch 3, Batch 1050: Loss 0.0030
Epoch 3, Batch 1100: Loss 0.0030
Epoch 3, Batch 1150: Loss 0.0034
Epoch 3, Batch 1200: Loss 0.0030
Epoch 3, Batch 1250: Loss 0.0030
Epoch 3, Batch 1300: Loss 0.0028
Epoch 3, Batch 1350: Loss 0.0033
Epoch 3, Batch 1400: Loss 0.0029
Epoch 3, Batch 1450: Loss 0.0029
Epoch 3, Batch 1500: Loss 0.0032
Epoch 3, Bat

Epoch 3 completed
Epoch 4, Batch 50: Loss 0.0030
Epoch 4, Batch 100: Loss 0.0030
Epoch 4, Batch 150: Loss 0.0029
Epoch 4, Batch 200: Loss 0.0029
Epoch 4, Batch 250: Loss 0.0031
Epoch 4, Batch 300: Loss 0.0033
Epoch 4, Batch 350: Loss 0.0030
Epoch 4, Batch 400: Loss 0.0028
Epoch 4, Batch 450: Loss 0.0027
Epoch 4, Batch 500: Loss 0.0033
Epoch 4, Batch 550: Loss 0.0027
Epoch 4, Batch 600: Loss 0.0029
Epoch 4, Batch 650: Loss 0.0028
Epoch 4, Batch 700: Loss 0.0030
Epoch 4, Batch 750: Loss 0.0029
Epoch 4, Batch 800: Loss 0.0030
Epoch 4, Batch 850: Loss 0.0029
Epoch 4, Batch 900: Loss 0.0027
Epoch 4, Batch 950: Loss 0.0031
Epoch 4, Batch 1000: Loss 0.0028
Epoch 4, Batch 1050: Loss 0.0029
Epoch 4, Batch 1100: Loss 0.0029
Epoch 4, Batch 1150: Loss 0.0028
Epoch 4, Batch 1200: Loss 0.0029
Epoch 4, Batch 1250: Loss 0.0029
Epoch 4, Batch 1300: Loss 0.0029
Epoch 4, Batch 1350: Loss 0.0028
Epoch 4, Batch 1400: Loss 0.0028
Epoch 4, Batch 1450: Loss 0.0029
Epoch 4, Batch 1500: Loss 0.0029
Epoch 4, Bat

Epoch 4 completed
Epoch 5, Batch 50: Loss 0.0028
Epoch 5, Batch 100: Loss 0.0027
Epoch 5, Batch 150: Loss 0.0029
Epoch 5, Batch 200: Loss 0.0029
Epoch 5, Batch 250: Loss 0.0027
Epoch 5, Batch 300: Loss 0.0028
Epoch 5, Batch 350: Loss 0.0027
Epoch 5, Batch 400: Loss 0.0028
Epoch 5, Batch 450: Loss 0.0027
Epoch 5, Batch 500: Loss 0.0026
Epoch 5, Batch 550: Loss 0.0028
Epoch 5, Batch 600: Loss 0.0029
Epoch 5, Batch 650: Loss 0.0028
Epoch 5, Batch 700: Loss 0.0028
Epoch 5, Batch 750: Loss 0.0028
Epoch 5, Batch 800: Loss 0.0027
Epoch 5, Batch 850: Loss 0.0029
Epoch 5, Batch 900: Loss 0.0027
Epoch 5, Batch 950: Loss 0.0028
Epoch 5, Batch 1000: Loss 0.0026
Epoch 5, Batch 1050: Loss 0.0026
Epoch 5, Batch 1100: Loss 0.0026
Epoch 5, Batch 1150: Loss 0.0028
Epoch 5, Batch 1200: Loss 0.0027
Epoch 5, Batch 1250: Loss 0.0026
Epoch 5, Batch 1300: Loss 0.0027
Epoch 5, Batch 1350: Loss 0.0026
Epoch 5, Batch 1400: Loss 0.0027
Epoch 5, Batch 1450: Loss 0.0027
Epoch 5, Batch 1500: Loss 0.0027
Epoch 5, Bat

Epoch 5 completed
Epoch 6, Batch 50: Loss 0.0027
Epoch 6, Batch 100: Loss 0.0026
Epoch 6, Batch 150: Loss 0.0026
Epoch 6, Batch 200: Loss 0.0026
Epoch 6, Batch 250: Loss 0.0027
Epoch 6, Batch 300: Loss 0.0026
Epoch 6, Batch 350: Loss 0.0027
Epoch 6, Batch 400: Loss 0.0027
Epoch 6, Batch 450: Loss 0.0026
Epoch 6, Batch 500: Loss 0.0027
Epoch 6, Batch 550: Loss 0.0027
Epoch 6, Batch 600: Loss 0.0028
Epoch 6, Batch 650: Loss 0.0027
Epoch 6, Batch 700: Loss 0.0027
Epoch 6, Batch 750: Loss 0.0027
Epoch 6, Batch 800: Loss 0.0027
Epoch 6, Batch 850: Loss 0.0027
Epoch 6, Batch 900: Loss 0.0027
Epoch 6, Batch 950: Loss 0.0026
Epoch 6, Batch 1000: Loss 0.0027
Epoch 6, Batch 1050: Loss 0.0027
Epoch 6, Batch 1100: Loss 0.0026
Epoch 6, Batch 1150: Loss 0.0026
Epoch 6, Batch 1200: Loss 0.0026
Epoch 6, Batch 1250: Loss 0.0026
Epoch 6, Batch 1300: Loss 0.0026
Epoch 6, Batch 1350: Loss 0.0026
Epoch 6, Batch 1400: Loss 0.0026
Epoch 6, Batch 1450: Loss 0.0026
Epoch 6, Batch 1500: Loss 0.0026
Epoch 6, Bat

Epoch 6 completed
Epoch 7, Batch 50: Loss 0.0027
Epoch 7, Batch 100: Loss 0.0025
Epoch 7, Batch 150: Loss 0.0026
Epoch 7, Batch 200: Loss 0.0027
Epoch 7, Batch 250: Loss 0.0026
Epoch 7, Batch 300: Loss 0.0026
Epoch 7, Batch 350: Loss 0.0026
Epoch 7, Batch 400: Loss 0.0026
Epoch 7, Batch 450: Loss 0.0026
Epoch 7, Batch 500: Loss 0.0025
Epoch 7, Batch 550: Loss 0.0026
Epoch 7, Batch 600: Loss 0.0026
Epoch 7, Batch 650: Loss 0.0026
Epoch 7, Batch 700: Loss 0.0026
Epoch 7, Batch 750: Loss 0.0026
Epoch 7, Batch 800: Loss 0.0026
Epoch 7, Batch 850: Loss 0.0026
Epoch 7, Batch 900: Loss 0.0026
Epoch 7, Batch 950: Loss 0.0025
Epoch 7, Batch 1000: Loss 0.0026
Epoch 7, Batch 1050: Loss 0.0026
Epoch 7, Batch 1100: Loss 0.0026
Epoch 7, Batch 1150: Loss 0.0025
Epoch 7, Batch 1200: Loss 0.0026
Epoch 7, Batch 1250: Loss 0.0026
Epoch 7, Batch 1300: Loss 0.0025
Epoch 7, Batch 1350: Loss 0.0026
Epoch 7, Batch 1400: Loss 0.0025
Epoch 7, Batch 1450: Loss 0.0026
Epoch 7, Batch 1500: Loss 0.0026
Epoch 7, Bat

Epoch 7 completed
Epoch 8, Batch 50: Loss 0.0025
Epoch 8, Batch 100: Loss 0.0026
Epoch 8, Batch 150: Loss 0.0025
Epoch 8, Batch 200: Loss 0.0025
Epoch 8, Batch 250: Loss 0.0025
Epoch 8, Batch 300: Loss 0.0025
Epoch 8, Batch 350: Loss 0.0025
Epoch 8, Batch 400: Loss 0.0025
Epoch 8, Batch 450: Loss 0.0026
Epoch 8, Batch 500: Loss 0.0026
Epoch 8, Batch 550: Loss 0.0025
Epoch 8, Batch 600: Loss 0.0025
Epoch 8, Batch 650: Loss 0.0025
Epoch 8, Batch 700: Loss 0.0025
Epoch 8, Batch 750: Loss 0.0025
Epoch 8, Batch 800: Loss 0.0026
Epoch 8, Batch 850: Loss 0.0026
Epoch 8, Batch 900: Loss 0.0026
Epoch 8, Batch 950: Loss 0.0025
Epoch 8, Batch 1000: Loss 0.0026
Epoch 8, Batch 1050: Loss 0.0025
Epoch 8, Batch 1100: Loss 0.0025
Epoch 8, Batch 1150: Loss 0.0025
Epoch 8, Batch 1200: Loss 0.0026
Epoch 8, Batch 1250: Loss 0.0025
Epoch 8, Batch 1300: Loss 0.0025
Epoch 8, Batch 1350: Loss 0.0025
Epoch 8, Batch 1400: Loss 0.0025
Epoch 8, Batch 1450: Loss 0.0026
Epoch 8, Batch 1500: Loss 0.0025
Epoch 8, Bat

Epoch 8 completed
Epoch 9, Batch 50: Loss 0.0025
Epoch 9, Batch 100: Loss 0.0026
Epoch 9, Batch 150: Loss 0.0026
Epoch 9, Batch 200: Loss 0.0025
Epoch 9, Batch 250: Loss 0.0025
Epoch 9, Batch 300: Loss 0.0026
Epoch 9, Batch 350: Loss 0.0025
Epoch 9, Batch 400: Loss 0.0025
Epoch 9, Batch 450: Loss 0.0026
Epoch 9, Batch 500: Loss 0.0025
Epoch 9, Batch 550: Loss 0.0025
Epoch 9, Batch 600: Loss 0.0025
Epoch 9, Batch 650: Loss 0.0025
Epoch 9, Batch 700: Loss 0.0025
Epoch 9, Batch 750: Loss 0.0025
Epoch 9, Batch 800: Loss 0.0025
Epoch 9, Batch 850: Loss 0.0026
Epoch 9, Batch 900: Loss 0.0025
Epoch 9, Batch 950: Loss 0.0025
Epoch 9, Batch 1000: Loss 0.0025
Epoch 9, Batch 1050: Loss 0.0025
Epoch 9, Batch 1100: Loss 0.0025
Epoch 9, Batch 1150: Loss 0.0025
Epoch 9, Batch 1200: Loss 0.0025
Epoch 9, Batch 1250: Loss 0.0025
Epoch 9, Batch 1300: Loss 0.0026
Epoch 9, Batch 1350: Loss 0.0026
Epoch 9, Batch 1400: Loss 0.0026
Epoch 9, Batch 1450: Loss 0.0026
Epoch 9, Batch 1500: Loss 0.0025
Epoch 9, Bat

Epoch 9 completed
Epoch 10, Batch 50: Loss 0.0025
Epoch 10, Batch 100: Loss 0.0025
Epoch 10, Batch 150: Loss 0.0025
Epoch 10, Batch 200: Loss 0.0026
Epoch 10, Batch 250: Loss 0.0026
Epoch 10, Batch 300: Loss 0.0025
Epoch 10, Batch 350: Loss 0.0025
Epoch 10, Batch 400: Loss 0.0025
Epoch 10, Batch 450: Loss 0.0026
Epoch 10, Batch 500: Loss 0.0025
Epoch 10, Batch 550: Loss 0.0026
Epoch 10, Batch 600: Loss 0.0025
Epoch 10, Batch 650: Loss 0.0026
Epoch 10, Batch 700: Loss 0.0025
Epoch 10, Batch 900: Loss 0.0026
Epoch 10, Batch 950: Loss 0.0025
Epoch 10, Batch 1000: Loss 0.0025
Epoch 10, Batch 1050: Loss 0.0025
Epoch 10, Batch 1100: Loss 0.0025
Epoch 10, Batch 1150: Loss 0.0025
Epoch 10, Batch 1200: Loss 0.0026
Epoch 10, Batch 1250: Loss 0.0026
Epoch 10, Batch 1300: Loss 0.0026
Epoch 10, Batch 1350: Loss 0.0025
Epoch 10, Batch 1400: Loss 0.0025
Epoch 10, Batch 1450: Loss 0.0025
Epoch 10, Batch 1500: Loss 0.0025
Epoch 10, Batch 1550: Loss 0.0025
Epoch 10, Batch 1600: Loss 0.0025
Epoch 10, Bat

Epoch 10 completed
