In [None]:
import os
# Test split of the combined MA dataset, resolved relative to the notebook's working directory
# (one level up, then datasets/ma_dataset/combined/test).
test_set_path = os.path.join(
  os.path.abspath(''),
  os.pardir,
  'datasets', 'ma_dataset', 'combined', 'test',
)

In [None]:
import torch
from torch.utils.data import DataLoader
# Shared evaluation settings for all model sections below:
#   BATCH_SIZE  - images per DataLoader batch
#   NUM_QUERIES - number of line queries for the transformer dataset/model
BATCH_SIZE, NUM_QUERIES = 1, 6
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(DEVICE)

In [None]:
from test_loss import gather_eval_data, show_lines
import time
import numpy as np

# MS-ERFNet Model

In [None]:
from mse_dataset import CropRowDataset
mse_img_height = 512
mse_img_width = 512
ds_mse = CropRowDataset(os.path.join(test_set_path, 'masks'), os.path.join(test_set_path, 'imgs'), (mse_img_height, mse_img_width), False)
dl_mse = DataLoader(ds_mse, BATCH_SIZE, shuffle=False, num_workers=0)

from ms_erfnet import MSERFNet
p_mse = os.path.join(os.path.abspath(''), 'best_mse.pt')
m_mse = MSERFNet()
m_mse.load_state_dict(torch.load(p_mse, map_location=DEVICE))
m_mse.to(DEVICE)
m_mse.eval()

import cv2 as cv
from mse_croprows_in_masks import MSECropRowFinder
crf_mse = MSECropRowFinder()

# Pixels whose min-max-normalised heatmap value is below THRESHOLD become foreground (1) in the binary mask.
THRESHOLD = 0.2
# 5x5 structuring element for the morphological closing applied to the binary mask.
kernel = np.ones((5,5), np.uint8)

import matplotlib.pyplot as plt

from tqdm import tqdm
with torch.no_grad():
  preds = []
  gts = []
  # Total forward-pass + post-processing time, summed over ALL batches (data loading excluded).
  duration = 0.0
  for test_batch in tqdm(dl_mse, desc='MS-ERFNet'):
    imgs = test_batch['image'].to(DEVICE)
    # The final batch may be smaller than BATCH_SIZE, so iterate over the real batch length.
    n_imgs = imgs.shape[0]
    # One ground-truth entry per image keeps gts index-aligned with preds for any batch size.
    # p[i].tolist() is the same as the former p.tolist()[0] when i == 0.
    for i in range(n_imgs):
      gt = [p[i].tolist() for p in test_batch['poly']]
      gts.append(gt)

    start_time = time.time()
    pred = m_mse(imgs)
    for i in range(n_imgs):
      current = pred[i, 0]
      # Per-image min-max normalisation to [0, 1]; the epsilon avoids division by zero on a flat heatmap.
      norm = (current - current.min()) / (current.max() - current.min() + 1e-8)
      # The blocking device-to-host copy in .cpu() waits for the GPU result, so the timing covers real work.
      np_norm = norm.cpu().numpy()

      fixed_bin = (np_norm < THRESHOLD).astype(np.uint8)
      fixed_bin = cv.morphologyEx(fixed_bin, cv.MORPH_CLOSE, kernel, iterations=1)
      fixed_lines, _ = crf_mse.process(fixed_bin)
      preds.append(fixed_lines)
    duration += time.time() - start_time
  # Guard against an empty test set (no batches -> duration stays 0).
  fps = len(preds) / duration if duration > 0 else float('nan')

  gather_eval_data(os.path.join(os.path.abspath(''), os.pardir, 'mse_eval_data.csv'), preds, gts, ds_mse.filenames, fps, mse_img_height, mse_img_width)

    # Debug visualisation (move inside the per-image loop above to use):
    # plt.imshow(pred[i, 0].cpu().numpy(), cmap='jet')
    # plt.colorbar()
    # plt.show()

    # np_8 = (np_norm * 255).astype(np.uint8)
    # _, otsu = cv.threshold(np_8, 0, 1, cv.THRESH_BINARY_INV + cv.THRESH_OTSU)
    # otsu = cv.morphologyEx(otsu,cv.MORPH_OPEN, kernel, iterations=1)
    
    # fig, axs = plt.subplots(1, 3, figsize=(15, 4))
    
    # axs[0].imshow(np_norm, cmap='jet')
    # axs[0].set_title('Heatmap normalisiert')
    # axs[0].axis('off')

    # axs[1].imshow(otsu, cmap='gray')
    # axs[1].set_title('Otsu')
    # axs[1].axis('off')

    # axs[2].imshow(fixed_bin, cmap='gray')
    # axs[2].set_title('Fixed Thresh')
    # axs[2].axis('off')

    # plt.tight_layout()
    # plt.show()

    # line_img = show_lines(imgs[i], fixed_lines, gt)
    # plt.imshow(line_img)
    # plt.axis('off')
    # plt.show()

      
    # line_pairings = match_lines(fixed_lines, gt_lines, np_norm.shape[0], np_norm.shape[1])

    # print(f'Section-Angle-Loss: {section_angle_loss(line_pairings, np_norm.shape[0], np_norm.shape[1])}')
    # print(f'Lateral-Pixel-Loss: {lateral_pixel_loss(line_pairings, np_norm.shape[0], np_norm.shape[1])}')
    # print(f'Found {len(fixed_lines) - len(gt_lines)} lines more than in Groundtruth')
    # pass


