In [None]:
!pip install torchkeras peft

In [None]:


############先是所有的配置参数.

import os

# 导入常用模块
import numpy as np

import torch
from torch import nn 
from torch.utils.data import Dataset,DataLoader 


# 配置参数
from argparse import Namespace
# Plain attribute container holding every tunable setting used by the cells below.
cfg = Namespace()

# dataset: column names read by preprocess()
cfg.prompt_column = 'prompt'
cfg.response_column = 'response'
cfg.history_column =None
cfg.source_prefix = '' # prefix prepended to every prompt

# Token budgets passed as max_length when encoding the prompt and the response.
cfg.max_source_length = 128 
cfg.max_target_length = 128

# model
cfg.model_name_or_path = 'THUDM/chatglm2-6b'  # remote hub id; the author also notes 'THUDM/chatglm-6b'
cfg.quantization_bit = None # 4 or 8 may be chosen, for prediction only


# train
# NOTE(review): the fit() call at the end of this file passes epochs=20 explicitly,
# so cfg.epochs is not what controls that run.
cfg.epochs = 100 
cfg.lr = 5e-3
cfg.batch_size = 2
cfg.gradient_accumulation_steps = 1 # gradient accumulation

# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available()else "cpu") 






#==========定义知识样本.######先处理我们的数据.
from torch.utils.data import Dataset,DataLoader 
import transformers
from transformers import  AutoModel,AutoTokenizer,AutoConfig,DataCollatorForSeq2Seq
# Tokenizer for the configured checkpoint; trust_remote_code=True lets transformers
# run the tokenizer implementation shipped in the model repository.
# (The transformers imports that used to be repeated below this call are already
# done just above it, so the duplicate lines were dropped.)
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_name_or_path, trust_remote_code=True)


import pandas as pd 
# The single piece of knowledge to teach the model: a keyword and its description.
keyword = '梦中情炉'

description = '''梦中情炉一般指的是炼丹工具torchkeras。
这是一个通用的pytorch模型训练模版工具。
torchkeras是一个三好炼丹炉：好看，好用，好改。
她有torch的灵动，也有keras的优雅，并且她的美丽，无与伦比。
所以她的作者一个有毅力的吃货给她取了一个别名叫做梦中情炉。'''




#对prompt使用一些简单的数据增强的方法，以便更好地收敛。
def get_prompt_list(keyword):
    """Return eight phrasings of a question about ``keyword``.

    The first entry is the keyword on its own; the remaining ones embed it in
    fixed question templates. Used as lightweight prompt augmentation.
    """
    templates = (
        '{}',
        '你知道{}吗?',
        '{}是什么？',
        '介绍一下{}',
        '你听过{}吗?',
        '啥是{}？',
        '{}是何物？',
        '何为{}？',
    )
    return [template.format(keyword) for template in templates]

# Pair every augmented prompt with the same target description.
data = [dict(prompt=question, response=description)
        for question in get_prompt_list(keyword)]
dfdata = pd.DataFrame(data)


import datasets 
# The training and validation sets are one and the same Dataset object.
ds_val_raw = datasets.Dataset.from_pandas(dfdata)
ds_train_raw = ds_val_raw
#这是支持 history列处理，并且按照batch预处理数据的方法。

def preprocess(examples):
    """Tokenize a batch of prompt/response pairs into fixed-length training rows.

    Reads the column names, prefix and length budgets from the global ``cfg``
    and encodes with the global ``tokenizer``. Rows whose prompt or response is
    empty are skipped.

    Args:
        examples: mapping from column name to a list of values (one batch).

    Returns:
        dict with ``"input_ids"`` and ``"labels"``; every row has length
        ``cfg.max_source_length + cfg.max_target_length + 1``. In ``labels``
        the prompt positions and the padding are replaced by -100.
    """
    # +1 reserves the slot of the EOS token appended after the response. Without
    # it, a pair that uses both budgets completely produced a row one token
    # longer than all the others (a negative pad length pads nothing).
    max_seq_length = cfg.max_source_length + cfg.max_target_length + 1
    pad_id, eos_id = tokenizer.pad_token_id, tokenizer.eos_token_id

    model_inputs = {"input_ids": [], "labels": []}
    prompts = examples[cfg.prompt_column]
    responses = examples[cfg.response_column]

    for i in range(len(prompts)):
        query, answer = prompts[i], responses[i]
        if not (query and answer):
            continue

        history = examples[cfg.history_column][i] if cfg.history_column is not None else None
        prompt = cfg.source_prefix + tokenizer.build_prompt(query, history)

        a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
                                 max_length=cfg.max_source_length)
        b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
                                 max_length=cfg.max_target_length)

        input_ids = a_ids + b_ids + [eos_id]
        # Prompt positions carry the pad id for now; they become -100 below.
        labels = [pad_id] * len(a_ids) + b_ids + [eos_id]

        pad_len = max_seq_length - len(input_ids)
        input_ids = input_ids + [pad_id] * pad_len
        labels = labels + [pad_id] * pad_len
        # Every pad id in the labels (prompt part and padding) is replaced by -100.
        labels = [(l if l != pad_id else -100) for l in labels]

        model_inputs["input_ids"].append(input_ids)
        model_inputs["labels"].append(labels)
    return model_inputs


# Tokenize both splits with preprocess(), in batches and with 4 worker processes;
# the raw text columns are dropped so only the columns returned by preprocess() remain.
ds_train = ds_train_raw.map(
    preprocess,
    batched=True,
    num_proc=4,
    remove_columns=ds_train_raw.column_names
)

