In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms

import torch.optim as optim
import torchvision.models as models

import PIL.Image as Image

In [2]:
train_data = dset.ImageFolder(root="Gestures_all")

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [4]:
resnet18 = models.resnet18(pretrained=True)

In [5]:
# 图像预处理
image_size = (128, 128)  # 图像大小设为128×128
data_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图片
    transforms.Resize(image_size),  # 重设图像大小
    transforms.CenterCrop(image_size),  # 中间裁剪
    transforms.ToTensor(),  # 数据进转化为torch.FloadTensor，并归一化到[0, 1.0]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

In [6]:
train_data = dset.ImageFolder(root="Gestures_all",transform=data_transform)
# 数据集长度
totallen = len(train_data)
print('train data length:',totallen)

train data length: 120625


In [8]:
trainlen = int(totallen * 0.95)  # 训练集长度（总数据的0.7，2075张图片）
vallen = totallen - trainlen  # 验证集长度（总数据的0.3，890张图片）
train_db,val_db=torch.utils.data.random_split(train_data, [trainlen, vallen])  # 分离数据
print('train:', len(train_db), 'validation:', len(val_db))

# batch size，每个banch中训练样本的数量
bs = 32
# 训练集
train_loader = torch.utils.data.DataLoader(train_db, batch_size=bs, shuffle=True, num_workers=12)
# 验证集
val_loader = torch.utils.data.DataLoader(val_db, batch_size=bs, shuffle=True, num_workers=12)

train: 114593 validation: 6032


In [9]:
# 相关函数（计算预测正确个数）
def get_num_correct(out, labels):
    return out.argmax(dim=1).eq(labels).sum().item()

In [10]:
model = resnet18
n_classes = len(train_data.classes)
model.fc = nn.Linear(512, n_classes)

In [11]:
n_classes

24

In [12]:
import torch.nn.init as init

for name, module in model._modules.items():
    if(name=='fc'):
        # print(module.weight.shape)
        init.kaiming_uniform_(module.weight, a=0, mode='fan_in')

In [13]:
import datetime

In [15]:
optimizer=torch.optim.SGD(model.parameters(), lr=0.01)
model = model.to(device)
epoch_num = 30 # 训练次数
for epoch in range(epoch_num):
    start_time = datetime.datetime.now()
    total_loss=0
    total_correct=0
    val_correct=0
    model.train()
    for batch in train_loader:#GetBatch
        images,labels=batch
        outs=model(images.to(device))#PassBatch
        loss=F.cross_entropy(outs,labels.to(device))#CalculateLoss
        optimizer.zero_grad()
        loss.backward()#CalculateGradients
        optimizer.step()#UpdateWeights
        total_loss+=loss.item()
        total_correct+=get_num_correct(outs,labels.to(device))
    for batch in val_loader:
        images,labels=batch
        outs=model(images.to(device))
        val_correct+=get_num_correct(outs,labels.to(device))
    end_time = datetime.datetime.now()
    if epoch % 10 == 0:
        torch.save(model, 'gestures-120625-' + str(epoch) + 'train.pkl')
    print("loss:",total_loss,"train_correct:",total_correct/trainlen, "val_correct:",val_correct/vallen, "time:", str(end_time - start_time))

loss: 1936.6197801735252 train_correct: 0.836141823671603 val_correct: 0.9207559681697612 time: 1:10:04.413396
loss: 270.1901269913651 train_correct: 0.9784978140026005 val_correct: 0.9618700265251989 time: 1:07:04.984635
loss: 121.95029278902803 train_correct: 0.9911774715733073 val_correct: 0.9781167108753316 time: 1:23:29.135024


KeyboardInterrupt: 

In [16]:
torch.save(model, 'gestures-120625-3train.pkl')