In [1]:
import pandas as pd
import numpy as np
import torch

In [2]:
# Directory prefix that is concatenated with the CSV file names when loading data.
data_path = 'data/'

# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Prepare data

In [3]:
# Read the three CSV splits found under `data_path`.
train, val, test = (
    pd.read_csv(data_path + file_name)
    for file_name in ('train.csv', 'val.csv', 'test.csv')
)

# Keep only the rows whose `label` column equals 0.
# NOTE(review): presumably label 0 marks the "impolite" class — confirm against the data.
train, val, test = (
    frame.loc[frame['label'].eq(0)]
    for frame in (train, val, test)
)

# DataLoader

In [4]:
from torch.utils.data import Dataset, DataLoader

class ImpoliteDataset(Dataset):
    """Minimal map-style dataset that wraps an indexable collection of samples.

    Args:
        X: Sized, indexable collection of samples. Items are returned as
            stored, without any transformation.
    """

    def __init__(self, X):
        self.X = X

    def __len__(self) -> int:
        """Return the number of stored samples."""
        sample_count = len(self.X)
        return sample_count

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.X[idx]
        return sample

In [100]:
# Wrap the `text` column of each split in a dataset.
train_dataset = ImpoliteDataset(train['text'].tolist())
val_dataset = ImpoliteDataset(val['text'].tolist())
test_dataset = ImpoliteDataset(test['text'].tolist())

# Only the training split is shuffled. Validation and test are iterated in a
# fixed order so that evaluation runs are reproducible and comparable between
# epochs (shuffling brings no benefit when no parameters are updated).
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Model

In [94]:
from torch import nn

class PrefixModel(nn.Module):
    """Three-layer MLP that maps a prefix seed to ``prefix_len`` integer values.

    Architecture: Linear(prefix_seed_length -> mid_dim) -> ReLU ->
    Linear(mid_dim -> mid_dim) -> ReLU -> Linear(mid_dim -> prefix_len) -> ReLU.

    Args:
        prefix_seed_length: Size of the last dimension of the seed tensor.
        prefix_len: Number of values produced per seed row.
        mid_dim: Width of the two hidden layers.
    """

    def __init__(self, prefix_seed_length, prefix_len, mid_dim=512):
        super().__init__()
        # Modules are listed in forward order; nn.Sequential numbers them
        # 0..5, which fixes the ``layer.<index>`` parameter names.
        stack = [
            nn.Linear(prefix_seed_length, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, prefix_len),
            nn.ReLU(),
        ]
        self.layer = nn.Sequential(*stack)

    def forward(self, x):
        """Run the MLP on ``x`` and return the result cast to ``torch.long``.

        The input is cast to float before the linear layers. Because the last
        module is a ReLU, the returned integers are non-negative.
        """
        # NOTE(review): casting to an integer dtype detaches the result from
        # autograd, so gradients cannot flow back into these layers through
        # the returned tensor — confirm this is intended.
        return self.layer(x.float()).long()

In [95]:
class StyleTransferModel(nn.Module):
    """Prepend a learned prefix to the input ids and hand them to a generator.

    Args:
        prefix_seed_length: Size of the last dimension of ``prefix_seed``;
            forwarded to ``PrefixModel``.
        prefix_len: Number of prefix positions produced by ``PrefixModel``.
        prefix_seed: Tensor fed to the prefix model on every forward pass.
            It is stored as a plain attribute (not a parameter or buffer), so
            ``Module.to()`` does not move it; the caller must place it on the
            right device.
        text_generation_func: Callable that receives the concatenated id
            tensor and returns the generated output.
        mid_dim: Hidden width forwarded to ``PrefixModel``.
    """

    def __init__(self, prefix_seed_length, prefix_len, prefix_seed, text_generation_func, mid_dim=512):
        super().__init__()
        self.prefix_model = PrefixModel(prefix_seed_length, prefix_len, mid_dim)
        self.prefix_seed = prefix_seed
        self.text_generation_func = text_generation_func

    def forward(self, x):
        """Return ``text_generation_func(cat([prefix, x], dim=1))``.

        Args:
            x: 2-D tensor of input ids, one row per sample.

        Returns:
            Whatever ``text_generation_func`` returns for the prefixed ids.
        """
        prefix = self.prefix_model(self.prefix_seed)
        # A single prefix row is shared by every sample: broadcast it over the
        # batch so that batches larger than one can be concatenated along dim=1.
        # With a batch of one (or a seed that already has one row per sample)
        # this is a no-op.
        if prefix.size(0) == 1 and x.size(0) > 1:
            prefix = prefix.expand(x.size(0), -1)
        input_ids = torch.cat([prefix, x], dim=1)
        return self.text_generation_func(input_ids)

