In [45]:
import torch
import torch.nn as nn
from utils.model_tools import ModelValidator
from utils.my_logging import Logger
from data.dataloader import build_dataloader
from model import build_model
from utils import parse_args
from tqdm import tqdm
import configs
import numpy as np
import matplotlib.pyplot as plt
import time

In [46]:
config_file = 'segformer_b4_gaofen'
weight = 'work/models/segformer_b4_gaofen/2021-09-14T16:40:39.pkl'

config = getattr(configs, config_file)
model = build_model(config["model"])
cfg_test_pipeline = config["test_pipeline"]
cfg_test_pipeline['dataloader']['batch_size'] = 1
cfg_test_pipeline['dataloader']['num_workers'] = 1
cfg_test_pipeline['dataloader']['shuffle'] = True

test_loader = build_dataloader(cfg_test_pipeline)
train_config = config["train_config"]
device = train_config["device"]
model.load_state_dict(torch.load(weight, map_location="cpu"))
model.to(device)

[ToTensor(), Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]


Segformer(
  (encoder): mit_b4(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=64, out_features=6

In [48]:
with torch.no_grad():
    for val_img, val_mask in tqdm(
        test_loader, total=len(test_loader), desc="Valid", ncols=100
    ):
        val_img = val_img.to(device)
        # val_mask = val_mask.to(device)

        pred_img_1 = model(val_img)

        # pred_img_2 = model(torch.flip(val_img, [-1]))
        # pred_img_2 = torch.flip(pred_img_2, [-1])

        # pred_img_3 = model(torch.flip(val_img, [-2]))
        # pred_img_3 = torch.flip(pred_img_3, [-2])

        # pred_img_4 = model(torch.flip(val_img, [-1, -2]))
        # pred_img_4 = torch.flip(pred_img_4, [-1, -2])

        # pred_list = pred_img_1 + pred_img_2 + pred_img_3 + pred_img_4
        pred_list = pred_img_1
        pred_list = torch.argmax(pred_list.cpu(), 1).byte().numpy()
        
        gt = val_mask.data.numpy()
        
        vis = np.concatenate([gt[0], pred_list[0]], axis=0)
        vis = np.clip(vis, a_min=0.0, a_max=1.0)
        plt.imsave("vis/valid.jpg", vis)
        time.sleep(2)

Valid:   4%|█▉                                                     | 14/400 [00:30<14:11,  2.21s/it]


KeyboardInterrupt: 