In [1]:
"""必要なモジュールの読み込み"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import cv2
import torchvision
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import clip



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)












cuda


In [2]:
# 画像にマスクをかける
from demo.simmim import MaskGenerator

# 画像の読み込み
image = cv2.imread("xrays/val_0.png")
image = cv2.resize(image, (512, 512))
## 前処理
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image = transform(image)
mask = MaskGenerator(image)
# マスクの生成
masked_image = mask

# 画像の表示
## numpyに変換

# image Encoderの実装
from demo.image_encoder import ResnetEncoder

encoder = ResnetEncoder()
# imagetensorを(B, H, W, C)に変換
image = image.unsqueeze(0)

image = image.to(device)


encoder.to(device)
features = encoder.forward(image)
print(features.shape)










error: OpenCV(4.10.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\resize.cpp:4152: error: (-215:Assertion failed) !ssize.empty() in function 'cv::resize'


In [3]:
# 特徴量抽出された画像を復元
from demo.reconstruction import Recostruction

recostruction = Recostruction(encoder_outchannels=2048)

recostruction.to(device)

features = features.to(device)

reconstructed_image = recostruction.forward(features)

# 画像の表示
reconstructed_image = reconstructed_image.squeeze(0)
reconstructed_image = reconstructed_image.detach().cpu().numpy()
print(reconstructed_image.shape)

reconstructed_image = np.transpose(reconstructed_image, (1, 2, 0))
# 正規化を戻す

cv2.imwrite("reconstructed_image.png", reconstructed_image)






(3, 511, 511)


True

In [None]:
"""学習の実装"""
# データセットの読み込み
from demo.data import UnlabelledDataset

from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader

# 再構築ブランチのデータセットの読み込み
reconstruction_dataset = UnlabelledDataset(root="xrays/val", transforms=transform)
reconstruction_dataloader = DataLoader(reconstruction_dataset, batch_size=16, shuffle=True, num_workers=4)
# 検出ブランチのデータセットの読み込み
detection_dataset = CocoDetection(root="xrays/val", annFile="xrays/val.json",transform=transform)
detection_dataloader = DataLoader(detection_dataset, batch_size=16, shuffle=True, num_workers=4, collate_fn = lambda x: tuple(zip(*x)))


# モデルの読み込み
from demo.image_encoder import ResnetEncoder
from demo.detection import FCOSDetector
from demo.reconstruction import Recostruction

encoder = ResnetEncoder()
detector = FCOSDetector()
recostruction = Recostruction(encoder_outchannels=2048)


# 損失関数
from demo.loss import ReconstructionLoss, TextureConsistencyLoss

# オプティマイザー
import torch.optim as optim
optimizer = optim.Adam(
    [
        {"params": encoder.parameters(), "lr": 1e-4},
        {"params": detector.parameters(), "lr": 1e-4},
        {"params": recostruction.parameters(), "lr": 1e-4},
    ]
)

# 学習ループ
NUM_EPOCHS = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# モデルをデバイスに移動
encoder.to(device)
detector.to(device)
recostruction.to(device)

for epoch in range(NUM_EPOCHS):
    encoder.train()
    detector.train()
    recostruction.train()

    total_loss = 0
    
    for images, targets in detection_dataloader:
        images = images.to(device)
        targets = targets.to(device)

        # imageにマスクをかける
        mask = MaskGenerator(image)
        masked_image = mask

        # マスクをかけた画像をencoderに通す
        features = encoder.forward(masked_image)

        # 特徴量抽出された画像を復元
        reconstructed_image = recostruction.forward(features)

        # 再構築ブランチの損失関数を計算
        reconstruction_loss = ReconstructionLoss(image, reconstructed_image)
        texture_consistency_loss = TextureConsistencyLoss(image, reconstructed_image)

        # 検出ブランチの損失関数を計算
        detections = detector.forward(image)
        detection_loss = detector(images, targets)
        detection_loss = sum(loss for loss in detection_loss.values())

        # 総損失を計算  
        total_loss = reconstruction_loss + texture_consistency_loss + detection_loss

        # 勾配を計算
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # ロスを表示
        print(f"Epoch {epoch+1}, Loss: {total_loss.item()}")








    



