In [None]:
import glob
import os
import cv2
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageStat
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings(action='ignore')

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet18
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Mini-batch size for each of the three data splits.
train_bs = valid_bs = test_bs = 64

lr_init = 0.01      # base learning rate handed to the optimizer
max_epoch = 100     # number of passes over the training data

# Directory that holds the saved model checkpoints.
save_path = '../models/'

In [None]:
# Root folder of every dataset split; each root contains one sub-folder per class.
train_path = '../200_bird/train/'
valid_path = '../200_bird/valid/'
test_path = '../200_bird/test/'

# Kept for backward compatibility; not populated anywhere in the visible cells.
train_all = {}
test_all = {}
valid_all = {}

split_roots = {
    'Train': train_path,
    'Valid': valid_path,
    'Test': test_path,
}

# os.listdir returns entries in arbitrary order. The position of a class in
# `categories` is used as its integer label, so sort the names to make the
# label mapping reproducible, and ignore stray files (only folders are classes).
categories = sorted(
    entry for entry in os.listdir(train_path)
    if os.path.isdir(os.path.join(train_path, entry))
)

# image_paths[split][category] -> sorted list of the .jpg files of that class.
image_paths = {
    'Train': {},
    'Valid': {},
    'Test' : {},
}

for Type, root in split_roots.items():
    for category in categories:
        pattern = os.path.join(root, category, '*.jpg')
        image_paths[Type][category] = sorted(glob.glob(pattern))

In [None]:
# Map every category name to its position in `categories`, stored as a string
# so it can be written straight into the index files.
mapkey = {category: str(i) for i, category in enumerate(categories)}

def gen_txt(txt_path, image_paths):
    """Write an index file listing every image of one split with its label.

    Each output line has the form ``<image path> <label>``, where the label is
    looked up in the module-level ``mapkey`` dictionary.

    Args:
        txt_path: Destination file; it is created or overwritten.
        image_paths: Dict mapping a category name to a list of image paths.

    Raises:
        KeyError: If a category in ``image_paths`` is missing from ``mapkey``.
    """
    # The context manager guarantees the file is flushed and closed, even if
    # a lookup fails half-way through.
    with open(txt_path, 'w') as f:
        for key, paths in image_paths.items():
            label = mapkey[key]
            f.writelines(path + ' ' + label + '\n' for path in paths)

# Produce one index file per split.
for list_file, split_name in (
    ('../200_bird/train.txt', 'Train'),
    ('../200_bird/valid.txt', 'Valid'),
    ('../200_bird/test.txt', 'Test'),
):
    gen_txt(list_file, image_paths[split_name])

In [None]:
# Compute the per-channel mean and standard deviation of the training set.
# Only the 'Train' split is used, so that no statistics of the validation or
# test images leak into the preprocessing.

all_paths = [
    path
    for category in image_paths['Train']
    for path in image_paths['Train'][category]
]

m_list, s_list = [], []
for path in tqdm(all_paths):
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None instead of
        # raising, which would otherwise crash on the division below.
        print(f'Skipping unreadable image: {path}')
        continue
    img = img / 255.0
    m, s = cv2.meanStdDev(img)
    m_list.append(m.reshape((3,)))
    s_list.append(s.reshape((3,)))

if not m_list:
    raise RuntimeError('No readable training images found; cannot compute statistics.')

# Average the per-image statistics over all images. Note that `s` is the mean
# of the per-image standard deviations, not the std of the pooled pixels.
m_array = np.array(m_list)
s_array = np.array(s_list)
m = m_array.mean(axis=0, keepdims=True)
s = s_array.mean(axis=0, keepdims=True)

In [None]:
# BGR -> RGB: the statistics were computed on images read with cv2.imread
# (BGR channel order), so reverse the channel axis before printing.
print(m[0][::-1])
print(s[0][::-1])

In [None]:
class MyDataset(Dataset):
    """Image dataset backed by an index file of ``<image path> <label>`` lines.

    The label is the last space-separated token of each line and must be an
    integer; everything before it is the image path (which may contain spaces).

    Args:
        txt_path: Path of the index file.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the integer label.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        # Close the index file deterministically instead of leaking the handle.
        with open(txt_path, 'r') as fh:
            for line in fh:
                line = line.rstrip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                path, _, label = line.rpartition(' ')
                imgs.append((path, int(label)))

        # List of (path, label) pairs; the DataLoader supplies an index and
        # __getitem__ loads the corresponding image.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at position ``index``."""
        fn, label = self.imgs[index]
        # Pixel values are 0~255 here; transforms.ToTensor later divides by 255.
        # The context manager releases the underlying file once it is decoded.
        with Image.open(fn) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)   # e.g. augmentation and tensor conversion

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label

    def __len__(self):
        """Number of samples listed in the index file."""
        return len(self.imgs)

