In [2]:
import os
# Restrict the visible GPUs before torch initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch
import torch.nn.functional as F
import wandb
from torchvision.models import resnet18, resnet34, resnet50, efficientnet_b1
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

# Hyperparameters (single source of truth for the run and its wandb config).
learning_rate = 0.001
epochs = 5
batch_size = 256
net_type = "pretrained_efficientnet_b1"

# Setup Weights and Biases. The config must be handed to wandb.init (or
# wandb.config.update); assigning a plain dict to `wandb.config` afterwards only
# rebinds the module attribute, so the values were never recorded with the run.
wandb.init(
    project="Thomas-Masters-Project",
    config={
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
        "network": net_type,
    },
)

def test(model, test_loader):
    """Evaluate a feature-returning model and log top-1 accuracy to wandb.

    Args:
        model: module whose forward returns a sequence with the logits first
            (e.g. ``FeatureExtractor`` returning ``(logits, features)``); any
            extra elements are ignored. It must expose ``clear_features()``,
            which is called after every batch so hooked activations do not
            pile up during evaluation.
        test_loader: DataLoader yielding ``(data, target)`` batches; the length
            of its ``dataset`` is used as the accuracy denominator.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output, *_ = model(data)
            # Predicted class = index of the largest logit.
            _, idx = output.max(dim=1)
            correct += (idx == target).sum().item()

            # Clear the features collected by the forward hooks for this batch.
            model.clear_features()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('Test set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset), accuracy))

    wandb.log({"accuracy": accuracy})


def test_ce(model, test_loader):
    """Evaluate a plain classifier and log top-1 accuracy to wandb.

    Args:
        model: module whose forward returns the logits tensor directly
            (no feature tuple, no ``clear_features``).
        test_loader: DataLoader yielding ``(data, target)`` batches; the length
            of its ``dataset`` is used as the accuracy denominator.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Predicted class = index of the largest logit.
            _, idx = output.max(dim=1)
            correct += (idx == target).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('Test set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset), accuracy))

    wandb.log({"accuracy": accuracy})

[34m[1mwandb[0m: Currently logged in as: [33mtnoel20[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
def _init_resnet_18(output_size, pretrained = False, features_hook = None):
    """Build a torchvision ResNet-18 whose ``fc`` head has ``output_size`` outputs.

    When ``features_hook`` is given, it is registered as a forward hook on each
    of the four residual stages (``layer1`` .. ``layer4``).
    """
    model = resnet18(pretrained=pretrained)
    # Keep the backbone's pooled feature width, swap only the output size.
    model.fc = torch.nn.Linear(model.fc.in_features, output_size)

    if features_hook is None:
        return model

    hooked_stages = {'layer1', 'layer2', 'layer3', 'layer4'}
    for module_name, submodule in model.named_modules():
        if module_name in hooked_stages:
            submodule.register_forward_hook(features_hook)
    return model
    
def _init_resnet_34(output_size, pretrained = False, features_hook = None):
    """Build a torchvision ResNet-34 whose ``fc`` head has ``output_size`` outputs.

    The head's input width is taken from the stock ``fc`` layer rather than
    hard-coded. It was previously written as 1024, but ResNet-34's pooled
    feature width is 512, so every forward pass failed with a shape mismatch.

    When ``features_hook`` is given, it is registered as a forward hook on each
    of the four residual stages (``layer1`` .. ``layer4``).
    """
    model = resnet34(pretrained=pretrained)
    model.fc = torch.nn.Linear(model.fc.in_features, output_size)
    if features_hook is not None:
        for name, module in model.named_modules():
            if name in ['layer1', 'layer2', 'layer3', 'layer4']:
                module.register_forward_hook(features_hook)

    return model

def _init_resnet_50(output_size, pretrained = False, features_hook = None):
    """Build a torchvision ResNet-50 whose ``fc`` head has ``output_size`` outputs.

    When ``features_hook`` is given, it is registered as a forward hook on each
    of the four residual stages (``layer1`` .. ``layer4``).
    """
    model = resnet50(pretrained=pretrained)
    # Keep the backbone's pooled feature width, swap only the output size.
    model.fc = torch.nn.Linear(model.fc.in_features, output_size)

    if features_hook is None:
        return model

    hooked_stages = {'layer1', 'layer2', 'layer3', 'layer4'}
    for module_name, submodule in model.named_modules():
        if module_name in hooked_stages:
            submodule.register_forward_hook(features_hook)
    return model

def _init_efficientnet_b1(output_size, pretrained = False, features_hook = None):
    """Build a torchvision EfficientNet-B1 with an ``output_size``-way head.

    The whole ``classifier`` attribute is replaced by a single
    1280 -> ``output_size`` Linear layer. NOTE(review): torchvision's stock
    classifier presumably also holds a dropout layer before its Linear, which
    this replacement removes; confirm that is intended.

    When ``features_hook`` is given, it is registered as a forward hook on each
    stage named ``features.0`` .. ``features.8``.
    """
    model = efficientnet_b1(pretrained=pretrained)
    head_in_features = 1280
    model.classifier = torch.nn.Linear(head_in_features, output_size)

    if features_hook is None:
        return model

    hooked_stages = {'features.%d' % stage_idx for stage_idx in range(9)}
    for module_name, submodule in model.named_modules():
        if module_name in hooked_stages:
            submodule.register_forward_hook(features_hook)
    return model

def create_pretrained_model(architecture, n_classes, features_hook = None):
    """Build a pretrained backbone chosen by substring match on ``architecture``.

    Candidates are checked in order ('resnet18', 'resnet34', 'resnet50',
    'efficientnet_b1'); the first one contained in ``architecture`` selects the
    builder, which is called with ``pretrained=True``.

    Raises:
        NotImplementedError: if ``architecture`` matches none of the candidates.
    """
    builders = (
        ('resnet18', _init_resnet_18),
        ('resnet34', _init_resnet_34),
        ('resnet50', _init_resnet_50),
        ('efficientnet_b1', _init_efficientnet_b1),
    )
    for key, build in builders:
        if key in architecture:
            return build(n_classes, True, features_hook)
    raise NotImplementedError()
    
def create_model(architecture, n_classes, features_hook = None):
    """Build a non-pretrained backbone chosen by substring match on ``architecture``.

    Candidates are checked in order ('resnet18', 'resnet34', 'resnet50',
    'efficientnet_b1'); the first one contained in ``architecture`` selects the
    builder, which is called with ``pretrained=False``.

    Raises:
        NotImplementedError: if ``architecture`` matches none of the candidates.
    """
    for key, builder in (('resnet18', _init_resnet_18),
                         ('resnet34', _init_resnet_34),
                         ('resnet50', _init_resnet_50),
                         ('efficientnet_b1', _init_efficientnet_b1)):
        if key in architecture:
            return builder(n_classes, False, features_hook)
    raise NotImplementedError()

class FeatureExtractor(torch.nn.Module):
    """Wrap a backbone and collect the outputs of its hooked submodules.

    ``feature_hook`` is passed as ``features_hook`` to ``create_model`` /
    ``create_pretrained_model``, so each hooked submodule appends its output
    to ``self._features`` during the forward pass.

    Args:
        architecture: model name. If it contains ``'pretrained'``, the
            pretrained builder is used; otherwise the non-pretrained one.
        n_classes: output size of the replacement classification head.
    """

    def __init__(self, architecture, n_classes = None):
        super().__init__()
        self._features = []
        if 'pretrained' in architecture:
            self.model = create_pretrained_model(
                architecture,
                n_classes,
                features_hook=self.feature_hook)
        else:
            self.model = create_model(
                architecture,
                n_classes,
                features_hook=self.feature_hook)

    def feature_hook(self, module, input, output):
        """Forward hook: record the output of a hooked submodule."""
        self._features.append(output)

    def forward(self, x):
        """Run the backbone and return ``(logits, features)``.

        ``features`` holds only the activations recorded during *this* call,
        in the order the hooked submodules fired. The buffer is reset before
        the pass, so activations left over from an earlier call that was never
        followed by ``clear_features()`` (e.g. an interrupted training step)
        cannot leak into this batch's features.
        """
        self._features = []
        logits = self.model(x)
        return logits, self._features

    def clear_features(self):
        """Drop references to the most recently collected activations."""
        self._features = []

In [4]:
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.optim import Adam, SGD, RMSprop

# One shared preprocessing pipeline for both splits: tensor conversion, then
# per-channel normalisation with the standard ImageNet mean/std.
cifar100_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR100('./data', train=True, download=True,
                                  transform=cifar100_transform)
train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                               shuffle=True, drop_last=True)

test_dataset = datasets.CIFAR100('./data', train=False,
                                 transform=cifar100_transform)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size,
                              shuffle=False, drop_last=False)

num_training_classes = 100

Files already downloaded and verified


In [5]:
import numpy as np
from large_margin import LargeMarginLoss


# Large-margin loss over all training classes (top_k). dist_norm=np.inf is
# presumably the norm used for the margin distance; confirm against
# large_margin.LargeMarginLoss.
lm_settings = dict(gamma=10000, alpha_factor=4, top_k=num_training_classes, dist_norm=np.inf)
lm = LargeMarginLoss(**lm_settings)

# Hook-instrumented backbone; its forward returns (logits, features).
net = FeatureExtractor(net_type, num_training_classes).to(device)

def train_lm(model, train_loader, optimizer, epoch, lm):
    """Train ``model`` for one epoch with the large-margin loss.

    Args:
        model: FeatureExtractor-style module whose forward returns
            ``(logits, features)`` and which exposes ``clear_features()``.
        train_loader: iterable of ``(inputs, target)`` batches with integer
            class-index targets.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only in the progress printout.
        lm: loss callable invoked as ``lm(logits, one_hot_targets, features)``.
    """
    model.train()

    # wandb.watch is meant to be called once per model before training, not on
    # every batch; the attribute guards against re-watching on later epochs.
    if not getattr(model, "_wandb_watched", False):
        wandb.watch(model)
        model._wandb_watched = True

    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs = inputs.to(device)
        optimizer.zero_grad()
        output, features = model(inputs)
        # The loss reads gradients w.r.t. the intermediate feature maps.
        for feature in features:
            feature.retain_grad()

        # One-hot targets sized to the model's output width (was hard-coded to
        # 100) and placed on the logits' device (was a bare .cuda()).
        one_hot = F.one_hot(target.to(output.device), num_classes=output.size(1)).float()

        loss = lm(output, one_hot, features)

        # Log a plain float rather than the graph-attached loss tensor.
        loss_value = loss.item()
        wandb.log({"loss": loss_value})

        loss.backward()
        optimizer.step()
        model.clear_features()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(inputs), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))

import time

# Pass the configured learning rate explicitly so the value recorded in the
# wandb config is the one actually used. Adam's default lr (1e-3) equals
# learning_rate today, so results are unchanged.
optim = Adam(net.parameters(), lr=learning_rate)
for i in range(0, epochs):
    start_time = time.time()
    train_lm(net, train_loader, optim, i, lm)
    end_time = time.time()

    print('Epoch {} took {} seconds to complete'.format(i+1, end_time-start_time))

    test(net, test_loader)

Downloading: "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth" to /nfs/stak/users/noelt/.cache/torch/hub/checkpoints/efficientnet_b1_rwightman-533bc792.pth
77.8%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 1 took 2900.1922855377197 seconds to complete
Test set: Accuracy: 4487/10000 (45%)

Epoch 2 took 2909.7070820331573 seconds to complete
Test set: Accuracy: 5111/10000 (51%)

Epoch 3 took 2919.935966491699 seconds to complete
Test set: Accuracy: 5362/10000 (54%)

Epoch 4 took 2931.5770175457 seconds to complete
Test set: Accuracy: 5560/10000 (56%)

Epoch 5 took 2948.283804178238 seconds to complete
Test set: Accuracy: 5640/10000 (56%)



In [None]:
def train_ce(model, train_loader, optimizer, epoch):
    """Run one epoch of plain cross-entropy training.

    ``model(inputs)`` must return logits directly (no feature tuple). A
    progress line is printed every 100 batches.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss_ce = F.cross_entropy(logits, labels)
        loss_ce.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        progress_pct = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            progress_pct, loss_ce.item()))

# Cross-entropy baseline built with the same factory, backbone and
# num_training_classes-way head as the large-margin run above. A bare
# efficientnet_b1(pretrained=True) keeps its stock ImageNet classifier, so
# argmax could pick labels that never occur in CIFAR-100 and the baseline was
# not comparable to the LM model.
if 'pretrained' in net_type:
    net = create_pretrained_model(net_type, num_training_classes)
else:
    net = create_model(net_type, num_training_classes)
net = net.to(device)

import time

optim = Adam(net.parameters(), lr=learning_rate)
for i in range(0, epochs):
    start_time = time.time()
    train_ce(net, train_loader, optim, i)
    end_time = time.time()
    print('Epoch {} took {} seconds to complete'.format(i+1, end_time-start_time))

    test_ce(net, test_loader)