In [1]:
# !pip install datasets transformers

In [1]:
import torch
device = 'cuda:pick_a_device' if torch.cuda.is_available() else 'cpu'

In [2]:
import os
import datetime
from pathlib import Path
import json
from tqdm import tqdm

from torch import nn
from torch.utils.data import DataLoader

import datasets
from datasets import load_dataset
from transformers import AutoTokenizer

import model.rotator_lima4_hippo as rotator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
PATH = "data"
folder_path = Path(f"limanet")
os.makedirs(PATH/folder_path, exist_ok=True)

In [4]:
tokenizer_type = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
text = ["I love bird.", "Hi, I'm Bob."]
tokenized = tokenizer(text, return_tensors='pt', padding=True)
tokenized['input_ids']

tensor([[ 101, 1045, 2293, 4743, 1012,  102,    0,    0,    0],
        [ 101, 7632, 1010, 1045, 1005, 1049, 3960, 1012,  102]])

In [5]:
dataset = load_dataset("bookcorpus", cache_dir='~/data1-0756727/cache/huggingface')
# dataset.set_format('torch')

In [6]:
def transform(batch):
  tokenized = tokenizer(batch['text'], return_tensors='pt', padding=True, add_special_tokens=False)
  packed = torch.nn.utils.rnn.pack_padded_sequence(tokenized['input_ids'], tokenized['attention_mask'].sum(dim=1), enforce_sorted=False, batch_first=True)
  to_return = {
    'data': packed.data,
    'batch_sizes': packed.batch_sizes,
    'attention_mask': tokenized['attention_mask'],
    'sorted_indices': packed.sorted_indices,
    'unsorted_indices': packed.unsorted_indices
  }
  return to_return

dataset.set_transform(transform)

dataset['train'][:2]

{'data': tensor([ 2788,  2021,  1010,  2074,  2002,  2028,  2052,  2298,  2022,  2012,
         13311,  1037,  2105,  7163,  1996,  2239,  2542,  2741,  2282,  2032,
          1010,  8134,  2652,  4937,  2007, 22436,  2010,  2594, 10899,  1012,
          1012]),
 'batch_sizes': tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]),
 'sorted_indices': tensor([0, 1]),
 'unsorted_indices': tensor([0, 1])}

In [33]:
# # test
# rotator_depth = 3
# import importlib
# rotator = importlib.reload(rotator)
# model = rotator.Rotator(tokenizer.vocab_size, depth=rotator_depth)#.to(device)
# model.eval()
# batch = dataset['train'][:3]

# gen = model(**batch)
# predicts, targets = [torch.concat(res) for res in zip(*gen)]
# predicts.shape, targets.shape

(torch.Size([42, 64]), torch.Size([42, 64]))

In [10]:
batch_size = 512
shuffle = False

class Trainloader:
  def __init__(self, dataset, batch_size):
    self.data = dataset
    self.batch_size = batch_size

  def __len__(self):
    return len(self.data) // batch_size + 1

  def __iter__(self):
    current_start = 0
    for i in range(self.__len__()):
      current_start += self.batch_size
      yield self.data[current_start:current_start+self.batch_size]

# train_dataloader = DataLoader(dataset["train"], batch_size=batch_size, collate_fn=collate)
train_loader = Trainloader((dataset.shuffle() if shuffle else dataset)['train'], batch_size=batch_size)


n = next(iter(train_loader))
# print(tokenizer.decode(n['data']))
{key: n[key].shape for key in n}

{'data': torch.Size([7111]),
 'batch_sizes': torch.Size([56]),
 'attention_mask': torch.Size([512, 56]),
 'sorted_indices': torch.Size([512]),
 'unsorted_indices': torch.Size([512])}

In [11]:
import importlib

learning_rate = 0.0001
model_params = {
  'rotator_depth': 3,
  'num_heads': 32,
  'dim': 128,
  'hidden_dim': 32,
  'rotary_denom': 0.5,
  'margin': 8,
}
loss_params = {
  'sampling_word_size': 10
}

class Complex_mse(nn.Module):
  def __init__(self, **kwargs):
    super().__init__()

  def __repr__(self):
      return 'mse'

  def forward(self, predicts, targets):
    diffs = (predicts - targets)
    # print(diffs)
    # diff_square_norms = diffs * diffs.conj()
    diff_square_norms = diffs.real**2 + diffs.imag**2
    # diff_square_norms = torch.clamp(diff_square_norms, min=-.5, max=.5)
    # print(diff_square_norms)
    return {
        'mse': diff_square_norms.real.sum(dim=[1]).mean()
    }

