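# MFMViT-QF Training
Training script for MFMViT-QF: a ViT trained to recover images whose Fourier spectrum has been randomly masked, conditioned on quality-factor (QF) vectors extracted by a frozen, pretrained FBCNN. The loss compares the spectrum amplitude of the prediction with that of the original image.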

## 1. Imports

In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os

from models.mfmQF import MFMViT_QF
from src.fft_utils import apply_fft, apply_ifft, get_mask, get_spectrum_amplitude
from src.loss import MFMLoss
from src.qf_extraction import load_fbcnn_qf

  warn(


## 2. Configuration

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Training config
img_size = 224      # images are resized to img_size x img_size before training
batch_size = 16
lr = 1e-4           # Adam learning rate
epochs = 50
mask_ratio = 0.5    # forwarded to get_mask(..., ratio=mask_ratio)
qf_dim = 64         # size of the QF vector produced by the frozen FBCNN

# Seed torch's CPU and CUDA generators so DataLoader shuffling and model weight
# initialisation are reproducible across Restart & Run All.
# NOTE(review): get_mask is only covered if it draws from torch's RNG — confirm.
seed = 42
torch.manual_seed(seed)

Using device: cuda


## 3. Data Loading

In [3]:
# Dataset root: set the MFM_DATA_ROOT environment variable instead of editing the notebook.
# Expects ImageFolder layouts under <data_root>/train and (optionally) <data_root>/val.
data_root = os.environ.get('MFM_DATA_ROOT', 'E:/data')
train_dir = os.path.join(data_root, 'train')
val_dir = os.path.join(data_root, 'val')

# Resize to the configured resolution, convert to a tensor, then apply ImageNet mean/std.
# NOTE(review): these normalised tensors are also what fbcnn_qf receives later;
# confirm load_fbcnn_qf expects normalised input rather than raw [0, 1] pixels.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation is optional: both objects stay None when the val folder is absent.
val_dataset = datasets.ImageFolder(val_dir, transform=transform) if os.path.isdir(val_dir) else None
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) if val_dataset is not None else None

print(f"Train samples: {len(train_dataset)}")
if val_dataset is not None:
    print(f"Val samples: {len(val_dataset)}")

Train samples: 80216
Val samples: 6197


## 4. Model Setup

In [4]:
# Frozen FBCNN, configured with qf_dim, used only to produce QF conditioning vectors.
fbcnn_ckpt = 'checkpoints/FBCNN/fbcnn_color.pth'
fbcnn_qf = load_fbcnn_qf(fbcnn_ckpt, qf_dim=qf_dim, device=device)

# MFMViT with QF conditioning switched on; nn.Module.to returns the same module.
model = MFMViT_QF(img_size=img_size, qf_dim=qf_dim, use_qf=True)
model = model.to(device)

# Loss on the training device; Adam optimises only the MFMViT-QF parameters
# (the FBCNN extractor is never handed to the optimizer).
criterion = MFMLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Destination for the checkpoints written by the training loop.
os.makedirs('checkpoints/MFMQF', exist_ok=True)

print("Models initialized successfully!")

[FBCNN] Loaded pretrained weights from checkpoints/FBCNN/fbcnn_color.pth
[FBCNN] Frozen with qf_dim=64
Models initialized successfully!


## 5. Training Loop

In [5]:
checkpoint_dir = 'checkpoints/MFMQF'
save_every = 10      # write an intermediate checkpoint every N epochs
val_mask_seed = 0    # fixed seed so every epoch is validated on the same random masks


def compute_masked_spectrum_loss(images):
    """Return the masked-frequency reconstruction loss for one batch.

    Moves ``images`` to ``device`` and extracts QF vectors with the frozen FBCNN.
    It then multiplies the image spectrum by a random mask from ``get_mask`` and
    feeds the inverse transform, plus the QF vectors, to ``model``. Finally it
    compares the spectrum amplitude of the prediction against the original with
    ``criterion``.

    Shared by training and validation so both optimise/report the exact same objective.
    """
    images = images.to(device)

    # The QF extractor is frozen: never build a graph through it.
    with torch.no_grad():
        qf_vectors = fbcnn_qf(images)

    fft_original = apply_fft(images)
    amplitude_original = get_spectrum_amplitude(fft_original)

    mask = get_mask(images.shape[0], images.shape[1], img_size,
                    ratio=mask_ratio, device=device)
    corrupted_spatial = apply_ifft(fft_original * mask)
    predicted_spatial = model(corrupted_spatial, qf_vectors)

    amplitude_predicted = get_spectrum_amplitude(apply_fft(predicted_spatial))
    return criterion(amplitude_predicted, amplitude_original, mask)


def evaluate(loader):
    """Return the mean per-batch loss over ``loader`` (which must be non-empty).

    Runs under ``torch.random.fork_rng`` with a fixed seed. Validation masks are
    then identical every epoch, so val losses are comparable, and the training
    RNG stream is restored afterwards.
    NOTE(review): assumes get_mask draws from torch's RNG — confirm.
    """
    model.eval()
    fork_devices = [device] if device.type == 'cuda' else []
    total = 0.0
    with torch.no_grad(), torch.random.fork_rng(devices=fork_devices):
        torch.manual_seed(val_mask_seed)
        for images, _ in loader:
            total += compute_masked_spectrum_loss(images).item()
    return total / len(loader)


# Skip validation when there is no loader or it yields no batches (avoids division by zero).
has_val = val_loader is not None and len(val_loader) > 0

for epoch in range(epochs):
    model.train()
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
    train_loss = 0.0

    for images, _ in pbar:
        loss = compute_masked_spectrum_loss(images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # one host sync per step, reused for logging
        train_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    # Mean of per-batch losses (same definition as evaluate()).
    train_loss /= max(len(train_loader), 1)
    val_loss = evaluate(val_loader) if has_val else 0.0

    msg = f'Epoch {epoch+1} | Train Loss: {train_loss:.4f}'
    if has_val:
        msg += f' | Val Loss: {val_loss:.4f}'
    print(msg)

    if (epoch + 1) % save_every == 0:
        torch.save(model.state_dict(), f'{checkpoint_dir}/mfm_qf_epoch_{epoch+1}.pth')

torch.save(model.state_dict(), f'{checkpoint_dir}/mfm_qf_final.pth')
print('Training complete!')

Epoch 1/50:   0%|          | 0/5014 [00:00<?, ?it/s]

[MFMViT-QF] QF modulation enabled: qf_dim=64 -> embed_dim=768
[MFMViT-QF init] embed_dim=768, num_patches=196, patches_per_side=14, patch_size=16, patch_dim=768


Epoch 1/50: 100%|██████████| 5014/5014 [14:51<00:00,  5.62it/s, loss=0.7825]


Epoch 1 | Train Loss: 0.7889 | Val Loss: 0.7506


Epoch 2/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.7137]


Epoch 2 | Train Loss: 0.7124 | Val Loss: 0.6822


Epoch 3/50: 100%|██████████| 5014/5014 [14:21<00:00,  5.82it/s, loss=0.6234]


Epoch 3 | Train Loss: 0.6470 | Val Loss: 0.6209


Epoch 4/50: 100%|██████████| 5014/5014 [14:26<00:00,  5.79it/s, loss=0.5712]


Epoch 4 | Train Loss: 0.5925 | Val Loss: 0.5698


Epoch 5/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.5376]


Epoch 5 | Train Loss: 0.5467 | Val Loss: 0.5277


Epoch 6/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.4902]


Epoch 6 | Train Loss: 0.5090 | Val Loss: 0.4932


Epoch 7/50: 100%|██████████| 5014/5014 [14:25<00:00,  5.79it/s, loss=0.4582]


Epoch 7 | Train Loss: 0.4772 | Val Loss: 0.4624


Epoch 8/50: 100%|██████████| 5014/5014 [14:25<00:00,  5.79it/s, loss=0.4454]


Epoch 8 | Train Loss: 0.4501 | Val Loss: 0.4389


Epoch 9/50: 100%|██████████| 5014/5014 [14:32<00:00,  5.75it/s, loss=0.4169]


Epoch 9 | Train Loss: 0.4272 | Val Loss: 0.4161


Epoch 10/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.4049]


Epoch 10 | Train Loss: 0.4068 | Val Loss: 0.3971


Epoch 11/50: 100%|██████████| 5014/5014 [14:24<00:00,  5.80it/s, loss=0.3824]


Epoch 11 | Train Loss: 0.3893 | Val Loss: 0.3801


Epoch 12/50: 100%|██████████| 5014/5014 [14:23<00:00,  5.80it/s, loss=0.3492]


Epoch 12 | Train Loss: 0.3736 | Val Loss: 0.3656


Epoch 13/50: 100%|██████████| 5014/5014 [14:32<00:00,  5.75it/s, loss=0.3524]


Epoch 13 | Train Loss: 0.3595 | Val Loss: 0.3529


Epoch 14/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.3397]


Epoch 14 | Train Loss: 0.3472 | Val Loss: 0.3410


Epoch 15/50: 100%|██████████| 5014/5014 [14:24<00:00,  5.80it/s, loss=0.3319]


Epoch 15 | Train Loss: 0.3360 | Val Loss: 0.3296


Epoch 16/50: 100%|██████████| 5014/5014 [14:25<00:00,  5.79it/s, loss=0.3059]


Epoch 16 | Train Loss: 0.3258 | Val Loss: 0.3211


Epoch 17/50: 100%|██████████| 5014/5014 [14:30<00:00,  5.76it/s, loss=0.3389]


Epoch 17 | Train Loss: 0.3170 | Val Loss: 0.3120


Epoch 18/50: 100%|██████████| 5014/5014 [14:31<00:00,  5.75it/s, loss=0.3090]


Epoch 18 | Train Loss: 0.3088 | Val Loss: 0.3046


Epoch 19/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.2909]


Epoch 19 | Train Loss: 0.3014 | Val Loss: 0.2974


Epoch 20/50: 100%|██████████| 5014/5014 [14:24<00:00,  5.80it/s, loss=0.2951]


Epoch 20 | Train Loss: 0.2944 | Val Loss: 0.2911


Epoch 21/50: 100%|██████████| 5014/5014 [14:28<00:00,  5.77it/s, loss=0.3015]


Epoch 21 | Train Loss: 0.2882 | Val Loss: 0.2837


Epoch 22/50: 100%|██████████| 5014/5014 [14:25<00:00,  5.79it/s, loss=0.2917]


Epoch 22 | Train Loss: 0.2823 | Val Loss: 0.2788


Epoch 23/50: 100%|██████████| 5014/5014 [14:24<00:00,  5.80it/s, loss=0.2776]


Epoch 23 | Train Loss: 0.2769 | Val Loss: 0.2730


Epoch 24/50: 100%|██████████| 5014/5014 [14:25<00:00,  5.80it/s, loss=0.2696]


Epoch 24 | Train Loss: 0.2721 | Val Loss: 0.2690


Epoch 25/50: 100%|██████████| 5014/5014 [14:38<00:00,  5.71it/s, loss=0.2659]


Epoch 25 | Train Loss: 0.2673 | Val Loss: 0.2649


Epoch 26/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.76it/s, loss=0.2685]


Epoch 26 | Train Loss: 0.2634 | Val Loss: 0.2613


Epoch 27/50: 100%|██████████| 5014/5014 [14:26<00:00,  5.79it/s, loss=0.2567]


Epoch 27 | Train Loss: 0.2594 | Val Loss: 0.2571


Epoch 28/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.2493]


Epoch 28 | Train Loss: 0.2556 | Val Loss: 0.2541


Epoch 29/50: 100%|██████████| 5014/5014 [14:30<00:00,  5.76it/s, loss=0.2457]


Epoch 29 | Train Loss: 0.2520 | Val Loss: 0.2496


Epoch 30/50: 100%|██████████| 5014/5014 [14:28<00:00,  5.77it/s, loss=0.2486]


Epoch 30 | Train Loss: 0.2488 | Val Loss: 0.2473


Epoch 31/50: 100%|██████████| 5014/5014 [14:30<00:00,  5.76it/s, loss=0.2532]


Epoch 31 | Train Loss: 0.2458 | Val Loss: 0.2431


Epoch 32/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.2337]


Epoch 32 | Train Loss: 0.2425 | Val Loss: 0.2403


Epoch 33/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.76it/s, loss=0.2225]


Epoch 33 | Train Loss: 0.2398 | Val Loss: 0.2383


Epoch 34/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.2600]


Epoch 34 | Train Loss: 0.2371 | Val Loss: 0.2356


Epoch 35/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.2177]


Epoch 35 | Train Loss: 0.2349 | Val Loss: 0.2336


Epoch 36/50: 100%|██████████| 5014/5014 [14:27<00:00,  5.78it/s, loss=0.2253]


Epoch 36 | Train Loss: 0.2327 | Val Loss: 0.2307


Epoch 37/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.76it/s, loss=0.2288]


Epoch 37 | Train Loss: 0.2302 | Val Loss: 0.2291


Epoch 38/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.2317]


Epoch 38 | Train Loss: 0.2283 | Val Loss: 0.2278


Epoch 39/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.2366]


Epoch 39 | Train Loss: 0.2261 | Val Loss: 0.2262


Epoch 40/50: 100%|██████████| 5014/5014 [14:31<00:00,  5.75it/s, loss=0.2243]


Epoch 40 | Train Loss: 0.2245 | Val Loss: 0.2241


Epoch 41/50: 100%|██████████| 5014/5014 [14:33<00:00,  5.74it/s, loss=0.2232]


Epoch 41 | Train Loss: 0.2230 | Val Loss: 0.2225


Epoch 42/50: 100%|██████████| 5014/5014 [14:32<00:00,  5.75it/s, loss=0.2248]


Epoch 42 | Train Loss: 0.2211 | Val Loss: 0.2201


Epoch 43/50: 100%|██████████| 5014/5014 [14:33<00:00,  5.74it/s, loss=0.2052]


Epoch 43 | Train Loss: 0.2197 | Val Loss: 0.2182


Epoch 44/50: 100%|██████████| 5014/5014 [14:31<00:00,  5.76it/s, loss=0.2306]


Epoch 44 | Train Loss: 0.2180 | Val Loss: 0.2176


Epoch 45/50: 100%|██████████| 5014/5014 [14:33<00:00,  5.74it/s, loss=0.2116]


Epoch 45 | Train Loss: 0.2166 | Val Loss: 0.2163


Epoch 46/50: 100%|██████████| 5014/5014 [14:29<00:00,  5.77it/s, loss=0.2187]


Epoch 46 | Train Loss: 0.2153 | Val Loss: 0.2144


Epoch 47/50: 100%|██████████| 5014/5014 [14:31<00:00,  5.75it/s, loss=0.2141]


Epoch 47 | Train Loss: 0.2143 | Val Loss: 0.2138


Epoch 48/50: 100%|██████████| 5014/5014 [14:30<00:00,  5.76it/s, loss=0.2102]


Epoch 48 | Train Loss: 0.2129 | Val Loss: 0.2128


Epoch 49/50: 100%|██████████| 5014/5014 [14:34<00:00,  5.74it/s, loss=0.1955]


Epoch 49 | Train Loss: 0.2117 | Val Loss: 0.2126


Epoch 50/50: 100%|██████████| 5014/5014 [14:34<00:00,  5.73it/s, loss=0.2278]


Epoch 50 | Train Loss: 0.2106 | Val Loss: 0.2115
Training complete!
