In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import scipy.io
import skimage.io

from PIL import Image, ImageFilter

# Dataset Class

In [2]:
# input size => (N, 3, 256, 320)

class PartAffordanceDataset(Dataset):
    """Part Affordance Dataset"""
    
    def __init__(self, csv_file, transform=None):
        super().__init__()
        
        self.image_class_path = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.image_class_path)
    
    def __getitem__(self, idx):
        image_path = self.image_class_path.iloc[idx, 0]
        class_path = self.image_class_path.iloc[idx, 1]
        image = skimage.io.imread(image_path) # read as numpy array
        cls = scipy.io.loadmat(class_path)["gt_label"]
        
        sample = {'image': image, 'class': cls}
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

In [3]:
def crop_center_numpy(array, crop_height, crop_weight):
    h, w = array.shape
    return array[h//2 - crop_height//2: h//2 + crop_height//2,
                 w//2 - crop_weight//2: w//2 + crop_weight//2
                ]

In [4]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))

In [5]:
class CenterCrop(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = Image.fromarray(np.uint8(image))
        
        image = crop_center_pil_image(image, 320, 256)
        cls = crop_center_numpy(cls, 256, 320)
        
        image = np.asarray(image)
        
        return {'image': image, 'class': cls}

In [6]:
class ToTensor(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = image.transpose((2, 0, 1))
        # cls = cls.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image).float(), 
                'class': torch.from_numpy(cls).long()}

In [7]:
# mean = 0
# std = 0
# n = 0

# for sample in data_laoder:
#     img = sample['image']   
#     img = img.view(len(img), 3, -1)
#     mean += img.mean(2).sum(0)
#     std += img.std(2).sum(0)
#     n += len(img)
    
# mean /= n
# std /= n

In [8]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

In [9]:
class Normalize(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = transforms.functional.normalize(image, mean, std)
        
        return {'image': image, 'class': cls}

In [10]:
train_data = PartAffordanceDataset('train.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    # OneHot(), # don't need to convert class into one-hot when calculating CrossEntropy Loss
                                    ToTensor(),
                                    Normalize()
                                ]))

In [11]:
test_data = PartAffordanceDataset('test.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [12]:
train_loader = DataLoader(train_data, batch_size=10, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False, num_workers=4)

### class weight

In [13]:
# cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

# for sample in data_laoder:
#     img = sample['class'].numpy()
    
#     num, cnt = np.unique(img, return_counts=True)
    
#     for n, c in zip(num, cnt):
#         cnt_dict[n] += c

cnt_dict  
 0: 2078085712,  
 1: 34078992,  
 2: 15921090,  
 3: 12433420,  
 4: 38473752,  
 5: 6773528,  
 6: 9273826,  
 7: 20102080  

In [14]:
class_num = torch.tensor([2078085712, 34078992, 15921090, 12433420, 
                          38473752, 6773528, 9273826, 20102080])

total = class_num.sum().item()
print(total)

2215142400


In [15]:
frequency = class_num.float() / total
median = torch.median(frequency)

In [16]:
class_weight = median / frequency

In [17]:
class_weight

tensor([0.0077, 0.4672, 1.0000, 1.2805, 0.4138, 2.3505, 1.7168, 0.7920])

# Define Model

In [13]:
class DeconvBn_2(nn.Module):
    """ Deconvolution(stride=2) => Batch Normilization """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        return self.bn(self.deconv(x))

In [14]:
class DeconvBn_8(nn.Module):
    """ Deconvolution(stride=8) => Batch Normilization """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=8, stride=8, bias=False)
        self.bn = nn.BatchNorm2d(out_channel)
        
    def forward(self, x):
        return self.bn(self.deconv(x))

