In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms
from PIL import Image
import os
import cv2

In [2]:
transforms = torchvision.transforms.Compose({
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
})

In [26]:
root_path = "C:/Users/ohhara/generative_model/generative_model/dataset/Tufts Dental Database/Radiographs"
image_list = os.listdir(root_path)
with open("C:/Users/ohhara/generative_model/generative_model/dataset/Tufts Dental Database/Radiographs/index.txt", "w") as f:
   for image_name in image_list:
      f.write(image_name+"\n")


In [15]:
# データセットの準備
class TuftsDataset(Dataset):
    def __init__(self, root_path, transform=None, input_size=256):
        self.root_path = root_path
        self.transform = transform
        self.len = len(os.listdir(root_path))
         
    
    def __getitem__(self, index):
        image_path = os.path.join(self.root_path, JPG")
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image
    
    def __len__(self):
        return self.len
    




In [16]:
mydataset = TuftsDataset(root_path="C:/Users/ohhara/generative_model/generative_model/dataset/Tufts Dental Database/Radiographs", transform=transforms)


In [19]:
dataloader = DataLoader(mydataset, batch_size=32, shuffle=True)
data = dataloader.__iter__().__next__()[0]


FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/ohhara/generative_model/generative_model/dataset/Tufts Dental Database/Radiographs\\863.JPG'

In [None]:
"""学習プロセスの表示の為の関数"""
def display_process(hist, G, image_frame_dim, sample_z, fix=True):
    plt.gcf().clear()



In [None]:
# 重みの初期化関数の定義
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.zero_(0)
        


In [None]:
"""Generatorの作成"""
class Generator(nn.Module):
    def __init__(self, input_dim=100, output_dim=3, input_size=256):
        super(Generator, self).__init__()
        self.input_size = input_size
        self.output_dim = output_dim
        self.input_dim = input_dim

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128*(self.input_size//16)**2),
            nn.BatchNorm1d(128*(self.input_size//16)**2),
            nn.ReLU(),
        )

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, self.output_dim, stride=2, padding=1),
            nn.Tanh(),
        )
        initialize_weights(self)

        def forward(self, input):
            x = self.fc(input)
            x = x.view(-1, 128, self.input_size//16, self.input_size//16)
            x = self.deconv(x)
            return x
        



In [None]:
"""Discriminatorの作成"""
class Discriminator(nn.Module):
    def __init__(self, input_dim=1, output_dim=1, input_size=256, sig_out=True):
        super(Discriminator, self).__init__()
        self.input_size = input_size
        self.output_dim = output_dim
        self.sig_out = sig_out
        
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

        )
        self.fc = nn.Sequential(
            nn.Linear(128*(self.input_size//4)**2, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            
        )
        if self.sig_out:
            self.fc.add_module("sigmoid",nn.Sigmoid())
        initialize_weights(self)
        
        def forward(self, input):
            x = self.conv(input)
            x = x.view(-1, 128*(self.input_size//4)**2)
            x = self.fc(x)
            return x


In [None]:
dataloader = DataLoader(mydataset, batch_size=32, shuffle=True)


In [None]:
"""GANの学習"""
class GAN:
    def __init__(self, epoch=500):
        self.epochs = epoch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sample_num = 16
        self.batch_size = 32
        self.z_dim = 64
        self.lrG = 0.0002
        self.lrD = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999

        # データローダ
        self.data_loader = DataLoader(mydataset, batch_size=self.batch_size, shuffle=True)
        

        self.G = Generator(input_dim=self.z_dim, output_dim=3, input_size=256).to(self.device)
        self.D = Discriminator(input_dim=3, output_dim=1, input_size=256).to(self.device)
        self.criterion = nn.BCELoss()
        self.optimizerG = torch.optim.Adam(self.G.parameters(), lr=self.lrG, betas=(self.beta1, self.beta2))
        
