In [1]:
#importowanie bibliotek
import cv2
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg
from jupyter_clickable_image_widget import ClickableImageWidget
from xy_dataset import XYDataset

In [2]:
# Read the x value encoded in the file name
def get_x(path, width):
    """Return the x label parsed from ``path``, centred and scaled by half of ``width``.

    Args:
        path: File name whose second ``_``-separated field is an integer x position.
        width: Image width used as the reference for centring and scaling.

    Returns:
        ``(x - width/2) / (width/2)`` as a float; 0.0 means the horizontal centre.
    """
    pixel_x = float(int(path.split("_")[1]))
    half_width = width / 2
    return (pixel_x - half_width) / half_width

# Read the y value encoded in the file name
def get_y(path, height):
    """Return the y label parsed from ``path``, centred and scaled by half of ``height``.

    Args:
        path: File name whose third ``_``-separated field is an integer y position.
        height: Image height used as the reference for centring and scaling.

    Returns:
        ``(y - height/2) / (height/2)`` as a float; 0.0 means the vertical centre.
    """
    pixel_y = float(int(path.split("_")[2]))
    half_height = height / 2
    return (pixel_y - half_height) / half_height

class XYDataset(torch.utils.data.Dataset):
    """Image dataset whose (x, y) regression target is encoded in each file name.

    Every ``*.jpg`` file found directly inside ``directory`` is one sample. The
    target is obtained by passing the file's base name to ``get_x`` / ``get_y``
    together with the image width / height.

    Args:
        directory: Folder that is searched for ``*.jpg`` files.
        random_hflips: When true, each sample is mirrored horizontally with
            probability 0.5 and its x target is negated accordingly. When
            false, no flipping is performed.
    """

    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        # Sorted so that the index -> file mapping does not depend on the order
        # in which the file system happens to list the directory.
        self.image_paths = sorted(glob.glob(os.path.join(self.directory, '*.jpg')))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment and normalise sample ``idx``.

        Returns:
            A tuple ``(image, target)`` where ``image`` is a normalised float
            tensor resized to 224x224 with its channel axis reversed, and
            ``target`` is a float tensor ``[x, y]``.
        """
        image_path = self.image_paths[idx]
        file_name = os.path.basename(image_path)

        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded image and guarantees three channels for
        # the 3-channel normalisation below.
        with PIL.Image.open(image_path) as source:
            image = source.convert('RGB')

        width, height = image.size
        x = float(get_x(file_name, width))
        y = float(get_y(file_name, height))

        # Honour the constructor flag: previously the flip was applied
        # regardless of ``random_hflips``. ``np.random.rand()`` yields a plain
        # scalar, avoiding the deprecated 1-element-array -> float conversion.
        if self.random_hflips and np.random.rand() > 0.5:
            image = transforms.functional.hflip(image)
            x = -x  # mirroring the image mirrors the horizontal target too

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # Reverse the channel axis (RGB -> BGR).
        # NOTE(review): presumably matches BGR camera frames at inference time — confirm.
        image = image.numpy()[::-1].copy()
        image = torch.from_numpy(image)
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        return image, torch.tensor([x, y]).float()
    
# Build the dataset from the 'ad_dataset_sala230' image folder; random_hflips is passed as False.
dataset = XYDataset(directory='ad_dataset_sala230', random_hflips=False)

In [3]:
# Split the dataset into a training subset and a held-out test subset.
test_percent = 0.1  # fraction of all samples reserved for the test subset
SPLIT_SEED = 0      # fixed seed so the random split is repeatable between runs

num_test = int(test_percent * len(dataset))
num_train = len(dataset) - num_test

# Without an explicit generator, random_split draws its permutation from
# PyTorch's global RNG. Seeding it here keeps the held-out subset (which is
# what the best checkpoint is selected on) the same across kernel restarts,
# provided the dataset enumerates its files in the same order.
torch.manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [num_train, num_test])

In [4]:
# Training batches: reshuffled at the start of every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

# Evaluation batches: kept in a fixed order. The reported test loss is a mean
# of per-batch means, so reshuffling would change which samples land in the
# (possibly smaller) last batch and add avoidable noise to the metric that is
# used to pick the best checkpoint.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0
)

In [5]:
# Create a torchvision ResNet-18, requesting pretrained weights (pretrained=True).
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed torchvision version.
model = models.resnet18(pretrained=True)

In [6]:
# Replace the final fully connected layer with a 2-output layer (x, y).
# The input width is read from the existing head instead of hard-coding 512,
# so the cell stays correct if the backbone is swapped and is safe to re-run.
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Use the GPU when one is available; otherwise fall back to the CPU instead
# of failing on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [7]:
NUM_EPOCHS = 100
BEST_MODEL_PATH = 'steering_model_sala230_100.pth'  # checkpoint with the lowest test loss so far
best_loss = 1e9  # sentinel larger than any expected loss

optimizer = optim.Adam(model.parameters())

# Fail before training starts rather than with a ZeroDivisionError after a
# full epoch when one of the loaders has no batches.
if len(train_loader) == 0 or len(test_loader) == 0:
    raise ValueError('train_loader and test_loader must each yield at least one batch')

for epoch in range(NUM_EPOCHS):

    # ---- optimisation pass over the training batches ----
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        train_loss += float(loss)
        loss.backward()
        optimizer.step()
    train_loss /= len(train_loader)  # mean of the per-batch losses

    # ---- evaluation pass over the held-out batches ----
    model.eval()
    test_loss = 0.0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch and reduces memory use.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            test_loss += float(F.mse_loss(outputs, labels))
    test_loss /= len(test_loader)  # mean of the per-batch losses

    print('%f, %f' % (train_loss, test_loss))

    # Keep only the weights that achieved the lowest test loss so far.
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss

0.481239, 0.280518
0.052267, 0.029087
0.042583, 0.053563
0.038394, 0.016149
0.024998, 0.010291
0.028118, 0.053651
0.029428, 0.017180
0.026256, 0.015141
0.023554, 0.028529
0.018301, 0.018918
0.016252, 0.020275
0.019747, 0.010506
0.011272, 0.011312
0.015047, 0.018193
0.016508, 0.012344
0.014238, 0.022125
0.014575, 0.018530
0.015859, 0.015652
0.010976, 0.009908
0.011024, 0.017384
0.009074, 0.015173
0.009046, 0.013238
0.008858, 0.014485
0.006909, 0.011378
0.006991, 0.014336
0.007324, 0.009767
0.005660, 0.012550
0.007806, 0.011604
0.010171, 0.015134
0.007556, 0.008562
0.006391, 0.009306
0.005975, 0.009602
0.005898, 0.014061
0.005439, 0.009432
0.005179, 0.008774
0.007113, 0.009600
0.005027, 0.007475
0.005001, 0.011622
0.006635, 0.008443
0.004586, 0.009061
0.008431, 0.011294
0.005325, 0.008909
0.006443, 0.012192
0.006359, 0.008450
0.005162, 0.013921
0.007076, 0.008460
0.004640, 0.013832
0.003775, 0.011598
0.003367, 0.008132
0.004805, 0.009527
0.005561, 0.011734
0.004535, 0.014050
0.004102, 0.