In [1]:
!pip install transformers



In [2]:
# Mount Google Drive at /content/drive so files under "My Drive" are reachable
# from the notebook. The first mount in a session asks for authorization.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import sys
import pickle
import torch
import numpy as np
import torch.optim as optim
# from transformers import BertConfig, BertModel, BertForMaskedLM
from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel

from IPython.display import clear_output

In [4]:
# Custom hyperparameters for a tiny GPT-2; the trailing comments give the
# transformers GPT-2 defaults they replace.
vocab_size = 21
max_position_embeddings = 20 # 1024
n_ctx = max_position_embeddings # 1024
n_embd = 16 # 768
n_layer = 8 # 12
n_head = 8 # 12
resid_pdrop = 0 # 0.1
embd_pdrop = 0 # 0.1
attn_pdrop = 0 # 0.1
layer_norm_epsilon = 1e-5 # 1e-5


# Use the GPT2Config keyword names `vocab_size` and `n_ctx`. The older
# `vocab_size_or_config_json_file` keyword is not recognised by newer transformers
# releases, and `n_ctw` was never a GPT2Config parameter; unrecognised keywords are
# stored as inert attributes, so the model would silently keep the default vocabulary.
config = GPT2Config(vocab_size=vocab_size,
                    n_positions=max_position_embeddings,
                    n_ctx=n_ctx,
                    n_embd=n_embd,
                    n_layer=n_layer,
                    n_head=n_head,
                    resid_pdrop=resid_pdrop,
                    embd_pdrop=embd_pdrop,
                    attn_pdrop=attn_pdrop,
                    layer_norm_epsilon=layer_norm_epsilon)

model = GPT2LMHeadModel(config)

# Fail fast if the installed transformers version ignored any of the sizing keywords.
assert model.config.vocab_size == vocab_size, model.config.vocab_size
assert model.config.n_positions == max_position_embeddings, model.config.n_positions

model.to('cuda')

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(21, 16)
    (wpe): Embedding(20, 16)
    (drop): Dropout(p=0, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0, inplace=False)
          (resid_dropout): Dropout(p=0, inplace=False)
        )
        (ln_2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0, inplace=False)
          (resid_dropout): Dropout(p=0, inplace=False)
        )
        (ln_2): LayerNorm((16,), eps=1e-05

In [0]:
# Load the pre-tokenised shingle array from Drive (presumably one token sequence
# per row; the training loop samples rows and indexes positions within them).
d = "drive/My Drive/Colab Notebooks/smaug/data"
shingle_path = os.path.join(d, "ecoli_MG1655_shingles_length20_overlap10.npy")

# np.load opens and closes the file itself, so no separate open() handle is needed.
# Append a slice such as [:2056] for a quick debugging run on a subset.
ecoli_shingles = np.load(shingle_path)

# The training loop indexes input_ids[row, position], so a 2-D array is required.
assert ecoli_shingles.ndim == 2, ecoli_shingles.shape

In [0]:
# Move the whole token array to the GPU once. Cast to int64 because the tokens are
# also passed as `labels` to the model, and cross-entropy targets must be Long.
tokens_tensor = torch.as_tensor(ecoli_shingles, dtype=torch.long, device='cuda')
dtrain = tokens_tensor

In [0]:
# AdamW over all model parameters, with PyTorch's default hyperparameters spelled out.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

In [0]:
def savemodel(model_to_save=None,
              modeldir="drive/My Drive/Colab Notebooks/smaug/data/models/ecoli_trivial_length20_overlap10"):
    """Save a model checkpoint to Google Drive and print the saved file names.

    Args:
        model_to_save: object exposing ``save_pretrained``. Defaults to the
            notebook-global ``model`` when None.
        modeldir: directory to write the checkpoint into. It is created if missing.
    """
    # Imported locally so the helper works without relying on other cells' imports.
    from google.colab import drive

    # Mounting again when Drive is already mounted only prints a notice; keeping the call
    # here presumably guards against the mount dropping during long training runs.
    drive.mount('/content/drive')

    if model_to_save is None:
        model_to_save = model

    # Some transformers releases assert that the target directory already exists.
    os.makedirs(modeldir, exist_ok=True)
    model_to_save.save_pretrained(modeldir)
    print(os.listdir(modeldir))

In [0]:
# Train the causal LM on randomly sampled mini-batches of token sequences.
# GPT2LMHeadModel shifts `labels` internally, so labels=input_ids trains next-token
# prediction: the logits at position k-1 are scored against the token at position k.
model.train()

batch_size = 2**14
n_steps = 100000
save_every = 500   # checkpoint whenever i % save_every == save_offset
save_offset = 10
show_row = 10      # which sequence of the batch to print predictions for

np.random.seed(42424)  # batch sampling below uses the global NumPy RNG
for i in range(n_steps):
    if i % save_every == save_offset:
        savemodel()
    optimizer.zero_grad()

    # Sample batch_size row indices uniformly, with replacement.
    select_idx = np.random.randint(0, len(dtrain), batch_size)
    input_ids = dtrain[select_idx]

    outputs = model(input_ids, labels=input_ids)
    loss, prediction_scores = outputs[:2]

    loss.backward()
    optimizer.step()

    clear_output(wait=True)
    print("Loss:", loss.item())

    # Print predicted vs. actual token ids for the last (up to) 10 positions of one
    # sequence; row/positions are clamped so small batches or short sequences don't crash.
    row = min(show_row, input_ids.shape[0] - 1)
    seq_len = input_ids.shape[1]
    for k in range(max(1, seq_len - 10), seq_len):
        print(i, "\t", torch.argmax(prediction_scores[row, k-1]).item(), "\t", input_ids[row, k].item(), "\t")

# Save the final weights: the in-loop schedule's last checkpoint is taken before the
# final (save_every - save_offset) updates, which would otherwise be lost.
savemodel()
    


Loss: 2.8211727142333984
8070 	 14 	 6 	
8070 	 1 	 12 	
8070 	 7 	 6 	
8070 	 1 	 3 	
8070 	 1 	 10 	
8070 	 1 	 7 	
8070 	 7 	 7 	
8070 	 7 	 17 	
8070 	 7 	 17 	
8070 	 1 	 10 	
