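# Applying Few-shot Learning

### Why few-shot learning
- The dataset's labels are very fine-grained
- I studied few-shot learning for a previous competition, and that dataset had characteristics similar to this one
    - Both have fine-grained labels
- Supervised learning alone hit a ceiling on performance
- The characteristics of the shopping-mall data described in the assignment PDF seemed well suited to few-shot learning
- Every class has the same number of training samples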

### Fine-tuning
- Fine-tuned with Transductive Information Maximization (TIM)
    - Pretrained on a public set, then fine-tuned on the given dataset
    - The score dropped after fine-tuning
    - This approach updates the support-set feature map rather than the model weights
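
As a rough illustration of the quantity TIM is built around, the sketch below computes the empirical mutual information between query samples and their predicted labels: the entropy of the marginal label distribution minus the mean entropy of the individual predictions. It is a minimal pure-Python sketch, not the code used in this notebook and not easyfsl's implementation. The function names are hypothetical, and the support-set cross-entropy term and the weighting coefficients of the full TIM objective are left out.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one probability vector; 0 * log(0) counts as 0."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def mutual_information(query_probs):
    """Marginal label entropy minus mean conditional entropy over the query set."""
    n = len(query_probs)
    # Marginal label distribution: average prediction over all query samples
    marginal = [sum(col) / n for col in zip(*query_probs)]
    # Conditional entropy: average entropy of each individual prediction
    conditional = sum(entropy(p) for p in query_probs) / n
    return entropy(marginal) - conditional


confident = [[1.0, 0.0], [0.0, 1.0]]  # confident and class-balanced predictions
uncertain = [[0.5, 0.5], [0.5, 0.5]]  # every prediction is a coin flip
print(mutual_information(confident))  # log(2), about 0.693
print(mutual_information(uncertain))  # 0.0
```

Confident, class-balanced predictions give the highest value, which is the behaviour the transductive step pushes the query predictions toward.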
    
### Applying RandAug
- Used because data augmentation is known to improve generalization
- The paper A Closer Look at Few-shot Classification also uses random crop, flip, and color jitter
- It made no significant difference to performance

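### Improving accuracy with two inference passes over the Top-k
- Method
    1. Run a first inference pass to extract the Top-5 classes for the input image
    2. Run a second inference pass restricted to those Top-5 classes
- Motivation
    - The evaluation code provided by the Python package (easyfsl) uses randomly sampled data for both the support set and the query set
    - Accuracy was observed to be high when the query label is included in the support set
    - The idea was that re-running inference on the initial Top-5 might pull a class that had slipped to 2nd or 3rd place back up to 1st
- Results
    - Top-1 accuracy improved by about 3.9 percentage points
        - 0.5212 -> 0.5598
    - The Top-k is now made up of classes that are more similar to the input class
    - Classes that had slipped to 2nd or 3rd place were pulled up to 1st
        - Previously 7 labels never made Top-1; after the change, every class in the test set appears as a Top-1 prediction
            - See fsl_multi_infer.ipynb and fsl_inference.ipynb
    - Inference is about 2.5x slower because of the additional inference pass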
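
The two-pass idea can be sketched roughly as follows. This is a toy pure-Python version under stated assumptions, not the code in fsl_multi_infer.ipynb: it assumes a prototype-based classifier (class mean plus Euclidean distance), and it assumes the second pass rebuilds the prototypes from a separately drawn support set for the candidate classes. All function names and the toy embeddings are hypothetical. In this sketch, if both passes used identical prototypes, the order among the candidates would not change.

```python
import math


def prototype(vectors):
    """Mean of a list of equal-length embedding vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def rank_classes(query, support):
    """Class labels sorted by distance from the query to each class prototype."""
    dists = {label: math.dist(query, prototype(vecs)) for label, vecs in support.items()}
    return sorted(dists, key=dists.get)


def two_stage_predict(query, first_support, second_support, k=5):
    """First pass over all classes, second pass over the Top-k candidates only."""
    candidates = rank_classes(query, first_support)[:k]
    # Second pass: a smaller task built only from the candidate classes
    narrowed = {label: second_support[label] for label in candidates}
    return rank_classes(query, narrowed)


# Toy 2-D embeddings (hypothetical values)
first_support = {"a": [[0.0, 0.0]], "b": [[1.0, 0.0]], "c": [[5.0, 5.0]]}
second_support = {"a": [[2.0, 0.0]], "b": [[1.0, 0.0]], "c": [[5.0, 5.0]]}
query = [0.4, 0.0]

print(rank_classes(query, first_support))                            # ['a', 'b', 'c']
print(two_stage_predict(query, first_support, second_support, k=2))  # ['b', 'a']
```

The second pass costs an extra forward pass per query, which is consistent with the slowdown noted above.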
    
### References
- A CLOSER LOOK AT FEW-SHOT CLASSIFICATION
    - https://arxiv.org/pdf/1904.04232.pdf
- Prototypical Networks for Few-shot Learning
    - https://arxiv.org/pdf/1703.05175.pdf
- Transductive Information Maximization For Few-Shot Learning
    - https://arxiv.org/pdf/2008.11297.pdf
- easy-few-shot-learning
    - https://github.com/sicara/easy-few-shot-learning


# Create JSON
- Do not rename the JSON files

In [None]:
# Write the class-list JSON files for the train and test splits
import json
import os
from collections import defaultdict
from glob import glob


def write_class_json(split):
    """Write data/<split>.json listing each class folder and its name."""
    fsl = defaultdict(list)

    # Class folders are named with integers, so sort them numerically.
    # os.path.basename works with both '\\' and '/' path separators.
    label_list = glob(f'data/{split}/*')
    label_list.sort(key=lambda x: int(os.path.basename(x)))

    for label in label_list:
        fsl['class_roots'].append(label)
        fsl['class_names'].append(os.path.basename(label))

    with open(f'data/{split}.json', 'w') as f:
        json.dump(fsl, f, indent=2)


write_class_json('train')
write_class_json('test')

In [2]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import timm
import os

from easyfsl.methods.utils import evaluate


In [3]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
n_way = 5
n_shot = 6
n_query = 4

DEVICE = "cuda"
n_workers = 6

In [5]:
from easyfsl.datasets import CUB, CUSTOM
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


n_tasks_per_epoch = 500
n_validation_tasks = 100

# train_set = CUB(split="train", training=True)
# val_set = CUB(split="val", training=False)

train_set = CUSTOM(split="data/train", image_size=224 ,training=True)
val_set = CUSTOM(split="data/test", image_size=224, training=False)

train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)
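
To make the episode sizes concrete: with n_way = 5, n_shot = 6 and n_query = 4, each task holds 30 support images and 20 query images drawn from the same 5 classes. The sketch below shows one way such an episode could be sampled. It is a simplified stand-in written for illustration, not easyfsl's `TaskSampler` code, and the function name and toy dataset sizes are hypothetical.

```python
import random


def sample_episode(labels, n_way, n_shot, n_query, rng):
    """Pick n_way classes, then n_shot support and n_query query indices per class."""
    by_class = {}
    for index, label in enumerate(labels):
        by_class.setdefault(label, []).append(index)

    support, query = [], []
    for label in rng.sample(sorted(by_class), n_way):
        # Draw without replacement so support and query never share an item
        picked = rng.sample(by_class[label], n_shot + n_query)
        support += picked[:n_shot]
        query += picked[n_shot:]
    return support, query


# Toy dataset: 8 classes with 12 items each (hypothetical sizes)
labels = [c for c in range(8) for _ in range(12)]
support, query = sample_episode(labels, n_way=5, n_shot=6, n_query=4, rng=random.Random(0))
print(len(support), len(query))  # 30 20
```

Because the query classes are always among the support classes, accuracy measured this way is a 5-way accuracy, not accuracy over the full label set.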

In [5]:
# timm.list_models(pretrained=True)

In [7]:
class CNN(nn.Module):
    """Backbone wrapper around a pretrained timm model."""

    def __init__(self, name='efficientnet_b0', num_classes=0):
        super().__init__()
        # With num_classes=0, timm drops the classifier head and the model outputs pooled features
        self.model = timm.create_model(name, pretrained=True, num_classes=num_classes)

    def forward(self, x):
        return self.model(x)
        # return self.model.forward_features(x)

In [8]:
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier, TIM, RelationNetworks

convolutional_network = CNN()
# model_checkpoint = torch.load('./fsl_model/pretrained_effi_b0/9E_model.pt')
# convolutional_network.load_state_dict(model_checkpoint["model_state_dict"], strict=False)

# few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)
few_shot_classifier = PrototypicalNetworks(convolutional_network, use_softmax=True).to(DEVICE)

# few_shot_classifier = TIM(convolutional_network).to(DEVICE)
# few_shot_classifier = RelationNetworks(convolutional_network).to(DEVICE)

#### Load Pretrained Model to few shot model

In [8]:
# fsm_checkpoint = torch.load('./fsl_model/publicset_randaug_224_effi_0/13E_model.pt')
# few_shot_classifier.load_state_dict(fsm_checkpoint["model_state_dict"])

In [9]:
from torch.optim import SGD, Optimizer, AdamW
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 8
scheduler_milestones = [3, 6]
scheduler_gamma = 0.1
learning_rate = 1e-2


train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
# train_optimizer = AdamW(
#     few_shot_classifier.parameters(), lr=learning_rate, weight_decay=5e-4
# )

train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)
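
With the settings above (base rate 1e-2, milestones [3, 6], gamma 0.1, 8 epochs) and the scheduler stepped once at the end of every epoch, the learning rate should follow the pattern sketched below. This is a plain-Python re-computation for illustration only; the helper name is hypothetical and it does not call PyTorch.

```python
def multistep_lr(base_lr, milestones, gamma, n_epochs):
    """Learning rate in effect during each epoch under a MultiStepLR-style schedule."""
    rates = []
    for epoch in range(n_epochs):
        # Count the milestones already reached at the start of this epoch
        n_decays = sum(1 for m in milestones if epoch >= m)
        rates.append(base_lr * gamma ** n_decays)
    return rates


for epoch, lr in enumerate(multistep_lr(1e-2, [3, 6], 0.1, 8)):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

Epochs 0 to 2 train at 1e-2, epochs 3 to 5 at 1e-3, and epochs 6 and 7 at 1e-4.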


In [10]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
#             loss.requires_grad_(True)
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

In [11]:
from easyfsl.methods.utils import evaluate

start_e = 0
best_e = None
resume = False  # set to True to continue training from a saved checkpoint
model_save_path = './submission/fsl_224_effi_b0/'
os.makedirs(model_save_path, exist_ok=True)

tb_logs_dir = Path(".")
tb_writer = SummaryWriter(log_dir=str(model_save_path))


best_state = few_shot_classifier.state_dict()
if resume:
    fsl_checkpoint = torch.load(os.path.join(model_save_path, '2E_model.pt'))
    few_shot_classifier.load_state_dict(fsl_checkpoint["model_state_dict"])
    train_optimizer.load_state_dict(fsl_checkpoint['optimizer_state_dict'])
    # The checkpoint stores the last completed epoch, so continue with the next one
    start_e = fsl_checkpoint["epoch"] + 1
    
best_validation_accuracy = 0.0
for epoch in range(start_e, n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        print("Ding ding ding! We found a new best model!")
        best_e = epoch
        # Take a fresh snapshot of the current weights when saving
        torch.save({
                "epoch" : epoch,
                "model_state_dict" : few_shot_classifier.state_dict(),
                "optimizer_state_dict" : train_optimizer.state_dict()
            }, os.path.join(model_save_path, str(epoch)+'E_model.pt'))
    

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()

Epoch 0


Training: 100%|████████████████████████████████████████████████████████████| 500/500 [01:41<00:00,  4.93it/s, loss=1.1]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.92it/s, accuracy=0.904]


Ding ding ding! We found a new best model!
Epoch 1


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:38<00:00,  5.10it/s, loss=0.973]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:07<00:00, 14.26it/s, accuracy=0.907]


Ding ding ding! We found a new best model!
Epoch 2


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:37<00:00,  5.11it/s, loss=0.947]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.63it/s, accuracy=0.909]


Ding ding ding! We found a new best model!
Epoch 3


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:37<00:00,  5.10it/s, loss=0.937]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.78it/s, accuracy=0.904]


Epoch 4


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:38<00:00,  5.10it/s, loss=0.935]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.67it/s, accuracy=0.901]


Epoch 5


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:37<00:00,  5.10it/s, loss=0.935]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.79it/s, accuracy=0.911]


Ding ding ding! We found a new best model!
Epoch 6


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:38<00:00,  5.10it/s, loss=0.936]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.58it/s, accuracy=0.896]


Epoch 7


Training: 100%|██████████████████████████████████████████████████████████| 500/500 [01:37<00:00,  5.10it/s, loss=0.933]
Validation: 100%|████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.46it/s, accuracy=0.898]


# Eval

In [10]:
fsm_checkpoint = torch.load(os.path.join(model_save_path, f'{best_e}E_model.pt'))
few_shot_classifier.load_state_dict(fsm_checkpoint["model_state_dict"])

<All keys matched successfully>

In [11]:
n_way = 5
n_shot = 6
n_query = 4

DEVICE = "cuda"
n_workers = 6
n_validation_tasks = 500 #100

val_set = CUSTOM(split="data/test", image_size=224, training=False)

val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

Validation: 100%|████████████████████████████████████████████████████| 500/500 [00:33<00:00, 15.07it/s, accuracy=0.891]


In [14]:
with open(os.path.join(model_save_path,'score.txt'), 'w') as f :
    f.write(str(validation_accuracy)+'\n')
    f.write('n_way : ' + str(n_way)+'\n')
    f.write('n_shot : ' + str(n_shot)+'\n')
    f.write('n_query : ' + str(n_query)+'\n')
    

In [15]:
# Augmentation is excluded during fine-tuning