In [2]:
# !pip install timm
# !pip install transformers
# !pip install tqdm

In [3]:
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt

## For ImageEncoder
import torch
from torch import nn

import torch.nn.functional as F
import timm

## For TextEncoder
import transformers
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

In [4]:
## 구현한 모듈에서 가져오는 func
from utils import *
from dataset import *
from models.clip import *
from models.ImageEncoder import *
from models.TextEncoder import *
from dataset import *

In [5]:
print(torch.__version__)        # 사용한 pytorch 버전체크 (1.8.1+cu111)
print(transformers.__version__) # 사용한 transformer 버전체크(4.26.0)
print(timm.__version__ )        # 사용한 timm 버전체크()

1.8.1+cu111
4.26.0
0.6.12


In [6]:
## TODO: Requirements 만들기

In [7]:
# 경우에 따라 달라질 수 있는 config 여기서 설정
config = dict()
config['dataset'] = '8k'
config['save_path'] = 'results'

In [8]:
# 불러온 데이터셋 기본적인 형태만 먼저 확인
print(len(os.listdir('./dataset/images'))) # 이미지의 갯수
df = pd.read_csv('./dataset/captions.txt', sep = '|')
print(len(df)) # caption에 들어간 단어

8092
40455


In [9]:
if config['dataset'] == '8k' :
    df = pd.read_csv('./dataset/captions.txt', sep = '|')
    df['id'] = [id_ for id_ in range(df.shape[0] // 5) for _ in range(5)]
    df.to_csv('./dataset/captions.csv')
    image_path = './dataset/images'
    captions_path = './dataset'
else:
    raise ImplementationError(f'{config["dataset"]} is not implemented')

In [10]:
# model config 추가 (image/text encoder + train 관련 configuration 모음)
model_config = {

    'debug': False,
    
    'image_path': image_path,
    'caption_path': captions_path,
    
    'batch_size': 64, # 조정
    'num_workers': 4,
    'head_lr': 1e-5,
    'image_encoder_lr': 1e-4,
    'text_encoder_lr': 1e-5,
    'weight_decay': 1e-3,
    
    'patience': 1,
    'factor': 0.8,
    'epochs': 100,
    
    'device': 'cuda:0',
    
    'model_name': 'resnet50',
    'model_modify': False, # ResNet 변형버전 사용 여부
    'image_embedding': 2048, # 모델에 따라 조정 (768, 2048)
    'text_encoder_model': 'distilbert-base-uncased', 
    'text_embedding': 768, # 모델에 따라 조정
    'text_tokenizer': 'distilbert-base-uncased',
    'max_length': 200,
    
    'pretrained': True, # ImageEncoder, TextEncoder 모두적용
    'trainable': True, 
    'temperature': 0.5,
    
    'image_size': 224,
    
    'num_projection_layers': 1,
    'projection_dim': 256,
    'dropout': 0.1
}

config["model"] = model_config

In [21]:
# Load Dataset (Input으로 들어갈 수 있는 형태로)
max_id = df["id"].max() + 1 if not config["model"]["debug"] else 100
image_ids = np.arange(0, max_id)

np.random.seed(42)
valid_ids = np.random.choice(
    image_ids, size=int(0.2 * len(image_ids)), replace=False
)

In [22]:
train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
train_df = df[df["id"].isin(train_ids)].reset_index(drop=True)
test_df = df[df["id"].isin(valid_ids)].reset_index(drop=True)

In [23]:
# Train CLIP func
def train_epoch(config, model, train_loader, optimizer, lr_scheduler, step) :
    
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total = len(train_loader))
    
    for batch in tqdm_object :
        batch = {k: v.to(config["model"]["device"]) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step == 'batch' :
            lr_scheduler.step()
        
        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)
        
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter

In [24]:
# Test CLIP func
def test_epoch(config, model, test_loader) :
    loss_meter = AvgMeter()
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    
    for batch in tqdm_object :
        batch = {k: v.to(config["model"]["device"]) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        
        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)
        
        tqdm_object.set_postfix(test_loss=loss_meter.avg)
        
    return loss_meter

In [25]:
# main에 들어갈 부분 (실행)
tokenizer = DistilBertTokenizer.from_pretrained(config["model"]["text_tokenizer"])

In [26]:
train_df = train_df.rename(columns={'image_name':'image'})
train_df = train_df.rename(columns={'caption_text':'caption'})

test_df = test_df.rename(columns={'image_name':'image'})
test_df = test_df.rename(columns={'caption_text':'caption'})

In [27]:
train_loader = build_loaders(config, train_df, tokenizer, mode="train")
test_loader = build_loaders(config, test_df, tokenizer, mode="valid")

In [28]:
model = CLIP(config).to(config["model"]["device"])

## Check Image Encoder: resnet50


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Check Text Encoder: distilbert-base-uncased


In [29]:
# training 옵션들 
params = [
        {"params": model.image_encoder.parameters(), "lr": config['model']['image_encoder_lr']},
        {"params": model.text_encoder.parameters(), "lr": config['model']['text_encoder_lr']},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": config['model']['head_lr'], "weight_decay": config['model']['head_lr']}
    ]

In [30]:
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=config['model']['patience'], factor=config['model']['factor'])
step = "epoch" # 로깅 기준

