# Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import glob
import os
import random

import numpy as np
import torchvision.transforms.functional as T
from PIL import Image
from torch.backends import cudnn
from torchvision.transforms import RandomCrop

import pandas as pd
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.notebook import tqdm

In [2]:
# Kaggle dataset folders (train and test, going by the path names).
# These hold the same values as TRAIN_PATH / TEST_PATH in the config cell and
# are not referenced anywhere else in the visible code.
train = '/kaggle/input/rain13kdataset/train/train/Rain13K'
test = '/kaggle/input/rain13kdataset/test/test'

# Model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MDTA(nn.Module):
    """Multi-head attention computed across channels instead of across pixels.

    Queries, keys and values are produced by a 1x1 convolution followed by a
    3x3 depthwise convolution. Every head attends over its own slice of the
    channels, so the attention map is (channels_per_head x channels_per_head)
    and its size does not depend on the spatial resolution.

    Args:
        channels: number of input (and output) feature channels.
        num_heads: number of attention heads the channels are split into.
    """

    def __init__(self, channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # One learnable scale per head, applied to the logits before softmax.
        self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        qkv_channels = 3 * channels
        self.qkv = nn.Conv2d(channels, qkv_channels, kernel_size=1, bias=False)
        self.qkv_conv = nn.Conv2d(qkv_channels, qkv_channels, kernel_size=3, padding=1,
                                  groups=qkv_channels, bias=False)
        self.project_out = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply channel attention to ``x`` of shape (batch, channels, height, width)."""
        batch, _, height, width = x.shape
        # (batch, heads, channels_per_head, pixels)
        per_head = (batch, self.num_heads, -1, height * width)

        query, key, value = self.qkv_conv(self.qkv(x)).chunk(3, dim=1)
        # L2-normalise along the pixel axis so the logits are cosine similarities.
        query = F.normalize(query.reshape(per_head), dim=-1)
        key = F.normalize(key.reshape(per_head), dim=-1)
        value = value.reshape(per_head)

        logits = torch.matmul(query, key.transpose(-2, -1).contiguous()) * self.temperature
        weights = torch.softmax(logits, dim=-1)
        attended = torch.matmul(weights, value).reshape(batch, -1, height, width)
        return self.project_out(attended)


class GDFN(nn.Module):
    """Gated depthwise-convolution feed-forward network.

    The input is expanded to two parallel branches of
    ``int(channels * expansion_factor)`` channels each; the GELU of the first
    branch gates the second one element-wise before projecting back to
    ``channels``.

    Args:
        channels: number of input (and output) feature channels.
        expansion_factor: width multiplier for the hidden branches.
    """

    def __init__(self, channels, expansion_factor):
        super().__init__()

        # Both branches are produced by the same convolutions, hence the factor 2.
        both_branches = 2 * int(channels * expansion_factor)
        self.project_in = nn.Conv2d(channels, both_branches, kernel_size=1, bias=False)
        self.conv = nn.Conv2d(both_branches, both_branches, kernel_size=3, padding=1,
                              groups=both_branches, bias=False)
        self.project_out = nn.Conv2d(both_branches // 2, channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the gated feed-forward transform of ``x`` (same shape as ``x``)."""
        expanded = self.conv(self.project_in(x))
        gate, content = expanded.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * content)


class TransformerBlock(nn.Module):
    """Pre-norm residual block: channel attention (MDTA) followed by a gated FFN (GDFN).

    Args:
        channels: number of feature channels (kept unchanged by the block).
        num_heads: number of attention heads passed to ``MDTA``.
        expansion_factor: hidden-width multiplier passed to ``GDFN``.
    """

    def __init__(self, channels, num_heads, expansion_factor):
        super().__init__()

        self.norm1 = nn.LayerNorm(channels)
        self.attn = MDTA(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = GDFN(channels, expansion_factor)

    @staticmethod
    def _normalize_channels(norm, feature_map):
        """Apply a LayerNorm over the channel axis of a (b, c, h, w) tensor."""
        b, c, h, w = feature_map.shape
        # LayerNorm normalises the last axis, so move channels there: (b, h*w, c).
        tokens = feature_map.reshape(b, c, -1).transpose(-2, -1).contiguous()
        tokens = norm(tokens)
        # Restore the (b, c, h, w) layout expected by the convolutions.
        return tokens.transpose(-2, -1).contiguous().reshape(b, c, h, w)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        x = x + self.attn(self._normalize_channels(self.norm1, x))
        x = x + self.ffn(self._normalize_channels(self.norm2, x))
        return x


class DownSample(nn.Module):
    """Halve the spatial resolution and double the channel count.

    A 3x3 convolution first halves the channels; ``PixelUnshuffle(2)`` then
    folds every 2x2 spatial block into the channel axis (x4 channels), for a
    net result of ``2 * channels`` at half height and width.
    """

    def __init__(self, channels):
        super().__init__()
        squeeze = nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1, bias=False)
        self.body = nn.Sequential(squeeze, nn.PixelUnshuffle(2))

    def forward(self, x):
        """Return the down-sampled feature map."""
        return self.body(x)


