In [1]:
import torch
import torch.nn as nn
import numpy
from torchvision import datasets , transforms
import matplotlib.pyplot as plt
import time #dfdsfdsf

## 加载数据，并制作成数据集

In [2]:
# Resize every image to 128x128 and convert it to a float tensor in [0, 1].
dataTrans = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

# One ImageFolder per split; class labels come from the sub-directory names.
foodSets = {
    x: datasets.ImageFolder("../data/food/" + x, transform=dataTrans)
    for x in ["training", "validation"]
}

# Only the training split needs shuffling. Validation loss/accuracy are sums over
# the whole split, so a fixed read order gives the same metrics with less overhead.
dataloads = {
    x: torch.utils.data.DataLoader(
        foodSets[x],
        batch_size=16,
        shuffle=(x == "training"),
        num_workers=10,
    )
    for x in ["training", "validation"]
}

## 建立CNN模型

In [3]:
class foodCNN(torch.nn.Module):
    """Five-stage convolutional classifier for the food image dataset.

    Each stage is Conv3x3 (stride 1, padding 1) -> BatchNorm -> ReLU -> MaxPool 2x2,
    so a 128x128 RGB input becomes a 512x4x4 feature map. That map is flattened
    and fed to a three-layer fully connected head.

    Args:
        num_classes: number of output logits (default 11, the value the notebook
            was written for).
    """

    def __init__(self, num_classes=11):
        super().__init__()
        self.cnn = nn.Sequential(
            # stage 1
            nn.Conv2d(3, 64, 3, 1, 1),      # [64, 128, 128]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),          # [64, 64, 64]

            # stage 2
            nn.Conv2d(64, 128, 3, 1, 1),    # [128, 64, 64]
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),          # [128, 32, 32]

            # stage 3
            nn.Conv2d(128, 256, 3, 1, 1),   # [256, 32, 32]
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),          # [256, 16, 16]

            # stage 4
            nn.Conv2d(256, 512, 3, 1, 1),   # [512, 16, 16]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),          # [512, 8, 8]

            # stage 5
            nn.Conv2d(512, 512, 3, 1, 1),   # [512, 8, 8]
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),          # [512, 4, 4]
        )
        # Parameter-free (adds no state_dict entries). For a 128x128 input the map
        # is already 4x4, so this is an exact identity. Other input sizes are
        # brought to the 512*4*4 features the head below expects.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return logits of shape [batch, num_classes] for images x of shape [batch, 3, H, W].

        H and W must be at least 32 so the five 2x2 poolings leave a non-empty map.
        """
        y = self.pool(self.cnn(x))
        y = torch.flatten(y, 1)  # [batch, 512 * 4 * 4]
        return self.fc(y)

## Training

In [5]:
# Use the GPU when one is present; fall back to CPU instead of crashing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
model = foodCNN().to(device)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
num_epoch = 10
since = time.time()
n_t = len(dataloads["training"].dataset)
n_v = len(dataloads["validation"].dataset)


def run_epoch(loader, n_samples, train):
    """Run one pass of `model` over `loader` and return (mean loss, accuracy).

    When `train` is True the model is in training mode and `opt` steps after
    every batch. Otherwise it runs in eval mode with gradients disabled.
    Both metrics are averaged over `n_samples` (guarded against an empty split).
    """
    model.train(train)
    total_loss, total_correct = 0.0, 0
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = lossFunc(out, y)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            # loss is a batch mean; re-weight by batch size so the epoch mean is exact.
            total_loss += loss.item() * x.size(0)
            total_correct += (out.argmax(1) == y).sum().item()
    denom = max(n_samples, 1)
    return total_loss / denom, total_correct / denom


for epoch in range(num_epoch):
    t_loss, t_acc = run_epoch(dataloads["training"], n_t, train=True)
    now = time.time() - since
    print("train:loss{:.4f}; correct{:.4f};time:{:.0f}m {:.0f}s".format(t_loss, t_acc, now // 60, now % 60))

    v_loss, v_acc = run_epoch(dataloads["validation"], n_v, train=False)
    now = time.time() - since
    print("validation:loss{:.4f}; correct{:.4f};time{:.0f}m {:.0f}s".format(v_loss, v_acc, now // 60, now % 60))



train:loss2.1674; correct0.2325;time:1m 2s
validation:loss2.0312; correct0.2805;time1m 11s
train:loss1.9694; correct0.3032;time:2m 13s
validation:loss1.9295; correct0.3155;time2m 22s
train:loss1.8333; correct0.3570;time:3m 25s
validation:loss1.9289; correct0.3259;time3m 33s
train:loss1.7002; correct0.4048;time:4m 39s
validation:loss1.6782; correct0.4163;time4m 49s
train:loss1.5701; correct0.4513;time:5m 55s
validation:loss2.2316; correct0.3096;time6m 3s
train:loss1.4758; correct0.4827;time:7m 6s
validation:loss1.5381; correct0.4831;time7m 14s
train:loss1.3550; correct0.5274;time:8m 16s
validation:loss1.5898; correct0.4761;time8m 25s
train:loss1.2630; correct0.5631;time:9m 26s
validation:loss1.3635; correct0.5362;time9m 34s
train:loss1.1665; correct0.5983;time:10m 35s
validation:loss1.3155; correct0.5566;time10m 44s
train:loss1.0677; correct0.6332;time:11m 46s
validation:loss1.3315; correct0.5577;time11m 54s


In [7]:

# Bundle model weights and optimizer state together so training can be resumed later.
state = dict(state_dict=model.state_dict(), optimizer=opt.state_dict())
torch.save(state, "checkpoint")

## 检查数据集：显示一张训练样本图片及其类别

In [None]:
# Pull one batch from the training loader. Use the built-in next(): current
# PyTorch DataLoader iterators no longer provide a .next() method.
myIter = iter(dataloads["training"])
images, classes = next(myIter)

# ToTensor produced a CHW tensor; plt.imshow expects HWC.
im = images[0].numpy().transpose(1, 2, 0)
label = classes[0].item()

fig, ax = plt.subplots()
ax.imshow(im)
ax.set_title(foodSets["training"].classes[label])
ax.axis("off")
plt.show()

label, foodSets["training"].class_to_idx