In [1]:
import torch
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

Loading data

In [3]:
# Emails paired with reference summaries (one 'train' split; see the features below).
DATASET_ID = 'argilla/FinePersonas-Conversations-Email-Summaries'
data = load_dataset(DATASET_ID)
data

DatasetDict({
    train: Dataset({
        features: ['conversation_id', 'email', 'maximum_brevity_summary', 'summary', 'distilabel_metadata', 'model_name'],
        num_rows: 363584
    })
})

In [4]:
# Work on the first N_SAMPLES rows of the train split instead of all 363,584.
N_SAMPLES = 165_000
data_small = data['train'].select(range(N_SAMPLES))
data_small

Dataset({
    features: ['conversation_id', 'email', 'maximum_brevity_summary', 'summary', 'distilabel_metadata', 'model_name'],
    num_rows: 165000
})

In [5]:
# Hold out 10% of the subset for evaluation.
# - seed=42 makes the split reproducible across kernel restarts.
# - The isinstance guard keeps the cell idempotent: without it, re-running the
#   cell would call train_test_split on the already-split DatasetDict and fail.
if isinstance(data_small, Dataset):
    data_small = data_small.train_test_split(test_size=0.1, seed=42)
data_small

DatasetDict({
    train: Dataset({
        features: ['conversation_id', 'email', 'maximum_brevity_summary', 'summary', 'distilabel_metadata', 'model_name'],
        num_rows: 148500
    })
    test: Dataset({
        features: ['conversation_id', 'email', 'maximum_brevity_summary', 'summary', 'distilabel_metadata', 'model_name'],
        num_rows: 16500
    })
})

In [6]:
# Inspect one training example: the email plus its two reference summaries.
data_small['train'][0]

{'conversation_id': 996,
 'email': "Subject: Re: Reaching out for expertise\n\nEmily,\n\nI had a chance to review the outline, and I think it's a great starting point. I made a few suggestions and added some terms I think would be valuable to include. Please see the attached revised outline.\n\nI've also been thinking about how we could expand our collaboration beyond the glossary. I believe there's an opportunity to create an online course that combines your expertise in science communication and my knowledge of fertility education. We could use the glossary as the foundation and build upon it with engaging content, visuals, and interactive elements.\n\nWhat do you think? I'm happy to discuss this idea further during our call on Tuesday.\n\nBest regards,\nOlivia",
 'maximum_brevity_summary': 'Olivia reviewed and revised the outline, suggesting an expansion into an online course.',
 'summary': 'Olivia reviewed the outline and made suggestions, adding valuable terms. She proposes expand

In [7]:
# The one-sentence summary; this column is the training target (labels) below.
data_small['train'][0]['maximum_brevity_summary']

'Olivia reviewed and revised the outline, suggesting an expansion into an online course.'

In [8]:
# The longer summary, shown for comparison only; it is dropped before training.
data_small['train'][0]['summary']

'Olivia reviewed the outline and made suggestions, adding valuable terms. She proposes expanding the collaboration to create an online course that combines expertise in science communication and fertility education, using the glossary as a foundation. Olivia suggests discussing this further during the call on Tuesday.'

Dataset preprocessing

In [3]:
# Checkpoint shared by the tokenizer and the model.
# Alternative previously tried: 'facebook/bart-large-cnn'.
check_point = 'facebook/bart-base'
tokenizer = AutoTokenizer.from_pretrained(check_point)

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [10]:
# Keep only the source text ('email') and the target ('maximum_brevity_summary').
unused_columns = ['conversation_id', 'summary', 'distilabel_metadata', 'model_name']
data_input = data_small.remove_columns(unused_columns)

In [11]:
# Only the input text and the target summary remain in each split.
data_input

DatasetDict({
    train: Dataset({
        features: ['email', 'maximum_brevity_summary'],
        num_rows: 148500
    })
    test: Dataset({
        features: ['email', 'maximum_brevity_summary'],
        num_rows: 16500
    })
})

In [12]:
max_length = 1024  # max email length in tokens (bart-base's input limit); longer emails are truncated
max_target = 128   # max target-summary length in tokens; longer summaries are truncated


def token_func(example):
    """Tokenize a batch of emails and their short summaries for seq2seq training.

    Intended for ``Dataset.map(token_func, batched=True)``, so each field of
    ``example`` is a list of strings.

    Args:
        example: batch dict with an 'email' column (source text) and a
            'maximum_brevity_summary' column (target text).

    Returns:
        The tokenizer output for the emails (input_ids, attention_mask, ...)
        with an added 'labels' entry holding the target token ids.
    """
    # Named model_inputs (not `input`) to avoid shadowing the builtin.
    model_inputs = tokenizer(
        example['email'],
        max_length=max_length,
        truncation=True,
    )
    # text_target= is the documented way to tokenize seq2seq targets; some
    # tokenizers treat target text differently from source text.
    labels = tokenizer(
        text_target=example['maximum_brevity_summary'],
        max_length=max_target,
        truncation=True,
    )
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [13]:
# Tokenize every split in batches; adds input_ids, attention_mask and labels.
data_ecd = data_input.map(function=token_func, batched=True)

Map:   0%|          | 0/148500 [00:00<?, ? examples/s]

Map:   0%|          | 0/16500 [00:00<?, ? examples/s]

In [14]:
# Tokenized columns now sit alongside the raw text columns.
data_ecd

DatasetDict({
    train: Dataset({
        features: ['email', 'maximum_brevity_summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 148500
    })
    test: Dataset({
        features: ['email', 'maximum_brevity_summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 16500
    })
})

In [15]:
# Keep only the tensors the model consumes; the raw text is no longer needed.
# select_columns (rather than remove_columns on the text columns) keeps this
# cell safe to re-run: a second remove_columns would fail on missing columns.
data_ecd = data_ecd.select_columns(['input_ids', 'attention_mask', 'labels'])
data_ecd

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 148500
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 16500
    })
})

