In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import argparse
import numpy as np
import pandas as pd
import yaml
import scipy.io
import skimage.io
import time
import tqdm

from addict import Dict
from PIL import Image, ImageFilter
from tensorboardX import SummaryWriter

from models.SegNet import SegNetBasic
from models.discriminator import Discriminator


In [6]:
class PartAffordanceDataset(Dataset):
    """Labelled samples of the Part Affordance Dataset.

    The index is a CSV file read with ``pandas.read_csv``. For each row,
    column 0 is used as the image path and column 1 as the path of a
    ``.mat`` file whose ``"gt_label"`` entry is the label map.
    """

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file: path of the CSV index file.
            transform: optional callable applied to each sample dict
                before it is returned.
        """
        super().__init__()

        self.image_class_path = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.image_class_path)

    def __getitem__(self, idx):
        """Load row ``idx`` as ``{'image': ..., 'class': ...}``.

        The image is read with ``skimage.io.imread`` and the label with
        ``scipy.io.loadmat``; ``self.transform`` is applied when set.
        """
        index_table = self.image_class_path
        sample = {
            'image': skimage.io.imread(index_table.iloc[idx, 0]),
            'class': scipy.io.loadmat(index_table.iloc[idx, 1])["gt_label"],
        }

        if not self.transform:
            return sample
        return self.transform(sample)



class PartAffordanceDatasetWithoutLabel(Dataset):
    """Unlabelled samples of the Part Affordance Dataset.

    The index is a CSV file read with ``pandas.read_csv``; column 0 of each
    row is used as the image path. Images are opened with ``PIL.Image.open``.
    """

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file: path of the CSV index file.
            transform: optional callable applied to each sample dict
                before it is returned.
        """
        super().__init__()

        self.image_path = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Load row ``idx`` as ``{'image': <PIL image>}``, transformed if requested."""
        sample = {'image': Image.open(self.image_path.iloc[idx, 0])}

        if not self.transform:
            return sample
        return self.transform(sample)

In [7]:
def crop_center_numpy(array, crop_height, crop_weight):
    """Return the centred ``crop_height`` x ``crop_weight`` window of ``array``.

    Args:
        array: array whose first two axes are (height, width). Any further
            axes (e.g. a trailing channel axis) are kept untouched.
        crop_height: number of rows to keep.
        crop_weight: number of columns to keep (the parameter name is kept
            as-is for backward compatibility; it is the crop *width*).

    Returns:
        A slice (view) of ``array``. When a requested size exceeds the
        corresponding array dimension, that whole dimension is returned.
    """
    # Only the two leading axes are spatial, so (H, W) and (H, W, C) both work.
    h, w = array.shape[:2]

    # Derive the end from the start so odd crop sizes are honoured exactly,
    # and clamp the start at 0 so an oversized crop cannot produce a negative
    # index (which would silently select from the end of the axis).
    top = max((h - crop_height) // 2, 0)
    left = max((w - crop_weight) // 2, 0)

    return array[top: top + crop_height, left: left + crop_weight]


def crop_center_pil_image(pil_img, crop_height, crop_width):
    """Crop the centred ``crop_height`` x ``crop_width`` region of ``pil_img``.

    ``pil_img.size`` is read as ``(width, height)`` and the result of
    ``pil_img.crop`` on the ``(left, upper, right, lower)`` box is returned.
    """
    img_width, img_height = pil_img.size

    left = (img_width - crop_width) // 2
    upper = (img_height - crop_height) // 2
    right = (img_width + crop_width) // 2
    lower = (img_height + crop_height) // 2

    return pil_img.crop((left, upper, right, lower))


class CenterCrop(object):
    """Center-crop the image (and the label map, when present) of a sample dict.

    The labelled dataset in this notebook loads images with
    ``skimage.io.imread`` (a NumPy array) while the unlabelled one uses
    ``PIL.Image.open``. The image is therefore cropped according to its
    type; previously the labelled branch handed the NumPy array to the PIL
    helper, which fails when unpacking ``.size``.

    Args:
        crop_height: output height in pixels (default 256).
        crop_width: output width in pixels (default 320).
    """

    def __init__(self, crop_height=256, crop_width=320):
        self.crop_height = crop_height
        self.crop_width = crop_width

    def _crop_image(self, image):
        """Center-crop a PIL image or an (H, W[, C]) array."""
        # Anything exposing ``crop`` is treated as a PIL image.
        if hasattr(image, 'crop'):
            return crop_center_pil_image(image, self.crop_height, self.crop_width)

        # Array image: slice the two leading (spatial) axes only, so a
        # trailing channel axis is preserved.
        h, w = image.shape[:2]
        top = max((h - self.crop_height) // 2, 0)
        left = max((w - self.crop_width) // 2, 0)
        return image[top: top + self.crop_height, left: left + self.crop_width]

    def __call__(self, sample):
        """Return a new dict with the cropped ``'image'`` (and ``'class'`` if given)."""
        cropped = {'image': self._crop_image(sample['image'])}

        if 'class' in sample:
            cropped['class'] = crop_center_numpy(
                sample['class'], self.crop_height, self.crop_width
            )

        return cropped



class ToTensor(object):
    """Convert a sample's image (and label map, when present) to tensors.

    The image is returned as a float tensor laid out (C, H, W) with its
    original value range. Two problems of the previous version are fixed:

    * the labelled branch left the image as (H, W, C), which neither the
      channel-wise ``Normalize`` step nor a ``(N, C, H, W)`` network accepts;
    * the unlabelled branch rescaled pixels to [0, 1], while the default
      ``Normalize`` statistics in this notebook are on the 0-255 scale.

    Both branches now share one conversion so labelled and unlabelled
    images are preprocessed identically.
    """

    def __call__(self, sample):
        """Return a new dict with ``'image'`` (float, CHW) and, if given, ``'class'`` (long)."""
        # np.array accepts both PIL images and ndarrays and yields a writable
        # copy, which torch.from_numpy requires to avoid a warning.
        image = np.array(sample['image'])

        # Single-channel images come back 2-D; add an explicit channel axis.
        if image.ndim == 2:
            image = image[:, :, None]

        # (H, W, C) -> (C, H, W)
        image_tensor = torch.from_numpy(image).permute(2, 0, 1).float().contiguous()

        converted = {'image': image_tensor}
        if 'class' in sample:
            converted['class'] = torch.from_numpy(sample['class']).long()

        return converted



class Normalize(object):
    """Normalise a sample's image with per-channel mean and std.

    The image is passed to ``transforms.functional.normalize``; a
    ``'class'`` entry, when present, is carried over unchanged.
    """

    def __init__(self, mean=[55.8630, 59.9099, 91.7419], std=[31.6852, 29.8496, 19.0835]):
        # The default lists are stored as-is and never mutated here.
        self.mean, self.std = mean, std

    def __call__(self, sample):
        """Return a new dict holding the normalised image (and the label, if any)."""
        normalized = {
            'image': transforms.functional.normalize(sample['image'], self.mean, self.std)
        }

        if 'class' in sample:
            normalized['class'] = sample['class']

        return normalized




In [9]:
# Segmentation network. The 8 matches the eight classes (background
# included) that eval_model iterates over; the 3 is presumably the number
# of input image channels — TODO confirm against SegNetBasic's signature.
model = SegNetBasic(3, 8)

In [3]:
def full_train(model, sample, criterion_ce_full, optimizer, device):
    """Run one fully supervised optimisation step of the segmentation network.

    Args:
        model: network to train; it is switched to training mode.
        sample: dict with an ``'image'`` batch and a ``'class'`` target batch.
        criterion_ce_full: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer holding the model's parameters.
        device: device the batch is moved to before the forward pass.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()

    images = sample['image'].to(device)
    targets = sample['class'].to(device)

    # Forward pass; the original comment gives the output shape as (N, 8, H, W).
    logits = model(images)
    loss = criterion_ce_full(logits, targets)

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

