In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import os
import pandas as pd
import copy
import argparse
from PIL import Image
from util import *

In [3]:
# Experiment configuration constants, RNG seeding and device selection.
import random

LOG_DIR = './log'
DATASET = 'plain'
MODEL_ID = 'plain_label_shuffling_test'
DATALOADER_WORKERS = 4
LEARNING_RATE = 0.01
LR_DECAY_FACTOR = 0.1
LR_DECAY_EPOCHS = 30
DROPOUT = True
DROPOUT_RATE = 0.5
FC_ADD_DIM = 4096
MOMENTUM = 0.9
EPOCHS = 100
BATCH_SIZE = 256
DISPLAY_STEP = 10
NUM_CLASSES = 397
CAFFE_WEIGHTS = True
BIAS = True
# Seed for Python's `random` (used by the sampler shuffles), NumPy and torch.
SEED = 0

# exist_ok=True avoids the check-then-create race of isdir() + makedirs().
os.makedirs(LOG_DIR, exist_ok=True)

# Seed every RNG the notebook touches so sampling/training is reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
class sun_dataset (torch.utils.data.Dataset):
    """Image dataset read from a headerless, space-separated list file.

    Column 0 of each line is an image path and column 1 its label. The label is
    shifted down by one (the list file presumably uses 1-based labels — confirm).

    Args:
        txt_file: path of the list file.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__ (self, txt_file, transform=None):
        super().__init__()
        self.df = pd.read_csv(txt_file, header=None, sep=' ')
        self.transform = transform
        print('Loading from %s' % txt_file)

    def __len__ (self):
        return len(self.df)

    def __getitem__ (self, idx):
        """Return ``(image, label)`` for row position ``idx``."""
        path = self.df.iloc[idx, 0]
        label = self.df.iloc[idx, 1] - 1

        # Image.open is lazy and keeps the file open until the image is closed;
        # the context manager releases the handle right after decoding. Forcing
        # RGB gives grayscale/palette/CMYK files the same 3 channels as the rest.
        with Image.open(path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


In [140]:
class dataset_test (torch.utils.data.Dataset):
    """List-file dataset that yields the raw path string instead of an image.

    Parses the file exactly like ``sun_dataset`` (headerless, space-separated:
    column 0 = path, column 1 = label) and shifts the label down by one, which
    makes it convenient for inspecting what a sampler draws without touching
    image files.
    """

    def __init__ (self, txt_file):
        super().__init__()
        frame = pd.read_csv(txt_file, header=None, sep=' ')
        self.df = frame
        print('Loading from %s' % txt_file)

    def __len__ (self):
        return self.df.shape[0]

    def __getitem__ (self, idx):
        path, raw_label = self.df.iloc[idx, 0], self.df.iloc[idx, 1]
        return path, raw_label - 1

In [159]:
# partly implemented by frombeijingwithlove

from torch.utils.data.sampler import Sampler
import random

class RandomCycleIter:
    """Endless iterator over ``data`` that reshuffles at the start of every pass.

    Args:
        data: iterable of items; copied into a list once.
        rng: optional object with a ``shuffle`` method (e.g. a seeded
            ``random.Random``). Defaults to the global ``random`` module, which
            preserves the previous behaviour.

    ``next()`` raises ``StopIteration`` when ``data`` is empty (like
    ``itertools.cycle([])``) rather than failing with an ``IndexError``.
    """

    def __init__ (self, data, rng=None):
        self.data_list = list(data)
        self.length = len(self.data_list)
        # Start on the last slot so the very first __next__ wraps around and
        # shuffles before yielding anything.
        self.i = self.length - 1
        self.rng = random if rng is None else rng

    def __iter__ (self):
        return self

    def __next__ (self):
        if self.length == 0:
            raise StopIteration

        self.i += 1

        if self.i == self.length:
            self.i = 0
            self.rng.shuffle(self.data_list)

        return self.data_list[self.i]
    
def class_aware_sample_generator (cls_iter, data_iter_list, n):
    """Yield ``n`` items: at each step take the next class id from ``cls_iter``
    and emit the next element of that class's iterator in ``data_iter_list``.
    """
    remaining = n
    while remaining > 0:
        class_id = next(cls_iter)
        yield next(data_iter_list[class_id])
        remaining -= 1
        
class ClassAwareSampler (Sampler):
    """Class-balanced sampler: cycles through classes (reshuffling the class
    order each pass) and draws one row index from the current class per step.

    Args:
        data_source: dataset exposing a ``df`` DataFrame whose column at
            position 1 holds the label (as ``sun_dataset``/``dataset_test`` do).
        num_classes: number of label values; labels must lie in
            ``1..num_classes`` and are shifted down by one like the datasets do.
        num_samples: indices yielded per ``__iter__``. ``0`` (default) means
            ``(size of the largest class) * (number of non-empty classes)``.

    Raises:
        ValueError: if a label lies outside ``1..num_classes``.
    """

    def __init__ (self, data_source, num_classes, num_samples=0):

        self.data_source = data_source
        class_data_list = [[] for _ in range(num_classes)]

        # Collect positional row indices per class (the datasets look rows up
        # with iloc). Reading the label column directly avoids iterrows(),
        # which builds a Series for every row.
        for position, label in enumerate(self.data_source.df.iloc[:, 1]):
            if not 1 <= label <= num_classes:
                raise ValueError(
                    'label %r at row %d is outside 1..%d' % (label, position, num_classes))
            class_data_list[label - 1].append(position)

        # Only cycle over classes that have samples; drawing from an empty class
        # would otherwise fail part-way through an epoch.
        present_classes = [c for c, idxs in enumerate(class_data_list) if idxs]
        self.class_iter = RandomCycleIter(present_classes)
        self.data_iter_list = [RandomCycleIter(x) for x in class_data_list]

        if not present_classes:
            # Nothing to draw from.
            self.num_samples = 0
        elif num_samples > 0:
            self.num_samples = num_samples
        else:
            largest = max(len(x) for x in class_data_list)
            self.num_samples = largest * len(present_classes)

    def __iter__ (self):
        return class_aware_sample_generator(self.class_iter, self.data_iter_list, self.num_samples)

    def __len__ (self):
        return self.num_samples
        
        
        

In [185]:
# Small list file used to sanity-check the class-aware sampler in the next cells.
dataset = dataset_test('./small_test.txt')

Loading from ./small_test.txt


In [187]:
# Batches of 3 drawn class-by-class; small_test.txt holds 4 classes.
# shuffle must stay False because a custom sampler is supplied.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=False,
    sampler=ClassAwareSampler(dataset, 4),
)

In [188]:
# Each batch: a tuple of path strings and a tensor of their 0-based labels.
for batch_paths, batch_labels in dataloader:
    print(batch_paths, batch_labels)

('1-4', '4-1', '2-4') tensor([ 0,  3,  1])
('3-2', '1-1', '3-1') tensor([ 2,  0,  2])
('2-1', '4-2', '2-2') tensor([ 1,  3,  1])
('4-1', '3-1', '1-6') tensor([ 3,  2,  0])
('2-3', '4-1', '3-2') tensor([ 1,  3,  2])
('1-7', '3-2', '2-5') tensor([ 0,  2,  1])
('1-8', '4-1', '2-5') tensor([ 0,  3,  1])
('1-2', '3-1', '4-2') tensor([ 0,  2,  3])
('3-2', '4-2', '1-9') tensor([ 2,  3,  0])
('2-3', '1-5', '4-1') tensor([ 1,  0,  3])
('2-4', '3-1', '4-1') tensor([ 1,  2,  3])
('3-2', '1-3', '2-1') tensor([ 2,  0,  1])


In [189]:
# NOTE(review): scratch value not used by the dataset/sampler cells above — candidate for removal.
a = dict(a=1, b=2)

In [190]:
# NOTE(review): scratch cell that only displays `a` from the previous cell; unrelated to the sampler check — candidate for removal.
a

{'a': 1, 'b': 2}