In [102]:
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

class GenerationModel():
    """Thin wrapper that generates token ids with ``model`` and decodes them to text.

    Args:
        model: Object exposing ``generate(...)`` and
            ``generation_config.eos_token_id``.
        tokenizer: Object exposing ``batch_decode(...)``.
        max_token_length: Value passed to ``generate`` as ``max_length``.
    """

    def __init__(self, model, tokenizer, max_token_length=256):
        self.model = model
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length

    def generate_text(self, inputs):
        """Generate continuations for ``inputs`` and return them as decoded strings.

        Args:
            inputs: Token ids passed to ``model.generate`` as ``input_ids``.

        Returns:
            The result of ``tokenizer.batch_decode`` with special tokens skipped.
        """
        # The end-of-sequence token is suppressed until at least 15 tokens exist.
        min_length_processor = MinLengthLogitsProcessor(
            15, eos_token_id=self.model.generation_config.eos_token_id
        )
        # NOTE(review): `temperature` normally only matters for sampling and
        # `early_stopping` for beam search; neither mode is enabled here, so
        # these two arguments probably have no effect — confirm.
        output = self.model.generate(
            input_ids=inputs,
            # top_k=50,
            temperature=0.7,
            # Honour the limit configured on the instance instead of a
            # hard-coded 256 (the default value keeps the old behaviour).
            max_length=self.max_token_length,
            early_stopping=True,
            no_repeat_ngram_size=2,
            logits_processor=LogitsProcessorList([min_length_processor]),
        )
        generated_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return generated_text

In [143]:
from evaluate import load

# Cross-entropy criterion handed to `classification_loss` by the calling code.
classification_criterion = nn.CrossEntropyLoss()


def classification_loss(classifier, tokenizer, classification_criterion, generated_text, max_length=256):
    """Score ``generated_text`` with ``classifier`` against class index 1.

    Args:
        classifier: Callable taking the tokenizer output as keyword arguments
            and returning an object with a ``logits`` attribute.
        tokenizer: Callable that turns text into model inputs; its result must
            support ``.to(device)``.
        classification_criterion: Loss callable applied as ``(logits, target)``.
        generated_text: Text (or list of texts) to classify.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        The criterion value for the logits with every target set to 1.
        NOTE(review): class 1 is presumably the desired ("polite") label — confirm.
    """
    encoded = tokenizer(
        generated_text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)
    logits = classifier(**encoded).logits
    # Only the shape, dtype and device of the argmax result are used: the
    # target is a tensor of ones laid out like the predicted class indices.
    predicted = torch.argmax(logits, dim=1)
    target = torch.ones_like(predicted)
    return classification_criterion(logits, target)

