In [1]:
import torch
import torchvision.models as models

In [2]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import os 
import time
from datetime import timedelta

In [3]:
import shutil
tran_filenames = os.listdir('data/train')

In [4]:
batch_size = 32

In [5]:
tran_filenames

['cat.0.jpg',
 'cat.1.jpg',
 'cat.10.jpg',
 'cat.100.jpg',
 'cat.1000.jpg',
 'cat.10000.jpg',
 'cat.10001.jpg',
 'cat.10002.jpg',
 'cat.10003.jpg',
 'cat.10004.jpg',
 'cat.10005.jpg',
 'cat.10006.jpg',
 'cat.10007.jpg',
 'cat.10008.jpg',
 'cat.10009.jpg',
 'cat.1001.jpg',
 'cat.10010.jpg',
 'cat.10011.jpg',
 'cat.10012.jpg',
 'cat.10013.jpg',
 'cat.10014.jpg',
 'cat.10015.jpg',
 'cat.10016.jpg',
 'cat.10017.jpg',
 'cat.10018.jpg',
 'cat.10019.jpg',
 'cat.1002.jpg',
 'cat.10020.jpg',
 'cat.10021.jpg',
 'cat.10022.jpg',
 'cat.10023.jpg',
 'cat.10024.jpg',
 'cat.10025.jpg',
 'cat.10026.jpg',
 'cat.10027.jpg',
 'cat.10028.jpg',
 'cat.10029.jpg',
 'cat.1003.jpg',
 'cat.10030.jpg',
 'cat.10031.jpg',
 'cat.10032.jpg',
 'cat.10033.jpg',
 'cat.10034.jpg',
 'cat.10035.jpg',
 'cat.10036.jpg',
 'cat.10037.jpg',
 'cat.10038.jpg',
 'cat.10039.jpg',
 'cat.1004.jpg',
 'cat.10040.jpg',
 'cat.10041.jpg',
 'cat.10042.jpg',
 'cat.10043.jpg',
 'cat.10044.jpg',
 'cat.10045.jpg',
 'cat.10046.jpg',
 'cat.1004

In [6]:
train_cat = filter(lambda x:x[:3]=='cat', tran_filenames)
train_dog = filter(lambda x:x[:3]=='dog', tran_filenames)

In [7]:
# from shutil import copyfile

# def rmrf_mkdir(dirname):
#     if os.path.exists(dirname):
#         shutil.rmtree(dirname)
#     os.mkdir(dirname)

# rmrf_mkdir('data/train2')
# os.mkdir('data/train2/cat')
# os.mkdir('data/train2/dog')

# # rmrf_mkdir('test2')
# # os.symlink('../test/', 'test2/test')

# for filename in train_cat:
#     copyfile('data/train/'+filename, 'data/train2/cat/'+filename)

# for filename in train_dog:
#     copyfile('data/train/'+filename, 'data/train2/dog/'+filename)


## 准备数据集

In [8]:
datasetdir = os.path.join('./data')
traindir = os.path.join(datasetdir,'train2')

In [9]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(traindir,transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]))

n_train = int(len(train_dataset)*0.9)
n_validation = len(train_dataset)-n_train

In [10]:
# 制作训练集验证集
train_data,valid_data = torch.utils.data.random_split(train_dataset,[n_train,n_validation])

In [11]:
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10
)
val_loader = torch.utils.data.DataLoader(
    valid_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10
)

In [12]:
classes=[d for d in os.listdir(traindir) if os.path.isdir(os.path.join(traindir,d))]

In [13]:
classes

['cat', 'dog']

##  加载模型
- 在做比赛的时候，建议使用已经训练好的ImageNet模型进行修改和Fine-Tune，这样可以比较快的提交结果，但是并不说明效果一定比重新训练的好

In [14]:
def resnet34(pretrained=True):
    model = models.resnet34(pretrained=pretrained)
    def set_untrainable(layer):
        for p in layer.parameters():
            p.requires_grad = False
    for layer in model.children():
        layer.apply(set_untrainable)
    model.fc = nn.Linear(512,2)
    return model

def densenet161(pretrained=True):
    model = models.densenet161(pretrained=pretrained)
    def set_untrainable(layer):
        for p in layer.parameters():
            p.requires_grad = False

    for layer in model.children():
        layer.apply(set_untrainable)
    model.classifier = nn.Linear(2208, 2)
    model.cuda()
    return model

In [15]:
model = resnet34()

## 训练

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# if torch.cuda.device_count() > 1:
#     model = nn.DataParallel(model)
model = model.to(device)

loss = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                       lr=1e-2)

In [17]:
def check_accu(model,loader):
    num_correct = 0
    num_samples = 0
    model.eval() 
    start_time = time.time()
    for x, y in loader:
        with torch.no_grad():
            x_var = x.cuda()
            scores = model(x_var)
            _, preds = scores.data.cpu().max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
    acc = float(num_correct) / num_samples
    print('Got %d / %d correct (%.2f)' % 
          (num_correct, num_samples, 100 * acc))
    print('duration = %s' % timedelta(seconds=time.time() - start_time))

def train(model,loader,epochs):
    #start_time=time.time()
    for epoch in range(epochs):
        model.train()
        for t,(x,y) in enumerate(loader):
            batch_start = time.time()
            x = x.to(device)
            y = y.to(device)
            score = model(x)
            l = loss(score,y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            if (t+1)%100 == 0:
                print('t = %d, loss = %.4f, duration = %s' % (t + 1, 
                l.item(), timedelta(seconds=time.time() - batch_start)))
                check_accu(model,val_loader)
    

In [18]:
train(model, train_loader, epochs=1)

t = 100, loss = 0.1673, duration = 0:00:00.285549
Got 2378 / 2500 correct (95.12)
duration = 0:00:13.255753
t = 200, loss = 0.0610, duration = 0:00:00.281679
Got 2368 / 2500 correct (94.72)
duration = 0:00:13.268645
t = 300, loss = 0.0722, duration = 0:00:00.278628
Got 2284 / 2500 correct (91.36)
duration = 0:00:13.362712
t = 400, loss = 0.3829, duration = 0:00:00.283391
Got 2352 / 2500 correct (94.08)
duration = 0:00:13.362105
t = 500, loss = 0.1578, duration = 0:00:00.282051
Got 2386 / 2500 correct (95.44)
duration = 0:00:13.310507
t = 600, loss = 0.2267, duration = 0:00:00.282381
Got 2369 / 2500 correct (94.76)
duration = 0:00:13.319198
t = 700, loss = 0.0080, duration = 0:00:00.285527
Got 2376 / 2500 correct (95.04)
duration = 0:00:13.346316
