### Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os, sys
from src.model import MFMViT
from src.fft_utils import *
from src.loss import MFMLoss
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

  warn(


### Trainer

In [3]:
def _mfm_batch_loss(model, criterion, images, img_size, mask_ratio, device):
    """Run one masked-frequency forward pass and return the batch loss.

    The batch is transformed with ``apply_fft``, multiplied by a mask from
    ``get_mask``, transformed back with ``apply_ifft`` and fed to ``model``.
    The loss compares the amplitude spectrum of the model output with the
    amplitude spectrum of the clean input.

    Args:
        model: Network called on the corrupted spatial batch.
        criterion: Loss callable invoked as
            ``criterion(amplitude_predicted, amplitude_original, mask)``.
        images: Image batch already moved to ``device``. Its first two
            dimensions are forwarded to ``get_mask`` (presumably batch size
            and channel count -- confirm against ``src.fft_utils``).
        img_size: Size argument forwarded to ``get_mask``.
        mask_ratio: Value forwarded to ``get_mask`` as ``ratio``.
        device: Device on which ``get_mask`` creates the mask.

    Returns:
        The value returned by ``criterion`` (callers treat it as a scalar
        tensor supporting ``.backward()`` and ``.item()``).
    """
    fft_original = apply_fft(images)
    amplitude_original = get_spectrum_amplitude(fft_original)

    mask = get_mask(images.shape[0], images.shape[1], img_size,
                    ratio=mask_ratio, device=device)
    fft_masked = fft_original * mask

    corrupted_spatial = apply_ifft(fft_masked)
    predicted_spatial = model(corrupted_spatial)

    fft_predicted = apply_fft(predicted_spatial)
    amplitude_predicted = get_spectrum_amplitude(fft_predicted)

    return criterion(amplitude_predicted, amplitude_original, mask)


def trainer(train_dir='E:/data/train', val_dir='E:/data/val', img_size=224,
            batch_size=16, lr=1e-4, epochs=50, mask_ratio=0.5,
            warmup_epochs=5, checkpoint_dir='checkpoints', save_every=10,
            seed=None):
    """Train ``MFMViT`` with ``MFMLoss`` and write checkpoints to disk.

    Calling ``trainer()`` with no arguments reproduces the original
    hard-coded configuration.

    Args:
        train_dir: ``ImageFolder`` root for the training images.
        val_dir: ``ImageFolder`` root for validation images. Validation is
            skipped when this path does not exist.
        img_size: Images are resized to ``(img_size, img_size)``; the value is
            also passed to the model constructor and to ``get_mask``.
        batch_size: Batch size for both data loaders.
        lr: Base learning rate for Adam (reached at the end of warmup).
        epochs: Total number of training epochs.
        mask_ratio: ``ratio`` forwarded to ``get_mask``.
        warmup_epochs: Number of epochs of linear warmup before cosine
            annealing takes over. ``0`` disables warmup.
        checkpoint_dir: Directory for ``.pth`` files (created if missing).
        save_every: Save an intermediate checkpoint every this many epochs;
            a falsy value disables intermediate checkpoints.
        seed: If not ``None``, passed to ``torch.manual_seed`` before any
            data loader or model is built. Randomness that does not come from
            torch's RNG is not affected.

    Returns:
        None. Side effects: prints one summary line per epoch and saves model
        ``state_dict`` files under ``checkpoint_dir``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): mean/std look like the commonly used ImageNet channel
    # statistics -- confirm they suit the dataset under train_dir.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.ImageFolder(train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Validation is optional: only built when the directory is present.
    val_loader = None
    if os.path.exists(val_dir):
        val_dataset = datasets.ImageFolder(val_dir, transform=transform)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    # Explicit check instead of relying on DataLoader truthiness (__len__).
    has_val = val_loader is not None and len(val_loader) > 0

    model = MFMViT(img_size=img_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = MFMLoss().to(device)

    # Both schedulers share one optimizer; exactly one of them is stepped per
    # epoch. The max(1, ...) guards avoid division by zero when
    # warmup_epochs == 0 or epochs == warmup_epochs.
    warmup_span = max(1, warmup_epochs)
    warmup_scheduler = LambdaLR(
        optimizer, lambda epoch: min(1.0, (epoch + 1) / warmup_span))
    cosine_scheduler = CosineAnnealingLR(
        optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-6)

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        train_loss = 0.0

        for images, _ in pbar:
            images = images.to(device)
            loss = _mfm_batch_loss(model, criterion, images,
                                   img_size, mask_ratio, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # ---- validation pass (same pipeline, no gradients) ----
        val_loss = 0.0
        if has_val:
            model.eval()
            with torch.no_grad():
                for images, _ in val_loader:
                    images = images.to(device)
                    loss = _mfm_batch_loss(model, criterion, images,
                                           img_size, mask_ratio, device)
                    val_loss += loss.item()

            val_loss /= len(val_loader)

        # Step the schedule once per epoch: warmup first, cosine afterwards.
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            cosine_scheduler.step()

        train_loss /= len(train_loader)
        msg = f'Epoch {epoch+1} | Train Loss: {train_loss:.4f}'
        if has_val:
            msg += f' | Val Loss: {val_loss:.4f}'
        # The LR is read after the scheduler step, so the printed value is the
        # rate the *next* epoch will use, not the one just trained with.
        msg += f' | LR: {optimizer.param_groups[0]["lr"]:.2e}'
        print(msg + '\n')

        if save_every and (epoch + 1) % save_every == 0:
            torch.save(model.state_dict(),
                       f'{checkpoint_dir}/mfm_vit_epoch_{epoch+1}.pth')

    torch.save(model.state_dict(), f'{checkpoint_dir}/mfm_vit_final.pth')
    print('Training complete!')

# Start training with the settings defined by trainer(); a summary line is
# printed per epoch and model weights are written under checkpoints/.
trainer()

Epoch 1/50:   0%|          | 0/5014 [00:00<?, ?it/s]

[MFMViT init] embed_dim=768, num_patches=196, patches_per_side=14, patch_size=16, patch_dim=768


Epoch 1/50: 100%|██████████| 5014/5014 [14:01<00:00,  5.96it/s, loss=0.2811]


Epoch 1 | Train Loss: 0.3265 | Val Loss: 0.3008 | LR: 4.00e-05



Epoch 2/50: 100%|██████████| 5014/5014 [13:52<00:00,  6.02it/s, loss=0.2861]


Epoch 2 | Train Loss: 0.2863 | Val Loss: 0.2731 | LR: 6.00e-05



Epoch 3/50: 100%|██████████| 5014/5014 [13:50<00:00,  6.04it/s, loss=0.2592]


Epoch 3 | Train Loss: 0.2621 | Val Loss: 0.2488 | LR: 8.00e-05



Epoch 4/50: 100%|██████████| 5014/5014 [13:50<00:00,  6.04it/s, loss=0.2503]


Epoch 4 | Train Loss: 0.2425 | Val Loss: 0.2310 | LR: 1.00e-04



Epoch 5/50: 100%|██████████| 5014/5014 [13:44<00:00,  6.08it/s, loss=0.2193]


Epoch 5 | Train Loss: 0.2269 | Val Loss: 0.2173 | LR: 1.00e-04



Epoch 6/50: 100%|██████████| 5014/5014 [13:42<00:00,  6.09it/s, loss=0.2149]


Epoch 6 | Train Loss: 0.2146 | Val Loss: 0.2073 | LR: 9.99e-05



Epoch 7/50: 100%|██████████| 5014/5014 [13:50<00:00,  6.03it/s, loss=0.2101]


Epoch 7 | Train Loss: 0.2067 | Val Loss: 0.2036 | LR: 9.95e-05



Epoch 8/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.10it/s, loss=0.1920]


Epoch 8 | Train Loss: 0.2009 | Val Loss: 0.1951 | LR: 9.89e-05



Epoch 9/50: 100%|██████████| 5014/5014 [13:42<00:00,  6.09it/s, loss=0.1818]


Epoch 9 | Train Loss: 0.1968 | Val Loss: 0.1920 | LR: 9.81e-05



Epoch 10/50: 100%|██████████| 5014/5014 [13:45<00:00,  6.07it/s, loss=0.2105]


Epoch 10 | Train Loss: 0.1925 | Val Loss: 0.1947 | LR: 9.70e-05



Epoch 11/50: 100%|██████████| 5014/5014 [13:53<00:00,  6.01it/s, loss=0.1865]


Epoch 11 | Train Loss: 0.1906 | Val Loss: 0.1911 | LR: 9.57e-05



Epoch 12/50: 100%|██████████| 5014/5014 [13:31<00:00,  6.18it/s, loss=0.1821]


Epoch 12 | Train Loss: 0.1894 | Val Loss: 0.1834 | LR: 9.42e-05



Epoch 13/50: 100%|██████████| 5014/5014 [13:37<00:00,  6.13it/s, loss=0.1838]


Epoch 13 | Train Loss: 0.1860 | Val Loss: 0.1834 | LR: 9.25e-05



Epoch 14/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.13it/s, loss=0.1880]


Epoch 14 | Train Loss: 0.1853 | Val Loss: 0.1820 | LR: 9.05e-05



Epoch 15/50: 100%|██████████| 5014/5014 [13:42<00:00,  6.09it/s, loss=0.1792]


Epoch 15 | Train Loss: 0.1839 | Val Loss: 0.1822 | LR: 8.84e-05



Epoch 16/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1762]


Epoch 16 | Train Loss: 0.1835 | Val Loss: 0.1855 | LR: 8.61e-05



Epoch 17/50: 100%|██████████| 5014/5014 [13:37<00:00,  6.13it/s, loss=0.1963]


Epoch 17 | Train Loss: 0.1823 | Val Loss: 0.1835 | LR: 8.36e-05



Epoch 18/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.13it/s, loss=0.1756]


Epoch 18 | Train Loss: 0.1812 | Val Loss: 0.1781 | LR: 8.10e-05



Epoch 19/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.13it/s, loss=0.1771]


Epoch 19 | Train Loss: 0.1802 | Val Loss: 0.1783 | LR: 7.82e-05



Epoch 20/50: 100%|██████████| 5014/5014 [13:42<00:00,  6.09it/s, loss=0.1821]


Epoch 20 | Train Loss: 0.1794 | Val Loss: 0.1790 | LR: 7.53e-05



Epoch 21/50: 100%|██████████| 5014/5014 [13:45<00:00,  6.08it/s, loss=0.1897]


Epoch 21 | Train Loss: 0.1795 | Val Loss: 0.1775 | LR: 7.22e-05



Epoch 22/50: 100%|██████████| 5014/5014 [13:46<00:00,  6.07it/s, loss=0.1980]


Epoch 22 | Train Loss: 0.1772 | Val Loss: 0.1753 | LR: 6.90e-05



Epoch 23/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.12it/s, loss=0.1763]


Epoch 23 | Train Loss: 0.1768 | Val Loss: 0.1750 | LR: 6.58e-05



Epoch 24/50: 100%|██████████| 5014/5014 [13:40<00:00,  6.11it/s, loss=0.1689]


Epoch 24 | Train Loss: 0.1761 | Val Loss: 0.1758 | LR: 6.25e-05



Epoch 25/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1768]


Epoch 25 | Train Loss: 0.1766 | Val Loss: 0.1793 | LR: 5.91e-05



Epoch 26/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.11it/s, loss=0.1750]


Epoch 26 | Train Loss: 0.1749 | Val Loss: 0.1739 | LR: 5.57e-05



Epoch 27/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.10it/s, loss=0.1620]


Epoch 27 | Train Loss: 0.1743 | Val Loss: 0.1725 | LR: 5.22e-05



Epoch 28/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.11it/s, loss=0.1658]


Epoch 28 | Train Loss: 0.1737 | Val Loss: 0.1730 | LR: 4.88e-05



Epoch 29/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1753]


Epoch 29 | Train Loss: 0.1733 | Val Loss: 0.1740 | LR: 4.53e-05



Epoch 30/50: 100%|██████████| 5014/5014 [13:40<00:00,  6.11it/s, loss=0.1555]


Epoch 30 | Train Loss: 0.1719 | Val Loss: 0.1748 | LR: 4.19e-05



Epoch 31/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1614]


Epoch 31 | Train Loss: 0.1709 | Val Loss: 0.1697 | LR: 3.85e-05



Epoch 32/50: 100%|██████████| 5014/5014 [13:37<00:00,  6.13it/s, loss=0.1597]


Epoch 32 | Train Loss: 0.1701 | Val Loss: 0.1685 | LR: 3.52e-05



Epoch 33/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.11it/s, loss=0.1524]


Epoch 33 | Train Loss: 0.1692 | Val Loss: 0.1678 | LR: 3.20e-05



Epoch 34/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1837]


Epoch 34 | Train Loss: 0.1687 | Val Loss: 0.1680 | LR: 2.88e-05



Epoch 35/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.13it/s, loss=0.1872]


Epoch 35 | Train Loss: 0.1680 | Val Loss: 0.1679 | LR: 2.58e-05



Epoch 36/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.13it/s, loss=0.1586]


Epoch 36 | Train Loss: 0.1674 | Val Loss: 0.1659 | LR: 2.28e-05



Epoch 37/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.13it/s, loss=0.1738]


Epoch 37 | Train Loss: 0.1664 | Val Loss: 0.1651 | LR: 2.00e-05



Epoch 38/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1771]


Epoch 38 | Train Loss: 0.1659 | Val Loss: 0.1652 | LR: 1.74e-05



Epoch 39/50: 100%|██████████| 5014/5014 [13:38<00:00,  6.12it/s, loss=0.1686]


Epoch 39 | Train Loss: 0.1654 | Val Loss: 0.1649 | LR: 1.49e-05



Epoch 40/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1734]


Epoch 40 | Train Loss: 0.1647 | Val Loss: 0.1640 | LR: 1.26e-05



Epoch 41/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1651]


Epoch 41 | Train Loss: 0.1639 | Val Loss: 0.1635 | LR: 1.05e-05



Epoch 42/50: 100%|██████████| 5014/5014 [13:39<00:00,  6.12it/s, loss=0.1561]


Epoch 42 | Train Loss: 0.1637 | Val Loss: 0.1626 | LR: 8.52e-06



Epoch 43/50: 100%|██████████| 5014/5014 [13:40<00:00,  6.11it/s, loss=0.1657]


Epoch 43 | Train Loss: 0.1629 | Val Loss: 0.1629 | LR: 6.79e-06



Epoch 44/50: 100%|██████████| 5014/5014 [13:40<00:00,  6.11it/s, loss=0.1513]


Epoch 44 | Train Loss: 0.1625 | Val Loss: 0.1622 | LR: 5.28e-06



Epoch 45/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.11it/s, loss=0.1570]


Epoch 45 | Train Loss: 0.1620 | Val Loss: 0.1616 | LR: 3.99e-06



Epoch 46/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.10it/s, loss=0.1709]


Epoch 46 | Train Loss: 0.1618 | Val Loss: 0.1613 | LR: 2.92e-06



Epoch 47/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.10it/s, loss=0.1527]


Epoch 47 | Train Loss: 0.1615 | Val Loss: 0.1613 | LR: 2.08e-06



Epoch 48/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.10it/s, loss=0.1658]


Epoch 48 | Train Loss: 0.1613 | Val Loss: 0.1611 | LR: 1.48e-06



Epoch 49/50: 100%|██████████| 5014/5014 [13:43<00:00,  6.09it/s, loss=0.1772]


Epoch 49 | Train Loss: 0.1611 | Val Loss: 0.1610 | LR: 1.12e-06



Epoch 50/50: 100%|██████████| 5014/5014 [13:41<00:00,  6.11it/s, loss=0.1574]


Epoch 50 | Train Loss: 0.1609 | Val Loss: 0.1609 | LR: 1.00e-06

Training complete!