In [15]:
class FCN8s(nn.Module):
    """ Fully Convolutional Network"""
    
    def __init__(self, in_channel, n_classes):
        super().__init__()
        vgg =torchvision.models.vgg16_bn(pretrained=True).features
        
        # confirm the architecture of vgg16 by "print(vgg)"
        
        self.pool3 = vgg[:24]
        self.pool4 = vgg[24:34]
        self.pool5 = vgg[34:]
        
        self.deconv_bn1 = DeconvBn_2(512, 512)
        self.deconv_bn2 = DeconvBn_2(512, 256)
        self.deconv_bn3 = DeconvBn_8(256, n_classes)
        
    def forward(self, x):
        # vgg16
        x3 = self.pool3(x)     # output size => (N, 256, H/8, W/8)
        x4 = self.pool4(x3)    # output size => (N, 512, H/16, W/16)
        x5 = self.pool5(x4)    # output size => (N, 512, H/32, W/32)
        
        score = self.deconv_bn1(x5)
        score = self.deconv_bn2(x4 + score)
        score = self.deconv_bn3(x3 + score)
        
        return score

# Training

In [16]:
from tensorboardX import SummaryWriter
import tqdm

In [17]:
# reference

# SMOOTH = 1e-6

# def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
#     # You can comment out this line if you are passing tensors of equal shape
#     # But if you are passing output from UNet or something it will most probably
#     # be with the BATCH x 1 x H x W shape
#     outputs = outputs.squeeze(1)  # BATCH x 1 x H x W => BATCH x H x W
    
#     intersection = (outputs & labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
#     union = (outputs | labels).float().sum((1, 2))         # Will be zzero if both are 0
    
#     iou = (intersection + SMOOTH) / (union + SMOOTH)  # We smooth our devision to avoid 0/0
    
#     thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # This is equal to comparing with thresolds
    
#     return thresholded  # Or thresholded.mean() if you are interested in average across the batch

In [18]:
def eval_model(model, test_loader, device='cpu'):
    model.eval()
    
    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)
    
    for sample in test_loader:
        x, y = sample['image'], sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            _, ypred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
        
        for i in range(8):
            y_i = (y == i)           
            ypred_i = (ypred == i)   

            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
    
    """ iou[i] is the IoU of class i """
    iou = intersection / union
    
    return iou

In [19]:
def train_model(model, train_loader, test_loader, optimizer_cls=optim.Adadelta, 
                criterion=nn.CrossEntropyLoss(), max_epoch=200, device='cpu', writer=None):
    
    model.to(device)
    
    train_losses = []
    val_iou = []
    mean_iou = []
    best_mean_iou = 0.0
    
    optimizer = optimizer_cls(model.parameters())
    
    for epoch in range(max_epoch):
        model.train()
        running_loss = 0.0
        
        for i, sample in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()
            
            x, y = sample['image'], sample['class']
            
            x = x.to(device)
            y = y.to(device)

            h = model(x)
            loss = criterion(h, y)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        train_losses.append(running_loss / i)
        
        val_iou.append(eval_model(model, test_loader, device).to('cpu').float())
        mean_iou.append(val_iou[-1].mean().item())
        
        if best_mean_iou < mean_iou[-1]:
            best_mean_iou = mean_iou[-1]
            torch.save(model.state_dict(), "./FCN_results_without_class_weight/best_mean_iou_model.prm")
        
        if writer is not None:
            writer.add_scalar("train_loss", train_losses[-1], epoch)
            writer.add_scalar("mean_IoU", mean_iou[-1], epoch)
            writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
                                           'iou of class 1': val_iou[-1][1],
                                           'iou of class 2': val_iou[-1][2],
                                           'iou of class 3': val_iou[-1][3],
                                           'iou of class 4': val_iou[-1][4],
                                           'iou of class 5': val_iou[-1][5],
                                           'iou of class 6': val_iou[-1][6],
                                           'iou of class 7': val_iou[-1][7]}, epoch)
            
        print(epoch, train_losses[-1], mean_iou[-1])
        
    torch.save(model.state_dict(), "./FCN_results_without_class_weight/final_model.prm")

In [20]:
model = FCN8s(3, 8)
writer = SummaryWriter("./FCN_results_without_class_weight/")
train_model(model, train_loader, test_loader, criterion=nn.CrossEntropyLoss(), device="cuda", writer=writer)

