In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import os
import pandas as pd
import copy
from PIL import Image
%matplotlib inline

In [2]:
# SUN397 training list: one '<image_path> <label>' pair per line, space separated, no header.
tr = pd.read_csv('./sun397_train_lt.txt', sep=' ', header=None)

In [3]:
# label id -> human-readable class name; filled by the next cell.
class_names = dict()

In [4]:
# Map every label id in the training list to a class name, taken from the parent
# directory of the first image path that carries that label. One groupby pass
# replaces a full-frame boolean filter per label; sort=False keeps the labels in
# order of first appearance, as iterating tr[1].unique() did.
# NOTE(review): keys are the raw label ids from the txt file, while
# sun_dataset.__getitem__ returns `label - 1`; confirm the labelling convention
# before indexing this mapping with dataset labels or model predictions.
class_names.update(
    tr.groupby(1, sort=False)[0].first().str.split('/').str[-2].to_dict()
)

In [5]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
#  device = torch.device('cpu')

In [44]:
# Paths and hyper-parameters used throughout the notebook.
LOG_DIR = './log'
DATALOADER_WORKERS = 4
LEARNING_RATE = 0.01
MOMENTUM = 0.9
EPOCHS = 60
BATCH_SIZE = 256
DISPLAY_STEP = 1        # print minibatch stats every DISPLAY_STEP training steps
NUM_CLASSES = 397
SEED = 0

# Seed the global torch and numpy RNGs so stochastic steps (e.g. the new head's
# initialisation and DataLoader shuffling) are repeatable across runs. Multi-worker
# loading and cuDNN kernels may still introduce some nondeterminism.
torch.manual_seed(SEED)
np.random.seed(SEED)

os.makedirs(LOG_DIR, exist_ok=True)

In [8]:
class sun_dataset (torch.utils.data.Dataset):
    """Image-classification dataset backed by a space-separated list file.

    Each line of ``txt_file`` is ``<image_path> <label>`` (no header row).
    ``__getitem__`` returns ``(image, label - 1)``, where ``image`` is an RGB PIL
    image, or whatever ``transform`` returns for it when a transform is given.
    """

    def __init__ (self, txt_file, transform=None):
        super().__init__()
        # Column 0: image path, column 1: integer label.
        self.df = pd.read_csv(txt_file, header=None, sep=' ')
        self.transform = transform

    def __len__ (self):
        return len(self.df)

    def __getitem__ (self, idx):
        path = self.df.iloc[idx, 0]
        # Shift labels down by one. Presumably the list files are 1-based, and
        # CrossEntropyLoss needs labels in [0, NUM_CLASSES) — TODO confirm against the data.
        label = self.df.iloc[idx, 1] - 1

        # Force 3-channel RGB so grayscale / RGBA / palette images match the
        # 3-channel Normalize. The context manager closes the file handle that
        # Image.open keeps open until the pixel data is loaded.
        with Image.open(path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [9]:
# Alternative training crop to try: transforms.RandomResizedCrop(224)
# Per-split preprocessing. Normalize uses the standard ImageNet channel mean/std,
# matching the ImageNet-pretrained backbone loaded further down.
# NOTE(review): RandomCrop(224) with no preceding Resize assumes every image is at
# least 224 px on each side — TODO confirm for this dataset.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # val and test share one deterministic recipe (centre crop, no flip), with an
    # independent Compose instance per split.
    **{
        split: transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        for split in ('val', 'test')
    },
}

