In [37]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms
from PIL import Image
import os
import cv2

In [38]:
# Preprocessing applied to every image, in this exact order:
#   1. resize to 256x256,
#   2. convert to a float tensor,
#   3. normalise each of the three channels with mean 0.5 / std 0.5.
# Compose needs an ordered sequence; a set literal has no defined iteration
# order, so a list is used to guarantee the steps run as written.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    # For single-channel inputs, a torchvision.transforms.Lambda step that
    # repeats the channel three times (x.repeat(3, 1, 1)) could be inserted
    # here so that the 3-channel Normalize below still applies.
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [39]:
# Directory holding the radiograph images used for training.
root_path = "C:/Users/ohhara/generative_model/generative_model/dataset/Tufts Dental Database/Radiographs"

# File extensions (lower-case) that count as images when building the index.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

# Keep only image files and sort them so the index is identical on every run,
# regardless of the order in which the OS lists the directory.
image_list = sorted(
    name for name in os.listdir(root_path)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)

# Write the index beside the image directory rather than inside it, so the
# index file never appears in later listings of the image directory.
index_path = os.path.join(os.path.dirname(root_path), "index.txt")
with open(index_path, "w") as f:
    f.writelines(image_name + "\n" for image_name in image_list)


In [61]:
# Dataset preparation
class TuftsDataset(Dataset):
    """Image-only dataset over the image files of one directory.

    The directory is scanned once at construction time; items are returned
    in sorted file-name order. No labels are produced.

    Args:
        root_path: Directory containing the image files.
        transform: Optional callable applied to each opened PIL image before
            it is returned.
        input_size: Unused; kept so existing callers that pass it keep working.
    """

    # File extensions (lower-case) treated as images when scanning root_path.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, root_path, transform=None, input_size=256):
        self.root_path = root_path
        self.transform = transform
        # Build the file list once instead of re-reading an index file on
        # every item access; non-image entries (e.g. a stray index.txt) are
        # skipped so that the length always matches the indexable items.
        self.image_names = sorted(
            name for name in os.listdir(root_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.len = len(self.image_names)

    def __getitem__(self, index):
        """Return the (optionally transformed) image at position ``index``."""
        image_path = os.path.join(self.root_path, self.image_names[index])
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of image files found in ``root_path``."""
        return self.len
    




In [62]:
mydataset = TuftsDataset(root_path="C:/Users/ohhara/generative_model/generative_model/dataset/Tufts Dental Database/Radiographs", transform=transforms)


In [42]:
"""学習プロセスの表示の為の関数"""
def display_process(hist, G, image_frame_dim, sample_z, fix=True):
    """Show the loss curves next to a grid of generated samples.

    The left half of the figure plots ``hist["D_loss"]`` and
    ``hist["G_loss"]``; the right half shows up to
    ``image_frame_dim * image_frame_dim`` images produced by ``G(sample_z)``.

    Args:
        hist: Mapping with the keys ``"D_loss"`` and ``"G_loss"``, each a
            sequence of loss values.
        G: Callable generator; ``G(sample_z)`` must return a tensor laid out
            as (N, C, H, W).
        image_frame_dim: Number of rows and columns of the sample grid.
        sample_z: Latent batch passed to ``G``.
        fix: Unused; retained for backward compatibility.
    """
    fig = plt.figure(figsize=(24, 15))
    fig.subplots_adjust(left=0, right=1, bottom=0, hspace=0.05, wspace=0.05)

    # Left half: loss curves. Each curve is plotted against its own index so
    # histories of different lengths cannot raise a shape error.
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.plot(hist["D_loss"], label="D_loss")
    ax1.plot(hist["G_loss"], label="G_loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()

    # Generate without tracking gradients and move to NumPy in (N, H, W, C)
    # layout, which is what imshow expects.
    with torch.no_grad():
        samples = G(sample_z)
    samples = samples.detach().cpu().numpy().transpose(0, 2, 3, 1)
    # Assumes the generator output lies in [-1, 1]; map it to [0, 1] and clip
    # so imshow never sees out-of-range floats.
    samples = np.clip((samples + 1) / 2, 0.0, 1.0)

    single_channel = samples.shape[-1] == 1
    num_cells = min(image_frame_dim * image_frame_dim, len(samples))
    for i in range(num_cells):
        row = i // image_frame_dim
        # Position inside an image_frame_dim x (2 * image_frame_dim) grid,
        # shifted into the right half of the figure.
        position = (row + 1) * image_frame_dim + i + 1
        ax = fig.add_subplot(image_frame_dim, image_frame_dim * 2, position, xticks=[], yticks=[])
        if single_channel:
            ax.imshow(samples[i][:, :, 0], cmap="gray")
        else:
            ax.imshow(samples[i])

    plt.show()
    
                             


In [48]:
# Weight-initialisation helper
def initialize_weights(model):
    """Initialise convolutional and 2-D batch-norm layers of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights ~ N(0, 0.02).
    * ``nn.BatchNorm2d``: scale ~ N(1, 0.02), so the layer starts close to an
      identity scaling instead of suppressing its input.
    * Biases of those layers are set to zero when the layer has one.

    Other layer types are left with their default initialisation.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # weight is None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
        else:
            continue
        # bias is None for bias=False convolutions and affine=False batch norm.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        


In [44]:
"""Generatorの作成"""
class Generator(nn.Module):
    """Generator network: latent vector -> image with values in [-1, 1].

    A two-layer fully connected stem produces a
    128 x (input_size/16) x (input_size/16) feature map, which four
    stride-2 transposed convolutions upsample by a factor of 16 to
    ``output_dim x input_size x input_size``. The final activation is Tanh.

    Args:
        input_dim: Size of the latent vector.
        output_dim: Number of channels of the generated image.
        input_size: Height and width of the generated image; must be a
            positive multiple of 16.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 16.
    """

    def __init__(self, input_dim=100, output_dim=1, input_size=256):
        super(Generator, self).__init__()
        if input_size <= 0 or input_size % 16 != 0:
            raise ValueError("input_size must be a positive multiple of 16")
        self.input_size = input_size
        self.output_dim = output_dim
        self.input_dim = input_dim
        # Spatial size of the feature map handed from the stem to the deconvs.
        self.base_size = self.input_size // 16

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * self.base_size ** 2),
            nn.BatchNorm1d(128 * self.base_size ** 2),
            nn.ReLU(),
        )

        # Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles the
        # spatial size, so four of them turn base_size into input_size.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, self.output_dim, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, input):
        """Map a (N, input_dim) latent batch to (N, output_dim, H, W) images."""
        x = self.fc(input)
        x = x.view(-1, 128, self.base_size, self.base_size)
        x = self.deconv(x)
        return x
        



