## Install and import packages

In [None]:
# Fetch the WangchanBERTa checkpoint repository and the thai2transformers source tree.
!git clone https://huggingface.co/airesearch/wangchanberta-base-att-spm-uncased
!git clone https://github.com/vistec-AI/thai2transformers

# Pull the Git LFS objects from inside the checkpoint repository
# (presumably the model weight files), then return to /content.
%cd /content/wangchanberta-base-att-spm-uncased/
!git lfs pull
%cd /content
# Copy preprocess.py into /content so that `from preprocess import ...` in the import cell resolves.
!cp /content/thai2transformers/thai2transformers/preprocess.py /content

# Python dependencies. NOTE(review): versions are unpinned, so a re-run may install
# different releases; consider pinning them for reproducibility.
!pip install timm
!pip install transformers
!pip install sentencepiece
!pip install pythainlp
!pip install pythainlp[translate]
!pip install emoji

In [None]:
import os
import gc
import itertools
import pickle
import sys

import cv2
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import (
    CamembertModel,
    CamembertTokenizer,
    CamembertConfig,
)
from pythainlp.translate import Translate

from preprocess import process_transformers


## Config

In [None]:
class CFG:
    """Static hyper-parameters shared by the model, the data loaders and the training loop."""

    # --- runtime -----------------------------------------------------------
    # Train on the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- data loading ------------------------------------------------------
    batch_size = 32
    num_workers = 2
    max_length = 200  # passed to the tokenizer as its truncation limit

    # --- text encoder ------------------------------------------------------
    text_encoder_model = "/content/wangchanberta-base-att-spm-uncased"
    text_tokenizer = "/content/wangchanberta-base-att-spm-uncased"
    text_embedding = 768  # input width handed to the projection head
    pretrained = True     # load checkpoint weights instead of a fresh config
    trainable = True      # whether the encoder weights receive gradients

    # --- projection head ---------------------------------------------------
    num_projection_layers = 1
    projection_dim = 512
    dropout = 0.1

    # --- optimisation ------------------------------------------------------
    epochs = 100
    head_lr = 1e-3          # learning rate of the projection-head parameter group
    text_encoder_lr = 1e-5  # learning rate of the encoder parameter group
    weight_decay = 1e-3     # applied to the projection-head parameter group
    patience = 1            # ReduceLROnPlateau patience
    factor = 0.8            # ReduceLROnPlateau multiplicative factor
    temperature = 1.0

## Define the network

