In [None]:
# ====================================================================================
# 0. 라이브러리 임포트  학습용 코드
# ====================================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_large
import numpy as np
import cv2
import math

# ====================================================================================
# Part 1: 텍스트 탐지 모듈 (DBNet) 및 손실 함수 (DBLoss)
# ====================================================================================

# 1-1. DBNet 아키텍처
class MobileNetV3_Backbone(nn.Module):
    """MobileNetV3-Large feature extractor exposing four intermediate feature maps.

    The torchvision ``features`` stack is cut into four consecutive stages. The
    stages are run one after another and the output of each stage is returned,
    so the result is a list of four tensors ordered from the shallowest stage
    to the deepest one.

    Args:
        pretrained: When true, load the ImageNet weights shipped with torchvision.

    Attributes:
        out_layers: The four consecutive slices of the torchvision feature stack.
        out_channels: Channel count of each returned feature map.
    """

    def __init__(self, pretrained=False):
        super().__init__()
        features = mobilenet_v3_large(weights='IMAGENET1K_V1' if pretrained else None).features
        self.out_layers = nn.ModuleList([
            features[0:4], features[4:7], features[7:13], features[13:17]
        ])
        # Output channels of features[3], features[6], features[12] and features[16].
        self.out_channels = [24, 40, 112, 960]

    def forward(self, x):
        """Return the output of every stage, feeding each stage with the previous one."""
        outputs = []
        for stage in self.out_layers:
            # Stages are consecutive slices of one network, so they must be chained;
            # only the first stage accepts the raw 3-channel image.
            x = stage(x)
            outputs.append(x)
        return outputs

class DB_FPN(nn.Module):
    """Feature-pyramid neck that fuses four backbone maps into one tensor.

    Each input map is projected to ``out_channels`` with a 1x1 convolution,
    merged top-down (deepest to shallowest) by upsampling and addition, then
    every merged level is brought to the resolution of the shallowest level,
    reduced to ``out_channels // 4`` channels by a 3x3 convolution and the four
    results are concatenated along the channel axis.

    Args:
        in_channels: Channel count of each of the four input feature maps,
            ordered from the shallowest map to the deepest one.
        out_channels: Width of the lateral projections; also the channel count
            of the concatenated output.
    """

    def __init__(self, in_channels, out_channels=256):
        super().__init__()
        self.in_convs = nn.ModuleList([nn.Conv2d(c, out_channels, 1, bias=False) for c in in_channels])
        self.out_convs = nn.ModuleList([nn.Conv2d(out_channels, out_channels // 4, 3, padding=1, bias=False) for _ in in_channels])

    def forward(self, features):
        """Fuse ``features`` (four maps, shallow to deep) into a single map.

        Returns:
            Tensor with ``out_channels`` channels at the spatial size of
            ``features[0]``.
        """
        inner_features = [conv(f) for conv, f in zip(self.in_convs, features)]

        # Resize to the size of the lateral map instead of a fixed scale factor:
        # for inputs that are not a multiple of 32 the deeper maps are not exact
        # halves of the shallower ones and a fixed 2x upsample would not line up.
        p4 = inner_features[3]
        p3 = F.interpolate(p4, size=inner_features[2].shape[2:]) + inner_features[2]
        p2 = F.interpolate(p3, size=inner_features[1].shape[2:]) + inner_features[1]
        p1 = F.interpolate(p2, size=inner_features[0].shape[2:]) + inner_features[0]

        target_size = p1.shape[2:]
        final_features = [
            self.out_convs[0](p1),
            self.out_convs[1](F.interpolate(p2, size=target_size)),
            self.out_convs[2](F.interpolate(p3, size=target_size)),
            self.out_convs[3](F.interpolate(p4, size=target_size)),
        ]
        return torch.cat(final_features, dim=1)

class DB_Head(nn.Module):
    """Prediction head producing a probability map and a threshold map.

    Both branches share the same layout: 3x3 convolution, ReLU, a stride-2
    transposed convolution and a 2x bilinear upsample, i.e. the spatial size of
    the input is multiplied by four. The extra upsample is what brings the
    maps from half resolution up to the resolution of the ground-truth maps
    the loss compares them with.

    Args:
        in_channels: Channel count of the fused feature map fed to the head.
    """

    def __init__(self, in_channels=256):
        super().__init__()
        self.prob_conv = self._make_branch(in_channels)
        self.thresh_conv = self._make_branch(in_channels)

    @staticmethod
    def _make_branch(in_channels):
        """Build one single-channel prediction branch (4x spatial upsampling)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
            # Parameter-free, so existing state_dict keys are unaffected.
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        )

    def forward(self, x):
        """Return ``(prob_map, thresh_map)``, both squashed to [0, 1] by a sigmoid."""
        prob_map = torch.sigmoid(self.prob_conv(x))
        thresh_map = torch.sigmoid(self.thresh_conv(x))
        return prob_map, thresh_map

class DBNet(nn.Module):
    """Text detector assembled from a backbone, an FPN neck and a prediction head."""

    def __init__(self):
        super().__init__()
        backbone = MobileNetV3_Backbone()
        self.backbone = backbone
        # The neck is sized from the channel counts the backbone reports.
        self.fpn = DB_FPN(in_channels=backbone.out_channels)
        self.head = DB_Head(in_channels=256)

    def forward(self, x):
        """Run backbone -> neck -> head and return whatever the head returns."""
        pyramid = self.backbone(x)
        fused = self.fpn(pyramid)
        return self.head(fused)

# 1-2. DBNet 손실 함수
class DiceLoss(nn.Module):
    """Masked Dice loss: ``1 - 2 * |P * T| / (|P| + |T| + eps)``.

    Args:
        eps: Small constant added to the denominator to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target, mask):
        """Return the scalar Dice loss of ``pred`` vs ``target`` restricted to ``mask``."""
        masked_pred = pred * mask
        masked_target = target * mask
        overlap = torch.sum(masked_pred * masked_target)
        denominator = masked_pred.sum() + masked_target.sum() + self.eps
        return 1 - (2 * overlap / denominator)

class DBLoss(nn.Module):
    """Combined detection loss: ``L_s + alpha * L_b + beta * L_t``.

    * ``L_s``: masked BCE plus masked Dice on the probability map.
    * ``L_b``: masked Dice on the approximate binary map
      ``sigmoid(k * (prob - thresh))``.
    * ``L_t``: masked L1 on the threshold map.

    Note that when ``gt_prob`` is all zeros both Dice terms equal exactly 1
    regardless of the prediction, because the intersection is always zero.

    Args:
        alpha: Weight of the binary-map term.
        beta: Weight of the threshold-map term.
        k: Steepness of the differentiable binarization.
    """

    def __init__(self, alpha=1.0, beta=10.0, k=50):
        super().__init__()
        self.alpha, self.beta, self.k = alpha, beta, k
        self.bce_loss = nn.BCELoss(reduction='none')
        self.dice_loss = DiceLoss()

    def forward(self, pred_maps, gt_maps):
        """Compute the scalar loss.

        Args:
            pred_maps: ``(pred_prob, pred_thresh)`` with values in [0, 1].
            gt_maps: ``(gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask)``.
        """
        pred_prob, pred_thresh = pred_maps
        gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask = gt_maps

        # L_s (BCE + Dice). The BCE is averaged over the pixels the mask keeps,
        # the same way the L1 term below is, so masked-out pixels do not dilute it.
        masked_bce = self.bce_loss(pred_prob, gt_prob) * gt_prob_mask
        bce_prob_loss = masked_bce.sum() / (gt_prob_mask.sum() + 1e-6)
        dice_prob_loss = self.dice_loss(pred_prob, gt_prob, gt_prob_mask)
        loss_s = bce_prob_loss + dice_prob_loss

        # L_b (Dice on the differentiable-binarization map). torch.sigmoid is the
        # overflow-safe form of 1 / (1 + exp(-k * x)).
        db_map = torch.sigmoid(self.k * (pred_prob - pred_thresh))
        loss_b = self.dice_loss(db_map, gt_prob, gt_prob_mask)

        # L_t (masked L1 on the threshold map)
        loss_t = (torch.abs(pred_thresh - gt_thresh) * gt_thresh_mask).sum() / (gt_thresh_mask.sum() + 1e-6)

        return loss_s + self.alpha * loss_b + self.beta * loss_t

# ====================================================================================
# Part 2: 텍스트 인식 모듈 (CRNN)
# ====================================================================================
class CRNN(nn.Module):
    """Convolutional-recurrent text recognizer.

    A convolutional stack turns the image into a feature map, the height axis
    is squeezed away, and the width axis is treated as the time axis of a
    two-layer bidirectional LSTM followed by a per-step linear classifier.

    Args:
        num_chars: Number of output classes per time step.
        rnn_hidden_size: Hidden size of each LSTM direction.
    """

    def __init__(self, num_chars, rnn_hidden_size=256):
        super().__init__()
        layers = [
            # Two plain conv blocks, each halving height and width.
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Two batch-normalised blocks that halve the height only.
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
            # Unpadded 2x2 convolution.
            nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        self.cnn = nn.Sequential(*layers)
        self.rnn = nn.LSTM(512, rnn_hidden_size, bidirectional=True, num_layers=2)
        self.classifier = nn.Linear(rnn_hidden_size * 2, num_chars)

    def forward(self, x):
        """Return per-step class scores shaped ``(seq_len, batch, num_chars)``."""
        feature_map = self.cnn(x)
        # (B, C, 1, W) -> (B, C, W) -> (W, B, C): width becomes the sequence axis.
        sequence = feature_map.squeeze(2).permute(2, 0, 1)
        recurrent_out, _ = self.rnn(sequence)
        return self.classifier(recurrent_out)

# ====================================================================================
# Part 3: 데이터 시뮬레이션 (더미 데이터 생성)
# ====================================================================================
def generate_dummy_ocr_image(path="dummy_ocr_test.png"):
    """Write a white 800x600 test image with two lines of black text and return its path."""
    canvas = np.full((600, 800, 3), 255, dtype=np.uint8)
    font, font_scale, black, thickness = cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 0), 5
    lines = (
        ('hello ocr 123', (100, 200)),
        ('pytorch', (150, 400)),
    )
    for text, origin in lines:
        cv2.putText(canvas, text, origin, font, font_scale, black, thickness)
    cv2.imwrite(path, canvas)
    return path

def generate_dummy_training_batch(batch_size, img_size, character_list):
    """Create one random batch that stands in for real training data.

    Returns:
        ``(det_images, db_gts, rec_images, gt_texts)`` where ``db_gts`` is the
        4-tuple ``(gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask)`` of float
        tensors and ``gt_texts`` is a list of ``batch_size`` random strings of
        3 to 9 characters drawn from ``character_list``.
    """
    map_shape = (batch_size, 1, img_size, img_size)

    # 1. Detector inputs and targets.
    det_images = torch.rand(batch_size, 3, img_size, img_size)
    text_region = torch.rand(map_shape) > 0.8
    full_mask = torch.ones(map_shape)
    thresh_target = torch.rand(map_shape)
    # The threshold term is supervised only where the probability target is positive.
    thresh_region = text_region.clone()
    db_gts = tuple(t.float() for t in (text_region, full_mask, thresh_target, thresh_region))

    # 2. Recognizer inputs (stand-ins for cropped word images), laid out (B, C, H, W).
    rec_images = torch.rand(batch_size, 3, 32, 100)

    # Random label strings.
    alphabet = list(character_list)

    def _random_word():
        n_chars = np.random.randint(3, 10)
        return "".join(np.random.choice(alphabet, n_chars))

    gt_texts = [_random_word() for _ in range(batch_size)]

    return det_images, db_gts, rec_images, gt_texts

# ====================================================================================
# Part 4: OCR 파이프라인 (추론용)
# ====================================================================================
class OCR_Pipeline:
    def __init__(self, detector, recognizer, char_map):
        self.detector = detector
        self.recognizer = recognizer
        self.idx_to_char = {i + 1: char for i, char in enumerate(char_map)}
        
        self.detector.eval()
        self.recognizer.eval()

    def _get_cropped_image(self, image, box):
        w = max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3]))
        h = max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2]))
        dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]], dtype=np.float32)
        M = cv2.getPerspectiveTransform(box.astype(np.float32), dst_pts)
        return cv2.warpPerspective(image, M, (int(w), int(h)))

    def _ctc_greedy_decode(self, preds):
        preds_idx = preds.argmax(dim=2).transpose(1, 0).contiguous().view(-1)
        decoded_text, last_char_idx = [], 0
        for idx in preds_idx:
            if idx.item() == 0 or idx.item() == last_char_idx:
                last_char_idx = idx.item() if idx.item() != 0 else 0
                continue
            decoded_text.append(self.idx_to_char.get(idx.item(), ''))
            last_char_idx = idx.item()
        return "".join(decoded_text)

    def predict(self, image_path):
        image = cv2.imread(image_path)
        img_h, img_w = image.shape[:2]
        
        det_input = cv2.resize(image, (640, 640))
        det_input = torch.from_numpy(det_input).permute(2, 0, 1).float().unsqueeze(0) / 255.0

        with torch.no_grad():
            prob_map, _ = self.detector(det_input)
            
            # 후처리
            binary_map = (prob_map.squeeze().numpy() > 0.3).astype(np.uint8)
            contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            
            scale_x, scale_y = img_w / 640, img_h / 640
            results = []
            for contour in contours:
                if cv2.contourArea(contour) < 100: continue
                box = np.intp(cv2.boxPoints(cv2.minAreaRect(contour)) * np.array([scale_x, scale_y]))
                
                cropped_img = self._get_cropped_image(image, box)
                if cropped_img.shape[0] < 10 or cropped_img.shape[1] < 10: continue

                rec_input = cv2.resize(cropped_img, (100, 32))
                rec_input = torch.from_numpy(rec_input).permute(2, 0, 1).float().unsqueeze(0) / 255.0
                
                preds = self.recognizer(rec_input)
                text = self._ctc_greedy_decode(preds)
                results.append({'box': box.tolist(), 'text': text})
        
        return results, image

# ====================================================================================
# Part 5: 메인 실행 블록 (학습 시뮬레이션 및 추론)
# ====================================================================================
if __name__ == '__main__':
    # --- 1. Models and hyper-parameters ---
    print("Step 1: 모델 및 설정 초기화...")
    CHARACTER_SET = "0123456789abcdefghijklmnopqrstuvwxyz"
    NUM_CLASSES = len(CHARACTER_SET) + 1  # +1 for the CTC blank token
    BATCH_SIZE = 4
    IMG_SIZE = 640
    # Label indices start at 1 because index 0 is the CTC blank. Built once here
    # rather than on every training step.
    char_to_idx = {char: i + 1 for i, char in enumerate(CHARACTER_SET)}

    detector = DBNet()
    recognizer = CRNN(num_chars=NUM_CLASSES)

    db_loss_fn = DBLoss()
    ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

    optimizer_det = torch.optim.Adam(detector.parameters(), lr=1e-4)
    optimizer_rec = torch.optim.Adam(recognizer.parameters(), lr=1e-4)

    # --- 2. Simulated training on random data ---
    print("\nStep 2: 더미 학습 과정 시뮬레이션 시작...")
    num_steps = 5
    for i in range(num_steps):
        # Fresh random batch for this step.
        det_imgs, db_gts, rec_imgs, rec_gts = generate_dummy_training_batch(BATCH_SIZE, IMG_SIZE, CHARACTER_SET)

        # --- Detector step ---
        detector.train()
        optimizer_det.zero_grad()
        pred_maps = detector(det_imgs)
        loss_det = db_loss_fn(pred_maps, db_gts)
        loss_det.backward()
        optimizer_det.step()

        # --- Recognizer step ---
        # A real pipeline would train on crops of detected regions; here the two
        # models are simulated independently.
        recognizer.train()
        optimizer_rec.zero_grad()
        preds_rec = recognizer(rec_imgs)  # (SeqLen, Batch, NumClasses)

        # Inputs for CTCLoss; the batch size is read from the prediction tensor
        # instead of being repeated as a literal.
        log_probs = F.log_softmax(preds_rec, dim=2)
        seq_len, batch = preds_rec.size(0), preds_rec.size(1)
        input_lengths = torch.full(size=(batch,), fill_value=seq_len, dtype=torch.long)

        targets = torch.cat([torch.tensor([char_to_idx[c] for c in text], dtype=torch.long) for text in rec_gts])
        target_lengths = torch.tensor([len(text) for text in rec_gts], dtype=torch.long)

        loss_rec = ctc_loss_fn(log_probs, targets, input_lengths, target_lengths)
        loss_rec.backward()
        optimizer_rec.step()

        print(f"  Step {i+1}/{num_steps} - Detector Loss: {loss_det.item():.4f}, Recognizer Loss: {loss_rec.item():.4f}")

    print("더미 학습 완료.")

    # --- 3. Inference ---
    print("\nStep 3: 추론 파이프라인 실행...")
    # Wrap the (nominally) trained models in the inference pipeline.
    pipeline = OCR_Pipeline(detector, recognizer, CHARACTER_SET)

    # Create a test image and run prediction on it.
    dummy_img_path = generate_dummy_ocr_image()
    ocr_results, result_image = pipeline.predict(dummy_img_path)

    print("\n--- 최종 OCR 결과 ---")
    if not ocr_results:
        print("탐지된 텍스트가 없습니다.")
    else:
        for result in ocr_results:
            # Draw the detected box.
            box = np.array(result['box'], dtype=np.intp)
            cv2.drawContours(result_image, [box], 0, (0, 255, 0), 2)
            # Write the recognised text just above the first corner. Plain Python
            # ints are passed because the origin comes from a numpy array.
            text_origin = (int(box[0][0]), int(box[0][1]) - 10)
            cv2.putText(result_image, result['text'], text_origin, cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            print(f"Box: {result['box']}, Text: '{result['text']}'")

    # Save the annotated image.
    result_img_path = "final_ocr_result.png"
    cv2.imwrite(result_img_path, result_image)
    print(f"\n결과 이미지가 '{result_img_path}'로 저장되었습니다.")

In [None]:
!pip uninstall torch torchvision torchaudio

In [None]:
import torch
# Report whether PyTorch can see a usable CUDA device.
has_cuda = torch.cuda.is_available()
print(has_cuda)

In [None]:
#detection


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_large
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import json
import os
from tqdm import tqdm

# 데이터 증강 및 기하학 라이브러리
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pyclipper

# ====================================================================================
# Part 1: Data pipeline (Dataset, augmentation, ground-truth generation)
# ====================================================================================

class OcrDataset(Dataset):
    """Text-detection dataset yielding an image together with its target maps.

    Each item is the 5-tuple ``(image, gt_prob_map, gt_prob_mask, gt_thresh_map,
    gt_thresh_mask)``; the four maps are float32 tensors shaped ``(1, H, W)``
    matching the spatial size of the returned image.

    Args:
        annotation_path: JSON file whose ``'images'`` entry is a list of records
            with a ``'file_name'`` and an optional ``'polygons'`` list of
            ``{'points': [[x, y], ...]}`` dicts.
        image_dir: Directory the file names are relative to.
        transform: Optional callable invoked as
            ``transform(image=..., keypoints=...)`` and returning a mapping with
            ``'image'`` and ``'keypoints'``. It is expected to keep every
            keypoint, in order.
        target_size: Stored on the instance; not used by the dataset itself.

    NOTE(review): records without a ``'polygons'`` entry produce all-zero target
    maps. If the label file stores polygons under another key, every sample is
    empty — verify against the actual label schema.
    """

    def __init__(self, annotation_path, image_dir, transform=None, target_size=640):
        super().__init__()
        self.image_dir = image_dir
        self.transform = transform
        self.target_size = target_size

        with open(annotation_path, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)['images']

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling back to the next loadable one.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            RuntimeError: If not a single sample of the dataset can be loaded.
        """
        total = len(self)
        if not -total <= idx < total:
            raise IndexError(f"index {idx} is out of range for a dataset of size {total}")

        # Try each record at most once so a dataset with no readable image ends
        # in a clear error instead of unbounded recursion.
        for offset in range(total):
            sample = self._load_sample(self.annotations[(idx + offset) % total])
            if sample is not None:
                break
        else:
            raise RuntimeError("OcrDataset: no sample could be loaded; check image_dir and the annotation file.")

        image, polygons = sample
        height, width = self._image_hw(image)
        maps = self.make_db_ground_truth(polygons, width, height)
        return (image, *maps)

    @staticmethod
    def _read_rgb(image_path):
        """Read ``image_path`` as an RGB array.

        The file is read with ``np.fromfile`` and decoded with ``cv2.imdecode``
        because ``cv2.imread`` cannot open paths containing Korean characters.
        """
        img_array = np.fromfile(image_path, np.uint8)
        image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        if image is None:
            # The file was read but could not be decoded.
            raise IOError(f"Failed to decode image at {image_path}")
        # OpenCV decodes to BGR; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _image_hw(image):
        """Return ``(height, width)`` of a CHW tensor or of an HWC array."""
        if torch.is_tensor(image):
            return int(image.shape[-2]), int(image.shape[-1])
        return int(image.shape[0]), int(image.shape[1])

    def _load_sample(self, annotation_info):
        """Load one record; return ``(image, polygons)`` or ``None`` if it must be skipped."""
        image_key = annotation_info['file_name']
        image_path = os.path.join(self.image_dir, image_key)

        try:
            image = self._read_rgb(image_path)
        except Exception as e:  # deliberately broad: unreadable files are skipped, not fatal
            print(f"Error loading image {image_path}: {e}")
            return None

        point_lists = [p['points'] for p in annotation_info.get('polygons', []) if p.get('points')]

        if self.transform is None:
            return image, [np.asarray(points, dtype=np.float32) for points in point_lists]

        # The transform takes one flat keypoint list; polygons are rebuilt
        # afterwards from the per-polygon point counts.
        flat_points = [point for points in point_lists for point in points]
        try:
            transformed = self.transform(image=image, keypoints=flat_points)
        except Exception as e:
            # The untransformed image has a different size and layout than the
            # transformed ones and cannot be batched with them, so skip it.
            print(f"Augmentation failed for {image_key}, skipping sample. Error: {e}")
            return None

        keypoints = transformed['keypoints']
        polygons, start = [], 0
        for points in point_lists:
            stop = start + len(points)
            polygons.append(np.asarray(keypoints[start:stop], dtype=np.float32))
            start = stop
        return transformed['image'], polygons

    def make_db_ground_truth(self, polygons, width, height):
        """Rasterise ``polygons`` into the four target maps.

        Args:
            polygons: Iterable of ``(N, 2)`` arrays of ``(x, y)`` points.
            width: Map width in pixels.
            height: Map height in pixels.

        Returns:
            ``(gt_prob_map, gt_prob_mask, gt_thresh_map, gt_thresh_mask)`` as
            float32 tensors shaped ``(1, height, width)``.

        NOTE(review): the probability map is filled with the full polygon, the
        threshold mask with the shrunk polygon, and the threshold map with the
        normalised interior distance transform. The DBNet paper instead uses the
        shrunk polygon for the probability map and a border band around the
        polygon for the threshold map — confirm this deviation is intended.
        """
        gt_prob_map = np.zeros((height, width), dtype=np.float32)
        gt_prob_mask = np.ones((height, width), dtype=np.float32)
        gt_thresh_map = np.zeros((height, width), dtype=np.float32)
        gt_thresh_mask = np.zeros((height, width), dtype=np.float32)
        upper_bound = np.array([width - 1, height - 1], dtype=np.float32)

        for polygon in polygons:
            if len(polygon) < 3:
                continue
            # Keep only (x, y) and stay in float32: cv2.contourArea and
            # cv2.arcLength accept only 32-bit float or int point arrays.
            polygon = np.asarray(polygon, dtype=np.float32).reshape(len(polygon), -1)[:, :2]
            polygon = np.clip(polygon, 0.0, upper_bound).astype(np.float32, copy=False)
            polygon_int = polygon.astype(np.int32)

            cv2.fillPoly(gt_prob_map, [polygon_int], 1.0)

            area = cv2.contourArea(polygon)
            perimeter = cv2.arcLength(polygon, True)
            if perimeter == 0:
                continue
            # Shrink offset D = A * (1 - r^2) / L with shrink ratio r = 0.4.
            distance = area * (1 - 0.4**2) / (perimeter + 1e-6)

            pco = pyclipper.PyclipperOffset()
            # Clipper works on integer coordinates, so pass the integer polygon.
            pco.AddPath(polygon_int.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            shrunk_polygons = pco.Execute(-distance)
            if not shrunk_polygons:
                continue

            cv2.fillPoly(gt_thresh_mask, [np.array(p, dtype=np.int32) for p in shrunk_polygons], 1.0)

            # Distance of every interior pixel to the polygon border, scaled to [0, 1].
            interior = np.zeros((height, width), dtype=np.uint8)
            cv2.fillPoly(interior, [polygon_int], 1)
            dist_map = cv2.distanceTransform(interior, cv2.DIST_L2, 5)

            min_dist, max_dist = np.min(dist_map), np.max(dist_map)
            if max_dist > min_dist:
                dist_map = (dist_map - min_dist) / (max_dist - min_dist)

            inside = dist_map > 0
            gt_thresh_map[inside] = dist_map[inside]

        return torch.from_numpy(gt_prob_map).unsqueeze(0), torch.from_numpy(gt_prob_mask).unsqueeze(0), torch.from_numpy(gt_thresh_map).unsqueeze(0), torch.from_numpy(gt_thresh_mask).unsqueeze(0)
def get_transforms(size):
    """Build the training augmentation pipeline for ``size`` x ``size`` inputs.

    The pipeline resizes, optionally rotates by up to 10 degrees, optionally
    jitters colour, normalises with the ImageNet mean/std and converts the image
    to a tensor. Keypoints are given as ``(x, y)`` pairs and are never dropped
    (``remove_invisible=False``), so their count and order are preserved.
    """
    return A.Compose([
        # p=1.0 applies the resize to every sample; the deprecated
        # always_apply flag is no longer passed.
        A.Resize(size, size, p=1.0),
        A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT),
        A.ColorJitter(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))


# ====================================================================================
# Part 2: 모델 및 손실 함수 아키텍처
# ====================================================================================
class MobileNetV3_Backbone(nn.Module):
    """MobileNetV3-Large feature extractor that taps four intermediate layers.

    Args:
        pretrained: When true, load the ``IMAGENET1K_V1`` torchvision weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = 'IMAGENET1K_V1' if pretrained else None
        # Only the convolutional feature stack of the classifier network is kept.
        self.features = mobilenet_v3_large(weights=weights).features

        # Positions inside ``features`` whose outputs are handed to the FPN.
        self.output_indices = [3, 6, 12, 16]

    def forward(self, x):
        """Run the feature stack and return the outputs at ``output_indices``, in order."""
        taps = []
        for position, layer in enumerate(self.features):
            x = layer(x)
            if position not in self.output_indices:
                continue
            taps.append(x)
        return taps

class DB_FPN(nn.Module):
    """Feature-pyramid neck fusing four backbone maps into one tensor.

    Args:
        in_channels: Channel count of each input map, shallowest first. The
            default matches the four MobileNetV3-Large taps used by the backbone.
        out_channels: Width of the lateral projections; also the channel count
            of the concatenated output.
    """

    # The default is a tuple so the shared default value cannot be mutated.
    def __init__(self, in_channels=(24, 40, 112, 960), out_channels=256):
        super().__init__()
        self.in_convs = nn.ModuleList([nn.Conv2d(c, out_channels, 1, bias=False) for c in in_channels])
        self.out_convs = nn.ModuleList([nn.Conv2d(out_channels, out_channels // 4, 3, padding=1, bias=False) for _ in in_channels])

    def forward(self, features):
        """Fuse ``features`` and return a map at the resolution of the first one.

        ``features`` holds four maps ordered from the highest resolution to the
        lowest one.
        """
        c2, c3, c4, c5 = features

        # 1x1 lateral projections to a common width.
        in5 = self.in_convs[3](c5)
        in4 = self.in_convs[2](c4)
        in3 = self.in_convs[1](c3)
        in2 = self.in_convs[0](c2)

        # Top-down pathway: upsample the coarser level and add it to the finer one.
        out4 = in4 + F.interpolate(in5, size=in4.shape[2:], mode='bilinear', align_corners=False)
        out3 = in3 + F.interpolate(out4, size=in3.shape[2:], mode='bilinear', align_corners=False)
        out2 = in2 + F.interpolate(out3, size=in2.shape[2:], mode='bilinear', align_corners=False)

        # 3x3 smoothing per level, then bring every level to the finest resolution.
        target_size = out2.shape[2:]
        p5 = F.interpolate(self.out_convs[3](in5), size=target_size, mode='bilinear', align_corners=False)
        p4 = F.interpolate(self.out_convs[2](out4), size=target_size, mode='bilinear', align_corners=False)
        p3 = F.interpolate(self.out_convs[1](out3), size=target_size, mode='bilinear', align_corners=False)
        p2 = self.out_convs[0](out2)

        return torch.cat([p2, p3, p4, p5], dim=1)

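# DB_Head below adds an upsampling stage so its outputs match the input resolution,
# and DBNet takes a `pretrained` flag; DiceLoss and DBLoss follow the earlier cell's design.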
class DB_Head(nn.Module):
    """Prediction head with a probability branch and a threshold branch.

    Each branch is: 3x3 convolution, batch norm, ReLU, a stride-2 transposed
    convolution and a 2x bilinear upsample, so the spatial size of the input is
    multiplied by four overall.

    Args:
        in_channels: Channel count of the fused feature map fed to the head.
    """

    def __init__(self, in_channels=256):
        super().__init__()
        hidden = in_channels // 4
        self.prob_conv = self._build_branch(in_channels, hidden)
        self.thresh_conv = self._build_branch(in_channels, hidden)

    @staticmethod
    def _build_branch(in_channels, hidden):
        """Assemble one single-channel prediction branch."""
        return nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            # First 2x upsampling step (learned).
            nn.ConvTranspose2d(hidden, 1, kernel_size=2, stride=2),
            # Second 2x upsampling step (bilinear).
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        )

    def forward(self, x):
        """Return ``(prob_map, thresh_map)``, each passed through a sigmoid."""
        prob_map = torch.sigmoid(self.prob_conv(x))
        thresh_map = torch.sigmoid(self.thresh_conv(x))
        return prob_map, thresh_map


class DBNet(nn.Module):
    """Text detector assembled from a backbone, an FPN neck and a prediction head.

    Args:
        pretrained: Forwarded to the backbone.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.backbone = MobileNetV3_Backbone(pretrained)
        self.fpn = DB_FPN()
        self.head = DB_Head()

    def forward(self, x):
        """Run backbone -> neck -> head and return whatever the head returns."""
        multi_scale = self.backbone(x)
        fused = self.fpn(multi_scale)
        return self.head(fused)

class DiceLoss(nn.Module):
    """Masked Dice loss: ``1 - 2 * |P * T| / (|P| + |T| + eps)``.

    Args:
        eps: Small constant added to the denominator to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target, mask):
        """Return the scalar Dice loss of ``pred`` vs ``target`` restricted to ``mask``."""
        masked_pred = pred * mask
        masked_target = target * mask
        overlap = torch.sum(masked_pred * masked_target)
        denominator = masked_pred.sum() + masked_target.sum() + self.eps
        return 1 - (2 * overlap / denominator)

class DBLoss(nn.Module):
    """Combined detection loss: ``L_s + alpha * L_b + beta * L_t``.

    * ``L_s``: masked Dice plus masked BCE on the probability map.
    * ``L_b``: masked Dice on the approximate binary map
      ``sigmoid(k * (prob - thresh))``.
    * ``L_t``: masked L1 on the threshold map.

    Note that when ``gt_prob`` is all zeros both Dice terms equal exactly 1
    regardless of the prediction, because the intersection is always zero; a
    total loss that settles at 2.0 is the signature of empty targets.

    Args:
        alpha: Weight of the binary-map term.
        beta: Weight of the threshold-map term.
        k: Steepness of the differentiable binarization.
    """

    def __init__(self, alpha=1.0, beta=10.0, k=50):
        super().__init__()
        self.alpha, self.beta, self.k = alpha, beta, k
        self.bce_loss = nn.BCELoss(reduction='none')
        self.dice_loss = DiceLoss()

    def forward(self, pred_maps, gt_maps):
        """Compute the scalar loss.

        Args:
            pred_maps: ``(pred_prob, pred_thresh)`` with values in [0, 1].
            gt_maps: ``(gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask)``.
        """
        pred_prob, pred_thresh = pred_maps
        gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask = gt_maps

        # L_s. The BCE is averaged over the pixels the mask keeps, the same way
        # the L1 term below is, so masked-out pixels do not dilute it.
        masked_bce = self.bce_loss(pred_prob, gt_prob) * gt_prob_mask
        bce_prob_loss = masked_bce.sum() / (gt_prob_mask.sum() + 1e-6)
        loss_s = self.dice_loss(pred_prob, gt_prob, gt_prob_mask) + bce_prob_loss

        # L_b. torch.sigmoid is the overflow-safe form of 1 / (1 + exp(-k * x)).
        db_map = torch.sigmoid(self.k * (pred_prob - pred_thresh))
        loss_b = self.dice_loss(db_map, gt_prob, gt_prob_mask)

        # L_t: masked L1 on the threshold map.
        loss_t = (torch.abs(pred_thresh - gt_thresh) * gt_thresh_mask).sum() / (gt_thresh_mask.sum() + 1e-6)

        return loss_s + self.alpha * loss_b + self.beta * loss_t

# ====================================================================================
# Part 3: 실제 학습 스크립트
# ====================================================================================
def main():
    """Train the DBNet detector on the annotated dataset and save periodic checkpoints.

    The learning rate follows a cosine schedule stepped once per batch, and the
    detector weights are written to ``dbnet_detector_epoch_<n>.pth`` every fifth
    epoch. The function returns early, after printing a message, when the
    dataset files cannot be found or the dataset is empty.
    """
    # --- 1. Configuration (adjust to the local environment) ---
    # These two paths must be changed to point at your own dataset.
    ANNOTATION_PATH = r"C:\Users\User\DBNet_OCR\[라벨]Training\2.책표지\04.사회과학\combined_labels_cleaned.json"
    IMAGE_DIR = r"C:\Users\User\DBNet_OCR\[원천]Training_책표지1\04.사회과학"

    # Hyper-parameters
    IMG_SIZE = 640
    BATCH_SIZE = 4
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-4
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"사용 장치: {DEVICE}")
    print("사전 학습된 MobileNetV3 가중치를 사용합니다.")

    # --- 2. Data pipeline ---
    print("데이터 로더를 준비합니다...")
    try:
        dataset = OcrDataset(
            annotation_path=ANNOTATION_PATH,
            image_dir=IMAGE_DIR,
            transform=get_transforms(IMG_SIZE)
        )
    except FileNotFoundError as e:
        print("오류: 데이터셋 경로를 찾을 수 없습니다. ANNOTATION_PATH와 IMAGE_DIR을 확인하세요.")
        print(e)
        return

    # An empty dataset would otherwise end in a division by zero when the
    # epoch average is computed.
    if len(dataset) == 0:
        print("오류: 데이터셋이 비어 있습니다. 어노테이션 파일을 확인하세요.")
        return

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0  # raise to match the available CPU cores
    )

    # --- 3. Model, loss and optimizer ---
    print("모델과 옵티마이저를 준비합니다...")
    detector = DBNet(pretrained=True).to(DEVICE)
    db_loss_fn = DBLoss().to(DEVICE)
    optimizer = torch.optim.AdamW(detector.parameters(), lr=LEARNING_RATE)
    # The scheduler is stepped per batch, so its period is the total step count.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(data_loader) * NUM_EPOCHS)

    # --- 4. Training loop ---
    print("학습을 시작합니다...")
    for epoch in range(NUM_EPOCHS):
        detector.train()
        epoch_loss = 0
        # Becomes true once any batch of this epoch contains a positive target pixel.
        saw_text_pixels = False

        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch_idx, (images, gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask) in enumerate(progress_bar):
            images = images.to(DEVICE)
            gt_prob, gt_prob_mask = gt_prob.to(DEVICE), gt_prob_mask.to(DEVICE)
            gt_thresh, gt_thresh_mask = gt_thresh.to(DEVICE), gt_thresh_mask.to(DEVICE)

            if not saw_text_pixels:
                saw_text_pixels = bool((gt_prob > 0).any())

            pred_prob, pred_thresh = detector(images)
            loss = db_loss_fn((pred_prob, pred_thresh), (gt_prob, gt_prob_mask, gt_thresh, gt_thresh_mask))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # learning rate is adjusted every step

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")

        avg_loss = epoch_loss / len(data_loader)
        print(f"Epoch {epoch+1} 완료, 평균 손실: {avg_loss:.4f}")

        if not saw_text_pixels:
            # With empty targets both Dice terms are constant, so the loss
            # flattens at 2.0 and nothing useful is learned.
            print("경고: 이번 에포크의 모든 배치에서 GT 확률 맵이 비어 있습니다. 어노테이션의 'polygons' 항목을 확인하세요.")

        # --- 5. Checkpointing ---
        if (epoch + 1) % 5 == 0:  # every fifth epoch
            torch.save(detector.state_dict(), f"dbnet_detector_epoch_{epoch+1}.pth")
            print(f"Epoch {epoch+1} 모델 저장 완료.")

    print("학습이 모두 완료되었습니다.")

if __name__ == '__main__':
    main()

사용 장치: cuda
사전 학습된 MobileNetV3 가중치를 사용합니다.
데이터 로더를 준비합니다...


  A.Resize(size, size, always_apply=True),


모델과 옵티마이저를 준비합니다...
학습을 시작합니다...


Epoch 1/20: 100%|█████████████████████████████████████████| 2801/2801 [12:03<00:00,  3.87it/s, loss=2.0050, lr=9.9e-05]


Epoch 1 완료, 평균 손실: 2.0361


Epoch 2/20: 100%|█████████████████████████████████████████| 2801/2801 [12:16<00:00,  3.81it/s, loss=2.0006, lr=9.8e-05]


Epoch 2 완료, 평균 손실: 2.0017


Epoch 3/20: 100%|█████████████████████████████████████████| 2801/2801 [12:38<00:00,  3.69it/s, loss=2.0001, lr=9.5e-05]


Epoch 3 완료, 평균 손실: 2.0003


Epoch 4/20: 100%|█████████████████████████████████████████| 2801/2801 [12:08<00:00,  3.85it/s, loss=2.0000, lr=9.0e-05]


Epoch 4 완료, 평균 손실: 2.0001


Epoch 5/20: 100%|█████████████████████████████████████████| 2801/2801 [12:30<00:00,  3.73it/s, loss=2.0000, lr=8.5e-05]


Epoch 5 완료, 평균 손실: 2.0000
Epoch 5 모델 저장 완료.


Epoch 6/20: 100%|█████████████████████████████████████████| 2801/2801 [12:36<00:00,  3.70it/s, loss=2.0000, lr=7.9e-05]


Epoch 6 완료, 평균 손실: 2.0000


Epoch 7/20:   2%|▊                                          | 55/2801 [00:14<11:59,  3.81it/s, loss=2.0000, lr=7.9e-05]