In [60]:
import torchvision.datasets
from torch.utils.data import Dataset
import torch
import pandas as pd
import os
import numpy as np
import glob
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data.dataset import T_co
import random

In [61]:
annot_root = "Data/train"
images_only_root = "Data/train_semi_supervised"
test_root = "Data/test"
train_csv = "Data/train.csv"
epsilon = np.finfo(float).eps
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 4
device

device(type='cuda', index=0)

In [62]:
def process_df(df: pd.DataFrame) -> pd.DataFrame:
	df = df.set_index('id')
	annotations = df.groupby(['id'])['annotation'].transform(
		lambda x: ' '.join(x)).drop_duplicates()
	types = df.groupby(['id'])['cell_type'].unique()
	return annotations.to_frame().join(types)

def reconstruct_annotations_image(annotations, height, width) -> np.array:
	annotations = np.array(annotations.split(" ")).astype(int)
	image = np.zeros(shape=(height * width))
	for i in np.arange(0, len(annotations), 2):
		begin = annotations[i]
		length = annotations[i+1] - 1
		image[begin: begin+length] = 1
	return image.reshape(height, width)

def ToTensor(sample):
	image, ann_image = sample['image'], sample['annotation']
	# swap color axis because
	# numpy image: H x W x C
	# torch image: C x H x W
	image = image.transpose((2, 0, 1))
	ann_image = ann_image.transpose((2, 0, 1))
	return {'image': torch.from_numpy(image),
			'annotation': torch.from_numpy(ann_image)}

class RandomChoiceTransform(torch.nn.Module):
	def __init__(self, transforms):
	   super().__init__()
	   self.transforms = transforms

	def __call__(self, imgs):
		t = random.choice(self.transforms)
		return [t(img) for img in imgs]

class CellsAnnTrainDataSet(Dataset):
	def __init__(self, root_dir, csv_file, color_transform=None, shape_transform=None):
		self.df = None
		df = pd.read_csv(csv_file)
		self.df = process_df(df)
		self.root_dir = root_dir
		self.color_transform = color_transform
		self.shape_transform = shape_transform

	def __len__(self):
		return len(self.df)

	def __getitem__(self, index):
		if torch.is_tensor(index):
			index = index.tolist()
		image_id = self.df.iloc[index].name
		image_name = f"{image_id}.png"
		image_path = os.path.join(self.root_dir, image_name)
		image = plt.imread(image_path)[:,:,np.newaxis]
		width, height = image.shape[:2]
		annotations = self.df.iloc[index]['annotation']
		ann_image = reconstruct_annotations_image(annotations, width, height)[:,:,np.newaxis]
		if self.color_transform is not None:
			image = self.color_transform(image)
		if self.shape_transform is not None:
			image = self.shape_transform(image)
			ann_image = self.shape_transform(ann_image)
		return ToTensor({'image': image, "annotation": ann_image})

In [63]:
image_w_ann_data = CellsAnnTrainDataSet(annot_root, train_csv)
image_only_data = torchvision.datasets.ImageFolder(images_only_root, transform=[transforms.ToTensor()])
images_w_ann_loader = torch.utils.data.DataLoader(image_w_ann_data, batch_size = batch_size, shuffle=True)
images_only_loader = torch.utils.data.DataLoader(image_only_data, batch_size = batch_size, shuffle=True)

sample_image, sample_ann = image_w_ann_data[0]['image'], image_w_ann_data[0]['annotation']

In [64]:
def formatForPlot(*args):
	outputs = []
	for arg in args:
		if torch.is_tensor(arg):
			arg =  arg.detach().cpu().numpy()
		outputs.append(arg.squeeze())
	return outputs if len(outputs) > 1 else outputs[0] if len(outputs) > 0 else None

def plot_generator(orig_image, orig_ann, fake_ann):
	orig_image, orig_ann, fake_ann = formatForPlot(orig_image, orig_ann, fake_ann)
	fig = plt.figure(figsize=(15,5))
	ax = fig.subplots(ncols=3)
	ax[0].imshow(orig_image.squeeze(), cmap='gray')
	ax[1].imshow(orig_ann.squeeze(), cmap='gray')
	ax[2].imshow(fake_ann.squeeze(), cmap='gray')
	ax[0].set_title("image")
	ax[1].set_title("annotation")
	ax[1].set_title("fake annotation")
	plt.show()