class Complex_mse_with_inverse_norm(Complex_mse):
    def __init__(self, relax=.25, min_dist=1, **kwargs):
        super().__init__()
        self.relax = relax
        self.min_dist = min_dist
        
    def __repr__(self):
        return f'mse_with_inverse_norm(relax={self.relax}, min_dist={self.min_dist})'

    def forward(self, predicts, targets):
        inverse_norm = torch.clamp(1/(predicts.norm(dim=-1)+self.relax) - 1/(self.min_dist+self.relax), min=0)
        return  {
            **super().forward(predicts, targets),
            'inverse_norm': inverse_norm.mean()
        }

import numpy as np
class Complex_triplet_loss(nn.Module):
  def __init__(self, model, sampling_word_size=10, margin=5, distance_metric='l2', **kwargs):
    super().__init__()
    self.sampling_word_size = sampling_word_size
    self.distance_metric = distance_metric
    self.margin = margin
    self.model = model

    if distance_metric != 'l2':
      raise NotImplementedError

  def __repr__(self):
      return f'triplet_loss(margin={self.margin})'

  def forward(self, predicts, targets):
    sampled_word_vecs = self.model.predictor.all_word_embeddings()[np.random.choice(self.model.vocab_size, size=self.sampling_word_size)]

    pos_dist = (predicts - targets).norm(dim=1)
    sampled_dists = self.model.pairwise_distance(targets, sampled_word_vecs, distance=self.distance_metric)
    neg_dist = sampled_dists.min(dim=1).values
      
    return {
        'triplet_loss': torch.clamp(pos_dist-neg_dist+self.margin, min=0).mean()
    }

class Multiplet_loss(nn.Module):
  def __init__(self, model, sampling_word_size=10, margin=5, distance_metric='l2', **kwargs):
    super().__init__()
    self.sampling_word_size = sampling_word_size
    self.distance_metric = distance_metric
    self.margin = margin
    self.model = model

    if distance_metric != 'l2':
      raise NotImplementedError

  def __repr__(self):
      return f'multiplet_loss(margin={self.margin}, sampling_word_size={self.sampling_word_size})'

  def forward(self, predicts, targets):
    sampled_word_vecs = self.model.predictor.all_word_embeddings()[np.random.choice(self.model.vocab_size, size=self.sampling_word_size)]

    pos_dist = (predicts - targets).norm(dim=1) #shape: [batch]
    sampled_dists = self.model.pairwise_distance(targets, sampled_word_vecs, distance=self.distance_metric) 
    neg_dist = sampled_dists.min(dim=1).values
    triplets = (pos_dist[:, None] - neg_dist + self.margin) # shape: [batch, self.sampling_word_size]
    
    return {
        'multiplet_loss': torch.clamp(triplets, min=0).mean()
    }

        
class Complex_mse_triplet_loss(nn.Module):
  def __init__(self, model, sampling_word_size=10, margin=5, distance_metric='l2', **kwargs):
    super().__init__()
    self.triplet = Complex_triplet_loss(model, sampling_word_size, margin, distance_metric)
    self.mse = Complex_mse()

  def __repr__(self):
      return f'{str(self.triplet)}+{str(self.mse)}'

  def forward(self, predicts, targets):
    return {
      **self.triplet(predicts, targets),
      **self.mse(predicts, targets)
    }

class Complex_mse_squared_triplet_loss(nn.Module):
  def __init__(self, model, sampling_word_size=10, margin=5, distance_metric='l2', **kwargs):
    super().__init__()
    self.triplet = Complex_triplet_loss(model, sampling_word_size, margin, distance_metric)
    self.mse = Complex_mse()

  def __repr__(self):
      return f'sq{str(self.triplet)}+{str(self.mse)}'

  def forward(self, predicts, targets):
    return {
      'sqtriplet_loss': self.triplet(predicts, targets)['triplet_loss']**2,
      **self.mse(predicts, targets)
    }

class Mse_multiplet_loss(nn.Module):
  def __init__(self, model, sampling_word_size=10, margin=5, distance_metric='l2', **kwargs):
    super().__init__()
    self.multiplet = Multiplet_loss(model, sampling_word_size, margin, distance_metric)
    self.mse = Complex_mse()

  def __repr__(self):
      return f'{str(self.multiplet)}+{str(self.mse)}'

  def forward(self, predicts, targets):
    return {
      **self.multiplet(predicts, targets),
      **self.mse(predicts, targets)
    }


