In [None]:
import os
import sys
import numpy as np
import pandas as pd
from fastai.vision.all import *

sys.path.append('../input/pytorch-image-model-jan-2021/pytorch-image-models-master/')
import timm

# Global RNG seed; it is reused below by RandomSplitter so the train/valid split is stable.
seed = 21
# fastai's set_seed seeds the Python/NumPy/torch RNGs (per its docs).
# reproducible=True additionally asks for deterministic cuDNN, which is slower but repeatable.
set_seed(seed, reproducible=True)

In [None]:
# Competition data root; get_x below also reads this module-level `path`.
path = Path('../input/plant-pathology-2021-fgvc8')
train_csv = path / 'train.csv'
data = pd.read_csv(train_csv)
data.head()

In [None]:
def get_x(r):
    """Map one row of the training DataFrame to its image file path.

    Reads the ``'image'`` column of ``r`` and joins it under
    ``<path>/train_images``, where ``path`` is the module-level competition root.
    """
    image_name = r['image']
    return path / 'train_images' / image_name

def get_y(r):
    """Return the list of labels for one row of the training DataFrame.

    The ``'labels'`` column holds space-separated class names. ``str.split()``
    with no argument collapses runs of whitespace and ignores leading/trailing
    whitespace. Splitting on a single ``' '`` would instead produce ``''``
    entries, which would show up as a spurious category in the multi-label vocab.
    """
    return r['labels'].split()

# Input side length in pixels; it matches the 224 in the ViT model name used later.
img_size = 224

# Per-item (CPU): random crop of at least 35% of the image area, resized to img_size.
item_transforms = RandomResizedCrop(img_size, min_scale=0.35)

# Per-batch (GPU): fastai's standard augmentations at doubled magnitude with
# vertical flips allowed, then ImageNet mean/std normalisation.
batch_transforms = [
    *aug_transforms(mult=2.0, flip_vert=True, size=img_size),
    Normalize.from_stats(*imagenet_stats),
]

# Image input, multi-label target; the random split uses the global seed.
dblock = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock),
    splitter=RandomSplitter(seed=seed),
    get_x=get_x,
    get_y=get_y,
    item_tfms=item_transforms,
    batch_tfms=batch_transforms,
)

In [None]:
# Build train/valid DataLoaders from the DataFrame rows, then show up to 9 samples.
batch_size = 64
dls = dblock.dataloaders(data, bs=batch_size)
dls.show_batch(max_n=9)

In [None]:
# Multi-label F1, reported as the training metric.
f1score_multi = F1ScoreMulti()
# timm ViT-B/16 (224px input) with a head sized to the number of label classes.
# NOTE(review): pretrained=False is timm's default; it is spelled out here so it
# is visible that the ViT is trained from scratch. If ImageNet weights were
# intended, use pretrained=True (the weights must be available offline on Kaggle).
# Also check timm's normalisation stats for this model, which may differ from imagenet_stats.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=dls.c)
learn = Learner(dls, model, metrics=f1score_multi)

In [None]:
# Release unused cached CUDA memory held by PyTorch's allocator before the
# lr_find / training cells below (no-op if CUDA was never initialised).
torch.cuda.empty_cache()

In [None]:
# Learning-rate range test: plots loss against an increasing learning rate.
# NOTE(review): its suggestion is not fed into the next cell, which hard-codes 3e-3.
learn.lr_find()

In [None]:
# One-cycle schedule: 10 epochs with a peak learning rate of 3e-3.
learn.fit_one_cycle(n_epoch=10, lr_max=3e-3)

In [None]:
# Export the trained Learner for inference.
# The plain literal replaces an f-string that had no placeholders.
# NOTE(review): fastai resolves the filename relative to learn.path. load_learner
# will need get_x/get_y defined in the loading session, since they are pickled by reference.
learn.export('vitb16.pkl')