In [None]:
from fastai.vision.all import *
# Seed the run once, up front, before any data split, augmentation or training happens.
# NOTE(review): `reproducible=True` presumably also switches the backend to
# deterministic kernels — confirm against the fastai `set_seed` documentation.
seed = 42
set_seed(seed, reproducible=True)

In [None]:
# Input locations (relative to the notebook's working directory):
#   path      - competition dataset; train.csv and sample_submission.csv are read from it
#   data_path - root of the pre-resized image dataset used for training
path, data_path = (
    Path('../input/plant-pathology-2021-fgvc8'),
    Path('../input/resized-plant2021'),
)

In [None]:
# Load the competition's training label table and preview its first rows.
train_csv = path/'train.csv'
df = pd.read_csv(train_csv)
df.head()

In [None]:
# (rows, columns) of the training label table.
df.shape

In [None]:
# Frequency of each distinct raw string in the `labels` column.
# NOTE(review): the column is later split on ' ' (see `label_delim`), so a
# multi-label combination is presumably counted here as its own category.
df['labels'].value_counts()

In [None]:
# Per-item transform, applied before items are collated into a batch.
item_tfms = [RandomResizedCrop(224)]
# Per-batch augmentations; each one fires independently with its own probability `p`.
batch_tfms = [Dihedral(p=0.5),
              Rotate(max_deg=180, p=0.5, pad_mode='reflection'),
              Zoom(min_zoom=1.0, max_zoom=1.2, p=0.5, pad_mode='reflection'),
              Warp(magnitude=0.2, p=0.5),
              Brightness(max_lighting=0.15, p=0.75),
              Contrast(max_lighting=0.15, p=0.75)]

# Multi-label DataLoaders built straight from the label table: file names come from
# the first column, space-separated labels from the second.
dls = ImageDataLoaders.from_df(
    df = df,
    # Reuse `data_path` rather than repeating the dataset root as a string literal.
    folder = data_path/'img_sz_256',
    item_tfms = item_tfms,
    batch_tfms = batch_tfms,
    # `from_df` creates its own random splitter from `valid_pct` and `seed`; it has no
    # `splitter` parameter, so the 10% hold-out has to be requested through these two
    # arguments. Passing `seed` also pins the split even if this cell is re-run alone.
    valid_pct = 0.1,
    seed = seed,
    label_delim = ' ',
    bs=128)

In [None]:
# Visual sanity check: display a sample batch from `dls` with its labels.
dls.show_batch()

In [None]:
# Create torch's hub checkpoint directory and copy ResNet-50 weights from an attached
# input dataset into it under a hashed file name.
# NOTE(review): presumably this lets the pretrained backbone load without a download;
# the `/root/.cache` location and the `19c8e357` suffix depend on the runtime's home
# directory and the installed torchvision version — confirm if either changes.
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/resnet50/resnet50.pth  /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth

In [None]:
# Metrics reported alongside the loss: thresholded multi-label accuracy and multi-label F1.
multilabel_metrics = [accuracy_multi, F1ScoreMulti()]

# ResNet-50 learner for the multi-label task: Adam optimiser, BCE-with-logits loss.
learn = cnn_learner(dls, resnet50,
                    opt_func=Adam,
                    loss_func=BCEWithLogitsLossFlat(),
                    metrics=multilabel_metrics)
# Switch the learner to mixed-precision training.
learn = learn.to_fp16()

In [None]:
# Run the learning-rate finder on the freshly created learner.
# NOTE(review): the learning rate given to `fine_tune` further down is a hard-coded
# literal, not read from this call's result — re-check it if the data or model changes.
learn.lr_find()

In [None]:
# Base learning rate for fine-tuning (hard-coded literal).
base_lr = 0.025118863582611083

# Training-time trackers, both left on their default monitored quantity:
#   - stop early after 2 epochs without improvement
#   - keep a checkpoint of the best model seen
tracker_cbs = [EarlyStoppingCallback(patience=2), SaveModelCallback()]

# 1 frozen epoch followed by 5 fine-tuning epochs.
learn.fine_tune(5, base_lr, freeze_epochs=1, cbs=tracker_cbs)

In [None]:
# Eyeball a sample of model predictions next to their targets.
learn.show_results()

In [None]:
# Build an interpretation object from the trained learner; it is used next to
# plot the highest-loss samples.
interp = ClassificationInterpretation.from_learner(learn)

In [None]:
# Plot the 9 samples with the largest loss to spot mislabeled or hard images.
interp.plot_top_losses(9)

In [None]:
# Load the sample submission; its `image` column drives inference and the frame
# itself is filled in and written out as the final submission.
sample_submission_csv = path/"sample_submission.csv"
submission_df = pd.read_csv(sample_submission_csv)
submission_df.head()

In [None]:
def _to_test_image_path(x):
    """Return the path of test image file name `x` inside the competition's test folder."""
    return f"../input/plant-pathology-2021-fgvc8/test_images/{x}"

# One path per submission row, in the same order as `submission_df`.
test_image_path_series = submission_df["image"].map(_to_test_image_path)
test_image_path_series.head()

In [None]:
# Build a test DataLoader from the learner's existing `dls` so the test image paths
# go through the same pipeline, then run inference on it.
# `get_preds` returns a pair; only the first element (the predictions) is kept.
# NOTE(review): the discarded second element is presumably the targets, which do not
# exist for the unlabeled test set — confirm.
test_dl = learn.dls.test_dl(test_image_path_series)
preds, _ = learn.get_preds(dl=test_dl)

In [None]:
# Per-class score cut-off for turning scores into labels.
thresh = 0.5

def _decode_labels(pred, vocab, thresh):
    """Turn one row of per-class scores into a space-separated label string.

    Every class whose score exceeds `thresh` is kept, in vocabulary order. If no
    class clears the threshold, the single highest-scoring class is used instead,
    so the result is never an empty string (which would become a blank `labels`
    field in the submission file).
    """
    picked = [vocab[i] for i, p in enumerate(pred) if p > thresh]
    if not picked:
        picked = [vocab[int(pred.argmax())]]
    return ' '.join(picked)

labelled_preds = [_decode_labels(pred, learn.dls.vocab, thresh) for pred in preds]
# Preview only the first few entries instead of printing the whole list.
labelled_preds[:10]

In [None]:
# Replace the `labels` column of the sample submission with the decoded predictions.
# Rows line up because the test image paths were generated from this same frame's
# `image` column (assuming the test DataLoader preserves item order).
submission_df["labels"] = labelled_preds
submission_df.head()

In [None]:
# Write the submission file to the working directory; index=False keeps the
# DataFrame's row index out of the CSV.
submission_df.to_csv("submission.csv", index=False)