In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import scipy.io
import skimage.io

from PIL import Image, ImageFilter

# Define Model

In [2]:
class Encoder(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        x, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        return x, idx

In [3]:
class Decoder(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return x

In [4]:
class SegNetBasic(nn.Module):
    """ 
        SegNet Basic is a smaller version of SegNet
        Please refer to this repository:
        https://github.com/0bserver07/Keras-SegNet-Basic
    """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.encoder1 = Encoder(in_channel, 64)
        self.encoder2 = Encoder(64, 80)
        self.encoder3 = Encoder(80, 96)
        self.encoder4 = Encoder(96, 128)
        
        self.decoder1 = Decoder(128, 96)
        self.decoder2 = Decoder(96, 80)
        self.decoder3 = Decoder(80, 64)
        self.decoder4 = Decoder(64, out_channel)
        
    def forward(self, x):
        size1 = x.size()
        x, idx1 = self.encoder1(x)

        size2 = x.size()
        x, idx2 = self.encoder2(x)

        size3 = x.size()
        x, idx3 = self.encoder3(x)
        
        size4 = x.size()
        x, idx4 = self.encoder4(x)

        x = F.max_unpool2d(x, idx4, kernel_size=2, stride=2, output_size=size4)
        x = self.decoder1(x)
        
        x = F.max_unpool2d(x, idx3, kernel_size=2, stride=2, output_size=size3)
        x = self.decoder2(x)

        x = F.max_unpool2d(x, idx2, kernel_size=2, stride=2, output_size=size2)
        x = self.decoder3(x)

        x = F.max_unpool2d(x, idx1, kernel_size=2, stride=2, output_size=size1)
        x = self.decoder4(x)

        return x

# Dataset

In [5]:
# count the number of pixels in each affordance
import glob

path_list = glob.glob('./part-affordance-dataset/tools/*', recursive=True)
print(path_list)
print(len(path_list))

['./part-affordance-dataset/tools/mug_15', './part-affordance-dataset/tools/knife_12', './part-affordance-dataset/tools/spoon_01', './part-affordance-dataset/tools/ladle_04', './part-affordance-dataset/tools/trowel_02', './part-affordance-dataset/tools/scissors_01', './part-affordance-dataset/tools/bowl_08', './part-affordance-dataset/tools/tenderizer_02', './part-affordance-dataset/tools/shovel_02', './part-affordance-dataset/tools/bowl_07', './part-affordance-dataset/tools/spoon_02', './part-affordance-dataset/tools/saw_03', './part-affordance-dataset/tools/cup_04', './part-affordance-dataset/tools/spoon_06', './part-affordance-dataset/tools/spoon_04', './part-affordance-dataset/tools/spoon_08', './part-affordance-dataset/tools/mug_19', './part-affordance-dataset/tools/pot_01', './part-affordance-dataset/tools/knife_11', './part-affordance-dataset/tools/scoop_02', './part-affordance-dataset/tools/bowl_02', './part-affordance-dataset/tools/cup_05', './part-affordance-dataset/tools/mug

In [6]:
image_path_list = []

for path in path_list:
    i = glob.glob(path + '/*rgb.jpg')
    
    image_path_list += i

In [7]:
class_path_list = []

for path in image_path_list:
    c = path[:-7] + 'label.mat'
    class_path_list.append(c)

In [11]:
df = pd.DataFrame({
    'image_path': image_path_list,
    'class_path': class_path_list},
    columns=['image_path', 'class_path']
)

In [5]:
# # write data as csv_file
# df.to_csv('image_class_path.csv', index=None)

In [14]:
df = pd.read_csv("image_class_path.csv")

In [15]:
df_s = df.sample(frac=1)

In [16]:
df_s

Unnamed: 0,image_path,class_path
24546,./part-affordance-dataset/tools/hammer_04/hamm...,./part-affordance-dataset/tools/hammer_04/hamm...
25558,./part-affordance-dataset/tools/mug_04/mug_04_...,./part-affordance-dataset/tools/mug_04/mug_04_...
17133,./part-affordance-dataset/tools/bowl_05/bowl_0...,./part-affordance-dataset/tools/bowl_05/bowl_0...
26672,./part-affordance-dataset/tools/spoon_03/spoon...,./part-affordance-dataset/tools/spoon_03/spoon...
10798,./part-affordance-dataset/tools/mallet_03/mall...,./part-affordance-dataset/tools/mallet_03/mall...
11685,./part-affordance-dataset/tools/mug_14/mug_14_...,./part-affordance-dataset/tools/mug_14/mug_14_...
5518,./part-affordance-dataset/tools/scoop_02/scoop...,./part-affordance-dataset/tools/scoop_02/scoop...
7302,./part-affordance-dataset/tools/bowl_03/bowl_0...,./part-affordance-dataset/tools/bowl_03/bowl_0...
28594,./part-affordance-dataset/tools/scissors_04/sc...,./part-affordance-dataset/tools/scissors_04/sc...
26333,./part-affordance-dataset/tools/mug_02/mug_02_...,./part-affordance-dataset/tools/mug_02/mug_02_...


In [17]:
28843*0.8

23074.4

In [43]:
df_train = df_s[:23100]

In [19]:
df_train.to_csv('train.csv', index=None)

In [20]:
df_test = df_s[23100:]

In [21]:
df_test.to_csv('test.csv', index=None)

### the number of the pixels before center crop

In [172]:
# cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

# for path in class_path_list:
#     mat = scipy.io.loadmat(path)
#     num, cnt = np.unique(mat['gt_label'], return_counts=True)
    
#     for n, c in zip(num, cnt):
#         cnt_dict[n] += c

cnt_dict

0: 8723340880,<br>
 1: 34159284,<br>
 2: 16009552,<br>
 3: 12433420,<br>
 4: 38476964,<br>
 5: 6773540,<br>
 6: 9273880,<br>
 7: 20102080<br>

# Define Dataset Class

In [5]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import pandas as pd
import numpy as np
import scipy.io
import skimage.io
from PIL import Image, ImageFilter
from sklearn.model_selection import train_test_split

In [6]:
class PartAffordanceDataset(Dataset):
    """Part Affordance Dataset"""
    
    def __init__(self, csv_file, transform=None):
        super().__init__()
        
        self.image_class_path = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.image_class_path)
    
    def __getitem__(self, idx):
        image_path = self.image_class_path.iloc[idx, 0]
        class_path = self.image_class_path.iloc[idx, 1]
        image = skimage.io.imread(image_path) # read as numpy array
        cls = scipy.io.loadmat(class_path)["gt_label"]
        
        sample = {'image': image, 'class': cls}
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

In [7]:
def crop_center_numpy(array, crop_height, crop_weight):
    h, w = array.shape
    return array[h//2 - crop_height//2: h//2 + crop_height//2,
                 w//2 - crop_weight//2: w//2 + crop_weight//2
                ]

In [8]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))

