### Imports

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os, sys
from src.model import MFMViT
from src.fft_utils import *
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

### Trainer

In [None]:
def trainer():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    img_size = 224
    batch_size = 16
    lr = 1e-4
    epochs = 50
    mask_ratio = 0.5
    warmup_epochs = 5
    
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    train_dataset = datasets.ImageFolder('E:/data/train', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    # Validation dataset (if exists)
    val_dataset = datasets.ImageFolder('E:/data/val', transform=transform) if os.path.exists('E:/data/val') else None
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) if val_dataset else None
    
    model = MFMViT(img_size=img_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    warmup_scheduler = LambdaLR(optimizer, lambda epoch: min(1.0, (epoch + 1) / warmup_epochs))
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=1e-6)
    
    os.makedirs('checkpoints', exist_ok=True)
    
    for epoch in range(epochs):
        # Training
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        train_loss = 0
        
        for images, _ in pbar:
            images = images.to(device)
            
            fft_original = apply_fft(images)
            amplitude_original = get_spectrum_amplitude(fft_original)
            
            mask = get_mask(images.shape[0], images.shape[1], img_size, 
                          ratio=mask_ratio, device=device)
            fft_masked = fft_original * mask
            
            corrupted_spatial = apply_ifft(fft_masked)
            predicted_spatial = model(corrupted_spatial)
            
            fft_predicted = apply_fft(predicted_spatial)
            amplitude_predicted = get_spectrum_amplitude(fft_predicted)
            
            loss = criterion(amplitude_predicted * mask, amplitude_original * mask)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Validation
        val_loss = 0
        if val_loader:
            model.eval()
            with torch.no_grad():
                for images, _ in val_loader:
                    images = images.to(device)
                    
                    fft_original = apply_fft(images)
                    amplitude_original = get_spectrum_amplitude(fft_original)
                    
                    mask = get_mask(images.shape[0], images.shape[1], img_size, 
                                  ratio=mask_ratio, device=device)
                    fft_masked = fft_original * mask
                    
                    corrupted_spatial = apply_ifft(fft_masked)
                    predicted_spatial = model(corrupted_spatial)
                    
                    fft_predicted = apply_fft(predicted_spatial)
                    amplitude_predicted = get_spectrum_amplitude(fft_predicted)
                    
                    loss = criterion(amplitude_predicted * mask, amplitude_original * mask)
                    val_loss += loss.item()
            
            val_loss /= len(val_loader)
        
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            cosine_scheduler.step()
        
        train_loss /= len(train_loader)
        msg = f'Epoch {epoch+1} | Train Loss: {train_loss:.6f}'
        if val_loader:
            msg += f' | Val Loss: {val_loss:.6f}'
        msg += f' | LR: {optimizer.param_groups[0]["lr"]:.2e}'
        print(msg + '\n')
        
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'checkpoints/mfm_vit_epoch_{epoch+1}.pth')
    
    torch.save(model.state_dict(), 'checkpoints/mfm_vit_final.pth')
    print('Training complete!')

trainer()


Epoch 1/50: 100%|██████████| 1336/1336 [04:00<00:00,  5.56it/s, loss=13338.3438]


Epoch 1 | Train Loss: 19044.708215 | Val Loss: 14770.454159 | LR: 4.00e-05



Epoch 2/50: 100%|██████████| 1336/1336 [03:56<00:00,  5.65it/s, loss=5773.7368] 


Epoch 2 | Train Loss: 9070.880006 | Val Loss: 6994.413939 | LR: 6.00e-05



Epoch 3/50: 100%|██████████| 1336/1336 [03:56<00:00,  5.65it/s, loss=7941.4077]


Epoch 3 | Train Loss: 6611.652681 | Val Loss: 6141.730342 | LR: 8.00e-05



Epoch 4/50: 100%|██████████| 1336/1336 [03:56<00:00,  5.64it/s, loss=5669.5781]


Epoch 4 | Train Loss: 6164.645501 | Val Loss: 5918.572889 | LR: 1.00e-04



Epoch 5/50: 100%|██████████| 1336/1336 [03:56<00:00,  5.64it/s, loss=6473.7642]


Epoch 5 | Train Loss: 5883.975444 | Val Loss: 5733.585936 | LR: 1.00e-04



Epoch 6/50: 100%|██████████| 1336/1336 [03:56<00:00,  5.64it/s, loss=5392.1001]


Epoch 6 | Train Loss: 5521.145441 | Val Loss: 5430.275198 | LR: 9.99e-05



Epoch 7/50: 100%|██████████| 1336/1336 [04:00<00:00,  5.56it/s, loss=4564.3965]


Epoch 7 | Train Loss: 5300.154755 | Val Loss: 5256.009271 | LR: 9.95e-05



Epoch 8/50: 100%|██████████| 1336/1336 [03:55<00:00,  5.68it/s, loss=5367.5708]


Epoch 8 | Train Loss: 5076.746818 | Val Loss: 5107.836536 | LR: 9.89e-05



Epoch 9/50: 100%|██████████| 1336/1336 [03:51<00:00,  5.78it/s, loss=4244.1450]


Epoch 9 | Train Loss: 4951.064136 | Val Loss: 4984.596456 | LR: 9.81e-05



Epoch 10/50: 100%|██████████| 1336/1336 [03:53<00:00,  5.72it/s, loss=5073.7515]


Epoch 10 | Train Loss: 4788.882865 | Val Loss: 4778.858478 | LR: 9.70e-05



Epoch 11/50: 100%|██████████| 1336/1336 [03:54<00:00,  5.71it/s, loss=4509.7305]


Epoch 11 | Train Loss: 4654.297732 | Val Loss: 4686.558941 | LR: 9.57e-05



Epoch 12/50: 100%|██████████| 1336/1336 [04:00<00:00,  5.56it/s, loss=4454.1284]
