In [None]:
from albumentations.pytorch import ToTensor
from albumentations import (
    Compose, HorizontalFlip, CLAHE, HueSaturationValue,
    RandomBrightness, RandomContrast, RandomGamma, OneOf, Resize,
    ToFloat, ShiftScaleRotate, GridDistortion, RandomRotate90, Cutout,
    RGBShift, RandomBrightness, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise, CoarseDropout,
    IAAAdditiveGaussianNoise, GaussNoise, OpticalDistortion, RandomSizedCrop, VerticalFlip, Normalize
)

import os
import pandas as pd
import numpy as np
import seaborn as sns
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as t
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn import metrics

!pip install --upgrade efficientnet_pytorch
from efficientnet_pytorch import EfficientNet
from tqdm.notebook import tqdm


In [None]:
import scipy

from numpy import pi
from numpy import sin
from numpy import zeros
from numpy import r_
from scipy import signal
from scipy import misc # pip install Pillow
import matplotlib.pylab as pylab

%matplotlib inline
# Large default figure size (width, height in inches) for every plot below.
pylab.rcParams['figure.figsize'] = (20.0, 7.0)

In [None]:
# Reproducibility: seed every RNG the notebook touches (Python, NumPy, torch CPU and the current CUDA device).
import random

seed = 42
print(f'setting everything to seed {seed}')
random.seed(seed)
# NOTE(review): PYTHONHASHSEED only takes effect for Python interpreters started after this
# line; it does not change hash randomisation of the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Ask cuDNN for deterministic algorithms (may be slower).
torch.backends.cudnn.deterministic = True

In [None]:
# Dataset root and the 10-way label set: 'Normal' (no hidden message) plus each
# steganography algorithm (JMiPOD / JUNIWARD / UERD) with a numeric suffix
# (presumably the JPEG quality factor — confirm against the dataset description).
data_dir = '../input/alaska2-image-steganalysis'
# NOTE(review): folder_names is not used by the cells below.
folder_names = ['JMiPOD/', 'JUNIWARD/', 'UERD/']
class_names = ['Normal', 'JMiPOD_75', 'JMiPOD_90', 'JMiPOD_95', 
               'JUNIWARD_75', 'JUNIWARD_90', 'JUNIWARD_95',
                'UERD_75', 'UERD_90', 'UERD_95']
# class name -> integer id (its position in class_names); id 0 is 'Normal'.
class_labels = { name: i for i, name in enumerate(class_names)}
num_classes = len(class_labels)

In [None]:
# Pre-computed train/validation split. The Alaska dataset unpacks exactly two values
# per row, presumably (image path, integer label) — TODO confirm the CSV column order.
train_df = pd.read_csv('../input/alaska2trainvalsplit/alaska2_train_df.csv')
val_df = pd.read_csv('../input/alaska2trainvalsplit/alaska2_val_df.csv')

print(train_df.head(10))
# Class balance of the training split.
train_df.Label.hist()
plt.title('Distribution of Classes')

In [None]:
#train_df = train_df.sample(1000)
#val_df = val_df.sample(500)
#train_df,val_df

In [None]:
import scipy
import os
import numpy as np
import pandas as pd
from numpy import pi
from numpy import sin
from numpy import zeros
from numpy import r_
from scipy import signal
from scipy import misc #pip install Pillow
from scipy import fftpack
import matplotlib.pylab as pylab

%matplotlib inline
# NOTE(review): duplicates the figure-size setting from the earlier setup cell (harmless but redundant).
pylab.rcParams['figure.figsize'] = (20.0, 7.0)


In [None]:
def dct2(a):
    """Orthonormal 2-D DCT-II of `a`: a 1-D DCT along axis 0, then along axis 1."""
    first_pass = scipy.fftpack.dct(a, axis=0, norm='ortho')
    return scipy.fftpack.dct(first_pass, axis=1, norm='ortho')

def idct2(a):
    """Inverse of `dct2`: orthonormal 1-D inverse DCT along axis 0, then along axis 1."""
    first_pass = scipy.fftpack.idct(a, axis=0, norm='ortho')
    return scipy.fftpack.idct(first_pass, axis=1, norm='ortho')

