In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import scipy.io
import skimage.io

from PIL import Image, ImageFilter

# Define Model

In [2]:
class Encoder(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        x, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
        return x, idx

In [3]:
class Decoder(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return x

In [4]:
class SegNetBasic(nn.Module):
    """ 
        SegNet Basic is a smaller version of SegNet
        Please refer to this repository:
        https://github.com/0bserver07/Keras-SegNet-Basic
    """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.encoder1 = Encoder(in_channel, 64)
        self.encoder2 = Encoder(64, 80)
        self.encoder3 = Encoder(80, 96)
        self.encoder4 = Encoder(96, 128)
        
        self.decoder1 = Decoder(128, 96)
        self.decoder2 = Decoder(96, 80)
        self.decoder3 = Decoder(80, 64)
        self.decoder4 = Decoder(64, out_channel)
        
    def forward(self, x):
        size1 = x.size()
        x, idx1 = self.encoder1(x)

        size2 = x.size()
        x, idx2 = self.encoder2(x)

        size3 = x.size()
        x, idx3 = self.encoder3(x)
        
        size4 = x.size()
        x, idx4 = self.encoder4(x)

        x = F.max_unpool2d(x, idx4, kernel_size=2, stride=2, output_size=size4)
        x = self.decoder1(x)
        
        x = F.max_unpool2d(x, idx3, kernel_size=2, stride=2, output_size=size3)
        x = self.decoder2(x)

        x = F.max_unpool2d(x, idx2, kernel_size=2, stride=2, output_size=size2)
        x = self.decoder3(x)

        x = F.max_unpool2d(x, idx1, kernel_size=2, stride=2, output_size=size1)
        x = self.decoder4(x)

        return x

# Dataset

In [7]:
# count the number of pixels in each affordance
import glob

path_list = glob.glob('./part-affordance-dataset/tools/*', recursive=True)
print(path_list)
print(len(path_list))

['./part-affordance-dataset/tools/mug_15', './part-affordance-dataset/tools/knife_12', './part-affordance-dataset/tools/spoon_01', './part-affordance-dataset/tools/ladle_04', './part-affordance-dataset/tools/trowel_02', './part-affordance-dataset/tools/scissors_01', './part-affordance-dataset/tools/bowl_08', './part-affordance-dataset/tools/tenderizer_02', './part-affordance-dataset/tools/shovel_02', './part-affordance-dataset/tools/bowl_07', './part-affordance-dataset/tools/spoon_02', './part-affordance-dataset/tools/saw_03', './part-affordance-dataset/tools/cup_04', './part-affordance-dataset/tools/spoon_06', './part-affordance-dataset/tools/spoon_04', './part-affordance-dataset/tools/spoon_08', './part-affordance-dataset/tools/mug_19', './part-affordance-dataset/tools/pot_01', './part-affordance-dataset/tools/knife_11', './part-affordance-dataset/tools/scoop_02', './part-affordance-dataset/tools/bowl_02', './part-affordance-dataset/tools/cup_05', './part-affordance-dataset/tools/mug

In [8]:
105*50

5250

In [9]:
image_path_list = []

for path in path_list:
    i = glob.glob(path + '/*rgb.jpg')
    image_path_list += i

SyntaxError: invalid syntax (<ipython-input-9-dc0bd215e457>, line 5)

In [10]:
class_path_list = []

for path in image_path_list:
    c = path[:-7] + 'label.mat'
    class_path_list.append(c)

NameError: name 'image_path_list' is not defined

In [None]:
df = pd.DataFrame({
    'image_path': image_path_list,
    'class_path': class_path_list
})

In [None]:
# # write data as csv_file
# df.to_csv('image_class_path.csv', index=None)

In [None]:
df = pd.read_csv('image_class_path.csv')

In [None]:
df_s = df.sample(frac=1, random_state=2)

In [9]:
df_s

Unnamed: 0,image_path,class_path
6650,./part-affordance-dataset/tools/bowl_04/bowl_0...,./part-affordance-dataset/tools/bowl_04/bowl_0...
16785,./part-affordance-dataset/tools/scissors_06/sc...,./part-affordance-dataset/tools/scissors_06/sc...
21334,./part-affordance-dataset/tools/knife_06/knife...,./part-affordance-dataset/tools/knife_06/knife...
10676,./part-affordance-dataset/tools/turner_08/turn...,./part-affordance-dataset/tools/turner_08/turn...
3308,./part-affordance-dataset/tools/saw_03/saw_03_...,./part-affordance-dataset/tools/saw_03/saw_03_...
1304,./part-affordance-dataset/tools/trowel_02/trow...,./part-affordance-dataset/tools/trowel_02/trow...
23177,./part-affordance-dataset/tools/turner_02/turn...,./part-affordance-dataset/tools/turner_02/turn...
28795,./part-affordance-dataset/tools/scissors_04/sc...,./part-affordance-dataset/tools/scissors_04/sc...
4522,./part-affordance-dataset/tools/spoon_08/spoon...,./part-affordance-dataset/tools/spoon_08/spoon...
5254,./part-affordance-dataset/tools/knife_11/knife...,./part-affordance-dataset/tools/knife_11/knife...


In [20]:
len(image_path_list)

28843

In [21]:
28843*0.8

23074.4

In [10]:
df_train = df_s[:23100]

In [11]:
df_train.to_csv('train.csv', index=None)

In [12]:
df_test = df_s[23100:]

In [13]:
df_test.to_csv('test.csv', index=None)

### the number of the pixels before center crop

In [172]:
cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

for path in class_path_list:
    mat = scipy.io.loadmat(path)
    num, cnt = np.unique(mat['gt_label'], return_counts=True)
    
    for n, c in zip(num, cnt):
        cnt_dict[n] += c

cnt_dict

0: 8723340880,<br>
 1: 34159284,<br>
 2: 16009552,<br>
 3: 12433420,<br>
 4: 38476964,<br>
 5: 6773540,<br>
 6: 9273880,<br>
 7: 20102080<br>

# Define Dataset Class

In [5]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import pandas as pd
import numpy as np
import scipy.io
import skimage.io
from PIL import Image, ImageFilter
from sklearn.model_selection import train_test_split

In [11]:
class PartAffordanceDataset(Dataset):
    """Part Affordance Dataset"""
    
    def __init__(self, csv_file, transform=None):
        super().__init__()
        
        self.image_class_path = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.image_class_path)
    
    def __getitem__(self, idx):
        image_path = self.image_class_path.iloc[idx, 0]
        class_path = self.image_class_path.iloc[idx, 1]
        image = skimage.io.imread(image_path) # read as numpy array
        cls = scipy.io.loadmat(class_path)["gt_label"]
        
        sample = {'image': image, 'class': cls}
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

In [12]:
def crop_center_numpy(array, crop_height, crop_weight):
    h, w = array.shape
    return array[h//2 - crop_height//2: h//2 + crop_height//2,
                 w//2 - crop_weight//2: w//2 + crop_weight//2
                ]

In [13]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))

