In [None]:
import math, re, os
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
print(f"Tensorflow version {tf.__version__}")
# Autotune constant for tf.data; only referenced by the commented-out TFRecord pipeline below.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
import sys
# The SwinTransformer implementation is loaded from '../input/swintransformertf'
# (presumably an attached Kaggle dataset), so that directory is appended to the
# import path before importing it.
sys.path.append('../input/swintransformertf')
from swintransformer import SwinTransformer

# TPU or GPU detection

In [None]:
# Cross-compatible TPU / GPU / multi-GPU detection (TensorFlow 2.4+ API).
# A TPU is tried first; the ValueError branch is the "no TPU attached" path.
try:
    strategy = tf.distribute.TPUStrategy(
        tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    )
except ValueError:
    # GPU or multi-GPU machine. Alternatives, if needed:
    #   tf.distribute.get_strategy()                              -> default strategy (CPU / single GPU)
    #   tf.distribute.experimental.MultiWorkerMirroredStrategy()  -> clusters of multi-GPU machines
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

# Competition data access
TPUs read data directly from Google Cloud Storage (GCS). This Kaggle utility copies the dataset to a GCS bucket co-located with the TPU. If you have multiple datasets attached to the notebook, you can pass the name of a specific dataset to the `get_gcs_path` function; the name of a dataset is the name of the directory it is mounted in. Use `!ls /kaggle/input/` to list attached datasets. Note that this notebook does not use GCS: the cell below is commented out, and images are read from the local `../input/` mount instead.

In [None]:
# GCS_DS_PATH = KaggleDatasets().get_gcs_path("flower-classification") # you can list the bucket with "!gsutil ls $GCS_DS_PATH"

# Configuration

In [None]:
# 224 x 224 matches the 'swin_large_224' backbone used below and is the size to use on a GPU;
# larger sizes may not fit in GPU memory.
IMAGE_SIZE = [224, 224]
epochs = 15
# Global batch size: 16 images per replica times the number of replicas in the strategy.
batch_size = 16 * strategy.num_replicas_in_sync

# GCS_PATH_SELECT = { # available image sizes
#     192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',
#     224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
#     331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',
#     512: GCS_DS_PATH + '/tfrecords-jpeg-512x512'
# }
# GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

# TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
# VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
# TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec') # predictions on this dataset should be submitted for the competition

# CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
#            'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
#            'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
#            'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
#            'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
#            'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
#            'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
#            'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
#            'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
#            'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
#            'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 102

## Visualization utilities
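Helpers that turn data into pixels for display; nothing here is of much interest to the machine-learning practitioner. Only `display_training_curves` is active; the flower-dataset helpers are kept commented out for reference.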

In [None]:
# # numpy and matplotlib defaults
# np.set_printoptions(threshold=15, linewidth=80)

# def batch_to_numpy_images_and_labels(data):
#     images, labels = data
#     numpy_images = images.numpy()
#     numpy_labels = labels.numpy()
#     if numpy_labels.dtype == object: # binary string in this case, these are image ID strings
#         numpy_labels = [None for _ in enumerate(numpy_images)]
#     # If no labels, only image IDs, return None for labels (this is the case for test data)
#     return numpy_images, numpy_labels

# def title_from_label_and_target(label, correct_label):
#     if correct_label is None:
#         return CLASSES[label], True
#     correct = (label == correct_label)
#     return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
#                                 CLASSES[correct_label] if not correct else ''), correct

# def display_one_flower(image, title, subplot, red=False, titlesize=16):
#     plt.subplot(*subplot)
#     plt.axis('off')
#     plt.imshow(image)
#     if len(title) > 0:
#         plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
#     return (subplot[0], subplot[1], subplot[2]+1)
    
# def display_batch_of_images(databatch, predictions=None):
#     """This will work with:
#     display_batch_of_images(images)
#     display_batch_of_images(images, predictions)
#     display_batch_of_images((images, labels))
#     display_batch_of_images((images, labels), predictions)
#     """
#     # data
#     images, labels = batch_to_numpy_images_and_labels(databatch)
#     if labels is None:
#         labels = [None for _ in enumerate(images)]
        
#     # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
#     rows = int(math.sqrt(len(images)))
#     cols = len(images)//rows
        
#     # size and spacing
#     FIGSIZE = 13.0
#     SPACING = 0.1
#     subplot=(rows,cols,1)
#     if rows < cols:
#         plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
#     else:
#         plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
#     # display
#     for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
#         title = '' if label is None else CLASSES[label]
#         correct = True
#         if predictions is not None:
#             title, correct = title_from_label_and_target(predictions[i], label)
#         dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
#         subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
#     #layout
#     plt.tight_layout()
#     if label is None and predictions is None:
#         plt.subplots_adjust(wspace=0, hspace=0)
#     else:
#         plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
#     plt.show()

# def display_confusion_matrix(cmat, score, precision, recall):
#     plt.figure(figsize=(15,15))
#     ax = plt.gca()
#     ax.matshow(cmat, cmap='Reds')
#     ax.set_xticks(range(len(CLASSES)))
#     ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
#     plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
#     ax.set_yticks(range(len(CLASSES)))
#     ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
#     plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
#     titlestring = ""
#     if score is not None:
#         titlestring += 'f1 = {:.3f} '.format(score)
#     if precision is not None:
#         titlestring += '\nprecision = {:.3f} '.format(precision)
#     if recall is not None:
#         titlestring += '\nrecall = {:.3f} '.format(recall)
#     if len(titlestring) > 0:
#         ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
#     plt.show()
    
def display_training_curves(training, validation, title, subplot):
    """Draw one metric's train/validation curves into panel ``subplot`` of a shared figure.

    ``subplot`` is a 3-digit matplotlib code such as 211. When its last digit is 1
    (the first panel), a fresh 10x10 figure is created before drawing.

    :param training: per-epoch values measured on the training set.
    :param validation: per-epoch values measured on the validation set.
    :param title: metric name, used in the panel title and as the y-axis label.
    :param subplot: 3-digit matplotlib subplot code.
    """
    is_first_panel = subplot % 10 == 1
    if is_first_panel:
        plt.subplots(figsize=(10, 10), facecolor='#F0F0F0')
        plt.tight_layout()

    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for series in (training, validation):
        ax.plot(series)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

# Datasets

In [None]:
# def decode_image(image_data):
#     image = tf.image.decode_jpeg(image_data, channels=3)  # image format uint8 [0,255]
#     image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
#     return image

# def read_labeled_tfrecord(example):
#     LABELED_TFREC_FORMAT = {
#         "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
#         "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
#     }
#     example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
#     image = decode_image(example['image'])
#     label = tf.cast(example['class'], tf.int32)
#     return image, label # returns a dataset of (image, label) pairs

# def read_unlabeled_tfrecord(example):
#     UNLABELED_TFREC_FORMAT = {
#         "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
#         "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
#         # class is missing, this competitions's challenge is to predict flower classes for the test dataset
#     }
#     example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
#     image = decode_image(example['image'])
#     idnum = example['id']
#     return image, idnum # returns a dataset of image(s)

# def load_dataset(filenames, labeled=True, ordered=False):
#     # Read from TFRecords. For optimal performance, reading from multiple files at once and
#     # disregarding data order. Order does not matter since we will be shuffling the data anyway.

#     ignore_order = tf.data.Options()
#     if not ordered:
#         ignore_order.experimental_deterministic = False # disable order, increase speed

#     dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
#     dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
#     dataset = dataset.map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord, num_parallel_calls=AUTO)
#     # returns a dataset of (image, label) pairs if labeled=True or (image, id) pairs if labeled=False
#     return dataset

# def data_augment(image, label):
#     # data augmentation. Thanks to the dataset.prefetch(AUTO) statement in the next function (below),
#     # this happens essentially for free on TPU. Data pipeline code is executed on the "CPU" part
#     # of the TPU while the TPU itself is computing gradients.
#     image = tf.image.random_flip_left_right(image)
#     #image = tf.image.random_saturation(image, 0, 2)
#     return image, label   

# def get_training_dataset():
#     dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
#     dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
#     dataset = dataset.repeat() # the training dataset must repeat for several epochs
#     dataset = dataset.shuffle(2048)
#     dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
#     dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
#     return dataset

# def get_validation_dataset(ordered=False):
#     dataset = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
#     dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
#     dataset = dataset.cache()
#     dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
#     return dataset

# def get_test_dataset(ordered=False):
#     dataset = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
#     dataset = dataset.batch(BATCH_SIZE)
#     dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
#     return dataset

# def count_data_items(filenames):
#     # the number of data items is written in the name of the .tfrec files, i.e. flowers00-230.tfrec = 230 data items
#     n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
#     return np.sum(n)

# NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
# NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
# NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
# STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
# VALIDATION_STEPS = -(-NUM_VALIDATION_IMAGES // BATCH_SIZE) # The "-(-//)" trick rounds up instead of down :-)
# TEST_STEPS = -(-NUM_TEST_IMAGES // BATCH_SIZE)             # The "-(-//)" trick rounds up instead of down :-)
# print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

# Dataset visualizations

In [None]:
# # data dump
# print("Training data shapes:")
# for image, label in get_training_dataset().take(3):
#     print(image.numpy().shape, label.numpy().shape)
# print("Training data label examples:", label.numpy())
# print("Validation data shapes:")
# for image, label in get_validation_dataset().take(3):
#     print(image.numpy().shape, label.numpy().shape)
# print("Validation data label examples:", label.numpy())
# print("Test data shapes:")
# for image, idnum in get_test_dataset().take(3):
#     print(image.numpy().shape, idnum.numpy().shape)
# print("Test data IDs:", idnum.numpy().astype('U')) # U=unicode string

In [None]:
# # Peek at training data
# training_dataset = get_training_dataset()
# training_dataset = training_dataset.unbatch().batch(20)
# train_batch = iter(training_dataset)

In [None]:
# # run this cell again for next set of images
# display_batch_of_images(next(train_batch))

In [None]:
# # peer at test data
# test_dataset = get_test_dataset()
# test_dataset = test_dataset.unbatch().batch(20)
# test_batch = iter(test_dataset)

In [None]:
# # run this cell again for next set of images
# display_batch_of_images(next(test_batch))

In [None]:
import json

def _normalise_labels(raw_label):
    """Turn one raw label string from a JSON-lines split into a list of lowercase tags."""
    if raw_label == '':
        return ['others']
    labels = raw_label.replace("'", "").split(", ")
    # "No Finding" listed alongside "Support Devices" is treated as just "support devices".
    if len(labels) > 1 and "No Finding" in labels and "Support Devices" in labels:
        labels = ["support devices"]
    # A lone "No Finding" becomes the "normal" tag.
    if len(labels) == 1 and "No Finding" in labels:
        labels = ["normal"]
    return [x.lower() for x in labels]


def load_split(file: str, is_jsonl):
    """
    Load one data split into tag and caption dictionaries.

    Two on-disk formats are supported:

    * JSON-lines (``is_jsonl=True``): one object per non-blank line with ``id``, ``label``
      and ``text`` fields. ``label`` is a comma-separated string of quoted findings. An
      empty label becomes ``['others']``, a lone "No Finding" becomes ``['normal']``,
      "No Finding" together with "Support Devices" becomes ``['support devices']``, and
      all tags are lowercased. Keys are ``id + ".jpg"``.
    * Plain JSON (``is_jsonl=False``): one object mapping filename -> list of tags. This
      format carries no captions, so the returned caption dict is empty.

    :param file: path of the split file.
    :param is_jsonl: whether ``file`` is JSON-lines rather than a single JSON object.
    :return: ``(partition, concepts, text)``. ``partition`` maps filename -> list of tags,
        ``concepts`` is the sorted list of unique tags in the split, and ``text`` maps
        filename -> caption (empty for plain JSON).
    """
    # Defined up front so the plain-JSON path no longer fails with NameError at the return.
    text = {}
    with open(file) as json_file:
        if is_jsonl:
            # Parse the file once, through the managed handle; blank lines are skipped.
            records = [json.loads(line) for line in json_file if line.strip()]
            print(len(records))
            data = {}
            for record in records:
                key = record["id"] + ".jpg"
                text[key] = record["text"]
                data[key] = _normalise_labels(record["label"])
        else:
            data = json.load(json_file)

    print('Loaded from: ', file)
    partition = dict(data)  # train/val/test partition (shallow copy)
    # Sorted so the tag order is reproducible across runs (a raw set's order is not).
    concepts = sorted({tag for tags in partition.values() for tag in tags})
    return partition, concepts, text  # maybe concepts can be omitted in test set.


#from keras.utils import CustomObjectScope
import math
#from Attention import Attention
# import utilities as u
import os
from tqdm import tqdm
import numpy as np
import random


def load_data(data, concepts_list, image_dir=None, preprocess=None, target_size=(224, 224)):
    """
    Load the images and tags of ``data`` into NumPy arrays, multi-hot encoding the tags.

    NOTE(review): this helper is not called anywhere in the notebook; training uses
    CustomDataGen. Unlike CustomDataGen, it does not search Train/, Test/ and Valid/
    subfolders, so pass the folder that actually contains the files as ``image_dir``.

    :param data: dict mapping image filename -> list of tags.
    :param concepts_list: ordered list of tags; column ``i`` of ``y`` is ``concepts_list[i]``.
    :param image_dir: folder holding the images. Defaults to the notebook-level ``images_dir``.
    :param preprocess: optional callable applied to each image array. By default pixels are
        left raw, as in CustomDataGen; the model's first Lambda layer normalises them.
    :param target_size: (height, width) that every image is resized to.
    :return: ``(x, y)``: stacked image arrays, and an int 0/1 matrix with one row per
        image and one column per tag.
    """
    if image_dir is None:
        image_dir = images_dir  # defined in the data-generator cell; resolved at call time
    # The bare names `image` / `preprocess_input` were never imported; use the Keras module explicitly.
    keras_image = tf.keras.preprocessing.image

    x_data, y_data = [], []
    for img_id in tqdm(data.keys()):
        img = keras_image.load_img(os.path.join(image_dir, img_id), target_size=target_size)  # PIL image
        x = keras_image.img_to_array(img)
        if preprocess is not None:
            x = preprocess(x)
        # Set membership: constant-time lookup per tag instead of scanning the image's tag list.
        image_concepts = set(data[img_id])
        concepts = np.array([1 if c in image_concepts else 0 for c in concepts_list], dtype=int)
        x_data.append(x)
        y_data.append(concepts)
    return np.array(x_data), np.array(y_data)


def tune_threshold(model, x_val, y_val, generator, multilabel: bool):
    """
    Pick the decision threshold that maximises F1 on validation data.

    Thresholds 0.01, 0.02, ..., 1.00 are applied to the model's predicted scores.

    :param model: trained model with a Keras-style ``predict``.
    :param x_val: validation images as a NumPy array; ignored when ``generator`` is given.
    :param y_val: ground-truth labels, row-aligned with the predictions.
    :param generator: optional Sequence/generator to predict from instead of ``x_val``. It
        must yield every validation row, in ``y_val`` order (e.g. ``shuffle=False`` and a
        batch size that divides the row count, since CustomDataGen drops a partial batch).
        Otherwise predictions and labels will not line up.
    :param multilabel: if True, score micro-F1 after appending the "no label" indicator
        column via ``add_column``; otherwise use binary F1.
    :return: ``(best_threshold, best_f1)``.
    """
    # 2D NumPy array (rows = images, columns = score for each tag).
    if generator is not None:
        predictions = model.predict(generator)
    else:
        predictions = model.predict(x_val, batch_size=16, verbose=1)

    # The ground truth does not depend on the threshold, so augment it once, outside the loop.
    val_data = add_column(y_val) if multilabel else y_val

    steps = 100
    f1_scores = {}
    for i in tqdm(range(steps)):
        threshold = float(i + 1) / steps  # ImageCLEF 2020 tests.
        y_pred_val = (predictions >= threshold).astype('int')
        if multilabel:
            f1_scores[threshold] = f1_score(val_data, add_column(y_pred_val), average="micro")
        else:
            f1_scores[threshold] = f1_score(val_data, y_pred_val)

    best_threshold = max(f1_scores, key=f1_scores.get)  # key with max value; ties keep the lowest threshold
    print('The best F1 score on validation data' +
          ' is ' + str(f1_scores[best_threshold]) +
          ' achieved with threshold = ' + str(best_threshold) + '\n')

    return best_threshold, f1_scores[best_threshold]


def removeNormal(data, concepts):
    """
    Drop the "normal" tag from every label list and from the concept list.

    Both arguments are modified in place and also returned. Rows whose only tag was
    "normal" end up with an empty list.

    :param data: dict mapping filename -> list of tags (each list is mutated).
    :param concepts: list of unique tags (mutated).
    :return: ``(data, concepts)``.
    """
    for vals in data.values():
        # Slice assignment keeps the same list object and removes every occurrence, not just the first.
        vals[:] = [v for v in vals if v != "normal"]

    # Guarded: a split without any "normal" row would otherwise raise ValueError here.
    if "normal" in concepts:
        concepts.remove("normal")

    return data, concepts
    

splitDir = "../input/mimicmedvillsplit/"

# All three split files are read as JSON-lines (is_jsonl=True).
train_data, train_concepts, train_captions = load_split(splitDir + 'train.json', True)
val_data, val_concepts, val_captions = load_split(splitDir + 'valid.json', True)
test_data, test_concepts, test_captions = load_split(splitDir + 'test.json', True)

# Tag count before "normal" is stripped, and whether "normal" occurs among the training tags.
print('Total tags: ', len(train_concepts))
print("normal" in train_concepts)

# Strip the "normal" tag from each split; removeNormal mutates the dicts/lists it is given.
train_data, train_concepts = removeNormal(train_data, train_concepts)
val_data, val_concepts = removeNormal(val_data, val_concepts)
test_data, test_concepts = removeNormal(test_data, test_concepts)

print('Total tags: ', len(train_concepts))

import pandas as pd 

def dict_to_df(data_dict, text_dict):
    """
    Combine a filename -> tags dict and a filename -> caption dict into one DataFrame.

    Columns are ``filename``, ``labels`` and ``captions``, in ``data_dict`` order. Filenames
    missing from ``text_dict`` get a NaN caption (left join on ``filename``).
    """
    labels_df = pd.DataFrame(list(data_dict.items()), columns=["filename", "labels"])
    captions_df = pd.DataFrame(list(text_dict.items()), columns=["filename", "captions"])
    return labels_df.join(captions_df.set_index("filename"), on="filename")



# Seed for the random re-split below; fixing it makes the train/val/test partition reproducible.
SPLIT_SEED = 42

train_df = dict_to_df(train_data, train_captions)
val_df = dict_to_df(val_data, val_captions)
test_df = dict_to_df(test_data, test_captions)

# Only the image modality is trained below, so captions are dropped.
train_df = train_df.drop(columns='captions')
val_df = val_df.drop(columns='captions')
test_df = test_df.drop(columns='captions')

# The provided splits are pooled, shuffled and re-split 65/15/20 below.
# NOTE(review): the re-split is per image; if several images come from the same patient/study
# they can land in different splits -- confirm this is acceptable.
df = pd.concat([train_df, val_df, test_df]).copy()
df = df.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)
del train_df, val_df, test_df

from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
# Multi-hot encode the tag lists; column j corresponds to mlb.classes_[j].
df["labels"] = mlb.fit_transform(df["labels"].tolist()).tolist()

idx1 = int(0.65 * len(df))
idx2 = int(0.8 * len(df))
train_df = df[0:idx1].reset_index(drop=True).copy()
val_df = df[idx1:idx2].reset_index(drop=True).copy()
test_df = df[idx2:].reset_index(drop=True).copy()

del df
print(mlb.classes_)

In [None]:
from os.path import exists

class CustomDataGen(tf.keras.utils.Sequence):
    """
    Keras ``Sequence`` yielding ``(inputs, labels)`` batches from a DataFrame.

    ``df`` needs a ``labels`` column (one label vector per row; in this notebook these are
    MultiLabelBinarizer rows). It also needs ``filename`` for the "image"/"multimodal"
    modalities and ``captions`` for "text"/"multimodal".

    :param df: source rows. The frame is copied, so end-of-epoch shuffling never reorders
        the caller's frame.
    :param batch_size: rows per batch. ``len()`` is ``n // batch_size``, so a trailing
        partial batch is never produced.
    :param input_size: target image size; only the first two entries (height, width) are used.
    :param preprocess_input: stored as ``self.preprocess_input`` but not applied. Images are
        returned with raw pixel values.
    :param path: image root. Each file is looked up under ``Train/``, then ``Test/``, and
        otherwise assumed to be in ``Valid/``.
    :param shuffle: reshuffle the rows (unseeded) at the end of every epoch.
    :param modality: "image", "text" or "multimodal".
    :raises ValueError: if ``modality`` is not one of the supported values.
    """

    MODALITIES = ("image", "text", "multimodal")

    def __init__(self, df, batch_size, input_size, preprocess_input, path, shuffle, modality):
        # Initialise the Keras base class; newer Keras versions expect subclasses to call it.
        super().__init__()
        # Fail fast: __get_data has no branch for any other value.
        if modality not in self.MODALITIES:
            raise ValueError("modality must be one of {}, got {!r}".format(self.MODALITIES, modality))
        self.df = df.copy()
        self.batch_size = batch_size
        self.input_size = input_size
        self.shuffle = shuffle
        self.n = len(self.df)
        self.preprocess_input = preprocess_input
        self.path = path
        self.modality = modality

    def on_epoch_end(self):
        """Reshuffle the private copy of the rows when ``shuffle`` is enabled."""
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)

    def __get_input(self, path, target_size):
        """Load one image as an array resized to ``target_size[:2]`` (raw pixel values)."""
        # Rows come from the pooled train/val/test frames, so the file may sit in any split
        # folder; anything not found under Train/ or Test/ is assumed to be in Valid/.
        for split_dir in ("Train/", "Test/"):
            full_path = self.path + split_dir + path
            if exists(full_path):
                break
        else:
            full_path = self.path + "Valid/" + path

        image = tf.keras.preprocessing.image.load_img(full_path, target_size=(target_size[0], target_size[1]))
        return tf.keras.preprocessing.image.img_to_array(image)

    def __get_data(self, batches):
        """Build ``(X, y)`` for the rows in ``batches`` according to ``self.modality``."""
        y_batch = np.asarray([np.asarray(y) for y in batches.labels])

        if self.modality == "image":
            image_batch = np.asarray([self.__get_input(x, self.input_size) for x in batches.filename])
            return image_batch, y_batch
        if self.modality == "text":
            text_batch = np.asarray([np.asarray(x) for x in batches.captions])
            return text_batch, y_batch
        if self.modality == "multimodal":
            image_batch = np.asarray([self.__get_input(x, self.input_size) for x in batches.filename])
            text_batch = np.asarray([np.asarray(x) for x in batches.captions])
            return (image_batch, text_batch), y_batch
        # Reachable only if self.modality was reassigned after construction.
        raise ValueError("unsupported modality: {!r}".format(self.modality))

    def __getitem__(self, index):
        """Return batch ``index``: rows ``[index * batch_size, (index + 1) * batch_size)``."""
        batches = self.df[index * self.batch_size:(index + 1) * self.batch_size]
        X, y = self.__get_data(batches)
        return X, y

    def __len__(self):
        """Number of full batches; leftover rows beyond the last full batch are skipped."""
        return self.n // self.batch_size
        
# Image root; CustomDataGen looks each file up under its Train/, Test/ or Valid/ subfolder.
images_dir = "../input/medvilmimic/mimic/re_512_3ch/"

# Generator input size follows the IMAGE_SIZE configuration, which is also the model's input shape.
train_generator = CustomDataGen(train_df, batch_size=batch_size, input_size=(*IMAGE_SIZE, 3), preprocess_input=None, path=images_dir, shuffle=True, modality='image')
val_generator = CustomDataGen(val_df, batch_size=batch_size, input_size=(*IMAGE_SIZE, 3), preprocess_input=None, path=images_dir, shuffle=True, modality='image')
# Test generator (unused). Pass the image root itself: CustomDataGen appends the split subfolder.
# test_generator = CustomDataGen(test_df, batch_size=batch_size, input_size=(*IMAGE_SIZE, 3), preprocess_input=None, path=images_dir, shuffle=False, modality='image')


In [None]:
NUM_TRAINING_IMAGES = train_generator.n
NUM_VALIDATION_IMAGES = val_generator.n
# CustomDataGen.__len__ drops the trailing partial batch, so take the step counts from the
# generators themselves. Rounding up here would promise one validation batch more than
# val_generator can yield.
STEPS_PER_EPOCH = len(train_generator)
VALIDATION_STEPS = len(val_generator)

# Model
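Any of these model names can be passed to `SwinTransformer(...)`; this notebook uses `swin_large_224`, which matches the 224 x 224 input size:  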
`swin_tiny_224`    
`swin_small_224`  
`swin_base_224`  
`swin_base_384`  
`swin_large_224`  
`swin_large_384`  

In [None]:
!pip install tensorflow-addons
import tensorflow_addons as tfa

In [None]:
# Micro-averaged F1 over all tags, binarising the sigmoid outputs at 0.5.
f1_tfa = tfa.metrics.F1Score(
    num_classes=len(mlb.classes_),
    threshold=0.5,
    average="micro",
)

In [None]:
# Display the global batch size (16 per replica x strategy.num_replicas_in_sync).
batch_size

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Both callbacks watch val_loss. Training stops after 3 epochs without improvement and
# restores the best weights; the learning rate is cut 10x after a single stagnant epoch.
early_stopping = EarlyStopping(
    monitor='val_loss', mode='auto', patience=3, restore_best_weights=True, verbose=1,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss', mode='min', factor=0.1, patience=1, verbose=1,
)


In [None]:
# Build and compile under the distribution strategy. batch_size is
# 16 * strategy.num_replicas_in_sync, so outside the scope a multi-replica run would push
# the whole global batch through a single device.
with strategy.scope():
    # Cast pixels to float32 and apply Keras' "torch" preprocessing
    # (scale to [0, 1], then normalise each channel with the ImageNet mean/std).
    img_adjust_layer = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"),
        input_shape=[*IMAGE_SIZE, 3],
    )
    # TPU variant: SwinTransformer(..., use_tpu=True)
    pretrained_model = SwinTransformer('swin_large_224', len(mlb.classes_), include_top=False, pretrained=True, use_tpu=False)

    model = tf.keras.Sequential([
        img_adjust_layer,
        pretrained_model,
        # One independent sigmoid per tag: a multi-label head trained with binary cross-entropy.
        tf.keras.layers.Dense(len(mlb.classes_), activation='sigmoid'),
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8),
        loss='binary_crossentropy',
        # The metric is created here rather than reusing `f1_tfa`, because Keras rejects
        # metrics whose variables were created outside the model's strategy scope. It has the
        # same settings and default name ("f1_score"), so the history keys are unchanged.
        metrics=[tfa.metrics.F1Score(num_classes=len(mlb.classes_), average="micro", threshold=0.5)],
    )
model.summary()

# Training

In [None]:
# Train from the Keras Sequence generators. EarlyStopping restores the best weights by
# val_loss, and ReduceLROnPlateau lowers the learning rate when val_loss stalls.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
    callbacks=[early_stopping, reduce_lr],
)

In [None]:
import matplotlib.pyplot as plt

# Training vs. validation binary cross-entropy per epoch, labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train')
ax.plot(history.history['val_loss'], label='validation')
ax.set(title='Model loss', xlabel='epoch', ylabel='binary cross-entropy')
ax.legend()
plt.show()

In [None]:
# Training vs. validation micro-F1 (tfa F1Score, threshold 0.5) per epoch.
fig, ax = plt.subplots()
ax.plot(history.history['f1_score'], label='train')
ax.plot(history.history['val_f1_score'], label='validation')
ax.set(title='Model micro-F1', xlabel='epoch', ylabel='micro F1 @ 0.5')
ax.legend()
plt.show()

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


def add_column(y):
    """Append a "no positive entry" indicator column to a 2-D label/prediction matrix.

    The new last column is 1 for rows whose entries sum to zero and 0 otherwise, so rows
    without any tag are counted as their own class when scoring.
    """
    y = np.asarray(y)
    row_is_empty = (y.sum(axis=1) == 0).astype('int')
    return np.column_stack((y, row_is_empty))

In [None]:
# Ground-truth label matrix for the validation rows, in val_df order.
y_val = np.array(list(val_df["labels"]))

In [None]:
# Unshuffled, batch-size-1 validation generator: every row is yielded exactly once and in
# val_df order, so predictions line up with y_val.
val_generator2 = CustomDataGen(
    val_df,
    batch_size=1,
    input_size=(224, 224, 3),
    preprocess_input=None,
    path=images_dir,
    shuffle=False,
    modality='image',
)


In [None]:
# Sigmoid scores for every validation row, in val_df order.
y_pred = model.predict(val_generator2)

# Grid-search one global threshold on micro-F1; on ties the lowest threshold is kept.
best_threshold, best_f1 = 0, 0
for thres in np.arange(0, 0.7, 0.05):
    y_p = (y_pred >= thres).astype('int')
    f1 = f1_score(y_val, y_p, average="micro")
    print(thres, f1)
    if f1 > best_f1:
        best_threshold, best_f1 = thres, f1

In [None]:
# Binarise fresh validation predictions with the tuned threshold, then append the
# "no tag" indicator column to both predictions and ground truth.
y_pred = (model.predict(val_generator2) >= best_threshold).astype('int')
y_pred2, y_val2 = add_column(y_pred), add_column(y_val)

In [None]:
# sklearn's signature is f1_score(y_true, y_pred): pass the ground truth first.
for avg_type in ['micro', 'macro']:
    print(avg_type, f1_score(y_val2, y_pred2, average=avg_type))

In [None]:
from sklearn.metrics import classification_report

# The columns of y_val2 / y_pred2 are mlb.classes_ (the MultiLabelBinarizer order) followed by
# the indicator column that add_column appends for rows with no tag. After removeNormal, those
# rows are the "normal" studies. train_concepts is not guaranteed to share mlb's column order
# or tag set (mlb was fitted on all three splits), so it cannot name the columns.
target_names = [*mlb.classes_, "normal"]

print(classification_report(y_val2, y_pred2, target_names=target_names))

In [None]:
from sklearn.dummy import DummyClassifier

# Majority-class baseline. It ignores its inputs, so the label matrix doubles as a placeholder X.
y_train = np.array(train_df.labels.tolist())
dummy_clf = DummyClassifier(strategy="most_frequent")
dummy_clf.fit(y_train, y_train)


In [None]:
# Baseline F1: score the most-frequent dummy classifier on the same validation rows.
# DummyClassifier ignores its features, so y_val only supplies the row count (and the
# column count it was fitted with).
y_dummy2 = add_column(dummy_clf.predict(y_val))
for avg_type in ['micro', 'macro']:
    print(avg_type, f1_score(y_val2, y_dummy2, average=avg_type))

In [None]:
# Per-class report for the dummy baseline. Column names are mlb's classes plus the
# "normal" indicator column appended by add_column. zero_division=0 silences warnings for
# classes the dummy never predicts (their precision is reported as 0).
dummy_report_names = [*mlb.classes_, "normal"]
print(classification_report(y_val2, add_column(dummy_clf.predict(y_val)), target_names=dummy_report_names, zero_division=0))

In [None]:
import os

# exist_ok=True keeps the cell re-runnable; os.mkdir raised FileExistsError on a second run.
os.makedirs("saved", exist_ok=True)

In [None]:
# Save the trained weights under the prefix saved/ckpt (the "saved" folder is created in the cell above).
model.save_weights("saved/ckpt")

In [None]:
import shutil

# Zip the contents of the "saved" folder into saved.zip in the working directory.
shutil.make_archive(base_name="saved", format="zip", root_dir="saved")

# Confusion matrix

In [None]:
# cmdataset = get_validation_dataset(ordered=True) # since we are splitting the dataset and iterating separately on images and labels, order matters.
# images_ds = cmdataset.map(lambda image, label: image)
# labels_ds = cmdataset.map(lambda image, label: label).unbatch()
# cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy() # get everything as one batch
# cm_probabilities = model.predict(images_ds, steps=VALIDATION_STEPS)
# cm_predictions = np.argmax(cm_probabilities, axis=-1)
# print("Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
# print("Predicted labels: ", cm_predictions.shape, cm_predictions)

In [None]:
# cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)))
# score = f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
# precision = precision_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
# recall = recall_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
# cmat = (cmat.T / cmat.sum(axis=1)).T # normalized
# display_confusion_matrix(cmat, score, precision, recall)
# print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))