In [None]:
best_loss = float('inf')

for epoch in range(config['model']['epochs']):
    print(f"# Epoch: {epoch + 1}")
    
    model.train()
    train_loss = train_epoch(config, model, train_loader, optimizer, lr_scheduler, step)
    
    
    model.eval()
    with torch.no_grad():
        test_loss = test_epoch(config, model, test_loader)
    
    ## best loss 기준으로 weight 저장
    if test_loss.avg < best_loss :
        best_loss = test_loss.avg
        torch.save(model.state_dict(), f"./{config['save_path']}/best_model.pth")
        print('Save best Model !')
    
    lr_scheduler.step(test_loss.avg)

### Inference (CIFAR10 Test set)

In [11]:
import numpy as np
import torch
from tqdm.notebook import tqdm
from pkg_resources import packaging

In [12]:
# Step0. Load model
model = CLIP(config).to(config["model"]["device"])
model.load_state_dict(torch.load('./results/best_model_resnet50_bert.pth'))

## Check Image Encoder: resnet50


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Check Text Encoder: distilbert-base-uncased


<All keys matched successfully>

In [13]:
# Step1. Find Class and Templates
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
cifar10_templates = ['a photo of {}.'] ## 추가가능

In [14]:
# Step2. Load Dataset
test_root = './dataset/cifar10/Test'

test_transform_option = transforms.Compose([
                        transforms.Resize((32, 32)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
                    ])
test_datasets = torchvision.datasets.ImageFolder(root=test_root, transform = test_transform_option)
test_loader = torch.utils.data.DataLoader(test_datasets, batch_size = 256, shuffle=False, num_workers = 4)

In [15]:
# Step3. Create zero shot classifier weight
def zeroshot_classifier(classnames, templates, model):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates] 
            
            tokenizer = DistilBertTokenizer.from_pretrained(config['model']['text_tokenizer'])
            encoded_query = tokenizer(texts)
            
            batch = {
                key: torch.tensor(values).to(config['model']['device'])
                for key, values in encoded_query.items()
            }
            
            text_features = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            class_embeddings = model.text_projection(text_features)
            
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            
            zeroshot_weights.append(class_embedding)
            
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights


zeroshot_weights = zeroshot_classifier(cifar10_classes, cifar10_templates, model)

  0%|          | 0/10 [00:00<?, ?it/s]

In [16]:
zeroshot_weights.shape

torch.Size([256, 10])

In [22]:
# Step4. Zero shot prediction
with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    for i, (images, target) in enumerate(tqdm(test_loader)):
        images = images.cuda()
        target = target.cuda()
        
        # predict
        image_features = model.image_encoder(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        image_embeddings_n = F.normalize(image_features, p=2, dim=-1)
        text_embeddings_n = F.normalize(zeroshot_weights, p=2, dim=-1)
        print(text_embeddings_n.shape)
        print(image_embeddings_n.shape)
        dot_sim = text_embeddings_n @ image_embeddings_n
        logits = 100. * dot_sim
        print(logits.shape) # batch size

        # measure accuracy
        acc1, acc5 = accuracy(logits, target, topk=(1,))
        top1 += acc1
        top5 += acc5
        n += images.size(0)

top1 = (top1 / n) * 100
top5 = (top5 / n) * 100 

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")

  0%|          | 0/40 [00:00<?, ?it/s]

torch.Size([256, 10])
torch.Size([256, 2048])


RuntimeError: mat1 dim 1 must match mat2 dim 0