In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import os
import random
import itertools
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: (pad -> 3x3 conv -> instance norm) twice, added back onto the input.

    Each 3x3 convolution is preceded by a 1-pixel reflection pad, so the spatial
    size is unchanged and the result can be summed element-wise with the input.
    The channel count is the same on the way in and on the way out.

    Args:
        in_channels: number of channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Layer order (and therefore the Sequential indices 0..6) is kept stable:
        # state_dict keys are derived from it ("block.1.*", "block.5.*").
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3),
            nn.InstanceNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch applied to ``x``."""
        residual = self.block(x)
        return x + residual

In [None]:
class GeneratorResNet(nn.Module):
    """Image-to-image generator: stem conv -> 2x downsample -> residual blocks -> 2x upsample -> output conv.

    The shape comments below use a 3x224x224 input as a worked example.

    Args:
        in_channels: number of channels of the input image; the output has the
            same number of channels. NOTE: the stem and output layers derive
            their reflection pad (``in_channels``) and kernel size
            (``2 * in_channels + 1``) from this value, i.e. pad 3 / kernel 7
            for a 3-channel image.
        num_residual_blocks: how many ``ResidualBlock`` modules sit between
            the downsampling and upsampling stages.
    """

    def __init__(self, in_channels, num_residual_blocks=9):
        super().__init__()

        stem_width = 64
        edge_pad = in_channels
        edge_kernel = 2 * in_channels + 1

        # Stem: 3x224x224 -> 64x224x224 (reflection pad keeps the spatial size).
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(in_channels, stem_width, edge_kernel),
            nn.InstanceNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )

        channels = stem_width

        # Downsampling: 64x224x224 -> 128x112x112 -> 256x56x56.
        # Each step doubles the channels and halves the resolution (stride 2).
        down_layers = []
        for _ in range(2):
            wider = channels * 2
            down_layers.extend([
                nn.Conv2d(channels, wider, 3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            channels = wider
        self.down = nn.Sequential(*down_layers)

        # Transformation: residual blocks at constant shape (256x56x56).
        self.trans = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(num_residual_blocks)]
        )

        # Upsampling: 256x56x56 -> 128x112x112 -> 64x224x224.
        # nn.Upsample is used with its default mode (nearest neighbour),
        # followed by a size-preserving 3x3 conv that halves the channels.
        up_layers = []
        for _ in range(2):
            narrower = channels // 2
            up_layers.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(channels, narrower, 3, stride=1, padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            channels = narrower
        self.up = nn.Sequential(*up_layers)

        # Output: 64x224x224 -> 3x224x224, squashed into [-1, 1] by tanh.
        self.out = nn.Sequential(
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(channels, in_channels, edge_kernel),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        for stage in (self.conv, self.down, self.trans, self.up, self.out):
            x = stage(x)
        return x

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel score map rather than a single scalar.

    Four stride-2 blocks widen the features to 64, 128, 256 and 512 channels
    while halving the resolution each time; a final stride-1 convolution maps
    them to one channel. For a 3x224x224 input the output is 1x14x14.

    Args:
        in_channels: number of channels of the input image.
    """

    def __init__(self, in_channels):
        super().__init__()

        layers = []
        previous = in_channels
        for index, width in enumerate((64, 128, 256, 512)):
            # The first block is built without InstanceNorm; every later one has it.
            layers.extend(self.block(previous, width, normalize=index > 0))
            previous = width

        # One extra row on top and one extra column on the left (n -> n + 1),
        # so the 4x4 / stride-1 / padding-1 conv below maps n + 1 back to n.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(previous, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

        # Presumably the overall downsampling factor (four stride-2 convs: 2 ** 4);
        # not read inside this class.
        self.scale_factor = 16

    @staticmethod
    def block(in_channels, out_channels, normalize=True):
        """Return the layers of one downsampling step as a list.

        The step is a 4x4 stride-2 convolution, an optional InstanceNorm2d
        (when ``normalize`` is true) and a LeakyReLU with slope 0.2.
        """
        conv = nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)
        steps = [conv]
        if normalize:
            steps.append(nn.InstanceNorm2d(out_channels))
        steps.append(nn.LeakyReLU(0.2, inplace=True))
        return steps

    def forward(self, x):
        """Return the score map produced by ``self.model`` for ``x``."""
        return self.model(x)

In [None]:
# Two generators with the same architecture (3 input channels, 9 residual blocks).
# Further down, G_AB is applied to the benign images and G_BA to G_AB's output.
G_AB, G_BA = (GeneratorResNet(3, num_residual_blocks=9) for _ in range(2))

In [None]:
# Restore the trained generator weights.
# map_location forces every tensor onto the CPU at load time, so the checkpoint
# can be read on a machine without a GPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
weights_file = "../input/trainedmodel1/melanomagan_config_3_1.pth"
cpu_device = torch.device('cpu')
checkpoint = torch.load(weights_file, map_location=cpu_device)

