# Striving for Simplicity: Simple Yet Effective Prior-Aware Pseudo-Labeling for Semi-Supervised Ultrasound Image Segmentation

* Paper Code : 4867
* Paper Link : https://papers.miccai.org/miccai-2024/735-Paper2948.html
* Reproduction Level : 3 (25 points)
* Github Link : https://github.com/prachuryanath/SPSS-Reproduction

The paper presents a straightforward yet powerful pseudo-labeling technique for semi-supervised ultrasound image segmentation, effectively tackling the challenges of having limited labeled data and dealing with anatomical inaccuracies.

Instead of relying on complex methods, the proposed encoder-twin-decoder network uses an adversarially learned shape prior to ensure the segmentations are both anatomically accurate and aligned with the ground truth. This simple approach outperforms state-of-the-art techniques on two benchmarks, offering a solid foundation for future research in semi-supervised medical image segmentation. By striking a balance between labeled and unlabeled data, the method enhances the precision and usability of automated ultrasound analysis.

### By - Prachurya Nath

<div>
<img src="https://github.com/WUTCM-Lab/Shape-Prior-Semi-Seg/assets/155703366/029950f0-78a2-400a-91a4-837d8166a1cd" width="750" />
</div>

## Hardware Comments
* Graphics Used : A100 (40 GB)
* Training Time GAN : 19 hrs 52 mins
* Training Time CNN Model : 6 hrs 57 mins

## Figure to Reproduce : Fig 3

<div>
<img src="https://raw.githubusercontent.com/prachuryanath/WBC-NCA---Reproduction/refs/heads/main/images/spss_table.jpg" width="600" />
</div>

### Steps to download dataset

* Extract it into tn3k folder

In [10]:
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1reHyY5eTZ5uePXMVMzFOq5j3eFOSp50F' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1reHyY5eTZ5uePXMVMzFOq5j3eFOSp50F" -O thyroid.zip && rm -rf /tmp/cookies.txt

--2025-02-18 19:22:57--  https://docs.google.com/uc?export=download&confirm=&id=1reHyY5eTZ5uePXMVMzFOq5j3eFOSp50F
Resolving docs.google.com (docs.google.com)... 2a00:1450:4001:831::200e, 142.250.184.238
Connecting to docs.google.com (docs.google.com)|2a00:1450:4001:831::200e|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1reHyY5eTZ5uePXMVMzFOq5j3eFOSp50F&export=download [following]
--2025-02-18 19:22:57--  https://drive.usercontent.google.com/download?id=1reHyY5eTZ5uePXMVMzFOq5j3eFOSp50F&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 2a00:1450:4001:81d::2001, 216.58.206.65
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|2a00:1450:4001:81d::2001|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2431 (2.4K) [text/html]
Saving to: ‘thyroid.zip’


2025-02-18 19:22:57 (27.5 MB/s) - ‘thyroid.zip’ saved [2431/2431]
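
The log above reports a 2,431-byte `text/html` response, which suggests that Google Drive may have returned a confirmation page rather than the archive itself. The following is a minimal sketch, assuming the download is saved as `thyroid.zip` and should be unpacked into `tn3k/`; the helper name `extract_dataset` is hypothetical and not part of the repository. It checks that the file is a real zip archive before extracting it.

```python
import zipfile
from pathlib import Path


def extract_dataset(zip_path="thyroid.zip", dest="tn3k"):
    """Extract zip_path into dest; return False if the file is not a zip archive."""
    zip_path = Path(zip_path)
    # An HTML page saved under a .zip name fails this check.
    if not zip_path.is_file() or not zipfile.is_zipfile(zip_path):
        return False
    Path(dest).mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
    return True
```

If the function returns `False`, the archive probably needs to be downloaded again (for example manually through the browser) before the dataset can be extracted.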



## Create environment and install libraries

* python -m venv .spss
* source .spss/bin/activate
* pip install -r requirements.txt

## Import libraries

In [1]:
# Import necessary libraries
from __future__ import print_function
import argparse
import random
import torch
import yaml
import warnings
import os
import json

import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.nn import functional as F
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

from GAN.data.tn3k import tn3kDataSet
import GAN.models.dcgan as dcgan
import GAN.models.mlp as mlp

# Ignore UserWarnings to avoid clutter in output
warnings.filterwarnings("ignore", category=UserWarning)

## GAN

**From the paper :** We first train a GAN leveraging existing labels. Specifically, we resize all ground truth segmentation masks into a fixed size of 64 × 64 and set batch size to 16. We optimize the generator and discriminator using two RMSprop optimizers with a learning rate of 0.00005, and the total number of epochs is set to 5,000.

## Model configuration

In [5]:
# Load config file
with open("config_gan.yaml", "r") as f:
    config = yaml.safe_load(f)
    
from types import SimpleNamespace

opt = SimpleNamespace(**config)
print(opt) 

namespace(dataset='tn3k', dataroot='tn3k/', workers=2, batchSize=16, imageSize=64, nc=1, nz=100, ngf=64, ndf=64, niter=5001, lrD=5e-05, lrG=5e-05, beta1=0.5, ngpu=2, netG='', netD='', clamp_lower=-0.01, clamp_upper=0.01, Diters=5, noBN=False, mlp_G=False, mlp_D=False, n_extra_layers=0, experiment='result', adam=False, root='', expID=1)


In [6]:
# Create result directory if not existing
if opt.experiment is None:
    opt.experiment = 'samples'
os.makedirs(opt.experiment, exist_ok=True)

# Draw a random seed and print it; hard-code the printed value to repeat a run exactly
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

Random Seed:  5739


## Preprocessing

In [7]:
# Load the appropriate dataset
if opt.dataset == 'tn3k':
    dataset = tn3kDataSet(opt.root, opt.expID, mode='train')
assert dataset

# DataLoader for batching and shuffling the dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                        shuffle=True, num_workers=int(opt.workers))

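# Define model hyperparameters (values in parentheses come from config_gan.yaml)
ngpu = int(opt.ngpu)  # number of GPUs (2)
nz = int(opt.nz)  # size of the latent z vector (100)
ngf = int(opt.ngf)  # generator feature-map size (64)
ndf = int(opt.ndf)  # discriminator feature-map size (64)
nc = int(opt.nc)  # number of image channels (1, single-channel masks)
n_extra_layers = int(opt.n_extra_layers)  # number of extra layers on generator and discriminator (0)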

In [8]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Choose Generator architecture based on options
if opt.noBN:
    netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
elif opt.mlp_G:
    netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
else:
    netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)

In [9]:
# Save Generator configuration to JSON
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}

with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
    gcfg.write(json.dumps(generator_config)+"\n")

## Initializing the model

In [10]:
# Initialize Generator weights
netG.apply(weights_init)
if opt.netG != '': # load checkpoint if needed
    netG.load_state_dict(torch.load(opt.netG))

# Choose Discriminator architecture based on options    
if opt.mlp_D:
    netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
else:
    netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers)
    netD.apply(weights_init)

if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))

# Setup input tensors for real and fake data
input = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize)  # nc=1: single-channel masks
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
one = torch.FloatTensor([1])
mone = one * -1

In [11]:
# Enable GPU usage if available
if torch.cuda.is_available():
    print("using cuda ===================================== ")
    print("No of cuda devices :", torch.cuda.device_count())
    netD = nn.DataParallel(netD, device_ids=[0])
    netD = netD.cuda(0)
    netG = nn.DataParallel(netG, device_ids=[0])
    netG.cuda(0)
    input = input.cuda(0)
    one, mone = one.cuda(0), mone.cuda(0)
    noise, fixed_noise = noise.cuda(0), fixed_noise.cuda(0)

No of cuda devices : 1


In [12]:
# setup optimizer
if opt.adam:
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999))
else:
    optimizerD = optim.RMSprop(netD.parameters(), lr = opt.lrD)
    optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lrG)

## Training the GAN model

In [None]:
# Initialize generator iterations counter
gen_iterations = 0
for epoch in range(opt.niter):
    data_iter = iter(dataloader)
    i = 0
    while i < len(dataloader):
        # Discriminator (D) network
        # Enable gradient computation for the discriminator parameters
        for p in netD.parameters(): 
            p.requires_grad = True

        # Set the number of discriminator iterations based on the current training step
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            Diters = 100
        else:
            Diters = opt.Diters
            
        # Train the discriminator Diters times per generator update
        j = 0
        while j < Diters and i < len(dataloader):
            j += 1

            # clamp parameters to a cube
            for p in netD.parameters():
                p.data.clamp_(opt.clamp_lower, opt.clamp_upper)

            data = next(data_iter)
            i += 1

            # train with real images
            real = data['label']
            real_cpu = F.interpolate(real, size=(64, 64), mode='bilinear', align_corners=False)
            netD.zero_grad()
            batch_size = real_cpu.size(0)

            if torch.cuda.is_available():
                real_cpu = real_cpu.cuda(0)
            input.resize_as_(real_cpu).copy_(real_cpu)
            inputv = Variable(input)
            
            # Compute loss for real images
            errD_real = netD(inputv)
            errD_real.backward(one)

            # train with fake images generated by the generator
            noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
            with torch.no_grad():  # freeze netG; the `volatile` flag was removed in PyTorch 0.4
                fake = netG(noise)
            inputv = fake
            
            # Compute discriminator's error for fake images
            errD_fake = netD(inputv)
            errD_fake.backward(mone)
            errD = errD_real - errD_fake
            optimizerD.step()
            
        # Generator (G) Network
        for p in netD.parameters():
            p.requires_grad = False # to avoid computation
        netG.zero_grad()
        
        # Generate new noise to feed into the generator
        noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev)

        # Compute generator's error based on discriminator's response
        errG = netD(fake)
        errG.backward(one)
        optimizerG.step()
        gen_iterations += 1

        # Print loss values for the current step
        print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
            % (epoch, opt.niter, i, len(dataloader), gen_iterations,
            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

        # Save samples of real and fake images every 500 iterations
        if gen_iterations % 500 == 0:
            real = real.mul(0.5).add(0.5)
            vutils.save_image(real, '{0}/real_samples.png'.format(opt.experiment))
            with torch.no_grad():  # inference only; replaces the removed `volatile` flag
                fake = netG(fixed_noise)
            fake = F.interpolate(fake, size=(256, 256), mode='bilinear', align_corners=False)
            fake = fake.mul(0.5).add(0.5)
            vutils.save_image(fake, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))

    # Checkpointing: Save model weights every 1000 epochs
    if epoch % 1000 == 0:   
        torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
        torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))
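
The training loop above uses two schedules that are easy to check in isolation: how many discriminator updates are requested before each generator update, and at which epochs a checkpoint is written. The sketch below mirrors the conditions from the loop; the function names `critic_steps` and `checkpoint_epochs` are hypothetical. The real loop can run fewer discriminator updates than requested, because the inner `while` also stops once the dataloader is exhausted.

```python
def critic_steps(gen_iterations, default_diters=5):
    """Number of discriminator updates requested before the next generator update."""
    if gen_iterations < 25 or gen_iterations % 500 == 0:
        return 100
    return default_diters


def checkpoint_epochs(niter, every=1000):
    """Epoch indices at which `epoch % every == 0` triggers a checkpoint."""
    return [epoch for epoch in range(niter) if epoch % every == 0]


print([critic_steps(g) for g in (0, 24, 25, 499, 500, 501)])
# [100, 100, 5, 5, 100, 5]
print(checkpoint_epochs(5001))
# [0, 1000, 2000, 3000, 4000, 5000]
```

With `niter=5001` the last checkpoint falls on epoch 5000, which matches the `netD_epoch_5000.pth` file loaded later for the segmentation network; with `niter=5000` the last checkpoint would be epoch 4000.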

### GAN outputs 
* The fake images improve gradually as training progresses, as the samples below show.

### Real Sample :
<div>
<img src="https://raw.githubusercontent.com/prachuryanath/WBC-NCA---Reproduction/refs/heads/main/images/real_samples.png" width="600" />
</div>

### Fake sample after 250 iterations
<div>
<img src="https://raw.githubusercontent.com/prachuryanath/WBC-NCA---Reproduction/refs/heads/main/images/fake_samples_250.png" width="600" />
</div>

### Fake sample after 24500 iterations
<div>
<img src="https://raw.githubusercontent.com/prachuryanath/WBC-NCA---Reproduction/refs/heads/main/images/fake_samples_24500.png" width="600" />
</div>

# Segmentation Network

In [18]:
import torch
import os
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from itertools import cycle
from semi.code.data.build_dataset import build_dataset
from semi.code.models.build_model import build_model
from semi.code.models.dc_gan import DCGAN_D
from semi.code.utils.evaluate import evaluate
from semi.code.utils.loss import BceDiceLoss
import math
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

In [19]:
# Load config file
with open("config_model.yaml", "r") as f:
    config = yaml.safe_load(f)
    
from types import SimpleNamespace

args = SimpleNamespace(**config)
print(args) 

namespace(GPUs='0,1', root='', dataset='tn3k', ratio=2, manner='semi', mode='train', nEpoch=200, batch_size=32, num_workers=2, load_ckpt='best', model='MyModel', expID=3, ckpt_name='tn3k_1', lr=0.001, power=0.9, betas=[0.9, 0.999], weight_decay='1e-5', eps='1e-8', mt=0.9, nclasses=1, band=3)
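
In the namespace printed above, `weight_decay` and `eps` appear as the strings `'1e-5'` and `'1e-8'`, most likely because PyYAML only recognizes exponent notation as a float when the number contains a decimal point (for example `1.0e-5`). Passing a string to an optimizer would probably raise a `TypeError`. The sketch below shows one way to convert such fields after loading; the helper name `coerce_floats` and the `demo` namespace are hypothetical stand-ins.

```python
from types import SimpleNamespace


def coerce_floats(namespace, keys=("weight_decay", "eps")):
    """Return a copy of namespace with the named string fields converted to float."""
    values = dict(vars(namespace))
    for key in keys:
        if isinstance(values.get(key), str):
            values[key] = float(values[key])
    return SimpleNamespace(**values)


# Stand-in for the namespace printed above (only the relevant fields).
demo = SimpleNamespace(lr=0.001, weight_decay="1e-5", eps="1e-8")
fixed = coerce_floats(demo)
print(type(fixed.weight_decay).__name__, fixed.weight_decay)
```

An alternative is to write the values as `1.0e-5` and `1.0e-8` in `config_model.yaml`, so that they are loaded as floats directly.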


In [20]:
# Function to calculate Deep Supervised Segmentation Loss (BCE + Dice)
def DeepSupSeg(pred, gt):
    criterion = BceDiceLoss()
    loss = criterion(pred, gt)
    return loss

# Function to calculate learning rate based on polynomial decay
def lr_poly(base_lr, iter, max_iter, power):
    return base_lr * ((1-float(iter)/max_iter)**power)

# Function to adjust the learning rate during training using lr_poly function
def adjust_lr_rate(optimizer, iter, total_batch):
    lr = lr_poly(args.lr, iter, args.nEpoch*total_batch, args.power)
    optimizer.param_groups[0]['lr'] = lr
    return lr
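
To make the polynomial decay concrete, the sketch below evaluates the same formula at a few iterations. `lr=0.001`, `power=0.9` and `nEpoch=200` come from `config_model.yaml`; the figure of 50 batches per epoch is an assumed value chosen only for illustration.

```python
def lr_poly(base_lr, iteration, max_iter, power):
    """Polynomial decay: base_lr * (1 - iteration / max_iter) ** power."""
    return base_lr * ((1 - float(iteration) / max_iter) ** power)


base_lr, power = 0.001, 0.9   # lr and power from config_model.yaml
max_iter = 200 * 50           # nEpoch=200; 50 batches per epoch is an assumption

for iteration in (0, max_iter // 2, max_iter - 1):
    print(iteration, lr_poly(base_lr, iteration, max_iter, power))
```

The rate starts at `base_lr`, falls almost linearly because `power` is close to 1, and approaches zero at the final iteration.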

## Training the Segmentation Network without GAN weights

In [16]:
# Training function that loads data, initializes model, and performs training loop
def train():
    """load data"""
    train_l_data, _ , valid_data = build_dataset(args)
    train_l_dataloader = DataLoader(train_l_data, args.batch_size, shuffle=True, num_workers=args.num_workers)
    valid_sign = False
    if valid_data is not None:
        valid_sign = True
        valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=args.num_workers)
        val_total_batch = int(len(valid_data) / 1)
    
    """Initialize model and optimizer"""
    model = build_model(args)
    model = nn.DataParallel(model)
    model = model.cuda()
    
    # Using Stochastic Gradient Descent (SGD) optimizer
    # weight_decay is loaded from YAML as the string '1e-5', so cast it to float
    optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mt, weight_decay=float(args.weight_decay))

    # train
    print('\n---------------------------------')
    print('Start training')
    print("No of cuda devices :", torch.cuda.device_count())
    print('---------------------------------\n')

    F1_best, F1_second_best, F1_third_best = 0, 0, 0
    best = 0
    for epoch in range(args.nEpoch):

        model.train() # Set model to training mode

      
        print("Epoch: {}".format(epoch))
        total_batch = math.ceil(len(train_l_data) / args.batch_size)
        bar = tqdm(enumerate(train_l_dataloader), total=total_batch)
        for batch_id, data_l in bar:
            itr = total_batch * epoch + batch_id
            img, gt = data_l['image'], data_l['label']
            if torch.cuda.is_available():
                img = img.cuda()
                gt = gt.cuda()
            optim.zero_grad()
            mask = model(img)
            loss = DeepSupSeg(mask, gt) 
            loss.backward()
            optim.step()
            adjust_lr_rate(optim, itr, total_batch)
            
        # Validation step if validation data is provided
        if valid_sign:
            recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, dice, list_name, list_point = evaluate(model, valid_dataloader, val_total_batch)

            print("Valid Result:")
            print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f, dice: %.4f' \
                % (recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, dice))

            # Track and save best model based on dice score
            if dice > best:
                best = dice
            print("Best Dice:: ", best)

            # Save model checkpoints based on F1 score performance
            # os.path.join keeps the path relative when args.root is empty
            ckpt_dir = os.path.join(args.root, "semi/checkpoint", args.ckpt_name)
            if F1 > F1_best:
                F1_best = F1
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "best.pth"))
            elif F1 > F1_second_best:
                F1_second_best = F1
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "second_best.pth"))
            elif F1 > F1_third_best:
                F1_third_best = F1
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "third_best.pth"))
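
In the branch logic above, a new best F1 overwrites `best.pth` without demoting the previous best to second place, so `second_best.pth` is not guaranteed to hold the second-highest score overall. The sketch below is one possible alternative, not the repository's method: a small leaderboard that returns the rank of each new score. The class name `TopKScores` is hypothetical, the F1 values are illustrative only, and file handling is deliberately left out.

```python
import bisect


class TopKScores:
    """Keep the k highest scores seen so far, best first."""

    def __init__(self, k=3):
        self.k = k
        self.scores = []  # sorted in descending order

    def update(self, score):
        """Insert score; return its 0-based rank, or None if it misses the top k."""
        # bisect works on ascending lists, so search on negated scores.
        negated = [-s for s in self.scores]
        rank = bisect.bisect_right(negated, -score)
        if rank >= self.k:
            return None
        self.scores.insert(rank, score)
        del self.scores[self.k:]
        return rank


tracker = TopKScores(k=3)
for f1 in (0.60, 0.72, 0.65, 0.80, 0.50):  # illustrative values, not measured results
    print(f1, tracker.update(f1), tracker.scores)
```

The returned rank could be mapped to the `best`, `second_best` and `third_best` file names, with lower-ranked files shifted down before the new weights are saved.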

## Training Segmentation Network with DSR (GAN) weights
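
The objective optimized in the cell below can be summarized as follows. This is a reading of the code rather than a formula quoted from the paper: the symbols mirror the variable names `loss_l`, `loss_u_seg` and `loss_u_shape`, $D$ denotes the pre-trained discriminator `netD`, and $p_1, \dots, p_5$ stand for `predboud`, `inpimg2`, `inpimg3`, `inpimg4` and `inpimg5` after bilinear resizing to $64 \times 64$.

$$
\mathcal{L} = 2\,\mathcal{L}_{l} + \mathcal{L}_{u},
\qquad
\mathcal{L}_{u} = \mathcal{L}_{u}^{\text{seg}} + 0.1 \cdot \frac{1}{5} \sum_{k=1}^{5} D\big(\mathrm{resize}_{64 \times 64}(p_k)\big)
$$

Here $\mathcal{L}_{l}$ and $\mathcal{L}_{u}^{\text{seg}}$ are both computed with `BceDiceLoss`, the first against the ground-truth mask and the second against `mask_boud`. Because the GAN cell trains the generator to minimize $D$ on its outputs, minimizing the averaged discriminator score pushes the predicted masks toward shapes the discriminator treats as real.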

In [None]:
# This function is used to train the model with both labeled and unlabeled data using GAN model
def train_semi():
    # Load the dataset (labeled, unlabeled, and validation data)
    train_l_data, train_u_data, valid_data = build_dataset(args)
    train_l_dataloader = DataLoader(train_l_data, args.batch_size, shuffle=True, num_workers=args.num_workers)
    train_u_dataloader = DataLoader(train_u_data, args.batch_size, shuffle=True, num_workers=args.num_workers)
    valid_sign = False
    if valid_data is not None:
        valid_sign = True
        valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=args.num_workers)
        val_total_batch = int(len(valid_data) / 1)
        
    # Load the model
    model = build_model(args)
    model = nn.DataParallel(model)
    model = model.cuda()
    
    # Load the discriminator model for the GAN
    netD = DCGAN_D(64, 100, 1, 64, 1, 0)
    netD = nn.DataParallel(netD)
    netD.cuda()
    
    # Load pre-trained weights for the discriminator
    netD_weight = torch.load("GAN/result/netD_epoch_5000.pth")
    netD.load_state_dict(netD_weight)
    netD.eval()
    
    # Initialize the optimizer for the model
    # weight_decay is loaded from YAML as the string '1e-5', so cast it to float
    optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mt, weight_decay=float(args.weight_decay))

    # train
    print('\n---------------------------------')
    print('Start training_semi')
    print('---------------------------------\n')
    F1_best, F1_second_best, F1_third_best = 0, 0, 0
    best = 0
    for epoch in range(args.nEpoch):
        model.train()
        print("Epoch: {}".format(epoch))
        loader = iter(zip(cycle(train_l_dataloader), train_u_dataloader))
        bar = tqdm(range(len(train_u_dataloader)))
        
        # Iterate through the training batches
        for batch_id in bar:
            data_l, data_u = next(loader)
            total_batch = len(train_u_dataloader)
            itr = total_batch * epoch + batch_id
            img_l, gt = data_l['image'], data_l['label']
            img_u = data_u
            if torch.cuda.is_available():
                img_l = img_l.cuda()
                gt = gt.cuda()
                img_u = img_u.cuda()
            optim.zero_grad()

            # Forward pass for labeled data
            pred_l = model(img_l)
            mask = pred_l[0]
            loss_l_seg = DeepSupSeg(mask, gt)
            loss_l = loss_l_seg
            
            # Forward pass for unlabeled data
            pred_u = model(img_u)
            _, predboud, inpimg2, inpimg3, inpimg4, inpimg5, mask_boud = pred_u
            loss_u_seg = DeepSupSeg(predboud, mask_boud)
            
            # Apply GAN loss to unlabeled data
            shape_u_1 = F.interpolate(predboud, size = (64, 64), mode = 'bilinear', align_corners = False)
            shape_u_2 = F.interpolate(inpimg2, size = (64, 64), mode = 'bilinear', align_corners = False)
            shape_u_3 = F.interpolate(inpimg3, size = (64, 64), mode = 'bilinear', align_corners = False)
            shape_u_4 = F.interpolate(inpimg4, size = (64, 64), mode = 'bilinear', align_corners = False)
            shape_u_5 = F.interpolate(inpimg5, size = (64, 64), mode = 'bilinear', align_corners = False)
            loss_u_shape = (netD(shape_u_1) + netD(shape_u_2) + netD(shape_u_3) + netD(shape_u_4) + netD(shape_u_5)) / 5
            loss_u = loss_u_seg + 0.1 * loss_u_shape
            
            # Total loss combines the losses from labeled and unlabeled data
            loss = 2 * loss_l + loss_u
            loss.sum().backward()
            optim.step()
            
            # Adjust the learning rate
            adjust_lr_rate(optim, itr, total_batch)
        model.eval()
        
        # If validation data is available, evaluate the model after each epoch
        if valid_sign:
            recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, dice, list_name, list_point = evaluate(model, valid_dataloader, val_total_batch)

            print("Valid Result:")
            print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f, dice: %.4f' \
                % (recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean,dice))
            
            # Track the best dice score seen so far
            if dice > best:
                best = dice
            print("Best Dice:: ", best)

            # Save model checkpoints based on F1 score performance
            # os.path.join keeps the path relative when args.root is empty
            ckpt_dir = os.path.join(args.root, "semi/checkpoint", args.ckpt_name)
            if F1 > F1_best:
                F1_best = F1
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "best.pth"))
            elif F1 > F1_second_best:
                F1_second_best = F1
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "second_best.pth"))
            elif F1 > F1_third_best:
                F1_third_best = F1
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "third_best.pth"))
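
The pairing `zip(cycle(train_l_dataloader), train_u_dataloader)` used above defines one epoch as a single pass over the unlabeled loader, with the smaller labeled loader repeated as needed. The sketch below shows the same pattern on plain lists, which are stand-ins for the two dataloaders.

```python
from itertools import cycle

labeled_batches = ["L0", "L1"]                      # stand-in for train_l_dataloader
unlabeled_batches = ["U0", "U1", "U2", "U3", "U4"]  # stand-in for train_u_dataloader

# zip stops when the shorter (non-cycled) iterable is exhausted.
pairs = list(zip(cycle(labeled_batches), unlabeled_batches))
print(pairs)
# [('L0', 'U0'), ('L1', 'U1'), ('L0', 'U2'), ('L1', 'U3'), ('L0', 'U4')]
```

Note that `itertools.cycle` stores the items from its first pass and replays them, so within one epoch the repeated labeled batches are the same ones rather than freshly shuffled batches.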

## Testing the model

In [15]:
# This function is used to evaluate the model on the test dataset.
def test():
  
    print('loading data......')
    test_data = build_dataset(args)
    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=args.num_workers)
    total_batch = len(test_data)
    
    model = build_model(args)
    model.eval()
    
    # Evaluate the model on the test dataset
    recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, dice, list_name, list_point = evaluate(model, test_dataloader, total_batch)
    
    print("Test Result:")
    print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f, dice: %.4f' \
                % (recall, specificity, precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean,dice))
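
The repository's `evaluate` function is not shown in this notebook, so the following is only a generic sketch of how Dice and IoU are commonly defined for a binary mask. The function name `dice_and_iou` is hypothetical, and returning 1.0 when both masks are empty is an assumed convention.

```python
def dice_and_iou(pred, target):
    """Dice and IoU for two equal-length sequences of 0/1 pixel labels."""
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    denom = 2 * tp + fp + fn
    if denom == 0:  # both masks empty: treated as a perfect match (assumption)
        return 1.0, 1.0
    dice = 2 * tp / denom
    iou = tp / (tp + fp + fn)
    return dice, iou


# Flattened 2x2 masks: one true positive, one false positive.
print(dice_and_iou([1, 1, 0, 0], [1, 0, 0, 0]))
```

For a single binary mask, Dice computed this way equals the F1 score, since both reduce to 2TP / (2TP + FP + FN).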

In [16]:
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPUs

In [17]:
# Create checkpoint directory if it doesn't exist
checkpoint_name = os.path.join(args.root, 'semi/checkpoint/' + args.ckpt_name)
os.makedirs(checkpoint_name, exist_ok=True)

In [None]:
# Decide on the mode: full training, semi-supervised training, or testing
if args.manner == 'full':
    print('---{}-Seg Train---'.format(args.dataset))
    train()
elif args.manner =='semi':
    print('---{}-seg Semi-Train--'.format(args.dataset))
    train_semi()
elif args.manner == 'test':
    print('---{}-Seg Test---'.format(args.dataset))
    test()
print('Done')

## Results

* The images below show the results obtained with our complete model. They are close to the results reported in the original paper for the setting that uses a 1/8 subset of the images.


<div>
<img src="https://raw.githubusercontent.com/prachuryanath/WBC-NCA---Reproduction/refs/heads/main/images/Screenshot%202025-02-21%20001220.jpg" width="600" />
</div>


<div>
<img src="https://raw.githubusercontent.com/prachuryanath/WBC-NCA---Reproduction/refs/heads/main/images/Screenshot%202025-02-21%20002126.jpg" width="400" />
</div>

### Some labels and model outputs

#### First segmentation

* Real mask :
<div>
<img src="https://github.com/prachuryanath/WBC-NCA---Reproduction/blob/main/images/1651.jpg?raw=true" width="300" />
</div>

* Model Output Mask :
<div>
<img src="https://github.com/prachuryanath/WBC-NCA---Reproduction/blob/main/images/1651.png?raw=true" width="300" />
</div>

#### Second segmentation

* Real mask :
<div>
<img src="https://github.com/prachuryanath/WBC-NCA---Reproduction/blob/main/images/2150.jpg?raw=true" width="300" />
</div>

* Model Output Mask :

<div>
<img src="https://github.com/prachuryanath/WBC-NCA---Reproduction/blob/main/images/2150.png?raw=true" width="300" />
</div>

#### Third segmentation

* Real Mask :
<div>
<img src="https://github.com/prachuryanath/WBC-NCA---Reproduction/blob/main/images/1671.jpg?raw=true" width="300" />
</div>

* Model Output Mask :
<div>
<img src="https://github.com/prachuryanath/WBC-NCA---Reproduction/blob/main/images/1671.png?raw=true" width="300" />
</div>

## Challenges

1. Insufficient Code Comments: The repository lacks comments, requiring additions for clarity.
2. Disorganized Repository: The folder structure and setup are unorganized, needing restructuring and configuration files.
3. Required Code Modifications: Some moderate changes are necessary to align the code with the paper’s results.


## Conclusion

Reproducing the paper's findings involved notable difficulties, mainly in restructuring the code and adding explanatory comments. This reproduction is classified as Level 3: the paper and repository provide most of the necessary setup, but the code lacks sufficient comments and the repository is poorly structured. Reproducing even one figure from the paper therefore required moderate changes, including restructuring the code, adding comments, and ensuring a proper setup for smooth execution.