In [20]:
from torchvision.models import vgg #Last layer self.orientation=4 self.confidence=4
from torch_lib.Model import *
from torch_lib.ClassAverages import *
from torchvision import transforms
import os, glob, cv2
from library.ron_utils import *

def class2angle(bin_class, residual, num_bins=4):
    """Convert a (bin index, residual) orientation prediction into an angle.

    The full circle is split into ``num_bins`` equal bins; the result is the
    start angle of bin ``bin_class`` plus ``residual`` (radians).

    Args:
        bin_class: index of the chosen bin. Any value accepted by ``float()``
            after multiplication (int, numpy scalar, 1-element tensor).
        residual: offset in radians added to the bin start. If it is a tensor,
            the result is a tensor too (``float + residual``).
        num_bins: number of orientation bins; defaults to 4, the value that was
            previously hard-coded.

    Returns:
        ``bin_class * 2*pi/num_bins + residual`` (not wrapped to any range).
    """
    angle_per_class = 2 * np.pi / float(num_bins)
    angle = float(angle_per_class * bin_class)
    return angle + residual

def get_calibration_cam_to_image(cab_f):
    """Read the ``P2`` 3x4 projection matrix from a KITTI-style calib file.

    Args:
        cab_f: path to a calib text file whose ``P2:`` line holds 12 numbers.

    Returns:
        A (3, 4) float ndarray built from the first line containing ``'P2:'``,
        or ``None`` if no such line exists.
    """
    # `with` closes the file even on the early return (the old bare open leaked it).
    with open(cab_f) as f:
        for line in f:
            if 'P2:' in line:
                # split() (not split(' ')) tolerates repeated spaces and tabs.
                values = line.split()[1:]
                return np.asarray([float(v) for v in values]).reshape(3, 4)
    return None


# Inference cell: run the trained orientation/dimension regressor on every
# ground-truth 2D box of the KITTI training split and write one KITTI-format
# prediction file per frame into `pred_label_root`.
# (Originally sketched as `def Run_GT_pred_labels(weights_path, pred_label_root):`.)
weights_path = 'weights/BL_4bin_epoch_20.pkl'
pred_label_root = './BL_class2angle'
os.makedirs(pred_label_root, exist_ok=True)
my_vgg = vgg.vgg19_bn(pretrained=True)
# 4-channel (RGB + theta_ray) variant would replace the first conv:
# my_vgg.features[0] = nn.Conv2d(4, 64, (3,3), (1,1), (1,1))
model = Model(features=my_vgg.features).cuda()