In [None]:
# Per-channel statistics used to standardise the input images
# (presumably in RGB order, matching the RGB images the dataset yields).
normMean = [0.47176269, 0.47058557, 0.40052285]
normStd = [0.20032301, 0.196909, 0.20223242]
normTransform = transforms.Normalize(normMean, normStd)

# Training pipeline: random augmentation, then tensor conversion and normalisation.
trainTransform = transforms.Compose([
    transforms.RandomCrop(size=(224,224), padding=28),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normTransform,
])

# Evaluation pipelines are deterministic: tensor conversion and normalisation only.
validTransfrom = transforms.Compose([
    transforms.ToTensor(),
    normTransform,
])
# Correctly spelled alias; the misspelled name above is kept so existing
# references keep working.
validTransform = validTransfrom

testTransform = transforms.Compose([
    transforms.ToTensor(),
    normTransform,
])

In [None]:
train_txt_path = '../200_bird/train.txt'
valid_txt_path = '../200_bird/valid.txt'
test_txt_path = '../200_bird/test.txt'

# Only the training split gets random augmentation. Validation must use the
# deterministic pipeline (`validTransfrom`, sic), otherwise validation accuracy
# is measured on randomly cropped/rotated images and fluctuates between runs.
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransfrom)
test_data = MyDataset(txt_path=test_txt_path, transform=testTransform)

In [None]:
# One DataLoader per split; every loader reshuffles its samples on each pass.
train_loader = DataLoader(train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=valid_bs, shuffle=True)
test_loader = DataLoader(test_data, batch_size=test_bs, shuffle=True)

In [None]:
def image_convert(img, mean=None, std=None):
    """Undo the input normalisation and return an image array ready for plotting.

    Args:
        img: Normalised image tensor laid out as (C, H, W), on any device.
        mean: Per-channel mean that was subtracted during normalisation.
            Defaults to the module-level ``normMean``.
        std: Per-channel std that the image was divided by.
            Defaults to the module-level ``normStd``.

    Returns:
        numpy array of shape (H, W, C) with values clipped to [0, 1].
    """
    # De-normalise with the very statistics used by the Normalize transform;
    # using different constants here would shift the displayed colours.
    if mean is None:
        mean = normMean
    if std is None:
        std = normStd

    img = img.detach().cpu().numpy()
    img = img.transpose(1,2,0)     # (C, H, W) -> (H, W, C) as expected by imshow
    img = img * np.asarray(std) + np.asarray(mean)
    # Floating point round-off can push values slightly outside the valid range.
    return np.clip(img, 0.0, 1.0)


