In [None]:
import os
import numpy as np
import pandas as pd
import random
import math
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow.keras.backend as K
from sklearn.model_selection import GroupKFold
import matplotlib.pyplot as plt
import gc
import sys
sys.path.append('../input/swin-transformer-tf-fixed')
from swintransformer import SwinTransformer


def seed_everything(seed):
    """Seed Python, NumPy and TensorFlow RNGs and set determinism-related env vars.

    Args:
        seed: integer seed applied to every library.

    Note:
        PYTHONHASHSEED only affects interpreters started after it is set
        (e.g. subprocesses); it cannot change str hashing in the running kernel.
    """
    random.seed(seed)         # Python built-in RNG
    np.random.seed(seed)      # NumPy legacy global RNG (used by np.random.choice below)
    tf.random.set_seed(seed)  # TensorFlow global seed
    os.environ['PYTHONHASHSEED'] = str(seed)  # for hash-based ops in child processes
    # TF reads this flag as a boolean ("1"/"true"); writing the seed value
    # (e.g. "41") is not a valid boolean, so determinism was never switched on.
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'


seed_everything(41)

In [None]:
# Photometric augmentation ranges (consumed by build_augmenter).
SATURATION = (0.9, 1.1)   # (lower, upper) saturation factor
CONTRAST = (0.9, 1.1)     # (lower, upper) contrast factor
BRIGHTNESS = 0.1          # max absolute brightness delta

# Geometric augmentation scales (consumed by transform / transform_test);
# each parameter there is a standard-normal draw scaled as noted.
ROTATION = 10.0   # rotation std dev, degrees
SHEAR = 2.0       # shear std dev, degrees
HZOOM = 8.0       # height zoom = 1 + N(0, 1) / HZOOM
WZOOM = 4.0       # width zoom  = 1 + N(0, 1) / WZOOM
HSHIFT = 4.0      # height shift std dev, pixels
WSHIFT = 4.0      # width shift std dev, pixels

