In [None]:
!pip install -q efficientnet_pytorch
!pip install albumentations==0.5.2
from efficientnet_pytorch import EfficientNet
from albumentations.pytorch import ToTensorV2
from albumentations import (
    Compose, HorizontalFlip, CLAHE, HueSaturationValue,
    RandomBrightness, RandomContrast, RandomGamma, OneOf, Resize,
    ToFloat, ShiftScaleRotate, GridDistortion, ElasticTransform, JpegCompression, HueSaturationValue,
    RGBShift, RandomBrightness, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise, CenterCrop,
    IAAAdditiveGaussianNoise, GaussNoise, OpticalDistortion, RandomSizedCrop, VerticalFlip
)
import os
import torch
import pandas as pd
import numpy as np
import random
import torch.nn as nn
import matplotlib.pyplot as plt
from glob import glob
import torchvision
from torch.utils.data import Dataset
import time
from tqdm.notebook import tqdm
# from tqdm import tqdm
from sklearn import metrics
import cv2
import gc
import torch.nn.functional as F

In [None]:
# Seed every random number generator this notebook relies on
seed = 42
print(f'setting everything to seed {seed}')

# Record the seed in the environment as well
os.environ['PYTHONHASHSEED'] = str(seed)

# Python, NumPy, PyTorch (CPU) and PyTorch (current CUDA device), in that order
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

# Ask cuDNN to use deterministic algorithms
torch.backends.cudnn.deterministic = True

In [None]:
# Build the train / validation split
data_dir = '../input/alaska2-image-steganalysis'
sample_size = 75000   # maximum number of images taken from each class folder
val_fraction = 0.25   # share of the distinct image names held out for validation
train_fn, val_fn = [], []
train_labels, val_labels = [], []

# The position in this list is the class label:
# 0 = Cover, 1 = JMiPOD, 2 = JUNIWARD, 3 = UERD
folder_names = ['Cover/', 'JMiPOD/', 'JUNIWARD/', 'UERD/']

# Sort BEFORE truncating so the selected subset does not depend on the
# arbitrary order in which glob() lists the directory.
files_by_folder = [
    sorted(glob(f"{data_dir}/{folder}/*.jpg"))[:sample_size]
    for folder in folder_names
]

# Split by image name rather than per folder, so that every version of an
# image (cover and its stego variants) ends up on the same side of the split.
# Shuffling each folder independently would put a cover in train and its
# stego counterpart in validation, leaking information between the two sets.
# NOTE(review): assumes a stego image shares its file name with its cover
# image across the four folders - confirm against the dataset layout.
image_ids = sorted({os.path.basename(fn) for files in files_by_folder for fn in files})
np.random.shuffle(image_ids)
val_size = int(len(image_ids) * val_fraction)
val_ids = set(image_ids[:val_size])

for label, files in enumerate(files_by_folder):
    for fn in files:
        if os.path.basename(fn) in val_ids:
            val_fn.append(fn)
            val_labels.append(label)
        else:
            train_fn.append(fn)
            train_labels.append(label)

assert len(train_labels) == len(train_fn), "wrong labels"
assert len(val_labels) == len(val_fn), "wrong labels"

# Assemble the file lists into data frames and sanity-check them
train_df = pd.DataFrame({'ImageFileName': train_fn, 'Label': train_labels}, columns=['ImageFileName', 'Label'])
train_df['Label'] = train_df['Label'].astype(int)
val_df = pd.DataFrame({'ImageFileName': val_fn, 'Label': val_labels}, columns=['ImageFileName', 'Label'])
val_df['Label'] = val_df['Label'].astype(int)

print(train_df)
train_df.Label.hist()