In [None]:
from functools import reduce
from operator import __add__
class Conv2dSamePadding(nn.Conv2d):
    def __init__(self,*args,**kwargs):
        super(Conv2dSamePadding, self).__init__(*args, **kwargs)
        self.zero_pad_2d = nn.ZeroPad2d(reduce(__add__,
            [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]]))

    def forward(self, input):
        return  self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)

In [65]:
def init_normal_weights(m, mean=0, std=0.02):
	if isinstance(m, nn.Conv2d):
		torch.nn.init.normal_(m.weight, mean=mean, std=std)

class Discriminator(nn.Module):
	def __init__(self):
		super(Discriminator, self).__init__()
		self.conv1 = nn.Conv2d(2, 64, kernel_size=(4,4), stride=(2,2), padding="same")
		self.conv2 = nn.Conv2d(64, 128, kernel_size=(4,4), stride=(2,2), padding="same")
		self.batchNorm2 = nn.BatchNorm2d(128)
		self.conv3 = nn.Conv2d(128, 256, kernel_size=(4,4), stride=(2,2), padding="same")
		self.batchNorm3 = nn.BatchNorm2d(256)
		self.conv4 = nn.Conv2d(256, 512, kernel_size=(4,4), stride=(2,2), padding="same")
		self.batchNorm4 = nn.BatchNorm2d(512)
		self.conv5 = nn.Conv2d(512, 512, kernel_size=(4,4), stride=(2,2), padding="same")
		self.batchNorm5 = nn.BatchNorm2d(512)
		self.conv6 = nn.Conv2d(512, 1, kernel_size=(4,4), stride=(2,2), padding="same")
		self.fc = nn.Linear(512, 1)
		for conv in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6]:
			init_normal_weights(conv)


	def forward(self, x, y):
		z = torch.cat((x, y), dim=1)
		z = F.leaky_relu(self.conv1(z), negative_slope=0.2, inplace=True)
		z = F.leaky_relu(self.batchNorm2(self.conv2(z)), negative_slope=0.2, inplace=True)
		z = F.leaky_relu(self.batchNorm3(self.conv3(z)), negative_slope=0.2, inplace=True)
		z = F.leaky_relu(self.batchNorm4(self.conv4(z)), negative_slope=0.2, inplace=True)
		z = F.leaky_relu(self.batchNorm5(self.conv5(z)), negative_slope=0.2, inplace=True)
		z = self.conv6(z)
		z = self.fc(z)
		return torch.sigmoid(z)

class EncoderBlock(nn.Module):
	def __init__(self, in_ch, out_ch, batchNorm=True):
		super(EncoderBlock, self).__init__()
		layers = [nn.Conv2d(in_ch, out_ch, kernel_size=(4,4), stride=(2,2), padding="same")]
		if batchNorm:
			layers.append(nn.BatchNorm2d(out_ch))
		layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
		init_normal_weights(layers[0])
		self.layers = nn.Sequential(*layers)

	def forward(self, x):
		return self.layers(x)

class DecoderBlock(nn.Module):
	def __init__(self, in_ch, out_ch, dropout = True):
		super(DecoderBlock, self).__init__()
		layers = [
			nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(4,4), stride=(2,2), padding='same'),
			nn.BatchNorm2d(out_ch)
		]
		if dropout:
			layers.append(nn.Dropout(0.5))
		init_normal_weights(layers[0])
		self.layers = nn.Sequential(*layers)

	def forward(self, x, skip_in):
		x = self.layers(x)
		return F.relu(torch.cat((x, skip_in), dim=1))

class Generator(nn.Module):
	def __init__(self):
		super(Generator, self).__init__()
		self.e1 = EncoderBlock(1, 64, batchNorm=False)
		self.e2 = EncoderBlock(64, 128)
		self.e3 = EncoderBlock(128, 256)
		self.e4 = EncoderBlock(256, 512)
		self.e5 = EncoderBlock(512, 512)
		self.e6 = EncoderBlock(512, 512)
		self.e7 = EncoderBlock(512, 512)
		self.bottleneck = nn.Sequential(
			nn.Conv2d(512, 512, kernel_size=(4,4), stride=(2,2), padding='same'),
			nn.ReLU(inplace=True)
		)
		init_normal_weights(self.bottleneck[0])
		self.d1 = DecoderBlock(512, 512)
		self.d2 = DecoderBlock(512, 512)
		self.d3 = DecoderBlock(512, 512)
		self.d4 = DecoderBlock(512, 512, dropout=False)
		self.d5 = DecoderBlock(512, 256, dropout=False)
		self.d6 = DecoderBlock(256, 128, dropout=False)
		self.d7 = DecoderBlock(128, 64, dropout=False)
		self.output = nn.Sequential(
			nn.ConvTranspose2d(64, 1, kernel_size=(4,4), stride=(2,2), padding='same'),
			nn.Tanh()
		)
		init_normal_weights(self.output[0])

	def forward(self, x):
		# encoding:
		e1 = self.e1(x)
		e2 = self.e2(e1)
		e3 = self.e3(e2)
		e4 = self.e4(e3)
		e5 = self.e5(e4)
		e6 = self.e6(e5)
		e7 = self.e7(e6)

		# bottleneck:
		b = self.bottleneck(e7)

		# decoding:
		d1 = self.d1(b, e7)
		d2 = self.d2(d1, e6)
		d3 = self.d3(d2, e5)
		d4 = self.d4(d3, e4)
		d5 = self.d5(d4, e3)
		d6 = self.d6(d5, e2)
		d7 = self.d7(d6, e1)
		return self.output(d7)

