In [2]:
import pandas as pd
import numpy as np
import torch

In [44]:
data_path = 'data/'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Prepare data

In [66]:
train = pd.read_csv(data_path + 'train.csv')
val = pd.read_csv(data_path + 'val.csv')
test = pd.read_csv(data_path + 'test.csv')

train = train[train['label'] == 0]
val = val[val['label'] == 0]
test = test[test['label'] == 0]

# DataLoader

In [67]:
from torch.utils.data import Dataset, DataLoader

class ImpoliteDataset(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx]

In [68]:
train_dataset = ImpoliteDataset(train['text'].tolist())
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = ImpoliteDataset(val['text'].tolist())
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataset = ImpoliteDataset(test['text'].tolist())
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# Model

In [48]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

text_generation_tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
text_generation_model = AutoModelForCausalLM.from_pretrained("facebook/xglm-564M")

In [42]:
classification_name = "airesearch/wangchanberta-base-att-spm-uncased"
classification_tokenizer = AutoTokenizer.from_pretrained(classification_name)
classification_model = AutoModelForSequenceClassification.from_pretrained('./checkpoints/classifier/new_5')

In [43]:
for param in text_generation_model.parameters():
    param.requires_grad = False
    
for param in classification_model.parameters():
    param.requires_grad = False

In [53]:
from torch import nn

class PrefixModel(nn.Module):
    def __init__(self, seq_len, prefix_len, mid_dim=512):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(seq_len, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, prefix_len),
        )

    def forward(self, x):
        x = self.layer(x)
        return x

In [54]:
class StyleTransferModel(nn.Module):
    def __init__(self, prefix_len, text_generation_func, mid_dim=512, max_token_length=256):
        super().__init__()
        self.prefix_model = PrefixModel(max_token_length, prefix_len, mid_dim)
        self.text_generation_func = text_generation_func
        self.max_token_length = max_token_length

    def forward(self, x):
        tokenized = self.tokenizer(x, return_tensors='pt', padding=True, truncation=True, max_length=self.max_token_length)
        prefix = self.prefix_model(tokenized['input_ids'])
        input_ids = torch.cat([prefix, tokenized['input_ids']], dim=1)
        generated = self.text_generation_func.generate(input_ids)
        return generated

In [55]:
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)

class GenerationModel():
    def __init__(self, model, tokenizer, max_token_length=256):
        self.model = model
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length
        
    def generate_text(self, inputs):
        output = self.model.generate(
                                        input_ids=inputs, 
                                        top_k=50, temperature=0.7, 
                                        max_length=256, early_stopping=True, 
                                        no_repeat_ngram_size=2,
                                        logits_processor=LogitsProcessorList([
                                            MinLengthLogitsProcessor(15, eos_token_id=self.generation_config.eos_token_id),
                                            ])
                                    )
        generated_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return generated_text

In [56]:
from evaluate import load

classification_loss = nn.CrossEntropyLoss()

def classification_loss(classifier, tokenizer, generated_text, max_length=256):
    tokenized_text = tokenizer(generated_text, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
    out = classifier(**tokenized_text)
    target = torch.ones_like(out.logits.argmax(dim=1))
    return classification_loss(out, target)

def content_loss(bert_score, pred, ref, model_type="bert-base-multilingual-cased"):
    results = bert_score.compute(predictions=pred, references=ref, lang="th", model_type=model_type)
    loss = -1 * results['f1'].requires_grad_(True).mean().to(device)
    return loss

In [57]:
generation_model = GenerationModel(text_generation_model, text_generation_tokenizer)
style_transfer_model = StyleTransferModel(10, generation_model.generate_text)

In [47]:
from torchsummaryX import summary

sample_text = []

summary(style_transfer_model,

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)