# msrun --worker_num=8 --local_worker_num=8 --master_port=8118 --join=True bert_classify.py
import mindspore as ms
import numpy as np
from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from mindnlp.dataset import load_dataset
from mindnlp.peft import get_peft_model, PeftModel, LoraConfig, TaskType
from mindnlp.engine import TrainingArguments, Trainer
from mindspore import nn
from mindspore.dataset import GeneratorDataset
import json

# # Enable communication initialization (uncomment for distributed training)
# from mindspore.communication import init
# from mindspore import context
# context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
# init()

# Load the pretrained BERT tokenizer and bert-base-uncased with a fresh
# 2-class sequence-classification head (binary SST-2 sentiment labels).
# Both calls download/read checkpoint files as a side effect.
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased",
    num_labels=2
)

import mindspore.dataset as ds
from mindspore.communication import get_rank, get_group_size

# Query this process's rank and the total worker count (intended for the
# sharded-dataset variants commented out below; currently unused).
# NOTE(review): get_rank()/get_group_size() normally require init() — which
# is commented out above — to have been called first. Confirm that msrun
# initializes communication implicitly, otherwise these calls will fail.
rank_id = get_rank()
rank_size = get_group_size()

class MyDataset():
    """In-memory random-access dataset over a JSON file of examples.

    The file must contain a JSON array of objects, each with the keys
    'text', 'label' and 'task'. Items are returned as (text, label, task)
    tuples, matching the column_names given to GeneratorDataset.
    """

    def __init__(self, path):
        # Explicit encoding: the platform default may not be UTF-8 and
        # would mis-decode non-ASCII text in the JSON file.
        with open(path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

    def __getitem__(self, index):
        # Single lookup instead of indexing self.data[index] three times.
        item = self.data[index]
        return (item['text'], item['label'], item['task'])

    def __len__(self):
        return len(self.data)
    
# Load the three SST-2 splits from disk, then rebind the same names to
# MindSpore GeneratorDataset wrappers (the raw MyDataset objects are only
# needed as generator sources). Shuffling is disabled for reproducibility.
train_dataset = MyDataset('./sst2/train.json')
eval_dataset = MyDataset('./sst2/dev.json')
test_dataset = MyDataset('./sst2/test.json')
# Sharded variants for data-parallel training (one shard per worker):
# train_dataset = GeneratorDataset(source=train_dataset, column_names=['text','label','task'], shuffle=False, num_shards=rank_size, shard_id=rank_id)
# eval_dataset = GeneratorDataset(source=eval_dataset, column_names=['text','label','task'], shuffle=False, num_shards=rank_size, shard_id=rank_id)
# test_dataset = GeneratorDataset(source=test_dataset, column_names=['text','label','task'], shuffle=False, num_shards=rank_size, shard_id=rank_id)

train_dataset = GeneratorDataset(source=train_dataset, column_names=['text','label','task'], shuffle=False)
eval_dataset = GeneratorDataset(source=eval_dataset, column_names=['text','label','task'], shuffle=False)
test_dataset = GeneratorDataset(source=test_dataset, column_names=['text','label','task'], shuffle=False)


# train_dataset = load_dataset('json', data_files='./sst2/train.json', num_shards=rank_size, shard_id=rank_id)
# eval_dataset = load_dataset('json', data_files='./sst2/dev.json', num_shards=rank_size, shard_id=rank_id)
# test_dataset = load_dataset('json', data_files='./sst2/test.json', num_shards=rank_size, shard_id=rank_id)
def tokenize(text, label, task):
    """Tokenize one example, padded/truncated to a fixed 128 tokens.

    Returns (input_ids, token_type_ids, attention_mask, label). The
    ``task`` column is accepted so the dataset map can consume it, but it
    is intentionally dropped from the output.
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    return encoded["input_ids"], encoded["token_type_ids"], encoded["attention_mask"], label
# Tokenize every split. tokenize() consumes the 'task' column and drops
# it, leaving the model-ready tensor columns plus integer 'labels'.
train_dataset = train_dataset.map(tokenize, 
                                  input_columns=['text','label','task'],
                                  output_columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
eval_dataset = eval_dataset.map(tokenize, 
                                  input_columns=['text','label','task'],
                                  output_columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
test_dataset = test_dataset.map(tokenize, 
                                  input_columns=['text','label','task'],
                                  output_columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
# LoRA adapter configuration: inject rank-8 adapters (alpha=32,
# dropout=0.1) into BERT's attention query/key/value projections.
# get_peft_model freezes the base model so only adapter weights (and the
# classification head, per SEQ_CLS task type) are trained.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=['query','key','value'],
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False  # training mode: adapters are trainable
)
model = get_peft_model(model, peft_config)
# Training hyperparameters: checkpoints/logs go to ./output, and
# evaluation runs at the end of every epoch.
# NOTE(review): lr=1e-3 is high for full fine-tuning but is a common
# choice for LoRA, where only low-rank adapter weights are updated.
train_args = TrainingArguments(
    'output',  # output dir; was f'output' — a pointless f-string with no placeholders
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=0.001,
    num_train_epochs=10,
    evaluation_strategy='epoch'
)
def compute_metrics(eval_preds):
    """Compute classification accuracy from an (predictions, labels) pair.

    Used by the Trainer at each evaluation; returns {"accuracy": float}.
    """
    predictions, references = eval_preds
    acc = nn.Accuracy('classification')
    acc.clear()
    acc.update(predictions, references)
    return {"accuracy": acc.eval()}
# Wire everything into the Trainer and run the fine-tuning loop.
# Evaluation happens once per epoch (evaluation_strategy='epoch' above),
# reporting accuracy via compute_metrics.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()