In [163]:
%useLatestDescriptors
%use kotlin-dl (0.5.1)
%use dataframe (0.9.1)

In [164]:
// Training hyper-parameters.
val EPOCHS = 3
val TRAINING_BATCH_SIZE = 8
val TEST_BATCH_SIZE = 16

// Binary classification (cat vs dog).
val NUM_CLASSES = 2
// Images are resized to IMAGE_SIZE x IMAGE_SIZE and fed to the network with NUM_CHANNELS channels.
val NUM_CHANNELS = 3
val IMAGE_SIZE = 300
// Fraction of the shuffled dataset that goes to the training split; the rest is the test split.
val TRAIN_TEST_SPLIT_RATIO = 0.7
// Directory the fine-tuned model is saved to and reloaded from. Kept relative, like the
// "../cache" paths used elsewhere in this notebook, instead of pointing into one developer's
// home directory. Assumes the notebook runs one level below the project root — adjust if not.
// NOTE(review): the parent "savedmodels" directory may need to exist before saving.
val PATH_TO_MODEL = "../savedmodels/customResNet50"

In [165]:
import org.jetbrains.kotlinx.dl.api.core.Functional
import org.jetbrains.kotlinx.dl.api.core.SavingFormat
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.WritingMode
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
import org.jetbrains.kotlinx.dl.api.core.initializer.GlorotUniform
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
import org.jetbrains.kotlinx.dl.api.core.layer.pooling.GlobalAvgPool2D
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.inference.keras.loadWeightsForFrozenLayers
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModelHub
import org.jetbrains.kotlinx.dl.api.inference.loaders.TFModels
import org.jetbrains.kotlinx.dl.api.preprocessing.pipeline
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath
import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
import org.jetbrains.kotlinx.dl.impl.preprocessing.call
import org.jetbrains.kotlinx.dl.impl.preprocessing.image.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.fileLoader
import org.jetbrains.kotlinx.dl.impl.summary.logSummary
import java.awt.image.BufferedImage
import java.io.File

