# Generative Weak Segmentation

#### Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms.v2 as v2

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

# Fix the NumPy and PyTorch RNG seeds so that later random operations
# (weight init, dataset split, shuffling) repeat from run to run.
np.random.seed(0)
torch.manual_seed(0)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#### TensorBoard logging

In [None]:
import pathlib

# Give every execution its own TensorBoard run: pick the first
# ./logs/segment/run<N> (N = 1, 2, ...) that does not exist yet.
runs_root = pathlib.Path('./logs/segment')
run_id = 1
while (runs_root / f'run{run_id}').exists():
    run_id += 1
logdir = runs_root / f'run{run_id}'

logdir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(logdir)

print(f'Logging to: {logdir}')

#### Hyperparameters

In [None]:
import yaml

# Path to a YAML hyperparameter file; set to '' to use only the defaults below.
# hparams_file = ''
hparams_file = './hparams_seg.yaml'

# Defaults for every hyperparameter this notebook reads. Values from the YAML
# file override these key by key, so a file that omits a key no longer causes
# a KeyError further down.
default_hparams = {
    'image_size': [224, 224],
    'batch_size': 32,
    'num_epochs': 10,
    'lr': 1e-4,
    # model hparams (loss weights; presumably read by model.loss_fn via
    # model.hparams, which the training cell sets -- confirm in seg_model)
    'recon': 1,
    'mask_reg': 1e-10,
    'cls_guide': 1e-4,
}

hparams = dict(default_hparams)
if hparams_file:
    with open(hparams_file, encoding='utf-8') as f:
        # safe_load parses plain YAML data only; it returns None for an empty
        # file, which is treated as "no overrides".
        hparams.update(yaml.safe_load(f) or {})

writer.add_text('hparams', yaml.dump(hparams, sort_keys=False))

#### Prepare dataset

In [None]:
import os
from PIL import Image

