In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from TinyImageNet import TinyImageNet
from hard_margin_loss import MarginLoss
from margin_loss_soft_logit import MarginLoss as LogitMargin
#32 by 32 image size
#from models.vgg import VGG

#64 by 64 image size
from models.vgg_tiny import VGG

In [None]:
def _to_rgb(image):
    """Return `image` converted to RGB mode via its `.convert` method.

    Assumes the dataset yields objects exposing `.convert` (e.g. PIL images)
    -- TODO confirm against TinyImageNet.
    """
    return image.convert("RGB")


# Geometric augmentations; RandomApply gates the whole list with one
# probability instead of sampling each op independently.
_augmentation_ops = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(64),
]
augmentation = transforms.RandomApply(_augmentation_ops, p=0.8)

# Training pipeline: force RGB, (maybe) augment, then convert to a tensor.
_training_steps = [transforms.Lambda(_to_rgb), augmentation, transforms.ToTensor()]
training_transform = transforms.Compose(_training_steps)

# Validation pipeline: same as training but without augmentation.
_valid_steps = [transforms.Lambda(_to_rgb), transforms.ToTensor()]
valid_transform = transforms.Compose(_valid_steps)

In [None]:
# Dataset directory and the `in_memory` flag forwarded to TinyImageNet.
root = "tiny-imagenet-200"
in_memory = False

# Build the train split first, then the validation split, each paired with
# its own transform pipeline.
training_set, valid_set = [
    TinyImageNet(root, split, transform=split_transform, in_memory=in_memory)
    for split, split_transform in (
        ('train', training_transform),
        ('val', valid_transform),
    )
]

In [None]:
# Number of passes over the training set used by the training cells.
max_epochs = 150

# Everything below runs on the CUDA device.
device = torch.device("cuda")

# Instantiate the 'VGG13' configuration and move its parameters to the GPU
# before the optimizer captures them.
model = VGG('VGG13')
model.to(device)

# Adam at lr 0.01; the learning rate is multiplied by 0.2 at each milestone.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[50, 100],
    gamma=0.2,
)

In [None]:
# Worker / pinned-memory settings shared by both loaders.
_loader_options = {"num_workers": 8, "pin_memory": True}

# Training batches are reshuffled every epoch; validation keeps dataset order.
trainloader = DataLoader(training_set, batch_size=256, shuffle=True, **_loader_options)
validloader = DataLoader(valid_set, batch_size=64, **_loader_options)

In [None]:
# Softmax (cross-entropy) training: `criterion` is the loss used by the
# first training loop below, applied to the model's first output and the
# integer class targets.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
import time
assert torch.cuda.is_available()
# Cross-entropy training stage.
# Each epoch: one optimisation pass over `trainloader`, a top-1 accuracy
# check on `validloader`, a checkpoint save, and finally one LR-schedule step.
try:
    for epoch in range(max_epochs):
        start = time.time()
        epoch_loss = 0.0
        model.train()
        for idx, (data, target) in enumerate(trainloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # The model returns a pair; only the first element (class scores)
            # is used for the loss.
            output, _ = model(data)

            batch_loss = criterion(output, target)

            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

            if idx % 100 == 0:
                print('{:.1f}% of epoch'.format(idx / float(len(trainloader)) * 100), end='\r')

        # evaluate on validation set
        num_hits = 0
        num_instances = len(valid_set)

        with torch.no_grad():
            model.eval()
            for idx, (data, target) in enumerate(validloader):
                data, target = data.to(device), target.to(device)
                output, _ = model(data)
                # Predicted class = index of the largest score along dim 1.
                _, pred = torch.max(output, 1)

                num_hits += (pred == target).sum().item()

        valid_acc = num_hits / num_instances * 100
        print(f' Validation acc: {valid_acc}%')

        # Mean training loss per batch and wall-clock time for this epoch.
        epoch_loss /= float(len(trainloader))
        elapsed = time.time() - start
        print(f' Epoch {epoch + 1}/{max_epochs}: mean train loss {epoch_loss:.4f}, {elapsed:.1f}s')

        # save model (overwritten every epoch, so the file holds the latest weights)
        torch.save(model.state_dict(), 'vgg13_model_CE.pth')

        # Step the schedule only after this epoch's optimizer updates.
        # Stepping at the top of the epoch (before any optimizer.step())
        # skips the first schedule value on PyTorch >= 1.1 and makes the
        # milestones take effect one epoch early.
        lr_scheduler.step()

finally:
    # Release cached GPU memory whether training finished or was interrupted.
    torch.cuda.empty_cache()

In [None]:
# PC loss training stage: start from the weights saved by the
# cross-entropy stage.
ce_checkpoint_path = 'vgg13_model_CE.pth'
model.load_state_dict(torch.load(ce_checkpoint_path))

# Fresh optimizer and schedule so this stage restarts at lr 0.01 with the
# same milestones as the first stage.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[50, 100],
    gamma=0.2,
)

# The two loss terms that the PC-loss training loop combines.
logit_margin = LogitMargin()
c_margin = MarginLoss(margin=0.995)

In [None]:
import time
assert torch.cuda.is_available()
# PC-loss training stage.
# Same epoch structure as the cross-entropy stage, but the objective is the
# hard-margin loss plus a 0.05-weighted logit-margin loss.
try:
    for epoch in range(max_epochs):
        start = time.time()
        epoch_loss = 0.0
        model.train()
        for idx, (data, target) in enumerate(trainloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # The model returns a pair; only the first element (class scores)
            # feeds both loss terms.
            output, _ = model(data)

            m_loss = c_margin(output, target)
            logit_m_loss = logit_margin(output, target)
            batch_loss = m_loss + 0.05 * logit_m_loss

            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

            if idx % 100 == 0:
                print('{:.1f}% of epoch'.format(idx / float(len(trainloader)) * 100), end='\r')

        # evaluate on validation set
        num_hits = 0
        num_instances = len(valid_set)

        with torch.no_grad():
            model.eval()
            for idx, (data, target) in enumerate(validloader):
                data, target = data.to(device), target.to(device)
                output, _ = model(data)
                # Predicted class = index of the largest score along dim 1.
                _, pred = torch.max(output, 1)

                num_hits += (pred == target).sum().item()

        valid_acc = num_hits / num_instances * 100
        print(f' Validation acc: {valid_acc}%')

        # Mean training loss per batch and wall-clock time for this epoch.
        epoch_loss /= float(len(trainloader))
        elapsed = time.time() - start
        print(f' Epoch {epoch + 1}/{max_epochs}: mean train loss {epoch_loss:.4f}, {elapsed:.1f}s')

        # save model (overwritten every epoch, so the file holds the latest weights)
        torch.save(model.state_dict(), 'vgg13_model_PC_loss_995.pth')

        # Step the schedule only after this epoch's optimizer updates.
        # Stepping at the top of the epoch (before any optimizer.step())
        # skips the first schedule value on PyTorch >= 1.1 and makes the
        # milestones take effect one epoch early.
        lr_scheduler.step()

finally:
    # Release cached GPU memory whether training finished or was interrupted.
    torch.cuda.empty_cache()