In [4]:
def eval_model(model, test_loader, device='cpu'):
    """Compute per-class IoU of ``model`` over the first batches of ``test_loader``.

    NOTE: a later cell in this notebook defines ``eval_model`` again; when
    both cells are run in order, that later definition replaces this one.

    Args:
        model: segmentation network; switched to evaluation mode.
        test_loader: iterable of dicts with ``'image'`` and ``'class'`` batches.
        device: device the batches are moved to.

    Returns:
        A tensor of 8 values where entry ``i`` is the IoU of class ``i``.
        A class that appears in neither the labels nor the predictions has
        a zero union, so its entry is NaN.

    Side effects:
        Prints the elapsed wall-clock time in seconds.
    """
    model.eval()

    start = time.time()

    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)

    for k, sample in enumerate(test_loader):
        x, y = sample['image'], sample['class']

        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            ypred = model(x)    # ypred.shape => (N, 8, H, W)
            _, ypred = ypred.max(1)    # y_pred.shape => (N, 256, 320)

        for i in range(8):
            y_i = (y == i)
            ypred_i = (ypred == i)

            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            # |A ∪ B| = |A| + |B| - |A ∩ B|
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter

        # Only the first 11 batches (k = 0..10) are evaluated.
        if k == 10:
            break

    # iou[i] is the IoU of class i
    iou = intersection / union

    # Elapsed time is "now minus start"; the operands were previously
    # swapped, which printed a negative duration.
    taken_time = time.time() - start

    print(taken_time)

    return iou

