In [1]:
import transformers
import jsonlines
import torch
from torch.utils.data import Dataset, DataLoader
import tqdm
import torch.optim as optim
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BlenderbotForConditionalGeneration,
    BlenderbotTokenizer,
)

In [2]:
train_path = './training_data.jsonl'

In [3]:
train_data = []
with jsonlines.open(train_path) as f:
    for i, line in enumerate(f.iter()):
        train_data.append(line)
# val_data = []
# with jsonlines.open(val_path) as f:
#     for line in f.iter():
#         val_data.append(line)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
mname = "facebook/blenderbot-400M-distill"

model = BlenderbotForConditionalGeneration.from_pretrained(mname).to(device)
tokenizer = BlenderbotTokenizer.from_pretrained(mname)


In [5]:
def preprocess_function(examples):
    encode_examples = {}
    inputs = []
    targets = []
    for doc in examples:
        inputs.append(doc['source'])
        targets.append(doc['target'])
            
    # inputs = [doc["maintext"] for doc in examples]
    encode_examples  = tokenizer(inputs, max_length=120, truncation=True, padding=True)
    encode_examples['target'] = tokenizer(targets, max_length=100, truncation=True, padding=True)
    return encode_examples

In [6]:
tokenized_set = preprocess_function(train_data)
#tokenized_summerize['validation'] = preprocess_function(val_data)

In [7]:
print(len(tokenized_set['target']['input_ids']))
print(len(tokenized_set['input_ids']))

47710
47710


In [8]:
class Dialog(Dataset):
    def __init__(self, encoded_dataset):
        self.token = encoded_dataset['input_ids']
        self.label = encoded_dataset['target']['input_ids'] if 'target' in encoded_dataset.keys() else None
    def __getitem__(self, index):
        if self.label is None:
            return torch.tensor(self.token[index])
        else:
            return torch.tensor(self.token[index]), torch.tensor(self.label[index])
    def __len__(self):
        return len(self.token)    

In [9]:
trainset = Dialog(tokenized_set)
#valset = Dialog(tokenized_summerize['validation'])
trainloader = DataLoader(dataset = trainset, batch_size = 64, shuffle = True)
#valloader = DataLoader(dataset = valset, batch_size = 16, shuffle = False)

In [10]:
def save_checkpoint(checkpoint_path, model, optimizer):
    #torch.save(model.state_dict(), checkpoint_path)
    model.save_pretrained(checkpoint_path)
    tokenizer.save_pretrained(checkpoint_path)
    print('model saved to %s' % checkpoint_path)
    
def load_checkpoint(checkpoint_path, model, optimizer):
    model = torch.load(checkpoint_path)
    #model.load_state_dict(state['state_dict'])
    print('model loaded from %s' % checkpoint_path)

In [11]:
def train_save(model, epoch, save_interval, log_interval=100):
    optimizer = optim.AdamW(model.parameters(),lr=1e-4, betas=(0.9, 0.999),weight_decay=0)
    model.train()  # set training mode
    best = 0    
    iteration = 0
    for ep in range(epoch):
        loss_sum = 0
        for batch_idx, (input_ids, labels) in enumerate(trainloader):
            input_ids, labels = input_ids.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(input_ids=input_ids,labels=labels)
            loss = output.loss
            loss_sum = loss_sum + loss.item()
            loss.backward()
            optimizer.step()
            
            if iteration % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    ep, batch_idx * len(input_ids), len(trainloader.dataset),
                    100. * batch_idx / len(trainloader), loss.item()))
            if iteration % save_interval == 0 and iteration > 0:
                # test(model)
                if iteration>500:
                    save_checkpoint('./final-%i/' % iteration, model, optimizer)    
            
            iteration += 1
            
        print('\nTrain set:Loss: ({:.0f})\n'.format(loss_sum / len(trainloader)))
        
    
    # save the final model
    save_checkpoint('./final-%i/' % iteration, model, optimizer)

In [12]:
model = model.to(device)
train_save(model, epoch = 45, save_interval = 1000, log_interval = 100)


Train set:Loss: (1)

model saved to ./final-1000/

Train set:Loss: (0)

model saved to ./final-2000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-3000/

Train set:Loss: (0)

model saved to ./final-4000/

Train set:Loss: (0)

model saved to ./final-5000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-6000/

Train set:Loss: (0)

model saved to ./final-7000/

Train set:Loss: (0)

model saved to ./final-8000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-9000/

Train set:Loss: (0)

model saved to ./final-10000/

Train set:Loss: (0)

model saved to ./final-11000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-12000/

Train set:Loss: (0)

model saved to ./final-13000/

Train set:Loss: (0)

model saved to ./final-14000/

Train set:Loss: (0)




Train set:Loss: (0)

model saved to ./final-15000/

Train set:Loss: (0)

model saved to ./final-16000/

Train set:Loss: (0)

model saved to ./final-17000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-18000/

Train set:Loss: (0)

model saved to ./final-19000/

Train set:Loss: (0)

model saved to ./final-20000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-21000/

Train set:Loss: (0)

model saved to ./final-22000/

Train set:Loss: (0)

model saved to ./final-23000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-24000/

Train set:Loss: (0)

model saved to ./final-25000/

Train set:Loss: (0)

model saved to ./final-26000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-27000/

Train set:Loss: (0)

model saved to ./final-28000/

Train set:Loss: (0)

model saved to ./final-29000/

Train set:Loss: (0)




Train set:Loss: (0)

model saved to ./final-30000/

Train set:Loss: (0)

model saved to ./final-31000/

Train set:Loss: (0)

model saved to ./final-32000/

Train set:Loss: (0)


Train set:Loss: (0)

model saved to ./final-33000/

Train set:Loss: (0)

model saved to ./final-33570/


In [3]:
import torch
print(torch.__version__)

1.11.0+cu113