100%|██████████| 2360/2360 [07:29<00:00,  5.25it/s]


0 0.17716328388150274 0.8809050917625427


100%|██████████| 2360/2360 [07:29<00:00,  5.25it/s]


1 0.01992074077954455 0.8966249823570251


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


2 0.017313478183318694 0.9051161408424377


100%|██████████| 2360/2360 [07:29<00:00,  5.25it/s]


3 0.016106035933451234 0.9070571064949036


100%|██████████| 2360/2360 [07:29<00:00,  5.25it/s]


4 0.014871698739453234 0.9032094478607178


100%|██████████| 2360/2360 [07:29<00:00,  5.25it/s]


5 0.014106169980589415 0.9121999144554138


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


6 0.013458768860841642 0.9120826125144958


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


7 0.012917540015020313 0.9131811261177063


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


8 0.012414135734482355 0.9145854711532593


100%|██████████| 2360/2360 [07:27<00:00,  5.27it/s]


9 0.011936701643712987 0.9167629480361938


100%|██████████| 2360/2360 [07:27<00:00,  5.27it/s]


10 0.011462875067849768 0.9183233976364136


100%|██████████| 2360/2360 [07:27<00:00,  5.27it/s]


11 0.011024978414143153 0.9188547134399414


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


12 0.010619661861169182 0.9201186299324036


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


13 0.010270735701556337 0.9191166758537292


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


14 0.009916556849065046 0.9201535582542419


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


15 0.009568421280460811 0.921238362789154


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


16 0.00924647315600733 0.9215643405914307


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


17 0.00892818267739466 0.9194531440734863


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


18 0.008635509478488101 0.9219616651535034


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


19 0.00838962672716508 0.921676754951477


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


20 0.008150922059902274 0.9225894212722778


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


21 0.007919796546877706 0.9228162169456482


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


22 0.00766647252186678 0.9230561256408691


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


23 0.007434801636567721 0.9229086637496948


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


24 0.0072222221607201025 0.9234530329704285


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


25 0.0070253657866995774 0.9237632155418396


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


26 0.006847506381027265 0.9228942394256592


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


27 0.006699312019946339 0.9235650897026062


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


28 0.006546673052752767 0.9237663745880127


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


29 0.0063602715387589825 0.9241076707839966


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


30 0.006215264744371264 0.9231640696525574


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


31 0.00609285253304996 0.9239416718482971


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


32 0.005949215130859156 0.9249095916748047


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


33 0.005847618301036235 0.9215842485427856


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


34 0.005727203412932131 0.9227313995361328


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


35 0.005591917212641134 0.9226158857345581


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


36 0.005481948521566206 0.9251782298088074


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


37 0.005368296279946612 0.9243266582489014


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


38 0.005275826443728153 0.9253337383270264


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


39 0.005179795015116916 0.9246577620506287


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


40 0.005076756028635344 0.9246212244033813


100%|██████████| 2360/2360 [07:28<00:00,  5.27it/s]


41 0.004985839459628985 0.9249484539031982


100%|██████████| 2360/2360 [07:28<00:00,  5.26it/s]


42 0.004880721817616627 0.9251054525375366


100%|██████████| 2360/2360 [07:29<00:00,  5.25it/s]


43 0.004813344254651537 0.9231788516044617


100%|██████████| 2360/2360 [07:35<00:00,  5.19it/s]


44 0.0047322024684833645 0.9253475666046143


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


45 0.004634441175995906 0.9244801998138428


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


46 0.004555688295300202 0.9252924919128418


100%|██████████| 2360/2360 [07:37<00:00,  5.16it/s]


47 0.004496253405777788 0.9254516363143921


100%|██████████| 2360/2360 [07:35<00:00,  5.18it/s]


48 0.004408030870108818 0.9254871606826782


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


49 0.004341532616596156 0.9252212047576904


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


50 0.004285546704248737 0.9253305196762085


100%|██████████| 2360/2360 [07:35<00:00,  5.18it/s]


