In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import click
import numpy as np
import pandas as pd
import scipy.io
import sys
import tqdm

from PIL import Image, ImageFilter
from tensorboardX import SummaryWriter


from models.deeplabv2 import DeepLabV2
from models.msc import MSC
from models.discriminator import Discriminator
from dataset import PartAffordanceDataset, PartAffordanceDatasetWithoutLabel
from dataset import CenterCrop, ToTensor, Normalize



''' one-hot representation '''

def one_hot(label, n_classes, device):
    one_hot_label = torch.eye(n_classes, requires_grad=True, device=device)[label].transpose(1, 3).transpose(2, 3)
    return one_hot_label
    

''' scheduler for learning rate '''

def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter, max_iter, power):
    if iter % lr_decay_iter or iter > max_iter:
        return None
    new_lr = init_lr * (1 - float(iter) / max_iter) ** power
    optimizer.param_groups[0]["lr"] = new_lr
    optimizer.param_groups[1]["lr"] = 10 * new_lr
    optimizer.param_groups[2]["lr"] = 20 * new_lr


def poly_lr_scheduler_d(optimizer, init_lr, iter, lr_decay_iter, max_iter, power):
    if iter % lr_decay_iter or iter > max_iter:
        return None
    new_lr = init_lr * (1 - float(iter) / max_iter) ** power
    optimizer.param_groups[0]["lr"] = new_lr
    if len(optimizer.param_groups) > 1:
        optimizer.param_groups[1]['lr'] = 10 * new_lr


''' model, weight initialization, get params '''

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.weight, 1)


def DeepLabV2_ResNet101_MSC(n_classes):
    return MSC(
        scale=DeepLabV2(
            n_classes=n_classes, n_blocks=[3, 4, 23, 3], pyramids=[6, 12, 18, 24]
        ),
        pyramids=[0.5, 0.75],
    )


def get_params(model, key):
    # For Dilated FCN
    if key == "1x":
        for m in model.named_modules():
            if "layer" in m[0]:
                if isinstance(m[1], nn.Conv2d):
                    for p in m[1].parameters():
                        yield p
    # For conv weight in the ASPP module
    if key == "10x":
        for m in model.named_modules():
            if "aspp" in m[0]:
                if isinstance(m[1], nn.Conv2d):
                    yield m[1].weight
    # For conv bias in the ASPP module
    if key == "20x":
        for m in model.named_modules():
            if "aspp" in m[0]:
                if isinstance(m[1], nn.Conv2d):
                    yield m[1].bias



''' training '''

def full_train(
        model, model_d, sample, criterion_ce_full, criterion_bce, 
        optimizer, optimizer_d, ones, zeros, device):

    ''' full supervised learning '''
    
    model.train()
    model.scale.freeze_bn()
    model_d.train()

    # train segmentation network
    x, y = sample['image'], sample['class']

    batch_len = len(x)

    x = x.to(device)
    y = y.to(device)

    h = model(x)     # shape => (N, 8, H/8, W/8)
    h = F.interpolate(h, size=(256, 256), mode='bilinear', align_corners=True)

    h_ = h.detach()    # h_ is for calculating loss for discriminator
    y_ = y.detach()    # y_is for the same purpose.  shape => (N, H, W)

    d_out = model_d(h)    # shape => (N, 1, H/32, W/32)
    d_out = F.interpolate(d_out, size=(256, 256), mode='bilinear', align_corners=True)    # shape => (N, 1, H, W)
    d_out = d_out.squeeze()
    
    loss_ce = criterion_ce_full(h, y)
    loss_adv = criterion_bce(d_out, ones[:batch_len])
    loss_full = loss_ce + 0.01 * loss_adv

    optimizer.zero_grad()
    optimizer_d.zero_grad()
    loss_full.backward()
    optimizer.step()


    # train discriminator
    seg_out = model_d(h_)    # shape => (N, 1, H/32, W/32)
    seg_out = F.interpolate(seg_out, size=(256, 256), mode='bilinear', align_corners=True)    # shape => (N, 1, H, W)
    seg_out = seg_out.squeeze()
    
    y_ = one_hot(y_, 8, device)    # shape => (N, 8, H, W)
    true_out = model_d(y_)    # shape => (N, 1, H/32, W/32)
    true_out = F.interpolate(true_out, size=(256, 256), mode='bilinear', align_corners=True)    # shape => (N, 1, H, W)
    true_out = true_out.squeeze()

    loss_d_fake = criterion_bce(seg_out, zeros[:batch_len])
    loss_d_real = criterion_bce(true_out, ones[:batch_len])
    loss_d = loss_d_fake + loss_d_real

    optimizer.zero_grad()
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
    
    return loss_full.item(), loss_d.item()