class UpSample(nn.Module):
    """Double the spatial resolution and halve the channel count.

    A 3x3 convolution first doubles the channels; ``PixelShuffle(2)`` then
    spreads groups of four channels over 2x2 spatial blocks (/4 channels), for
    a net result of ``channels // 2`` at twice the height and width.
    """

    def __init__(self, channels):
        super().__init__()
        widen = nn.Conv2d(channels, channels * 2, kernel_size=3, padding=1, bias=False)
        self.body = nn.Sequential(widen, nn.PixelShuffle(2))

    def forward(self, x):
        """Return the up-sampled feature map."""
        return self.body(x)


class Restormer(nn.Module):
    """Four-level encoder/decoder transformer for image restoration.

    The network predicts a residual that is added to the input image.

    Args:
        num_blocks: number of ``TransformerBlock``s per level (4 entries).
        num_heads: attention heads per level (4 entries).
        channels: feature channels per level (4 entries). ``DownSample`` /
            ``UpSample`` double / halve the channel count, so every entry must
            be twice the previous one for the shapes to line up.
        num_refinement: number of refinement blocks after the last decoder.
        expansion_factor: hidden-width multiplier of every ``GDFN``.

    Note:
        ``forward`` indexes the levels explicitly (0..3), so exactly four
        levels are supported. The defaults are tuples rather than lists so the
        shared default objects cannot be mutated by accident; lists are still
        accepted.
    """

    def __init__(self, num_blocks=(4, 6, 6, 8), num_heads=(1, 2, 4, 8), channels=(48, 96, 192, 384),
                 num_refinement=4, expansion_factor=2.66):
        super().__init__()

        self.embed_conv = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False)

        self.encoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(
            num_ch, num_ah, expansion_factor) for _ in range(num_tb)]) for num_tb, num_ah, num_ch in
                                       zip(num_blocks, num_heads, channels)])
        # There is one down-sample and one up-sample fewer than there are encoders.
        self.downs = nn.ModuleList([DownSample(num_ch) for num_ch in channels[:-1]])
        self.ups = nn.ModuleList([UpSample(num_ch) for num_ch in list(reversed(channels))[:-1]])
        # 1x1 convolutions that shrink the concatenated skip connection back to
        # the level's width; the top-level decoder has none (see below).
        self.reduces = nn.ModuleList([nn.Conv2d(channels[i], channels[i - 1], kernel_size=1, bias=False)
                                      for i in reversed(range(2, len(channels)))])
        # There is one decoder fewer than there are encoders.
        self.decoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(channels[2], num_heads[2], expansion_factor)
                                                       for _ in range(num_blocks[2])])])
        self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[1], expansion_factor)
                                             for _ in range(num_blocks[1])]))
        # The last decoder keeps the concatenated width (channels[1]) instead of
        # reducing it to channels[0].
        self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
                                             for _ in range(num_blocks[0])]))

        self.refinement = nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
                                          for _ in range(num_refinement)])
        self.output = nn.Conv2d(channels[1], 3, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Restore ``x`` of shape (batch, 3, H, W).

        H and W must be divisible by 8, because three ``PixelUnshuffle(2)``
        stages are applied.
        """
        fo = self.embed_conv(x)
        out_enc1 = self.encoders[0](fo)
        out_enc2 = self.encoders[1](self.downs[0](out_enc1))
        out_enc3 = self.encoders[2](self.downs[1](out_enc2))
        out_enc4 = self.encoders[3](self.downs[2](out_enc3))

        # Decoder path: up-sample, concatenate the encoder skip, reduce, refine.
        out_dec3 = self.decoders[0](self.reduces[0](torch.cat([self.ups[0](out_enc4), out_enc3], dim=1)))
        out_dec2 = self.decoders[1](self.reduces[1](torch.cat([self.ups[1](out_dec3), out_enc2], dim=1)))
        fd = self.decoders[2](torch.cat([self.ups[2](out_dec2), out_enc1], dim=1))
        fr = self.refinement(fd)
        # Global residual connection: the network only predicts the correction.
        out = self.output(fr) + x
        return out

In [4]:
def psnr(x, y, data_range=255.0):
    """Return the peak signal-to-noise ratio (in dB) between ``x`` and ``y``.

    Both tensors are scaled to [0, 1] by ``data_range`` and the mean squared
    error is taken over all elements. Identical inputs give ``inf``.
    """
    scaled_x = x / data_range
    scaled_y = y / data_range
    mean_sq_err = torch.mean((scaled_x - scaled_y) ** 2)
    return -10 * torch.log10(mean_sq_err)


def ssim(x, y, kernel_size=11, kernel_sigma=1.5, data_range=255.0, k1=0.01, k2=0.03):
    """Return the mean structural similarity (SSIM) between ``x`` and ``y``.

    Args:
        x, y: image batches of shape (batch, channels, height, width).
        kernel_size: side length of the Gaussian window.
        kernel_sigma: standard deviation of the Gaussian window.
        data_range: value range of the inputs; they are scaled to [0, 1].
        k1, k2: stabilisation constants of the SSIM formula.

    Returns:
        A scalar tensor: the SSIM map averaged over batch, channels and pixels.
        The Gaussian filtering uses no padding, so only windows that fit
        entirely inside the image contribute.
    """
    x = x / data_range
    y = y / data_range

    # Average-pool large images so the shorter side is roughly 256 pixels.
    pool = max(1, round(min(x.size()[-2:]) / 256))
    if pool > 1:
        x = F.avg_pool2d(x, kernel_size=pool)
        y = F.avg_pool2d(y, kernel_size=pool)

    n_channels = x.size(1)

    # Normalised 2-D Gaussian window, replicated once per channel (depthwise filter).
    offsets = torch.arange(kernel_size, dtype=x.dtype, device=x.device)
    offsets -= (kernel_size - 1) / 2.0
    sq_dist = offsets ** 2
    window = (- (sq_dist.unsqueeze(0) + sq_dist.unsqueeze(1)) / (2 * kernel_sigma ** 2)).exp()
    window /= window.sum()
    kernel = window.unsqueeze(0).repeat(n_channels, 1, 1, 1)

    def local_mean(t):
        # Gaussian-weighted mean of every window, computed per channel.
        return F.conv2d(t, weight=kernel, stride=1, padding=0, groups=n_channels)

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_xx = mu_x ** 2
    mu_yy = mu_y ** 2
    mu_xy = mu_x * mu_y

    # Local (co)variances: E[ab] - E[a]E[b].
    sigma_xx = local_mean(x ** 2) - mu_xx
    sigma_yy = local_mean(y ** 2) - mu_yy
    sigma_xy = local_mean(x * y) - mu_xy

    c1 = k1 ** 2
    c2 = k2 ** 2
    # Contrast/structure term and luminance term (alpha = beta = gamma = 1).
    contrast_structure = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
    similarity = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * contrast_structure
    return similarity.mean()

# Config

In [None]:
# Dataset and output locations.
TRAIN_PATH = '/kaggle/input/rain13kdataset/train/train/Rain13K'
TEST_PATH = '/kaggle/input/rain13kdataset/test/test'
SAVE_PATH = './results/Rain13K'  # log.csv and the model checkpoints are written here
DATA_NAME = 'Rain13K'

# Restormer model configuration (passed to the Restormer constructor).
NUM_BLOCKS = [3, 4, 4, 6]        # transformer blocks per level
NUM_HEADS = [1, 2, 2, 4]         # attention heads per level
CHANNELS = [32, 64, 128, 256]    # feature channels per level
EXPANSION_FACTOR = 2.66          # hidden-width multiplier of the feed-forward network
NUM_REFINEMENT = 2               # refinement blocks after the last decoder

# Progressive-learning schedule: the training loop switches to the next
# (batch size, patch size) pair on the iteration after each milestone.
NUM_ITER = 50000                 # total number of training iterations
BATCH_SIZES =  [8, 6, 4, 2]      # batch size per stage
PATCH_SIZES = [96, 128, 160, 192]  # training crop size per stage
MILESTONES = [12500, 25000, 37500]  # last iteration of stages 1..3

LR = 3e-4      # initial learning rate given to AdamW
WORKERS = 2    # DataLoader worker processes
SEED = 42      # seed for the Python, NumPy and PyTorch RNGs
MODEL_FILE = None # Set to a .pth path to continue training (not referenced in the visible code)

In [None]:
# Create the output directory. exist_ok=True replaces the check-then-create
# pattern, which can race with another process and silently accepts a
# regular file sitting at SAVE_PATH.
os.makedirs(SAVE_PATH, exist_ok=True)

# Seed every random number generator used by the pipeline (Python, NumPy,
# PyTorch CPU and, when present, all CUDA devices) for repeatable runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [7]:
# Pick the GPU when one is available. ``device`` is the plain string name and
# ``DEVICE`` the matching torch.device object.
device = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(device)

In [None]:
def rgb_to_y(x):
    """Convert an RGB batch (batch, 3, H, W) to its Y (luminance) channel.

    Returns a (batch, 1, H, W) tensor holding the weighted channel sum plus an
    offset of 16. The metrics are computed on this channel.
    """
    luma_weights = torch.tensor([0.256789, 0.504129, 0.097906], dtype=x.dtype, device=x.device)
    weighted = x * luma_weights.view(1, -1, 1, 1)
    return torch.sum(weighted, dim=1, keepdim=True).add(16.0)

def pad_image_needed(img, size):
    """Reflect-pad ``img`` when it is smaller than ``size`` = (height, width).

    A deficient axis is padded on BOTH sides by the full deficit, so the
    result may be larger than ``size``; the caller crops afterwards.

    Reflect padding of a tensor requires the pad amount to be smaller than the
    axis length, so one call fails for an image narrower than half the target.
    The padding is therefore applied in several passes when necessary. The
    output is unchanged whenever a single pass is enough.

    Args:
        img: image accepted by ``TF.pad`` (tensor or PIL image).
        size: ``(height, width)`` the image must at least reach.

    Returns:
        The padded image, or ``img`` itself when it is already large enough.
    """
    width, height = TF.get_image_size(img)
    if width < size[1]:
        img = _reflect_pad_both_sides(img, size[1] - width, horizontal=True)
    if height < size[0]:
        img = _reflect_pad_both_sides(img, size[0] - height, horizontal=False)
    return img


def _reflect_pad_both_sides(img, amount, horizontal):
    """Reflect-pad ``amount`` pixels on both sides of one axis, in as many passes as needed."""
    while amount > 0:
        step = amount
        if isinstance(img, torch.Tensor):
            extent = img.shape[-1] if horizontal else img.shape[-2]
            # Tensors need pad < extent. For extent <= 1 no valid step exists,
            # so the full amount is passed on and TF.pad reports the error.
            if extent > 1:
                step = min(amount, extent - 1)
        # Two-element padding: [left/right, top/bottom].
        padding = [step, 0] if horizontal else [0, step]
        img = TF.pad(img, padding, padding_mode='reflect')
        amount -= step
    return img

# Build Dataset

In [None]:
import os
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import Dataset

class RainDataset(Dataset):
    """Paired rainy / clean image dataset.

    Training mode returns a random ``patch_size`` crop with random horizontal
    and vertical flips. Evaluation mode returns the whole image, reflect-padded
    at the bottom/right so both sides are divisible by 8.

    Args:
        input_list: paths of the rainy images.
        target_list: paths of the clean images, aligned index by index with
            ``input_list``.
        is_train: selects crop/flip augmentation (True) or full-image padding.
        patch_size: side length of the training crops.
        length: number of samples reported by ``__len__`` in training mode
            (used by progressive learning); indices wrap around the real
            images. Ignored in evaluation mode or when ``None``.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, input_list, target_list, is_train=True, patch_size=128, length=None):
        if len(input_list) != len(target_list):
            raise ValueError(
                f"input_list and target_list must have the same length, "
                f"got {len(input_list)} and {len(target_list)}")

        self.is_train = is_train
        self.patch_size = patch_size

        # The file lists come straight from train_test_split.
        self.rain_images = input_list
        self.norain_images = target_list

        self.num = len(self.rain_images)
        # ``length`` only applies to training; otherwise use the real image count.
        self.sample_num = length if (self.is_train and length is not None) else self.num

    def __len__(self):
        return self.sample_num

    def __getitem__(self, idx):
        """Return ``(rain, norain, image_name, h, w)``; ``h``/``w`` are the original image size."""
        real_idx = idx % self.num
        rain_path = self.rain_images[real_idx]
        image_name = os.path.basename(rain_path)

        # Context managers close the image files deterministically.
        with Image.open(rain_path) as rain_pil:
            rain = TF.to_tensor(rain_pil.convert('RGB'))
        with Image.open(self.norain_images[real_idx]) as norain_pil:
            norain = TF.to_tensor(norain_pil.convert('RGB'))

        h, w = rain.shape[1:]

        if self.is_train:
            # Random crop (after padding images smaller than the patch).
            crop_size = (self.patch_size, self.patch_size)
            rain = pad_image_needed(rain, crop_size)
            norain = pad_image_needed(norain, crop_size)

            # ``RandomCrop`` is imported by name at the top of the file. The
            # alias ``T`` is bound to two different torchvision modules in this
            # file, so ``T.RandomCrop`` depends on which import ran last.
            i, j, th, tw = RandomCrop.get_params(rain, crop_size)
            rain = TF.crop(rain, i, j, th, tw)
            norain = TF.crop(norain, i, j, th, tw)

            # The same random flips are applied to both images of the pair.
            if torch.rand(1) < 0.5:
                rain = TF.hflip(rain)
                norain = TF.hflip(norain)
            if torch.rand(1) < 0.5:
                rain = TF.vflip(rain)
                norain = TF.vflip(norain)
        else:
            # Validation/test: keep the whole image and pad the bottom/right so
            # height and width are divisible by 8 (0 when they already are).
            pad_h = (-h) % 8
            pad_w = (-w) % 8
            rain = F.pad(rain, (0, pad_w, 0, pad_h), 'reflect')
            norain = F.pad(norain, (0, pad_w, 0, pad_h), 'reflect')
        return rain, norain, image_name, h, w

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

