In [46]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import sys
from copy import deepcopy



In [19]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [20]:

# Training hyperparameters read as module-level globals by the training cells.
LEARNING_RATE = 1e-3  # optimizer step size
NUM_EPOCHS = 10       # number of train/evaluate rounds per model

In [21]:
# Load MNIST data
def load_mnist(batch_size=128, data_dir='./data', num_workers=2):
    """Build the MNIST training and test data loaders.

    Both splits are requested with ``download=True`` and share one transform:
    conversion to a tensor followed by a fixed normalization.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Directory passed to ``datasets.MNIST`` as ``root``.
        num_workers: Worker process count passed to both ``DataLoader``s.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles its
        data; the test loader does not.
    """
    # NOTE(review): presumably the MNIST training-set mean / std — confirm.
    mnist_mean = (0.1307,)
    mnist_std = (0.3081,)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mnist_mean, mnist_std)
    ])

    train_dataset = datasets.MNIST(root=data_dir, train=True,
                                   download=True, transform=transform)
    test_dataset = datasets.MNIST(root=data_dir, train=False,
                                  download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

# Create the train/test loaders as notebook-level globals and report the
# size of each split.
BATCH_SIZE = 128
print("\n..... Loading MNIST dataset...")
train_loader, test_loader = load_mnist(BATCH_SIZE)
print(
    f"Training samples: {len(train_loader.dataset)}",
    f"Test samples: {len(test_loader.dataset)}",
    sep="\n",
)


..... Loading MNIST dataset...
Training samples: 60000
Test samples: 10000


In [22]:
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 conv + average-pool stages, then three linear layers.

    The hard-coded ``16 * 5 * 5`` flatten size means the second pooling stage
    must produce 16 feature maps of 5x5, which is what a 1x28x28 input yields
    (28 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_classes: Width of the final linear layer (number of output logits).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x, mode=None, pp=False, c=(0, 0, 0, 0)):
        """Run the network, optionally compressing or inspecting activations.

        Args:
            x: Input batch fed to ``conv1`` (single-channel images).
            mode: ``"compress"`` passes the activation after each of the four
                hidden stages through ``compress(x, c[i])``; ``"Test"`` prints
                a marker line after the last layer; anything else is a plain
                forward pass.
            pp: Debug switch. When truthy, the shape of every intermediate
                activation is printed and ``sys.exit()`` is called instead of
                returning.
            c: Four per-stage settings forwarded to ``compress``; ``None`` is
                treated as ``(0, 0, 0, 0)``.

        Returns:
            The output of ``fc3`` (one row of ``num_classes`` logits per sample).

        NOTE(review): ``compress`` is not defined in the visible part of this
        notebook; it has to exist in the kernel namespace before
        ``mode="compress"`` is used.
        """
        # Identity test on purpose: `c == None` is evaluated element-wise for
        # array-like settings, which makes the `if` raise instead of falling
        # through.
        if c is None:
            c = (0, 0, 0, 0)
        compressing = mode == "compress"

        x = self.pool(self.relu(self.conv1(x)))
        # layer 1
        if pp:
            print("hahahahahhahahahahhahahh!!!!!!!!!@!!!!!!!")
            print(pp)
            print(x.shape)
        if compressing:
            x = compress(x, c[0])

        x = self.pool(self.relu(self.conv2(x)))
        # layer 2
        if pp:
            print(x.shape)
        if compressing:
            x = compress(x, c[1])

        x = x.view(-1, 16 * 5 * 5)
        x = self.relu(self.fc1(x))
        # layer 3
        if pp:
            print(x.shape)
        if compressing:
            x = compress(x, c[2])

        x = self.relu(self.fc2(x))
        # layer 4
        if pp:
            print(x.shape)
        if compressing:
            x = compress(x, c[3])

        x = self.fc3(x)
        if mode == "Test":
            print("LeNet Forward Testing !!!!!!!!!")
        if pp:
            # Debug runs stop here on purpose once the shapes have been printed.
            sys.exit()
        return x

In [23]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the block
    changes the spatial stride or the channel count, in which case it becomes
    a strided 1x1 convolution followed by batch norm.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x, mode=None):
        """Return ``relu(main_path(x) + shortcut(x))``.

        ``mode="Test"`` additionally prints a marker line; any other value
        has no effect.
        """
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main += self.shortcut(x)
        main = self.relu(main)
        if mode == "Test":
            print("ResidualBlock Forward Testing !!!!!!!!!")
        return main

