In [4]:
import os
import random
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models


In [5]:
# 1. Dataset
class SiameseDataset(Dataset):
    anchor_dir = r"C:\Users\yun\Desktop\write\an"
    def __init__(self, anchor_dir, other_dir, transform=None, positive_pairs=10, negative_pairs=10):
        #
        self.anchor_images = [os.path.join(anchor_dir, f) for f in os.listdir(anchor_dir)]
        self.other_images = [os.path.join(other_dir, f) for f in os.listdir(other_dir)]
        self.transform = transform
        self.pairs = []
        # positive pairs (same class)
        for i in range(positive_pairs):
            img1, img2 = random.sample(self.anchor_images, 2)
            self.pairs.append((img1, img2, 1))
        # negative pairs
        for i in range(negative_pairs):
            img1 = random.choice(self.anchor_images)
            img2 = random.choice(self.other_images)
            self.pairs.append((img1, img2, 0))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        path1, path2, label = self.pairs[idx]
        img1 = Image.open(path1).convert('RGB')
        img2 = Image.open(path2).convert('RGB')
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, torch.tensor([label], dtype=torch.float32)


In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [49]:

# 3. Model
class ResNetEmbedder(nn.Module):
    def __init__(self):
        super(ResNetEmbedder, self).__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # 移除最後分類層
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, 128)  # 特徵向量降維
    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    
# 4. Loss\ n
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, out1, out2, label):
        # out shape: (B, D), label shape: (B,1)
        dist = nn.functional.pairwise_distance(out1, out2)
        loss_pos = label * dist.pow(2)
        loss_neg = (1 - label) * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return (loss_pos + loss_neg).mean()


In [47]:

# 5. Training loop
def train(anchor_dir, other_dir, epochs=20, batch_size=8, lr=1e-4):
    dataset = SiameseDataset(anchor_dir, other_dir, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNetEmbedder().to(device)
    criterion = ContrastiveLoss(margin=1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    losses = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for img1, img2, label in loader:
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            out1 = model(img1)
            out2 = model(img2)
            loss = criterion(out1, out2, label)
            optimizer.zero_grad()
            label = label * 2 - 1
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(loader):.4f}")

    torch.save(model.state_dict(), 'resnet_siamese.pth')

    # 儲存 loss 到檔案
    with open('train_loss.txt', 'w') as f:
        for l in losses:
            f.write(f"{l}\n")

    return model
    


In [61]:

# 6. Inference
def similarity(model, img_path1, img_path2):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    
    img1 = Image.open(img_path1).convert('RGB')
    img2 = Image.open(img_path2).convert('RGB')
    
    img1 = transform(img1).unsqueeze(0).to(device)
    img2 = transform(img2).unsqueeze(0).to(device)

    with torch.no_grad():
        feat1 = model(img1)
        feat2 = model(img2)
        cos_sim = nn.functional.cosine_similarity(feat1, feat2)
        return (cos_sim.item() * 100 )  # 轉成百分比
    

def load_model(model_path='resnet_siamese.pth'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNetEmbedder().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


In [50]:
anchor_dir = r'C:\Users\yun\Desktop\write\an'
other_dir = r'C:\Users\yun\Desktop\write\xxx'
model = train(anchor_dir, other_dir, epochs=500)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\yun/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:01<00:00, 96.6MB/s]


Epoch 1/500, Loss: 0.6192
Epoch 2/500, Loss: 0.2673
Epoch 3/500, Loss: 0.2255
Epoch 4/500, Loss: 0.2563
Epoch 5/500, Loss: 0.2528
Epoch 6/500, Loss: 0.2444
Epoch 7/500, Loss: 0.2501
Epoch 8/500, Loss: 0.2542
Epoch 9/500, Loss: 0.2510
Epoch 10/500, Loss: 0.2193
Epoch 11/500, Loss: 0.2197
Epoch 12/500, Loss: 0.2163
Epoch 13/500, Loss: 0.2133
Epoch 14/500, Loss: 0.1747
Epoch 15/500, Loss: 0.2309
Epoch 16/500, Loss: 0.2073
Epoch 17/500, Loss: 0.2118
Epoch 18/500, Loss: 0.2527
Epoch 19/500, Loss: 0.2520
Epoch 20/500, Loss: 0.2310
Epoch 21/500, Loss: 0.2545
Epoch 22/500, Loss: 0.2443
Epoch 23/500, Loss: 0.2524
Epoch 24/500, Loss: 0.2438
Epoch 25/500, Loss: 0.2415
Epoch 26/500, Loss: 0.2032
Epoch 27/500, Loss: 0.2085
Epoch 28/500, Loss: 0.2255
Epoch 29/500, Loss: 0.2520
Epoch 30/500, Loss: 0.2329
Epoch 31/500, Loss: 0.2079
Epoch 32/500, Loss: 0.2471
Epoch 33/500, Loss: 0.2534
Epoch 34/500, Loss: 0.2298
Epoch 35/500, Loss: 0.2459
Epoch 36/500, Loss: 0.2278
Epoch 37/500, Loss: 0.2437
Epoch 38/5

In [91]:
if __name__ == '__main__':
    
    score = similarity(model, 
                       r'C:\Users\yun\Desktop\write\an\螢幕擷取畫面 2025-05-24 234703.png',
                       r'C:\Users\yun\Desktop\write\dba770c3-97b3-49e4-bae9-767df7eeee7a.jpg')
    print(f"相似度: {score:.2f}%")

相似度: 81.02%
