In [None]:
import torch

from torch.utils.data import DataLoader

from Encoder import PerSec
from quantizer import GumbelVectorQuantizer
from custom_dataset import ContrastiveLearningDataset
from contrastive_loss import ContrastiveLoss

In [None]:
# Most recent forward outputs captured by the hooks below, keyed by layer name.
activations = {}


def get_activation(name):
	"""Return a forward hook that stores a clone of the layer output under `activations[name]`."""
	def _store_output(module, inputs, result):
		# clone() is differentiable, so the stored copy stays attached to the autograd graph.
		activations[name] = result.clone()
	return _store_output

In [None]:
# Seed the global RNG before any module is constructed so initial weights are repeatable.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `PerSec` is rebound below from the imported class to the model instance.
# Resolve the class from either form so that re-running this cell builds a
# fresh encoder instead of calling the existing instance with no input.
persec_cls = PerSec if isinstance(PerSec, type) else type(PerSec)
PerSec = persec_cls().to(device)
stroke_quantizer = GumbelVectorQuantizer(extracted_feature_size=64, num_groups=2, num_vectors=256, temperature=0.5).to(device)
semantic_quantizer = GumbelVectorQuantizer(extracted_feature_size=512, num_groups=2, num_vectors=256, temperature=0.5).to(device)

# Record the outputs of these sub-modules on every forward pass; they are read
# back from `activations` under the same names.
PerSec.stage1.register_forward_hook(get_activation('stage1'))
PerSec.stroke_context_aggregator.register_forward_hook(get_activation('stroke_context_aggregator'))
PerSec.stage4.register_forward_hook(get_activation('stage4'))

In [None]:
# Probe pass: push one dummy image through the encoder so the forward hooks
# record the stage1 / stage4 outputs, whose shapes give the sequence lengths.
random_data = torch.randn(1, 3, 32, 384).to(device)
with torch.no_grad():  # only the shapes are needed, so skip building an autograd graph
	PerSec(random_data)

# Fraction of positions to mask at the low (stage1) and high (stage4) level.
p_low = 0.2
p_high = 0.15

# Number of positions = product of dims 2 and 3 of the captured outputs.
LOW_LENGTH = activations['stage1'].shape[2] * activations['stage1'].shape[3]
HIGH_LENGTH = activations['stage4'].shape[2] * activations['stage4'].shape[3]

# These counts are used as a tensor dimension for torch.randint, so they must be
# integers; `length * fraction` alone is a float and is rejected as a size.
MASK_NUM_LOW = round(LOW_LENGTH * p_low)
MASK_NUM_HIGH = round(HIGH_LENGTH * p_high)

In [None]:
BATCH_SIZE = 64
IMAGE_DIR = "data/images"

# A single Adam optimiser jointly updates the encoder and both quantizers.
optimizer = torch.optim.Adam(
	[
		*PerSec.parameters(),
		*stroke_quantizer.parameters(),
		*semantic_quantizer.parameters(),
	],
	lr=1e-3,
)
dataset = ContrastiveLearningDataset(IMAGE_DIR)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
contrastive_loss = ContrastiveLoss()

In [None]:
torch.manual_seed(0)
epochs = 100

# Weights of the stroke / semantic quantizer entropy terms in the total loss.
alpha = 0.2
beta = 0.1

# Loss-curve samples: the global optimisation step and the loss logged at that step.
steps = []
loss_value = []
step_count = 0

# torch.randint needs integer sizes; coerce once here so the loop also works
# when the mask counts were computed as floats (length * fraction).
num_mask_low = round(MASK_NUM_LOW)
num_mask_high = round(MASK_NUM_HIGH)

for epoch in range(epochs):
	for i, batch in enumerate(dataloader):
		batch = batch.to(device)
		optimizer.zero_grad()

		# Size the masks from the actual batch: the last batch of an epoch can be
		# smaller than BATCH_SIZE because the DataLoader does not drop it.
		current_batch_size = batch.shape[0]
		mask_low = torch.randint(0, LOW_LENGTH, (current_batch_size, num_mask_low), device=device)
		mask_high = torch.randint(0, HIGH_LENGTH, (current_batch_size, num_mask_high), device=device)
		# semantic context aggregator's output
		SECP_activation = PerSec(batch, mask_low, mask_high)

		# The forward hooks refreshed `activations` during the call above.
		stroke_context_quantized, stroke_quantizer_entropy = stroke_quantizer(activations["stage1"])
		stroke_contrastive_loss = contrastive_loss(activations["stroke_context_aggregator"], stroke_context_quantized, mask_low)

		semantic_context_quantized, semantic_quantizer_entropy = semantic_quantizer(activations["stage4"])
		semantic_contrastive_loss = contrastive_loss(SECP_activation, semantic_context_quantized, mask_high)

		loss = stroke_contrastive_loss + alpha * stroke_quantizer_entropy + semantic_contrastive_loss + beta * semantic_quantizer_entropy
		if i % 100 == 0:
			# Read the scalar once; each .item() call forces a device-to-host sync.
			current_loss = loss.item()
			print(f"Epoch: {epoch+1}, Batch: {i}, Loss: {current_loss}")
			steps.append(step_count)
			loss_value.append(current_loss)
		loss.backward()
		optimizer.step()
		# Count every optimisation step so the recorded x-axis is the true global
		# step, even when an epoch is not a multiple of 100 batches.
		step_count += 1

