Install all required packages

In [1]:
! pip install datasets torch




[notice] A new release of pip is available: 23.2.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
! pip install -U accelerate
! pip install -U transformers




[notice] A new release of pip is available: 23.2.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 23.2.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [85]:
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import Trainer, TrainingArguments

In [86]:
def load_dataset(file_path, tokenizer, block_size=128):
  """Wrap a plain-text training file in a ``transformers.TextDataset``.

  Args:
    file_path: Path of the text file to tokenize.
    tokenizer: Tokenizer forwarded to ``TextDataset``.
    block_size: Token-block length forwarded to ``TextDataset`` (default 128).

  Returns:
    The constructed ``TextDataset``.

  NOTE: shares its name with ``datasets.load_dataset``; do not import that
  function into this namespace or one will shadow the other.
  """
  dataset_options = dict(
    tokenizer=tokenizer,
    file_path=file_path,
    block_size=block_size,
  )
  return TextDataset(**dataset_options)


def load_data_collator(tokenizer, mlm=False):
  """Create the batch collator handed to the ``Trainer``.

  Args:
    tokenizer: Tokenizer forwarded to ``DataCollatorForLanguageModeling``.
    mlm: Masked-LM mode flag. The default ``False`` selects causal-LM
      collation (labels follow the inputs, no random masking), which is
      what a GPT-2 style model is trained with.

  Returns:
    A ``DataCollatorForLanguageModeling`` instance.
  """
  collator_options = {"tokenizer": tokenizer, "mlm": mlm}
  return DataCollatorForLanguageModeling(**collator_options)

def train(train_file_path, model_name,
          output_dir,
          overwrite_output_dir,
          per_device_train_batch_size,
          num_train_epochs,
          save_steps):
  """Fine-tune a pretrained GPT-2 checkpoint on a plain-text file.

  Args:
    train_file_path: Path of the training text file (passed to ``load_dataset``).
    model_name: Hub id or local path of the base model and tokenizer, e.g. ``'gpt2'``.
    output_dir: Directory that receives intermediate checkpoints, the final
      fine-tuned model and the tokenizer.
    overwrite_output_dir: Forwarded to ``TrainingArguments``.
    per_device_train_batch_size: Training batch size per device.
    num_train_epochs: Number of passes over the data (``TrainingArguments``
      accepts a float here).
    save_steps: Interval, in optimizer steps, between intermediate checkpoints.
  """
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
  train_dataset = load_dataset(train_file_path, tokenizer)
  data_collator = load_data_collator(tokenizer)
  # The Trainer below is not given the tokenizer, so trainer.save_model() will
  # not write it; persist it explicitly so output_dir is loadable for inference.
  tokenizer.save_pretrained(output_dir)

  # The base weights are no longer written to output_dir before training:
  # trainer.save_model() writes the fine-tuned weights at the end, and an early
  # copy would leave the untrained model there if training fails.
  model = GPT2LMHeadModel.from_pretrained(model_name)

  training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    # Previously accepted by train() but never forwarded, so it was silently ignored.
    save_steps=save_steps,
  )

  trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
  )

  trainer.train()
  # With no argument, save_model() writes to training_args.output_dir.
  trainer.save_model()

In [87]:
# Training configuration consumed by the `train(...)` call in the next cell.
model_name = "gpt2"                          # base checkpoint id
train_file_path = "dynamic_event_data.txt"   # plain-text training corpus
output_dir = "result"                        # where training writes its artifacts
overwrite_output_dir = False
per_device_train_batch_size = 8
num_train_epochs = 5.0                       # kept as a float
save_steps = 500                             # intended checkpoint interval (optimizer steps);
                                             # only effective if train() forwards it

In [88]:
# Fine-tune with the configuration above. Slow: the recorded run took ~58 minutes.
train(
    model_name=model_name,
    train_file_path=train_file_path,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps,
)

  0%|          | 0/18318 [27:10<?, ?it/s]

  0%|          | 0/18318 [1:19:01<?, ?it/s]      

{'loss': 0.6607, 'grad_norm': 3.360593318939209, 'learning_rate': 5.357142857142857e-06, 'epoch': 4.46}



100%|██████████| 560/560 [57:56<00:00,  6.21s/it]


{'train_runtime': 3476.1107, 'train_samples_per_second': 1.279, 'train_steps_per_second': 0.161, 'train_loss': 0.6332331827708653, 'epoch': 5.0}


Inference

In [89]:
from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel, GPT2TokenizerFast, GPT2Tokenizer

In [103]:
def load_model(model_path):
  """Load a ``GPT2LMHeadModel`` from ``model_path`` (local directory or Hub id)."""
  return GPT2LMHeadModel.from_pretrained(model_path)


def load_tokenizer(tokenizer_path):
  """Load a ``GPT2Tokenizer`` from ``tokenizer_path`` (local directory or Hub id)."""
  return GPT2Tokenizer.from_pretrained(tokenizer_path)


def generate_text(sequence, max_length, model_path="result"):
  """Sample a continuation of ``sequence`` from a saved model and print it.

  Args:
    sequence: Prompt; converted with ``str()`` before tokenizing.
    max_length: Upper bound on the total token count of the output, prompt
      included (``generate``'s ``max_length`` semantics).
    model_path: Directory holding the fine-tuned model and tokenizer.
      Defaults to ``"result"``, the ``output_dir`` used for training.

  Raises:
    ValueError: if ``sequence`` is empty; generation needs at least one
      prompt token. Checked before the (slow) model load.
  """
  prompt = str(sequence)
  if not prompt:
    raise ValueError("sequence must be a non-empty prompt")

  # NOTE(review): model and tokenizer are re-read from disk on every call.
  model = load_model(model_path)
  tokenizer = load_tokenizer(model_path)

  # Calling the tokenizer (rather than .encode) also returns an attention_mask.
  # Passing it explicitly means generate() does not have to infer it, which it
  # cannot do reliably when pad_token_id == eos_token_id (set below).
  inputs = tokenizer(prompt, return_tensors='pt')
  final_outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=max_length,
      pad_token_id=model.config.eos_token_id,  # GPT-2 has no pad token; reuse EOS
      top_k=50,
      top_p=0.95,
  )
  print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))

In [105]:
# Interactive cell: the output depends on what is typed here (and sampling is random),
# so this cell is not reproducible under Restart & Run All.
sequence = input("Prompt: ").strip()
max_len = int(input("max_length (total tokens, prompt included): "))
generate_text(sequence, max_len)

When is the next career fairs for cs students?
Response: CSO Seminar Series: Career Fair Info Session is scheduled for September 16.

Prompt: What type of event is CSO Seminar Series: Career Fair Info Session?
Response: CSO Seminar Series: Career Fair Info Session is classified as a Training &amp; Workshops.

Prompt: Where is University Writing Center presents: Class of 2024 located?
Response: The event will take place