In [14]:
class CenterCrop(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = Image.fromarray(np.uint8(image))
        
        image = crop_center_pil_image(image, 320, 240)
        cls = crop_center_numpy(cls, 240, 320)
        
        image = np.asarray(image)
        
        return {'image': image, 'class': cls}

In [15]:
class ToTensor(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image).float(), 
                'class': torch.from_numpy(cls).long()}

In [16]:
# data = PartAffordanceDataset('image_class_path.csv',
#                                 transform=transforms.Compose([
#                                     CenterCrop(),
#                                     ToTensor()
#                                 ]))

# data_loader = DataLoader(data, batch_size=10, shuffle=False)

In [17]:
# mean = 0
# std = 0
# n = 0

# for sample in data_loader:
#     img = sample['image']   
#     img = img.view(len(img), 3, -1)
#     mean += img.mean(2).sum(0)
#     std += img.std(2).sum(0)
#     n += len(img)
    
# mean /= n
# std /= n

In [18]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

In [19]:
class Normalize(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = transforms.functional.normalize(image, mean, std)
        
        return {'image': image, 'class': cls}

In [15]:
train_data = PartAffordanceDataset('train.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [16]:
test_data = PartAffordanceDataset('test.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [17]:
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

### count the number of pixels in each class after center crop

In [25]:
dataset = PartAffordanceDataset('image_class_path.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor()
                                ]))
data_laoder = DataLoader(dataset, batch_size=100, shuffle=False)

In [175]:
cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

for sample in data_laoder:
    img = sample['class'].numpy()
    
    num, cnt = np.unique(img, return_counts=True)
    
    for n, c in zip(num, cnt):
        cnt_dict[n] += c

cnt_dict

0: 2078085712,  
 1: 34078992,  
 2: 15921090,  
 3: 12433420,  
 4: 38473752,  
 5: 6773528,  
 6: 9273826,  
 7: 20102080  

# Training

In [18]:
from tensorboardX import SummaryWriter
import tqdm

In [19]:
def eval_model(model, test_loader, device='cpu'):
    model.eval()
    
    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)
    
    for sample in test_loader:
        x, y = sample['image'], sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            _, ypred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
        
        for i in range(8):
            y_i = (y == i)           
            ypred_i = (ypred == i)   
            
            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
    
    """ iou[i] is the IoU of class i """
    iou = intersection / union
    
    return iou

