# SG-ShadowNet Notebook
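This notebook collects simplified excerpts of the SG-ShadowNet implementation (taken from `train.py`, `test.py`, `utils.py`, `losses.py`, and `model.py`): training and testing logic, utility functions, loss functions, and model architectures.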

## Import Required Libraries
Import the necessary libraries and modules.

In [None]:
from __future__ import print_function
import os
import datetime
import argparse
import itertools
import torchvision
from torch.utils.data import DataLoader
import torch
from utils.utils import LambdaLR, weights_init_normal, tensor2img, calc_RMSE
from models.model import ConGenerator_S2F, ConRefineNet
from loss.losses import L_spa
from data.datasets import ImageDataset, TestImageDataset
import numpy as np
from skimage import io, color
from skimage.transform import resize
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")

## Training Logic
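The following cell contains the (simplified) training logic for SG-ShadowNet. It optimizes only the first-stage generator `netG_1`, using an L1 loss between its output and `sgt`. `netG_2` and the region/spatial losses are set up but not yet used in the loss.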

In [None]:
# Training logic extracted from train.py
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='epoch to start linearly decaying the learning rate to 0')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--iter_loss', type=int, default=100, help='average loss for n iterations')
parser.add_argument('--output_dir', type=str, default='output/', help='directory for per-epoch generator checkpoints')
opt = parser.parse_args([])  # Empty list to avoid Jupyter conflicts
opt.dataroot = 'input/dataset/ISTD'
# Without this, the weights produced by training were never persisted.
os.makedirs(opt.output_dir, exist_ok=True)

# Initialize networks
netG_1 = ConGenerator_S2F().cuda()
netG_2 = ConRefineNet().cuda()
netG_1.apply(weights_init_normal)
netG_2.apply(weights_init_normal)

# Define loss functions.
# NOTE(review): only criterion_identity contributes to loss_G below;
# criterion_region and criterion_spa are created but not yet used.
criterion_region = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
criterion_spa = L_spa()

# Optimizers: a single Adam over the parameters of both generators.
# NOTE(review): netG_2 is never called in the loop, so its parameters get no
# gradient from loss_G.
optimizer_G = torch.optim.Adam(itertools.chain(netG_1.parameters(), netG_2.parameters()), lr=opt.lr, betas=(0.5, 0.999))
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Data loaders
dataloader = DataLoader(ImageDataset(opt.dataroot, unaligned=True), batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

# Training loop
for epoch in range(opt.epoch, opt.n_epochs):
    running_loss = 0.0
    for i, (s, sgt, mask, mask50) in enumerate(dataloader):
        # NOTE(review): mask50 is moved to the GPU but not used below.
        s, sgt, mask, mask50 = s.cuda(), sgt.cuda(), mask.cuda(), mask50.cuda()
        optimizer_G.zero_grad()
        fake_sf_temp = netG_1(s, mask)
        loss_1 = criterion_identity(fake_sf_temp, sgt)
        loss_G = loss_1
        loss_G.backward()
        optimizer_G.step()

        # Report the mean loss over each window of opt.iter_loss iterations
        # (the option was previously declared but unused).
        running_loss += loss_G.item()
        if opt.iter_loss > 0 and (i + 1) % opt.iter_loss == 0:
            print(f'[epoch {epoch + 1}/{opt.n_epochs}] iter {i + 1}: '
                  f'loss_G = {running_loss / opt.iter_loss:.5f}')
            running_loss = 0.0
    lr_scheduler_G.step()

    # Persist both generators as state_dicts, the same format the testing cell
    # loads with load_state_dict, overwriting the previous epoch's files.
    torch.save(netG_1.state_dict(), os.path.join(opt.output_dir, 'netG_1.pth'))
    torch.save(netG_2.state_dict(), os.path.join(opt.output_dir, 'netG_2.pth'))

## Testing Logic
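The following cell loads the pretrained SG-ShadowNet generator checkpoints and switches both networks to evaluation mode. The loop at the end is a placeholder that only prints indices. It does not yet run inference or write results to `savepath`.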

In [None]:
# Testing logic extracted from test.py
parser = argparse.ArgumentParser()
parser.add_argument('--generator_1', type=str, default='pretrained/netG_1_aistd.pth', help='generator_1 checkpoint file')
parser.add_argument('--generator_2', type=str, default='pretrained/netG_2_aistd.pth', help='generator_2 checkpoint file')
parser.add_argument('--savepath', type=str, default='results/aistd/', help='save path')
opt = parser.parse_args([])

netG_1 = ConGenerator_S2F().cuda()
netG_2 = ConRefineNet().cuda()
# Deserialize onto the CPU first. Otherwise torch.load restores each tensor to
# the device it was saved from, which fails when that device is not present.
# load_state_dict then copies the weights into the already-.cuda() parameters.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
netG_1.load_state_dict(torch.load(opt.generator_1, map_location='cpu'))
netG_2.load_state_dict(torch.load(opt.generator_2, map_location='cpu'))
netG_1.eval()
netG_2.eval()

# Example testing loop
# TODO: replace this placeholder with iteration over the real test set
# (e.g. TestImageDataset); opt.savepath is not used yet.
for idx in range(10):  # Replace with actual test data
    print(f'Testing image {idx}')

## Utility Functions
The following cell contains utility functions used in SG-ShadowNet.

In [None]:
# Utility functions extracted from utils.py
def labimage2tensor(labimage, h=480, w=640):
    """Resize a CIE-Lab image and rescale its channels for network input.

    Args:
        labimage: array expected to have shape (H, W, 3), holding L, a, b in
            their raw value ranges. The scaling below maps L in [0, 100] and
            a/b in [-128, 127] onto [-1, 1]. Presumably produced by
            ``skimage.color.rgb2lab``; TODO confirm at call sites.
        h: target height in pixels.
        w: target width in pixels.

    Returns:
        float32 ``torch.Tensor`` of shape (h, w, 3), channels last.
    """
    # preserve_range=True keeps the raw Lab values for every input dtype.
    # Without it, integer-typed input is rescaled to [0, 1] by resize, which
    # breaks the Lab normalization below. Float input is unaffected.
    labimage_t = resize(labimage, (h, w, 3), preserve_range=True)
    # L channel: [0, 100] -> [-1, 1]
    labimage_t[:, :, 0] = labimage_t[:, :, 0] / 50.0 - 1.0
    # a/b channels: [-128, 127] -> [-1, 1]
    labimage_t[:, :, 1:] = 2.0 * (labimage_t[:, :, 1:] + 128.0) / 255.0 - 1.0
    return torch.from_numpy(labimage_t).float()

## Loss Functions
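The following cell contains a simplified version of the `L_spa` loss. Note that it reuses the name `L_spa`, so it shadows the `L_spa` imported from `loss.losses` at the top of the notebook.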

In [None]:
# Loss functions extracted from losses.py
class L_spa(torch.nn.Module):
    """Mean squared error between two tensors.

    NOTE: this definition shadows the ``L_spa`` imported from
    ``loss.losses`` at the top of the notebook.
    """

    def forward(self, org, enhance):
        """Return the mean over all elements of ``(org - enhance) ** 2``."""
        residual = org - enhance
        squared_error = torch.square(residual)
        return squared_error.mean()

## Model Architectures
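The following cell contains a simplified stand-in for the `ConGenerator_S2F` generator. It reuses the name `ConGenerator_S2F`, so once the cell runs successfully it shadows the class imported from `models.model`. Re-running the training or testing cells afterwards would then build this stub instead of the real network.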

In [None]:
# Model architectures extracted from model.py
# Subclass torch.nn.Module explicitly: the bare name `nn` is never imported in
# this notebook, so `nn.Module` raised NameError when the cell ran.
class ConGenerator_S2F(torch.nn.Module):
    """Simplified stand-in for the shadow-to-free generator.

    It defines no submodules or parameters; ``forward`` just masks its input.

    NOTE: this definition shadows the ``ConGenerator_S2F`` imported from
    ``models.model``. Cells that instantiate ``ConGenerator_S2F()`` after
    this one runs will get this stub, not the real network.
    """

    def forward(self, xin, mask):
        """Return ``xin`` multiplied elementwise (with broadcasting) by ``mask``."""
        return xin * mask