# Attempts to create a pose estimation for spacecraft

## Naive approach using existing basic Linear model

Create a dataset per spacecraft.

In [2]:
# create dataset from train_labels.csv

import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import io
from skimage import io, transform
import random
# create dataset that contains the image label and data label
class SpacecraftPoseDataset(Dataset):
    """ Spacecraft pose dataset """

    def __init__(self, csv_file=None, root_dir=None, df=None, transform = None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        if df is not None:
            self.labels_frame = df
        else:
            self.labels_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()


        img_name = os.path.join(self.root_dir, self.labels_frame.iloc[idx, 0], f'{self.labels_frame.iloc[idx, 1]:03}.png')
        image = io.imread(img_name)
        label = self.labels_frame.iloc[idx, 1:]
        sample = {'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

df = pd.read_csv('data/train_labels.csv')
# split df by chain_id
spacecraft = []
# print(df.head())
for chain_id in df['chain_id'].unique():
    spacecraft.append(df[df['chain_id'] == chain_id])

# print(spacecraft)
for i in range(len(spacecraft)):
    spacecraft[i] = SpacecraftPoseDataset(df=spacecraft[i], root_dir='data/images')
    # print(spacecraft[i].__len__())
    # print(spacecraft[i].__getitem__(0))

# SpacecraftPoseDataset = SpacecraftPoseDataset(csv_file = 'data/train_labels.csv', root_dir = 'data/images')

# train_dataloader = DataLoader(SpacecraftPoseDataset, batch_size=64)

# split dataset
# train_set, test_set, unused_set = torch.utils.data.random_split(SpacecraftPoseDataset, [0.2, 0.05, 0.75])
print("num of spacecraft:",len(spacecraft))


num of spacecraft: 660


In [15]:
# use torch to make a Linear model

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch
import torchvision.transforms as transforms


class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 7),
        )

    def forward(self, x):
        # resize x to 3x1280x1024
        # x = x.reshape(3, 1280, 1024)
       
        # convert x to tensor
        x = torch.tensor(x, dtype=torch.float32, device='mps')
        # take the average of the 3 channels
        x = torch.mean(x, dim=-1)
        x = torch.mean(x, dim=-1)

        # print(x.shape)
        x = self.model(x)
        return x

# Create an instance of the custom CNN model
cnn_model = LinearModel()


In [20]:
# use torch to make a CNN model

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch
import torchvision.transforms as transforms


class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=10, stride=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=10, stride=5),
            nn.ReLU(),
        )
        self.model2 = nn.Sequential(
            nn.Linear(29184, 128),
            nn.ReLU(),
            nn.Linear(128, 7),
            nn.ReLU(),
        )

    def forward(self, x):
        # resize x to 3x1280x1024
        x = x.reshape(3, 1280, 1024)
        # convert x to tensor
        x = torch.tensor(x, dtype=torch.float32, device='mps')
        # print(x.shape)
        x = self.model(x)
        # x = x.reshape(7)
        x = self.flatten(x)
        x = x.reshape(29184)
        x = self.model2(x)  
        return x

# Create an instance of the custom CNN model
cnn_model = BasicCNN()


In [34]:
# use torch to make a simpler CNN model

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch
import torchvision.transforms as transforms


class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=20, stride=10),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=20, stride=10),
        )
        self.model2 = nn.Sequential(
            nn.Linear(3168, 128),
            nn.ReLU(),
            nn.Linear(128, 7),
            nn.ReLU(),
        )

    def forward(self, x):
        # resize x to 3x1280x1024
        x = x.reshape(3, 1280, 1024)
        # convert x to tensor
        x = torch.tensor(x, dtype=torch.float32, device='mps')
        # print(x.shape)
        x = self.model(x)
        # x = x.reshape(7)
        x = self.flatten(x)
        x = x.reshape(3168)
        x = self.model2(x)  
        return x

# Create an instance of the custom CNN model
cnn_model = BasicCNN()


In [16]:
import time
# dataset = SpacecraftPoseDataset
dataset = spacecraft

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.parameters(), lr=0.0001)

# Set the model to training mode
cnn_model.train()

# Train the model
# Set the number of epochs
num_epochs = 10

num_spacecrafts = 30

# Set the initial running loss
running_loss = 0.0

# Train the model for the specified number of epochs
# dont use dataloader, use dataset

