In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import transforms


import argparse
import time
import tqdm
import yaml

from addict import Dict
from tensorboardX import SummaryWriter

from dataset import PartAffordanceDataset, ToTensor, CenterCrop, Normalize
from dataset import Resize, RandomFlip, RandomRotate, ColorChange, RandomCrop
from model.drn import drn_c_58

In [2]:
CONFIG = Dict(yaml.safe_load(open('./result/drn_c_58/config.yaml')))

""" DataLoader """

train_data = PartAffordanceDataset(CONFIG.train_data,
                                   config=CONFIG,
                                   transform=transforms.Compose([
                                       RandomRotate(45),
                                       RandomCrop(CONFIG),
                                       Resize(CONFIG),
                                       ColorChange(),
                                       RandomFlip(),
                                       ToTensor(CONFIG),
                                       Normalize()
                                   ]))

test_data = PartAffordanceDataset(CONFIG.test_data,
                                  config=CONFIG,
                                  transform=transforms.Compose([
                                      CenterCrop(CONFIG),
                                      Resize(CONFIG),
                                      ToTensor(CONFIG),
                                      Normalize()
                                  ]))

train_loader = DataLoader(train_data, batch_size=CONFIG.batch_size,
                          shuffle=True, num_workers=CONFIG.num_workers)
test_loader = DataLoader(test_data, batch_size=CONFIG.batch_size,
                         shuffle=False, num_workers=CONFIG.num_workers)

In [3]:
from torchvision.utils import save_image, make_grid

In [313]:
for i, sample in enumerate(train_loader):
    a = sample['aff_label']
#     img = sample['image']
#     save_image(make_grid(img), './{}.png'.format(i))
    if i == 10:
        break

In [316]:
a.sum()

tensor(21.)

In [228]:
n, c, h, w = feats.shape

In [229]:
desc_inds = torch.argsort(
    feats.reshape(n, c, -1), dim=2, descending=True)   # (n, c, h*w)

In [230]:
ds = torch.zeros_like(desc_inds, dtype=torch.float) + 0.996

In [231]:
ds.shape

torch.Size([8, 18, 784])

In [233]:
ds[a.nonzero()[:, 0], a.nonzero()[:, 1]]

tensor([[0.9960, 0.9960, 0.9960,  ..., 0.9960, 0.9960, 0.9960],
        [0.9960, 0.9960, 0.9960,  ..., 0.9960, 0.9960, 0.9960],
        [0.9960, 0.9960, 0.9960,  ..., 0.9960, 0.9960, 0.9960],
        ...,
        [0.9960, 0.9960, 0.9960,  ..., 0.9960, 0.9960, 0.9960],
        [0.9960, 0.9960, 0.9960,  ..., 0.9960, 0.9960, 0.9960],
        [0.9960, 0.9960, 0.9960,  ..., 0.9960, 0.9960, 0.9960]])

In [172]:
ds[:, 0] = 0.2

In [173]:
nn, cc, inds = torch.meshgrid(
    torch.arange(n), torch.arange(c), torch.arange(h * w)
)

In [188]:
weights = torch.pow(ds, inds.float())    # (n, c, h*w)

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'exponent'

In [177]:
z_dc

