In [2]:
%cd /home/ircv3/HYU-2024-Embedded/jetracer/AlexNet_Road_Center_Detection

/home/ircv3/HYU-2024-Embedded/jetracer/AlexNet_Road_Center_Detection


In [110]:
import torch
import torchvision

def get_model(num_classes=2):
    model = torchvision.models.alexnet(num_classes=num_classes, dropout=0.0)
    return model

device = torch.device('cuda')
model = get_model(1)
model = model.to(device)

In [111]:
from cnn.center_dataset import CenterDataset
from torch.utils.data import random_split
import random
import numpy as np

# TODO
######################################
seed = 1027
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

determinstic = True # 안 됨
if determinstic:
    torch.backends.cudnn.determinstic = True
    torch.backends.cudnn.benchmark = False

batch_size = 16
######################################

dataset = CenterDataset('dataset/line', random_hflip=False, only_use_x=True)

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# 데이터셋 분할
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size]) # 3363, 420, 421
  # 메소드: .dataset, .indices

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=0,
    batch_size=batch_size,
    shuffle=True,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    num_workers=0,
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    num_workers=0,
    batch_size=batch_size,
    shuffle=True,
)

In [112]:
import torch.nn.functional as f

import ipywidgets
import datetime
import pytz
import os
import time

# 한국 시간 타임존 설정
korea_timezone = pytz.timezone('Asia/Seoul')
# 현재 한국 시간 구하기
korea_time = datetime.datetime.now(tz=korea_timezone)
# 포맷팅하여 timestamp 생성
timestamp = korea_time.strftime('%y%m%d_%H%M%S')

# TODO
#######################################
epoch = 32

# learning_rate = 1e-2
learning_rate = 2e-3
# learning_rate = 5e-4

reduction = 'sum'
# reduction = 'mean'
#######################################

checkpoint_dir = f'checkpoints/{timestamp}_b-{batch_size}_r-{reduction}'

epoch_slider = ipywidgets.IntSlider(description='Epochs', value=epoch, min=1, max=200, step=1)
lr_slider = ipywidgets.FloatSlider(description='lr', value=learning_rate, min=1e-4, max=1e-2, step=1e-4, readout_format='.4f')
train_button = ipywidgets.Button(description='Train', icon='tasks')
stop_button = ipywidgets.Button(description='Stop', icon='tasks')
loss_text = ipywidgets.Textarea(description='Progress', value='', rows=15, layout=ipywidgets.Layout(width="100%", height="auto"))
layout = ipywidgets.VBox([ipywidgets.HBox([epoch_slider, lr_slider, train_button, stop_button]), loss_text])

training_flag = True

min_train_loss = 99
min_val_loss = 99

def train_model(b):
    global epoch_slider, training_flag, checkpoint_dir
    for epoch in range(1, epoch_slider.value+1):
        if training_flag == False:
          break
        loss_text.value += "\n<<<<< Epoch {:d} >>>>>\n".format(epoch)
        time.sleep(1)
        train_step(epoch)
    with open(os.path.join(checkpoint_dir, 'log.txt'), 'w') as file:
      file.write(loss_text.value)
    print('Train is finished and log is saved!')

def train_stop(b):
    global training_flag
    training_flag = False
    print('<<< Train Stop! >>>')

def train_step(epoch):
    global model, lr_slider, loss_text, train_laoder, device, checkpoint_dir, training_flag, min_train_loss, min_val_loss, reduction

    try:
        # optimizer = torch.optim.Adam(model.parameters(), lr=lr_slider.value)
        # optimizer_name
        optimizer = torch.optim.SGD(model.parameters(), lr=lr_slider.value, momentum=0.9)
        
        loss_text.value += "<<<<< lr: {:f} >>>>>\n".format(optimizer.state_dict()['param_groups'][0]['lr'])

        train_button.disabled = True

        num_iters = len(train_loader)
        for ii, (images, labels) in enumerate(train_loader):

            if training_flag == False:
              break

            ii = ii+1

            model = model.train() # 적절한 dropout과 BN으로 변경

            # send data to device
            images = images.to(device)
            labels = labels.to(device)

            # zero gradients of parameters
            optimizer.zero_grad()

            # execute model to get outputs
            outputs = model(images)

            # compute MSE loss over x coordinates
            loss = f.mse_loss(outputs, labels, reduction=reduction)

            # run backpropogation to accumulate gradients
            loss.backward()

            # step optimizer to adjust parameters
            optimizer.step()
            
            if ii % 10 == 0 or ii == num_iters:
              if reduction == 'sum':
                train_loss = (loss.item() / labels.shape[0]) ** 0.5
              elif reduction == 'mean':
                train_loss = loss.item() ** 0.5
              msg = "[{:04d} / {:04d}] train_loss: {:.6f}".format(ii, num_iters, train_loss)
              loss_text.value += msg
                
              if train_loss < min_train_loss:
                msg = "-->min "
                loss_text.value += msg
                min_train_loss = train_loss

              # if ii % 100 == 0 or ii == num_iters:
              if ii == num_iters:
                  model.eval()

                  results_, labels_ = None, None
                  with torch.no_grad():
                    val_loss = 0
                    max_val_batch_loss = 0
                    for images, labels in val_loader:
                      images = images.to(device)
                      labels = labels.to(device)
                      results = model(images)
                      val_loss += f.mse_loss(results, labels, reduction='sum')
                      if max_val_batch_loss < val_loss:
                            results_, labels_ = results, labels
                            max_val_batch_loss = val_loss
                    val_loss = (val_loss / val_size) ** 0.5
                    
                  msg = " | val_loss: {:.6f}".format(val_loss)
                  loss_text.value += msg
                  
                  if val_loss < min_val_loss:
                    msg = "-->min "
                    loss_text.value += msg
                    min_val_loss = val_loss
      
                  r_list, l_list = [], []
                  for r, l in zip(results_, labels_):
                      r_list.append(int((r.item() / 2 + 0.5 ) * 960))
                      l_list.append(int((l.item() / 2 + 0.5 ) * 960))
                    
                  msg = "\n----- labels: {}\n---- results: {}".format(l_list, r_list)
                  loss_text.value += msg
                  
              loss_text.value += "\n"
              

    except Exception as e:
        print(e)
        pass

    model = model.eval()

    val_loss_zfill = f'{val_loss:.6f}'.replace('.', '').zfill(9)
    
    if epoch == 1:
        os.makedirs(checkpoint_dir, exist_ok=True)
    
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimzizer_state_dict': optimizer.state_dict(),
                'loss': val_loss
                },
               f'{checkpoint_dir}/epoch_{epoch}_loss_{val_loss_zfill}.pt')

    # torch.save(model.state_dict(), f'model/epoch_{epoch}.pth')
    # torch.save(model, f'model/epoch_{epoch}.pt')

    train_button.disabled = False