In [24]:
class ResNet(nn.Module):
    """Small residual network for single-channel images.

    Layout: a 3x3 conv stem (16 channels), three stages of two
    ``ResidualBlock``s each (16, 32 and 64 channels; the last two stages use
    stride 2), global average pooling, and a linear classifier.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(16, 16, 2, stride=1)
        self.layer2 = self._make_layer(16, 32, 2, stride=2)
        self.layer3 = self._make_layer(32, 64, 2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack residual blocks into one stage.

        Only the first block receives ``stride`` and the channel change; the
        remaining ``num_blocks - 1`` blocks keep shape (stride 1).
        """
        head = ResidualBlock(in_channels, out_channels, stride)
        tail = [ResidualBlock(out_channels, out_channels, 1)
                for _ in range(1, num_blocks)]
        return nn.Sequential(head, *tail)

    def forward(self, x, mode=None,pp=False,c=(0,0,0,0)):
        """Compute class logits for a batch.

        ``pp`` and ``c`` are accepted but never read in this method.
        ``mode="Test"`` prints a marker line after the classifier; other
        values have no effect.
        """
        stem = self.relu(self.bn1(self.conv1(x)))
        features = self.layer3(self.layer2(self.layer1(stem)))
        pooled = self.avg_pool(features)
        # Collapse the pooled 1x1 maps to one feature row per sample.
        logits = self.fc(pooled.view(pooled.size(0), -1))
        if mode == "Test":
            print("ResNet Forward Testing !!!!!!!!!")
        return logits

