In [1]:
import pandas as pd
from click.core import batch
from seaborn import histplot

from utils import pic2float, pic2int, pic2pil, sigmoid, swimg

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms as tt
from torchvision.transforms import v2
from torchvision.transforms import functional as tf

# dataset
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from IPython.display import clear_output

import plotly.express as px
from torch.amp import autocast, GradScaler
import kornia

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


In [2]:
DATA_PATH = 'N:\PROJECTS\python\STUDY\SHADOW\DATASET\DATA_01'

In [3]:
class ShadowDataset(Dataset):
    """Paired image dataset for shadow generation.

    Each sample is a dict of image tensors read from sub-folders of ``root_dir``:
    ``'color'`` and ``'mask'`` always, plus ``'target'`` in ``'train'`` mode.
    ``__getitem__`` moves the tensors to the global ``device``, applies a random
    reflect-pad + rotation + crop + colour-jitter augmentation (torchvision v2
    transforms share their random parameters across the dict entries) and adds
    ``'rot'``: a 2-element tensor ``[sin, cos]`` of the rotation angle.

    Args:
        root_dir: dataset root folder.
        transform, target_transform: accepted for API compatibility; unused.
        mode: ``'train'`` (INPUT_COLOR / INPUT_MASK / TARGET folders) or
            ``'test'`` (TEST_COLOR / TEST_MASK folders).
        mem: if True, decode every sample up front and keep it in memory.
        n: number of samples to expose; defaults to the number of colour files.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    _DIR_NAMES = {
        'train': {'color': 'INPUT_COLOR', 'mask': 'INPUT_MASK', 'target': 'TARGET'},
        'test': {'color': 'TEST_COLOR', 'mask': 'TEST_MASK'},
    }

    def __init__(self, root_dir, transform=None, target_transform=None, mode='train', mem=False, n=None):
        if mode not in self._DIR_NAMES:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        self.root_dir = root_dir
        self.mode = mode
        self.dir_names = {k: os.path.join(root_dir, v) for k, v in self._DIR_NAMES[mode].items()}
        # os.listdir order is arbitrary: sort so the i-th colour / mask / target files
        # belong to the same sample (assumes corresponding files share sortable names).
        self.files = {k: sorted(os.listdir(v)) for k, v in self.dir_names.items()}
        if n is None:
            n = len(self.files['color'])
        self.n = n
        # PIL image -> float32 tensor (scaled to [0, 1] for 8-bit images)
        self.init_transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])
        # Crop + colour augmentation applied after padding and rotation; built once.
        self.aug_transform = v2.Compose([
            v2.CenterCrop(768),
            v2.RandomCrop(512),
            v2.ColorJitter(brightness=0, contrast=0, saturation=0.25, hue=0.45),
        ])

        self.data = []
        self.mem = False
        if mem:
            self.prepare()

    def prepare(self):
        """Decode all ``self.n`` samples into ``self.data`` and switch to in-memory mode."""
        # Reset so a second call reloads from disk instead of appending duplicates.
        self.mem = False
        self.data = []
        for i in range(self.n):
            self.data.append(self.load(i))
            print(f'load {i} / {len(self)}                 ', end='\r')
        print()
        self.mem = True

    def load(self, idx):
        """Return sample ``idx`` as a dict of float32 image tensors (from cache when ``mem``)."""
        if self.mem:
            return self.data[idx]
        out = {
            key: Image.open(os.path.join(self.dir_names[key], self.files[key][idx]))
            for key in self.dir_names
        }
        return self.init_transform(out)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return an augmented sample plus ``'rot'`` = [sin, cos] of the rotation angle."""
        out = self.load(idx)
        out = {k: v.to(device) for k, v in out.items()}

        # Uniform integer angle in [0, 360) degrees (randint's upper bound is exclusive).
        rotate_angle = torch.randint(0, 360, (1,)).item()
        # Reflect-pad before rotating, presumably so fewer fill-value (1.0) corners
        # end up inside the later crop.
        out = v2.Pad(256, padding_mode='reflect')(out)
        out = v2.RandomRotation((rotate_angle, rotate_angle), fill=1)(out)
        out = self.aug_transform(out)

        # Encode the angle as sin/cos so 0 and 360 degrees map to the same features.
        angle = torch.deg2rad(torch.tensor([rotate_angle], dtype=torch.float32))
        out['rot'] = torch.cat([torch.sin(angle), torch.cos(angle)], dim=0)
        return out