In [None]:
def dct_ext(img, thresh=0.02):
    """Blockwise 8x8 DCT of `img` with small coefficients zeroed out.

    Every non-overlapping 8x8 tile (edge tiles may be smaller) is transformed
    with `dct2`. Coefficients whose magnitude does not exceed
    ``thresh * max(dct)`` are then set to 0.

    Args:
        img: array whose first two axes are rows and columns.
        thresh: cut-off as a fraction of the largest coefficient. The default
            0.02 is the value that used to be hard-coded.

    Returns:
        A float64 array with the same shape as `img`.
    """
    coeffs = np.zeros(img.shape)
    for i in range(0, img.shape[0], 8):
        for j in range(0, img.shape[1], 8):
            coeffs[i:i + 8, j:j + 8] = dct2(img[i:i + 8, j:j + 8])

    keep = np.abs(coeffs) > (thresh * np.max(coeffs))
    return coeffs * keep

In [None]:
from torch.utils.data import Dataset
import cv2

class Alaska(Dataset):
    """Training/validation dataset over a two-column frame of (image path, label).

    Images are read with OpenCV and converted from BGR to RGB. When a transform
    is given, it is called albumentations-style (``trans(image=...)``) and its
    ``'image'`` entry is returned. Otherwise the RGB array itself is returned.
    """

    def __init__(self, dataframe, trans = None):
        # Each row must unpack into exactly two values: path first, label second.
        self.data = dataframe
        self.transform = trans

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for positional row `idx`.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        fname, target = self.data.iloc[idx]
        bgr = cv2.imread(fname)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {fname}')
        # BGR -> RGB. Copy to a contiguous array, because torch rejects
        # negative-stride numpy arrays when no transform copies the data.
        img = np.ascontiguousarray(bgr[:, :, ::-1])

        if self.transform:
            img = self.transform(image = img)['image']
        return img, target

In [None]:
# Training pipeline: random vertical/horizontal flips, then scale uint8 pixels to
# [0, 1] floats (ToFloat divides by max_value) and convert HWC arrays to CHW tensors.
# NOTE(review): albumentations' ToTensor is deprecated in newer releases (ToTensorV2).
augmentations_train = Compose([
    #Resize(512, 512, p=1), 
    VerticalFlip(p=0.5),
    HorizontalFlip(p=0.5),
    ToFloat(max_value=255),
    ToTensor()
],p=1)

# Validation/test pipeline: no random augmentation, same scaling and tensor conversion.
augmentations_test = Compose([
    #Resize(512, 512, p=1),
    ToFloat(max_value=255),
    ToTensor()
])

In [None]:
# Wrap the split frames: random flips for training, deterministic transforms for validation.
train_ds = Alaska(train_df, trans = augmentations_train)
val_ds = Alaska(val_df, trans = augmentations_test)
train_ds

In [None]:
# Number of training samples.
len(train_ds)

In [None]:
# Sanity check: show one training sample (CHW tensor permuted to HWC for imshow).
img, lab = train_ds[200]
plt.imshow(img.permute(1,2,0))

In [None]:
# Throwaway loader (num_workers=0: loads in the main process) used only for the
# preview grid; it is deleted again after plotting.
batch_size = 64
num_workers = 0

temp_dl = DataLoader(train_ds, batch_size = batch_size, num_workers = num_workers, shuffle=True)


In [None]:
import gc

# Preview one shuffled batch as a grid of thumbnails titled with their label ids.
images, labels = next(iter(temp_dl))
images = images.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) for imshow
max_images = 64
grid_width = 16
grid_height = int(max_images / grid_width)
fig, axs = plt.subplots(grid_height, grid_width,
                        figsize=(grid_width+1, grid_height+1))

