In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import scipy.io
import skimage.io

from PIL import Image, ImageFilter

# Dataset Class

In [2]:
# input size => (N, 3, 256, 320)

class PartAffordanceDataset(Dataset):
    """Part Affordance Dataset"""
    
    def __init__(self, csv_file, transform=None):
        super().__init__()
        
        self.image_class_path = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.image_class_path)
    
    def __getitem__(self, idx):
        image_path = self.image_class_path.iloc[idx, 0]
        class_path = self.image_class_path.iloc[idx, 1]
        image = skimage.io.imread(image_path) # read as numpy array
        cls = scipy.io.loadmat(class_path)["gt_label"]
        
        sample = {'image': image, 'class': cls}
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

In [3]:
def crop_center_numpy(array, crop_height, crop_weight):
    h, w = array.shape
    return array[h//2 - crop_height//2: h//2 + crop_height//2,
                 w//2 - crop_weight//2: w//2 + crop_weight//2
                ]

In [4]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))

In [5]:
class CenterCrop(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = Image.fromarray(np.uint8(image))
        
        image = crop_center_pil_image(image, 320, 256)
        cls = crop_center_numpy(cls, 256, 320)
        
        image = np.asarray(image)
        
        return {'image': image, 'class': cls}

In [6]:
class ToTensor(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = image.transpose((2, 0, 1))
        # cls = cls.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image).float(), 
                'class': torch.from_numpy(cls).long()}

In [7]:
# mean = 0
# std = 0
# n = 0

# for sample in data_laoder:
#     img = sample['image']   
#     img = img.view(len(img), 3, -1)
#     mean += img.mean(2).sum(0)
#     std += img.std(2).sum(0)
#     n += len(img)
    
# mean /= n
# std /= n

In [8]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

In [9]:
class Normalize(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = transforms.functional.normalize(image, mean, std)
        
        return {'image': image, 'class': cls}

In [10]:
train_data = PartAffordanceDataset('train.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    # OneHot(), # don't need to convert class into one-hot when calculating CrossEntropy Loss
                                    ToTensor(),
                                    Normalize()
                                ]))

In [11]:
test_data = PartAffordanceDataset('test.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [12]:
train_loader = DataLoader(train_data, batch_size=20, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=20, shuffle=False, num_workers=4)

### class weight

In [13]:
# cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

# for sample in data_laoder:
#     img = sample['class'].numpy()
    
#     num, cnt = np.unique(img, return_counts=True)
    
#     for n, c in zip(num, cnt):
#         cnt_dict[n] += c

cnt_dict  
 0: 2078085712,  
 1: 34078992,  
 2: 15921090,  
 3: 12433420,  
 4: 38473752,  
 5: 6773528,  
 6: 9273826,  
 7: 20102080  

In [14]:
class_num = torch.tensor([2078085712, 34078992, 15921090, 12433420, 
                          38473752, 6773528, 9273826, 20102080])

total = class_num.sum().item()
print(total)

2215142400


In [15]:
frequency = class_num.float() / total
median = torch.median(frequency)

In [16]:
class_weight = median / frequency

In [17]:
class_weight

tensor([0.0077, 0.4672, 1.0000, 1.2805, 0.4138, 2.3505, 1.7168, 0.7920])

# Define Model

In [18]:
class DeconvBn_2(nn.Module):
    """ Deconvolution(stride=2) => Batch Normilization """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        return self.bn(self.deconv(x))

In [19]:
class DeconvBn_8(nn.Module):
    """ Deconvolution(stride=8) => Batch Normilization """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=8, stride=8, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        return self.bn(self.deconv(x))

In [20]:
class FCN8s(nn.Module):
    """ Fully Convolutional Network"""
    
    def __init__(self, in_channel, n_classes):
        super().__init__()
        vgg =torchvision.models.vgg16_bn(pretrained=True).features
        
        # confirm the architecture of vgg16 by "print(vgg)"
        
        self.pool3 = vgg[:24]
        self.pool4 = vgg[24:34]
        self.pool5 = vgg[34:]
        
        self.deconv_bn1 = DeconvBn_2(512, 512)
        self.deconv_bn2 = DeconvBn_2(512, 256)
        self.deconv_bn3 = DeconvBn_8(256, n_classes)
        
    def forward(self, x):
        # vgg16
        x3 = self.pool3(x)     # output size => (N, 256, H/8, W/8)
        x4 = self.pool4(x3)    # output size => (N, 512, H/16, W/16)
        x5 = self.pool5(x4)    # output size => (N, 512, H/32, W/32)
        
        score = self.deconv_bn1(x5)
        score = self.deconv_bn2(x4 + score)
        score = self.deconv_bn3(x3 + score)
        
        return score

# Training

In [21]:
from tensorboardX import SummaryWriter
import tqdm

In [22]:
# reference

# SMOOTH = 1e-6

# def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
#     # You can comment out this line if you are passing tensors of equal shape
#     # But if you are passing output from UNet or something it will most probably
#     # be with the BATCH x 1 x H x W shape
#     outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    
#     intersection = (outputs & labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
#     union = (outputs | labels).float().sum((1, 2))         # Will be zzero if both are 0
    
#     iou = (intersection + SMOOTH) / (union + SMOOTH)  # We smooth our devision to avoid 0/0
    
#     thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # This is equal to comparing with thresolds
    
#     return thresholded  # Or thresholded.mean() if you are interested in average across the batch

In [23]:
def eval_model(model, test_loader, device='cpu'):
    model.eval()
    
    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)
    
    for sample in test_loader:
        x, y = sample['image'], sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            _, ypred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
        
        for i in range(8):
            y_i = (y == i)           
            ypred_i = (ypred == i)   

            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
    
    """ iou[i] is the IoU of class i """
    iou = intersection / union
    
    return iou

