In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import ToTensor, Resize, Normalize, Compose
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torch.optim import AdamW

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# 데이터셋 클래스 정의
class KoreanTextDataset(Dataset):
    def __init__(self, image_dir, txt_dir, max_samples, transform=None, target_size=(384, 384)):
        self.image_dir = image_dir
        self.txt_dir = txt_dir
        self.transform = transform
        self.target_size = target_size  # 고정 크기 설정
        self.max_samples = max_samples
        self.samples, self.skipped_files = self._load_data()

    def _load_data(self):
        data = []
        skipped_files = []
        txt_files = []
        image_files = []

        # 하위 폴더 탐색
        for root, _, files in os.walk(self.txt_dir):
            txt_files.extend([os.path.join(root, f) for f in files if f.endswith(".txt")])
        for root, _, files in os.walk(self.image_dir):
            image_files.extend([os.path.join(root, f) for f in files if f.endswith(".png")])

        # 최대 샘플 제한
        image_files = sorted(image_files)[:self.max_samples]

        for image_path in image_files:
            txt_file = os.path.basename(image_path).replace(".png", ".txt")
            txt_path = next((f for f in txt_files if os.path.basename(f) == txt_file), None)

            if txt_path and os.path.exists(txt_path):
                with open(txt_path, "r", encoding="utf-8-sig") as f:
                    lines = f.readlines()

                # 텍스트 파일 파싱
                bboxes = []
                question_text = ""
                bbox_section = False
                question_section = False
                for line in lines:
                    line = line.strip()  # 양쪽 공백 제거
                    if not line:  # 빈 줄 건너뜀
                        continue
                    if "[bboxs]" in line:
                        bbox_section = True
                        question_section = False
                        continue
                    elif "[question_text]" in line:
                        bbox_section = False
                        question_section = True
                        continue

                    if bbox_section:
                        parts = line.split()
                        if len(parts) == 5 and parts[0] == '0':  # 올바른 데이터인지 확인
                            x_center, y_center, width, height = map(float, parts[1:])
                            x = x_center - (width / 2)
                            y = y_center - (height / 2)
                            w = width
                            h = height
                            bboxes.append((x, y, w, h))
                        else:
                            print(f"Invalid bbox line in {txt_path}: {line}")
                            continue

                    if question_section:
                        question_text += line.strip()

                data.append((image_path, bboxes, question_text))
            else:
                skipped_files.append(image_path)
        return data, skipped_files

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, bboxes, label = self.samples[idx]
        image = Image.open(image_path).convert("RGB")
        cropped_images = []

        for bbox in bboxes:
            x, y, w, h = bbox
            left = int(x * image.width)
            top = int(y * image.height)
            right = int((x + w) * image.width)
            bottom = int((y + h) * image.height)

            cropped_image = image.crop((left, top, right, bottom))
            cropped_image = cropped_image.resize(self.target_size)
            if self.transform:
                cropped_image = self.transform(cropped_image)
            cropped_images.append(cropped_image)

        image = image.resize(self.target_size)  # 전체 이미지를 384x384로 조정
        if self.transform:
            image = self.transform(image)

        return image, torch.stack(cropped_images), label

# 데이터셋 경로 설정
train_image_dir = "C:/Users/boar2/Desktop/final_project/github/Data/images/training/x"
train_txt_dir = "C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y"
val_image_dir = "C:/Users/boar2/Desktop/final_project/github/Data/images/validation/x"
val_txt_dir = "C:/Users/boar2/Desktop/final_project/github/Data/images/validation/formatted_y"

# 데이터셋 로드
train_dataset = KoreanTextDataset(
    image_dir=train_image_dir,
    txt_dir=train_txt_dir,
    max_samples=50,
    transform=Compose([
        ToTensor(),
        Normalize((0.5,), (0.5,))
    ])
)

val_dataset = KoreanTextDataset(
    image_dir=val_image_dir,
    txt_dir=val_txt_dir,
    max_samples=10,
    transform=Compose([
        ToTensor(),
        Normalize((0.5,), (0.5,))
    ])
)

# 건너뛴 파일 출력
if train_dataset.skipped_files:
    print("Skipped training files:", train_dataset.skipped_files)
if val_dataset.skipped_files:
    print("Skipped validation files:", val_dataset.skipped_files)

# 데이터 로더 생성
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# 모델 초기화
processor = TrOCRProcessor.from_pretrained("team-lucid/trocr-small-korean")
model = VisionEncoderDecoderModel.from_pretrained("team-lucid/trocr-small-korean")

# 모델 설정
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# 학습 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