class CustomDataset(torch.utils.data.Dataset):
    """Paired positive/negative image dataset.

    ``root`` must contain ``pos/`` and ``neg/`` subdirectories. Item ``idx`` is
    ``(pos_img, neg_img)``: the ``idx``-th regular file of each directory
    (after sorting by path), opened with PIL and, if ``transforms`` is set,
    passed through it separately.
    """

    def __init__(self, root, transforms=None):
        """
        Args:
            root: directory containing ``pos`` and ``neg`` subdirectories.
            transforms: optional callable applied to each loaded PIL image.
        """
        self.transforms = transforms
        # os.scandir yields entries in filesystem-dependent order; sorting makes
        # indices, pos/neg pairing and any downstream random split reproducible.
        self.pos_files = self._list_files(os.path.join(root, 'pos'))
        self.neg_files = self._list_files(os.path.join(root, 'neg'))

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of the regular files directly inside ``directory``."""
        with os.scandir(directory) as it:
            return sorted(entry.path for entry in it if entry.is_file())

    def __len__(self):
        # Each item consumes one file from each directory, so the shorter listing
        # bounds the length (using only len(pos_files) let __getitem__ index
        # past the end of neg_files and raise IndexError).
        return min(len(self.pos_files), len(self.neg_files))

    def __getitem__(self, idx):
        pos_img = self._load(self.pos_files[idx])
        neg_img = self._load(self.neg_files[idx])

        if self.transforms:
            pos_img = self.transforms(pos_img)
            neg_img = self.transforms(neg_img)

        return (pos_img, neg_img)

    @staticmethod
    def _load(path):
        """Open the image at ``path`` and return an in-memory copy."""
        with Image.open(path) as img:
            # copy() gives an image that stays usable after the with-block
            # closes the underlying file.
            return img.copy()

In [None]:
image_size = hparams['image_size']
batch_size = hparams['batch_size']

# Per-image pipeline: PIL image -> tensor image -> resize -> float scaled to [0, 1].
transforms_list = [
    v2.ToImage(),
    # v2.RandomHorizontalFlip(),
    v2.Resize(image_size),
    v2.ToDtype(torch.float, scale=True),
]
transforms_composed = v2.Compose(transforms_list)

dataset = CustomDataset('./dataset/preprocessed/', transforms=transforms_composed)

# A dedicated seeded generator keeps the 80/20 split identical regardless of how
# much of the global RNG earlier cells (or re-runs of this cell) have consumed.
split_generator = torch.Generator().manual_seed(0)
dataset_train, dataset_val = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=split_generator
)

train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=batch_size,
    shuffle=False
)

In [None]:
# Preview one training batch: up to four positives followed by up to four
# negatives, laid out four images per row.
pos_batch, neg_batch = next(iter(train_loader))
pos, neg = pos_batch[:4], neg_batch[:4]
for imgs in (pos, neg):
    print(imgs.shape)

preview = torch.cat([pos, neg], dim=0)
grid_img = torchvision.utils.make_grid(preview, nrow=4)
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

#### Build the model

In [None]:
# Import the module under an alias: binding the classifier instance to the
# module's own name made `cls_model.Classifier` fail when this cell was re-run.
import cls_model as cls_module

# Classifier with a single output (num_classes=1), restored from its training run.
cls_model = cls_module.Classifier(num_classes=1).to(device)
# map_location lets a checkpoint saved on one device (e.g. GPU) load on another (e.g. CPU-only).
cls_model.load_state_dict(torch.load('./logs/cls/run1/best_model.pth', map_location=device))
# NOTE(review): if GenWeakSegNet registers cls_model as a submodule, the optimizer
# built from model.parameters() will also update the classifier unless seg_model
# freezes it -- confirm this is intended.

import seg_model
model = seg_model.GenWeakSegNet(cls_model, num_classes=2).to(device)

from torchinfo import summary
print(summary(model, input_size=(batch_size, 3, *image_size)))

#### Visualization utilities

In [None]:
class Visualizer:
    """Log inputs, reconstructions and predicted masks of a segmentation model to a writer.

    The model is called as ``y_img, y_mask = model(x)``. The reductions below
    imply the class axis is dim 1 of both outputs. Presumably ``y_mask`` is
    (N, C, H, W) class logits and ``y_img`` is (N, C, 3, H, W), one RGB image
    per class -- TODO confirm against seg_model.
    """

    def __init__(self, model, writer, device, batch_size=64):
        self.model = model
        self.writer = writer
        self.device = device
        self.batch_size = batch_size

        # RGB colour (components in [0, 1]) used to paint each class index in the mask image.
        self.mask_colors = [
            (0, 0, 1),
            (0, 0, 0),
        ]

    def vis_samples(self, samples, step, tag):
        """Run the model on ``samples`` and log ``{tag}/x``, ``{tag}/x_hat``, ``{tag}/y_mask``.

        Args:
            samples: sequence of same-shaped image tensors, stacked in chunks of
                ``self.batch_size``.
            step: global step passed to ``writer.add_images``.
            tag: tag prefix for the three logged image batches.

        The model runs in eval mode under ``torch.no_grad()``, and its previous
        train/eval mode is restored even if the forward pass raises. An empty
        ``samples`` logs nothing.
        """
        if len(samples) == 0:
            return

        x_all, x_hat_all, y_mask_all = [], [], []

        training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for i in range(0, len(samples), self.batch_size):
                    x = torch.stack(samples[i:i + self.batch_size]).to(self.device)
                    y_img, y_mask = self.model(x)

                    # Reparameterized categorical sample: one hard (one-hot) class per
                    # pixel. Sampling must run over the class axis (dim 1); flattening
                    # with view(-1, num_classes) grouped contiguous spatial values instead.
                    sm = F.gumbel_softmax(y_mask, hard=True, dim=1).unsqueeze(2)
                    # Compose the image: each pixel takes the RGB value of its sampled class.
                    x_hat = torch.sum(sm * y_img, dim=1)

                    x_all.append(x)
                    x_hat_all.append(x_hat)
                    y_mask_all.append(torch.argmax(y_mask, dim=1))
        finally:
            self.model.train(training)

        x_all = torch.cat(x_all, dim=0)
        x_hat_all = torch.cat(x_hat_all, dim=0)
        labels = torch.cat(y_mask_all, dim=0).unsqueeze(1)

        # Paint each class index with its colour; indices without a colour stay black.
        mc_all = torch.zeros_like(x_hat_all)
        for i, color in enumerate(self.mask_colors):
            mask = (labels == i).float()
            color = torch.tensor(color, dtype=torch.float, device=mc_all.device).view(1, 3, 1, 1)
            mc_all += mask * color

        self.writer.add_images(f'{tag}/x', x_all, step)
        self.writer.add_images(f'{tag}/x_hat', x_hat_all, step)
        self.writer.add_images(f'{tag}/y_mask', mc_all, step)

In [None]:
visualizer = Visualizer(model, writer, device, batch_size=batch_size)
n_vis = 50


def _first_pairs(ds, n):
    """Return the first ``n`` positives of ``ds`` followed by its first ``n`` negatives.

    Each item is loaded once (indexing ``ds[i][0]`` and ``ds[i][1]`` separately
    read every image pair twice). ``n`` is clamped to ``len(ds)`` so a small split
    does not raise IndexError.
    """
    pairs = [ds[i] for i in range(min(n, len(ds)))]
    return [p for p, _ in pairs] + [q for _, q in pairs]


vx_train = _first_pairs(dataset_train, n_vis)
vx_val = _first_pairs(dataset_val, n_vis)

# Log the untrained model's outputs as step 0.
visualizer.vis_samples(vx_train, 0, 'train')
visualizer.vis_samples(vx_val, 0, 'val')

#### Training and evaluation

In [None]:
def evaluate(model, dataloader, device=None):
    """Return the per-image mean of every loss term reported by ``model.loss_fn``.

    Each ``(pos, neg)`` batch is concatenated into one batch and labelled exactly
    like the training loop does: ``[1, 1]`` per positive, ``[0, 1]`` per negative.

    Args:
        model: module returning ``(y_img, y_mask)`` and exposing
            ``loss_fn(x, label, y_img, y_mask) -> (loss, dict_of_scalar_losses)``.
        dataloader: iterable of ``(pos, neg)`` image-tensor batches.
        device: device to evaluate on. Defaults to the device of the model's
            first parameter, or the CPU if the model has no parameters.

    Returns:
        dict mapping each loss name to a Python float: the batch values weighted
        by the number of images in each batch (pos + neg) and divided by the
        total number of images. Empty if ``dataloader`` yields nothing.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored, even on error.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    totals = {}
    n_seen = 0

    training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for pos, neg in tqdm(dataloader, leave=False):
                pos = pos.to(device)
                neg = neg.to(device)
                x = torch.cat((pos, neg), dim=0)

                y_pos = torch.tensor([1, 1], dtype=torch.float, device=device).repeat(pos.shape[0], 1)
                y_neg = torch.tensor([0, 1], dtype=torch.float, device=device).repeat(neg.shape[0], 1)
                label = torch.cat((y_pos, y_neg), dim=0)

                y_img, y_mask = model(x)
                _, batch_losses = model.loss_fn(x, label, y_img, y_mask)

                # Weight by the number of images actually scored (pos + neg) and
                # normalise by the same count. The old version weighted by
                # 2 * batch but divided by the number of pairs, doubling every mean.
                n = x.shape[0]
                n_seen += n
                for k, v in batch_losses.items():
                    # float() accepts both 0-d tensors and Python numbers.
                    totals[k] = totals.get(k, 0.0) + float(v) * n
    finally:
        model.train(training)

    return {k: total / n_seen for k, total in totals.items()}

