In [None]:
import os
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import subprocess
import torch
import shutil
from PIL import Image
from pdb import set_trace as st
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter
import glob
import tqdm
import shutil

In [36]:
def file_reading_func(root="/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200/val/val_annotations.txt"):
    # Reads an annotation file and sorts it in the form of dictionary
    # file_name -> classname 
    f = open(root, "r")
    out_dict = {}
    for x in f:
        cur_line = x.rstrip('\n').split('\t')
        out_dict[cur_line[0]] = cur_line[1]
    return out_dict


def create_dir(dirpath, print_description=""):
    # Function to create directory
    # Checks if path exists
    try:
        if not os.path.exists(dirpath):
            os.makedirs(dirpath, mode=0o750)
    except Exception as e:
        print(e)
        print("ERROR IN CREATING ", print_description, " PATH:", dirpath)
        
def reorganize_files(new_root="/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200/new_val", 
                     old_root="/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200/val", 
                     file_reading_func=file_reading_func):
    # Function that creates a new folder for imgfolder format, used for validation set in tiny imageNet
    # new_root should be an empty path where the new folder is created
    # old_root is where the annotated txt file and the images are stored
    # file_reading_func processes
    # val/classname/*.jpeg
    # To save space, does not copy but moves the files 
    # Creates new images in the format *class_name*_*sample_number*.JPEG
    ann_root = os.path.join(old_root, "val_annotations.txt")
    file_dict = file_reading_func(ann_root)
    create_dir(new_root)

    file_count = {}
    for path, currentDirectory, files in os.walk(old_root):
        for file in files:
            if file.split(".")[-1] == "JPEG":
                # Only move JPEG
                current_path = os.path.join(path, file)
                class_name = file_dict[file]
                try:
                    current_count = file_count[class_name]
                except:
                    file_count[class_name] = 0
                    current_count = 0
                    create_dir(os.path.join(new_root, class_name))
                new_path = os.path.join(new_root, class_name)
                new_img_name = str(class_name)+"_"+str(current_count)+".JPEG"
                new_img_path = os.path.join(new_path, new_img_name)
                file_count[class_name] += 1
                shutil.move(current_path, new_img_path)
            
    print("Finished processing all images")

reorganize_files()

Finished processing all images


In [18]:
for path, currentDirectory, files in os.walk("/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200/val"):
    for file in files:
        print(file)

val_annotations.txt
val_5461.JPEG
val_9927.JPEG
val_1195.JPEG
val_6456.JPEG
val_6677.JPEG
val_4179.JPEG
val_3230.JPEG
val_2585.JPEG
val_1432.JPEG
val_2681.JPEG
val_5927.JPEG
val_9966.JPEG
val_5411.JPEG
val_733.JPEG
val_3633.JPEG
val_3471.JPEG
val_7963.JPEG
val_3619.JPEG
val_2667.JPEG
val_6894.JPEG
val_15.JPEG
val_1603.JPEG
val_6835.JPEG
val_2264.JPEG
val_8198.JPEG
val_7760.JPEG
val_6330.JPEG
val_7871.JPEG
val_6996.JPEG
val_3507.JPEG
val_7509.JPEG
val_7321.JPEG
val_6820.JPEG
val_7652.JPEG
val_2340.JPEG
val_3496.JPEG
val_1382.JPEG
val_823.JPEG
val_337.JPEG
val_6458.JPEG
val_4028.JPEG
val_6191.JPEG
val_4280.JPEG
val_7912.JPEG
val_1645.JPEG
val_5962.JPEG
val_8920.JPEG
val_9026.JPEG
val_3487.JPEG
val_3239.JPEG
val_7486.JPEG
val_1055.JPEG
val_6929.JPEG
val_2460.JPEG
val_8245.JPEG
val_590.JPEG
val_3252.JPEG
val_3118.JPEG
val_5322.JPEG
val_795.JPEG
val_792.JPEG
val_4804.JPEG
val_9106.JPEG
val_1245.JPEG
val_4615.JPEG
val_2304.JPEG
val_5767.JPEG
val_2210.JPEG
val_6702.JPEG
val_4186.JPEG
val_4916

