In [3]:
import os
# Root directory of the extracted competition data. The later cells read
# train.csv and the train/ image folder relative to this path.
data_path = './data'

In [None]:
if not os.path.exists('open.zip') and not os.path.exists('./data'): # Download data from google drive
    !pip install gdown
    !gdown https://drive.google.com/uc?id=13oGkm3Ao7fL2p51H62J68Gw630ABBR0g 
    !unzip open.zip -d data # unzip in the 'data' folder


In [7]:
# Generate the "origin" pictures from the training data, but only once: the
# existence of <data_path>/origin is used as the "already done" check.
# NOTE(review): presumably save_origin writes into that folder -- confirm in prepare_origin.
from prepare_origin import save_origin

if not os.path.exists(f'{data_path}/origin'):
    save_origin(data_path)

In [68]:
import torch
import os
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PuzzleDataset(Dataset):
    """Jigsaw-puzzle dataset that cuts each image into a grid of pieces.

    The annotation CSV is read with pandas. Column 1 of each row holds an image path
    containing ``"train/"``; columns 2 onward hold the 1-based piece labels, which
    are returned shifted to 0-based float32 values.

    Args:
        csv_file: Path to the annotation CSV.
        root_dir: Directory the part of each image path after ``"train/"`` is joined to.
        transform: Optional callable applied to every piece. If it is ``None``, or
            returns something that is not a tensor, the piece is converted to a
            float32 CHW tensor scaled to [0, 1] so the pieces can be stacked.
        grid_size: Number of rows and columns the image is cut into
            (default 4, i.e. 16 pieces).
    """

    def __init__(self, csv_file, root_dir, transform=None, grid_size=4):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.grid_size = grid_size

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _to_tensor(piece):
        """Convert an image-like piece to a float32 CHW tensor scaled to [0, 1]."""
        array = np.asarray(piece, dtype=np.float32) / 255.0
        if array.ndim == 2:  # single-channel piece -> add a channel axis
            array = array[:, :, np.newaxis]
        return torch.from_numpy(np.ascontiguousarray(array.transpose(2, 0, 1)))

    def __getitem__(self, index):
        """Return ``(pieces, label)`` for row ``index``.

        ``pieces`` stacks grid_size * grid_size pieces in row-major order (left to
        right, top to bottom). ``label`` holds the row's label columns minus 1.
        """
        # Load the image. The CSV path is assumed to contain "train/"; only the
        # part after it is resolved against root_dir.
        rel_path = self.annotations.iloc[index, 1].split("train/")[1]
        img_name = os.path.join(self.root_dir, rel_path)
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')

        # Cut the image into grid_size x grid_size pieces. Width and height are
        # handled separately so non-square images are split correctly.
        width, height = image.size
        piece_w = width // self.grid_size
        piece_h = height // self.grid_size
        pieces = []
        for i in range(self.grid_size):
            for j in range(self.grid_size):
                piece = image.crop((j * piece_w, i * piece_h, (j + 1) * piece_w, (i + 1) * piece_h))
                if self.transform is not None:
                    piece = self.transform(piece)
                # torch.stack below needs tensors (fails on raw PIL pieces).
                if not torch.is_tensor(piece):
                    piece = self._to_tensor(piece)
                pieces.append(piece)

        # Labels: shift to start from 0 instead of 1.
        label = self.annotations.iloc[index, 2:].values.astype(np.float32) - 1

        return torch.stack(pieces), torch.tensor(label)

# Build the training pipeline: each puzzle piece goes through ToTensor, and the
# loader draws shuffled batches of 4 puzzles.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = PuzzleDataset(
    csv_file=f'{data_path}/train.csv',
    root_dir=f'{data_path}/train',
    transform=transform,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=4)


In [69]:
import torch
import torch.nn as nn
from torchvision.models import resnet50

class PuzzleSolver(nn.Module):
    """Predict one position value per puzzle piece.

    Every piece is encoded independently by a frozen, pretrained torchvision
    ResNet-50. The per-piece outputs are concatenated and passed through a small
    MLP head that emits ``num_pieces`` values.

    Args:
        num_pieces: Number of pieces per puzzle. The head's input size is derived
            from it, so ``forward`` only accepts inputs with exactly this many pieces.
    """

    def __init__(self, num_pieces):
        super(PuzzleSolver, self).__init__()
        self.num_pieces = num_pieces
        # NOTE: newer torchvision deprecates `pretrained=` in favour of
        # `weights=ResNet50_Weights.DEFAULT`; kept for compatibility with older versions.
        self.resnet = resnet50(pretrained=True)
        for param in self.resnet.parameters():
            param.requires_grad = False
        # requires_grad=False freezes the weights but not the BatchNorm running
        # statistics, which keep updating whenever the module is in training mode.
        # Keep the frozen backbone in eval mode (also enforced in `train`) so it
        # stays a fixed feature extractor.
        self.resnet.eval()

        self.fc = nn.Sequential(
            nn.Linear(self.resnet.fc.out_features * num_pieces, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_pieces)
        )

    def train(self, mode=True):
        """Switch the head's training mode while keeping the frozen backbone in eval mode."""
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, x):
        """Score every piece of a batch of puzzles.

        Args:
            x: Tensor of shape [batch_size, num_pieces, C, H, W].

        Returns:
            Tensor of shape [batch_size, num_pieces] (final position predictions).

        Raises:
            ValueError: if ``x`` is not 5-D or its piece count differs from ``num_pieces``.
        """
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D input [batch, pieces, C, H, W], got shape {tuple(x.shape)}")
        batch_size, num_pieces, C, H, W = x.shape
        if num_pieces != self.num_pieces:
            raise ValueError(f"expected {self.num_pieces} pieces per puzzle, got {num_pieces}")
        # Fold the piece axis into the batch axis so the backbone sees plain images.
        x = x.reshape(batch_size * num_pieces, C, H, W)
        features = self.resnet(x)  # per-piece features
        features = features.reshape(batch_size, -1)  # [batch_size, num_pieces * feature_size]
        return self.fc(features)

# Create the solver for a 4x4 puzzle (16 pieces, matching the dataset's split).
model = PuzzleSolver(16)




In [72]:
from tqdm.auto import tqdm

# Train on the GPU when one is available; the original loop always ran on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss: MSE between the predicted and labelled per-piece values.
criterion = nn.MSELoss()
# Parameters with requires_grad=False never receive gradients, so only the
# trainable ones are handed to the optimizer.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    loop = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]', leave=False)
    for images, labels in loop:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # The progress bar is not kept (leave=False), so log a per-epoch summary.
    mean_loss = running_loss / max(len(train_loader), 1)
    print(f'Epoch [{epoch+1}/{num_epochs}] mean loss: {mean_loss:.4f}')


  0%|          | 0/17500 [00:00<?, ?it/s]

                                                                             

KeyboardInterrupt: 

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f='puzzle_solver.pth')
