In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import Adafactor

In [2]:
model_size = "google/mt5-base"

BATCH_SIZE = 8

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
train_path = "../data/train/"
test_path = "../data/test/"
val_path = "../data/val/"

In [4]:
def clean_german_text(text):
    return ";".join(text.split(";")[1:]).strip()


def clean_english_text(text):
    return text.strip()

In [5]:
def load_data(path):
    data_list = []
    
    en_articles = [clean_english_text(item) for item in open(path + "article", "r")]
    en_highlights = [clean_english_text(item)  for item in open(path + "highlights", "r")]
    de_articles = [clean_german_text(item)  for item in open(path + "articles_german", "r")]
    de_highlights = [clean_german_text(item)  for item in open(path + "highlights_german", "r")]
    
    for en_a, en_h, de_a, de_h in zip(en_articles, en_highlights, de_articles, de_highlights):
        data_list.append((en_a, en_h, de_a, de_h))
    
    return data_list
    
train_data = load_data(train_path)

In [6]:
test_data = load_data(test_path)
val_data = load_data(val_path)

In [7]:
from transformers import ( 
    T5Tokenizer, 
    MT5ForConditionalGeneration
)

In [8]:
tokenizer = T5Tokenizer.from_pretrained(model_size)
en_de_prefix = tokenizer("summarize: en_to_ger ", return_tensors="pt").input_ids.reshape(-1,) * torch.ones(BATCH_SIZE, dtype=int).reshape(-1,1)
de_en_prefix = tokenizer("summarize: ger_to_en ", return_tensors="pt").input_ids.reshape(-1,) * torch.ones(BATCH_SIZE, dtype=int).reshape(-1,1)
en_en_prefix = tokenizer("summarize: en_to_en ", return_tensors="pt").input_ids.reshape(-1,) * torch.ones(BATCH_SIZE, dtype=int).reshape(-1,1)
de_de_prefix = tokenizer("summarize: ger_to_ger ", return_tensors="pt").input_ids.reshape(-1,) * torch.ones(BATCH_SIZE, dtype=int).reshape(-1,1)

In [9]:
en_de_prefix.shape

torch.Size([8, 9])

In [10]:
if en_de_prefix[0][-1] == 1: 
    en_de_prefix = en_de_prefix[:,:-1]
    de_en_prefix = de_en_prefix[:,:-1]
    en_en_prefix = en_en_prefix[:,:-1]
    de_de_prefix = de_de_prefix[:,:-1]

In [11]:
en_de_prefix

tensor([[196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198],
        [196098,  10701,    267,    289,    290,    476,    290,   2198]])

In [12]:
len_prefix = en_de_prefix.shape[0] 
print(len_prefix)

8


In [13]:
en_de_prefix.shape

torch.Size([8, 8])

In [14]:
ones = torch.ones((BATCH_SIZE, len_prefix), dtype=int)

In [15]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.x = dataset
        
        
    def __getitem__(self, index):
        x_en = tokenizer(self.x[index][0], max_length=512 - len_prefix, return_tensors="pt", truncation=True, padding='max_length')
        y_en = tokenizer(self.x[index][1], max_length=150, return_tensors="pt", truncation=True, padding='max_length')
        x_de = tokenizer(self.x[index][2], max_length=512 - len_prefix, return_tensors="pt", truncation=True, padding='max_length')
        y_de = tokenizer(self.x[index][3], max_length=150, return_tensors="pt", truncation=True, padding='max_length')
        
        return x_en.input_ids.reshape(-1), x_en.attention_mask.reshape(-1), y_en.input_ids.reshape(-1), x_de.input_ids.reshape(-1), x_de.attention_mask.reshape(-1), y_de.input_ids.reshape(-1)
    
    def __len__(self):
        return len(self.x)

In [16]:
train_ds = MyDataset(train_data)
test_ds = MyDataset(test_data)
val_ds = MyDataset(val_data)

In [17]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

In [18]:
model = MT5ForConditionalGeneration.from_pretrained(model_size).to(device)
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    eps=(1e-30, 1e-3),
    clip_threshold=1.0,
    decay_rate=-0.8,
    beta1=None,
    weight_decay=0.0,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False
)

In [19]:
def get_reandom_dataset(data):
    rand_int = np.random.randint(4)
    x_en_input_ids, x_en_attention_mask, y_en_input_ids, x_de_input_ids, x_de_attention_mask, y_de_input_ids = data
    print(en_de_prefix.shape, x_en_input_ids.shape)
    
    if rand_int == 0:
        return torch.cat([en_de_prefix, x_en_input_ids], dim=1).to(device), torch.cat([ones, x_en_attention_mask], dim=1).to(device), y_de_input_ids.to(device)
    elif rand_int == 1:
        return torch.cat([en_de_prefix, x_en_input_ids], dim=1).to(device), torch.cat([ones, x_en_attention_mask], dim=1).to(device), y_en_input_ids.to(device)
    elif rand_int == 2:
        return torch.cat([de_de_prefix, x_de_input_ids], dim=1).to(device), torch.cat([ones, x_de_attention_mask], dim=1).to(device), y_de_input_ids.to(device)
    elif rand_int == 3:
        return torch.cat([de_en_prefix, x_de_input_ids], dim=1).to(device), torch.cat([ones, x_de_attention_mask], dim=1).to(device), y_en_input_ids.to(device)
    

In [20]:
print_every = 1000

### Training loop ###
for data in tqdm(train_loader, desc="Training"):
    optimizer.zero_grad()
    input_ids, attention_mask, y = get_reandom_dataset(data)
    loss = model(input_ids, attention_mask, y).loss
    train_loss.append(loss.item())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()

    if epoch_step % print_every == 0:
        print(f"Training step {epoch_step} Accuracy: {train_accuracy}, Training loss: {train_loss}")

Training:   0%|          | 0/35890 [00:00<?, ?it/s]

torch.Size([8, 8]) torch.Size([8, 504])


RuntimeError: CUDA out of memory. Tried to allocate 30.00 MiB (GPU 0; 10.76 GiB total capacity; 9.14 GiB already allocated; 31.69 MiB free; 9.23 GiB reserved in total by PyTorch)