In [1]:
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LDAMLoss(nn.Module):
    """Label-distribution-aware margin (LDAM) cross-entropy loss.

    Every class ``j`` is assigned a margin proportional to ``n_j ** -0.25``
    (``n_j`` being its sample count), rescaled so that the largest margin
    equals ``max_m``.  In ``forward`` the margin of each sample's target class
    is subtracted from that class's logit only; the resulting logits are
    multiplied by ``s`` and passed to ``F.cross_entropy``.

    Args:
        cls_num_list: 1-D sequence/array with the (strictly positive) number
            of training samples of each class.
        max_m: Margin given to the rarest class.
        weight: Optional per-class weight tensor forwarded to
            ``F.cross_entropy``.
        s: Positive scale applied to the logits before the cross-entropy.

    Raises:
        ValueError: If ``cls_num_list`` is not a non-empty 1-D collection of
            strictly positive counts.
        AssertionError: If ``s`` is not positive.
    """

    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super(LDAMLoss, self).__init__()
        counts = np.asarray(cls_num_list, dtype=np.float64)
        if counts.ndim != 1 or counts.size == 0:
            raise ValueError("cls_num_list must be a non-empty 1-D sequence of class counts")
        # A zero/negative/NaN count would turn the margins into inf/NaN.
        if not np.all(counts > 0):
            raise ValueError("cls_num_list must contain only strictly positive counts")

        m_list = 1.0 / np.sqrt(np.sqrt(counts))
        m_list = m_list * (max_m / np.max(m_list))
        # Created without assuming a GPU; forward() moves the margins to the
        # device of the logits it receives.
        self.m_list = torch.tensor(m_list, dtype=torch.float32)
        assert s > 0
        self.s = s
        self.weight = weight

    def forward(self, x, target):
        """Compute the LDAM loss.

        Args:
            x: Logits, indexed as ``(sample, class)``.
            target: Integer class index of every sample.

        Returns:
            The value returned by ``F.cross_entropy`` on the margin-adjusted,
            ``s``-scaled logits.
        """
        # Keep the margins next to the logits (CPU or whichever GPU is in use).
        if self.m_list.device != x.device:
            self.m_list = self.m_list.to(x.device)

        target_col = target.reshape(-1, 1)

        # Boolean one-hot mask of the target class; torch.where needs a bool
        # condition (uint8 masks are rejected by current PyTorch).
        index = torch.zeros_like(x, dtype=torch.bool)
        index.scatter_(1, target_col, 1)

        # Margin of each sample's target class, as a column vector.
        batch_m = self.m_list[target_col.view(-1)].to(x.dtype).view(-1, 1)
        x_m = x - batch_m

        # Only the target logit is shifted; all other logits stay untouched.
        output = torch.where(index, x_m, x)
        return F.cross_entropy(self.s * output, target, weight=self.weight)

In [3]:
from data.celebA_dataset import CelebADataset
from models import model_attributes
from data.dro_dataset import DRODataset

    
# --- Experiment configuration ---
# Directory that contains the CelebA data. Set the CELEBA_ROOT environment
# variable to use another location; the original path remains the default.
root_dir = os.environ.get('CELEBA_ROOT') or '/home/thiennguyen/research/datasets/celebA/'
target_name = 'Blond_Hair'  # attribute being classified (blond vs. not blond)
confounder_names = ['Male']  # spurious attribute we want the model not to rely on (gender)
model_type = 'resnet10vw'  # passed to the dataset; used there to pick the input size for rescaling
augment_data = False
fraction = 1.0
splits = ['train', 'val', 'test']
# NOTE(review): not referenced anywhere in this cell; presumably the number of
# (label, confounder) groups rather than the number of labels -- confirm.
n_classes = 4

# --- Dataset ---
full_dataset = CelebADataset(
    root_dir=root_dir,
    target_name=target_name,
    confounder_names=confounder_names,
    model_type=model_type,  # selects the model's input size (for resizing) and input type
    augment_data=augment_data,  # when enabled: random resized crop and random flip
)

# Subset objects holding the indices of each split. `train_frac` is the share
# of the training data to keep (random removal when below 1).
subsets = full_dataset.get_splits(
    splits,
    train_frac=fraction,
    subsample_to_minority=False,
)

# Wrap every split separately in a DRODataset.
dro_subsets = [
    DRODataset(
        subsets[split],
        process_item_fn=None,
        n_groups=full_dataset.n_groups,
        n_classes=full_dataset.n_classes,
        group_str_fn=full_dataset.group_str,
    )
    for split in splits
]

# --- Data loaders ---
train_data, val_data, test_data = dro_subsets
train_loader = train_data.get_loader(train=True, reweight_groups=False, batch_size=128)
val_loader = val_data.get_loader(train=False, reweight_groups=None, batch_size=5)
test_loader = test_data.get_loader(train=False, reweight_groups=None, batch_size=128)

In [11]:
# Per-group sample counts of the training split, as a NumPy array.
cls_num_list = train_data.group_counts().numpy()
# LDAM loss whose margins are derived from those counts (defaults otherwise).
l = LDAMLoss(cls_num_list=cls_num_list)

In [13]:
# Show the counts as a float32 tensor.
torch.tensor(cls_num_list, dtype=torch.float32)

tensor([71629., 66874., 22880.,  1387.])