<a href="https://colab.research.google.com/github/snehith-3939/FirstRepo/blob/main/vsr%2B%2Bscratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

# Model

In [None]:
!pip install wandb -qU

In [None]:
# Log in to your W&B account
import wandb
import random
import math

In [None]:
# Authenticate this session with Weights & Biases so the wandb.init() call in
# train() can log runs to the W&B project.
wandb.login()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d

In [None]:
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gating.

    Global-average-pools each channel, passes the pooled vector through a
    bottleneck MLP ending in a sigmoid, and rescales every channel of the
    input by its resulting weight in (0, 1).

    Args:
        channels: number of input/output channels.
        reduction: bottleneck width divisor (hidden size = channels // reduction).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        n, ch, _, _ = x.size()
        pooled = self.avg_pool(x).reshape(n, ch)
        gate = self.fc(pooled).reshape(n, ch, 1, 1)
        return x * gate.expand_as(x)

In [None]:
class ResidualBlock(nn.Module):
    """Conv-BN-LeakyReLU-Conv-BN residual block with built-in channel attention.

    The residual branch is rescaled per channel by a sigmoid gate computed
    from its global average, then added back onto the block input.

    Args:
        in_channels: channel count of input and output (the block is shape-preserving).
        reduction: bottleneck divisor of the attention MLP.
    """

    def __init__(self, in_channels=64, reduction=16):
        super().__init__()
        squeezed = in_channels // reduction
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.LeakyReLU(0.1, inplace=True)
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_channels)

        # Channel attention on the residual branch
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, squeezed, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        n, ch, _, _ = branch.size()
        weights = self.fc(self.avg_pool(branch).view(n, ch)).view(n, ch, 1, 1)
        return branch * weights.expand_as(branch) + x

In [None]:
class FlowEstimation(nn.Module):
    """Predict a 2-channel flow map from a pair of feature maps.

    The pair is concatenated on the channel axis, encoded (the second conv has
    stride 2, halving H and W), projected to 2 channels, and, when
    ``use_decoder`` is set, passed through three (x2 bilinear upsample + 3x3 conv)
    stages. With the decoder the output is therefore 4x the input spatial size;
    without it, 1/2 the input size.

    Args:
        num_feat: channels of each input feature map.
        use_decoder: whether to apply the upsampling decoder.
    """

    def __init__(self, num_feat=64, use_decoder=True):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(num_feat * 2, num_feat, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_feat),
            nn.LeakyReLU(0.1),
            nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_feat),
            nn.LeakyReLU(0.1),
        )
        self.flow_pred = nn.Conv2d(num_feat, 2, kernel_size=3, padding=1, bias=False)
        self.use_decoder = use_decoder
        if self.use_decoder:
            stages = []
            for _ in range(3):
                stages.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
                stages.append(nn.Conv2d(2, 2, kernel_size=3, padding=1))
            self.decoder = nn.Sequential(*stages)

    def forward(self, x1, x2):
        flow = self.flow_pred(self.encoder(torch.cat((x1, x2), dim=1)))
        return self.decoder(flow) if self.use_decoder else flow

In [None]:
class BidirectionalPropagation(nn.Module):
    """One recurrent propagation step.

    Aligns the carried hidden state to the current frame's features with
    ``DeformableAlignment``, fuses the two with a 3x3 conv, and refines the
    result with ``num_block`` residual blocks.

    Args:
        num_feat: feature channel width.
        num_block: number of ResidualBlocks in the refinement stack.
    """

    def __init__(self, num_feat, num_block):
        super().__init__()
        self.deform_align = DeformableAlignment(num_feat)
        self.fuse_conv = nn.Conv2d(num_feat * 2, num_feat, 3, padding=1)
        self.blocks = nn.Sequential(*(ResidualBlock(num_feat) for _ in range(num_block)))

    def forward(self, feat, hidden, flow):
        previous = torch.zeros_like(feat) if hidden is None else hidden
        aligned = self.deform_align(feat, previous, flow)
        refined = self.blocks(self.fuse_conv(torch.cat((feat, aligned), dim=1)))
        # The refined features are both this step's output and the next step's hidden state.
        return refined, refined