def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Compose the 3x3 affine matrix that maps destination pixel indices to source indices.

    Args:
        rotation: rotation angle in degrees.
        shear: shear angle in degrees.
        height_zoom, width_zoom: zoom factors; their reciprocals go on the diagonal.
        height_shift, width_shift: translations placed in the last column.

    transform / transform_test in this notebook pass float32 tensors of shape [1]
    for every argument.

    Returns:
        3x3 tensor equal to (rotation @ shear) @ (zoom @ shift).
    """
    # Degrees -> radians.
    rotation = math.pi * rotation / 180.
    shear = math.pi * shear / 180.

    def as_3x3(entries):
        # Nine row-major entries -> 3x3 matrix.
        return tf.reshape(tf.concat([entries], axis=0), [3, 3])

    cos_r = tf.math.cos(rotation)
    sin_r = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    rot_m = as_3x3([cos_r, sin_r, zero,
                    -sin_r, cos_r, zero,
                    zero, zero, one])

    cos_s = tf.math.cos(shear)
    sin_s = tf.math.sin(shear)
    shear_m = as_3x3([one, sin_s, zero,
                      zero, cos_s, zero,
                      zero, zero, one])

    zoom_m = as_3x3([one / height_zoom, zero, zero,
                     zero, one / width_zoom, zero,
                     zero, zero, one])

    shift_m = as_3x3([one, zero, height_shift,
                      zero, one, width_shift,
                      zero, zero, one])

    rot_shear = K.dot(rot_m, shear_m)
    zoom_shift = K.dot(zoom_m, shift_m)
    return K.dot(rot_shear, zoom_shift)


def _random_affine(image, dim=384):
    """Apply one random rotation/shear/zoom/shift to a single square RGB image.

    Args:
        image: one [dim, dim, 3] image (not a batch).
        dim: side length of the square image.

    Returns:
        [dim, dim, 3] tensor. Source coordinates are clipped to the image, so
        pixels that map outside it repeat the border instead of being zero-filled.
    """
    xdim = dim % 2

    # Draw the affine parameters (same draw order as before the refactor).
    rot = ROTATION * tf.random.normal([1], dtype='float32')
    shr = SHEAR * tf.random.normal([1], dtype='float32')
    h_zoom = 1.0 + tf.random.normal([1], dtype='float32') / HZOOM
    w_zoom = 1.0 + tf.random.normal([1], dtype='float32') / WZOOM
    h_shift = HSHIFT * tf.random.normal([1], dtype='float32')
    w_shift = WSHIFT * tf.random.normal([1], dtype='float32')

    m = get_mat(rot, shr, h_zoom, w_zoom, h_shift, w_shift)

    # Centred destination coordinates: x counts down from dim//2 (top row),
    # y counts up from -dim//2 (left column); z = 1 makes them homogeneous.
    x = tf.repeat(tf.range(dim // 2, -dim // 2, -1), dim)
    y = tf.tile(tf.range(-dim // 2, dim // 2), [dim])
    z = tf.ones([dim * dim], dtype='int32')
    idx = tf.stack([x, y, z])

    # Map destination pixels back onto source pixels and clamp into the image.
    idx2 = K.dot(m, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    idx2 = K.clip(idx2, -dim // 2 + xdim + 1, dim // 2)

    # Centred coordinates -> (row, col) indices, then gather the source pixels.
    idx3 = tf.stack([dim // 2 - idx2[0, ], dim // 2 - 1 + idx2[1, ]])
    d = tf.gather_nd(image, tf.transpose(idx3))

    return tf.reshape(d, [dim, dim, 3])


def transform(image, label):
    """tf.data map fn: random affine augmentation of one 384x384x3 image.

    The label is passed through unchanged.
    """
    return _random_affine(image), label


def transform_test(image):
    """Label-free variant of ``transform`` (used for test-time augmentation)."""
    return _random_affine(image)

In [None]:
def auto_select_accelerator():
    """Return a TPU distribution strategy when TPU setup succeeds, else the default one.

    On success this connects to the TPU cluster and initializes the TPU system.
    A ValueError anywhere in that setup (e.g. no TPU attached) falls back to
    ``tf.distribute.get_strategy()``. Prints the selected accelerator and the
    number of replicas.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        chosen = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        # No usable TPU: use the default (current) strategy.
        chosen = tf.distribute.get_strategy()

    replicas = chosen.num_replicas_in_sync
    print(f"Running on {replicas} replicas")
    return chosen


def build_decoder(with_labels=True, target_size=(256, 256), ext='jpg'):
    """Return a tf.data map function that reads, decodes and resizes one image file.

    Args:
        with_labels: if True the returned fn maps (path, label) -> (image, label);
            otherwise it maps path -> image.
        target_size: (height, width) passed to tf.image.resize.
        ext: image format: 'png', 'jpg' or 'jpeg' (case-insensitive).

    Returns:
        The decode function. Images are decoded with 3 channels, cast to
        float32, divided by 255 and then resized.

    Raises:
        ValueError: if ``ext`` is not supported. This is checked here, when the
            decoder is built, instead of only when tf.data traces ``decode``.
    """
    ext_key = ext.lower() if isinstance(ext, str) else ext
    if ext_key not in ('png', 'jpg', 'jpeg'):
        raise ValueError("Image extension not supported")

    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext_key == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        else:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)
        return img

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True):
    """Return a tf.data map function applying light photometric augmentation.

    The returned function applies, in order: a random left-right flip, a random
    saturation factor in SATURATION, a random contrast factor in CONTRAST and a
    random brightness shift of at most BRIGHTNESS. With ``with_labels=True`` it
    maps (img, label) -> (augmented img, label); otherwise img -> augmented img.
    """
    def augment(img):
        # tf.image.random_flip_up_down is not applied.
        flipped = tf.image.random_flip_left_right(img)
        saturated = tf.image.random_saturation(flipped, SATURATION[0], SATURATION[1])
        contrasted = tf.image.random_contrast(saturated, CONTRAST[0], CONTRAST[1])
        return tf.image.random_brightness(contrasted, BRIGHTNESS)

    def augment_with_labels(img, label):
        return augment(img), label

    if with_labels:
        return augment_with_labels
    return augment