data_path = '/kaggle/input/rain13kdataset/train/train/Rain13K'
test_size = 0.2

input_folder = os.path.join(data_path, 'input')
target_folder = os.path.join(data_path, 'target')


def _sorted_image_paths(folder):
    """Return the sorted full paths of the .png/.jpg/.jpeg files directly inside ``folder``."""
    paths = []
    for entry in os.listdir(folder):
        if entry.endswith(('.png', '.jpg', '.jpeg')):
            paths.append(os.path.join(folder, entry))
    paths.sort()
    return paths


# Both lists are sorted the same way, so index i of one is paired with index i
# of the other (this assumes matching file names in the two folders).
input_files = _sorted_image_paths(input_folder)
target_files = _sorted_image_paths(target_folder)

print(f"Tổng số ảnh tìm thấy: {len(input_files)}")

# Split inputs and targets together so the pairs stay aligned.
_split = train_test_split(input_files, target_files, test_size=test_size, random_state=42)
train_inputs, val_inputs, train_targets, val_targets = _split

print(f"Số lượng Train: {len(train_inputs)}")
print(f"Số lượng Valid: {len(val_inputs)}")

# Validation runs on whole images, one per batch.
val_dataset = RainDataset(val_inputs, val_targets, is_train=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=WORKERS)