In [None]:
class DeformableAlignment(nn.Module):
    """Flow-guided deformable alignment of a neighbor feature map.

    The neighbor features are first warped with ``grid_sample`` using an
    identity sampling grid shifted by the (resized) flow, then refined by a
    modulated deformable 3x3 convolution whose offsets and masks are predicted
    from the current features, the warped features and the flow.

    Args:
        num_feat: channel width of ``feat`` and ``neighbor_feat``.
    """

    def __init__(self, num_feat=64):
        super().__init__()
        # 27 output channels = 18 offset channels + 9 mask channels for a 3x3 kernel.
        self.offset_conv = nn.Sequential(
            nn.Conv2d(2 * num_feat + 2, num_feat, 3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(num_feat, num_feat, 3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(num_feat, 27, 3, padding=1),
        )
        self.deform_conv = DeformConv2d(num_feat, num_feat, 3, padding=1)

    def forward(self, feat, neighbor_feat, flow):
        batch, _, height, width = feat.shape

        # Identity affine transform -> base sampling grid; the flow is added on top of it.
        theta = torch.eye(2, 3, device=feat.device).unsqueeze(0).expand(batch, -1, -1)
        base_grid = F.affine_grid(theta, neighbor_feat.size(), align_corners=False)

        # The flow may come at a different resolution; resize it to the feature size.
        flow_resized = F.interpolate(flow, size=(height, width), mode='bilinear', align_corners=False)
        sample_grid = base_grid + flow_resized.permute(0, 2, 3, 1)
        warped = F.grid_sample(neighbor_feat, sample_grid, mode='bilinear',
                               padding_mode='border', align_corners=False)

        offset_mask = self.offset_conv(torch.cat([feat, warped, flow_resized], dim=1))
        offset = offset_mask[:, :18]
        mask = torch.sigmoid(offset_mask[:, 18:])
        return self.deform_conv(warped, offset, mask)

In [None]:
class BasicVSRPlusPlus(nn.Module):
    """Bidirectional recurrent video super-resolution network.

    Per-frame features are propagated forward (t = 0 .. T-1) and backward
    (t = T-1 .. 0) with flow-guided deformable alignment. The two streams are
    fused per frame, and each frame is upsampled by ``scale`` via PixelShuffle.

    Args:
        scale: upsampling factor of the reconstruction head.
        num_feat: channel width of intermediate features.
        num_block: total residual blocks, split evenly between the two
            propagation directions.
    """

    def __init__(self, scale=4, num_feat=64, num_block=30):
        super().__init__()
        self.scale = scale

        # Feature extraction
        self.feat_extract = nn.Sequential(
            nn.Conv2d(3, num_feat, 3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(num_feat, num_feat, 3, padding=1),
        )

        # Propagation
        self.forward_prop = BidirectionalPropagation(num_feat, num_block//2)
        self.backward_prop = BidirectionalPropagation(num_feat, num_block//2)

        # Flow estimation (one module shared by both directions)
        self.flow_estimation = FlowEstimation(num_feat)

        # Fusion and reconstruction (Added ChannelAttention)
        self.fusion = nn.Conv2d(2*num_feat, num_feat, 3, padding=1)
        self.reconstruction = nn.Sequential(
            ResidualBlock(num_feat),
            nn.Conv2d(num_feat, num_feat, 3, padding=1),
            ChannelAttention(num_feat),  # Added
            nn.Conv2d(num_feat, 3*(scale**2), 3, padding=1),
            nn.PixelShuffle(scale),
            nn.Conv2d(3, 3, 3, padding=1),
        )

    def forward(self, lr_seq):
        """Super-resolve a clip.

        Args:
            lr_seq: tensor of shape (B, T, C, H, W); ``feat_extract`` expects C == 3.

        Returns:
            Tensor of shape (B, T, 3, H * scale, W * scale).
        """
        B, T, C, H, W = lr_seq.shape

        lr_feats = [self.feat_extract(lr_seq[:, t]) for t in range(T)]

        # forward_flows[t]  is estimated from (frame t, frame t+1); zeros at t = T-1.
        # backward_flows[t] is estimated from (frame t, frame t-1); zeros at t = 0.
        forward_flows, backward_flows = [], []
        for t in range(T):
            if t < T - 1:
                fwd_flow = self.flow_estimation(lr_feats[t], lr_feats[t + 1])
            elif forward_flows:
                fwd_flow = torch.zeros_like(forward_flows[-1])
            else:
                fwd_flow = torch.zeros(B, 2, H, W, device=lr_seq.device)
            forward_flows.append(fwd_flow)

            if t > 0:
                bwd_flow = self.flow_estimation(lr_feats[t], lr_feats[t - 1])
            else:
                bwd_flow = torch.zeros(B, 2, H, W, device=lr_seq.device)
            backward_flows.append(bwd_flow)

        # Forward propagation: at step t the hidden state was produced at frame t-1,
        # so it is aligned with the flow estimated against frame t-1. The zero flow
        # at t = 0 coincides with the step that has no hidden state yet.
        forward_feats, hidden = [], None
        for t in range(T):
            feat, hidden = self.forward_prop(lr_feats[t], hidden, backward_flows[t])
            forward_feats.append(feat)

        # Backward propagation: at step t the hidden state comes from frame t+1,
        # so it uses the flow estimated against frame t+1 (zeros at t = T-1).
        backward_feats, hidden = [], None
        for t in reversed(range(T)):
            feat, hidden = self.backward_prop(lr_feats[t], hidden, forward_flows[t])
            backward_feats.append(feat)
        backward_feats.reverse()

        # Fusion and reconstruction, frame by frame
        sr_outputs = [
            self.reconstruction(self.fusion(torch.cat([fwd, bwd], dim=1)))
            for fwd, bwd in zip(forward_feats, backward_feats)
        ]
        return torch.stack(sr_outputs, dim=1)

In [None]:
# Add at the top
!pip install opencv-python-headless tqdm  # Required for image processing
import cv2
from tqdm import tqdm



In [None]:
import os
import cv2
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import zipfile
import numpy as np
from tqdm import tqdm

# Dataset preprocessing

In [None]:
!pip install -q gdown

# Use gdown to download the ZIP file from shared Drive link
file_id = "17iJgEQHkoT9FQKgvarslGXvK_hcxmKD8"
!gdown --id {file_id} --output images.zip

Downloading...
From (original): https://drive.google.com/uc?id=17iJgEQHkoT9FQKgvarslGXvK_hcxmKD8
From (redirected): https://drive.google.com/uc?id=17iJgEQHkoT9FQKgvarslGXvK_hcxmKD8&confirm=t&uuid=50d2580b-56ea-41ee-b2c2-453bb30c5320
To: /content/images.zip
100% 9.82G/9.82G [03:55<00:00, 41.7MB/s]


In [None]:
def unzip_file(zip_filepath, extract_dir):
    """Extract every member of a zip archive into a directory.

    Problems are reported by printing a message rather than raising, so the
    cell keeps running.

    Args:
        zip_filepath: Path of the archive to open.
        extract_dir: Destination directory for the extracted members.
    """
    try:
        with zipfile.ZipFile(zip_filepath, 'r') as archive:
            archive.extractall(extract_dir)
    except FileNotFoundError:
        message = f"Error: File not found at {zip_filepath}"
    except zipfile.BadZipFile:
        message = f"Error: Invalid zip file at {zip_filepath}"
    except Exception as err:
        message = f"An unexpected error occurred: {err}"
    else:
        message = f"Successfully unzipped {zip_filepath} to {extract_dir}"
    print(message)

# Example usage (replace with your actual file paths):
unzip_file("/content/images.zip", "/content/dataset")


Successfully unzipped /content/images.zip to /content/dataset


In [None]:
import os
from PIL import Image

root_folder = "dataset"  # Main folder with subfolders like '000', '001', etc.

# Recursively collect all PNG/JPG images
image_files = [
    os.path.join(dirpath, file)
    for dirpath, _, filenames in os.walk(root_folder)
    for file in filenames
    if file.lower().endswith(('.jpg', '.jpeg', '.png'))
]

# Gather info: count per format and per (width, height), plus total bytes on disk.
formats = {}
resolutions = {}
total_size_kb = 0

for fpath in image_files:
    try:
        # The context manager releases the file handle as soon as format/size are read,
        # instead of leaving one open per image until garbage collection.
        with Image.open(fpath) as img:
            fmt = img.format
            size = img.size

        # Count format and resolution
        formats[fmt] = formats.get(fmt, 0) + 1
        resolutions[size] = resolutions.get(size, 0) + 1

        # File size
        total_size_kb += os.path.getsize(fpath) / 1024
    except Exception as e:
        print(f"❌ Could not open {fpath}: {e}")

# Print summary
print(f"\n📁 Total images found: {len(image_files)}")
print(f"🧾 Total dataset size: {total_size_kb:.2f} KB")

print("\n🖼️ Image formats:")
for fmt, count in formats.items():
    print(f" - {fmt}: {count} images")

print("\n📏 Image resolutions:")
for res, count in resolutions.items():
    print(f" - {res[0]}x{res[1]}: {count} images")



📁 Total images found: 24000
🧾 Total dataset size: 9590304.56 KB

🖼️ Image formats:
 - PNG: 24000 images

📏 Image resolutions:
 - 640x360: 24000 images


In [None]:
import os
import shutil
import random
from PIL import Image

# Source and destination folders
original_root = "dataset"
output_root = "main_dataset/hr_dataset"
# JPEG is lossy even at this quality.
# NOTE(review): converting the HR targets to JPEG bakes compression artifacts into
# the ground truth; consider keeping PNG if disk space allows.
JPEG_QUALITY = 95

# Get all folders (each folder = one video); sorted so the seeded sample is reproducible
video_folders = sorted([f for f in os.listdir(original_root) if os.path.isdir(os.path.join(original_root, f))])

# Select half randomly
random.seed(42)
half_folders = random.sample(video_folders, len(video_folders) // 2)

# Create output folder
os.makedirs(output_root, exist_ok=True)

for folder in half_folders:
    src_folder = os.path.join(original_root, folder)
    dst_folder = os.path.join(output_root, folder)
    os.makedirs(dst_folder, exist_ok=True)

    for filename in os.listdir(src_folder):
        src_path = os.path.join(src_folder, filename)
        lowered = filename.lower()

        if lowered.endswith(".png"):
            # Convert PNG -> JPEG
            dst_filename = filename.rsplit('.', 1)[0] + ".jpg"
            dst_path = os.path.join(dst_folder, dst_filename)
            try:
                with Image.open(src_path) as img:
                    img.convert("RGB").save(dst_path, "JPEG", quality=JPEG_QUALITY)
            except Exception as e:
                print(f"❌ Failed to convert {src_path}: {e}")

        elif lowered.endswith((".jpg", ".jpeg")):
            # Copy existing JPEGs as-is
            shutil.copy2(src_path, os.path.join(dst_folder, filename))

print(f"✅ Copied and converted {len(half_folders)} folders to '{output_root}' (PNG → JPEG, quality={JPEG_QUALITY}, lossy)")


✅ Copied and converted 120 folders to 'main_dataset/hr_dataset' (PNG → JPEG, no compression)


In [None]:
from PIL import Image
import os

# Set your dataset path (JPEG HR frames written by the conversion step)
root_folder = "main_dataset/hr_dataset"

# Collect all JPEG file paths
image_files = [
    os.path.join(dirpath, file)
    for dirpath, _, filenames in os.walk(root_folder)
    for file in filenames
    if file.lower().endswith(('.jpg', '.jpeg'))
]

# Analyze dataset: count per format and per (width, height), plus total bytes on disk.
formats = {}
resolutions = {}
total_size_kb = 0

for fpath in image_files:
    try:
        # Close each file as soon as format/size have been read.
        with Image.open(fpath) as img:
            fmt = img.format
            size = img.size

        formats[fmt] = formats.get(fmt, 0) + 1
        resolutions[size] = resolutions.get(size, 0) + 1
        total_size_kb += os.path.getsize(fpath) / 1024
    except Exception as e:
        print(f"❌ Could not open {fpath}: {e}")

# Print result
print(f"📁 Total images found: {len(image_files)}")
print(f"🧾 Total dataset size: {total_size_kb:.2f} KB\n")

print("🖼️ Image formats:")
for fmt, count in formats.items():
    print(f" - {fmt}: {count} images")

print("\n📏 Image resolutions:")
for res, count in resolutions.items():
    print(f" - {res[0]}x{res[1]}: {count} images")


📁 Total images found: 12000
🧾 Total dataset size: 1312686.03 KB

🖼️ Image formats:
 - JPEG: 12000 images

📏 Image resolutions:
 - 640x360: 12000 images


In [None]:
import os
import random
import matplotlib.pyplot as plt
from PIL import Image

# Dataset path: the JPEG HR frames produced by the PNG -> JPEG conversion cell.
# (The previous value, "half_videos_png2jpeg", is never created in this notebook,
# so the walk silently returned no files.)
dataset_path = "main_dataset/hr_dataset"

# Get all image paths, sorted for a stable order across runs
image_paths = sorted(
    os.path.join(dirpath, file)
    for dirpath, _, filenames in os.walk(dataset_path)
    for file in filenames
    if file.lower().endswith(('.jpg', '.jpeg'))
)

In [None]:
import os
from PIL import Image
from tqdm import tqdm

def generate_lr_images(hr_dir, lr_dir, scale):
    """Write a bicubic 1/``scale`` downsample of every file under ``hr_dir`` to ``lr_dir``.

    The directory layout and file names under ``hr_dir`` are mirrored in
    ``lr_dir``. Each output is saved with the format implied by its extension.
    For ``.jpg`` that means Pillow's default JPEG quality, so the LR inputs are
    re-compressed. Files that fail to open or save are reported and skipped.

    Args:
        hr_dir: root of the high-resolution frames (walked recursively).
        lr_dir: destination root; sub-directories are created as needed.
        scale: integer factor; each output is (w // scale, h // scale).
    """
    # One flat list gives a single progress bar instead of one bar per directory.
    hr_paths = [os.path.join(root, fname) for root, _, files in os.walk(hr_dir) for fname in files]

    for hr_path in tqdm(hr_paths, desc=f"LR x{scale}"):
        rel_path = os.path.relpath(hr_path, hr_dir)
        lr_path = os.path.join(lr_dir, rel_path)
        try:
            with Image.open(hr_path) as src:
                hr_img = src.convert("RGB")
            w, h = hr_img.size
            lr_img = hr_img.resize((w // scale, h // scale), Image.BICUBIC)

            # Maintain directory structure
            os.makedirs(os.path.dirname(lr_path), exist_ok=True)
            lr_img.save(lr_path)
        except Exception as e:
            # Frame names repeat across videos, so report the path relative to hr_dir.
            print(f"Error processing {rel_path}: {e}")

In [None]:
# Build the x4 bicubic LR frames once. This is skipped when the LR root already exists,
# so Restart & Run All does not redo the resize; delete that folder to regenerate
# (e.g. after an interrupted run left it partially filled).
LR_SCALE = 4
HR_DATASET_DIR = "/content/main_dataset/hr_dataset"
LR_DATASET_DIR = "/content/main_dataset/lr_dataset"
if not os.path.isdir(LR_DATASET_DIR):
    generate_lr_images(HR_DATASET_DIR, LR_DATASET_DIR, scale=LR_SCALE)

In [None]:
def validate_pairs(hr_root, lr_root, scale):
    """Check that every HR frame has an LR twin exactly 1/``scale`` its size.

    ``hr_root`` and ``lr_root`` must contain the same video sub-directories and
    frame file names. Non-directory entries directly under ``hr_root`` are ignored.

    Raises:
        AssertionError: if an LR frame is missing or its size does not match.
    """
    for seq in tqdm(sorted(os.listdir(hr_root))):
        hr_seq = os.path.join(hr_root, seq)
        lr_seq = os.path.join(lr_root, seq)
        if not os.path.isdir(hr_seq):
            continue

        for frame in sorted(os.listdir(hr_seq)):
            lr_path = os.path.join(lr_seq, frame)
            assert os.path.isfile(lr_path), f"missing LR frame: {lr_path}"

            with Image.open(os.path.join(hr_seq, frame)) as hr_img, Image.open(lr_path) as lr_img:
                hr_size, lr_size = hr_img.size, lr_img.size

            # Validate scale factor
            assert hr_size == (lr_size[0] * scale, lr_size[1] * scale), (
                f"{seq}/{frame}: HR {hr_size} is not {scale}x LR {lr_size}")

# The HR/LR pairs live in a single tree; the train/val split happens later, per video.
validate_pairs("/content/main_dataset/hr_dataset", "/content/main_dataset/lr_dataset", 4)

# Load dataset

In [None]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class REDSDataset(Dataset):
    """Sliding-window HR/LR clip dataset for video super-resolution.

    Every sub-directory of ``hr_root`` is one video whose frames are ``.jpg``
    files, and ``lr_root`` must contain the same sub-directories and file names.
    Item ``i`` is a window of ``seq_length`` consecutive frames, returned as
    ``(lr_seq, hr_seq)`` tensors shaped
    (seq_length, 3, crop_size // scale, crop_size // scale) and
    (seq_length, 3, crop_size, crop_size).

    Args:
        hr_root: root of the high-resolution frame tree.
        lr_root: root of the low-resolution frame tree.
        scale: HR/LR size ratio; ``crop_size`` must be divisible by it.
        train: if True, use a random crop, random temporal reversal and random
            horizontal flip; otherwise use a deterministic center crop.
        seq_length: frames per window.
        crop_size: side length of the HR crop in pixels.

    Raises:
        ValueError: if ``crop_size`` is not a multiple of ``scale``.
    """

    def __init__(self, hr_root, lr_root, scale=4, train=True, seq_length=5, crop_size=256):
        if crop_size % scale:
            raise ValueError(f"crop_size ({crop_size}) must be divisible by scale ({scale})")
        self.hr_root = hr_root
        self.lr_root = lr_root
        self.scale = scale
        self.train = train
        self.seq_length = seq_length
        self.crop_size = crop_size
        self.to_tensor = transforms.ToTensor()

        # Sorted so the index -> window mapping (and any seeded split built on it)
        # does not depend on the file system's listing order.
        self.frame_names = {}
        self.sequences = []
        for seq in sorted(os.listdir(hr_root)):
            seq_dir = os.path.join(hr_root, seq)
            if not os.path.isdir(seq_dir):
                continue
            frames = sorted(f for f in os.listdir(seq_dir) if f.endswith('.jpg'))
            self.frame_names[seq] = frames
            # One window per start position that still has seq_length frames after it
            for start in range(len(frames) - seq_length + 1):
                self.sequences.append((seq, start, start + seq_length))

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, start, end = self.sequences[idx]
        names = self.frame_names[seq][start:end]
        lr_crop = self.crop_size // self.scale

        # Crop geometry is taken from the first HR frame of the window.
        with Image.open(os.path.join(self.hr_root, seq, names[0])) as first_hr:
            w, h = first_hr.size
        max_lr_x = w // self.scale - lr_crop
        max_lr_y = h // self.scale - lr_crop

        # Augmentations are drawn once per window so all frames stay consistent.
        if self.train:
            reverse = random.random() > 0.5
            flip = random.random() > 0.5
            lr_x = random.randint(0, max_lr_x)
            lr_y = random.randint(0, max_lr_y)
        else:
            reverse = False
            flip = False
            lr_x = max_lr_x // 2
            lr_y = max_lr_y // 2

        # The crop is chosen on the LR grid and scaled up, so the HR patch covers
        # exactly the pixels the LR patch was downsampled from.
        x, y = lr_x * self.scale, lr_y * self.scale
        hr_box = (x, y, x + self.crop_size, y + self.crop_size)
        lr_box = (lr_x, lr_y, lr_x + lr_crop, lr_y + lr_crop)

        if reverse:
            names = names[::-1]

        hr_frames, lr_frames = [], []
        for name in names:
            hr_img = Image.open(os.path.join(self.hr_root, seq, name)).convert('RGB').crop(hr_box)
            lr_img = Image.open(os.path.join(self.lr_root, seq, name)).convert('RGB').crop(lr_box)

            if flip:
                hr_img = hr_img.transpose(Image.FLIP_LEFT_RIGHT)
                lr_img = lr_img.transpose(Image.FLIP_LEFT_RIGHT)

            hr_frames.append(self.to_tensor(hr_img))
            lr_frames.append(self.to_tensor(lr_img))

        return torch.stack(lr_frames), torch.stack(hr_frames)

In [None]:
HR_ROOT = "/content/main_dataset/hr_dataset"
LR_ROOT = "/content/main_dataset/lr_dataset"
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

# Two views of the same files: random crops/flips for training, deterministic
# center crops for validation.
dataset = REDSDataset(hr_root=HR_ROOT, lr_root=LR_ROOT, seq_length=5, train=True)
val_view = REDSDataset(hr_root=HR_ROOT, lr_root=LR_ROOT, seq_length=5, train=False)

# Split by video, not by window: neighbouring 5-frame windows share 4 frames, so a
# window-level split leaks validation frames into training. A dedicated Random(seed)
# makes the split reproducible (random.seed() returns None, so passing it as the
# random_split generator left the split unseeded).
videos = sorted({seq for seq, _, _ in dataset.sequences})
random.Random(SPLIT_SEED).shuffle(videos)
train_videos = set(videos[:int(TRAIN_FRACTION * len(videos))])

train_idx = [i for i, (seq, _, _) in enumerate(dataset.sequences) if seq in train_videos]
val_idx = [i for i, (seq, _, _) in enumerate(val_view.sequences) if seq not in train_videos]
assert train_idx and val_idx, "train/val split produced an empty side"

train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(val_view, val_idx)

train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False)

# Train the model

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [None]:
pip install wandb



In [None]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33maditya5748rai[0m ([33maditya5748rai-iit-indore[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from tqdm import tqdm

def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, max_grad_norm, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(loader, desc=desc)
    for lr_seq, hr_seq in progress_bar:
        lr_seq = lr_seq.to(device, non_blocking=True, memory_format=torch.channels_last_3d)
        hr_seq = hr_seq.to(device, non_blocking=True, memory_format=torch.channels_last_3d)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            sr_seq = model(lr_seq)
            loss = criterion(sr_seq, hr_seq)

        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale here; unscale them first
        # so max_grad_norm applies to the real gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)
    return running_loss / max(len(loader), 1)


def _validate(model, loader, psnr_metric, ssim_metric, device, use_amp, desc, collect_samples):
    """Evaluate ``model`` on ``loader``.

    Returns:
        (psnr, ssim, samples). psnr and ssim are floats aggregated over the whole
        loader. samples is empty, or [SR, HR, LR] wandb.Image objects for the
        middle frame of the first clip in the first batch.
    """
    model.eval()
    psnr_metric.reset()
    ssim_metric.reset()
    samples = []
    with torch.no_grad():
        for lr_val, hr_val in tqdm(loader, desc=desc):
            lr_val = lr_val.to(device, non_blocking=True, memory_format=torch.channels_last_3d)
            hr_val = hr_val.to(device, non_blocking=True, memory_format=torch.channels_last_3d)

            with torch.autocast(device_type=device.type, enabled=use_amp):
                sr_val = model(lr_val)

            # Score in fp32 on the valid [0, 1] range. Time is folded into the batch so
            # SSIM sees 4-D images; a 5-D input would be scored as 3-D volumes.
            sr_frames = sr_val.float().clamp(0, 1).flatten(0, 1)
            hr_frames = hr_val.float().flatten(0, 1)
            psnr_metric.update(sr_frames, hr_frames)
            ssim_metric.update(sr_frames, hr_frames)

            if collect_samples and not samples:
                mid = sr_val.shape[1] // 2  # middle frame of the first clip
                samples = [
                    wandb.Image(frames[0, mid].float().clamp(0, 1).cpu().permute(1, 2, 0).numpy())
                    for frames in (sr_val, hr_val, lr_val)
                ]
    return psnr_metric.compute().item(), ssim_metric.compute().item(), samples


def train(train_loader, val_loader, config=None):
    """Train BasicVSRPlusPlus with L1 loss, log to W&B and keep the best checkpoint.

    Args:
        train_loader: yields (lr_seq, hr_seq) batches, presumably shaped
            (B, T, C, h, w) and (B, T, C, h*scale, w*scale) as produced by REDSDataset.
        val_loader: same format; evaluated once per epoch for PSNR and SSIM.
        config: optional dict overriding any of the default hyper-parameters below.

    Side effects: opens and always closes a W&B run, and writes
    'best_reds_vsr.pth' whenever validation PSNR improves.
    """
    print('Initializing training...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'  # fp16 autocast + loss scaling only on GPU

    # Configuration
    run_config = {
        "scale": 4,
        "num_epochs": 18,
        # NOTE(review): 0.02 is far above the learning rates usually used with AdamW
        # for VSR models; confirm this is intended.
        "lr": 0.02,
        "weight_decay": 0.0005,
        "max_grad_norm": 0.01,
        "architecture": "BasicVSRPlusPlus",
        "dataset": "REDS",
        "optimizer": "AdamW",
        "loss": "L1Loss",
    }
    run_config.update(config or {})
    num_epochs = run_config["num_epochs"]

    wandb.init(
        project="video-super-resolution",
        config=run_config,
        notes="Training BasicVSR++ " + run_config["dataset"] + " with scale factor " + str(run_config["scale"])
    )

    try:
        model = BasicVSRPlusPlus(scale=run_config["scale"]).to(device, memory_format=torch.channels_last)
        wandb.watch(model, log="gradients", log_freq=100)

        criterion = nn.L1Loss()
        # REDSDataset builds its tensors with ToTensor(), so pixel values lie in [0, 1].
        psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
        ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

        optimizer = optim.AdamW(model.parameters(), lr=run_config["lr"], weight_decay=run_config["weight_decay"])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

        best_psnr = float('-inf')
        print('Starting training loop...')

        for epoch in range(num_epochs):
            tag = f"Epoch {epoch+1}/{num_epochs}"
            avg_loss = _train_one_epoch(
                model, train_loader, criterion, optimizer, scaler, device, use_amp,
                run_config["max_grad_norm"], desc=f"{tag} [Train]")
            avg_psnr, avg_ssim, sample_images = _validate(
                model, val_loader, psnr_metric, ssim_metric, device, use_amp,
                desc=f"{tag} [Val]", collect_samples=(epoch % 2 == 0))  # images every 2 epochs

            log_data = {
                "epoch": epoch + 1,
                "train_loss": avg_loss,
                "val_psnr": avg_psnr,
                "val_ssim": avg_ssim,
                "learning_rate": scheduler.get_last_lr()[0],
            }
            if sample_images:
                log_data.update({
                    "SR Sample": sample_images[0],
                    "HR Ground Truth": sample_images[1],
                    "LR Input": sample_images[2],
                })
            wandb.log(log_data)

            # Update scheduler and save best model
            scheduler.step()
            if avg_psnr > best_psnr:
                best_psnr = avg_psnr
                torch.save(model.state_dict(), 'best_reds_vsr.pth')
                wandb.save('best_reds_vsr.pth')

            print(f"{tag} | "
                  f"Train Loss: {avg_loss:.4f} | "
                  f"Val PSNR: {avg_psnr:.2f} | "
                  f"Val SSIM: {avg_ssim:.4f}")
    finally:
        # Close the W&B run even if an epoch raises, so the next run starts cleanly.
        wandb.finish()
    print("Training completed!")

In [None]:
# Launch the full training run (W&B logging + best-checkpoint saving) on the loaders built above.
train(train_loader, val_loader)

Initializing training...


  scaler = GradScaler()


Starting training loop...


  with autocast():
Epoch 1/18 [Train]: 100%|██████████| 384/384 [10:32<00:00,  1.65s/it, loss=0.0673]
  with autocast():
Epoch 1/18 [Val]: 100%|██████████| 96/96 [01:38<00:00,  1.03s/it]


NameError: name 'avg_ssim' is not defined