In [None]:
!pip install efficientnet_pytorch

In [None]:
%autosave 20
import os
import gc
gc.enable()
import time
from glob import glob
import random
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
from PIL import Image
from tqdm import tqdm 

import torch
import torchvision
# import pretrainedmodels
import efficientnet_pytorch
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from torch.cuda import amp
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from albumentations import (
    Compose, HorizontalFlip, CLAHE, HueSaturationValue,
    RandomBrightness, RandomContrast, RandomGamma, OneOf, Resize,
    ToFloat, ShiftScaleRotate, GridDistortion, RandomRotate90, Cutout,
    RGBShift, RandomBrightness, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise, CoarseDropout,
    IAAAdditiveGaussianNoise, GaussNoise, OpticalDistortion, RandomSizedCrop, VerticalFlip
)
from catalyst.data.sampler import DistributedSamplerWrapper, BalanceClassSampler
# from apex import amp

import sklearn
from sklearn import metrics
from sklearn.model_selection import GroupKFold

import warnings
warnings.simplefilter('ignore')
warnings.filterwarnings("ignore")

print(torch.__version__)
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

In [None]:
def fix_seed(seed=2020, benchmark=False):
    """Seed the Python, NumPy and torch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``; also
            exported as the ``PYTHONHASHSEED`` environment variable.
        benchmark: value for ``torch.backends.cudnn.benchmark``. Defaults to
            ``False`` because the cuDNN autotuner may select different
            algorithms from run to run, which undoes
            ``cudnn.deterministic = True``. Pass ``True`` to trade
            reproducibility for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only newly started Python interpreters read this (e.g. spawned worker
    # processes); the current process's hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (nn.DataParallel uses all of them).
    # Safe to call without CUDA: it is silently ignored.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = benchmark


fix_seed()
print('Seeding Completed.')

In [None]:
class EfficientNetB2(nn.Module):
    """EfficientNet-B2 backbone + global average pooling + linear classification head.

    Args:
        pretrained: if truthy, load backbone weights from ``weights_path``.
            ``None`` (default) or ``False`` keep the randomly initialised
            backbone, e.g. when a full checkpoint is loaded afterwards.
        weights_path: state-dict file for the ``efficientnet-b2`` backbone.
        num_classes: number of logits produced by the head (default 4).
    """

    DEFAULT_WEIGHTS = '../input/efficientnet-pytorch/efficientnet-b2-27687264.pth'

    def __init__(self, pretrained=None, weights_path=DEFAULT_WEIGHTS, num_classes=4):
        super(EfficientNetB2, self).__init__()
        self.model = efficientnet_pytorch.EfficientNet.from_name('efficientnet-b2')
        # Truthiness rather than `is not None`, so pretrained=False really skips loading.
        if pretrained:
            # map_location='cpu' lets this load on CPU-only hosts; move the module
            # to the GPU afterwards.
            self.model.load_state_dict(torch.load(weights_path, map_location='cpu'))
        # Not applied in forward(); kept so existing attribute access keeps working.
        self.dropout = nn.Dropout(p=0.1)
        # 1408 is the expected channel width of efficientnet-b2's extract_features()
        # output; it must match the checkpoint this head is restored from.
        self.classifier = nn.Linear(in_features=1408, out_features=num_classes)

    def forward(self, images):
        """Return raw (pre-softmax) logits of shape (batch, num_classes).

        ``images`` is a 4-D batch with the batch dimension first.
        """
        features = self.model.extract_features(images)
        # Global average pool to (batch, channels, 1, 1), then flatten to (batch, channels).
        pooled = F.adaptive_avg_pool2d(features, 1).flatten(1)
        return self.classifier(pooled)

In [None]:
net = EfficientNetB2(pretrained=None)
# nn.DataParallel refuses to run unless the wrapped module's parameters are already
# on device_ids[0] (cuda:0), so move the model to the GPU *before* wrapping it.
# Without CUDA, DataParallel has no device_ids and just calls the wrapped module.
if torch.cuda.is_available():
    net = net.cuda()
net = nn.DataParallel(net)

In [None]:
# map_location='cpu': materialise the saved tensors on the host regardless of the
# device they were saved from. load_state_dict then copies them onto net's own
# device, so this works on CPU-only machines and avoids a second GPU copy.
# (torch.load unpickles the file: only use trusted checkpoints.)
checkpoint = torch.load('../input/alaska-2-checkpoints/best-checkpoint-001epoch.bin',
                        map_location='cpu')
net.load_state_dict(checkpoint['model_state_dict']);
net.eval();

In [None]:
# Deterministic test-time preprocessing: divide pixel values by 255 into floats,
# then convert the image to a torch tensor (returned under the "image" key).
_test_transforms = [
    ToFloat(max_value=255),
    ToTensorV2(),
]
AUGMENTATIONS_TEST = Compose(_test_transforms, p=1)

In [None]:
class Alaska2TestDataset(Dataset):
    """Map-style dataset over the image paths in the first column of ``df``.

    Args:
        df: DataFrame whose first column holds image file paths.
        augmentations: optional albumentations-style callable invoked as
            ``augmentations(image=...)``. When given, ``__getitem__`` returns its
            output (a dict such as ``{'image': tensor}``); otherwise it returns
            the RGB image array read by OpenCV.
    """

    def __init__(self, df, augmentations=None):
        self.data = df
        self.augment = augmentations

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional lookup: independent of the frame's index labels and avoids the
        # deprecated integer-key fallback of Series.__getitem__.
        fn = self.data.iloc[idx, 0]
        im = cv2.imread(fn)
        if im is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {fn}")
        # BGR -> RGB. cvtColor returns a contiguous array; a [..., ::-1] view has
        # negative strides, which torch.from_numpy / default_collate reject.
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        if self.augment is not None:
            im = self.augment(image=im)

        return im


TEST_IMAGE_GLOB = "../input/alaska2-image-steganalysis/Test/*.jpg"
test_filenames = sorted(glob(TEST_IMAGE_GLOB))
# An empty match would otherwise silently produce an empty submission.
if not test_filenames:
    raise FileNotFoundError(f"No test images matched {TEST_IMAGE_GLOB!r}")
test_df = pd.DataFrame({'ImageFileName': test_filenames})

batch_size = 16
num_workers = 4
test_dataset = Alaska2TestDataset(test_df, augmentations=AUGMENTATIONS_TEST)
# shuffle=False keeps predictions in test_df row order, which the submission relies on.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=False,
                                          drop_last=False)

In [None]:
# Flip test-time augmentation. For each batch:
# - logits of the vertically flipped images (dim 2) are weighted 0.25;
# - logits of the horizontally flipped images (dim 3) are weighted 0.25;
# - logits of the unflipped images are weighted 0.5.
# The blend is softmaxed over dim 1 and one probability row per image is appended to `preds`.
preds = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        images = batch["image"].cuda()
        vflip_logits = net(images.flip(2))
        hflip_logits = net(images.flip(3))
        blended = 0.25 * vflip_logits + 0.25 * hflip_logits
        blended = blended + 0.5 * net(images)
        preds.extend(F.softmax(blended, 1).cpu().numpy())

In [None]:
preds = np.asarray(preds)
# One probability row per test image, in test_df order (the loader does not shuffle).
assert len(preds) == len(test_df), (len(preds), len(test_df))

# Submission score = probability mass outside class 0. Softmax rows sum to 1, so
# this equals preds[:, 1:].sum(1) for every row. The previous argmax-dependent
# branching therefore computed the same value either way.
new_preds = 1.0 - preds[:, 0]

test_df['Id'] = test_df['ImageFileName'].map(os.path.basename)
test_df['Label'] = new_preds

test_df = test_df.drop('ImageFileName', axis=1)
# NOTE(review): the file name says "eb0" but the model above is EfficientNet-B2.
test_df.to_csv('submission_eb0.csv', index=False)
test_df.head()