In [None]:
import torch

In [None]:
# Logging
import logging
# Configure the root logger to write INFO-and-above records to 'mse.log'
# (basicConfig does nothing if the root logger already has handlers), then
# create the logger used for this notebook's progress messages.
logging.basicConfig(level=logging.INFO, filename='mse.log')
logger = logging.getLogger(__name__)

In [None]:
# Training configuration - stated to follow the paper. The trailing comments
# keep the alternative values the author left inline.
IMG_SIZE = (512, 512)
NUM_QUERIES = 11
NUM_EPOCHS = 1000
LEARNING_RATE = 0.007
BATCH_SIZE = 2  # alternative: 16
NUM_WORKERS = 0  # alternative: 16

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Report the selected device both in the cell output and in the log file.
print(DEVICE)
logger.info(f'Device: {DEVICE}')

In [None]:
# Init: Datasets and Dataloader
import os
from mse_dataset import CropRowDataset
from torch.utils.data import DataLoader

# Both splits live under ../datasets/ma_dataset/combined/<split> relative to
# the working directory; each split folder provides 'imgs' and 'masks'.
# (An earlier, now-disabled variant pointed at ./dataset/<split> instead.)
dataset_root = os.path.join(os.path.abspath(''), os.pardir, 'datasets', 'ma_dataset', 'combined')

# Training split - the loader reshuffles the samples.
train_path = os.path.join(dataset_root, 'train')
train_dataset = CropRowDataset(
    label_path=os.path.join(train_path, 'masks'),
    img_path=os.path.join(train_path, 'imgs'),
    img_size=IMG_SIZE,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation split - kept in dataset order (no shuffling).
val_path = os.path.join(dataset_root, 'val')
val_dataset = CropRowDataset(
    label_path=os.path.join(val_path, 'masks'),
    img_path=os.path.join(val_path, 'imgs'),
    img_size=IMG_SIZE,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

logger.info('created datasets and dataloader')

In [None]:
# Init: Model
import torch.optim as optim
from ms_erfnet import MSERFNet
from hungarian_loss import HungarianLoss
# Bookkeeping for a fresh run; the checkpoint-loading cell replaces
# epoch_start, best_loss and patience_counter when a checkpoint exists.
epoch_start = 0
best_loss = float('inf')
patience_limit = 20
patience_counter = 0

# Build the network on the selected device and attach an Adam optimizer
# (learning rate from the hyperparameter cell, weight decay of 1e-5).
model = MSERFNet()
model = model.to(DEVICE)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)

logger.info('created model and optimizer')

In [None]:
# Load Model if already existing
def restore_training_state(checkpoint, model, optimizer):
    """Load weights from a checkpoint and return the resume bookkeeping.

    Args:
        checkpoint: Mapping with the keys 'model_state_dict',
            'optimizer_state_dict', 'epoch', 'loss' and 'patience_counter',
            i.e. the layout written by the training loop's periodic save.
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.

    Returns:
        Tuple ``(epoch_start, best_loss, patience_counter)`` where
        ``epoch_start`` is the first epoch index that still has to run and
        ``best_loss`` is a plain Python float.
    """
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The checkpoint stores the 0-based index of the last *completed* epoch,
    # so training must resume with the next one; using the stored value
    # directly would train that epoch a second time.
    resume_epoch = checkpoint['epoch'] + 1
    # The stored loss may be a 0-dim tensor (possibly on another device);
    # normalise it so later comparisons and saves deal with a float.
    restored_best = float(checkpoint['loss'])
    return resume_epoch, restored_best, checkpoint['patience_counter']


checkpoint_path = os.path.join(os.path.abspath(''), 'mse_checkpoint.pt.tar')
if os.path.isfile(checkpoint_path):
    logger.info('found existing model')
    # map_location moves stored tensors onto the device selected for this run.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    epoch_start, best_loss, patience_counter = restore_training_state(checkpoint, model, optimizer)
    logger.info('loaded existing model for continuation of training')

In [None]:
# Training Cycle
from tqdm import tqdm
import matplotlib.pyplot as plt
# Per epoch: one optimisation pass over train_loader, one gradient-free pass
# over val_loader, a snapshot of the weights whenever the validation loss
# improves, and a resumable checkpoint after every 10th epoch.
for epoch in range(epoch_start, NUM_EPOCHS, 1):
    model.train()
    epoch_loss = 0.0
    for train_batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}'):
        train_images = train_batch['image'].to(DEVICE)
        train_gt_params = train_batch['gt'].to(DEVICE)
        
        #===DEBUG===
        # show_mask = train_gt_params[0].squeeze(0).cpu().numpy()
        # show_img = train_images[0].cpu().permute(1,2,0).numpy()
        # fig, axs = plt.subplots(1, 2)
        # axs[0].imshow(show_img)
        # axs[0].axis('off')
        # axs[1].imshow(show_mask, cmap='gray', vmin=0, vmax=1)
        # axs[1].axis('off')
        # plt.tight_layout()
        # plt.show()
        #===DEBUG END===
        
        optimizer.zero_grad()

        train_pred_params = model(train_images)
        # Drop dim 1 and cast to integer class indices for cross_entropy
        # (assumes 'gt' has a singleton dim 1, e.g. (N, 1, H, W) - TODO confirm).
        train_gt_params = train_gt_params.squeeze(1).long()
        batch_loss = torch.nn.functional.cross_entropy(train_pred_params, train_gt_params, reduction='mean')

        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    avg_train_loss = epoch_loss / len(train_loader)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_batch in tqdm(val_loader, desc='Validation'):
            val_imgs = val_batch['image'].to(DEVICE)
            val_gt_params = val_batch['gt'].to(DEVICE)
            val_pred_params = model(val_imgs)
            val_gt_params = val_gt_params.squeeze(1).long()
            # .item() keeps the running sum a Python float, matching the
            # training loss; otherwise avg_val_loss and best_loss become
            # device tensors that leak into the printout, log and checkpoint.
            val_loss += torch.nn.functional.cross_entropy(val_pred_params, val_gt_params, reduction='mean').item()
    avg_val_loss = val_loss / len(val_loader)
    print(f'train_loss: {avg_train_loss} | val_loss: {avg_val_loss}')

    if avg_val_loss < best_loss:
        # New best validation loss: keep a weights-only snapshot.
        best_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_mse.pt')
        print('New Model saved')
        patience_counter = 0
        logger.info(f'updated best_model in epoch {epoch+1} with training_loss: {avg_train_loss} and validation_loss: {avg_val_loss}')
    # Early stopping is currently disabled, so patience_counter is only ever
    # reset above and never incremented:
    #else:
        #patience_counter += 1
        #if patience_counter >= patience_limit:
            #print('Should: Early stop due to no improvement in validation')
            #logger.info(f'Should: Early stopping in epoch {epoch}')
            #break

    if (epoch+1) % 10 == 0:
        # Resumable checkpoint; 'epoch' is the 0-based index of the epoch
        # that has just finished.
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_loss,
                    'patience_counter': patience_counter},
                   'mse_checkpoint.pt.tar')
        logger.info(f'created checkpoint for epoch {epoch+1} with training_loss: {avg_train_loss} and validation_loss: {avg_val_loss}')