# Plot at most as many images as there are grid cells; a larger batch would index past the grid.
for i, (im, label) in enumerate(zip(images[:max_images], labels[:max_images])):
    ax = axs[i // grid_width, i % grid_width]
    ax.imshow(im.squeeze())
    ax.set_title(str(label.item()))
    ax.axis('off')

# Legend built from class_names, the 10-way label set the model is trained on.
# Assumes the CSV 'Label' ids follow class_names order — TODO confirm.
plt.suptitle(', '.join(f'{idx}: {name}' for idx, name in enumerate(class_names)))
plt.show()
del images, temp_dl
gc.collect()

In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_device(data, device):
    """Move a tensor/module, or a (possibly nested) list/tuple of them, to `device`.

    Sequences come back as lists. Leaves are moved with ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader():
    """Wrap a DataLoader (or any iterable of batches) so each batch lands on `device`."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        # Generator function: iteration over self.dl starts lazily, on the first next().
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        return len(self.dl)
        

In [None]:
# Run on the GPU when available, otherwise the CPU.
device = get_default_device()

In [None]:
class Net(nn.Module):
    """EfficientNet backbone + global average pooling + linear classifier.

    The head produces ``num_classes`` logits (notebook-global). The submodule
    names ``model`` and ``dense_output`` determine the state-dict keys of saved
    checkpoints, so they must stay unchanged.
    """

    def __init__(self, model_name='efficientnet-b0', num_features=1280):
        """
        Args:
            model_name: pretrained EfficientNet variant to load.
            num_features: channel count produced by ``extract_features`` for
                that variant (1280 for efficientnet-b0).
        """
        super().__init__()
        self.model = EfficientNet.from_pretrained(model_name)
        self.dense_output = nn.Linear(num_features, num_classes)

    def forward(self, x):
        feat = self.model.extract_features(x)
        # Average over the whole spatial extent -> (batch, channels, 1, 1), then
        # flatten to (batch, channels). flatten(1) keeps the batch axis intact,
        # so a channel mismatch fails in the Linear layer instead of silently
        # regrouping samples the way reshape(-1, 1280) would.
        feat = F.avg_pool2d(feat, feat.size()[2:]).flatten(1)
        return self.dense_output(feat)
        

In [None]:
# NOTE(review): this instance is never used — a fresh `model = Net()` is created again
# before training, so this cell only costs an extra pretrained-weight load.
model = Net()

In [None]:
# https://www.kaggle.com/anokas/weighted-auc-metric-updated

def alaska_weighted_auc(y_true, y_valid, tpr_thresholds=(0.0, 0.4, 1.0), weights=(2, 1)):
    """Alaska2 competition metric: ROC AUC with TPR bands weighted differently.

    The TPR axis is split at `tpr_thresholds`. The area under the ROC curve
    inside each band is multiplied by the matching weight. The sum is then
    divided by ``sum(band_height * weight)``, so a perfect classifier scores 1.0.

    Args:
        y_true: binary ground truth (1 = positive / contains a hidden message).
        y_valid: scores, where higher means more likely positive.
        tpr_thresholds: band edges on the TPR axis (defaults to the competition's).
        weights: one weight per band, i.e. ``len(tpr_thresholds) - 1`` entries.

    Returns:
        The normalised weighted AUC.
    """
    fpr, tpr, _ = metrics.roc_curve(y_true, y_valid, pos_label=1)

    # Height of each TPR band; normalisation makes the best possible score 1.
    edges = np.asarray(tpr_thresholds, dtype=float)
    areas = edges[1:] - edges[:-1]
    normalization = np.dot(areas, weights)

    competition_metric = 0
    for idx, weight in enumerate(weights):
        y_min = tpr_thresholds[idx]
        y_max = tpr_thresholds[idx + 1]
        mask = (y_min < tpr) & (tpr < y_max)
        band_fpr = fpr[mask]
        band_tpr = tpr[mask]

        if band_fpr.size:
            pad_start = band_fpr[-1]
        else:
            # No ROC point lies strictly inside this band: the curve jumps across
            # it (e.g. a perfect classifier). Indexing band_fpr[-1] would raise
            # IndexError. Start the flat y_max segment where TPR first reaches y_max.
            pad_start = fpr[np.argmax(tpr >= y_max)]

        # Extend the band's curve flat at y_max from its last point out to FPR = 1.
        x_padding = np.linspace(pad_start, 1, 100)
        x = np.concatenate([band_fpr, x_padding])
        y = np.concatenate([band_tpr, [y_max] * len(x_padding)])
        y = y - y_min  # normalize such that curve starts at y=0
        competition_metric += metrics.auc(x, y) * weight

    return competition_metric / normalization

In [None]:
def loss_batch(model, loss_func, xb, yb, opt = None, metric = None):
    """Forward one batch; optionally take an optimizer step and compute a metric.

    Args:
        model: callable producing predictions from `xb`.
        loss_func: ``loss_func(preds, yb) -> scalar tensor``.
        xb, yb: input batch and targets.
        opt: when given, backprop the loss, step, then clear the gradients.
        metric: optional ``metric(preds, yb)``, evaluated on the pre-step predictions.

    Returns:
        ``(loss as float, batch size, metric value or None)``.
    """
    outputs = model(xb)
    batch_loss = loss_func(outputs, yb)

    if opt is not None:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()

    metric_value = metric(outputs, yb) if metric is not None else None
    return batch_loss.item(), len(xb), metric_value

In [None]:
def evaluate(model, loss_fn, valid_dl, metric = None):
    """Run `model` once over `valid_dl` and summarise its performance.

    Each batch is moved to the device holding the model's parameters (CPU if the
    model has none), so the loader may yield CPU or already-moved tensors.

    Args:
        model: classifier returning one logit per class. Column 0 is 'Normal';
            every other column is a stego class.
        loss_fn: ``loss_fn(logits, labels) -> scalar tensor``.
        valid_dl: iterable of ``(images, labels)`` batches.
        metric: optional ``metric(logits, labels) -> float``, averaged over
            batches weighted by batch size.

    Returns:
        ``(avg_loss, total, avg_metric, auc_score, eval_accuracy)``.
        ``avg_metric`` is None when no metric is given. ``auc_score`` is the
        Alaska2 weighted AUC of stego-vs-normal. ``eval_accuracy`` is the
        multi-class arg-max accuracy.

    Raises:
        ValueError: if `valid_dl` yields no batches.
    """
    first_param = next(iter(model.parameters()), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    batch_losses, batch_sizes, batch_metrics = [], [], []
    label_chunks, prob_chunks = [], []
    with torch.no_grad():
        # A single pass: loss, metric and predictions all come from the same
        # forward call, so the validation set is read and scored only once.
        for imgs, labels in valid_dl:
            imgs = imgs.to(target_device, dtype=torch.float)
            labels = labels.to(target_device, dtype=torch.long)
            logits = model(imgs)

            batch_losses.append(loss_fn(logits, labels).item())
            batch_sizes.append(len(imgs))
            if metric is not None:
                batch_metrics.append(metric(logits, labels))

            label_chunks.append(labels.cpu().numpy())
            prob_chunks.append(F.softmax(logits, 1).cpu().numpy())

    if not batch_sizes:
        raise ValueError('valid_dl yielded no batches')

    true_labels = np.concatenate(label_chunks).astype(int)
    probs = np.concatenate(prob_chunks)

    eval_accuracy = (probs.argmax(1) == true_labels).mean()

    # Stego score = P(any stego class). Summing the small stego probabilities keeps
    # resolution for confidently-normal images, where 1 - P(Normal) would round
    # many distinct scores to the same float32 value.
    stego_score = probs[:, 1:].sum(axis=1)
    is_stego = (true_labels != 0).astype(int)
    auc_score = alaska_weighted_auc(is_stego, stego_score)

    total = np.sum(batch_sizes)
    avg_loss = np.dot(batch_losses, batch_sizes) / total
    avg_metric = None
    if metric is not None:
        avg_metric = np.dot(batch_metrics, batch_sizes) / total
    return avg_loss, total, avg_metric, auc_score, eval_accuracy

In [None]:
def fit(epochs, model, loss_fn, train_dl, valid_dl, opt_fn = None, lr = None, metric = None):
    """Train `model` for `epochs` epochs, validating after each one.

    Args:
        epochs: number of passes over `train_dl`.
        model: module to train. Batches are fed to it as-is, so the loaders
            must already yield tensors on the model's device.
        loss_fn: ``loss_fn(preds, targets) -> scalar tensor``.
        train_dl, valid_dl: iterables of ``(inputs, targets)`` batches.
        opt_fn: optimizer class, instantiated once here (default: SGD).
        lr: learning rate. When None, the optimizer's own default is used.
        metric: optional ``metric(preds, targets)``, forwarded to `evaluate`.

    Returns:
        Five per-epoch lists: ``(train_losses, val_losses, val_metrics,
        auc_metrics, eval_metrics)``. ``train_losses`` holds each epoch's
        batch-size-weighted mean training loss.

    Raises:
        ValueError: if `train_dl` yields no batches in an epoch.
    """
    train_losses, val_losses, val_metrics, auc_metrics, eval_metrics = [], [], [], [], []

    if opt_fn is None:
        opt_fn = torch.optim.SGD
    # Only pass lr when one was given, so lr=None does not reach the optimizer.
    opt_kwargs = {} if lr is None else {'lr': lr}
    opt = opt_fn(model.parameters(), **opt_kwargs)

    for epoch in range(epochs):
        model.train()
        loss_sum, n_seen = 0.0, 0
        for xb, yb in train_dl:
            batch_loss, batch_n, _ = loss_batch(model, loss_fn, xb, yb, opt)
            loss_sum += batch_loss * batch_n
            n_seen += batch_n
        if n_seen == 0:
            raise ValueError('train_dl yielded no batches')
        # Epoch mean, not just the last batch's loss.
        train_loss = loss_sum / n_seen

        model.eval()
        val_loss, total, val_metric, auc_score, eval_accuracy = evaluate(model, loss_fn, valid_dl, metric)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_metrics.append(val_metric)
        auc_metrics.append(auc_score)
        eval_metrics.append(eval_accuracy)

        if metric is None:
            print('Epoch [{}/{}], train_loss: {:.4f}, val_loss: {:.4f}'
                  .format(epoch+1, epochs, train_loss, val_loss))
        else:
            print('Epoch [{}/{}], train_loss: {:.4f}, val_loss: {:.4f}, val_{}: {:.4f},auc_score: {:.4f}, eval_accuracy: {:.4f}'
                 .format(epoch+1, epochs, train_loss, val_loss, metric.__name__, val_metric, auc_score, eval_accuracy))
    return train_losses, val_losses, val_metrics, auc_metrics, eval_metrics

In [None]:
def accuracy(outputs, labels):
    """Fraction of samples whose highest-scoring class (along dim 1) equals the label.

    Returns a Python float in [0, 1].
    """
    predicted = outputs.max(dim = 1)[1]
    n_correct = (predicted == labels).sum().item()
    return n_correct / len(predicted)

In [None]:
# Loaders used for training: small batches, 8 worker processes; validation order
# is fixed (shuffle=False) so its predictions stay aligned across epochs.
batch_size = 8
num_workers = 8

train_dl = DataLoader(train_ds, batch_size = batch_size, num_workers = num_workers, shuffle=True)
val_dl = DataLoader(val_ds, batch_size = batch_size, num_workers = num_workers, shuffle=False)


In [None]:
# Fresh model; both loaders are wrapped so every batch is moved to `device`, and the
# model's parameters are moved there too (nn.Module.to moves in place and returns self).
# NOTE(review): re-running this cell wraps the already-wrapped loaders again (redundant but harmless).
model = Net() 
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)
to_device(model, device)


In [None]:
def load_cp(cp, target=None):
    """Load the weights stored under ``cp['model_state']`` into a model.

    Args:
        cp: checkpoint dict containing a ``'model_state'`` state dict.
        target: module to load into. Defaults to the notebook-global `model`,
            so existing ``load_cp(cp)`` calls keep working.
    """
    if target is None:
        target = model
    target.load_state_dict(cp['model_state'])
    

In [None]:
# Compact fingerprint of the freshly initialised weights (tensor count, element count,
# sum) instead of printing every parameter tensor. Compare it with the parameter
# print after the checkpoint is loaded to confirm the weights changed.
with torch.no_grad():
    params = list(model.parameters())
    print(len(params), sum(p.numel() for p in params),
          sum(p.double().sum().item() for p in params))

In [None]:
# map_location places the stored tensors directly on the active device, so a
# checkpoint saved on a GPU also loads on a CPU-only kernel.
cp = torch.load('../input/checkpoint/checkpoint26102020.pth', map_location=device)
load_cp(cp)

In [None]:
# Compact fingerprint of the weights after loading the checkpoint (tensor count,
# element count, sum) instead of printing every parameter tensor.
with torch.no_grad():
    params = list(model.parameters())
    print(len(params), sum(p.numel() for p in params),
          sum(p.double().sum().item() for p in params))

In [None]:
# Fine-tuning hyper-parameters passed to `fit`.
num_epochs = 2
opt_fn = torch.optim.AdamW  # optimizer class; `fit` creates the instance internally
lr = 1e-4

In [None]:
# Fine-tune the loaded model and keep the per-epoch histories for inspection.
x = fit(num_epochs, model, F.cross_entropy, train_dl, val_dl, opt_fn, lr, metric = accuracy)
train_losses, val_losses, val_metrics, auc_metrics, eval_metrics = x

In [None]:
import glob
class Alaska2TestDataset(Dataset):
    """Unlabelled test images; each image path is read from the first column of `df`.

    Returns the RGB image array or, when `augmentations` is given, whatever
    ``augmentations(image=...)`` returns (the inference loop reads its ``'image'`` key).
    """

    def __init__(self, df, augmentations=None):

        self.data = df
        self.augment = augmentations

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load positional row `idx`.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Explicit positional lookup. `Series[0]` falling back to positions is deprecated in pandas.
        fn = self.data.iloc[idx, 0]
        bgr = cv2.imread(fn)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {fn}')
        # BGR -> RGB as a contiguous copy; torch rejects negative-stride numpy arrays.
        im = np.ascontiguousarray(bgr[:, :, ::-1])

        if self.augment:
            # Apply transformations
            im = self.augment(image=im)

        return im


# Every JPEG under <data_dir>/Test, sorted for a stable row order.
test_filenames = sorted(glob.glob(f"{data_dir}/Test/*.jpg"))
if not test_filenames:
    # Fail loudly: an empty glob (e.g. a wrong data_dir) would otherwise yield an
    # empty submission.csv without any error.
    raise FileNotFoundError(f"no test images found under {data_dir}/Test")
test_df = pd.DataFrame({'ImageFileName': test_filenames})

# Test-time loader: deterministic transforms, fixed order (shuffle=False) and every
# image kept (drop_last=False), so predictions line up with test_df rows.
batch_size = 16
num_workers = 4
test_dataset = Alaska2TestDataset(test_df, augmentations=augmentations_test)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=False,
                                          drop_last=False)