# load_state_dict is strict by default: the checkpoint must match every key.
checkpoint = torch.load(weights_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# ImageNet channel mean/std for crop preprocessing.
# NOTE(review): cv2.imread yields BGR while these stats are listed in RGB order;
# presumably training used the same cv2 pipeline — confirm before changing.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
process = transforms.Compose([transforms.ToTensor(), normalize])

# Kitti image_2 / label_2 / calib dirs plus the per-object extra labels
# (column 5 of each extra line is read as theta_ray below).
img_root = "./Kitti/training/image_2"
label_root = "./Kitti/training/label_2"
calib_root = "./Kitti/training/calib"
extra_label_root = "./Kitti/training/extra_label"

# glob() returns files in arbitrary (filesystem) order, but the four lists are
# walked in lockstep, so sort them to keep frame i aligned across directories.
images = sorted(glob.glob(os.path.join(img_root, '*.png')))
labels = sorted(glob.glob(os.path.join(label_root, '*.txt')))
calibs = sorted(glob.glob(os.path.join(calib_root, '*.txt')))
extra = sorted(glob.glob(os.path.join(extra_label_root, '*.txt')))
if not (len(images) == len(labels) == len(calibs) == len(extra)):
    raise RuntimeError(
        'KITTI file count mismatch: {} images, {} labels, {} calibs, {} extra labels'.format(
            len(images), len(labels), len(calibs), len(extra)))


def _wrap_to_pi(angle):
    """Shift an angle by at most one full turn so it lies in [-pi, pi].

    Works on floats, numpy values and 0-dim tensors; assumes |angle| < 3*pi.
    """
    if angle > np.pi:
        return angle - 2 * np.pi
    if angle < -np.pi:
        return angle + 2 * np.pi
    return angle


# dim averages: per-class dimension priors added to the regressed deltas.
averages_all = ClassAverages()
start = time.time()
for i, (img_path, label_path, calib_path, extra_path) in enumerate(
        zip(images, labels, calibs, extra)):
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError('Could not read image {}'.format(img_path))
    cam_to_img = get_calibration_cam_to_image(calib_path)
    if cam_to_img is None:
        raise ValueError('No P2 line in calib file {}'.format(calib_path))

    CLASSes = list()
    BOX2Ds = list()
    CROPs_tensor = list()
    Alphas = list()
    THETAs = list()
    with open(extra_path) as extra_f:
        extra_labels = extra_f.read().splitlines()
    with open(label_path) as f:
        lines = f.readlines()

    for idx, line in enumerate(lines):
        # split() drops the newline itself; the old line[:-1] also chopped the
        # last real character of a final line without a trailing newline.
        elements = line.split()
        if not elements or elements[0] == 'DontCare':
            continue
        for j in range(1, len(elements)):
            elements[j] = float(elements[j])

        CLASSes.append(elements[0])
        # Columns 4-7 are the 2D box: left, top, right, bottom.
        top_left = (int(round(elements[4])), int(round(elements[5])))
        btm_right = (int(round(elements[6])), int(round(elements[7])))
        box = [top_left, btm_right]
        BOX2Ds.append(box)
        # cv2 images are (H, W, 3); +1 makes the right/bottom edges inclusive.
        crop = img[top_left[1]:btm_right[1]+1, top_left[0]:btm_right[0]+1]
        crop = cv2.resize(src=crop, dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
        crop = process(crop)
        # The extra label file is assumed line-aligned with label_2, so idx
        # deliberately counts DontCare lines too.
        # Use calc function if cam_to_img changes to "camera_cal/calib_cam_to_cam.txt"
        theta_ray = float(extra_labels[idx].split()[5])
        THETAs.append(theta_ray)
        CROPs_tensor.append(crop)
        # 4-channel variant (RGB + theta_ray condition plane):
        # cond = torch.tensor(theta_ray).expand(1, crop.shape[1], crop.shape[2])
        # img_cond = torch.concat((crop, cond), dim=0)  # (3+1, 224, 224); also try group loss
        # CROPs_tensor.append(img_cond)

    # Frames with no usable objects (e.g. only DontCare) still get an empty
    # prediction file; torch.stack([]) would raise otherwise.
    delta_DIMs = []
    if CROPs_tensor:
        # Put all crops together as one batch; inference only, so skip autograd.
        with torch.no_grad():
            input_ = torch.stack(CROPs_tensor).cuda()
            [ORIENTs, CONFs, delta_DIMs] = model(input_)

        # Most confident bin per object, then that bin's residual
        # (ORIENTs is indexed as (num_objects, num_bins)).
        cls_argmax = torch.max(CONFs, dim=1)[1]
        residual_orient = ORIENTs[torch.arange(len(ORIENTs)), cls_argmax]
        for argmax, residual in zip(cls_argmax, residual_orient):
            Alphas.append(_wrap_to_pi(class2angle(argmax, residual)))

    # Write pred_label.txt with the same basename as the ground-truth label file.
    pred_lines = []
    for class_, delta, alpha, theta, box_2d in zip(CLASSes, delta_DIMs, Alphas, THETAs, BOX2Ds):
        delta = delta.detach().cpu().numpy()  # torch -> numpy
        alpha = alpha.detach().cpu().numpy()  # torch -> numpy
        dim = delta + averages_all.get_item(class_)
        # KITTI label format expects rotation_y in [-pi, pi]; alpha + theta_ray can leave it.
        rotation_y = _wrap_to_pi(alpha + theta)
        loc, _ = calc_location(dim, cam_to_img, box_2d, alpha, theta)

        pred_lines.append(
            '{CLASS} 0.0 0 {A:.2f} {left} {top} {right} {btm} {H:.2f} {W:.2f} {L:.2f} {X:.2f} {Y:.2f} {Z:.2f} {Ry:.2f}\n'.format(
                CLASS=class_, A=alpha, left=box_2d[0][0], top=box_2d[0][1], right=box_2d[1][0], btm=box_2d[1][1],
                H=dim[0], W=dim[1], L=dim[2], X=loc[0], Y=loc[1], Z=loc[2], Ry=rotation_y))
    with open(os.path.join(pred_label_root, os.path.basename(label_path)), 'w') as new_f:
        new_f.writelines(pred_lines)

    if i % 500 == 0:
        print(i)
elapsed = time.time() - start
print('Done, take {} min {} sec'.format(elapsed // 60, elapsed % 60))  # around 10min

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
Done, take 10.0 min 30.51249098777771 sec