51 0.004208522424083669 0.9258168935775757


100%|██████████| 2360/2360 [07:35<00:00,  5.18it/s]


52 0.00415060250394539 0.9250734448432922


100%|██████████| 2360/2360 [07:32<00:00,  5.21it/s]


53 0.004077197142156972 0.9244197607040405


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


54 0.004015526834175768 0.9252328276634216


100%|██████████| 2360/2360 [07:37<00:00,  5.16it/s]


55 0.003941738239859419 0.9249245524406433


100%|██████████| 2360/2360 [07:37<00:00,  5.16it/s]


56 0.0038792505780777383 0.9250216484069824


100%|██████████| 2360/2360 [07:39<00:00,  5.14it/s]


57 0.0038229353229849647 0.9254277348518372


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


58 0.003755961064444306 0.9245156645774841


100%|██████████| 2360/2360 [07:37<00:00,  5.15it/s]


59 0.0036997353542010227 0.9248698353767395


100%|██████████| 2360/2360 [07:34<00:00,  5.19it/s]


60 0.0036389849292988454 0.925497829914093


100%|██████████| 2360/2360 [07:35<00:00,  5.18it/s]


61 0.00359151053331017 0.9249183535575867


100%|██████████| 2360/2360 [07:36<00:00,  5.17it/s]


62 0.003544886150135216 0.9254908561706543


100%|██████████| 2360/2360 [07:37<00:00,  5.16it/s]


63 0.0035159005690721347 0.925294816493988


100%|██████████| 2360/2360 [07:26<00:00,  5.29it/s]


64 0.003436918711923805 0.9251009821891785


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


65 0.003388207610605751 0.9242032170295715


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


66 0.003343405164673612 0.9249057769775391


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


67 0.003283591790680646 0.9255163073539734


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


68 0.0032238321356225052 0.9254275560379028


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


69 0.00318412631467828 0.9244811534881592


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


70 0.003148349618427487 0.9254236221313477


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


71 0.0030998854980625974 0.9240642786026001


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


72 0.0030573136555646823 0.9249374270439148


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


73 0.0030066641392549496 0.9251340627670288


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


74 0.0029672668825249902 0.925343930721283


100%|██████████| 2360/2360 [07:26<00:00,  5.29it/s]


75 0.002910589463215269 0.9256623983383179


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


76 0.0028690967280275556 0.9252813458442688


100%|██████████| 2360/2360 [07:26<00:00,  5.29it/s]


77 0.0028330333981068426 0.925227701663971


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


78 0.0027843523392948864 0.9253849983215332


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


79 0.0027594448469643175 0.9252501130104065


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


80 0.002713056147620066 0.9251991510391235


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


81 0.002687185823247421 0.9248448610305786


100%|██████████| 2360/2360 [07:26<00:00,  5.29it/s]


82 0.002636829760354352 0.9254103899002075


100%|██████████| 2360/2360 [07:26<00:00,  5.29it/s]


83 0.002594862698611692 0.9224841594696045


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


84 0.002567922994147018 0.9253701567649841


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


85 0.002529542242745731 0.9251444339752197


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


86 0.002488381382916478 0.9254922270774841


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


87 0.002449424028794955 0.9252225160598755


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


88 0.0024206586005545176 0.9254390001296997


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


89 0.0023851316947203355 0.9251100420951843


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


90 0.0023331952376738066 0.9254093170166016


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


91 0.0023085824694092046 0.9249361157417297


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


92 0.0022934832822984687 0.9252597093582153


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


93 0.002251363690822508 0.9247732162475586


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


94 0.0022063424956150616 0.9252371191978455


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


95 0.002172461546191195 0.9250582456588745


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


96 0.002134731850857661 0.9249227046966553


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


97 0.0021050396526762177 0.9245458841323853


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


98 0.00208799976490528 0.9252803325653076


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


99 0.002046598655879561 0.9250534772872925


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


100 0.0020152976524815966 0.9248652458190918


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


101 0.0019829192777879256 0.9252105951309204


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


102 0.001958050834605823 0.925297737121582


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