In [166]:
// Pretrained ResNet50 backbone without its classification head (noTop = true), configured for
// IMAGE_SIZE x IMAGE_SIZE x NUM_CHANNELS inputs. The hub uses ../cache/pretrainedModels as its
// cache directory.
val modelHub = TFModelHub(cacheDirectory = File("../cache/pretrainedModels"))
val modelType = TFModels.CV.ResNet50(
    inputShape = intArrayOf(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
    noTop = true,
)
val noTopModel = modelHub.loadModel(modelType)

In [167]:
// Image preprocessing applied to every sample: bilinear resize to IMAGE_SIZE x IMAGE_SIZE,
// conversion to BGR, flattening into a FloatArray, then the ResNet50-specific preprocessor
// from the model zoo.
val preprocessing = pipeline<BufferedImage>()
    .resize {
        interpolation = InterpolationType.BILINEAR
        outputWidth = IMAGE_SIZE
        outputHeight = IMAGE_SIZE
    }
    .convert { colorMode = ColorMode.BGR }
    .toFloatArray { }
    .call(TFModels.CV.ResNet50().preprocessor)

// Build the dataset from the small cats-vs-dogs image folders. FromFolders assigns label 0 to
// the "cat" folder and label 1 to the "dog" folder. The result is shuffled before splitting.
val dataset = run {
    val datasetRoot = File(dogsCatsSmallDatasetPath(File("../cache")))
    val labelByFolder = mapOf("cat" to 0, "dog" to 1)
    OnFlyImageDataset.create(datasetRoot, FromFolders(mapping = labelByFolder), preprocessing)
}.shuffle()

// TRAIN_TEST_SPLIT_RATIO of the samples go to `train`, the remainder to `test`.
val (train, test) = dataset.split(TRAIN_TEST_SPLIT_RATIO)

In [168]:
// Obtain the pretrained ResNet50 weights matching `modelType` from the model hub (presumably
// served from the hub's cache directory when already downloaded). They are copied into the
// backbone once the full model has been assembled.
val hdfFile = modelHub.loadWeights(modelType)

In [169]:
// Classification head placed on top of the ResNet50 backbone:
// global average pooling -> 200-unit ReLU dense layer -> NUM_CLASSES raw logits.
// The output layer is linear because training uses a softmax-cross-entropy-with-logits loss.
val topModel = run {
    val pooling = GlobalAvgPool2D(name = "top_avg_pool")
    val hidden = Dense(
        name = "top_dense",
        outputSize = 200,
        activation = Activations.Relu,
        kernelInitializer = GlorotUniform(),
        biasInitializer = GlorotUniform(),
    )
    val logits = Dense(
        name = "pred",
        outputSize = NUM_CLASSES,
        activation = Activations.Linear,
        kernelInitializer = GlorotUniform(),
        biasInitializer = GlorotUniform(),
    )
    // noInput = true: the head declares no Input layer of its own (presumably because it is
    // fed by the pretrained backbone when combined via Functional.of).
    Sequential.of(pooling, hidden, logits, noInput = true)
}

In [170]:
// Combine the pretrained backbone and the new head into one Functional model, then compile it
// for training on logits with Adam, tracking accuracy.
val model = Functional.of(pretrainedModel = noTopModel, topModel = topModel).apply {
    compile(
        optimizer = Adam(),
        loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
        metric = Metrics.ACCURACY,
    )
}

In [171]:
// Copy the pretrained weights from `hdfFile` into the model. Per the function name, only the
// frozen (pretrained backbone) layers are filled; the new head presumably keeps its
// GlorotUniform initialisation.
model.loadWeightsForFrozenLayers(hdfFile)

In [172]:
// Baseline: test-split accuracy before any fine-tuning of the head.
val accuracyBeforeTraining = model
    .evaluate(dataset = test, batchSize = TEST_BATCH_SIZE)
    .metrics[Metrics.ACCURACY]
// The last expression of the cell is rendered as the cell output.
"Accuracy before training $accuracyBeforeTraining"



Accuracy before training 0.4241071343421936

TODO: show examples of cats and dogs that the model misclassifies.

In [173]:
// Fine-tune on the training split; `trainHistory` keeps the per-batch loss/metric log.
val trainHistory = model.fit(
    dataset = train,
    epochs = EPOCHS,
    batchSize = TRAINING_BATCH_SIZE,
)

// Re-measure accuracy on the held-out test split after training.
val accuracyAfterTraining = model
    .evaluate(dataset = test, batchSize = TEST_BATCH_SIZE)
    .metrics[Metrics.ACCURACY]

"Accuracy after training $accuracyAfterTraining"

Accuracy after training 0.96875

In [174]:
// Per-batch training history as a DataFrame, printed as a text table.
val trainHistoryDF = trainHistory.batchHistory.toDataFrame().also { history -> history.print() }

    epochIndex batchIndex lossValue        metricValues
  0          1          0  2,948558              [0.75]
  1          1          1  0,626892               [1.5]
  2          1          2  1,450236               [2.0]
  3          1          3  0,758204             [2.625]
  4          1          4  0,278900               [3.5]
  5          1          5  0,664509             [4.125]
  6          1          6  0,620354               [5.0]
  7          1          7  0,371981              [5.75]
  8          1          8  0,241257 [6.583333492279053]
  9          2          0  0,013084               [1.0]
 10          2          1  0,000213               [2.0]
 11          2          2  0,103701               [3.0]
 12          2          3  0,001088               [4.0]
 13          2          4  0,003956               [5.0]
 14          2          5  0,000152               [6.0]
 15          2          6  0,001304               [7.0]
 16          2          7  0,327463             

In [175]:
// Persist the trained model to PATH_TO_MODEL in the JSON-config + custom-variables format.
// WritingMode.OVERRIDE lets the cell be re-run over an existing saved model.
File(PATH_TO_MODEL).let { modelDir ->
    model.save(modelDir, SavingFormat.JSON_CONFIG_CUSTOM_VARIABLES, writingMode = WritingMode.OVERRIDE)
}

In [176]:
// Close the trained model to release its resources; the saved copy at PATH_TO_MODEL is
// reloaded as `model2` in the following cells.
model.close()

In [177]:
// Rebuild the model architecture from the JSON config written by model.save(...).
// The weights are loaded separately (see setUpModel).
val model2 = Functional.loadModelConfiguration(File("$PATH_TO_MODEL/modelConfig.json"))

// File -> (FloatArray, TensorShape) loader for single-image inference. It wraps the exact
// `preprocessing` pipeline the model was trained with, rather than a hand-copied duplicate,
// so training-time and inference-time preprocessing cannot drift apart.
val fileDataLoader = preprocessing.fileLoader()

In [178]:
/**
 * Prepares a model whose architecture was reloaded from [PATH_TO_MODEL]: compiles it, logs its
 * layer summary, and loads the saved weights from [PATH_TO_MODEL].
 *
 * It is compiled with the same optimizer, loss, and metric as the model that was trained
 * (Adam, softmax cross-entropy on logits, accuracy), so `evaluate` on the reloaded model reports
 * values comparable to those seen during training.
 *
 * @param it the reloaded model, typically obtained via `Functional.loadModelConfiguration`.
 */
fun setUpModel(it: Functional) {
    it.compile(
        optimizer = Adam(),
        loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
        metric = Metrics.ACCURACY,
    )
    it.logSummary()

    // Loaded after compile(), presumably because loading variables needs the compiled graph.
    it.loadWeights(File(PATH_TO_MODEL))
}

In [179]:
 model2.use { loadedModel ->
     setUpModel(loadedModel)
     // Classify images 0..49 of each class and print the predicted label index
     // (0 = cat, 1 = dog, per the FromFolders mapping used when building the dataset).
     // NOTE(review): these files live in the folder the train/test split was drawn from,
     // so many were likely seen during training — treat the results as optimistic.
     for ((header, animal) in listOf("CATS" to "cat", "DOGS" to "dog")) {
         println(header)
         repeat(50) { idx ->
             val fileName = "$animal.$idx.jpg"
             val (pixels, _) = fileDataLoader.load(File("../cache/datasets/small-dogs-vs-cats/$animal/$fileName"))
             val predicted = loadedModel.predict(pixels)
             println("Predicted object for $fileName is $predicted")
         }
     }
 }

CATS
Predicted object for cat.0.jpg is 0
Predicted object for cat.1.jpg is 0
Predicted object for cat.2.jpg is 0
Predicted object for cat.3.jpg is 0
Predicted object for cat.4.jpg is 0
Predicted object for cat.5.jpg is 0
Predicted object for cat.6.jpg is 0
Predicted object for cat.7.jpg is 0
Predicted object for cat.8.jpg is 0
Predicted object for cat.9.jpg is 0
Predicted object for cat.10.jpg is 0
Predicted object for cat.11.jpg is 0
Predicted object for cat.12.jpg is 0
Predicted object for cat.13.jpg is 0
Predicted object for cat.14.jpg is 0
Predicted object for cat.15.jpg is 0
Predicted object for cat.16.jpg is 0
Predicted object for cat.17.jpg is 0
Predicted object for cat.18.jpg is 0
Predicted object for cat.19.jpg is 0
Predicted object for cat.20.jpg is 0
Predicted object for cat.21.jpg is 0
Predicted object for cat.22.jpg is 0
Predicted object for cat.23.jpg is 0
Predicted object for cat.24.jpg is 0
Predicted object for cat.25.jpg is 0
Predicted object for cat.26.jpg is 0
Predic