### 初始化任务名字（影响数据读取和保存的路径）等

In [1]:
import torchvision.transforms as transforms
from dataset import ImageClassificationDataset

# Task name; it prefixes every path this notebook reads from or writes to.
TASK = 'test'

# Pixel size used for the image preview widgets.
width = 224
height = 224

rootDir = TASK + '/cross'
imagesDir = rootDir + '/images'

# Class labels handed to every dataset.
CATEGORIES = ['道路','十字路口','左三叉路口','右三叉路口','左右三叉路口']

# One dataset is created per name; its directory is imagesDir + '_' + name.
DATASETS = ['A', 'B']

# Per-image pipeline: colour jitter, resize to 224x224, convert to a tensor,
# then normalise each channel with the given mean and std.
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

datasets = {
    name: ImageClassificationDataset(imagesDir + '_' + name, CATEGORIES, TRANSFORMS)
    for name in DATASETS
}

### 显示收集的数据

In [2]:
import cv2
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget
from imagelist import ImageList
import threading
import time


# Load the recorded images for this task.
imagelist = ImageList(TASK + '/record')

# The first configured dataset is active until the dropdown selects another.
dataset = datasets[DATASETS[0]]

# Image preview widgets share the same pixel size.
_preview_size = dict(width=width, height=height)
camera_widget = ipywidgets.Image(**_preview_size)
snapshot_widget = ipywidgets.Image(**_preview_size)

# Selection widgets.
dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
category_widget = ipywidgets.Dropdown(options=CATEGORIES, description='分类')

# No observer is registered yet at this point, so the counter is seeded
# directly with the count for the initially selected category.
count_widget = ipywidgets.IntText(
    description='数量',
    value=dataset.get_count(category_widget.value),
)

# Sets the active dataset.
def set_dataset(change):
    """Switch the active dataset and refresh the displayed sample count.

    Args:
        change: change dict delivered by ``dataset_widget``; ``change['new']``
            is the key of the newly selected dataset in ``datasets``.
    """
    global dataset
    dataset = datasets[change['new']]
    # get_count() is consumed as a single value everywhere else in this
    # notebook; unpacking it into two names here broke the counter refresh.
    count_widget.value = dataset.get_count(category_widget.value)
dataset_widget.observe(set_dataset, names='value')

# Refresh the counter whenever another category is selected.
def update_counts(change):
    """Show how many samples the active dataset reports for the new category.

    Args:
        change: change dict delivered by ``category_widget``; ``change['new']``
            is the newly selected category.
    """
    selected_category = change['new']
    count_widget.value = dataset.get_count(selected_category)
category_widget.observe(update_counts, names='value')

# Store the current image under the selected category, then refresh the count.
def save(c):
    """Click handler: add the currently shown image to the active dataset.

    Args:
        c: the clicked button (unused).
    """
    current_image = imagelist.get_value()
    dataset.save_entry(current_image, category_widget.value)
    count_widget.value = dataset.get_count(category_widget.value)

# Button that adds the current image to the selected category.
add_button = ipywidgets.Button(description='添加图片到分类')
add_button.on_click(save)

# Play/pause button.
# playstat == 0: the playback loop keeps advancing through the images.
# playstat == 1: the playback loop keeps showing the current image.
# The button label names the action the next click performs.
playstat = 0
play_button = ipywidgets.Button(description='停止')

def playclick(c):
    """Click handler: toggle between playing and paused.

    Args:
        c: the clicked button (unused).
    """
    global playstat
    currently_paused = play_button.description == '播放'
    if currently_paused:
        play_button.description, playstat = '停止', 0
    else:
        play_button.description, playstat = '播放', 1
play_button.on_click(playclick)

def forward(c):
    """Click handler: advance the image list by the step shown in step_widget.

    Args:
        c: the clicked button (unused).
    """
    stride = step_widget.value
    imagelist.get_next(stride)

forward_button = ipywidgets.Button(description='前进')
forward_button.on_click(forward)

def back(c):
    """Click handler: move the image list back by the step shown in step_widget.

    Args:
        c: the clicked button (unused).
    """
    stride = step_widget.value
    imagelist.get_next(-stride)

back_button = ipywidgets.Button(description='后退')
back_button.on_click(back)

# Number of images the forward/back buttons skip per click.
step_widget = ipywidgets.IntText(value=1)

# The four playback controls share one row, each taking the same width.
for _control in (play_button, forward_button, step_widget, back_button):
    _control.layout.width = '15%'


