### **Check the backbone models and the structured segmentation models.**

In [1]:
import torch
from model_backbone.xception65 import Xception65
from model_backbone.mobilenet import MobileNetV2
from model_seg.deeplabv3plus import deeplabv3plus
from model_seg.deeplabv3plus_mobilev2 import deeplabv3plus_mobilev2
from model_seg.unet import unet
from model_seg.surface_water.gmnet import gmnet
from model_seg.surface_water.watnet import watnet
from model_seg.hrnet import hrnet
from torchsummary import summary
import matplotlib.pyplot as plt


##### Input data simulation. We assume input data with a patch size of 256x256, 4 bands, and a batch size of 4.
##### Note: gmnet requires multiscale patches as input, so multiscale patches are also simulated here.


In [2]:
# Fix the RNG so the simulated tensors are identical on every Restart & Run All.
torch.manual_seed(0)

BATCH_SIZE, NUM_BANDS, PATCH_SIZE = 4, 4, 256
N_SCALES = 3  # gmnet takes a list of multiscale patches; three are simulated here

# Single-scale input: (batch, bands, height, width).
input = torch.randn(BATCH_SIZE, NUM_BANDS, PATCH_SIZE, PATCH_SIZE)
# Multiscale input for gmnet: every scale is simulated at PATCH_SIZE x PATCH_SIZE
# (presumably real multiscale patches are resampled to a common size — TODO confirm).
input_scales = [torch.randn(BATCH_SIZE, NUM_BANDS, PATCH_SIZE, PATCH_SIZE) for _ in range(N_SCALES)]
# Simulated single-channel label tensor; not used by the shape checks below.
truth = torch.randn(BATCH_SIZE, 1, PATCH_SIZE, PATCH_SIZE)

# To check against a real sample instead, load a saved patch, e.g.:
#   input_scales, truth = torch.load(f='data/test_patches/patch_000.pt')
#   input_scales = [torch.unsqueeze(x, 0) for x in input_scales]


### Check backbone networks.

In [3]:
# Backbone check: Xception65 with num_classes=2 maps the input batch to one
# score vector per sample, as the recorded output (batch, num_classes) shows.
model = Xception65(num_bands=4, num_classes=2)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output shape:', outp.shape)
assert outp.shape == (input.shape[0], 2), f'unexpected output shape: {tuple(outp.shape)}'


output shape: torch.Size([4, 2])


In [4]:
# Backbone check: MobileNetV2 with num_classes=2 should also give (batch, num_classes).
# NOTE(review): the two extra torch.Size lines in this cell's recorded output are not
# printed by this cell; presumably they are leftover debug prints inside
# model_backbone.mobilenet — TODO remove them there.
model = MobileNetV2(num_bands=4, num_classes=2)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output shape:', outp.shape)
assert outp.shape == (input.shape[0], 2), f'unexpected output shape: {tuple(outp.shape)}'


torch.Size([4, 4, 256, 256])
torch.Size([4, 32, 128, 128])
output shape: torch.Size([4, 2])


### Check segmentation networks.

In [5]:
# U-Net: the segmentation output keeps the batch size and the input's spatial size.
model = unet(num_bands=4, num_classes=2)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output:', outp.shape)
# The recorded output is (batch, 1, H, W) for num_classes=2, i.e. a single-channel map
# (presumably a binary score map — TODO confirm).
assert outp.shape == (input.shape[0], 1, *input.shape[2:]), f'unexpected output shape: {tuple(outp.shape)}'


output: torch.Size([4, 1, 256, 256])


In [6]:
# DeepLabv3+ (Xception backbone): the output should be at full input resolution.
model = deeplabv3plus(num_bands=4, num_classes=2)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output:', outp.shape)
# The recorded output is (batch, 1, H, W) for num_classes=2.
assert outp.shape == (input.shape[0], 1, *input.shape[2:]), f'unexpected output shape: {tuple(outp.shape)}'


output: torch.Size([4, 1, 256, 256])


In [7]:
# DeepLabv3+ with a MobileNetV2 backbone.
# channels_fea: presumably the channel widths of the backbone feature maps fed to
# the decoder — TODO confirm against model_seg.deeplabv3plus_mobilev2.
model = deeplabv3plus_mobilev2(num_bands=4, num_classes=2, channels_fea=[16, 24, 64])
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output:', outp.shape)
# The recorded output is (batch, 1, H, W) for num_classes=2.
assert outp.shape == (input.shape[0], 1, *input.shape[2:]), f'unexpected output shape: {tuple(outp.shape)}'


output: torch.Size([4, 1, 256, 256])


In [8]:
# WatNet (surface-water model): single-scale input, full-resolution output.
model = watnet(num_bands=4, num_classes=2)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output:', outp.shape)
# The recorded output is (batch, 1, H, W) for num_classes=2.
assert outp.shape == (input.shape[0], 1, *input.shape[2:]), f'unexpected output shape: {tuple(outp.shape)}'


output: torch.Size([4, 1, 256, 256])


In [9]:
# HRNet: the output should be at full input resolution.
model = hrnet(num_bands=4, num_classes=2)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input)
print('output:', outp.shape)
# The recorded output is (batch, 1, H, W) for num_classes=2.
assert outp.shape == (input.shape[0], 1, *input.shape[2:]), f'unexpected output shape: {tuple(outp.shape)}'


output: torch.Size([4, 1, 256, 256])


In [13]:
# GMNet consumes the list of multiscale patches rather than a single tensor.
# scale_high/scale_mid/scale_low: presumably the original patch extents (pixels) of the
# three scales before they were brought to a common size — TODO confirm in model_seg.surface_water.gmnet.
model = gmnet(num_bands=4, num_classes=2, scale_high=2048, scale_mid=512, scale_low=256)
with torch.no_grad():  # forward pass is only for shape checking; skip the autograd graph
    outp = model(input_scales)
print('output:', outp.shape)
# The recorded output is (batch, 1, H, W) at the resolution of the simulated patches.
ref_shape = input_scales[0].shape
assert outp.shape == (ref_shape[0], 1, *ref_shape[2:]), f'unexpected output shape: {tuple(outp.shape)}'



output: torch.Size([4, 1, 256, 256])
