In [28]:
import torch
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader,Dataset
# def train(config: str, wandbkey: Optional[str] = None, debug_mode: bool = False):
from models.model import Model
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW


In [36]:
max_input_length = 512
max_target_length = 128
batch = 8
print_every = 50
prefix = "Content:"

In [42]:
class wikiData(Dataset):
    def __init__(self, df, tokenizer, max_length=128):

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attention_mask = []
        self.labels = []
        inputs = [prefix + " ".join(str(text).split()[:10]) for text in df["body_text"]]
        input_tokenize = tokenizer( 
                                inputs,
                                add_special_tokens=True,        #Add Special tokens like [CLS] and [SEP]
                                max_length=max_length,
                                padding = 'max_length',         #for padding to max_length for equal sequence length
                                truncation = True,              #truncate the text if it is greater than max_length
                                return_attention_mask=True,     #will return attention mask
                                return_tensors="pt"             #return tensor formate
                                )

        self.input_ids = torch.tensor(input_tokenize['input_ids'])
        self.attention_mask = torch.tensor(input_tokenize['attention_mask'])
        
        with tokenizer.as_target_tokenizer():
            label_tokenize = tokenizer(
                                    list(df["body_text"]), 
                                    add_special_tokens=True,        #Add Special tokens like [CLS] and [SEP]
                                    max_length=max_length,
                                    padding = 'max_length',         #for padding to max_length for equal sequence length
                                    truncation = True,              #truncate the text if it is greater than max_length
                                    return_attention_mask=True,     #will return attention mask
                                    return_tensors="pt"
                                    )
                
            self.labels = torch.tensor(label_tokenize['input_ids'])
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attention_mask[idx], self.labels[idx] 
    

In [43]:
import random
def valid(model, valid_dataloader,tokenizer):
    model.eval()
    
    running_loss = 0
    for batch in valid_dataloader :
        input_ids = batch[0]
        masks = batch[1]
        labels = batch[2]
        # Turn off gradients for validation, will speed up inference
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=masks, labels=labels)
            loss = outputs.loss
        running_loss += loss.item()

    random_batch = random.choice(list(valid_dataloader))

    
    original_text = tokenizer.decode(random_batch[2][0], skip_special_tokens=True) 
    print("Original content:", original_text)
    outputs = model.generate(random_batch[0])
    print("Generate content:",outputs)

    return(running_loss/len(valid_dataloader))
    
    

In [46]:
def train():
    # Hyperparameters
    lr = 5e-5
    epochs = 20
    batch = 8
    print_every = 50
    seed = 123
    
    train_losses = []
    valid_losses = []

    model = Model(lr)

    # Optimizer and tokenizer 
    tokenizer = model.tokenizer
    optimizer = model.configure_optimizers()

    
    # Readfile and make to dataloader
    filepath = "../data/processed/"
    df_train = pd.read_csv(filepath+'train.csv')
    df_valid = pd.read_csv(filepath+'valid.csv')    
    train_data = wikiData(df_train.head(2000) ,tokenizer,max_target_length)
    valid_data = wikiData(df_valid.head(200) ,tokenizer,max_target_length)
    train_dataloader = DataLoader(train_data, batch_size =batch,shuffle= True)
    valid_dataloader = DataLoader(valid_data, batch_size =batch,shuffle= True)

    
    for e in range(epochs):
        train_loss = 0
        running_loss = 0
        model.train()
        print("Epoch: {}/{}.. ".format(e + 1, epochs))
        for steps, batch in enumerate(train_dataloader):
            # load data and labels in the batch
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]

            # Training
            model.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=masks, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
            
            running_loss += loss.item()
            train_loss += loss.item()
            if steps % print_every == 0 and not steps == 0:
                # original_text = tokenizer.decode(labels[0], skip_special_tokens=True)
                # print("Original content:", original_text)
                # outputs = model.generate(input_ids)
                # print("Generate content:",outputs)
                print(
                    "Batch: {}/{}.. ".format(steps, len(train_dataloader)),
                    "Training Loss: {:.3f}.. ".format(running_loss / print_every))
                running_loss = 0
                
            loss.backward()
            optimizer.step()
        valid_loss = valid(model, valid_dataloader, tokenizer)
        
        print(
            "Training Loss: {:.3f}.. ".format(train_loss / len(train_dataloader)),
            "Valid Loss: {:.3f} ".format(valid_loss),)
        valid_losses.append(valid_loss)
        train_losses.append(train_loss / len(train_dataloader))
        

In [47]:
train()

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  self.input_ids = torch.tensor(input_tokenize['input_ids'])
  self.attention_mask = torch.tensor(input_tokenize['attention_mask'])
  self.labels = torch.tensor(label_tokenize['input_ids'])


Epoch: 1/20.. 
Batch: 50/250..  Training Loss: 6.037.. 
Batch: 100/250..  Training Loss: 4.562.. 
Batch: 150/250..  Training Loss: 4.275.. 
Batch: 200/250..  Training Loss: 4.302.. 
Original content: Contents1 Directions2 Content3 Martin Seligman4 Learned Helplessness5 Research6 Imprinting7 Essay Topics8 BibliographyDirectionsedit edit sourceThis content should include the following itemsImprinting and Learned helplessness Lorenz and SeligmanContentedit edit sourceMartin Seligmanedit edit sourceThrough Martin Seligmans experiments and the evaluation and observation of human behavior over the past decades by other researchers in his field psychologists have solidified the theory of passive resignation better known as learned helplessness The identification of Learned Helplessness has led to a greater




Generate content: 
Training Loss: 4.648..  Valid Loss: 3.695 
Epoch: 2/20.. 
Batch: 50/250..  Training Loss: 4.198.. 
Batch: 100/250..  Training Loss: 3.916.. 
Batch: 150/250..  Training Loss: 3.955.. 
Batch: 200/250..  Training Loss: 3.832.. 
Original content: Specialsearchpetroleum prefixEnglishHanziSpecialsearchoil prefixEnglishHanziPetroleum is an oil that is under the ground or the sea It is used to produce petrol kerosene diesel etc
Generate content: 
Training Loss: 3.936..  Valid Loss: 3.527 
Epoch: 3/20.. 
Batch: 50/250..  Training Loss: 3.848.. 
Batch: 100/250..  Training Loss: 3.739.. 
Batch: 150/250..  Training Loss: 3.711.. 
Batch: 200/250..  Training Loss: 3.695.. 
Original content: Picture Function or Picture Statementedit edit sourceWhich of the following did you need help with Please click a link to go to the desired pagePicture FunctionpictureHandle a hrefpicture20functionhtmlPICTUREaPicture Statementa hrefpicture20statementhtmlPICTUREa h1v1h2v2 pictureHandle
Generate 