In [4]:
# Training split, fully decoded into memory up front (n=None -> every colour file).
tdata = ShadowDataset(root_dir=DATA_PATH, mode='train', mem=True, n=None)

load 1499 / 1500                 


In [5]:
class ShadowDatasetWrapper(Dataset):
    """Adapt a ShadowDataset to ``(input, target)`` pairs.

    The network input is assembled with ``add_rot_mask`` (colour + mask channel +
    rotation planes). This lets the EMA-averaged generator be run over the data to
    recalibrate its BatchNorm statistics. Samples must contain a ``'target'`` key.
    """

    def __init__(self, sd):
        self.sd = sd

    def __len__(self):
        return len(self.sd)

    def __getitem__(self, idx):
        sample = self.sd[idx]
        # add_rot_mask expects batched tensors: add a leading batch dimension ...
        batched = {key: sample[key][None].to(device) for key in ('color', 'mask', 'target', 'rot')}
        net_input = add_rot_mask(batched['color'], batched['rot'], batched['mask'])
        # ... and drop it again; the DataLoader re-batches the pairs.
        return net_input[0], batched['target'][0]


tdata_w = ShadowDatasetWrapper(tdata)

In [6]:
def add_mask(x, mask):
    """Append the first channel of ``mask`` to ``x`` along dim 1 (4-D, channel-first tensors)."""
    mask_plane = mask[:, 0:1, :, :]  # slice (not index) keeps the channel dim
    return torch.cat([x, mask_plane], 1)

def add_rot(x, rot):
    """Append two constant planes holding the rotation features to ``x``.

    Args:
        x: (N, C, H, W) tensor.
        rot: (N, >=2) or (1, >=2) tensor; columns 0 and 1 (e.g. sin / cos of the
            rotation angle) are broadcast over H x W.

    Returns:
        (N, C + 2, H, W) tensor on ``x``'s device and with ``x``'s dtype, so a
        half-precision input is no longer silently promoted to float32.
    """
    n, _, h, w = x.shape
    planes = rot[:, :2].to(device=x.device, dtype=x.dtype)[:, :, None, None].expand(n, 2, h, w)
    return torch.cat((x, planes), dim=1)

def add_rot_mask(x, rot, mask):
    """Build the generator input: ``x`` + first ``mask`` channel + two rotation planes."""
    return add_rot(add_mask(x, mask), rot)
def vis(pics):
    """Send images to the preview web server so they can be viewed in a browser.

    ``pics`` is converted with ``pic2pil`` and handed to ``swimg`` as a one-item list.
    """
    swimg([pic2pil(pics)])
    

In [7]:
# Generator building blocks
class Block(nn.Module):
    """3x3 same-padding convolution, optionally followed by BatchNorm, ReLU and Dropout(0.5).

    The layers live in ``self.block`` (an ``nn.Sequential``): index 0 is the conv, then
    the enabled optional layers in the order norm, relu, dropout. Checkpoint keys
    depend on this layout.
    """

    def __init__(self, in_channels, out_channels, dropout=False, norm=True, relu=True):
        super().__init__()
        optional_layers = (
            (norm, lambda: nn.BatchNorm2d(out_channels)),
            (relu, lambda: nn.ReLU(inplace=True)),
            (dropout, lambda: nn.Dropout(0.5)),
        )
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)]
        layers += [make() for enabled, make in optional_layers if enabled]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