def semi_train(
        model, model_d, sample, criterion_ce_semi, criterion_bce, 
        optimizer, optimizer_d, ones, zeros, device):

    ''' semi supervised learning '''
    
    model.train()
    model.scale.freeze_bn()
    model_d.eval()

    # train segmentation network
    x = sample['image']
    batch_len = len(x)
    
    x = x.to(device)
    
    h = model(x)     # shape => (N, 8, H/8, W/8)
    h = F.interpolate(h, size=(256, 256), mode='bilinear', align_corners=True)

    _, h_ = torch.max(h, dim=1)    # to calculate the crossentropy loss. shape => (N, H, W)

    with torch.no_grad():
        d_out = model_d(h)    # shape => (N, 1, H/32, W/32)
        d_out = F.interpolate(d_out, size=(256, 256), mode='bilinear', align_corners=True)    # shape => (N, 1, H, W)
        d_out = d_out.squeeze()

    loss_adv = criterion_bce(d_out, ones[:batch_len])


    # if the pixel value of the output from discriminator is more than a threshold,
    # its value is viewd as one from true label. Else, its value is ignored(value=255).
    h_[d_out < 0.2] = 255

    loss_ce = criterion_ce_semi(h, h_)
    loss_semi = 0.001 * loss_adv + 0.1 * loss_ce

    optimizer.zero_grad()
    optimizer_d.zero_grad()
    loss_semi.backward()
    optimizer.step()

    return loss_semi.item()

In [2]:
pretrained_model = './models/deeplabv2_resnet101_COCO_init.pth'
class_weight_flag = True
batch_size = 6
num_workers = 4
max_epoch = 1000
learning_rate = 0.00025
learning_rate_d = 0.0001
n_classes = 8
device = 'cuda'

In [3]:
''' DataLoader '''
train_data_with_label = PartAffordanceDataset('train_with_label.csv',
                                        transform=transforms.Compose([
                                            CenterCrop(),
                                            ToTensor(),
                                            Normalize()
                                        ]))

train_data_without_label = PartAffordanceDatasetWithoutLabel('train_without_label.csv',
                                        transform=transforms.Compose([
                                            CenterCrop(),
                                            ToTensor(),
                                            Normalize()
                                         ]))

test_data = PartAffordanceDataset('test.csv',
                            transform=transforms.Compose([
                                CenterCrop(),
                                ToTensor(),
                                Normalize()
                            ]))

train_loader_with_label = DataLoader(train_data_with_label, batch_size=batch_size, shuffle=True, num_workers=num_workers)
train_loader_without_label = DataLoader(train_data_without_label, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)


''' define model, optimizer, loss '''
model = DeepLabV2_ResNet101_MSC(n_classes)
model_d = Discriminator(n_classes)

model.apply(init_weights)
model_d.apply(init_weights)

state_dict = torch.load(pretrained_model)
model.load_state_dict(state_dict, strict=False)

model.to(device)
model_d.to(device)

optimizer = optim.SGD(
                    params=[{
                        "params": get_params(model, key="1x"),
                        "lr": learning_rate,
                        "weight_decay": 5.0e-4,
                            },
                            {
                        "params": get_params(model, key="10x"),
                        "lr": 10 * learning_rate,
                        "weight_decay": 5.0e-4,
                            },
                            {
                            "params": get_params(model, key="20x"),
                            "lr": 20 * learning_rate,
                            "weight_decay": 0.0,
                            }],
                    momentum=0.9)



optimizer_d = optim.Adam(model_d.parameters(), lr=learning_rate_d, betas=(0.9,0.99))

if class_weight_flag:
    class_weight = torch.tensor([0.0057, 0.4689, 1.0000, 1.2993, 
                                0.4240, 2.3702, 1.7317, 0.8149])    # refer to dataset.py
    criterion_ce_full = nn.CrossEntropyLoss(weight=class_weight.to(device))
else:
    criterion_ce_full = nn.CrossEntropyLoss()

criterion_ce_semi = nn.CrossEntropyLoss(ignore_index=255)
criterion_bce = nn.BCEWithLogitsLoss()

