In [1]:
import warnings

from docutils.nodes import reference
from torch.utils.data import DataLoader

warnings.filterwarnings('ignore')

In [2]:
from datasets import load_dataset
train_ds, test_ds = load_dataset('cifar10',split=['train[:5000]', 'test[:2000]'])
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']
train_ds[0].keys()

dict_keys(['img', 'label'])

In [3]:
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
feature_extractor #ViTFeatureExtractor 설정확인용
list(feature_extractor.size.values())

[224, 224]

In [4]:
from torchvision import transforms

normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

transforms_for_train = transforms.Compose(
    [
        transforms.Resize(list(feature_extractor.size.values())),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

transforms_for_val = transforms.Compose(
    [
        transforms.Resize(list(feature_extractor.size.values())),
        transforms.CenterCrop(list(feature_extractor.size.values())),
        transforms.ToTensor(),
        normalize,
    ]
)

def train_transforms(imagedata):
    imagedata['pixel_values'] = [transforms_for_train(image.convert('RGB')) for image in imagedata['img']]
    return imagedata

def test_transforms(imagedata):
    imagedata['pixel_values'] = [transforms_for_val(image.convert('RGB')) for image in imagedata['img']]
    return imagedata

train_ds.set_transform(train_transforms)
val_ds.set_transform(test_transforms)
test_ds.set_transform(test_transforms)

print (train_ds[0].keys())
print (type(train_ds[0]['img']))
print (type(train_ds[0]['label']), train_ds[0]['label'])
print (type(train_ds[0]['pixel_values']), train_ds[0]['pixel_values'].shape)

dict_keys(['img', 'label', 'pixel_values'])
<class 'PIL.PngImagePlugin.PngImageFile'>
<class 'int'> 8
<class 'torch.Tensor'> torch.Size([3, 224, 224])


In [5]:
from torch.utils.data import DataLoader
import torch

def collate_fn(imagedata):
    pixel_values = torch.stack([example['pixel_values'] for example in imagedata])
    labels = torch.tensor([example['label'] for example in imagedata])
    return {'pixel_values': pixel_values, 'labels':labels}

In [6]:
id2label = {id:label for id, label in enumerate(train_ds.features['label'].names)}
label2id = {label:id for id, label in id2label.items()}
id2label

{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}

In [7]:
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                  num_labels=10,
                                                  id2label=id2label,
                                                  label2id=label2id) #출력부분 날리고 이걸로 교체한거임 즉 기존모델에서 수정한거

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"현재 사용 중인 디바이스: {device}")

# 모델을 GPU로 이동
model.to(device)

현재 사용 중인 디바이스: cuda


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [9]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    output_dir="test-cifar-10",
    save_strategy='epoch',
    eval_strategy = 'epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    logging_dir = 'logs',
    remove_unused_columns=False,
    optim='adamw_torch',
    lr_scheduler_type='constant',
    save_total_limit=10,
    fp16=True,  # Mixed Precision 학습 (NVIDIA GPU에서 속도 대폭 향상)
    dataloader_num_workers=0,

)

In [10]:
import evaluate
import numpy as np
metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [11]:

import torch
from transformers import Trainer, TrainingArguments
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

In [12]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.658613,0.948
2,0.933700,0.313869,0.956
3,0.933700,0.231727,0.95
4,0.146400,0.183186,0.956
5,0.146400,0.190162,0.96
6,0.052200,0.174274,0.958
7,0.052200,0.191822,0.956
8,0.027500,0.195622,0.954
9,0.016600,0.210815,0.954
10,0.016600,0.21683,0.954


TrainOutput(global_step=2820, training_loss=0.20989607530282745, metrics={'train_runtime': 314.7115, 'train_samples_per_second': 142.988, 'train_steps_per_second': 8.961, 'total_flos': 3.48738956568576e+18, 'train_loss': 0.20989607530282745, 'epoch': 10.0})