In [None]:
import os
import re

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import KFold
from tensorflow import keras
from tensorflow.keras import layers
AUTO = tf.data.experimental.AUTOTUNE  # let tf.data pick parallelism / prefetch buffer sizes
folds = 5  # number of CV folds; not referenced by the cells shown here

image_size = 128  # We'll resize input images to this size
input_shape = (image_size, image_size, 3)

# Predictions on this dataset should be submitted for the competition.
# gfile.glob does not guarantee an order, so sort for a reproducible file order across runs.
TEST_FILENAMES = sorted(tf.io.gfile.glob('../input/ranzcr-clip-catheter-line-classification/test_tfrecords/*.tfrec'))
# Fail fast on a wrong data path instead of erroring much later with an empty dataset.
assert TEST_FILENAMES, 'no test TFRecords found - check the competition data path'
print(TEST_FILENAMES)

batch_size = 4

In [None]:
import sys
# Install `validators` from a local input dataset (offline kernel); presumably
# required by vit_keras at import time — confirm against the vitkeras dataset.
!pip install ../input/validators
# Make the vit_keras package under ../input/vitkeras/ importable without pip-installing it.
package_path = '../input/vitkeras/'
sys.path.append(package_path)

# Online alternative (needs internet access): install vit-keras from PyPI instead.
#!pip install vit-keras
# NOTE(review): `utils` is not used by the cells shown here.
from vit_keras import vit, utils


In [None]:

# Parsing spec for the unlabeled (test) TFRecords: every example holds an encoded
# image and its StudyInstanceUID, each stored as a scalar byte string.
UNLABELED_TFREC_FORMAT = dict(
    image=tf.io.FixedLenFeature([], tf.string),
    StudyInstanceUID=tf.io.FixedLenFeature([], tf.string),
)

def decode_image(image_data):
    """Decode a JPEG byte string into a float32 RGB image tensor.

    The pixels are divided by 255 (so uint8 input lands in [0, 1]), the image is
    resized to ``image_size`` x ``image_size``, and a static
    ``(image_size, image_size, 3)`` shape is attached.
    """
    side = [image_size, image_size]
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(rgb, tf.float32) / 255.0
    resized = tf.image.resize(scaled, side)
    # The reshape pins a fully-defined static shape for the downstream Keras model.
    return tf.reshape(resized, side + [3])


def read_unlabeled_tfrecord(example):
    """Parse one serialized test example into a ``(decoded_image, StudyInstanceUID)`` pair."""
    features = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    return decode_image(features['image']), features['StudyInstanceUID']


def load_dataset(filenames, labeled=True, ordered=False):
    """Build an unbatched dataset of ``(image, StudyInstanceUID)`` pairs from TFRecords.

    Args:
        filenames: TFRecord paths; reads from several files are interleaved.
        labeled: Currently ignored — every record is parsed with
            ``read_unlabeled_tfrecord`` whatever this flag says.
        ordered: When False, deterministic ordering is switched off so elements
            are used as soon as they stream in (faster, but order may vary).

    Returns:
        A ``tf.data.Dataset`` of ``(image, id)`` pairs.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return (
        records
        .with_options(options)
        .map(read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    )


def get_test_dataset(ordered=False):
    """Return the test set as batches of ``(images, ids)``, prefetched in the background.

    Pass ``ordered=True`` when the dataset is iterated more than once and the
    passes must line up element by element (e.g. predictions vs. ids).
    """
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(batch_size)
        .prefetch(AUTO)
    )


# Record count embedded in a TFRecord file name, e.g. "00-230.tfrec" -> "230".
_TFREC_COUNT_RE = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total number of records across ``filenames``.

    The count is read from each file's name (``<prefix>-<count>.tfrec``; e.g.
    ``flowers00-230.tfrec`` holds 230 items) instead of scanning the files.
    Only the base name is searched, so hyphens/digits in directory names are ignored.

    Args:
        filenames: Iterable of TFRecord paths.

    Returns:
        The summed record count as an ``int`` (0 for an empty iterable).

    Raises:
        ValueError: If a file name carries no ``-<digits>.`` record count.
    """
    total = 0
    for filename in filenames:
        match = _TFREC_COUNT_RE.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot read a record count from TFRecord name {filename!r}")
        total += int(match.group(1))
    return total

