In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision import transforms
from PIL import Image
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------
# 1️⃣ Transform 정의
# ---------------------------------------------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# ---------------------------------------------------
# 2️⃣ CSV 기반 커스텀 Dataset 정의
# ---------------------------------------------------
class CustomImageDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None, label_encoder=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.label_encoder = label_encoder

        # 레이블이 존재하면 숫자로 변환
        if 'target' in self.data.columns and self.label_encoder is not None:
            self.data['target'] = self.label_encoder.transform(self.data['target'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = f"{self.img_dir}/{self.data.iloc[idx, 0]}"
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # 레이블이 없으면 -1 반환
        if 'target' in self.data.columns:
            label = torch.tensor(self.data.iloc[idx, 1], dtype=torch.long)
            return image, label
        else:
            return image, -1

# ---------------------------------------------------
# 3️⃣ LabelEncoder 준비 (문자열 레이블 → 숫자)
# ---------------------------------------------------
train_df = pd.read_csv("/root/CV_/datasets/data/train.csv")
le = LabelEncoder()
train_df['target'] = le.fit_transform(train_df['target'])

# Dataset 생성
train_dataset = CustomImageDataset(
    csv_file="/root/CV_/datasets/data/train.csv",
    img_dir="/root/CV_/datasets/data/train",
    transform=transform,
    label_encoder=le
)

test_dataset = CustomImageDataset(
    csv_file="/root/CV_/datasets/data/sample_submission.csv",
    img_dir="/root/CV_/datasets/data/test",
    transform=transform,
    label_encoder=le  # test에는 레이블 없어도 무방
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)

In [3]:
# ---------------------------------------------------
# 4️⃣ 모델 준비
# ---------------------------------------------------
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [4]:
# 출력 클래스 수를 train dataset에 맞춤
num_classes = len(le.classes_)
if model.config.num_labels != num_classes:
    # 마지막 classifier 레이어 이름 확인
    if hasattr(model, 'classifier'):
        in_features = model.classifier.in_features
        model.classifier = nn.Linear(in_features, num_classes)
    elif hasattr(model, 'score'):
        in_features = model.score.in_features
        model.score = nn.Linear(in_features, num_classes)
    else:
        print("모델 구조 확인 필요 - 마지막 레이어 이름 다를 수 있음")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

In [5]:
# ---------------------------------------------------
# 5️⃣ 학습 Loop
# ---------------------------------------------------
num_epochs = 30             

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

Epoch 1/30 - Loss: 0.8920
Epoch 2/30 - Loss: 0.2196
Epoch 3/30 - Loss: 0.1274
Epoch 4/30 - Loss: 0.0753
Epoch 5/30 - Loss: 0.0602
Epoch 6/30 - Loss: 0.0550
Epoch 7/30 - Loss: 0.0153
Epoch 8/30 - Loss: 0.0357
Epoch 9/30 - Loss: 0.0074
Epoch 10/30 - Loss: 0.0029
Epoch 11/30 - Loss: 0.0014
Epoch 12/30 - Loss: 0.0012
Epoch 13/30 - Loss: 0.0010
Epoch 14/30 - Loss: 0.0009
Epoch 15/30 - Loss: 0.0008
Epoch 16/30 - Loss: 0.0007
Epoch 17/30 - Loss: 0.0006
Epoch 18/30 - Loss: 0.0005
Epoch 19/30 - Loss: 0.0005
Epoch 20/30 - Loss: 0.0004
Epoch 21/30 - Loss: 0.0004
Epoch 22/30 - Loss: 0.0004
Epoch 23/30 - Loss: 0.0003
Epoch 24/30 - Loss: 0.0003
Epoch 25/30 - Loss: 0.0003
Epoch 26/30 - Loss: 0.0002
Epoch 27/30 - Loss: 0.0002
Epoch 28/30 - Loss: 0.0002
Epoch 29/30 - Loss: 0.0002
Epoch 30/30 - Loss: 0.0002


In [6]:
# ---------------------------------------------------
# 6️⃣ 테스트 예측
# ---------------------------------------------------
model.eval()
all_preds = []

with torch.no_grad():
    for images, _ in test_loader:
        images = images.to(device)
        outputs = model(images).logits
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())

# 숫자를 클래스 이름으로 변환
pred_labels = le.inverse_transform(all_preds)

# 결과 확인
for i, label in enumerate(pred_labels[:10]):
    print(f"Image {i}: Predicted class = {label}")

Image 0: Predicted class = 2
Image 1: Predicted class = 6
Image 2: Predicted class = 5
Image 3: Predicted class = 13
Image 4: Predicted class = 2
Image 5: Predicted class = 15
Image 6: Predicted class = 0
Image 7: Predicted class = 8
Image 8: Predicted class = 15
Image 9: Predicted class = 11


In [7]:
result = pd.read_csv('/root/CV_/datasets/data/sample_submission.csv')

In [8]:
result['target'] = pred_labels

In [9]:
result.to_csv('vit_output_2.csv', index=False)