# Install essential Python packages

In [None]:
!pip install -q timm albumentations wandb

In [None]:
!pip install catalyst

# Import essential Python packages

In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold
import torch
from torch import nn
import torch.nn.functional as F
import os
import time
from datetime import datetime
import random
import cv2
import pandas as pd
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler
import wandb
import timm
import shutil
from catalyst.data.sampler import BalanceClassSampler
import warnings

warnings.filterwarnings(action='ignore')

# Hyper-parameter setting

In [None]:
# Mount Google Drive at /gdrive (Colab only), requesting a forced remount.
from google.colab import drive
drive.mount('/gdrive', force_remount=True)

In [None]:
# If the dataset archive was uploaded as-is, extract it into the current working directory
!unzip -qq '/gdrive/My Drive/22년 조교 업무/2학기_인공지능/2022_AI_dataset.zip'

In [None]:
# Experiment configuration: every tunable value for the notebook lives in this dict.
args = dict(
    model_name='tf_efficientnet_b0_ns',  # network architecture
    lr=1e-3,  # learning rate
    weight_decay=1e-4,  # weight decay
    drop_rate=0.2,  # dropout ratio during training
    image_size=256,  # image size
    num_epochs=1,  # number of training epochs
    batch_size=16,  # mini-batch size
    num_classes=2,  # number of classes to predict
    num_folds=5,  # number of folds the dataset is split into
    val_fold=0,  # fold used for validation
    seed=42,  # random seed
    log_step=50,  # log every this many iterations
    model_save_step=5,  # save the model every this many epochs
    workspace_path='/content',  # working directory
    checkpoint_dir='./checkpoints',  # directory where models are saved
    pretrained_name='tf_efficientnet_b0_ns_model_best.pth',  # trained model file name (including .pth)
)

# Both dataset roots sit directly under the workspace directory.
TRAIN_DATA_ROOT_PATH, TEST_DATA_ROOT_PATH = (
    os.path.join(args['workspace_path'], split) for split in ('train', 'test')
)

In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch random number generators with ``seed``.

    Also exports the value as ``PYTHONHASHSEED`` and, when CUDA is available,
    seeds the CUDA generator and sets cuDNN to deterministic mode with
    benchmarking turned off.

    Args:
        seed: integer seed shared by every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed from the shared configuration so args['seed'] is the single source of truth
# (the previous hard-coded literal did not match the configured value).
seed_everything(args['seed'])

In [None]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Start a Weights & Biases run for this notebook, then assign the run's name.
wandb.init(project="2022_AI", entity="jeo514")
wandb.run.name = 'steganalysis(classification)'

# GroupKFold splitting

In [None]:
# Build the sample table: one row per training image holding its class folder,
# file name and integer label (enumerate fixes cover -> 0, stego -> 1).
dataset = []

for label, kind in enumerate(['cover', 'stego']):
    # glob() returns files in a file-system dependent order, so sort before the
    # seeded shuffle below; otherwise the same seed can still yield different folds.
    for path in sorted(glob(f'{TRAIN_DATA_ROOT_PATH}/{kind}/*.png', recursive=True)):
        dataset.append({
            'kind': kind,
            'image_name': path.replace('\\', '/').split('/')[-1],
            'label': label
        })

if not dataset:
    # Fail here with a readable message instead of a KeyError on the empty frame below.
    raise FileNotFoundError(f'no training images (*.png) found under {TRAIN_DATA_ROOT_PATH}')

random.shuffle(dataset)
dataset = pd.DataFrame(dataset)

# Rows that share an image_name form one group, so GroupKFold never places
# them in different folds.
gkf = GroupKFold(n_splits=args['num_folds'])

dataset['fold'] = 0
for fold_number, (train_index, val_index) in enumerate(gkf.split(X=dataset.index, y=dataset['label'], groups=dataset['image_name'])):
    # val_index is positional; translate it to index labels before assigning.
    dataset.loc[dataset.index[val_index], 'fold'] = fold_number

In [None]:
# Preview the first rows of the table, including the assigned fold column.
dataset.head()

# Dataset

