**※ GPU環境で利用してください**

In [None]:
!pip install timm

In [None]:
import argparse
import operator
import os
import time
from collections import OrderedDict

import timm
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.optim import create_optimizer
from timm.utils import AverageMeter, accuracy
from timm.utils.summary import update_summary
from torch.autograd import Variable
from IPython.display import display

In [None]:
# Training configuration. The values are parsed from an explicit argument list
# (see parse_args below) rather than from sys.argv, so the cell can run inside
# a notebook. The resulting namespace is handed to create_optimizer() and, via
# vars(args), to resolve_data_config().
parser = argparse.ArgumentParser(description="Training Config", add_help=False)

# Optimizer selection and hyper-parameters.
parser.add_argument(
    "--opt",
    default="sgd",
    type=str,
    metavar="OPTIMIZER",
    help='Optimizer (default: "sgd")',
)
parser.add_argument(
    "--weight-decay", type=float, default=0.0001, help="weight decay (default: 0.0001)"
)
parser.add_argument(
    "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)"
)
parser.add_argument(
    "--momentum",
    type=float,
    default=0.9,
    metavar="M",
    help="Optimizer momentum (default: 0.9)",
)

# Input tensor shape as three integers; stored as args.input_size.
parser.add_argument(
    "--input-size",
    default=None,
    nargs=3,
    type=int,
    metavar="N N N",
    help="Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty",
)

# Fix the input size to 3x224x224; every other option keeps its default.
args = parser.parse_args(["--input-size", "3", "224", "224"])

# Training-loop constants used by the dataset / loader / training cells.
EPOCHS = 30
BATCH_SIZE = 32
NUM_WORKERS = 2

In [None]:
# Set this to the dataset directory on Google Colab, i.e. the directory that
# contains the train, validation and test subdirectories.
dataset_path = '/content/drive/MyDrive/VisionTransformer/'

In [None]:
# Check which models are supported: list the timm model names that have
# pretrained weights. The bare expression on the last line displays the list
# as the cell output.
model_names = timm.list_models(pretrained=True)
model_names

In [None]:
# Create the pretrained 'vit_base_patch16_224' model with num_classes set to
# the number of fine-tuning classes, then move it to the GPU.
NUM_FINETUNE_CLASSES = 2 # two classes: {'clear': 0, 'cloudy': 1}
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
model.cuda()

In [None]:
# Resolve the data configuration from the parsed arguments and the model.
# Later cells read data_config['input_size'] from the result.
data_config = resolve_data_config(vars(args), model=model)

In [None]:
# Build one dataset per split. Each split lives in a subdirectory of
# dataset_path named after the split, and only the 'train' split is created
# with is_training=True.
dataset_train, dataset_eval, dataset_test = (
    create_dataset(
        split_name,
        root=os.path.join(dataset_path, split_name),
        is_training=(split_name == 'train'),
        batch_size=BATCH_SIZE,
    )
    for split_name in ('train', 'validation', 'test')
)

In [None]:
# Wrap each dataset in a timm loader. All three share the resolved input size,
# batch size and worker count; only the training loader is built with
# is_training=True.
# NOTE(review): only 'input_size' is forwarded from data_config, so
# create_loader's own defaults apply for every other setting (e.g.
# normalisation) - confirm this is intended.
loader_train, loader_eval, loader_test = (
    create_loader(
        split_dataset,
        input_size=data_config['input_size'],
        batch_size=BATCH_SIZE,
        is_training=training_flag,
        num_workers=NUM_WORKERS,
    )
    for split_dataset, training_flag in (
        (dataset_train, True),
        (dataset_eval, False),
        (dataset_test, False),
    )
)

In [None]:
# Cross-entropy loss instances for training and validation, moved to the GPU.
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()

In [None]:
# Build the optimizer for the model from the parsed arguments
# (--opt, --lr, --weight-decay and --momentum are defined on the parser).
optimizer = create_optimizer(args, model)

