In [28]:
#!pip install git+https://github.com/openai/CLIP.git

In [29]:
#!pip install wandb

In [30]:
import os
import clip

import wandb
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
from tqdm.autonotebook import tqdm

from PIL import Image
from utils import *

print(torch.__version__)

1.8.1+cu111


In [34]:
# Load Pretrained Model
# `clip.load` returns the network together with the preprocessing callable that
# the dataset class later applies to every image it opens.
device = "cuda:0"
# jit=False asks for the non-JIT model — presumably so it can be fine-tuned below; confirm against the CLIP docs.
model, preprocess = clip.load("RN50", device=device, jit=False) 
#print('Official Preprocess Method: ', preprocess) ## preprocess used by the official CLIP
#print('Official CLIP Model:', model)              ## clip (modifiedResNet + Transformer)

In [None]:
# test load json
import json

# json.loads() parses a JSON *document* held in a string; handing it a file
# path raises JSONDecodeError. Open the file and parse its contents instead.
with open('./config.json', encoding='utf-8') as config_file:
    loaded_config = json.load(config_file)

# Bare last expression so the notebook still displays the parsed content.
loaded_config

In [7]:
# Load Sample Dataset
df = pd.read_csv('./dataset/captions.txt', sep = '|')

# One integer id per distinct image, numbered in order of first appearance.
# Deriving the id from `image_name` (instead of assuming exactly five
# consecutive caption rows per image) gives the same ids on well-formed data,
# and neither crashes nor mislabels rows when an image has a different number
# of captions.
df['id'] = pd.factorize(df['image_name'])[0]

df.to_csv('./dataset/captions.csv')
image_path = './dataset/images'
captions_path = './dataset'

In [8]:
# Default config settings
config = dict()
config['dataset'] = '8k'
config['save_path'] = 'results'

model_config = {
    # Also used as the wandb run name by the training loop.
    'name': 'ModifiedResNet-Transformer',  # the missing comma here was a SyntaxError
    'debug': False,

    'image_path': image_path,
    'caption_path': captions_path,

    'batch_size': 64,  # tune as needed
    'num_workers': 4,
    'head_lr': 1e-5,
    'image_encoder_lr': 1e-4,
    'text_encoder_lr': 1e-5,
    'weight_decay': 1e-3,

    # Consumed by the ReduceLROnPlateau scheduler.
    'patience': 1,
    'factor': 0.8,
    'epochs': 100,
}

config["model"] = model_config