# InstaCropNet

In [None]:
from SegCropNet.dataloader.data_loaders import TusimpleSet
ds_insta = TusimpleSet(test_set_path, img_size=(256, 512), transform=False, shuffle=False)
# collate_fn returns the raw list of sample dicts, so each batch is a Python list of samples.
dl_insta = DataLoader(ds_insta, BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=lambda x: x)

from SegCropNet.model.SegCropNet.SegCropNet import SegCropNet
p_insta = os.path.join(os.path.abspath(''), 'best_insta.pt')
m_insta = SegCropNet(arch='UNet')
m_insta.load_state_dict(torch.load(p_insta, map_location=DEVICE))
m_insta.to(DEVICE)
m_insta.eval()

from insta_cluster import dbscan
import numpy as np
from mse_croprows_in_masks import MSECropRowFinder
crf_insta = MSECropRowFinder()

THRESHOLD = 0.2
kernel = np.ones((5,5), np.uint8)

import matplotlib.pyplot as plt

import torch.nn.functional as F
from tqdm import tqdm
with torch.no_grad():
  preds = []
  gts = []
  # Total forward-pass + clustering time, summed over all samples (data loading excluded).
  duration = 0.0
  for test_batch in tqdm(dl_insta, desc='InstaCropNet'):
    # Handle every sample in the list-collated batch, not just test_batch[0], so that
    # BATCH_SIZE > 1 neither drops samples nor indexes past the model's batch of one.
    for sample in test_batch:
      inputs = sample['input'].type(torch.FloatTensor).to(DEVICE)
      gt = sample['poly']
      gts.append(gt)

      start_time = time.time()
      # The model is fed one sample at a time (unsqueeze adds the batch axis), so index 0 below.
      pred = m_insta(inputs.unsqueeze(0))
      pred_lines = dbscan((pred['binary_seg_pred'][0, 0].cpu().numpy() * 255).astype(np.uint8), pred['instance_seg_logits'][0].permute(1,2,0).cpu().numpy())
      preds.append(pred_lines)
      duration += time.time() - start_time
  # Guard against an empty test set (no samples -> duration stays 0).
  fps = len(preds) / duration if duration > 0 else float('nan')

  gather_eval_data(os.path.join(os.path.abspath(''), os.pardir, 'insta_eval_data.csv'), preds, gts, ds_insta._gt_img_list, fps, ds_insta.img_size[0], ds_insta.img_size[1])

      # Debug visualisation (move inside the per-sample loop above to use):
      # binaries = sample['binary'].type(torch.LongTensor)
      # instances = sample['instance'].type(torch.FloatTensor)
      # fig, axs = plt.subplots(2, 3, figsize=(15, 4))
      
      # inp = inputs.cpu() * 0.5 + 0.5
      # inp = show_lines(inp, pred_lines, gt)
      # axs[0,0].imshow(inp)
      # axs[0,0].set_title('Input')
      # axs[0,0].axis('off')

      # axs[0,1].imshow(binaries.cpu(), cmap='gray')
      # axs[0,1].set_title('GT Binary')
      # axs[0,1].axis('off')

      # axs[0,2].imshow(instances.cpu(), cmap='gray')
      # axs[0,2].set_title('GT Instanced Binary')
      # axs[0,2].axis('off')
      
      # axs[1,0].imshow(F.softmax(pred['binary_seg_logits'][0, 1], dim=0).cpu().numpy(), cmap='jet')
      # axs[1,0].set_title('Probability of Crop Row')
      # axs[1,0].axis('off')

      # axs[1,1].imshow(pred['binary_seg_pred'][0, 0].cpu().numpy(), cmap='gray')
      # axs[1,1].set_title('Binary')
      # axs[1,1].axis('off')

      # axs[1,2].imshow(pred['instance_seg_logits'][0].permute(1,2,0).cpu().numpy(), cmap='gray')
      # axs[1,2].set_title('Instanced Binary')
      # axs[1,2].axis('off')

      # plt.tight_layout()
      # plt.show()
      # pass