##### Training

In [None]:
num_epochs = hparams['num_epochs']
optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'])

# Attach the hyperparameters to the model (presumably read by model.loss_fn
# for the loss weights -- confirm in seg_model).
model.hparams = hparams
model.train()
step = 0
best_val_loss = np.inf

for epoch in tqdm(range(num_epochs)):
    for pos, neg in tqdm(train_loader, leave=False):
        pos = pos.to(device)
        neg = neg.to(device)
        x = torch.cat((pos, neg), dim=0)

        # Labels are created directly on the target device (no per-step host->device
        # copy): [1, 1] for every positive, [0, 1] for every negative.
        y_pos = torch.tensor([1, 1], dtype=torch.float, device=device).repeat(pos.shape[0], 1)
        y_neg = torch.tensor([0, 1], dtype=torch.float, device=device).repeat(neg.shape[0], 1)
        label = torch.cat((y_pos, y_neg), dim=0)

        y_img, y_mask = model(x)
        loss, loss_dict = model.loss_fn(x, label, y_img, y_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1
        for k, v in loss_dict.items():
            writer.add_scalar(f'train/{k}', v.item(), step)

    val_loss_dict = evaluate(model, val_loader)
    for k, v in val_loss_dict.items():
        writer.add_scalar(f'val/{k}', v, step)

    # Log qualitative samples once per epoch.
    visualizer.vis_samples(vx_train, step, 'train')
    visualizer.vis_samples(vx_val, step, 'val')

    # Checkpoint whenever the aggregate validation loss improves; keep a plain
    # float so the comparison does not hold on to a tensor.
    val_loss = float(val_loss_dict['loss'])
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), logdir / 'best_model.pth')

torch.save(model.state_dict(), logdir / 'last_model.pth')
# SummaryWriter buffers events; flush so the last epoch is visible in TensorBoard.
writer.flush()

#### Evaluate the best checkpoint on the validation set

In [None]:
# Reload the checkpoint with the lowest validation loss and report its losses on val_loader.
best_state = torch.load(logdir / 'best_model.pth')
model.load_state_dict(best_state)
loss_dict = evaluate(model, val_loader)
print(loss_dict)