In [None]:
# download data from my personal drive (these links should work for everyone)
%%shell
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1NPUFI13qBHFUMjs-XxFFMIJ5fuBtc2cg' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1NPUFI13qBHFUMjs-XxFFMIJ5fuBtc2cg" -O UR5_images.zip && rm -rf /tmp/cookies.txt
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1_xY4bdY_JRHJ_R0g7gtvGDvYE1W5YNS1' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1_xY4bdY_JRHJ_R0g7gtvGDvYE1W5YNS1" -O UR5_positions.csv && rm -rf /tmp/cookies.txt
unzip UR5_images.zip
mv UR5_images/ images/

In [52]:
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import os
from skimage import io, transform
from torchvision.transforms import v2
from matplotlib import pyplot as plt

In [17]:
image_labels = pd.read_csv('UR5_positions.csv')
image_labels = image_labels[['ImageID','Joint1','Joint2','Joint3','Joint4','Joint5']]

In [19]:
image_labels

Unnamed: 0,ImageID,Joint1,Joint2,Joint3,Joint4,Joint5
0,0,2.093517,0.906312,2.523018,-1.073683,-2.068640
1,1,-2.349126,0.328547,2.794923,2.421446,0.232648
2,2,-0.796445,-1.380814,-0.211948,0.113720,1.000168
3,3,2.044282,-1.948778,0.550287,0.796973,0.394149
4,4,-1.588074,2.281538,3.139811,1.151248,2.943370
...,...,...,...,...,...,...
9995,9995,2.302785,0.077975,1.670275,2.402927,0.217893
9996,9996,-1.163020,2.925114,-1.184967,2.837455,-1.712592
9997,9997,-0.516064,-0.707306,-2.761259,0.284029,2.103629
9998,9998,1.783519,0.833506,-1.365241,0.614984,1.289216


In [181]:
class RobotImageDataset(Dataset):
    def __init__(self, csv_file: str, root_dir: str):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
        """
        self.annotations = pd.read_csv(csv_file)
        self.annotations = self.annotations[['ImageID','Joint1','Joint2','Joint3','Joint4','Joint5']]
        self.root_dir = root_dir
        self.transform = v2.Compose(
            [
                v2.ToImage(),
                v2.Resize(size=(224, 224), antialias=True),
                v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                v2.Grayscale(),
                v2.ToTensor()
            ])

    def process_sample(self, top, side):
        top = top[:,:,:3]
        side = side[:,:,:3]
        top_tens = self.transform(top)
        side_tens = self.transform(side)

        return(torch.cat([side_tens,top_tens],dim = 0))
    def __len__(self):
        return(len(self.annotations))

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        side_image_path = os.path.join(self.root_dir, 'side_view_'+ str(idx) + '.png')
        top_image_path = os.path.join(self.root_dir, 'top_view_'+str(idx) + '.png')

        side_image = io.imread(side_image_path)
        top_image = io.imread(top_image_path)

        im_data = self.process_sample(top_image, side_image)

        joint_values = self.annotations.iloc[idx, 1:].to_numpy(dtype=float)

        sample = {'images': im_data, 'joint_values': joint_values}


        return sample


In [182]:
robotdata = RobotImageDataset('/content/UR5_positions.csv', '/content/images')
dataloader = DataLoader(robotdata, batch_size=64,
                        shuffle=True)



In [None]:
for x in dataloader:
  print(x['images'].shape)
  input(x['joint_values'].shape)

In [188]:
from torchvision.models.resnet import ResNet18_Weights
from torchvision.models import resnet18
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


weights = ResNet18_Weights.DEFAULT
resnet_pll = resnet18(weights=weights)
for idx,child in enumerate(resnet_pll.children()):
  if idx == 9:
    print(f'last layer: {child}')
  else:
    for param in child.parameters():
      param.requires_grad = False
resnet_pll.fc = nn.Sequential(nn.Linear(in_features=512, out_features=256,bias=True),
                              nn.ReLU(),
                              nn.Linear(in_features=256, out_features = 5, bias=True))

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 63.2MB/s]

last layer: Linear(in_features=512, out_features=1000, bias=True)





In [191]:
# Check if GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Params
LEARNING_RATE = 1e-4

# Initialize
resnet_pll.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(resnet_pll.parameters(), lr=LEARNING_RATE)



# Define training loop
def train_one_epoch():
    running_loss = 0.0
    train_items = 0
    train_correct = 0

    for i, data in enumerate(dataloader):
        X = data['images'].float()
        y = data['joint_values'].float()
        X, y = X.to(device), y.to(device)
        preds = resnet_pll(X).float()
        loss = criterion(preds, y)
        running_loss +=loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss = running_loss/(i+1)

    return(train_loss)


cpu


In [192]:
# Train
import time

EPOCHS = 15
timestr = time.strftime("%Y-%m-%d_%H-%M-%S")
model_dir = os.path.join('models/resnet18', timestr)
for i in range(EPOCHS):
    training_loss = train_one_epoch()
    print(f"Epoch {i} | train loss: {training_loss}")

    model_folder_name = f'epoch_{i:04d}_loss_{training_loss:.8f}'
    if not os.path.exists(os.path.join(model_dir, model_folder_name)):
      os.makedirs(os.path.join(model_dir, model_folder_name))
    torch.save(resnet_pll.state_dict(), os.path.join(model_dir, model_folder_name, 'model_state_dict.pth'))

RuntimeError: ignored