def build_dataset(paths, labels=None, bsize=128, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024,
                  seed=None, cache_dir="", transform_fn=None,
                  drop_remainder=None):
    """Build a batched tf.data pipeline.

    Stage order: decode -> cache -> augment -> transform -> repeat -> shuffle
    -> batch -> prefetch.

    Args:
        paths: image file paths, one per example.
        labels: optional per-example labels; when given, elements are (image, label).
        bsize: batch size.
        cache: cache decoded images (in memory when ``cache_dir == ""``, else on
            disk under ``cache_dir``).
        decode_fn: per-element decode fn; defaults to ``build_decoder`` matched to
            whether labels are present.
        augment_fn: photometric augmentation fn; defaults to ``build_augmenter``
            matched to whether labels are present.
        augment: apply ``augment_fn`` and then ``transform_fn``.
        repeat: repeat the dataset indefinitely.
        shuffle: shuffle buffer size; a falsy value disables shuffling.
        seed: shuffle seed.
        cache_dir: cache location ("" keeps the cache in memory).
        transform_fn: geometric augmentation applied after ``augment_fn``;
            defaults to ``transform``.
        drop_remainder: drop the last partial batch. Defaults to ``repeat``.
            A finite (non-repeating) pass keeps every sample, so validation
            and prediction do not silently lose up to ``bsize - 1`` rows.

    Returns:
        A prefetched ``tf.data.Dataset`` of batches.
    """
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    with_labels = labels is not None
    if decode_fn is None:
        decode_fn = build_decoder(with_labels)
    if augment_fn is None:
        augment_fn = build_augmenter(with_labels)
    if transform_fn is None:
        transform_fn = transform
    if drop_remainder is None:
        drop_remainder = bool(repeat)

    AUTO = tf.data.experimental.AUTOTUNE
    slices = paths if labels is None else (paths, labels)

    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    dset = dset.cache(cache_dir) if cache else dset

    # Augmentations run after the cache so each epoch sees fresh random draws.
    dset = dset.map(augment_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.map(transform_fn, num_parallel_calls=AUTO) if augment else dset
    dset = dset.repeat() if repeat else dset
    dset = dset.shuffle(shuffle, seed=seed) if shuffle else dset
    dset = dset.batch(bsize, drop_remainder=drop_remainder).prefetch(AUTO)

    return dset


def build_dataset_test(paths, labels=None, bsize=128, cache=True,
                       decode_fn=None, augment_fn=None,
                       augment=True, repeat=True, shuffle=1024,
                       seed=None, cache_dir="", drop_remainder=None):
    """Same pipeline as ``build_dataset`` but with the label-free ``transform_test``.

    Intended for unlabeled / test-time-augmentation datasets. All arguments
    have the same meaning as in ``build_dataset``.
    """
    return build_dataset(
        paths, labels=labels, bsize=bsize, cache=cache,
        decode_fn=decode_fn, augment_fn=augment_fn,
        augment=augment, repeat=repeat, shuffle=shuffle,
        seed=seed, cache_dir=cache_dir, transform_fn=transform_test,
        drop_remainder=drop_remainder,
    )

In [None]:
COMPETITION_NAME = "siim-384x384-study-png"   # study-level PNGs (read from '/train/')
COMPETITION_NAME2 = "crop384"                 # cropped PNGs (read from '/images/')
BATCH_SIZE = 8 * 16                           # global batch size (128)

# Resolve the GCS path of each Kaggle dataset, in the order listed above.
GCS_DS_PATH, GCS_DS_PATH2 = (
    KaggleDatasets().get_gcs_path(dataset_name)
    for dataset_name in (COMPETITION_NAME, COMPETITION_NAME2)
)

In [None]:
# Study-level labels: the four class columns follow the id column.
df = pd.read_csv('../input/siim-covid19-detection/train_study_level.csv')
label_cols = df.columns[1:5]
# Index of the largest label column (presumably the one-hot positive class).
df["class"] = df[label_cols].values.argmax(axis=1)

# Class weights = 1 / log10(class count), keyed by class index 0..3.
class_counts = df["class"].value_counts().sort_index()
weight = 1 / np.log10(class_counts.values)
class_weight = dict((cls, weight[cls]) for cls in range(4))
print(class_weight)

In [None]:
# crop data
# Map every DICOM image id to its study id ("<study>_study"). Assumes the
# train/<study>/<series>/<image>.dcm layout, so the study folder is the
# second-to-last component of each directory path.
imageid2studyid = dict()
for dirname, _, filenames in os.walk('/kaggle/input/siim-covid19-detection/train'):
    study_id = dirname.split('/')[-2] + "_study"
    for file in filenames:
        imageid2studyid[file.replace('.dcm', '')] = study_id

# One row per cropped PNG ("<image_id>_crop.png"), tagged with its study id.
crop_ids = os.listdir("/kaggle/input/crop384/images")
crop_study_ids = [imageid2studyid[crop_id.replace("_crop.png", "")] for crop_id in crop_ids]
crop_data = pd.DataFrame({
    "id": crop_study_ids,
    "image_name": crop_ids,
})
# Attach the study-level columns of df. The concat cell later adds "is_crop"
# and "image_name" to df. Dropping them here keeps this cell re-runnable:
# otherwise a re-run after that cell yields image_name_x / image_name_y.
crop_data = pd.merge(
    crop_data,
    df.drop(columns=["is_crop", "image_name"], errors="ignore"),
    on="id",
    how="left",
)

In [None]:
# concat data
# Stack the full-study rows and the crop rows into one table, tagging each row
# with its source (is_crop) and its PNG file name.
df = df.assign(is_crop=0, image_name=df["id"] + ".png")
crop_data = crop_data.assign(is_crop=1)
total_data = (
    pd.concat([df, crop_data], axis=0)
    .reset_index(drop=True)
)
total_data

In [None]:
# Group-aware CV: rows are grouped by study id, so a study's full image and
# its crops always land in the same fold.
folds = 5
gkf = GroupKFold(n_splits=folds)
total_data['fold'] = -1
fold_splits = gkf.split(total_data, groups=total_data.id.tolist())
for fold_no, (_, holdout_rows) in enumerate(fold_splits):
    total_data.loc[holdout_rows, 'fold'] = fold_no

# Placeholder columns for out-of-fold predictions (plain and augmented views).
pred_cols_ori = [f"{col}_pred_ori" for col in label_cols]
pred_cols_aug = [f"{col}_pred_aug" for col in label_cols]
for col in [*pred_cols_ori, *pred_cols_aug]:
    total_data[col] = np.nan

total_data

In [None]:
historys = []


def _fold_split(fold, train):
    """Return (paths, labels) for one side of a cross-validation fold.

    Args:
        fold: fold number to hold out.
        train: True -> every row whose fold differs from ``fold``;
            False -> the rows of ``fold`` itself.

    Returns:
        paths: pandas Series of GCS image paths. Full-study PNGs (is_crop == 0)
            come first, then the crop PNGs (is_crop == 1).
        labels: array of the ``label_cols`` values, in the same row order.
    """
    in_split = (total_data['fold'] != fold) if train else (total_data['fold'] == fold)
    full_rows = total_data[in_split & (total_data['is_crop'] == 0)]
    crop_rows = total_data[in_split & (total_data['is_crop'] == 1)]
    paths = pd.concat([
        GCS_DS_PATH + '/train/' + full_rows['image_name'],
        GCS_DS_PATH2 + '/images/' + crop_rows['image_name'],
    ])
    labels = np.concatenate(
        [full_rows[label_cols].values, crop_rows[label_cols].values], axis=0)
    return paths, labels


for i in range(folds):
    # auto_select_accelerator() re-initializes the TPU system (when present) every fold.
    strategy = auto_select_accelerator()

    valid_paths, valid_labels = _fold_split(i, train=False)
    train_paths, train_labels = _fold_split(i, train=True)

    # Shuffle paths and labels with one shared permutation (seeded NumPy global RNG).
    shuffle_idx = np.random.choice(len(train_paths), size=len(train_paths), replace=False)
    train_paths = train_paths.iloc[shuffle_idx]
    train_labels = train_labels[shuffle_idx]
    print(len(train_paths), len(valid_paths))

    decoder = build_decoder(with_labels=True, target_size=(384, 384), ext='png')
    test_decoder = build_decoder(with_labels=False, target_size=(384, 384), ext='png')

    train_dataset = build_dataset(
        train_paths, train_labels, bsize=BATCH_SIZE, decode_fn=decoder
    )
    valid_dataset = build_dataset(
        valid_paths, valid_labels, bsize=BATCH_SIZE, decode_fn=decoder,
        repeat=False, shuffle=False, augment=False
    )

    # NOTE(review): the plain and 3x augmented views of the validation fold are
    # built but never consumed before being deleted below, so the *_pred_ori /
    # *_pred_aug columns of total_data stay NaN.
    test_dataset = build_dataset_test(
        valid_paths, bsize=BATCH_SIZE, decode_fn=test_decoder,
        repeat=False, shuffle=False, augment=False
    )
    test_datasets = [
        build_dataset_test(
            valid_paths, bsize=BATCH_SIZE, decode_fn=test_decoder,
            repeat=False, shuffle=False, augment=True
        )
        for _ in range(3)
    ]

    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=[384, 384, 3]),
            SwinTransformer('swin_large_384', num_classes=4, include_top=False, pretrained=True, use_tpu=True),
            tf.keras.layers.Dropout(rate=0.1),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dropout(rate=0.1),
            tf.keras.layers.Dense(4, activation='softmax')
        ])
        # Compile under the same strategy scope the model was built in, so the
        # optimizer and metric variables are created by that strategy as well.
        model.compile(
            optimizer=tf.keras.optimizers.Adam(1e-4),
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=[tf.keras.metrics.AUC(multi_label=True), "acc"])

    model.summary()

    steps_per_epoch = train_paths.shape[0] // BATCH_SIZE
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        f'model{i}.h5', save_best_only=True, monitor='val_loss', mode='min', save_weights_only=True)
    lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", patience=3, min_lr=1e-6, mode='min')

    history = model.fit(
        train_dataset,
        epochs=15,
        verbose=1,
        callbacks=[checkpoint, lr_reducer],
        steps_per_epoch=steps_per_epoch,
        validation_data=valid_dataset,
        class_weight=class_weight,
    )

    historys.append(history.history)

    del decoder, test_decoder, train_dataset, valid_dataset, test_dataset, test_datasets, model
    gc.collect()