rotator = importlib.reload(rotator)
model = rotator.Rotator(tokenizer.vocab_size, **model_params).to(device)
# using_loss = Complex_mse()
# using_loss = Complex_triplet_loss(model)
# using_loss = Complex_mse_triplet_loss(model)
# using_loss = Complex_mse_squared_triplet_loss(model)
using_loss = Mse_multiplet_loss(model, **loss_params)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
# # testing model output
# batch = next(iter(train_loader))
# batch = {key: batch[key].to(device) for key in batch}

# gen = model(**batch)
# predicts, targets = [torch.concat(res) for res in zip(*gen)]
# losses = using_loss(predicts, targets)
# loss = sum(losses.values())

# loss_dict = {
#   'total': f"{loss.item(): .6f}",
#   **{key: f"{losses[key].item(): .6f}" for key in losses},
# }
# print(loss_dict)

# del batch, loss
# torch.cuda.empty_cache()

In [13]:
# def backward_hook(module, gin, gout):
#   print(f"{len(gin)=}, {len(gout)=}")
#   print(*[f"{i=}, {gi.shape=}, {gi.mean()=}, {gi.min()=}, {gi.max()=}" for i, gi in enumerate(gin)], sep='\n')
#   print(*[f"{i=}, {go.shape=}, {go.mean()=}, {go.min()=}, {go.max()=}" for i, go in enumerate(gout)], sep='\n')

# model.limas[0].register_backward_hook(backward_hook)

In [14]:
time_string = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime('%Y%m%d.%H:%M:%S')
subfolder_path = Path(f"{time_string}-batch_size_{batch_size}")

os.makedirs(PATH/folder_path/subfolder_path, exist_ok=True)

with open(PATH/folder_path/subfolder_path/f'parameters.json', 'w') as f:
  f.write(json.dumps({
    'tokenizer_type': tokenizer_type,
    'model': str(model),
    **model_params,
    'learning_rate': learning_rate,
    'batch_size': batch_size,
    'shuffle': shuffle,
    'loss_type': str(using_loss),
    **loss_params,
  }))  
time_string

'20240427.15:29:39'

In [15]:
num_epochs = 1
save_every_n_batches = 5000


# os.makedirs(PATH/folder_path/subfolder_path, exist_ok=True)

loss_history = []

for epoch_num, epoch in enumerate(range(num_epochs)):
  bar = tqdm(train_loader)

  for batch_num, batch in enumerate(bar):

    optimizer.zero_grad()

    batch = {key: batch[key].to(device) for key in batch}

    gen = model(**batch)
    predicts, targets = [torch.concat(res) for res in zip(*gen)]
    losses = using_loss(predicts, targets)
    loss = sum(losses.values())

    loss_dict = {
      'total': f"{loss.item(): .6f}",
      **{key: f"{losses[key].item(): .6f}" for key in losses},
    }
      
    bar.set_postfix(loss_dict)

    # with torch.autograd.detect_anomaly(True):
      # loss.backward()
      # torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10, norm_type=2)
    loss.backward()
    # break 
    optimizer.step()

    if (batch_num % 100 == 0):
      torch.cuda.empty_cache()
      loss_history.append([epoch_num, batch_num, loss_dict])
      # print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, batch {batch_num}")

    if (batch_num % save_every_n_batches == 0):
      prefix = f"batch_{batch_num}-"
      torch.save(model, PATH/folder_path/subfolder_path/(prefix + f'model.pt'))
        
      with open(PATH/folder_path/subfolder_path/'history.json', 'w') as f:
        f.write(json.dumps(loss_history))  

      torch.cuda.empty_cache()

    if (batch_num in [100, 500, 1000, 1500, 2000, 3000, 7500, 12500, 17500, 22500]):
        prefix = f"batch_{batch_num}-"
        torch.save(model, PATH/folder_path/subfolder_path/(prefix + f'model.pt'))
        
        with open(PATH/folder_path/subfolder_path/'history.json', 'w') as f:
          f.write(json.dumps(loss_history))  

    del batch, loss

 12%|███                       | 16789/144540 [1:41:19<10:30:41,  3.38it/s, total=73.347336, multiplet_loss=1.492947, mse=71.854385]Token indices sequence length is longer than the specified maximum sequence length for this model (1119 > 512). Running this sequence through the model will result in indexing errors
 19%|████▉                     | 27701/144540 [2:53:31<12:11:55,  2.66it/s, total=81.986778, multiplet_loss=7.776369, mse=74.210411]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.82 GiB. GPU 5 has a total capacity of 15.78 GiB of which 881.75 MiB is free. Process 6126 has 14.92 GiB memory in use. Of the allocated memory 10.18 GiB is allocated by PyTorch, and 3.85 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)