In [16]:
class image_caption_dataset(torch.utils.data.Dataset):
    """Image/caption pairs for CLIP fine-tuning.

    Captions are tokenized once at construction time with ``clip.tokenize``;
    images are opened and run through the module-level ``preprocess`` callable
    only when an item is requested.

    Args:
        list_image_path: Image file paths, index-aligned with ``list_txt``.
        list_txt: Caption strings, one per image path.
    """

    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        self.caption = clip.tokenize(list_txt)

    def __len__(self):
        """Return the number of tokenized captions (one per sample)."""
        return len(self.caption)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, tokenized_caption)`` for sample ``idx``."""
        image = preprocess(Image.open(self.image_path[idx]))
        # The tokens are stored on `self.caption`; the attribute `self.title`
        # was never defined, so every lookup raised AttributeError.
        caption = self.caption[idx]
        return image, caption

In [17]:
# Image ids run from 0 up to and including the largest id in the frame.
max_id = 1 + df["id"].max()
image_ids = np.arange(max_id)

# Fixed seed so the same 20% of images is held out on every run.
np.random.seed(42)
valid_ids = np.random.choice(image_ids,
                             replace=False,
                             size=int(len(image_ids) * 0.2))

In [18]:
# Split train / test set
# Every image id that was not drawn for validation. `np.setdiff1d` does this in
# one vectorised pass; the previous `id_ not in valid_ids` test rescanned the
# whole validation array for each id (quadratic in the number of images).
train_ids = np.setdiff1d(image_ids, valid_ids)

# Select caption rows by image id so all captions of one image stay on the
# same side of the split.
train_df = df[df["id"].isin(train_ids)].reset_index(drop=True)
test_df = df[df["id"].isin(valid_ids)].reset_index(drop=True)

In [19]:
# Load Pytorch Dataset

def _build_dataset(frame, image_dir=image_path):
    """Wrap one split's rows in an ``image_caption_dataset``.

    Args:
        frame: DataFrame with ``image_name`` and ``caption_text`` columns.
        image_dir: Directory the image file names are relative to. Defaults to
            the notebook-level ``image_path`` rather than a re-typed literal.

    Returns:
        An ``image_caption_dataset`` over the rows of ``frame``.
    """
    image_files = [
        os.path.join(image_dir, file_nm)
        for file_nm in frame['image_name'].tolist()
    ]
    captions = frame['caption_text'].tolist()
    return image_caption_dataset(image_files, captions)


## Train Dataset
train_dataset = _build_dataset(train_df)

## Test Dataset
test_dataset = _build_dataset(test_df)

In [20]:
## Load Train DataLoader
# Only the 'train' mode shuffles; the comparison below evaluates to the bool
# that DataLoader expects.
mode = 'train'
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config['model']['batch_size'],
    num_workers=config['model']['num_workers'],
    shuffle=(mode == 'train'),
)

## Load Test DataLoader
mode = 'test'
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config['model']['batch_size'],
    num_workers=config['model']['num_workers'],
    shuffle=(mode == 'train'),
)

In [21]:
# Cast the model's params and grads to fp32.
## Works around attributes turning into nan/inf: https://github.com/openai/CLIP/issues/57
def convert_models_to_clip(model) :
    """Cast every parameter of ``model``, and its gradient if any, to fp32 in place.

    Args:
        model: Any object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        None. The parameters are modified in place.
    """
    for p in model.parameters() :
        p.data = p.data.float()
        # Frozen parameters, and parameters that took no part in the last
        # backward pass, have `grad is None`; touching `.data` on them would
        # raise AttributeError, so they are skipped.
        if p.grad is not None :
            p.grad.data = p.grad.data.float()

In [22]:
# Revert the params and grads that were cast to fp32 back to their original precision,
# using CLIP's own weight-conversion helper.
clip.model.convert_weights(model)

In [23]:
# Set Loss
# One cross-entropy criterion per direction (image->text and text->image).
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
    weight_decay=0.2,
)

# Shrinks the learning rate when the monitored (minimised) metric stops improving.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    patience=config['model']['patience'],
    factor=config['model']['factor'],
)

# Scheduler granularity flag read by train_epoch.
step = 'epoch'

In [24]:
# Training Func
def train_epoch(config, model, train_loader, optimizer, lr_scheduler, step):
    """Run one optimisation pass over ``train_loader``.

    For every batch the image->text and text->image logits are each scored
    against the diagonal (sample ``i`` pairs with caption ``i``) and the two
    cross-entropy terms are averaged into a single loss.

    Relies on the notebook-level ``device``, ``loss_img`` and ``loss_txt``.

    Args:
        config: Not read inside this function.
        model: Callable returning ``(logits_per_image, logits_per_text)``.
        train_loader: Iterable yielding ``(images, text)`` batches.
        optimizer: Zeroed and stepped once per batch.
        lr_scheduler: Stepped after each batch only when ``step == 'batch'``.
        step: Scheduler granularity flag.

    Returns:
        The ``AvgMeter`` that was updated with every batch loss.
    """
    running_loss = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))

    for images, text in progress:
        optimizer.zero_grad()

        images = images.to(device)
        text = text.to(device)

        logits_per_image, logits_per_text = model(images, text)

        # Matching pairs sit on the diagonal, so row i's correct class is i.
        labels = torch.arange(len(images), dtype=torch.long, device=device)
        image_term = loss_img(logits_per_image, labels)
        text_term = loss_txt(logits_per_text, labels)
        batch_loss = (image_term + text_term) / 2
        batch_loss.backward()

        # Take the optimizer step on fp32 params/grads, then hand the weights
        # back to CLIP's own conversion helper.
        convert_models_to_clip(model)
        optimizer.step()
        clip.model.convert_weights(model)

        # NOTE(review): this steps the scheduler without a metric, while the
        # epoch loop steps it with a loss value — confirm the scheduler in use
        # accepts a metric-free step before switching to 'batch'.
        if step == 'batch':
            lr_scheduler.step()

        running_loss.update(batch_loss.item(), images.size(0))
        progress.set_postfix(train_loss=running_loss.avg, lr=get_lr(optimizer))

    return running_loss

In [25]:
## Test func
def test_epoch(config, model, test_loader):
    """Compute the loss over ``test_loader`` without updating the model.

    Uses the same symmetric cross-entropy as training. This function does not
    itself disable gradient tracking or switch the model to eval mode; the
    caller is expected to do both.

    Relies on the notebook-level ``device``, ``loss_img`` and ``loss_txt``.

    Args:
        config: Not read inside this function.
        model: Callable returning ``(logits_per_image, logits_per_text)``.
        test_loader: Iterable yielding ``(images, text)`` batches.

    Returns:
        The ``AvgMeter`` that was updated with every batch loss.
    """
    running_loss = AvgMeter()
    progress = tqdm(test_loader, total=len(test_loader))

    for images, text in progress:
        images = images.to(device)
        text = text.to(device)

        logits_per_image, logits_per_text = model(images, text)

        # Matching pairs sit on the diagonal, so row i's correct class is i.
        labels = torch.arange(len(images), dtype=torch.long, device=device)
        image_term = loss_img(logits_per_image, labels)
        text_term = loss_txt(logits_per_text, labels)
        batch_loss = (image_term + text_term) / 2

        running_loss.update(batch_loss.item(), images.size(0))
        progress.set_postfix(test_loss=running_loss.avg)

    return running_loss

In [None]:
best_loss = float('inf')

# torch.save does not create missing directories, so make sure the output
# folder exists before the first checkpoint is written.
os.makedirs(config['save_path'], exist_ok=True)

wandb.init(project="clip-finetune-socar", name=config['model']['name'])
for epoch in range(config['model']['epochs']):
    print(f"# Epoch: {epoch + 1}")

    model.train()
    train_loss = train_epoch(config, model, train_loader, optimizer, lr_scheduler, step)

    # Evaluation: eval mode and no gradient tracking.
    model.eval()
    with torch.no_grad():
        test_loss = test_epoch(config, model, test_loader)

    # Log the scalar averages; the meter objects themselves are not metrics.
    wandb.log({'loss/train': train_loss.avg,
               'loss/test': test_loss.avg})

    ## Save the weights whenever the test loss reaches a new best
    if test_loss.avg < best_loss :
        best_loss = test_loss.avg
        torch.save(model.state_dict(), f"./{config['save_path']}/best_model.pth")
        print('Save best Model !')

    # The plateau scheduler monitors the epoch-level test loss.
    lr_scheduler.step(test_loss.avg)