In [None]:
def get_first_occur(learning_rates_list):
    """Return the 0-based epochs just before each lower learning rate first appears.

    Every distinct learning rate except the largest is considered. The result
    is ordered by ascending learning rate, and each entry is the index of the
    rate's first occurrence minus one. For a non-increasing schedule, that is
    the epoch at whose end the rate was reduced.

    Args:
        learning_rates_list: per-epoch learning rates (list or 1-D array).

    Returns:
        Integer NumPy array. It is empty (but still integer-typed, so it can be
        used as an index) when the rate never changes.
    """
    rates = np.asarray(learning_rates_list)
    # np.unique sorts ascending and, with return_index, gives first-occurrence indices.
    _, first_idx = np.unique(rates, return_index=True)
    return first_idx[:-1].astype(int) - 1


def _plot_train_valid(axis, train_vals, valid_vals, min_idx, first_occur,
                      train_label, valid_label, title):
    """Draw one train-vs-valid panel.

    The x axis holds 1-based epochs. A segment joins both curves at ``min_idx``
    (the best-val-loss epoch, 1-based). Red stars mark the 0-based epochs in
    ``first_occur`` (the epochs just before each learning-rate drop).
    """
    train_vals = np.asarray(train_vals)
    valid_vals = np.asarray(valid_vals)
    epochs = range(1, len(train_vals) + 1)
    axis.plot(epochs, train_vals, "bo-", label=train_label)
    axis.plot(epochs, valid_vals, "go-", label=valid_label)
    axis.plot([min_idx, min_idx], [valid_vals[min_idx - 1], train_vals[min_idx - 1]])
    axis.scatter(first_occur + 1, train_vals[first_occur], s=300, c="r", marker="*")
    axis.scatter(first_occur + 1, valid_vals[first_occur], s=300, c="r", marker="*")
    axis.legend()
    axis.grid()
    axis.set_title(title)


