In [1]:
# !git clone https://github.com/rmaahin/Video-Super-Resolution-Using-Transformers.git

Cloning into 'Video-Super-Resolution-Using-Transformers'...
remote: Enumerating objects: 48743, done.[K
remote: Total 48743 (delta 0), reused 0 (delta 0), pack-reused 48743 (from 1)[K
Receiving objects: 100% (48743/48743), 1.71 GiB | 41.31 MiB/s, done.
Updating files: 100% (48018/48018), done.


In [3]:
!pip install tqdm

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m704.1 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading tqdm-4.67.1-py3-none-any.whl (78 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.5/78.5 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tqdm
Successfully installed tqdm-4.67.1
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [4]:
import os
import glob
import math
import copy
import yaml
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image

In [5]:
os.chdir('Video-Super-Resolution-Using-Transformers')

In [6]:
class VSRDataset(Dataset):
    def __init__(self, root_dir, sequence_length=5, scale=3, mode='train'):
        """
        Args:
            root_dir (str): Root folder containing video subfolders.
            sequence_length (int): Number of consecutive LR frames as input.
            scale (int): Resolution upscaling factor.
            mode (str): 'train' or 'test'. For 'train' random center, 'test' sliding window.
        """
        self.root_dir = root_dir
        self.sequence_length = sequence_length
        self.half = sequence_length // 2
        self.scale = scale
        self.mode = mode
        self.videos = []

        self.samples = []
        self.to_tensor = ToTensor()

        self._build_index()

    def _build_index(self):
        video_folders = sorted(glob.glob(os.path.join(self.root_dir, "*")))

        for vid_path in video_folders:
            lr_folder = os.path.join(vid_path, "lr_images")
            hr_folder = os.path.join(vid_path, "hr_images")
            lr_frames = sorted(glob.glob(os.path.join(lr_folder, "*.png")))

            if len(lr_frames) < self.sequence_length:
                continue

            for i in range(self.half, len(lr_frames) - self.half):
                self.samples.append({
                    "video": os.path.basename(vid_path),
                    "center_index": i,
                    "lr_paths": [os.path.join(lr_folder, f"{idx:08d}.png")
                                 for idx in range(i - self.half, i + self.half + 1)],
                    "hr_path": os.path.join(hr_folder, f"{i:08d}.png")
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        lr_seq = [self.to_tensor(Image.open(p)) for p in sample["lr_paths"]]
        hr_img = self.to_tensor(Image.open(sample["hr_path"]))

        lr_seq = torch.stack(lr_seq, dim=0)  # [T, C, H, W]
        return {
            "lr": lr_seq,                   # input sequence [T, 3, H, W]
            "hr": hr_img,                   # target HR frame [3, H*scale, W*scale]
            "video": sample["video"],
            "index": sample["center_index"]
        }

In [7]:
class STCSA(nn.Module):
    def __init__(self, channels, patch_size=3, max_frames=7):  # <-- NEW
        super().__init__()
        self.channels = channels
        self.query_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.key_conv   = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.out_proj   = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.temporal_embed = nn.Parameter(torch.zeros(1, max_frames, channels, 1, 1))  # <-- NEW

    def forward(self, x):
        # x: [B, T, C, H, W]
        B, T, C, H, W = x.shape

        if T != self.temporal_embed.shape[1]:
            raise ValueError(f"Temporal embed expects {self.temporal_embed.shape[1]} frames, got {T}")

        x = x + self.temporal_embed[:, :T]  # broadcast addition

        x_reshape = x.view(B*T, C, H, W)
        Q = self.query_conv(x_reshape)
        K = self.key_conv(x_reshape)
        V = self.value_conv(x_reshape)

        Q = Q.view(B, T, C, H*W).permute(0, 2, 1, 3).reshape(B, C, -1)
        K = K.view(B, T, C, H*W).permute(0, 2, 3, 1).reshape(B, C, -1)
        V = V.view(B, T, C, H*W).permute(0, 2, 1, 3).reshape(B, C, -1)

        attn_weights = torch.softmax(torch.bmm(Q.transpose(1,2), K) / (C ** 0.5), dim=-1)
        out = torch.bmm(attn_weights, V.transpose(1,2))
        out = out.transpose(1,2).reshape(B, C, T, H, W).permute(0, 2, 1, 3, 4)

        out = self.out_proj(out.view(B*T, C, H, W)).view(B, T, C, H, W)
        return x + out  # Residual

In [8]:
def generate_grid(B, H, W, device):
    y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    grid = torch.stack((x, y), 2).float()  # [H, W, 2]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, H, W, 2]
    return grid.to(device)

def warp(x, flow):
    """
    Warp an image or feature map with optical flow
    x: [B, C, H, W]
    flow: [B, 2, H, W] in pixels
    """
    B, C, H, W = x.size()
    grid = generate_grid(B, H, W, x.device)  # [B, H, W, 2]
    grid = grid + flow.permute(0, 2, 3, 1)  # add flow
    grid[:, :, :, 0] = 2.0 * grid[:, :, :, 0] / (W - 1) - 1.0
    grid[:, :, :, 1] = 2.0 * grid[:, :, :, 1] / (H - 1) - 1.0
    return F.grid_sample(x, grid, mode='bilinear', padding_mode='border', align_corners=True)

class FlowEstimator(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(channels * 2, channels, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(channels, 2, 3, 1, 1)  # Output flow: 2 channels
        )

    def forward(self, ref, nbr):
        x = torch.cat([ref, nbr], dim=1)
        flow = self.encoder(x)
        return flow

class BOFF(nn.Module):
    def __init__(self, channels, sequence_length): # Add sequence_length as argument
        super().__init__()
        self.flow_net = FlowEstimator(channels)
        self.fuser = nn.Conv2d(channels * sequence_length, channels, 3, 1, 1) # Use sequence_length
        self.norm = nn.LayerNorm([channels, 1, 1])
        self.sequence_length = sequence_length # Store sequence_length

    def forward(self, x):
        # x: [B, T, C, H, W]
        B, T, C, H, W = x.size()
        center_idx = T // 2
        ref = x[:, center_idx]  # [B, C, H, W]

        warped_feats = [ref]  # include center frame unwarped

        for i in range(T):
            if i == center_idx:
                continue
            nbr = x[:, i]
            flow = self.flow_net(ref, nbr)
            warped = warp(nbr, flow)
            warped_feats.append(warped)

        aligned = torch.cat(warped_feats, dim=1)  # [B, C*T, H, W]
        fused = self.fuser(aligned)  # [B, C, H, W]

        # Expand to [B, T, C, H, W] with same fused output
        return x + fused.unsqueeze(1).expand(-1, T, -1, -1, -1)

In [9]:
class Reconstructor(nn.Module):
    def __init__(self, in_channels=64, out_channels=3, scale=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * (scale ** 2), kernel_size=3, padding=1)
        self.upsample = nn.PixelShuffle(scale)

        self.refine = nn.Sequential(
            nn.Conv2d(out_channels, 32, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 3, 1, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.upsample(x)
        x = self.refine(x)  # <--- extra sharpness refinement
        return x

In [10]:
class FeatureExtractor(nn.Module):
    def __init__(self, in_channels=3, feat_channels=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, feat_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):  # x: [B, T, C, H, W]
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)
        feats = self.layers(x)  # apply sequential
        return feats.view(B, T, -1, H, W)  # [B, T, C, H, W]

class VSRTransformer(nn.Module):
    def __init__(self, in_channels=3, feat_channels=64, scale=3, num_blocks=4, sequence_length=5): # Add sequence_length
        super().__init__()
        self.feat_extractor = FeatureExtractor(in_channels, feat_channels)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                STCSA(feat_channels, max_frames=sequence_length),
                BOFF(feat_channels, sequence_length) # Pass sequence_length to BOFF
            )
            for _ in range(num_blocks)
        ])
        self.reconstructor = Reconstructor(feat_channels, in_channels, scale)

    def forward(self, x):  # x: [B, T, 3, H, W]
        feats = self.feat_extractor(x)  # [B, T, C, H, W]
        for block in self.blocks:
            feats = block(feats)
        center = feats[:, feats.shape[1] // 2]  # [B, C, H, W]
        sr = self.reconstructor(center)
        return sr  # [B, 3, H*scale, W*scale]

In [11]:
class CharbonnierLoss(nn.Module):
    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt):
        diff = pred - gt
        loss = torch.mean(torch.sqrt(diff * diff + self.eps * self.eps))
        return loss

In [12]:
def calc_psnr(sr, hr):
    sr = torch.clamp(sr, 0.0, 1.0)
    hr = torch.clamp(hr, 0.0, 1.0)
    mse = torch.mean((sr - hr) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * math.log10(1.0 / math.sqrt(mse.item()))

In [13]:
def load_config(path):
    with open(path, 'r') as f:
        return yaml.safe_load(f)

In [17]:
config = load_config("data/config.yaml")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
dataset = VSRDataset(
    root_dir=config["dataset_path"],
    sequence_length=config["sequence_length"],
    scale=config["scale"]
)
loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True, num_workers=2)

# Model
model = VSRTransformer(
    in_channels=3,
    feat_channels=config["feat_channels"],
    scale=config["scale"],
    num_blocks=config["num_blocks"],
    sequence_length=config["sequence_length"]
).to(device)

# Validation Dataset
val_dataset = VSRDataset(
    root_dir=config["val_dataset_path"],  # NEW
    sequence_length=config["sequence_length"],
    scale=config["scale"],
    mode='test'  # ensure no shuffle
)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Early Stopping
best_val_loss = float('inf')
patience = 5  # NEW: stop if no improvement for 5 epochs
trigger_times = 0
best_model = None

# Loss & Optimizer
criterion = CharbonnierLoss()
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

# Training loop
for epoch in range(config["epochs"]):
    model.train()
    epoch_loss = 0.0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{config['epochs']}", leave=False)
    for batch in pbar:
        lr = batch["lr"].to(device)
        hr = batch["hr"].to(device)

        sr = model(lr)
        sr = torch.clamp(sr, 0.0, 1.0)
        hr = torch.clamp(hr, 0.0, 1.0)
        loss = criterion(sr, hr)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix(loss=loss.item())

    avg_train_loss = epoch_loss / len(loader)
    print(f"[Epoch {epoch+1}] Avg Train Loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    val_psnr = 0.0
    with torch.no_grad():
        for batch in val_loader:
            lr = batch["lr"].to(device)
            hr = batch["hr"].to(device)
            sr = model(lr)
            sr = torch.clamp(sr, 0.0, 1.0)
            hr = torch.clamp(hr, 0.0, 1.0)
            loss = criterion(sr, hr)
            val_loss += loss.item()
            val_psnr += calc_psnr(sr[0], hr[0])

    avg_val_loss = val_loss / len(val_loader)
    avg_val_psnr = val_psnr / len(val_loader)
    print(f"[Epoch {epoch+1}] Val Loss: {avg_val_loss:.4f} | Val PSNR: {avg_val_psnr:.2f} dB")

    # Early Stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        trigger_times = 0
        best_model = copy.deepcopy(model.state_dict())  # save best model
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Save sample output
    if (epoch + 1) % 5 == 0:
        save_image(sr[0].cpu(), f"outputepoch_modified{epoch+1}.png")

    # Save checkpoint
    if (epoch + 1) % config["save_every"] == 0:
        torch.save(model.state_dict(), f"weights/modified-model/vsrepoch_modified{epoch+1}.pth")

torch.save(best_model, "weights/modified-model/vsr_best_model.pth")  # NEW
print("✅ Best model saved to vsr_best_model.pth")

                                                                            

[Epoch 1] Avg Train Loss: 0.0450
[Epoch 1] Val Loss: 0.0439 | Val PSNR: 23.06 dB


                                                                            

[Epoch 2] Avg Train Loss: 0.0362
[Epoch 2] Val Loss: 0.0408 | Val PSNR: 23.52 dB


                                                                            

[Epoch 3] Avg Train Loss: 0.0344
[Epoch 3] Val Loss: 0.0404 | Val PSNR: 23.70 dB


                                                                            

[Epoch 4] Avg Train Loss: 0.0334
[Epoch 4] Val Loss: 0.0383 | Val PSNR: 23.93 dB


                                                                            

[Epoch 5] Avg Train Loss: 0.0328
[Epoch 5] Val Loss: 0.0374 | Val PSNR: 24.04 dB


                                                                            

[Epoch 6] Avg Train Loss: 0.0323
[Epoch 6] Val Loss: 0.0371 | Val PSNR: 24.16 dB


                                                                             

[Epoch 7] Avg Train Loss: 0.0320
[Epoch 7] Val Loss: 0.0368 | Val PSNR: 24.21 dB


                                                                            

[Epoch 8] Avg Train Loss: 0.0316
[Epoch 8] Val Loss: 0.0369 | Val PSNR: 24.24 dB


                                                                            

[Epoch 9] Avg Train Loss: 0.0314
[Epoch 9] Val Loss: 0.0363 | Val PSNR: 24.32 dB


                                                                             

[Epoch 10] Avg Train Loss: 0.0312
[Epoch 10] Val Loss: 0.0359 | Val PSNR: 24.36 dB
✅ Best model saved to vsr_best_model.pth


In [15]:
config = load_config("data/config.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
dataset = VSRDataset(
    root_dir=config["test_dataset_path"],
    sequence_length=config["sequence_length"],
    scale=config["scale"],
    mode='test'
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Model
model = VSRTransformer(
    in_channels=3,
    feat_channels=config["feat_channels"],
    scale=config["scale"],
    num_blocks=config["num_blocks"]
).to(device)

model.load_state_dict(torch.load("weights/modified-model/vsr_best_model.pth", map_location=device))
model.eval()

os.makedirs("test_outputs", exist_ok=True)

psnr_total = 0.0
with torch.no_grad():
    for i, batch in enumerate(tqdm(loader, desc="Testing")):
        lr = batch["lr"].to(device)
        hr = batch["hr"].to(device)

        sr = model(lr)
        sr = torch.clamp(sr, 0.0, 1.0)
        hr = torch.clamp(hr, 0.0, 1.0)

        psnr = calc_psnr(sr[0], hr[0])
        psnr_total += psnr

        save_image(sr[0].cpu(), f"test_outputs/sr_{i:04d}.png")
        save_image(hr[0].cpu(), f"test_outputs/hr_{i:04d}.png")

avg_psnr = psnr_total / len(loader)
print(f"\nAvg PSNR over test set: {avg_psnr:.2f} dB")

Testing: 100%|██████████| 2400/2400 [04:08<00:00,  9.65it/s]


Avg PSNR over test set: 24.75 dB