def plot_batch():
    """Display the first batch drawn from ``train_loader`` as a labelled image grid."""
    images, labels = next(iter(train_loader))

    # Size the grid from the batch actually received: it can be smaller than
    # train_bs (small dataset) and need not be divisible by the row count.
    n_images = len(images)
    n_rows = 4
    n_cols = max(1, -(-n_images // n_rows))   # ceiling division

    plt.figure(figsize=(20,10))
    for idx in range(n_images):
        plt.subplot(n_rows, n_cols, idx+1)
        plt.imshow(image_convert(images[idx]))
        plt.title(categories[int(labels[idx])])
    plt.show()

plot_batch()

In [None]:
class fc_part(nn.Module):
    """Classification head: one linear layer mapping features to class logits.

    Args:
        in_features: Width of the incoming feature vector (default 512).
        num_classes: Number of output logits (default 200).
    """

    def __init__(self, in_features=512, num_classes=200):
        super().__init__()
        # The attribute name `fc1` is part of the state-dict keys of saved
        # checkpoints, so it must not be renamed.
        self.fc1 = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the logits for a batch of feature vectors ``x``."""
        return self.fc1(x)

In [None]:
# ResNet-18 without pretrained weights; its final layer is swapped for the
# fc_part head, then the whole network is moved to the target device.
model = resnet18(pretrained=False)
model.fc = fc_part()
model = model.to(device)

In [None]:
# Log one training batch and the model graph to TensorBoard.
images, labels = next(iter(train_loader))
# `device` is the CPU whenever CUDA is unavailable, so this is a no-op there.
images = images.to(device)
grid = torchvision.utils.make_grid(images)

comment = f'batch_size{train_bs} lr{lr_init}'
tb = SummaryWriter(comment=comment)
tb.add_image('images', grid)
tb.add_graph(model, images)

In [None]:
# Separate the parameters of the new head (model.fc) from the rest of the
# network so the optimizer can give the two groups different learning rates.
ignored_params = list(map(id, model.fc.parameters()))
# Materialise the result as a list: a lazy `filter` object is exhausted by its
# first consumer, so re-running the optimizer cell would otherwise hand the
# optimizer an empty backbone group.
base_params = [p for p in model.parameters() if id(p) not in ignored_params]

In [None]:
# Two parameter groups: the backbone uses the base learning rate, the head
# (model.fc) uses a rate ten times larger.
param_groups = [
    {'params': base_params},
    {'params': model.fc.parameters(), 'lr': lr_init*10},
]
optimizer = optim.SGD(param_groups, lr=lr_init, momentum=0.9, weight_decay=1e-4)

# Loss function.
criterion = nn.CrossEntropyLoss()
# Learning-rate decay policy: multiply the rate by 0.1 every 50 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

In [None]:
# Make sure the checkpoint directory exists before the first save.
os.makedirs(save_path, exist_ok=True)

cur_max = 0     # best validation accuracy seen so far
for epoch in range(1, max_epoch+1):

    train_loss = 0.0    # sum of per-sample losses over the epoch
    train_correct = 0.0
    train_total = 0.0

    model.train()
    with tqdm(train_loader, desc = f'Train epoch: {epoch}') as t:
        for inputs, labels in t:
            # Fetch images and labels.
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward, backward, update weights
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate prediction statistics. `loss` is the batch mean, so
            # weight it by the batch size to make the later division by the
            # number of samples yield a true per-sample average.
            _, predicted = torch.max(outputs, axis = 1)
            batch_size = labels.size(0)
            train_total += batch_size
            train_correct += torch.sum(predicted == labels).item()
            train_loss += loss.item() * batch_size

            # Information shown on the right of the progress bar.
            t.set_postfix(train_loss = loss.item(), train_accuracy = train_correct / train_total)

    valid_loss = 0.0
    valid_correct = 0
    valid_total = 0

    model.eval()
    with torch.no_grad():
        with tqdm(valid_loader, desc = f'Valid epoch: {epoch}') as t:
            for inputs, labels in t:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = F.cross_entropy(outputs, labels)
                batch_size = labels.size(0)
                valid_loss += loss.item() * batch_size
                _, predicted = torch.max(outputs, axis = 1)
                valid_total += batch_size
                valid_correct += torch.sum(predicted == labels).item()
                t.set_postfix(valid_loss = loss.item(), valid_accuracy = valid_correct / valid_total)

    valid_accuracy = valid_correct / valid_total

    tb.add_scalar('train_loss', train_loss/train_total, epoch)
    tb.add_scalar('train_accuracy', train_correct/train_total, epoch)
    tb.add_scalar('valid_loss', valid_loss/valid_total, epoch)
    tb.add_scalar('valid_accuracy', valid_accuracy, epoch)

    # Save a checkpoint whenever the validation accuracy is a new best (or a
    # tie) and reaches at least 95%.
    cur_max = max(cur_max, valid_accuracy)
    if valid_accuracy >= max(0.95, cur_max):
        model_name = 'model_' + str(valid_accuracy*1000) + '.pkl'
        torch.save(model.state_dict(), os.path.join(save_path, model_name))

    # Update the learning rate. The scheduler must be stepped after the
    # optimizer steps of the epoch; stepping it first shifts the decay schedule
    # by one epoch.
    scheduler.step()

# Push any buffered events to disk.
tb.flush()

In [None]:
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load('../models/model_955.0.pkl', map_location=device))

In [None]:
# Evaluate the model on the test split; running figures appear in the progress bar.
test_loss = 0.0
test_correct = 0
test_total = 0

model.eval()
with torch.no_grad(), tqdm(test_loader, desc = f'Test epoch: {1}') as t:
    for inputs, labels in t:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model.forward(inputs)
        loss = F.cross_entropy(outputs, labels)
        test_loss += loss.item()

        predicted = torch.max(outputs, axis = 1).indices
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

        t.set_postfix(test_loss = loss.item(), test_accuracy = test_correct / test_total)

In [None]:
def plot_val_images():
    """Show up to ten images of one ``test_loader`` batch, titled with the
    actual and the predicted class name."""
    images, labels = next(iter(test_loader))
    images = images.to(device)

    # Inference only: use evaluation mode and do not track gradients.
    model.eval()
    with torch.no_grad():
        img_out = model(images)
    # The predicted class is the arg-max over the logits, not a copy of the
    # ground-truth labels.
    _, pred_labels = torch.max(img_out, 1)
    pred_labels = pred_labels.cpu()

    n_show = min(10, len(images))   # the batch may hold fewer than ten images
    fig = plt.figure(figsize=(35,9))
    for idx in range(n_show):
        ax = fig.add_subplot(2,5,idx+1)
        ax.imshow(image_convert(images[idx]))
        label = int(labels[idx])
        pred_label = int(pred_labels[idx])
        ax.set_title('Act {},pred {}'.format(categories[label],categories[pred_label]))

plot_val_images()