In [2]:
import torch
import os
from os import listdir
import numpy as np
from skimage import io
from __future__ import print_function, division
import torch.utils.data
import random
import numpy as np
from skimage.io import imread
import torch.nn as nn

In [3]:
class Images_Dataset_folder(torch.utils.data.Dataset):
    """Dataset that serves one image from each of two folders per index.

    The two folders are listed independently (sorted by file name) and may
    hold different numbers of files. The dataset length is the larger of the
    two counts; the shorter listing is cycled with a modulo index.

    Args:
        images_dir: directory holding the first set of images.
        labels_dir: directory holding the second set of images.

    Each item is a tuple ``(img, label)`` of ``float32`` arrays reshaped to
    ``(1, H, W)``. Every image is min-max scaled to [0, 1] and then divided
    by its own median.
    """

    def __init__(self, images_dir, labels_dir):
        self.images = sorted(os.listdir(images_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.images_len = len(self.images)
        self.labels_len = len(self.labels)
        self.length_dataset = max(self.images_len, self.labels_len)

    def __len__(self):
        """Return the size of the larger of the two folders."""
        return self.length_dataset

    @staticmethod
    def _normalise(arr):
        """Min-max scale ``arr`` to [0, 1], then divide by its median.

        A constant image (zero value range) becomes all zeros, and the median
        division is skipped when the median is zero, so the result never
        picks up NaN/inf from a division by zero.
        """
        arr = np.asarray(arr, dtype=np.float32)
        lowest = np.amin(arr)
        value_range = np.amax(arr) - lowest
        if value_range == 0:
            arr = np.zeros_like(arr)
        else:
            arr = (arr - lowest) / value_range
        median = np.median(arr)
        if median != 0:
            arr = arr / median
        return arr

    def _load(self, directory, file_name):
        """Read one image, normalise it and add a leading channel axis."""
        # os.path.join works whether or not `directory` ends with a separator.
        raw = np.array(io.imread(os.path.join(directory, file_name)), dtype=np.float32)
        arr = self._normalise(raw)
        height, width = arr.shape[0], arr.shape[1]
        return np.reshape(arr, (1, height, width))

    def __getitem__(self, i):
        """Return the ``(img, label)`` pair for index ``i``.

        The shorter file listing wraps around via the modulo index.
        """
        # No RNG seeding here: this dataset applies no random transforms, and
        # reseeding the global `random`/`torch` generators for every sample
        # would disturb any other code that relies on them.
        img = self._load(self.images_dir, self.images[i % self.images_len])
        label = self._load(self.labels_dir, self.labels[i % self.labels_len])
        return img, label
    

In [5]:
# Sub-folders of `dir_path`, one per imaging modality.
data_Inputs = ['NMuMg_phase_contrast', 'HK2_DIC', 'T47D_fluorescence']
dir_path = './dataset/'

# Training hyper-parameters.
batch_size = 5
num_workers = 0
epochs = 30

# The two training folders: first and second entry of `data_Inputs`.
t_data = ''.join((dir_path, data_Inputs[0], '/train/'))
l_data = ''.join((dir_path, data_Inputs[1], '/train/'))

Training_Data = Images_Dataset_folder(t_data, l_data)

# NOTE(review): no `shuffle` is passed, so batches are drawn in the same
# order every epoch — confirm this is intended.
train_loader = torch.utils.data.DataLoader(
    Training_Data,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [9]:
class Block(nn.Module):
    """4x4 reflect-padded convolution -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=True,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.conv(x)



class Discriminator(nn.Module):
    """Convolutional discriminator producing a single-channel score map.

    Args:
        in_channels: number of channels of the input image.
        features: channel widths of the successive stages. The first entry
            is the width of the initial convolution; each further entry adds
            one ``Block``. Every block uses stride 2 except the final one,
            which uses stride 1.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Copy so the caller's sequence (or the shared default) is never aliased.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index in range(1, len(features)):
            feature = features[index]
            # Decide by position, not by value: comparing `feature` with
            # `features[-1]` would also give stride 1 to an earlier stage
            # that happens to have the same width as the last one.
            stride = 1 if index == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        x = self.initial(x)
        return self.model(x)

def test():
    """Smoke test: push a random 5x3x256x256 batch through a Discriminator and print the output shape."""
    batch = torch.randn((5, 3, 256, 256))
    disc = Discriminator(in_channels=3)
    scores = disc(batch)
    print(scores.shape)


if __name__ == "__main__":
    test()

torch.Size([5, 1, 30, 30])


In [10]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        down: when True use a reflect-padded ``nn.Conv2d``; otherwise use
            ``nn.ConvTranspose2d``.
        use_act: when True finish with an in-place ReLU; otherwise with
            ``nn.Identity``.
        **kwargs: forwarded unchanged to the convolution constructor.
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        norm_layer = nn.InstanceNorm2d(out_channels)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        """Run ``x`` through the conv/norm/activation stack."""
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers with a skip connection around them.

    The second ``ConvBlock`` is built with ``use_act=False``. ``channels`` is
    used for both input and output, so the tensor shape is preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        first = ConvBlock(channels, channels, **conv_args)
        second = ConvBlock(channels, channels, use_act=False, **conv_args)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the output of the two-layer block."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: a 7x7 stem, two stride-2 ``ConvBlock`` downsampling stages,
    ``num_residuals`` ``ResidualBlock`` layers, two stride-2 transposed
    ``ConvBlock`` upsampling stages, and a final 7x7 convolution back to
    ``img_channels``. No activation is applied after the final convolution.

    Args:
        img_channels: channels of both the input and the output image.
        num_features: channel width of the stem (doubled by each downsampling stage).
        num_residuals: number of residual blocks in the middle of the network.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        nf = num_features

        stem_conv = nn.Conv2d(img_channels, nf, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        self.initial = nn.Sequential(stem_conv, nn.InstanceNorm2d(nf), nn.ReLU(inplace=True))

        # Channel widths (in, out) of the two downsampling stages.
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, kernel_size=3, stride=2, padding=1)
                for c_in, c_out in ((nf, nf*2), (nf*2, nf*4))
            ]
        )

        residual_layers = [ResidualBlock(nf*4) for _ in range(num_residuals)]
        self.res_blocks = nn.Sequential(*residual_layers)

        # Channel widths (in, out) of the two upsampling stages.
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
                for c_in, c_out in ((nf*4, nf*2), (nf*2, nf*1))
            ]
        )

        self.last = nn.Conv2d(nf*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        """Map the image batch ``x`` through stem, down, residual, up and output layers."""
        out = self.initial(x)
        # Downsampling stages, the residual stack, then upsampling stages, in order.
        for stage in [*self.down_blocks, self.res_blocks, *self.up_blocks]:
            out = stage(out)
        return self.last(out)

def test():
    """Smoke test: run a random 2x3x256x256 batch through a Generator and print the output shape."""
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    # Pass the 9 by keyword: the second positional parameter of Generator is
    # `num_features`, so `Generator(img_channels, 9)` would build a
    # 9-feature-wide network instead of one with 9 residual blocks.
    gen = Generator(img_channels, num_residuals=9)
    print(gen(x).shape)

if __name__ == "__main__":
    test()

torch.Size([2, 3, 256, 256])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image


def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler,
             lambda_cycle=10, lambda_identity=0.0):
    """Run one training epoch of the two-generator / two-discriminator setup.

    Each batch from ``loader`` is a pair ``(zebra, horse)``. ``gen_H`` maps
    zebra -> horse and ``gen_Z`` maps horse -> zebra; ``disc_H`` and
    ``disc_Z`` score horse and zebra images respectively. Per batch the
    discriminators are updated first (on detached fakes), then the
    generators (adversarial + cycle + optional identity loss).

    Args:
        disc_H, disc_Z: discriminators for the two domains.
        gen_Z, gen_H: generators producing zebra / horse images.
        loader: iterable of ``(zebra, horse)`` tensor batches.
        opt_disc: optimizer over both discriminators' parameters.
        opt_gen: optimizer over both generators' parameters.
        l1: loss used for the cycle and identity terms.
        mse: loss used for the adversarial terms.
        d_scaler, g_scaler: gradient scalers exposing ``scale``, ``step`` and
            ``update`` for the discriminator and generator steps.
        lambda_cycle: weight of the cycle-consistency losses.
        lambda_identity: weight of the identity losses. When it is 0 the two
            extra identity forward passes are skipped.

    Side effects: updates all four networks in place, writes
    ``./horse_{idx}.png`` and ``./zebra_{idx}.png`` every 10th batch, and
    reports running means of the discriminator outputs on the progress bar.
    """
    # Use the device the generators already live on rather than a bare
    # `.cuda()`: the batches then always end up next to the weights, and
    # the function also works when the models are on the CPU.
    device = next(gen_H.parameters()).device
    use_amp = device.type == "cuda"

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (zebra, horse) in enumerate(loop):
        zebra = zebra.to(device)
        horse = horse.to(device)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            # detach(): the discriminator step must not backprop into the generator.
            D_H_fake = disc_H(fake_horse.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # Average the two discriminator losses.
            D_loss = (D_H_loss + D_Z_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast(enabled=use_amp):
            # Adversarial loss for both generators: fakes should be scored as real.
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle loss: translating there and back should reproduce the input.
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * lambda_cycle
                + cycle_horse_loss * lambda_cycle
            )

            # Identity loss: only pay for the two extra generator passes when
            # the term actually contributes to the objective.
            if lambda_identity != 0:
                identity_zebra = gen_Z(zebra)
                identity_horse = gen_H(horse)
                identity_zebra_loss = l1(zebra, identity_zebra)
                identity_horse_loss = l1(horse, identity_horse)
                G_loss = (
                    G_loss
                    + identity_horse_loss * lambda_identity
                    + identity_zebra_loss * lambda_identity
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            # NOTE(review): the *0.5+0.5 rescale presumes generator outputs in
            # [-1, 1] — confirm against the data normalisation used upstream.
            save_image(fake_horse*0.5+0.5, f"./horse_{idx}.png")
            save_image(fake_zebra*0.5+0.5, f"./zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))




# Everything below runs with CUDA device index 5 selected as the current device.
with torch.cuda.device(5):
    # Single-channel discriminators and generators, moved to the current GPU.
    disc_H = Discriminator(in_channels=1).cuda()
    disc_Z = Discriminator(in_channels=1).cuda()
    gen_Z = Generator(img_channels=1, num_residuals=9).cuda()
    gen_H = Generator(img_channels=1, num_residuals=9).cuda()

    # One optimizer per role, each covering both networks of that role.
    opt_disc = optim.Adam(
        [*disc_H.parameters(), *disc_Z.parameters()],
        lr=1e-5,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        [*gen_Z.parameters(), *gen_H.parameters()],
        lr=1e-5,
        betas=(0.5, 0.999),
    )

    # Losses handed to train_fn as its `l1` and `mse` arguments.
    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # Separate gradient scalers for the generator and discriminator steps.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(epochs):
        print('epoch:', epoch)
        train_fn(
            disc_H, disc_Z, gen_Z, gen_H,
            train_loader,
            opt_disc, opt_gen,
            L1, mse,
            d_scaler, g_scaler,
        )
        

In [14]:
# Run every image of the chosen modality through gen_Z and collect the outputs.
pred_array = []
out_idx = 1

path_test_gt = dir_path+data_Inputs[out_idx]+'/train/Img'
gt_list = sorted(listdir(path_test_gt))

# Send inputs to the device that actually holds gen_Z's weights. A bare
# 'cuda' means the *current* device, which need not be the GPU the model
# was created on.
gen_device = next(gen_Z.parameters()).device

# Switching to eval mode is loop-invariant, so do it once up front.
gen_Z.eval()
with torch.no_grad():
    for file_name in gt_list:
        label = np.array(imread(os.path.join(path_test_gt, file_name)), dtype=np.float32)

        # Min-max scale to [0, 1], then divide by the median. Zero
        # denominators are guarded so a constant image or a zero median
        # cannot inject NaN/inf into the network.
        lowest = np.amin(label)
        value_range = np.amax(label) - lowest
        if value_range == 0:
            label = np.zeros_like(label)
        else:
            label = (label - lowest) / value_range
        median = np.median(label)
        if median != 0:
            label = label / median

        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        label_h = label.shape[0]
        label_w = label.shape[1]
        label = np.reshape(label, (1, 1, label_h, label_w))

        y1 = torch.tensor(label).to(gen_device)
        y_pred1 = gen_Z(y1)
        # Drop the singleton axes on the tensor itself, then move to NumPy.
        y_eval = y_pred1.squeeze().cpu().numpy()
        pred_array.append(y_eval)

In [213]:
# Pack the collected predictions into one array and write it to disk in .npy format.
p = np.asarray(pred_array)
with open('./dataset/HK2_NMuMg.npy', 'wb') as out_file:
    np.save(out_file, p)