In [30]:
def train_model(model, train_loader, criterion, optimizer, epoch):
    """Run one training pass over ``train_loader``.

    Batches are moved to the notebook-level ``device``; a tqdm bar shows the
    running mean loss and accuracy.

    Args:
        model: Network to optimize; switched to training mode here.
        train_loader: Iterable of ``(data, target)`` batches.
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses and the
        accuracy in percent over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc=f'Epoch {epoch}')
    for step, (data, target) in enumerate(progress, start=1):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = output.max(1)[1]  # index of the largest logit per row
        n_seen += target.size(0)
        n_correct += predicted.eq(target).sum().item()

        progress.set_postfix({
            'loss': loss_sum / step,
            'acc': 100. * n_correct / n_seen
        })

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

In [31]:
def evaluate_model(model, test_loader, criterion, mode=None,c=None):
    """Measure loss and accuracy on ``test_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here. It is called
            as ``model(data, mode=mode, c=c)``, so its ``forward`` must accept
            both keyword arguments.
        test_loader: Iterable of ``(data, target)`` batches.
        criterion: Loss function called as ``criterion(output, target)``.
        mode: Forwarded unchanged to the model.
        c: Forwarded unchanged to the model.

    Returns:
        ``(test_loss, test_acc)``: the mean of the per-batch losses and the
        accuracy in percent.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc='Evaluating'):
            data = data.to(device)
            target = target.to(device)
            output = model(data, mode=mode, c=c)
            loss_sum += criterion(output, target).item()
            predicted = output.max(1)[1]  # index of the largest logit per row
            n_seen += target.size(0)
            n_correct += predicted.eq(target).sum().item()

    mean_loss = loss_sum / len(test_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

In [32]:
def plot_history(history, model_name):
    """Plot train/test loss and accuracy curves side by side, save, and show.

    Args:
        history: Mapping with the keys ``'train_loss'``, ``'test_loss'``,
            ``'train_acc'`` and ``'test_acc'``, each holding one value per
            epoch.
        model_name: Used in both panel titles and in the output file name
            ``{model_name}_training_history.png``.
    """
    _, axes = plt.subplots(1, 2, figsize=(12, 4))

    # One row per panel:
    # (train key, test key, train label, test label, y-axis label, title suffix)
    panels = (
        ('train_loss', 'test_loss', 'Train Loss', 'Test Loss',
         'Loss', 'Loss'),
        ('train_acc', 'test_acc', 'Train Accuracy', 'Test Accuracy',
         'Accuracy (%)', 'Accuracy'),
    )
    for ax, panel in zip(axes, panels):
        train_key, test_key, train_label, test_label, y_label, suffix = panel
        ax.plot(history[train_key], label=train_label, linewidth=2)
        ax.plot(history[test_key], label=test_label, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(f'{model_name} - {suffix}', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    out_path = f'{model_name}_training_history.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f'Saved training history plot: {out_path}')
    plt.show()

In [33]:
def initial_training(model_name):
    """Train, evaluate, checkpoint, and plot one of the notebook's models.

    The model is trained for ``NUM_EPOCHS`` epochs with Adam at
    ``LEARNING_RATE`` and a ``StepLR`` schedule (learning rate x0.1 every 5
    epochs), and evaluated on the test loader after every epoch. Afterwards
    the weights, optimizer state, final metrics and per-epoch history are
    saved to ``{model_name}_mnist.pth`` and the history is plotted.

    Relies on the notebook-level globals ``train_loader``, ``test_loader``,
    ``device``, ``NUM_EPOCHS`` and ``LEARNING_RATE``.

    Args:
        model_name: ``'LeNet'`` or ``'ResNet'``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Reject unknown names up front. Without the explicit branch `model` stays
    # unbound and the function fails later with an unrelated UnboundLocalError.
    if model_name == 'LeNet':
        model = LeNet(num_classes=10)
    elif model_name == 'ResNet':
        model = ResNet(num_classes=10)
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; valid values: 'LeNet', 'ResNet'"
        )

    print(f"....... Training {model_name}")

    # Initialize model
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Training history: one entry per epoch for each metric
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': []
    }

    # Training loop
    print(f"\nTraining {model_name} for {NUM_EPOCHS} epochs...")
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_acc = train_model(model, train_loader, criterion,
                                            optimizer, epoch)
        test_loss, test_acc = evaluate_model(model, test_loader, criterion)
        # The scheduler advances once per epoch, after that epoch's optimizer steps.
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        print(f'\nEpoch {epoch}/{NUM_EPOCHS}:')
        print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'  Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.2f}%')

    # Final evaluation
    print(f"\n...... Final Evaluation for {model_name}...")
    final_test_loss, final_test_acc = evaluate_model(model, test_loader, criterion)
    print(f"{model_name} - FINAL RESULTS")
    print(f"{'=' * 60}")
    print(f"Final Test Loss: {final_test_loss:.4f}")
    print(f"Final Test Accuracy: {final_test_acc:.2f}%")

    # Save model
    print(f"\n..... Saving {model_name} model...")
    torch.save({
        'epoch': NUM_EPOCHS,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'test_loss': final_test_loss,
        'test_acc': final_test_acc,
        'history': history
    }, f'{model_name}_mnist.pth')
    print(f"Model saved as: {model_name}_mnist.pth")

    # Plot training history
    print(f"\n....... Plotting training history for {model_name}...")
    plot_history(history, model_name)

In [33]:
# Run the full train / evaluate / save / plot pipeline for LeNet.
model_name = "LeNet" # valid values: LeNet, ResNet
initial_training(model_name)

In [None]:
# Run the full train / evaluate / save / plot pipeline for ResNet.
model_name = "ResNet" # valid values: LeNet, ResNet
initial_training(model_name)

In [None]:
# NOTE(review): a bare `break` outside any loop is a SyntaxError, so this cell
# always fails when run. Presumably it is a deliberate stop marker that keeps
# "Run All" from reaching the experimental cells below — confirm, and consider
# an explicit `raise` with a message instead.
break

In [None]:
def ResNet18_MNIST(num_classes=10):
    """Build a ResNet-18 adapted to single-channel MNIST-sized images.

    Args:
        num_classes: Number of output logits of the final linear layer.

    Returns:
        A torchvision ``ResNet`` built from ``BasicBlock`` with a
        ``[2, 2, 2, 2]`` stage layout, a 1-channel 3x3 stem convolution, no
        max-pool, and a ``num_classes``-way classifier.
    """
    # Local, aliased import: this notebook defines its own `ResNet` class
    # (constructor takes only `num_classes`), which shadows the torchvision
    # one this builder needs, and `BasicBlock` is not imported at the top.
    from torchvision.models.resnet import BasicBlock, ResNet as TorchvisionResNet

    # Load standard ResNet18
    model = TorchvisionResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

    # Modify first conv to accept 1 channel instead of 3
    model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Remove the first maxpool (not needed for 28x28 images)
    model.maxpool = nn.Identity()

    # Change final fully connected layer for 10 classes
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

In [None]:

# Smoke test: push one random MNIST-shaped batch through the ResNet-18
# variant and print the shape of the result.
model = ResNet18_MNIST()
x = torch.randn(16, 1, 28, 28)  # batch of 16 MNIST images
y = model(x)
print(y.shape)  # torch.Size([16, 10])


In [None]:
def ResNet50_MNIST(num_classes=10):
    """Build a ResNet-50 adapted to single-channel MNIST-sized images.

    Args:
        num_classes: Number of output logits of the final linear layer.

    Returns:
        A torchvision ``ResNet`` built from ``Bottleneck`` with a
        ``[3, 4, 6, 3]`` stage layout, a 1-channel 3x3 stem convolution, no
        max-pool, and a ``num_classes``-way classifier.
    """
    # Local, aliased import: this notebook defines its own `ResNet` class
    # (constructor takes only `num_classes`), which shadows the torchvision
    # one this builder needs, and `Bottleneck` is not imported at the top.
    from torchvision.models.resnet import Bottleneck, ResNet as TorchvisionResNet

    # Load the standard ResNet-50
    model = TorchvisionResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)

    # Change the first conv layer to accept 1-channel input (instead of 3)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Remove the first maxpool layer (28x28 images are small)
    model.maxpool = nn.Identity()

    # Change the classification head to output 10 classes
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

In [None]:

# Smoke test: push one random MNIST-shaped batch through the ResNet-50
# variant and print the shape of the result.
model = ResNet50_MNIST()
x = torch.randn(16, 1, 28, 28)  # batch of 16 MNIST images
y = model(x)
print(y.shape)  # torch.Size([16, 10])
