In [1]:
import math
import time

import h5py
import matplotlib.pyplot as plt
import mlflow
import numpy as np
from numpy.random import choice
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.optim import SGD, Optimizer, Adam, RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from tqdm.auto import tqdm as tq
from craterdetection.detection.loss_functions.dice import f_score

from craterdetection.detection.enet import ENet
from craterdetection.detection.training import CraterDataset, RAdam, dice_coefficient

from craterdetection.detection.loss_functions.lovasz import LovaszHingeLoss

from torchsummary import summary

# Prefer the GPU when PyTorch reports one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# MLflow configuration: which tracking server to log to and under which experiment name.
MLFLOW_TRACKING_URI = "http://localhost:5000/"
MLFLOW_EXPERIMENT_NAME = "crater-detection"

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

In [3]:
# Build one dataset and one DataLoader per split ("training", "validation", "test"),
# all read from the file at `dataset_path`.
dataset_path = "../data/dataset4.h5"

train_dataset = CraterDataset(dataset_path, "training", device)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True, num_workers=0)

validation_dataset = CraterDataset(dataset_path, "validation", device)
# Must wrap validation_dataset: wrapping train_dataset here would make every
# validation metric a measurement on the training data.
validation_loader = DataLoader(validation_dataset, batch_size=20, num_workers=0)

test_dataset = CraterDataset(dataset_path, "test", device)
# Must wrap test_dataset for the same reason; shuffled so each draw of a batch
# shows different samples.
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True, num_workers=0)

In [4]:
# Instantiate ENet with a single output class and move it to the selected device.
model = ENet(num_classes=1)
model.to(device)
# Print a layer-by-layer summary for a (1, 256, 256) input, i.e. one input channel.
# NOTE(review): the saved output of this cell is a RuntimeError raised by this call:
# a layer with weight [4, 16, 2, 2] received 14 channels instead of 16.
# Presumably ENet's first block is built for a different number of input channels
# than the single channel given here -- confirm against the ENet definition.
summary(model, input_size=(1, 256, 256))

RuntimeError: Given groups=1, weight of size [4, 16, 2, 2], expected input[2, 14, 128, 128] to have 16 channels, but got 14 channels instead

In [None]:
# Qualitative check: for one batch drawn from `test_loader`, show the input image,
# the mask supplied by the loader, and the model's sigmoid output side by side.
images, masks = next(iter(test_loader))
images, masks = images.to(device), masks.to(device)

model.eval()
with torch.no_grad():
    # Sigmoid squashes the raw model output into [0, 1] for display.
    out = torch.sigmoid(model(images))
    # out = out > 0.5

n_samples = len(images)
# squeeze=False keeps `axes` two-dimensional even when the batch holds a single
# image, so the axes[i, j] indexing below never fails. Height scales with the
# number of rows (6 in per row, i.e. 15x60 for a batch of 10).
fig, axes = plt.subplots(n_samples, 3, figsize=(15, 6 * n_samples), squeeze=False)

# Label the columns once, on the first row, so the figure stands on its own.
for ax, title in zip(axes[0], ("Image", "Mask", "Prediction")):
    ax.set_title(title)

for i in range(n_samples):
    axes[i, 0].imshow(images[i, 0].cpu().numpy()*255, cmap='Greys_r')
    axes[i, 1].imshow(masks[i, 0].cpu().numpy(), cmap='Greys_r')
    axes[i, 2].imshow(out[i, 0].cpu().numpy(), cmap='Greys_r')
fig.tight_layout()