In [50]:
"""Discriminatorの作成"""
class Discriminator(nn.Module):
    """Discriminator network: image -> real/fake score.

    Two stride-2 convolutions reduce the image to a
    128 x (input_size/4) x (input_size/4) feature map, which is flattened and
    passed through a two-layer fully connected head.

    Args:
        input_dim: Number of channels of the input image.
        output_dim: Number of output units of the head.
        input_size: Height and width of the input image; must be a positive
            multiple of 4.
        sig_out: When true, a Sigmoid is appended so the output lies in (0, 1).

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 4.
    """

    def __init__(self, input_dim=1, output_dim=1, input_size=256, sig_out=True):
        super(Discriminator, self).__init__()
        if input_size <= 0 or input_size % 4 != 0:
            raise ValueError("input_size must be a positive multiple of 4")
        self.input_dim = input_dim
        self.input_size = input_size
        self.output_dim = output_dim
        self.sig_out = sig_out

        # Both convolutions use stride 2, so the spatial size is halved twice
        # and matches the (input_size // 4) assumed by the linear layer below.
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        # NOTE: with input_size=256 the first linear layer alone holds
        # 128 * 64 * 64 * 1024 (about 5.4e8) weights.
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) ** 2, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
        )
        if self.sig_out:
            self.fc.add_module("sigmoid", nn.Sigmoid())
        initialize_weights(self)

    def forward(self, input):
        """Map a (N, input_dim, H, W) image batch to (N, output_dim) scores."""
        x = self.conv(input)
        # Flatten per sample; keeping the batch dimension explicit means a
        # shape mismatch raises instead of silently changing the batch size.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [None]:
