In [16]:
# camera-ready

import sys

sys.path.append("/root/deeplabv3")
from datasets import DatasetThnSeq # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)

sys.path.append("/root/deeplabv3/model")
from deeplabv3 import DeepLabV3

sys.path.append("/root/deeplabv3/utils")
from utils import label_img_to_color

import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import pickle
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2

# ------------------------------------------------------------------------------
# Configuration, model and data loading.
# ------------------------------------------------------------------------------
PROJECT_DIR = "/root/deeplabv3" # (root of the deeplabv3 checkout; the same directory is appended to sys.path above)
batch_size = 2

# Build the network on the GPU and restore the trained weights from disk.
network = DeepLabV3("eval_seq_thn", project_dir=PROJECT_DIR).cuda()
network.load_state_dict(torch.load(PROJECT_DIR + "/pretrained_models/model_13_2_2_2_epoch_580.pth"))

val_dataset = DatasetThnSeq(thn_data_path=PROJECT_DIR + "/test_data")

# The DataLoader below does not set drop_last, so it also yields the final,
# possibly smaller, batch. The number of batches is therefore the ceiling of
# len(val_dataset)/batch_size, not the floor.
num_val_batches = (len(val_dataset) + batch_size - 1) // batch_size
print ("num_val_batches:", num_val_batches)

# shuffle=False keeps the loader's iteration order deterministic.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=batch_size, shuffle=False,
                                         num_workers=1)

# ------------------------------------------------------------------------------
# Run the network over every batch and write, for each image, three PNG files
# into network.model_dir: the de-normalised input, the colour-coded prediction
# and an overlay of the two.
# ------------------------------------------------------------------------------
network.eval() # (set in evaluation mode, this affects BatchNorm and dropout)

# Per-channel constants used to undo the input normalisation.
# NOTE(review): presumably the same mean/std that DatasetThnSeq applies - confirm.
norm_std = np.array([0.229, 0.224, 0.225])
norm_mean = np.array([0.485, 0.456, 0.406])

unsorted_img_ids = [] # (one [numeric id, img_id] pair per processed image)
for step, (imgs, img_ids) in enumerate(val_loader):
    print(step) # (progress indicator: index of the current batch)

    with torch.no_grad(): # (no autograd bookkeeping is needed during inference, this reduces memory consumption)
        imgs = imgs.cuda() # (shape: (batch_size, 3, img_h, img_w))
        outputs = network(imgs) # (shape: (batch_size, num_classes, img_h, img_w))

    ############################################################################
    # save data for visualization:
    ############################################################################
    outputs = outputs.detach().cpu().numpy() # (shape: (batch_size, num_classes, img_h, img_w))
    pred_label_imgs = np.argmax(outputs, axis=1).astype(np.uint8) # (shape: (batch_size, img_h, img_w))

    # Copy the whole batch to the host once and move channels last.
    imgs_np = imgs.detach().cpu().numpy()
    imgs_np = np.transpose(imgs_np, (0, 2, 3, 1)) # (shape: (batch_size, img_h, img_w, 3))

    for i in range(pred_label_imgs.shape[0]):
        pred_label_img = pred_label_imgs[i] # (shape: (img_h, img_w))
        img_id = img_ids[i]

        # Undo the normalisation and scale to 0..255. Round and clip before the
        # uint8 cast: a bare cast truncates (e.g. 254.99998 -> 254) and wraps
        # around for values outside the valid range.
        img = imgs_np[i]*norm_std + norm_mean
        img = np.clip(np.rint(img*255.0), 0, 255).astype(np.uint8)

        pred_label_img_color = label_img_to_color(pred_label_img)
        overlayed_img = 0.35*img + 0.65*pred_label_img_color
        overlayed_img = overlayed_img.astype(np.uint8)

        # Kept as top-level names because code after this loop may read the
        # frame size from them.
        img_h = overlayed_img.shape[0]
        img_w = overlayed_img.shape[1]

        cv2.imwrite(network.model_dir + "/" + img_id + ".png", img)
        cv2.imwrite(network.model_dir + "/" + img_id + "_pred.png", pred_label_img_color)
        cv2.imwrite(network.model_dir + "/" + img_id + "_overlayed.png", overlayed_img)

        # The text before the first '.' of img_id is parsed as an integer so the
        # images can later be ordered numerically rather than lexicographically.
        unsorted_img_ids.append([int(img_id.split('.')[0]),img_id])

