In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.transforms import transforms
from PIL import Image
import numpy as np
import random

In [3]:
seed = 2023
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True

In [4]:
with open('./flags.txt', 'r', encoding='utf8') as f:
    flags = f.read()

In [5]:
flags = flags.split('\n')
label_to_flag = dict(enumerate(flags))
flag_to_label = dict([(j, i) for i, j in label_to_flag.items()])

In [6]:
import os
import json

idx = [i[:-5] for i in os.listdir('./dataset/')]
labels = []
unknow = []
for p in os.listdir('./dataset/'):
    with open(f'./dataset/{p}', 'r', encoding='utf8') as f:
        label = json.load(f)
        label = [item[0] for item in label['flags'].items() if item[1] is True]
        if label[0] == 'unknow':
            unknow.append(p[:-5])
        else:
            labels.append(label[0])
idx = [i for i in idx if i not in unknow]
labels = [flag_to_label[i] - 1 for i in labels]

In [7]:
len(labels)

3553

In [8]:
from sklearn.model_selection import train_test_split

file_paths = ['./data4/'+i+'.png' for i in idx]
train_imgs, val_imgs, train_labels, val_labels = train_test_split(file_paths, labels, test_size=0.3, stratify=labels, random_state=2023)
val_imgs, test_imgs, val_labels, test_labels = train_test_split(val_imgs, val_labels, test_size=0.5, stratify=val_labels, random_state=2023)
# val_imgs, test_imgs, val_labels, test_labels = train_test_split(val_imgs, val_labels, test_size=0.5, random_state=2023)

In [9]:
# from collections import Counter

img_size = (112, 112)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
# trf = transforms.Compose([
#     transforms.Resize(img_size), 
#     transforms.ToTensor(),
#     transforms.Normalize(mean=norm_mean, std=norm_std)
# ])
trf_dict = {
    'train': transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomChoice([
            transforms.RandomRotation(degrees=(-10, 10)),
            transforms.RandomHorizontalFlip(p=0.5)
        ]),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)
    ]),
    'test': transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std)
    ])
}
class ImgDataset(Dataset):
    
    def __init__(self, img_paths, labels, trf):
        self.img_paths = img_paths
        self.labels = labels
        self.trf = trf
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        path = self.img_paths[idx]
        img = Image.open(path).convert('RGB')
        img = self.trf(img)
        target = self.labels[idx]
        target = torch.tensor(target, dtype=torch.long)
        
        return img, target

In [10]:
batch_size = 32

train_dataset = ImgDataset(train_imgs, train_labels, trf_dict['train'])
val_dataset = ImgDataset(val_imgs, val_labels, trf_dict['val'])
test_dataset = ImgDataset(test_imgs, test_labels, trf_dict['test'])

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size//2, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size//2, shuffle=False)

In [11]:
for img, label in train_loader:
    print(img.shape)
    print(label.shape)
    break

torch.Size([32, 3, 112, 112])
torch.Size([32])


In [12]:
class ImgClassifyModel(nn.Module):
    
    def __init__(self, class_num, pretrained=None):
        super().__init__()
        self.model = models.efficientnet_b5(pretrained=False)
        # self.model = models.efficientnet_b7(pretrained=False)
        if pretrained:
            self.model.load_state_dict(torch.load(pretrained))
        # self.model.classifier.add_module('3', nn.Linear(1000, class_num))
        self.model.classifier[1] = nn.Linear(2048, class_num)
        # self.model.classifier[1] = nn.Linear(2560, class_num)
    
    def forward(self, x):
        x = self.model(x)
        
        return x

class_num = 90
model_path = './model/efficientnet_b5_lukemelas-b6417697.pth'
# model_path = None
model = ImgClassifyModel(class_num=class_num, pretrained=model_path)

inputs = torch.zeros((32, 3, 112, 112))
outputs = model(inputs)
print(outputs.shape)

torch.Size([32, 90])


In [13]:
def eval_acc(y, pred_y):
    pred_y = pred_y.detach().argmax(dim=-1)
    acc = (y == pred_y).cpu().numpy()
    return acc.mean()

def train(epoch, model, iterator, optimizer, loss_fct, scheduler=None, device='cpu'):
    model.train()
    step = 0
    all_loss = 0
    all_acc = 0
    for img, label in iterator:
        step += 1
        img = img.to(device)
        label = label.to(device)
        
        pred = model(img)
        loss = loss_fct(pred, label)
        all_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()
        
        acc = eval_acc(label, pred)
        all_acc += acc
    
    print("Epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}".format(epoch, all_loss / step, all_acc / step))

def validate(epoch, model, iterator, loss_fct, device):
    model.eval()
    step = 0
    all_loss = 0
    all_acc = 0
    with torch.no_grad():
        for img, label in iterator:
            step += 1
            img = img.to(device)
            label = label.to(device)

            pred = model(img)
            loss = loss_fct(pred, label)
            all_loss += loss.item()

            acc = eval_acc(label, pred)
            all_acc += acc
    
    print("Epoch: {}, Val Loss: {:.4f}, Val Acc: {:.4f}".format(epoch, all_loss / step, all_acc / step))
    return model, all_loss / step, all_acc / step

In [14]:
%%time
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
epochs = 30
lr = 0.0005
# opt = torch.optim.Adam(model.parameters(), lr=lr)
opt = torch.optim.RMSprop(model.parameters(), alpha=0.9, momentum=0.9, lr=lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=6, gamma=0.1)
loss_func = nn.CrossEntropyLoss()

best_model = None
best_val_loss = 1e10
best_val_acc = 1e-10

