In [2]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from transformers import DeiTForImageClassification, DeiTFeatureExtractor
import numpy as np
import os

In [7]:
base_dir = 'img_preprocessed/'  
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')

In [4]:
#feature_extractor 설정

feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [5]:
#데이터 변환 
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

In [8]:
#데이터 로드
train_dataset = ImageFolder(root=train_dir, transform=transform)
val_dataset = ImageFolder(root=val_dir, transform=transform)

In [9]:
# 데이터 로더 설정
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [10]:
# 모델 설정
model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224', num_labels=4)

# Optimizer 설정
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# 학습 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DeiTForImageClassification(
  (deit): DeiTModel(
    (embeddings): DeiTEmbeddings(
      (patch_embeddings): DeiTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DeiTEncoder(
      (layer): ModuleList(
        (0-11): 12 x DeiTLayer(
          (attention): DeiTAttention(
            (attention): DeiTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): DeiTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): DeiTIntermediate(
            (dense): Linear(in

In [11]:
# 학습 함수
def train_epoch(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch in data_loader:
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = torch.nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        
    accuracy = correct / total
    average_loss = total_loss / len(data_loader)
    return average_loss, accuracy

In [12]:
def evaluate(model, data_loader, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs).logits
            loss = torch.nn.CrossEntropyLoss()(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            
    accuracy = correct / total
    average_loss = total_loss / len(data_loader)
    return average_loss, accuracy

In [13]:
# 학습 루프
num_epochs = 50
best_val_loss = np.inf
patience = 5
early_stopping_counter = 0

In [14]:
for epoch in range(num_epochs):
    train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, device)
    val_loss, val_accuracy = evaluate(model, val_loader, device)
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}')
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
    
    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        # 모델 저장
        model.save_pretrained('model/')
    else:
        early_stopping_counter += 1
    
    if early_stopping_counter >= patience:
        break

Epoch 1/50
Train Loss: 0.5471, Train Accuracy: 0.7771
Validation Loss: 0.3045, Validation Accuracy: 0.8876
Epoch 2/50
Train Loss: 0.1530, Train Accuracy: 0.9480
Validation Loss: 0.3815, Validation Accuracy: 0.8601
Epoch 3/50
Train Loss: 0.0597, Train Accuracy: 0.9796
Validation Loss: 0.2923, Validation Accuracy: 0.9098
Epoch 4/50
Train Loss: 0.0276, Train Accuracy: 0.9914
Validation Loss: 0.3661, Validation Accuracy: 0.8954
Epoch 5/50
Train Loss: 0.0040, Train Accuracy: 0.9988
Validation Loss: 0.3834, Validation Accuracy: 0.9046
Epoch 6/50
Train Loss: 0.0232, Train Accuracy: 0.9924
Validation Loss: 0.4305, Validation Accuracy: 0.8928
Epoch 7/50
Train Loss: 0.0322, Train Accuracy: 0.9901
Validation Loss: 1.1725, Validation Accuracy: 0.7529
Epoch 8/50
Train Loss: 0.0516, Train Accuracy: 0.9836
Validation Loss: 0.3529, Validation Accuracy: 0.8993