In [None]:
# PyTorch dataset for the labelled (train / validation) images
class Alaska2Dataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs from a two-column data frame.

    Args:
        df: data frame whose rows unpack to ``(file name, label)``.
        augmentations: optional callable invoked as ``augmentations(image=im)``;
            its return value is handed back unchanged in place of the image.
    """

    def __init__(self, df, augmentations=None):

        self.data = df
        self.augment = augmentations

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Samplers hand out positions 0..len-1, so index by position; a
        # label-based lookup breaks as soon as the frame's index is not a
        # fresh RangeIndex (e.g. after sample() without reset_index()).
        fn, label = self.data.iloc[idx]
        im = cv2.imread(fn)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"could not read image: {fn}")
        # Reverse the channel axis (OpenCV loads images as BGR)
        im = im[:, :, ::-1]
        if self.augment:
            im = self.augment(image=im)
        return im, label


img_size = 512


def _resize_step():
    """Resize step shared by both pipelines (always applied)."""
    return Resize(height=img_size, width=img_size, p=1)


def _tensor_steps():
    """Final steps shared by both pipelines: scale by 1/255, then convert to a tensor."""
    return [ToFloat(max_value=255), ToTensorV2()]


# Training pipeline: resize, random flips and random JPEG re-compression
AUGMENTATIONS_TRAIN = Compose(
    [
        _resize_step(),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
        JpegCompression(quality_lower=75, quality_upper=100, p=0.5),
        *_tensor_steps(),
    ],
    p=1,
)

# Evaluation pipeline: resize and tensor conversion only, no random transforms
AUGMENTATIONS_TEST = Compose([_resize_step(), *_tensor_steps()], p=1)

In [None]:
# Preview 64 training images through the evaluation pipeline (no augmentation)
temp_df = train_df.sample(64).reset_index(drop=True)
train_dataset = Alaska2Dataset(temp_df, augmentations=AUGMENTATIONS_TEST)
batch_size = 64
num_workers = 0

temp_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
)

# A single batch holds the whole sample; reorder axes to (0, 2, 3, 1) for imshow
images, labels = next(iter(temp_loader))
images = images['image'].permute(0, 2, 3, 1)

max_images = 64
grid_width = 16
grid_height = max_images // grid_width
fig, axs = plt.subplots(grid_height, grid_width, figsize=(grid_width + 1, grid_height + 1))

# Fill the grid row by row, titling each tile with its class label
for i, (im, label) in enumerate(zip(images, labels)):
    row, col = divmod(i, grid_width)
    ax = axs[row, col]
    ax.imshow(im.squeeze())
    ax.set_title(str(label.item()))
    ax.axis('off')

plt.suptitle("0: COVER, 1: JMiPOD, 2: JUNIWARD, 3:UERD")
plt.show()

# Drop the image batch and reclaim its memory
del images
gc.collect()

In [None]:
# Preview the same 64 images through the training pipeline (with augmentation)
train_dataset = Alaska2Dataset(temp_df, augmentations=AUGMENTATIONS_TRAIN)
batch_size = 64
num_workers = 0

temp_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
)

# A single batch holds the whole sample; reorder axes to (0, 2, 3, 1) for imshow
images, labels = next(iter(temp_loader))
images = images['image'].permute(0, 2, 3, 1)

max_images = 64
grid_width = 16
grid_height = max_images // grid_width
fig, axs = plt.subplots(grid_height, grid_width,
                        figsize=(grid_width + 1, grid_height + 1))

# Fill the grid row by row, titling each tile with its class label
for i, (im, label) in enumerate(zip(images, labels)):
    row, col = divmod(i, grid_width)
    ax = axs[row, col]
    ax.imshow(im.squeeze())
    ax.set_title(str(label.item()))
    ax.axis('off')

plt.suptitle("0: No Hidden Message, 1: JMiPOD, 2: JUNIWARD, 3:UERD")
plt.show()

# The preview sample is no longer needed; drop it and reclaim memory
del images, temp_df
gc.collect()

In [None]:
# Network definition: EfficientNet-B0 feature extractor + linear classifier
class Net(nn.Module):
    """EfficientNet-B0 features, average-pooled over space, fed to a 4-way linear layer."""

    feature_dim = 1280  # channel count the classifier head expects
    num_classes = 4     # size of the output logit vector

    def __init__(self):
        super().__init__()
        self.model = EfficientNet.from_pretrained('efficientnet-b0')
        self.dense_output = nn.Linear(self.feature_dim, self.num_classes)

    def forward(self, x):
        """Return the classifier logits for the input batch ``x``."""
        feature_map = self.model.extract_features(x)
        # Pool with a kernel covering the full spatial extent of the feature map
        spatial_extent = feature_map.shape[2:]
        pooled = F.avg_pool2d(feature_map, spatial_extent)
        flat = pooled.reshape(-1, self.feature_dim)
        return self.dense_output(flat)

In [None]:
batch_size = 8
num_workers = 8

# Build the datasets: augmentation for training, none for validation
train_dataset = Alaska2Dataset(train_df, augmentations=AUGMENTATIONS_TRAIN)
valid_dataset = Alaska2Dataset(val_df, augmentations=AUGMENTATIONS_TEST)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=True)

# Validation uses a batch twice as large and a fixed sample order
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=batch_size*2,
                                           num_workers=num_workers,
                                           shuffle=False)

# Fall back to the CPU when no GPU is present instead of failing on .to('cuda')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)

# Resume from a previously saved checkpoint; map_location lets weights that
# were saved from a GPU be loaded on whichever device was selected above.
checkpoint_path = '../input/alaska/epoch_13_val_loss_6.67_auc_0.819.pth'
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
def alaska_weighted_auc(y_true, y_valid):
    """Weighted area under the ROC curve, split into TPR bands.

    The ROC curve is cut into the TPR bands [0.0, 0.4] and [0.4, 1.0], which
    carry weights 2 and 1. Within each band the curve is shifted down by the
    band's lower bound and padded out to fpr = 1 at the band's upper bound
    before its area is taken. The weighted sum of the band areas is divided
    by the weighted sum of the band heights.

    Args:
        y_true: binary ground-truth labels (1 is the positive class).
        y_valid: scores for the positive class.

    Returns:
        The normalised weighted AUC.
    """
    tpr_thresholds = [0.0, 0.4, 1.0]
    weights = [2, 1]

    fpr, tpr, _ = metrics.roc_curve(y_true, y_valid, pos_label=1)

    # Height of every TPR band; their weighted sum is the best attainable score
    band_heights = np.diff(np.array(tpr_thresholds))
    normalization = np.dot(band_heights, weights)

    competition_metric = 0
    bands = zip(tpr_thresholds[:-1], tpr_thresholds[1:], weights)
    for y_min, y_max, weight in bands:
        # ROC points lying strictly inside the current band
        inside = (tpr > y_min) & (tpr < y_max)
        band_fpr, band_tpr = fpr[inside], tpr[inside]

        # Extend the curve horizontally from its last in-band point to fpr = 1
        x_padding = np.linspace(band_fpr[-1], 1, 100)
        x = np.concatenate([band_fpr, x_padding])
        y = np.concatenate([band_tpr, np.full(len(x_padding), y_max)])
        y = y - y_min  # shift so the band starts at y = 0

        competition_metric += metrics.auc(x, y) * weight

    return competition_metric / normalization

In [None]:
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 0
train_loss, val_loss = [], []


def stego_score_from_probs(probs):
    """Collapse per-class probabilities into one "contains a hidden message" score.

    Column 0 is the cover class and the remaining columns are the stego
    classes, so the score is the summed probability of the stego columns
    (i.e. 1 - P(cover) for rows that sum to one). Using P(cover) itself for
    samples predicted as cover would rank confident covers as stego.

    Args:
        probs: array of shape (n_samples, n_classes) with class probabilities.

    Returns:
        1-D array with one score per sample; higher means more likely stego.
    """
    probs = np.asarray(probs)
    return probs[:, 1:].sum(axis=1)


for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # ---- training pass ----
    model.train()
    running_loss = 0
    tk0 = tqdm(train_loader, total=int(len(train_loader)))
    for im, labels in tk0:
        inputs = im["image"].to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.long)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        tk0.set_postfix(loss=(loss.item()))

    # Each loss.item() is already a batch mean, so the epoch mean is the sum
    # divided by the number of batches (len of the loader), nothing more.
    epoch_loss = running_loss / max(len(train_loader), 1)
    train_loss.append(epoch_loss)
    print('Training Loss: {:.8f}'.format(epoch_loss))

    # ---- validation pass ----
    tk1 = tqdm(valid_loader, total=int(len(valid_loader)))
    model.eval()
    running_loss = 0
    y, preds = [], []
    with torch.no_grad():
        for (im, labels) in tk1:
            inputs = im["image"].to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.long)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            y.extend(labels.cpu().numpy().astype(int))
            preds.extend(F.softmax(outputs, 1).cpu().numpy())
            running_loss += loss.item()
            tk1.set_postfix(loss=(loss.item()))

        epoch_loss = running_loss / max(len(valid_loader), 1)
        val_loss.append(epoch_loss)
        preds = np.array(preds)
        y = np.array(y)

        # Multi-class accuracy from the arg-max class
        pred_labels = preds.argmax(1)
        acc = (pred_labels == y).mean()*100

        # Multi-class -> binary: the metric scores "cover vs. any stego"
        new_preds = stego_score_from_probs(preds)
        y[y != 0] = 1
        auc_score = alaska_weighted_auc(y, new_preds)
        print(f'Val Loss: {epoch_loss:.3}, Weighted AUC:{auc_score:.3}, Acc: {acc:.3}')

    # NOTE(review): the "+4" epoch offset in the file name is a hard-coded
    # resume offset - confirm it matches the checkpoint that was loaded.
    torch.save(model.state_dict(),f"epoch_{epoch+4}_val_loss_{epoch_loss:.3}_auc_{auc_score:.3}.pth")

In [None]:
# Training (red) and validation (blue) loss per epoch on one set of axes
fig, ax = plt.subplots(figsize=(15, 7))
for history, colour in ((train_loss, 'r'), (val_loss, 'b')):
    ax.plot(history, c=colour)
ax.legend(['train_loss', 'val_loss'])
ax.set_title('Loss Plot')

# 进行测试

In [None]:
# Dataset for the unlabelled test images
class Alaska2TestDataset(Dataset):
    """Dataset yielding images only, read from the first column of a data frame.

    Args:
        df: data frame whose first column holds the image file names.
        augmentations: optional callable invoked as ``augmentations(image=im)``;
            its return value is handed back unchanged in place of the image.
    """

    def __init__(self, df, augmentations=None):

        self.data = df
        self.augment = augmentations

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the image for the row at position ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Purely positional lookup (row idx, first column): avoids both the
        # label-based row access and the deprecated positional Series[0].
        fn = self.data.iloc[idx, 0]
        im = cv2.imread(fn)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"could not read image: {fn}")
        # Reverse the channel axis (OpenCV loads images as BGR)
        im = im[:, :, ::-1]

        if self.augment:
            im = self.augment(image=im)

        return im


# Collect the test images in sorted order and wrap them in a one-column frame
test_filenames = sorted(glob(f"{data_dir}/Test/*.jpg"))
test_df = pd.DataFrame({'ImageFileName': list(test_filenames)})[['ImageFileName']]

batch_size = 16
num_workers = 4

# Evaluation pipeline only; keep file order and keep the last partial batch
test_dataset = Alaska2TestDataset(test_df, augmentations=AUGMENTATIONS_TEST)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [None]:
model.eval()

preds = []
tk0 = tqdm(test_loader)
with torch.no_grad():
    for batch in tk0:
        inputs = batch["image"].to(device)
        # Test-time augmentation: blend the logits of the vertically flipped
        # (dim 2), horizontally flipped (dim 3) and original batch with
        # weights 0.25 / 0.25 / 0.5, then turn them into probabilities.
        outputs = 0.25*model(inputs.flip(2)) + 0.25*model(inputs.flip(3))
        outputs = outputs + 0.5*model(inputs)
        preds.extend(F.softmax(outputs, 1).cpu().numpy())

preds = np.array(preds)

# Multi-class -> binary: the submitted score must grow with the likelihood
# that the image hides a message, so use the summed probability of the stego
# classes (columns 1..3), i.e. 1 - P(cover), for EVERY sample. Reporting
# P(cover) for samples predicted as cover would rank confident covers as stego.
new_preds = preds[:, 1:].sum(axis=1)

# The submission identifies an image by its bare file name
test_df['Id'] = test_df['ImageFileName'].apply(os.path.basename)
test_df['Label'] = new_preds

test_df = test_df.drop('ImageFileName', axis=1)
test_df.to_csv('submission.csv', index=False)
print(test_df.head())