In [None]:
class TextEncoder(nn.Module):
    """CamemBERT backbone that turns a tokenised sentence into a single vector.

    Args:
        model_name: Path or identifier passed to ``CamembertModel.from_pretrained``.
        pretrained: When true, load the checkpoint; otherwise build a model
            from a default ``CamembertConfig``.
        trainable: Value assigned to ``requires_grad`` on every backbone parameter.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            CamembertModel.from_pretrained(model_name)
            if pretrained
            else CamembertModel(config=CamembertConfig())
        )

        # Freeze or unfreeze the whole backbone in one pass.
        for weight in self.model.parameters():
            weight.requires_grad = trainable

        # The hidden state at position 0 (the first token) is used as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the last-layer hidden state at ``target_token_idx`` for each sequence."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state[:, self.target_token_idx, :]

In [None]:
class ProjectionHead(nn.Module):
    """Residual projection block mapping encoder features to ``projection_dim``.

    The input is linearly projected; that projection is then passed through
    GELU -> linear -> dropout, added back to itself, and layer-normalised.

    Args:
        embedding_dim: Width of the incoming feature vectors.
        projection_dim: Width of the output vectors.
        dropout: Dropout probability applied before the residual addition.
    """

    def __init__(self, embedding_dim, projection_dim=CFG.projection_dim, dropout=CFG.dropout):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict ordering stay unchanged.
        self.projection = nn.Linear(in_features=embedding_dim, out_features=projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(in_features=projection_dim, out_features=projection_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(normalized_shape=projection_dim)

    def forward(self, x):
        """Project ``x`` and return the normalised residual output."""
        shortcut = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(refined + shortcut)

In [None]:
class TextModel(nn.Module):
    """Text encoder followed by a projection head.

    Args:
        text_embedding: Width of the encoder output fed to the projection head.
    """

    def __init__(self, text_embedding = CFG.text_embedding):
        super().__init__()
        self.text_encoder = TextEncoder()
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)

    def forward(self, batch):
        """Encode ``batch["input_ids"]`` / ``batch["attention_mask"]`` and project the result."""
        sentence_vectors = self.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        # Map the encoder features into the projection space.
        return self.text_projection(sentence_vectors)

## Define the dataset, data loaders, and train/validation split

In [None]:
def make_train_valid_dfs(df):
    """Split ``df`` into a deterministic 80/20 train/validation pair.

    Validation ids are drawn without replacement from ``0 .. len(df) - 1``
    using a fixed seed of 42, and rows are selected by matching those ids
    against ``df.index``. Both returned frames have a fresh ``0..n-1`` index.

    Args:
        df: DataFrame to split. Rows are matched by index label, so a default
            ``RangeIndex`` is expected for every row to land in one split.

    Returns:
        Tuple ``(train_dataframe, valid_dataframe)``.
    """
    text_ids = np.arange(0, len(df))

    # A private RandomState(42) draws exactly the ids that
    # np.random.seed(42) + np.random.choice would, without reseeding the
    # process-wide NumPy generator as a side effect.
    rng = np.random.RandomState(42)
    valid_ids = rng.choice(
        text_ids,
        size=int(0.2 * len(text_ids)), replace=False
    )

    # Vectorised membership masks replace a per-id `in` scan over a NumPy
    # array, which was quadratic in the number of rows.
    is_valid = df.index.isin(valid_ids)
    is_known_id = df.index.isin(text_ids)

    train_dataframe = df[is_known_id & ~is_valid].reset_index(drop=True)
    valid_dataframe = df[is_valid].reset_index(drop=True)
    return train_dataframe, valid_dataframe

In [None]:
class customImageDataset(torch.utils.data.Dataset):
    """Dataset pairing tokenised captions with their target embeddings.

    Args:
        inputs: Mapping/DataFrame exposing ``'caption'`` and ``'index'`` columns.
        embed_text: Container indexed by the values of ``inputs['index']``;
            each entry is converted with ``torch.tensor`` on access.
        tokenizer: Callable that tokenises a list of captions and returns a
            mapping of field name to per-caption sequences.
    """

    def __init__(self, inputs, embed_text, tokenizer):
        self.captions = list(inputs['caption'])
        self.index = list(inputs['index'])
        # All captions are tokenised once, up front, padded to a common length.
        self.encoded_captions = tokenizer(
            list(self.captions),
            padding=True,
            truncation=True,
            max_length=CFG.max_length,
        )
        self.target_embedding = embed_text

    def __getitem__(self, idx):
        """Return the tokenizer fields for row ``idx`` plus its ``"target"`` embedding."""
        item = {}
        for field, column in self.encoded_captions.items():
            item[field] = torch.tensor(column[idx])
        # The 'index' column holds the key of this caption's embedding.
        embedding_key = self.index[idx]
        item["target"] = torch.tensor(self.target_embedding[embedding_key])
        return item

    def __len__(self):
        """Number of caption rows in the dataset."""
        return len(self.index)

def build_loaders(dataframe, text_embed, tokenizer, mode):
    """Wrap ``dataframe`` in a ``customImageDataset`` and return its DataLoader.

    Batches are shuffled only when ``mode == "train"``; batch size and worker
    count come from ``CFG``.
    """
    is_training = mode == "train"
    return torch.utils.data.DataLoader(
        customImageDataset(dataframe, text_embed, tokenizer),
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_training,
    )

## Loss function

In [None]:
# Mean-squared error between the model's projected text embedding and the
# target embedding; used by the training and validation loops in main().
criterion = nn.MSELoss()

## Training section

In [None]:
# TODO: temporary data-loading cell with hard-coded Google Drive paths —
# remove or parameterise it before sharing the notebook.

# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with open('/content/drive/MyDrive/ccs_synthetic_sub/openai_text_embedding_01.pickle', 'rb') as f:
    text_embed = pickle.load(f)

df = pd.read_csv('/content/drive/MyDrive/ccs_synthetic_sub/Thai_translated/thai_translated_01.csv')

# Positional row id, used by the dataset to look up each caption's target in
# text_embed. It is sized from the frame itself instead of a hard-coded
# 1,000,000-element series, so a larger CSV no longer yields NaN (float) ids.
df = df.assign(index=np.arange(len(df), dtype="int32"))
df

Unnamed: 0,caption,index
0,เก้าอี้ว่างที่นั่งอยู่หน้าหน้าต่าง,0
1,บ้านภายใต้การก่อสร้างที่มีการสร้างอาคาร,1
2,เด็กน้อยที่นั่งอยู่ในตะกร้า,2
3,ชายในผ้ากันเปื้อน<_>ทํางานในลังปลา,3
4,วิวจากสะพานเหนือแม่น้ําในเซ็นทรัลปาร์ค,4
...,...,...
999995,ผู้เล่นฟุตบอล<_>ได้แข่งกัน,999995
999996,คอลเลกชันของมือวาดภาพการผจญภัยทะเลบนกระดานดําภ...,999996
999997,คนที่สวมชุดที่มีสายกีต้าร์และยืนขึ้นโดยมีคนไมโ...,999997
999998,แผนภาพสีขาวและสีเขียวของวงจรที่มีการคลิกเมาส์บ...,999998


In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group_lr = (group["lr"] for group in optimizer.param_groups)
    return next(first_group_lr, None)

In [None]:
"""
- tqdm track loss : finish
- finish valid test : finish
- lr_scheduler : ?
- batch synchronize : finish
- df match dataloader : untest
"""

def _train_one_epoch(model, loader, optimizer, desc=None):
    """Run one optimisation pass over ``loader`` and return the mean training loss."""
    model.train()
    loss_sum, n_seen = 0.0, 0
    progress = tqdm(loader, total=len(loader), desc=desc)
    for batch in progress:
        batch = {k: v.to(CFG.device) for k, v in batch.items()}
        y_pred = model(batch)
        loss = criterion(y_pred, batch["target"].squeeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the value so the progress bar shows a plain float
        # and does not keep the autograd graph alive.
        batch_loss = loss.item()
        count = batch["target"].size(0)
        loss_sum += batch_loss * count
        n_seen += count
        progress.set_postfix(train_loss=batch_loss, lr=get_lr(optimizer))
    return loss_sum / n_seen if n_seen else None


def _validate(model, loader, desc=None):
    """Return the sample-weighted mean loss over ``loader``, or None if it is empty."""
    model.eval()
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        progress = tqdm(loader, total=len(loader), desc=desc)
        for batch in progress:
            # Batches from the DataLoader are already tensors; just move them.
            batch = {k: v.to(CFG.device) for k, v in batch.items()}
            y_pred = model(batch)
            batch_loss = criterion(y_pred, batch["target"].squeeze(1)).item()
            count = batch["target"].size(0)
            loss_sum += batch_loss * count
            n_seen += count
            progress.set_postfix(valid_loss=loss_sum / n_seen)
    return loss_sum / n_seen if n_seen else None


def main(df):
    """Train ``TextModel`` to regress caption embeddings and keep the best checkpoint.

    The frame is split with ``make_train_valid_dfs``; targets are read from
    the module-level ``text_embed`` and the loss from the module-level
    ``criterion``. After every epoch the mean validation loss drives both the
    ``ReduceLROnPlateau`` scheduler and the best-model checkpoint, which is
    written to ``"text_MSE.pt"``.

    Args:
        df: DataFrame with the ``'caption'`` and ``'index'`` columns expected
            by ``customImageDataset``.
    """
    train_df, valid_df = make_train_valid_dfs(df)
    tokenizer = CamembertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders(train_df, text_embed, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, text_embed, tokenizer, mode="valid")

    model = TextModel().to(CFG.device)
    # Separate parameter groups: a small LR for the encoder, a larger LR plus
    # weight decay for the projection head.
    params = [
        {"params": model.text_encoder.parameters(),
         "lr": CFG.text_encoder_lr},
        {"params": model.text_projection.parameters(),
         "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )

    best_loss = float('inf')

    for epoch in range(CFG.epochs):
        tag = f"epoch {epoch + 1}/{CFG.epochs}"
        _train_one_epoch(model, train_loader, optimizer, desc=f"{tag} [train]")
        valid_loss = _validate(model, valid_loader, desc=f"{tag} [valid]")

        # The checkpoint and the scheduler act once per epoch on the epoch
        # mean, not on individual (noisy) validation batches.
        if valid_loss is not None:
            if valid_loss < best_loss:
                best_loss = valid_loss
                torch.save(model.state_dict(), "text_MSE.pt")
                print("Saved Best Model!")
            lr_scheduler.step(valid_loss)

        torch.cuda.empty_cache()


In [None]:
# Launch training on the captions DataFrame loaded earlier in the notebook;
# the best checkpoint is written to "text_MSE.pt".
main(df)