In [None]:
class DatasetRetriever(Dataset):
    """Labelled images read from ``TRAIN_DATA_ROOT_PATH/<kind>/<image_name>``.

    The three sequences are indexed in parallel: position ``i`` of each one
    describes sample ``i``.

    Args:
        kinds: per-sample sub-directory name under the training root.
        image_names: per-sample file name.
        labels: per-sample class label.
        transforms: optional callable invoked as ``transforms(image=...)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, kinds, image_names, labels, transforms=None):
        super().__init__()
        self.kinds = kinds
        self.image_names = image_names
        self.labels = labels
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Load sample ``index`` and return ``(image, label)``.

        The image is read with OpenCV, converted from BGR to RGB, cast to
        float32 and scaled by 1/255 before the optional transforms run.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        kind, image_name, label = self.kinds[index], self.image_names[index], self.labels[index]
        image_path = f'{TRAIN_DATA_ROOT_PATH}/{kind}/{image_name}'.replace('\\', '/')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing or unreadable file by returning None;
            # raise here rather than letting cvtColor fail with an opaque error.
            raise FileNotFoundError(f'cannot read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        if self.transforms:
            image = self.transforms(image=image)['image']

        return image, label

    def __len__(self) -> int:
        # len() works for NumPy arrays and plain sequences alike.
        return len(self.image_names)

    def get_labels(self):
        """Return all labels as a list (used by the class-balancing sampler)."""
        return list(self.labels)

# Simple Augmentations: Flips

In [None]:
def get_train_transforms():
    """Return the training pipeline: horizontal and vertical flips (p=0.5 each), then ToTensorV2."""
    steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, p=1.0)

def get_valid_transforms():
    """Return the evaluation pipeline: ToTensorV2 only, with no augmentation."""
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps, p=1.0)

In [None]:
# Rows whose fold equals args['val_fold'] are held out for validation; all
# other rows are used for training. Each subset gets its own transform pipeline.
train_dataset, validation_dataset = (
    DatasetRetriever(
        kinds=subset.kind.values,
        image_names=subset.image_name.values,
        labels=subset.label.values,
        transforms=make_transforms(),
    )
    for subset, make_transforms in (
        (dataset[dataset['fold'] != args['val_fold']], get_train_transforms),
        (dataset[dataset['fold'] == args['val_fold']], get_valid_transforms),
    )
)

In [None]:
# Sanity check: fetch the first training sample and display it.
image, target = train_dataset[0]
# Move axis 0 to the end before plotting (assumes the transform pipeline
# returned a channel-first tensor — TODO confirm).
numpy_image = image.permute(1,2,0).cpu().numpy()

fig, ax = plt.subplots(1, 1, figsize=(16, 8))

# Hide the axes so only the image is shown.
ax.set_axis_off()
ax.imshow(numpy_image)

In [None]:
# Visualise where a cover image and the stego image of the same name differ.
cover_image = cv2.imread(os.path.join(TRAIN_DATA_ROOT_PATH, 'cover', '00001.png'), cv2.IMREAD_COLOR)
stego_image = cv2.imread(os.path.join(TRAIN_DATA_ROOT_PATH, 'stego', '00001.png'), cv2.IMREAD_COLOR)
if cover_image is None or stego_image is None:
    # cv2.imread returns None instead of raising when a file cannot be read.
    raise FileNotFoundError(f'cannot read 00001.png from cover/ and stego/ under {TRAIN_DATA_ROOT_PATH}')
cover_image = cv2.cvtColor(cover_image, cv2.COLOR_BGR2RGB)
stego_image = cv2.cvtColor(stego_image, cv2.COLOR_BGR2RGB)

# Subtract in a signed type: on unsigned 8-bit pixels a negative difference
# wraps around (e.g. 5 - 6 -> 255) and np.abs cannot undo that, which would
# paint tiny changes as maximal ones.
difference = np.abs(cover_image.astype(np.int16) - stego_image.astype(np.int16)).astype(np.uint8)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))

ax.set_axis_off()
im = ax.imshow(difference, vmin=0, vmax=255)
fig.colorbar(im)
plt.show()

In [None]:
# Training batches: BalanceClassSampler in "downsampling" mode chooses the
# samples, and an incomplete final batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args['batch_size'],
    sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode="downsampling"),
    drop_last=True,
    pin_memory=False,
)

# Validation batches: every sample, visited in order.
val_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=args['batch_size'],
    sampler=SequentialSampler(validation_dataset),
    shuffle=False,
    pin_memory=False,
)

# Metrics

In [None]:
class AverageMeter:
    """Running weighted mean that also remembers the most recent value.

    Attributes are ``val`` (last value), ``sum`` (weighted total),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, label, topk=(1,)):
    """Compute top-k accuracy for a batch, with k = ``max(topk)``.

    Only the largest entry of ``topk`` is used; smaller entries do not
    produce separate results.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        label: 1-D tensor of target class indices, one per sample.
        topk: tuple of k values; its maximum selects how many of the
            highest-scoring classes count as a hit.

    Returns:
        ``(acc, result)`` where ``acc`` is a one-element tensor holding the
        percentage of samples whose label is among the top-k scores, and
        ``result`` is a flat NumPy boolean array of length ``k * batch``
        (all rank-1 comparisons first, then rank-2, ...).
    """
    maxk = max(topk)
    batch_size = label.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # indices of the k highest scores per sample
    pred = pred.t()
    correct = pred.eq(label.view(1, -1).expand_as(pred))
    # reshape rather than view: for k > 1 the comparison result can keep the
    # non-contiguous layout of the transposed predictions, and view(-1)
    # raises on a non-contiguous tensor.
    hits = correct.reshape(-1)
    acc = hits.float().sum(0, keepdim=True).mul_(100.0 / batch_size)
    result = to_np(hits)

    return acc, result


