In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import os
import random
import itertools
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
class ResidualBlock(nn.Module):
    """Channel-preserving residual unit: ``forward(x) = x + F(x)``.

    ``F`` is pad -> 3x3 conv -> InstanceNorm -> ReLU -> pad -> 3x3 conv -> InstanceNorm.
    Reflection padding of 1 before each 3x3 convolution keeps H and W unchanged,
    so the residual branch can be added straight onto the input.

    Args:
        in_channels: channel count of both the input and the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        c = in_channels
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, 3),
            nn.InstanceNorm2d(c),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, 3),
            nn.InstanceNorm2d(c),
        ]
        # Attribute name + layer order define the state_dict keys (block.1.*, block.5.*)
        # that saved checkpoints rely on, so keep both stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [3]:
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator (used for both translation directions).

    Layout: stem conv -> two stride-2 downsampling convs -> ``num_residual_blocks``
    residual blocks -> two (upsample x2 + conv) stages -> output conv + tanh.
    Spatial size is preserved end to end when H and W are divisible by 4
    (e.g. 3*224*224 in, 3*224*224 out); output values lie in [-1, 1] (tanh).

    Args:
        in_channels: number of image channels; also the number of output channels.
        num_residual_blocks: number of ResidualBlocks in the bottleneck.
        stem_kernel_size: odd kernel size of the first and last convolutions.
            Defaults to 7. Previously this was ``2*in_channels+1``, which only gave
            7 for RGB input; for ``in_channels=3`` the layers (and checkpoint
            shapes) are unchanged.

    Raises:
        ValueError: if ``stem_kernel_size`` is even (size-preserving padding needs an odd kernel).
    """

    def __init__(self, in_channels, num_residual_blocks=9, stem_kernel_size=7):
        super(GeneratorResNet, self).__init__()
        if stem_kernel_size % 2 != 1:
            raise ValueError(f'stem_kernel_size must be odd, got {stem_kernel_size}')
        # Reflection padding of k // 2 keeps H and W unchanged through a k x k conv.
        stem_padding = stem_kernel_size // 2

        # Initial convolution  3*224*224 -> 64*224*224
        out_channels = 64
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(stem_padding),
            nn.Conv2d(in_channels, out_channels, stem_kernel_size),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        channels = out_channels

        # Downsampling   64*224*224 -> 128*112*112 -> 256*56*56
        down = []
        for _ in range(2):
            out_channels = channels * 2
            down += [
                nn.Conv2d(channels, out_channels, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.down = nn.Sequential(*down)

        # Transformation (residual blocks)  256*56*56, shape unchanged
        self.trans = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_residual_blocks)])

        # Upsampling  256*56*56 -> 128*112*112 -> 64*224*224
        up = []
        for _ in range(2):
            out_channels = channels // 2
            up += [
                nn.Upsample(scale_factor=2),  # nearest-neighbour: nn.Upsample's default mode
                nn.Conv2d(channels, out_channels, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.up = nn.Sequential(*up)

        # Output layer  64*224*224 -> 3*224*224, squashed to [-1, 1] by tanh
        self.out = nn.Sequential(
            nn.ReflectionPad2d(stem_padding),
            nn.Conv2d(channels, in_channels, stem_kernel_size),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.down(x)
        x = self.trans(x)
        x = self.up(x)
        x = self.out(x)
        return x

In [4]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel grid of real/fake scores.

    Four stride-2 blocks halve H and W each time (channels in -> 64 -> 128 -> 256 -> 512).
    A final 4x4 conv then maps 512 channels to one score per output cell.
    For a 3*224*224 input the output is 1*14*14; in general each output side is
    ``input_side // self.scale_factor``.
    """

    def __init__(self, in_channels):
        super().__init__()
        widths = [in_channels, 64, 128, 256, 512]
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the first block (which sees the raw image) skips InstanceNorm.
            # NOTE(review): the notebook does not state the rationale; it matches common
            # GAN-discriminator practice of not normalising the input layer.
            layers.extend(self.block(c_in, c_out, normalize=idx > 0))
        layers += [
            # A 4x4 conv with padding=1 on both sides shrinks each side by 1
            # (14 -> 13). Padding one extra column on the left and one extra row on top
            # first (512*14*14 -> 512*15*15) cancels that, so the output is 1*14*14.
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
        ]
        self.model = nn.Sequential(*layers)
        # Four stride-2 blocks: 2 ** 4 = 16 input pixels per output cell along each side.
        self.scale_factor = 16

    @staticmethod
    def block(in_channels, out_channels, normalize=True):
        """Return [Conv 4x4 stride 2 pad 1, (InstanceNorm), LeakyReLU(0.2)]; halves H and W."""
        conv = nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        return [conv, *norm, nn.LeakyReLU(0.2, inplace=True)]

    def forward(self, x):
        scores = self.model(x)
        return scores

In [5]:
# Least-squares adversarial loss (MSE against all-ones / all-zeros targets) plus
# L1 losses for the cycle-consistency term and the (currently unused) identity term.
criterion_GAN, criterion_cycle, criterion_identity = nn.MSELoss(), nn.L1Loss(), nn.L1Loss()

In [6]:
# Seed torch before building the networks so weight initialisation is reproducible
# (this also makes later torch-driven randomness, such as SubsetRandomSampler order
# and the random image flips, repeatable across runs).
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)

# G_AB translates A -> B and is judged by D_B; G_BA translates B -> A and is judged by D_A.
G_AB = GeneratorResNet(3, num_residual_blocks=9)
D_B = Discriminator(3)

G_BA = GeneratorResNet(3, num_residual_blocks=9)
D_A = Discriminator(3)

In [7]:
# Adam settings shared by every optimizer.
lr = 0.0002
b1, b2 = 0.5, 0.999

# One optimizer steps both generators together (their losses are summed into loss_G);
# each discriminator gets its own optimizer.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=lr,
    betas=(b1, b2),
)
optimizer_D_A, optimizer_D_B = (
    torch.optim.Adam(disc.parameters(), lr=lr, betas=(b1, b2)) for disc in (D_A, D_B)
)

In [8]:
# Uncomment to resume from a saved checkpoint (networks + optimizers).
# NOTE: the LambdaLR schedulers are created in a later cell and start from epoch 0
# unless their state is restored as well, so the learning-rate decay restarts on resume.
# checkpoint = torch.load("../input/trainedmodel1/melanomagan_config_3.pth")
# G_AB.load_state_dict(checkpoint['G_AB_state_dict'])
# G_BA.load_state_dict(checkpoint['G_BA_state_dict'])
# D_A.load_state_dict(checkpoint['D_A_state_dict'])
# D_B.load_state_dict(checkpoint['D_B_state_dict'])
# optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
# optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A_state_dict'])
# optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B_state_dict'])

In [9]:
cuda = torch.cuda.is_available()
print(f'cuda: {cuda}')
if cuda:
    # Move all four networks to the GPU.
    G_AB, D_B, G_BA, D_A = (net.cuda() for net in (G_AB, D_B, G_BA, D_A))
    # Optimizer state restored from a checkpoint must live on the same device as the
    # parameters, so move every tensor-valued state entry as well.
    for optimizer in (optimizer_G, optimizer_D_A, optimizer_D_B):
        for state in optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.cuda()

    criterion_GAN, criterion_cycle, criterion_identity = (
        loss_fn.cuda() for loss_fn in (criterion_GAN, criterion_cycle, criterion_identity)
    )

In [10]:
# G_AB.train()
# G_BA.train()
# D_A.train()
# D_B.train()

In [11]:
n_epoches = 500   # total number of training epochs
decay_epoch = 10  # epoch after which the learning rate starts to decay


def lambda_func(epoch):
    """LR multiplier: 1.0 up to ``decay_epoch``, then linear decay reaching 0.0 at ``n_epoches``."""
    decayed_epochs = max(0, epoch - decay_epoch)
    return 1 - decayed_epochs / (n_epoches - decay_epoch)

# One LambdaLR per optimizer, all following the same multiplier schedule.
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_func)
    for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

In [12]:
class ImageDataset(Dataset):
    """Unpaired benign/malignant image dataset for CycleGAN training.

    Item ``idx`` is the malignant file at sorted position ``idx``, paired with a benign
    image drawn uniformly at random from the *whole* benign folder. Items are returned
    as ``(benign_img, malign_img)``, i.e. domain A = benign and domain B = malignant.

    Args:
        malign_dir: folder containing the malignant images.
        benign_dir: folder containing the benign images.
        size: target (H, W). Currently unused because resizing is disabled, so images
            presumably already share one size (default batching stacks tensors).
        normalize: if True, map pixel values from [0, 1] to [-1, 1]; otherwise keep
            ToTensor's [0, 1] range.

    ``len(self)`` is the smaller of the two folder sizes, so only the first
    ``len(self)`` malignant files (in sorted order) are ever indexed.
    """

    def __init__(self, malign_dir, benign_dir, size=(256, 256), normalize=True):
        super().__init__()
        self.malign_dir = malign_dir
        self.benign_dir = benign_dir

        def tensor_steps():
            # ToTensor, then optionally (x - 0.5) / 0.5 to move [0, 1] to [-1, 1].
            steps = [transforms.ToTensor()]
            if normalize:
                steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
            return steps

        # Benign images: no augmentation (resize / flip are disabled).
        self.transform_b = transforms.Compose(tensor_steps())
        # Malignant images: random flips plus a rotation.
        # NOTE(review): RandomRotation((90, 90)) samples from a zero-width range, so it
        # always rotates by exactly 90 degrees rather than a random angle. Confirm intent.
        self.transform_m = transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.RandomRotation((90, 90)),
            *tensor_steps(),
        ])

        # os.listdir order is arbitrary; sorting keeps index -> file stable so the
        # seeded train/validation split over indices is reproducible.
        self.malign_idx = dict(enumerate(sorted(os.listdir(self.malign_dir))))
        self.benign_idx = dict(enumerate(sorted(os.listdir(self.benign_dir))))

    @staticmethod
    def _load(path, transform):
        """Open an image file, force 3-channel RGB, apply ``transform``, and close the file."""
        # convert('RGB') handles grayscale / RGBA files that would break the 3-channel Normalize.
        with Image.open(path) as img:
            return transform(img.convert('RGB'))

    def __getitem__(self, idx):
        # Draw the unpaired benign partner from the benign pool itself. The bound used to
        # be the malignant count, which raised KeyError when there were more malignant
        # than benign files and otherwise never sampled part of the benign folder.
        benign_key = int(np.random.randint(len(self.benign_idx)))
        benign_path = os.path.join(self.benign_dir, self.benign_idx[benign_key])
        malign_path = os.path.join(self.malign_dir, self.malign_idx[idx])
        benign_img = self._load(benign_path, self.transform_b)
        malign_img = self._load(malign_path, self.transform_m)
        return benign_img, malign_img

    def __len__(self):
        return min(len(self.malign_idx), len(self.benign_idx))

In [13]:
# Data locations and loader settings.
benign_dir = '../input/melanoma/Melanoma/train/benign'
malign_dir = '../input/melanoma/Melanoma/train/malignant'
batch_size = 8
validation_split = .1
shuffle_dataset = True
random_seed = 42

dataset = ImageDataset(malign_dir, benign_dir, normalize=True)

# Hold out the first `split` (shuffled) item indices for validation.
# NOTE(review): ImageDataset indexes malignant files and pairs each with a random
# benign image from the whole folder, so the benign side is not split; validation
# batches can contain benign images also seen in training.
dataset_size = len(dataset)
split = int(np.floor(validation_split * dataset_size))
indices = list(range(dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

trainloader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
testloader = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

In [14]:
# Batches per epoch: (training, validation).
(len(trainloader), len(testloader))

In [15]:
from torchvision.utils import make_grid

# Tensor type used to move batches onto the active device (CUDA float when available).
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

def sample_images(real_A, real_B, figside=3):
    """Plot a 4-row grid: real A, G_AB(real A), real B, G_BA(real B).

    Each row holds one tile per batch element. The generators run without gradient
    tracking and are restored to their previous train/eval mode afterwards.

    Args:
        real_A: batch of domain-A images, shape (N, C, H, W).
        real_B: batch of domain-B images with the same shape as ``real_A``.
        figside: figure size in inches per image tile.

    Raises:
        AssertionError: if the two batches differ in shape.
    """
    assert real_A.size() == real_B.size(), 'The image size for two domains must be the same'

    was_training = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()
    try:
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            real_A = real_A.type(Tensor)
            fake_B = G_AB(real_A)
            real_B = real_B.type(Tensor)
            fake_A = G_BA(real_B)
    finally:
        G_AB.train(was_training[0])
        G_BA.train(was_training[1])

    # make_grid's `nrow` is the number of images per row, i.e. the number of columns.
    ncols = real_A.size(0)
    rows = [make_grid(t, nrow=ncols, normalize=True) for t in (real_A, fake_B, real_B, fake_A)]
    # Stack the four row-grids vertically (dim 1 = height), then convert CHW -> HWC for imshow.
    image_grid = torch.cat(rows, 1).cpu().permute(1, 2, 0)

    plt.figure(figsize=(figside * ncols, figside * 4))
    plt.imshow(image_grid)
    plt.axis('off')
    plt.show()

In [16]:
# Sanity check before training: translate one training batch with the current generators.
real_A, real_B = next(iter(trainloader))
sample_images(real_A, real_B, figside=3)

In [None]:
# Per-epoch mean losses, appended once per epoch and plotted after training.
G_loss = []
D_loss = []
D_A_loss = []
D_B_loss = []
GAN_loss = []
# identity_loss = []  # identity loss is currently disabled
cycle_loss = []

for epoch in range(n_epoches):

    # Running sums over this epoch's batches; turned into means after the batch loop.
    g_loss_epoch = 0
    da_loss_epoch = 0
    db_loss_epoch = 0
    gan_loss_epoch = 0
    # identity_loss_epoch = 0
    cycle_loss_epoch = 0

    for real_A, real_B in trainloader:
        real_A, real_B = real_A.type(Tensor), real_B.type(Tensor)

        # Ground-truth targets for the least-squares (MSE) adversarial loss, shaped like
        # the discriminator output (each side = input side // scale_factor).
        out_shape = [real_A.size(0), 1, real_A.size(2) // D_A.scale_factor, real_A.size(3) // D_A.scale_factor]
        # Built directly on the batch's device/dtype instead of CPU-then-copy.
        valid = torch.ones(out_shape, dtype=real_A.dtype, device=real_A.device)
        fake = torch.zeros(out_shape, dtype=real_A.dtype, device=real_A.device)

        # ---------------- Train generators ----------------
        # Set training mode at every step because sample_images() may have switched
        # the generators to eval mode.
        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

        # Identity loss (disabled).
        # NOTE(review): in the CycleGAN paper each generator is fed an image already in
        # its *target* domain, i.e. L1(G_BA(real_A), real_A) and L1(G_AB(real_B), real_B);
        # an earlier draft here compared fake_B with real_A instead.
        # loss_id_A = criterion_identity(G_BA(real_A), real_A)
        # loss_id_B = criterion_identity(G_AB(real_B), real_B)
        # loss_identity = (loss_id_A + loss_id_B) / 2

        # Adversarial loss: train G so that D scores its translations as real.
        # This backward pass also fills D_A/D_B gradients; the discriminators'
        # zero_grad() calls below clear them before D is updated.
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle-consistency loss: A -> B -> A and B -> A -> B should reconstruct the input.
        recov_A = G_BA(fake_B)
        recov_B = G_AB(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total generator loss; the cycle term is weighted by 10.
        loss_G = loss_GAN + 10.0 * loss_cycle

        loss_G.backward()
        optimizer_G.step()

        # ---------------- Train discriminator A ----------------
        # Fakes are detached so the discriminator update does not reach the generators.
        optimizer_D_A.zero_grad()

        loss_real = criterion_GAN(D_A(real_A), valid)
        loss_fake = criterion_GAN(D_A(fake_A.detach()), fake)
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # ---------------- Train discriminator B ----------------
        optimizer_D_B.zero_grad()

        loss_real = criterion_GAN(D_B(real_B), valid)
        loss_fake = criterion_GAN(D_B(fake_B.detach()), fake)
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        g_loss_epoch += loss_G.item()
        da_loss_epoch += loss_D_A.item()
        db_loss_epoch += loss_D_B.item()
        gan_loss_epoch += loss_GAN.item()
        # identity_loss_epoch += loss_identity.item()
        cycle_loss_epoch += loss_cycle.item()

    # Advance the LR-decay schedules once per epoch.
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # Convert running sums into per-batch means for this epoch.
    len_tr = len(trainloader)
    g_loss_epoch /= len_tr
    da_loss_epoch /= len_tr
    db_loss_epoch /= len_tr
    d_loss_epoch = (da_loss_epoch + db_loss_epoch) / 2
    gan_loss_epoch /= len_tr
    # identity_loss_epoch /= len_tr
    cycle_loss_epoch /= len_tr

    G_loss.append(g_loss_epoch)
    D_loss.append(d_loss_epoch)
    D_A_loss.append(da_loss_epoch)
    D_B_loss.append(db_loss_epoch)
    GAN_loss.append(gan_loss_epoch)
    # identity_loss.append(identity_loss_epoch)
    cycle_loss.append(cycle_loss_epoch)

    # Report and checkpoint every 2 epochs.
    if (epoch+1) % 2 == 0:

        print(f'[Epoch {epoch+1}/{n_epoches}]')
        print(f'[G loss: {g_loss_epoch} | GAN: {gan_loss_epoch} cycle: {cycle_loss_epoch}]')
        print(f'[D loss: {d_loss_epoch} | D_A: {da_loss_epoch} D_B: {db_loss_epoch}]')

        # Besides networks and optimizers, store the completed-epoch count and the
        # LR-scheduler states, so a resumed run can continue the decay schedule
        # instead of restarting it from the initial learning rate.
        torch.save({
            'epoch': epoch + 1,
            'G_AB_state_dict': G_AB.state_dict(),
            'G_BA_state_dict': G_BA.state_dict(),
            'D_A_state_dict': D_A.state_dict(),
            'D_B_state_dict': D_B.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_A_state_dict': optimizer_D_A.state_dict(),
            'optimizer_D_B_state_dict': optimizer_D_B.state_dict(),
            'lr_scheduler_G_state_dict': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A_state_dict': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B_state_dict': lr_scheduler_D_B.state_dict(),
            }, "./melanomagan.pth")
    # Visualise one validation batch every 10 epochs.
    if (epoch+1) % 10 == 0:
        test_real_A, test_real_B = next(iter(testloader))
        sample_images(test_real_A, test_real_B)

In [None]:
# Per-epoch mean losses recorded by the training loop, one labelled curve each.
loss_curves = {
    'G total': G_loss,
    'D mean (D_A, D_B)': D_loss,
    'D_A': D_A_loss,
    'D_B': D_B_loss,
    'G adversarial': GAN_loss,
    # 'identity': identity_loss,  # identity loss is disabled in training
    'cycle': cycle_loss,
}

fig, ax = plt.subplots(figsize=(10, 8))
for curve_name, values in loss_curves.items():
    ax.plot(values, label=curve_name)
ax.set(title='CycleGAN training losses', xlabel='epoch', ylabel='mean loss per batch')
ax.legend()
plt.show()

In [None]:
# Reusable iterator over the validation loader; each cell below draws its next batch.
itr = iter(testloader)

In [None]:
# Next validation batch: real A | G_AB(A) | real B | G_BA(B).
test_real_A, test_real_B = next(itr)
sample_images(real_A=test_real_A, real_B=test_real_B)

In [None]:
# Next validation batch: real A | G_AB(A) | real B | G_BA(B).
test_real_A, test_real_B = next(itr)
sample_images(real_A=test_real_A, real_B=test_real_B)

In [None]:
# Next validation batch: real A | G_AB(A) | real B | G_BA(B).
test_real_A, test_real_B = next(itr)
sample_images(real_A=test_real_A, real_B=test_real_B)