In [1]:
"""
Run a forward pass of the ViT model fine-tuned on CIFAR-10 (C10) and inspect its outputs.
"""

import os, sys, math
sys.path.append("../src")
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import DefaultDataCollator, ViTForImageClassification, TrainingArguments, Trainer
from utils.helper import get_device
from utils.vit_util import processor, transforms, compute_metrics

2024-04-18 22:41:25.038648: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-04-18 22:41:26.119763: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-04-18 22:41:26.119918: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
  met_acc = load_metric("accuracy")


In [2]:
# Resolve the compute device (cuda or cpu).
device = get_device()

# Load the dataset (only the very first load takes a little while) and attach
# the preprocessing so that it is applied on the fly whenever items are read.
cifar10 = load_dataset("cifar10")
cifar10_preprocessed = cifar10.with_transform(transforms)

# Collator used to assemble examples into batches.
data_collator = DefaultDataCollator()

# List of the strings naming each label.
labels = cifar10_preprocessed["train"].features["label"].names

# Load the fine-tuned model from its output directory, move it to the device
# and switch it to evaluation mode.
pretrained_dir = "/src/src/out_vit_c10"
model = ViTForImageClassification.from_pretrained(pretrained_dir)
model = model.to(device)
model.eval()

# Restore the settings that were saved at training time.
training_args_path = os.path.join(pretrained_dir, "training_args.bin")
training_args = torch.load(training_args_path)

# Build the Trainer object around the loaded model and the preprocessed splits.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cifar10_preprocessed["train"],
    eval_dataset=cifar10_preprocessed["test"],
    data_collator=data_collator,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

Device: cuda


Found cached dataset parquet (/root/.cache/huggingface/datasets/parquet/plain_text-d4c080360fb556b0/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Index of the encoder layer whose intermediate dense parameters are listed.
li = 11
key = f"vit.encoder.layer.{li}.intermediate.dense"

# Walk every named parameter and print only those under the selected prefix.
for name, param in model.named_parameters():
    if not name.startswith(key):
        continue
    print(name, param.shape)

vit.encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
vit.encoder.layer.11.intermediate.dense.bias torch.Size([3072])


In [14]:
# Slice the first five preprocessed training examples and display the shape of
# their pixel tensor once it has been moved to the target device.
sample_batch = cifar10_preprocessed["train"][:5]
sample_batch["pixel_values"].to(device).shape

torch.Size([5, 3, 224, 224])

In [13]:
# Run a single forward pass on the first five training images, requesting all
# optional outputs (hidden states, attention maps and intermediate states).
# NOTE(review): `output_intermediate_states` does not look like a stock
# Hugging Face ViT argument; presumably the installed model code was extended
# to support it. Confirm against the environment.
pixel_values = cifar10_preprocessed["train"][:5]["pixel_values"].to(device)

# Inference only: disable autograd so that no graph is retained for the many
# returned tensors, and call the module itself (not `.forward`) so that any
# registered hooks run as usual.
with torch.no_grad():
    mf = model(
        pixel_values,
        output_hidden_states=True,
        output_attentions=True,
        output_intermediate_states=True,
    )

print(mf.keys())
# Tuple lengths of the last three output keys.
print([len(mf[k]) for k in list(mf.keys())[-3:]])

odict_keys(['logits', 'hidden_states', 'attentions', 'intermediate_states'])
[13, 12, 12]


In [5]:
# Shape of the classification logits returned by the forward pass.
mf["logits"].shape

torch.Size([5, 10])

In [6]:
# Report how many hidden-state tensors were returned, then the shape of each
# one and of its slice at index 0 along the second axis.
# NOTE(review): index 0 is presumably the [CLS] token position; confirm.
hidden_states = mf.hidden_states
print(len(hidden_states))
for i, hs in enumerate(hidden_states):
    first_token = hs[:, 0, :]
    print(f"mf.hidden_states[{i}].shape = {hs.shape}")
    print(f"mf.hidden_states[{i}][:, 0, :].shape = {first_token.shape}")

13
mf.hidden_states[0].shape = torch.Size([5, 197, 768])
mf.hidden_states[0][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[1].shape = torch.Size([5, 197, 768])
mf.hidden_states[1][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[2].shape = torch.Size([5, 197, 768])
mf.hidden_states[2][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[3].shape = torch.Size([5, 197, 768])
mf.hidden_states[3][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[4].shape = torch.Size([5, 197, 768])
mf.hidden_states[4][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[5].shape = torch.Size([5, 197, 768])
mf.hidden_states[5][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[6].shape = torch.Size([5, 197, 768])
mf.hidden_states[6][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[7].shape = torch.Size([5, 197, 768])
mf.hidden_states[7][:, 0, :].shape = torch.Size([5, 768])
mf.hidden_states[8].shape = torch.Size([5, 197, 768])
mf.hidden_states[8][:, 0, :].shape = torch.Size

In [10]:
# Report how many attention tensors were returned and the shape of each one.
attentions = mf.attentions
print(len(attentions))
for i, attn in enumerate(attentions):
    attn_shape = attn.shape
    print(f"mf.attentions[{i}].shape = {attn_shape}")

12
mf.attentions[0].shape = torch.Size([5, 12, 197, 197])
mf.attentions[1].shape = torch.Size([5, 12, 197, 197])
mf.attentions[2].shape = torch.Size([5, 12, 197, 197])
mf.attentions[3].shape = torch.Size([5, 12, 197, 197])
mf.attentions[4].shape = torch.Size([5, 12, 197, 197])
mf.attentions[5].shape = torch.Size([5, 12, 197, 197])
mf.attentions[6].shape = torch.Size([5, 12, 197, 197])
mf.attentions[7].shape = torch.Size([5, 12, 197, 197])
mf.attentions[8].shape = torch.Size([5, 12, 197, 197])
mf.attentions[9].shape = torch.Size([5, 12, 197, 197])
mf.attentions[10].shape = torch.Size([5, 12, 197, 197])
mf.attentions[11].shape = torch.Size([5, 12, 197, 197])


In [11]:
# Report how many intermediate-state tensors were returned and the shape of
# each one.
intermediate_states = mf.intermediate_states
print(len(intermediate_states))
for i, med in enumerate(intermediate_states):
    med_shape = med.shape
    print(f"mf.intermediate_states[{i}].shape = {med_shape}")

12
mf.intermediate_states[0].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[1].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[2].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[3].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[4].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[5].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[6].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[7].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[8].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[9].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[10].shape = torch.Size([5, 197, 3072])
mf.intermediate_states[11].shape = torch.Size([5, 197, 3072])