print("start train=============================")
for epoch in range(1, epochs+1):
    train(epoch, model=model, iterator=train_loader, loss_fct=loss_func, optimizer=opt, device=device)
    candidate_model, loss, acc = validate(epoch, model=model, iterator=val_loader, loss_fct=loss_func, device=device)
    scheduler.step()
    
    if loss < best_val_loss and acc > best_val_acc:
        best_model = candidate_model
        best_val_loss, best_val_acc = loss, acc
    
    print("===========================================")
print("train finish=============================")

Epoch: 1, Train Loss: 4.1492, Train Acc: 0.1057
Epoch: 1, Val Loss: 5122851.4412, Val Acc: 0.0129
Epoch: 2, Train Loss: 3.3710, Train Acc: 0.1943
Epoch: 2, Val Loss: 113809.4559, Val Acc: 0.0129
Epoch: 3, Train Loss: 2.6929, Train Acc: 0.3267
Epoch: 3, Val Loss: 272.7396, Val Acc: 0.0129
Epoch: 4, Train Loss: 2.1453, Train Acc: 0.4472
Epoch: 4, Val Loss: 324.2922, Val Acc: 0.0500
Epoch: 5, Train Loss: 1.9089, Train Acc: 0.5061
Epoch: 5, Val Loss: 20.5827, Val Acc: 0.1971
Epoch: 6, Train Loss: 1.5227, Train Acc: 0.5879
Epoch: 6, Val Loss: 6.7400, Val Acc: 0.3460
Epoch: 7, Train Loss: 0.9506, Train Acc: 0.7318
Epoch: 7, Val Loss: 1.2755, Val Acc: 0.7048
Epoch: 8, Train Loss: 0.6234, Train Acc: 0.8305
Epoch: 8, Val Loss: 0.9924, Val Acc: 0.7471
Epoch: 9, Train Loss: 0.4625, Train Acc: 0.8657
Epoch: 9, Val Loss: 0.9051, Val Acc: 0.7548
Epoch: 10, Train Loss: 0.4208, Train Acc: 0.8744
Epoch: 10, Val Loss: 0.8533, Val Acc: 0.7621
Epoch: 11, Train Loss: 0.3195, Train Acc: 0.9050
Epoch: 11, Va

In [15]:
best_model.eval()
test_acc = 0
step = 0
with torch.no_grad():
    for img, label in test_loader:
        step += 1
        img = img.to(device)
        label = label.to(device)
        
        pred = best_model(img)
        
        acc = eval_acc(label, pred)
        test_acc += acc

test_acc / step

0.7838235294117647

In [16]:
best_val_acc

0.8231617647058824

In [21]:
# torch.save(best_model.state_dict(), "./model/model_0724.pt")

In [16]:
trf = transforms.Compose([
    transforms.Resize(img_size), 
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std)
])
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [75]:
test_img = Image.open('./10013.png').convert('RGB')
test_img = trf(test_img).unsqueeze(0)

In [76]:
test_img = test_img.to(device)
# model = ImgClassifyModel(class_num=class_num, pretrained=None)
# model.load_state_dict(torch.load('./model/model_0724.pt'))
# model = model.to(device)

In [77]:
model.eval()
pred = model(test_img)

In [78]:
pred.detach().argmax(dim=-1).cpu().numpy()[0]

26

In [79]:
label_to_flag

{0: 'unknow',
 1: '鱼',
 2: '鸟',
 3: '熊猫',
 4: '兔子',
 5: '猫',
 6: '鳄鱼',
 7: '狗',
 8: '猪',
 9: '狮子',
 10: '老虎',
 11: '鸡',
 12: '袋鼠',
 13: '猴子',
 14: '乌龟',
 15: '羊',
 16: '牛',
 17: '蝴蝶',
 18: '企鹅',
 19: '瓢虫',
 20: '鹿、梅花鹿',
 21: '长颈鹿',
 22: '骆驼',
 23: '照相机',
 24: '剪刀',
 25: '钥匙',
 26: '荡秋千',
 27: '螺丝刀',
 28: '铁轨',
 29: '锅',
 30: '鼠标',
 31: '书',
 32: '键盘',
 33: '订书机',
 34: '拱桥',
 35: '锅铲',
 36: '听诊器',
 37: '笔、钢笔',
 38: '冰箱',
 39: '足球',
 40: '轮胎',
 41: '手表',
 42: '斧头',
 43: '方向盘',
 44: '打印机',
 45: '台灯',
 46: '桌球',
 47: '排插',
 48: '梳子',
 49: '救护车',
 50: '电钻',
 51: '望远镜',
 52: '雨伞',
 53: '积木、城堡积木',
 54: '叉子',
 55: '井盖',
 56: '齿轮',
 57: '轮船、船',
 58: '袜子',
 59: '头盔',
 60: '火箭',
 61: '钱包',
 62: '红绿灯',
 63: '口红',
 64: '拉链',
 65: '计算器',
 66: '烟斗',
 67: '大巴车、公交车',
 68: '地球仪',
 69: '喷泉池',
 70: '牙刷、刷子',
 71: '气球',
 72: '手套',
 73: '水壶',
 74: '碗',
 75: '针筒',
 76: '苍蝇拍、电蚊拍',
 77: '羽毛球',
 78: '过山车',
 79: '桌子',
 80: '贝壳',
 81: '眼镜、带着眼镜的人',
 82: '轮椅',
 83: '帽子、带着帽子的人',
 84: '领带',
 85: '音响',
 86: '椅子',
 87: 