In [None]:
import os
from ultralytics import YOLO

# Working directory: every relative path used below (dataset yaml, runs/ output) resolves
# against ROOT. Override it with the WORKSPACE_ROOT environment variable rather than
# editing a user-specific absolute path.
ROOT = os.environ.get("WORKSPACE_ROOT", "/home/suwako/workspace/")
os.chdir(ROOT)

# dataset config (relative to ROOT)
ds = "data/kaist-sanit.yaml"

ch = 4  # number of input channels
nc = 1  # number of classes

# last expression of the cell, so Jupyter actually displays the directory we switched to
os.getcwd()


## 1. Full training from scratch

In [None]:
# Baseline: YOLOv8s built from its architecture YAML (no pretrained .pt weights are loaded).
scratch_model = YOLO("yolov8s.yaml")
scratch_model.train(
    project="runs/transfer",
    name=f"kaist-scratch-{ch}ch",
    data=ds,
    epochs=30,
)

# Drop the reference so the model can be garbage-collected before the next experiment.
del scratch_model


## 2. Transfer learning with frozen layers, then fine-tuning

In [None]:
def freez_layer(trainer, *, in_channels=None, num_classes=None):
    """``on_train_start`` callback: freeze the whole network except the detection head
    and, when its channel count differs from ``ch``, the input conv layer.

    Args:
        trainer: ultralytics trainer. ``trainer.model.model`` is expected to be the
            ``nn.Sequential`` of layers whose first entry has a ``.conv`` attribute and
            whose last entry is the detection head (with ``.cv3`` class branches).
        in_channels: expected number of input channels; defaults to the notebook-level ``ch``.
        num_classes: expected number of classes; defaults to the notebook-level ``nc``.

    Returns:
        list[str]: names (relative to ``trainer.model.model``) of the parameters left
        trainable. The trainer does not use callback return values; this is for inspection.
    """
    ch_in = ch if in_channels is None else in_channels
    n_cls = nc if num_classes is None else num_classes

    layers = trainer.model.model

    # Start from a fully frozen network and selectively re-enable parts below.
    # NOTE: requires_grad only stops gradient updates; BatchNorm running statistics of
    # frozen layers still update while the model is in train mode.
    layers.requires_grad_(False)

    children = list(layers.children())
    input_layer, output_layer = children[0], children[-1]

    # Inspect rather than modify: assigning `conv.in_channels` would only relabel the module
    # without resizing its weight tensor. NOTE(review): at on_train_start the optimizer has
    # presumably been built already, so replacing the layer here would not train it either —
    # adapt the channel count when the model is constructed.
    current_in = input_layer.conv.in_channels
    if current_in != ch_in:
        print(
            f"WARNING: input conv expects {current_in} channels but ch={ch_in}; "
            "its weights are not resized here (unfreezing it anyway)"
        )
        input_layer.requires_grad_(True)

    # Same reasoning for the class outputs: check them instead of assigning `out_channels`.
    wrong_outputs = [
        branch[2].out_channels
        for branch in output_layer.cv3
        if branch[2].out_channels != n_cls
    ]
    if wrong_outputs:
        print(f"WARNING: detection head class outputs {wrong_outputs} do not match nc={n_cls}")

    # Re-enable the head, except the DFL conv: its weight is a fixed projection that
    # ultralytics' trainer always keeps frozen (`.dfl` is in its always-frozen list).
    for name, param in output_layer.named_parameters():
        if not name.startswith("dfl."):
            param.requires_grad = True

    return [name for name, param in layers.named_parameters() if param.requires_grad]


In [None]:
# Transfer learning, stage 1: load the pretrained YOLOv8s checkpoint and register
# `freez_layer` so that, once training starts, only the layers it selects stay trainable.
freez_model = YOLO("yolov8s.pt")
freez_model.add_callback("on_train_start", freez_layer)

freez_model.train(
    project="runs/transfer",
    name=f"kaist-freez-{ch}ch",
    data=ds,
    epochs=50,
    close_mosaic=0,  # 0 = never switch mosaic augmentation off near the end
)


In [None]:
# Fine-tune: continue from the best checkpoint of the frozen-transfer run with every layer
# trainable (no freezing callback is registered on this model).
try:
    tgr = freez_model.metrics.save_dir / "weights/best.pt"
except (NameError, AttributeError):
    # The frozen-transfer cell was not run in this kernel (or produced no metrics):
    # fall back to its default run directory, built from `ch` like that cell's run name.
    # NOTE(review): ultralytics appends a suffix (e.g. "-4ch2") when the run directory
    # already exists — check that this path points at the intended run.
    tgr = f"runs/transfer/kaist-freez-{ch}ch/weights/best.pt"

# Drop the frozen-transfer model if it exists; an unconditional `del` would raise
# NameError on exactly the fallback path above.
if "freez_model" in globals():
    del freez_model

tune_model = YOLO(tgr)

tune_model.train(
    data=ds,
    epochs=15,
    close_mosaic=5,  # switch mosaic augmentation off for the last 5 epochs
    project="runs/transfer",
    name=f"kaist-finetune-{ch}ch",
)

del tune_model


## 3. Full transfer learning (no frozen layers)

In [None]:
# Full transfer learning: start from the pretrained YOLOv8s checkpoint and train without
# registering any freezing callback.
trans_model = YOLO("yolov8s.pt")
trans_model.train(
    project="runs/transfer",
    name=f"kaist-trans-{ch}ch",
    data=ds,
    epochs=30,
)


## 4. Results

In [None]:
# yolo cfg=cfg/yolov8-aug.yaml model=yolov8s.pt data=data/All-Season.yaml name=All-Season-pre100