<a href="https://colab.research.google.com/github/uwaa-ndcl/ACC_2018_Avant/blob/master/resnet%26pose.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

# connect google drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## data cleaning

In [3]:
# dataset source: https://cvgl.stanford.edu/projects/objectnet3d/
import os, glob, scipy.io as sio
from tqdm import tqdm

root = '/content/drive/MyDrive/standford_pose_class_1'
ann_dir = os.path.join(root, 'annotation')
img_dir = os.path.join(root, 'image')

clean_rel_paths = []  # only keep the readable samples

for mat_path in tqdm(glob.glob(os.path.join(ann_dir, '*.mat'))):
    base = os.path.splitext(os.path.basename(mat_path))[0]
    try:
        sio.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
    except NotImplementedError:
        # v7.3 cannot be read, delete
        print(f"Deleting (v7.3): {mat_path}")
        os.remove(mat_path)
        # delete the corresponding pics（jpg/png）
        for ext in ['.jpg', '.png']:
            img_path = os.path.join(img_dir, base + ext)
            if os.path.exists(img_path):
                print(f"Deleting image: {img_path}")
                os.remove(img_path)
        continue
    except Exception:
        # other damages, delete
        print(f"Deleting (corrupt): {mat_path}")
        os.remove(mat_path)
        for ext in ['.jpg', '.png']:
            img_path = os.path.join(img_dir, base + ext)
            if os.path.exists(img_path):
                print(f"Deleting image: {img_path}")
                os.remove(img_path)
        continue

    # usable samples
    imgs = glob.glob(os.path.join(img_dir, base + '.*'))
    if imgs:
        clean_rel_paths.append(os.path.relpath(imgs[0], root))

print(f"✅ Usable samples: {len(clean_rel_paths)}")


100%|██████████| 530/530 [00:17<00:00, 30.39it/s] 

✅ Usable samples: 530





## dataloader

### dataset check

In [9]:
# see the structure of the dataset
import scipy.io as sio, glob, os, pprint, numpy as np

# root = 'class1'
mat_path = glob.glob(os.path.join(root, 'annotation', '*.mat'))[0]
rec  = sio.loadmat(mat_path)['record'][0][0]
obj  = rec['objects'][0][0]

vp = obj['viewpoint'][0][0]          # (1,1) → struct
print('viewpoint dtype names →')
pprint.pprint(vp.dtype.names)

# 也可打印具体数值看看
for k in vp.dtype.names:
    v = vp[k]
    # v 可能还是 ndarray，把标量值取出来
    try: v = float(v[0][0])
    except Exception: pass
    print(f'{k:15s} → {v}')




viewpoint dtype names →
('azimuth_coarse',
 'elevation_coarse',
 'azimuth',
 'elevation',
 'distance',
 'focal',
 'px',
 'py',
 'theta',
 'error',
 'interval_azimuth',
 'interval_elevation',
 'num_anchor',
 'viewport')
azimuth_coarse  → 0.0
elevation_coarse → 10.0
azimuth         → []
elevation       → []
distance        → 5.207271426214868
focal           → 1.0
px              → 185.5
py              → 180.0
theta           → 0.0
error           → []
interval_azimuth → []
interval_elevation → []
num_anchor      → 12.0
viewport        → 2000.0


### dataloader define&data augmentation

In [3]:
import os, glob, torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
import scipy.io as sio   # 读取 .mat
import numpy as np


def safe_angle(v_fine, v_coarse):
    """优先取 fine；若为空则用 coarse"""
    return float(v_fine[0][0]) if v_fine.size else float(v_coarse[0][0])


class PoseDataset(Dataset):
    def __init__(self, root, file_list, transform=None):
        """
        root       : 根目录 'class1'
        file_list  : 图像相对路径列表，例如 ['images/img0001.jpg', ...]
        transform  : torchvision transforms
        """
        self.root = root
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_rel_path = self.file_list[idx]
        img_path = os.path.join(self.root, img_rel_path)
        img = Image.open(img_path).convert('RGB')

        base = os.path.splitext(os.path.basename(img_rel_path))[0]
        mat_path = os.path.join(self.root, 'annotation', f'{base}.mat')

        rec = sio.loadmat(mat_path)['record'][0][0]
        obj = rec['objects'][0][0]
        vp  = obj['viewpoint'][0][0]

        # ---------- 角度读取 ----------
        yaw   = safe_angle(vp['azimuth'],   vp['azimuth_coarse'])
        pitch = safe_angle(vp['elevation'], vp['elevation_coarse'])
        roll  = float(vp['theta'][0][0])                 # theta 始终有值

        pose = torch.tensor([yaw, pitch, roll], dtype=torch.float32)

        if self.transform:
            img = self.transform(img)

        return img, pose


In [4]:
from sklearn.model_selection import train_test_split
import random, numpy as np


all_imgs = sorted(glob.glob(os.path.join(root, 'image', '*.*')))          # 绝对路径
all_imgs = [os.path.relpath(p, root) for p in all_imgs]                    # 变为相对路径

# 固定随机种子便于复现
random.seed(42); np.random.seed(42)

