In [5]:
import torch
import torch.optim as optim

import torch.nn as nn
import torch.nn.functional as F



class CNNPointDetector(nn.Module):
    def __init__(self):
        super(CNNPointDetector, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 2)  # Output (x, y) coordinates
        
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = x.view(-1, 64 * 8 * 8)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize model
model = CNNPointDetector()
# 初始化模型
model = CNNPointDetector()
model.load_state_dict(torch.load('fan_corner_detector_1127.pth'))  # 加载模型参数
model.train()

  model.load_state_dict(torch.load('fan_corner_detector_1127.pth'))  # 加载模型参数


CNNPointDetector(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)

In [7]:
import pandas as pd
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class FanDataset(Dataset):
    def __init__(self, annotations_file, img_size=64):
        self.annotations = pd.read_csv(annotations_file)
        self.img_size = img_size
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        img_path = self.annotations.iloc[idx, 0]
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (self.img_size, self.img_size)) / 2047.0  # Resize and normalize
        image = image.reshape(1, self.img_size, self.img_size)  # Reshape for CNN input

        # Label: x and y coordinates of the corner
        label = self.annotations.iloc[idx, 1:3].values.astype(np.float32)
        return torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

# 使用自己的标注文件
dataset = FanDataset('selected_points_1127.csv')
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [8]:
import torch.optim as optim

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 700
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}", 'about',(num_epochs-epoch)/1.5,'min left')

Epoch [1/700], Loss: 20193.6994 about 466.6666666666667 min left
Epoch [2/700], Loss: 24080.6989 about 466.0 min left
Epoch [3/700], Loss: 16553.3662 about 465.3333333333333 min left
Epoch [4/700], Loss: 4518.9365 about 464.6666666666667 min left
Epoch [5/700], Loss: 14729.3285 about 464.0 min left
Epoch [6/700], Loss: 3229.2860 about 463.3333333333333 min left
Epoch [7/700], Loss: 7319.6159 about 462.6666666666667 min left
Epoch [8/700], Loss: 4325.8805 about 462.0 min left
Epoch [9/700], Loss: 3686.9128 about 461.3333333333333 min left
Epoch [10/700], Loss: 2961.5212 about 460.6666666666667 min left
Epoch [11/700], Loss: 2069.2524 about 460.0 min left
Epoch [12/700], Loss: 1841.2236 about 459.3333333333333 min left
Epoch [13/700], Loss: 1373.3524 about 458.6666666666667 min left
Epoch [14/700], Loss: 1458.4600 about 458.0 min left
Epoch [15/700], Loss: 1570.8805 about 457.3333333333333 min left
Epoch [16/700], Loss: 1576.8704 about 456.6666666666667 min left
Epoch [17/700], Loss: 129

In [None]:
torch.save(model.state_dict(), './fan_corner_detector_112-f.pth')