In [1]:
!nvidia-smi

## Get Dataset

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time

In [3]:
BATCH_SIZE = 16

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

In [4]:
print(device)

cuda:0


In [5]:
cnn_dailymail = tfds.load(name="cnn_dailymail")

In [6]:
train_tfds = cnn_dailymail['train']
test_tfds = cnn_dailymail['test']
val_tfds = cnn_dailymail['validation']

In [7]:
train_ds_iter = tfds.as_numpy(train_tfds)
val_ds_iter = tfds.as_numpy(val_tfds)
test_ds_iter = tfds.as_numpy(test_tfds)

In [8]:
def write_data(iter_dataset, name, path="data/"):
    
    articles_file = Path(path + name + "/article").open("w")
    highlights_file = Path(path + name + "/highlights").open("w")

    for item in iter_dataset:
        articles_file.write(item["article"].decode("utf-8") + "\n")
        articles_file.flush()
        highlights_file.write(item["highlights"].decode("utf-8").replace("\n", " ") + "\n")
        highlights_file.flush()

In [9]:
write_data(train_ds_iter, "train")

In [10]:
write_data(test_ds_iter, "test")
write_data(val_ds_iter, "val")

## Define Model

In [11]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

## Define Pytorch Dataset

In [12]:
def read_files(name):
    article_path = "data/%s/article" % name
    highlights_path = "data/%s/highlights" % name
    
    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [13]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(model.config.prefix + self.transfrom(self.x[index]), max_length=512, return_tensors="pt", pad_to_max_length=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", pad_to_max_length=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [14]:
def get_dataset(name):
    articles, highlights = read_files(name)
    return MyDataset(articles, highlights)

In [15]:
train_ds = get_dataset("train")
test_ds = get_dataset("test")
val_ds = get_dataset("val")

In [16]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

## Define Step function

In [17]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    output = model(inputs_ids, attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
    return output[0] # loss

## Train

In [None]:
EPOCHS = 1
log_interval = 200
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)
                
                

| epoch   0 | [  199/17945] | ms/batch 362.94 | loss  2.24 | val loss  2.08
| epoch   0 | [  399/17945] | ms/batch 360.18 | loss  2.43 | val loss  2.00
| epoch   0 | [  599/17945] | ms/batch 363.08 | loss  2.33 | val loss  2.02
| epoch   0 | [  799/17945] | ms/batch 364.22 | loss  2.06 | val loss  1.93
| epoch   0 | [  999/17945] | ms/batch 365.22 | loss  2.28 | val loss  1.97


In [None]:
for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]