ds_val = ds_val_raw.map(
    preprocess,
    batched=True,
    num_proc=4,
    remove_columns=ds_val_raw.column_names
)
# Collator used to assemble batches. padding=False: no padding is requested from
# the collator; preprocess() already pads the rows itself.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=None,
    label_pad_token_id=-100,
    pad_to_multiple_of=None,
    padding=False
)

# Only the training loader shuffles.
dl_train = DataLoader(ds_train,batch_size = cfg.batch_size,
                      num_workers = 2, shuffle = True, collate_fn = data_collator 
                     )
dl_val = DataLoader(ds_val,batch_size = cfg.batch_size,
                      num_workers = 2, shuffle = False, collate_fn = data_collator 
                     )




# Number of batches per training epoch.
print(len(dl_train))




# Load the model configuration and weights for the configured checkpoint.
config = AutoConfig.from_pretrained(cfg.model_name_or_path, trust_remote_code=True)



model = AutoModel.from_pretrained(cfg.model_name_or_path,config=config,
                                  trust_remote_code=True, device_map='auto').half() # half precision for GPU training; author's note: always use device_map='auto' so placement on GPU/CPU memory is automatic

# Optional quantization; author's note: usable for testing, not recommended unless the hardware is very limited.
if cfg.quantization_bit is not None:
    print(f"Quantized to {cfg.quantization_bit} bit")
    model = model.quantize(cfg.quantization_bit)
    
# Moving to the GPU by hand is left disabled (device_map='auto' is used above).
# model = model.cuda();


# # ChatGLM can be tested conveniently in Jupyter by registering a magic command
# from torchkeras.chat import ChatGLM 
# chatglm = ChatGLM(model,tokenizer)

# Smoke test: ask one question to check that the model loaded correctly.
print('测试一下是否加载成功')
response,history= model.chat(tokenizer,query='世界上最高的山峰是什么？',history=[])
print(response)




#定义一条知识样本~#===========================


from peft import get_peft_model, AdaLoraConfig, TaskType

# Reduce GPU memory use during training.
model.config.use_cache=False
model.supports_gradient_checkpointing = True  #
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

peft_config = AdaLoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False,
    r=8,
    lora_alpha=32, lora_dropout=0.1,
    # ChatGLM2 uses one fused attention projection named "query_key_value".
    # The previous ["query", "value"] only matched it through plain suffix
    # matching ("...query_key_value" ends with "value"); PEFT versions that
    # match whole module names find no target and raise. Naming the layer
    # explicitly selects the same module under both matching rules.
    target_modules=["query_key_value"]
)

peft_model = get_peft_model(model, peft_config)

peft_model.is_parallelizable = True
peft_model.model_parallel = True
peft_model.print_trainable_parameters()





In [None]:
from accelerate import Accelerator
# Notebook-wide Accelerator (fp16 mixed precision); the KerasModel classes
# defined later in this file pick it up through the global name AC.
AC=Accelerator(mixed_precision='fp16',cpu=None,
            gradient_accumulation_steps=1)

#================over.


# Smoke test of the PEFT-wrapped model under autocast, with gradients disabled.
with AC.autocast() , torch.no_grad():

    a=peft_model.chat(tokenizer,query='世界上最高的山峰是什么',history=[],max_length=40)
    print(a,'debug!!!!!!!!!!!')


In [None]:

import sys,datetime
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

#=========设置打印信息的.
class EpochRunner:
    """Run a StepRunner over one pass of a dataloader and aggregate its logs.

    ``stage``, ``accelerator`` and ``net`` are taken from the wrapped step
    runner; ``quiet`` disables the tqdm progress bar.
    """
    def __init__(self,steprunner,quiet=False):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        self.accelerator = steprunner.accelerator
        self.net = steprunner.net
        self.quiet = quiet
        
    def __call__(self,dataloader):
        """Process at most ``n`` batches and return the epoch log.

        ``n`` is ``dataloader.size`` when that attribute exists, otherwise
        ``len(dataloader)``.

        Returns:
            dict holding the mean of every step loss, the metrics of the last
            step and the computed stateful metrics. Empty when the ``n``-th
            step was never reached (e.g. an empty dataloader).
        """
        n = dataloader.size  if hasattr(dataloader,'size') else len(dataloader)
        loop = tqdm(enumerate(dataloader,start=1), 
                    total=n,
                    file=sys.stdout,
                    disable=not self.accelerator.is_local_main_process or self.quiet,
                    ncols=100
                   )
        epoch_losses = {}
        # Bound up front so the return below cannot hit an unbound name when
        # the loop body never reaches step n.
        epoch_log = {}
        
        for step, batch in loop: 
            # Batches beyond the configured size are skipped before any work is
            # done on them; breaking only after the step would run one extra
            # (training) step.
            if step>n:
                break
            with self.accelerator.accumulate(self.net):
                step_losses,step_metrics = self.steprunner(batch)   
                step_log = dict(step_losses,**step_metrics)

                for k,v in step_losses.items():
                    epoch_losses[k] = epoch_losses.get(k,0.0)+v
                
                # Per-step debug trace, then the training log.
                print('当前step')
                if step<n:
                    loop.set_postfix(**step_log)
                    
                    if hasattr(self,'progress') and self.accelerator.is_local_main_process:
                        post_log = dict(**{'i':step,'n':n},**step_log)
                        self.progress.set_postfix(**post_log)

                else:
                    # step == n: last step, build the epoch summary.
                    epoch_metrics = step_metrics
                    epoch_metrics.update({self.stage+"_"+name:metric_fn.compute().item() 
                                     for name,metric_fn in self.steprunner.metrics_dict.items()})
                    epoch_losses = {k:v/step for k,v in epoch_losses.items()}
                    epoch_log = dict(epoch_losses,**epoch_metrics)
                    loop.set_postfix(**epoch_log)
            
                    
                    if hasattr(self,'progress') and self.accelerator.is_local_main_process:
                        post_log = dict(**{'i':step,'n':n},**epoch_log)
                        self.progress.set_postfix(**post_log)
                    
                    # Stateful metrics start fresh for the next epoch.
                    for name,metric_fn in self.steprunner.metrics_dict.items():
                        metric_fn.reset()  
        print(55555,epoch_log)
        return epoch_log