In [None]:
from matplotlib import pyplot as plt

# Plot the loss samples collected during pretraining against their step index.
fig, ax = plt.subplots()
ax.plot(steps, loss_value, label='Training Loss')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Pretraining loss')
ax.legend()  # without this the `label` passed to plot() is never displayed
plt.show()

In [None]:
import os

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before writing into it.
os.makedirs("model", exist_ok=True)

# Persist the weights of the encoder and both quantizers.
torch.save(PerSec.state_dict(), "model/PerSec.pth")
torch.save(stroke_quantizer.state_dict(), "model/stroke_quantizer.pth")
torch.save(semantic_quantizer.state_dict(), "model/semantic_quantizer.pth")

In [None]:
# Decoder training
import torch
from torch.utils.data import DataLoader, random_split, ConcatDataset

import os
import pickle
from PIL import Image
from torchmetrics.text import CharErrorRate

from Encoder import PerSec
from Decoder import LSTMAttnDecoder

from custom_dataset import DecoderDataset

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
CHAR_TO_TOKEN_FILE = "char_to_token.pkl"
TOKEN_TO_CHAR_FILE = "token_to_char.pkl"

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point these paths at vocabulary files from a trusted source.
with open(CHAR_TO_TOKEN_FILE, "rb") as f:
	char_to_token = pickle.load(f)

with open(TOKEN_TO_CHAR_FILE, "rb") as f:
	token_to_char = pickle.load(f)


def tokenizer(text):
	"""Map every element of `text` to its token via `char_to_token`.

	Each element must provide `.item()`; the value it returns is the lookup key.
	NOTE(review): presumably `text` is a tensor/array of character codes — confirm with callers.
	"""
	tokens = []
	for symbol in text:
		tokens.append(char_to_token[symbol.item()])
	return tokens

In [None]:
torch.manual_seed(0)

# `PerSec` is rebound below from the imported class to the model instance.
# Resolve the class from either form so this cell can be re-run on its own.
persec_cls = PerSec if isinstance(PerSec, type) else type(PerSec)
PerSec = persec_cls().to(device)

# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU can still be restored on a CPU-only machine.
PerSec.load_state_dict(torch.load("model/PerSec.pth", map_location=device))

# The decoder's output size equals the number of entries in the vocabulary.
Decoder = LSTMAttnDecoder(hidden_size=512, output_size=len(token_to_char)).to(device)

In [None]:
# Seed the global RNG so the train/validation split below is repeatable.
torch.manual_seed(0)

NATURAL_DATASET_FOLDER = "data/images"
NATURAL_DATASET_CSV = "data/texts"
GENERATED_DATASET_FOLDER = "data/generated_images"
GENERATED_DATASET_CSV = "data/generated_texts"
natural_dataset = DecoderDataset(csv_path=NATURAL_DATASET_CSV, image_dir=NATURAL_DATASET_FOLDER, token_dict=char_to_token)
# The labels come from the generated-text path, not from the image folder.
generated_dataset = DecoderDataset(csv_path=GENERATED_DATASET_CSV, image_dir=GENERATED_DATASET_FOLDER, token_dict=char_to_token)

# Bind the combined dataset to `dataset` explicitly; otherwise the split below
# would pick up a stale `dataset` left in the kernel by an earlier cell (or fail
# with a NameError on a fresh kernel).
dataset = ConcatDataset([natural_dataset, generated_dataset])

# 90 % / 10 % train / validation split.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [None]:
# Loss, optimiser and metric objects for the decoder stage.
cross_entropy = torch.nn.CrossEntropyLoss()
# The encoder and the decoder are updated together by one Adam optimiser.
optimizer = torch.optim.Adam(
	[*PerSec.parameters(), *Decoder.parameters()],
	lr=1e-3,
)
cer = CharErrorRate()

In [None]:
# Epoch count; presumably for the decoder training loop, which is not visible here — TODO confirm.
epochs = 100
# Token stored under the '<SOS>' key of the vocabulary (presumably the start-of-sequence marker).
SOS_token = char_to_token['<SOS>']