In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pathlib
import torch.utils.data
from sklearn.preprocessing import MultiLabelBinarizer

import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim
import pandas as pd
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import PIL

In [32]:
RANDOM_SEED = 666
cuda = torch.device('cuda')

LABEL_MAP = {
0: "Nucleoplasm" ,
1: "Nuclear membrane"   ,
2: "Nucleoli"   ,
3: "Nucleoli fibrillar center",   
4: "Nuclear speckles"   ,
5: "Nuclear bodies"   ,
6: "Endoplasmic reticulum"   ,
7: "Golgi apparatus"  ,
8: "Peroxisomes"   ,
9:  "Endosomes"   ,
10: "Lysosomes"   ,
11: "Intermediate filaments"  , 
12: "Actin filaments"   ,
13: "Focal adhesion sites"  ,
14: "Microtubules"   ,
15: "Microtubule ends"   ,
16: "Cytokinetic bridge"   ,
17: "Mitotic spindle"  ,
18: "Microtubule organizing center",  
19: "Centrosome",
20: "Lipid droplets"   ,
21: "Plasma membrane"  ,
22: "Cell junctions"   ,
23: "Mitochondria"   ,
24: "Aggresome"   ,
25: "Cytosol" ,
26: "Cytoplasmic bodies",
27: "Rods & rings"}

In [33]:
# 데이터셋
class ProteinDataset(Dataset):
    BANDS_NAMES = ['_red.png','_green.png','_blue.png','_yellow.png']
    
    def __len__(self):
        return len(self.images_df)
    
    def __init__(self, images_df, 
                 base_path, 
                 image_transform, 
                 augmentator=None,
                 train_mode=True    
                ):
        if not isinstance(base_path, pathlib.Path):
            base_path = pathlib.Path(base_path)
            
        self.images_df = images_df.copy()
        self.image_transform = image_transform
        self.augmentator = augmentator
        self.images_df.Id = self.images_df.Id.apply(lambda x: base_path / x)
        self.mlb = MultiLabelBinarizer(classes=list(LABEL_MAP.keys()))
        self.train_mode = train_mode
        
    def __getitem__(self, index):
        y = None
        X = self._load_multiband_image(index)
        if self.train_mode:
            y = self._load_multilabel_target(index)
        
        if self.augmentator is not None:
            X = self.augmentator(X)
            
        X = self.image_transform(X)
            
        return X, y 
        
    def _load_multiband_image(self, index):
        row = self.images_df.iloc[index]
        image_bands = []
        for band_name in self.BANDS_NAMES:
            p = str(row.Id.absolute()) + band_name
            pil_channel = PIL.Image.open(p)
            image_bands.append(pil_channel)
            
        # lets pretend its a RBGA image to support 4 channels
        band4image = PIL.Image.merge('RGBA', bands=image_bands)
        return band4image
    
    def _load_multilabel_target(self, index):
        return list(map(int, self.images_df.iloc[index].Target.split(' ')))
    
        
    def collate_func(self, batch):
        labels = None
        images = [x[0] for x in batch]
        
        if self.train_mode:
            labels = [x[1] for x in batch]
            labels_one_hot  = self.mlb.fit_transform(labels)
            labels = torch.FloatTensor(labels_one_hot)
            
        
        return torch.stack(images)[:,:4,:,:], labels



In [34]:
# cnn모델 셋팅
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    
        # 각 컨볼루션층은 padding = 1 로 설정하여 사이즈가 줄어들지 않도록 설정하였다.
        self.C1 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3,padding=1)
        self.C2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3,padding=1)
        
        self.C3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3,padding=1)
        self.C4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3,padding=1)
        
        self.C5 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3,padding=1)
        self.C6 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=1)
        self.C7 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=1)
        self.C8 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3,padding=1)
        
        self.C9 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3,padding=1)
        self.C10 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,padding=1)
        
        # pool 연산 4번과 conv 연산들을 거쳐 512 * 512 사이즈의 RGBY 4 channel input이  32 * 32 사이즈의  64 channel이 된다.
        self.L1 = nn.Linear(32*32*64, 512)
        self.L2 = nn.Linear(512, 28)

    # conv - conv - relu - pool - conv - conv - relu - pool - conv - conv - conv -conv -relu - pool 
    # - conv - conv - relu - pool - fc - relu - fc - relu
    def forward(self, x):
        x=self.C1(x)
        x=self.C2(x)
        x=F.max_pool2d(F.relu(x),2)

        x=self.C3(x)
        x=self.C4(x)
        x=F.max_pool2d(F.relu(x),2)
        
        x=self.C5(x)
        x=self.C6(x)
        x=self.C7(x)
        x=self.C8(x)
        x=F.max_pool2d(F.relu(x),2)
        
        x=self.C9(x)
        x=self.C10(x)
        x=F.max_pool2d(F.relu(x),2)
        
        x = x.view(-1, self.num_flat_features(x))

        x = F.relu(self.L1(x))
        x = F.relu(self.L2(x))

        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [35]:
# test set 예측을 위한 함수
def predict_submission(model, submission_load):
    all_preds = []
    
    for i, b in enumerate(submission_load):
        if i % 100: print('processing batch {}/{}'.format(i, len(submission_load)))
        X, _ = b
        pred = model(X)
        all_preds.append(pred.sigmoid().cpu().data.numpy())
        
    return np.concatenate(all_preds)
        
