In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# -- Load the pretrained causal language model together with its matching tokenizer -- #
print("Loading AI model...")
model_name = "distilgpt2"  # Hugging Face Hub checkpoint id used for both tokenizer and model
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)

  from .autonotebook import tqdm as notebook_tqdm


Loading AI model...


Loading weights: 100%|██████████| 76/76 [00:00<00:00, 1344.78it/s, Materializing param=transformer.wte.weight]            
GPT2LMHeadModel LOAD REPORT from: distilgpt2
Key                                        | Status     |  | 
-------------------------------------------+------------+--+-
transformer.h.{0, 1, 2, 3, 4, 5}.attn.bias | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [41]:
##-- POEM GENERATOR: continue the prompt line by line using temperature + top-p (nucleus) sampling --##
txt = "In the quiet gallery of ancient echoes, where the marble statues breathe the dust of centuries past," ##initial prompt
print(f"Start prompt: '{txt}'")
print("Generating...\n")

##-- CONFIG --##
nlns = 4 ##number of lines in the poem
wpln = 4 ##min amt of words per line before a ., ! or ? may end it
wplnmax = 12 ##max amt of words per line (hard cut-off)
maxtknsperln = 64 ##safety cap on tokens per line so a line that never reaches a word limit cannot loop forever
temp = 0.82 ##higher is more random, vice versa (from testing 0.78-0.84 is the best range)
p = 0.86 ##nucleus sampling threshold (from testing 0.84 to 0.89 is the best range)
seed = 42 ##fixed so Restart & Run All reproduces the same poem; set to None for a new poem each run

if seed is not None:
    torch.manual_seed(seed) ##torch.multinomial below draws from torch's global RNG

##ids of EVERY vocab token whose text contains a newline; encoding "\n" alone only finds the
##single-newline token and lets multi-newline tokens (e.g. "\n\n") through, which would break a line
nvocab = min(len(tokenizer), model.config.vocab_size)
vocabtxt = tokenizer.batch_decode([[tid] for tid in range(nvocab)])
newlines = [tid for tid, tkntxt in enumerate(vocabtxt) if "\n" in tkntxt] ##newline token ids masked during line generation
##also ban the end-of-text token so the model can't stop and wander off into unrelated text;
##downside: it can't close a thought either and will sometimes just keep repeating itself
eosids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
bannedids = torch.tensor(newlines + eosids, dtype=torch.long)

maxctx = model.config.max_position_embeddings ##the model cannot attend over more positions than this
ctxids = tokenizer.encode(txt) ##running context as token ids (prompt + everything generated so far)
lns = [] ##finished lines, formatted at the end

for i in range(nlns): ##for each line
    print(f"Generating Line {i+1}...")
    lnids = [] ##token ids generated for the current line
    lntxt = "" ##decoded text of the current line
    while len(lnids) < maxtknsperln:
        ##feed the sampled ids straight back in (a decode -> re-encode round trip can re-split tokens
        ##differently from what was sampled), keeping only the most recent maxctx positions
        inids = torch.tensor([ctxids[-maxctx:]])
        with torch.no_grad(): ##inference only, no gradients
            nxttknlogits = model(inids).logits[0, -1, :] ##logits for the token after the last position
        nxttknlogits[bannedids] = -float("inf") ##banned tokens end up with probability 0

        nxttknprobs = torch.softmax(nxttknlogits / temp, dim=0) ##temperature scaling, then normalize to probabilities

        ##-- TOP-P: sample only from the smallest set of most-likely tokens whose total probability exceeds p --##
        sortprobs, sortidxs = torch.sort(nxttknprobs, descending=True)
        cumulativeprobs = torch.cumsum(sortprobs, dim=0)
        sortidxtorm = cumulativeprobs > p ##mask of sorted positions beyond the threshold
        sortidxtorm[1:] = sortidxtorm[:-1].clone() ##shift right so the token that crosses p is kept
        sortidxtorm[0] = False ##the most likely token is never removed
        sortprobs[sortidxtorm] = 0.0
        sortprobs = sortprobs / sortprobs.sum() ##renormalize the remaining probabilities
        idxinsort = torch.multinomial(sortprobs, 1).item() ##random choice weighted by the remaining probs
        nxttknid = sortidxs[idxinsort].item() ##map the sorted position back to a token id

        lnids.append(nxttknid)
        ctxids.append(nxttknid)
        ##decode the whole line rather than token by token: GPT-2's byte-level BPE can split a single
        ##character (e.g. a curly apostrophe) across tokens, and decoding a lone piece gives U+FFFD
        lntxt = tokenizer.decode(lnids, clean_up_tokenization_spaces=False)

        wcnt = len(lntxt.split()) ##count words on the line
        if wcnt >= wplnmax or (wcnt >= wpln and lntxt.rstrip().endswith((".", "!", "?"))):
            break ##line is done

    lns.append(lntxt.strip()) ##after line is done store it
    print(f"Line {i+1}: {lntxt.strip()}") ##then output it as an individual line