Tổng số ảnh tìm thấy: 13711
Số lượng Train: 10968
Số lượng Valid: 2743


# Validation

In [None]:
import pandas as pd
import os
import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Early stopping: give up after this many consecutive validations without a PSNR improvement.
PATIENCE = 10
patience_counter = 0

# Best validation metrics so far, and the per-validation history.
best_psnr = 0.0
best_ssim = 0.0
results = {metric: [] for metric in ('PSNR', 'SSIM', 'Loss')}


def validation(net, loader, max_size=384):
    """Evaluate ``net`` on ``loader`` and return ``(mean PSNR, mean SSIM)``.

    Images whose longer side exceeds ``max_size`` are resized (bilinear) before
    inference to avoid running out of memory; their metrics are computed at the
    reduced resolution against the equally resized ground truth. Smaller images
    are processed at full size and cropped back to their original dimensions.
    Metrics are computed on the Y channel of the 8-bit quantised images.

    Args:
        net: the model; it is switched to eval mode and left in that mode.
        loader: yields ``(rain, norain, name, h, w)`` with batch size 1
            (``h.item()`` / ``w.item()`` are used).
        max_size: longest side allowed before the image is resized.

    Returns:
        ``(psnr, ssim)`` averaged over the loader, or ``(0.0, 0.0)`` when the
        loader is empty.
    """
    net.eval()
    psnr_sum = 0.0
    ssim_sum = 0.0
    n_images = 0

    with torch.no_grad():
        for rain, norain, _, h, w in tqdm(loader, desc="Validating", leave=False):
            rain = rain.to(DEVICE)
            norain = norain.to(DEVICE)

            original_h, original_w = h.item(), w.item()
            longest_side = max(original_h, original_w)

            if longest_side <= max_size:
                # Small image: run at full size, then drop the dataset's padding.
                restored = net(rain)
                restored = restored[:, :, :original_h, :original_w]
                reference = norain[:, :, :original_h, :original_w]
            else:
                # Large image: shrink so the longest side is about max_size.
                scale = max_size / longest_side
                new_h = int(original_h * scale)
                new_w = int(original_w * scale)

                # Move up to the next multiple of 8 (this always adds 1..8 pixels).
                new_h = ((new_h + 8) // 8) * 8
                new_w = ((new_w + 8) // 8) * 8

                rain_resized = F.interpolate(rain, size=(new_h, new_w), mode='bilinear', align_corners=False)
                norain_resized = F.interpolate(norain, size=(new_h, new_w), mode='bilinear', align_corners=False)

                restored = net(rain_resized)

                # Trim to the resized dimensions.
                restored = restored[:, :, :new_h, :new_w]
                reference = norain_resized[:, :, :new_h, :new_w]

            # Clamp to [0, 1] and quantise to 8 bits.
            out = torch.clamp(restored, 0, 1).mul(255).byte()
            norain_final = torch.clamp(reference, 0, 1).mul(255).byte()

            # Metrics on the Y channel.
            y = rgb_to_y(out.double())
            gt = rgb_to_y(norain_final.double())

            psnr_sum += psnr(y, gt).item()
            ssim_sum += ssim(y, gt).item()
            n_images += 1

            # Drop the large tensors before the next image is loaded.
            del rain, norain, restored, reference, out, y, gt, norain_final
    torch.cuda.empty_cache()
    if n_images == 0:
        return 0.0, 0.0
    return psnr_sum/n_images, ssim_sum/n_images

def save_log_csv(val_psnr, val_ssim, n_iter):
    """Record one validation result in ``results`` and append it to ``log.csv``.

    Args:
        val_psnr: validation PSNR to log.
        val_ssim: validation SSIM to log.
        n_iter: training iteration of this validation; it is written as the
            row index (column ``Iter``) and shown in the console line.

    The logged loss is the most recent entry of ``results['Loss']``, or 0 when
    no loss has been recorded yet. The same fallback is used for the console
    line, which previously raised ``IndexError`` in that case.
    """
    results['PSNR'].append(val_psnr)
    results['SSIM'].append(val_ssim)

    # Single source of truth for the loss used by both the CSV row and the print.
    last_loss = results['Loss'][-1] if results['Loss'] else 0.0

    log_path = os.path.join(SAVE_PATH, 'log.csv')
    row = pd.DataFrame(data={'Loss': [last_loss],
                             'PSNR': [val_psnr],
                             'SSIM': [val_ssim]},
                       index=pd.Index([n_iter], name='Iter'))
    # Append; write the header only when the file is being created.
    row.to_csv(log_path, mode='a', header=not os.path.exists(log_path))

    print(f"Iter: {n_iter} | Loss: {last_loss:.4f} | PSNR: {val_psnr:.2f} | SSIM: {val_ssim:.3f}")



# Train

In [None]:
def check_checkpoint(checkpoint_path):
    """Print a summary of a saved checkpoint and return its contents.

    Args:
        checkpoint_path: path of the ``.pth`` file to inspect.

    Returns:
        The loaded checkpoint dictionary (mapped to the CPU), or ``None`` when
        the file does not exist.

    Raises:
        KeyError: if the checkpoint lacks ``'iteration'``, ``'psnr'`` or ``'ssim'``.
    """
    if os.path.exists(checkpoint_path):
        print("THÔNG TIN CHECKPOINT")

        state = torch.load(checkpoint_path, map_location='cpu')

        print(f"Iteration: {state['iteration']}")
        print(f"PSNR: {state['psnr']:.2f}")
        print(f"SSIM: {state['ssim']:.4f}")
        print(f"Keys trong checkpoint: {list(state.keys())}")

        # The learning rate is read from the first optimizer parameter group, when present.
        if 'optimizer_state_dict' in state:
            saved_lr = state['optimizer_state_dict']['param_groups'][0]['lr']
            print(f"Learning Rate: {saved_lr:.2e}")

        print("=" * 60 + "\n")
        return state

    print(f"Không tìm thấy checkpoint: {checkpoint_path}")
    return None

In [None]:
import torch
import os
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Checkpoint to resume from, and whether resuming is wanted at all.
CHECKPOINT_PATH = '/kaggle/input/deraining-restormer-model/keras/default/1/results/Rain13K/best_model.pth'
RESUME_TRAINING = True

print("Khởi tạo model Restormer...")
model = Restormer(
    num_blocks=NUM_BLOCKS,
    num_heads=NUM_HEADS,
    channels=CHANNELS,
    num_refinement=NUM_REFINEMENT,
    expansion_factor=EXPANSION_FACTOR,
).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# One cosine cycle over the whole run, stepped once per iteration.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=NUM_ITER, eta_min=1e-6)

# Defaults for a fresh run; the resume logic may override the first two.
start_iter = 1
stage_index = 0
# Sample-weighted running loss since the last validation.
total_loss = 0.0
total_num = 0
# Created lazily by the training loop.
train_loader_iter = None

# Resume only when it was requested and the checkpoint file actually exists.
resume_requested = RESUME_TRAINING and os.path.exists(CHECKPOINT_PATH)

if not resume_requested:
    print("\n⚠️  Không tìm thấy checkpoint hoặc RESUME_TRAINING=False")
    print("   → Training từ đầu\n")
    best_psnr, best_ssim = 0.0, 0.0
else:
    print("ĐANG TẢI CHECKPOINT...")

    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

    # Restore model, optimizer and scheduler in that order, reporting each step.
    for _component, _state_key, _message in (
        (model, 'model_state_dict', "Đã load model weights"),
        (optimizer, 'optimizer_state_dict', "Đã load optimizer state"),
        (lr_scheduler, 'scheduler_state_dict', "Đã load scheduler state"),
    ):
        _component.load_state_dict(checkpoint[_state_key])
        print(_message)

    # Continue with the iteration after the one that was saved.
    start_iter = checkpoint['iteration'] + 1
    best_psnr = checkpoint['psnr']
    best_ssim = checkpoint['ssim']

    # The stage is the position after the last milestone already passed.
    for idx, milestone in enumerate(MILESTONES):
        if start_iter > milestone:
            stage_index = idx + 1

    print("\nTHÔNG TIN CHECKPOINT:")
    print(f"   Last Iteration: {checkpoint['iteration']}")
    print(f"   Resume từ Iteration: {start_iter}")
    print(f"   Best PSNR: {best_psnr:.2f}")
    print(f"   Best SSIM: {best_ssim:.4f}")
    print(f"   Current Stage: {stage_index + 1}/{len(BATCH_SIZES)}")
    print(f"   Current LR: {optimizer.param_groups[0]['lr']:.2e}")
    print("=" * 60 + "\n")

# Early stopping always starts from a clean counter, resumed or not.
patience_counter = 0

print(f"Bắt đầu/tiếp tục training với Early Stopping (Patience={PATIENCE})...")

# Main training loop with progressive learning (batch and patch size change at
# each milestone), validation every 1000 iterations, best-model checkpointing
# and early stopping. The progress bar starts at the resumed position.
train_bar = tqdm(range(start_iter, NUM_ITER + 1), desc="Training", initial=start_iter-1, total=NUM_ITER)

for n_iter in train_bar:
    # A new DataLoader is needed on the first pass (fresh start or resume) and
    # on the first iteration after each milestone, when the stage changes.
    need_new_loader = (
        train_loader_iter is None or  
        n_iter == 1 or 
        (n_iter - 1) in MILESTONES
    )
    
    if need_new_loader:
        # Stage = position after the last milestone that n_iter has passed.
        current_stage = 0
        for idx, milestone in enumerate(MILESTONES):
            if n_iter > milestone:
                current_stage = idx + 1
        
        stage_index = current_stage
        
        # Last iteration of this stage (NUM_ITER for the final stage).
        if stage_index < len(MILESTONES):
            end_iter = MILESTONES[stage_index]
        else:
            end_iter = NUM_ITER
        
        # Iteration count at which this stage begins (0 for the first stage).
        start_iter_stage = MILESTONES[stage_index - 1] if stage_index > 0 else 0
        
        # Look up the configuration for this stage.
        curr_batch = BATCH_SIZES[stage_index]
        curr_patch = PATCH_SIZES[stage_index]
        # Size the dataset so one pass over the loader yields one batch per
        # iteration of the whole stage.
        length = curr_batch * (end_iter - start_iter_stage)
        
        train_ds = RainDataset(train_inputs, train_targets, is_train=True, 
                               patch_size=curr_patch, length=length)
        train_loader_iter = iter(DataLoader(train_ds, batch_size=curr_batch, 
                                           shuffle=True, num_workers=WORKERS, 
                                           pin_memory=True))
        
        print(f"\n[Giai đoạn {stage_index+1}/{len(BATCH_SIZES)}]")
        print(f"  Patch Size: {curr_patch} | Batch Size: {curr_batch}")
        print(f"  Iterations: {start_iter_stage+1} → {end_iter}")
    
    # validation() leaves the model in eval mode, so switch back every iteration.
    model.train()
    
    try:
        rain, norain, _, _, _ = next(train_loader_iter)
    except StopIteration:
        # Safety net: rebuild the loader if it runs out before the stage ends.
        train_loader_iter = iter(DataLoader(train_ds, batch_size=curr_batch, 
                                           shuffle=True, num_workers=WORKERS, 
                                           pin_memory=True))
        rain, norain, _, _, _ = next(train_loader_iter)
    
    # Forward
    rain, norain = rain.to(DEVICE), norain.to(DEVICE)
    out = model(rain)
    loss = F.l1_loss(out, norain)
    
    # Backward (gradient norm clipped to 1.0; scheduler stepped once per iteration)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    lr_scheduler.step()
    
    # Update statistics (sample-weighted running sum since the last validation)
    total_num += rain.size(0)
    total_loss += loss.item() * rain.size(0)
    
    # Update progress bar
    current_lr = optimizer.param_groups[0]['lr']
    train_bar.set_postfix({
        'Loss': f"{total_loss/total_num:.4f}",
        'Best_PSNR': f"{best_psnr:.2f}",
        'LR': f"{current_lr:.2e}",
        'Stage': f"{stage_index+1}/{len(BATCH_SIZES)}"
    })
    
    # Validate every 1000 iterations.
    if n_iter % 1000 == 0:
        # Mean training loss over the samples seen since the previous validation.
        avg_loss = total_loss / total_num
        results['Loss'].append(avg_loss)
        
        torch.cuda.empty_cache()
        
        print(f"[VALIDATION @ Iter {n_iter}]")
        v_psnr, v_ssim = validation(model, val_loader, max_size=384)
        
        save_log_csv(v_psnr, v_ssim, n_iter)
        
        # Only PSNR decides whether this is a new best model.
        if v_psnr > best_psnr:
            best_psnr, best_ssim = v_psnr, v_ssim
            
            # Save everything needed to resume: weights, optimizer and
            # scheduler state, metrics and the current stage.
            torch.save({
                'iteration': n_iter,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': lr_scheduler.state_dict(),
                'psnr': best_psnr,
                'ssim': best_ssim,
                'stage_index': stage_index,
                'patience_counter': 0
            }, os.path.join(SAVE_PATH, 'best_model.pth'))
            
            print(f"[NEW BEST] PSNR: {best_psnr:.2f} | SSIM: {best_ssim:.4f}")
            print(f" → Model đã được lưu!")
            
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"Không cải thiện. Patience: {patience_counter}/{PATIENCE}")
            
            # Early stopping after PATIENCE consecutive validations without improvement.
            if patience_counter >= PATIENCE:
                print("\n" + "="*60)
                print(f"[EARLY STOPPING]")
                print(f"Model không cải thiện trong {PATIENCE} lần kiểm tra ({PATIENCE*1000} iters)")
                print(f"Best PSNR: {best_psnr:.2f} | Best SSIM: {best_ssim:.4f}")
                print("="*60)
                break
        
        print("="*60 + "\n")
        
        # Start a fresh running-loss window for the next 1000 iterations.
        total_loss, total_num = 0.0, 0



Khởi tạo model Restormer...
ĐANG TẢI CHECKPOINT...
Đã load model weights
Đã load optimizer state
Đã load scheduler state

THÔNG TIN CHECKPOINT:
   Last Iteration: 36000
   Resume từ Iteration: 36001
   Best PSNR: 33.18
   Best SSIM: 0.9344
   Current Stage: 3/4
   Current LR: 5.52e-05

Bắt đầu/tiếp tục training với Early Stopping (Patience=10)...


Training:  72%|#######2  | 36000/50000 [00:00<?, ?it/s]


[Giai đoạn 3/4]
  Patch Size: 160 | Batch Size: 4
  Iterations: 25001 → 37500
[VALIDATION @ Iter 37000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 37000 | Loss: 0.0226 | PSNR: 33.23 | SSIM: 0.934
[NEW BEST] PSNR: 33.23 | SSIM: 0.9341
 → Model đã được lưu!


[Giai đoạn 4/4]
  Patch Size: 192 | Batch Size: 2
  Iterations: 37501 → 50000
[VALIDATION @ Iter 38000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 38000 | Loss: 0.0222 | PSNR: 33.12 | SSIM: 0.934
Không cải thiện. Patience: 1/10

[VALIDATION @ Iter 39000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 39000 | Loss: 0.0226 | PSNR: 33.24 | SSIM: 0.935
[NEW BEST] PSNR: 33.24 | SSIM: 0.9345
 → Model đã được lưu!

[VALIDATION @ Iter 40000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 40000 | Loss: 0.0224 | PSNR: 33.05 | SSIM: 0.934
Không cải thiện. Patience: 1/10

[VALIDATION @ Iter 41000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 41000 | Loss: 0.0220 | PSNR: 33.31 | SSIM: 0.935
[NEW BEST] PSNR: 33.31 | SSIM: 0.9348
 → Model đã được lưu!

[VALIDATION @ Iter 42000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 42000 | Loss: 0.0216 | PSNR: 33.27 | SSIM: 0.935
Không cải thiện. Patience: 1/10

[VALIDATION @ Iter 43000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 43000 | Loss: 0.0217 | PSNR: 33.29 | SSIM: 0.935
Không cải thiện. Patience: 2/10

[VALIDATION @ Iter 44000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 44000 | Loss: 0.0219 | PSNR: 33.32 | SSIM: 0.935
[NEW BEST] PSNR: 33.32 | SSIM: 0.9349
 → Model đã được lưu!

[VALIDATION @ Iter 45000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 45000 | Loss: 0.0221 | PSNR: 33.35 | SSIM: 0.935
[NEW BEST] PSNR: 33.35 | SSIM: 0.9352
 → Model đã được lưu!

[VALIDATION @ Iter 46000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 46000 | Loss: 0.0218 | PSNR: 33.36 | SSIM: 0.935
[NEW BEST] PSNR: 33.36 | SSIM: 0.9353
 → Model đã được lưu!

[VALIDATION @ Iter 47000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 47000 | Loss: 0.0219 | PSNR: 33.40 | SSIM: 0.935
[NEW BEST] PSNR: 33.40 | SSIM: 0.9354
 → Model đã được lưu!

[VALIDATION @ Iter 48000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 48000 | Loss: 0.0220 | PSNR: 33.35 | SSIM: 0.935
Không cải thiện. Patience: 1/10

[VALIDATION @ Iter 49000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 49000 | Loss: 0.0218 | PSNR: 33.37 | SSIM: 0.935
Không cải thiện. Patience: 2/10

[VALIDATION @ Iter 50000]


Validating:   0%|          | 0/2743 [00:00<?, ?it/s]

Iter: 50000 | Loss: 0.0219 | PSNR: 33.38 | SSIM: 0.935
Không cải thiện. Patience: 3/10



In [None]:
# Report the best result and where the best checkpoint was written.
best_model_path = os.path.join(SAVE_PATH, 'best_model.pth')
print("HOÀN THÀNH TRAINING!")
print(f"Best PSNR: {best_psnr:.2f} | Best SSIM: {best_ssim:.4f}")
print(f"Model được lưu tại: {best_model_path}")
print("=" * 60)

# Also keep the weights of the last iteration, tagged with the metrics of the
# most recent validation (v_psnr / v_ssim), which need not be the best ones.
final_checkpoint = {
    'iteration': n_iter,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': lr_scheduler.state_dict(),
    'psnr': v_psnr,
    'ssim': v_ssim,
}
torch.save(final_checkpoint, os.path.join(SAVE_PATH, 'final_model.pth'))

print(f"Final model (iter {n_iter}) đã được lưu!")

HOÀN THÀNH TRAINING!
Best PSNR: 33.40 | Best SSIM: 0.9354
Model được lưu tại: ./results/Rain13K/best_model.pth
Final model (iter 50000) đã được lưu!