# Total test examples, taken from the record counts encoded in the TFRecord file names.
NUM_TEST_IMAGES = count_data_items(filenames=TEST_FILENAMES)


In [None]:
def _build_model(weights_path):
    """Build ViT-L/32 (no top) + an 11-way sigmoid head and load fine-tuned weights."""
    transformer = vit.vit_l32(
        image_size=image_size,
        pretrained=False,
        include_top=False,
        pretrained_top=False,
    )
    model = tf.keras.Sequential([
        transformer,
        tf.keras.layers.Dense(11, activation='sigmoid'),
    ])
    # vit_l32 only consults its `weights` argument when pretrained=True (and then
    # expects a pretrained-checkpoint name, not a file path), so passing the .h5
    # there never loaded it. Load the fine-tuned checkpoint explicitly instead.
    # Assumes it was saved from this same Sequential(transformer, Dense(11))
    # topology — TODO confirm against the training notebook.
    model.load_weights(weights_path)
    return model


def _strip_jpg_suffix(uid):
    """Remove an exact trailing ``.jpg`` from ``uid``.

    Unlike ``uid.rstrip(".jpg")``, which strips any trailing run of the
    characters ``.``, ``j``, ``p`` and ``g``, this removes only the suffix itself.
    """
    return uid[:-len(".jpg")] if uid.endswith(".jpg") else uid


def test(weights_paths=('../input/vitweight/fold0vit.h5',), output_path='submission.csv'):
    """Predict the test set with fine-tuned ViT checkpoint(s) and write a submission CSV.

    Args:
        weights_paths: One checkpoint path or a sequence of them. Each is loaded
            into a fresh model and the per-model predictions are averaged
            (a fold ensemble when several are given).
        output_path: Destination of the submission CSV.

    Returns:
        The submission ``DataFrame``: ``StudyInstanceUID`` followed by one
        probability column per label.

    Raises:
        ValueError: If no checkpoint is given, or ids and predictions disagree in length.
    """
    if isinstance(weights_paths, str):
        weights_paths = (weights_paths,)
    if not weights_paths:
        raise ValueError('weights_paths must name at least one checkpoint')

    labels = ['ETT - Abnormal', 'ETT - Borderline',
              'ETT - Normal', 'NGT - Abnormal', 'NGT - Borderline',
              'NGT - Incompletely Imaged', 'NGT - Normal', 'CVC - Abnormal',
              'CVC - Borderline', 'CVC - Normal', 'Swan Ganz Catheter Present']

    # ordered=True keeps element order deterministic, so the separate passes over
    # test_ds (one per model for images, one for ids) line up row by row.
    test_ds = get_test_dataset(ordered=True)
    test_images_ds = test_ds.map(lambda image, idnum: image)

    predictions = []
    for weights_path in weights_paths:
        model = _build_model(weights_path)
        predictions.append(model.predict(test_images_ds))
        del model  # release our reference before building the next checkpoint
    mean = np.average(predictions, axis=0)

    # Collect every id batch instead of re-batching by the filename-derived count.
    id_batches = test_ds.map(lambda image, idnum: idnum).as_numpy_iterator()
    test_ids = np.concatenate(list(id_batches)).astype('U')
    if len(test_ids) != len(mean):
        raise ValueError(f'{len(test_ids)} ids but {len(mean)} prediction rows')

    submission = pd.DataFrame(mean, columns=labels)
    submission.insert(0, "StudyInstanceUID",
                      [_strip_jpg_suffix(uid) for uid in test_ids],
                      allow_duplicates=False)
    submission.to_csv(output_path, index=False)
    return submission


submission = test()
submission.head()