#===============修改下面代码为自己跑. 来优化性能:

from accelerate import Accelerator 
#============torchkeras来写训练代码果然牛逼,图标太牛逼了.
#======第一步设置好自定义的KerasModel
flag=0
class StepRunner:
    """Execute one forward pass and, when training, one optimization step.

    The loss is read from the ``.loss`` attribute of the network output; the
    ``loss_fn`` and ``metrics_dict`` arguments are only stored.
    """
    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net = net
        self.loss_fn = loss_fn
        self.metrics_dict = metrics_dict
        self.stage = stage
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = Accelerator() if accelerator is None else accelerator
        # Put the network in the mode that matches the stage.
        switch_mode = self.net.train if self.stage == 'train' else self.net.eval
        switch_mode()
        self.flag = 0

    def _optimize(self, loss):
        """Back-propagate ``loss``, clip gradients on sync steps, advance optimizer and scheduler."""
        accel = self.accelerator
        accel.backward(loss)
        if accel.sync_gradients:
            accel.clip_grad_norm_(self.net.parameters(), 1.0)
        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        self.optimizer.zero_grad()

    def __call__(self, batch):
        """Run one step on ``batch`` (needs ``"input_ids"`` and ``"labels"``).

        Returns:
            ``(step_losses, step_metrics)``: the gathered and summed loss under
            ``"<stage>_loss"``, and for the train stage the learning rate of the
            first parameter group (0.0 without an optimizer).
        """
        with self.accelerator.autocast():
            outputs = self.net(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss

        training = self.stage == "train"
        has_optimizer = self.optimizer is not None
        if has_optimizer and training:
            self._optimize(loss)

        summed_loss = self.accelerator.gather(loss).sum()

        # losses (or plain metrics that can be averaged)
        step_losses = {self.stage + "_loss": summed_loss.item()}

        # metrics (stateful metrics)
        step_metrics = {}
        if training:
            step_metrics['lr'] = (
                self.optimizer.state_dict()['param_groups'][0]['lr']
                if has_optimizer else 0.0
            )
        return step_losses, step_metrics
class KerasModel(torch.nn.Module):
    """Keras-style fit/evaluate/predict wrapper around a network, driven by ``accelerate``.

    The per-step and per-epoch logic is delegated to the ``StepRunner`` and
    ``EpochRunner`` class attributes, which can be replaced by custom runners.
    """
    
    StepRunner,EpochRunner = StepRunner,EpochRunner
    
    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None,tokenizer=None):
        """Store the network and its training components.

        Args:
            net: the module to train.
            loss_fn: handed to the step runner.
            metrics_dict: optional mapping of metric modules, wrapped in a ``ModuleDict``.
            optimizer: optimizer to use; Adam (lr=3e-4) over ``net.parameters()`` when omitted.
            lr_scheduler: optional scheduler handed to the training step runner.
            tokenizer: accepted but not used by this class.
        """
        super().__init__()
        self.net,self.loss_fn,self.metrics_dict = net, loss_fn, torch.nn.ModuleDict(metrics_dict) 
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(
            self.net.parameters(), lr=3e-4)
        self.lr_scheduler = lr_scheduler
        self.from_scratch = True     # no checkpoint has been loaded into the network yet

    # Author's note: the two generic save/load methods below are usually not used
    # (not adaptable enough); the module-level save_ckpt/load_ckpt override them.
    def save_ckpt(self, ckpt_path=None, accelerator= None):
        """Save the full state dict of ``self.net`` to ``ckpt_path`` (default ``self.ckpt_path``)."""
        accelerator = accelerator if accelerator is not None else self.accelerator
        net_dict = accelerator.get_state_dict(self.net)
        accelerator.save(net_dict,ckpt_path if ckpt_path is not None else self.ckpt_path)
      
    def load_ckpt(self, ckpt_path=None):
        """Load a state dict saved by ``save_ckpt`` and mark the model as no longer from scratch."""
        self.net.load_state_dict(
            torch.load(ckpt_path if ckpt_path is not None else self.ckpt_path,
            map_location='cpu'))
        self.from_scratch = False

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.net.forward(x)
    
    def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint',
            patience=5, monitor="val_loss", mode="min", callbacks=None, 
            plot=False,  wandb=False, quiet=None, 
            mixed_precision='no', cpu=False, gradient_accumulation_steps=1,dfhistorypath='dfhistory.csv'):
        """Train for up to ``epochs`` epochs with optional validation and early stopping.

        Args:
            train_data: training dataloader.
            val_data: optional validation dataloader.
            epochs: last epoch index to run (inclusive).
            ckpt_path: stored on the instance; no checkpoint is written by this method.
            patience: stop once the best ``monitor`` value is more than this many epochs old.
            monitor: key of ``self.history`` used for early stopping.
            mode: ``"min"`` or ``"max"``, the direction in which ``monitor`` improves.
            callbacks: optional list of callback objects.
            plot: when truthy, VisMetric/VisProgress callbacks are added.
            wandb: ``False`` or a project name; adds a WandbCallback.
            quiet: currently without effect, progress bars are always disabled.
            mixed_precision, cpu, gradient_accumulation_steps: stored only; the
                global accelerator ``AC`` defines the actual settings.
            dfhistorypath: CSV file the history is written to.

        Returns:
            The history as a DataFrame on the local main process, ``None`` elsewhere.
        """
        from torchkeras.utils import colorful,is_jupyter
        # Expose every argument as an instance attribute (self.dfhistorypath is read below).
        self.__dict__.update(locals())
        # Reuse the notebook-level Accelerator instead of creating a new one.
        self.accelerator = AC
        device = str(self.accelerator.device)
        device_type = '🐌'  if 'cpu' in device else ('⚡️' if 'cuda' in device else '🚀')
        self.accelerator.print(
            colorful("<<<<<< "+device_type +" "+ device +" is used >>>>>>"))
    
        self.net,self.loss_fn,self.metrics_dict,self.optimizer,self.lr_scheduler= self.accelerator.prepare(
            self.net,self.loss_fn,self.metrics_dict,self.optimizer,self.lr_scheduler)
        
        train_dataloader,val_dataloader = self.accelerator.prepare(train_data,val_data)
        train_dataloader.size = train_data.size if hasattr(train_data,'size') else len(train_data)
        train_dataloader.size = min(train_dataloader.size,len(train_dataloader))
        
        if val_data:
            val_dataloader.size = val_data.size if hasattr(val_data,'size') else len(val_data)
            val_dataloader.size = min(val_dataloader.size,len(val_dataloader))
        
        self.history = {}
        callbacks = callbacks if callbacks is not None else []
        
        if bool(plot):
            from torchkeras.kerascallbacks import VisProgress,VisMetric
            callbacks = [VisMetric(),VisProgress()]+callbacks
            
        if wandb!=False:
            from torchkeras.kerascallbacks import WandbCallback
            project = wandb if isinstance(wandb,str) else 'torchkeras'
            callbacks.append(WandbCallback(project=project))
            
        self.callbacks = [self.accelerator.prepare(x) for x in callbacks]
        
        if self.accelerator.is_local_main_process:
            [cb.on_fit_start(model = self) for cb in self.callbacks if hasattr(cb,'on_fit_start')]
                
        # Epoch 0 (only run after load_ckpt) is a pass without optimizer, see below.
        start_epoch = 1 if self.from_scratch else 0
        
        # ========================== training loop
        for epoch in range(start_epoch,epochs+1):
            # The per-epoch banner is switched off and the runners are always quiet.
            should_quiet=True

            # 1，train -------------------------------------------------  
            train_step_runner = self.StepRunner(
                    net = self.net,
                    loss_fn = self.loss_fn,
                    accelerator = self.accelerator,
                    stage="train",
                    metrics_dict=deepcopy(self.metrics_dict),
                    optimizer = self.optimizer if epoch>0 else None,
                    lr_scheduler = self.lr_scheduler if epoch>0 else None
            )
            train_epoch_runner = self.EpochRunner(train_step_runner,should_quiet)
            train_metrics = {'epoch':epoch}
            print('111111')
            train_metrics.update(train_epoch_runner(train_dataloader))
            print(train_metrics)
            for name, metric in train_metrics.items():
                    self.history[name] = self.history.get(name, []) + [metric]
            # on_train_epoch_end callbacks are deliberately not invoked here.
                    
            # 2，validate -------------------------------------------------
            if val_dataloader is not None:
                val_step_runner = self.StepRunner(
                    net = self.net,
                    loss_fn = self.loss_fn,
                    accelerator = self.accelerator,
                    stage="val",
                    metrics_dict= deepcopy(self.metrics_dict)
                )
                val_epoch_runner = self.EpochRunner(val_step_runner,should_quiet)
                with torch.no_grad():
                    val_metrics = val_epoch_runner(val_dataloader)

                for name, metric in val_metrics.items():
                    self.history[name] = self.history.get(name, []) + [metric]
                
            if self.accelerator.is_local_main_process:
                [cb.on_validation_epoch_end(model = self) for cb in self.callbacks 
                 if hasattr(cb,'on_validation_epoch_end')]

            # 3，early-stopping -------------------------------------------------
            # (Per-epoch checkpoint saving was removed by the author as too frequent.)
            self.accelerator.wait_for_everyone()
            # The monitored metric may be missing (e.g. "val_loss" without
            # val_data); indexing the history directly raised KeyError then.
            arr_scores = self.history.get(monitor)
            if arr_scores:
                best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
                if len(arr_scores)-best_score_idx>patience:
                    break
                
        if self.accelerator.is_local_main_process:   
            # Built from this instance's history; the previous code rebuilt it
            # from the unrelated global `model`, which has no `history`.
            dfhistory = pd.DataFrame(self.history)
            # [cb.on_fit_end(model = self) for cb in self.callbacks 
            #      if hasattr(cb,'on_fit_end')]
            if epoch<epochs:
                self.accelerator.print(colorful(
                        "<<<<<< {} without improvement in {} epoch,""early stopping >>>>>> \n"
                    ).format(monitor,patience))
            # self.net = self.accelerator.unwrap_model(self.net)
            # self.net.cpu()

            dfhistory.to_csv(self.dfhistorypath,index=None)
            # self.load_ckpt(ckpt_path)
            return dfhistory

    def predict(self,batch):
        """Run the network on ``batch["input_ids"]`` without gradients and return its output."""
        accelerator = Accelerator() if not hasattr(self,'accelerator') else self.accelerator
        self.net,self.loss_fn,self.metrics_dict = accelerator.prepare(
            self.net,self.loss_fn,self.metrics_dict)
        self.net.eval()
        with torch.no_grad():
            # Call the prepared network itself: the StepRunner class has no
            # `net` attribute, and there is no validation data to prepare here.
            a=self.net(input_ids=batch["input_ids"])
        return a

    def evaluate(self, val_data, quiet=False):
        """Run one validation epoch over ``val_data`` and return its metrics dict."""
        accelerator = Accelerator() if not hasattr(self,'accelerator') else self.accelerator
        self.net,self.loss_fn,self.metrics_dict = accelerator.prepare(
            self.net,self.loss_fn,self.metrics_dict)
        val_data = accelerator.prepare(val_data)
        val_step_runner = self.StepRunner(net = self.net,stage="val",
                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),
                    accelerator = accelerator)
        val_epoch_runner = self.EpochRunner(val_step_runner,quiet=quiet)
        with torch.no_grad():
            val_metrics = val_epoch_runner(val_data)
        return val_metrics
    
    def fit_ddp(self,num_processes,train_data,
            val_data=None, epochs=10, ckpt_path='checkpoint',
            patience=5, monitor="val_loss", mode="min", callbacks=None, 
            plot=True, wandb=False, quiet=None, 
            mixed_precision='no', cpu=False, gradient_accumulation_steps=1
           ):
        """Run ``fit`` in ``num_processes`` processes through ``accelerate.notebook_launcher``."""
        from accelerate import notebook_launcher
        # Positional order must match the signature of fit().
        args = (train_data,val_data,epochs,ckpt_path,patience,monitor,mode,
            callbacks,plot,wandb,quiet,mixed_precision,cpu,gradient_accumulation_steps)
        notebook_launcher(self.fit, args, num_processes=num_processes)
    
    def evaluate_ddp(self, num_processes, val_data, quiet=False):
        """Run ``evaluate`` in ``num_processes`` processes through ``accelerate.notebook_launcher``."""
        from accelerate import notebook_launcher
        args = (val_data,quiet)
        notebook_launcher(self.evaluate, args, num_processes=num_processes)









    
