In [1]:
!pip3 install tqdm torch torchvision enet-seifeddine-dridi --index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/

Looking in indexes: https://download.pytorch.org/whl/cu118, https://test.pypi.org/simple/, https://pypi.org/simple/


In [2]:
from time import time

import torch
from tqdm import tqdm
from lanenet.enet.config import EnetConfig
from lanenet.enet.model_utils import load_model, eval_model, compute_loss, segment_image

In [3]:
pretrained_model_path = 'pretrained_model/enet_model_2116.pt'
# pretrained_model_path = None
config = EnetConfig(pretrained_model_path=pretrained_model_path, train_full_model=True, max_epoch=1, dataset_root_folder='datasets/cityscapes/data_unzipped')
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.multiprocessing.set_start_method('spawn')
torch.set_flush_denormal(True)
model, optimizer, train_dataset, test_dataset = load_model(config, device)
train_dataset_iter = iter(train_dataset)
test_dataset_iter = iter(test_dataset)

last_checkpoint_saving_time = time()
saving_period = 2 * 60  # 2 minutes
best_eval_loss = float('inf')

Checkpoint file pretrained_model/enet_model_2116.pt successfully loaded


In [5]:
bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ''{rate_noinv_fmt}{postfix}]'
progress_bar = tqdm(range(0, 1), bar_format=bar_format)
for epoch in progress_bar:
    try:
        in_tensor, target = next(train_dataset_iter)
    except StopIteration:
        # Iterator is exhausted
        train_dataset_iter = iter(train_dataset)
        in_tensor, target = next(train_dataset_iter)
    in_tensor = in_tensor.to(device)
    target = target.to(device)
    logits = model(in_tensor)
    loss = compute_loss(logits, target, config.custom_weight_scaling_const, config.scaling_props_range)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    progress_bar.set_description(f"Epoch [{epoch}/{config.max_epoch}]")
    progress_bar.set_postfix(loss=loss.item())
    if time() - last_checkpoint_saving_time >= saving_period:
        last_checkpoint_saving_time = time()
        try:
            eval_loss = eval_model(model, test_dataset_iter, config.custom_weight_scaling_const,
                                   config.scaling_props_range, device)
        except StopIteration:
            # Iterator is exhausted
            test_dataset_iter = iter(test_dataset)
            eval_loss = eval_model(model, test_dataset_iter, config.custom_weight_scaling_const,
                                   config.scaling_props_range, device)
        if eval_loss < best_eval_loss:
            config.save_checkpoint(model, optimizer, loss, epoch)
            best_eval_loss = eval_loss

Epoch [0/100]: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  0.39it/s, loss=3.44]


In [4]:
    try:
        in_tensor, target = next(test_dataset_iter)
    except StopIteration:
        # Iterator is exhausted
        test_dataset_iter = iter(test_dataset)
        in_tensor, target = next(test_dataset_iter)
    in_tensor = in_tensor.to(device)
    target = target.to(device)

In [None]:
from lanenet.enet.model_utils import segment_image

segment_image(model, in_tensor, device)