## Test training with real data

In [None]:
import sys
sys.path.append('..')

import os

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import pytorch_lightning as pl


from pytorch_lightning.callbacks import ModelCheckpoint


In [None]:
from dataset import BinarySegmentationDataset, MultiLabelSegmentationDataset
from utils import visualize_gray, visualize_binaire
from train_eval import get_training_augmentation, get_validation_augmentation, get_portrait_validation_augmentation, get_portrait_augmentation, get_landscape_augmentation, get_landscape_validation_augmentation
from train_eval import BinarySegmentationModel, MulticlassSegmentationModel



In [None]:
# Root folder of the dataset on disk. Placeholder value: point it at a real
# dataset before running this cell.
DATA_DIR = "your/path/to/dataset"

# YAML file handed to the dataset constructor
# (presumably maps class names to mask gray levels -- confirm in `dataset`).
yaml_path = f"{DATA_DIR}/class_gray_levels.yaml"

# Image and mask folders of the "val" split, which this test run trains on.
images_dir = f"{DATA_DIR}/val"
masks_dir = f"{DATA_DIR}/val_mask_gray"

binary_dataset = BinarySegmentationDataset.from_yaml(
    images_dir=images_dir,
    masks_dir=masks_dir,
    yaml_path=yaml_path,
    augmentation=get_training_augmentation(),
)

# Shuffled mini-batches of 8 samples, loaded by 4 worker processes.
train_loader = DataLoader(
    dataset=binary_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)


In [None]:
# Binary segmentation model configured with arch "Unet", encoder "resnet34"
# and 3 input channels, trained with focal loss (binary mode) and Adam.
# NOTE(review): T_max=50 is far longer than the 2 epochs trained below; if the
# scheduler steps once per epoch the learning rate will barely decay -- confirm.
model = BinarySegmentationModel(
    arch="Unet",
    encoder_name="resnet34",
    in_channels=3,
    loss_fn=smp.losses.FocalLoss(smp.losses.BINARY_MODE),
    optimizer=Adam,
    optimizer_kwargs={"lr": 2e-4},
    lr_scheduler=CosineAnnealingLR,
    lr_scheduler_kwargs={"T_max": 50, "eta_min": 1e-5},
    save_interval=1,
)

# Validation must not reuse the shuffled training loader: Lightning warns about
# shuffled validation loaders and the batch order would change on every pass.
# Build an unshuffled loader over the same dataset with the same batch
# settings. It still reads the training samples (and their augmentation), so
# the validation metrics of this test run do not measure generalization.
val_loader = DataLoader(
    train_loader.dataset,
    batch_size=train_loader.batch_size,
    shuffle=False,
    num_workers=train_loader.num_workers,
)

# Short 2-epoch run, just enough to exercise the training/validation loop.
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Root folder of the dataset on disk. Placeholder value: point it at a real
# dataset before running this cell.
DATA_DIR = "your/path/to/dataset"

# YAML file handed to the dataset constructor
# (presumably maps class names to mask gray levels -- confirm in `dataset`).
yaml_path = f"{DATA_DIR}/class_gray_levels.yaml"

# Image and mask folders of the "val" split, which this test run trains on.
images_dir = f"{DATA_DIR}/val"
masks_dir = f"{DATA_DIR}/val_mask_gray"

# NOTE(review): despite its name, `binary_dataset` holds a
# MultiLabelSegmentationDataset from here on and replaces any earlier binding;
# the name is kept so that other cells reading it keep working.
binary_dataset = MultiLabelSegmentationDataset.from_yaml(
    images_dir=images_dir,
    masks_dir=masks_dir,
    yaml_path=yaml_path,
    augmentation=get_portrait_augmentation(),
)

# Class count reported by the dataset; the model cell reads `num_classe`.
num_classe = binary_dataset.get_num_classes()
print(num_classe)

# Shuffled mini-batches of 4 samples, loaded by 4 worker processes.
train_loader = DataLoader(
    dataset=binary_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Name of this experiment; used below to build the log and checkpoint paths.
run_name = "test0"

# Multiclass segmentation model configured with arch "Unet", encoder
# "resnet34", 3 input channels and one output per dataset class.
# NOTE(review): the loader is built from a MultiLabelSegmentationDataset while
# the loss uses MULTICLASS_MODE -- confirm the mask format matches the loss.
# NOTE(review): T_max=50 is far longer than the 2 epochs trained below; if the
# scheduler steps once per epoch the learning rate will barely decay -- confirm.
model = MulticlassSegmentationModel(
    arch="Unet",
    encoder_name="resnet34",
    in_channels=3,
    out_classes=num_classe,
    loss_fn=smp.losses.FocalLoss(smp.losses.MULTICLASS_MODE),
    optimizer=Adam,
    optimizer_kwargs={"lr": 2e-4},
    lr_scheduler=CosineAnnealingLR,
    lr_scheduler_kwargs={"T_max": 50, "eta_min": 1e-5},
    save_interval=1,
)

# Keep the single checkpoint with the lowest monitored "val_loss", plus the
# most recent one (save_last).
# Assumes the model logs a metric named "val_loss" -- confirm in `train_eval`.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"logs/{run_name}/best_models",
    filename="best-model-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
)

# Validation must not reuse the shuffled training loader: Lightning warns about
# shuffled validation loaders and the batch order would change on every pass.
# Build an unshuffled loader over the same dataset with the same batch
# settings. It still reads the training samples (and their augmentation), so
# the monitored "val_loss" of this test run does not measure generalization.
val_loader = DataLoader(
    train_loader.dataset,
    batch_size=train_loader.batch_size,
    shuffle=False,
    num_workers=train_loader.num_workers,
)

# Short 2-epoch run; logs and checkpoints go under logs/<run_name>.
trainer = pl.Trainer(
    max_epochs=2,
    callbacks=[checkpoint_callback],
    default_root_dir=f"logs/{run_name}",
)
trainer.fit(model, train_loader, val_loader)