## Vgg
VGGNet全部使用3\*3的卷积核和2\*2的池化核，通过不断加深网络结构来提升性能。VGG表明了卷积神经网络深度的增加和小卷积核的使用对网络最终的分类识别效果有很大的作用。
![vgg](pics/vgg_architectures.png)

从上图可知，VGG16由13个卷积层和3个全连接层组成，期间还有5个max pooling以及最后的softmax

In [2]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

## VGG模型

In [21]:
class VGG(nn.Module):
    """VGG-style CNN classifier with batch normalisation.

    The convolutional trunk is built from ``cfg`` by :meth:`make_layers`. Its
    output is pooled to a fixed 7x7 grid, flattened and passed through a
    three-layer fully connected classifier.

    Args:
        cfg: sequence whose items are either an int (output channels of a
            3x3 conv + batch-norm + ReLU stage) or the string ``'M'``
            (a 2x2 max-pooling layer).
        num_classes: number of output logits.
    """

    def __init__(self, cfg, num_classes=12):
        super(VGG, self).__init__()
        cfg = list(cfg)  # cfg is traversed twice below
        # Convolutional feature extractor described by cfg.
        self.layers = self.make_layers(cfg)
        # Parameter-free pooling to a fixed 7x7 map. With a 224x224 input and
        # five 'M' stages the map is already 7x7, so this is an identity
        # there; for other input sizes it keeps the classifier shape valid.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Channels produced by the last conv stage (512 for the standard
        # VGG configurations); falls back to the 3 input channels when cfg
        # contains no conv stage at all.
        conv_channels = [item for item in cfg if item != 'M']
        out_channel = conv_channels[-1] if conv_channels else 3
        # Fully connected head.
        self.classifier = nn.Sequential(
            nn.Linear(out_channel * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        out = self.layers(x)
        out = self.avgpool(out)
        # Keep the batch dimension and flatten everything else (-1).
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def make_layers(self, cfg):
        """Build the convolutional trunk.

        Args:
            cfg: list of channel counts and ``'M'`` markers;
                ``'M'`` inserts a max-pooling layer.

        Returns:
            ``nn.Sequential`` holding the layers in cfg order.
        """
        layers = []
        in_channel = 3  # the first conv stage consumes 3-channel input
        for item in cfg:
            if item == 'M':
                # Max-pooling stage: halves the spatial resolution.
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                # Conv stage: 3x3 conv (padding keeps the size), BN, ReLU.
                layers += [nn.Conv2d(in_channels=in_channel, out_channels=item, kernel_size=3, padding=1),
                           nn.BatchNorm2d(item),
                           nn.ReLU(inplace=True)]
                in_channel = item
        return nn.Sequential(*layers)

In [4]:
# Layer layouts of the VGG variants, consumed by VGG.make_layers: an int is
# the output-channel count of one conv stage, 'M' marks a max-pooling layer.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

In [5]:
# Build the VGG16 variant with the default number of output classes.
vggnet = VGG(cfg['VGG16'])

In [6]:
vggnet

VGG(
  (layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), 

### 作曲线图

In [4]:
import matplotlib.pyplot as plt
def show_curve(ys, title):
    """Plot one metric (loss or accuracy) against the epoch index.

    Args:
        ys: sequence of per-epoch values.
        title: metric name; used in the plot title and as the y-axis label.
    """
    values = np.array(ys)
    epochs = np.array(range(len(ys)))
    plt.plot(epochs, values, c='b')
    plt.axis()
    plt.xlabel('epoch')
    plt.ylabel('{}'.format(title))
    plt.title('{} curve'.format(title))
    plt.show()

## 数据预处理

### 划分训练集和验证集，把数据分成训练集和验证集的不同目录（！！！以下代码暂时用不上）

In [19]:
import os
import numpy
import random, shutil

def moveSomeFileToNewDir(fileDir, tarDir, rate=0.1):
    """Move a random fraction of the files of every class sub-directory.

    For each sub-directory ``<fileDir>/<cls>``, ``int(n_entries * rate)``
    entries are drawn at random and moved into ``<tarDir>/<cls>``, which is
    created on demand. Entries of ``fileDir`` that are not directories are
    ignored.

    Args:
        fileDir: root directory holding one sub-directory per class.
        tarDir: destination root; works with or without a trailing separator.
        rate: fraction of each sub-directory to move (e.g. 0.15 moves 15 out
            of 100 files).
    """
    allDir = os.listdir(fileDir)  # every entry directly below fileDir
    print(allDir)
    for sub_name in allDir:
        sonDirName = os.path.join(fileDir, sub_name)
        if not os.path.isdir(sonDirName):
            continue
        pathDir = os.listdir(sonDirName)  # candidate files of this class
        picknumber = int(len(pathDir) * rate)
        sample = random.sample(pathDir, picknumber)
        print(sample)
        # Joined with os.path.join so a missing trailing slash on tarDir can
        # no longer glue the class name onto the parent directory name.
        newDir = os.path.join(tarDir, sub_name)
        for name in sample:
            oldDir = os.path.join(sonDirName, name)
            # Created lazily so classes with an empty sample get no directory.
            os.makedirs(newDir, exist_ok=True)
            newTarDir = os.path.join(newDir, name)

            print(oldDir, newTarDir)
            shutil.move(oldDir, newDir)


fileDir = "./data/train"      # source image directory
valDir = './data/val/'       # destination directory for the validation set
new_train = './data/new_train/'# new test set
moveSomeFileToNewDir(fileDir, valDir)
testDir = r'E:\test_data'     # destination directory for the test set
# NOTE(review): this second call targets valDir again, so another share of the
# remaining training files is moved into the validation folder; testDir and
# new_train are never used. Confirm whether testDir was intended here.
moveSomeFileToNewDir(fileDir, valDir)
# --------------------- 
# 作者：hhy9820 
# 来源：CSDN 
# 原文：https://blog.csdn.net/hyl999/article/details/84887464 
# 版权声明：本文为博主原创文章，转载请附上博文链接！

['Small-flowered Cranesbill', 'Common wheat', 'Scentless Mayweed', 'Maize', 'Charlock', 'Shepherds Purse', 'Cleavers', 'Loose Silky-bent', 'Sugar beet', 'Common Chickweed', 'Fat Hen', 'Black-grass']
['f8d9d8885.png', 'ba08ca84c.png', '3008533bc.png', '8e9cdb545.png', '2db36f1a0.png', '18a460203.png', '9eed2c8f4.png', '87429079a.png', 'e909c5348.png', '523dae399.png', '273fdc7f5.png', '75107b8b4.png', 'db360e6d2.png', '3b0100994.png', '6d7fa83ff.png', '0e7f05ec0.png', '5dbc8a7d9.png', '840ca8a36.png', '85f1d46e7.png', '4daedff7d.png', '4507d5e15.png', '1c41c5fd4.png', 'c38a4c31e.png', 'abf3ee5df.png', '05598e057.png', '758271672.png', 'f2f975384.png', '21f0b515d.png', '22429fb19.png', 'cb6988242.png', 'f2d762192.png', 'efbf3750d.png', '90e1ab9bd.png', '4ef7552f4.png', '9ab1673ae.png', '6226031d8.png', '5861480ff.png', '524b4014b.png', '19f14f508.png', 'c423a70c3.png', '7be34ad47.png', '789711110.png', '91fa6a4e8.png', '92888379a.png', '5d1f49b75.png', '47aa8024a.png', '09be59300.png', '

./data/train/Charlock/7f251fb9d.png ./data/val/Charlock/7f251fb9d.png
./data/train/Charlock/8b35222d0.png ./data/val/Charlock/8b35222d0.png
./data/train/Charlock/d3228543a.png ./data/val/Charlock/d3228543a.png
./data/train/Charlock/ae66022e7.png ./data/val/Charlock/ae66022e7.png
['2926b17c0.png', '21cfeb62a.png', 'f00311fd8.png', 'b3e5c949e.png', '3d32f86f4.png', '60ee96ab9.png', 'ba4b5df66.png', '143203030.png', '9123349c5.png', '500bd1f17.png', '95e89ddd3.png', '19fb8b2cc.png', 'c16206dca.png', '1c6a48d4f.png', '179cedc9e.png', '50ef0e765.png', '33ea3207a.png', '0bef4ae08.png', '5512ca7ba.png', '42de1a9d5.png', '65241684b.png', '77686f343.png', '995a8bd85.png']
./data/train/Shepherds Purse/2926b17c0.png ./data/val/Shepherds Purse/2926b17c0.png
./data/train/Shepherds Purse/21cfeb62a.png ./data/val/Shepherds Purse/21cfeb62a.png
./data/train/Shepherds Purse/f00311fd8.png ./data/val/Shepherds Purse/f00311fd8.png
./data/train/Shepherds Purse/b3e5c949e.png ./data/val/Shepherds Purse/b3e5c9

['09d0908b0.png', 'f18f2ca04.png', '06c42cf3f.png', 'cb199a0d6.png', '6cc932059.png', '3ee3ef6a3.png', '0b91b1f50.png', 'f1530c3e3.png', '53585f37d.png', '9253be20b.png', 'd920f1441.png', 'b53c5ac08.png', '9b3f2f7a1.png', 'fe03224a0.png', '310656b36.png', 'b48e67073.png', '98f407d78.png', '5b4b5f5ca.png', '2e025ece6.png', 'cba7f2307.png', 'bbfa8d1c9.png', '7a7c2d6f8.png', '2aa88416e.png', '1efb03a94.png', '993fcfaa4.png', '9435b2c58.png', '0c25871d9.png', '0c7fc717a.png', '672d71ed0.png', '1007fd84f.png', '1c1cce1e6.png', '709ff44b4.png', '1265c4a42.png', 'aa83de6bb.png', '0366e36eb.png', '303835197.png', 'ae7415e25.png', '7175e9d7d.png', 'c0fd4e4aa.png', '07e651912.png', '7c933aa92.png', '974108721.png', '61cb94bb2.png', '0a26afdf8.png', '06e9bbeba.png', '3a2a3ddb9.png', 'bd4304980.png', 'a0c39c1dd.png', '63ac8cb8b.png', '422cf9f7d.png', 'af98e2c11.png', 'dd76f845f.png', '6e64646e7.png', '4426efc94.png', 'bc92f8149.png', '63e5b3269.png', 'c2ab91ad2.png', 'ab6338bd1.png', '2f963cc5b.pn

./data/train/Common wheat/0382d0faf.png ./data/val/Common wheat/0382d0faf.png
./data/train/Common wheat/e244e2544.png ./data/val/Common wheat/e244e2544.png
./data/train/Common wheat/42098546c.png ./data/val/Common wheat/42098546c.png
./data/train/Common wheat/7d9f34d96.png ./data/val/Common wheat/7d9f34d96.png
./data/train/Common wheat/143774101.png ./data/val/Common wheat/143774101.png
./data/train/Common wheat/d42042a90.png ./data/val/Common wheat/d42042a90.png
['6c2cef408.png', '97ea07cab.png', '0372b48e1.png', '2c28f05a9.png', 'eef749129.png', '69da83055.png', '0dc27b35d.png', 'a82d370cd.png', '1ed148332.png', '2a6edb04e.png', '275152b11.png', 'b3ab2acea.png', 'f616c6831.png', '083a8c7d2.png', '0438cc647.png', '6398c3458.png', '67cd81694.png', '88b1699ea.png', 'a029e78bb.png', '2bdd11146.png', '03ee6340f.png', '4ae939d7d.png', '41c73ed49.png', '406f54018.png', 'fb1ba1eb6.png', '0a2bcaf43.png', 'aa4b8d708.png', '0d58d5433.png', '553e69068.png', '8e6c73968.png', '368fb9035.png', 'eb4

./data/train/Loose Silky-bent/accab3c58.png ./data/val/Loose Silky-bent/accab3c58.png
./data/train/Loose Silky-bent/b3b03d5b6.png ./data/val/Loose Silky-bent/b3b03d5b6.png
./data/train/Loose Silky-bent/330630db5.png ./data/val/Loose Silky-bent/330630db5.png
./data/train/Loose Silky-bent/4d1241f40.png ./data/val/Loose Silky-bent/4d1241f40.png
./data/train/Loose Silky-bent/217b3e661.png ./data/val/Loose Silky-bent/217b3e661.png
./data/train/Loose Silky-bent/830455097.png ./data/val/Loose Silky-bent/830455097.png
./data/train/Loose Silky-bent/7b4598d18.png ./data/val/Loose Silky-bent/7b4598d18.png
./data/train/Loose Silky-bent/34634ddc3.png ./data/val/Loose Silky-bent/34634ddc3.png
./data/train/Loose Silky-bent/0bcf22873.png ./data/val/Loose Silky-bent/0bcf22873.png
./data/train/Loose Silky-bent/4387157f1.png ./data/val/Loose Silky-bent/4387157f1.png
./data/train/Loose Silky-bent/db5bc3276.png ./data/val/Loose Silky-bent/db5bc3276.png
./data/train/Loose Silky-bent/333f36618.png ./data/val

### 自定义Dataset

In [8]:
from torch.utils.data import Dataset, DataLoader
import PIL

class myDataset(Dataset):
    """Image dataset driven by a table of file names and class ids.

    Args:
        labels: table indexed positionally through ``.iloc``; column 0 holds
            the image path relative to ``root`` and column 2 the class id.
        root: directory the relative image paths are joined to.
        subset: unused; kept so existing callers stay valid.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, labels, root, subset=False, transform=None):
        self.labels = labels
        self.root = root
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, class_id)`` for row ``idx`` of the label table."""
        # iloc[i, c]: value in row i, column c (column 0 = relative path).
        img_name = self.labels.iloc[idx, 0]
        full_name = os.path.join(self.root, img_name)
        # Force three channels regardless of the file's native mode.
        image = PIL.Image.open(full_name).convert('RGB')

        # Column 2 = class id.
        labels = self.labels.iloc[idx, 2]

        # Identity check: `!= None` would dispatch to a custom __ne__.
        if self.transform is not None:
            image = self.transform(image)

        return image, int(labels)

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.labels)

### 初始化自定义dataset和loader
https://www.kaggle.com/tylercosner/pytorch-starter-pre-trained-resnet50-torchvision/notebook

#### 定义目录、类别

In [7]:
import os
data_dir = './data/'

# Every entry of the training folder is treated as one class name.
classes = os.listdir(data_dir + 'train/')
# Sort key: names starting with a digit come first, ordered by the integer
# before their first space; all other names follow in plain string order.
classes = sorted(classes, key=lambda item: (int(item.partition(' ')[0])
                               if item[0].isdigit() else float('inf'), item))
# Map class id (position in the sorted list) -> class name.
num_to_class = dict(zip(range(len(classes)), classes))
num_to_class

{0: 'Black-grass',
 1: 'Charlock',
 2: 'Cleavers',
 3: 'Common Chickweed',
 4: 'Common wheat',
 5: 'Fat Hen',
 6: 'Loose Silky-bent',
 7: 'Maize',
 8: 'Scentless Mayweed',
 9: 'Shepherds Purse',
 10: 'Small-flowered Cranesbill',
 11: 'Sugar beet'}

#### 定义训练数据的结构

In [9]:
import pandas as pd
# One row per training image: [relative path, class name, class id].
# NOTE(review): the name `train` is reused by the training function defined
# further down in this file, which rebinds it once that cell runs.
train = []
for index, label in enumerate(classes):
    path = data_dir + 'train/' + label + '/'
    for file in os.listdir(path):
        # Path is stored relative to the training folder: "<class>/<file>".
        train.append(['{}/{}'.format(label, file), label, index])

df = pd.DataFrame(train, columns=['file', 'category', 'category_id',])

#### 分离训练集和验证集

In [10]:
# Random 80% of the rows go to training.
# NOTE(review): no random_state is passed, so the split is presumably not
# reproducible between runs unless the global RNG is seeded — confirm.
train_data = df.sample(frac=0.8)
# Validation set = every row whose file is not in the training sample.
valid_data = df[~df['file'].isin(train_data['file'])]

#### 定义dataset、dataloader

In [11]:
# Presumably the side length of the square network input (the transforms in
# this file crop to 224) — confirm.
image_size = 224
# Number of samples per mini-batch for the data loaders.
batch_size = 64

In [12]:
# Per-channel mean / std used by Normalize in both pipelines (the commonly
# used ImageNet statistics).
# Training pipeline: random resized crop and horizontal flip as augmentation.
train_trans = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation pipeline: deterministic resize followed by a centre crop.
valid_trans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Both subsets read from the training folder; they differ only in the rows
# selected and the transform applied.
train_set = myDataset(train_data, data_dir + 'train/', transform = train_trans)
valid_set = myDataset(valid_data, data_dir + 'train/', transform = valid_trans)
# test_set = myDataset(sample_submission, data_dir + 'test/', transform = valid_trans)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Validation needs no shuffling; a fixed order keeps the batch composition
# (and therefore the per-batch averaged loss) stable between evaluations.
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
# test_loader  = DataLoader(test_set, batch_size=1, shuffle=False)

In [13]:
# Number of samples behind each loader.
dataset_sizes = {
    'train': len(train_loader.dataset), 
    'valid': len(valid_loader.dataset)
}

## 训练函数

In [14]:
def train(model, train_loader, loss_func, optimizer, device):
    """Train ``model`` for one epoch and return the mean batch loss.

    Args:
        model: network to optimise; must already live on ``device``.
        train_loader: iterable of ``(images, targets)`` batches with a length.
        loss_func: callable mapping ``(outputs, targets)`` to a scalar loss.
        optimizer: optimizer holding the model parameters.
        device: device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    # Put dropout / batch-norm layers into training mode. Without this, a
    # model that was switched to eval mode earlier (e.g. for validation)
    # would silently keep training in eval mode.
    model.train()

    total_loss = 0
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)

        # forward
        outputs = model(images)
        loss = loss_func(outputs, targets)

        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss * 1.0 / (len(train_loader))

    return avg_loss

## 评估函数

In [15]:
def evaluate(model, val_loader, loss_fn, device):
    """Compute accuracy and mean batch loss of ``model`` on ``val_loader``.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: network to evaluate; must already live on ``device``.
        val_loader: iterable of ``(images, targets)`` batches with a length.
        loss_fn: callable mapping ``(outputs, targets)`` to a scalar loss.
        device: device the batches are moved to.

    Returns:
        ``(accuracy, mean_loss)``: accuracy in percent over all samples and
        the sum of per-batch losses divided by the number of batches.
    """
    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        correct = 0
        total = 0
        total_loss = 0

        for images, targets in val_loader:
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)

            # Index of the largest logit in each row = predicted class.
            _, predicted = torch.max(outputs, dim=1)

            correct += (predicted == targets).sum().item()
            total += targets.size(0)

            loss = loss_fn(outputs, targets)
            total_loss += loss.item()  # accumulate every batch loss

        accuracy = correct * 100.0 / total
        total_loss = total_loss * 1.0 / (len(val_loader))
        return accuracy, total_loss

## 拟合函数

In [16]:
def fit(model, num_epochs, optimizer, train_loader, valid_loader, device):
    """Train and evaluate a classifier for ``num_epochs`` epochs.

    Cross-entropy is used as the loss. After every epoch the model is scored
    on both loaders, the results are printed, and once all epochs are done
    the four recorded curves are plotted.

    Args:
        model: CNN network.
        num_epochs: number of training epochs.
        optimizer: optimizer updating the model parameters.
        train_loader: loader with the training batches.
        valid_loader: loader with the validation batches.
        device: device the model, the loss and the batches are moved to.

    Returns:
        The trained model.
    """
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    criterion.to(device)

    # Per-epoch history, keyed by the title later used for plotting.
    history = {
        "train loss": [],
        "train accuracy": [],
        "validation loss": [],
        "validation accuracy": [],
    }

    for epoch_idx in range(num_epochs):
        current = epoch_idx + 1
        print('Epoch {}/{}:'.format(current, num_epochs))

        # Training pass, then score the model on the training data.
        epoch_loss = train(model, train_loader, criterion, optimizer, device)
        train_acc, _ = evaluate(model, train_loader, criterion, device)
        print('Epoch: {}/{}. Train set: Average loss: {:.4f}, Accuracy: {:.4f}%'.format(
            current, num_epochs, epoch_loss, train_acc))
        history["train accuracy"].append(train_acc)
        history["train loss"].append(epoch_loss)

        # Score the model on the validation data.
        valid_acc, valid_loss = evaluate(model, valid_loader, criterion, device)
        print('Epoch: {}/{}. Validation set: Average loss: {:.4f}, Accuracy: {:.4f}%'.format(
            current, num_epochs, valid_loss, valid_acc))
        history["validation accuracy"].append(valid_acc)
        history["validation loss"].append(valid_loss)

    # Plot in the fixed order: train loss/accuracy, validation loss/accuracy.
    for curve_title in ("train loss", "train accuracy",
                        "validation loss", "validation accuracy"):
        show_curve(history[curve_title], curve_title)
    return model

## 初始化参数

In [22]:
# Fresh VGG16 instance (untrained weights) used for the training run below.
vggnet = VGG(cfg['VGG16'])

In [23]:
num_epoch = 10
lr = 1e-3
# Prefer the second GPU, but fall back to the first GPU or the CPU so the
# notebook also runs on machines that do not have a 'cuda:1' device.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# vggnet = VGG(cfg['VGG16'])

# SGD with momentum over all model parameters.
optimizer = torch.optim.SGD(vggnet.parameters(), lr=lr, momentum=0.9)



## 训练模型

In [24]:
# Run the full train / evaluate loop and plot the resulting curves.
fit(vggnet, num_epoch, optimizer, train_loader, valid_loader, device)

Epoch 1/10:
64


KeyboardInterrupt: 

## 保存模型（待）

In [23]:
# show parameters in model

# Print model's state_dict
print("Model's state_dict:")
# Name and shape of every tensor stored in the model.
for param_tensor, tensor_value in vggnet.state_dict().items():
    print(param_tensor, "\t", tensor_value.size())

# Print optimizer's state_dict
print("\nOptimizer's state_dict:")
# The misspelt `optimzer` raised NameError; use the defined `optimizer`.
for var_name, var_value in optimizer.state_dict().items():
    print(var_name, "\t", var_value)

Model's state_dict:
layers.0.weight 	 torch.Size([64, 3, 3, 3])
layers.0.bias 	 torch.Size([64])
layers.1.weight 	 torch.Size([64])
layers.1.bias 	 torch.Size([64])
layers.1.running_mean 	 torch.Size([64])
layers.1.running_var 	 torch.Size([64])
layers.1.num_batches_tracked 	 torch.Size([])
layers.3.weight 	 torch.Size([64, 64, 3, 3])
layers.3.bias 	 torch.Size([64])
layers.4.weight 	 torch.Size([64])
layers.4.bias 	 torch.Size([64])
layers.4.running_mean 	 torch.Size([64])
layers.4.running_var 	 torch.Size([64])
layers.4.num_batches_tracked 	 torch.Size([])
layers.7.weight 	 torch.Size([128, 64, 3, 3])
layers.7.bias 	 torch.Size([128])
layers.8.weight 	 torch.Size([128])
layers.8.bias 	 torch.Size([128])
layers.8.running_mean 	 torch.Size([128])
layers.8.running_var 	 torch.Size([128])
layers.8.num_batches_tracked 	 torch.Size([])
layers.10.weight 	 torch.Size([128, 128, 3, 3])
layers.10.bias 	 torch.Size([128])
layers.11.weight 	 torch.Size([128])
layers.11.bias 	 torch.Size([128])
l

In [27]:
# save model

import os

save_path = './model/model.pt'
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Only the parameters / buffers are stored, not the module object itself.
torch.save(vggnet.state_dict(), save_path)

## 使用已有的模型参数来初始化模型

In [15]:
# Layer layouts of the VGG variants (same table as used when the checkpoint
# was created): an int is the output-channel count of one conv stage, 'M'
# marks a max-pooling layer.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

In [14]:
save_path = './model/model.pt'
# map_location='cpu' lets a checkpoint written from a GPU be restored on a
# machine without that GPU; the model is moved to its device separately.
saved_parametes = torch.load(save_path, map_location='cpu')
print(saved_parametes)

OrderedDict([('layers.0.weight', tensor([[[[-0.1438, -0.0157, -0.1004],
          [-0.1716, -0.1502, -0.0154],
          [ 0.1081,  0.1225,  0.0179]],

         [[ 0.0763,  0.0847,  0.0539],
          [-0.0197, -0.0904, -0.0503],
          [ 0.0305,  0.0030,  0.0242]],

         [[ 0.1018,  0.1460, -0.1783],
          [-0.1371, -0.0325,  0.1242],
          [ 0.1093, -0.1373, -0.1183]]],


        [[[-0.0621, -0.0403,  0.0187],
          [ 0.1941, -0.0290, -0.1887],
          [ 0.1184,  0.0075, -0.0286]],

         [[ 0.0008, -0.1495, -0.0305],
          [ 0.0682, -0.0363, -0.1111],
          [ 0.1830,  0.1046, -0.1198]],

         [[-0.0904,  0.0915,  0.0168],
          [ 0.1767, -0.0320, -0.0148],
          [ 0.1797, -0.0539, -0.0282]]],


        [[[-0.0345,  0.0082,  0.1282],
          [ 0.0759,  0.1590, -0.0355],
          [-0.0018,  0.0665,  0.0520]],

         [[ 0.0678,  0.1821, -0.0708],
          [ 0.1623,  0.1467,  0.0316],
          [ 0.0974, -0.0706, -0.0248]],

         [[

In [15]:
# Initialize a new VGG16 model from the saved parameters.
new_vgg = VGG(cfg['VGG16'])
new_vgg.load_state_dict(saved_parametes)

In [16]:
# Prefer the second GPU, but fall back to the first GPU or the CPU so the
# notebook also runs on machines that do not have a 'cuda:1' device.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

## 获取每一张图片
https://www.kaggle.com/solomonk/pytorch-simplenet-augmentation-cnn-lb-0-945

In [17]:
def testImageLoader(image_name, test_trans):
    """Load one image file as a single-item batch.

    The file is opened with PIL, converted to RGB, passed through
    ``test_trans`` and given a leading axis via ``unsqueeze(0)``.
    No device transfer happens here.
    """
    rgb_image = Image.open(image_name).convert('RGB')
    return test_trans(rgb_image).unsqueeze(0)


## 预测测试集结果

In [21]:
# Load the submission template and rename its two columns so that the file
# name column matches the 'file' key used elsewhere in this notebook.
sample_submission = pd.read_csv(data_dir + 'sample_submission.csv')
sample_submission.columns = ['file', 'species']
# sample_submission['category_id'] = 0
sample_submission.head(3)

Unnamed: 0,file,species
0,0021e90e4.png,Sugar beet
1,003d61042.png,Sugar beet
2,007b3da8b.png,Sugar beet


In [19]:
def predict(model, transform, sample_submission, device, test_dir='./data/test'):
    """Predict the species of every test image listed in ``sample_submission``.

    Args:
        model: trained classifier; moved to ``device`` and set to eval mode.
        transform: preprocessing passed on to ``testImageLoader``.
        sample_submission: table with a ``'file'`` column of image file names.
        device: device used for inference.
        test_dir: directory that contains the test images.

    Returns:
        DataFrame with columns ``'file'`` and ``'species'``, one row per
        listed file that exists in ``test_dir``; missing files are skipped.
    """
    model = model.to(device)
    model.eval()

    columns = ['file', 'species']
    # Rows are collected in a list and turned into a frame once at the end:
    # DataFrame.append copied the frame per row and no longer exists in
    # pandas 2.x.
    rows = []

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for _, row in sample_submission.iterrows():
            currImage = os.path.join(test_dir, row['file'])
            if not os.path.isfile(currImage):
                continue

            test_img = testImageLoader(currImage, transform).to(device)

            # Index of the largest logit = predicted class id.
            img_index = model(test_img).argmax(dim=1).item()

            rows.append({'file': row['file'],
                         'species': num_to_class[int(img_index)]})

    return pd.DataFrame(rows, columns=columns)

In [23]:
# Classify the test images with the restored model, using the validation
# preprocessing pipeline.
df_pred = predict(new_vgg, valid_trans, sample_submission, device)

### 写入文件

In [24]:
import os

# to_csv does not create missing directories; make sure ./result exists.
os.makedirs('./result', exist_ok=True)
# index=False: write only the two data columns, no row index.
df_pred.to_csv('./result/result0.csv', columns=['file', 'species'], index=False)