In [4]:
import os
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np

In [5]:
class CustomDataset(Dataset):
    """Map-style dataset pairing each row's ``hpo_def`` text (source) with its ``def`` text (target).

    Items are tokenized lazily on access, padded/truncated to ``max_length``
    tokens, and returned as framework tensors (``return_tensors="pt"``).

    Args:
        data: table accessed positionally via ``.iloc``; must provide the
            columns ``'hpo_def'`` and ``'def'``.
        tokenizer: callable tokenizer accepting ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments.
        max_length: token length every encoded sequence is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        # Identical settings for source and target; with a Hugging Face tokenizer
        # and a single string this is presumably a (1, max_length) tensor per key.
        return self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return ``(source_encoding, target_encoding)`` for row ``idx``."""
        row = self.data.iloc[idx]
        source, target = row['hpo_def'], row['def']
        return self._encode(source), self._encode(target)

# --- Configuration: values a reader is likely to tune. ---
DATA_PATH = './source_data/orphadef_hpodef_db.csv'
MODEL_DIR = "../flanT5/"
MAX_LENGTH = 1024      # every source/target is padded/truncated to this many tokens
GPU_INDEX = 3          # preferred CUDA device ordinal; falls back to 0 if it does not exist
LEARNING_RATE = 0.001  # NOTE(review): high for full AdamW fine-tuning of T5 (1e-4..3e-4 is more common) — confirm
SEED = 42
batch_size = 2

# Seed torch's default generator; the shuffling DataLoader below draws from it,
# so batch order becomes reproducible across runs.
torch.manual_seed(SEED)

# Rows with a missing source or target load as NaN floats, which the tokenizer
# rejects mid-epoch; drop them up front instead. `.iloc` indexing in the
# dataset is positional, so the resulting index gaps are harmless.
train_data = pd.read_csv(DATA_PATH).dropna(subset=['hpo_def', 'def'])
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = T5ForConditionalGeneration.from_pretrained(MODEL_DIR)

train_dataset = CustomDataset(train_data, tokenizer, max_length=MAX_LENGTH)

# A fixed "cuda:3" makes model.to(device) fail with an invalid device ordinal
# on machines with fewer than four GPUs; use the first GPU in that case.
if torch.cuda.is_available():
    gpu_ordinal = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_ordinal}")
else:
    device = torch.device("cpu")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)


In [None]:
# Fine-tune `model` for `epochs` passes over `train_loader`.
# Uses `model`, `optimizer`, `train_loader`, `tokenizer` and `device` from the setup cell.
epochs = 50
log_every = 50    # print a running-average line every this many batches
save_path = None  # e.g. './flan-model.pth' to write a checkpoint after every epoch

for epoch in range(epochs):
    total_loss = 0.0
    model.train()
    count = 0
    for input_sequences, target_sequences in train_loader:
        count += 1
        # squeeze(1) drops the extra dim left by per-item return_tensors="pt"
        # (presumably (batch, 1, max_length) -> (batch, max_length)).
        input_ids = input_sequences['input_ids'].squeeze(1).to(device)
        attention_mask = input_sequences['attention_mask'].squeeze(1).to(device)
        target_ids = target_sequences['input_ids'].squeeze(1).to(device)
        # Pad positions must be -100 (CrossEntropyLoss's default ignore_index) so
        # they are excluded from the loss; with raw pad ids the model is mostly
        # trained to emit padding.
        labels = target_ids.masked_fill(target_ids == tokenizer.pad_token_id, -100)
        optimizer.zero_grad()
        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # mean() reduces per-replica losses if the model is wrapped in
        # DataParallel; it is a no-op on a scalar loss.
        loss = output.loss.mean()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        average_loss = total_loss / count
        if count % log_every == 0:
            print('epochs:', epoch, 'batch:', count, "Average Loss:", average_loss)
    print('epochs:', epoch, 'batches:', count, "Epoch Average Loss:", total_loss / max(count, 1))
    # Unwrap DataParallel so saved keys lack the "module." prefix.
    state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
    if save_path:
        torch.save(state_dict, save_path)
  