# Bind the step runner defined in this cell on the class (the class body already
# captured the same name when it was created).
KerasModel.StepRunner = StepRunner 


#仅仅保存lora相关的可训练参数
def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):
    """Save ``self.net`` with its own ``save_pretrained`` into ``ckpt_path``.

    Args:
        ckpt_path: target directory passed to ``save_pretrained``.
        accelerator: accelerator used to unwrap the network; defaults to
            ``self.accelerator``.
    """
    # The default None used to be dereferenced directly (AttributeError);
    # fall back to the accelerator the model already holds.
    if accelerator is None:
        accelerator = self.accelerator
    unwrap_net = accelerator.unwrap_model(self.net)
    unwrap_net.save_pretrained(ckpt_path)
    
def load_ckpt(self, ckpt_path='checkpoint'):
    """Rebuild ``self.net`` from ``ckpt_path`` on top of its current base model.

    Calls ``from_pretrained`` of the current network with the underlying
    ``base_model.model`` and the checkpoint path, then records that the model
    no longer starts from scratch.
    """
    current_net = self.net
    restore = current_net.from_pretrained
    self.net = restore(current_net.base_model.model, ckpt_path)
    self.from_scratch = False
    
# Replace the generic checkpoint methods of KerasModel with the two functions above.
KerasModel.save_ckpt = save_ckpt 
KerasModel.load_ckpt = load_ckpt 
optimizer = torch.optim.AdamW(peft_model.parameters(),lr=cfg.lr) 