In [10]:
# Load the experiment configuration into an attribute-accessible Dict.
# The file is opened in a ``with`` block so the handle is closed right
# after parsing instead of being left open.
with open('./result_segnet/config_segnet.yaml') as config_file:
    CONFIG = Dict(yaml.safe_load(config_file))

In [14]:
# Labelled training set: crop -> tensor -> normalise.
train_data_with_label = PartAffordanceDataset(
    csv_file='train.csv',
    transform=transforms.Compose([
        CenterCrop(),
        ToTensor(),
        Normalize(),
    ]),
)

# Unlabelled training set, same preprocessing chain.
train_data_without_label = PartAffordanceDatasetWithoutLabel(
    csv_file='train_without_label_4to1.csv',
    transform=transforms.Compose([
        CenterCrop(),
        ToTensor(),
        Normalize(),
    ]),
)

# NOTE(review): the test set is only converted to tensors; it is neither
# centre-cropped nor normalised like the training data. Confirm whether this
# mismatch is intended before trusting evaluation numbers.
test_data = PartAffordanceDataset(
    csv_file='test.csv',
    transform=transforms.Compose([
        ToTensor(),
    ]),
)

# Training loaders shuffle and use the configured worker count.
train_loader_with_label = DataLoader(
    dataset=train_data_with_label,
    batch_size=CONFIG.batch_size,
    shuffle=True,
    num_workers=CONFIG.num_workers,
)
train_loader_without_label = DataLoader(
    dataset=train_data_without_label,
    batch_size=CONFIG.batch_size,
    shuffle=True,
    num_workers=CONFIG.num_workers,
)

# The test loader keeps the dataset order.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=CONFIG.batch_size,
    shuffle=False,
)

In [7]:
def one_hot(label, n_classes, device):
    """One-hot encode an integer label batch along a new channel axis.

    Rows of an ``n_classes`` x ``n_classes`` identity matrix (created with
    ``requires_grad=True`` on ``device``) are looked up with ``label``, and
    the resulting last axis is moved to position 1. For a ``(N, H, W)``
    label this gives a ``(N, n_classes, H, W)`` tensor.
    """
    identity = torch.eye(n_classes, requires_grad=True, device=device)
    channels_last = identity[label]

    # Same axis order as transpose(1, 3) followed by transpose(2, 3).
    return channels_last.permute(0, 3, 1, 2)

