https://github.com/facebookresearch/deit/blob/main/README_deit.md

#bash

pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install torchtext==0.9.1
pip install timm==0.3.2

- ImportError: cannot import name 'container_abcs' from 'torch._six' 오류 해결 위해 다운그레이드 필요!

### import

In [1]:
import torch
import timm

print(torch.__version__)  # PyTorch 버전 확인
print(timm.__version__)   # timm 버전 확인


  from .autonotebook import tqdm as notebook_tqdm


1.8.1+cu111
0.3.2


In [2]:
import torch
import timm
assert timm.__version__ == "0.3.2"

### Data

In [12]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

# 데이터 전처리
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # DeiT 모델 입력 크기에 맞춤
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = ImageFolder(root="C:/Users/seonahryu/Desktop/brp1/ILSVRC2012_img_val", transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

### Evaluation func

In [18]:
import time
import torch.nn as nn
from thop import profile # FLOPs

# 모델 FLOPs 계산 함수
def compute_flops(model, input_size=(1, 3, 224, 224)):
    flops, params = profile(model, inputs=(torch.randn(input_size),))
    return flops

# 모델 평가 함수 정의
def evaluate_model(model, dataloader):
    model.eval()
    correct_top1 = 0
    correct_top5 = 0
    total = 0
    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0.0

    start_time = time.time()  # 평가 시작 시간 기록

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to('cuda' if torch.cuda.is_available() else 'cpu'), labels.to('cuda' if torch.cuda.is_available() else 'cpu')
            outputs = model(images)
            _, predicted_top1 = torch.max(outputs.data, 1)
            predicted_top5 = torch.topk(outputs.data, 5)[1]  # Top-5 예측
            
            total += labels.size(0)
            correct_top1 += (predicted_top1 == labels).sum().item()
            
            # Top-5 비교
            correct_top5 += (predicted_top5 == labels.view(-1, 1)).sum().item()
            
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

    end_time = time.time()  # 평가 종료 시간 기록
    elapsed_time = end_time - start_time  # 총 소요 시간 계산
    
    accuracy_top1 = 100 * correct_top1 / total
    accuracy_top5 = 100 * correct_top5 / total
    avg_loss = total_loss / len(dataloader)  # 평균 손실 계산
    
    # 결과 출력
    print(f'Acc@1: {accuracy_top1:.3f} Acc@5: {accuracy_top5:.3f} loss: {avg_loss:.3f}')
    
    # 이미지 처리 속도 계산
    im_per_sec = total / elapsed_time
    print(f'Processing speed: {im_per_sec:.2f} images/sec')
    
    # FLOPs 계산
    flops = compute_flops(model)
    print(f'FLOPs: {flops:.2f}')

## DeiT-tiny

In [14]:
model_tiny = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)

Using cache found in C:\Users\seonahryu/.cache\torch\hub\facebookresearch_deit_main


In [15]:
model_tiny.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv

In [19]:
evaluate_model(model_tiny, test_loader) #34min

Acc@1: 69.952 Acc@5: 90.014 loss: 1.310
Processing speed: 24.46 images/sec
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 1078633728.00


which should give

* Acc@1 72.202 Acc@5 91.124 loss 1.219

제발...

## DeiT-small

In [22]:
model_small = torch.hub.load('facebookresearch/deit:main', 'deit_small_patch16_224', pretrained=True)

Using cache found in C:\Users\seonahryu/.cache\torch\hub\facebookresearch_deit_main


In [None]:
model_small.eval()
evaluate_model(model_small, test_loader) #69min

Acc@1: 78.846 Acc@5: 94.488 loss: 0.928
Processing speed: 11.94 images/sec
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 4248783360.00


giving

* Acc@1 79.854 Acc@5 94.968 loss 0.881

## DeiT-base

In [24]:
model_base = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)

Using cache found in C:\Users\seonahryu/.cache\torch\hub\facebookresearch_deit_main
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to C:\Users\seonahryu/.cache\torch\hub\checkpoints\deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:09<00:00, 36.1MB/s] 


In [25]:
model_base.eval()
evaluate_model(model_base, test_loader) #263min

Acc@1: 81.252 Acc@5: 95.322 loss: 0.854
Processing speed: 3.16 images/sec
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPs: 16863630336.00


This should give

* Acc@1 81.846 Acc@5 95.594 loss 0.820