# Step 2: instantiate the model wrapper
keras_model = KerasModel(peft_model,loss_fn = None,
        optimizer=optimizer) 
ckpt_path = 'chatglm2_my' # path used for saving
# Step 3 (author's note): the calls further down train, plot and store the model.







































import sys,datetime
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

#=========设置打印信息的.
class EpochRunner:
    """Run a StepRunner over one pass of a dataloader and aggregate its logs.

    ``stage``, ``accelerator`` and ``net`` are taken from the wrapped step
    runner; ``quiet`` disables the tqdm progress bar.
    """
    def __init__(self,steprunner,quiet=False):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        self.accelerator = steprunner.accelerator
        self.net = steprunner.net
        self.quiet = quiet
        
    def __call__(self,dataloader):
        """Process at most ``n`` batches and return the epoch log.

        ``n`` is ``dataloader.size`` when that attribute exists, otherwise
        ``len(dataloader)``.

        Returns:
            dict holding the mean of every step loss, the metrics of the last
            step and the computed stateful metrics. Empty when the ``n``-th
            step was never reached (e.g. an empty dataloader).
        """
        n = dataloader.size  if hasattr(dataloader,'size') else len(dataloader)
        loop = tqdm(enumerate(dataloader,start=1), 
                    total=n,
                    file=sys.stdout,
                    disable=not self.accelerator.is_local_main_process or self.quiet,
                    ncols=100
                   )
        epoch_losses = {}
        # Bound up front so the return below cannot hit an unbound name when
        # the loop body never reaches step n.
        epoch_log = {}
        
        for step, batch in loop: 
            # Batches beyond the configured size are skipped before any work is
            # done on them; breaking only after the step would run one extra
            # (training) step.
            if step>n:
                break
            with self.accelerator.accumulate(self.net):
                step_losses,step_metrics = self.steprunner(batch)   
                step_log = dict(step_losses,**step_metrics)
                # Per-step trace of the losses.
                print(step_losses.items())
                for k,v in step_losses.items():
                    epoch_losses[k] = epoch_losses.get(k,0.0)+v
                # training log
                if step<n:
                    loop.set_postfix(**step_log)
                    
                    if hasattr(self,'progress') and self.accelerator.is_local_main_process:
                        post_log = dict(**{'i':step,'n':n},**step_log)
                        self.progress.set_postfix(**post_log)

                else:
                    # step == n: last step, build the epoch summary.
                    epoch_metrics = step_metrics
                    epoch_metrics.update({self.stage+"_"+name:metric_fn.compute().item() 
                                     for name,metric_fn in self.steprunner.metrics_dict.items()})
                    epoch_losses = {k:v/step for k,v in epoch_losses.items()}
                    epoch_log = dict(epoch_losses,**epoch_metrics)
                    loop.set_postfix(**epoch_log)
            
                    
                    if hasattr(self,'progress') and self.accelerator.is_local_main_process:
                        post_log = dict(**{'i':step,'n':n},**epoch_log)
                        self.progress.set_postfix(**post_log)
                    
                    # Stateful metrics start fresh for the next epoch.
                    for name,metric_fn in self.steprunner.metrics_dict.items():
                        metric_fn.reset()  
        return epoch_log


