In [1]:
# !pip install datasets transformers

In [2]:
import torch
# Device string used for the model and for every training batch.
# The earlier value 'cuda:pick_a_device' was a template placeholder, which is
# not a valid torch device string and makes every `.to(device)` call fail on a
# CUDA machine. 'cuda' selects the current CUDA device; change it to e.g.
# 'cuda:1' to pin a specific GPU.
cuda_device = 'cuda'
device = cuda_device if torch.cuda.is_available() else 'cpu'

In [3]:
from collections import defaultdict
import os
import datetime
from pathlib import Path
import json
from tqdm import tqdm

from torch import nn
from torch.utils.data import DataLoader

# import datasets
# from datasets import load_dataset
from transformers import AutoTokenizer

import utils.load_corpus as load_corpus
import utils.losses as l

import model.pick_a_model as rotator

In [4]:
# Root output directory and the per-project folder created beneath it.
# Later cells build their output paths as PATH/folder_path/<run folder>.
PATH = "data"
folder_path = Path("limanet")
(PATH/folder_path).mkdir(parents=True, exist_ok=True)

In [5]:
# Name of the pretrained tokenizer checkpoint to load.
tokenizer_type = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
# Smoke test: tokenize two sentences of different length as one padded batch.
text = ["I love bird.", "Hi, I'm Bob."]
tokenized = tokenizer(text, return_tensors='pt', padding=True)
# Display the token-id matrix; the shorter sentence is right-padded with 0s.
tokenized['input_ids']

tensor([[ 101, 1045, 2293, 4743, 1012,  102,    0,    0,    0],
        [ 101, 7632, 1010, 1045, 1005, 1049, 3960, 1012,  102]])

In [6]:
import importlib
# Reload the helper module so local edits to it are picked up without
# restarting the kernel.
load_corpus = importlib.reload(load_corpus)

# NOTE(review): cache_dir is a user-specific, home-relative path; adjust it for
# your machine. Whether '~' is expanded depends on load_corpus — confirm.
dataset = load_corpus.load_bookcorpus(tokenizer, cache_dir='~/data1-0756727/cache/huggingface')
# dataset = load_corpus.load_msmarco(tokenizer, cache_dir='~/data1-0756727/cache/huggingface')

# Peek at the first two training examples. Each slice comes back as a dict with
# 'data', 'batch_sizes', 'attention_mask', 'sorted_indices' and
# 'unsorted_indices' tensors (see the output below).
dataset['train'][:2]

{'data': tensor([ 2788,  2021,  1010,  2074,  2002,  2028,  2052,  2298,  2022,  2012,
         13311,  1037,  2105,  7163,  1996,  2239,  2542,  2741,  2282,  2032,
          1010,  8134,  2652,  4937,  2007, 22436,  2010,  2594, 10899,  1012,
          1012]),
 'batch_sizes': tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]),
 'sorted_indices': tensor([0, 1]),
 'unsorted_indices': tensor([0, 1])}

In [7]:
# def transform(batch):
#   tokenized = tokenizer(batch['text'], return_tensors='pt', padding=True, add_special_tokens=False)
#   packed = torch.nn.utils.rnn.pack_padded_sequence(tokenized['input_ids'], tokenized['attention_mask'].sum(dim=1), enforce_sorted=False, batch_first=True)
#   to_return = {
#     'data': packed.data,
#     'batch_sizes': packed.batch_sizes,
#     'attention_mask': tokenized['attention_mask'],
#     'sorted_indices': packed.sorted_indices,
#     'unsorted_indices': packed.unsorted_indices
#   }
#   return to_return

# dataset.set_transform(transform)

# dataset['train'][:2]

In [8]:
# # test
# rotator_depth = 3
# import importlib
# rotator = importlib.reload(rotator)
# model = rotator.Rotator(tokenizer.vocab_size, depth=rotator_depth, dim=128, num_heads=32, hippo_dim=16, num_hippo_heads=8)#.to(device)
# model.eval()
# batch = dataset['train'][:3]

# gen = model(**batch)
# predicts, targets = [torch.concat(res) for res in zip(*gen)]
# predicts.shape, targets.shape