In [None]:
def train_one_epoch(epoch, model, loader, optimizer, loss_fn, args, output_dir=None):
    """Run one training pass over ``loader`` and return the average loss.

    Args:
        epoch: Index of the current epoch. Accepted for signature
            compatibility; not used by the computation.
        model: Network to train; it is switched to train mode here.
        loader: Iterable yielding ``(input, target)`` batches.
        optimizer: Optimizer stepped once per batch. If it exposes
            ``is_second_order`` the backward pass keeps the graph, and if it
            exposes ``sync_lookahead`` that hook is called after the epoch.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        args: Parsed configuration. Accepted for signature compatibility; unused.
        output_dir: Accepted for signature compatibility; unused.

    Returns:
        OrderedDict with the single key ``"loss"`` holding the mean training
        loss of the epoch, weighted by batch size.
    """
    second_order = getattr(optimizer, "is_second_order", False)
    losses_m = AverageMeter()

    model.train()

    for inputs, target in loader:
        # Same device handling as validate()/test(): make sure the batch is on
        # the GPU. This is a no-op when the loader already yields CUDA tensors.
        inputs = inputs.cuda()
        target = target.cuda()

        output = model(inputs)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward(create_graph=second_order)
        optimizer.step()

        # Record the batch loss. Without this update the meter stays at its
        # initial value and the reported training loss is always 0.
        losses_m.update(loss.item(), inputs.size(0))

    if hasattr(optimizer, "sync_lookahead"):
        optimizer.sync_lookahead()

    return OrderedDict([("loss", losses_m.avg)])

