In [1]:
import torch
from torch import nn
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel, DataCollatorWithPadding, get_linear_schedule_with_warmup
from datasets import load_dataset
import evaluate
from torch.utils.data import DataLoader
from peft import get_peft_model, LoraConfig, TaskType
from functools import partial
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
import numpy as np

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\10bao\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
class MatchSum(nn.Module):
    def __init__(self, candidate_num, encoder_name='bert-base-uncased', hidden_size=768):
        super().__init__()
        self.hidden_size = hidden_size
        self.candidate_num = candidate_num
        self.encoder = BertModel.from_pretrained(encoder_name)
        
        self.doc_proj = nn.Linear(hidden_size, hidden_size)
        self.cand_proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_ids, candidate_ids, summary_ids=None):
        pad_id = self.encoder.config.pad_token_id

        doc_mask = text_ids.ne(pad_id)
        doc_out = self.encoder(text_ids, attention_mask=doc_mask).last_hidden_state
        doc_emb = self.doc_proj(self.dropout(doc_out[:, 0]))  

        summary_score = None
        if summary_ids is not None:
            sum_mask = summary_ids.ne(pad_id)
            sum_out = self.encoder(summary_ids, attention_mask=sum_mask).last_hidden_state
            sum_emb = self.cand_proj(self.dropout(sum_out[:, 0]))  
            summary_score = torch.cosine_similarity(sum_emb, doc_emb, dim=-1)

        bs, k, seq = candidate_ids.shape
        flat_cand = candidate_ids.view(-1, seq)
        cand_mask = flat_cand.ne(pad_id)
        cand_out = self.encoder(flat_cand, attention_mask=cand_mask).last_hidden_state
        cand_emb = self.cand_proj(self.dropout(cand_out[:, 0])).view(bs, k, -1)  
        
        doc_expand = doc_emb.unsqueeze(1).expand_as(cand_emb)
        cand_score = torch.cosine_similarity(cand_emb, doc_expand, dim=-1)

        return {'score': cand_score, 'summary_score': summary_score}

In [4]:
raw = load_dataset('ccdv/cnn_dailymail', '3.0.0', trust_remote_code=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
def split_sentences(batch):
    """Split article into sentences and tokenize"""
    try:
        sentences = sent_tokenize(batch['article'])
        sentences = [sent.strip() for sent in sentences if len(sent.strip()) > 10]
        sentences = sentences[:50]
        
        candidates = []
        for sent in sentences:
            tokens = tokenizer.encode(sent, truncation=True, max_length=128, add_special_tokens=True)
            if len(tokens) > 3: candidates.append(tokens)
        
        if not candidates:
            candidates = [tokenizer.encode(batch['article'][:200], truncation=True, max_length=128, add_special_tokens=True)]
        
        batch['candidates'] = candidates
        batch['article_ids'] = tokenizer.encode(batch['article'],    truncation=True, max_length=512, add_special_tokens=True)
        batch['summary_ids'] = tokenizer.encode(batch['highlights'], truncation=True, max_length=128, add_special_tokens=True)
        return batch
    
    except Exception as e:
        print(f"Error processing batch: {e}")
        batch['candidates']  = [tokenizer.encode("Default sentence.", add_special_tokens=True)]
        batch['article_ids'] = tokenizer.encode(batch['article'][:200], truncation=True, max_length=512, add_special_tokens=True)
        batch['summary_ids'] = tokenizer.encode(batch['highlights'],    truncation=True, max_length=128, add_special_tokens=True)
        return batch

In [None]:
processed = raw.map(split_sentences, batched=False, num_proc=1)

Map:   0%|          | 0/287113 [00:00<?, ? examples/s]

In [None]:
def collate_fn(batch, tokenizer):
    texts = [item['article_ids'] for item in batch]
    cands = [item['candidates']  for item in batch]
    sums  = [item['summary_ids'] for item in batch]

    max_text_len = max(len(t) for t in texts)
    text_tensor = torch.full((len(batch), max_text_len), fill_value=tokenizer.pad_token_id, dtype=torch.long)
    for i, text_ids in enumerate(texts):
        length = min(len(text_ids), max_text_len)
        text_tensor[i, :length] = torch.tensor(text_ids[:length], dtype=torch.long)

    max_sum_len = max(len(s) for s in sums)
    sum_tensor = torch.full((len(batch), max_sum_len), fill_value=tokenizer.pad_token_id, dtype=torch.long)
    for i, sum_ids in enumerate(sums):
        length = min(len(sum_ids), max_sum_len)
        sum_tensor[i, :length] = torch.tensor(sum_ids[:length], dtype=torch.long)

    max_k = max(len(cand_list) for cand_list in cands)
    max_len = max(len(sent) for cand_list in cands for sent in cand_list)
    
    cand_tensor = torch.full((len(batch), max_k, max_len), fill_value=tokenizer.pad_token_id, dtype=torch.long)
    
    for i, cand_list in enumerate(cands):
        for j, sent_ids in enumerate(cand_list):
            length = min(len(sent_ids), max_len)
            cand_tensor[i, j, :length] = torch.tensor(sent_ids[:length], dtype=torch.long)

    return {
        'text_ids': text_tensor,
        'candidate_ids': cand_tensor,
        'summary_ids': sum_tensor
    }

In [None]:
args = {
    'candidate_num': 50,
    'encoder_name': 'bert-base-uncased',
    'batch_size': 4,
    'n_epochs': 3,
    'lr': 2e-5,
    'warmup_steps': 100,
    'max_train_samples': 1000,
    'max_val_samples': 100,
    'max_test_samples': 100
}

In [None]:
train_ds = processed['train'].select(range(0,10))
val_ds = processed['validation'].select(range(0,10))
test_ds = processed['test'].select(range(0,10))
# train_ds = processed['train'].select(range(min(args['max_train_samples'], len(processed['train']))))
# val_ds = processed['validation'].select(range(min(args['max_val_samples'], len(processed['validation']))))
# test_ds = processed['test'].select(range(min(args['max_test_samples'], len(processed['test']))))

train_loader = DataLoader(train_ds, batch_size=args['batch_size'], shuffle=True, collate_fn=partial(collate_fn, tokenizer=tokenizer))
val_loader = DataLoader(val_ds, batch_size=args['batch_size'], collate_fn=partial(collate_fn, tokenizer=tokenizer))
test_loader = DataLoader(test_ds, batch_size=args['batch_size'], collate_fn=partial(collate_fn, tokenizer=tokenizer))

In [None]:
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['query', 'key', 'value']
)

