In [1]:
import torch

if torch.cuda.is_available():
    print("GPU is available.")
    device = torch.device("cuda")
else:
    print("GPU is not available.")
    device = torch.device("cpu")

GPU is available.


In [2]:
# 필요한 라이브러리 import 

import os
import sys
import cv2
import numpy as np
from typing import Tuple, Sequence, Callable, Dict

import torch
from torch import Tensor
from torch.utils.data import Dataset

from torch import nn
# from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection import KeypointRCNN

import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from typing import Tuple 
import albumentations as A
from albumentations.pytorch import ToTensorV2

import time
import torch.optim as optim


In [3]:
# 학습할 수 있도록 데이터 전처리 클래스

class KeypointDataset(Dataset):
    def __init__(
        self,
        image_dir: os.PathLike,
        label_df: pd.DataFrame,
        transforms: Sequence[Callable]=None
    ) -> None:
        self.image_dir = image_dir
        self.df = label_df
        self.transforms = transforms

    def __len__(self) -> int:
        return self.df.shape[0]
    
    def __getitem__(self, index: int) -> Tuple[Tensor, Dict]:
        image_id = self.df.iloc[index, 1]
        labels = np.array([1])

        keypoints = self.df.iloc[index, 2:].values.reshape(-1, 2).astype(np.int64)
    
        x1, y1 = min(keypoints[:, 0]), min(keypoints[:, 1])
        x2, y2 = max(keypoints[:, 0]), max(keypoints[:, 1])
        

        
        boxes = np.array([[x1, y1, x2, y2]], dtype=np.int64)
        image_path = os.path.join(self.image_dir, image_id)
        img_array = np.fromfile(image_path, np.uint8)
        image = cv2.imdecode(img_array, cv2.COLOR_BGR2RGB)
        targets ={
            'image': image,
            'bboxes': boxes,
            'labels': labels,
            'keypoints': keypoints
        }
#         print('bboxes : ' , boxes)
        
        if self.transforms is not None:
            targets = self.transforms(**targets)

        image = targets['image']
        image = image / 255.0
   

        targets = {
            'labels': torch.as_tensor(targets['labels'], dtype=torch.int64),
            'boxes': torch.as_tensor(targets['bboxes'], dtype=torch.float32),
            'keypoints': torch.as_tensor(
                np.concatenate([targets['keypoints'], np.ones((15, 1))], axis=1)[np.newaxis], dtype=torch.float32
            )
        }

        return image, targets

In [4]:
# 데이터를 저장한 파일을 불러와서 학습용/ 검즘용 데이터로 나눠서 전처리 함수에 전달

def collate_fn(batch: torch.Tensor)->Tuple:
    return tuple(zip(*batch))

# Data Transform & Train-Test-Split
def load_data(image_dir, train_key, valid_key):
    transforms = A.Compose([
        # A.Resize(500, 500, always_apply=True),
        A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToTensorV2()
    ],  bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']),
        keypoint_params=A.KeypointParams(format='xy')
    )

    trainset = KeypointDataset(image_dir, train_key, transforms)
    validset = KeypointDataset(image_dir, valid_key, transforms)
    train_loader = DataLoader(trainset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(validset, batch_size=2, shuffle=False, collate_fn=collate_fn)

    return train_loader, valid_loader

In [5]:
def get_model() -> nn.Module:
    backbone = resnet_fpn_backbone('resnet50', pretrained=True, trainable_layers = 2) # resnet101, resnet152 
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=['0', '1', '2', '3'],
        output_size=7,
        sampling_ratio=2
    )

    keypoint_roi_pooler = MultiScaleRoIAlign(
        featmap_names=['0', '1', '2', '3'],
        output_size=14,
        sampling_ratio=2
    )

    model = KeypointRCNN(
        backbone, 
        num_classes=2,
        num_keypoints=15,
        box_roi_pool=roi_pooler,
        keypoint_roi_pool=keypoint_roi_pooler
    )

    return model

In [7]:
def train(model, train_loader, optimizer, epoch, device = 'cuda'):
    model.train()                                        
    total_loss = 0
    for batch_idx, (images, targets) in enumerate(train_loader):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  

        optimizer.zero_grad()
        losses = model(images, targets)
        loss = losses['loss_keypoint']
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if (batch_idx+1) % 200 == 0:
            print(f'| epoch: {epoch} | batch: {batch_idx+1}/{len(train_loader)} | batch loss: {loss.item()}')

    return total_loss / len(train_loader)

def evaluate(model, test_loader, device = 'cuda'):
    model.train()      
    test_loss = 0      # test_loss 초기화
    
    with torch.no_grad(): 
        for images, targets in test_loader:
            # data, target 값 DEVICE에 할당
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  

            losses = model(images, targets)                       # validation loss
            test_loss += float(losses['loss_keypoint'])           # sum of all loss 
    
    test_loss /= len(test_loader.dataset)                         # 평균 loss
    return test_loss