In [20]:
def train_model(model, train_loader, test_loader, optimizer_cls=optim.Adam, 
                criterion=nn.CrossEntropyLoss(), max_epoch=200, device='cpu', writer=None):
    
    model.to(device)
    
    train_losses = []
    val_iou = []
    mean_iou = []
    best_iou = 0.0
    
    optimizer = optimizer_cls(model.parameters(), lr=0.01)
    
    for epoch in range(max_epoch):
        model.train()
        running_loss = 0.0
        
        for i, sample in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()
            
            x, y = sample['image'], sample['class']
            
            x = x.to(device)
            y = y.to(device)

            h = model(x)
            loss = criterion(h, y)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        train_losses.append(running_loss / i)
        
        val_iou.append(eval_model(model, test_loader, device))
        mean_iou.append(val_iou[-1].mean().item())
        
        if best_iou < mean_iou[-1]:
            best_iou = mean_iou[-1]
            torch.save(model.state_dict(), "./SegNet_without_class_weights_results/best_iou_model.prm")
        
        if writer is not None:
            writer.add_scalar("train_loss", train_losses[-1], epoch)
            writer.add_scalar("mean_IoU", mean_iou[-1], epoch)
            writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
                                           'iou of class 1': val_iou[-1][1],
                                           'iou of class 2': val_iou[-1][2],
                                           'iou of class 3': val_iou[-1][3],
                                           'iou of class 4': val_iou[-1][4],
                                           'iou of class 5': val_iou[-1][5],
                                           'iou of class 6': val_iou[-1][6],
                                           'iou of class 7': val_iou[-1][7]}, epoch)
            
        print(epoch, train_losses[-1], mean_iou[-1])
        
    torch.save(model.state_dict(), "./SegNet_without_class_weights_results/final_model.prm")

In [21]:
# num0 = 2078085712
# ratio = [1, num0 / 34078992, num0 / 15921090, num0/12433420, 
#          num0 / 38473752, num0 / 6773528, num0 / 9273826, num0 / 20102080]
# class_weight = torch.tensor(ratio, dtype=torch.float32) / 100

In [22]:
model = SegNetBasic(3, 8)
writer = SummaryWriter("./SegNet_without_class_weights_results/")
train_model(model, train_loader, test_loader, device="cuda", writer=writer)

