## Check the backbone models and the segmentation models.


In [1]:
import torch
from model.backbone.xception65 import Xception65
from model.backbone.mobilenet import MobileNetV2
from model.seg_model.deeplabv3plus import deeplabv3plus
from model.seg_model.deeplabv3plus_mobilev2 import deeplabv3plus_mobilev2
from model.seg_model.unet import unet
from model.seg_model.unet_scales import unet_scales
from model.seg_model.unet_scales_gate import unet_scales_gate
from model.seg_model.hrnet import hrnet
from torchsummary import summary
import matplotlib.pyplot as plt



In [2]:
# Random stand-in tensors used by the model checks in this notebook.
# NOTE(review): `input` shadows the Python builtin of the same name; the name
# is kept because the following cells reference it.
patch_shape = (4, 4, 256, 256)
input = torch.randn(*patch_shape)
# Three same-shaped tensors, passed as a list to the unet_scales and
# unet_scales_gate checks.
input_scales = [torch.randn(*patch_shape) for _ in range(3)]
truth = torch.randn(4, 1, 256, 256)
# Alternative (disabled): load a saved test patch instead of random tensors.
# # input_scales, truth = torch.load(f='data/test_patches/patch_000.pt')
# # input_scales = [torch.unsqueeze(input, 0) for input in input_scales]


### Check backbone networks.

In [4]:
# Smoke test: build the Xception65 backbone, run one forward pass on the
# random batch, and report the size of what comes out.
model = Xception65(
    num_bands=4,
    num_classes=2,
)
outp = model(input)
print('output shape:', outp.size())


output shape: torch.Size([4, 2])


In [6]:
# Smoke test: build the MobileNetV2 backbone, run one forward pass on the
# random batch, and report the size of what comes out.
model = MobileNetV2(
    num_bands=4,
    num_classes=2,
)
outp = model(input)
print('output shape:', outp.size())


output shape: torch.Size([4, 2])


### Check U-Net

In [8]:
# Smoke test: build the U-Net, run one forward pass on the random batch, and
# report the output size.
# NOTE(review): the recorded output below has 1 channel although
# num_classes=2 -- presumably a single-channel (binary) head; confirm in
# model.seg_model.unet.
model = unet(
    num_bands=4,
    num_classes=2,
)
outp = model(input)
print('output:', outp.size())


output: torch.Size([4, 1, 256, 256])


### Check Deeplabv3plus

In [11]:
# Smoke test: build deeplabv3plus, run one forward pass on the random batch,
# and report the output size.
model = deeplabv3plus(
    num_bands=4,
    num_classes=2,
)
outp = model(input)
print('output:', outp.size())



output: torch.Size([4, 1, 256, 256])


### Check Deeplabv3plus_mobilev2

In [12]:
# Smoke test: build deeplabv3plus_mobilev2 with the given feature-channel
# list, run one forward pass on the random batch, and report the output size.
model = deeplabv3plus_mobilev2(
    num_bands=4,
    num_classes=2,
    channels_fea=[16, 24, 64],
)
outp = model(input)
print('output:', outp.size())


output: torch.Size([4, 1, 256, 256])


### Check HRNet

In [13]:
# Smoke test: build hrnet, run one forward pass on the random batch, and
# report the output size.
model = hrnet(
    num_bands=4,
    num_classes=2,
)
outp = model(input)
print('output:', outp.size())


output: torch.Size([4, 1, 256, 256])


### Check unet_scales

In [14]:
# Smoke test: build unet_scales with the three scale settings, feed it the
# list of three random tensors, and report the output size.
model = unet_scales(
    num_bands=4,
    num_classes=2,
    scale_high=2048,
    scale_mid=512,
    scale_low=256,
)
outp = model(input_scales)
print('output:', outp.size())


output: torch.Size([4, 1, 256, 256])


### Check unet_scales_gate

In [15]:
# Smoke test: build unet_scales_gate with the three scale settings, feed it
# the list of three random tensors, and report the output size.
model = unet_scales_gate(
    num_bands=4,
    num_classes=2,
    scale_high=2048,
    scale_mid=512,
    scale_low=256,
)
outp = model(input_scales)
print('output:', outp.size())


output: torch.Size([4, 1, 256, 256])
