In [1]:
import os

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from src.dataset import TrashNet
from src.models import get_model
from src.trainer import WasteClassifier

# Configuration

Adjust these settings to suit your needs.

In [2]:
class config:
    """User-tunable settings shared by the training, test and export cells below."""

    # --- Reproducibility & hardware ---
    SEED = 42
    ACCELERATOR = "gpu"

    # --- Trainer ---
    EPOCHS = 200
    BATCH_SIZE = 2
    VAL_EACH_EPOCH = 2  # used as the Trainer's check_val_every_n_epoch

    # --- Data ---
    DATA_DIR = "./split_data"

    # --- TensorBoard logger: runs are written under <DIR>/<NAME>/<VERSION> ---
    TENSORBOARD = dict(DIR="", NAME="LOG", VERSION="0")

    # --- Checkpoints: kept inside the TensorBoard run directory, in a CKPT subfolder ---
    CHECKPOINT_DIR = os.path.join(
        TENSORBOARD["DIR"],
        TENSORBOARD["NAME"],
        TENSORBOARD["VERSION"],
        "CKPT",
    )

    # Checkpoint path handed to trainer.test in the Test section.
    TEST_CKPT_PATH = None

    # Checkpoint path handed to trainer.fit to resume training (None = no resume).
    CONTINUE_TRAINING = None

# Train

In [None]:
# Seed the RNGs first so model initialisation below is reproducible.
seed_everything(config.SEED)

# Backbone wrapped in the Lightning module, plus the datamodule feeding it.
model = get_model()
system = WasteClassifier(model=model)
dm = TrashNet(data_dir=config.DATA_DIR, batch_size=config.BATCH_SIZE)

# Keep the 3 checkpoints with the lowest val_loss; stop once val_loss stops improving.
checkpoint_callback = ModelCheckpoint(
    dirpath=config.CHECKPOINT_DIR,
    monitor="val_loss",
    mode="min",
    save_top_k=3,
)
early_stopping = EarlyStopping(monitor="val_loss", mode="min")

logger = TensorBoardLogger(
    save_dir=config.TENSORBOARD["DIR"],
    name=config.TENSORBOARD["NAME"],
    version=config.TENSORBOARD["VERSION"],
)

trainer = Trainer(
    accelerator=config.ACCELERATOR,
    max_epochs=config.EPOCHS,
    check_val_every_n_epoch=config.VAL_EACH_EPOCH,
    accumulate_grad_batches=5,  # gradients from 5 batches are accumulated per optimizer step
    gradient_clip_val=1.0,
    deterministic=True,
    enable_checkpointing=True,
    default_root_dir=config.CHECKPOINT_DIR,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
    log_every_n_steps=10,
)

# Resumes from config.CONTINUE_TRAINING when it is set; otherwise starts a fresh run.
trainer.fit(model=system, datamodule=dm, ckpt_path=config.CONTINUE_TRAINING)

# Test

Set `config.TEST_CKPT_PATH` to `path/to/your/checkpoint.ckpt` before running the cell below.

In [None]:
# Evaluate on the test split using a fresh module instance; the weights are
# restored from a checkpoint by trainer.test.
model = get_model()
system = WasteClassifier(model=model)

# Per the Lightning docs, passing a model with ckpt_path=None tests that model's
# *current* weights -- here freshly initialised, i.e. untrained. When no explicit
# checkpoint is configured, fall back to "best": the best checkpoint tracked by
# the ModelCheckpoint callback during the trainer.fit(...) run in this kernel.
test_ckpt_path = config.TEST_CKPT_PATH if config.TEST_CKPT_PATH is not None else "best"
trainer.test(model=system, datamodule=dm, ckpt_path=test_ckpt_path)

# Export to ONNX

In [None]:
import torch
from src.models import ConvNext
from src.dataset import WasteDataset

In [None]:
# Checkpoint (.ckpt) whose weights are exported to ONNX in the cells below.
# Defaults to the checkpoint configured for the Test section; replace it with a
# path string here to export a different checkpoint.
path_to_ckpt = config.TEST_CKPT_PATH

In [None]:
# Rebuild the classifier on CPU and load trained weights from `path_to_ckpt` for export.
# Fail early with an actionable message instead of letting torch.load(None) error obscurely.
if path_to_ckpt is None:
    raise ValueError(
        "path_to_ckpt is not set: point it at a trained .ckpt file before exporting to ONNX."
    )

# from_pretrained=False: presumably skips loading pretrained backbone weights,
# since the checkpoint below supplies all weights anyway -- TODO confirm.
model = ConvNext(from_pretrained=False)
system = WasteClassifier(model=model)

ckpt = torch.load(path_to_ckpt, map_location='cpu')
# assumes a Lightning checkpoint whose 'state_dict' keys match WasteClassifier -- TODO confirm
system.load_state_dict(ckpt['state_dict'])
system.eval()  # inference mode (e.g. dropout/batch-norm behaviour) for tracing
print("Done")

In [None]:
# Take the first sample of the dataset to use as the ONNX tracing input.
dataset = WasteDataset(root=config.DATA_DIR)
first_sample = next(iter(dataset), None)
if first_sample is None:
    # Without this check `img` would stay undefined and the export cell would fail with a NameError.
    raise RuntimeError(
        f"WasteDataset at {config.DATA_DIR!r} yielded no samples; cannot build an ONNX input sample."
    )
img, label = first_sample
print(img.size())
print(label)

In [None]:
# Prepend a dimension of size 1 (presumably the batch axis) so the single sample
# matches the model's batched input, then trace and write the ONNX graph.
onnx_input_sample = img.unsqueeze(0)
system.to_onnx("model.onnx", input_sample=onnx_input_sample)