In [None]:
from fastai.vision.all import *

In [None]:
# Fix the random seeds up front so runs are repeatable.
SEED = 999
set_seed(SEED)

To use `timm`, we import it straight from the attached source folder, temporarily changing into that folder so Python can find the package:

In [None]:
# Change into the attached timm source tree so `import timm` resolves to the package
# found in the current directory (assumes the cwd is on sys.path, as in IPython — TODO confirm
# for other kernels), import what we need, then move back up.
%cd ../input/timm030/pytorch-image-models-master/pytorch-image-models-master/
from timm import create_model
# Four levels up lands in the parent of the starting directory (the one holding `input/`);
# later relative paths such as Path("input") and `%cd working` depend on this.
%cd ../../../../

In [None]:
# Sanity check: the current directory should contain `input/` (read via Path("input") below)
# and `working/` (where the submission is written at the end).
%ls

We want to be able to recreate our model exactly, so we bring in the `timm` helper functions from the `wwf` library:

In [None]:
from fastai.vision.learner import _update_first_layer

# Cell
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    """Create a body (feature extractor) from any model in the `timm` library.

    - `arch`: `timm` model name passed to `create_model`.
    - `pretrained`: whether `timm` loads its pretrained weights.
    - `cut`: where the body ends.
        * `None`: cut just before the last top-level child that `has_pool_type` flags.
        * `int`: keep `list(model.children())[:cut]`.
        * callable: called with the model; its return value is the body.
    - `n_in`: number of input channels; the first layer is adapted by `_update_first_layer`.

    Raises `ValueError` when `cut` is None and no child contains a pooling layer,
    and `TypeError` when `cut` is neither None, an int, nor a callable.
    """
    # No classifier (num_classes=0) and no global pooling (global_pool='') so only features remain.
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    children = list(model.children())
    if cut is None:
        # Index of the last top-level child containing a pooling layer.
        cut = next((i for i, o in reversed(list(enumerate(children))) if has_pool_type(o)), None)
        if cut is None:
            # Previously a bare StopIteration escaped here; give an actionable message instead.
            raise ValueError(f"could not find a pooling layer to cut `{arch}` at; pass `cut` explicitly")
    if isinstance(cut, int): return nn.Sequential(*children[:cut])
    if callable(cut): return cut(model)
    # `NamedError` is not a defined exception; raise a real one for a wrongly-typed `cut`.
    raise TypeError("cut must be either integer or function")

# Cell
def create_timm_model(arch:str, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,
                     concat_pool=True, **kwargs):
    """Create custom architecture using `arch`, `n_in` and `n_out` from the `timm` library.

    - `arch`: `timm` model name, passed to `create_timm_body`.
    - `n_out`: number of outputs of the head.
    - `cut`: where to cut the `timm` model (None, int or callable); see `create_timm_body`.
    - `pretrained`, `n_in`: forwarded to `create_timm_body`.
    - `init`: initializer applied to the head only; skipped when None.
    - `custom_head`: used as the head instead of `create_head(...)` when given.
    - `concat_pool`, `**kwargs`: forwarded to `create_head`; the head's input size is
      doubled when `concat_pool` is True.
    Returns `nn.Sequential(body, head)`.
    """
    # Honour the caller's `cut` (it used to be replaced by a hard-coded None).
    body = create_timm_body(arch, pretrained, cut, n_in)
    if custom_head is None:
        nf = num_features_model(nn.Sequential(*body.children())) * (2 if concat_pool else 1)
        head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)
    else: head = custom_head
    model = nn.Sequential(body, head)
    # Only the head (model[1]) is initialized with `init`.
    if init is not None: apply_init(model[1], init)
    return model

# Cell
from fastai.vision.learner import _add_norm

# Cell
def timm_learner(dls, arch:str, loss_func=None, pretrained=True, cut=None, splitter=None,
                y_range=None, config=None, n_out=None, normalize=True, **kwargs):
    """Build a convnet style learner from `dls` and `arch` using the `timm` library.

    - `dls`: the `DataLoaders`; `n_out` defaults to `get_c(dls)`.
    - `arch`: `timm` model name.
    - `loss_func`, `**kwargs`: forwarded to `Learner`.
    - `pretrained`: use `timm` pretrained weights and freeze the returned learner.
    - `cut`: forwarded to `create_timm_model` (see `create_timm_body`).
    - `splitter`: parameter-group splitter for `Learner`; defaults to `default_split`.
    - `y_range`: forwarded to the head; may also be supplied as `config['y_range']`,
      in which case the explicit argument takes precedence.
    - `config`: extra keyword arguments for `create_timm_model`; the caller's dict is not modified.
    - `normalize`: NOTE(review): currently unused — no normalization is added here
      (`_add_norm` is imported but never called), so `dls` must already normalize its inputs.
    Returns the `Learner`.
    """
    # Work on a copy so popping `y_range` never mutates the caller's dict.
    config = dict(config) if config is not None else {}
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # Always remove `y_range` from `config` so it is never passed twice to `create_timm_model`.
    config_y_range = config.pop('y_range', None)
    if y_range is None: y_range = config_y_range
    # Pass the real `cut` (previously `default_split` was passed in its place).
    model = create_timm_model(arch, n_out, cut=cut, pretrained=pretrained, y_range=y_range, **config)
    learn = Learner(dls, model, loss_func=loss_func,
                    splitter=splitter if splitter is not None else default_split, **kwargs)
    if pretrained: learn.freeze()
    return learn

