In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models



In [None]:
# The archive we train on is massive, so it cannot be loaded all at once.
# Instead, we read it lazily and randomly sample images across classes
# only as they are needed.


# Strategy: each minibatch of size batch_size is drawn from (up to)
# N randomly chosen classes.
import tarfile
import random
from collections import defaultdict, deque

class LazyRandomMinibatchGenerator:
    """
    Lazily read file names from a .tar.gz archive and serve minibatches that
    mix files from up to N randomly chosen classes.

    Members are pulled from the archive only when more classes are needed, so
    the archive never has to be listed up front. A member's class is the path
    component at index `class_dir_index` of its "/"-split name (default layout:
    imagenet21k_resized/imagenet21k_small_classes/<class>/<filename>).
    Minibatches contain member *names* (paths inside the archive), not image
    data.

    Once the archive has been fully read and every class has been drained,
    all classes are refilled in a new random order, so generation can continue
    for multiple passes. Use it as a context manager (or call `close()`) to
    release the underlying tarfile.
    """

    def __init__(self, tar_path, N, class_dir_index=2, seed=None):
        """
        :param tar_path: Path to the .tar.gz file containing images.
        :param N: Maximum number of distinct classes sampled per minibatch.
        :param class_dir_index: Index of the class directory in the
            "/"-split member name (default 2).
        :param seed: Optional seed for a private RNG. When None, the global
            `random` module is used.
        """
        self.tar_path = tar_path
        self.N = N
        self.class_dir_index = class_dir_index
        # Private RNG when seeded; otherwise share the global random state.
        self._rng = random if seed is None else random.Random(seed)

        # class_name -> deque of file names still available in the current pass
        self.images_by_class = defaultdict(deque)
        # class_name -> every file name read so far for that class. popleft()
        # discards entries from the deques, so this record is what allows the
        # classes to be refilled for another pass.
        self._all_images_by_class = defaultdict(list)

        # Keep the tarfile open so reading can resume where it left off
        self.tar = tarfile.open(self.tar_path, "r:gz")

        self.end_of_archive = False  # True once the whole tar has been read
        self.all_discovered_classes = set()  # Classes found so far
        self.exhausted_classes = set()  # Classes whose deque has run empty

        # Lazily created iterator over tar members (see `tar_iterator`)
        self._tar_iter = None

    def __enter__(self):
        """Return self so the generator can be used in a `with` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the tarfile on leaving the `with` block; never swallow errors."""
        self.close()
        return False

    @property
    def tar_iterator(self):
        """
        Iterator over the tarfile's members, created once and reused so that
        successive reads continue from the current position.
        """
        if self._tar_iter is None:
            self._tar_iter = iter(self.tar)
        return self._tar_iter

    def close(self):
        """Close the tarfile (once you're completely done generating)."""
        self.tar.close()

    def _read_until_we_have_enough_classes(self, needed_classes):
        """
        Read from the tar until at least `needed_classes` classes have been
        discovered or the end of the archive is reached.
        """
        if self.end_of_archive:
            return

        idx = self.class_dir_index
        while len(self.all_discovered_classes) < needed_classes:
            try:
                member = next(self.tar_iterator)
            except StopIteration:
                self.end_of_archive = True
                break

            parts = member.name.split("/")
            # Require at least one path component after the class directory,
            # so stray files next to the class folders are not mistaken for
            # single-image classes.
            if member.isfile() and len(parts) > idx + 1:
                class_name = parts[idx]
                self.all_discovered_classes.add(class_name)
                self.images_by_class[class_name].append(member.name)
                self._all_images_by_class[class_name].append(member.name)
                # A drained class that receives more files is usable again.
                self.exhausted_classes.discard(class_name)

    def _shuffle_exhausted_classes(self):
        """
        Start a new pass once the archive has been fully read and no
        discovered class has any files left: refill every class's deque from
        its full file record, in a fresh random order.

        Refilling is deferred until the end of the archive so that classes not
        yet read from the tar are still discovered by later calls.
        """
        if not self.end_of_archive:
            return
        if any(self.images_by_class[cls] for cls in self.all_discovered_classes):
            return

        # Sorted iteration keeps seeded runs reproducible across processes.
        for cls in sorted(self.all_discovered_classes):
            file_list = list(self._all_images_by_class[cls])
            self._rng.shuffle(file_list)
            self.images_by_class[cls] = deque(file_list)

        self.exhausted_classes.clear()

    def _non_empty_classes(self):
        """Sorted list of discovered classes that still have files queued."""
        return sorted(c for c in self.all_discovered_classes
                      if self.images_by_class[c])

    def generate_minibatch(self, batch_size):
        """
        Generate a minibatch of up to `batch_size` member names drawn from up
        to N distinct, randomly chosen classes. Each round visits the chosen
        classes in a fresh random order and takes one file per class, so
        several classes appear in the same minibatch.

        :param batch_size: Maximum number of member names to return.
        :return: List of member names. It is shorter than `batch_size` when
            the chosen classes run out, and empty when the archive holds no
            usable files.
        """
        # 1) Ensure at least N classes still have files, reading more of the
        #    archive if needed and possible.
        non_empty_classes = self._non_empty_classes()
        while len(non_empty_classes) < self.N and not self.end_of_archive:
            needed = max(self.N, len(self.all_discovered_classes) + 10)
            self._read_until_we_have_enough_classes(needed)
            non_empty_classes = self._non_empty_classes()

        if not non_empty_classes:
            # Everything may have been consumed just before the archive end
            # was reached; try starting a new pass before giving up.
            self._shuffle_exhausted_classes()
            non_empty_classes = self._non_empty_classes()
            if not non_empty_classes:
                return []

        # Pick up to N random classes from the non-empty ones
        if len(non_empty_classes) <= self.N:
            chosen_classes = non_empty_classes
        else:
            chosen_classes = self._rng.sample(non_empty_classes, self.N)

        # 2) Round-robin over the chosen classes, in random order each round
        minibatch = []
        while len(minibatch) < batch_size and chosen_classes:
            for cls in self._rng.sample(chosen_classes, len(chosen_classes)):
                if len(minibatch) >= batch_size:
                    break  # Stop once the batch is full
                queue = self.images_by_class[cls]
                if queue:
                    minibatch.append(queue.popleft())
                if not queue:
                    self.exhausted_classes.add(cls)
                    chosen_classes.remove(cls)

        # 3) If everything has been consumed, refill for the next pass
        self._shuffle_exhausted_classes()

        return minibatch