train_imgs, tmp = train_test_split(all_imgs, test_size=0.2, random_state=42)
val_imgs, test_imgs = train_test_split(tmp, test_size=0.5, random_state=42)

print(f'Train: {len(train_imgs)}, Val: {len(val_imgs)}, Test: {len(test_imgs)}')


Train: 424, Val: 53, Test: 53


In [5]:
BATCH_SIZE = 32
IMG_SIZE   = 224  # ResNet-50 标准输入

train_tf = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

test_tf = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

train_ds = PoseDataset(root, train_imgs, transform=train_tf)
val_ds   = PoseDataset(root, val_imgs,   transform=test_tf)
test_ds  = PoseDataset(root, test_imgs,  transform=test_tf)

from torch.utils.data import DataLoader
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=4)




viewpoint dtype names →
('azimuth_coarse',
 'elevation_coarse',
 'azimuth',
 'elevation',
 'distance',
 'focal',
 'px',
 'py',
 'theta',
 'error',
 'interval_azimuth',
 'interval_elevation',
 'num_anchor',
 'viewport')
azimuth_coarse  → 0.0
elevation_coarse → 10.0
azimuth         → []
elevation       → []
distance        → 5.207271426214868
focal           → 1.0
px              → 185.5
py              → 180.0
theta           → 0.0
error           → []
interval_azimuth → []
interval_elevation → []
num_anchor      → 12.0
viewport        → 2000.0


## training

### define network

In [7]:
import torch, torch.nn as nn
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 3)   # yaw(left-right rotation), pitch(up down rotation), roll(forward-backward rotation)
model = model.to(device)


In [8]:
from tqdm import tqdm
import torch.optim as optim

EPOCHS = 15
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

def evaluate(loader):
    model.eval()
    total_loss, n = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = torch.sqrt(criterion(pred, y))  # RMSE
            total_loss += loss.item() * x.size(0)
            n += x.size(0)
    return total_loss / n

best_val = float('inf')
for epoch in range(1, EPOCHS+1):
    model.train()
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{EPOCHS}')
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        pbar.set_postfix(train_loss=loss.item())

    val_rmse = evaluate(val_loader)
    print(f'⚡️  Val RMSE: {val_rmse:.3f}')
    if val_rmse < best_val:
        best_val = val_rmse
        torch.save(model.state_dict(), 'best_pose_resnet50.pth')
        print('🔖  Model saved.')


Epoch 1/15: 100%|██████████| 14/14 [00:08<00:00,  1.61it/s, train_loss=814]


⚡️  Val RMSE: 26.367
🔖  Model saved.


Epoch 2/15: 100%|██████████| 14/14 [00:07<00:00,  1.87it/s, train_loss=1.83e+3]


⚡️  Val RMSE: 25.198
🔖  Model saved.


Epoch 3/15: 100%|██████████| 14/14 [00:06<00:00,  2.24it/s, train_loss=686]


⚡️  Val RMSE: 24.291
🔖  Model saved.


Epoch 4/15: 100%|██████████| 14/14 [00:08<00:00,  1.72it/s, train_loss=854]


⚡️  Val RMSE: 24.098
🔖  Model saved.


Epoch 5/15: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s, train_loss=1.16e+3]


⚡️  Val RMSE: 23.361
🔖  Model saved.


Epoch 6/15: 100%|██████████| 14/14 [00:06<00:00,  2.17it/s, train_loss=1.64e+3]


⚡️  Val RMSE: 23.958


Epoch 7/15: 100%|██████████| 14/14 [00:08<00:00,  1.70it/s, train_loss=26.9]


⚡️  Val RMSE: 23.963


Epoch 8/15: 100%|██████████| 14/14 [00:08<00:00,  1.66it/s, train_loss=234]


⚡️  Val RMSE: 23.761


Epoch 9/15: 100%|██████████| 14/14 [00:08<00:00,  1.70it/s, train_loss=591]


⚡️  Val RMSE: 23.773


Epoch 10/15: 100%|██████████| 14/14 [00:07<00:00,  1.93it/s, train_loss=1.03e+3]


⚡️  Val RMSE: 23.482


Epoch 11/15: 100%|██████████| 14/14 [00:06<00:00,  2.00it/s, train_loss=138]


⚡️  Val RMSE: 23.942


Epoch 12/15: 100%|██████████| 14/14 [00:07<00:00,  1.77it/s, train_loss=176]


⚡️  Val RMSE: 23.476


Epoch 13/15: 100%|██████████| 14/14 [00:06<00:00,  2.25it/s, train_loss=235]


⚡️  Val RMSE: 23.794


Epoch 14/15: 100%|██████████| 14/14 [00:07<00:00,  1.94it/s, train_loss=427]


⚡️  Val RMSE: 24.073


Epoch 15/15: 100%|██████████| 14/14 [00:06<00:00,  2.22it/s, train_loss=359]


⚡️  Val RMSE: 24.168


In [10]:
model.load_state_dict(torch.load('best_pose_resnet50.pth'))
test_rmse = evaluate(test_loader)
print(f'🎯  Test RMSE (deg): {test_rmse:.3f}')




🎯  Test RMSE (deg): 23.049