Recreate our `DataLoaders` so that we can build a `test_dl` from them:

In [None]:
# Inputs live under input/, relative to the directory we returned to after importing timm.
path = Path("input")
data_path = path.joinpath('cassava-leaf-disease-classification')

In [None]:
# List the competition folder to check that the files used below are present
# (train.csv, sample_submission.csv and the train/test image folders).
data_path.ls()

In [None]:
# Training labels: one row per training image (image_id, label).
df = pd.read_csv(data_path.joinpath('train.csv'))

In [None]:
# Prefix each image_id with its sub-folder so get_x can join it directly onto data_path.
# Ids that already carry the prefix are left alone, so re-running this cell is harmless
# (the original prefixed again on every run).
df['image_id'] = df['image_id'].astype(str).map(
    lambda x: x if x.startswith('train_images/') else f'train_images/{x}')

In [None]:
def get_x(row):
    "Path to the image for one row of `df`; `image_id` already includes its sub-folder."
    return data_path / row['image_id']

def get_y(row):
    "Class label for one row of `df`."
    return row['label']

# Rebuild the DataBlock (presumably identical to the one used for training — TODO confirm).
blocks = (ImageBlock, CategoryBlock)
splitter = RandomSplitter(valid_pct=0.2)
item_tfms = [Resize(512)]
batch_tfms = [
    RandomResizedCropGPU(448),
    *aug_transforms(),
    Normalize.from_stats(*imagenet_stats),
]
block = DataBlock(
    blocks=blocks,
    get_x=get_x,
    get_y=get_y,
    splitter=splitter,
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)
dls = block.dataloaders(df, bs=32)


Build a `Learner` with the same architecture (`pretrained=False`, since our trained weights are loaded next):

In [None]:
# No pretrained download: the trained weights are loaded from our own checkpoint.
learn = timm_learner(dls, arch='efficientnet_b3', pretrained=False, metrics=accuracy)

Then load our trained weights:

In [None]:
# Directory used by learn.save / learn.load. Point it at the attached dataset folder with
# the hyphenated name, matching the checkpoint path loaded below; the underscore spelling
# previously used here did not match it.
learn.model_dir = Path('input/b3-example-submission')

In [None]:
# Load the trained checkpoint straight into the model, forwarding learn.opt as-is.
load_model(
    Path('input') / 'b3-example-submission' / 'b3.pth',
    learn.model,
    learn.opt,
)

In [None]:
# Submission template; its image_id column lists the test images to predict.
sample_df = pd.read_csv(data_path.joinpath('sample_submission.csv'))
sample_df.head(n=5)

In [None]:
# New frame with sub-folder-prefixed ids for the test_dl; sample_df keeps the bare ids
# that go into the submission file.
sample_copy = sample_df.assign(
    image_id=sample_df['image_id'].map(lambda fname: f'test_images/{fname}')
)

Finally, we get our predictions using test-time augmentation (TTA):

In [None]:
# Build the test DataLoader from the prefixed test ids, reusing the learner's DataLoaders.
test_dl = learn.dls.test_dl(test_items=sample_copy)

In [None]:
# Test-time augmentation over the test set; only the predictions are used,
# the second returned value is discarded.
preds, _ = learn.tta(dl=test_dl)

In [None]:
# argmax gives a position in the DataLoaders' vocab, not the label itself; map it back
# through the vocab so the submission stays correct even if labels are not exactly 0..n-1.
sample_df['label'] = [learn.dls.vocab[i] for i in preds.argmax(dim=-1).tolist()]

In [None]:
# Move into working/ so that submission.csv below is written there.
%cd working

In [None]:
# index=False: write only the DataFrame's columns, without the row index.
sample_df.to_csv(path_or_buf='submission.csv', index=False)