Commit
extra experiment for omniglot and cross_char
wyharveychen committed Feb 1, 2019
1 parent 28a219a commit 31a75f3
Showing 12 changed files with 320 additions and 11 deletions.
2 changes: 1 addition & 1 deletion backbone.py
@@ -32,7 +32,7 @@ def forward(self, x):
        L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data)
        self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
        cos_dist = self.L(x_normalized) # matrix product by forward function
-       scores = 2* (cos_dist) # a fixed scale factor to scale the output of cos value into a reasonably large input for softmax
+       scores = 10* (cos_dist) # a fixed scale factor to scale the output of cos value into a reasonably large input for softmax

        return scores

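The comment explains the change: cosine similarities lie in [-1, 1], so an unscaled softmax over them is nearly uniform and trains slowly. A minimal sketch of the effect (illustrative values, not from the commit):

import torch
import torch.nn.functional as F

# cosine scores of one query against three classes (made-up values)
cos_dist = torch.tensor([[0.9, 0.3, -0.2]])
print(F.softmax(cos_dist, dim=1))       # ~[0.53, 0.29, 0.18]: almost flat
print(F.softmax(10 * cos_dist, dim=1))  # ~[0.998, 0.002, 0.000]: peaked enough to learn from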
4 changes: 3 additions & 1 deletion configs.py
@@ -1,4 +1,6 @@
-save_dir = './record/'
+save_dir = '/work/newriver/wyharveychen/CloserLookFewShot/' # author's cluster path; point this at your own record directory
data_dir = {}
data_dir['CUB'] = './filelists/CUB/'
data_dir['miniImagenet'] = './filelists/miniImagenet/'
+data_dir['omniglot'] = './filelists/omniglot/'
+data_dir['emnist'] = './filelists/emnist/'
5 changes: 5 additions & 0 deletions filelists/emnist/download_emnist.sh
@@ -0,0 +1,5 @@
#!/usr/bin/env bash
wget https://github.com/NanqingD/DAOSL/raw/master/data/emnist.zip
unzip emnist.zip
python invert_emnist.py
python write_cross_char_valnovel_filelist.py
33 changes: 33 additions & 0 deletions filelists/emnist/invert_emnist.py
@@ -0,0 +1,33 @@
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random
from PIL import Image
import PIL.ImageOps

cwd = os.getcwd()
data_path = join(cwd,'emnist')
inv_data_path = join(cwd,'inv_emnist')
savedir = './'

#if not os.path.exists(savedir):
# os.makedirs(savedir)
if not os.path.exists(inv_data_path):
    os.makedirs(inv_data_path)

character_folder_list = [str(i) for i in range(62)] # lazy hack: the 62 EMNIST 'byclass' categories (10 digits + 52 letters) live in folders named '0'..'61'

classfile_list_all = []

for character_folder in character_folder_list:
    character_folder_path = join(data_path, character_folder)
    inv_character_folder_path = join(inv_data_path, character_folder)
    image_list = [img for img in listdir(character_folder_path) if (isfile(join(character_folder_path, img)) and img[0] != '.')]
    if not os.path.exists(inv_character_folder_path):
        os.makedirs(inv_character_folder_path)
    for img in image_list:
        inverted_img = PIL.ImageOps.invert(Image.open(join(character_folder_path, img)))
        inverted_img.save(join(inv_character_folder_path, img))
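The inversion exists because EMNIST glyphs are stored MNIST-style, white-on-black, while Omniglot strokes are black-on-white; flipping polarity lets cross_char treat the two domains alike. A quick hedged check (the file path is illustrative):

from PIL import Image
import PIL.ImageOps

img = Image.open('emnist/0/example.png').convert('L')  # hypothetical sample
inv = PIL.ImageOps.invert(img)
# invert() maps every pixel v to 255 - v, exactly flipping polarity
assert inv.getpixel((0, 0)) == 255 - img.getpixel((0, 0))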

59 changes: 59 additions & 0 deletions filelists/emnist/write_cross_char_valnovel_filelist.py
@@ -0,0 +1,59 @@
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random

cwd = os.getcwd()
data_path = join(cwd,'inv_emnist')
savedir = './'
dataset_list = ['val','novel']

#if not os.path.exists(savedir):
# os.makedirs(savedir)

folder_list = [str(i) for i in range(62)] # lazy hack: same '0'..'61' EMNIST folder naming as in invert_emnist.py
label_dict = dict(zip(folder_list,range(0,len(folder_list))))

classfile_list_all = []

for i, folder in enumerate(folder_list):
    folder_path = join(data_path, folder)
    classfile_list_all.append([join(folder_path, cf) for cf in listdir(folder_path) if (isfile(join(folder_path, cf)) and cf[0] != '.')])
    random.shuffle(classfile_list_all[i])

for dataset in dataset_list:
    file_list = []
    label_list = []
    for i, classfile_list in enumerate(classfile_list_all):
        # even-indexed classes go to val, odd-indexed classes to novel
        if 'val' in dataset:
            if i % 2 == 0:
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()
        if 'novel' in dataset:
            if i % 2 == 1:
                file_list = file_list + classfile_list
                label_list = label_list + np.repeat(i, len(classfile_list)).tolist()

    fo = open(savedir + dataset + ".json", "w")
    fo.write('{"label_names": [')
    fo.writelines(['"%s",' % item for item in folder_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET) # step back over the trailing comma
    fo.write('],')

    fo.write('"image_names": [')
    fo.writelines(['"%s",' % item for item in file_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_labels": [')
    fo.writelines(['%d,' % item for item in label_list])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write(']}')

    fo.close()
    print("%s -OK" % dataset)
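The seek(0, os.SEEK_END) / seek(tell() - 1) pairs above exist only to overwrite the trailing comma left by each writelines loop. A minimal equivalent using the standard json module (a sketch, not part of this commit) avoids the bookkeeping:

import json

def write_filelist(path, label_names, image_names, image_labels):
    # json.dump emits valid JSON directly, so no trailing-comma surgery is needed
    with open(path, "w") as fo:
        json.dump({"label_names": label_names,
                   "image_names": image_names,
                   "image_labels": image_labels}, fo)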
19 changes: 19 additions & 0 deletions filelists/omniglot/download_omniglot.sh
@@ -0,0 +1,19 @@
#!/usr/bin/env bash
wget https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/train.txt
wget https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/val.txt
wget https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/test.txt

DATADIR=./images
mkdir -p $DATADIR
wget -O images_background.zip https://github.com/brendenlake/omniglot/blob/master/python/images_background.zip?raw=true
wget -O images_evaluation.zip https://github.com/brendenlake/omniglot/blob/master/python/images_evaluation.zip?raw=true
unzip images_background.zip -d $DATADIR
unzip images_evaluation.zip -d $DATADIR
mv $DATADIR/images_background/* $DATADIR/
mv $DATADIR/images_evaluation/* $DATADIR/
rmdir $DATADIR/images_background
rmdir $DATADIR/images_evaluation

python rot_omniglot.py
python write_omniglot_filelist.py
python write_cross_char_base_filelist.py
37 changes: 37 additions & 0 deletions filelists/omniglot/rot_omniglot.py
@@ -0,0 +1,37 @@
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random
from PIL import Image

cwd = os.getcwd()
data_path = join(cwd,'images')
savedir = './'

#if not os.path.exists(savedir):
# os.makedirs(savedir)

language_folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
language_folder_list.sort()

classfile_list_all = []

for language_folder in language_folder_list:
    language_folder_path = join(data_path, language_folder)
    character_folder_list = [cf for cf in listdir(language_folder_path) if isdir(join(language_folder_path, cf))]
    character_folder_list.sort()
    for character_folder in character_folder_list:
        character_folder_path = join(language_folder_path, character_folder)
        image_list = [img for img in listdir(character_folder_path) if (isfile(join(character_folder_path, img)) and img[0] != '.')]
        for deg in [0, 90, 180, 270]:
            rot_str = "rot%03d" % deg
            rot_character_path = join(character_folder_path, rot_str)
            print(rot_character_path)
            if not os.path.exists(rot_character_path):
                os.makedirs(rot_character_path)
            for img in image_list:
                rot_img = Image.open(join(character_folder_path, img)).rotate(deg)
                rot_img.save(join(character_folder_path, rot_str, img))
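Following the standard Omniglot protocol, each 90-degree rotation of a character is treated as its own class, which turns the original 1,623 characters into 6,492 classes. A hedged sanity check after the script runs:

import os

# count the rotation subfolders created under ./images;
# expect four (rot000/rot090/rot180/rot270) per character folder
rot_dirs = [d for _, dirs, _ in os.walk('images') for d in dirs if d.startswith('rot')]
print(len(rot_dirs))  # should be 4 * the number of character folders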

64 changes: 64 additions & 0 deletions filelists/omniglot/write_cross_char_base_filelist.py
@@ -0,0 +1,64 @@
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random
import re

cwd = os.getcwd()
data_path = join(cwd,'images')
savedir = './'

#if not os.path.exists(savedir):
# os.makedirs(savedir)

cl = -1
folderlist = []

language_folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
language_folder_list.sort()

filelists = {}

for language_folder in language_folder_list:
    if language_folder == 'Latin':
        continue
    language_folder_path = join(data_path, language_folder)
    character_folder_list = [cf for cf in listdir(language_folder_path) if isdir(join(language_folder_path, cf))]
    character_folder_list.sort()
    for character_folder in character_folder_list:
        character_folder_path = join(language_folder_path, character_folder)
        label = join(language_folder, character_folder)
        folderlist.append(label)
        filelists[label] = [join(character_folder_path, img) for img in listdir(character_folder_path) if (isfile(join(character_folder_path, img)) and img[-3:] == 'png')]

filelists_flat = []
labellists_flat = []
for key, filelist in filelists.items():
    cl += 1
    random.shuffle(filelist)
    filelists_flat += filelist
    labellists_flat += np.repeat(cl, len(filelist)).tolist()

fo = open(join(savedir, "noLatin.json"), "w")
fo.write('{"label_names": [')
fo.writelines(['"%s",' % item for item in folderlist])
fo.seek(0, os.SEEK_END)
fo.seek(fo.tell()-1, os.SEEK_SET)
fo.write('],')

fo.write('"image_names": [')
fo.writelines(['"%s",' % item for item in filelists_flat])
fo.seek(0, os.SEEK_END)
fo.seek(fo.tell()-1, os.SEEK_SET)
fo.write('],')

fo.write('"image_labels": [')
fo.writelines(['%d,' % item for item in labellists_flat])
fo.seek(0, os.SEEK_END)
fo.seek(fo.tell()-1, os.SEEK_SET)
fo.write(']}')

fo.close()
print("noLatin -OK")
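Latin is held out of the cross_char base split because the val/novel classes, drawn from EMNIST, are themselves handwritten Latin letters and digits; keeping it would leak the target alphabet into training. A hedged check of the written file:

import json

with open('noLatin.json') as f:
    meta = json.load(f)
# labels are 'language/character' paths, so any Latin class would start with 'Latin'
assert not any(name.startswith('Latin') for name in meta['label_names'])
print('no Latin classes: OK')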
59 changes: 59 additions & 0 deletions filelists/omniglot/write_omniglot_filelist.py
@@ -0,0 +1,59 @@
import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random
import re

cwd = os.getcwd()
data_path = join(cwd,'images')
savedir = './'
dataset_list = ['base', 'val', 'novel']

#if not os.path.exists(savedir):
# os.makedirs(savedir)

cl = -1
folderlist = []

datasetmap = {'base': 'train', 'val': 'val', 'novel': 'test'} # map the repo's split names onto the Vinyals split files fetched by download_omniglot.sh
filelists = {'base':{},'val':{},'novel':{} }
filelists_flat = {'base':[],'val':[],'novel':[] }
labellists_flat = {'base':[],'val':[],'novel':[] }

for dataset in dataset_list:
    with open(datasetmap[dataset] + ".txt", "r") as lines:
        for i, line in enumerate(lines):
            label = line.replace('\n', '')
            folderlist.append(label)
            filelists[dataset][label] = [join(data_path, label, f) for f in listdir(join(data_path, label))]

    for key, filelist in filelists[dataset].items():
        cl += 1
        random.shuffle(filelist)
        filelists_flat[dataset] += filelist
        labellists_flat[dataset] += np.repeat(cl, len(filelist)).tolist()

for dataset in dataset_list:
    fo = open(savedir + dataset + ".json", "w")
    fo.write('{"label_names": [')
    fo.writelines(['"%s",' % item for item in folderlist])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET) # step back over the trailing comma
    fo.write('],')

    fo.write('"image_names": [')
    fo.writelines(['"%s",' % item for item in filelists_flat[dataset]])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write('],')

    fo.write('"image_labels": [')
    fo.writelines(['%d,' % item for item in labellists_flat[dataset]])
    fo.seek(0, os.SEEK_END)
    fo.seek(fo.tell() - 1, os.SEEK_SET)
    fo.write(']}')

    fo.close()
    print("%s -OK" % dataset)
15 changes: 8 additions & 7 deletions io_utils.py
@@ -6,6 +6,7 @@

model_dict = dict(
    Conv4 = backbone.Conv4,
+   Conv4S = backbone.Conv4S,
    Conv6 = backbone.Conv6,
    ResNet10 = backbone.ResNet10,
    ResNet18 = backbone.ResNet18,
@@ -15,16 +16,16 @@

def parse_args(script):
    parser = argparse.ArgumentParser(description= 'few-shot script %s' %(script))
-   parser.add_argument('--dataset' , default='CUB', help='CUB/ miniImagenet/cross')
+   parser.add_argument('--dataset' , default='CUB', help='CUB/miniImagenet/cross/omniglot/cross_char')
    parser.add_argument('--model' , default='Conv4', help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
    parser.add_argument('--method' , default='baseline', help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') #relationnet_softmax replaces the L2 norm with a softmax to expedite training; maml_approx uses a first-order gradient approximation for efficiency
-   parser.add_argument('--train_n_way' , default=5, type=int, help='class num to classify for training')
-   parser.add_argument('--test_n_way' , default=5, type=int, help='class num to classify for testing (validation) ')
-   parser.add_argument('--n_shot' , default=5, type=int, help='number of labeled data in each class, same as n_support')
-   parser.add_argument('--train_aug' , action='store_true', help='perform data augmentation or not during training ')
+   parser.add_argument('--train_n_way' , default=5, type=int, help='class num to classify for training') #baseline and baseline++ would ignore this parameter
+   parser.add_argument('--test_n_way' , default=5, type=int, help='class num to classify for testing (validation) ') #baseline and baseline++ only use this parameter in finetuning
+   parser.add_argument('--n_shot' , default=5, type=int, help='number of labeled data in each class, same as n_support') #baseline and baseline++ only use this parameter in finetuning
+   parser.add_argument('--train_aug' , action='store_true', help='perform data augmentation or not during training ') #still required for save_features.py and test.py to find the model path correctly

    if script == 'train':
-       parser.add_argument('--num_classes' , default=200, type=int, help='total number of classes in softmax, only used in baseline')
+       parser.add_argument('--num_classes' , default=200, type=int, help='total number of classes in softmax, only used in baseline') #make it larger than the maximum label value in base class
        parser.add_argument('--save_freq' , default=50, type=int, help='Save frequency')
        parser.add_argument('--start_epoch' , default=0, type=int, help='Starting epoch')
        parser.add_argument('--stop_epoch' , default=400, type=int, help='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes
@@ -35,7 +36,7 @@ def parse_args(script):
        parser.add_argument('--save_iter', default=-1, type=int, help='save feature from the model trained in x epoch, use the best model if x is -1')
    elif script == 'test':
        parser.add_argument('--split' , default='novel', help='base/val/novel') #default novel, but you can also test base/val class accuracy if you want
-       parser.add_argument('--save_iter', default=-1, type=int, help='save feature from the model trained in x epoch, use the best model if x is -1') #please match the one used in save_features
+       parser.add_argument('--save_iter', default=-1, type=int, help='saved feature from the model trained in x epoch, use the best model if x is -1')
        parser.add_argument('--adaptation' , action='store_true', help='further adaptation in test time or not')
    else:
        raise ValueError('Unknown script')
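Putting the new options together, a cross_char run would look roughly like this (a hedged example following the repo's command conventions; Conv4 is swapped to Conv4S internally for the 28x28 inputs, and --train_aug is disallowed for omniglot/cross_char):

python ./train.py --dataset cross_char --model Conv4 --method protonet --n_shot 5
python ./save_features.py --dataset cross_char --model Conv4 --method protonet --n_shot 5
python ./test.py --dataset cross_char --model Conv4 --method protonet --n_shot 5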
16 changes: 15 additions & 1 deletion save_features.py
@@ -45,16 +45,28 @@ def save_features(model, data_loader, outfile ):
assert params.method != 'maml' and params.method != 'maml_approx', 'maml does not support save_feature and run'

if 'Conv' in params.model:
-   image_size = 84
+   if params.dataset in ['omniglot', 'cross_char']:
+       image_size = 28
+   else:
+       image_size = 84
else:
    image_size = 224

+if params.dataset in ['omniglot', 'cross_char']:
+   assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
+   params.model = 'Conv4S'

split = params.split
if params.dataset == 'cross':
    if split == 'base':
        loadfile = configs.data_dir['miniImagenet'] + 'all.json'
    else:
        loadfile = configs.data_dir['CUB'] + split + '.json'
+elif params.dataset == 'cross_char':
+   if split == 'base':
+       loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
+   else:
+       loadfile = configs.data_dir['emnist'] + split + '.json'
else:
    loadfile = configs.data_dir[params.dataset] + split + '.json'

@@ -84,6 +96,8 @@ def save_features(model, data_loader, outfile ):
        model = backbone.Conv4NP()
    elif params.model == 'Conv6':
        model = backbone.Conv6NP()
+   elif params.model == 'Conv4S':
+       model = backbone.Conv4SNP()
    else:
        model = model_dict[params.model]( flatten = False )
elif params.method in ['maml' , 'maml_approx']:
