In [None]:
# Setup cell: imports, process-wide configuration, and the model under inspection.
#
# stereo_toolbox is loaded from a local checkout. The star imports below presumably
# provide the unqualified names used later (load_checkpoint_flexible, IGEVStereo,
# auto_mask) — confirm which module exports each.
# NOTE(review): any star-imported name that matches a later explicit import
# (np, torch, F, DataLoader, ...) is silently rebound by that import; consider
# importing the needed names explicitly.
import sys
sys.path.insert(0, '/home/xp/stereo_toolbox/')
from stereo_toolbox.datasets import *
from stereo_toolbox.models import *
from stereo_toolbox.evaluation import *
from stereo_toolbox.loss_functions import *

import os
# Route HTTP(S) traffic of libraries that honor HTTP_PROXY/HTTPS_PROXY through a fixed proxy.
# NOTE(review): host-specific hard-coded address that overwrites any proxy already set
# in the shell; consider reading it from the environment instead.
os.environ['HTTP_PROXY'] = 'http://10.13.73.98:7890'
os.environ['HTTPS_PROXY'] = 'http://10.13.73.98:7890'

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
# Let cuDNN benchmark and pick the fastest convolution algorithms; this pays off when
# input sizes stay fixed across calls.
torch.backends.cudnn.benchmark = True
import matplotlib.pyplot as plt
import argparse

import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation/runtime warnings that may be relevant.
warnings.filterwarnings("ignore")

device = 'cuda:0'
# Build an IGEV-Stereo model and load weights from a BaCon training checkpoint.
# The third argument presumably selects the student sub-state-dict inside the
# checkpoint — confirm against load_checkpoint_flexible.
model = load_checkpoint_flexible(IGEVStereo(),
                                 '/home/xp/BaCon/checkpoint/c1_p1e1_s1e-2_ms_mo_lr2e-4/iteration_00036500.pth',
                                 'model_student'
                                 )
# Wrap in DataParallel (the original module is then reachable as model.module) and
# move the parameters to `device`.
model = nn.DataParallel(model).to(device)

In [None]:
# Preview auto_mask's output for several occlusion thresholds on a few random
# student samples. Each sample gets one figure per threshold, then its raw left
# and right views.
if '/home/xp/BaCon/' not in sys.path:  # keep re-runs from growing sys.path
    sys.path.insert(0, '/home/xp/BaCon/')
from dataloader import BaCon_Dataset

dataset = BaCon_Dataset(split='split1_mini', training=True, root_dir='/data1/xp/Carla/data6/')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Preview settings. These replace two single-element loops and a dead
# `occlusion_threshold = 0.1` that the threshold loop immediately overwrote.
OCCLUSION = True
STATIONARY = False
OCCLUSION_THRESHOLDS = [0.05, 0.1, 0.2, 0.4]
NUM_PREVIEW_BATCHES = 12  # same count as the former `if i > 10: break`
MASKED_TINT = np.array([1.0, 1.0, 0.0])  # yellow, blended into pixels where the mask is 0