G_AB.load_state_dict(checkpoint['G_AB_state_dict'])
G_BA.load_state_dict(checkpoint['G_BA_state_dict'])

In [None]:
# Move both generators to the GPU when one is available; otherwise they stay on the CPU.
cuda = torch.cuda.is_available()
print(f'cuda: {cuda}')
if cuda:
    G_AB, G_BA = G_AB.cuda(), G_BA.cuda()

In [None]:
# Switch both generators to evaluation mode before running inference.
for generator in (G_AB, G_BA):
    generator.eval()
# Bare expression: show the architecture of G_BA as the cell output.
G_BA

In [None]:
# Training-image folders: `benign_dir` is the input folder read by the cells
# below; `malign_dir` is defined alongside it.
benign_dir, malign_dir = (
    '../input/melanoma/Melanoma/train/benign',
    '../input/melanoma/Melanoma/train/malignant',
)

In [None]:
# Pre-processing applied to every input image:
#   1. ToTensor   - PIL image -> float tensor (C, H, W) in [0, 1]
#   2. Normalize  - (x - 0.5) / 0.5 per channel, i.e. [0, 1] -> [-1, 1],
#                   the same range as the generators' tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# File names (not full paths) found in the benign folder.
# os.listdir returns entries in arbitrary order; sorting makes the order of the
# previews and of the generated files reproducible from run to run.
imgs = sorted(os.listdir(benign_dir))

In [None]:
# Preview the cycle for every benign image as a 3-panel figure:
#   original | G_AB(original) | G_BA(G_AB(original))
#
# The generators may have been moved to the GPU, so each input is sent to the
# device their weights live on and the outputs are brought back to the CPU
# before plotting.
device = next(G_AB.parameters()).device

for i in imgs:
    # A dedicated name is used here so the notebook-level `path` variable
    # (the output folder read by the saving/archiving cells) is not overwritten
    # when this cell is run or re-run.
    img_path = benign_dir + '/' + i
    with Image.open(img_path) as img_file:
        # Force 3 channels: the Normalize transform and the generators expect
        # RGB, and grayscale / RGBA files would otherwise fail.
        img = img_file.convert('RGB')

    transformed_img = torch.unsqueeze(transform(img), 0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        generated_img = G_AB(transformed_img)
        reconstructed_img = G_BA(generated_img)

    # Drop the batch dimension, move to the CPU and undo the normalisation
    # ([-1, 1] -> [0, 1]) so matplotlib can display the result.
    generated_img = torch.squeeze(generated_img, 0).cpu() * 0.5 + 0.5
    reconstructed_img = torch.squeeze(reconstructed_img, 0).cpu() * 0.5 + 0.5

    fig = plt.figure(figsize=(15, 12))
    panels = (
        img,
        generated_img.permute(1, 2, 0).numpy(),      # (C, H, W) -> (H, W, C)
        reconstructed_img.permute(1, 2, 0).numpy(),
    )
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel)
        plt.axis('off')

    # Render now and release the figure; otherwise one open figure per image
    # accumulates in memory for the whole loop.
    plt.show()
    plt.close(fig)
    
    

In [2]:
import os

# Output folder for the generated images. The name `path` is kept because the
# saving and archiving cells below read this notebook-level variable.
path = './generated_malign'

# exist_ok=True makes the cell safe to re-run and avoids the gap between
# "check that it exists" and "create it".
os.makedirs(path, exist_ok=True)

In [None]:
# Translate every benign image with G_AB and save the result as
# `<path>/generated_<original file name>`.
#
# The generator may live on the GPU: inputs are sent to the device its weights
# are on, and the output is moved back to the CPU before converting to NumPy.
device = next(G_AB.parameters()).device

for i in imgs:
    path_ip = benign_dir + '/' + i
    with Image.open(path_ip) as img_file:
        # Force 3 channels so grayscale / RGBA files match the 3-channel pipeline.
        img = img_file.convert('RGB')

    transformed_img = torch.unsqueeze(transform(img), 0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        generated_img = G_AB(transformed_img)

    # Drop the batch dimension and undo the normalisation: [-1, 1] -> [0, 1].
    generated_img = torch.squeeze(generated_img, 0).cpu()
    generated_img = (generated_img * 0.5) + 0.5

    # (C, H, W) float in [0, 1] -> (H, W, C) uint8 in [0, 255].
    # Round to nearest (a plain astype would truncate) and clip defensively.
    generated_img = generated_img.permute(1, 2, 0).numpy()
    generated_img = np.clip(np.rint(generated_img * 255.0), 0, 255).astype(np.uint8)

    save_img = Image.fromarray(generated_img)
    save_img.save(path + '/generated_' + i)

In [3]:
import shutil

# Zip the contents of the output folder into ./generated_malign.zip;
# the call returns the archive's file name, which is shown as the cell output.
shutil.make_archive(base_name='generated_malign', format='zip', root_dir=path)