In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
from d2l import torch as d2l
import random
import time
import pandas as pd
from PIL import Image
from modules import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU whenever one is visible

In [2]:
path = '../data/dog-breed-identification/'
# Training metadata: one row per image with its file `id` and `breed` name.
train_csv = pd.read_csv(f'{path}labels.csv')
# Alphabetically sorted distinct breeds; a breed's position in this list is its class index.
label_list = sorted(set(train_csv['breed']))
# Sample submission: one row per test image `id`; it fixes the test set order.
test_csv = pd.read_csv(f'{path}sample_submission.csv')
print(train_csv.shape, test_csv.shape)

(10222, 2) (10357, 121)


In [3]:
class TrainDataset(data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        
        self.trans = transforms.Compose([transforms.RandomCrop(224),
                                         transforms.RandomHorizontalFlip(p=0.5),
                                         transforms.ColorJitter(brightness=0.2,
                                                                contrast=0.2,
                                                                saturation=0.2,
                                                                hue=0.2),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.4736, 0.4504, 0.3909],
                                                              std=[0.2655, 0.2607, 0.2650],
                                                              inplace=True)])
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        image, label = self.dataset[index]
        resize = transforms.Resize(random.randint(256, 480))
        return self.trans(resize(image)), label

In [4]:
class ValidDataset(data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.trans = transforms.Compose([transforms.Resize(256),
                                         transforms.TenCrop(224),
                                         transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                         transforms.Normalize(mean=[0.4736, 0.4504, 0.3909],
                                                              std=[0.2655, 0.2607, 0.2650],
                                                              inplace=True)])
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        image, label = self.dataset[index]
        return self.trans(image), label

In [5]:
class TrainValidDataset(data.Dataset):
    """All labelled training images as ``(PIL RGB image, class index)`` pairs.

    Rows come from the module-level ``train_csv``. Image files are read from
    ``path + 'train/<id>.jpg'``, and each breed name is mapped to its position
    in ``label_list``.
    """

    def __init__(self):
        super().__init__()
        # Breed name -> class index, built once instead of a list search per sample.
        self._label_to_index = {breed: idx for idx, breed in enumerate(label_list)}

    def __len__(self):
        return train_csv.shape[0]

    def __getitem__(self, index):
        image_id = train_csv['id'].iloc[index]
        breed = train_csv['breed'].iloc[index]
        # Load eagerly inside `with` so the file handle is closed right away.
        # convert('RGB') guarantees 3 channels: a grayscale or CMYK JPEG would
        # otherwise break the 3-channel in-place Normalize used downstream.
        with Image.open(path + 'train/' + image_id + '.jpg') as img:
            image = img.convert('RGB')
        return image, self._label_to_index[breed]

In [6]:
# Randomly assign 9200 labelled images to training and the rest to validation.
# The split is seeded so the same images land in each subset on every run.
SPLIT_SEED = 0
full_dataset = TrainValidDataset()
n_train = 9200
n_valid = len(full_dataset) - n_train
assert n_valid > 0, f'need more than {n_train} labelled images, got {len(full_dataset)}'
train_dataset, valid_dataset = data.random_split(full_dataset,
                                                 [n_train, n_valid],
                                                 generator=torch.Generator().manual_seed(SPLIT_SEED))
train_dataset, valid_dataset = TrainDataset(train_dataset), ValidDataset(valid_dataset)

In [7]:
class TestDataset(data.Dataset):
    """Unlabelled test images for one test-time-augmentation view.

    Images are read from ``path + 'test/<id>.jpg'`` in ``test_csv`` order.

    Args:
        size: passed to ``transforms.Resize``; an int resizes the shorter side
            and keeps the aspect ratio, so output shapes differ per image.
        horizontal_flip: if true, every image is mirrored before resizing.
    """

    def __init__(self, size, horizontal_flip):
        super().__init__()
        steps = [transforms.Resize(size),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=[0.4736, 0.4504, 0.3909],
                                      std=[0.2655, 0.2607, 0.2650],
                                      inplace=True)]
        if horizontal_flip:
            # p=1 makes the flip unconditional: this view is always mirrored.
            steps.insert(0, transforms.RandomHorizontalFlip(p=1))
        self.trans = transforms.Compose(steps)

    def __len__(self):
        return test_csv.shape[0]

    def __getitem__(self, index):
        image_id = test_csv['id'].iloc[index]
        # Close the file deterministically, and force 3 channels so the
        # 3-channel in-place Normalize never sees grayscale/CMYK input.
        with Image.open(path + 'test/' + image_id + '.jpg') as img:
            image = img.convert('RGB')
        return self.trans(image)

