In [None]:
# Re-import changed project modules (e.g. detecter.*) before every cell runs,
# so edits to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import logging

# Notebook-wide logger; records at INFO and above are emitted.
logger = logging.getLogger("train")
logger.setLevel(logging.INFO)

# Re-running this cell must not stack a second handler (every line would print
# twice). hasHandlers() also inspects ancestor loggers, so when e.g. the root
# logger is already configured we rely on propagation instead of adding one.
needs_handler = not logger.hasHandlers()
if needs_handler:
	stderr = logging.StreamHandler()  # defaults to sys.stderr
	stderr.setLevel(logging.INFO)
	logger.addHandler(stderr)

logger.info("hello world")


In [None]:
from detecter import config

# Runtime overrides for the project-wide config module, applied before the
# datasets/model below are built.
# NOTE(review): assumes detecter reads these attributes lazily (at use time) -- confirm.
CONFIG_OVERRIDES = {
	"NUM_CORE": 4,  # presumably a worker/process count -- TODO confirm
	"WORD2VEC_BATCH_SIZE": 128,
}
for key, value in CONFIG_OVERRIDES.items():
	setattr(config, key, value)


In [None]:
import random

import torch
from torch.utils import data

from detecter.dataset import OJClone

# Seed the RNGs so the shuffled training order is reproducible. DataLoader's
# default sampler draws from torch's global RNG. Model initialisation in the
# next cell presumably draws from the same RNG, so it becomes deterministic too.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

DATA_DIR = "dataset/OJClone"
batch_size = 1

# One name per split, instead of rebinding a single `ds` variable.
train_ds = OJClone.BiDataSet(OJClone.DataSet(f"{DATA_DIR}/train.jsonl"))
loader = data.DataLoader(train_ds, batch_size=batch_size, collate_fn=OJClone.collate_fn, shuffle=True)

valid_ds = OJClone.BiDataSet(OJClone.DataSet(f"{DATA_DIR}/valid.jsonl"))
v_loader = data.DataLoader(valid_ds, batch_size=batch_size, collate_fn=OJClone.collate_fn)


In [None]:
import torch

from detecter.train import Trainer
from detecter.model import AstAttention, Classifier

# Encoder. The positional sizes (384, 768) are passed straight to AstAttention.
# TODO: document which is the input width and which is the output/hidden width.
model = AstAttention(384, 768, num_layers=6, num_heads=8).cuda()
# 2-way head over the 768-wide encoder output (presumably clone / not-clone).
classifier = Classifier(768, 2).cuda()
trainer = Trainer(model=model, classifier=classifier).cuda()

# Separate parameter groups: the encoder gets a small LR plus weight decay,
# while the freshly initialised head gets a 10x larger LR and AdamW's default decay.
# NOTE(review): only model/classifier params are optimised. Any parameters owned
# by Trainer itself are not included -- confirm that is intended.
encoder_group = {"params": model.parameters(), "lr": 3e-5, "weight_decay": 0.1}
head_group = {"params": classifier.parameters(), "lr": 3e-4}
optimizer = torch.optim.AdamW([encoder_group, head_group])

In [None]:
# Restore the best weights saved so far (log/model.pt), if any.
# `min_loss` is the best validation loss seen so far. The training loop only
# overwrites log/model.pt when it improves on this value.
try:
	with open("log/model.pt", "rb") as f:
		save = torch.load(f)
	model_keys = model.load_state_dict(save["model_state_dict"], strict=False)
	classifier_keys = classifier.load_state_dict(save["classifier_state_dict"], strict=False)
	min_loss = save["loss"]
	# strict=False tolerates architecture drift. Report it rather than ignore it silently.
	for part, result in (("model", model_keys), ("classifier", classifier_keys)):
		if result.missing_keys or result.unexpected_keys:
			logger.warning(
				"%s: missing keys %s, unexpected keys %s",
				part, result.missing_keys, result.unexpected_keys,
			)
	logger.info("loaded log/model.pt (best loss %s)", min_loss)
except IOError:
	logger.info("no model")
	# Only reset when nothing was loaded. Previously this ran unconditionally and
	# discarded the stored best loss, so the next epoch always overwrote model.pt.
	min_loss = float("inf")

In [None]:
# Resume the latest training state (trainer weights, optimizer state, epoch
# counter) if a checkpoint exists. Otherwise start from epoch 1.
# NOTE(review): if Trainer registers `model`/`classifier` as submodules, this
# replaces the best weights loaded from log/model.pt with the latest ones -- confirm intended.
try:
	with open("log/trainer.ckpt", "rb") as f:
		save = torch.load(f)
	trainer_keys = trainer.load_state_dict(save["trainer_state_dict"], strict=False)
	optimizer.load_state_dict(save["optimizer_state_dict"])
	epoch = save["epoch"]
	# strict=False would otherwise hide missing or renamed parameters.
	if trainer_keys.missing_keys or trainer_keys.unexpected_keys:
		logger.warning(
			"trainer: missing keys %s, unexpected keys %s",
			trainer_keys.missing_keys, trainer_keys.unexpected_keys,
		)
	logger.info("resuming from log/trainer.ckpt at epoch %s", epoch)
except IOError:
	epoch = 1
	logger.info("no ckpt")


In [None]:
import os

import torch
from tqdm.auto import tqdm

# None = train until the cell is interrupted (the original behaviour).
MAX_EPOCHS = None
# True = run a single train batch and a single valid batch per epoch, as a
# quick end-to-end check of the pipeline.
SMOKE_TEST = False

# torch.save(..., open(path, "wb")) fails if the directory does not exist.
os.makedirs("log", exist_ok=True)

while MAX_EPOCHS is None or epoch <= MAX_EPOCHS:
	logger.info("epoch {}".format(epoch))

	# Training: a real gradient step per batch. The forward pass must run with
	# autograd enabled, and grads are cleared before each backward.
	trainer.train()
	for batch in tqdm(loader, desc="train"):
		optimizer.zero_grad()
		batch_loss = trainer(batch)
		batch_loss.backward()
		optimizer.step()
		if SMOKE_TEST:
			break
	# NOTE(review): evaluate() appears to summarise (and presumably reset) the
	# statistics accumulated by the forward calls above -- confirm in detecter.train.
	logger.info("train loss %s", trainer.evaluate())

	# Validation: no gradients are needed.
	trainer.eval()
	with torch.no_grad():
		for batch in tqdm(v_loader, desc="valid"):
			trainer(batch)
			if SMOKE_TEST:
				break
	valid_loss = trainer.evaluate()
	logger.info("valid loss %s", valid_loss)

	# Latest state, with the keys the resume cell reads. Store the *next*
	# epoch so that resuming does not repeat the epoch that just finished.
	with open("log/trainer.ckpt", "wb") as f:
		torch.save({
			"trainer_state_dict": trainer.state_dict(),
			"optimizer_state_dict": optimizer.state_dict(),
			"epoch": epoch + 1,
		}, f)

	# Best weights so far, with the keys the model-loading cell reads.
	if valid_loss < min_loss:
		min_loss = valid_loss
		with open("log/model.pt", "wb") as f:
			torch.save({
				"model_state_dict": model.state_dict(),
				"classifier_state_dict": classifier.state_dict(),
				"loss": valid_loss,
			}, f)
		logger.info("new best model saved (loss %s)", valid_loss)

	epoch += 1