PEFT model
* Fine-tune with LoRA (low-rank adaptation): the base weights stay frozen and only small adapter matrices are trained

In [4]:
# Load the pretrained encoder-decoder from the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=check_point)

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [17]:
# Show the module tree; the attention projection names listed here
# (k_proj, v_proj, q_proj, out_proj) are the ones targeted by LoRA below.
print(model)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_lay

In [18]:
from peft import LoraConfig, get_peft_model

In [19]:
# LoRA setup for a seq2seq model: rank-8 adapters on the attention projections
# (k/v/q/out_proj), i.e. 72 Linear layers (12,288 params each) in bart-base.
lora_target_modules = ['k_proj', 'v_proj', 'q_proj', 'out_proj']
model_config = LoraConfig(
    task_type='SEQ_2_SEQ_LM',
    r=8,
    target_modules=lora_target_modules,
)

In [20]:
# Wrap the base model with LoRA adapters and move it to the selected device;
# the report below confirms only a small fraction of parameters is trainable.
model_peft = get_peft_model(model, model_config)
model_peft = model_peft.to(device)
model_peft.print_trainable_parameters()

trainable params: 884,736 || all params: 140,305,152 || trainable%: 0.6306


Training

In [21]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

In [22]:
# Pads each batch dynamically (inputs and labels) for seq2seq training.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [23]:
# NOTE(review): the training log saved below came from different settings.
# It shows 20 epochs, and its 371,260 train steps / 2,063 eval steps imply
# batch size 8, so it does not match n_epoch/batch_size here. Re-run to refresh.
n_epoch = 10
batch_size = 4
training_args = Seq2SeqTrainingArguments(
    output_dir='Summarization',
    num_train_epochs=n_epoch,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate, log and checkpoint once per epoch. Eval and save strategies
    # must match for load_best_model_at_end.
    eval_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    # Without a limit, every epoch's checkpoint is kept on disk.
    save_total_limit=2,
    # The last epoch is not necessarily the best, so restore the checkpoint
    # with the lowest eval_loss at the end of training.
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
)

