In [1]:
from datasets import load_dataset
from transformers import BertTokenizerFast, BertModel
from embed_layer import Word2Vec, ContextualEmbedding
from e2e import E2E 
from data import SquadDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

In [2]:

dataset = load_dataset("nyu-mll/glue", "qnli")


KeyboardInterrupt: 

In [21]:
dataset['train'][0]['label'].type()

AttributeError: 'int' object has no attribute 'type'

In [3]:
dataset['train'][0]

{'question': 'When did the third Digimon series begin?',
 'sentence': 'Unlike the two seasons before it and most of the seasons that followed, Digimon Tamers takes a darker and more realistic approach to its story featuring Digimon who do not reincarnate after their deaths and more complex character development in the original Japanese.',
 'label': 1,
 'idx': 0}

In [3]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [4]:
random_train = dataset['train'].select(range(2269,12269))
random_val = dataset['validation'].select(range(2269,3269))
random_test = dataset['validation'].select(range(3269,4269))

In [5]:
train_data = SquadDataset(random_train, 32, tokenizer)
validation_data = SquadDataset(random_val, 32, tokenizer)

In [7]:
tokenizer.vocab_size

30522

In [8]:
class BiDAF(nn.Module):
	def __init__(self, vocab_size, embed_size, hidden_size, c_len, BERT=False):
		super(BiDAF, self).__init__()
		self.w2v = Word2Vec(vocab_size, embed_size, BERT) # vocab_size, embed_size
		self.acontext = ContextualEmbedding(embed_size, hidden_size) # embed_size, hidden_size
		self.e2e = E2E(hidden_size, c_len) # hidden_size, c_len
	
	def forward(self, q, c):
		q = self.w2v(q)
		c = self.w2v(c)
		q = self.acontext(q)
		c = self.acontext(c)
		return self.e2e(q, c)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
BERT = False

In [9]:
if BERT:
	model = BiDAF(vocab_size=tokenizer.vocab_size, embed_size=768, hidden_size=384, c_len=96, BERT=True).to(device)
	model = torch.compile(model)
else:
	model = BiDAF(vocab_size=tokenizer.vocab_size, embed_size=100, hidden_size=196, c_len=96).to(device)
	model = torch.compile(model)



In [10]:
optimizer = optim.Adadelta(model.parameters(), lr = 0.5, weight_decay=0.001)

In [11]:
critereon = nn.CrossEntropyLoss().to('cuda')

In [22]:
def train(model, train_data, optimizer, critereon, epochs=1):
	start = time.time()
	for epoch in range(epochs):
		model.train()
		running_loss = 0.0
		for questions, contexts, labels in train_data:
			optimizer.zero_grad()
			questions = questions.to(device)
			contexts = contexts.to(device)
			labels = labels.long().to(device)
			# with torch.autocast(device_type=device, dtype=torch.bfloat16):
			output = model(questions, contexts)
			loss = critereon(output, labels)
			loss.backward()
			optimizer.step()
			torch.cuda.synchronize()
			running_loss += loss.item()
		print(f"Epoch: {epoch}, Loss: {running_loss/len(train_data)}")
	end = time.time()
	print(f"Training time: {end-start}")

In [24]:
train(model, train_data, optimizer, critereon, epochs=5)

KeyboardInterrupt: 

In [None]:
def evaluation(model, val_data, critereon):
	model.eval()
	running_loss = 0.0
	# with torch.no_grad():
	for questions, contexts, labels in val_data:
			questions = questions.to(device)
			contexts = contexts.to(device)
			labels = labels.long().to(device)
			with torch.no_grad():
				# with torch.autocast(device_type=device, dtype=torch.float16):
				output = model(questions, contexts)
				loss = critereon(output, labels)
				running_loss += loss.item()
	print(f"Validation Loss: {running_loss/len(val_data)}")

In [None]:
evaluation(model, validation_data, critereon)

In [25]:
def accuracy(model, val_data):
	model.eval()
	correct = 0
	total = 0
	# with torch.no_grad():
	for questions, contexts, labels in val_data:
			questions = questions.to(device)
			contexts = contexts.to(device)
			labels = labels.long().to(device)
			with torch.no_grad():
				# with torch.autocast(device_type=device, dtype=torch.bfloat16):
				output = model(questions, contexts)
				_, predicted = torch.max(output, 1)
				total += labels.size(0)
				correct += (predicted == labels).sum().item()
	print(f"Accuracy: {100*correct/total}")

In [None]:
accuracy(model, validation_data)