# supplementary constant for discriminator
ones = torch.ones(batch_size, 256, 256).to(device)
zeros = torch.zeros(batch_size, 256, 256).to(device)


''' training '''

losses_full = []
losses_semi = []
losses_d = []
val_iou = []
mean_iou = []
best_iou = 0.0

In [4]:
for epoch in range(max_epoch):

    epoch_loss_full = 0.0
    epoch_loss_d = 0.0
    epoch_loss_semi = 0.0

    poly_lr_scheduler(
        optimizer=optimizer,
        init_lr=learning_rate,
        iter=epoch - 1,
        lr_decay_iter=10,
        max_iter=max_epoch,
        power=0.9,
    )

    poly_lr_scheduler_d(
        optimizer=optimizer_d,
        init_lr=learning_rate_d,
        iter=epoch - 1,
        lr_decay_iter=10,
        max_iter=max_epoch,
        power=0.9,
    )


    # only supervised learning
    if epoch < 0:
        for i, sample in tqdm.tqdm(enumerate(train_loader_with_label), 
                                   total=len(train_loader_with_label)):

            loss_full, loss_d = full_train(
                                    model, model_d, sample, criterion_ce_full, criterion_bce,
                                    optimizer, optimizer_d, ones, zeros, device)
            print(loss_full, loss_d)
            epoch_loss_full += loss_full
            epoch_loss_d += loss_d

        losses_full.append(epoch_loss_full / i)   # mean loss over all samples
        losses_d.append(epoch_loss_d / i)
        losses_semi.append(0.0)


    # semi-supervised learning
    if epoch >= 0:
        for i, (sample1, sample2) in tqdm.tqdm(enumerate(zip(train_loader_with_label, train_loader_without_label)), 
                                                total=len(train_loader_with_label)):

            loss_full, loss_d = full_train(
                                    model, model_d, sample1, criterion_ce_full, criterion_bce,
                                    optimizer, optimizer_d, ones, zeros, device)
            
            epoch_loss_full += loss_full
            epoch_loss_d += loss_d

            loss_semi = semi_train(
                                    model, model_d, sample2, criterion_ce_semi, criterion_bce,
                                    optimizer, optimizer_d, ones, zeros, device)
            
            epoch_loss_semi += loss_semi

        losses_full.append(epoch_loss_full / i)   # mean loss over all samples
        losses_d.append(epoch_loss_d / i)
        losses_semi.append(epoch_loss_semi / i)

  "See the documentation of nn.Upsample for details.".format(mode))


2.8698792457580566 1.253151297569275


  0%|          | 1/1967 [00:01<41:05,  1.25s/it]

0.011708456091582775
11.97008991241455 4.213683128356934


  0%|          | 2/1967 [00:02<40:32,  1.24s/it]

0.008660722523927689
3.994086265563965 0.36702650785446167


  0%|          | 3/1967 [00:03<39:42,  1.21s/it]

0.007229161448776722
4.498116970062256 0.3504098653793335


  0%|          | 4/1967 [00:04<39:15,  1.20s/it]

0.002384445397183299
4.05718994140625 0.2752782702445984


  0%|          | 5/1967 [00:05<39:00,  1.19s/it]

0.13536643981933594
3.9291133880615234 0.5638864040374756


  0%|          | 6/1967 [00:07<38:37,  1.18s/it]

0.004842708352953196
4.964388370513916 0.20941978693008423


  0%|          | 7/1967 [00:08<38:35,  1.18s/it]

0.002662213984876871
2.342280149459839 0.2223885953426361


  0%|          | 8/1967 [00:09<38:24,  1.18s/it]

0.10166565328836441
3.393409013748169 3.2462804317474365


  0%|          | 9/1967 [00:10<38:08,  1.17s/it]

0.13631108403205872
2.080451488494873 0.33272701501846313


  1%|          | 10/1967 [00:11<37:56,  1.16s/it]

0.0060912128537893295
1.844444990158081 0.25567737221717834


  1%|          | 11/1967 [00:12<37:50,  1.16s/it]

0.006948020309209824
2.46440052986145 0.32061752676963806


  1%|          | 12/1967 [00:14<37:45,  1.16s/it]

0.004863671492785215
1.927861213684082 0.40020695328712463


  1%|          | 13/1967 [00:15<37:43,  1.16s/it]

