In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

from keras.layers import Activation
from keras.layers import Conv2D , BatchNormalization

In [None]:
# NOTE(review): `train`, `test` and `query` (directory paths with one sub-folder
# per class) are not defined in the cells shown — set them before running this.
IMAGE_SIZE = (256, 128)   # (height, width); must match the model's input_shape
MAX_IMAGES = 30000        # each split is loaded as ONE in-memory batch

idg = keras.preprocessing.image.ImageDataGenerator()


def load_split(directory, max_images=MAX_IMAGES):
    """Load every image under ``directory`` as a single shuffled (x, y) batch.

    Args:
        directory: root folder passed to ``flow_from_directory`` (one
            sub-folder per class).
        max_images: batch size of the single ``next()`` call, i.e. the most
            images that can be loaded from this split.

    Returns:
        ``(x, y)``: RGB images resized to ``IMAGE_SIZE`` and their labels, as
        produced by one batch of the directory iterator (one-hot labels under
        the default ``class_mode='categorical'``).

    Raises:
        ValueError: if the directory holds more than ``max_images`` images;
            previously the extra images were silently dropped.
    """
    iterator = idg.flow_from_directory(directory,
                                       target_size=IMAGE_SIZE,
                                       color_mode='rgb',
                                       batch_size=max_images,
                                       shuffle=True)
    if iterator.samples > max_images:
        raise ValueError(
            f"{directory!r} contains {iterator.samples} images but only "
            f"{max_images} fit in one batch; raise max_images.")
    return next(iterator)


xtrain, ytrain = load_split(train)
xtest, ytest = load_split(test)
xquery, yquery = load_split(query)

for name, array in [("xtrain", xtrain), ("ytrain", ytrain),
                    ("xtest", xtest), ("ytest", ytest),
                    ("xquery", xquery), ("yquery", yquery)]:
    print(name, array.shape)

In [None]:
# Number of classes in the training split, read from the one-hot labels rather
# than typed in by hand: custom_loss multiplies each head's class maps by these
# labels, so the head width must equal this count exactly.
u = ytrain.shape[1]

In [None]:
def custom_loss(s1, s2, s3, m1, m2, m3, y_train, threshold_rank=20):
    """Per-sample loss for the three-head model: identification + overlap terms.

    Args:
        s1, s2, s3: softmax class probabilities of the three heads,
            presumably shape (batch, classes).
        m1, m2, m3: the heads' class maps before pooling, shape
            (batch, H, W, classes); (batch, 8, 4, classes) for a 256x128 input
            to ResNet50.
        y_train: one-hot labels of shape (batch, classes). Rows must line up
            with the samples that produced ``s*`` / ``m*``.
        threshold_rank: index into each sample's ascending-sorted, flattened
            true-class map whose value becomes the sigmoid threshold. The
            default 20 reproduces the original 21st-smallest-of-32 choice.

    Returns:
        Tensor of shape (batch,) equal to ``Li + Loap``:
        ``Li`` is ``-log(p1 * p2 * p3)`` with ``p_i`` the probability head i
        assigns to the true class; ``Loap`` is the spatial mean of the product
        of the three heads' soft masks ``sigmoid(map_i - sigma_i)``, presumably
        meant to discourage the heads from activating on the same locations.
    """
    y = tf.cast(y_train, s1.dtype)

    # Identification term. Summing logs (instead of the log of the product)
    # avoids float32 underflow of p1*p2*p3 to 0; the clip keeps a zero
    # probability from producing an infinite loss / NaN gradient.
    eps = 1e-7
    Li = -tf.math.add_n([
        tf.math.log(tf.clip_by_value(tf.math.reduce_sum(s * y, axis=1), eps, 1.0))
        for s in (s1, s2, s3)
    ])

    # Labels broadcast over the spatial grid: (batch, 1, 1, classes). This
    # replaces materialising H*W numpy copies and works on symbolic tensors.
    y_map = y[:, tf.newaxis, tf.newaxis, :]

    masks = []
    for m in (m1, m2, m3):
        # True-class activation at every spatial position: (batch, H, W).
        p = tf.math.reduce_sum(m * y_map, axis=-1)
        # Per-sample threshold taken from the sorted activations; the batch
        # size is read dynamically so any batch size works.
        flat = tf.reshape(p, (tf.shape(p)[0], -1))
        sigma = tf.sort(flat, axis=-1)[:, threshold_rank]
        masks.append(tf.math.sigmoid(p - sigma[:, tf.newaxis, tf.newaxis]))

    Loap = tf.math.reduce_mean(masks[0] * masks[1] * masks[2], axis=(1, 2))

    # The original computed this sum but returned only Loap, silently dropping
    # the identification term.
    return Li + Loap


In [None]:
resnet = keras.applications.ResNet50