def train_model(train_loader, val_loader, model_path=None, num_epochs=30, device='cuda'):
    if not os.path.exists(model_path):
        model = get_model()
    else:
        model = torch.load(model_path)
    model.to(device)
    
    best_loss = 999999  # initialize best loss
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

    for epoch in range(1, num_epochs+1):
        since = time.time()
        train(model, train_loader, optimizer, epoch, device)
        train_loss = train(model, train_loader, optimizer, epoch, device)
        val_loss = evaluate(model, val_loader)
        print('Train Keypoint Loss (avg): {:.4f}'.format(train_loss))

        if val_loss <= best_loss:   # update best loss
            best_loss = val_loss
            torch.save(model, './models/RCNN_ep'+str(epoch)+'_'+str(best_loss)+'.pt')
            print('Best Model Saved, Loss: ', val_loss)
        
        time_elapsed = time.time()-since
        print()
        print('---------------------- epoch {} ------------------------'.format(epoch))
        print('Train Keypoint Loss: {:.4f}, Val Keypoint Loss: {:.4f}'.format(train_loss, val_loss))   
        print('Completed in {:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
        print()

def main():
    current_folder = globals()['_dh'][0]
    path = os.path.dirname(os.path.join(current_folder,''))
    os.chdir(path)
    
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_img_path = './images'
    train_key_path = './filename.csv'

    # Load the entire DataFrame and split it into parts.
    total_df = pd.read_csv(train_key_path)
    num_parts = 1
    data_parts = np.array_split(total_df, num_parts)

    model_path = None

    # Train on each part sequentially.
    for i in range(num_parts):
        print(f"Training on part {i + 1}/{num_parts}")
        train_key, valid_key = train_test_split(data_parts[i], test_size=0.3, random_state=42)
        train_loader, valid_loader = load_data(train_img_path, train_key, valid_key)

        model_path = f"./models/RCNN_part{i + 1}.pt"
        train_model(train_loader, valid_loader, model_path=model_path, num_epochs=30, device=DEVICE)
    '''
    default: epoch - 30, 
             device - cuda
    '''


In [8]:
main()

Train Keypoint Loss (avg): 2.7222

---------------------- epoch 29 ------------------------
Train Keypoint Loss: 2.7222, Val Keypoint Loss: 1.7162
Completed in 745m 53s

| epoch: 30 | batch: 200/74328 | batch loss: 3.0117721557617188
| epoch: 30 | batch: 400/74328 | batch loss: 2.0546228885650635
| epoch: 30 | batch: 600/74328 | batch loss: 2.38338565826416
| epoch: 30 | batch: 800/74328 | batch loss: 2.2934749126434326
| epoch: 30 | batch: 1000/74328 | batch loss: 2.365713357925415
| epoch: 30 | batch: 1200/74328 | batch loss: 2.466822385787964
| epoch: 30 | batch: 1400/74328 | batch loss: 2.49348521232605
| epoch: 30 | batch: 1600/74328 | batch loss: 2.356402635574341
| epoch: 30 | batch: 1800/74328 | batch loss: 3.455949068069458
| epoch: 30 | batch: 2000/74328 | batch loss: 2.2623848915100098
| epoch: 30 | batch: 2200/74328 | batch loss: 2.1707465648651123
| epoch: 30 | batch: 2400/74328 | batch loss: 3.1435694694519043
| epoch: 30 | batch: 2600/74328 | batch loss: 1.97807407379150

| epoch: 30 | batch: 25000/74328 | batch loss: 2.418539047241211
| epoch: 30 | batch: 25200/74328 | batch loss: 2.69354510307312
| epoch: 30 | batch: 25400/74328 | batch loss: 2.960627794265747
| epoch: 30 | batch: 25600/74328 | batch loss: 2.3154726028442383
| epoch: 30 | batch: 25800/74328 | batch loss: 2.810124397277832
| epoch: 30 | batch: 26000/74328 | batch loss: 4.4683732986450195
| epoch: 30 | batch: 26200/74328 | batch loss: 2.283747673034668
| epoch: 30 | batch: 26400/74328 | batch loss: 2.493077039718628
| epoch: 30 | batch: 26600/74328 | batch loss: 2.306143283843994
| epoch: 30 | batch: 26800/74328 | batch loss: 2.80708384513855
| epoch: 30 | batch: 27000/74328 | batch loss: 2.955204486846924
| epoch: 30 | batch: 27200/74328 | batch loss: 1.8457856178283691
| epoch: 30 | batch: 27400/74328 | batch loss: 2.534271478652954
| epoch: 30 | batch: 27600/74328 | batch loss: 3.5484910011291504
| epoch: 30 | batch: 27800/74328 | batch loss: 1.9277796745300293
| epoch: 30 | batch: 2

| epoch: 30 | batch: 50200/74328 | batch loss: 3.3694891929626465
| epoch: 30 | batch: 50400/74328 | batch loss: 2.9172677993774414
| epoch: 30 | batch: 50600/74328 | batch loss: 3.7928881645202637
| epoch: 30 | batch: 50800/74328 | batch loss: 2.5667877197265625
| epoch: 30 | batch: 51000/74328 | batch loss: 3.6441967487335205
| epoch: 30 | batch: 51200/74328 | batch loss: 2.204216241836548
| epoch: 30 | batch: 51400/74328 | batch loss: 2.278938055038452
| epoch: 30 | batch: 51600/74328 | batch loss: 1.7713336944580078
| epoch: 30 | batch: 51800/74328 | batch loss: 1.9769757986068726
| epoch: 30 | batch: 52000/74328 | batch loss: 1.968371033668518
| epoch: 30 | batch: 52200/74328 | batch loss: 2.9566643238067627
| epoch: 30 | batch: 52400/74328 | batch loss: 4.360531330108643
| epoch: 30 | batch: 52600/74328 | batch loss: 2.651684522628784
| epoch: 30 | batch: 52800/74328 | batch loss: 2.54831862449646
| epoch: 30 | batch: 53000/74328 | batch loss: 2.555910348892212
| epoch: 30 | batc

| epoch: 30 | batch: 1200/74328 | batch loss: 2.936589002609253
| epoch: 30 | batch: 1400/74328 | batch loss: 2.567556381225586
| epoch: 30 | batch: 1600/74328 | batch loss: 2.390641450881958
| epoch: 30 | batch: 1800/74328 | batch loss: 2.333359956741333
| epoch: 30 | batch: 2000/74328 | batch loss: 3.331517457962036
| epoch: 30 | batch: 2200/74328 | batch loss: 1.829755187034607
| epoch: 30 | batch: 2400/74328 | batch loss: 2.6339423656463623
| epoch: 30 | batch: 2600/74328 | batch loss: 2.5312702655792236
| epoch: 30 | batch: 2800/74328 | batch loss: 2.803135395050049
| epoch: 30 | batch: 3000/74328 | batch loss: 3.899508476257324
| epoch: 30 | batch: 3200/74328 | batch loss: 2.845576763153076
| epoch: 30 | batch: 3400/74328 | batch loss: 2.3201980590820312
| epoch: 30 | batch: 3600/74328 | batch loss: 2.7210001945495605
| epoch: 30 | batch: 3800/74328 | batch loss: 2.6502623558044434
| epoch: 30 | batch: 4000/74328 | batch loss: 2.235126495361328
| epoch: 30 | batch: 4200/74328 | b

| epoch: 30 | batch: 26400/74328 | batch loss: 4.419778823852539
| epoch: 30 | batch: 26600/74328 | batch loss: 2.473423480987549
| epoch: 30 | batch: 26800/74328 | batch loss: 3.0484859943389893
| epoch: 30 | batch: 27000/74328 | batch loss: 3.202897310256958
| epoch: 30 | batch: 27200/74328 | batch loss: 2.5265324115753174
| epoch: 30 | batch: 27400/74328 | batch loss: 2.839341878890991
| epoch: 30 | batch: 27600/74328 | batch loss: 3.122657060623169
| epoch: 30 | batch: 27800/74328 | batch loss: 3.7506020069122314
| epoch: 30 | batch: 28000/74328 | batch loss: 1.9508042335510254
| epoch: 30 | batch: 28200/74328 | batch loss: 1.88569974899292
| epoch: 30 | batch: 28400/74328 | batch loss: 2.1049468517303467
| epoch: 30 | batch: 28600/74328 | batch loss: 2.4604616165161133
| epoch: 30 | batch: 28800/74328 | batch loss: 3.3811206817626953
| epoch: 30 | batch: 29000/74328 | batch loss: 2.7187821865081787
| epoch: 30 | batch: 29200/74328 | batch loss: 2.5065298080444336
| epoch: 30 | bat

| epoch: 30 | batch: 51400/74328 | batch loss: 3.0220561027526855
| epoch: 30 | batch: 51600/74328 | batch loss: 2.1246721744537354
| epoch: 30 | batch: 51800/74328 | batch loss: 2.569762945175171
| epoch: 30 | batch: 52000/74328 | batch loss: 3.3981142044067383
| epoch: 30 | batch: 52200/74328 | batch loss: 2.06313419342041
| epoch: 30 | batch: 52400/74328 | batch loss: 2.6149075031280518
| epoch: 30 | batch: 52600/74328 | batch loss: 3.081749200820923
| epoch: 30 | batch: 52800/74328 | batch loss: 4.1751203536987305
| epoch: 30 | batch: 53000/74328 | batch loss: 2.1405935287475586
| epoch: 30 | batch: 53200/74328 | batch loss: 4.416858196258545
| epoch: 30 | batch: 53400/74328 | batch loss: 3.541029453277588
| epoch: 30 | batch: 53600/74328 | batch loss: 2.3038017749786377
| epoch: 30 | batch: 53800/74328 | batch loss: 2.3277804851531982
| epoch: 30 | batch: 54000/74328 | batch loss: 1.4124348163604736
| epoch: 30 | batch: 54200/74328 | batch loss: 3.3595423698425293
| epoch: 30 | ba