In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import argparse
import numpy as np
import pandas as pd
import yaml
import scipy.io
import time
import tqdm

from addict import Dict
from PIL import Image, ImageFilter
from tensorboardX import SummaryWriter

from models.SegNet import SegNetBasic
from models.discriminator import Discriminator
from dataset import PartAffordanceDataset, PartAffordanceDatasetWithoutLabel
from dataset import CenterCrop, ToTensor, Normalize

In [2]:
model = SegNetBasic(3, 8)

In [3]:
def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.weight, 0)

In [4]:
model.apply(init_weights)

SegNetBasic(
  (encoder1): Encoder(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encoder2): Encoder(
    (conv): Conv2d(64, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encoder3): Encoder(
    (conv): Conv2d(80, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encoder4): Encoder(
    (conv): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder1): Decoder(
    (conv): Conv2d(128, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder

In [5]:
def full_train(model, sample, criterion_ce_full, optimizer, device):

    ''' full supervised learning for segmentation network'''

    model.train()

    x, y = sample['image'], sample['class']

    x = x.to(device)
    y = y.to(device)

    h = model(x)     # shape => (N, 8, H, W)

    loss_ce = criterion_ce_full(h, y)

    optimizer.zero_grad()
    loss_ce.backward()
    optimizer.step()

    return loss_ce.item()

In [6]:
def eval_model(model, test_loader, device='cpu'):
    model.eval()
    
    start = time.time()
    
    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)
    
    for k, sample in enumerate(test_loader):
        x, y = sample['image'], sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            ypred = model(x)    # ypred.shape => (N, 8, H, W)
            _, ypred = ypred.max(1)    # y_pred.shape => (N, 256, 320)

        for i in range(8):
            y_i = (y == i)           
            ypred_i = (ypred == i)   
            
            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
            
        if k == 10:
            break
    
        
    """ iou[i] is the IoU of class i """
    iou = intersection / union
    
    taken_time = start - time.time()
    
    print(taken_time)
    
    return iou

In [7]:
CONFIG = Dict(yaml.safe_load(open('./result_segnet/config_segnet.yaml')))

In [8]:
train_data_with_label = PartAffordanceDataset('train.csv',
                                        transform=transforms.Compose([
                                            CenterCrop(),
                                            ToTensor(),
                                            Normalize()
                                        ]))

train_data_without_label = PartAffordanceDatasetWithoutLabel('train_without_label_4to1.csv',
                                        transform=transforms.Compose([
                                            CenterCrop(),
                                            ToTensor(),
                                            Normalize()
                                        ]))

test_data = PartAffordanceDataset('test.csv',
                            transform=transforms.Compose([
                                CenterCrop(),
                                ToTensor(),
                                Normalize()
                            ]))

train_loader_with_label = DataLoader(train_data_with_label, batch_size=CONFIG.batch_size, shuffle=True, num_workers=CONFIG.num_workers)
train_loader_without_label = DataLoader(train_data_without_label, batch_size=CONFIG.batch_size, shuffle=True, num_workers=CONFIG.num_workers)
test_loader = DataLoader(test_data, batch_size=CONFIG.batch_size, shuffle=False)

In [9]:
eval_model(model, test_loader)

-57.366888999938965


tensor([0.8753,    nan,    nan,    nan, 0.0000,    nan,    nan,    nan])

In [10]:
def one_hot(label, n_classes, device):
    one_hot_label = torch.eye(n_classes, requires_grad=True, device=device)[label].transpose(1, 3).transpose(2, 3)
    return one_hot_label

In [11]:
def eval_model(model, test_loader, device='cpu'):
    
    start = time.time()
    
    intersections = torch.zeros(8).to(device)
    unions = torch.zeros(8).to(device)
    
    for i, sample in enumerate(test_loader):
        x = sample['image']
        y = sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            ypred = model(x)    # ypred.shape => (N, 8, H, W)
            _, ypred = ypred.max(1)    # y_pred.shape => (N, 256, 320)

        p = one_hot(ypred, 8, device).long()
        t = one_hot(y, 8, device).long()
        
        intersection = torch.sum(p & t, (0,2,3))
        union = torch.sum(p | t, (0, 2, 3))
        
        intersections += intersection.float()
        unions += union.float()
        
        if i == 10:
            break
        
    iou = intersections / unions
    
    taken_time = time.time() - start
    print(taken_time)
    
    return iou

In [12]:
eval_model(model, test_loader)

57.96979212760925


tensor([0.8753,    nan,    nan,    nan, 0.0000,    nan,    nan,    nan])

In [3]:
a = torch.ones(10, 8, 256, 320).long()
b = torch.ones(10, 8, 256, 320).long()
c = torch.zeros(10, 8, 256, 320).long()

In [4]:
z =a | c

In [28]:
torch.sum(z, (2,3)).shape

torch.Size([10, 8])