In [None]:
# Test-set inference with flip test-time augmentation. Logits are blended as
# 0.25 * vertical flip + 0.25 * horizontal flip + 0.5 * original image, then turned
# into one score per image: the probability that it contains a hidden message.
model.eval()

preds = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        inputs = batch["image"].to(device)
        # Inputs are (N, C, H, W): dim 2 flips vertically, dim 3 horizontally.
        outputs = 0.25 * model(inputs.flip(2)) + 0.25 * model(inputs.flip(3))
        outputs = outputs + 0.5 * model(inputs)
        preds.extend(F.softmax(outputs, 1).cpu().numpy())

preds = np.array(preds)
# Class 0 is 'Normal'; all other classes are stego variants. The submitted score must
# increase with the likelihood of hidden data for EVERY image, so use P(stego) = sum
# of the stego-class probabilities. Writing P(Normal) for images predicted Normal
# would rank the most confidently clean images as the most suspicious.
new_preds = preds[:, 1:].sum(1)

# Build the submission as a new frame; test_df (still used by test_dataset) stays intact,
# so this cell can be re-run.
submission = pd.DataFrame({
    'Id': test_df['ImageFileName'].map(os.path.basename),
    'Label': new_preds,
})
submission.to_csv('submission.csv', index=False)
submission.head()

In [None]:
# Save only the model weights (state dict) for later reuse.
torch.save(model.state_dict(), 'alaska2-effnetb0-4eps.pth')

In [None]:
# Only model weights are checkpointed. `opt_fn` is the optimizer *class*
# (torch.optim.AdamW), not the instance that `fit` created and stepped; that instance
# is local to `fit`. Calling `opt_fn.state_dict()` on the class raises TypeError.
# `load_cp` only reads 'model_state', so this checkpoint loads the same way.
checkpoint = {
    'epochs': 4,  # NOTE(review): hard-coded; presumably earlier checkpoint epochs + num_epochs — TODO confirm
    'model_state': model.state_dict(),
}
torch.save(checkpoint, 'checkpoint07112020.pth')

In [None]:
# Compact fingerprint of the final weights (tensor count, element count, sum)
# instead of printing every parameter tensor.
with torch.no_grad():
    params = list(model.parameters())
    print(len(params), sum(p.numel() for p in params),
          sum(p.double().sum().item() for p in params))