# Encoder-decoder (U-Net style) generator
class Generator(nn.Module):
    """U-Net-style encoder/decoder producing a 3-channel image in [-1, 1] (tanh).

    Input is (N, 6, H, W): colour + mask channel + 2 rotation planes, as assembled
    by ``add_rot_mask``. The encoder has five 2x max-pool stages, so the skip
    connections only line up when H and W are multiples of 32. Other sizes are
    replicate-padded up to the next multiple of 32 and the output is cropped back
    to H x W. Multiples of 32 are processed exactly as before.
    """

    # total down-sampling factor of the encoder (5 x MaxPool2d(2))
    size_multiple = 32

    def __init__(self):
        super(Generator, self).__init__()
        # Encoder
        self.enc1 = Block(6, 64, norm=False)  # network input
        self.enc2 = Block(64, 128)
        self.enc3 = Block(128, 256)
        self.enc4 = Block(256, 256)
        self.enc5 = Block(256, 256)
        self.enc6 = Block(256, 256, dropout=True)
        self.downsample = nn.MaxPool2d(2)

        # Decoder (input channels = upsampled features + encoder skip)
        self.dec6 = Block(256, 256)
        self.dec5 = Block(256*2, 256)
        self.dec4 = Block(256*2, 256)
        self.dec3 = Block(256*2, 128)
        self.dec2 = Block(128*2, 64)
        self.dec1 = Block(64*2, 3, norm=False, relu=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        h, w = x.shape[-2:]
        pad_h = (-h) % self.size_multiple
        pad_w = (-w) % self.size_multiple
        if pad_h or pad_w:
            # F.pad pads the last dims first: (left, right, top, bottom)
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder: channels in -> out, spatial size relative to the (padded) input
        enc1 = self.enc1(x)                      # 6 -> 64,    1/1
        enc2 = self.enc2(self.downsample(enc1))  # 64 -> 128,  1/2
        enc3 = self.enc3(self.downsample(enc2))  # 128 -> 256, 1/4
        enc4 = self.enc4(self.downsample(enc3))  # 256 -> 256, 1/8
        enc5 = self.enc5(self.downsample(enc4))  # 256 -> 256, 1/16
        enc6 = self.enc6(self.downsample(enc5))  # 256 -> 256, 1/32

        # Decoder: upsample, concatenate the matching encoder skip, convolve
        dec6 = self.dec6(enc6)                                           # 256 -> 256
        dec5 = self.dec5(torch.cat([enc5, self.upsample(dec6)], dim=1))  # 256+256 -> 256
        dec4 = self.dec4(torch.cat([enc4, self.upsample(dec5)], dim=1))  # 256+256 -> 256
        dec3 = self.dec3(torch.cat([enc3, self.upsample(dec4)], dim=1))  # 256+256 -> 128
        dec2 = self.dec2(torch.cat([enc2, self.upsample(dec3)], dim=1))  # 128+128 -> 64
        dec1 = self.dec1(torch.cat([enc1, self.upsample(dec2)], dim=1))  # 64+64 -> 3

        out = torch.tanh(dec1)
        # Undo the padding (a no-op slice when none was added).
        return out[..., :h, :w]

# Discriminator
class Discriminator(nn.Module):
    """Patch discriminator over (N, 9, H, W) inputs: generator input (6) + image (3).

    It has five conv stages, each ending in a 2x max-pool, followed by a 1-channel
    3x3 conv. A 512x512 input therefore yields a 16x16 map of raw logits (no
    sigmoid; used with ``BCEWithLogitsLoss``). The ``self.model`` layer indices
    are unchanged, so existing state-dict checkpoints still load.
    """

    # (in_channels, out_channels) of each conv stage
    _stages = ((9, 64), (64, 128), (128, 256), (256, 512), (512, 512))

    @staticmethod
    def conv(in_ch, out_ch):
        """3x3, stride-1, same-padding convolution."""
        # A staticmethod rather than a lambda stored on the instance: a lambda
        # attribute makes the module unpicklable (e.g. ``torch.save(model)``).
        return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

    def __init__(self):
        super(Discriminator, self).__init__()
        # Parameter-free layers, registered once and reused in every stage.
        self.downsample = nn.MaxPool2d(2)
        self.act = nn.LeakyReLU(0.2, inplace=True)

        layers = []
        for in_ch, out_ch in self._stages:
            layers += [
                self.conv(in_ch, out_ch), self.act,
                nn.BatchNorm2d(out_ch),
                self.conv(out_ch, out_ch), self.act,
                self.downsample,  # halves H and W
            ]
        layers.append(self.conv(512, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [8]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import math

class SmoothWarmupCosineAnnealingWarmRestarts(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, warmup_iters=10, last_epoch=-1, decay=0.9995):
        """CosineAnnealingWarmRestarts with a linear warm-up after every restart and a decaying envelope.

        For the first ``warmup_iters`` steps of each cycle the lr rises linearly from
        ``eta_min`` towards ``base_lr``. Afterwards it follows a half cosine from
        ``base_lr`` down to ``eta_min`` over the rest of the cycle. Every lr is
        multiplied by ``self.mult``. ``self.mult`` is multiplied by ``decay`` on each
        post-warm-up lr computation and is never reset, so successive peaks shrink.

        Args:
            optimizer (Optimizer): wrapped optimizer.
            T_0 (int): number of iterations until the first restart.
            T_mult (int): cycle-length growth factor (T_i = T_0 * T_mult^i).
            eta_min (float): minimum learning rate.
            warmup_iters (int): number of warm-up iterations after each restart.
            last_epoch (int): index of the last epoch.
            decay (float): multiplicative decay applied to the lr envelope per step.
        """
        # These must exist before the base constructor runs, because its initial
        # step() already calls get_lr().
        self.decay = decay
        self.mult = 1
        self.warmup_iters = warmup_iters
        super().__init__(optimizer, T_0, T_mult, eta_min, last_epoch)

    def get_lr(self):
        # Position inside the current cycle (maintained by the base class).
        t = self.T_cur

        if t < self.warmup_iters:
            ramp = t / self.warmup_iters
            return [self.mult * (self.eta_min + (b - self.eta_min) * ramp) for b in self.base_lrs]

        progress = (t - self.warmup_iters) / (self.T_i - self.warmup_iters)
        self.mult *= self.decay
        cos_term = 1 + math.cos(math.pi * progress)
        return [self.mult * (self.eta_min + (b - self.eta_min) * cos_term / 2) for b in self.base_lrs]


In [9]:
# Model initialisation
generator = Generator().to(device)
discriminator = Discriminator().to(device)
iter = 0  # global training-step counter (NOTE: shadows the builtin ``iter``)
log = []  # one tuple of training metrics per step

# Optimisers
lr = 1e-3
g_optimizer = optim.AdamW(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.AdamW(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))


def _make_scheduler(optimizer):
    """Constant 300-step cosine cycles, lr floor 1e-6, 20-step linear warm-up after each restart."""
    return SmoothWarmupCosineAnnealingWarmRestarts(
        optimizer,
        T_0=300,
        T_mult=1,
        eta_min=1e-6,
        warmup_iters=20,
    )


g_scheduler = _make_scheduler(g_optimizer)
d_scheduler = _make_scheduler(d_optimizer)

# Losses
adversarial_loss = nn.BCEWithLogitsLoss()
reconstruction_loss = nn.L1Loss()

scaler = GradScaler()

# EMA (decay 0.999) copy of the generator with averaged weights
ema_model = torch.optim.swa_utils.AveragedModel(generator, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))

In [10]:
def train_step(batch, gan=False, iter=iter):
    """Run one GAN training step: a discriminator update followed by a generator update.

    Args:
        batch: dict with 'color', 'mask', 'target' and 'rot' tensors (a collated
            batch of ShadowDataset samples).
        gan: unused; kept for call-site compatibility.
        iter: global step index, driving the loss-weight schedules below.

    Returns:
        Tuple of floats: (g_l1, g_adv, d, d_real, d_fake, lr_g*1000, lr_d*1000,
        g_grad_norm, d_grad_norm).
    """
    # L1 weight schedule: 1000x at step 0, decaying linearly to a floor of 1x
    # around step 150. Both losses are additionally ramped in over the first 10 steps.
    l1_dec_step = 150
    l1_dec = 1000 * max(0.001, (l1_dec_step - iter) / l1_dec_step)
    warm_step = 10
    warm = min(warm_step, iter) / warm_step
    l1_mul = 50 * warm * l1_dec
    adv_mul = warm

    generator.train()
    discriminator.train()

    with autocast('cuda'):
        color = batch['color'].to(device)
        mask = batch['mask'].to(device)
        target = batch['target'].to(device)
        rot = batch['rot'].to(device)
        net_input = add_rot_mask(color, rot, mask)

        fake = generator(net_input)

        # === Discriminator ===
        d_real_logits = discriminator(torch.cat((net_input, target), dim=1))
        d_fake_logits = discriminator(torch.cat((net_input, fake.detach()), dim=1))
        # Labels match the logit map's shape and device, whatever the crop size
        # (previously hard-coded as CPU (N, 1, 16, 16) tensors).
        d_real_loss = adversarial_loss(d_real_logits, torch.ones_like(d_real_logits))
        d_fake_loss = adversarial_loss(d_fake_logits, torch.zeros_like(d_fake_logits))
        d_loss = adv_mul * (d_real_loss + d_fake_loss) / 2

    d_optimizer.zero_grad()
    scaler.scale(d_loss).backward()
    # Unscale before clipping so the threshold applies to the true gradient norm.
    scaler.unscale_(d_optimizer)
    d_grad = torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 0.5)
    scaler.step(d_optimizer)
    d_scheduler.step()

    with autocast('cuda'):
        # === Generator ===
        g_logits = discriminator(torch.cat((net_input, fake), dim=1))
        g_fake_loss = adversarial_loss(g_logits, torch.ones_like(g_logits))

        # Judging by the variable names, mask == 1 marks "background" (weight 1) and
        # mask == 0 the object (weight 0.1). TODO confirm the mask convention.
        inv_mask = 1 - mask
        g_object_l_loss = reconstruction_loss(fake * inv_mask, target * inv_mask) * .1
        g_bg_l_loss = reconstruction_loss(fake * mask, target * mask)
        g_l_loss = (g_object_l_loss + g_bg_l_loss) * l1_mul
        g_loss = g_l_loss + g_fake_loss * adv_mul

    # This backward also fills the discriminator's .grad; they are cleared by
    # d_optimizer.zero_grad() on the next step.
    g_optimizer.zero_grad()
    scaler.scale(g_loss).backward()
    scaler.unscale_(g_optimizer)
    g_grad = torch.nn.utils.clip_grad_norm_(generator.parameters(), 0.5)
    scaler.step(g_optimizer)
    # A single scaler.update() per iteration, after both optimizers have stepped.
    scaler.update()
    g_scheduler.step()
    ema_model.update_parameters(generator)

    lr_g = g_optimizer.param_groups[0]["lr"]
    lr_d = d_optimizer.param_groups[0]["lr"]

    return (g_l_loss.item(), g_fake_loss.item(),
            d_loss.item(), d_real_loss.item(), d_fake_loss.item(),
            lr_g * 1000, lr_d * 1000,
            g_grad.item(), d_grad.item())

   

In [12]:
import gc
epoch = 300


def _show_progress(batch, ep, epoch, step, max_step):
    """Preview two targets next to the generator's outputs and plot the second half of the log."""
    clear_output(wait=False)
    # Preview in eval mode: in train mode this 2-sample forward pass would update the
    # BatchNorm running statistics and apply dropout. The previous mode is restored afterwards.
    was_training = generator.training
    generator.eval()
    with torch.inference_mode():
        preview = generator(
            add_rot_mask(
                batch['color'][:2].to(device),
                batch['rot'][:2].to(device),
                batch['mask'][:2].to(device),
            )
        )
        to_vis = torch.cat([*batch['target'][:2].to(device), *preview], dim=2).detach().cpu()
    generator.train(was_training)

    vis(to_vis)
    print(f'{ep} / {epoch}, step: {step} / {max_step}')
    df = pd.DataFrame(log, columns=['iter', 'g_l', 'g_f', 'd', 'd_r', 'd_f', 'lr_g', 'lr_d', 'g_grad', 'd_grad'])
    fig = px.line(df.iloc[len(df) // 2:, 1:])
    fig.update_yaxes(type="log")
    fig.show()


def train(epoch=20, iter=iter, batch_size=16):
    """Train for ``epoch`` passes over ``tdata``, starting at global step ``iter``.

    Appends one metrics tuple per step to the global ``log`` and shows a preview
    every 20 steps.

    Returns:
        The step counter after the last step. Store it so the next call continues
        the loss-weight schedules in ``train_step`` instead of restarting them.
    """
    torch.cuda.empty_cache()
    gc.collect()

    dl = DataLoader(tdata, batch_size=batch_size, shuffle=True)
    max_step = len(dl)
    for ep in range(epoch):
        for step, batch in enumerate(dl):
            loss = train_step(batch, gan=True, iter=iter)
            log.append((iter, *loss))
            iter += 1
            if step % 20 == 0:
                _show_progress(batch, ep, epoch, step, max_step)
    return iter


# Keep the global step counter in sync so a later call resumes the schedules.
iter = train(epoch, iter=iter)
        

299 / 300, step: 80 / 94


In [None]:
# Full training history: every metric (all columns after 'iter') on a log-scale y axis.
df = gfx = pd.DataFrame(log, columns=['iter', 'g_l', 'g_f', 'd', 'd_r', 'd_f', 'lr_g', 'lr_d', 'g_grad', 'd_grad'])
fig = px.line(gfx.loc[:, 'g_l':])
fig.update_yaxes(type="log")
fig.show()
df.tail(50)


In [13]:
# Recompute the EMA generator's BatchNorm running statistics by passing the
# training data (unshuffled) through it once.
dl_w = DataLoader(dataset=tdata_w, batch_size=16, shuffle=False)
torch.optim.swa_utils.update_bn(loader=dl_w, model=ema_model)

In [14]:
save_path = 'models'
version = 'v03'
# Save each model's state_dict as models/<ClassName>_<version>.pth.
os.makedirs(save_path, exist_ok=True)
for m in (generator, discriminator, ema_model):
    name = m._get_name()
    torch.save(m.state_dict(), os.path.join(save_path, f'{name}_{version}.pth'))
    n_params = sum(p.numel() for p in m.parameters())
    print(f'Save {name} with {n_params} params')


Save Generator with 5837827 params
Save Discriminator with 9416001 params
Save AveragedModel with 5837827 params


In [33]:
DTYPE = torch.float16

MODEL = 'SHADOW/models/Generator.pth'
AV_MODEL = 'SHADOW/models/AveragedModel.pth'

USE_AVERAGE = True


class ShadowGenerator:
    """Inference wrapper that composites generated shadows onto input images.

    Args:
        generator: a ready-to-use model. If None, a ``Generator`` is built (wrapped in
            an EMA ``AveragedModel`` when ``avarage`` is true). Its weights are loaded
            from ``ROOT/AV_MODEL`` or ``ROOT/MODEL``; it is then switched to eval mode
            and cast to ``DTYPE`` on ``device``.
        device: device the inputs are sent to (and the default model is loaded onto).
        avarage: load the EMA-averaged checkpoint for the default model.
    """

    def __init__(self, generator=None, device=device, avarage=USE_AVERAGE):
        if generator is None:
            generator = self._load_default(device, avarage)
        self.generator = generator
        self.device = device

    @staticmethod
    def _load_default(device, avarage):
        """Build the default generator and load its checkpoint from under ROOT."""
        # Imported here so this cell does not rely on a later cell having run first.
        from constant import ROOT

        generator = Generator()
        if avarage:
            generator = torch.optim.swa_utils.AveragedModel(
                generator, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
            model_path = AV_MODEL
        else:
            model_path = MODEL
        # map_location='cpu' so checkpoints saved from a GPU also load on CPU-only hosts.
        state = torch.load(os.path.join(ROOT, model_path), weights_only=True, map_location='cpu')
        generator.load_state_dict(state)
        generator.eval()
        return generator.to(DTYPE).to(device)

    def generate(self, colors, masks, rots=None,):
        """Generate shadows and composite them over the original images.

        Args:
            colors: sequence of HxWxC images (anything ``torch.tensor`` accepts).
            masks: sequence of HxWxC masks matching ``colors``. They are inverted
                (``1 - m``) before use, so presumably 1 marks the object. TODO confirm.
            rots: rotation angle in degrees applied to every image; when None, a
                random integer angle in [0, 360) is drawn.

        Returns:
            List of (C, H, W) CPU tensors: generated pixels where the inverted mask
            is 1, the original colour where it is 0.
        """
        colors = torch.stack([torch.tensor(i).permute(2, 0, 1) for i in colors])
        masks = 1 - torch.stack([torch.tensor(i).permute(2, 0, 1) for i in masks])

        rotate_angle = torch.randint(0, 360, (1,)).item() if rots is None else rots
        angle = torch.deg2rad(torch.tensor([rotate_angle], dtype=torch.float32))
        rots = torch.cat([torch.sin(angle), torch.cos(angle)], dim=0)
        rots = rots[None, :].expand(colors.size(0), -1)

        net_input = add_rot_mask(colors, rots, masks)

        # Match the model's own parameter dtype (DTYPE for the default model) and
        # use self.device rather than the module-level default.
        param = next(self.generator.parameters())
        with torch.inference_mode():
            shadow = self.generator(net_input.to(device=self.device, dtype=param.dtype))

        return [s * masks[i] + colors[i] * (1 - masks[i]) for i, s in enumerate(shadow.cpu())]


sg = ShadowGenerator()


In [None]:
from ML_SERVER.sam import sam_process
from constant import ROOT

def generate_shadow(images, masks, rots=None):
    """Composite generated shadows onto ``images`` using the module-level ``sg`` instance."""
    return sg.generate(images, masks, rots)


def test():
    """Smoke test: run ``sam_process`` on ROOT/image.jpg, generate shadows and preview everything."""
    generator = ShadowGenerator()
    image = pic2float(Image.open(os.path.join(ROOT, "image.jpg")))
    images, masks, _ = sam_process(image)
    shadow_comp = generator.generate(images, masks)
    swimg([*images, *masks, *shadow_comp])