In [None]:
base_model = MatchSum(args['candidate_num'], args['encoder_name'])
base_model.encoder = get_peft_model(base_model.encoder, peft_config)

for name, param in base_model.encoder.named_parameters():
    if 'lora_' not in name:
        param.requires_grad = False
model = base_model.to(device)

In [None]:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args['lr'])
total_steps = len(train_loader) * args['n_epochs']
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args['warmup_steps'],
    num_training_steps=total_steps
)

rouge = evaluate.load('rouge')
bertscore = evaluate.load('bertscore')

In [None]:
def compute_metrics(preds, refs):
    try:
        r = rouge.compute(predictions=preds, references=refs)
        b = bertscore.compute(predictions=preds, references=refs, lang='en')
        return {
            'rouge1': r['rouge1'],
            'rouge2': r['rouge2'],
            'rougeL': r['rougeL'],
            'bertscore_f1': np.mean(b['f1'])
        }
    except Exception as e:
        print(f"Error computing metrics: {e}")
        return {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0, 'bertscore_f1': 0.0}

In [None]:
def generate_summaries(model, dataloader):
    model.eval()
    preds, refs = [], []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            try:
                text_ids = batch['text_ids'].to(device)
                cand_ids = batch['candidate_ids'].to(device)
                sum_ids = batch['summary_ids'].to(device)
                
                out = model(text_ids, cand_ids)
                top_idx = out['score'].argmax(dim=1)
                
                for i, idx in enumerate(top_idx):
                    sent_ids = cand_ids[i, idx].tolist()
                    sent_ids = [id for id in sent_ids if id != tokenizer.pad_token_id]
                    pred_text = tokenizer.decode(sent_ids, skip_special_tokens=True)
                    preds.append(pred_text)
                    
                    ref_ids = sum_ids[i].tolist()
                    ref_ids = [id for id in ref_ids if id != tokenizer.pad_token_id]
                    ref_text = tokenizer.decode(ref_ids, skip_special_tokens=True)
                    refs.append(ref_text)
                    
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                preds.append("Error generating summary.")
                refs.append("Error in reference.")
                
    return preds, refs

In [None]:
best_rouge = 0.0

for epoch in range(args['n_epochs']):
    model.train()
    total_loss = 0
    num_batches = 0
    
    for batch_idx, batch in enumerate(train_loader):
        try:
            text_ids = batch['text_ids'].to(device)
            cand_ids = batch['candidate_ids'].to(device)
            sum_ids = batch['summary_ids'].to(device)

            out = model(text_ids, cand_ids, sum_ids)
            
            pos_score = out['summary_score']  
            neg_max, _ = out['score'].max(dim=1)  
            
            margin = 0.5
            loss = torch.relu(neg_max + margin - pos_score).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1
            
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
                
        except Exception as e:
            print(f"Error in training batch {batch_idx}: {e}")
            continue

    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    print("Evaluating on validation set...")
    preds, refs = generate_summaries(model, val_loader)
    val_metrics = compute_metrics(preds, refs)
    print(f"Epoch {epoch+1} Validation Metrics:")
    for k, v in val_metrics.items():
        print(f"  {k}: {v:.4f}")
    
    if val_metrics['rougeL'] > best_rouge:
        best_rouge = val_metrics['rougeL']
        print(f"New best ROUGE-L: {best_rouge:.4f}")

Error in training batch 0: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in training batch 1: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in training batch 2: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in training batch 3: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Error computing metrics: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Epoch 1 Validation Metrics:
  rouge1: 0.0000
  rouge2: 0.0000
  rougeL: 0.0000
  bertscore_f1: 0.0000
Error in training batch 0: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in training batch 1: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in tra

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Error computing metrics: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Epoch 2 Validation Metrics:
  rouge1: 0.0000
  rouge2: 0.0000
  rougeL: 0.0000
  bertscore_f1: 0.0000
Error in training batch 0: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in training batch 1: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error in tra

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Error computing metrics: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Epoch 3 Validation Metrics:
  rouge1: 0.0000
  rouge2: 0.0000
  rougeL: 0.0000
  bertscore_f1: 0.0000


In [None]:
print("Final evaluation on test set...")
preds, refs = generate_summaries(model, test_loader)
test_metrics = compute_metrics(preds, refs)

print("\nFinal Test Metrics:")
for k, v in test_metrics.items():
    print(f"  {k}: {v:.4f}")

In [None]:
print("\nExample Predictions:")
for i in range(min(3, len(preds))):
    print(f"\nExample {i+1}:")
    print(f"Prediction: {preds[i]}")
    print(f"Reference:  {refs[i]}")

print("\nTraining completed!")

ImportError: cannot import name 'Tester' from 'fastNLP' (d:\anaconda3\Lib\site-packages\fastNLP\__init__.py)