#===============修改下面代码为自己跑. 来优化性能:

from accelerate import Accelerator 
#============torchkeras来写训练代码果然牛逼,图标太牛逼了.
#======第一步设置好自定义的KerasModel
flag=0
class StepRunner:
    """Execute one forward pass and, when training, one optimization step.

    The loss is read from the ``.loss`` attribute of the network output; the
    ``loss_fn`` and ``metrics_dict`` arguments are only stored.
    """
    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net = net
        self.loss_fn = loss_fn
        self.metrics_dict = metrics_dict
        self.stage = stage
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = Accelerator() if accelerator is None else accelerator
        # Put the network in the mode that matches the stage.
        switch_mode = self.net.train if self.stage == 'train' else self.net.eval
        switch_mode()
        self.flag = 0

    def _optimize(self, loss):
        """Back-propagate ``loss``, clip gradients on sync steps, advance optimizer and scheduler."""
        accel = self.accelerator
        accel.backward(loss)
        if accel.sync_gradients:
            accel.clip_grad_norm_(self.net.parameters(), 1.0)
        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        self.optimizer.zero_grad()

    def __call__(self, batch):
        """Run one step on ``batch`` (needs ``"input_ids"`` and ``"labels"``).

        Returns:
            ``(step_losses, step_metrics)``: the gathered and summed loss under
            ``"<stage>_loss"``, and for the train stage the learning rate of the
            first parameter group (0.0 without an optimizer).
        """
        # The module-level flag is raised on the first call of any runner
        # (hook left by the author for inspecting the first batch).
        global flag
        if not flag:
            flag = 1

        with self.accelerator.autocast():
            outputs = self.net(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss

        training = self.stage == "train"
        has_optimizer = self.optimizer is not None
        if has_optimizer and training:
            self._optimize(loss)

        summed_loss = self.accelerator.gather(loss).sum()

        # losses (or plain metrics that can be averaged)
        step_losses = {self.stage + "_loss": summed_loss.item()}

        # metrics (stateful metrics)
        step_metrics = {}
        if training:
            step_metrics['lr'] = (
                self.optimizer.state_dict()['param_groups'][0]['lr']
                if has_optimizer else 0.0
            )
        return step_losses, step_metrics
class KerasModel(torch.nn.Module):
    """Keras-style fit/evaluate/predict wrapper around a chat model, driven by ``accelerate``.

    The per-step and per-epoch logic is delegated to the ``StepRunner`` and
    ``EpochRunner`` class attributes, which can be replaced by custom runners.
    The notebook-level accelerator ``AC`` is used for everything.
    """
    
    StepRunner,EpochRunner = StepRunner,EpochRunner
    
    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None,tokenizer=None,mixed_precision=None,cpu=None,gradient_accumulation_steps=None):
        """Store the network and its training components.

        Args:
            net: the module to train.
            loss_fn: handed to the step runner.
            metrics_dict: optional mapping of metric modules, wrapped in a ``ModuleDict``.
            optimizer: optimizer to use; Adam (lr=3e-4) over ``net.parameters()`` when omitted.
            lr_scheduler: optional scheduler handed to the training step runner.
            tokenizer: tokenizer used by ``predict``; when omitted, ``predict``
                falls back to the global ``tokenizer``.
            mixed_precision, cpu, gradient_accumulation_steps: accepted but
                ignored; the global accelerator ``AC`` defines these settings.
        """
        super().__init__()
        self.net,self.loss_fn,self.metrics_dict = net, loss_fn, torch.nn.ModuleDict(metrics_dict) 
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(
            self.net.parameters(), lr=3e-4)
        self.lr_scheduler = lr_scheduler
        self.from_scratch = True     # no checkpoint has been loaded into the network yet
        # Keep the tokenizer instead of discarding the argument.
        self.tokenizer = tokenizer
        
        self.accelerator= AC

    # Author's note: the two generic save/load methods below are usually not used
    # (not adaptable enough); the module-level save_ckpt/load_ckpt override them.
    def save_ckpt(self, ckpt_path=None, accelerator= None):
        """Save the full state dict of ``self.net`` to ``ckpt_path`` (default ``self.ckpt_path``)."""
        accelerator = accelerator if accelerator is not None else self.accelerator
        net_dict = accelerator.get_state_dict(self.net)
        accelerator.save(net_dict,ckpt_path if ckpt_path is not None else self.ckpt_path)
      
    def load_ckpt(self, ckpt_path=None):
        """Load a state dict saved by ``save_ckpt`` and mark the model as no longer from scratch."""
        self.net.load_state_dict(
            torch.load(ckpt_path if ckpt_path is not None else self.ckpt_path,
            map_location='cpu'))
        self.from_scratch = False

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.net.forward(x)
    
    def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint',
            patience=5, monitor="val_loss", mode="min", callbacks=None, 
            plot=False,  wandb=False, quiet=None, 
            mixed_precision='no', cpu=False, gradient_accumulation_steps=1,dfhistorypath='dfhistory.csv'):
        """Train for up to ``epochs`` epochs with optional validation and early stopping.

        Args:
            train_data: training dataloader.
            val_data: optional validation dataloader.
            epochs: last epoch index to run (inclusive).
            ckpt_path: stored on the instance; no checkpoint is written by this method.
            patience: stop once the best ``monitor`` value is more than this many epochs old.
            monitor: key of ``self.history`` used for early stopping.
            mode: ``"min"`` or ``"max"``, the direction in which ``monitor`` improves.
            callbacks: optional list of callback objects.
            plot: when truthy, VisMetric/VisProgress callbacks are added.
            wandb: ``False`` or a project name; adds a WandbCallback.
            quiet: currently without effect, progress bars are always disabled.
            mixed_precision, cpu, gradient_accumulation_steps: stored only; the
                accelerator set in ``__init__`` defines the actual settings.
            dfhistorypath: CSV file the history is written to.

        Returns:
            The history as a DataFrame on the local main process, ``None`` elsewhere.
        """
        from torchkeras.utils import colorful,is_jupyter
        # Expose every argument as an instance attribute (self.dfhistorypath is read below).
        self.__dict__.update(locals())

        device = str(self.accelerator.device)
        device_type = '🐌'  if 'cpu' in device else ('⚡️' if 'cuda' in device else '🚀')
        self.accelerator.print(
            colorful("<<<<<< "+device_type +" "+ device +" is used >>>>>>"))
    
        self.net,self.loss_fn,self.metrics_dict,self.optimizer,self.lr_scheduler= self.accelerator.prepare(
            self.net,self.loss_fn,self.metrics_dict,self.optimizer,self.lr_scheduler)
        
        train_dataloader,val_dataloader = self.accelerator.prepare(train_data,val_data)
        train_dataloader.size = train_data.size if hasattr(train_data,'size') else len(train_data)
        train_dataloader.size = min(train_dataloader.size,len(train_dataloader))
        
        if val_data:
            val_dataloader.size = val_data.size if hasattr(val_data,'size') else len(val_data)
            val_dataloader.size = min(val_dataloader.size,len(val_dataloader))
        
        self.history = {}
        callbacks = callbacks if callbacks is not None else []
        
        if bool(plot):
            from torchkeras.kerascallbacks import VisProgress,VisMetric
            callbacks = [VisMetric(),VisProgress()]+callbacks
            
        if wandb!=False:
            from torchkeras.kerascallbacks import WandbCallback
            project = wandb if isinstance(wandb,str) else 'torchkeras'
            callbacks.append(WandbCallback(project=project))
            
        self.callbacks = [self.accelerator.prepare(x) for x in callbacks]
        
        if self.accelerator.is_local_main_process:
            [cb.on_fit_start(model = self) for cb in self.callbacks if hasattr(cb,'on_fit_start')]
                
        # Epoch 0 (only run after load_ckpt) is a pass without optimizer, see below.
        start_epoch = 1 if self.from_scratch else 0
        
        # ========================== training loop
        for epoch in range(start_epoch,epochs+1):
            # The per-epoch banner is switched off and the runners are always quiet.
            should_quiet=True

            # 1，train -------------------------------------------------  
            train_step_runner = self.StepRunner(
                    net = self.net,
                    loss_fn = self.loss_fn,
                    accelerator = self.accelerator,
                    stage="train",
                    metrics_dict=deepcopy(self.metrics_dict),
                    optimizer = self.optimizer if epoch>0 else None,
                    lr_scheduler = self.lr_scheduler if epoch>0 else None
            )

            train_epoch_runner = self.EpochRunner(train_step_runner,should_quiet)
            train_metrics = {'epoch':epoch}
            train_metrics.update(train_epoch_runner(train_dataloader))

            for name, metric in train_metrics.items():
                    self.history[name] = self.history.get(name, []) + [metric]
            # on_train_epoch_end callbacks are deliberately not invoked here.
                    
            # 2，validate -------------------------------------------------
            if val_dataloader is not None:
                val_step_runner = self.StepRunner(
                    net = self.net,
                    loss_fn = self.loss_fn,
                    accelerator = self.accelerator,
                    stage="val",
                    metrics_dict= deepcopy(self.metrics_dict)
                )
                val_epoch_runner = self.EpochRunner(val_step_runner,should_quiet)
                with torch.no_grad():
                    val_metrics = val_epoch_runner(val_dataloader)

                for name, metric in val_metrics.items():
                    self.history[name] = self.history.get(name, []) + [metric]
                
            if self.accelerator.is_local_main_process:
                [cb.on_validation_epoch_end(model = self) for cb in self.callbacks 
                 if hasattr(cb,'on_validation_epoch_end')]

            # 3，early-stopping -------------------------------------------------
            # (Per-epoch checkpoint saving was removed by the author as too frequent.)
            self.accelerator.wait_for_everyone()
            # The monitored metric may be missing (e.g. "val_loss" without
            # val_data); indexing the history directly raised KeyError then.
            arr_scores = self.history.get(monitor)
            if arr_scores:
                best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
                if len(arr_scores)-best_score_idx>patience:
                    break
                
        if self.accelerator.is_local_main_process:   
            dfhistory = pd.DataFrame(self.history)
            # [cb.on_fit_end(model = self) for cb in self.callbacks 
            #      if hasattr(cb,'on_fit_end')]
            if epoch<epochs:
                self.accelerator.print(colorful(
                        "<<<<<< {} without improvement in {} epoch,""early stopping >>>>>> \n"
                    ).format(monitor,patience))
            # self.net = self.accelerator.unwrap_model(self.net)
            # self.net.cpu()

            dfhistory.to_csv(self.dfhistorypath,index=None)
            # self.load_ckpt(ckpt_path)
            return dfhistory

    def predict(self,q,max_length):
        """Answer the query ``q`` with the network's ``chat`` method.

        Args:
            q: query string.
            max_length: forwarded to ``chat`` as ``max_length``.

        Returns:
            Whatever ``self.net.chat`` returns (the callers in this file index
            element 0 of it).
        """
        self.net.eval()
        accelerator = self.accelerator
        self.net,self.loss_fn,self.metrics_dict = accelerator.prepare(
            self.net,self.loss_fn,self.metrics_dict)

        # Prefer the tokenizer given to the constructor; otherwise use the
        # notebook-level `tokenizer` as before.
        chat_tokenizer = getattr(self, 'tokenizer', None)
        if chat_tokenizer is None:
            chat_tokenizer = tokenizer

        with accelerator.autocast() , torch.no_grad():

            a=self.net.chat(chat_tokenizer,query=q,history=[],max_length=max_length)
            print(a,'debug!!!!!!!!!!!')

        return a

    def evaluate(self, val_data, quiet=False):
        """Run one validation epoch over ``val_data`` and return its metrics dict."""
        accelerator = Accelerator() if not hasattr(self,'accelerator') else self.accelerator
        self.net,self.loss_fn,self.metrics_dict = accelerator.prepare(
            self.net,self.loss_fn,self.metrics_dict)
        val_data = accelerator.prepare(val_data)
        val_step_runner = self.StepRunner(net = self.net,stage="val",
                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),
                    accelerator = accelerator)
        val_epoch_runner = self.EpochRunner(val_step_runner,quiet=quiet)
        with torch.no_grad():
            val_metrics = val_epoch_runner(val_data)
        return val_metrics
    
    def fit_ddp(self,num_processes,train_data,
            val_data=None, epochs=10, ckpt_path='checkpoint',
            patience=5, monitor="val_loss", mode="min", callbacks=None, 
            plot=True, wandb=False, quiet=None, 
            mixed_precision='no', cpu=False, gradient_accumulation_steps=1
           ):
        """Run ``fit`` in ``num_processes`` processes through ``accelerate.notebook_launcher``."""
        from accelerate import notebook_launcher
        # Positional order must match the signature of fit().
        args = (train_data,val_data,epochs,ckpt_path,patience,monitor,mode,
            callbacks,plot,wandb,quiet,mixed_precision,cpu,gradient_accumulation_steps)
        notebook_launcher(self.fit, args, num_processes=num_processes)
    
    def evaluate_ddp(self, num_processes, val_data, quiet=False):
        """Run ``evaluate`` in ``num_processes`` processes through ``accelerate.notebook_launcher``."""
        from accelerate import notebook_launcher
        args = (val_data,quiet)
        notebook_launcher(self.evaluate, args, num_processes=num_processes)









    