In [8]:
class SubmissionGenerater:
    """Test-time-augmentation (TTA) inference that produces a submission table.

    For every size in ``sizes``, two ``TestDataset`` views are built: the
    original image and its horizontally flipped version. Each view gets its own
    non-shuffled DataLoader.

    Args:
        batch_size: batch size for every loader. Test images keep their aspect
            ratio after resizing, so batches larger than 1 will likely fail to
            collate.
        sizes: resize targets, one pair of views per entry.
    """

    def __init__(self, batch_size, sizes=(224, 256, 384, 480, 640)):
        self.sizes = list(sizes)
        self.datasets = []
        for size in self.sizes:
            self.datasets += [TestDataset(size, False), TestDataset(size, True)]
        # One DataLoader per dataset. shuffle=False keeps batch b referring to the
        # same images in every view, which generate() relies on.
        self.dataloaders = [data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=0) for dataset in self.datasets]

    def generate(self, net):
        """Run every TTA view through ``net`` and average the softmax outputs.

        Args:
            net: classifier module; it is switched to eval mode.

        Returns:
            ``(submission, output_tensor)``. ``output_tensor`` holds the
            softmax probabilities averaged over all views, one row per test
            image (each row sums to 1). ``submission`` is a DataFrame with an
            ``id`` column (from ``test_csv``) followed by one column per entry
            of ``label_list``.

        Raises:
            ValueError: if there are no dataloaders to run.
        """
        if not self.dataloaders:
            raise ValueError('no test dataloaders to run inference on')
        net.eval()
        # batch_sums[b] accumulates the softmax of batch b summed over all views.
        batch_sums = []
        with torch.no_grad():
            # Run through every dataloader (one per TTA view).
            for loader_idx, dataloader in enumerate(self.dataloaders):
                print(f'{loader_idx+1:2d} dataset inferencing')
                for batch_idx, inputs in enumerate(dataloader):
                    probs = F.softmax(net(inputs.to(device)), dim=1)
                    if batch_idx < len(batch_sums):
                        batch_sums[batch_idx] += probs
                    else:
                        batch_sums.append(probs)
        # Average rather than sum, so every row is a proper probability distribution
        # instead of values in [0, number of views].
        output_tensor = torch.cat(batch_sums, dim=0) / len(self.dataloaders)
        print(output_tensor.shape)

        # Build the table in one shot with a single device-to-host copy.
        submission = pd.DataFrame(output_tensor.cpu().numpy(), columns=label_list)
        submission.insert(0, 'id', test_csv['id'].iloc[:len(submission)].to_numpy())
        return submission, output_tensor

In [9]:
# Model class and block type come from the local `modules` package.
net = DenseNet_ImageNet(k=24,
                        theta=0.5,
                        block=Bottleneck,
                        archi='169',
                        num_classes=120,
                        batch_norm=True,
                        dropout=0.1).to(device)
# The checkpoint file name encodes the hyper-parameters of the model that produced it.
checkpoint_path = f'DenseNet_ImageNet_archi={net.archi}_k={net.k}_theta={net.theta}_dropout={net.dropout}.pth'
# map_location remaps stored tensors onto `device`, so a GPU-saved checkpoint also
# loads on the CPU fallback. torch.load unpickles the file: only load trusted checkpoints.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [10]:
# batch_size=1: test images are resized on their shorter side only, so their shapes
# differ and presumably cannot be stacked into larger batches.
generater = SubmissionGenerater(batch_size=1)
submission, output_tensor = generater.generate(net=net)

 1 dataset inferencing
 2 dataset inferencing
 3 dataset inferencing
 4 dataset inferencing
 5 dataset inferencing
 6 dataset inferencing
 7 dataset inferencing
 8 dataset inferencing
 9 dataset inferencing
10 dataset inferencing
torch.Size([10357, 120])


In [11]:
submission_path = f'submission_archi={net.archi}_k={net.k}_theta={net.theta}_dropout={net.dropout}.csv'
# index=False omits the pandas row index; the `id` column already identifies each row.
submission.to_csv(submission_path, index=False)