################################################################################
# create visualization video:
################################################################################
# Each video frame is a 2x2 grid of image-sized tiles: input image (top left),
# colour-coded prediction (top right) and the overlay centred in the bottom row.
print(unsorted_img_ids)
sorted_img_ids = sorted(unsorted_img_ids) # (ordered by the numeric id stored first in each pair)
print(sorted_img_ids)


def _read_saved_image(path):
    """Read an image from `path`, raising IOError if cv2 cannot load it."""
    loaded = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if loaded is None: # (cv2.imread signals failure by returning None instead of raising)
        raise IOError("could not read image: %s" % path)
    return loaded


if not sorted_img_ids:
    # Nothing was processed, so there is no frame size and nothing to encode.
    print("no images were processed, skipping video creation")
else:
    # Take the frame size from the first saved image instead of relying on
    # variables left over from the inference loop.
    first_img = _read_saved_image(network.model_dir + "/" + sorted_img_ids[0][1] + ".png")
    img_h, img_w = first_img.shape[0], first_img.shape[1]

    out = cv2.VideoWriter("%s/thn_combined.avi" % network.model_dir, cv2.VideoWriter_fourcc(*"MJPG"), 12, (2*img_w, 2*img_h))
    try:
        if not out.isOpened():
            # A writer that failed to open would otherwise drop every frame silently.
            raise IOError("could not open video writer for %s/thn_combined.avi" % network.model_dir)

        x_offset = img_w // 2 # (horizontally centres the overlay in the bottom row)
        for _, img_id in sorted_img_ids:
            img = _read_saved_image(network.model_dir + "/" + img_id + ".png")
            pred_img = _read_saved_image(network.model_dir + "/" + img_id + "_pred.png")
            overlayed_img = _read_saved_image(network.model_dir + "/" + img_id + "_overlayed.png")

            combined_img = np.zeros((2*img_h, 2*img_w, 3), dtype=np.uint8)

            combined_img[0:img_h, 0:img_w] = img
            combined_img[0:img_h, img_w:(2*img_w)] = pred_img
            combined_img[img_h:(2*img_h), x_offset:(x_offset + img_w)] = overlayed_img

            out.write(combined_img)
    finally:
        # Always release the writer, even if a frame could not be read or written.
        out.release()


pretrained resnet, 18
num_val_batches: 86
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
[[6, '6.jpg'], [150, '150.jpg'], [59, '59.jpg'], [69, '69.jpg'], [94, '94.jpg'], [70, '70.jpg'], [29, '29.jpg'], [157, '157.jpg'], [130, '130.jpg'], [23, '23.jpg'], [57, '57.jpg'], [91, '91.jpg'], [48, '48.jpg'], [25, '25.jpg'], [11, '11.jpg'], [22, '22.jpg'], [71, '71.jpg'], [66, '66.jpg'], [10, '10.jpg'], [93, '93.jpg'], [58, '58.jpg'], [13, '13.jpg'], [77, '77.jpg'], [134, '134.jpg'], [88, '88.jpg'], [41, '41.jpg'], [100, '100.jpg'], [123, '123.jpg'], [9, '9.jpg'], [168, '168.jpg'], [129, '129.jpg'], [125, '125.jpg'], [151, '151.jpg'], [167, '167.jpg'], [74, '74.jpg'], [136, '136.jpg'], [24, '24.jpg'], [38, '38.jpg'], [144, '144.jpg'], [165, '165.jpg'], [153, '153.jpg'], [121, '121.jpg'], [78, '78.

In [3]:
# Show the interpreter's current module search path (a bare expression, so the
# notebook displays its value as the cell output).
sys.path

['/root/WJY/final',
 '/root/anaconda3/lib/python38.zip',
 '/root/anaconda3/lib/python3.8',
 '/root/anaconda3/lib/python3.8/lib-dynload',
 '',
 '/root/anaconda3/lib/python3.8/site-packages',
 '/root/rl/OneAgent',
 '/root/anaconda3/lib/python3.8/site-packages/IPython/extensions',
 '/root/.ipython']