def play(camera_widget, play_button, stop_event=None):
    """Playback loop: push a JPEG of the current image to the preview widget.

    Every 0.2 s the loop either advances to the next image (``playstat == 0``)
    or re-sends the current one (any other value), so that forward/back clicks
    still show up while paused.

    Args:
        camera_widget: image widget whose ``value`` receives the JPEG bytes.
        play_button: the play/pause button (unused; kept for compatibility).
        stop_event: optional ``threading.Event``; the loop ends once it is set.
            With ``None`` the loop runs forever, as before.
    """
    while stop_event is None or not stop_event.is_set():
        time.sleep(0.2)
        if playstat == 0:
            frame = imagelist.get_next()
        else:
            frame = imagelist.get_value()
        camera_widget.value = bgr8_to_jpeg(frame)


# Re-running this cell used to stack another endless loop on top of the old
# one, and each loop advances the image list. Stop the previous loop first.
_previous_stop_event = globals().get('play_stop_event')
if _previous_stop_event is not None:
    _previous_stop_event.set()
play_stop_event = threading.Event()

# Daemon thread: the playback loop must not keep the kernel process alive.
execute_thread = threading.Thread(
    target=play,
    args=(camera_widget, play_button, play_stop_event),
    daemon=True,
)
execute_thread.start()


        


# Preview image stacked above the row of playback controls.
_playback_controls = ipywidgets.HBox([play_button, back_button, step_widget, forward_button])
camera_collection_widget = ipywidgets.VBox([camera_widget, _playback_controls])

# Full data-collection panel: preview, dataset/category pickers, counter, add button.
_collection_children = [
    ipywidgets.HBox([camera_collection_widget]),
    dataset_widget,
    category_widget,
    count_widget,
    add_button,
]
data_collection_widget = ipywidgets.VBox(_collection_children)
display(data_collection_widget)


