In [None]:
from PIL import Image
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

class GridMapDataset(Dataset):
    """Unlabeled image dataset over the regular files directly inside ``root``.

    Each item is ``{'data': tensor}``. The tensor is the image converted to RGB and
    passed through ``transform``. The default transform resizes to 128x128, applies
    ToTensor, then Normalize(0.5, 0.5) per channel, giving values in [-1, 1].

    Args:
        root: directory containing the images (not searched recursively).
        split: unused; kept only for signature compatibility.
        attr_name: unused; kept only for signature compatibility.
        transform: callable applied to the PIL image. A falsy value selects the
            default pipeline.
        extensions: optional iterable of file suffixes to keep, e.g. ``['.png', 'pgm']``.
            Matching is case-insensitive and the leading dot is optional. ``None``
            keeps every regular file; stray non-image files will then fail in
            ``__getitem__``.
    """

    def __init__(self, root, split='train', attr_name=None, transform=None, extensions=None):
        self.root_dir = Path(root)
        files = (f for f in self.root_dir.iterdir() if f.is_file())
        if extensions is not None:
            allowed = {e.lower() if e.startswith('.') else '.' + e.lower() for e in extensions}
            files = (f for f in files if f.suffix.lower() in allowed)
        # iterdir() order is filesystem-dependent; sort so that index -> file is reproducible.
        self.image_files = sorted(files)

        self.transform = transform or transforms.Compose([
            transforms.Resize((128,128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``{'data': transformed RGB image}`` for the ``idx``-th file."""
        img_file = self.image_files[idx]
        # `with` closes the file handle that Image.open keeps open lazily;
        # convert() returns an independent, fully loaded copy.
        with Image.open(img_file) as raw:
            img = raw.convert('RGB')
        return {'data': self.transform(img)}

# Usage example: build a DataLoader and check the batch size and tensor shape.
from torch.utils.data import DataLoader

DATA_DIR = "/path/to/dir"  # TODO: point this at the directory of grid-map images

try:
    dataset = GridMapDataset(DATA_DIR)
    if len(dataset) == 0:
        # Without this check, next(iter(loader)) raises a StopIteration with an empty message.
        raise ValueError(f'no files found in {DATA_DIR}')
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    print('data.shape =', batch['data'].shape)
except Exception as e:
    print('Error creating/using GridMapDataset:', e)

In [3]:
# マスクを作るユーティリティと可視化関数
import torch
import random
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import tqdm

def random_mask(shape, mask_size=(32,32), rng=None):
    """Return a mask that zeroes one randomly placed rectangle.

    The result is a float32 tensor of ``shape`` (C, H, W). It is 1 everywhere
    except one ``mask_size`` rectangle, which is 0 in every channel.

    Args:
        shape: (C, H, W), e.g. ``tensor.shape`` of a single image.
        mask_size: (mh, mw) height and width of the zeroed rectangle.
        rng: optional ``random.Random`` instance for reproducible masks. Defaults to
            the module-level ``random`` generator.

    Returns:
        torch.float32 tensor of shape (C, H, W): 1 = keep, 0 = hole.

    Raises:
        ValueError: if the rectangle does not fit inside (H, W).
    """
    C, H, W = shape
    mh, mw = mask_size
    if not (0 <= mh <= H and 0 <= mw <= W):
        raise ValueError(f"mask_size {tuple(mask_size)} does not fit inside image of size ({H}, {W})")
    gen = random if rng is None else rng
    # randint is inclusive on both ends, so top + mh <= H and left + mw <= W.
    top = gen.randint(0, H - mh)
    left = gen.randint(0, W - mw)
    mask = torch.ones((C, H, W), dtype=torch.float32)
    mask[:, top:top+mh, left:left+mw] = 0.0
    return mask


def show_masked(image_tensor, mask, title=None):
    """Display ``image_tensor * mask`` in a new 3x3-inch matplotlib figure.

    Args:
        image_tensor: (C, H, W) tensor. It is assumed to be normalized to [-1, 1],
            as with the dataset's Normalize(0.5, 0.5), and is mapped back to [0, 1]
            for display. Masked pixels become 0 before un-normalization, so they
            appear mid-gray.
        mask: tensor broadcastable to ``image_tensor``; 1 keeps a pixel, 0 blanks it.
        title: optional figure title.
    """
    # detach()/cpu() so tensors on GPU or tensors that require grad can be converted to numpy.
    im = (image_tensor * mask).detach().cpu().permute(1, 2, 0).numpy()
    # Undo Normalize(0.5, 0.5) and clip to imshow's valid float range [0, 1].
    im = (im * 0.5 + 0.5).clip(0.0, 1.0)
    plt.figure(figsize=(3,3))
    if im.shape[2] == 1:
        # imshow rejects (H, W, 1) arrays; show single-channel data as fixed-range grayscale.
        plt.imshow(im[:, :, 0], cmap='gray', vmin=0.0, vmax=1.0)
    else:
        plt.imshow(im)
    if title:
        plt.title(title)
    plt.axis('off')


def save_im(image_tensor, mask, name="masked.jpg"):
    """Save ``image_tensor * mask`` as an image file (best effort).

    Args:
        image_tensor: (C, H, W) tensor. It is assumed to be normalized to [-1, 1] and
            is mapped back to [0, 255]. Inputs already in [0, 1] are NOT detected:
            they come out brightened, and anything above 1 after mapping is clipped.
        mask: tensor broadcastable to ``image_tensor``; 1 keeps a pixel, 0 blanks it.
        name: output path. Relative paths resolve against the current working
            directory. Missing parent directories are created.

    Any error is printed rather than raised, so a bulk export loop keeps going.
    """
    try:
        # detach()/cpu() so tensors on GPU or tensors that require grad can be converted.
        im = (image_tensor * mask).detach().cpu()
        # C,H,W -> H,W,C
        im_np = im.permute(1, 2, 0).numpy()
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1].
        im_np = im_np * 0.5 + 0.5
        im_uint8 = (im_np * 255).clip(0, 255).astype('uint8')
        if im_uint8.shape[2] == 1:
            # PIL expects HxW for grayscale; a 2-D uint8 array is saved as mode 'L'.
            img_pil = Image.fromarray(im_uint8[:, :, 0])
        else:
            img_pil = Image.fromarray(im_uint8)
        path = Path.cwd() / name
        # Nested names such as 'out/mask_32_32/000001.jpg' need their directory to exist.
        path.parent.mkdir(parents=True, exist_ok=True)
        img_pil.save(str(path))
        print(f'Saved masked image to: {path}')
    except Exception as e:
        print('Failed to save masked image:', e)

# for i in tqdm.tqdm(range(len(dataset))):
# for i in range(len(dataset)):
#     sample = dataset[i]['data']
#     m = random_mask(sample.shape, mask_size=(32,32))
#     save_im(sample, m, f'dataset/celeba/mask_32_32/{i:06}.jpg')

# Quick check (only if `dataset` was created above): mask one random sample and display it.
# Nothing is written to disk here; use save_im for that.
try:
    num = random.randrange(len(dataset))
    sample = dataset[num]['data']
    m = random_mask(sample.shape, mask_size=(32,32))
    show_masked(sample, m, f'Masked sample {num}')
except Exception as e:
    print('Mask util test skipped:', e)


In [10]:
# シンプルな畳み込みオートエンコーダと学習ループ
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

class SimpleAutoEncoder(nn.Module):
    """Small convolutional encoder/decoder used to fill masked image regions.

    Encoder: three stride-2 3x3 convolutions (3 -> 16 -> 32 -> 64 channels), each
    followed by ReLU, so every spatial side shrinks 8x. Decoder: three stride-2 4x4
    transposed convolutions (64 -> 32 -> 16 -> 3) that mirror it. When H and W are
    multiples of 8 the output has the input's shape. The final Tanh keeps outputs
    in [-1, 1], the range produced by Normalize(0.5, 0.5).
    """

    def __init__(self):
        super().__init__()
        # Layers are appended in the same order as a hand-written Sequential, so the
        # state_dict keys (encoder.0, encoder.2, ...) and init order are unchanged.
        encoder_layers = []
        for c_in, c_out in ((3, 16), (16, 32), (32, 64)):
            encoder_layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        for c_in, c_out in ((64, 32), (32, 16)):
            decoder_layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            decoder_layers.append(nn.ReLU())
        decoder_layers.append(nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1))
        decoder_layers.append(nn.Tanh())  # output in -1..1 to match the normalized input range
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` (3-channel images) to a 64-channel latent and decode it back."""
        return self.decoder(self.encoder(x))

try:
    checkpoint_path = "grid_map_complement.pt"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleAutoEncoder().to(device)
    # Named `optimizer`, not `optim`: rebinding `optim` shadowed the `torch.optim` module
    # imported above, so re-running this cell failed on `optim.Adam`.
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    if len(loader) == 0:
        raise ValueError('dataset is empty; nothing to train on')

    epochs = 10000
    # Stop once the epoch-average training loss has not improved for this many epochs.
    early_stopping_patience = 50
    early_stopping_best_loss = float("inf")
    early_stopping_patience_counter = 0
    for epoch in range(epochs):
        model.train()  # the visualization cell calls model.eval(); restore training mode on re-run
        total_loss = 0.0
        for batch in loader:
            imgs = batch['data'].to(device)
            # One independent random hole per image (1 = keep, 0 = hole).
            masks = torch.stack([random_mask(img.shape, mask_size=(32,32)) for img in imgs]).to(device)
            masked = imgs * masks

            recon = model(masked)
            # Only the hole region contributes to the loss.
            loss = criterion(recon * (1-masks), imgs * (1-masks))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch+1}/{epochs} loss={avg_loss:.6f}")

        if avg_loss < early_stopping_best_loss:
            early_stopping_best_loss = avg_loss
            early_stopping_patience_counter = 0
            # Save the best model so far as the checkpoint ("loss" is the epoch-average loss).
            checkpoint_params = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": avg_loss,
            }
            torch.save(checkpoint_params, checkpoint_path)
        else:
            early_stopping_patience_counter += 1
            if early_stopping_patience_counter >= early_stopping_patience:
                print(f"Early stopping at epoch {epoch+1}: no improvement for {early_stopping_patience} epochs")
                break

except Exception as e:
    print('Training/test loop skipped or failed:', e)


In [15]:
# Visualize one batch to check the inpainting result (masked regions are replaced by the prediction).
batch = next(iter(loader))
imgs = batch['data']
masks = torch.stack([random_mask(img.shape, mask_size=(32,32)) for img in imgs])
masked = imgs * masks

# Load the best checkpoint saved by the training cell. map_location lets a checkpoint
# written on a GPU be loaded on a CPU-only machine (and vice versa).
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

model.eval()
with torch.no_grad():
    pred = model(masked.to(device)).cpu()

# masks use 1 = keep and 0 = hole, so the completed image is
# imgs * masks + pred * (1 - masks): original pixels outside the hole, prediction inside it.
filled = imgs * masks + pred * (1 - masks)

# Show the first image: original / masked input / raw reconstruction / completed result.
show_masked(imgs[0], torch.ones_like(imgs[0]), 'Original')
show_masked(imgs[0], masks[0], 'Masked')
show_masked(pred[0], torch.ones_like(pred[0]), 'Reconstructed')
show_masked(filled[0], torch.ones_like(filled[0]), 'Filled (masked regions replaced by Reconstructed)')

In [16]:
# When run on its own, redisplay the completed image from the previous cell
# (prints a hint instead if that cell has not been run yet).
try:
    show_masked(
        filled[0],
        torch.ones_like(filled[0]),
        title='Filled (masked regions replaced by pred)',
    )
except NameError:
    print('Variable filled not found. Run the previous cell first.')