0.0035459098871797323
2.2502543926239014 0.44072863459587097


  1%|          | 14/1967 [00:16<37:38,  1.16s/it]

0.0036840436514467
1.95070219039917 0.38630300760269165


  1%|          | 15/1967 [00:17<37:36,  1.16s/it]

0.005218432750552893
2.022118091583252 0.3907071053981781


  1%|          | 16/1967 [00:18<37:34,  1.16s/it]

0.005122071132063866
1.9362027645111084 0.3504157066345215


  1%|          | 17/1967 [00:19<37:31,  1.15s/it]

0.005259358324110508
2.290332078933716 0.35551223158836365


  1%|          | 18/1967 [00:20<37:29,  1.15s/it]

0.008057921193540096
2.0660722255706787 0.299469918012619


  1%|          | 19/1967 [00:22<37:27,  1.15s/it]

0.010536549612879753
2.083909034729004 0.2644241154193878


  1%|          | 20/1967 [00:23<37:30,  1.16s/it]

0.05754852294921875
2.344194173812866 0.3417905271053314


  1%|          | 21/1967 [00:24<37:27,  1.15s/it]

0.049362048506736755
2.3307483196258545 0.26051437854766846


  1%|          | 22/1967 [00:25<37:25,  1.15s/it]

0.07296470552682877
2.13665509223938 0.6764751672744751


  1%|          | 23/1967 [00:26<37:24,  1.15s/it]

0.07109033316373825
1.785438060760498 0.5833394527435303


  1%|          | 24/1967 [00:27<37:24,  1.16s/it]

0.07179375737905502
1.8190076351165771 1.0444958209991455


  1%|▏         | 25/1967 [00:29<38:02,  1.18s/it]

0.06407677382230759
2.15037202835083 1.5430760383605957


  1%|▏         | 26/1967 [00:30<38:13,  1.18s/it]

0.051441941410303116
2.0724167823791504 1.6285820007324219


  1%|▏         | 27/1967 [00:31<38:30,  1.19s/it]

0.04607565701007843
2.01420521736145 1.1314961910247803


  1%|▏         | 28/1967 [00:32<38:43,  1.20s/it]

0.04255591705441475
1.751344084739685 0.39074862003326416


  1%|▏         | 29/1967 [00:33<38:48,  1.20s/it]

0.00858160387724638
2.3360936641693115 0.4577636122703552


  2%|▏         | 30/1967 [00:35<38:44,  1.20s/it]

0.009269634261727333
1.8112484216690063 0.8698969483375549


  2%|▏         | 31/1967 [00:36<38:50,  1.20s/it]

0.007887675426900387
2.150132179260254 1.0758354663848877


  2%|▏         | 32/1967 [00:37<38:51,  1.20s/it]

0.006908936891704798
1.8819122314453125 0.8515539169311523


  2%|▏         | 33/1967 [00:38<38:58,  1.21s/it]

0.006734306458383799
2.0483977794647217 0.7720101475715637


  2%|▏         | 34/1967 [00:39<38:43,  1.20s/it]

0.006717617157846689
1.8270946741104126 0.5364201068878174


  2%|▏         | 35/1967 [00:41<38:16,  1.19s/it]

0.006812836974859238
1.832937479019165 0.321870356798172


  2%|▏         | 36/1967 [00:42<38:18,  1.19s/it]

0.007502488326281309
2.067319631576538 0.2914794981479645


  2%|▏         | 37/1967 [00:43<38:15,  1.19s/it]

0.007698163390159607
2.0559587478637695 0.23794183135032654


  2%|▏         | 38/1967 [00:44<38:14,  1.19s/it]

0.006840172223746777
1.8440251350402832 0.18941205739974976


  2%|▏         | 39/1967 [00:45<38:19,  1.19s/it]

0.005976554471999407
2.3094379901885986 0.1665370911359787


  2%|▏         | 40/1967 [00:47<38:09,  1.19s/it]

0.00518804183229804
1.9418361186981201 0.2282865345478058


  2%|▏         | 41/1967 [00:48<37:45,  1.18s/it]

0.034713950008153915
1.9351756572723389 0.20811393857002258


  2%|▏         | 42/1967 [00:49<37:36,  1.17s/it]

0.03173390030860901
1.7852531671524048 0.25071749091148376


  2%|▏         | 43/1967 [00:50<37:27,  1.17s/it]

0.046767182648181915
2.0552620887756348 0.3405505120754242


  2%|▏         | 44/1967 [00:51<37:37,  1.17s/it]