In [None]:
def validate(model, loader, loss_fn, args):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable yielding ``(input, target)`` batches.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss.
        args: Parsed configuration. Accepted for signature compatibility; unused.

    Returns:
        OrderedDict with ``"loss"`` (mean loss) and ``"accuracy"`` (mean top-1
        accuracy as returned by ``accuracy``), both weighted by batch size.
    """
    step_timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    model.eval()

    tick = time.time()
    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            logits = model(images)
            # Some models return several outputs; the first one is used.
            if isinstance(logits, (tuple, list)):
                logits = logits[0]

            batch_loss = loss_fn(logits, labels)
            # Top-1 and top-2 are both requested; only top-1 is kept.
            top1 = accuracy(logits, labels, topk=(1, 2))[0]

            torch.cuda.synchronize()

            loss_meter.update(batch_loss.data.item(), images.size(0))
            top1_meter.update(top1.item(), logits.size(0))

            step_timer.update(time.time() - tick)
            tick = time.time()

    metrics = OrderedDict()
    metrics["loss"] = loss_meter.avg
    metrics["accuracy"] = top1_meter.avg
    return metrics

In [None]:
# State and settings for the training loop.
num_epochs = EPOCHS
eval_metric = "accuracy"  # key of the validation metrics used to pick the best epoch
best_metric = None  # best value of eval_metric seen so far (None until the first epoch)
best_epoch = None  # epoch index at which best_metric was reached
compare = operator.gt  # a larger metric value counts as an improvement

# Output location for the training-summary CSV file and the fine-tuned model weights
output_dir = "/content/drive/MyDrive/VisionTransformer/output"

In [None]:
# Train for num_epochs epochs. After every epoch the model is validated, the
# metrics are appended to summary.csv, and the weights are saved whenever the
# validation metric improves.

# Create the output directory up front so that neither the CSV summary nor the
# checkpoint write fails because the directory is missing.
if output_dir is not None:
    os.makedirs(output_dir, exist_ok=True)

for epoch in range(num_epochs):
    train_metrics = train_one_epoch(
        epoch, model, loader_train, optimizer, train_loss_fn, args, output_dir=output_dir
    )

    eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

    if output_dir is not None:
        update_summary(
            epoch,
            train_metrics,
            eval_metrics,
            os.path.join(output_dir, "summary.csv"),
            # best_metric is still None only before the first epoch finishes,
            # so the header is written exactly once.
            write_header=best_metric is None,
        )

    metric = eval_metrics[eval_metric]
    if best_metric is None or compare(metric, best_metric):
        best_metric = metric
        best_epoch = epoch
        # Guarded like the summary write above: with output_dir=None there is
        # nowhere to save to, and os.path.join(None, ...) would raise.
        if output_dir is not None:
            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))

    print(epoch)
    print(eval_metrics)
    print("Best metric: {0} (epoch {1})".format(best_metric, best_epoch))

In [None]:
# Reload the weights stored in best_model.pth (written by the training loop
# whenever the validation metric improved), mapping the tensors onto the GPU.
model.load_state_dict(
    torch.load(
        os.path.join(output_dir, "best_model.pth"), map_location=torch.device("cuda")
    )
)

In [None]:
# Single-image inference helpers.
model.eval()

image_size = data_config["input_size"][-1]
loader = transforms.Compose(
    [
        # Resize to an explicit (height, width). An int argument would only fix
        # the shorter side, so a non-square image would not match the model's
        # fixed input size.
        transforms.Resize(tuple(data_config["input_size"][-2:])),
        transforms.ToTensor(),
        # ImageNet mean/std: the normalisation timm's create_loader applies by
        # default, which is what the train/validation/test loaders in this
        # file use. Without it the model sees differently scaled inputs here
        # than during fine-tuning.
        # NOTE(review): the eval loaders may also resize/crop differently
        # (crop_pct) - confirm if exact parity with loader_test is required.
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)


def image_loader(image_name):
    """Load an image file and return it as a 1-image batch on the GPU.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        CUDA float tensor with a leading batch dimension of size 1, produced by
        the module-level ``loader`` transform (resize, to-tensor, normalise).
    """
    # The context manager closes the underlying file once the pixel data has
    # been converted to an independent RGB image.
    with Image.open(image_name) as img:
        rgb_image = img.convert("RGB")
    # A plain tensor is sufficient for inference; the deprecated Variable
    # wrapper with requires_grad=True is not needed.
    batch = loader(rgb_image).unsqueeze(0)
    return batch.cuda()


# Converts the model's logits into class probabilities along the class axis.
m = nn.Softmax(dim=1)

In [None]:
# Run the model on one sample image from test/clear: show the image, then
# display the softmax output of the model as the cell result.
clear_image_path = os.path.join(dataset_path, 'test/clear/12_3542_1635.png')
predicted_clear_image = image_loader(clear_image_path)
display(Image.open(clear_image_path))
m(model(predicted_clear_image))

In [None]:
# Run the model on one sample image from test/cloudy: show the image, then
# display the softmax output of the model as the cell result.
cloudy_image_path = os.path.join(dataset_path, 'test/cloudy/12_3503_1735.png')
predicted_cloudy_image = image_loader(cloudy_image_path)
display(Image.open(cloudy_image_path))
m(model(predicted_cloudy_image))

In [None]:
def test(model, loader, args):
    """Measure the top-1 accuracy of ``model`` on ``loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable yielding ``(input, target)`` batches.
        args: Parsed configuration. Accepted for signature compatibility; unused.

    Returns:
        Dict with the single key ``'accuracy'``: the mean top-1 accuracy as
        returned by ``accuracy``, weighted by batch size.
    """
    step_timer = AverageMeter()
    top1_meter = AverageMeter()

    model.eval()

    tick = time.time()
    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            logits = model(images)
            # Some models return several outputs; the first one is used.
            if isinstance(logits, (tuple, list)):
                logits = logits[0]

            # Top-1 and top-2 are both requested; only top-1 is kept.
            top1 = accuracy(logits, labels, topk=(1, 2))[0]

            torch.cuda.synchronize()

            top1_meter.update(top1.item(), logits.size(0))

            step_timer.update(time.time() - tick)
            tick = time.time()

    result = {}
    result['accuracy'] = top1_meter.avg
    return result

In [None]:
# Evaluate the model on the test loader; the returned accuracy dict is
# displayed as the cell output.
test(model, loader_test, args)