def content_loss(bert_score, pred, ref, model_type="bert-base-multilingual-cased"):
    """Return the negated mean BERTScore F1 between ``pred`` and ``ref``.

    Args:
        bert_score: Object whose ``compute(...)`` returns a mapping with an
            ``'f1'`` list of floats.
        pred: Predicted texts.
        ref: Reference texts.
        model_type: Model name forwarded to ``bert_score.compute``.

    Returns:
        Scalar tensor ``-mean(f1)`` moved to the global ``device``.
    """
    scores = bert_score.compute(
        predictions=pred,
        references=ref,
        lang="th",
        model_type=model_type,
    )
    # NOTE(review): the F1 values arrive as plain Python floats, so this tensor
    # is a new leaf; enabling requires_grad on it does not link the loss to
    # whatever produced `pred`.
    f1 = torch.FloatTensor(scores['f1']).requires_grad_(True)
    mean_f1 = f1.mean().to(device)
    return -1 * mean_f1

# Test model summary

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

# Causal language model used as the text generator, together with its tokenizer.
text_generation_tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
text_generation_model = AutoModelForCausalLM.from_pretrained("facebook/xglm-564M")

# The classifier tokenizer is loaded by model name; the classifier weights
# come from a local checkpoint directory.
classification_name = "airesearch/wangchanberta-base-att-spm-uncased"
classification_tokenizer = AutoTokenizer.from_pretrained(classification_name)
classification_model = AutoModelForSequenceClassification.from_pretrained('./checkpoints/classifier/new_5')

# Freeze every parameter of both pretrained networks.
text_generation_model.requires_grad_(False)
classification_model.requires_grad_(False)

In [104]:
from torchsummaryX import summary

# Two short strings: one is used as the prefix seed, the other as a sample input.
prefix_seed = "สวัสดี0"
sample_text = ["สวัสดี1"]

# Token ids for the seed string and for the one-element sample batch.
tokenized_prefix = text_generation_tokenizer(
    prefix_seed,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=256,
)["input_ids"]
tokenized = text_generation_tokenizer(
    sample_text,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=256,
)["input_ids"]

# The seed length is the number of token ids in the seed row; the prefix
# model is asked for 10 output positions.
generation_model = GenerationModel(text_generation_model, text_generation_tokenizer)
style_transfer_model = StyleTransferModel(
    tokenized_prefix.size(1),
    10,
    tokenized_prefix,
    generation_model.generate_text,
)

In [99]:
# Print torchsummaryX's per-layer table for one forward pass on the sample ids.
summary(style_transfer_model, tokenized)

tensor([[  4112,      0,   4306,     97,   1611,      0,      0,   9950,    875,
           7363,      2, 167498,    285]])
                              Kernel Shape Output Shape    Params Mult-Adds
Layer                                                                      
0_prefix_model.layer.Linear_0     [3, 512]     [1, 512]    2.048k    1.536k
1_prefix_model.layer.ReLU_1              -     [1, 512]         -         -
2_prefix_model.layer.Linear_2   [512, 512]     [1, 512]  262.656k  262.144k
3_prefix_model.layer.ReLU_3              -     [1, 512]         -         -
4_prefix_model.layer.Linear_4    [512, 10]      [1, 10]     5.13k     5.12k
5_prefix_model.layer.ReLU_5              -      [1, 10]         -         -
----------------------------------------------------------------------------
                        Totals
Total params          269.834k
Trainable params      269.834k
Non-trainable params       0.0
Mult-Adds               268.8k


  df_sum = df.sum()


Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_prefix_model.layer.Linear_0,"[3, 512]","[1, 512]",2048.0,1536.0
1_prefix_model.layer.ReLU_1,-,"[1, 512]",,
2_prefix_model.layer.Linear_2,"[512, 512]","[1, 512]",262656.0,262144.0
3_prefix_model.layer.ReLU_3,-,"[1, 512]",,
4_prefix_model.layer.Linear_4,"[512, 10]","[1, 10]",5130.0,5120.0
5_prefix_model.layer.ReLU_5,-,"[1, 10]",,


In [105]:
# Smoke test: run the whole pipeline on the sample ids and display the generated text.
style_transfer_model(tokenized)

['Man source porque สวัสดี1. ...1, ೭੧ ೨ ੧೨೭']

# Training

In [139]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