# 학습 및 검증 루프
epochs = 10
for epoch in range(epochs):
    model.train()
    total_train_loss = 0

    for full_image, cropped_images, ground_truth_text in train_dataloader:
        full_image = full_image.to(device)
        cropped_images = cropped_images.to(device)
        labels = processor.tokenizer(ground_truth_text, padding=True, truncation=True, return_tensors="pt").input_ids.to(device)

        outputs = model(pixel_values=full_image, labels=labels)
        loss = outputs.loss
        total_train_loss += loss.item()

        # 모델이 예측한 텍스트 디코딩 및 출력
        predicted_ids = model.generate(pixel_values=full_image)
        predicted_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        print(f"Ground Truth: {ground_truth_text[0]}")
        print(f"Predicted Text: {predicted_text}")
        print(f"Loss: {loss.item():.4f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_train_loss = total_train_loss / len(train_dataloader)

    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for full_image, cropped_images, ground_truth_text in val_dataloader:
            full_image = full_image.to(device)
            cropped_images = cropped_images.to(device)
            labels = processor.tokenizer(ground_truth_text, padding=True, truncation=True, return_tensors="pt").input_ids.to(device)

            outputs = model(pixel_values=full_image, labels=labels)
            loss = outputs.loss
            total_val_loss += loss.item()

            # 모델이 예측한 텍스트 디코딩 및 출력
            predicted_ids = model.generate(pixel_values=full_image)
            predicted_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            print(f"Validation Ground Truth: {ground_truth_text[0]}")
            print(f"Validation Predicted Text: {predicted_text}")
            print(f"Validation Loss: {loss.item():.4f}")

    avg_val_loss = total_val_loss / len(val_dataloader)
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Validation Loss: {avg_val_loss:.4f}")

Invalid bbox line in C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y\elementary3\P3_1_01_21181_50079.txt: 1 0.487458 0.532213 0.640468 0.661064
Invalid bbox line in C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y\elementary3\P3_1_01_21181_50080.txt: 1 0.509197 0.549020 0.707358 0.627451
Invalid bbox line in C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y\elementary3\P3_1_01_21181_50081.txt: 1 0.505853 0.537356 0.627090 0.643678
Invalid bbox line in C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y\elementary3\P3_1_01_21181_50082.txt: 1 0.502508 0.545312 0.673913 0.603125
Invalid bbox line in C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y\elementary3\P3_1_01_21181_50083.txt: 1 0.505853 0.536337 0.720736 0.642442
Invalid bbox line in C:/Users/boar2/Desktop/final_project/github/Data/images/training/formatted_y\elementary3\P3_1_01_21181_66404.txt: 

Config of the encoder: <class 'transformers.models.deit.modeling_deit.DeiTModel'> is overwritten by shared encoder config: DeiTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 384,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_norm_eps": 1e-12,
  "model_type": "deit",
  "num_attention_heads": 6,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "torch_dtype": "float32",
  "transformers_version": "4.46.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 384,
  "d_model": 256,
  "decoder_attention_heads": 8,
  "decod

Ground Truth: $186+213$에서 $186$과 $213$을 각각 가장 가까운 몇백으로 어림하여 계산해 보세요.$\\$
Predicted Text: ['여기에서 213을 각각 가까운 몇백으로 어림히']
Loss: 9.7253
Ground Truth: $628+231$에서 $628$과 $231$을 각각 가장 가까운 몇백 몇십으로 어림하여 계산해 보세요.
Predicted Text: ['어릴적#ff,#ff8ffA28.']
Loss: 7.7642
Ground Truth: $523+316$에서 $523$과 $316$을 각각 가장 가까운 몇백 몇십으로 어림하여 계산해 보세요.
Predicted Text: [' 기억이 없다.]']
Loss: 7.7950
Ground Truth: 다음이 나타내는 수보다 $232$만큼 더 큰 수를 구해 보세요. $100$이 $6$ 개, $10$이 $6$ 개, $1$이 $8$ 개인 수
Predicted Text: ['장은 분명한 느낌의�이라는 이름으로 번역되는 것은 곧 글을']
Loss: 7.6800
Ground Truth: 다음 계산에서 ㉠이 실제로 나타내는 수를 구해 보세요. $\begin{array}{r}  \overset{\boxed{㉠}} 2 \overset{1}57 \\ +865 \\ \hline 1122 \end{array}$
Predicted Text: ['#000#$+X5의 효과를 받은 효과가 되기 때문에는']
Loss: 10.5559
Ground Truth: 성재는 집에서 출발하여 서점에 들렀다가 학교에 가려고 합니다. 성재가 걸어야 하는 거리는 몇 $ m$ 인지 구해 보세요.[figure_text]null
Predicted Text: ['+++++++++++++++++++']
Loss: 8.5304
Ground Truth: $598$과 $387$을 각각 가장 가까운 몇백으로 어림하여 $598-387$을 계산해 보세요.
Predicted Text: ['$$ $+$ $+10 $+5$$$+5$$$']
Loss: 5.6

: 

In [None]:
# 모델 저장
model.save_pretrained("./trocr-korean-finetuned")
processor.save_pretrained("./trocr-korean-finetuned")

# pth 파일로 저장
torch.save(model.state_dict(), "./trocr-korean-finetuned.pth")

# 모델 저장 설명
# save_pretrained: 트랜스포머 모델 및 프로세서를 재사용 가능하도록 디렉터리에 저장
# state_dict: PyTorch 텐서로 학습된 가중치만 저장