def to_np(x):
    """Detach tensor ``x`` from the autograd graph and return it as a NumPy array on the CPU."""
    return x.detach().cpu().numpy()

# Loss

In [None]:
# Cross-entropy loss applied to the model output, moved to the selected device.
loss_func = nn.CrossEntropyLoss().to(device)

# Model

In [None]:
# Create the checkpoint directory from the configured path; exist_ok=True makes
# re-running this cell harmless when the directory is already there.
os.makedirs(args['checkpoint_dir'], exist_ok=True)
def get_net():
    """Build the timm model described by the global ``args``.

    The architecture name, number of output classes and dropout rate all come
    from ``args``; pretrained weights are requested.

    Returns:
        The created model (not yet moved to any device).
    """
    net = timm.create_model(
        args['model_name'],
        pretrained=True,
        num_classes=args['num_classes'],
        # Without this the configured dropout ratio was never applied to the model.
        drop_rate=args['drop_rate'],
    )
    return net


model = get_net().to(device)

# Save & Load pretrained model

In [None]:
def save_checkpoint(model, optimizer, best_acc, checkpoint_path, model_name, is_best, epoch):
    """Write a training checkpoint to ``checkpoint_path``.

    The file ``<model_name>_model_epoch_<epoch>.pth`` stores the model and
    optimizer state dicts together with ``best_acc`` and ``epoch``. When
    ``is_best`` is true the file is also copied to
    ``<model_name>_model_best.pth`` in the same directory.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        best_acc: best accuracy so far, stored as-is.
        checkpoint_path: target directory; created if missing.
        model_name: prefix for the checkpoint file names.
        is_best: whether to also refresh the "best" copy.
        epoch: epoch number stored in the file and used in its name.
    """
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_acc': best_acc,
        'epoch': epoch
    }

    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(checkpoint_path, exist_ok=True)

    filename = os.path.join(checkpoint_path, f'{model_name}_model_epoch_{epoch}.pth')
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(checkpoint_path, f'{model_name}_model_best.pth'))


def load_checkpoint(model, optimizer, pretrained_path, device):
    """Restore a checkpoint written by ``save_checkpoint``.

    Args:
        model: module that receives the saved ``'model'`` state dict.
        optimizer: optimizer that receives the saved ``'optimizer'`` state
            dict when the checkpoint contains one; may be ``None`` to skip it.
        pretrained_path: path of the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(model, optimizer, best_acc, epoch)`` with the values read from the file.
    """
    state = torch.load(pretrained_path, map_location=device)
    model.load_state_dict(state['model'])
    # The optimizer state is saved with every checkpoint; restore it as well so
    # resumed training continues with the same optimizer statistics.
    if optimizer is not None and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    best_acc = state['best_acc']
    epoch = state['epoch']
    print(f'\t## loaded trained models (epoch: {epoch})\n')
    return model, optimizer, best_acc, epoch