In [21]:
def eval_model(model, test_loader, device='cpu'):
    """Compute per-class IoU of ``model`` over the first batches of ``test_loader``.

    Predictions and labels are one-hot encoded with ``one_hot`` and the
    intersection / union counts are accumulated per class over the batch
    and spatial axes.

    Args:
        model: segmentation network; switched to evaluation mode.
        test_loader: iterable of dicts with ``'image'`` and ``'class'`` batches.
        device: device used for the batches and the accumulators.

    Returns:
        A tensor of 8 values where entry ``i`` is the IoU of class ``i``.
        A class that appears in neither the labels nor the predictions has
        a zero union, so its entry is NaN.

    Side effects:
        Prints the elapsed wall-clock time in seconds.
    """
    # Evaluate in inference mode so that layers which behave differently
    # during training (e.g. BatchNorm or Dropout, if the model has any) do
    # not distort the metric. This call was previously missing here,
    # although the earlier eval_model definition in this notebook has it.
    model.eval()

    start = time.time()

    intersections = torch.zeros(8).to(device)
    unions = torch.zeros(8).to(device)

    for i, sample in enumerate(test_loader):
        x = sample['image']
        y = sample['class']

        x = x.to(device)
        y = y.to(device)

        # Nothing below needs gradients, so the one-hot encoding is also
        # kept inside no_grad to avoid building an autograd graph.
        with torch.no_grad():
            ypred = model(x)    # ypred.shape => (N, 8, H, W)
            _, ypred = ypred.max(1)    # y_pred.shape => (N, 256, 320)

            p = one_hot(ypred, 8, device).long()
            t = one_hot(y, 8, device).long()

        # Reduce over batch and spatial axes, keeping the class axis.
        intersection = torch.sum(p & t, (0, 2, 3))
        union = torch.sum(p | t, (0, 2, 3))

        intersections += intersection.float()
        unions += union.float()

        # Only the first 11 batches (i = 0..10) are evaluated.
        if i == 10:
            break

    iou = intersections / unions

    taken_time = time.time() - start
    print(taken_time)

    return iou

In [2]:
# Scratch tensor laid out like the (N, 8, H, W) network output noted in eval_model.
a = torch.zeros(10, 8, 256, 320)

In [3]:
# Fill slice i along dim 1 with the constant value i, in place.
for i in range(8):
    a[:, i].fill_(i)

In [4]:
# Display the tensor to check the fill performed by the loop above.
a

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[2., 2., 2.,  ..., 2., 2., 2.],
          [2., 2., 2.,  ..., 2., 2., 2.],
          [2., 2., 2.,  ..., 2., 2., 2.],
          ...,
          [2., 2., 2.,  ..., 2., 2., 2.],
          [2., 2., 2.,  ..., 2., 2., 2.],
          [2., 2., 2.,  ..., 2., 2., 2.]],

         ...,

         [[5., 5., 5.,  ..., 5., 5., 5.],
          [5., 5., 5.,  ..., 5., 5., 5.],
          [5., 5., 5.,  ..., 5., 5., 5.],
          ...,
          [5., 5., 5.,  ..., 5., 5., 

In [7]:
# Softmax along dim 1, so the 8 values at each (n, h, w) position sum to 1.
b = F.softmax(a, dim=1)

In [12]:
# Sum the softmax output back over dim 1.
c = torch.sum(b,dim=1)

In [14]:
# Shape check: dim 1 has been reduced away by the sum.
c.shape

torch.Size([10, 256, 320])

In [1]:
import torch


In [3]:
# Scratch check: random tensor of shape (3, 8, 256, 320); the result is only displayed.
torch.rand((3, 8, 256, 320))

tensor([[[[5.2787e-02, 4.5112e-01, 2.3507e-01,  ..., 5.9310e-01,
           6.2576e-01, 5.0098e-01],
          [5.1723e-01, 8.6876e-01, 3.5743e-01,  ..., 4.6504e-01,
           8.6561e-01, 9.0760e-01],
          [7.2756e-02, 6.3288e-01, 7.6766e-01,  ..., 6.8076e-01,
           6.3129e-01, 9.5399e-01],
          ...,
          [3.4836e-01, 7.6110e-01, 4.8951e-01,  ..., 3.9664e-01,
           8.9063e-01, 7.8384e-01],
          [9.1087e-01, 7.2367e-01, 8.6599e-02,  ..., 9.0496e-01,
           2.2214e-01, 2.2129e-01],
          [1.4003e-01, 9.2004e-01, 5.3931e-01,  ..., 3.9316e-01,
           8.9206e-01, 8.9963e-01]],

         [[2.2617e-01, 3.6583e-01, 2.9902e-01,  ..., 4.2709e-01,
           8.0931e-03, 5.5957e-01],
          [4.6404e-01, 7.3141e-01, 9.9498e-01,  ..., 6.6768e-01,
           4.7515e-01, 1.9825e-01],
          [8.9449e-01, 1.4594e-01, 5.7890e-01,  ..., 2.9833e-01,
           9.0014e-01, 8.3672e-01],
          ...,
          [6.7938e-01, 5.0287e-01, 6.8052e-01,  ..., 2.0442