# Bind the step runner defined in this cell on the class (the class body already
# captured the same name when it was created).
KerasModel.StepRunner = StepRunner 


#仅仅保存lora相关的可训练参数
def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):
    """Save ``self.net`` with its own ``save_pretrained`` into ``ckpt_path``.

    Args:
        ckpt_path: target directory passed to ``save_pretrained``.
        accelerator: accelerator used to unwrap the network; defaults to
            ``self.accelerator``.
    """
    # The default None used to be dereferenced directly (AttributeError);
    # fall back to the accelerator the model already holds.
    if accelerator is None:
        accelerator = self.accelerator
    unwrap_net = accelerator.unwrap_model(self.net)
    unwrap_net.save_pretrained(ckpt_path)
    
def load_ckpt(self, ckpt_path='checkpoint'):
    """Rebuild ``self.net`` from ``ckpt_path`` on top of its current base model.

    Calls ``from_pretrained`` of the current network with the underlying
    ``base_model.model`` and the checkpoint path, then records that the model
    no longer starts from scratch.
    """
    current_net = self.net
    restore = current_net.from_pretrained
    self.net = restore(current_net.base_model.model, ckpt_path)
    self.from_scratch = False
    
# Replace the generic checkpoint methods of KerasModel with the two functions above.
KerasModel.save_ckpt = save_ckpt 
KerasModel.load_ckpt = load_ckpt 
optimizer = torch.optim.AdamW(peft_model.parameters(),lr=cfg.lr) 

# Step 2: instantiate the model wrapper
keras_model = KerasModel(peft_model,loss_fn = None,
        optimizer=optimizer, mixed_precision='fp16',cpu=False,
            gradient_accumulation_steps=cfg.gradient_accumulation_steps) 
ckpt_path = 'chatglm2_my' # path used for saving
# Step 3 (author's note): the calls below train, plot and store the model.

print('配置完毕')
if 1: # test before training

        print('训练之前开始测试')
        print(keras_model.predict('梦中情炉',max_length=200)[0])
        print(keras_model.predict('世界上最高的山峰是什么',max_length=200)[0])

if 1:# training
    # The training loader is passed as validation data as well.
    keras_model.fit(train_data = dl_train,
                val_data = dl_train,
                epochs=20,
                patience=20,
                monitor='val_loss',
                mode='min',
                ckpt_path = ckpt_path,

                plot=False, # no plotting, to save space
          
               )
if 1: # test after training

        print('训练之后开始测试')
        print(keras_model.predict('梦中情炉',max_length=200)[0])



In [None]:
# Inspect the data: number of training batches, then every raw example.
print(len(dl_train))
# for batch in dl_train:
#     print(batch)
for i in ds_train_raw:
    print(i)
