In [1]:
import datetime 
import torch 
import torch 
from transformers import GPT2Tokenizer, GPT2LMHeadModel 


# Note: not only GPT2 
def format_time(elapsed):
    # print nicely formated elapsed time
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

class HitchensQuoteModel(): 
    
    def __init__(self, model_input: GPT2LMHeadModel = None, device="cuda"):
        self.device = device
        self.device = torch.device(self.device)  # Sloppily use Cuda GPU. 
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token="<|endoftext|>")  # extract the gpt2 tokenizer
        self.model = self.init_model(model_input=model_input)  # initiate model
        self.model.to(self.device)  # send model to device

    def init_model(self, model_input: GPT2LMHeadModel) -> GPT2LMHeadModel:
        """
            Load pretrained model or use input model
        """
        if model_input is None:
            return GPT2LMHeadModel.from_pretrained("gpt2",
                                                    pad_token_id=self.tokenizer.eos_token_id)
        else:
            return model_input
    
     
    def generate_quotes(self,
                        prompt="",
                        num_sentences: int = 3,
                        max_length: int = 20,
                        num_beams: int = 50,
                        no_repeat_ngram_size: int = 3,
                        print_it: bool = True,
                        cuda: bool = True): 
        
        if type(prompt) == str:
            prompt = [prompt]
            
        input_ids = self.tokenizer(prompt,
                                    return_tensors='pt',
                                    padding=True,
                                    truncation=True)["input_ids"]

        input_ids = input_ids.to(self.device)
        output_temp = self.model.generate(
            input_ids, 
            max_length=max_length, 
            num_beams=num_beams, 
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=True,
            top_k=0,
            num_return_sequences=num_sentences
        )

        if print_it:
            for i in range(num_sentences):
                print(self.tokenizer.decode(output_temp[i]))

        self.model = self.model.to(self.device)
        return output_temp



In [8]:
model =  GPT2LMHeadModel.from_pretrained("hitch_gpt2", local_files_only=True  )
 

mod = HitchensQuoteModel(model_input=model)


# predict with prompt
prompt = "Trump is a " 
print("\n\n------------------ ALL ------------------")
mod.generate_quotes(prompt, max_length=200, num_sentences=1) 
# predict with prompt
prompt = "Clinton is a " 
print("\n\n------------------ ALL ------------------")
mod.generate_quotes(prompt, max_length=200, num_sentences=1) 

prompt = "Obama is a " 
print("\n\n------------------ ALL ------------------")
mod.generate_quotes(prompt, max_length=200, num_sentences=1) 

prompt = "Religion is a " 
print("\n\n------------------ ALL ------------------")
mod.generate_quotes(prompt, max_length=200, num_sentences=1) 



------------------ ALL ------------------
Trump is a icky, shady, demagogic nutbag...I and a few other people saw that he should be destroyed.<|endoftext|>


------------------ ALL ------------------
Clinton is a icky, arithmetical, pragmatic, pragmatic nutbag... I wouldn't have her job'. Those who profess unquenchable love for the sovereign are adamant that she press on in a task that they consider killingly hard.<|endoftext|>


------------------ ALL ------------------
Obama is a icky, arithmetical, pragmatic, pragmatic nutbag... I don't see how you can be saying this. Look. You want to take...you want your god to take responsibility for the huge number of collapsing stars and imploding galaxies and destroyed universes and failed solar systems that have left us in this tiny corner on the one planet in this petty solar system that can support life some of the time on some of its surface. And you want a creator who filled this earth with species, since life began 99% of which are now

tensor([[ 6892, 17035,   318,   257,   220, 17479,    11,   610,   342,  4164,
           605,  1109,    13,   632,   338,  5340,    11,   314,   892,    11,
          2158,   881,   314,  1549,  1716, 39977,   276, 11889,   393, 18101,
           656,   257,  1281,    12, 23149,  1048,    11,   314,   836,   470,
           892,   314,  1549,  1683,  1487,   616,  1570,   326,  1793,   318,
           582,    12,  9727,    13,   383,   845,  2126,   326,  1793,  9209,
           284,   617, 45029,   378, 20729,  1175, 10572,   287,  9671,    11,
           290,   339,   447,   247,    82,  1498,   284,  3551,   428,   866,
          7138,   290,   340,  4909,   262,  7429,   284,   477,   851,   836,
           447,   247,    83,  7030,   616,   640,   351,   326, 40317,  1174,
            83,    13,  4418,    11,   262,  3934,  8368, 17371,  9209,   691,
         17526,    11,   340,  2331,    30,   327,  2416,    13, 50256]],
       device='cuda:0')