# use m1 pro gpu to train
mps_device = torch.device('mps')
cnn_model.to(mps_device)
startTime = time.time()

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print("Time taken so far:", time.time() - startTime)
    for i in range(num_spacecrafts):
        print(f'Batch {i+1}/{num_spacecrafts}')
        i = random.randint(0, len(dataset) - 1)
        data = dataset[i]

        for j in range(len(data)):
            sample = data[j]
            image = sample['image']
            label = sample['label']

            # convert label to tensor
            label = torch.tensor(label[1:], dtype=torch.float32, device='mps')
            # print(label)
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = cnn_model(image)
            # print(outputs.shape)

            # Calculate the loss
            loss = criterion(outputs, label)

            # Backward pass
            loss.backward()

            # Optimize the weights
            optimizer.step()

            # Print the statistics
            running_loss += loss.item()




Epoch 1/10
Time taken so far: 0.0001780986785888672
Batch 1/30
Batch 2/30
Batch 3/30
Batch 4/30
Batch 5/30
Batch 6/30
Batch 7/30
Batch 8/30
Batch 9/30
Batch 10/30
Batch 11/30
Batch 12/30
Batch 13/30
Batch 14/30
Batch 15/30
Batch 16/30
Batch 17/30
Batch 18/30
Batch 19/30
Batch 20/30
Batch 21/30
Batch 22/30
Batch 23/30
Batch 24/30
Batch 25/30
Batch 26/30
Batch 27/30
Batch 28/30
Batch 29/30
Batch 30/30
Epoch 2/10
Time taken so far: 152.5457682609558
Batch 1/30
Batch 2/30
Batch 3/30
Batch 4/30
Batch 5/30
Batch 6/30
Batch 7/30
Batch 8/30
Batch 9/30
Batch 10/30
Batch 11/30
Batch 12/30
Batch 13/30
Batch 14/30
Batch 15/30
Batch 16/30
Batch 17/30
Batch 18/30
Batch 19/30
Batch 20/30
Batch 21/30
Batch 22/30
Batch 23/30
Batch 24/30
Batch 25/30
Batch 26/30
Batch 27/30
Batch 28/30
Batch 29/30
Batch 30/30
Epoch 3/10
Time taken so far: 300.20884704589844
Batch 1/30
Batch 2/30
Batch 3/30
Batch 4/30
Batch 5/30
Batch 6/30
Batch 7/30
Batch 8/30
Batch 9/30
Batch 10/30
Batch 11/30
Batch 12/30
Batch 13/30
Ba

In [17]:
# save the model
torch.save(cnn_model.state_dict(), 'linear_model.pth')

# convert to cpu
cnn_model.to('cpu')
# save the model
torch.save(cnn_model.state_dict(), 'linear_model_cpu.pth')


In [19]:
# test the model
# Set the model to evaluation mode
cnn_model.eval()

cnn_model.to(mps_device)

# Test the model
# Set the initial running loss
running_loss = 0.0

# run the model on the 5 random spacecrafts
for i in range(5):
    i = random.randint(0, len(dataset) - 1)
    sample = dataset[i]
    result_df = pd.DataFrame(columns=['chain_id','i',"x", "y", "z", "qw", "qx", "qy", "qz"])
    for j in range(len(sample)):
        sample = data[j]
        image = sample['image']
        label = sample['label']

        # convert label to tensor
        label = torch.tensor(label[1:], dtype=torch.float32, device=mps_device)
        # print(label)
        # Forward pass
        outputs = cnn_model(image)
        # print(outputs.detach().numpy())
        # add the result to the dataframe
        temp = outputs.cpu().detach().numpy()
        result_df = pd.concat([result_df, pd.DataFrame({'chain_id': sample['label'][0], 'i': j, 
                    'x': temp[0], 'y': temp[1], 'z': temp[2], 'qw': temp[3], 
                    'qx': temp[4], 'qy': temp[5], 'qz': temp[6]}, index=[0])], ignore_index=True)


        # Calculate the loss
        loss = criterion(outputs, label)

        # Print the statistics
        running_loss += loss.item()

print(result_df.head())


  chain_id  i            x          y            z           qw            qx  \
0        0  0  3392.112793 -53.273254 -4633.512695  2497.546143 -11770.953125   
1        1  1  3708.927979 -57.926975 -5068.404297  2731.128174 -12872.385742   
2        2  2  3506.833252 -55.030956 -4790.500000  2582.021729 -12169.142578   
3        3  3  4187.043457 -65.468628 -5722.483398  3083.154297 -14532.419922   
4        4  4  4553.538086 -71.295242 -6223.434570  3352.858154 -15804.304688   

            qy            qz  
0 -4489.433594  -9681.136719  
1 -4909.356934 -10587.912109  
2 -4641.238770 -10008.919922  
3 -5542.458008 -11953.802734  
4 -6027.540527 -13000.271484  