tensor([[[  1.2500],
         [  2.0000],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036]],

        [[  1.2500],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [  2.0000],
         [239.2036],
         [239.2036],
         [239.2036]],

        [[  1.2500],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         [239.2036],
         

In [176]:
z_dc = torch.sum(weights, dim=2, keepdim=True)    # (n, 1, h*w)

In [178]:
z_dc.shape

torch.Size([8, 18, 1])

In [179]:
desc_feats = feats.reshape(n, c, -1)[nn, cc, desc_inds]

In [180]:
desc_feats.shape

torch.Size([8, 18, 784])

In [181]:
y_gwrp = torch.sum(weights * desc_feats / z_dc, dim=2)

In [182]:
y_gwrp.shape

torch.Size([8, 18])

In [187]:
(a == 0.).nonzero()

tensor([[ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
        [ 0,  9],
        [ 0, 10],
        [ 0, 11],
        [ 0, 12],
        [ 0, 13],
        [ 0, 14],
        [ 0, 15],
        [ 0, 16],
        [ 0, 17],
        [ 1,  1],
        [ 1,  2],
        [ 1,  3],
        [ 1,  4],
        [ 1,  5],
        [ 1,  6],
        [ 1,  7],
        [ 1,  8],
        [ 1,  9],
        [ 1, 10],
        [ 1, 11],
        [ 1, 12],
        [ 1, 13],
        [ 1, 15],
        [ 1, 16],
        [ 1, 17],
        [ 2,  1],
        [ 2,  2],
        [ 2,  3],
        [ 2,  4],
        [ 2,  5],
        [ 2,  6],
        [ 2,  7],
        [ 2,  8],
        [ 2,  9],
        [ 2, 10],
        [ 2, 11],
        [ 2, 12],
        [ 2, 13],
        [ 2, 14],
        [ 2, 15],
        [ 2, 16],
        [ 3,  1],
        [ 3,  2],
        [ 3,  3],
        [ 3,  4],
        [ 3,  5],
        [ 3,  6],
        [ 3,  8],
        [ 

In [183]:
y_gwrp[a.nonzero()[:, 0], a.nonzero()[:, 1]]

tensor([0.8684, 0.9154, 0.9519, 0.9072, 0.8579, 0.9271, 0.9170, 0.8787, 0.9811,
        0.9010, 0.8960, 0.8874, 0.9080, 0.8244, 0.8966, 0.9016])

In [185]:
-torch.log(y_gwrp[a.nonzero()[:, 0], a.nonzero()[:, 1]])

tensor([0.1411, 0.0884, 0.0493, 0.0974, 0.1533, 0.0757, 0.0866, 0.1293, 0.0191,
        0.1043, 0.1098, 0.1194, 0.0965, 0.1931, 0.1091, 0.1036])

In [214]:
a = torch.randn((3, 10,11))

In [211]:
b, _ = torch.min(a.view(a.shape[0], -1), dim=1)

In [220]:
a - b.view(a.shape[0], 1, 1)

tensor([[[ 0.7765,  3.1403,  2.8301,  0.1948,  3.7324,  0.7689,  4.4386,
           2.5990,  2.8630,  1.8811,  2.1810],
         [ 2.6806,  0.4221,  1.8691,  1.2421,  3.5555,  2.7880,  2.2874,
           1.2096,  4.2552,  1.8552,  1.8166],
         [ 0.4182,  3.3875,  4.2118,  2.6281, -0.0603,  4.1654,  3.1464,
           2.6079,  5.4117,  1.6185,  0.9433],
         [ 1.7041,  4.2058,  2.3291,  2.8075,  2.3852,  4.4864,  4.2541,
           1.9252,  1.6006,  3.2935,  1.8956],
         [ 2.5204,  2.2754,  1.3410,  1.0843,  1.8647,  1.0619,  1.4325,
           1.3814,  0.9903,  2.6481,  1.3210],
         [ 1.3530,  2.3135,  1.7354,  2.3050,  2.6622,  3.2153,  2.6942,
           3.2961,  2.3779,  4.5564,  1.5830],
         [ 4.2169,  0.1637,  2.1036,  0.7843,  1.8701,  2.1472,  1.6780,
           1.7922,  2.5391,  2.0627,  2.7914],
         [ 1.6887,  2.3329,  0.4537,  2.5321,  3.0966,  2.2392,  0.9720,
           3.2287,  2.9211,  2.7971,  1.3497],
         [ 2.5053,  2.0959,  1.9453,  0.

In [206]:
a.max

torch.Size([3])

In [295]:
probmap = torch.softmax(torch.randn(10, 8, 40, 40), dim=1)
img = torch.randn((10, 3, 40, 40))

In [296]:
img

tensor([[[[-1.8112e+00, -8.2045e-01,  1.5209e+00,  ..., -4.7582e-01,
           -9.5373e-01, -9.0694e-01],
          [ 1.1852e+00,  5.2978e-01, -1.3915e+00,  ..., -8.6812e-01,
            1.1956e-01, -1.8131e+00],
          [-2.2529e+00, -1.0193e+00, -4.5320e-01,  ..., -6.3770e-01,
            2.0964e+00,  1.1363e-01],
          ...,
          [-2.8639e-01,  6.5036e-01,  2.2802e-01,  ..., -7.3508e-01,
            3.3833e-01,  6.2759e-03],
          [-5.1036e-01, -9.4658e-01, -1.9916e+00,  ..., -1.0299e+00,
           -3.2999e-01,  6.9528e-01],
          [ 6.3311e-01,  1.3467e+00, -9.0909e-01,  ...,  4.9998e-01,
            1.4604e-01, -7.2562e-01]],

         [[-2.4962e+00,  1.1567e+00, -8.3286e-01,  ..., -1.4584e+00,
           -2.0079e+00,  4.6978e-01],
          [-5.8979e-01, -1.0458e-01, -1.3512e-01,  ...,  9.6694e-01,
            2.9118e-01, -1.3437e+00],
          [-9.1424e-01,  6.3626e-01,  2.4245e-01,  ...,  4.9379e-01,
            1.2537e+00, -9.5697e-01],
          ...,
     

In [297]:
img = (img*255).numpy().astype(np.uint8).transpose(0, 2, 3, 1)

array([[[[ 51, 132, 249],
         [ 47,  38, 248],
         [131,  44, 242],
         ...,
         [135, 141, 230],
         [ 13,   0,  57],
         [ 25, 119, 196]],

        [[ 46, 106, 162],
         [135, 230, 249],
         [158, 222, 182],
         ...,
         [ 35, 246,  44],
         [ 30,  74,  42],
         [ 50, 170, 185]],

        [[194,  23, 239],
         [253, 162,  69],
         [141,  61,  99],
         ...,
         [ 94, 125,  18],
         [ 22,  63,   2],
         [ 28,  12, 169]],

        ...,

        [[183, 164, 123],
         [165, 126,  37],
         [ 58,  48, 169],
         ...,
         [ 69,  74, 129],
         [ 86, 101,  36],
         [  1,  51, 255]],

        [[126, 169,  35],
         [ 15, 233, 216],
         [  5,  53,  19],
         ...,
         [250, 231,  25],
         [172,  34, 199],
         [177, 206,  25]],

        [[161, 206, 102],
         [ 87, 125, 128],
         [ 25, 165, 223],
         ...,
         [127,  60, 135],
        

In [260]:
from utils.crf import DenseCRF
import numpy as np

In [261]:
crf = DenseCRF()

In [268]:
probmap = probmap.numpy()

In [269]:
# images: (B,C,H,W) -> (B,H,W,C)
Q = Parallel(n_jobs=-1)(
    [
        delayed(crf)(*pair) for pair in zip(img, probmap)
    ]
)

In [274]:
len(Q)

10

In [275]:
Q[0].shape

(8, 40, 40)

In [301]:
len(Q)

10

In [305]:
q =torch.tensor(Q)

In [307]:
q.shape

torch.Size([10, 8, 40, 40])

In [309]:
q.dtype

torch.float32