# Optimizer

In [None]:
# AdamW over all model parameters; learning rate and weight decay come from args.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=args['lr'],
    weight_decay=args['weight_decay'],
    betas=(0.9, 0.999),
)

# Load pretrained model

In [None]:
# Resuming from a saved checkpoint is currently disabled: the block below is
# commented out, so training always starts at epoch 0 with best_acc 0.
# if args['pretrained_name']:
#     pretrained_path = os.path.join(args['checkpoint_dir'], args['pretrained_name'])
#     model, optimizer, best_acc, initial_epoch = load_checkpoint(model, optimizer, pretrained_path, device)
# else:
#     initial_epoch = 0
#     best_acc = 0
initial_epoch = 0
best_acc = 0

# Train

In [None]:
def train(train_loader, model, *args):
    """Run one training pass and return ``(accuracy, loss)`` from ``iteration``.

    Puts ``model`` in train mode and forwards the remaining positional
    arguments to ``iteration('train', ...)`` with gradients enabled.
    """
    model.train()  # enable training-time behaviour of the model's layers

    with torch.enable_grad():
        return iteration('train', train_loader, model, *args)

# Validate

In [None]:
def validate(val_loader, model, *args):
    """Run one evaluation pass and return ``(accuracy, loss)`` from ``iteration``.

    Puts ``model`` in eval mode and forwards the remaining positional
    arguments to ``iteration('val', ...)`` with gradient tracking disabled.
    """
    model.eval()  # switch the model's layers to inference behaviour

    with torch.no_grad():
        return iteration('val', val_loader, model, *args)

# Iteration

In [None]:
def iteration(mode, data_loader, model, optimizer, loss_func, epoch):
    """Run one pass over ``data_loader`` and return ``(mean accuracy, mean loss)``.

    Reads the module-level ``device`` and ``args`` (``log_step``,
    ``num_epochs``). Parameters are updated only when ``mode == 'train'``.

    Args:
        mode: ``'train'`` to also back-propagate and step the optimizer;
            any other value only evaluates.
        data_loader: iterable yielding ``(input_img, target)`` batches.
        model: network called on each input batch.
        optimizer: optimizer stepped in train mode.
        loss_func: callable ``loss_func(output, target)``.
        epoch: zero-based epoch index, used only for the progress line.

    Returns:
        Sample-weighted averages of the per-batch accuracy and loss.
    """
    meters = {name: AverageMeter() for name in ('batch_time', 'data_time', 'loss', 'acc')}
    is_training = mode == 'train'
    total_batches = len(data_loader)

    tick = time.time()
    for step, (input_img, target) in enumerate(data_loader, start=1):
        # time spent waiting for this batch to be produced
        meters['data_time'].update(time.time() - tick)

        input_img, target = input_img.to(device), target.to(device)
        samples = input_img.size(0)

        # forward pass; NaN/inf entries of the output are replaced before the loss
        output = torch.nan_to_num(model(input_img))
        loss = loss_func(output, target)
        meters['loss'].update(loss.item(), samples)

        # batch accuracy from the softmax probabilities
        batch_acc, _ = accuracy(F.softmax(output, dim=1), target)
        meters['acc'].update(batch_acc.item(), samples)

        # backward pass and parameter update, training only
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # total time for this batch, including data loading
        meters['batch_time'].update(time.time() - tick)
        tick = time.time()

        if step % args['log_step'] == 0:
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f}) \t'
                  'Accuracy {acc.val:.4f} ({acc.avg:.4f})'
                  .format(epoch + 1, args['num_epochs'], step, total_batches, **meters))

    return meters['acc'].avg, meters['loss'].avg

# Pipeline: train, validate, log, and save

