In [1]:
# Run configuration shared by train.main (training) and train.get_datasets (evaluation) below.
args = dict(
    image_path='data/raw/image.jp2',
    mask_path='data/raw/mask.jp2',
    patch_size=512,
    path_to_save='data',
    save_images=True,
    # NOTE(review): 'cashe' looks like a typo of 'cache'; kept as-is so existing cached stats are found.
    mean_std_path='cashe/mean_std.pth',
    test_ratio=0.2,
    device='cuda',
    batch_size=16,
    num_workers=0,
    lr=1e-3,
    momentum=0.9,  # presumably SGD momentum — TODO confirm in train.py
    weight_decay=1e-4,
    epochs=50,
    encoder_name='efficientnet-b2',
    encoder_weights='imagenet',
    crop_size=256,
    model_save_path='model/best_model.pth',
    class_weight=550,  # presumably the loss weight for the rare positive class — TODO confirm in train.py
)

In [2]:
from train import main, get_datasets, get_mean_std
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import torch
import matplotlib.pyplot as plt

In [3]:
# Train the model with the config above. Judging by the output below, main() reports per-epoch
# train/test metrics and saves the best checkpoint to args['model_save_path'].
# NOTE(review): no random seed is set in this notebook, so runs are not reproducible unless
# train.py seeds internally — TODO confirm.
main(args)



  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 1| TRAIN LOSS: 1.926|  TEST LOSS: 1.777|  TEST_PRECISION: 0.001| TEST_RECALL: 0.522| TEST_ACCURACY: 0.501| TEST_F1: 0.002| TEST_IOU: 0.001| 
Saving new best model at model/best_model.pth


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 2| TRAIN LOSS: 1.660|  TEST LOSS: 1.408|  TEST_PRECISION: 0.002| TEST_RECALL: 0.558| TEST_ACCURACY: 0.463| TEST_F1: 0.004| TEST_IOU: 0.002| 
Saving new best model at model/best_model.pth


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 3| TRAIN LOSS: 1.613|  TEST LOSS: 1.276|  TEST_PRECISION: 0.002| TEST_RECALL: 0.813| TEST_ACCURACY: 0.433| TEST_F1: 0.004| TEST_IOU: 0.002| 
Saving new best model at model/best_model.pth


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 4| TRAIN LOSS: 1.517|  TEST LOSS: 1.357|  TEST_PRECISION: 0.002| TEST_RECALL: 0.856| TEST_ACCURACY: 0.307| TEST_F1: 0.004| TEST_IOU: 0.002| 


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 5| TRAIN LOSS: 1.476|  TEST LOSS: 1.292|  TEST_PRECISION: 0.002| TEST_RECALL: 0.875| TEST_ACCURACY: 0.315| TEST_F1: 0.004| TEST_IOU: 0.002| 


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 6| TRAIN LOSS: 1.611|  TEST LOSS: 1.275|  TEST_PRECISION: 0.002| TEST_RECALL: 0.876| TEST_ACCURACY: 0.317| TEST_F1: 0.004| TEST_IOU: 0.002| 
Saving new best model at model/best_model.pth


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 7| TRAIN LOSS: 1.490|  TEST LOSS: 1.318|  TEST_PRECISION: 0.002| TEST_RECALL: 0.850| TEST_ACCURACY: 0.327| TEST_F1: 0.004| TEST_IOU: 0.002| 


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 8| TRAIN LOSS: 1.459|  TEST LOSS: 1.266|  TEST_PRECISION: 0.002| TEST_RECALL: 0.769| TEST_ACCURACY: 0.309| TEST_F1: 0.004| TEST_IOU: 0.002| 
Saving new best model at model/best_model.pth


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

EPOCH 9| TRAIN LOSS: 1.426|  TEST LOSS: 1.283|  TEST_PRECISION: 0.002| TEST_RECALL: 0.850| TEST_ACCURACY: 0.352| TEST_F1: 0.004| TEST_IOU: 0.002| 


  0%|          | 0/22 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Build the U-Net skeleton on CPU for inference; its parameters are replaced by the saved
# checkpoint (strict load_state_dict) in the next cell.
# encoder_weights=None: pretrained encoder weights would be overwritten by that checkpoint
# anyway, so there is no point fetching args['encoder_weights'] here (avoids a download /
# failure when offline). The architecture is the same either way.
model = smp.Unet(
    encoder_name=args['encoder_name'],
    encoder_weights=None,
)
model.to('cpu')

In [None]:
# Restore the best checkpoint written by main(), using the path from the config rather than a
# duplicated hard-coded string.
# map_location='cpu': training ran with args['device'] ('cuda'), so the saved tensors may be CUDA
# tensors; mapping them to CPU lets the checkpoint load on a machine without a GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints (newer torch
# versions support weights_only=True).
model.load_state_dict(torch.load(args['model_save_path'], map_location='cpu'))
model.to('cpu')
model.eval()
print('Model loaded')

In [None]:
# Rebuild the datasets from the same config used for training; presumably ts is the training
# split and vs the validation/test split — TODO confirm against train.get_datasets.
# NOTE(review): if the split is random and unseeded, vs may not match the split main() evaluated on.
ts, vs = get_datasets(args)

In [None]:
# Fresh iterator over vs for the visualisation cell below, which draws one (image, mask) sample
# per run with next(). Re-run this cell to restart the iteration.
iter_vs = iter(vs)

In [None]:
# create a new figure
fig = plt.figure(figsize=(10, 5))
org, mask = next(iter_vs)
with torch.inference_mode():
    prediction = model.predict(org.unsqueeze(dim = 0)).squeeze(0)
    prediction_img = torch.moveaxis(F.softmax(prediction, dim = 1), 0, -1)
    
# add the original image subplot to the figure
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(org.permute(1,2,0))
ax1.set_title('Original Image')

# add the mask subplot to the figure
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(mask.permute(1,2,0))
ax2.set_title('Mask')

# add the prediction subplot to the figure
ax3 = fig.add_subplot(1, 3, 3)
ax3.imshow(prediction_img)
ax3.set_title('Prediction')

# adjust the spacing between subplots
fig.tight_layout()

# show the figure
plt.show()