103 0.0019232067973370735 0.9248149991035461


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


104 0.0018931699129054026 0.9251181483268738


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


105 0.0018645027972561997 0.9244446754455566


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


106 0.0018357920087370572 0.9249011874198914


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


107 0.0018051661392685693 0.9236010313034058


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


108 0.0017783048170328143 0.9250398278236389


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


109 0.001741635468253497 0.9249058365821838


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


110 0.0017247949832240644 0.9248538613319397


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


111 0.0016996319213863928 0.9251600503921509


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


112 0.0016738585418853364 0.9243032932281494


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


113 0.0016423815112170535 0.9248900413513184


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


114 0.0016177026364700051 0.9244115948677063


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


115 0.00160082506269849 0.9243512749671936


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


116 0.0015688983702004879 0.9249793291091919


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


117 0.001543884501284586 0.9251309633255005


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


118 0.001514553313480904 0.9243748784065247


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


119 0.0014962410896429464 0.9222036004066467


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


120 0.0014748610435466266 0.9247933030128479


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


121 0.0014434248604189926 0.9250097870826721


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


122 0.0014201142830762739 0.9250909090042114


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


123 0.0013945985920977948 0.9250463843345642


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


124 0.0013763276874824044 0.9251623749732971


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


125 0.0013517336047899203 0.9244792461395264


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


126 0.0013219531059323172 0.9247092604637146


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


127 0.0012968351287286036 0.9249926805496216


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


128 0.001267698301807174 0.9252172112464905


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


129 0.001257910723104124 0.9247483611106873


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


130 0.0012347749933581081 0.9247053861618042


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


131 0.0012140923277200557 0.9251244068145752


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


132 0.001186696731207293 0.9247991442680359


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


133 0.0011689923346627324 0.9245750308036804


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


134 0.00115917200811745 0.9249733686447144


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


135 0.001138081331599161 0.9247567057609558


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


136 0.0011095821579072218 0.9244847893714905


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


137 0.0010904699073742796 0.9227232933044434


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


138 0.001091499870146367 0.9246847033500671


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


139 0.0010607332418709852 0.9243859052658081


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


140 0.001049133964847606 0.9249884486198425


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


141 0.0010269042714811812 0.9250049591064453


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


142 0.001004768620548559 0.9248015284538269


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


143 0.0009819741894730066 0.924699068069458


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


144 0.0009608266696823629 0.9246730804443359


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


145 0.0009403870285914254 0.924523115158081


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


146 0.0009288879584127444 0.924268364906311


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


147 0.0009196992641078579 0.9233945608139038


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


148 0.0009042744304875844 0.9246556758880615


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


149 0.0008778048232324096 0.9247235655784607


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


150 0.000859775184683309 0.9247373342514038


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


151 0.0008423813854511187 0.9248026013374329


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


152 0.0008263855009592813 0.9242469668388367


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


153 0.0008212237245238995 0.9246984720230103


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


154 0.0007974347834104823 0.9244180917739868


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


155 0.0007922143365398417 0.9239101409912109


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


156 0.0007791306237393461 0.9242106676101685


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


157 0.0007534884946177397 0.9242110252380371


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


158 0.0007441492123441868 0.924743115901947


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


159 0.0007311352461269705 0.9240330457687378


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


160 0.0007311264446672232 0.9245830774307251


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


161 0.0007159828571648104 0.9247581958770752


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


162 0.0007008798857171747 0.9246759414672852


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


163 0.0006782646439524408 0.9244095087051392


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


164 0.0006598913020458493 0.9247317314147949


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


165 0.0006419447121344035 0.9240416884422302


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


166 0.0006311610366980537 0.9247227907180786


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


167 0.0006223871326152963 0.9245551228523254


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


168 0.0006078777947363804 0.9240438938140869


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


169 0.000591686923241309 0.9241266846656799


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


170 0.0005968759661535849 0.9248113036155701


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


171 0.0005818727449696565 0.9245377779006958


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


172 0.0005619811019380532 0.9241829514503479


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