def _show_image(image, title=None):
    """Show one (H, W, 3) image in its own axis-free figure, then release the figure."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    if title is not None:
        ax.set_title(title)
    ax.axis('off')
    plt.show()
    plt.close(fig)


model.eval()

for i, data in enumerate(tqdm(dataloader)):
    # NOTE(review): the comparison cell below reshapes this flag to (-1, 1, 1, 1)
    # before passing it as `reverse`; confirm auto_mask also accepts the (B,) form used here.
    student_flip_flag = data['student_flip_flag'].to(device)

    left = data['left_student'].to(device)
    right = data['right_student'].to(device)
    raw_left = data['raw_left_student'].to(device)
    raw_right = data['raw_right_student'].to(device)
    with torch.no_grad():
        pred = model(left, right)

    # batch_size=1: squeeze away the batch dim and go CHW -> HWC for matplotlib.
    raw_left_np = raw_left.squeeze().permute(1, 2, 0).cpu().numpy()
    raw_right_np = raw_right.squeeze().permute(1, 2, 0).cpu().numpy()

    for occlusion_threshold in OCCLUSION_THRESHOLDS:
        # NOTE(review): the comparison cell passes denorm=False for [0, 1] images; this
        # call relies on auto_mask's default `denorm` — confirm it suits raw images.
        occ_mask = auto_mask(raw_left, raw_right, pred, stationary=STATIONARY, occlusion=OCCLUSION,
                             occlusion_threshold=occlusion_threshold, reverse=student_flip_flag)
        occ_mask = occ_mask.squeeze().cpu().numpy()[..., np.newaxis]
        # Overlay the mask on raw_left with some transparency. Pixels where
        # occ_mask == 1 are unchanged; pixels where it is 0 become 65% image + 35% yellow.
        masked_raw_left = (raw_left_np * occ_mask
                           + raw_left_np * (1 - occ_mask) * 0.65
                           + (1 - occ_mask) * MASKED_TINT * 0.35)
        _show_image(masked_raw_left, f'{occlusion_threshold}')

    _show_image(raw_left_np)
    _show_image(raw_right_np)

    print()

    if i + 1 >= NUM_PREVIEW_BATCHES:
        break

In [None]:
# Setup for the teacher/student mask comparison: deterministic, cropped samples.
if '/home/xp/BaCon/' not in sys.path:  # keep re-runs from growing sys.path
    sys.path.insert(0, '/home/xp/BaCon/')
from dataloader import BaCon_Dataset

# crop_size presumably sets the (H, W) of the returned crops — confirm in BaCon_Dataset.
dataset = BaCon_Dataset(split='split1_mini', training=True, root_dir='/data1/xp/Carla/data6/', crop_size=[512,960])
# shuffle=False: the i-th batch is dataset[i], so the indices printed below identify samples.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

model.eval()

def conditional_flip_vectorized(images, flip_flag):
    """Horizontally flip (along the last dim) the batch entries whose flag is set.

    Args:
        images: tensor whose last dimension is width, e.g. (B, C, H, W).
        flip_flag: per-sample flag broadcastable against ``images`` (this notebook
            passes shape (B, 1, 1, 1)). Any non-zero / True value selects the
            flipped sample. May be bool, integer, or float.

    Returns:
        Tensor with the shape and dtype of ``images``. Entries are flipped along the
        last dim where ``flip_flag`` is set and unchanged elsewhere.
    """
    flipped_images = torch.flip(images, dims=[-1])
    # Select rather than blend arithmetically. `1 - flag` raises for bool tensors,
    # and `0 * inf` gives NaN, which would leak the unselected branch into the result.
    flip_mask = torch.as_tensor(flip_flag, device=images.device) != 0
    return torch.where(flip_mask, flipped_images, images)
    

# Compare auto_mask results for the teacher and student views of the same scene,
# restricted to samples where exactly one of them is horizontally flipped.
from torch.utils.data import Subset

# Dataset indices to inspect, inclusive (indices of interest noted earlier: 77, 122).
FIRST_INDEX = 301
LAST_INDEX = 451

# Standard ImageNet per-channel mean/std. They are loop-invariant, so build them once.
mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

# Load only the inspected window. The old loop decoded and discarded every sample
# before FIRST_INDEX. Its `i > 450` stop check also sat after a `continue`, so a run
# of matching flip flags could carry it past the window.
window_stop = min(LAST_INDEX + 1, len(dataset))
window_loader = DataLoader(Subset(dataset, range(FIRST_INDEX, window_stop)), batch_size=1, shuffle=False)

for i, data in enumerate(tqdm(window_loader), start=FIRST_INDEX):
    # Student input: raw images normalized here. Teacher input is used as provided
    # (assumed already normalized, since it is de-normalized below — confirm).
    ref_s = (data['raw_left_student'].to(device) - mean) / std
    tgt_s = (data['raw_right_student'].to(device) - mean) / std
    ref_t = data['left_teacher'].to(device)
    tgt_t = data['right_teacher'].to(device)

    student_flip_flag = data['student_flip_flag'].to(device).reshape(-1, 1, 1, 1)
    teacher_flip_flag = data['teacher_flip_flag'].to(device).reshape(-1, 1, 1, 1)

    # Only samples where the teacher and student flips differ are of interest
    # (.item() requires batch_size=1).
    if student_flip_flag.item() == teacher_flip_flag.item():
        continue

    print('index', i)

    # Feed each pair to the model in its flagged orientation, then apply the same
    # flip to the prediction to undo it.
    with torch.no_grad():
        pred_s = model(conditional_flip_vectorized(ref_s, student_flip_flag), conditional_flip_vectorized(tgt_s, student_flip_flag))
        pred_t = model(conditional_flip_vectorized(ref_t, teacher_flip_flag), conditional_flip_vectorized(tgt_t, teacher_flip_flag))
    pred_s = conditional_flip_vectorized(pred_s, student_flip_flag)
    pred_t = conditional_flip_vectorized(pred_t, teacher_flip_flag)

    # De-normalize back to [0, 1] RGB; auto_mask is then called with denorm=False.
    ref_s = (ref_s * std + mean).clamp(min=0., max=1.)
    tgt_s = (tgt_s * std + mean).clamp(min=0., max=1.)
    ref_t = (ref_t * std + mean).clamp(min=0., max=1.)
    tgt_t = (tgt_t * std + mean).clamp(min=0., max=1.)

    mask_t = auto_mask(ref_t, tgt_t, pred_t,
                       denorm=False, stationary=False, occlusion=True,
                       occlusion_threshold=0.2, reverse=teacher_flip_flag)
    mask_s = auto_mask(ref_s, tgt_s, pred_s,
                       denorm=False, stationary=False, occlusion=True,
                       occlusion_threshold=0.2, reverse=student_flip_flag)

    # batch_size=1: squeeze away the batch dim and go CHW -> HWC for matplotlib.
    ref_t = ref_t.squeeze().permute(1, 2, 0).cpu().numpy()
    tgt_t = tgt_t.squeeze().permute(1, 2, 0).cpu().numpy()
    tgt_s = tgt_s.squeeze().permute(1, 2, 0).cpu().numpy()

    mask_t = mask_t.squeeze().cpu().numpy()[..., np.newaxis]
    mask_s = mask_s.squeeze().cpu().numpy()[..., np.newaxis]

    # Partition pixels by the two masks:
    #   mask_0: teacher mask is 0                   -> red
    #   mask_1: both masks are 1                    -> no tint
    #   mask_2: teacher mask is 1, student mask 0   -> green
    mask_0 = 1 - mask_t
    mask_1 = mask_t * mask_s
    mask_2 = mask_t * (1 - mask_s)
    pseudo_color_image = (
        mask_0 * [1.0, 0.0, 0.0] +
        mask_1 * [0.0, 0.0, 0.0] +
        mask_2 * [0.0, 1.0, 0.0]
    )

    # Blend the pseudo-color map onto the (de-normalized) teacher reference image.
    overlay_image = ref_t * 0.75 + pseudo_color_image * 0.25

    overlay_panel = (overlay_image, "ref, green a=2, red a=0, others a=1")
    teacher_panel = (tgt_t, "tgt teacher")
    student_panel = (tgt_s, "tgt student")
    # The target view of whichever pair is flipped goes in the left panel.
    if teacher_flip_flag.item():
        panels = (teacher_panel, overlay_panel, student_panel)
    else:
        panels = (student_panel, overlay_panel, teacher_panel)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.axis('off')
        ax.set_title(title)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # don't accumulate open figures across iterations
    