In [83]:
def create_tasks(num_tasks=10, path="/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200", seed=0):
    # This function creates the corresponding tasks dictionary which maps class_name: task_id
     
    # Get the file path that stores the class names
    file_path = os.path.join(root_path, "wnids.txt")
    
    # Get the class names as a list
    lines = [line.rstrip('\n') for line in open(file_path)]
    
    # Make sure the tasks are correctly assigned
    nb_classes_task = len(lines) // num_tasks
    print("Split "+str(len(lines))+" classes into "+ str(num_tasks)+" tasks with "+str(nb_classes_task)+" classes per task")
    assert len(lines) % num_tasks == 0, "total "+str(len(lines))+" classes must be divisible by nb classes per task"
    
    # Create a dictionary to return in this format: task_number -> class_name
    out_dict = {}
    for i in range(num_tasks):
        out_dict[i] = lines[i*num_tasks:(i+1)*num_tasks]
        
    return out_dict

In [120]:
class tinyImageNet(torchvision.datasets.ImageFolder):
    # Class that inherits from imagefolder for dataloading purposes
    # Applies basic transforms and accepts a subset of classes for training/testing
    # Modifies the find_classes function from torch's Dataset class
    # Remember in defying paper, train is split 80:20 to train/val; val is for testing
    def __init__(self, root="/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200/train", 
                 transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]), 
                 target_transform=None, subset=None):
        # Subset stores a dictionary of class names to its label
        self.subset = subset
        super().__init__(root=root, transform=transform, target_transform=target_transform)
        self.all_classes = torchvision.datasets.folder.find_classes(root)
        
    def find_classes(self, path):
        if self.subset:
        # If using a subset of all classes, then only use those classes
            return self.subset.keys(), self.subset 
        else:
        # Else, use torch's generalized class function
            return torchvision.datasets.folder.find_classes(path)
    

In [121]:
def create_tasks(num_tasks=10, path="/root/yifei/CL_mod/src/dataset/tiny-imagenet-200/tiny-imagenet-200"):
    # Splitting into n classes creates a n-length array, where each item in the array consists of
    # a dictionary that maps class_labels to class_id (ie 0)
     
    # Get the file path that stores the class names
    file_path = os.path.join(root_path, "wnids.txt")
    
    # Get the class names as a list
    lines = [line.rstrip('\n') for line in open(file_path)]
    
    # Make sure the tasks are correctly assigned
    nb_classes_task = len(lines) // num_tasks
    print("Split "+str(len(lines))+" classes into "+ str(num_tasks)+" tasks with "+str(nb_classes_task)+" classes per task")
    assert len(lines) % num_tasks == 0, "total "+str(len(lines))+" classes must be divisible by nb classes per task"
    
    # Create a dictionary to return in this format: task_number -> class_name
    outputs = []
    current_id = 0
    for i in range(num_tasks):
        class_lbl = lines[i*num_tasks:(i+1)*num_tasks]  
        task_dict = {} 
        for j in range(i*num_tasks, (i+1)*num_tasks):
            task_dict[lines[j]] = current_id
            current_id += 1
        outputs.append(task_dict)
        
    return outputs

In [122]:
f = create_tasks()

Split 200 classes into 10 tasks with 20 classes per task


In [123]:
e = tinyImageNet(subset=f[1])
train_loader = torch.utils.data.DataLoader(e, batch_size=64, shuffle=True)

In [129]:
Counter(e.targets)

Counter({11: 500,
         14: 500,
         13: 500,
         17: 500,
         12: 500,
         19: 500,
         10: 500,
         15: 500,
         18: 500,
         16: 500})