In [10]:
# One dataset and one loader per split. Only the training split is shuffled: the
# val/test metrics do not depend on order, and a fixed order keeps evaluation repeatable.
# NOTE(review): `datasets` shadows the `torchvision.datasets` import from the first cell.
datasets = {x: sun_dataset(txt_file='./sun397_%s_lt.txt' % x, transform=data_transforms[x]) for x in ['train', 'val', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=BATCH_SIZE, shuffle=(x == 'train'), num_workers=DATALOADER_WORKERS) for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val', 'test']}

In [11]:
# train_set = sun_dataset(txt_file='./sun397_train_lt.txt', transform=data_transforms['train'])
# train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# val_set = sun_dataset(txt_file='./sun397_val_lt.txt', transform=data_transforms['val_test'])
# val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# test_set = sun_dataset(txt_file='./sun397_test_lt.txt', transform=data_transforms['val_test'])
# test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [12]:
# def plot_image (img_tensor, title=None):
    
#     img = img_tensor.numpy().transpose((1, 2, 0))
    
#     mean = np.array([0.485, 0.456, 0.406])
#     std = np.array([0.229, 0.224, 0.225])
#     img = img * std + mean
    
#     img = np.clip(img, 0, 1)
    
#     plt.figure(figsize=(15, 10))
#     plt.imshow(img)
#     if title:
#         plt.title(title)
    

# images, labels = next(iter(dataloaders['train']))

# image_grid = torchvision.utils.make_grid(images)

# plot_image(image_grid, [class_names[l.item()] for l in labels ])

In [13]:
def train_model (model, loss_function, optimizer, scheduler, num_epochs, model_id=None):
    """Train ``model``, keep the weights with the best validation accuracy, and checkpoint them.

    Every epoch runs a 'train' phase followed by a 'val' phase over the notebook-level
    ``dataloaders``. Per-phase averages are normalised by ``dataset_sizes``, and batches
    are moved to ``device``. Minibatch statistics are printed every ``DISPLAY_STEP``
    training steps.

    Args:
        model: network to train. Before returning, its weights are overwritten in
            place with the best-validation-accuracy weights.
        loss_function: callable ``loss_function(logits, labels)`` returning a scalar loss tensor.
        optimizer: optimizer over the parameters that should be updated.
        scheduler: learning-rate scheduler. It is stepped once per epoch, after that
            epoch's training-phase ``optimizer.step()`` calls.
        num_epochs: number of epochs to run. 0 skips training and only saves the
            current weights.
        model_id: tag used in the checkpoint file name ``model_<model_id>_checkpoint.pth.tar``.

    Returns:
        ``model`` (the same object) with the best weights loaded.
    """
    # Snapshot the starting weights so there is always something to restore.
    best_model_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    training_step = 0
    # Tracked separately so the checkpoint is well-defined even when num_epochs == 0.
    epochs_completed = 0
    for epoch in range(num_epochs):

        # Loop over the training phase and the validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_correct = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                # The last batch may be smaller than BATCH_SIZE, so use its real size.
                batch_size = inputs.shape[0]

                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(is_train):
                    logits = model(inputs)
                    _, preds = torch.max(logits, 1)
                    loss = loss_function(logits, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()
                        training_step += 1

                batch_correct = (preds == labels).sum().item()
                running_loss += loss.item() * batch_size
                running_correct += batch_correct

                if is_train and training_step % DISPLAY_STEP == 0:
                    print('Epoch: %d, Step: %5d, Minibatch_loss: %.3f, Minibatch_accuracy: %.3f' % (epoch, training_step, loss.item(), batch_correct / batch_size))

            if is_train:
                # Since PyTorch 1.1 the scheduler must be stepped *after* optimizer.step().
                # Stepping it at the start of the epoch skips the initial learning rate
                # and makes StepLR decay one epoch early.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_correct / dataset_sizes[phase]

            print('Epoch: %d, Phase: %s, Epoch_loss: %.3f, Epoch_accuracy: %.3f' % (epoch, phase, epoch_loss, epoch_acc))

            # Deep copy the best model weights.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_weights = copy.deepcopy(model.state_dict())

        epochs_completed = epoch + 1

    print()
    print('Training Complete.')
    print('Best validation accuracy: %.3f' % best_acc)

    # Load the best model weights.
    model.load_state_dict(best_model_weights)

    # Save the best model. NOTE: 'optimizer' holds the state after the final epoch,
    # which may come from a later epoch than the saved best weights.
    model_states = {'epoch': epochs_completed,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer' : optimizer.state_dict()}

    torch.save(model_states, 'model_%s_checkpoint.pth.tar' % model_id)

    return model
        

In [45]:
# Load the pretrained ResNet-152 backbone, then freeze every existing parameter so
# that only the replacement head created in the next cell receives gradients.
resnet = torchvision.models.resnet152(pretrained=True)
resnet.requires_grad_(False)

In [46]:
# Swap the backbone's final fully connected layer for a freshly initialised
# NUM_CLASSES-way linear head (new parameters are trainable by default).
num_features = resnet.fc.in_features
resnet.fc = nn.Linear(in_features=num_features, out_features=NUM_CLASSES)

In [47]:
# Move all parameters and buffers to the selected compute device.
resnet = resnet.to(device=device)

In [48]:
# Cross-entropy over the NUM_CLASSES logits.
loss_function = nn.CrossEntropyLoss()

# SGD with momentum over the new head only; the backbone parameters were frozen earlier.
optimizer = optim.SGD(params=resnet.fc.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# Multiply the learning rate by 0.1 every 30 scheduler steps (train_model steps it once per epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=30, gamma=0.1)

In [56]:
# resnet = train_model(model=resnet, loss_function=loss_function, optimizer=optimizer, scheduler=exp_lr_scheduler, num_epochs=EPOCHS, model_id='plain')

In [86]:
# Scratch check: a list of ten float zeros.
[0.0] * 10

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [None]:
def macro_acc (num_classes, dataloader, model, device=None, class_names=None, verbose=True):
    """Report per-class accuracy of ``model`` on ``dataloader`` and return the macro average.

    Args:
        num_classes: number of classes. Labels must be integers in ``[0, num_classes)``.
        dataloader: iterable of ``(images, labels)`` batches.
        model: network whose output is assumed to be ``(batch, num_classes)`` logits;
            the prediction is the argmax over dim 1.
        device: device to run inference on. Defaults to the device of the model's
            first parameter, or the CPU for a parameter-free model.
        class_names: optional indexable (list or dict) mapping a class index to a
            display name. Defaults to the index itself.
        verbose: if True, print one accuracy line per class that has samples.

    Returns:
        Mean of the per-class accuracies over classes with at least one sample
        (0.0 when no samples were seen).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    # Evaluate in eval mode (dropout/batch-norm), then restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                preds = model(images.to(device)).argmax(dim=1).cpu()
                labels = labels.reshape(-1).cpu().long()
                # Vectorised per-class tallies for the whole batch, whatever its size.
                class_total += torch.bincount(labels, minlength=num_classes)
                class_correct += torch.bincount(labels[preds == labels], minlength=num_classes)
    finally:
        model.train(was_training)

    accuracies = []
    for cls in range(num_classes):
        total = int(class_total[cls])
        if total == 0:
            # Class absent from this split: skip it instead of dividing by zero.
            continue
        acc = int(class_correct[cls]) / total
        accuracies.append(acc)
        if verbose:
            name = class_names[cls] if class_names is not None else cls
            print('Accuracy of %5s : %2d %%' % (name, 100 * acc))

    return sum(accuracies) / len(accuracies) if accuracies else 0.0



In [57]:
# Training list that also references augmented images; same '<path> <label>' layout, no header.
aug = pd.read_csv('./sun397_train_lt_with_aug.txt', sep=' ', header=None)

In [59]:
# Per-label sample counts of the augmented list; filled by the next cell.
num = list()

In [62]:
# Per-label sample counts, listed in the order each label first appears in the file.
# One value_counts pass replaces a full-frame boolean filter per label. Assigning
# (rather than appending) keeps the cell safe to re-run.
num = aug[1].value_counts().reindex(aug[1].unique()).tolist()

In [65]:
# NOTE(review): not referenced by any later cell in this notebook.
fail = list()

In [76]:
# Inspect the row index of the augmented list (output: a 0-based RangeIndex, one row per listed image).
aug.index

RangeIndex(start=0, stop=41333, step=1)

In [78]:
# Peek at the first listed image path.
# NOTE(review): the output shows absolute paths from the original machine, so the list is not portable as-is.
aug.iloc[0, 0]

'/home/zhmiao/low-shot/SUN397_250/a/abbey/sun_aflnmqnusoeqyzse.png'

In [80]:
# Repair listed paths that do not exist on disk by inserting a '/' before the last
# occurrence of 'aug' in the path. Fail loudly if a path cannot be repaired; a bare
# assert would be stripped under `python -O`.
for ind in range(len(aug)):

    path = aug.iloc[ind, 0]
    if os.path.isfile(path):
        continue

    head, sep, tail = path.rpartition('aug')
    if not sep:
        raise FileNotFoundError('Image not found and no "aug" component to repair: %r' % path)

    fixed = head + '/aug' + tail
    if not os.path.isfile(fixed):
        raise FileNotFoundError('Image not found: %r (also tried %r)' % (path, fixed))

    aug.iloc[ind, 0] = fixed

In [81]:
# Persist the repaired list in the same headerless, index-free, space-separated layout.
aug.to_csv('./sun397_train_lt_with_aug_new.txt', index=None, header=None, sep=' ')

In [82]:
# Reload the repaired list written by the previous cell.
aug = pd.read_csv('./sun397_train_lt_with_aug_new.txt', sep=' ', header=None)

In [84]:
# Sanity check: print every listed path that still does not exist on disk.
for path in aug[0]:
    if os.path.isfile(path):
        continue
    print(path)