In [24]:
def train_model(model, train_loader, test_loader, optimizer_cls=optim.Adadelta, 
                criterion=nn.CrossEntropyLoss(), max_epoch=200, device='cpu', writer=None):
    
    model.to(device)
    
    train_losses = []
    val_iou = []
    mean_iou = []
    best_mean_iou = 0.0
    
    optimizer = optimizer_cls(model.parameters())
    
    for epoch in range(max_epoch):
        model.train()
        running_loss = 0.0
        
        for i, sample in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()
            
            x, y = sample['image'], sample['class']
            
            x = x.to(device)
            y = y.to(device)

            h = model(x)
            loss = criterion(h, y)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        train_losses.append(running_loss / i)
        
        val_iou.append(eval_model(model, test_loader, device).to('cpu').float())
        mean_iou.append(val_iou[-1].mean().item())
        
        if best_mean_iou < mean_iou[-1]:
            best_mean_iou = mean_iou[-1]
            torch.save(model.state_dict(), "./FCN_results_with_class_weight/best_mean_iou_model.prm")
        
        if writer is not None:
            writer.add_scalar("train_loss", train_losses[-1], epoch)
            writer.add_scalar("mean_IoU", mean_iou[-1], epoch)
            writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
                                           'iou of class 1': val_iou[-1][1],
                                           'iou of class 2': val_iou[-1][2],
                                           'iou of class 3': val_iou[-1][3],
                                           'iou of class 4': val_iou[-1][4],
                                           'iou of class 5': val_iou[-1][5],
                                           'iou of class 6': val_iou[-1][6],
                                           'iou of class 7': val_iou[-1][7]}, epoch)
            
        print(epoch, train_losses[-1], mean_iou[-1])
        
    torch.save(model.state_dict(), "./FCN_results_with_class_weight/final_model.prm")