173 0.000554698490156824 0.9247201085090637


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


174 0.0005384932554711087 0.9246951937675476


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


175 0.0005295019836628698 0.9244673252105713


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


176 0.0005161926869244559 0.9243877530097961


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


177 0.0005144198643263341 0.923701286315918


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


178 0.0005008708887352344 0.9244451522827148


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


179 0.0004901153572951295 0.9244444370269775


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


180 0.0004850924218426533 0.9244093894958496


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


181 0.0005165142900495616 0.9245355725288391


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


182 0.0005406092086278065 0.9241147041320801


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


183 0.0005300150342460231 0.9244786500930786


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


184 0.0004915603010310887 0.9245222806930542


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


185 0.0004614509625641464 0.9245696067810059


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


186 0.00048035274510145116 0.9243521690368652


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


187 0.000458012973059799 0.924573540687561


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


188 0.0004218358847292593 0.9246298670768738


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


189 0.0003993200471467998 0.9245548248291016


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


190 0.00038716174345191105 0.9242516160011292


100%|██████████| 2360/2360 [07:25<00:00,  5.30it/s]


191 0.00037843709865087726 0.9244483709335327


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


192 0.0003709270530865907 0.9246881008148193


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


193 0.00036323137794957457 0.9244778156280518


100%|██████████| 2360/2360 [07:25<00:00,  5.29it/s]


194 0.00034843546483664126 0.9245861768722534


100%|██████████| 2360/2360 [07:38<00:00,  5.15it/s]


195 0.0003422764534910512 0.9242525100708008


100%|██████████| 2360/2360 [07:35<00:00,  5.19it/s]


196 0.0003371686996499174 0.9210181832313538


100%|██████████| 2360/2360 [07:34<00:00,  5.19it/s]


197 0.000336306526104755 0.924422025680542


100%|██████████| 2360/2360 [07:34<00:00,  5.19it/s]


198 0.00032150784773930146 0.9243425130844116


100%|██████████| 2360/2360 [07:35<00:00,  5.18it/s]


199 0.00031502902253114214 0.9245156049728394


In [26]:
colors = torch.tensor([[0, 0, 0],         # class 0 'background'  black
                       [255, 0, 0],       # class 1 'grasp'       red
                       [255, 255, 0],     # class 2 'cut'         yellow
                       [0, 255, 0],       # class 3 'scoop'       green
                       [0, 255, 255],     # class 4 'contain'     sky blue
                       [0, 0, 255],       # class 5 'pound'       blue
                       [255, 0, 255],     # class 6 'support'     purple
                       [255, 255, 255]    # class 7 'wrap grasp'  white
                      ])

In [27]:
def class_to_mask(cls):
    
    mask = colors[cls].transpose(1, 2).transpose(1, 3)
    
    return mask

In [28]:
def predict(model, sample, device='cpu'):
    model.eval()
    model.to(device)
    
    x, y = sample['image'], sample['class']
    
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        _, y_pred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
    
    true_mask = class_to_mask(y).to('cpu')
    pred_mask = class_to_mask(y_pred).to('cpu')
    
    save_image(true_mask, "./FCN_results_without_class_weight/true_mask_with_FCN.jpg")
    save_image(pred_mask, "./FCN_results_without_class_weight/pred_mask_with_FCN.jpg")

In [31]:
trained_model = FCN8s(3, 8)
trained_model.load_state_dict(torch.load("./FCN_results_without_class_weight/best_mean_iou_model.prm"))

eval_data = PartAffordanceDataset('eval.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [32]:
def reverse_normalize(x, mean, std):
    x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0]
    x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1]
    x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2]
    return x

In [33]:
eval_loader = DataLoader(eval_data, batch_size=8, shuffle=False)

mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

for sample in eval_loader:
    trained_model.eval()
    
    predict(trained_model, sample)
    
    x = sample["image"]
    x = reverse_normalize(x, mean, std)
    save_image(x/255, "./FCN_results_without_class_weight/original_img_with_FCN.jpg")
    
    break