In [1]:
# 텐서보드 테스트

from torch.utils.tensorboard import SummaryWriter
import numpy as np

writer = SummaryWriter()

for n_iter in range(100):
    writer.add_scalar('Loss/train', scalar_value =np.random.random(), global_step=n_iter)
    writer.add_scalar('Loss/test',  scalar_value =np.random.random(), global_step=n_iter)
    writer.add_scalar('Accuracy/train',  scalar_value =np.random.random(), global_step=n_iter)
    writer.add_scalar('Accuracy/test',  scalar_value =np.random.random(), global_step=n_iter)

In [9]:
# coding=utf-8
from __future__ import absolute_import, division, print_function
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, classification_report, precision_score, recall_score

import logging
import argparse
import os
import cv2
import random
import numpy as np
import pandas as pd
import time
from PIL import Image
from datetime import timedelta
from models.modeling import VisionTransformer, CONFIGS
import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

from utils.dataset import custom

In [7]:
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                            transforms.CenterCrop((448, 448)),
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

testset = custom(root='/data/TransFG_experiment/datasets/custom', dtype=2, transform = test_transform) 

In [8]:
test_sampler = SequentialSampler(testset)#if args.local_rank == -1 else DistributedSampler(testset) #SequentialSampler : 항상 같은 순서
test_loader = DataLoader(testset,
                             sampler=test_sampler,
                             batch_size=40,
                             num_workers=4,
                             pin_memory=True) if testset is not None else None

In [10]:
img_size = 448
smoothing_value = 0.0
pretrained_model = "output/sample_run_checkpoint.bin"
config = CONFIGS["ViT-B_16"] 
#config = CONFIGS["ViT-B_32"]
config.split = 'overlap'
config.slide_step = 12
num_classes = pd.read_csv('train_x.csv')['label'].nunique()
model = VisionTransformer(config, img_size, num_classes, smoothing_value, zero_head=True)

In [11]:
if pretrained_model is not None:
    pretrained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))['model']
    model.load_state_dict(pretrained_model)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(12, 12))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0): Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (out): Linear(in_features=768, out_features=768, bias=True)
     

In [13]:
epoch_iterator = tqdm(test_loader)

  0%|          | 0/2121 [00:00<?, ?it/s]

In [14]:
all_preds, all_label = [], []

In [15]:
start = time.time()
with torch.no_grad():
    for step, batch in enumerate(epoch_iterator): 
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        if len(y)==1:
            print(y)
            break
        loss, logits = model(x, y)
        #loss = loss.mean()
        preds = torch.argmax(logits, dim=-1) 
        all_label.append(list(y.cpu().numpy()))
        all_preds.append(list(preds.cpu().numpy())) 
end = time.time() 
print("Time elapsed: ", timedelta(seconds=end-start))        

100%|██████████| 2121/2121 [55:10<00:00,  1.56s/it]

Time elapsed:  0:55:09.464454





In [16]:
all_preds = np.array(sum(all_preds,[]))
all_label = np.array(sum(all_label,[]))

In [17]:
all_preds.shape

(84808,)

In [18]:
all_preds[:10]#.reshape(-1)

array([389,  35,  57, 379,  98, 386,   5, 418,  86, 228])

In [19]:
all_label[:10]

array([389,  35,  57, 379,  98, 386,   7, 418,  86, 228])

In [20]:
print('Accuracy :',accuracy_score(all_label, all_preds)) 
print('Precision :',precision_score(all_label, all_preds, average='weighted')) 
print('Recall :',recall_score(all_label, all_preds, average='weighted')) 
print('F1 score :',f1_score(all_label, all_preds, average='weighted'))

Accuracy : 0.8821337609659466
Precision : 0.8765267034400398
Recall : 0.8821337609659466
F1 score : 0.8744986230971883


  _warn_prf(average, modifier, msg_start, len(result))


In [21]:
print(classification_report(all_label, all_preds))

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.72      0.79      0.75       103
           1       0.75      0.42      0.54        36
           2       0.00      0.00      0.00        12
           3       0.67      0.39      0.49        36
           4       0.71      0.48      0.57        46
           5       0.55      0.66      0.60       253
           6       0.53      0.34      0.42       102
           7       0.46      0.28      0.35       163
           8       1.00      0.50      0.67        14
           9       0.81      0.62      0.70        42
          10       0.66      0.78      0.71       329
          11       0.65      0.83      0.73       521
          12       0.40      0.05      0.09        39
          13       0.00      0.00      0.00        10
          14       0.66      0.35      0.46        82
          15       0.62      0.44      0.51       157
          16       0.50      0.40      0.44        25
          17       0.50    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
