In [84]:
import os
import shutil
from sklearn.model_selection import train_test_split

# 폴더 경로 설정
base_dir = './Fewshot'
categories = ['Full', 'Part', 'Box', 'ETC']
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')

# train 및 val 디렉토리 생성
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

for category in categories:
    os.makedirs(os.path.join(train_dir, category), exist_ok=True)
    os.makedirs(os.path.join(val_dir, category), exist_ok=True)

# 데이터를 학습 및 검증 세트로 분할
for category in categories:
    category_dir = os.path.join(base_dir, category)
    images = os.listdir(category_dir)
    train_images, val_images = train_test_split(images, test_size=0.2, random_state=42)
    
    for img in train_images:
        shutil.copy(os.path.join(category_dir, img), os.path.join(train_dir, category, img))
    
    for img in val_images:
        shutil.copy(os.path.join(category_dir, img), os.path.join(val_dir, category, img))


In [85]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import numpy as np

class FewShotDataset(Dataset):
    def __init__(self, root_dir, transform=None, num_support=5, num_query=15):
        self.dataset = datasets.ImageFolder(root_dir, transform=transform)
        self.num_support = num_support
        self.num_query = num_query
        self.transform = transform
        self.class_indices = self._group_images_by_class()

    def _group_images_by_class(self):
        class_indices = {}
        for idx, (img_path, label) in enumerate(self.dataset.imgs):
            if label not in class_indices:
                class_indices[label] = []
            class_indices[label].append(img_path)
        return class_indices

    def __len__(self):
        return len(self.dataset.classes)

    def __getitem__(self, idx):
        class_label = list(self.class_indices.keys())[idx]
        image_paths = self.class_indices[class_label]

        # 이미지 수가 부족한 경우를 처리
        if len(image_paths) < self.num_support + self.num_query:
            # 부족한 경우 반복해서 채웁니다.
            selected_images = np.random.choice(image_paths, self.num_support + self.num_query, replace=True)
        else:
            selected_images = np.random.choice(image_paths, self.num_support + self.num_query, replace=False)

        support_images = selected_images[:self.num_support]
        query_images = selected_images[self.num_support:]

        support_set = [self.dataset.loader(img_path) for img_path in support_images]
        query_set = [self.dataset.loader(img_path) for img_path in query_images]

        if self.transform:
            support_set = [self.transform(img) for img in support_set]
            query_set = [self.transform(img) for img in query_set]

        support_labels = [class_label] * self.num_support
        query_labels = [class_label] * self.num_query

        return torch.stack(support_set), torch.tensor(support_labels), torch.stack(query_set), torch.tensor(query_labels)


In [86]:
# 변환 정의
# 변환 정의
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# 데이터셋 로드
train_dir = 'Fewshot/train'
val_dir = 'Fewshot/val'

num_support = 5
num_query = 15

train_dataset = FewShotDataset(train_dir, transform=transform, num_support=num_support, num_query=num_query)
val_dataset = FewShotDataset(val_dir, transform=transform, num_support=num_support, num_query=num_query)

# 클래스 개수 가져오기
num_classes = len(train_dataset.dataset.classes)

# 데이터 로더 생성
train_loader = DataLoader(train_dataset, batch_size=num_support * num_classes, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=num_support * num_classes, shuffle=False)


In [87]:
import torch.nn as nn
import torch.optim as optim
from torchvision import models

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.resnet = models.resnet50(pretrained=True)  # ResNet50 사용
        self.resnet.fc = nn.Identity()  # 마지막 FC 레이어 제거

    def forward(self, x):
        return self.resnet(x)

feature_extractor = FeatureExtractor()

