# インストール

In [None]:
%%capture
!pip install timm

In [None]:
%%capture
# for tpu
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl

# 結果準備

In [None]:
!wget https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv

In [None]:
import pandas as pd

# Accuracy results: the ImageNet results table downloaded in the previous cell.
results_imagenet = pd.read_csv('results-imagenet.csv')

# Inference-time results: reuse the accuracy table's columns and splice the
# two benchmark fields ('type' = device, 'time' = seconds per image) in
# directly after the first column.
columns = results_imagenet.columns.to_list()
columns[1:1] = ['type', 'time']

In [None]:
!ls -la ~/.cache

# 測定

In [None]:
import os
import shutil
import signal
import time
import psutil
from tqdm import tqdm
import numpy as np
from PIL import Image
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch
import torch_xla.core.xla_model as xm
import csv
from google.colab import runtime

import warnings
warnings.filterwarnings("ignore")

# Benchmark parameters (Colab form fields).
width = 1920#@param {type: "number"}
height = 1080#@param {type: "number"}
repeat = 5#@param {type: "number"}
output_path = "results.csv"#@param {type: "string"}
# NOTE(review): not referenced anywhere in this cell; kept in case another
# cell reads it.
terminate = False

# Devices to benchmark, in this order, for every model. TPU (which would need
# xm.xla_device()) is not part of the sweep.
device_types = ["cuda", "cpu"]

# Results from a previous run, if any, so that an interrupted sweep can be
# resumed; otherwise start from an empty table with the output schema.
if os.path.isfile(output_path):
  df = pd.read_csv(output_path)
else:
  df = pd.DataFrame(columns=columns)
# Rows measured during this run (one dict per model/device pair).
results = []

# Synthetic input images filled with random noise, one per timed inference.
# Image.fromarray reads the array as (rows, cols, channels), i.e.
# (height, width, 3), so the height must come first to get a
# width x height image.
images = [
  Image.fromarray((np.random.rand(height, width, 3)*255).astype('uint8')).convert('RGB')
  for _ in range(repeat)
]

# save
def save_results(df, results, output_path):
  """Write the previous results plus the newly measured rows to a CSV file.

  Args:
    df: DataFrame holding the rows that were already saved.
    results: list of row dicts measured in this run; they are laid out
      using the module-level ``columns`` list.
    output_path: destination CSV path (overwritten).

  The caller's ``df`` is not modified; the combined table only exists
  inside this function.
  """
  new_rows = pd.DataFrame(results, columns=columns)
  combined = pd.concat((df, new_rows), axis=0, ignore_index=True)
  combined.to_csv(output_path, index=False)

# model list
model_names = timm.list_pretrained()


def _benchmark(model_name, device_type):
  """Return the mean wall-clock seconds per image for one model on one device.

  Creates the pretrained model, moves it to ``device_type``, and runs every
  image in the module-level ``images`` list through it one at a time. The
  timed region includes the preprocessing transform and the host-to-device
  copy, as well as the forward pass.

  Args:
    model_name: timm model identifier passed to ``timm.create_model``.
    device_type: torch device string, e.g. "cuda" or "cpu".

  Returns:
    Elapsed seconds divided by the module-level ``repeat``.
  """
  with torch.no_grad():
    model = timm.create_model(model_name, pretrained=True)
    model.to(device_type)
    model.eval()
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    # CUDA kernels are launched asynchronously: wait for pending work before
    # starting the clock and for the forward passes before stopping it,
    # otherwise only the launch overhead is measured.
    if device_type == "cuda":
      torch.cuda.synchronize()
    start = time.perf_counter()
    for image in images:
      model(transform(image).unsqueeze(0).to(device_type))
    if device_type == "cuda":
      torch.cuda.synchronize()
    end = time.perf_counter()
  return (end - start)/repeat


# (model, device) pairs already present in the loaded results; these are
# skipped so that a rerun resumes where the previous one stopped.
done = set(zip(df['model'], df['type']))
hf_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")

try:
  for model_name in tqdm(model_names):
    # Only benchmark models that have an entry in the accuracy table.
    res_imagenet = results_imagenet[results_imagenet['model'] == model_name]
    if len(res_imagenet) == 0:
      continue

    for device_type in device_types:
      if (model_name, device_type) in done:
        continue

      # clear cache
      if device_type == "cuda":
        torch.cuda.empty_cache()
      if 80 < psutil.disk_usage('/').percent:
        # Downloaded weights fill the disk; drop them and checkpoint the
        # measurements. The directory may not exist, so ignore errors.
        shutil.rmtree(hf_cache, ignore_errors=True)
        save_results(df, results, output_path)

      # inference
      try:
        elapsed = _benchmark(model_name, device_type)
      except Exception as err:
        # One model failing (download error, out of memory, ...) must not
        # abort the whole sweep; report it and move on.
        print(f"skipped {model_name} on {device_type}: {err!r}")
        continue

      res_dict = res_imagenet.iloc[0].to_dict()
      res_dict['type'] = device_type
      res_dict['time'] = elapsed
      results.append(res_dict)
      print(f"{model_name},{elapsed},{device_type}")
except KeyboardInterrupt:
  # Manual interruption stops the sweep; measurements so far are still saved.
  pass
finally:
  # result save (runs on normal completion, interruption, or unexpected error)
  save_results(df, results, output_path)