In [1]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn, retinanet_resnet50_fpn # 親モデル
from torchvision.models.detection import ssdlite320_mobilenet_v3_large # 子モデル
from torchvision import transforms as T
import tqdm
import torch.optim as optim
import os
from torch.utils.tensorboard import SummaryWriter
import torch
import time


In [3]:
# 異常歯の検出
num_classes = 5

"""Dataset、Dataloaderの定義"""
# データセットの定義
from torchvision.datasets import CocoDetection
from data import CustomCocoDetection
from torch.utils.data import DataLoader

# データセットのパス
root = "C:/Users/ohhara/mobilenetv2-ssd/dataset/quadrant-enumeration-disease/xrays"
annFile = "C:/Users/ohhara/mobilenetv2-ssd/dataset/quadrant-enumeration-disease/train_corrected.json"

target_size = (320, 320)

transform = T.Compose([
    #T.ToTensor(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    

])

def collate_fn(batch):
    return tuple(zip(*batch))
# データセットの作成
dataset = CustomCocoDetection(root=root, annFile=annFile, transform=transform, target_size=target_size)

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn, drop_last=True)
# データローダーの作成
first_batch_image, first_batch_target = next(iter(dataloader))


                                                                 

loading annotations into memory...
Done (t=0.07s)
creating index...
index created!


In [4]:
# 二つのモデルを定義
teacher_model = retinanet_resnet50_fpn(pretrained=False, num_classes=num_classes)
student_model = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)

# モデルのパラメータ数を表示
print(f"Teacher model parameters: {sum(p.numel() for p in teacher_model.parameters())}")
print(f"Student model parameters: {sum(p.numel() for p in student_model.parameters())}")






Teacher model parameters: 32230929
Student model parameters: 3758900


In [None]:
# 教師モデルを学習させる
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
teacher_model.to(device)

teacher_model.train()
save_dir = "C:/Users/ohhara/mobilenetv2-ssd/distillation/teacher_model"
os.makedirs(save_dir, exist_ok=True)
epochs = 200

optimizer = optim.SGD(teacher_model.parameters(), lr=0.0001, momentum=0.9)
start_time = time.time()

for epoch in range(epochs):
    writer = SummaryWriter

    epoch_loss = 0
    for images, targets in tqdm.tqdm(dataloader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # 勾配のリセット
        optimizer.zero_grad()
        loss_dict = teacher_model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    # 平均損失を計算
    avg_epoch_loss = epoch_loss / len(dataloader)

    
    
    print(f"epoch:{epoch + 1}")
    print(f"loss:{avg_epoch_loss}")

    if (epoch + 1) % 10 == 0:
        save_path = os.path.join(save_dir, f"model_epoch_{epoch + 1}.pth")
        torch.save(teacher_model.state_dict(), save_path)
        print(f"Saved model weights at {save_path}")

end_time = time.time()
print(f"Training time: {end_time - start_time} seconds")


cuda


100%|██████████| 44/44 [27:01<00:00, 36.84s/it]


epoch:1
loss:1.8654985129833221


100%|██████████| 44/44 [27:00<00:00, 36.84s/it]


epoch:2
loss:1.8011962961066852


100%|██████████| 44/44 [27:00<00:00, 36.83s/it]


epoch:3
loss:1.7165450047362933


100%|██████████| 44/44 [27:05<00:00, 36.94s/it]


epoch:4
loss:1.6794976266947659


  7%|▋         | 3/44 [01:38<22:28, 32.90s/it]