VBox(children=(HBox(children=(VBox(children=(Image(value=b'', height='224', width='224'), HBox(children=(Butto…

### 定义模型（加载与保存）

In [3]:
import torch
import torchvision

device = torch.device('cuda')

# The network produces one output per category of the active dataset.
output_dim = len(dataset.categories)

# Alternative backbones are kept below as comments; only one is active.

# ALEXNET
# model = torchvision.models.alexnet(pretrained=True)
# model.classifier[-1] = torch.nn.Linear(4096, output_dim)

# SQUEEZENET (active): swap the classifier's conv layer for one that has
# output_dim output channels.
model = torchvision.models.squeezenet1_1(pretrained=True)
model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)
model.num_classes = output_dim

# RESNET 18
# model = torchvision.models.resnet18(pretrained=True)
# model.fc = torch.nn.Linear(512, output_dim)

# RESNET 34
# model = torchvision.models.resnet34(pretrained=True)
# model.fc = torch.nn.Linear(512, output_dim)

# DENSENET 121
# model = torchvision.models.densenet121(pretrained=True)
# model.classifier = torch.nn.Linear(model.num_features, output_dim)

model = model.to(device)

# Text box holding the checkpoint path, plus the buttons that act on it.
model_path_widget = ipywidgets.Text(value=rootDir + '/model.pth', description='模型文件')
model_load_button = ipywidgets.Button(description='加载模型')
model_save_button = ipywidgets.Button(description='保存模型')

def load_model(c):
    """Click handler: load weights from the path in model_path_widget into the model.

    Args:
        c: the clicked button (unused).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path_widget.value)
    model.load_state_dict(state_dict)
model_load_button.on_click(load_model)
    
def save_model(c):
    """Click handler: write the model's state dict to the path in model_path_widget.

    Args:
        c: the clicked button (unused).
    """
    target_path = model_path_widget.value
    torch.save(model.state_dict(), target_path)
model_save_button.on_click(save_model)

# Checkpoint path on top, load/save buttons side by side underneath.
_model_buttons = ipywidgets.HBox([model_load_button, model_save_button])
model_widget = ipywidgets.VBox([model_path_widget, _model_buttons])



### 实时测试与训练

In [4]:
from utils import preprocess
import torch.nn.functional as F

# 'live' runs continuous inference on the current image; 'stop' ends it.
state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
prediction_widget = ipywidgets.Text(description='分析结果')

# One vertical 0..1 slider per category, in the dataset's category order.
score_widgets = [
    ipywidgets.FloatSlider(min=0.0, max=1.0, description=category, orientation='vertical')
    for category in dataset.categories
]
    
def live(state_widget, model, imagelist, prediction_widget):
    """Classify the current image repeatedly while the state toggle reads 'live'.

    About every 0.1 s the current image is preprocessed and run through the
    model. The softmax scores go to ``score_widgets`` and the highest-scoring
    category name goes to ``prediction_widget``. The loop ends as soon as
    ``state_widget.value`` is anything other than 'live'.

    Args:
        state_widget: widget whose ``value`` gates the loop.
        model: network called on the preprocessed image.
        imagelist: source of the current image via ``get_value()``.
        prediction_widget: text widget that receives the predicted category.
    """
    while state_widget.value == 'live':
        time.sleep(0.1)
        image = imagelist.get_value()
        preprocessed = preprocess(image)
        # Inference only: skip autograd bookkeeping so no graph is built
        # for every frame.
        with torch.no_grad():
            output = model(preprocessed)
            output = F.softmax(output, dim=1).cpu().numpy().flatten()
        category_index = output.argmax()
        prediction_widget.value = dataset.categories[category_index]
        for i, score in enumerate(list(output)):
            score_widgets[i].value = score
            
def start_live(change):
    """Observer: start a live-inference thread when the toggle switches to 'live'.

    Args:
        change: change dict delivered by ``state_widget``; only a new value of
            'live' starts a thread.
    """
    if change['new'] != 'live':
        return
    live_thread = threading.Thread(
        target=live,
        args=(state_widget, model, imagelist, prediction_widget),
    )
    live_thread.start()

state_widget.observe(start_live, names='value')

# Score sliders in a row, then the predicted label, then the stop/live toggle.
_score_row = ipywidgets.HBox(score_widgets)
live_execution_widget = ipywidgets.VBox([_score_row, prediction_widget, state_widget])





# Number of samples per DataLoader batch.
BATCH_SIZE = 8

# Adam with its default hyper-parameters; the SGD alternative is kept for reference.
optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Controls: number of epochs still to run and the two action buttons.
epochs_widget = ipywidgets.IntText(value=1, description='epochs（次数）')
train_button = ipywidgets.Button(description='训练')
eval_button = ipywidgets.Button(description='验证')

# Read-outs updated after every batch.
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='进度')
loss_widget = ipywidgets.FloatText(description='loss')
accuracy_widget = ipywidgets.FloatText(description='准确度')

def train_eval(is_training):
    """Run training epochs, or one evaluation pass, over the active dataset.

    Live inference is switched off and both buttons are disabled while the
    loop runs. After every batch the progress, mean per-sample loss and
    accuracy widgets are updated. When training, ``epochs_widget`` is
    decremented after each epoch until it reaches zero; when evaluating,
    exactly one pass is made. Afterwards, even if an error occurred, the model
    is put back in eval mode, the buttons are re-enabled and the state toggle
    is set to 'live'.

    Args:
        is_training: True to update the model weights, False to only measure
            loss and accuracy.
    """
    # Only `model` is rebound here; every other module-level name is just read.
    global model

    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        # Presumably gives a running live-inference thread time to notice
        # the 'stop' state and exit before the model is used here.
        time.sleep(1)

        model = model.train() if is_training else model.eval()

        while epochs_widget.value > 0:
            seen = 0
            sum_loss = 0.0
            error_count = 0
            for images, labels in train_loader:
                # send data to device
                images = images.to(device)
                labels = labels.to(device)

                # Track gradients only when the weights are being updated.
                with torch.set_grad_enabled(is_training):
                    if is_training:
                        # zero gradients of parameters
                        optimizer.zero_grad()

                    outputs = model(images)
                    loss = F.cross_entropy(outputs, labels)

                    if is_training:
                        # backpropagate, then adjust parameters
                        loss.backward()
                        optimizer.step()

                count = len(labels.flatten())
                error_count += int((outputs.argmax(1) != labels).sum())
                seen += count
                # cross_entropy returns the batch mean; weight it by the batch
                # size so dividing by the samples seen gives a true per-sample
                # mean instead of a value that shrinks with the batch size.
                sum_loss += float(loss) * count

                progress_widget.value = seen / len(dataset)
                loss_widget.value = sum_loss / seen
                accuracy_widget.value = 1.0 - error_count / seen

            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except Exception:
        # The original `except e:` raised NameError as soon as anything went
        # wrong, which left both buttons disabled. Report the error and let
        # the cleanup below restore the UI.
        import traceback
        traceback.print_exc()
    finally:
        model = model.eval()

        train_button.disabled = False
        eval_button.disabled = False
        state_widget.value = 'live'

train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
# Training panel: epoch counter, progress, metrics, then the action buttons.
_train_buttons = ipywidgets.HBox([train_button, eval_button])
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    _train_buttons,
])

# Whole UI: data collection beside live inference, then training, then model I/O.
_top_row = ipywidgets.HBox([data_collection_widget, live_execution_widget])
all_widget = ipywidgets.VBox([_top_row, train_eval_widget, model_widget])

display(all_widget)


VBox(children=(HBox(children=(VBox(children=(HBox(children=(VBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\…