class PrototypicalNetwork(nn.Module):
    def __init__(self, feature_extractor, num_classes, num_support):
        super(PrototypicalNetwork, self).__init__()
        self.feature_extractor = feature_extractor
        self.num_classes = num_classes
        self.num_support = num_support

    def forward(self, support_images, query_images):
        # 지원 세트에서 특징 추출
        support_images = support_images.view(-1, *support_images.size()[2:])  # [num_classes * num_support, channels, height, width]
        support_embeddings = self.feature_extractor(support_images)
        
        # 디버그: support_embeddings의 크기 출력
        print(f"support_embeddings shape: {support_embeddings.shape}")
        
        # support_embeddings의 모양: [num_support * num_classes, embedding_size]
        num_support_samples = self.num_support * self.num_classes
        embedding_size = support_embeddings.size(-1)
        
        if support_embeddings.size(0) != num_support_samples:
            raise ValueError(f"Expected {num_support_samples} support samples, but got {support_embeddings.size(0)}")
        
        support_embeddings = support_embeddings.view(self.num_classes, self.num_support, embedding_size)

        # 프로토타입 계산
        prototypes = support_embeddings.mean(dim=1)

        # 질의 이미지에서 특징 추출
        query_images = query_images.view(-1, *query_images.size()[2:])  # [num_query * num_classes, channels, height, width]
        query_embeddings = self.feature_extractor(query_images)
        
        # 유클리드 거리 계산
        dists = torch.cdist(query_embeddings, prototypes)
        
        # 최솟값 인덱스를 사용하여 클래스 예측
        return dists





num_classes = len(train_dataset.dataset.classes)
num_support = 5  # num_support 값을 설정
model = PrototypicalNetwork(feature_extractor, num_classes, num_support)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\sunwoong/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 45.0MB/s]


In [88]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for support_images, support_labels, query_images, query_labels in train_loader:
        support_images = support_images.squeeze(0)
        query_images = query_images.squeeze(0)
        support_labels = support_labels.squeeze(0)
        query_labels = query_labels.squeeze(0)

        optimizer.zero_grad()
        outputs = model(support_images, query_images)

        query_labels = query_labels.long()
        query_labels = query_labels.view(-1)
        outputs = outputs.view(-1, outputs.size(-1))

        loss = criterion(outputs, query_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    scheduler.step()  # 학습률 조정

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}')

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels in val_loader:
            support_images = support_images.squeeze(0)
            query_images = query_images.squeeze(0)
            support_labels = support_labels.squeeze(0)
            query_labels = query_labels.squeeze(0)

            outputs = model(support_images, query_images)

            query_labels = query_labels.long()
            query_labels = query_labels.view(-1)
            outputs = outputs.view(-1, outputs.size(-1))

            loss = criterion(outputs, query_labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += query_labels.size(0)
            correct += (predicted == query_labels).sum().item()
    
    print(f'Validation Loss: {val_loss/len(val_loader)}, Accuracy: {100 * correct/total}%')


support_embeddings shape: torch.Size([20, 2048])
Epoch [1/20], Loss: 1.6667711734771729
support_embeddings shape: torch.Size([20, 2048])
Validation Loss: 10.938091278076172, Accuracy: 15.0%
support_embeddings shape: torch.Size([20, 2048])
Epoch [2/20], Loss: 4.181957244873047
support_embeddings shape: torch.Size([20, 2048])
Validation Loss: 24.08255958557129, Accuracy: 10.0%
support_embeddings shape: torch.Size([20, 2048])
Epoch [3/20], Loss: 8.0247220993042
support_embeddings shape: torch.Size([20, 2048])
Validation Loss: 12.465987205505371, Accuracy: 31.666666666666668%
support_embeddings shape: torch.Size([20, 2048])
Epoch [4/20], Loss: 2.0576529502868652
support_embeddings shape: torch.Size([20, 2048])
Validation Loss: 14.25047492980957, Accuracy: 23.333333333333332%
support_embeddings shape: torch.Size([20, 2048])
Epoch [5/20], Loss: 2.2134997844696045
support_embeddings shape: torch.Size([20, 2048])
Validation Loss: 20.55222511291504, Accuracy: 20.0%
support_embeddings shape: tor

In [None]:
from flask import Flask, request, jsonify
from PIL import Image
import io

app = Flask(__name__)

model.load_state_dict(torch.load('shoe_classification_model.pth'))
model.eval()

def predict(image_bytes):
    image = Image.open(io.BytesIO(image_bytes))
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
        return train_dataset.dataset.classes[predicted.item()]

@app.route('/predict', methods=['POST'])
def predict_route():
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    
    file = request.files['file']
    img_bytes = file.read()
    prediction = predict(img_bytes)
    return jsonify({'category': prediction})

if __name__ == '__main__':
    app.run(debug=True)