# Causal language model used as the text generator, together with its tokenizer.
text_generation_tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
text_generation_model = AutoModelForCausalLM.from_pretrained("facebook/xglm-564M")

# The classifier tokenizer is loaded by model name; the classifier weights
# come from a local checkpoint directory.
classification_name = "airesearch/wangchanberta-base-att-spm-uncased"
classification_tokenizer = AutoTokenizer.from_pretrained(classification_name)
classification_model = AutoModelForSequenceClassification.from_pretrained('./checkpoints/classifier/new_5')

# Freeze every parameter of both pretrained networks, then move them to `device`.
text_generation_model.requires_grad_(False)
classification_model.requires_grad_(False)

text_generation_model = text_generation_model.to(device)
classification_model = classification_model.to(device)

In [140]:
# Seed text for the prefix model, tokenized once and kept on `device`.
prefix_seed = "คำสุภาพของประโยค "
tokenized_prefix = text_generation_tokenizer(
    prefix_seed,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=256,
)["input_ids"].to(device)

In [141]:
# Generator wrapper plus the trainable prefix model that feeds it.
generation_model = GenerationModel(text_generation_model, text_generation_tokenizer)
style_transfer_model = StyleTransferModel(
    tokenized_prefix.size(1),
    10,
    tokenized_prefix,
    generation_model.generate_text,
)
# nn.Module.to moves the module together with all of its submodules
# (prefix_model included) in place, so a single call is sufficient.
style_transfer_model = style_transfer_model.to(device)

# AdamW over the registered parameters of the style-transfer module.
optimizer = torch.optim.AdamW(style_transfer_model.parameters(), lr=5e-5)

# Loss components: cross-entropy for the classifier term, BERTScore for content.
classification_criterion = nn.CrossEntropyLoss()
bert_score = load("bertscore")
bert_score_model = "bert-base-multilingual-cased"

In [144]:
import os

from tqdm import tqdm

epochs = 10


def _batch_loss(batch):
    """Generate text for one batch of source strings and return classification + content loss."""
    encoding = text_generation_tokenizer(
        batch, return_tensors='pt', padding=True, truncation=True, max_length=256
    ).input_ids
    encoding = encoding.to(device)
    output = style_transfer_model(encoding)
    return (
        classification_loss(classification_model, classification_tokenizer, classification_criterion, output)
        + content_loss(bert_score, output, batch, bert_score_model)
    )


# NOTE(review): both loss terms are computed from decoded text (re-tokenized for
# the classifier, scored as strings by BERTScore), so they do not appear to be
# connected by autograd to the prefix model's parameters; optimizer.step() is
# then not expected to change them. A differentiable objective is needed for
# this loop to actually train — confirm before a long run.

# Mean loss per epoch, kept for later inspection or plotting.
train_losses = []
val_losses = []
for epoch in range(epochs):
    print(f'Epoch {epoch}')

    # Switch back to training mode: the validation pass below sets eval().
    style_transfer_model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()
        optimizer.step()
        # Accumulate instead of overwriting so the epoch value is an average.
        train_loss += loss.item()
    train_loss /= max(len(train_dataloader), 1)
    train_losses.append(train_loss)

    style_transfer_model.eval()
    val_loss = 0.0
    # No gradients are needed for validation.
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            val_loss += _batch_loss(batch).item()
    val_loss /= max(len(val_dataloader), 1)
    val_losses.append(val_loss)

    # A plain nn.Module has no save_pretrained(); store its state dict inside
    # the per-epoch checkpoint directory instead.
    checkpoint_dir = f'./checkpoints/model/{epoch}'
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(style_transfer_model.state_dict(), os.path.join(checkpoint_dir, 'style_transfer_model.pt'))

    # Report the averaged training loss (not the last validation batch's loss).
    print(f'Epoch {epoch} train loss: {train_loss} val loss: {val_loss}')

Epoch 0


  1%|          | 45/4753 [00:38<1:06:17,  1.18it/s]


KeyboardInterrupt: 