0.05214646831154823
2.281172275543213 0.4594035744667053


  2%|▏         | 45/1967 [00:52<37:23,  1.17s/it]

0.0578266978263855
1.9729678630828857 0.5519734621047974


  2%|▏         | 46/1967 [00:54<37:14,  1.16s/it]

0.0654635950922966
2.091203212738037 0.4291480779647827


  2%|▏         | 47/1967 [00:55<37:09,  1.16s/it]

0.0841076597571373
1.9007800817489624 0.21572357416152954


  2%|▏         | 48/1967 [00:56<37:04,  1.16s/it]

0.10133564472198486
2.332698106765747 0.13438966870307922


  2%|▏         | 49/1967 [00:57<37:02,  1.16s/it]

0.10387008637189865
1.9274110794067383 0.25297993421554565


  3%|▎         | 50/1967 [00:58<36:59,  1.16s/it]

0.09944798052310944
1.7868869304656982 0.1923980712890625


  3%|▎         | 51/1967 [00:59<36:57,  1.16s/it]

0.10541000962257385
1.8932772874832153 0.2853808104991913


  3%|▎         | 52/1967 [01:01<36:53,  1.16s/it]

0.09207338094711304
1.9082915782928467 0.20880092680454254


  3%|▎         | 53/1967 [01:02<36:50,  1.16s/it]

0.006519518326967955
1.6963984966278076 0.23023900389671326


  3%|▎         | 54/1967 [01:03<36:50,  1.16s/it]

0.006707970984280109
1.9847044944763184 0.28108543157577515


  3%|▎         | 55/1967 [01:04<36:48,  1.16s/it]

0.024528127163648605
2.0366339683532715 0.3741839528083801


  3%|▎         | 56/1967 [01:05<36:47,  1.15s/it]

0.029199114069342613
1.8934401273727417 0.31763288378715515


  3%|▎         | 57/1967 [01:06<36:46,  1.16s/it]

0.036558378487825394
2.183948040008545 0.45450419187545776


  3%|▎         | 58/1967 [01:07<36:43,  1.15s/it]

0.04945949465036392
2.111417293548584 0.42479872703552246


  3%|▎         | 59/1967 [01:09<36:43,  1.16s/it]

0.056809909641742706
1.8871766328811646 0.3573371171951294


  3%|▎         | 60/1967 [01:10<36:44,  1.16s/it]

0.0515313595533371
1.9374065399169922 0.37023860216140747


  3%|▎         | 61/1967 [01:11<36:41,  1.16s/it]

0.03955599665641785
1.9340322017669678 0.4208599925041199


  3%|▎         | 62/1967 [01:12<36:41,  1.16s/it]

0.03100588172674179
1.977989912033081 0.5030295848846436


  3%|▎         | 63/1967 [01:13<36:38,  1.15s/it]

0.02473408728837967
1.8147164583206177 0.4861946702003479


  3%|▎         | 64/1967 [01:14<36:34,  1.15s/it]

0.01704687997698784
1.8672401905059814 0.3234764337539673


  3%|▎         | 65/1967 [01:16<36:34,  1.15s/it]

0.014907840639352798
2.051997661590576 0.3503275513648987


  3%|▎         | 66/1967 [01:17<36:33,  1.15s/it]

0.007723917253315449
2.049039602279663 0.40981659293174744


  3%|▎         | 67/1967 [01:18<36:32,  1.15s/it]

0.007902977056801319
2.0364768505096436 0.3540683090686798


  3%|▎         | 68/1967 [01:19<36:30,  1.15s/it]

0.007337816525250673
1.9757226705551147 0.2874699831008911


  4%|▎         | 69/1967 [01:20<36:31,  1.15s/it]

0.08282433450222015
1.797174096107483 0.24017083644866943


  4%|▎         | 70/1967 [01:21<36:31,  1.16s/it]

0.08863652497529984
1.9535452127456665 0.32859477400779724


  4%|▎         | 71/1967 [01:22<36:27,  1.15s/it]

0.07653693109750748
1.5670963525772095 0.2532805800437927


  4%|▎         | 72/1967 [01:24<36:48,  1.17s/it]

0.05998126417398453
1.6313265562057495 0.31231892108917236


  4%|▎         | 73/1967 [01:25<37:11,  1.18s/it]