train_button.on_click(train_model)
stop_button.on_click(train_stop)

display(layout)



VBox(children=(HBox(children=(IntSlider(value=32, description='Epochs', max=200, min=1), FloatSlider(value=0.0…

Train is finished and log is saved!


In [None]:
# 저장한 모델 불러오기


기존 모델 validation

In [75]:
model = get_model()
model.load_state_dict(torch.load('/home/ircv3/HYU-2024-Embedded/jetracer/model/road_following_model.pth'))
model = model.to(device)

In [76]:
loss_text = ''
model.eval()

with torch.no_grad():
    val_loss = 0
    max_val_batch_loss = 0
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        results = model(images)[:, 0].unsqueeze(1)
        val_loss += f.mse_loss(results, labels, reduction='sum')
        if max_val_batch_loss < val_loss:
            results_, labels_ = results, labels
            max_val_batch_loss = val_loss
    val_loss = (val_loss / val_size) ** 0.5

    msg = " | val_loss: {:.6f}".format(val_loss)
    loss_text += msg

    r_list, l_list = [], []
    for r, l in zip(results_, labels_):
        r_list.append(int((r.item() / 2 + 0.5 ) * 960))
        l_list.append(int((l.item() / 2 + 0.5 ) * 960))

    msg = "\n----- labels: {}\n---- results: {}".format(l_list, r_list)
    loss_text += msg
    
print(loss_text)

 | val_loss: 0.342896
----- labels: [457, 330, 603, 473, 348, 648, 643, 832, 657, 810, 956, 958, 368, 148, 763, 1]
---- results: [451, 337, 615, 586, 389, 677, 719, 637, 637, 531, 940, 1006, 454, 161, 738, 135]


In [77]:
from cnn.center_dataset import CenterDataset
from torch.utils.data import random_split
import random
import numpy as np

# TODO
######################################
seed = 1027
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

determinstic = True # 안 됨
if determinstic:
    torch.backends.cudnn.determinstic = True
    torch.backends.cudnn.benchmark = False

batch_size = 16
######################################

dataset = CenterDataset('dataset/line', random_hflip=False, only_use_x=True)

night_loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=batch_size,
    shuffle=True,
)

In [101]:
model = get_model()
model.load_state_dict(torch.load('/home/ircv3/HYU-2024-Embedded/jetracer/model/road_following_model.pth'))
model = model.to(device)

In [102]:
loss_text = ''
model.eval()

with torch.no_grad():
    val_loss = 0
    max_val_batch_loss = 0
    for images, labels in night_loader:
        images = images.to(device)
        labels = labels.to(device)
        results = model(images)[:, 0].unsqueeze(1)
        val_loss += f.mse_loss(results, labels, reduction='sum')
        if max_val_batch_loss < val_loss:
            results_, labels_ = results, labels
            max_val_batch_loss = val_loss
    val_loss = (val_loss / val_size) ** 0.5

    msg = " | val_loss: {:.6f}".format(val_loss)
    loss_text += msg

    r_list, l_list = [], []
    for r, l in zip(results_, labels_):
        r_list.append(int((r.item() / 2 + 0.5 ) * 960))
        l_list.append(int((l.item() / 2 + 0.5 ) * 960))

    msg = "\n----- labels: {}\n---- results: {}".format(l_list, r_list)
    loss_text += msg
    
print(loss_text)

 | val_loss: 0.416119
----- labels: [679, 485, 16, 510, 644, 441, 241, 264, 559, 399, 504, 595, 952, 731, 330, 958]
---- results: [648, 522, 35, 442, 691, 361, 223, 107, 614, 356, 500, 600, 376, 766, 362, 855]
