# Libraries

In [None]:
%cd ../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master
from efficientnet_pytorch import EfficientNet
%cd -

In [None]:
from fastai.vision.all import *
import albumentations

## Setting a seed

In [None]:
# Seed the RNGs through fastai's `set_seed` so runs are repeatable.
SEED = 42
set_seed(SEED)

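## Every function and class used to build the model in [the training notebook](https://www.kaggle.com/hubertwojewoda/cassava-efficientnet-b3) must be redefined here under the same names, because the exported learner refers to them by name and cannot be unpickled without them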

In [None]:
class AlbumentationsTransform(RandTransform):
    """Apply `train_aug` on split 0 and `valid_aug` on every other split.

    Both pipelines are Albumentations-style callables: invoked as
    `aug(image=<ndarray>)` and returning a dict with an `'image'` entry.
    """
    split_idx, order = None, 2

    def __init__(self, train_aug, valid_aug):
        store_attr()

    def before_call(self, b, split_idx):
        # Record the split for the current call; `encodes` reads it back.
        self.idx = split_idx

    def encodes(self, img: PILImage):
        # idx 0 selects the training pipeline; any other value (including
        # None) falls through to the validation pipeline.
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        result = pipeline(image=np.array(img))
        return PILImage.create(result['image'])

In [None]:
def get_x(row): return data_path/row['image_id']
def get_y(row):
    """Return the target stored in the row's `label` column."""
    label = row['label']
    return label

In [None]:
class CassavaModel(Module):
    """EfficientNet-B3 backbone, global average pooling, dropout and a linear head.

    Redefined here so the exported learner can be loaded; keep the attribute
    names `effnet` / `dropout` / `out` unchanged, since the saved weights are
    presumably keyed by them.
    """

    def __init__(self, num_classes):
        self.effnet = EfficientNet.from_pretrained("efficientnet-b3")
        self.dropout = nn.Dropout(0.1)
        # 1536 must equal the channel width of the pooled B3 features fed in below.
        self.out = nn.Linear(1536, num_classes)

    def forward(self, image):
        # The 4-way unpack also rejects inputs that are not 4-D.
        n, _c, _h, _w = image.shape
        features = self.effnet.extract_features(image)
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(n, -1)
        return self.out(self.dropout(pooled))

# Load the model

In [None]:
# Peek at the attached Kaggle inputs to confirm dataset / model folder names.
kaggle_inputs = Path('/kaggle/input')
kaggle_inputs.ls()

In [None]:
# Exported fastai learner from the training notebook. Loading it needs the
# transform, getters and model class defined in the cells above.
# cpu=False: don't force the learner onto the CPU.
learner_path = Path('/kaggle/input/effnet-inference/inference(1)')
learn = load_learner(learner_path, cpu=False)

In [None]:
# Switch the learner to half precision (fastai `to_fp16`) for the inference below.
# Left as the cell's last expression, so the notebook echoes the returned object.
learn.to_fp16()

# Submission

In [None]:
# Kaggle input root and the competition dataset folder. `get_x` resolves
# image paths against `data_path`, so this must run before `test_dl` is built.
path = Path("../input")
data_path = path.joinpath('cassava-leaf-disease-classification')

In [None]:
# The sample submission lists the test image ids and serves as the submission template.
sample_csv = data_path / 'sample_submission.csv'
test_df = pd.read_csv(sample_csv)
test_df.head()

In [None]:
# `get_x` joins `data_path` with each row's image_id, so point the ids at the
# test_images/ folder. Work on a copy: test_df keeps the bare ids that end up
# in submission.csv.
test_copy = test_df.copy()
test_copy['image_id'] = test_copy['image_id'].apply(
    lambda image_id: f'test_images/{image_id}'
)

In [None]:
# Wrap the test rows in a DataLoader built from the learner's own `dls`.
test_dl = learn.dls.test_dl(test_items=test_copy)

In [None]:
# Run the model over the test DataLoader. The first element of the returned
# pair holds the predictions; the second is discarded, since targets are not
# meaningful for the test rows.
preds, _ = learn.get_preds(dl=test_dl)

In [None]:
# Submit the highest-scoring class index for each test image.
top_class = preds.argmax(dim=-1)
test_df['label'] = top_class.numpy()

In [None]:
# Write the submission file without the DataFrame's index column.
test_df.to_csv(path_or_buf='submission.csv', index=False)

In [None]:
# Sanity-check the first rows of the final submission.
test_df.head(n=5)