record = np.zeros((folds, 4))
# squeeze=False keeps ax 2-D even when folds == 1.
fig, ax = plt.subplots(folds, 3, figsize=(3*8, folds*6), squeeze=False)
for idx, history in enumerate(historys):
    # The AUC key may carry a per-model suffix (auc, auc_1, ...); look it up
    # rather than deriving it from the fold index.
    name1 = next(key for key in history if key.startswith("auc"))
    name2 = "val_" + name1
    # ReduceLROnPlateau logs "lr" in tf.keras; newer Keras logs "learning_rate".
    lr_key = "lr" if "lr" in history else "learning_rate"
    # Force integer dtype: an empty float array (no LR drops) cannot be used as an index.
    first_occur = np.asarray(get_first_occur(history[lr_key]), dtype=int)

    best_epoch = int(np.argmin(history["val_loss"]))  # 0-based epoch of min val_loss
    min_idx = best_epoch + 1
    best_train_loss = np.array(history["loss"])[best_epoch]
    best_valid_loss = np.min(history["val_loss"])
    best_train_auc = np.array(history[name1])[best_epoch]
    best_valid_auc = np.array(history[name2])[best_epoch]

    _plot_train_valid(
        ax[idx][0], history["loss"], history["val_loss"], min_idx, first_occur,
        "train_loss", "valid_loss",
        f"{idx} Best Train Loss: {np.round(best_train_loss, 4)}, Best Valid Loss: {np.round(best_valid_loss, 4)}")
    _plot_train_valid(
        ax[idx][1], history[name1], history[name2], min_idx, first_occur,
        "train_auc", "valid_auc",
        f"{idx} Best Train AUC: {np.round(best_train_auc, 3)}, Best Valid AUC: {np.round(best_valid_auc, 3)}")
    record[idx] = [best_train_loss, best_valid_loss, best_train_auc, best_valid_auc]

    lr_axis = ax[idx][2]
    lr_axis.plot(range(1, len(history[lr_key])+1), history[lr_key], "ro-", label="learning_rate")
    lr_axis.legend()
    lr_axis.grid()
    min_lr = np.min(history[lr_key])
    lr_axis.set_title(f"{idx} Min LR: {np.round(min_lr, 6)}")

fig.tight_layout()

# Average only over folds that actually produced a history.
avg_record = np.mean(record[:len(historys)], axis=0)
print("Avg_Best_Train_Loss: ", np.round(avg_record[0], 4))
print("Avg_Best_Valid_Loss: ", np.round(avg_record[1], 4))
print("Avg_Best_Train_AUC: ", np.round(avg_record[2], 3))
print("Avg_Best_Valid_AUC: ", np.round(avg_record[3], 3))