<a href="https://colab.research.google.com/github/robgon-art/GreenLIT/blob/main/GreenLIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **GreenLIT: Using GPT-J with Multi-Task Learning to Create New Screenplays**
## How to fine-tune an ML model to create TV shows and movies with new titles, plot summaries, and scripts

![ReGEN Cover Image](https://raw.githubusercontent.com/robgon-art/ReGEN/main/cover_med.jpg)

Photo by [Tech Daily](https://unsplash.com/photos/PGuCnUzsRSM) on [Unsplash](https://unsplash.com/)<br/>

**By Robert A. Gonsalves**<br/>


In [None]:
#@title Initialize the System
# Show the GPU attached to this runtime; the model is later loaded onto CUDA.
!nvidia-smi
# Download a file from Google Drive by its ID -- presumably the checkpoint that is
# loaded further down as /content/GreenLIT_old.pt (TODO confirm the file name).
!gdown 1-LJoyi2fhU1PNLmxESwnOcBiqjNSfutT
# Install the third-party packages imported below.
!pip install transformers bitsandbytes-cuda111 wikipedia

from torch import nn
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from torch.cuda.amp import custom_fwd, custom_bwd
import torch.nn.functional as F
import wikipedia
import transformers
import torch

import nltk
nltk.download('wordnet')

# Fetch the GPT-J configuration and tokenizer published under "EleutherAI/gpt-j-6B".
# The tokenizer encodes prompts and decodes generated samples in the cells below;
# `config` is not referenced again in the visible part of this notebook.
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

def check_in_wiki(name):
  """Report whether a Wikipedia search for `name` returns a matching result.

  A search result matches when every word of `name` (ignoring the articles
  "the" and "a") occurs in it as a case-insensitive substring.

  Args:
    name: Text to search for, e.g. a candidate title.

  Returns:
    True if at least one search result contains all of the remaining words,
    otherwise False.
  """
  ignored_words = ("the", "a")
  lowered_parts = (part.lower() for part in name.split())
  required_words = [word for word in lowered_parts if word not in ignored_words]

  for hit in wikipedia.search(name):
    hit_lower = hit.lower()
    if all(word in hit_lower for word in required_words):
      return True
  return False

class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is stored as frozen, quantized buffers.

    The quantized weight and its quantization state (`absmax`, `code`) are
    registered as buffers with gradients disabled. `adapter` starts out as
    None; when it is set to a truthy module, its output for the same input is
    added to the layer output.
    """

    def __init__(self, weight, absmax, code, bias=None):
        """Store the quantized weight, its quantization state and an optional bias.

        Args:
            weight: Quantized weight of shape (out_features, in_features).
            absmax: Quantization state passed on to `dequantize_blockwise`.
            code: Quantization state passed on to `dequantize_blockwise`.
            bias: Optional `nn.Parameter`, or None for no bias.
        """
        assert bias is None or isinstance(bias, nn.Parameter)
        super().__init__()
        self.out_features, self.in_features = weight.shape
        frozen_tensors = (("weight", weight), ("absmax", absmax), ("code", code))
        for buffer_name, tensor in frozen_tensors:
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map, plus the adapter output if one is set."""
        result = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        if not self.adapter:
            return result
        result += self.adapter(input)
        return result

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen layer by quantizing the weight of an existing `nn.Linear`."""
        quantized, quant_state = quantize_blockise_lowmemory(linear.weight)
        return cls(quantized, *quant_state, linear.bias)

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}({self.in_features}, {self.out_features})"
 
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that dequantizes a weight and applies a linear map.

    Gradients are returned for the input and, when a bias was given, for the
    bias. The quantized weight, `absmax` and `code` receive no gradient.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Only the quantized form is saved; backward dequantizes it again.
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, dense_weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # The quantized weight and its quantization state must not need gradients.
        assert not any(ctx.needs_input_grad[1:4])
        _, weights_quantized, absmax, code = ctx.saved_tensors
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # grad_output is [*batch, out_features]; multiplying by the weight gives
        # the gradient with respect to the input.
        grad_input = grad_output @ dense_weight
        if ctx._has_bias:
            # Collapse all leading dimensions and sum them away.
            grad_bias = grad_output.flatten(0, -2).sum(dim=0)
        else:
            grad_bias = None
        # One entry per forward argument after ctx.
        return grad_input, None, None, None, grad_bias

class FrozenBNBEmbedding(nn.Module):
    """Embedding table whose weight is stored as frozen, quantized buffers.

    The quantized weight and its quantization state (`absmax`, `code`) are
    registered as buffers with gradients disabled. `adapter` starts out as
    None; when it is set to a truthy module, its output for the same input is
    added to the embedding output.
    """

    def __init__(self, weight, absmax, code):
        """Store the quantized table of shape (num_embeddings, embedding_dim) and its state."""
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        frozen_tensors = (("weight", weight), ("absmax", absmax), ("code", code))
        for buffer_name, tensor in frozen_tensors:
            self.register_buffer(buffer_name, tensor.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up `input` in the dequantized table; `kwargs` go to `F.embedding`."""
        # Neither the quantized weights nor the input indices are differentiable,
        # so the lookup itself runs without autograd tracking.
        with torch.no_grad():
            dense_weight = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            embedded = F.embedding(input, dense_weight, **kwargs)
        if not self.adapter:
            return embedded
        embedded += self.adapter(input)
        return embedded

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen table by quantizing the weight of an existing `nn.Embedding`."""
        quantized, quant_state = quantize_blockise_lowmemory(embedding.weight)
        return cls(quantized, *quant_state)

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}({self.num_embeddings}, {self.embedding_dim})"
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize `matrix` with `quantize_blockwise`, one chunk at a time.

    The flattened matrix is processed in slices of `chunk_size` elements, and
    the `code` returned for one chunk is passed on to the next call.

    Args:
        matrix: Tensor to quantize; it is flattened with `view(-1)`.
        chunk_size: Elements per chunk. Must be a multiple of 4096
            (presumably the quantization block size -- TODO confirm).

    Returns:
        A tuple `(matrix_i8, (absmax, code))`: the quantized values reshaped
        like `matrix`, the concatenated per-chunk `absmax` values, and the
        `code` from the last chunk.
    """
    assert chunk_size % 4096 == 0
    flattened = matrix.view(-1)
    code = None
    quantized_parts = []
    absmax_parts = []
    for start in range(0, matrix.numel(), chunk_size):
        # Clone so the quantizer gets its own copy of the slice.
        piece = flattened[start: start + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)

    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (torch.cat(absmax_parts), code)
 
 
def convert_to_int8(model):
    """Replace linear and embedding submodules of `model` with frozen 8-bit ones.

    Every `nn.Linear` child becomes a `FrozenBNBLinear` and every
    `nn.Embedding` child a `FrozenBNBEmbedding`. The replacements hold all-zero
    quantized weights and quantization state sized from the original layer;
    real values are presumably loaded afterwards (confirm against the caller).
    A linear layer's bias is carried over, and each replaced linear layer is
    printed. The model is modified in place and nothing is returned.
    """
    def empty_quant_state(layer):
        # One absmax entry per 4096 weight elements (rounded up), plus a 256-entry code.
        block_count = (layer.weight.numel() - 1) // 4096 + 1
        return torch.zeros(block_count), torch.zeros(256)

    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                blank_weight = torch.zeros(child.out_features, child.in_features, dtype=torch.uint8)
                blank_absmax, blank_code = empty_quant_state(child)
                replacement = FrozenBNBLinear(
                    weight=blank_weight,
                    absmax=blank_absmax,
                    code=blank_code,
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                blank_weight = torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8)
                blank_absmax, blank_code = empty_quant_state(child)
                replacement = FrozenBNBEmbedding(
                    weight=blank_weight,
                    absmax=blank_absmax,
                    code=blank_code,
                )
            else:
                continue
            setattr(parent, child_name, replacement)

class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP submodules go through `convert_to_int8`."""

    def __init__(self, config):
        super().__init__(config)
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J model that runs `convert_to_int8` on itself right after construction."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the nn.Linear / nn.Embedding children of the freshly built model.
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal LM that runs `convert_to_int8` on itself right after construction."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the nn.Linear / nn.Embedding children of the freshly built model.
        convert_to_int8(self)


# Make the library module's GPTJBlock name point at the int8 subclass defined above.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

# NOTE(review): torch.load unpickles the file, and unpickling can execute arbitrary
# code -- only load checkpoints obtained from a trusted source.
# The loaded object is placed on the CUDA device via map_location.
gpt = torch.load("/content/GreenLIT_old.pt",  map_location=torch.device('cuda'))
# Presumably an nn.Module; eval() switches it to evaluation mode for generation.
gpt.eval()

import re
import textwrap
from nltk.corpus import wordnet as wn

In [10]:
#@title Generate Titles and Summaries
genre = 'crime drama' #@param {type:"string"}
theme = 'cryptocurrency' #@param {type:"string"}

# Prompt layout: "GENRE: <genre> THEME: <theme> TITLE:". The space before
# "TITLE:" keeps the theme text from running straight into the field label.
prompt = "GENRE: " + genre + " THEME: " + theme + " TITLE:"
with torch.no_grad():
  prompt_tokens = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
  sample_outputs = gpt.generate(prompt_tokens, max_length=80, do_sample=True, 
    temperature=0.8, pad_token_id=tokenizer.eos_token_id, num_return_sequences=40)

# One pattern extracts both fields, so the title ends at the same "SUMMARY:"
# marker the summary starts from.
title_summary_pattern = re.compile(r'TITLE:(.*?)SUMMARY:(.*)')

# `titles` and `summaries` are parallel lists; the next cell indexes into them.
titles = []
summaries = []
count = 1
for sample_output in sample_outputs:
  results = tokenizer.decode(sample_output, skip_special_tokens=True)
  results = results.replace("\n", " ")

  # A sample cut off before "SUMMARY:" has no usable title/summary pair; skip
  # it instead of failing on a missing match.
  fields = title_summary_pattern.search(results)
  if fields is None:
    continue
  title = fields.group(1).strip()
  summary = fields.group(2).strip()

  # Require at least three letters in the title. This cheap check runs first
  # so the Wikipedia lookup (a network call) only happens for plausible titles.
  alpha = re.sub('[^a-zA-Z]+', '', title)
  if len(alpha) < 3:
    continue

  # Drop titles that already show up in Wikipedia search results.
  if check_in_wiki(title):
    continue

  titles.append(title)
  summaries.append(summary)

  out = str(count).zfill(2) + " " + title + " - " + summary
  wrapped = textwrap.fill(out, width=150, subsequent_indent="   ")
  print(wrapped)
  count += 1

01 Bitcoin Heist - The Series - Bitcoin Heist - The Series takes a look into the world of one of the most interesting financial inventions of our
   time.
02 Bitcoin Heist - The heists always happen in the same way. The thieves always arrive at the scene armed with a van stuffed with cash, and they
   always run from the get go.
03 The Secret Life of Cryptocurrencies - In a small mining town in Australia, a cryptocurrency trader and his wife are found dead in their mansion.
04 HOMBREQUE - A gang of thieves commits an armed robbery in a bank, and they are captured by the police. But one of them goes missing, and the police
   want to find out who was in the gang and what they know.
05 The Bitcoin Heist - A bank robber is caught, but with the help of one of his accomplices, a bitcoin-related crime is uncovered, which could put a
   serious strain on the digital currency in the near future.
06 I, Trespasser - A detective in the LAPD is transferred to the Homicide Task Force where he inves

In [18]:
#@title Choose a Title and Create a Script

choice = 7 #@param {type:"slider", min:1, max:40, step:1}
choice -= 1  # the slider is 1-based, the lists are 0-based

if not titles:
  print("No titles available. Run the title generation cell first.")
elif not 0 <= choice < len(titles):
  # Reject negative indexes too, so an out-of-range choice never wraps around
  # to the end of the list.
  print("Choose between 1 and " + str(len(titles)))
else:
  title = titles[choice]
  summary = summaries[choice]
  print(choice+1, title)

  # The prompt ends with an opening scene heading for the model to continue.
  prompt = "TITLE: " + title + " SUMMARY: " + summary + " SCRIPT:\n[Scene:"
  
  with torch.no_grad():
    prompt_tokens = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    sample_outputs = gpt.generate(prompt_tokens, max_length=480, do_sample=True, 
      temperature=0.7, pad_token_id=tokenizer.eos_token_id, num_return_sequences=1)
    results = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    results = results.strip()
    # Drop the echoed prompt and print only the continuation, re-adding the
    # scene heading that the prompt ended with.
    print("\n[Scene: " + results[len(prompt):])

7 The Crypto Hustle

[Scene:  The Crypto Hustle Conference, the conference room. Everyone is there except for Max, who is in a side room.]

MAX - Can I get some coffee?
MAYA - I'll get it.

(Max enters the conference room.)

MAYA - Max, I found some interesting stuff in the paper.
MAX - Really?
MAYA - Yeah, it's about that guy who was in the hospital the other night.
MAX - The guy with the kidney stones?
MAYA - Yeah, apparently he's some kind of expert in something called "crypto-currencies".
MAX - "Crypto" is short for cryptography.
MAYA - No, it's a type of currency that uses encryption to make it untraceable.
MAX - Oh, that's impossible!
MAYA - Well, apparently it's not. And I think I know how they're doing it.

(Max looks at Maya with disbelief.)

MAYA - It's this guy named Satoshi Nakamoto.
