In [1]:
import os
# Redirect the Hugging Face hub cache and the `datasets` cache to project
# storage. This cell runs before the cell that imports `datasets`; the
# libraries are expected to read these variables when they are imported
# (NOTE(review): confirm against the installed huggingface_hub/datasets versions).
os.environ.update({
    "HF_HOME": "/project2/jieyuz_1727/snehansh/hf_cache",
    "HF_DATASETS_CACHE": "/project2/jieyuz_1727/snehansh/hf_datasets",
})

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from torch_lr_finder import LRFinder

In [3]:
def get_simple_transforms():
    """Build the training-time image preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that applies, in order: a random resized crop
        to 224, a random horizontal flip, conversion to a tensor, and
        per-channel normalisation with the fixed mean/std below.
    """
    # Per-channel statistics for the three image channels (these are the
    # values commonly used for ImageNet with torchvision models).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

In [4]:
def load_data():
    """Load the ``imagenet-1k`` train split and attach the image transform.

    The split is fetched with ``load_dataset`` using a fixed cache directory,
    its size is printed, and a batch-level transform is registered through
    ``set_transform`` so that the ``'image'`` column is replaced by the output
    of ``get_simple_transforms()`` applied to each image after an RGB
    conversion.

    Returns:
        The dataset object returned by ``load_dataset``, with the transform set.
    """
    imagenet_train = load_dataset(
        "imagenet-1k",
        split="train",
        cache_dir="/project2/jieyuz_1727/snehansh/imagenet/cache",
    )
    print(f"Loaded {len(imagenet_train):,} images")

    augment = get_simple_transforms()

    def _images_to_tensors(examples):
        # `examples['image']` is iterated as a sequence of images exposing
        # `.convert` (presumably PIL images -- verify against the dataset).
        # Forcing RGB gives the 3-channel Normalize step a 3-channel input.
        converted = []
        for picture in examples['image']:
            converted.append(augment(picture.convert('RGB')))
        examples['image'] = converted
        return examples

    imagenet_train.set_transform(_images_to_tensors)
    return imagenet_train

In [5]:
def collate_fn(batch):
    """Collate a list of samples into an ``(images, labels)`` pair.

    Args:
        batch: Sequence of mappings, each with an ``'image'`` tensor and a
            ``'label'`` value.

    Returns:
        A tuple ``(images, labels)``: the image tensors stacked along a new
        leading dimension, and a tensor built from the labels in the same order.
    """
    image_tensors = []
    label_values = []
    for sample in batch:
        image_tensors.append(sample['image'])
        label_values.append(sample['label'])
    return torch.stack(image_tensors), torch.tensor(label_values)

In [8]:
def find_lr():
    """Run a learning-rate range test for a ResNet-50 and save the resulting plot.

    Steps:
      1. Load the dataset via ``load_data()`` and wrap it in a shuffled
         ``DataLoader`` (batch size 512, 6 workers, custom ``collate_fn``).
      2. Build an untrained ResNet-50 (1000 classes) on the GPU with an
         SGD + Nesterov optimizer and a label-smoothed cross-entropy loss.
      3. Sweep the learning rate exponentially from 1e-7 to 1 over at most
         100 iterations with ``LRFinder.range_test``.
      4. Plot loss against learning rate, save the figure to
         ``lr_finder_result.png`` in the working directory and display it.
      5. Call ``lr_finder.reset()`` (per torch-lr-finder's documentation this
         restores the model and optimizer to their initial state).

    Requires a CUDA device (``model.cuda()`` / ``device="cuda"``).

    Returns:
        None.
    """
    # Function-scope import, as in the rest of this notebook cell; it has to
    # come before the figure is created below.
    import matplotlib.pyplot as plt

    dataset = load_data()

    print("Creating DataLoader...")
    train_loader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=True,
        num_workers=6,
        pin_memory=True,
        collate_fn=collate_fn
    )
    print(f"DataLoader ready: {len(train_loader)} batches\n")

    model = models.resnet50(weights=None, num_classes=1000)
    model = model.cuda()

    optimizer = optim.SGD(
        model.parameters(),
        lr=1e-7,  # same value as start_lr passed to range_test below
        momentum=0.9,
        weight_decay=1e-4,
        nesterov=True
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    lr_finder = LRFinder(model, optimizer, criterion, device="cuda")

    lr_finder.range_test(
        train_loader,
        start_lr=1e-7,
        end_lr=1,
        num_iter=100,
        step_mode="exp",
        smooth_f=0.05,
        diverge_th=5
    )

    # Draw on a figure owned by this function. Without an explicit `ax`,
    # LRFinder.plot() creates its own figure and calls plt.show() on it; under
    # the notebook's inline backend that figure is closed afterwards, so a
    # later plt.savefig() would write an empty image.
    fig, ax = plt.subplots()
    # NOTE(review): log_lr=False puts an exponential LR sweep on a linear
    # x-axis, which compresses the low-LR region -- confirm this is intended.
    lr_finder.plot(
        skip_start=10,
        skip_end=5,
        log_lr=False,
        ax=ax,
    )

    # Save through the figure handle (not the pyplot "current figure"), then
    # show it so the plot still appears in the notebook output.
    fig.savefig('lr_finder_result.png', dpi=150, bbox_inches='tight')
    plt.show()

    lr_finder.reset()
    

In [None]:
# Run the learning-rate range test. This needs a CUDA device and writes
# lr_finder_result.png to the current working directory.
find_lr()