In [7]:


# -------------------------
# Hyperparameters
# -------------------------
data_dir = '/path/to/imagenet/'
batch_size = 256
num_workers = 8
num_epochs = 90
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
print_freq = 50  # Steps between printing training status

# -------------------------
# Device configuration
# -------------------------
if torch.mps.is_available():
    device_setting = 'mps'
elif torch.cuda.is_available():
    device_setting = 'cude'
else:
    device_setting ='cpu'
device = torch.device(device_setting)
print("Using device:", device)

    # -------------------------
    # Transforms and Datasets
    # -------------------------
    # Official ImageNet stats; can be used for normalization
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    
    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_transforms
    )
    val_dataset = torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'val'), transform=val_transforms
    )
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    
    # -------------------------
    # Model, Loss, Optimizer
    # -------------------------
    # Using a standard ResNet-18 from torchvision
    model = models.resnet18(pretrained=False)  # set pretrained=True to fine-tune
    # Change the final layer to match the 1000-class ImageNet
    # (not necessary if using the original ResNet-18 from torchvision, which outputs 1000 classes)
    model.fc = nn.Linear(model.fc.in_features, 1000)
    
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                          momentum=momentum, weight_decay=weight_decay)
    
    # Optional: learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # -------------------------
    # Training & Validation Loops
    # -------------------------
    for epoch in range(num_epochs):
        # ---- Train ----
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Compute training statistics
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
            
            if (i+1) % print_freq == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], "
                      f"Step [{i+1}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}")

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}] Training Loss: {epoch_loss:.4f}, Accuracy: {100.0*epoch_acc:.2f}%")
        
        # Step the scheduler (if you are using one)
        scheduler.step()

        # ---- Validate ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * images.size(0)
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(labels).sum().item()
                val_total += labels.size(0)
        
        val_loss /= val_total
        val_accuracy = val_correct / val_total
        print(f"Epoch [{epoch+1}/{num_epochs}] Validation Loss: {val_loss:.4f}, Accuracy: {100.0*val_accuracy:.2f}%")

    # -------------------------
    # Save the trained model
    # -------------------------
    torch.save(model.state_dict(), 'resnet18_imagenet.pth')
    print("Model saved to resnet18_imagenet.pth")




IndentationError: expected an indented block after 'else' statement on line 21 (2748051627.py, line 22)

In [6]:
# Kick off the full training pipeline. IPython kernels execute cells in the
# `__main__` module, so running this cell starts training; importing the code
# from elsewhere does not.
_run_as_entry_point = __name__ == '__main__'
if _run_as_entry_point:
    main()