In [27]:
def init_weight(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

In [28]:
model = FCN8s(3, 8)
model.apply(init_weight)

FCN8s(
  (pool3): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3),

In [25]:
writer = SummaryWriter("./FCN_results_with_class_weight/")
train_model(model, train_loader, test_loader, criterion=nn.CrossEntropyLoss(weight=class_weight.to('cuda')), device="cuda", writer=writer)

100%|██████████| 1180/1180 [07:20<00:00,  3.03it/s]


0 0.3986462745610135 0.6330952048301697


100%|██████████| 1180/1180 [07:16<00:00,  3.03it/s]


1 0.07249707246927561 0.686819314956665


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


2 0.053750061802092096 0.7170349955558777


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


3 0.047837933561779664 0.738013505935669


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


4 0.06748823552021949 0.7232653498649597


100%|██████████| 1180/1180 [07:12<00:00,  2.99it/s]


5 0.04451714942991683 0.7186663746833801


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


6 0.04949091883586361 0.7446253299713135


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


7 0.038958875809980205 0.7573626041412354


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


8 0.03696042795976003 0.7745365500450134


100%|██████████| 1180/1180 [07:16<00:00,  2.99it/s]


9 0.035267938623572935 0.7924236059188843


100%|██████████| 1180/1180 [07:19<00:00,  2.98it/s]


10 0.03392795960247466 0.7569003105163574


100%|██████████| 1180/1180 [07:18<00:00,  3.06it/s]


11 0.03261894445180691 0.7828565835952759


100%|██████████| 1180/1180 [07:19<00:00,  2.99it/s]


12 0.031458342713292016 0.7871162295341492


100%|██████████| 1180/1180 [07:13<00:00,  3.05it/s]


13 0.03059727944231468 0.7941336035728455


100%|██████████| 1180/1180 [07:17<00:00,  3.04it/s]


14 0.029629326429442392 0.8041478991508484


100%|██████████| 1180/1180 [07:17<00:00,  3.01it/s]


15 0.028664745744301973 0.7933037281036377


100%|██████████| 1180/1180 [07:19<00:00,  3.02it/s]


16 0.028031474382216086 0.7994261980056763


100%|██████████| 1180/1180 [07:17<00:00,  2.98it/s]


17 0.02691004977682392 0.8082736134529114


100%|██████████| 1180/1180 [07:18<00:00,  3.05it/s]


18 0.02623794734733933 0.810982346534729


100%|██████████| 1180/1180 [07:19<00:00,  3.05it/s]


19 0.025467447873340814 0.8091073632240295


100%|██████████| 1180/1180 [07:17<00:00,  2.94it/s]


20 0.024652934267459733 0.8169394135475159


100%|██████████| 1180/1180 [07:24<00:00,  3.03it/s]


21 0.023951891026885403 0.828045129776001


100%|██████████| 1180/1180 [07:20<00:00,  2.97it/s]


22 0.02331720533969546 0.8212076425552368


100%|██████████| 1180/1180 [07:21<00:00,  2.90it/s]


23 0.022707848581370912 0.8224647045135498


100%|██████████| 1180/1180 [07:19<00:00,  3.03it/s]


24 0.022007828083069423 0.8194459080696106


100%|██████████| 1180/1180 [07:21<00:00,  3.01it/s]


25 0.02115694596649379 0.8196450471878052


100%|██████████| 1180/1180 [07:21<00:00,  2.89it/s]


26 0.02052305472036784 0.8237771391868591


100%|██████████| 1180/1180 [07:24<00:00,  2.97it/s]


27 0.02004884694567846 0.8179020285606384


100%|██████████| 1180/1180 [07:21<00:00,  2.86it/s]


28 0.019612822549747805 0.8339904546737671


100%|██████████| 1180/1180 [07:19<00:00,  3.04it/s]


29 0.0190106291763099 0.8368421792984009


100%|██████████| 1180/1180 [07:14<00:00,  3.05it/s]


30 0.018228108834307814 0.8317654132843018


100%|██████████| 1180/1180 [07:23<00:00,  2.97it/s]


31 0.017727200355004767 0.835856020450592


100%|██████████| 1180/1180 [07:17<00:00,  3.02it/s]


32 0.017070746170933515 0.8372489213943481


100%|██████████| 1180/1180 [07:15<00:00,  3.01it/s]


33 0.01661197917237309 0.8361132740974426


100%|██████████| 1180/1180 [07:15<00:00,  3.03it/s]


34 0.01620193738395552 0.8445684909820557


100%|██████████| 1180/1180 [07:15<00:00,  3.03it/s]


35 0.015752410050086596 0.8353536128997803


100%|██████████| 1180/1180 [07:14<00:00,  3.03it/s]


36 0.015176889972415286 0.8483723998069763


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


37 0.014751077134737254 0.849240243434906


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


38 0.014439639049259104 0.8511528968811035


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


39 0.013978596950507549 0.8560994863510132


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


40 0.01355205631622945 0.851409375667572


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


41 0.013321645762471196 0.8525335192680359


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


42 0.013336511093023548 0.8571277856826782


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


43 0.012631055792719378 0.8593115210533142


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


44 0.012532936889390707 0.8614533543586731


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


45 0.011980587559293219 0.8611907958984375


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


46 0.01176652671774001 0.8575897216796875


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


47 0.01159801906364969 0.8579355478286743


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


48 0.011095181879673573 0.8594158291816711


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


49 0.010797758060219154 0.8653665781021118


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


50 0.010542121445689038 0.8691108226776123


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


51 0.010410367526221417 0.867010235786438


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


52 0.010134548922420888 0.8657082319259644


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


53 0.009794563458295597 0.8644595146179199


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


54 0.009632921892994788 0.8701090216636658


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


55 0.009441877599565456 0.8707809448242188


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


56 0.009206766622461714 0.8689863085746765


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


57 0.008985596727911601 0.8744690418243408


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


58 0.00877732467269675 0.7909148931503296


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


59 0.008620366722699238 0.8732945322990417


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


60 0.008441051227325752 0.877457320690155


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


61 0.008312033018155736 0.8686455488204956


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


62 0.008045009236650082 0.878824770450592


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


63 0.007880891220244414 0.8756841421127319


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


64 0.007740106393738734 0.8747234344482422


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


65 0.007706760200952276 0.8797920346260071


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


66 0.007444733037732611 0.8783200979232788


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


67 0.007290676086461701 0.881561279296875


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


68 0.008284352340619376 0.8826940059661865


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


69 0.07626954017229144 0.7809682488441467


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


70 0.021889474422238812 0.8339248895645142


100%|██████████| 1180/1180 [07:13<00:00,  3.02it/s]


71 0.014948306873004703 0.8480001091957092


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


72 0.01143076577823391 0.8599839806556702


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


73 0.00978282620131881 0.8654844164848328


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


74 0.008783037239879976 0.869948148727417


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


75 0.008117154635176177 0.8764138221740723


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


76 0.007635331839841235 0.8776335120201111


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


77 0.007267582798621557 0.8826636075973511


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


78 0.007022728466533876 0.8813828229904175


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


79 0.006748659950953555 0.8827632665634155


100%|██████████| 1180/1180 [07:13<00:00,  3.02it/s]


80 0.006537075632783419 0.8866127729415894


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


81 0.006333086530632948 0.8819543123245239


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


82 0.0061988700860582035 0.8852930068969727


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


83 0.006044060561373084 0.8893274068832397


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


84 0.005852956018522003 0.8835666179656982


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


85 0.005691171641934716 0.8895459175109863


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


86 0.005573749586669657 0.8911742568016052


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


87 0.00549637973332261 0.890146017074585


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


88 0.0054299736703944295 0.8921504020690918


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


89 0.005295926568543095 0.8916536569595337


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


90 0.005180558231023343 0.8947907090187073


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


91 0.00508151025206413 0.8952339887619019


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


92 0.004964054749166376 0.8943475484848022


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


93 0.004872509338391394 0.8947336673736572


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


94 0.004816595037989918 0.8922699093818665


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


95 0.0047202554386839 0.8947696685791016


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


96 0.004703275690311795 0.8928834795951843


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


97 0.004575650288270782 0.8965234160423279


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


98 0.004492557880313051 0.8947445154190063


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


99 0.004430075681020298 0.8973491191864014


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


100 0.004370696449682332 0.8973464965820312


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


101 0.004311166880841891 0.8996188044548035


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


102 0.004284430714679583 0.899145781993866


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


103 0.004288420974347264 0.9003196954727173


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


104 0.004180079227292861 0.8979096412658691


100%|██████████| 1180/1180 [07:13<00:00,  3.05it/s]


105 0.004078142534731315 0.8975201845169067


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


106 0.003989116593040494 0.90157550573349


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


107 0.003926745279983596 0.8988988995552063


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


108 0.003906324373888353 0.9019366502761841


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


109 0.0038544193239593476 0.9024327397346497


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


110 0.003771733448376991 0.9027378559112549


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


111 0.0037740126769180843 0.9021874666213989


100%|██████████| 1180/1180 [07:13<00:00,  3.05it/s]


112 0.003785417046701256 0.9023447036743164


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


113 0.0036865819513756683 0.9023312926292419


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


114 0.0035589572314356248 0.9018086194992065


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


115 0.0035758447071870723 0.9054793119430542


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


116 0.0035249892374590685 0.9005846381187439


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


117 0.0035188107469306344 0.9054570198059082


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


118 0.003488637114368676 0.9022617936134338


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


119 0.004540065248742497 0.8963029980659485


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


120 0.004608701351726707 0.9000248908996582


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


121 0.004213470570068488 0.8998806476593018


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


122 0.0038149128462629514 0.9019002914428711


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


123 0.003510762106957415 0.9053968787193298


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


124 0.003295306201062092 0.9048678874969482


100%|██████████| 1180/1180 [07:13<00:00,  3.05it/s]


125 0.003193738340874905 0.9035369157791138


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


126 0.0031149029339552305 0.9062021374702454


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


127 0.00303539240999784 0.9071841239929199


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


128 0.0029898735203253354 0.9062369465827942


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


129 0.0029488774391714296 0.9074038863182068


100%|██████████| 1180/1180 [07:19<00:00,  3.03it/s]


130 0.0029061728971121648 0.9078258872032166


100%|██████████| 1180/1180 [07:19<00:00,  3.03it/s]


131 0.0028945882585487367 0.9081862568855286


100%|██████████| 1180/1180 [07:20<00:00,  2.96it/s]


132 0.0028635875221801584 0.907732367515564


100%|██████████| 1180/1180 [07:21<00:00,  2.99it/s]


133 0.0029212312183980343 0.905977725982666


100%|██████████| 1180/1180 [07:19<00:00,  2.94it/s]


134 0.002903730804819465 0.9079429507255554


100%|██████████| 1180/1180 [07:14<00:00,  3.03it/s]


135 0.002907061582649217 0.9046892523765564


100%|██████████| 1180/1180 [07:18<00:00,  3.04it/s]


136 0.0029847137033591794 0.8989441990852356


100%|██████████| 1180/1180 [07:17<00:00,  3.05it/s]


137 0.0031075185888381303 0.9058961868286133


100%|██████████| 1180/1180 [07:19<00:00,  3.03it/s]


138 0.07660190594296482 0.7510302066802979


100%|██████████| 1180/1180 [07:15<00:00,  3.03it/s]


139 0.03775885982314114 0.8002917170524597


100%|██████████| 1180/1180 [07:15<00:00,  3.04it/s]


140 0.022016083150217348 0.8218770027160645


100%|██████████| 1180/1180 [07:15<00:00,  3.05it/s]


141 0.015699746575367664 0.8515130281448364


100%|██████████| 1180/1180 [07:16<00:00,  3.04it/s]


142 0.011131371027830296 0.8615949153900146


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


143 0.008715011241851317 0.8721811771392822


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


144 0.007229933275448583 0.8832398056983948


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


145 0.006105912295964004 0.8892663717269897


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


146 0.005385184624612306 0.8891268372535706


100%|██████████| 1180/1180 [07:12<00:00,  3.03it/s]


147 0.004864085304129796 0.893421471118927


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


148 0.0044663694023255 0.8967865705490112


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


149 0.004183578036964079 0.8983004093170166


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


150 0.0038943608521315584 0.9005599617958069


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


151 0.0036638571939607285 0.9035572409629822


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


152 0.0035905362054724577 0.9029739499092102


100%|██████████| 1180/1180 [07:13<00:00,  3.05it/s]


153 0.0034254126620429162 0.9038093686103821


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


154 0.003308720068146646 0.9052532911300659


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


155 0.003186459743920982 0.9051162600517273


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


156 0.003085411259188831 0.906252920627594


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


157 0.003002659184858203 0.9077150225639343


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


158 0.0029730799562906795 0.9076017737388611


100%|██████████| 1180/1180 [07:13<00:00,  3.03it/s]


159 0.00290228486510564 0.9060102105140686


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


160 0.002869751206832204 0.9086045622825623


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


161 0.002811311543216112 0.9083613753318787


100%|██████████| 1180/1180 [07:17<00:00,  2.84it/s]


162 0.002761395361477492 0.9097648859024048


100%|██████████| 1180/1180 [07:14<00:00,  3.03it/s]


163 0.002708251359560499 0.9082962870597839


100%|██████████| 1180/1180 [07:13<00:00,  3.04it/s]


164 0.0026521298384050616 0.910370945930481


100%|██████████| 1180/1180 [07:16<00:00,  3.06it/s]


165 0.0026116853038187157 0.9093931317329407


100%|██████████| 1180/1180 [07:15<00:00,  3.04it/s]


166 0.002583667682325807 0.9103217720985413


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


167 0.0025729845668780883 0.9098289608955383


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


168 0.0025376651143422945 0.910446286201477


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


169 0.0025281675696685797 0.906991720199585


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


170 0.0025025964725448216 0.9112609028816223


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


171 0.0024941243426605704 0.9106967449188232


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


172 0.0024651550973012752 0.9109419584274292


100%|██████████| 1180/1180 [07:12<00:00,  3.04it/s]


173 0.002435652822717405 0.9117019772529602


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


174 0.002417028483684294 0.9088654518127441


100%|██████████| 1180/1180 [07:12<00:00,  3.05it/s]


175 0.002412765735540732 0.9107418060302734


100%|██████████| 1180/1180 [07:19<00:00,  3.00it/s]


176 0.0024209612051532166 0.9128139019012451


100%|██████████| 1180/1180 [07:18<00:00,  3.03it/s]


177 0.0024018521427372004 0.9112863540649414


100%|██████████| 1180/1180 [07:23<00:00,  2.98it/s]


178 0.0024186471944984502 0.9109468460083008


100%|██████████| 1180/1180 [07:18<00:00,  3.03it/s]


179 0.002482129943393569 0.9112021923065186


100%|██████████| 1180/1180 [07:17<00:00,  3.00it/s]


180 0.002465320218562214 0.9116851687431335


100%|██████████| 1180/1180 [07:20<00:00,  3.01it/s]


181 0.0023806917312672714 0.9107927083969116


100%|██████████| 1180/1180 [07:13<00:00,  3.05it/s]


182 0.0024248945819643707 0.9115434885025024


100%|██████████| 1180/1180 [07:11<00:00,  3.04it/s]


183 0.0025040887378549566 0.9111201167106628


100%|██████████| 1180/1180 [07:11<00:00,  3.06it/s]


184 0.0039142046047317735 0.9021943211555481


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


185 0.006519191664784983 0.9001189470291138


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


186 0.004194154792944754 0.903379499912262


100%|██████████| 1180/1180 [07:11<00:00,  3.06it/s]


187 0.0035811740681102487 0.9068666696548462


100%|██████████| 1180/1180 [07:11<00:00,  3.06it/s]


188 0.0028897717992433106 0.909687876701355


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


189 0.002588462092727265 0.9110180735588074


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


190 0.0024154683820108635 0.9113780856132507


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


191 0.002300453979851389 0.9127278327941895


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


192 0.0021896653978882375 0.9119945168495178


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


193 0.002133085800905928 0.9131699204444885


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


194 0.0020817068554152933 0.9144887328147888


100%|██████████| 1180/1180 [07:11<00:00,  3.04it/s]


195 0.00204331816939923 0.9137371182441711


100%|██████████| 1180/1180 [07:11<00:00,  3.06it/s]


196 0.0020097706863258518 0.9134413003921509


100%|██████████| 1180/1180 [07:11<00:00,  3.05it/s]


197 0.0019848999342003534 0.9141438603401184


100%|██████████| 1180/1180 [07:11<00:00,  3.02it/s]


198 0.0019540422168687504 0.9143350124359131


100%|██████████| 1180/1180 [07:11<00:00,  3.06it/s]


199 0.0019489789140300921 0.9152449369430542


In [26]:
colors = torch.tensor([[0, 0, 0],         # class 0 'background'  black
                       [255, 0, 0],       # class 1 'grasp'       red
                       [255, 255, 0],     # class 2 'cut'         yellow
                       [0, 255, 0],       # class 3 'scoop'       green
                       [0, 255, 255],     # class 4 'contain'     sky blue
                       [0, 0, 255],       # class 5 'pound'       blue
                       [255, 0, 255],     # class 6 'support'     purple
                       [255, 255, 255]    # class 7 'wrap grasp'  white
                      ])

In [27]:
def class_to_mask(cls):
    
    mask = colors[cls].transpose(1, 2).transpose(1, 3)
    
    return mask

In [28]:
def predict(model, sample, device='cpu'):
    model.eval()
    model.to(device)
    
    x, y = sample['image'], sample['class']
    
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        _, y_pred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
    
    true_mask = class_to_mask(y).to('cpu')
    pred_mask = class_to_mask(y_pred).to('cpu')
    
    save_image(true_mask, "./FCN_results_with_class_weight/true_mask_with_FCN.jpg")
    save_image(pred_mask, "./FCN_results_with_class_weight/pred_mask_with_FCN.jpg")

In [31]:
trained_model = FCN8s(3, 8)
trained_model.load_state_dict(torch.load("./FCN_results_with_class_weight/best_mean_iou_model.prm"))

eval_data = PartAffordanceDataset('eval.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [32]:
def reverse_normalize(x, mean, std):
    x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0]
    x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1]
    x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2]
    return x

In [33]:
eval_loader = DataLoader(eval_data, batch_size=8, shuffle=False)

mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

for sample in eval_loader:
    trained_model.eval()
    
    predict(trained_model, sample)
    
    x = sample["image"]
    x = reverse_normalize(x, mean, std)
    save_image(x/255, "./FCN_results_with_class_weight/original_img_with_FCN.jpg")
    
    break