0.028836306184530258
1.3697766065597534 0.612615704536438


  4%|▍         | 74/1967 [01:26<37:28,  1.19s/it]

0.01534265000373125
2.0430400371551514 0.7894971966743469


  4%|▍         | 75/1967 [01:27<37:42,  1.20s/it]

0.011391492560505867
2.7277204990386963 0.5070832371711731


  4%|▍         | 76/1967 [01:28<37:51,  1.20s/it]

0.011684981174767017
1.4180712699890137 0.5150801539421082


  4%|▍         | 77/1967 [01:30<37:44,  1.20s/it]

0.011525508016347885
2.0453577041625977 0.6577736735343933


  4%|▍         | 78/1967 [01:31<37:25,  1.19s/it]

0.009751927107572556
1.5689260959625244 0.5549251437187195


  4%|▍         | 79/1967 [01:32<37:04,  1.18s/it]

0.007900122553110123
2.2243895530700684 0.6844226717948914


  4%|▍         | 80/1967 [01:33<37:01,  1.18s/it]

0.0062259407714009285
1.8258659839630127 0.42357996106147766


  4%|▍         | 81/1967 [01:34<36:50,  1.17s/it]

0.004781669471412897
1.7842903137207031 0.3134675920009613


  4%|▍         | 82/1967 [01:36<36:55,  1.18s/it]

0.12046321481466293
1.9656651020050049 0.2698366045951843


  4%|▍         | 83/1967 [01:37<36:40,  1.17s/it]

0.09806180000305176
2.0694022178649902 0.3041979670524597


  4%|▍         | 84/1967 [01:38<36:31,  1.16s/it]

0.07933719456195831
1.9231666326522827 0.19583441317081451


  4%|▍         | 85/1967 [01:39<36:26,  1.16s/it]

0.058030735701322556
2.330043315887451 0.34594419598579407


  4%|▍         | 86/1967 [01:40<36:20,  1.16s/it]

0.05012979358434677
2.07830810546875 0.43461453914642334


  4%|▍         | 87/1967 [01:41<36:22,  1.16s/it]

0.05211004242300987
1.980574131011963 0.48374998569488525


  4%|▍         | 88/1967 [01:42<36:15,  1.16s/it]

0.05370006710290909
1.7598956823349 0.39124464988708496


  5%|▍         | 89/1967 [01:44<36:12,  1.16s/it]

0.04902023449540138
1.710841178894043 0.3675198554992676


  5%|▍         | 90/1967 [01:45<36:09,  1.16s/it]

0.029899023473262787
1.697337031364441 0.2514742910861969


  5%|▍         | 91/1967 [01:46<36:07,  1.16s/it]

0.02063891291618347
1.633617877960205 0.20001038908958435


  5%|▍         | 92/1967 [01:47<36:02,  1.15s/it]

0.02492283098399639
2.125988006591797 0.28692153096199036


  5%|▍         | 93/1967 [01:48<36:14,  1.16s/it]

0.021364454180002213
1.913976788520813 0.30659234523773193


  5%|▍         | 94/1967 [01:49<36:44,  1.18s/it]

0.0061536431312561035
1.5616105794906616 0.32357311248779297


  5%|▍         | 95/1967 [01:51<37:00,  1.19s/it]

0.006108004134148359
1.4487308263778687 0.32076650857925415


  5%|▍         | 96/1967 [01:52<37:15,  1.19s/it]

0.005899030715227127
1.851867914199829 0.42920202016830444


  5%|▍         | 97/1967 [01:53<37:02,  1.19s/it]

0.005165720358490944
2.192650556564331 0.5112618803977966


  5%|▍         | 98/1967 [01:54<36:51,  1.18s/it]

0.004282210487872362
1.8800337314605713 0.4106219410896301


  5%|▌         | 99/1967 [01:55<36:33,  1.17s/it]

0.027784625068306923
2.2203335762023926 0.39545154571533203


  5%|▌         | 100/1967 [01:57<36:24,  1.17s/it]Process Process-1:
Process Process-2:
Process Process-5:
Process Process-4:
Process Process-8:
Process Process-7:
Process Process-3:
Process Process-6:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multip

  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
KeyboardInterrupt
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
KeyboardInterrupt
KeyboardInterrupt
  File "/home/yuchi/anaconda3/envs/torch/lib/python3.5/multiprocessing/queues.py", line 104, in get
    if timeout 

0.050244495272636414


KeyboardInterrupt: 