In [7]:
"""
General utils for training, evaluation and data loading
"""
import os
import torch
import pickle
import numpy as np
import torchvision.transforms as transforms

from PIL import Image

from torch.utils.data import BatchSampler
from torch.utils.data import Dataset, DataLoader

# General dataset settings
BASE_DIR = ''        # prefix joined onto the pkl path by find_class_imbalance
N_ATTRIBUTES = 312   # rows of the one-hot attribute matrix built by CUBDataset
N_CLASSES = 200      # number of CUB classes

# Training hyper-parameters (not referenced by the cells shown in this notebook)
UPWEIGHT_RATIO = 9.0
MIN_LR = 1e-4
LR_DECAY_SIZE = 0.1

class CUBDataset(Dataset):
    """
    Returns a compatible Torch Dataset object customized for the CUB dataset
    """

    def __init__(self, pkl_file_paths, use_attr, no_img, uncertain_label, image_dir, n_class_attr, transform=None):
        """
        Arguments:
        pkl_file_paths: list of full paths to the pkl data files. If none of them contains
            "train", at least one must contain "test" or "val" (asserted).
        use_attr: whether to load the attributes (e.g. False for simple finetune)
        no_img: when True together with use_attr, return only (attr_label, class_label)
            and never open the image (e.g. True for an A -> Y model)
        uncertain_label: if True, use 'uncertain_attribute_label' field (i.e. label weighted by
            uncertainty score, e.g. 1 & 3(probably) -> 0.75), else 'attribute_label'
        image_dir: default = 'images'. The stored img_path is cut at its 'CUB_200_2011'
            component. With 'images' the path is kept from that component on. Otherwise the
            remainder is re-rooted under image_dir and every 'images/' substring is removed.
        n_class_attr: number of classes to predict for each attribute. If 3, then make a
            separate class for not visible (attribute labels are one-hot encoded)
        transform: optional callable applied to the PIL image. Default = None, i.e. the image
            is returned untransformed
        """
        self.data = []
        self.is_train = any("train" in path for path in pkl_file_paths)
        if not self.is_train:
            assert any(("test" in path) or ("val" in path) for path in pkl_file_paths)
        for file_path in pkl_file_paths:
            # NOTE: pickle can execute arbitrary code on load; only use trusted pkl files.
            with open(file_path, 'rb') as f:
                self.data.extend(pickle.load(f))
        self.transform = transform
        self.use_attr = use_attr
        self.no_img = no_img
        self.uncertain_label = uncertain_label
        self.image_dir = image_dir
        self.n_class_attr = n_class_attr

    def __len__(self):
        return len(self.data)

    def _load_image(self, img_path):
        """Open `img_path` as an RGB PIL image, rewriting the path as described for image_dir.

        Falls back to the stored path unchanged when it has no 'CUB_200_2011' component
        (ValueError) or the rewritten file cannot be opened/decoded (OSError).
        """
        parts = img_path.split('/')
        try:
            cub_idx = parts.index('CUB_200_2011')
            if self.image_dir != 'images':
                trimmed = '/'.join([self.image_dir] + parts[cub_idx + 1:])
                trimmed = trimmed.replace('images/', '')
            else:
                trimmed = '/'.join(parts[cub_idx:])
            return Image.open(trimmed).convert('RGB')
        except (ValueError, OSError):
            return Image.open(img_path).convert('RGB')

    def __getitem__(self, idx):
        """Return sample `idx`; the tuple layout depends on the flags:

        - use_attr and no_img: (attr_label, class_label); attr_label is an
          (N_ATTRIBUTES, 3) one-hot array when n_class_attr == 3
        - use_attr: (img, class_label, attr_label)
        - otherwise: (img, class_label)
        """
        img_data = self.data[idx]
        class_label = img_data['class_label']

        if self.use_attr:
            key = 'uncertain_attribute_label' if self.uncertain_label else 'attribute_label'
            attr_label = img_data[key]
            if self.no_img:
                # The image is never returned on this path, so it is not opened at all.
                if self.n_class_attr == 3:
                    one_hot_attr_label = np.zeros((N_ATTRIBUTES, self.n_class_attr))
                    one_hot_attr_label[np.arange(N_ATTRIBUTES), attr_label] = 1
                    return one_hot_attr_label, class_label
                return attr_label, class_label

        img = self._load_image(img_data['img_path'])
        if self.transform:
            img = self.transform(img)

        if self.use_attr:
            return img, class_label, attr_label
        return img, class_label


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset.

    Each index is drawn with replacement, with probability proportional to
    1 / (number of indices sharing its label), where the label is the first
    attribute label (see `_get_label`).

    Arguments:
        dataset: object supporting `len(dataset)` and `dataset.data[idx]['attribute_label']`
        indices (list, optional): a list of indices; defaults to every index of `dataset`
        num_samples (int, optional): number of samples to draw per iteration;
            defaults to `len(indices)`
    """

    def __init__(self, dataset, indices=None, num_samples=None):
        from collections import Counter

        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Look each label up once, then weight every sample by the inverse frequency of its label.
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = Counter(labels)
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):  # Note: for single attribute dataset
        return dataset.data[idx]['attribute_label'][0]

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return self.num_samples

def load_data(pkl_paths, use_attr, no_img, batch_size, uncertain_label=False, n_class_attr=2, image_dir='images', resampling=False, resol=299):
    """
    Build a DataLoader over the CUB pkl files in `pkl_paths`.

    Training mode is selected when any path contains 'train.pkl'.
    - In training mode, images get color jitter, a random resized crop to `resol`
      and a random horizontal flip. Batches are shuffled and the last incomplete
      batch is dropped.
    - Otherwise, images are center-cropped to `resol` with no prior resize. Order
      is kept and the last batch is kept.

    If `resampling` is True, batches come from ImbalancedDatasetSampler, which
    upsamples the minority class. Use this when there is class imbalance and a
    weighted loss is not used.
    NOTE: resampling is customized for first attribute only, so change the sampler if necessary.

    NOTE(review): ToTensor scales pixels to [0, 1], and Normalize(mean=0.5, std=2)
    then maps them to [-0.25, 0.25]. That is NOT the [-1, 1] range Inception
    expects; std=0.5 would give [-1, 1]. ImageNet-pretrained torchvision models
    instead expect mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]. The
    values are kept unchanged so models already trained with this loader see
    identical preprocessing.

    Returns: torch.utils.data.DataLoader
    """
    is_training = any('train.pkl' in f for f in pkl_paths)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[2, 2, 2])
    if is_training:
        transform = transforms.Compose([
            transforms.ColorJitter(brightness=32/255, saturation=(0.5, 1.5)),
            transforms.RandomResizedCrop(resol),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # implicitly divides by 255
            normalize,
        ])
    else:
        transform = transforms.Compose([
            transforms.CenterCrop(resol),
            transforms.ToTensor(),  # implicitly divides by 255
            normalize,
        ])

    dataset = CUBDataset(pkl_paths, use_attr, no_img, uncertain_label, image_dir, n_class_attr, transform)
    # Training shuffles and drops the ragged final batch; evaluation keeps every sample in order.
    drop_last = shuffle = is_training
    if resampling:
        # The sampler randomizes order itself, so `shuffle` is not passed here.
        sampler = BatchSampler(ImbalancedDatasetSampler(dataset), batch_size=batch_size, drop_last=drop_last)
        return DataLoader(dataset, batch_sampler=sampler)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

def find_class_imbalance(pkl_file, multiple_attr=False, attr_idx=-1):
    """
    Calculate class imbalance ratio (#negatives / #positives) for binary attribute labels stored in pkl_file.

    The file is read from os.path.join(BASE_DIR, pkl_file). It holds a list of
    records, each with an 'attribute_label' sequence.
    If attr_idx >= 0, only that attribute is considered (in both modes) and a
    one-element list is returned.
    If multiple_attr is True, the ratio is computed separately for each attribute.
    Otherwise one overall ratio is computed across all attributes and repeated once
    per attribute.

    Raises ZeroDivisionError if the considered labels contain no positives.
    """
    # NOTE: pickle can execute arbitrary code on load; only use trusted pkl files.
    with open(os.path.join(BASE_DIR, pkl_file), 'rb') as f:
        data = pickle.load(f)
    n = len(data)
    n_attr = len(data[0]['attribute_label'])
    if attr_idx >= 0:
        n_attr = 1
    if multiple_attr:
        n_ones = [0] * n_attr
        total = [n] * n_attr
    else:
        n_ones = [0]
        total = [n * n_attr]
    for d in data:
        labels = d['attribute_label']
        if attr_idx >= 0:
            # A single selected attribute: count it in either mode.
            n_ones[0] += labels[attr_idx]
        elif multiple_attr:
            for i in range(n_attr):
                n_ones[i] += labels[i]
        else:
            n_ones[0] += sum(labels)
    # total / positives - 1 == negatives / positives
    imbalance_ratio = [t / ones - 1 for t, ones in zip(total, n_ones)]
    if not multiple_attr:  # e.g. [9.0] --> [9.0] * 312
        imbalance_ratio *= n_attr
    return imbalance_ratio

In [2]:
!pip install gdown

Collecting gdown
  Downloading gdown-4.7.1-py3-none-any.whl (15 kB)
Installing collected packages: gdown
Successfully installed gdown-4.7.1
[0m

In [3]:
!gdown 1DUkovCVCUqYScle624llHyeUuC7t0SaN
!gdown 1oRBQ7WY_9-qfWkxNC7ZGBW27wecfEy6f
!gdown 1EsYvhX6aRDELjoKpcRo3fuLo-aKF6DI9
!gdown 1ir5HukW2XO25GWqx3jV-gKZWWi6Zg_rw

Downloading...
From (uriginal): https://drive.google.com/uc?id=1DUkovCVCUqYScle624llHyeUuC7t0SaN
From (redirected): https://drive.google.com/uc?id=1DUkovCVCUqYScle624llHyeUuC7t0SaN&confirm=t&uuid=8d901508-fabc-48f9-8eb2-5ca2509ee65c
To: /kaggle/working/images.zip
100%|██████████████████████████████████████| 1.13G/1.13G [00:15<00:00, 72.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1oRBQ7WY_9-qfWkxNC7ZGBW27wecfEy6f
To: /kaggle/working/kaggle_train.pkl
100%|█████████████████████████████████████████| 497k/497k [00:00<00:00, 114MB/s]
Downloading...
From: https://drive.google.com/uc?id=1EsYvhX6aRDELjoKpcRo3fuLo-aKF6DI9
To: /kaggle/working/kaggle_test.pkl
100%|████████████████████████████████████████| 601k/601k [00:00<00:00, 90.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1ir5HukW2XO25GWqx3jV-gKZWWi6Zg_rw
To: /kaggle/working/kaggle_val.pkl
100%|████████████████████████████████████████| 124k/124k [00:00<00:00, 63.0MB/s]


In [4]:
import zipfile

# Unpack the downloaded image archive into the current working directory.
archive_path = './images.zip'
with zipfile.ZipFile(archive_path, "r") as archive:
    archive.extractall()

Dataset loader from the bottleneck concept script

In [8]:


train_data_path = "/kaggle/working/kaggle_train.pkl"
val_data_path = "/kaggle/working/kaggle_val.pkl"

# Both loaders share every option; only the pkl file differs.
# use_attr=False, so n_class_attr is never consulted by the dataset.
_loader_options = dict(
    use_attr=False,
    no_img=False,
    batch_size=64,
    uncertain_label=False,
    image_dir="",
    n_class_attr=0,
)
train_loader = load_data([train_data_path], **_loader_options)
val_loader = load_data([val_data_path], **_loader_options)

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from sklearn.metrics import accuracy_score

# Load the pre-trained ResNet18 model
model = models.resnet18(weights='DEFAULT')

# Replace the last layer so the network predicts one of the CUB classes.
num_classes = N_CLASSES
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Define the device to use for computation (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model once, before the optimizer is built over its parameters.
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define the number of epochs to train for
num_epochs = 50


def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Make one pass over `loader`: train when `optimizer` is given, otherwise evaluate.

    Returns (loss, accuracy), both averaged over samples. Returns (nan, nan) if
    the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_correct, n_seen = 0.0, 0, 0
    # Gradients are only tracked while training.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            # normalize=False returns the number of correct predictions (y_true first).
            n_correct += accuracy_score(labels.cpu().numpy(), preds.cpu().numpy(), normalize=False)
            n_seen += batch_size
    if n_seen == 0:
        return float('nan'), float('nan')
    return total_loss / n_seen, n_correct / n_seen


# Train the model
for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, loss_fn, device, optimizer)
    val_loss, val_acc = _run_epoch(model, val_loader, loss_fn, device)

    # Print the results for this epoch
    print(f"Epoch {epoch + 1}/{num_epochs}: "
          f"train_loss={train_loss:.4f} "
          f"train_acc={train_acc:.4f} "
          f"val_loss={val_loss:.4f} "
          f"val_acc={val_acc:.4f}")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

Epoch 1/50: train_loss=4.8614 train_acc=0.0490 val_loss=4.4651 val_acc=0.0576
Epoch 2/50: train_loss=3.8086 train_acc=0.1347 val_loss=3.5806 val_acc=0.1628
Epoch 3/50: train_loss=3.2003 train_acc=0.2329 val_loss=4.1831 val_acc=0.1469
Epoch 4/50: train_loss=2.8089 train_acc=0.3095 val_loss=2.8962 val_acc=0.3005
Epoch 5/50: train_loss=2.4553 train_acc=0.3872 val_loss=2.4875 val_acc=0.3623
Epoch 6/50: train_loss=2.2043 train_acc=0.4316 val_loss=2.1558 val_acc=0.4232
Epoch 7/50: train_loss=2.0520 train_acc=0.4797 val_loss=2.2456 val_acc=0.4407
Epoch 8/50: train_loss=1.8653 train_acc=0.5154 val_loss=1.8994 val_acc=0.4858
Epoch 9/50: train_loss=1.7281 train_acc=0.5515 val_loss=2.3273 val_acc=0.4207
Epoch 10/50: train_loss=1.6099 train_acc=0.5834 val_loss=1.8611 val_acc=0.5159
Epoch 11/50: train_loss=1.4773 train_acc=0.6128 val_loss=1.7721 val_acc=0.5551
Epoch 12/50: train_loss=1.3985 train_acc=0.6313 val_loss=1.4822 val_acc=0.5860
Epoch 13/50: train_loss=1.3269 train_acc=0.6569 val_loss=1.71

In [10]:
# Persist only the learned weights (the state_dict), not the pickled module object.
checkpoint_path = 'resnet18_trained.pth'
torch.save(model.state_dict(), checkpoint_path)

Saliency maps with a different data loader