"""GANの学習"""
class GAN:
    """Trainer for a Generator / Discriminator pair with binary cross-entropy loss.

    Args:
        epoch: Number of passes over the dataset.
        dataset: Dataset whose items are image tensors laid out as (C, H, W).
            When omitted, the notebook-level ``mydataset`` is used.
    """

    def __init__(self, epoch=500, dataset=None):
        self.epochs = epoch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sample_num = 16
        self.batch_size = 32
        self.input_size = 256
        self.z_dim = 64
        self.lrG = 0.0002
        self.lrD = 0.0002
        self.beta1 = 0.5
        self.beta2 = 0.999

        # Data loader. Keyword arguments are used so batch_size and shuffle
        # cannot be confused positionally. drop_last keeps every batch at
        # exactly batch_size, matching the fixed-size target tensors in train().
        if dataset is None:
            dataset = mydataset
        self.data_loader = DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, drop_last=True
        )
        # A single item is (C, H, W), so the channel count is its first dim.
        channels = dataset[0].shape[0]

        # Models (created directly on the target device) and optimisers.
        self.G = Generator(input_dim=self.z_dim, output_dim=channels, input_size=self.input_size).to(self.device)
        self.D = Discriminator(input_dim=channels, output_dim=1, input_size=self.input_size).to(self.device)
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), lr=self.lrG, betas=(self.beta1, self.beta2))
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=self.lrD, betas=(self.beta1, self.beta2))

        self.BCE_loss = nn.BCELoss().to(self.device)

        # Fixed latent batch so the displayed samples are comparable over time.
        self.sample_z = torch.randn(self.sample_num, self.z_dim, device=self.device)

    def train(self):
        """Run the training loop, recording one D and one G loss per iteration.

        Losses are stored in ``self.train_hist`` under the keys ``'D_loss'``
        and ``'G_loss'``. Progress is displayed every 10 iterations.
        """
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []

        self.y_real = torch.ones(self.batch_size, 1, device=self.device)
        self.y_fake = torch.zeros(self.batch_size, 1, device=self.device)

        # Both networks contain batch norm, so both must be in training mode.
        self.D.train()
        self.G.train()
        for epoch in range(self.epochs):
            for step, x_ in enumerate(self.data_loader):
                x_ = x_.to(self.device)
                # Sample z
                z = torch.randn(self.batch_size, self.z_dim, device=self.device)

                # --- Discriminator update ---
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real)

                # Detach so this step does not back-propagate into the generator.
                fake = self.G(z).detach()
                D_fake = self.D(fake)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist["D_loss"].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # --- Generator update ---
                self.G_optimizer.zero_grad()
                D_fake = self.D(self.G(z))
                G_loss = self.BCE_loss(D_fake, self.y_real)
                self.train_hist["G_loss"].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                # Display
                if (step + 1) % 10 == 0:
                    self._show_progress()

        plt.close()
        print("Training complete")

    def _show_progress(self):
        """Replace the cell output with the current loss curves and samples."""
        # Imported here because it is only needed for display; outside
        # IPython the previous output is simply not cleared.
        try:
            from IPython.display import clear_output
        except ImportError:
            clear_output = None

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # Sample in eval mode so batch norm uses its running statistics, then
        # restore training mode even if plotting raises.
        self.G.eval()
        try:
            with torch.no_grad():
                if clear_output is not None:
                    clear_output(wait=True)
                display_process(self.train_hist, self.G, image_frame_dim, self.sample_z)
        finally:
            self.G.train()
        # Release any figure left open so memory does not grow during training.
        plt.close("all")
                

                





In [None]:
# Build the trainer with its default settings and run the training loop.
gan = GAN()
gan.train()


In [63]:
# Sanity check: iterate once over the dataset in shuffled batches of 32 and
# print the shape of every batch. The batch index is not needed, and binding
# it to the name `iter` would replace the built-in at notebook scope.
for x in DataLoader(mydataset, batch_size=32, shuffle=True):
    print(x.shape)



torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([8, 3, 256, 256])