In [9]:
class CenterCrop(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = Image.fromarray(np.uint8(image))
        
        image = crop_center_pil_image(image, 320, 240)
        cls = crop_center_numpy(cls, 240, 320)
        
        image = np.asarray(image)
        
        return {'image': image, 'class': cls}

In [10]:
class ToTensor(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image).float(), 
                'class': torch.from_numpy(cls).long()}

In [11]:
# data = PartAffordanceDataset('image_class_path.csv',
#                                 transform=transforms.Compose([
#                                     CenterCrop(),
#                                     ToTensor()
#                                 ]))

# data_loader = DataLoader(data, batch_size=10, shuffle=False)

In [12]:
# mean = 0
# std = 0
# n = 0

# for sample in data_loader:
#     img = sample['image']   
#     img = img.view(len(img), 3, -1)
#     mean += img.mean(2).sum(0)
#     std += img.std(2).sum(0)
#     n += len(img)
    
# mean /= n
# std /= n

In [13]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

In [14]:
class Normalize(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = transforms.functional.normalize(image, mean, std)
        
        return {'image': image, 'class': cls}

In [15]:
train_data = PartAffordanceDataset('train.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [16]:
test_data = PartAffordanceDataset('test.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [34]:
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)
pred_loader = DataLoader(test_data, batch_size=10, shuffle=True)

### count the number of pixels in each class after center crop

In [35]:
# dataset = PartAffordanceDataset('image_class_path.csv',
#                                 transform=transforms.Compose([
#                                     CenterCrop(),
#                                     ToTensor()
#                                 ]))
# data_laoder = DataLoader(dataset, batch_size=100, shuffle=False)

In [36]:
cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

for sample in data_laoder:
    img = sample['class'].numpy()
    
    num, cnt = np.unique(img, return_counts=True)
    
    for n, c in zip(num, cnt):
        cnt_dict[n] += c

In [37]:
cnt_dict

{0: 2078085712,
 1: 34078992,
 2: 15921090,
 3: 12433420,
 4: 38473752,
 5: 6773528,
 6: 9273826,
 7: 20102080}

cnt_dict

0: 2078085712,  
 1: 34078992,  
 2: 15921090,  
 3: 12433420,  
 4: 38473752,  
 5: 6773528,  
 6: 9273826,  
 7: 20102080  

# Training

### class weight

In [17]:
class_num = torch.tensor([2078085712, 34078992, 15921090, 12433420, 
                          38473752, 6773528, 9273826, 20102080])

total = class_num.sum().item()
print(total)

2215142400


In [18]:
class_num[0].float() / (100.0 * class_num.float())

tensor([0.0100, 0.6098, 1.3052, 1.6714, 0.5401, 3.0680, 2.2408, 1.0338])

In [19]:
frequency = class_num.float() / total
median = torch.median(frequency)

In [20]:
class_weight = median / frequency

In [21]:
class_weight

tensor([0.0077, 0.4672, 1.0000, 1.2805, 0.4138, 2.3505, 1.7168, 0.7920])

In [22]:
from tensorboardX import SummaryWriter
import tqdm

In [40]:
def eval_model(model, test_loader, device='cpu'):
    model.eval()
    
    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)
    
    for sample in test_loader:
        x, y = sample['image'], sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            _, ypred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
        
        for i in range(8):
            y_i = (y == i)           
            ypred_i = (ypred == i)   
            
            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
    
    """ iou[i] is the IoU of class i """
    iou = intersection / union
    
    return iou

In [41]:
def train_model(model, train_loader, test_loader, optimizer_cls=optim.Adam, 
                criterion=nn.CrossEntropyLoss(), max_epoch=200, device='cpu', writer=None):
    
    model.to(device)
    
    train_losses = []
    val_iou = []
    mean_iou = []
    best_iou = 0.0
    
    optimizer = optimizer_cls(model.parameters(), lr=0.01)
    
    for epoch in range(max_epoch):
        model.train()
        running_loss = 0.0
        
        for i, sample in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()
            
            x, y = sample['image'], sample['class']
            
            x = x.to(device)
            y = y.to(device)

            h = model(x)
            loss = criterion(h, y)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        train_losses.append(running_loss / i)
        
        val_iou.append(eval_model(model, test_loader, device))
        mean_iou.append(val_iou[-1].mean().item())
        
        if best_iou < mean_iou[-1]:
            best_iou = mean_iou[-1]
            torch.save(model.state_dict(), "./SegNet_with_class_weight(median)_results/best_iou_model.prm")
        
        if writer is not None:
            writer.add_scalar("train_loss", train_losses[-1], epoch)
            writer.add_scalar("mean_IoU", mean_iou[-1], epoch)
            writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
                                           'iou of class 1': val_iou[-1][1],
                                           'iou of class 2': val_iou[-1][2],
                                           'iou of class 3': val_iou[-1][3],
                                           'iou of class 4': val_iou[-1][4],
                                           'iou of class 5': val_iou[-1][5],
                                           'iou of class 6': val_iou[-1][6],
                                           'iou of class 7': val_iou[-1][7]}, epoch)
            
        print(epoch, train_losses[-1], mean_iou[-1])
        
    torch.save(model.state_dict(), "./SegNet_with_class_weight(median)_results/final_model.prm")

