In [4]:
import os
import json
from functions import image_segmentation_generator

def train(model,
          train_images,
          train_annotations,
          val_images,
          val_annotations,
          input_height=None,
          input_width=None,
          n_classes=None,
          checkpoints_path = "checkpoints",
          epochs=5,
          batch_size=32,
          steps_per_epoch=512,
          val_steps_per_epoch=512,
          load_weights=None,
          read_image_type=1):  # cv2.IMREAD_COLOR = 1 (rgb)
    
    os.environ['PYTHONIOENCODING'] = 'utf-8'
    n_classes = model.n_classes
    input_height = model.input_height
    input_width = model.input_width
    output_height = model.output_height
    output_width = model.output_width
    
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Load weights if provided
    # if load_weights:
    #     print(f"Loading weights from {load_weights}")
    #     model.load_weights(load_weights)
    
    # Save model configuration
    # if checkpoints_path:
    #     config_file = os.path.join(checkpoints_path, "config.json")
    #     os.makedirs(os.path.dirname(config_file), exist_ok=True)
    #     with open(config_file, "w") as f:
    #         json.dump({
    #             "model_class": model.model_name,
    #             "n_classes": n_classes,
    #             "input_height": input_height,
    #             "input_width": input_width,
    #             "output_height": model.output_height,
    #             "output_width": model.output_width
    #         }, f)
    
    train_gen = image_segmentation_generator(train_images, train_annotations, batch_size, n_classes,
                                             input_height, input_width, model.output_height, model.output_width,
                                             read_image_type=read_image_type)
    
    val_gen = image_segmentation_generator(val_images, val_annotations, batch_size, n_classes,
                                           input_height, input_width, model.output_height, model.output_width,
                                           read_image_type=read_image_type)
    data, labels = next(train_gen)
    print("Data batch shape:", data.shape)
    print("Labels batch shape:", labels.shape)
    
    model.fit(train_gen, 
              steps_per_epoch=steps_per_epoch, 
              epochs=epochs,
              validation_data=val_gen, 
              validation_steps=val_steps_per_epoch, verbose =1)
    
    # if checkpoints_path:
    #     weights_path = os.path.join(checkpoints_path, "model_weights.h5")
    #     print(f"Saving model weights to {weights_path}")
    #     model.save_weights(weights_path)


In [5]:
import os

from model import fcn_8_vgg

# --- Configuration ---------------------------------------------------------
batch_size = 32
train_images = "training_data/train_images"
train_annotations = "training_data/train_annotations"
val_images = "training_data/val_images"
val_annotations = "training_data/val_annotations"
checkpoints_path = "checkpoints"
n_classes = 27
input_height = 224
input_width = 320
epochs = 5
load_weights = None

# Roughly one pass over each split per epoch; floor division drops the
# remainder. Both splits are counted the same way, by image files. This
# assumes every image has a matching annotation (TODO confirm).
# max(1, ...) avoids handing Keras 0 steps when a folder holds fewer files
# than one batch.
steps_per_epoch = max(1, len(os.listdir(train_images)) // batch_size)
val_steps_per_epoch = max(1, len(os.listdir(val_images)) // batch_size)
print(steps_per_epoch)
print(val_steps_per_epoch)

model = fcn_8_vgg(n_classes=n_classes, input_height=input_height, input_width=input_width)

175
43


In [6]:
# Training is long-running (over two hours per epoch in the recorded run).
# Keep the returned result so loss/accuracy curves can be inspected without
# re-training, and pass the configured checkpoint directory explicitly.
history = train(
    model=model,
    train_images=train_images,
    train_annotations=train_annotations,
    val_images=val_images,
    val_annotations=val_annotations,
    checkpoints_path=checkpoints_path,
    epochs=epochs,
    batch_size=batch_size,
    steps_per_epoch=steps_per_epoch,
    val_steps_per_epoch=val_steps_per_epoch,
    load_weights=load_weights,
)

Data batch shape: (32, 224, 320, 3)
Labels batch shape: (32, 76096, 27)
Epoch 1/5
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7638s[0m 44s/step - accuracy: 0.3811 - loss: 12.4670 - val_accuracy: 0.6546 - val_loss: 1.1846
Epoch 2/5
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8301s[0m 47s/step - accuracy: 0.6654 - loss: 1.1590 - val_accuracy: 0.7043 - val_loss: 1.0260
Epoch 3/5
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8414s[0m 48s/step - accuracy: 0.7068 - loss: 1.0150 - val_accuracy: 0.7266 - val_loss: 0.9389
Epoch 4/5
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8410s[0m 48s/step - accuracy: 0.7253 - loss: 0.9395 - val_accuracy: 0.7385 - val_loss: 0.8814
Epoch 5/5
[1m175/175[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8653s[0m 49s/step - accuracy: 0.7378 - loss: 0.8816 - val_accuracy: 0.7427 - val_loss: 0.8499


In [7]:
import os

# Persist the trained weights. The directory is created first so this works
# on a fresh checkout. Keras 3 requires the ".weights.h5" suffix for
# save_weights.
os.makedirs(checkpoints_path, exist_ok=True)
weights_path = os.path.join(checkpoints_path, "model.weights.h5")
print(f"Saving model weights to {weights_path}")
model.save_weights(weights_path)

Saving model weights to checkpoints\model.weights.h5