In [66]:
class CellGAN(nn.Module):
	def __init__(self):
		super(CellGAN, self).__init__()
		self.D = Discriminator()
		self.G = Generator()
		self.to(device)

	def forward(self, x):
		self.G.eval()
		return self.G(x)


	def G_loss(self, image):
		self.G.train()
		self.D.eval()
		fake_ann = self.G(image)
		D_fake = self.D(image, fake_ann)
		return -torch.mean(torch.log(D_fake + epsilon))

	def D_loss(self, image, real_ann = None):
		self.D.train()
		self.G.eval()
		fake_ann = self.G(image)
		D_fake = self.D(image, fake_ann)
		if real_ann is not None:
			D_real = self.D(image, real_ann)
			return -torch.mean(torch.log(D_real + epsilon) + torch.log(1-D_fake + epsilon))
		else:
			return -torch.mean(torch.log(1-D_fake + epsilon))

	def run_training(self,epochs: int = 10, G_lr: int = 2e-4, D_lr: int = 2e-4, plot:bool = True):
		D_Opt = torch.optim.Adam(params=self.D.parameters(), lr=D_lr, betas=(0.5,0.5))
		G_Opt = torch.optim.Adam(params=self.G.parameters(), lr=G_lr, betas=(0.5,0.5))
		d_losses, g_losses = [], []
		for epoch in range(1, epochs+1):
			images_only_iter = iter(images_only_loader)
			d_epoch_losses, g_epoch_losses = [], []
			for i, batch in enumerate(images_w_ann_loader, 0):
				image, ann = batch['image'].to(device), batch['annotation'].to(device)
				d_step_loss, g_step_loss = 0, 0
				# train discriminator:
				d_loss = self.D_loss(image, ann)
				D_Opt.zero_grad()
				d_loss.backward()
				D_Opt.step()
				d_step_loss += d_loss.item() / 2

				# train generator:
				g_loss = self.G_loss(image)
				G_Opt.zero_grad()
				g_loss.backward()
				G_Opt.step()
				g_step_loss += g_loss.item() / 2

				# semi supervised:
				for _ in range(3):
					step_d_losses, step_g_losses = [],[]
					image = next(images_only_iter)

					# train discriminator:
					d_loss = self.D_loss(image)
					D_Opt.zero_grad()
					d_loss.backward()
					D_Opt.step()
					g_step_loss += d_loss.item() / 6

					# train generator:
					g_loss = self.G_loss(image)
					G_Opt.zero_grad()
					g_loss.backward()
					G_Opt.step()
					g_step_loss += g_loss.item() / 6
				d_epoch_losses.append(d_step_loss)
				g_epoch_losses.append(g_step_loss)

			d_losses.append(np.mean(d_epoch_losses))
			g_losses.append(np.mean(g_epoch_losses))
			print(f'epoch [{epoch}/{epochs}],'
				  f' generator loss:{g_losses[epoch-1].round(3)},'
				  f' discriminator loss:{d_losses[epoch-1].round(3)}')
			if plot:
				self.G.eval()
				with torch.no_grad:
					fake_ann = self.G(sample_image)
					plot_generator(sample_image, sample_ann, fake_ann)

In [67]:
gan = CellGAN()
gan.run_training()

torch.Size([4, 1, 520, 704])
torch.Size([4, 64, 259, 351])
torch.Size([4, 128, 128, 174])
torch.Size([4, 256, 63, 86])
torch.Size([4, 512, 30, 42])


RuntimeError: Sizes of tensors must match except in dimension 2. Got 63 and 62 (The offending index is 0)