class MyModel(tf.keras.Model):
    """ResNet50 backbone followed by three 1x1-conv class heads.

    Each head ``m_i`` projects the backbone feature map to ``unit`` channels;
    global average pooling plus softmax turns it into class probabilities
    ``S_i``. The model output is the average of the three probability vectors.
    While training, ``custom_loss`` is attached through ``add_loss``.

    NOTE(review): ``custom_loss`` is fed the notebook-global ``ytrain``, so it
    only lines up with the batch when ``fit`` receives the whole training set
    as a single, unshuffled batch. Overriding ``train_step`` to use the batch
    labels would remove that coupling.
    """

    def __init__(self, unit, **kwargs):
        super().__init__(**kwargs)

        # Channels per head. custom_loss multiplies the head maps by one-hot
        # labels, so this must equal the number of training classes.
        self.unit = unit
        self.resnet = resnet(include_top=False, weights='imagenet', input_shape=(256, 128, 3))
        # Use tensorflow.keras layers throughout: the standalone `keras`
        # imports at the top of the notebook may be a different package, and
        # mixing the two inside one model can fail.
        self.F = keras.layers.BatchNormalization()

        self.m1 = keras.layers.Conv2D(self.unit, (1, 1), padding="same")
        self.m2 = keras.layers.Conv2D(self.unit, (1, 1), padding="same")
        self.m3 = keras.layers.Conv2D(self.unit, (1, 1), padding="same")

        self.s1 = keras.layers.GlobalAveragePooling2D()
        self.s2 = keras.layers.GlobalAveragePooling2D()
        self.s3 = keras.layers.GlobalAveragePooling2D()

        self.S1 = keras.layers.Softmax()
        self.S2 = keras.layers.Softmax()
        self.S3 = keras.layers.Softmax()

    def call(self, inputs, training=None):
        """Return the mean of the three heads' class probabilities.

        Args:
            inputs: batch of RGB images, shape (batch, 256, 128, 3); presumably
                raw 0-255 pixels as produced by ``ImageDataGenerator()`` with no
                rescaling.
            training: Keras training flag. The auxiliary loss is only added
                when it is truthy, so ``predict``/``evaluate`` on other splits
                (whose size differs from ``ytrain``) do not fail.
        """
        # The ImageNet weights expect caffe-style preprocessing (RGB->BGR plus
        # per-channel mean subtraction); the cast yields a fresh float tensor.
        x = keras.applications.resnet50.preprocess_input(tf.cast(inputs, tf.float32))
        x = self.resnet(x, training=training)
        F = self.F(x, training=training)

        m1 = self.m1(F)
        m2 = self.m2(F)
        m3 = self.m3(F)

        S1 = self.S1(self.s1(m1))
        S2 = self.S2(self.s2(m2))
        S3 = self.S3(self.s3(m3))

        if training:
            # Reduce the per-sample loss to a scalar so its scale does not
            # depend on how the Keras version aggregates non-scalar add_loss.
            self.add_loss(tf.math.reduce_mean(custom_loss(S1, S2, S3, m1, m2, m3, ytrain)))
        return (S1 + S2 + S3) / 3

In [None]:
# `u` (number of training classes) sets the channel count of every head;
# the original referenced the undefined name `uint`, raising NameError.
model = MyModel(unit=u)

In [None]:
# Stage 1: freeze the pretrained backbone so only the new heads are updated.
# A change to `trainable` only takes effect once the model is (re)compiled.
model.resnet.trainable = False
model.compile(metrics=['accuracy'], optimizer='adam')

In [None]:
# Stage 1 training. The model's auxiliary loss pairs each batch row-for-row
# with the global `ytrain`, so the whole training set must be one batch and
# must NOT be shuffled: fit's default shuffle=True permutes the inputs relative
# to those labels.
history_frozen = model.fit(xtrain, ytrain, batch_size=ytrain.shape[0], epochs=10,
                           shuffle=False)

In [None]:
# Layer overview after stage 1: the frozen backbone's weights are counted
# under "Non-trainable params".
model.summary()

In [None]:
# Stage 2: unfreeze the backbone for end-to-end fine-tuning and recompile so
# the change takes effect.
# NOTE(review): the Keras fine-tuning guide recommends a much lower learning
# rate than Adam's default when unfreezing a pretrained backbone.
model.resnet.trainable = True
model.compile(metrics=['accuracy'], optimizer='adam')

In [None]:
# Stage 2 training. As in stage 1, one full unshuffled batch keeps the inputs
# aligned with the global `ytrain` used by the model's auxiliary loss.
history_finetune = model.fit(xtrain, ytrain, batch_size=ytrain.shape[0], epochs=50,
                             shuffle=False)

In [None]:
# Layer overview after fine-tuning: the backbone's weights now count toward
# "Trainable params".
model.summary()