ctxt = tokenizer.decode(ctxids, clean_up_tokenization_spaces=False) ##full text: prompt + generated lines

print("\nFull Poem:")
for i, ln in enumerate(lns, 1):
    print(f"{i}. {ln}")


Start prompt: 'In the quiet gallery of ancient echoes, where the marble statues breathe the dust of centuries past,'
Generating...

Generating Line 1...
Line 1: is the most ancient shrine of the human race.
Generating Line 2...
Line 2: Each of the four labyrinths, the labyrinths, the halls and the streets
Generating Line 3...
Line 3: , which are filled with objects of history, have its own unique
Generating Line 4...
Line 4: resonance. And yet the people of the world are not the only

Full Poem:
1. is the most ancient shrine of the human race.
2. Each of the four labyrinths, the labyrinths, the halls and the streets
3. , which are filled with objects of history, have its own unique
4. resonance. And yet the people of the world are not the only


In [33]:
import torch.nn.functional as F

##-- CONDENSED GENERATOR: same temperature + top-p line generation as the cell above --##
##reuses tokenizer/model and that cell's newline ids (newlines) and sampling config (temp, p, wpln, wplnmax)
##instead of repeating them as magic numbers, so tuning happens in one place
ctxt = "In the quiet gallery of ancient echoes, where the marble statues breathe the dust of centuries past," ##initial prompt
nlns = 4
k = 5 ##not used by the top-p sampler below (left over from the top-k experiment)
maxtknsperln = 64 ##safety cap so a line that never reaches a word limit cannot loop forever

eosids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
bannedids = torch.tensor(list(newlines) + eosids, dtype=torch.long) ##never break a line or end the text
maxctx = model.config.max_position_embeddings ##the model cannot attend over more positions than this
ctxids = tokenizer.encode(ctxt) ##work on token ids: avoids a lossy decode -> re-encode round trip every step

for i in range(nlns):
    lnids = []
    lntxt = ""
    while len(lnids) < maxtknsperln:
        with torch.no_grad():
            nxttknlogits = model(torch.tensor([ctxids[-maxctx:]])).logits[0, -1, :]
        nxttknlogits[bannedids] = -float("inf") ##mask newline and end-of-text tokens

        ##--  TEMPERATURE + TOP-P  --##
        sortprobs, sortidxs = torch.sort(F.softmax(nxttknlogits / temp, dim=0), descending=True)
        sortidxtorm = torch.cumsum(sortprobs, dim=0) > p
        sortidxtorm[1:] = sortidxtorm[:-1].clone() ##keep the token that crosses p
        sortidxtorm[0] = False ##always keep the most likely token
        sortprobs[sortidxtorm] = 0.0
        sortprobs = sortprobs / sortprobs.sum()

        ##--  NEXT TOKEN HANDLING  --##
        nxttknid = sortidxs[torch.multinomial(sortprobs, 1).item()].item()
        lnids.append(nxttknid)
        ctxids.append(nxttknid)
        ##decode the whole line so characters split across several tokens are not turned into U+FFFD
        lntxt = tokenizer.decode(lnids, clean_up_tokenization_spaces=False)

        wcnt = len(lntxt.split())
        if wcnt >= wplnmax or (wcnt >= wpln and lntxt.rstrip().endswith((".", "!", "?"))):
            break

    print(f"Line {i+1}: {lntxt.strip()}")

ctxt = tokenizer.decode(ctxids, clean_up_tokenization_spaces=False) ##prompt + generated lines

Line 1: they are regularly found in the gardens of the Museum of Natural
Line 2: History and the University of North Carolina, Chapel Hill, North Carolina.
Line 3: The Central Park Conservancy, which also houses the Museum of Natural History
Line 4: , which also houses the Museum of Natural History, is located at