100%|██████████| 2360/2360 [07:48<00:00,  5.04it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

0 0.25749609359761527 0.34269246459007263


100%|██████████| 2360/2360 [07:48<00:00,  5.03it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

1 0.07439514800460852 0.602653443813324


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 1/2360 [00:00<06:05,  6.45it/s]

2 0.03932775350642083 0.741662323474884


100%|██████████| 2360/2360 [07:31<00:00,  5.23it/s]
  0%|          | 1/2360 [00:00<07:43,  5.09it/s]

3 0.032143937092197046 0.7682681679725647


100%|██████████| 2360/2360 [07:54<00:00,  4.97it/s]
  0%|          | 1/2360 [00:00<06:10,  6.36it/s]

4 0.02816715541099494 0.7248371839523315


100%|██████████| 2360/2360 [07:27<00:00,  5.28it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

5 0.027425403638034997 0.8078756332397461


100%|██████████| 2360/2360 [07:52<00:00,  5.00it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

6 0.023980272443504747 0.8095967173576355


100%|██████████| 2360/2360 [07:24<00:00,  5.31it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

7 0.0224203361275686 0.8410762548446655


100%|██████████| 2360/2360 [07:53<00:00,  4.99it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

8 0.021891408017589764 0.83408522605896


100%|██████████| 2360/2360 [07:26<00:00,  5.29it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

9 0.021183367417481756 0.839994490146637


100%|██████████| 2360/2360 [07:54<00:00,  4.97it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

10 0.020901540436810764 0.8322322368621826


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

11 0.020406621709251213 0.8374950289726257


100%|██████████| 2360/2360 [07:53<00:00,  4.98it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

12 0.02003540012917563 0.83253413438797


100%|██████████| 2360/2360 [07:48<00:00,  5.04it/s]
  0%|          | 1/2360 [00:00<07:41,  5.12it/s]

13 0.01962782215364925 0.8292673826217651


100%|██████████| 2360/2360 [07:45<00:00,  5.07it/s]
  0%|          | 1/2360 [00:00<06:07,  6.42it/s]

14 0.0196143306879427 0.8371412754058838


100%|██████████| 2360/2360 [07:53<00:00,  4.98it/s]
  0%|          | 1/2360 [00:00<05:58,  6.58it/s]

15 0.018905301399500554 0.8413196206092834


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]
  0%|          | 1/2360 [00:00<06:54,  5.69it/s]

16 0.018938368602074294 0.8393340706825256


100%|██████████| 2360/2360 [07:54<00:00,  4.97it/s]
  0%|          | 1/2360 [00:00<05:46,  6.80it/s]

17 0.018627684465883875 0.8509811162948608


100%|██████████| 2360/2360 [07:27<00:00,  5.28it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

18 0.018589611038101358 0.8444779515266418


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 1/2360 [00:00<05:46,  6.81it/s]

19 0.018128387896077975 0.8627512454986572


100%|██████████| 2360/2360 [07:27<00:00,  5.27it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

20 0.0181199998427822 0.851839542388916


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

21 0.01791338395167832 0.8565146923065186


100%|██████████| 2360/2360 [07:24<00:00,  5.31it/s]
  0%|          | 1/2360 [00:00<06:50,  5.75it/s]

22 0.01795774766399149 0.84378981590271


100%|██████████| 2360/2360 [07:52<00:00,  4.99it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

23 0.017511765469037766 0.839311957359314


100%|██████████| 2360/2360 [07:37<00:00,  5.16it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

24 0.017764932583368472 0.8473342657089233


100%|██████████| 2360/2360 [07:53<00:00,  4.99it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

25 0.01718672868768452 0.8529242873191833


100%|██████████| 2360/2360 [07:54<00:00,  4.97it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

26 0.017252955336996993 0.8483675718307495


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

27 0.017077775566470336 0.8497467637062073


100%|██████████| 2360/2360 [07:53<00:00,  4.98it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

28 0.016998787072612327 0.8490786552429199


100%|██████████| 2360/2360 [07:27<00:00,  5.28it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

29 0.016855085489307357 0.8416799902915955


100%|██████████| 2360/2360 [07:51<00:00,  5.01it/s]
  0%|          | 1/2360 [00:00<06:13,  6.32it/s]

30 0.01673808080452562 0.854082465171814


100%|██████████| 2360/2360 [07:22<00:00,  5.34it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

31 0.016691506926748853 0.8490341305732727


100%|██████████| 2360/2360 [07:54<00:00,  4.97it/s]
  0%|          | 1/2360 [00:00<07:38,  5.15it/s]

32 0.016627072878955734 0.8447034955024719


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]
  0%|          | 1/2360 [00:00<06:12,  6.34it/s]

33 0.016538741457504634 0.8381818532943726


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

34 0.01640959445350804 0.8442271947860718


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]
  0%|          | 1/2360 [00:00<06:51,  5.74it/s]

35 0.016313451858084995 0.8504714965820312


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 1/2360 [00:00<07:32,  5.21it/s]

36 0.016357069935739647 0.8545243144035339


100%|██████████| 2360/2360 [07:43<00:00,  5.09it/s]
  0%|          | 1/2360 [00:00<06:51,  5.73it/s]

37 0.016081522025235846 0.8569158911705017


100%|██████████| 2360/2360 [07:49<00:00,  5.03it/s]
  0%|          | 1/2360 [00:00<06:38,  5.93it/s]

38 0.01619640770461055 0.8559517860412598


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 1/2360 [00:00<06:23,  6.16it/s]

39 0.01599217887854124 0.8495839834213257


100%|██████████| 2360/2360 [07:32<00:00,  5.22it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

40 0.01594794035272143 0.8573684096336365


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

41 0.015991021409013314 0.8601050972938538


100%|██████████| 2360/2360 [07:27<00:00,  5.28it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

42 0.015734781691944286 0.8607248067855835


100%|██████████| 2360/2360 [07:55<00:00,  4.96it/s]
  0%|          | 1/2360 [00:00<07:34,  5.18it/s]

43 0.016112521521142287 0.8744149804115295


100%|██████████| 2360/2360 [07:27<00:00,  5.28it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

44 0.015601149669778503 0.8647000193595886


100%|██████████| 2360/2360 [08:38<00:00,  4.55it/s]
  0%|          | 1/2360 [00:00<07:38,  5.15it/s]

45 0.015703619683999134 0.8602010607719421


100%|██████████| 2360/2360 [08:08<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

46 0.015584229841045072 0.8700461387634277


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

47 0.015692343465366406 0.8467915654182434


100%|██████████| 2360/2360 [08:31<00:00,  4.61it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

48 0.015495990823058947 0.8591222763061523


100%|██████████| 2360/2360 [08:29<00:00,  4.64it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

49 0.015600039479182579 0.8579229116439819


100%|██████████| 2360/2360 [08:45<00:00,  4.49it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

50 0.015314728839637972 0.8505256175994873


100%|██████████| 2360/2360 [08:09<00:00,  4.82it/s]
  0%|          | 1/2360 [00:00<07:15,  5.41it/s]

51 0.015347000857362712 0.854249894618988


100%|██████████| 2360/2360 [08:45<00:00,  4.49it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

52 0.015442573217475324 0.8833284378051758


100%|██████████| 2360/2360 [08:10<00:00,  4.81it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

53 0.015210982896711309 0.865283191204071


100%|██████████| 2360/2360 [08:42<00:00,  4.52it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

54 0.015265444077072872 0.8584712147712708


100%|██████████| 2360/2360 [08:10<00:00,  4.82it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

55 0.015203256193578445 0.8525657653808594


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 1/2360 [00:00<06:23,  6.16it/s]

56 0.015111118397545567 0.8559494018554688


100%|██████████| 2360/2360 [08:34<00:00,  4.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

57 0.015087264697547117 0.8589873313903809


100%|██████████| 2360/2360 [08:24<00:00,  4.68it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

58 0.015024140852630997 0.8623402714729309


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 1/2360 [00:00<06:44,  5.83it/s]

59 0.015206929633213068 0.8582804799079895


100%|██████████| 2360/2360 [08:09<00:00,  4.82it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

60 0.014994985993277248 0.8545204401016235


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

61 0.014905040128880599 0.8790791630744934


100%|██████████| 2360/2360 [08:08<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

62 0.014936003870132959 0.8911903500556946


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

63 0.01486788006975533 0.8619596362113953


100%|██████████| 2360/2360 [08:10<00:00,  4.81it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

64 0.014864735866905888 0.8563994765281677


100%|██████████| 2360/2360 [08:43<00:00,  4.50it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

65 0.014781253188122071 0.8637582063674927


100%|██████████| 2360/2360 [08:35<00:00,  4.57it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

66 0.014785968406215128 0.8547287583351135


100%|██████████| 2360/2360 [08:24<00:00,  4.68it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

67 0.014709967927989084 0.8791029453277588


100%|██████████| 2360/2360 [08:43<00:00,  4.51it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

68 0.014713169541999527 0.8658230304718018


100%|██████████| 2360/2360 [08:09<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

69 0.014688880971604078 0.848665714263916


100%|██████████| 2360/2360 [08:45<00:00,  4.49it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

70 0.014675219746203103 0.8584675192832947


100%|██████████| 2360/2360 [08:08<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

71 0.014555164665140088 0.8596745729446411


100%|██████████| 2360/2360 [08:45<00:00,  4.49it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

72 0.014669002179054827 0.865885853767395


100%|██████████| 2360/2360 [08:08<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

73 0.01452944882296187 0.8600696325302124


100%|██████████| 2360/2360 [08:43<00:00,  4.51it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

74 0.01456420093947716 0.8678543567657471


100%|██████████| 2360/2360 [08:33<00:00,  4.60it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

75 0.014463069088142736 0.8573943376541138


100%|██████████| 2360/2360 [08:27<00:00,  4.65it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

76 0.014522715724551449 0.8644660711288452


100%|██████████| 2360/2360 [08:42<00:00,  4.51it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

77 0.014368890313097623 0.8611115217208862


100%|██████████| 2360/2360 [08:09<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

78 0.014455381199329043 0.8639352917671204


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 1/2360 [00:00<06:39,  5.90it/s]

79 0.014399187199158682 0.867906928062439


100%|██████████| 2360/2360 [08:09<00:00,  4.82it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

80 0.014771594888806633 0.8689473867416382


100%|██████████| 2360/2360 [08:42<00:00,  4.52it/s]
  0%|          | 1/2360 [00:00<07:05,  5.54it/s]

81 0.014278772336222791 0.8676146864891052


100%|██████████| 2360/2360 [08:09<00:00,  4.83it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

82 0.014336011589249346 0.8583075404167175


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

83 0.01426651282670255 0.8629851341247559


100%|██████████| 2360/2360 [08:33<00:00,  4.60it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

84 0.014418235480557968 0.8762858510017395


100%|██████████| 2360/2360 [08:25<00:00,  4.67it/s]
  0%|          | 1/2360 [00:00<06:06,  6.43it/s]

85 0.01418127643632189 0.8725388050079346


100%|██████████| 2360/2360 [08:44<00:00,  4.50it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

86 0.014185761447384758 0.8512136936187744


100%|██████████| 2360/2360 [08:09<00:00,  4.83it/s]
  0%|          | 1/2360 [00:00<05:26,  7.22it/s]

87 0.014164131900059527 0.8928039073944092


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<05:22,  7.30it/s]

88 0.014173353563255112 0.8923461437225342


100%|██████████| 2360/2360 [07:01<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<05:27,  7.19it/s]

89 0.01420355076680924 0.8659934997558594


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:48,  5.77it/s]

90 0.014121529537846459 0.8987671732902527


100%|██████████| 2360/2360 [07:11<00:00,  5.47it/s]
  0%|          | 1/2360 [00:00<05:25,  7.25it/s]

91 0.014059006263734476 0.8974979519844055


100%|██████████| 2360/2360 [07:01<00:00,  5.60it/s]
  0%|          | 1/2360 [00:00<06:35,  5.96it/s]

92 0.014098578189736949 0.8738524317741394


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:28,  7.18it/s]

93 0.014035457146684463 0.8576726913452148


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<05:23,  7.30it/s]

94 0.014027266569108706 0.8737738728523254


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:04,  6.48it/s]

95 0.014034084566513921 0.8597263097763062


100%|██████████| 2360/2360 [07:04<00:00,  5.56it/s]
  0%|          | 1/2360 [00:00<05:22,  7.30it/s]

96 0.014003629602757158 0.8640711903572083


100%|██████████| 2360/2360 [07:06<00:00,  5.53it/s]
  0%|          | 1/2360 [00:00<06:34,  5.98it/s]

97 0.014004591477568206 0.848856508731842


100%|██████████| 2360/2360 [07:19<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:35,  5.96it/s]

98 0.014018157584692978 0.8690809011459351


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:35,  5.97it/s]

99 0.013934629099169567 0.8734709620475769


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:31,  7.12it/s]

100 0.013987466751156835 0.8629025220870972


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

101 0.013865645261676433 0.8537488579750061


100%|██████████| 2360/2360 [07:14<00:00,  5.44it/s]
  0%|          | 1/2360 [00:00<05:26,  7.22it/s]

102 0.013888385944540607 0.8625053763389587


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

103 0.013921418574280585 0.8565607070922852


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

104 0.013852702794046209 0.883345365524292


100%|██████████| 2360/2360 [07:20<00:00,  5.35it/s]
  0%|          | 1/2360 [00:00<06:36,  5.95it/s]

105 0.013825850776819517 0.8661392331123352


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:36,  5.95it/s]

106 0.013797090679403962 0.8617722392082214


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:03,  6.50it/s]

107 0.013804545061289317 0.8737964034080505


100%|██████████| 2360/2360 [07:11<00:00,  5.48it/s]
  0%|          | 1/2360 [00:00<06:36,  5.94it/s]

108 0.01377751700613283 0.8648096323013306


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

109 0.013744346377996064 0.8834155797958374


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<07:38,  5.14it/s]

110 0.01379752393173162 0.8658259510993958


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

111 0.013755831310399992 0.8700324296951294


100%|██████████| 2360/2360 [07:20<00:00,  5.35it/s]
  0%|          | 1/2360 [00:00<06:25,  6.12it/s]

112 0.013673652102045585 0.8609415888786316


100%|██████████| 2360/2360 [07:04<00:00,  5.56it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

113 0.013706092821807616 0.8963622450828552


100%|██████████| 2360/2360 [07:06<00:00,  5.53it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

114 0.013752276020818605 0.8547776341438293


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:25,  7.25it/s]

115 0.013676829129179538 0.8648966550827026


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<05:23,  7.29it/s]

116 0.01366382540705416 0.8670176267623901


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:29,  6.06it/s]

117 0.01361322387478118 0.872445285320282


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

118 0.013653798874873462 0.8624249696731567


100%|██████████| 2360/2360 [07:13<00:00,  5.44it/s]
  0%|          | 1/2360 [00:00<06:24,  6.13it/s]

119 0.0136764098464931 0.9023815989494324


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:36,  5.96it/s]

120 0.013533351057567944 0.8697843551635742


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<05:27,  7.20it/s]

121 0.013624415971574048 0.869778037071228


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

122 0.013569164580194538 0.8709573149681091


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:30,  6.03it/s]

123 0.0135685044246828 0.8672856688499451


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:03,  6.49it/s]

124 0.013546040803375651 0.8552590012550354


100%|██████████| 2360/2360 [07:11<00:00,  5.47it/s]
  0%|          | 1/2360 [00:00<06:34,  5.99it/s]

125 0.013567717536461143 0.8893975019454956


100%|██████████| 2360/2360 [07:01<00:00,  5.60it/s]
  0%|          | 1/2360 [00:00<06:37,  5.93it/s]

126 0.0135365221941557 0.8832588195800781


100%|██████████| 2360/2360 [07:19<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:24,  7.26it/s]

127 0.013640894075076286 0.887403666973114


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:27,  6.09it/s]

128 0.013431386410007226 0.8681719899177551


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:26,  6.11it/s]

129 0.013407725665320364 0.8661791086196899


100%|██████████| 2360/2360 [07:04<00:00,  5.56it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

130 0.013476474136595772 0.8672317266464233


100%|██████████| 2360/2360 [07:05<00:00,  5.54it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

131 0.013443082185906657 0.8757042288780212


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:29,  6.05it/s]

132 0.013413754125426145 0.8709167838096619


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:33,  5.99it/s]

133 0.01346205702409205 0.900867223739624


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:26,  7.23it/s]

134 0.013444814260933751 0.8598535656929016


100%|██████████| 2360/2360 [07:02<00:00,  5.58it/s]
  0%|          | 1/2360 [00:00<06:29,  6.06it/s]

135 0.01344803474834027 0.8735413551330566


100%|██████████| 2360/2360 [07:14<00:00,  5.43it/s]
  0%|          | 1/2360 [00:00<05:25,  7.24it/s]

136 0.013356405829883648 0.8651074171066284


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

137 0.013395050216241945 0.8747798204421997


100%|██████████| 2360/2360 [07:01<00:00,  5.60it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

138 0.013390734225816332 0.8684653639793396


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<05:28,  7.18it/s]

139 0.013368987351981271 0.8752294182777405


100%|██████████| 2360/2360 [07:01<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<07:18,  5.37it/s]

140 0.013321739911240332 0.8610833883285522


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:30,  6.04it/s]

141 0.01333948027365456 0.863467276096344


100%|██████████| 2360/2360 [07:11<00:00,  5.47it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

142 0.013364451107912116 0.8698132038116455


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

143 0.013255656435698801 0.8734846115112305


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:21,  6.19it/s]

144 0.01329978417459877 0.904060959815979


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

145 0.013296147824816593 0.8669014573097229


100%|██████████| 2360/2360 [07:19<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:27,  6.09it/s]

146 0.013301448329600958 0.8872202634811401


100%|██████████| 2360/2360 [07:03<00:00,  5.57it/s]
  0%|          | 1/2360 [00:00<06:35,  5.96it/s]

147 0.013341476983853054 0.8558257222175598


100%|██████████| 2360/2360 [07:06<00:00,  5.53it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

148 0.013224813878503062 0.8678972721099854


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

149 0.013294480231145635 0.8618655204772949


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:30,  6.04it/s]

150 0.013243761078042782 0.8869732022285461


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:31,  6.02it/s]

151 0.013271891862357721 0.8852930665016174


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

152 0.013194621446249898 0.8694198727607727


100%|██████████| 2360/2360 [07:13<00:00,  5.44it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

153 0.013167378117400251 0.8694524168968201


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:35,  5.96it/s]

154 0.01321091355725131 0.8713523149490356


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:36,  5.95it/s]

155 0.013190649681965039 0.8749017715454102


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:35,  5.97it/s]

156 0.013164959768770596 0.8595268130302429


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<05:21,  7.34it/s]

157 0.013145354143910397 0.8588370084762573


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:40,  5.89it/s]

158 0.013127166693165564 0.8796624541282654


100%|██████████| 2360/2360 [07:12<00:00,  5.46it/s]
  0%|          | 1/2360 [00:00<06:38,  5.92it/s]

159 0.01315585099024072 0.8743121027946472


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

160 0.013148074871740032 0.8770698308944702


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

161 0.01315865586320315 0.8635058403015137


100%|██████████| 2360/2360 [07:02<00:00,  5.58it/s]
  0%|          | 1/2360 [00:00<06:25,  6.12it/s]

162 0.01311766139664713 0.8680830001831055


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:31,  6.03it/s]

163 0.01312240290950199 0.8753772974014282


100%|██████████| 2360/2360 [07:04<00:00,  5.56it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

164 0.013209525668688021 0.8672376275062561


100%|██████████| 2360/2360 [07:05<00:00,  5.54it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

165 0.013148051713036247 0.8663107752799988


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:37,  5.94it/s]

166 0.013088445015324279 0.8723428845405579


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:35,  5.96it/s]

167 0.013048970009444653 0.8664073944091797


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:28,  7.17it/s]

168 0.013076405439465835 0.8706517219543457


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:31,  6.02it/s]

169 0.013052148140821607 0.8582512736320496


100%|██████████| 2360/2360 [07:12<00:00,  5.45it/s]
  0%|          | 1/2360 [00:00<05:33,  7.08it/s]

170 0.013057554951794335 0.8665174245834351


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<05:28,  7.18it/s]

171 0.013092061481033911 0.8635299801826477


100%|██████████| 2360/2360 [07:01<00:00,  5.60it/s]
  0%|          | 1/2360 [00:00<06:29,  6.06it/s]

172 0.012991038673794614 0.8642004132270813


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:35,  5.96it/s]

173 0.01306364733442323 0.8660329580307007


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:28,  6.07it/s]

174 0.013018399852455521 0.8677811026573181


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:24,  6.13it/s]

175 0.012969579068559327 0.8652368187904358


100%|██████████| 2360/2360 [07:12<00:00,  5.46it/s]
  0%|          | 1/2360 [00:00<05:21,  7.33it/s]

176 0.013051391718968363 0.8762533664703369


100%|██████████| 2360/2360 [07:03<00:00,  5.58it/s]
  0%|          | 1/2360 [00:00<05:25,  7.24it/s]

177 0.012934905026209094 0.8676360845565796


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<05:28,  7.18it/s]

178 0.012974097432340478 0.8778597116470337


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:32,  6.02it/s]

179 0.012975947674989208 0.8694836497306824


100%|██████████| 2360/2360 [07:19<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:04,  6.47it/s]

180 0.01305264845528356 0.8646611571311951


100%|██████████| 2360/2360 [07:04<00:00,  5.56it/s]
  0%|          | 1/2360 [00:00<06:29,  6.05it/s]

181 0.012912220800877401 0.8646316528320312


100%|██████████| 2360/2360 [07:05<00:00,  5.54it/s]
  0%|          | 1/2360 [00:00<06:30,  6.05it/s]

182 0.012959904682314423 0.8706281185150146


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

183 0.01305668536516798 0.8715200424194336


100%|██████████| 2360/2360 [07:01<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

184 0.012927964132913989 0.8693456053733826


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:35,  5.97it/s]

185 0.012920002900924646 0.8669033050537109


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

186 0.01294841428811221 0.8564243912696838


100%|██████████| 2360/2360 [07:13<00:00,  5.44it/s]
  0%|          | 1/2360 [00:00<06:27,  6.09it/s]

187 0.01293149699574244 0.8799216151237488


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:29,  6.05it/s]

188 0.012916284972969988 0.8810272216796875


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<05:25,  7.24it/s]

189 0.012896470290076914 0.8691161274909973


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:27,  6.08it/s]

190 0.012871482146010846 0.8675312399864197


100%|██████████| 2360/2360 [07:02<00:00,  5.58it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

191 0.012927050260497687 0.8751747012138367


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:06,  6.44it/s]

192 0.012893478105009714 0.8678687810897827


100%|██████████| 2360/2360 [07:12<00:00,  5.46it/s]
  0%|          | 0/2360 [00:00<?, ?it/s]

193 0.012847475633901616 0.8813547492027283


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:47,  5.79it/s]

194 0.012880410427987992 0.8655335903167725


100%|██████████| 2360/2360 [07:19<00:00,  5.37it/s]
  0%|          | 1/2360 [00:00<06:35,  5.97it/s]

195 0.012910398220562074 0.8821432590484619


100%|██████████| 2360/2360 [07:02<00:00,  5.59it/s]
  0%|          | 1/2360 [00:00<06:29,  6.06it/s]

196 0.01286649992456561 0.8779321312904358


100%|██████████| 2360/2360 [07:20<00:00,  5.36it/s]
  0%|          | 1/2360 [00:00<06:26,  6.10it/s]

197 0.012849311990013575 0.873960018157959


100%|██████████| 2360/2360 [07:05<00:00,  5.55it/s]
  0%|          | 1/2360 [00:00<05:26,  7.24it/s]

198 0.012817848483065003 0.8674715757369995


100%|██████████| 2360/2360 [07:05<00:00,  5.55it/s]


199 0.013007371443204957 0.8684117197990417


In [7]:
colors = torch.tensor([[0, 0, 0],         # class 0 'background'  black
                       [255, 0, 0],       # class 1 'grasp'       red
                       [255, 255, 0],     # class 2 'cut'         yellow
                       [0, 255, 0],       # class 3 'scoop'       green
                       [0, 255, 255],     # class 4 'contain'     sky blue
                       [0, 0, 255],       # class 5 'pound'       blue
                       [255, 0, 255],     # class 6 'support'     purple
                       [255, 255, 255]    # class 7 'wrap grasp'  white
                      ])

In [8]:
def class_to_mask(cls):
    
    mask = colors[cls].transpose(1, 2).transpose(1, 3)
    
    return mask

In [30]:
def predict(model, sample, device='cpu'):
    model.eval()
    model.to(device)
    
    x, y = sample['image'], sample['class']
    
    
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        _, y_pred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
    
    true_mask = class_to_mask(y).to('cpu')
    pred_mask = class_to_mask(y_pred).to('cpu')
    
    save_image(true_mask, "./SegNet_without_class_weights_results/true_mask_with_SegNet_without_class_weights.jpg")
    save_image(pred_mask, "./SegNet_without_class_weights_results/pred_mask_with_SegNet_without_class_weights.jpg")

In [59]:
trained_model = SegNetBasic(3, 8)
trained_model.load_state_dict(torch.load("./SegNet_without_class_weights_results/final_model.prm"))

In [32]:
eval_data = PartAffordanceDataset('eval.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [33]:
def reverse_normalize(x, mean, std):
    x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0]
    x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1]
    x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2]
    return x

In [36]:
eval_loader = DataLoader(eval_data, batch_size=8, shuffle=False)

mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

for sample in eval_loader:
    trained_model.eval()
    
    predict(trained_model, sample)
    
    x = sample["image"]
    x = reverse_normalize(x, mean, std)
    save_image(x/255, "./SegNet_without_class_weights_results/original_img_with_SegNet_without_class_weights.jpg")
    
    break

In [6]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))

In [112]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

def predict_from_image(model, image, device='cpu'):
    model.eval()
    
    image = transforms.functional.resize(image, 420)
    image = crop_center_pil_image(image, 320, 240)
    image = transforms.functional.to_tensor(image)
    image = transforms.functional.normalize(image, mean, std)
    image = image.view((1, 3, 240, 320))
    print(image.shape)
    
    with torch.no_grad():
        _, y_pred = model(image).max(1)
        
    y_pred = class_to_mask(y_pred)
    
    save_image(y_pred, './test_image_with_segnet.jpg')

In [113]:
image = Image.open('Image from iOS.jpg')                                                                                

In [114]:
predict_from_image(trained_model, image)

torch.Size([1, 3, 240, 320])