In [9]:
# Number of dataset examples sliced into each training batch.
batch_size = 512
# When True, dataset.shuffle() is applied before the 'train' split is taken.
shuffle = False

class Trainloader:
  """Minimal batch iterator that slices a dataset into consecutive batches.

  The dataset only needs to support ``len()`` and slice indexing
  (``dataset[start:stop]``); each slice is yielded unchanged as one batch.

  Args:
    dataset: sliceable, sized collection of examples.
    batch_size: number of examples per batch (the last batch may be smaller).
  """
  def __init__(self, dataset, batch_size):
    self.data = dataset
    self.batch_size = batch_size

  def __len__(self):
    """Return the number of batches produced by one pass over the data."""
    # Ceiling division on the instance's own batch size. The previous
    # `len // batch_size + 1` read the notebook-global `batch_size` and counted
    # one extra (empty) batch whenever the size was an exact multiple.
    return -(-len(self.data) // self.batch_size)

  def __iter__(self):
    """Yield consecutive slices of at most ``batch_size`` examples."""
    # A stepped range produces exactly len(self) non-empty batches.
    for start in range(0, len(self.data), self.batch_size):
      yield self.data[start:start + self.batch_size]

# The hand-rolled Trainloader is used in place of torch's DataLoader:
# train_dataloader = DataLoader(dataset["train"], batch_size=batch_size, collate_fn=collate)
train_split = (dataset.shuffle() if shuffle else dataset)['train']
train_loader = Trainloader(train_split, batch_size=batch_size)

# Pull a single batch and show the shape of every tensor it contains.
n = next(iter(train_loader))
# print(tokenizer.decode(n['data']))
{name: tensor.shape for name, tensor in n.items()}

{'data': torch.Size([7722]),
 'batch_sizes': torch.Size([64]),
 'attention_mask': torch.Size([512, 64]),
 'sorted_indices': torch.Size([512]),
 'unsorted_indices': torch.Size([512])}

In [10]:
import importlib

# Pick up local edits to the model and loss modules without restarting the kernel.
rotator = importlib.reload(rotator)
l = importlib.reload(l)

# Optimizer step size; the alternative value is kept for the cosine loss.
learning_rate = 0.0001 # for euclidean
# learning_rate = 0.01 # for cosine

# Keyword arguments forwarded to rotator.Rotator (also written to parameters.json).
model_params = dict(
  rotator_depth=3,
  num_heads=32,
  dim=128,
  hidden_dim=32,
  rotary_denom=0.5,
)

# Keyword arguments forwarded to the loss constructor.
loss_params = dict(
  sampling_word_size=10,
  margin=8, # for euclidean
  # margin=.3, # for cosine
)

model = rotator.Rotator(tokenizer.vocab_size, **model_params).to(device)

# Active loss; the other losses from utils.losses are kept as alternatives.
using_loss = l.Mse_multiplet_loss(model, **loss_params)
# using_loss = l.Complex_mse()
# using_loss = l.Complex_triplet_loss(model)
# using_loss = l.Complex_mse_triplet_loss(model)
# using_loss = l.Complex_mse_squared_triplet_loss(model)
# using_loss = l.Cos_multiplet_loss(model, **loss_params)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
# # testing model output
# batch = next(iter(train_loader))
# batch = {key: batch[key].to(device) for key in batch}

# gen = model(**batch)
# predicts, targets = [torch.concat(res) for res in zip(*gen)]
# losses = using_loss(predicts, targets)
# loss = sum(losses.values())
# loss_dict = {
#   'total': f"{loss.item(): .6f}",
#   **{key: f"{losses[key].item(): .6f}" for key in losses},
# }

# print(loss_dict)
# # del batch, loss
# # torch.cuda.empty_cache()

In [12]:
# def backward_hook(module, gin, gout):
#   print(f"{len(gin)=}, {len(gout)=}")
#   print(*[f"{i=}, {gi.shape=}, {gi.mean()=}, {gi.min()=}, {gi.max()=}" for i, gi in enumerate(gin)], sep='\n')
#   print(*[f"{i=}, {go.shape=}, {go.mean()=}, {go.min()=}, {go.max()=}" for i, go in enumerate(gout)], sep='\n')

# model.limas[0].register_backward_hook(backward_hook)

In [13]:
# The run folder is named after the current time in a fixed UTC+8 timezone.
utc_plus_8 = datetime.timezone(datetime.timedelta(hours=8))
time_string = datetime.datetime.now(tz=utc_plus_8).strftime('%Y%m%d.%H:%M:%S')
subfolder_path = Path(time_string)

run_dir = PATH/folder_path/subfolder_path
os.makedirs(run_dir, exist_ok=True)

# Snapshot of every setting needed to identify this run.
run_parameters = {
  'tokenizer_type': tokenizer_type,
  'model': str(model),
  'model_params': model_params,
  'learning_rate': learning_rate,
  'batch_size': batch_size,
  'shuffle': shuffle,
  'loss_type': str(using_loss),
  'loss_params': loss_params,
}
with open(run_dir/'parameters.json', 'w') as f:
  f.write(json.dumps(run_parameters))

# Show the run identifier as the cell output.
time_string

'20240504.18:41:07'

In [14]:
# ---- Training loop configuration ----
num_epochs = 1
save_every_n_batches = 5000
# Extra checkpoints saved in addition to the regular save interval.
extra_checkpoint_batches = {100, 500, 1000, 1500, 2000, 3000, 7500, 12500, 17500, 22500}
# A loss-history entry is recorded every this many batches.
log_every_n_batches = 100
# Number of generator steps whose outputs are concatenated for one loss call.
steps_per_loss_chunk = 100

# Output directory of this run (created by the cell that writes parameters.json).
run_dir = PATH/folder_path/subfolder_path

# Entries of [epoch_num, batch_num, formatted loss dict].
loss_history = []

for epoch_num in range(num_epochs):
  bar = tqdm(train_loader)

  for batch_num, batch in enumerate(bar):

    optimizer.zero_grad()

    batch = {key: batch[key].to(device) for key in batch}
    # The model is consumed as an iterator of (predict, target) pairs.
    gen = model(**batch)

    # Per-loss-name running sums over all chunks of this batch.
    total_losses = defaultdict(lambda: torch.tensor([0], dtype=torch.float).to(device))
    exhausted = False
    while not exhausted:
      # Collect up to `steps_per_loss_chunk` outputs; the final chunk may be shorter.
      results = []
      for _ in range(steps_per_loss_chunk):
        try:
          results.append(next(gen))
        except StopIteration:
          exhausted = True
          break

      if results:
        predicts, targets = [torch.concat(pred_or_targ) for pred_or_targ in zip(*results)]
        losses = using_loss(predicts, targets)
        for loss_name, value in losses.items():
          total_losses[loss_name] += value

    loss = sum(total_losses.values())

    # Report the accumulated per-loss sums so the components add up to 'total'.
    # Previously the components came from the last chunk only, which disagreed
    # with 'total' whenever a batch spanned more than one chunk.
    loss_dict = {
      'total': f"{loss.item(): .6f}",
      **{key: f"{value.item(): .6f}" for key, value in total_losses.items()},
    }

    bar.set_postfix(loss_dict)

    # Debugging aids, enable if gradients misbehave:
    # with torch.autograd.detect_anomaly(True):
    #   loss.backward()
    # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10, norm_type=2)
    loss.backward()
    optimizer.step()

    if batch_num % log_every_n_batches == 0:
      loss_history.append([epoch_num, batch_num, loss_dict])

    # One checkpoint branch covers both the regular interval and the extra
    # early/intermediate batches (the two sets do not overlap).
    if batch_num % save_every_n_batches == 0 or batch_num in extra_checkpoint_batches:
      prefix = f"batch_{batch_num}-"
      # NOTE: torch.save(model, ...) pickles the whole module, so loading it
      # later needs the same model code to be importable.
      # NOTE(review): the file name has no epoch number, so with num_epochs > 1
      # later epochs overwrite earlier checkpoints.
      torch.save(model, run_dir/(prefix + 'model.pt'))

      with open(run_dir/'history.json', 'w') as f:
        f.write(json.dumps(loss_history))

  0%|                                                     | 185/144540 [01:52<24:18:45,  1.65it/s, total=505.241882, multiplet_loss=7.633089, mse=497.608795]


KeyboardInterrupt: 