In [None]:
# Lowest validation loss seen so far; infinity so that any real loss replaces it.
min_loss = float('inf')
for epoch in range(initial_epoch, args['num_epochs']):
    # train for one epoch
    print('# Training')
    train_acc, train_loss = train(train_loader, model, optimizer, loss_func, epoch)

    # evaluate on validation set
    print('# Validation')
    val_acc, val_loss = validate(val_loader, model, optimizer, loss_func, epoch)

    # Log everything for the epoch in one call: each wandb.log call advances
    # the run's step by default, so separate train/val calls would place the
    # two sets of metrics on different steps.
    wandb.log({
        'train_acc': train_acc,
        'train_loss': train_loss,
        'val_acc': val_acc,
        'val_loss': val_loss,
    })

    is_best = val_acc > best_acc
    best_acc = max(val_acc, best_acc)
    min_loss = min(val_loss, min_loss)

    # Save on a new best validation accuracy and on every model_save_step-th epoch.
    if is_best or (epoch + 1) % args['model_save_step'] == 0:
        save_checkpoint(model, optimizer, best_acc, args['checkpoint_dir'], args['model_name'], is_best, epoch + 1)

# Inference

In [None]:
# Reload the checkpoint named by args['pretrained_name'] and switch the model
# to evaluation mode for inference.
pretrained_path = os.path.join(args['checkpoint_dir'], args['pretrained_name'])
# Only the model is needed here; the other returned values are discarded.
model, _, _, _ = load_checkpoint(model, optimizer, pretrained_path, device)
model.eval()

In [None]:
class DatasetSubmissionRetriever(Dataset):
    """Unlabelled images read from ``TEST_DATA_ROOT_PATH/<image_name>``.

    Args:
        image_names: sequence of file names inside the test directory.
        transforms: optional callable invoked as ``transforms(image=...)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_names, transforms=None):
        super().__init__()
        self.image_names = image_names
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Load image ``index`` and return ``(image_name, image)``.

        The image is read with OpenCV, converted from BGR to RGB, cast to
        float32 and scaled by 1/255 before the optional transforms run.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_name = self.image_names[index]
        image_path = f'{TEST_DATA_ROOT_PATH}/{image_name}'
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing or unreadable file by returning None;
            # raise here rather than letting cvtColor fail with an opaque error.
            raise FileNotFoundError(f'cannot read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        if self.transforms:
            image = self.transforms(image=image)['image']

        return image_name, image

    def __len__(self) -> int:
        # len() works for NumPy arrays and plain sequences alike.
        return len(self.image_names)

In [None]:
# List the test images from the same directory the dataset class reads from
# (TEST_DATA_ROOT_PATH) instead of a path relative to the current directory,
# and sort the names so the submission order is deterministic.
# NOTE: this rebinds `dataset`, which previously held the training DataFrame.
dataset = DatasetSubmissionRetriever(
    image_names=np.array(sorted(
        path.replace('\\', '/').split('/')[-1]
        for path in glob(os.path.join(TEST_DATA_ROOT_PATH, '*.png'))
    )),
    transforms=get_valid_transforms(),
)


# Every test image is visited exactly once, in order.
test_loader = DataLoader(
    dataset,
    batch_size=args['batch_size'],
    shuffle=False,
    drop_last=False,
)

In [None]:
# Predict a label for every test image and collect the results in one table.
pred_ids, pred_labels = [], []

# Inference needs no gradients; no_grad avoids building the autograd graph
# and the memory that goes with it.
with torch.no_grad():
    for step, (image_names, images) in enumerate(test_loader):
        print(step, end='\r')
        images = images.to(device)
        output = model(images)
        _, class_pred = output.topk(1, 1, True, True)  # index of the highest score per image
        pred_ids.extend(image_names)
        pred_labels.extend(class_pred.view(-1).cpu().tolist())

# Build the frame once rather than concatenating inside the loop.
df_result = pd.DataFrame({'id': pred_ids, 'label': pred_labels}, columns=['id', 'label'])

In [None]:
# Write the predictions to a CSV whose name carries the current local time,
# then preview the first rows.
current_time = datetime.now().strftime(r'%y-%m-%d_%H-%M-%S')
df_result.to_csv(f'submission_{current_time}.csv', index=False)
df_result.head()