In [42]:
model = SegNetBasic(3, 8)
writer = SummaryWriter("./SegNet_with_class_weight(median)_results/")
train_model(model, train_loader, test_loader, criterion=nn.CrossEntropyLoss(weight=class_weight.to('cuda')), device="cuda", writer=writer)

100%|██████████| 2310/2310 [04:59<00:00,  7.52it/s]
  0%|          | 1/2310 [00:00<05:22,  7.16it/s]

0 0.9479384233407925 0.4458967447280884


100%|██████████| 2310/2310 [07:20<00:00,  4.70it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

1 0.31126673399898697 0.5815553069114685


100%|██████████| 2310/2310 [07:40<00:00,  6.37it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

2 0.1859119769239023 0.5518877506256104


100%|██████████| 2310/2310 [07:42<00:00,  4.99it/s]
  0%|          | 1/2310 [00:00<05:18,  7.26it/s]

3 0.13938977102107328 0.6321080923080444


100%|██████████| 2310/2310 [07:20<00:00,  6.36it/s]
  0%|          | 1/2310 [00:00<05:23,  7.14it/s]

4 0.12219092433956755 0.6218795776367188


100%|██████████| 2310/2310 [07:42<00:00,  4.89it/s]
  0%|          | 1/2310 [00:00<06:19,  6.09it/s]

5 0.10089501879234859 0.6404547095298767


100%|██████████| 2310/2310 [07:15<00:00,  4.93it/s]
  0%|          | 1/2310 [00:00<05:25,  7.09it/s]

6 0.09608773276735792 0.6605692505836487


100%|██████████| 2310/2310 [07:44<00:00,  4.72it/s]
  0%|          | 1/2310 [00:00<06:10,  6.23it/s]

7 0.08342502160411652 0.6491931676864624


100%|██████████| 2310/2310 [07:14<00:00,  4.88it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

8 0.07898528751868379 0.6751347780227661


100%|██████████| 2310/2310 [07:42<00:00,  4.84it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

9 0.07661883559621686 0.6486854553222656


100%|██████████| 2310/2310 [07:17<00:00,  4.72it/s]
  0%|          | 1/2310 [00:00<06:03,  6.35it/s]

10 0.07431364761187097 0.6700063943862915


100%|██████████| 2310/2310 [07:43<00:00,  4.97it/s]
  0%|          | 1/2310 [00:00<07:28,  5.15it/s]

11 0.06832222835733137 0.6523666381835938


100%|██████████| 2310/2310 [07:34<00:00,  5.04it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

12 0.0678605081544985 0.6935240626335144


100%|██████████| 2310/2310 [07:38<00:00,  5.97it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

13 0.06647871727355377 0.6817703247070312


100%|██████████| 2310/2310 [07:42<00:00,  5.13it/s]
  0%|          | 1/2310 [00:00<07:23,  5.21it/s]

14 0.061074292369512784 0.6786307096481323


100%|██████████| 2310/2310 [07:14<00:00,  5.93it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

15 0.056928208197844964 0.6548572182655334


100%|██████████| 2310/2310 [07:41<00:00,  4.79it/s]
  0%|          | 1/2310 [00:00<06:37,  5.80it/s]

16 0.05800027023934735 0.6962262988090515


100%|██████████| 2310/2310 [07:14<00:00,  5.13it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

17 0.054097605805445564 0.6845463514328003


100%|██████████| 2310/2310 [07:44<00:00,  5.04it/s]
  0%|          | 1/2310 [00:00<06:10,  6.24it/s]

18 0.055184435907665814 0.6935293674468994


100%|██████████| 2310/2310 [07:14<00:00,  4.92it/s]
  0%|          | 1/2310 [00:00<06:12,  6.19it/s]

19 0.053553013103595025 0.6814028024673462


100%|██████████| 2310/2310 [07:44<00:00,  5.16it/s]
  0%|          | 1/2310 [00:00<07:28,  5.15it/s]

20 0.052105821639254216 0.677937388420105


100%|██████████| 2310/2310 [07:22<00:00,  5.00it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

21 0.0533672124308753 0.7109189033508301


100%|██████████| 2310/2310 [07:40<00:00,  5.50it/s]
  0%|          | 1/2310 [00:00<07:36,  5.06it/s]

22 0.04990191492923621 0.6725143194198608


100%|██████████| 2310/2310 [07:35<00:00,  4.72it/s]
  0%|          | 1/2310 [00:00<06:29,  5.92it/s]

23 0.049677440373895744 0.6799400448799133


100%|██████████| 2310/2310 [07:32<00:00,  6.13it/s]
  0%|          | 1/2310 [00:00<05:20,  7.21it/s]

24 0.04876134100447955 0.6339251399040222


100%|██████████| 2310/2310 [07:44<00:00,  5.04it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

25 0.046119414063815097 0.6976977586746216


100%|██████████| 2310/2310 [07:15<00:00,  4.75it/s]
  0%|          | 1/2310 [00:00<05:49,  6.62it/s]

26 0.04984900370988819 0.6871145367622375


100%|██████████| 2310/2310 [07:45<00:00,  4.87it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

27 0.045758156151453906 0.7020561099052429


100%|██████████| 2310/2310 [07:14<00:00,  4.58it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

28 0.04664090252612488 0.6794081330299377


100%|██████████| 2310/2310 [07:42<00:00,  4.85it/s]
  0%|          | 1/2310 [00:00<06:24,  6.00it/s]

29 0.04519900467832427 0.7040132880210876


100%|██████████| 2310/2310 [07:16<00:00,  5.24it/s]
  0%|          | 1/2310 [00:00<06:49,  5.64it/s]

30 0.04501014156206545 0.6974311470985413


100%|██████████| 2310/2310 [07:43<00:00,  4.77it/s]
  0%|          | 1/2310 [00:00<05:51,  6.56it/s]

31 0.04526663488103823 0.7046093940734863


100%|██████████| 2310/2310 [07:32<00:00,  4.61it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

32 0.045994942183390536 0.7031357884407043


100%|██████████| 2310/2310 [07:39<00:00,  5.83it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

33 0.04100392541322565 0.7112332582473755


100%|██████████| 2310/2310 [07:41<00:00,  4.88it/s]
  0%|          | 1/2310 [00:00<06:28,  5.94it/s]

34 0.043919590768884686 0.6917855739593506


100%|██████████| 2310/2310 [07:17<00:00,  6.15it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

35 0.04166660578933538 0.706326425075531


100%|██████████| 2310/2310 [07:44<00:00,  4.87it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

36 0.04129197061684623 0.7290657758712769


100%|██████████| 2310/2310 [07:14<00:00,  5.29it/s]
  0%|          | 1/2310 [00:00<06:14,  6.17it/s]

37 0.04186017051356897 0.7046337127685547


100%|██████████| 2310/2310 [07:42<00:00,  4.70it/s]
  0%|          | 1/2310 [00:00<06:22,  6.03it/s]

38 0.04153247029250734 0.721276581287384


100%|██████████| 2310/2310 [07:13<00:00,  4.94it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

39 0.04215184145842204 0.6967706084251404


100%|██████████| 2310/2310 [07:44<00:00,  4.98it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

40 0.03851284468671547 0.6907954216003418


100%|██████████| 2310/2310 [07:25<00:00,  4.69it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

41 0.039737517019792674 0.7173754572868347


100%|██████████| 2310/2310 [07:41<00:00,  4.76it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

42 0.04092348965216427 0.704993486404419


100%|██████████| 2310/2310 [07:39<00:00,  4.83it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

43 0.03746678321325854 0.7154452204704285


100%|██████████| 2310/2310 [07:24<00:00,  5.88it/s]
  0%|          | 1/2310 [00:00<05:09,  7.47it/s]

44 0.0394193580303133 0.6968002319335938


100%|██████████| 2310/2310 [07:41<00:00,  4.78it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

45 0.039217261364381 0.7231190204620361


100%|██████████| 2310/2310 [07:13<00:00,  4.93it/s]
  0%|          | 1/2310 [00:00<06:50,  5.62it/s]

46 0.03713117645034585 0.7030229568481445


100%|██████████| 2310/2310 [07:43<00:00,  4.92it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

47 0.03903379186466756 0.711388111114502


100%|██████████| 2310/2310 [07:13<00:00,  4.83it/s]
  0%|          | 1/2310 [00:00<06:03,  6.34it/s]

48 0.03596191996835162 0.7213131189346313


100%|██████████| 2310/2310 [07:44<00:00,  5.43it/s]
  0%|          | 1/2310 [00:00<05:42,  6.75it/s]

49 0.038100232014044534 0.7187646627426147


100%|██████████| 2310/2310 [07:16<00:00,  4.76it/s]
  0%|          | 1/2310 [00:00<06:11,  6.21it/s]

50 0.03594690179085257 0.7120652794837952


100%|██████████| 2310/2310 [07:42<00:00,  4.61it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

51 0.038133805521694084 0.7239380478858948


100%|██████████| 2310/2310 [07:31<00:00,  4.81it/s]
  0%|          | 1/2310 [00:00<05:41,  6.75it/s]

52 0.03502317256386557 0.6882578730583191


100%|██████████| 2310/2310 [07:38<00:00,  6.16it/s]
  0%|          | 1/2310 [00:00<05:53,  6.54it/s]

53 0.03618180622407604 0.7408170104026794


100%|██████████| 2310/2310 [07:41<00:00,  4.69it/s]
  0%|          | 1/2310 [00:00<05:44,  6.70it/s]

54 0.03904643326944515 0.7428799271583557


100%|██████████| 2310/2310 [07:15<00:00,  6.14it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

55 0.037762125523259964 0.6992014646530151


100%|██████████| 2310/2310 [07:43<00:00,  4.70it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

56 0.035952637906619074 0.7103115320205688


100%|██████████| 2310/2310 [07:15<00:00,  5.05it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

57 0.03431552154971581 0.7185770273208618


100%|██████████| 2310/2310 [07:44<00:00,  4.64it/s]
  0%|          | 1/2310 [00:00<06:22,  6.03it/s]

58 0.03762117844041084 0.7136328220367432


100%|██████████| 2310/2310 [07:14<00:00,  5.14it/s]
  0%|          | 1/2310 [00:00<05:28,  7.03it/s]

59 0.0346130906624801 0.7297396659851074


100%|██████████| 2310/2310 [07:43<00:00,  4.70it/s]
  0%|          | 1/2310 [00:00<05:36,  6.86it/s]

60 0.034976318864181684 0.7186569571495056


100%|██████████| 2310/2310 [07:25<00:00,  5.37it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

61 0.033668282237702525 0.7406005859375


100%|██████████| 2310/2310 [07:44<00:00,  4.93it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

62 0.035178541133624744 0.7240380644798279


100%|██████████| 2310/2310 [07:44<00:00,  4.80it/s]
  0%|          | 1/2310 [00:00<05:21,  7.17it/s]

63 0.03297044756354812 0.7249711155891418


100%|██████████| 2310/2310 [07:20<00:00,  5.93it/s]
  0%|          | 1/2310 [00:00<07:30,  5.13it/s]

64 0.03383999320707857 0.7084926962852478


100%|██████████| 2310/2310 [07:43<00:00,  4.66it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

65 0.03427696725915861 0.7231460213661194


100%|██████████| 2310/2310 [07:12<00:00,  5.17it/s]
  0%|          | 1/2310 [00:00<07:25,  5.18it/s]

66 0.03527513094167332 0.7373120784759521


100%|██████████| 2310/2310 [07:42<00:00,  5.04it/s]
  0%|          | 1/2310 [00:00<05:49,  6.61it/s]

67 0.03227746438544813 0.7181694507598877


100%|██████████| 2310/2310 [07:14<00:00,  4.90it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

68 0.03284424947491646 0.7446251511573792


100%|██████████| 2310/2310 [07:42<00:00,  4.84it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

69 0.03453888057583065 0.7289090752601624


100%|██████████| 2310/2310 [07:19<00:00,  4.85it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

70 0.03393802315927639 0.7320225834846497


100%|██████████| 2310/2310 [07:43<00:00,  5.08it/s]
  0%|          | 1/2310 [00:00<05:38,  6.82it/s]

71 0.032130280500314176 0.7229412198066711


100%|██████████| 2310/2310 [07:31<00:00,  4.92it/s]
  0%|          | 1/2310 [00:00<06:25,  5.99it/s]

72 0.03206861813239279 0.7368724346160889


100%|██████████| 2310/2310 [07:32<00:00,  5.98it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

73 0.03383351478567561 0.7244682908058167


100%|██████████| 2310/2310 [07:44<00:00,  4.96it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

74 0.03184839886935462 0.7307894825935364


100%|██████████| 2310/2310 [07:12<00:00,  6.14it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

75 0.03193812599206394 0.7334988713264465


100%|██████████| 2310/2310 [07:46<00:00,  4.63it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

76 0.03206331752952157 0.7275597453117371


100%|██████████| 2310/2310 [07:03<00:00,  5.50it/s]
  0%|          | 1/2310 [00:00<04:59,  7.70it/s]

77 0.03249730727316279 0.7248004674911499


100%|██████████| 2310/2310 [06:42<00:00,  5.38it/s]
  0%|          | 1/2310 [00:00<05:32,  6.93it/s]

78 0.030587992003294828 0.7218435406684875


100%|██████████| 2310/2310 [06:20<00:00,  5.95it/s]
  0%|          | 1/2310 [00:00<05:57,  6.46it/s]

79 0.033033733938216854 0.7243591547012329


100%|██████████| 2310/2310 [06:42<00:00,  7.22it/s]
  0%|          | 1/2310 [00:00<06:50,  5.63it/s]

80 0.030518262579330456 0.727026104927063


100%|██████████| 2310/2310 [06:41<00:00,  5.69it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

81 0.032245386335006544 0.7461102604866028


100%|██████████| 2310/2310 [06:26<00:00,  5.66it/s]
  0%|          | 1/2310 [00:00<05:19,  7.23it/s]

82 0.03053059380770943 0.7346011996269226


100%|██████████| 2310/2310 [06:52<00:00,  5.07it/s]
  0%|          | 1/2310 [00:00<04:53,  7.88it/s]

83 0.032464177425161826 0.7272863388061523


100%|██████████| 2310/2310 [07:20<00:00,  4.98it/s]
  0%|          | 1/2310 [00:00<06:19,  6.09it/s]

84 0.030251535835473013 0.7294768691062927


100%|██████████| 2310/2310 [07:23<00:00,  5.15it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

85 0.03210581162998484 0.7196683883666992


100%|██████████| 2310/2310 [07:21<00:00,  4.86it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

86 0.030315036995956758 0.748259961605072


100%|██████████| 2310/2310 [07:41<00:00,  5.73it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

87 0.030861124468522832 0.7646054029464722


100%|██████████| 2310/2310 [07:34<00:00,  5.27it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

88 0.030556252731624208 0.7313066720962524


100%|██████████| 2310/2310 [07:24<00:00,  6.36it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

89 0.030463209733505716 0.7356234192848206


100%|██████████| 2310/2310 [07:27<00:00,  4.84it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

90 0.030361222244005354 0.7574634552001953


100%|██████████| 2310/2310 [07:10<00:00,  5.63it/s]
  0%|          | 1/2310 [00:00<06:14,  6.16it/s]

91 0.03203776574602061 0.7273143529891968


100%|██████████| 2310/2310 [06:49<00:00,  5.97it/s]
  0%|          | 1/2310 [00:00<05:00,  7.69it/s]

92 0.02941849574304701 0.7254880666732788


100%|██████████| 2310/2310 [06:29<00:00,  6.03it/s]
  0%|          | 1/2310 [00:00<04:58,  7.73it/s]

93 0.03137677091211339 0.7294725775718689


100%|██████████| 2310/2310 [06:45<00:00,  5.37it/s]
  0%|          | 1/2310 [00:00<05:45,  6.68it/s]

94 0.02975113943441831 0.7303140163421631


100%|██████████| 2310/2310 [06:37<00:00,  5.56it/s]
  0%|          | 1/2310 [00:00<06:00,  6.41it/s]

95 0.029431225160026303 0.7414294481277466


100%|██████████| 2310/2310 [06:30<00:00,  5.54it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

96 0.030052001126217035 0.7403866648674011


100%|██████████| 2310/2310 [06:43<00:00,  6.02it/s]
  0%|          | 1/2310 [00:00<05:03,  7.60it/s]

97 0.031279906855582316 0.7459648847579956


100%|██████████| 2310/2310 [06:27<00:00,  5.85it/s]
  0%|          | 1/2310 [00:00<04:58,  7.73it/s]

98 0.030007797064265294 0.7523114681243896


100%|██████████| 2310/2310 [06:43<00:00,  5.65it/s]
  0%|          | 1/2310 [00:00<05:51,  6.58it/s]

99 0.02872577156861877 0.7741700410842896


100%|██████████| 2310/2310 [06:29<00:00,  5.60it/s]
  0%|          | 1/2310 [00:00<06:02,  6.37it/s]

100 0.029172442131867368 0.7559229135513306


100%|██████████| 2310/2310 [06:31<00:00,  6.77it/s]
  0%|          | 1/2310 [00:00<04:57,  7.77it/s]

101 0.030364951390681096 0.7693836092948914


100%|██████████| 2310/2310 [06:43<00:00,  5.95it/s]
  0%|          | 1/2310 [00:00<04:58,  7.75it/s]

102 0.02935160442439766 0.7352402210235596


100%|██████████| 2310/2310 [06:27<00:00,  5.67it/s]
  0%|          | 1/2310 [00:00<05:48,  6.63it/s]

103 0.029085391218826053 0.7280812859535217


100%|██████████| 2310/2310 [06:43<00:00,  6.00it/s]
  0%|          | 1/2310 [00:00<05:01,  7.66it/s]

104 0.0291607588321461 0.6648954749107361


100%|██████████| 2310/2310 [06:26<00:00,  5.68it/s]
  0%|          | 1/2310 [00:00<06:00,  6.41it/s]

105 0.02979427307840935 0.7388861775398254


100%|██████████| 2310/2310 [06:36<00:00,  6.75it/s]
  0%|          | 1/2310 [00:00<05:12,  7.40it/s]

106 0.028653001312316496 0.7420180439949036


100%|██████████| 2310/2310 [06:43<00:00,  5.45it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

107 0.028549411857235273 0.7339214086532593


100%|██████████| 2310/2310 [06:27<00:00,  5.66it/s]
  0%|          | 1/2310 [00:00<05:53,  6.53it/s]

108 0.03038755149631763 0.7373933792114258


100%|██████████| 2310/2310 [06:43<00:00,  5.58it/s]
  0%|          | 1/2310 [00:00<06:00,  6.41it/s]

109 0.02826561794352898 0.7197709679603577


100%|██████████| 2310/2310 [06:52<00:00,  5.12it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

110 0.02784992360311083 0.753833532333374


100%|██████████| 2310/2310 [07:43<00:00,  4.18it/s]
  0%|          | 1/2310 [00:00<07:00,  5.49it/s]

111 0.028374506056957325 0.7310289740562439


100%|██████████| 2310/2310 [07:32<00:00,  4.89it/s]
  0%|          | 1/2310 [00:00<06:08,  6.27it/s]

112 0.02950834440960415 0.7710450887680054


100%|██████████| 2310/2310 [07:45<00:00,  6.31it/s]
  0%|          | 1/2310 [00:00<06:21,  6.05it/s]

113 0.028480794901111363 0.7506764531135559


100%|██████████| 2310/2310 [07:46<00:00,  5.00it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

114 0.027800569280936844 0.7382298707962036


100%|██████████| 2310/2310 [07:27<00:00,  6.05it/s]
  0%|          | 1/2310 [00:00<05:10,  7.44it/s]

115 0.02793845068179345 0.7310258746147156


100%|██████████| 2310/2310 [07:44<00:00,  4.92it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

116 0.030168495074417902 0.7711393237113953


100%|██████████| 2310/2310 [07:17<00:00,  5.07it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

117 0.027430484135505723 0.7662345170974731


100%|██████████| 2310/2310 [07:42<00:00,  5.01it/s]
  0%|          | 1/2310 [00:00<07:13,  5.33it/s]

118 0.02773906188081642 0.7271172404289246


100%|██████████| 2310/2310 [07:14<00:00,  5.18it/s]
  0%|          | 1/2310 [00:00<06:04,  6.33it/s]

119 0.03154453129179682 0.7273355722427368


100%|██████████| 2310/2310 [07:43<00:00,  4.77it/s]
  0%|          | 1/2310 [00:00<05:20,  7.19it/s]

120 0.026800041529440734 0.7365738749504089


100%|██████████| 2310/2310 [07:16<00:00,  5.09it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

121 0.027883037724304297 0.745805561542511


100%|██████████| 2310/2310 [07:45<00:00,  4.79it/s]
  0%|          | 1/2310 [00:00<05:36,  6.85it/s]

122 0.027966668636926508 0.7469882965087891


100%|██████████| 2310/2310 [07:18<00:00,  4.60it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

123 0.02699457453309702 0.730517566204071


100%|██████████| 2310/2310 [07:43<00:00,  4.96it/s]
  0%|          | 1/2310 [00:00<07:13,  5.33it/s]

124 0.027548249714938064 0.7666043639183044


100%|██████████| 2310/2310 [07:35<00:00,  5.14it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

125 0.0278002876113179 0.7330780625343323


100%|██████████| 2310/2310 [07:41<00:00,  6.23it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

126 0.02673958278038882 0.7417932748794556


100%|██████████| 2310/2310 [07:44<00:00,  5.06it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

127 0.028862149342926822 0.7374913692474365


100%|██████████| 2310/2310 [07:21<00:00,  6.07it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

128 0.026488964055451467 0.7592865228652954


100%|██████████| 2310/2310 [07:45<00:00,  4.94it/s]
  0%|          | 1/2310 [00:00<06:32,  5.88it/s]

129 0.02823660710657918 0.7763395309448242


100%|██████████| 2310/2310 [07:18<00:00,  5.12it/s]
  0%|          | 1/2310 [00:00<06:59,  5.51it/s]

130 0.026513482751740217 0.7588056325912476


100%|██████████| 2310/2310 [07:46<00:00,  5.10it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

131 0.027629662682243813 0.7911058068275452


100%|██████████| 2310/2310 [07:17<00:00,  5.03it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

132 0.027171659607225652 0.7691518664360046


100%|██████████| 2310/2310 [07:46<00:00,  5.17it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

133 0.02696572022126151 0.7628541588783264


100%|██████████| 2310/2310 [07:15<00:00,  4.84it/s]
  0%|          | 1/2310 [00:00<05:13,  7.36it/s]

134 0.02801531563150374 0.7228161692619324


100%|██████████| 2310/2310 [07:43<00:00,  5.11it/s]
  0%|          | 1/2310 [00:00<07:27,  5.16it/s]

135 0.025945602356742193 0.7241033315658569


100%|██████████| 2310/2310 [07:23<00:00,  4.79it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

136 0.028743105915513403 0.7549447417259216


100%|██████████| 2310/2310 [07:45<00:00,  4.63it/s]
  0%|          | 1/2310 [00:00<05:57,  6.46it/s]

137 0.0263949355946256 0.7494614124298096


100%|██████████| 2310/2310 [07:42<00:00,  5.18it/s]
  0%|          | 1/2310 [00:00<07:28,  5.15it/s]

138 0.026181241228366604 0.7476997375488281


100%|██████████| 2310/2310 [07:32<00:00,  5.86it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

139 0.027417177866842644 0.7503105401992798


100%|██████████| 2310/2310 [07:43<00:00,  4.87it/s]
  0%|          | 1/2310 [00:00<04:52,  7.89it/s]

140 0.026815203703768897 0.7411831021308899


100%|██████████| 2310/2310 [07:18<00:00,  5.02it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

141 0.026674937469753263 0.7722437381744385


100%|██████████| 2310/2310 [07:42<00:00,  4.88it/s]
  0%|          | 1/2310 [00:00<06:01,  6.39it/s]

142 0.02642611980610864 0.7546193599700928


100%|██████████| 2310/2310 [07:13<00:00,  5.19it/s]
  0%|          | 1/2310 [00:00<06:26,  5.97it/s]

143 0.028085818079019095 0.7666022777557373


100%|██████████| 2310/2310 [07:45<00:00,  5.14it/s]
  0%|          | 1/2310 [00:00<06:14,  6.17it/s]

144 0.025482097942738653 0.7334880232810974


100%|██████████| 2310/2310 [07:16<00:00,  5.24it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

145 0.026114381125762302 0.7652950286865234


100%|██████████| 2310/2310 [07:46<00:00,  4.97it/s]
  0%|          | 1/2310 [00:00<05:51,  6.57it/s]

146 0.0275622837736738 0.7363050580024719


100%|██████████| 2310/2310 [07:16<00:00,  4.81it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

147 0.02597076574283188 0.7609251737594604


100%|██████████| 2310/2310 [07:46<00:00,  5.09it/s]
  0%|          | 1/2310 [00:00<07:08,  5.39it/s]

148 0.02834781066055949 0.7414460778236389


100%|██████████| 2310/2310 [07:30<00:00,  5.21it/s]
  0%|          | 1/2310 [00:00<07:04,  5.44it/s]

149 0.02555752897069613 0.7869117856025696


100%|██████████| 2310/2310 [07:45<00:00,  5.69it/s]
  0%|          | 1/2310 [00:00<06:11,  6.22it/s]

150 0.02578579471678497 0.7437182664871216


100%|██████████| 2310/2310 [07:45<00:00,  5.08it/s]
  0%|          | 1/2310 [00:00<07:37,  5.05it/s]

151 0.026149000085151367 0.7740800380706787


100%|██████████| 2310/2310 [07:28<00:00,  6.11it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

152 0.027649003222194286 0.7686397433280945


100%|██████████| 2310/2310 [07:46<00:00,  5.11it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

153 0.024821489964256482 0.779561460018158


100%|██████████| 2310/2310 [07:18<00:00,  4.80it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

154 0.026811352222690002 0.7747917175292969


100%|██████████| 2310/2310 [07:46<00:00,  4.86it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

155 0.025404104629782383 0.741504967212677


100%|██████████| 2310/2310 [07:16<00:00,  4.86it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

156 0.025952772879485008 0.7589824795722961


100%|██████████| 2310/2310 [08:26<00:00,  4.57it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

157 0.02535039606117335 0.7204545736312866


100%|██████████| 2310/2310 [07:58<00:00,  4.26it/s]
  0%|          | 1/2310 [00:00<05:39,  6.80it/s]

158 0.028620505177901132 0.7599985599517822


100%|██████████| 2310/2310 [08:34<00:00,  4.65it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

159 0.02536254843286347 0.7882809638977051


100%|██████████| 2310/2310 [08:17<00:00,  4.68it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

160 0.02576647201555407 0.76485276222229


100%|██████████| 2310/2310 [08:24<00:00,  5.55it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

161 0.025456762083291028 0.7687813639640808


100%|██████████| 2310/2310 [08:35<00:00,  4.56it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

162 0.025401042776739716 0.7591294050216675


100%|██████████| 2310/2310 [07:59<00:00,  4.65it/s]
  0%|          | 1/2310 [00:00<06:57,  5.53it/s]

163 0.026175998935807496 0.7623401284217834


100%|██████████| 2310/2310 [08:35<00:00,  4.55it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

164 0.02630393099380216 0.7628633975982666


100%|██████████| 2310/2310 [08:00<00:00,  4.75it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

165 0.02522295518165584 0.7774348258972168


100%|██████████| 2310/2310 [08:32<00:00,  4.36it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

166 0.025294492408010467 0.7623269557952881


100%|██████████| 2310/2310 [08:00<00:00,  4.29it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

167 0.02516186709450005 0.7683058381080627


100%|██████████| 2310/2310 [08:34<00:00,  4.31it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

168 0.02509807794519824 0.7679992914199829


100%|██████████| 2310/2310 [08:20<00:00,  4.41it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

169 0.02657000006137407 0.7734635472297668


100%|██████████| 2310/2310 [08:19<00:00,  5.66it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

170 0.02570997963422702 0.763621985912323


100%|██████████| 2310/2310 [08:34<00:00,  4.39it/s]
  0%|          | 1/2310 [00:00<06:47,  5.67it/s]

171 0.024416964560837497 0.7515972256660461


100%|██████████| 2310/2310 [07:58<00:00,  4.43it/s]
  0%|          | 1/2310 [00:00<06:55,  5.56it/s]

172 0.02528871755583971 0.7372146844863892


100%|██████████| 2310/2310 [08:33<00:00,  4.37it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

173 0.026562318573280324 0.6158245205879211


100%|██████████| 2310/2310 [07:58<00:00,  4.22it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

174 0.026651302982975182 0.7663036584854126


100%|██████████| 2310/2310 [08:34<00:00,  4.51it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

175 0.02417632769196322 0.7537404894828796


100%|██████████| 2310/2310 [08:00<00:00,  4.42it/s]
  0%|          | 1/2310 [00:00<06:55,  5.56it/s]

176 0.026074068974130803 0.7496300935745239


100%|██████████| 2310/2310 [08:33<00:00,  4.37it/s]
  0%|          | 1/2310 [00:00<07:10,  5.36it/s]

177 0.02477791428165861 0.7801429033279419


100%|██████████| 2310/2310 [08:21<00:00,  4.48it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

178 0.024591794391168743 0.7748480439186096


100%|██████████| 2310/2310 [08:19<00:00,  5.68it/s]
  0%|          | 1/2310 [00:00<07:17,  5.27it/s]

179 0.02586299811613707 0.781362771987915


100%|██████████| 2310/2310 [08:33<00:00,  4.70it/s]
  0%|          | 1/2310 [00:00<07:10,  5.37it/s]

180 0.02457407482925136 0.7467730045318604


100%|██████████| 2310/2310 [07:58<00:00,  4.46it/s]
  0%|          | 1/2310 [00:00<06:54,  5.57it/s]

181 0.025150394147837145 0.7859334349632263


100%|██████████| 2310/2310 [08:34<00:00,  4.48it/s]
  0%|          | 1/2310 [00:00<07:17,  5.28it/s]

182 0.024891965868876785 0.7409635782241821


100%|██████████| 2310/2310 [07:58<00:00,  4.32it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

183 0.024264250424577204 0.721518337726593


100%|██████████| 2310/2310 [08:35<00:00,  4.30it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

184 0.025264636834864754 0.790497899055481


100%|██████████| 2310/2310 [07:58<00:00,  4.37it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

185 0.024134873448447162 0.7132746577262878


100%|██████████| 2310/2310 [08:32<00:00,  4.46it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

186 0.02595774906103799 0.7457727193832397


100%|██████████| 2310/2310 [08:18<00:00,  4.26it/s]
  0%|          | 1/2310 [00:00<06:58,  5.52it/s]

187 0.02396119598883308 0.7730768322944641


100%|██████████| 2310/2310 [08:22<00:00,  5.62it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

188 0.024268883693077545 0.665178120136261


100%|██████████| 2310/2310 [08:32<00:00,  4.68it/s]
  0%|          | 1/2310 [00:00<07:04,  5.44it/s]

189 0.025793601195709098 0.7866300940513611


100%|██████████| 2310/2310 [07:58<00:00,  4.80it/s]
  0%|          | 1/2310 [00:00<07:18,  5.26it/s]

190 0.023995439267054608 0.7652312517166138


100%|██████████| 2310/2310 [08:34<00:00,  4.39it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

191 0.0243918428817847 0.7445462942123413


100%|██████████| 2310/2310 [07:59<00:00,  4.38it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

192 0.024235503301960195 0.6899592876434326


100%|██████████| 2310/2310 [08:31<00:00,  4.55it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

193 0.025057132918131126 0.7429038286209106


100%|██████████| 2310/2310 [07:58<00:00,  4.93it/s]
  0%|          | 1/2310 [00:00<05:44,  6.71it/s]

194 0.024778159557901816 0.7846518754959106


100%|██████████| 2310/2310 [08:34<00:00,  4.23it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

195 0.023914748428074013 0.7382233738899231


100%|██████████| 2310/2310 [08:18<00:00,  4.38it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

196 0.024232091793960672 0.7607142925262451


100%|██████████| 2310/2310 [08:20<00:00,  5.61it/s]
  0%|          | 1/2310 [00:00<07:05,  5.43it/s]

197 0.025764800707784218 0.7745466232299805


100%|██████████| 2310/2310 [08:34<00:00,  4.66it/s]
  0%|          | 0/2310 [00:00<?, ?it/s]

198 0.023398797412960755 0.7772853374481201


100%|██████████| 2310/2310 [07:59<00:00,  4.31it/s]


199 0.025647739554287913 0.7147150039672852


In [16]:
colors = torch.tensor([[0, 0, 0],         # class 0 'background'  black
                       [255, 0, 0],       # class 1 'grasp'       red
                       [255, 255, 0],     # class 2 'cut'         yellow
                       [0, 255, 0],       # class 3 'scoop'       green
                       [0, 255, 255],     # class 4 'contain'     sky blue
                       [0, 0, 255],       # class 5 'pound'       blue
                       [255, 0, 255],     # class 6 'support'     purple
                       [255, 255, 255]    # class 7 'wrap grasp'  white
                      ])

In [17]:
def class_to_mask(cls):
    
    mask = colors[cls].transpose(1, 2).transpose(1, 3)
    
    return mask

In [18]:
def predict(model, sample, device='cpu'):
    model.eval()
    model.to(device)
    
    x, y = sample['image'], sample['class']
    
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        _, y_pred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
    
    true_mask = class_to_mask(y).to('cpu')
    pred_mask = class_to_mask(y_pred).to('cpu')
    
    save_image(true_mask, "./SegNet_with_class_weight(median)_results/true_mask_with_SegNet_with_class_weight(median).jpg")
    save_image(pred_mask, "./SegNet_with_class_weight(median)_results/pred_mask_with_SegNet_with_class_weight(median).jpg")

In [40]:
trained_model = SegNetBasic(3, 8)
trained_model.load_state_dict(torch.load("./SegNet_with_class_weight(median)_results/best_iou_model.prm"))

In [41]:
eval_data = PartAffordanceDataset('eval.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [42]:
def reverse_normalize(x, mean, std):
    x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0]
    x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1]
    x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2]
    return x

In [43]:
eval_loader = DataLoader(eval_data, batch_size=8, shuffle=False)

In [46]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

for sample in eval_loader:
    trained_model.eval()
    
    predict(trained_model, sample)
    
    x = sample["image"]
    x = reverse_normalize(x, mean, std)
    save_image(x/255, "./SegNet_with_class_weight(median)_results/original_img_with_SegNet_with_class_weight(median).jpg")
    
    break