In [24]:
# Train the PEFT wrapper rather than the bare base model. Trainer then
# recognises this as a LoRA run, and its per-epoch checkpoints contain only
# the small adapter weights instead of a full copy of the model.
trainer = Seq2SeqTrainer(
    model=model_peft,
    args=training_args,
    train_dataset=data_ecd['train'],
    eval_dataset=data_ecd['test'],
    data_collator=data_collator,
    # Recent transformers versions deprecate `tokenizer=` in favour of `processing_class=`.
    processing_class=tokenizer,
)

  trainer=Seq2SeqTrainer(


In [25]:
# Run fine-tuning; the test split is evaluated after every epoch (eval_strategy='epoch').
trainer.train()

  0%|          | 0/371260 [00:00<?, ?it/s]

{'loss': 1.1177, 'grad_norm': 3.1435399055480957, 'learning_rate': 4.75e-05, 'epoch': 1.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.8072907328605652, 'eval_runtime': 120.1154, 'eval_samples_per_second': 137.368, 'eval_steps_per_second': 17.175, 'epoch': 1.0}
{'loss': 0.9268, 'grad_norm': 3.656320095062256, 'learning_rate': 4.5e-05, 'epoch': 2.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.7364647388458252, 'eval_runtime': 120.0431, 'eval_samples_per_second': 137.451, 'eval_steps_per_second': 17.185, 'epoch': 2.0}
{'loss': 0.8657, 'grad_norm': 2.831218957901001, 'learning_rate': 4.25e-05, 'epoch': 3.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.7016783356666565, 'eval_runtime': 119.9694, 'eval_samples_per_second': 137.535, 'eval_steps_per_second': 17.196, 'epoch': 3.0}
{'loss': 0.8277, 'grad_norm': 2.8726093769073486, 'learning_rate': 4e-05, 'epoch': 4.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6784924864768982, 'eval_runtime': 119.8769, 'eval_samples_per_second': 137.641, 'eval_steps_per_second': 17.209, 'epoch': 4.0}
{'loss': 0.8012, 'grad_norm': 3.9873626232147217, 'learning_rate': 3.7500000000000003e-05, 'epoch': 5.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6583265066146851, 'eval_runtime': 119.7314, 'eval_samples_per_second': 137.808, 'eval_steps_per_second': 17.23, 'epoch': 5.0}
{'loss': 0.7807, 'grad_norm': 2.719867467880249, 'learning_rate': 3.5e-05, 'epoch': 6.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6415004134178162, 'eval_runtime': 119.7494, 'eval_samples_per_second': 137.788, 'eval_steps_per_second': 17.228, 'epoch': 6.0}
{'loss': 0.765, 'grad_norm': 2.8422605991363525, 'learning_rate': 3.2500000000000004e-05, 'epoch': 7.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6389526724815369, 'eval_runtime': 119.7124, 'eval_samples_per_second': 137.83, 'eval_steps_per_second': 17.233, 'epoch': 7.0}
{'loss': 0.7517, 'grad_norm': 3.8478503227233887, 'learning_rate': 3e-05, 'epoch': 8.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6312384009361267, 'eval_runtime': 119.6692, 'eval_samples_per_second': 137.88, 'eval_steps_per_second': 17.239, 'epoch': 8.0}
{'loss': 0.741, 'grad_norm': 3.2151739597320557, 'learning_rate': 2.7500000000000004e-05, 'epoch': 9.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6209737062454224, 'eval_runtime': 119.6625, 'eval_samples_per_second': 137.888, 'eval_steps_per_second': 17.24, 'epoch': 9.0}
{'loss': 0.7323, 'grad_norm': 2.5233752727508545, 'learning_rate': 2.5e-05, 'epoch': 10.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6165818572044373, 'eval_runtime': 119.6475, 'eval_samples_per_second': 137.905, 'eval_steps_per_second': 17.242, 'epoch': 10.0}
{'loss': 0.7251, 'grad_norm': 2.726480722427368, 'learning_rate': 2.25e-05, 'epoch': 11.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.610162615776062, 'eval_runtime': 119.6747, 'eval_samples_per_second': 137.874, 'eval_steps_per_second': 17.238, 'epoch': 11.0}
{'loss': 0.7181, 'grad_norm': 2.872225284576416, 'learning_rate': 2e-05, 'epoch': 12.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6049325466156006, 'eval_runtime': 119.6334, 'eval_samples_per_second': 137.921, 'eval_steps_per_second': 17.244, 'epoch': 12.0}
{'loss': 0.7122, 'grad_norm': 2.6132428646087646, 'learning_rate': 1.75e-05, 'epoch': 13.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6025649905204773, 'eval_runtime': 119.6363, 'eval_samples_per_second': 137.918, 'eval_steps_per_second': 17.244, 'epoch': 13.0}
{'loss': 0.7078, 'grad_norm': 2.4376466274261475, 'learning_rate': 1.5e-05, 'epoch': 14.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.6003775596618652, 'eval_runtime': 119.5791, 'eval_samples_per_second': 137.984, 'eval_steps_per_second': 17.252, 'epoch': 14.0}
{'loss': 0.7042, 'grad_norm': 2.9944655895233154, 'learning_rate': 1.25e-05, 'epoch': 15.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.5969794988632202, 'eval_runtime': 119.5729, 'eval_samples_per_second': 137.991, 'eval_steps_per_second': 17.253, 'epoch': 15.0}
{'loss': 0.7014, 'grad_norm': 3.6201438903808594, 'learning_rate': 1e-05, 'epoch': 16.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.5952738523483276, 'eval_runtime': 119.6041, 'eval_samples_per_second': 137.955, 'eval_steps_per_second': 17.249, 'epoch': 16.0}
{'loss': 0.698, 'grad_norm': 3.2962067127227783, 'learning_rate': 7.5e-06, 'epoch': 17.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.5933892130851746, 'eval_runtime': 119.5714, 'eval_samples_per_second': 137.993, 'eval_steps_per_second': 17.253, 'epoch': 17.0}
{'loss': 0.6958, 'grad_norm': 2.539396047592163, 'learning_rate': 5e-06, 'epoch': 18.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.5928394198417664, 'eval_runtime': 119.6053, 'eval_samples_per_second': 137.954, 'eval_steps_per_second': 17.248, 'epoch': 18.0}
{'loss': 0.6931, 'grad_norm': 3.201368570327759, 'learning_rate': 2.5e-06, 'epoch': 19.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.5914831757545471, 'eval_runtime': 119.5628, 'eval_samples_per_second': 138.003, 'eval_steps_per_second': 17.255, 'epoch': 19.0}
{'loss': 0.6927, 'grad_norm': 3.764519214630127, 'learning_rate': 0.0, 'epoch': 20.0}


  0%|          | 0/2063 [00:00<?, ?it/s]

{'eval_loss': 0.591890811920166, 'eval_runtime': 119.4881, 'eval_samples_per_second': 138.089, 'eval_steps_per_second': 17.265, 'epoch': 20.0}
{'train_runtime': 54579.3888, 'train_samples_per_second': 54.416, 'train_steps_per_second': 6.802, 'train_loss': 0.7679103551082799, 'epoch': 20.0}


TrainOutput(global_step=371260, training_loss=0.7679103551082799, metrics={'train_runtime': 54579.3888, 'train_samples_per_second': 54.416, 'train_steps_per_second': 6.802, 'total_flos': 4.626931935228641e+17, 'train_loss': 0.7679103551082799, 'epoch': 20.0})

Saving the model and loading it for inference

In [None]:
from transformers import pipeline

SAVE_DIR = './summarization_ft'
# get_peft_model replaced the base model's attention projections with LoRA
# wrapper modules. Saving `model` directly would therefore write state-dict keys
# that a plain BART model cannot load. Merging folds the trained adapters into
# the base weights and returns an ordinary seq2seq model.
merged_model = model_peft.merge_and_unload()
merged_model.save_pretrained(SAVE_DIR)
# Save the tokenizer alongside the weights so the directory is self-contained.
tokenizer.save_pretrained(SAVE_DIR)

In [6]:
# Load the fine-tuned checkpoint from the directory written by the save cell
# ('./summarization_ft'; previously this pointed at a non-existent
# 'summarization_peft'). The tokenizer is passed explicitly, so loading does not
# depend on tokenizer files having been saved to that directory.
pipe = pipeline(
    'summarization',
    model='./summarization_ft',
    tokenizer=tokenizer,
    device=device,
)