# 결과 file을 생성해주는 함수
def make_submission_file(sample_submission_df, predictions):
    submissions = []
    for row in predictions:
        subrow = ' '.join(list([str(i) for i in np.nonzero(row)[0]]))
        submissions.append(subrow)
    
    sample_submission_df['Predicted'] = submissions
    sample_submission_df.to_csv('submission.csv', index=None)
    
    return sample_submission_df

In [36]:
PATH_TO_IMAGES = './input/train/'
PATH_TO_TEST_IMAGES = './input/test/'
PATH_TO_META = './input/train.csv'
SAMPLE_SUBMI = './input/sample_submission.csv'

In [37]:
SEED = 666
DEV_MODE = True
    
df = pd.read_csv(PATH_TO_META)
df_train, df_test  = train_test_split(df, test_size=0.2, random_state=SEED)
df_submission = pd.read_csv(SAMPLE_SUBMI)

# 원본 사이즈
image_transform = transforms.Compose([transforms.ToTensor()])

# 256로 리사이즈
# image_transform = transforms.Compose([
#            transforms.Resize(256),
#            transforms.ToTensor(),
#        ])

 
# Prepare datasets and loaders
   
gtrain = ProteinDataset(df_train, base_path=PATH_TO_IMAGES, image_transform=image_transform)
gtest = ProteinDataset(df_test, base_path=PATH_TO_IMAGES, image_transform=image_transform)
gsub = ProteinDataset(df_submission, base_path=PATH_TO_TEST_IMAGES, train_mode=False, image_transform=image_transform)

train_load = DataLoader(gtrain, collate_fn=gtrain.collate_func, batch_size=64)
test_load = DataLoader(gtest, collate_fn=gtest.collate_func, batch_size=64)
submission_load = DataLoader(gsub, collate_fn=gsub.collate_func, batch_size=64)


In [40]:
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 이 주제는 다중라벨링이 필요하기 때문에 교차엔트로피를 사용하였다.
criterion = nn.CrossEntropyLoss()

# criterion = nn.MultiLabelMarginLoss()
# optimizer = optim.SGD(model.parameters(),lr = 0.01, momentum = 0.5)

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.99))

params = list(model.parameters())
print(len(params))
print(params[0].size())

def train(epoch):
    model.train()
    for batch_idx,(data,target) in enumerate(train_load):

        data = data.to(device)
        target = target.to(device)
        target = target.long()

        print(data.size()) # 디버깅을 위해 추가 .. 학습 진행 중임을 확인하기 위해 ..

        output = model(data)
        optimizer.zero_grad()
        
        loss = criterion(output, torch.max(target, 1)[1])
#         print(loss)
#         loss = FocalLoss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx%100==0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                                                                           epoch, batch_idx * len(data), len(train_load.dataset),100. * batch_idx / len(train_load), loss))


for epoch in range(1, 9):
    train(epoch)

24
torch.Size([8, 4, 3, 3])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([25, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size([64, 4, 512, 512])
torch.Size

In [41]:
# test data를 예측한다(성능평가)
submission_predictions =predict_submission(model, submission_load)

processing batch 1/183
processing batch 2/183
processing batch 3/183
processing batch 4/183
processing batch 5/183
processing batch 6/183
processing batch 7/183
processing batch 8/183
processing batch 9/183
processing batch 10/183
processing batch 11/183
processing batch 12/183
processing batch 13/183
processing batch 14/183
processing batch 15/183
processing batch 16/183
processing batch 17/183
processing batch 18/183
processing batch 19/183
processing batch 20/183
processing batch 21/183
processing batch 22/183
processing batch 23/183
processing batch 24/183
processing batch 25/183
processing batch 26/183
processing batch 27/183
processing batch 28/183
processing batch 29/183
processing batch 30/183
processing batch 31/183
processing batch 32/183
processing batch 33/183
processing batch 34/183
processing batch 35/183
processing batch 36/183
processing batch 37/183
processing batch 38/183
processing batch 39/183
processing batch 40/183
processing batch 41/183
processing batch 42/183
p

In [49]:
# 적절한 threshold 값을 설정하고 submission file을 생성한다.
THRESHOLD = 0.6

# threshold 값 설정을 위한 디버깅
print(submission_predictions)

p = submission_predictions>THRESHOLD

submission_file = make_submission_file(sample_submission_df=df_submission,predictions=p)

[[0.92533195 0.5        0.5        ... 0.5        0.5        0.5       ]
 [0.5        0.5        0.5        ... 0.53751856 0.5        0.5       ]
 [0.5        0.5        0.5        ... 0.9999678  0.5        0.5       ]
 ...
 [0.5        0.5        0.992447   ... 0.5        0.5        0.5       ]
 [0.69692147 0.5        0.8664491  ... 0.5        0.5        0.5       ]
 [0.5        0.5        0.5        ... 0.5        0.5        0.5       ]]


In [50]:
# submission file 확인
submission_file.head()

Unnamed: 0,Id,Predicted
0,00008af0-bad0-11e8-b2b8-ac1f6b6435d0,0
1,0000a892-bacf-11e8-b2b8-ac1f6b6435d0,19 23
2,0006faa6-bac7-11e8-b2b7-ac1f6b6435d0,21 25
3,0008baca-bad7-11e8-b2b9-ac1f6b6435d0,21 23 25
4,000cce7e-bad4-11e8-b2b8-ac1f6b6435d0,
