In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data
import torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

import pandas as pd
import numpy as np
import scipy.io
import skimage.io

from PIL import Image, ImageFilter

# Dataset Class

In [3]:
class PartAffordanceDataset(Dataset):
    """Part Affordance Dataset"""
    
    def __init__(self, csv_file, transform=None):
        super().__init__()
        
        self.image_class_path = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.image_class_path)
    
    def __getitem__(self, idx):
        image_path = self.image_class_path.iloc[idx, 0]
        class_path = self.image_class_path.iloc[idx, 1]
        image = skimage.io.imread(image_path) # read as numpy array
        cls = scipy.io.loadmat(class_path)["gt_label"]
        
        sample = {'image': image, 'class': cls}
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample

In [4]:
def crop_center_numpy(array, crop_height, crop_weight):
    h, w = array.shape
    return array[h//2 - crop_height//2: h//2 + crop_height//2,
                 w//2 - crop_weight//2: w//2 + crop_weight//2
                ]

In [5]:
def crop_center_pil_image(pil_img, crop_width, crop_height):
    img_width, img_height = pil_img.size
    return pil_img.crop(((img_width - crop_width) // 2,
                         (img_height - crop_height) // 2,
                         (img_width + crop_width) // 2,
                         (img_height + crop_height) // 2))

In [6]:
class CenterCrop(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = Image.fromarray(np.uint8(image))
        
        image = crop_center_pil_image(image, 320, 240)
        cls = crop_center_numpy(cls, 240, 320)
        
        image = np.asarray(image)
        
        return {'image': image, 'class': cls}

In [7]:
def one_hot(labels):
    h, w = labels.shape
    x = np.zeros((h, w, 8), dtype=np.int64)
    
    for i in range(h):
        for j in range(w):
            x[i, j, labels[i, j]] = 1
            
    return x

In [8]:
class OneHot(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        cls_one_hot = one_hot(cls)
        return {'image': image, 
                'class': cls_one_hot}

In [9]:
class ToTensor(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = image.transpose((2, 0, 1))
        # cls = cls.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image).float(), 
                'class': torch.from_numpy(cls).long()}

In [10]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

In [11]:
class Normalize(object):
    def __call__(self, sample):
        image, cls = sample['image'], sample['class']
        
        image = transforms.functional.normalize(image, mean, std)
        
        return {'image': image, 'class': cls}

In [12]:
train_data = PartAffordanceDataset('train.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    # OneHot(), # don't need to convert class into one-hot when calculating CrossEntropy Loss
                                    ToTensor(),
                                    Normalize()
                                ]))

In [12]:
test_data = PartAffordanceDataset('test.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [13]:
train_loader = DataLoader(train_data, batch_size=15, shuffle=True)
test_loader = DataLoader(test_data, batch_size=15, shuffle=False)
pred_loader = DataLoader(test_data, batch_size=5, shuffle=False)

### the number of pixels in each class

In [17]:
# dataset = PartAffordanceDataset('image_class_path.csv',
#                                 transform=transforms.Compose([
#                                     CenterCrop(),
#                                     ToTensor()
#                                 ]))
# data_laoder = DataLoader(dataset, batch_size=100, shuffle=False)

In [18]:
# cnt_dict = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0}

# for sample in data_laoder:
#     img = sample['class'].numpy()
    
#     num, cnt = np.unique(img, return_counts=True)
    
#     for n, c in zip(num, cnt):
#         cnt_dict[n] += c

cnt_dict

0: 2078085712,  
 1: 34078992,  
 2: 15921090,  
 3: 12433420,  
 4: 38473752,  
 5: 6773528,  
 6: 9273826,  
 7: 20102080  

In [14]:
class_num = torch.tensor([2078085712, 34078992, 15921090, 12433420, 
                          38473752, 6773528, 9273826, 20102080])

total = class_num.sum().item()
print(total)

2215142400


In [15]:
frequency = class_num.float() / total
median = torch.median(frequency)
class_weight = median / frequency

In [16]:
class_weight

tensor([0.0077, 0.4672, 1.0000, 1.2805, 0.4138, 2.3505, 1.7168, 0.7920])

# calculate mean and std for normalization

In [22]:
# mean = 0
# std = 0
# n = 0

# for sample in data_laoder:
#     img = sample['image']   
#     img = img.view(len(img), 3, -1)
#     mean += img.mean(2).sum(0)
#     std += img.std(2).sum(0)
#     n += len(img)
    
# mean /= n
# std /= n

# Define Model

### parts of U-Net

In [13]:
class DoubleConv(nn.Module):
    """ (Conv > BatchNorm > ReLU) * 2"""
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.net(x)
        return x

In [14]:
class Down(nn.Module):
    """ MaxPooling > DoubleConv """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channel, out_channel)
        )
        
    def forward(self, x):
        x = self.net(x)
        return x

In [15]:
class Up(nn.Module):
    """ UpSampling > concat > DoubleConv """
    
    def __init__(self, in_channel, out_channel):
        super().__init__()

        self.double_conv = DoubleConv(in_channel+out_channel, out_channel) # after concat
        
    def forward(self, x, skipped_layer):
        """ the size of x is the same as the input from skipped layer """
        
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = torch.cat([x, skipped_layer], dim=1)
        x = self.double_conv(x)
        return x

In [16]:
class UNet(nn.Module):
    """ the size of input is (3, 240, 320) and that of output is the same """
    
    def __init__(self, in_channel, n_classes):
        super().__init__()
        
        self.double_conv = DoubleConv(in_channel, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.up4 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up2 = Up(128, 64)
        self.up1 = Up(64, 32)
        self.conv = nn.Conv2d(32, n_classes, 1)
        
    def forward(self, x):
        # the left side of U-Net
        x1 = self.double_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.down4(x4)
        
        # the right side of U-Net
        x = self.up4(x, x4)
        x = self.up3(x, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)
        x = self.conv(x)
        
        return x

# Training

In [21]:
from tensorboardX import SummaryWriter
import tqdm

In [22]:
def eval_model(model, test_loader, device='cpu'):
    model.eval()
    
    intersection = torch.zeros(8)   # the dataset has 8 classes including background
    union = torch.zeros(8)
    
    for sample in test_loader:
        x, y = sample['image'], sample['class']
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            _, ypred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
        
        for i in range(8):
            y_i = (y == i)           
            ypred_i = (ypred == i)   
            
            inter = (y_i.byte() & ypred_i.byte()).float().sum().to('cpu')
            intersection[i] += inter
            union[i] += (y_i.float().sum() + ypred_i.float().sum()).to('cpu') - inter
    
    """ iou[i] is the IoU of class i """
    iou = intersection / union
    
    return iou

In [23]:
def train_model(model, train_loader, test_loader, optimizer_cls=optim.Adam, 
                criterion=nn.CrossEntropyLoss(), max_epoch=200, device='cpu', writer=None):
    
    model.to(device)
    
    train_losses = []
    val_iou = []
    mean_iou = []
    best_iou = 0.0
    
    optimizer = optimizer_cls(model.parameters(), lr=0.01)
    
    for epoch in range(max_epoch):
        model.train()
        running_loss = 0.0
        
        for i, sample in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()
            
            x, y = sample['image'], sample['class']
            
            x = x.to(device)
            y = y.to(device)

            h = model(x)
            loss = criterion(h, y)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        train_losses.append(running_loss / i)
        
        val_iou.append(eval_model(model, test_loader, device))
        mean_iou.append(val_iou[-1].mean().item())
        
        if best_iou < mean_iou[-1]:
            best_iou = mean_iou[-1]
            torch.save(model.state_dict(), "./U-Net_with_class_weight(median)_results/best_iou_model.prm")
        
        if writer is not None:
            writer.add_scalar("train_loss", train_losses[-1], epoch)
            writer.add_scalar("mean_IoU", mean_iou[-1], epoch)
            writer.add_scalars("class_IoU", {'iou of class 0': val_iou[-1][0],
                                           'iou of class 1': val_iou[-1][1],
                                           'iou of class 2': val_iou[-1][2],
                                           'iou of class 3': val_iou[-1][3],
                                           'iou of class 4': val_iou[-1][4],
                                           'iou of class 5': val_iou[-1][5],
                                           'iou of class 6': val_iou[-1][6],
                                           'iou of class 7': val_iou[-1][7]}, epoch)
            
        print(epoch, train_losses[-1], mean_iou[-1])
        
    torch.save(model.state_dict(), "./U-Net_with_class_weight(median)_results/final_model.prm")

In [24]:
model = UNet(3, 8)
writer = SummaryWriter("./U-Net_with_class_weight(median)_results/")
train_model(model, train_loader, test_loader, criterion=nn.CrossEntropyLoss(weight=class_weight.to('cuda')), device="cuda", writer=writer)

100%|██████████| 1540/1540 [15:41<00:00,  1.61it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

0 0.8172400682932219 0.44117721915245056


100%|██████████| 1540/1540 [15:19<00:00,  1.58it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

1 0.30335768800691165 0.5068031549453735


100%|██████████| 1540/1540 [14:49<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

2 0.18974766149981367 0.6763685941696167


100%|██████████| 1540/1540 [14:46<00:00,  1.50it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

3 0.14918468836658647 0.6821063756942749


100%|██████████| 1540/1540 [14:50<00:00,  1.93it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

4 0.10751829283037226 0.6225299835205078


100%|██████████| 1540/1540 [15:18<00:00,  2.31it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

5 0.08078467668487392 0.7084543704986572


100%|██████████| 1540/1540 [15:30<00:00,  1.47it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

6 0.08238293311744445 0.6178677678108215


100%|██████████| 1540/1540 [15:03<00:00,  1.46it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

7 0.08434234788887024 0.7267493605613708


100%|██████████| 1540/1540 [14:50<00:00,  1.58it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

8 0.07135959890870168 0.7038478851318359


100%|██████████| 1540/1540 [14:38<00:00,  1.48it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

9 0.05340945447266799 0.7201460599899292


100%|██████████| 1540/1540 [14:52<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

10 0.07538178109378106 0.6921923160552979


100%|██████████| 1540/1540 [15:28<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

11 0.04981967734319386 0.745348334312439


100%|██████████| 1540/1540 [15:26<00:00,  1.64it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

12 0.07089169398240644 0.7397600412368774


100%|██████████| 1540/1540 [14:52<00:00,  1.47it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

13 0.0540734937013919 0.7207469344139099


100%|██████████| 1540/1540 [14:35<00:00,  1.52it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

14 0.04573465398282219 0.7131005525588989


100%|██████████| 1540/1540 [14:44<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

15 0.05959058540151288 0.7283705472946167


100%|██████████| 1540/1540 [15:10<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

16 0.04337764287620415 0.7080122828483582


100%|██████████| 1540/1540 [15:32<00:00,  1.64it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

17 0.05712094003258393 0.7399424910545349


100%|██████████| 1540/1540 [15:06<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

18 0.04195407555814375 0.7625642418861389


100%|██████████| 1540/1540 [14:42<00:00,  1.63it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

19 0.04135997753646019 0.7443401217460632


100%|██████████| 1540/1540 [14:40<00:00,  1.61it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

20 0.05687924229998028 0.7567595839500427


100%|██████████| 1540/1540 [14:53<00:00,  1.92it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

21 0.03894690465172384 0.7445806860923767


100%|██████████| 1540/1540 [15:26<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

22 0.052034376396678385 0.7189707159996033


100%|██████████| 1540/1540 [15:22<00:00,  1.53it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

23 0.040594795192310455 0.7727477550506592


100%|██████████| 1540/1540 [14:57<00:00,  1.60it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

24 0.039553868770720396 0.6405915021896362


100%|██████████| 1540/1540 [14:38<00:00,  1.57it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

25 0.04296064920858991 0.7605265974998474


100%|██████████| 1540/1540 [14:45<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

26 0.03681470103782152 0.7314755320549011


100%|██████████| 1540/1540 [15:17<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

27 0.03683444639259999 0.7779368758201599


100%|██████████| 1540/1540 [15:27<00:00,  1.63it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

28 0.052718757212408915 0.7139151692390442


100%|██████████| 1540/1540 [15:02<00:00,  1.50it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

29 0.03675361195926526 0.7592933177947998


100%|██████████| 1540/1540 [14:39<00:00,  1.60it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

30 0.04165122465335215 0.7156141400337219


100%|██████████| 1540/1540 [14:39<00:00,  1.56it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

31 0.03865187885475728 0.7597687244415283


100%|██████████| 1540/1540 [14:58<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

32 0.033563404829350864 0.7778260111808777


100%|██████████| 1540/1540 [15:23<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

33 0.03382654529720874 0.754504919052124


100%|██████████| 1540/1540 [15:06<00:00,  1.63it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

34 0.047160892556721005 0.7811678647994995


100%|██████████| 1540/1540 [14:47<00:00,  1.55it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

35 0.03268969249975147 0.7776017189025879


100%|██████████| 1540/1540 [14:46<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

36 0.032581716545332824 0.7775541543960571


100%|██████████| 1540/1540 [14:54<00:00,  1.93it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

37 0.046126910382452105 0.7699235081672668


100%|██████████| 1540/1540 [15:25<00:00,  1.97it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

38 0.03229072785199831 0.7807507514953613


100%|██████████| 1540/1540 [15:32<00:00,  1.57it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

39 0.031323893787853704 0.7796289920806885


100%|██████████| 1540/1540 [15:06<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

40 0.03165527094092615 0.772951066493988


100%|██████████| 1540/1540 [14:54<00:00,  1.64it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

41 0.03780681881968893 0.7705194354057312


100%|██████████| 1540/1540 [11:43<00:00,  2.13it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

42 0.03024969447423511 0.8037403225898743


100%|██████████| 1540/1540 [11:53<00:00,  2.08it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

43 0.030894058318403226 0.7747011184692383


100%|██████████| 1540/1540 [11:57<00:00,  2.11it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

44 0.030582499642547188 0.7880792021751404


100%|██████████| 1540/1540 [12:42<00:00,  1.60it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

45 0.04180806101789387 0.7905295491218567


100%|██████████| 1540/1540 [14:26<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

46 0.029107281459523262 0.7927146553993225


100%|██████████| 1540/1540 [15:12<00:00,  2.53it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

47 0.030714330116990055 0.7914932370185852


100%|██████████| 1540/1540 [15:12<00:00,  1.55it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

48 0.02894094309819798 0.8016871213912964


100%|██████████| 1540/1540 [14:37<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

49 0.028885585111895692 0.7816388607025146


100%|██████████| 1540/1540 [11:51<00:00,  2.08it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

50 0.02947795385575918 0.7855564951896667


100%|██████████| 1540/1540 [12:00<00:00,  2.47it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

51 0.03466230735383671 0.7993636131286621


100%|██████████| 1540/1540 [12:06<00:00,  2.06it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

52 0.033105476415045 0.7977133393287659


100%|██████████| 1540/1540 [11:45<00:00,  2.09it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

53 0.026805636889946933 0.7934030294418335


100%|██████████| 1540/1540 [11:47<00:00,  2.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

54 0.02704353247786726 0.7987520098686218


100%|██████████| 1540/1540 [11:59<00:00,  2.09it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

55 0.027271054250014495 0.8026408553123474


100%|██████████| 1540/1540 [11:50<00:00,  2.09it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

56 0.027108625752719443 0.7995657324790955


100%|██████████| 1540/1540 [11:45<00:00,  2.11it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

57 0.027613438439541546 0.8083735108375549


100%|██████████| 1540/1540 [11:59<00:00,  2.11it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

58 0.02687819551402985 0.756256639957428


100%|██████████| 1540/1540 [11:57<00:00,  2.08it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

59 0.025732364542443424 0.7940279245376587


100%|██████████| 1540/1540 [11:45<00:00,  2.09it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

60 0.025890190269292138 0.8185473084449768


100%|██████████| 1540/1540 [14:56<00:00,  1.93it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

61 0.03463992804453464 0.7970828413963318


100%|██████████| 1540/1540 [15:51<00:00,  1.50it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

62 0.02512192175087602 0.8079693913459778


100%|██████████| 1540/1540 [15:41<00:00,  1.61it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

63 0.023910054071643587 0.8053655624389648


100%|██████████| 1540/1540 [15:10<00:00,  1.60it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

64 0.024319676614576458 0.8170074820518494


100%|██████████| 1540/1540 [14:45<00:00,  1.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

65 0.02443204686902536 0.8213095664978027


100%|██████████| 1540/1540 [14:43<00:00,  1.63it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

66 0.024726101585807275 0.8193380236625671


100%|██████████| 1540/1540 [14:51<00:00,  1.92it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

67 0.03243195254457334 0.8158618807792664


100%|██████████| 1540/1540 [15:16<00:00,  2.29it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

68 0.02342130670528508 0.8395110964775085


100%|██████████| 1540/1540 [15:33<00:00,  1.61it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

69 0.022689312604474432 0.8236275315284729


100%|██████████| 1540/1540 [15:18<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

70 0.023590811405420457 0.8158432245254517


100%|██████████| 1540/1540 [14:53<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

71 0.023346146703603944 0.7632960677146912


100%|██████████| 1540/1540 [14:46<00:00,  1.57it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

72 0.023330991088622688 0.7946252226829529


100%|██████████| 1540/1540 [14:49<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

73 0.02254571096186149 0.8096922636032104


100%|██████████| 1540/1540 [15:05<00:00,  2.35it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

74 0.030046376405747352 0.7905733585357666


100%|██████████| 1540/1540 [15:30<00:00,  2.01it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

75 0.022704066366170273 0.8259660005569458


100%|██████████| 1540/1540 [15:36<00:00,  1.51it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

76 0.02111876570051409 0.843940258026123


100%|██████████| 1540/1540 [15:01<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

77 0.02127763873984029 0.818200409412384


100%|██████████| 1540/1540 [15:03<00:00,  1.54it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

78 0.022270724251188576 0.8222898244857788


100%|██████████| 1540/1540 [14:51<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

79 0.02185515438093331 0.8228725790977478


100%|██████████| 1540/1540 [14:47<00:00,  1.62it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

80 0.02119327501997187 0.8422010540962219


100%|██████████| 1540/1540 [15:09<00:00,  2.35it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

81 0.02122604448827561 0.7679103016853333


100%|██████████| 1540/1540 [15:31<00:00,  1.59it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

82 0.024190094869624996 0.737048327922821


100%|██████████| 1540/1540 [15:29<00:00,  1.64it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

83 0.0220720179958718 0.8239198327064514


100%|██████████| 1540/1540 [14:51<00:00,  1.63it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

84 0.019358304922866543 0.8314065933227539


100%|██████████| 1540/1540 [14:45<00:00,  1.58it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

85 0.020374151790917616 0.7934911847114563


100%|██████████| 1540/1540 [15:25<00:00,  1.40it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

86 0.019768163419006085 0.8506183624267578


100%|██████████| 1540/1540 [16:14<00:00,  2.19it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

87 0.02959283726632382 0.8138813376426697


100%|██████████| 1540/1540 [16:41<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

88 0.01939289337732116 0.8106219172477722


100%|██████████| 1540/1540 [16:13<00:00,  1.48it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

89 0.01887850985572449 0.857711672782898


100%|██████████| 1540/1540 [15:45<00:00,  1.48it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

90 0.019204097778776623 0.8603644371032715


100%|██████████| 1540/1540 [15:52<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

91 0.01966540397237442 0.8279309272766113


100%|██████████| 1540/1540 [16:17<00:00,  2.18it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

92 0.0197492224667185 0.8116166591644287


100%|██████████| 1540/1540 [16:41<00:00,  1.48it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

93 0.019515829499087176 0.7662962675094604


100%|██████████| 1540/1540 [16:13<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

94 0.018860531015330084 0.79639732837677


100%|██████████| 1540/1540 [15:50<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

95 0.0190105111059225 0.7818004488945007


100%|██████████| 1540/1540 [15:51<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

96 0.0186650525049934 0.8350691199302673


100%|██████████| 1540/1540 [16:19<00:00,  2.19it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

97 0.018663383487798394 0.5903311371803284


100%|██████████| 1540/1540 [16:39<00:00,  1.35it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

98 0.018539096959191842 0.7385974526405334


100%|██████████| 1540/1540 [16:14<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

99 0.01802969558522725 0.8498731255531311


100%|██████████| 1540/1540 [15:49<00:00,  1.50it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

100 0.018199511576756165 0.8364809155464172


100%|██████████| 1540/1540 [15:50<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

101 0.01781879563206508 0.8045675754547119


100%|██████████| 1540/1540 [16:23<00:00,  2.16it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

102 0.03092954020893965 0.7973804473876953


100%|██████████| 1540/1540 [16:41<00:00,  1.48it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

103 0.018217795183537184 0.7697142362594604


100%|██████████| 1540/1540 [16:17<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

104 0.016816943100531955 0.7644516229629517


100%|██████████| 1540/1540 [15:46<00:00,  1.49it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

105 0.017012255443011964 0.8133794665336609


100%|██████████| 1540/1540 [15:52<00:00,  1.42it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

106 0.017297628614627903 0.5795503258705139


100%|██████████| 1540/1540 [16:16<00:00,  2.19it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

107 0.017743254267405466 0.7969100475311279


100%|██████████| 1540/1540 [16:50<00:00,  1.36it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

108 0.01744849282626094 0.6862926483154297


100%|██████████| 1540/1540 [16:11<00:00,  1.39it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

109 0.016998602099755393 0.7507978081703186


100%|██████████| 1540/1540 [12:55<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

110 0.01719690915475502 0.5716046690940857


100%|██████████| 1540/1540 [12:50<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

111 0.016886750542358425 0.8283900022506714


100%|██████████| 1540/1540 [12:54<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

112 0.01680899626569964 0.6627713441848755


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

113 0.016788039200947104 0.7820351123809814


100%|██████████| 1540/1540 [12:42<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

114 0.016485006609584234 0.6829151511192322


100%|██████████| 1540/1540 [12:56<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

115 0.016662552868650864 0.6849145889282227


100%|██████████| 1540/1540 [12:45<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

116 0.022323823817464013 0.7906542420387268


100%|██████████| 1540/1540 [12:40<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

117 0.017326299518666304 0.8209937810897827


100%|██████████| 1540/1540 [12:55<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

118 0.01517863654185999 0.8396770358085632


100%|██████████| 1540/1540 [12:54<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

119 0.015727803555557467 0.672956109046936


100%|██████████| 1540/1540 [12:40<00:00,  1.93it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

120 0.016411224952907998 0.8306739926338196


100%|██████████| 1540/1540 [12:52<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

121 0.015800081601014204 0.8394181132316589


100%|██████████| 1540/1540 [12:55<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

122 0.01604592883953361 0.8340046405792236


100%|██████████| 1540/1540 [12:41<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

123 0.015555864086228194 0.8531497120857239


100%|██████████| 1540/1540 [12:43<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

124 0.015819552594758302 0.8436644673347473


100%|██████████| 1540/1540 [12:55<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

125 0.015586926314079331 0.8568095564842224


100%|██████████| 1540/1540 [12:45<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

126 0.015654176079363716 0.8680230379104614


100%|██████████| 1540/1540 [12:41<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

127 0.015394978574704066 0.8479724526405334


100%|██████████| 1540/1540 [12:55<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

128 0.015368822730576609 0.7955278754234314


100%|██████████| 1540/1540 [12:54<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

129 0.015006226879290999 0.8739804625511169


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

130 0.015416116120573314 0.8151503801345825


100%|██████████| 1540/1540 [12:52<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

131 0.014789470792472692 0.7468255758285522


100%|██████████| 1540/1540 [12:54<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

132 0.014777957309277086 0.7901513576507568


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

133 0.01503782576225974 0.8577682971954346


100%|██████████| 1540/1540 [12:43<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

134 0.01519863654588863 0.8297964334487915


100%|██████████| 1540/1540 [12:55<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

135 0.014474742916062272 0.8609117865562439


100%|██████████| 1540/1540 [12:44<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

136 0.014843734099558125 0.8239460587501526


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

137 0.014533209951523065 0.8429516553878784


100%|██████████| 1540/1540 [12:54<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

138 0.02513864334028812 0.8219152092933655


100%|██████████| 1540/1540 [12:54<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

139 0.015387755888862181 0.826694667339325


100%|██████████| 1540/1540 [12:41<00:00,  1.92it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

140 0.013670339702027641 0.8369917869567871


100%|██████████| 1540/1540 [12:51<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

141 0.013552287229184907 0.7691352367401123


100%|██████████| 1540/1540 [12:55<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

142 0.014159999770258176 0.29994794726371765


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

143 0.014338922036643479 0.8203902840614319


100%|██████████| 1540/1540 [12:43<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

144 0.014270648329994987 0.8325129747390747


100%|██████████| 1540/1540 [12:55<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

145 0.013948471682077204 0.8284309506416321


100%|██████████| 1540/1540 [12:44<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

146 0.01973206674452578 0.8309177756309509


100%|██████████| 1540/1540 [12:41<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

147 0.01359157723993731 0.8000673055648804


100%|██████████| 1540/1540 [12:54<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

148 0.013084178707544837 0.8591777682304382


100%|██████████| 1540/1540 [12:54<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

149 0.013515455977503475 0.7728613018989563


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

150 0.015112358095793657 0.8562983274459839


100%|██████████| 1540/1540 [12:51<00:00,  2.34it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

151 0.013300169932536776 0.8490756750106812


100%|██████████| 1540/1540 [12:54<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

152 0.013500453731198835 0.8710470199584961


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

153 0.013657164298206122 0.8197787404060364


100%|██████████| 1540/1540 [12:44<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

154 0.013762404147627545 0.8353955149650574


100%|██████████| 1540/1540 [12:55<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

155 0.013828040811985421 0.8510718941688538


100%|██████████| 1540/1540 [12:44<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

156 0.013239420444811703 0.7965975999832153


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

157 0.013509635341216346 0.8487356901168823


100%|██████████| 1540/1540 [12:55<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

158 0.013152370697747894 0.8276564478874207


100%|██████████| 1540/1540 [12:53<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

159 0.013275333135276839 0.817908763885498


100%|██████████| 1540/1540 [12:40<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

160 0.02072730620650065 0.8362441658973694


100%|██████████| 1540/1540 [12:52<00:00,  2.32it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

161 0.013823830150128322 0.8594170212745667


100%|██████████| 1540/1540 [12:55<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

162 0.012263602344470511 0.8592092990875244


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

163 0.012240601280828801 0.8516265153884888


100%|██████████| 1540/1540 [12:44<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

164 0.012945408241785797 0.8342110514640808


100%|██████████| 1540/1540 [12:55<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

165 0.013080571235603543 0.8604485392570496


100%|██████████| 1540/1540 [12:44<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

166 0.012924659681035044 0.7699509859085083


100%|██████████| 1540/1540 [12:40<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

167 0.013153235588399445 0.790121853351593


100%|██████████| 1540/1540 [12:55<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

168 0.012879956686412982 0.8185667395591736


100%|██████████| 1540/1540 [12:54<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

169 0.01267318335338178 0.6929101347923279


100%|██████████| 1540/1540 [12:41<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

170 0.012936287060321888 0.8794510364532471


100%|██████████| 1540/1540 [12:53<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

171 0.012674026198798937 0.8682711124420166


100%|██████████| 1540/1540 [12:56<00:00,  1.96it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

172 0.012830343774608817 0.8329666256904602


100%|██████████| 1540/1540 [12:40<00:00,  1.95it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

173 0.012966643824636007 0.8387928009033203


100%|██████████| 1540/1540 [12:43<00:00,  2.33it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

174 0.012457800706239723 0.8812844157218933


100%|██████████| 1540/1540 [12:56<00:00,  1.94it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

175 0.013102007358406971 0.8532557487487793


100%|██████████| 1540/1540 [09:44<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

176 0.012178304332566869 0.8034756183624268


100%|██████████| 1540/1540 [09:38<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

177 0.01228741612594001 0.8618844151496887


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

178 0.012282791765633053 0.8622781038284302


100%|██████████| 1540/1540 [09:38<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

179 0.012679452827542808 0.8818132877349854


100%|██████████| 1540/1540 [09:41<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

180 0.011940307115810385 0.8080909252166748


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

181 0.012112072478650626 0.7992501854896545


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

182 0.01255923366813138 0.7195051312446594


100%|██████████| 1540/1540 [09:38<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

183 0.011978793957060696 0.7889151573181152


100%|██████████| 1540/1540 [09:37<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

184 0.011930943810446957 0.6313046216964722


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

185 0.012179043917366636 0.8446710705757141


100%|██████████| 1540/1540 [09:41<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

186 0.018666433581938365 0.843788743019104


100%|██████████| 1540/1540 [09:38<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

187 0.011506777320267746 0.8399491906166077


100%|██████████| 1540/1540 [09:37<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

188 0.015645374559396134 0.8234814405441284


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

189 0.012974671465468782 0.8704023361206055


100%|██████████| 1540/1540 [09:37<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

190 0.011070384898134026 0.8797211647033691


100%|██████████| 1540/1540 [09:37<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

191 0.011267258713652438 0.8637678623199463


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

192 0.011872837267019692 0.8673992156982422


100%|██████████| 1540/1540 [09:37<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

193 0.011810564879591857 0.7445842623710632


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

194 0.012599714736976431 0.8839831948280334


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

195 0.01174776998086505 0.7608171105384827


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

196 0.011757567337919281 0.8811920881271362


100%|██████████| 1540/1540 [09:37<00:00,  2.67it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

197 0.011769790786346922 0.86788409948349


100%|██████████| 1540/1540 [09:37<00:00,  2.66it/s]
  0%|          | 0/1540 [00:00<?, ?it/s]

198 0.01154389113364493 0.887050211429596


100%|██████████| 1540/1540 [09:39<00:00,  2.67it/s]


199 0.018344483912581745 0.857844352722168


In [17]:
colors = torch.tensor([[0, 0, 0],         # class 0
                       [255, 0, 0],       # class 1
                       [255, 255, 0], 
                       [0, 255, 0],
                       [0, 255, 255],
                       [0, 0, 255],
                       [255, 0, 255],
                       [255, 255, 255]    # class 7
                      ])

In [18]:
def class_to_mask(cls):
    
    mask = colors[cls].transpose(1, 2).transpose(1, 3)
    
    return mask

In [19]:
def predict(model, sample, device='cpu'):
    model.eval()
    model.to(device)
    
    x, y = sample['image'], sample['class']
    
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        _, y_pred = model(x).max(1)    # y_pred.shape => (N, 240, 320)
    
    true_mask = class_to_mask(y).to('cpu')
    pred_mask = class_to_mask(y_pred).to('cpu')
    
    save_image(true_mask, "./U-Net_with_class_weight(median)_results/true_mask_with_UNet.jpg")
    save_image(pred_mask, "./U-Net_with_class_weight(median)_results/pred_mask_with_UNet.jpg")

In [20]:
trained_model = UNet(3, 8)
trained_model.load_state_dict(torch.load("./U-Net_with_class_weight(median)_results/best_iou_model.prm"))

In [21]:
eval_data = PartAffordanceDataset('eval.csv',
                                transform=transforms.Compose([
                                    CenterCrop(),
                                    ToTensor(),
                                    Normalize()
                                ]))

In [22]:
def reverse_normalize(x, mean, std):
    x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0]
    x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1]
    x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2]
    return x

In [23]:
eval_loader = DataLoader(eval_data, batch_size=8, shuffle=False)

In [24]:
mean=[55.8630, 59.9099, 91.7419]
std=[31.6852, 29.8496, 19.0835]

for sample in eval_loader:
    trained_model.eval()
    
    predict(trained_model, sample)
    
    x = sample["image"]
    x = reverse_normalize(x, mean, std)
    save_image(x/255, "./U-Net_with_class_weight(median)_results/original_img_with_UNet.jpg")
    
    break