### Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os, sys
from src.model import MFMViT
from src.fft_utils import *
from src.loss import MFMLoss
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

  warn(


### Trainer

In [None]:
def trainer():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    img_size = 224
    batch_size = 16
    lr = 1e-4
    epochs = 50
    mask_ratio = 0.5
    warmup_epochs = 5
    
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    train_dataset = datasets.ImageFolder('E:/data/train', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    val_dataset = datasets.ImageFolder('E:/data/val', transform=transform) if os.path.exists('E:/data/val') else None
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) if val_dataset else None
    
    model = MFMViT(img_size=img_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = MFMLoss().to(device)
    
    warmup_scheduler = LambdaLR(optimizer, lambda epoch: min(1.0, (epoch + 1) / warmup_epochs))
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=1e-6)
    
    os.makedirs('checkpoints', exist_ok=True)
    
    for epoch in range(epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        train_loss = 0
        
        for images, _ in pbar:
            images = images.to(device)
            
            fft_original = apply_fft(images)
            amplitude_original = get_spectrum_amplitude(fft_original)
            
            mask = get_mask(images.shape[0], images.shape[1], img_size, 
                          ratio=mask_ratio, device=device)
            fft_masked = fft_original * mask
            
            corrupted_spatial = apply_ifft(fft_masked)
            predicted_spatial = model(corrupted_spatial)
            
            fft_predicted = apply_fft(predicted_spatial)
            amplitude_predicted = get_spectrum_amplitude(fft_predicted)
            
            loss = criterion(amplitude_predicted, amplitude_original, mask)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        val_loss = 0
        if val_loader:
            model.eval()
            with torch.no_grad():
                for images, _ in val_loader:
                    images = images.to(device)
                    
                    fft_original = apply_fft(images)
                    amplitude_original = get_spectrum_amplitude(fft_original)
                    
                    mask = get_mask(images.shape[0], images.shape[1], img_size, 
                                  ratio=mask_ratio, device=device)
                    fft_masked = fft_original * mask
                    
                    corrupted_spatial = apply_ifft(fft_masked)
                    predicted_spatial = model(corrupted_spatial)
                    
                    fft_predicted = apply_fft(predicted_spatial)
                    amplitude_predicted = get_spectrum_amplitude(fft_predicted)
                    
                    loss = criterion(amplitude_predicted, amplitude_original, mask)
                    val_loss += loss.item()
            
            val_loss /= len(val_loader)
        
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            cosine_scheduler.step()
        
        train_loss /= len(train_loader)
        msg = f'Epoch {epoch+1} | Train Loss: {train_loss:.4f}'
        if val_loader:
            msg += f' | Val Loss: {val_loss:.4f}'
        msg += f' | LR: {optimizer.param_groups[0]["lr"]:.2e}'
        print(msg + '\n')
        
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'checkpoints/mfm_vit_epoch_{epoch+1}.pth')
    
    torch.save(model.state_dict(), 'checkpoints/mfm_vit_final.pth')
    print('Training complete!')

trainer()

Epoch 1/50:   0%|          | 0/5014 [00:00<?, ?it/s]

[MFMViT init] embed_dim=768, num_patches=196, patches_per_side=14, patch_size=16, patch_dim=768


Epoch 1/50: 100%|██████████| 5014/5014 [13:44<00:00,  6.08it/s, loss=0.2955]


Epoch 1 | Train Loss: 0.3249 | Val Loss: 0.2959 | LR: 4.00e-05



Epoch 2/50: 100%|██████████| 5014/5014 [13:44<00:00,  6.08it/s, loss=0.2622]


Epoch 2 | Train Loss: 0.2827 | Val Loss: 0.2685 | LR: 6.00e-05



Epoch 3/50: 100%|██████████| 5014/5014 [14:08<00:00,  5.91it/s, loss=0.2437]


Epoch 3 | Train Loss: 0.2565 | Val Loss: 0.2450 | LR: 8.00e-05



Epoch 4/50: 100%|██████████| 5014/5014 [13:49<00:00,  6.04it/s, loss=0.2215]


Epoch 4 | Train Loss: 0.2375 | Val Loss: 0.2276 | LR: 1.00e-04



Epoch 5/50: 100%|██████████| 5014/5014 [13:57<00:00,  5.99it/s, loss=0.2167]


Epoch 5 | Train Loss: 0.2236 | Val Loss: 0.2156 | LR: 1.00e-04



Epoch 6/50: 100%|██████████| 5014/5014 [13:51<00:00,  6.03it/s, loss=0.2075]


Epoch 6 | Train Loss: 0.2120 | Val Loss: 0.2076 | LR: 9.99e-05



Epoch 7/50: 100%|██████████| 5014/5014 [13:47<00:00,  6.06it/s, loss=0.1942]


Epoch 7 | Train Loss: 0.2043 | Val Loss: 0.2009 | LR: 9.95e-05



Epoch 8/50: 100%|██████████| 5014/5014 [13:46<00:00,  6.07it/s, loss=0.1868]


Epoch 8 | Train Loss: 0.1985 | Val Loss: 0.1947 | LR: 9.89e-05



Epoch 9/50:  56%|█████▌    | 2788/5014 [07:42<05:54,  6.27it/s, loss=0.1973]