# Transformer-Based Model

In [None]:
from transformer_dataset import MaskLessDataset
# Input resolution (height, width) used both for loading and for the evaluation CSV.
trans_img_height = 360
trans_img_width = 640
ds_transformer = MaskLessDataset(os.path.join(test_set_path, 'labels'), os.path.join(test_set_path, 'imgs'), NUM_QUERIES, (trans_img_height, trans_img_width), full_transform=False)
dl_transformer = DataLoader(ds_transformer, BATCH_SIZE, shuffle=False, num_workers=0)

from transformer import TransformerBasedModel
p_transformer = os.path.join(os.path.abspath(''), 'best_transformer.pt')
m_transformer = TransformerBasedModel(max_crop_rows=NUM_QUERIES)
m_transformer.load_state_dict(torch.load(p_transformer, map_location=DEVICE))
m_transformer.to(DEVICE)
m_transformer.eval()

# A query's line is kept only if its class-1 probability exceeds this value.
CONFIDENCE_THRESHOLD = 0.8

import matplotlib.pyplot as plt

from tqdm import tqdm
with torch.no_grad():
  preds = []
  gts = []
  # Total forward-pass + filtering time, summed over ALL batches (data loading excluded).
  duration = 0.0
  for test_batch in tqdm(dl_transformer, desc='Transformer'):
    imgs = test_batch['image'].to(DEVICE)
    gt = test_batch['gt']
    # One ground-truth entry per image; for batch size 1 this equals the former gt.squeeze().
    for i in range(gt.shape[0]):
      gts.append(gt[i].squeeze().cpu().tolist())

    start_time = time.time()
    probs, pred = m_transformer(imgs)
    # Per-image filtering. Assumes probs is (batch, queries, 2) and pred is (batch, queries, ...),
    # which is what the former squeeze()/[:, 1] indexing implied for batch size 1 -- TODO confirm.
    for i in range(probs.shape[0]):
      np_probs = probs[i].cpu().numpy()
      confident = np_probs[:, 1] > CONFIDENCE_THRESHOLD
      preds.append(pred[i].cpu().numpy()[confident].tolist())
    duration += time.time() - start_time
  # Guard against an empty test set (no batches -> duration stays 0).
  fps = len(preds) / duration if duration > 0 else float('nan')

  gather_eval_data(os.path.join(os.path.abspath(''), os.pardir, 'trans_eval_data.csv'), preds, gts, ds_transformer.filenames, fps, trans_img_height, trans_img_width)

    # Debug visualisation (move inside the loop above to use):
    # line_img = show_lines(imgs[0], pred.squeeze().cpu().numpy()[confident].tolist(), gt.squeeze().cpu().tolist())
    # plt.imshow(line_img)
    # plt.axis('off')
    # plt.show()
